307 lines
6.5 KiB
Go
307 lines
6.5 KiB
Go
|
//
|
||
|
// schema.go
|
||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||
|
//
|
||
|
// Distributed under terms of the MIT license.
|
||
|
//
|
||
|
|
||
|
package orm
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"database/sql/driver"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
)
|
||
|
|
||
|
func getConnectionFor(e Entity) *connection {
|
||
|
configurator := newEntityConfigurator()
|
||
|
e.ConfigureEntity(configurator)
|
||
|
|
||
|
if len(globalConnections) > 1 && (configurator.connection == "" || configurator.table == "") {
|
||
|
panic("need table and DB name when having more than 1 DB registered")
|
||
|
}
|
||
|
if len(globalConnections) == 1 {
|
||
|
for _, db := range globalConnections {
|
||
|
return db
|
||
|
}
|
||
|
}
|
||
|
if db, exists := globalConnections[fmt.Sprintf("%s", configurator.connection)]; exists {
|
||
|
return db
|
||
|
}
|
||
|
panic("no db found")
|
||
|
}
|
||
|
|
||
|
func getSchemaFor(e Entity) *schema {
|
||
|
configurator := newEntityConfigurator()
|
||
|
c := getConnectionFor(e)
|
||
|
e.ConfigureEntity(configurator)
|
||
|
s := c.getSchema(configurator.table)
|
||
|
if s == nil {
|
||
|
s = schemaOfHeavyReflectionStuff(e)
|
||
|
c.setSchema(e, s)
|
||
|
}
|
||
|
return s
|
||
|
}
|
||
|
|
||
|
type schema struct {
|
||
|
Connection string
|
||
|
Table string
|
||
|
fields []*field
|
||
|
relations map[string]interface{}
|
||
|
setPK func(o Entity, value interface{})
|
||
|
getPK func(o Entity) interface{}
|
||
|
columnConstraints []*FieldConfigurator
|
||
|
}
|
||
|
|
||
|
func (s *schema) getField(sf reflect.StructField) *field {
|
||
|
for _, f := range s.fields {
|
||
|
if sf.Name == f.Name {
|
||
|
return f
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *schema) getDialect() *Dialect {
|
||
|
return GetConnection(s.Connection).Dialect
|
||
|
}
|
||
|
func (s *schema) Columns(withPK bool) []string {
|
||
|
var cols []string
|
||
|
for _, field := range s.fields {
|
||
|
if field.Virtual {
|
||
|
continue
|
||
|
}
|
||
|
if !withPK && field.IsPK {
|
||
|
continue
|
||
|
}
|
||
|
if s.getDialect().AddTableNameInSelectColumns {
|
||
|
cols = append(cols, s.Table+"."+field.Name)
|
||
|
} else {
|
||
|
cols = append(cols, field.Name)
|
||
|
}
|
||
|
}
|
||
|
return cols
|
||
|
}
|
||
|
|
||
|
func (s *schema) pkName() string {
|
||
|
for _, field := range s.fields {
|
||
|
if field.IsPK {
|
||
|
return field.Name
|
||
|
}
|
||
|
}
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
func genericFieldsOf(obj Entity) []*field {
|
||
|
t := reflect.TypeOf(obj)
|
||
|
for t.Kind() == reflect.Ptr {
|
||
|
t = t.Elem()
|
||
|
|
||
|
}
|
||
|
if t.Kind() == reflect.Slice {
|
||
|
t = t.Elem()
|
||
|
for t.Kind() == reflect.Ptr {
|
||
|
t = t.Elem()
|
||
|
}
|
||
|
}
|
||
|
var ec EntityConfigurator
|
||
|
obj.ConfigureEntity(&ec)
|
||
|
|
||
|
var fms []*field
|
||
|
for i := 0; i < t.NumField(); i++ {
|
||
|
ft := t.Field(i)
|
||
|
fm := fieldMetadata(ft, ec.columnConstraints)
|
||
|
fms = append(fms, fm...)
|
||
|
}
|
||
|
return fms
|
||
|
}
|
||
|
|
||
|
// valuesOfField flattens vf into the list of values to bind as SQL
// arguments, one per column:
//
//   - a value whose type implements driver.Valuer is passed through as-is;
//   - any other struct is descended into field by field;
//   - a pointer (possibly multi-level) to a struct that is not a Valuer is
//     dereferenced and descended into, and a nil one contributes one nil per
//     column its struct type would produce;
//   - a pointer to a non-struct (e.g. *string) is passed through as a single
//     value, since database/sql's default parameter converter dereferences
//     pointers and sends NULL for nil;
//   - everything else is passed through as a single value.
//
// Previously any pointer that was not a Valuer reached vf.NumField() on the
// pointer Value itself, which always panics.
func valuesOfField(vf reflect.Value) []interface{} {
	valuerType := reflect.TypeOf((*driver.Valuer)(nil)).Elem()

	switch vf.Kind() {
	case reflect.Struct:
		if vf.Type().Implements(valuerType) {
			return []interface{}{vf.Interface()}
		}
		var values []interface{}
		for i := 0; i < vf.NumField(); i++ {
			values = append(values, valuesOfField(vf.Field(i))...)
		}
		return values

	case reflect.Ptr:
		base := vf.Type().Elem()
		for base.Kind() == reflect.Ptr {
			base = base.Elem()
		}
		if vf.Type().Implements(valuerType) || base.Kind() != reflect.Struct {
			return []interface{}{vf.Interface()}
		}
		if vf.IsNil() {
			// Keep the column count stable: one NULL per column the
			// pointed-to type would have produced.
			values := valuesOfField(reflect.Zero(vf.Type().Elem()))
			for i := range values {
				values[i] = nil
			}
			return values
		}
		return valuesOfField(vf.Elem())

	default:
		return []interface{}{vf.Interface()}
	}
}
|
||
|
func genericValuesOf(o Entity, withPK bool) []interface{} {
|
||
|
t := reflect.TypeOf(o)
|
||
|
v := reflect.ValueOf(o)
|
||
|
if t.Kind() == reflect.Ptr {
|
||
|
t = t.Elem()
|
||
|
v = v.Elem()
|
||
|
}
|
||
|
fields := getSchemaFor(o).fields
|
||
|
pkIdx := -1
|
||
|
for i, field := range fields {
|
||
|
if field.IsPK {
|
||
|
pkIdx = i
|
||
|
}
|
||
|
|
||
|
}
|
||
|
|
||
|
var values []interface{}
|
||
|
|
||
|
for i := 0; i < t.NumField(); i++ {
|
||
|
if !withPK && i == pkIdx {
|
||
|
continue
|
||
|
}
|
||
|
if fields[i].Virtual {
|
||
|
continue
|
||
|
}
|
||
|
vf := v.Field(i)
|
||
|
values = append(values, valuesOfField(vf)...)
|
||
|
}
|
||
|
return values
|
||
|
}
|
||
|
|
||
|
func genericSetPkValue(obj Entity, value interface{}) {
|
||
|
genericSet(obj, getSchemaFor(obj).pkName(), value)
|
||
|
}
|
||
|
|
||
|
func genericGetPKValue(obj Entity) interface{} {
|
||
|
t := reflect.TypeOf(obj)
|
||
|
val := reflect.ValueOf(obj)
|
||
|
if t.Kind() == reflect.Ptr {
|
||
|
val = val.Elem()
|
||
|
}
|
||
|
|
||
|
fields := getSchemaFor(obj).fields
|
||
|
for i, field := range fields {
|
||
|
if field.IsPK {
|
||
|
return val.Field(i).Interface()
|
||
|
}
|
||
|
}
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
func (s *schema) createdAt() *field {
|
||
|
for _, f := range s.fields {
|
||
|
if f.IsCreatedAt {
|
||
|
return f
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
func (s *schema) updatedAt() *field {
|
||
|
for _, f := range s.fields {
|
||
|
if f.IsUpdatedAt {
|
||
|
return f
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *schema) deletedAt() *field {
|
||
|
for _, f := range s.fields {
|
||
|
if f.IsDeletedAt {
|
||
|
return f
|
||
|
}
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
func pointersOf(v reflect.Value) map[string]interface{} {
|
||
|
m := map[string]interface{}{}
|
||
|
actualV := v
|
||
|
for actualV.Type().Kind() == reflect.Ptr {
|
||
|
actualV = actualV.Elem()
|
||
|
}
|
||
|
for i := 0; i < actualV.NumField(); i++ {
|
||
|
f := actualV.Field(i)
|
||
|
if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
|
||
|
fm := pointersOf(f)
|
||
|
for k, p := range fm {
|
||
|
m[k] = p
|
||
|
}
|
||
|
} else {
|
||
|
fm := fieldMetadata(actualV.Type().Field(i), nil)[0]
|
||
|
m[fm.Name] = actualV.Field(i)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return m
|
||
|
}
|
||
|
func genericSet(obj Entity, name string, value interface{}) {
|
||
|
n2p := pointersOf(reflect.ValueOf(obj))
|
||
|
var val interface{}
|
||
|
for k, v := range n2p {
|
||
|
if k == name {
|
||
|
val = v
|
||
|
}
|
||
|
}
|
||
|
val.(reflect.Value).Set(reflect.ValueOf(value))
|
||
|
}
|
||
|
func schemaOfHeavyReflectionStuff(v Entity) *schema {
|
||
|
userEntityConfigurator := newEntityConfigurator()
|
||
|
v.ConfigureEntity(userEntityConfigurator)
|
||
|
for _, relation := range userEntityConfigurator.resolveRelations {
|
||
|
relation()
|
||
|
}
|
||
|
schema := &schema{}
|
||
|
if userEntityConfigurator.connection != "" {
|
||
|
schema.Connection = userEntityConfigurator.connection
|
||
|
}
|
||
|
if userEntityConfigurator.table != "" {
|
||
|
schema.Table = userEntityConfigurator.table
|
||
|
} else {
|
||
|
panic("you need to have table name for getting schema.")
|
||
|
}
|
||
|
|
||
|
schema.columnConstraints = userEntityConfigurator.columnConstraints
|
||
|
if schema.Connection == "" {
|
||
|
schema.Connection = "default"
|
||
|
}
|
||
|
if schema.fields == nil {
|
||
|
schema.fields = genericFieldsOf(v)
|
||
|
}
|
||
|
if schema.getPK == nil {
|
||
|
schema.getPK = genericGetPKValue
|
||
|
}
|
||
|
|
||
|
if schema.setPK == nil {
|
||
|
schema.setPK = genericSetPkValue
|
||
|
}
|
||
|
|
||
|
schema.relations = userEntityConfigurator.relations
|
||
|
|
||
|
return schema
|
||
|
}
|
||
|
|
||
|
func (s *schema) getTable() string {
|
||
|
return s.Table
|
||
|
}
|
||
|
|
||
|
func (s *schema) getSQLDB() *sql.DB {
|
||
|
return s.getConnection().DB
|
||
|
}
|
||
|
|
||
|
func (s *schema) getConnection() *connection {
|
||
|
if len(globalConnections) > 1 && (s.Connection == "" || s.Table == "") {
|
||
|
panic("need table and DB name when having more than 1 DB registered")
|
||
|
}
|
||
|
if len(globalConnections) == 1 {
|
||
|
for _, db := range globalConnections {
|
||
|
return db
|
||
|
}
|
||
|
}
|
||
|
if db, exists := globalConnections[fmt.Sprintf("%s", s.Connection)]; exists {
|
||
|
return db
|
||
|
}
|
||
|
panic("no db found")
|
||
|
}
|