golib/gdb/orm/schema.go

307 lines
6.5 KiB
Go
Raw Normal View History

2023-08-20 13:51:00 +08:00
//
// schema.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package orm
import (
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
)
func getConnectionFor(e Entity) *connection {
configurator := newEntityConfigurator()
e.ConfigureEntity(configurator)
if len(globalConnections) > 1 && (configurator.connection == "" || configurator.table == "") {
panic("need table and DB name when having more than 1 DB registered")
}
if len(globalConnections) == 1 {
for _, db := range globalConnections {
return db
}
}
if db, exists := globalConnections[fmt.Sprintf("%s", configurator.connection)]; exists {
return db
}
panic("no db found")
}
func getSchemaFor(e Entity) *schema {
configurator := newEntityConfigurator()
c := getConnectionFor(e)
e.ConfigureEntity(configurator)
s := c.getSchema(configurator.table)
if s == nil {
s = schemaOfHeavyReflectionStuff(e)
c.setSchema(e, s)
}
return s
}
type schema struct {
Connection string
Table string
fields []*field
relations map[string]interface{}
setPK func(o Entity, value interface{})
getPK func(o Entity) interface{}
columnConstraints []*FieldConfigurator
}
func (s *schema) getField(sf reflect.StructField) *field {
for _, f := range s.fields {
if sf.Name == f.Name {
return f
}
}
return nil
}
func (s *schema) getDialect() *Dialect {
return GetConnection(s.Connection).Dialect
}
func (s *schema) Columns(withPK bool) []string {
var cols []string
for _, field := range s.fields {
if field.Virtual {
continue
}
if !withPK && field.IsPK {
continue
}
if s.getDialect().AddTableNameInSelectColumns {
cols = append(cols, s.Table+"."+field.Name)
} else {
cols = append(cols, field.Name)
}
}
return cols
}
func (s *schema) pkName() string {
for _, field := range s.fields {
if field.IsPK {
return field.Name
}
}
return ""
}
func genericFieldsOf(obj Entity) []*field {
t := reflect.TypeOf(obj)
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() == reflect.Slice {
t = t.Elem()
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
}
var ec EntityConfigurator
obj.ConfigureEntity(&ec)
var fms []*field
for i := 0; i < t.NumField(); i++ {
ft := t.Field(i)
fm := fieldMetadata(ft, ec.columnConstraints)
fms = append(fms, fm...)
}
return fms
}
func valuesOfField(vf reflect.Value) []interface{} {
var values []interface{}
if vf.Type().Kind() == reflect.Struct || vf.Type().Kind() == reflect.Ptr {
t := vf.Type()
if vf.Type().Kind() == reflect.Ptr {
t = vf.Type().Elem()
}
if !t.Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
// go into
// it does not implement driver.Valuer interface
for i := 0; i < vf.NumField(); i++ {
vif := vf.Field(i)
values = append(values, valuesOfField(vif)...)
}
} else {
values = append(values, vf.Interface())
}
} else {
values = append(values, vf.Interface())
}
return values
}
func genericValuesOf(o Entity, withPK bool) []interface{} {
t := reflect.TypeOf(o)
v := reflect.ValueOf(o)
if t.Kind() == reflect.Ptr {
t = t.Elem()
v = v.Elem()
}
fields := getSchemaFor(o).fields
pkIdx := -1
for i, field := range fields {
if field.IsPK {
pkIdx = i
}
}
var values []interface{}
for i := 0; i < t.NumField(); i++ {
if !withPK && i == pkIdx {
continue
}
if fields[i].Virtual {
continue
}
vf := v.Field(i)
values = append(values, valuesOfField(vf)...)
}
return values
}
func genericSetPkValue(obj Entity, value interface{}) {
genericSet(obj, getSchemaFor(obj).pkName(), value)
}
func genericGetPKValue(obj Entity) interface{} {
t := reflect.TypeOf(obj)
val := reflect.ValueOf(obj)
if t.Kind() == reflect.Ptr {
val = val.Elem()
}
fields := getSchemaFor(obj).fields
for i, field := range fields {
if field.IsPK {
return val.Field(i).Interface()
}
}
return ""
}
func (s *schema) createdAt() *field {
for _, f := range s.fields {
if f.IsCreatedAt {
return f
}
}
return nil
}
func (s *schema) updatedAt() *field {
for _, f := range s.fields {
if f.IsUpdatedAt {
return f
}
}
return nil
}
func (s *schema) deletedAt() *field {
for _, f := range s.fields {
if f.IsDeletedAt {
return f
}
}
return nil
}
func pointersOf(v reflect.Value) map[string]interface{} {
m := map[string]interface{}{}
actualV := v
for actualV.Type().Kind() == reflect.Ptr {
actualV = actualV.Elem()
}
for i := 0; i < actualV.NumField(); i++ {
f := actualV.Field(i)
if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
fm := pointersOf(f)
for k, p := range fm {
m[k] = p
}
} else {
fm := fieldMetadata(actualV.Type().Field(i), nil)[0]
m[fm.Name] = actualV.Field(i)
}
}
return m
}
func genericSet(obj Entity, name string, value interface{}) {
n2p := pointersOf(reflect.ValueOf(obj))
var val interface{}
for k, v := range n2p {
if k == name {
val = v
}
}
val.(reflect.Value).Set(reflect.ValueOf(value))
}
func schemaOfHeavyReflectionStuff(v Entity) *schema {
userEntityConfigurator := newEntityConfigurator()
v.ConfigureEntity(userEntityConfigurator)
for _, relation := range userEntityConfigurator.resolveRelations {
relation()
}
schema := &schema{}
if userEntityConfigurator.connection != "" {
schema.Connection = userEntityConfigurator.connection
}
if userEntityConfigurator.table != "" {
schema.Table = userEntityConfigurator.table
} else {
panic("you need to have table name for getting schema.")
}
schema.columnConstraints = userEntityConfigurator.columnConstraints
if schema.Connection == "" {
schema.Connection = "default"
}
if schema.fields == nil {
schema.fields = genericFieldsOf(v)
}
if schema.getPK == nil {
schema.getPK = genericGetPKValue
}
if schema.setPK == nil {
schema.setPK = genericSetPkValue
}
schema.relations = userEntityConfigurator.relations
return schema
}
func (s *schema) getTable() string {
return s.Table
}
func (s *schema) getSQLDB() *sql.DB {
return s.getConnection().DB
}
func (s *schema) getConnection() *connection {
if len(globalConnections) > 1 && (s.Connection == "" || s.Table == "") {
panic("need table and DB name when having more than 1 DB registered")
}
if len(globalConnections) == 1 {
for _, db := range globalConnections {
return db
}
}
if db, exists := globalConnections[fmt.Sprintf("%s", s.Connection)]; exists {
return db
}
panic("no db found")
}