diff --git a/gdb/orm/.gitignore b/gdb/orm/.gitignore new file mode 100644 index 0000000..e702335 --- /dev/null +++ b/gdb/orm/.gitignore @@ -0,0 +1,3 @@ +**/.idea/* +cover.out +**db \ No newline at end of file diff --git a/gdb/orm/binder.go b/gdb/orm/binder.go new file mode 100644 index 0000000..3a65598 --- /dev/null +++ b/gdb/orm/binder.go @@ -0,0 +1,157 @@ +// +// binder.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "unsafe" +) + +// makeNewPointersOf creates a map of [field name] -> pointer to fill it +// recursively. it will go down until reaches a driver.Valuer implementation, it will stop there. +func (b *binder) makeNewPointersOf(v reflect.Value) interface{} { + m := map[string]interface{}{} + actualV := v + for actualV.Type().Kind() == reflect.Ptr { + actualV = actualV.Elem() + } + if actualV.Type().Kind() == reflect.Struct { + for i := 0; i < actualV.NumField(); i++ { + f := actualV.Field(i) + if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) { + f = reflect.NewAt(actualV.Type().Field(i).Type, unsafe.Pointer(actualV.Field(i).UnsafeAddr())) + fm := b.makeNewPointersOf(f).(map[string]interface{}) + for k, p := range fm { + m[k] = p + } + } else { + var fm *field + fm = b.s.getField(actualV.Type().Field(i)) + if fm == nil { + fm = fieldMetadata(actualV.Type().Field(i), b.s.columnConstraints)[0] + } + m[fm.Name] = reflect.NewAt(actualV.Field(i).Type(), unsafe.Pointer(actualV.Field(i).UnsafeAddr())).Interface() + } + } + } else { + return v.Addr().Interface() + } + + return m +} + +// ptrsFor first allocates for all struct fields recursively until reaches a driver.Value impl +// then it will put them in a map with their correct field name as key, then loops over cts +// and for each one gets appropriate one from the map and adds it to pointer list. +func (b *binder) ptrsFor(v reflect.Value, cts []*sql.ColumnType) []interface{} { + ptrs := b.makeNewPointersOf(v) + var scanInto []interface{} + if reflect.TypeOf(ptrs).Kind() == reflect.Map { + nameToPtr := ptrs.(map[string]interface{}) + for _, ct := range cts { + if nameToPtr[ct.Name()] != nil { + scanInto = append(scanInto, nameToPtr[ct.Name()]) + } + } + } else { + scanInto = append(scanInto, ptrs) + } + + return scanInto +} + +type binder struct { + s *schema +} + +func newBinder(s *schema) *binder { + return &binder{s: s} +} + +// bind binds given rows to the given object at obj. obj should be a pointer +func (b *binder) bind(rows *sql.Rows, obj interface{}) error { + cts, err := rows.ColumnTypes() + if err != nil { + return err + } + + t := reflect.TypeOf(obj) + v := reflect.ValueOf(obj) + if t.Kind() != reflect.Ptr { + return fmt.Errorf("obj should be a ptr") + } + // since passed input is always a pointer one deref is necessary + t = t.Elem() + v = v.Elem() + if t.Kind() == reflect.Slice { + // getting slice elemnt type -> slice[t] + t = t.Elem() + for rows.Next() { + var rowValue reflect.Value + // Since reflect.SetupConnections returns a pointer to the type, we need to unwrap it to get actual + rowValue = reflect.New(t).Elem() + // till we reach a not pointer type continue newing the underlying type. + for rowValue.IsZero() && rowValue.Type().Kind() == reflect.Ptr { + rowValue = reflect.New(rowValue.Type().Elem()).Elem() + } + newCts := make([]*sql.ColumnType, len(cts)) + copy(newCts, cts) + ptrs := b.ptrsFor(rowValue, newCts) + err = rows.Scan(ptrs...) + if err != nil { + return err + } + for rowValue.Type() != t { + tmp := reflect.New(rowValue.Type()) + tmp.Elem().Set(rowValue) + rowValue = tmp + } + v = reflect.Append(v, rowValue) + } + } else { + for rows.Next() { + ptrs := b.ptrsFor(v, cts) + err = rows.Scan(ptrs...) + if err != nil { + return err + } + } + } + // v is either struct or slice + reflect.ValueOf(obj).Elem().Set(v) + return nil +} + +func bindToMap(rows *sql.Rows) ([]map[string]interface{}, error) { + cts, err := rows.ColumnTypes() + if err != nil { + return nil, err + } + var ms []map[string]interface{} + for rows.Next() { + var ptrs []interface{} + for _, ct := range cts { + ptrs = append(ptrs, reflect.New(ct.ScanType()).Interface()) + } + + err = rows.Scan(ptrs...) + if err != nil { + return nil, err + } + m := map[string]interface{}{} + for i, ptr := range ptrs { + m[cts[i].Name()] = reflect.ValueOf(ptr).Elem().Interface() + } + + ms = append(ms, m) + } + return ms, nil +} diff --git a/gdb/orm/binder_test.go b/gdb/orm/binder_test.go new file mode 100644 index 0000000..a5ac817 --- /dev/null +++ b/gdb/orm/binder_test.go @@ -0,0 +1,91 @@ +// +// binder_test.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "database/sql" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + + _ "github.com/lib/pq" + "github.com/stretchr/testify/assert" +) + +type User struct { + ID int64 + Name string + Timestamps +} + +func (u User) ConfigureEntity(e *EntityConfigurator) { + e.Table("users") +} + +type Address struct { + ID int + Path string +} + +func TestBind(t *testing.T) { + t.Run("single result", func(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + mock. + ExpectQuery("SELECT .* FROM users"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "created_at", "updated_at", "deleted_at"}). + AddRow(1, "amirreza", sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{})) + rows, err := db.Query(`SELECT * FROM users`) + assert.NoError(t, err) + + u := &User{} + md := schemaOfHeavyReflectionStuff(u) + err = newBinder(md).bind(rows, u) + assert.NoError(t, err) + + assert.Equal(t, "amirreza", u.Name) + }) + + t.Run("multi result", func(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + mock. + ExpectQuery("SELECT .* FROM users"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "amirreza").AddRow(2, "milad")) + + rows, err := db.Query(`SELECT * FROM users`) + assert.NoError(t, err) + + md := schemaOfHeavyReflectionStuff(&User{}) + var users []*User + err = newBinder(md).bind(rows, &users) + assert.NoError(t, err) + + assert.Equal(t, "amirreza", users[0].Name) + assert.Equal(t, "milad", users[1].Name) + }) +} + +func TestBindMap(t *testing.T) { + db, mock, err := sqlmock.New() + assert.NoError(t, err) + mock. + ExpectQuery("SELECT .* FROM users"). + WillReturnRows(sqlmock.NewRows([]string{"id", "name", "created_at", "updated_at", "deleted_at"}). + AddRow(1, "amirreza", sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{})) + rows, err := db.Query(`SELECT * FROM users`) + assert.NoError(t, err) + + ms, err := bindToMap(rows) + + assert.NoError(t, err) + assert.NotEmpty(t, ms) + + assert.Len(t, ms, 1) +} diff --git a/gdb/orm/configurators.go b/gdb/orm/configurators.go new file mode 100644 index 0000000..073a726 --- /dev/null +++ b/gdb/orm/configurators.go @@ -0,0 +1,192 @@ +// +// configurators.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "database/sql" + + "git.hexq.cn/tiglog/golib/helper" +) + +type EntityConfigurator struct { + connection string + table string + this Entity + relations map[string]interface{} + resolveRelations []func() + columnConstraints []*FieldConfigurator +} + +func newEntityConfigurator() *EntityConfigurator { + return &EntityConfigurator{} +} + +func (ec *EntityConfigurator) Table(name string) *EntityConfigurator { + ec.table = name + return ec +} + +func (ec *EntityConfigurator) Connection(name string) *EntityConfigurator { + ec.connection = name + return ec +} + +func (ec *EntityConfigurator) HasMany(property Entity, config HasManyConfig) *EntityConfigurator { + if ec.relations == nil { + ec.relations = map[string]interface{}{} + } + ec.resolveRelations = append(ec.resolveRelations, func() { + if config.PropertyForeignKey != "" && config.PropertyTable != "" { + ec.relations[config.PropertyTable] = config + return + } + configurator := newEntityConfigurator() + property.ConfigureEntity(configurator) + + if config.PropertyTable == "" { + config.PropertyTable = configurator.table + } + + if config.PropertyForeignKey == "" { + config.PropertyForeignKey = helper.NewPluralizeClient().Singular(ec.table) + "_id" + } + + ec.relations[configurator.table] = config + + return + }) + return ec +} + +func (ec *EntityConfigurator) HasOne(property Entity, config HasOneConfig) *EntityConfigurator { + if ec.relations == nil { + ec.relations = map[string]interface{}{} + } + ec.resolveRelations = append(ec.resolveRelations, func() { + if config.PropertyForeignKey != "" && config.PropertyTable != "" { + ec.relations[config.PropertyTable] = config + return + } + + configurator := newEntityConfigurator() + property.ConfigureEntity(configurator) + + if config.PropertyTable == "" { + config.PropertyTable = configurator.table + } + if config.PropertyForeignKey == "" { + config.PropertyForeignKey = helper.NewPluralizeClient().Singular(ec.table) + "_id" + } + + ec.relations[configurator.table] = config + return + }) + return ec +} + +func (ec *EntityConfigurator) BelongsTo(owner Entity, config BelongsToConfig) *EntityConfigurator { + if ec.relations == nil { + ec.relations = map[string]interface{}{} + } + ec.resolveRelations = append(ec.resolveRelations, func() { + if config.ForeignColumnName != "" && config.LocalForeignKey != "" && config.OwnerTable != "" { + ec.relations[config.OwnerTable] = config + return + } + ownerConfigurator := newEntityConfigurator() + owner.ConfigureEntity(ownerConfigurator) + if config.OwnerTable == "" { + config.OwnerTable = ownerConfigurator.table + } + if config.LocalForeignKey == "" { + config.LocalForeignKey = helper.NewPluralizeClient().Singular(ownerConfigurator.table) + "_id" + } + if config.ForeignColumnName == "" { + config.ForeignColumnName = "id" + } + ec.relations[ownerConfigurator.table] = config + }) + return ec +} + +func (ec *EntityConfigurator) BelongsToMany(owner Entity, config BelongsToManyConfig) *EntityConfigurator { + if ec.relations == nil { + ec.relations = map[string]interface{}{} + } + ec.resolveRelations = append(ec.resolveRelations, func() { + ownerConfigurator := newEntityConfigurator() + owner.ConfigureEntity(ownerConfigurator) + + if config.OwnerLookupColumn == "" { + var pkName string + for _, field := range genericFieldsOf(owner) { + if field.IsPK { + pkName = field.Name + } + } + config.OwnerLookupColumn = pkName + + } + if config.OwnerTable == "" { + config.OwnerTable = ownerConfigurator.table + } + if config.IntermediateTable == "" { + panic("cannot infer intermediate table yet") + } + if config.IntermediatePropertyID == "" { + config.IntermediatePropertyID = helper.NewPluralizeClient().Singular(ownerConfigurator.table) + "_id" + } + if config.IntermediateOwnerID == "" { + config.IntermediateOwnerID = helper.NewPluralizeClient().Singular(ec.table) + "_id" + } + + ec.relations[ownerConfigurator.table] = config + }) + return ec +} + +type FieldConfigurator struct { + fieldName string + nullable sql.NullBool + primaryKey bool + column string + isCreatedAt bool + isUpdatedAt bool + isDeletedAt bool +} + +func (ec *EntityConfigurator) Field(name string) *FieldConfigurator { + cc := &FieldConfigurator{fieldName: name} + ec.columnConstraints = append(ec.columnConstraints, cc) + return cc +} + +func (fc *FieldConfigurator) IsPrimaryKey() *FieldConfigurator { + fc.primaryKey = true + return fc +} + +func (fc *FieldConfigurator) IsCreatedAt() *FieldConfigurator { + fc.isCreatedAt = true + return fc +} + +func (fc *FieldConfigurator) IsUpdatedAt() *FieldConfigurator { + fc.isUpdatedAt = true + return fc +} + +func (fc *FieldConfigurator) IsDeletedAt() *FieldConfigurator { + fc.isDeletedAt = true + return fc +} + +func (fc *FieldConfigurator) ColumnName(name string) *FieldConfigurator { + fc.column = name + return fc +} diff --git a/gdb/orm/connection.go b/gdb/orm/connection.go new file mode 100644 index 0000000..dde2388 --- /dev/null +++ b/gdb/orm/connection.go @@ -0,0 +1,160 @@ +// +// connection.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "database/sql" + "fmt" + + "git.hexq.cn/tiglog/golib/helper/table" +) + +type connection struct { + Name string + Dialect *Dialect + DB *sql.DB + Schemas map[string]*schema + DBSchema map[string][]columnSpec + DatabaseValidations bool +} + +func (c *connection) inferedTables() []string { + var tables []string + for t, s := range c.Schemas { + tables = append(tables, t) + for _, relC := range s.relations { + if belongsToManyConfig, is := relC.(BelongsToManyConfig); is { + tables = append(tables, belongsToManyConfig.IntermediateTable) + } + } + } + return tables +} + +func (c *connection) validateAllTablesArePresent() error { + for _, inferedTable := range c.inferedTables() { + if _, exists := c.DBSchema[inferedTable]; !exists { + return fmt.Errorf("orm infered %s but it's not found in your database, your database is out of sync", inferedTable) + } + } + return nil +} + +func (c *connection) validateTablesSchemas() error { + // check for entity tables: there should not be any struct field that does not have a coresponding column + for table, sc := range c.Schemas { + if columns, exists := c.DBSchema[table]; exists { + for _, f := range sc.fields { + found := false + for _, c := range columns { + if c.Name == f.Name { + found = true + } + } + if !found { + return fmt.Errorf("column %s not found while it was inferred", f.Name) + } + } + } else { + return fmt.Errorf("tables are out of sync, %s was inferred but not present in database", table) + } + } + + // check for relation tables: for HasMany,HasOne relations check if OWNER pk column is in PROPERTY, + // for BelongsToMany check intermediate table has 2 pk for two entities + + for table, sc := range c.Schemas { + for _, rel := range sc.relations { + switch rel.(type) { + case BelongsToConfig: + columns := c.DBSchema[table] + var found bool + for _, col := range columns { + if col.Name == rel.(BelongsToConfig).LocalForeignKey { + found = true + } + } + if !found { + return fmt.Errorf("cannot find local foreign key %s for relation", rel.(BelongsToConfig).LocalForeignKey) + } + case BelongsToManyConfig: + columns := c.DBSchema[rel.(BelongsToManyConfig).IntermediateTable] + var foundOwner bool + var foundProperty bool + + for _, col := range columns { + if col.Name == rel.(BelongsToManyConfig).IntermediateOwnerID { + foundOwner = true + } + if col.Name == rel.(BelongsToManyConfig).IntermediatePropertyID { + foundProperty = true + } + } + if !foundOwner || !foundProperty { + return fmt.Errorf("table schema for %s is not correct one of foreign keys is not present", rel.(BelongsToManyConfig).IntermediateTable) + } + } + } + } + + return nil +} + +func (c *connection) Schematic() { + fmt.Printf("SQL Dialect: %s\n", c.Dialect.DriverName) + for t, schema := range c.Schemas { + fmt.Printf("t: %s\n", t) + w := table.NewWriter() + w.AppendHeader(table.Row{"SQL Name", "Type", "Is Primary Key", "Is Virtual"}) + for _, field := range schema.fields { + w.AppendRow(table.Row{field.Name, field.Type, field.IsPK, field.Virtual}) + } + fmt.Println(w.Render()) + for _, rel := range schema.relations { + switch rel.(type) { + case HasOneConfig: + fmt.Printf("%s 1-1 %s => %+v\n", t, rel.(HasOneConfig).PropertyTable, rel) + case HasManyConfig: + fmt.Printf("%s 1-N %s => %+v\n", t, rel.(HasManyConfig).PropertyTable, rel) + + case BelongsToConfig: + fmt.Printf("%s N-1 %s => %+v\n", t, rel.(BelongsToConfig).OwnerTable, rel) + + case BelongsToManyConfig: + fmt.Printf("%s N-N %s => %+v\n", t, rel.(BelongsToManyConfig).IntermediateTable, rel) + } + } + fmt.Println("") + } +} + +func (c *connection) getSchema(t string) *schema { + return c.Schemas[t] +} + +func (c *connection) setSchema(e Entity, s *schema) { + var configurator EntityConfigurator + e.ConfigureEntity(&configurator) + c.Schemas[configurator.table] = s +} + +func GetConnection(name string) *connection { + return globalConnections[name] +} + +func (c *connection) exec(q string, args ...any) (sql.Result, error) { + return c.DB.Exec(q, args...) +} + +func (c *connection) query(q string, args ...any) (*sql.Rows, error) { + return c.DB.Query(q, args...) +} + +func (c *connection) queryRow(q string, args ...any) *sql.Row { + return c.DB.QueryRow(q, args...) +} diff --git a/gdb/orm/dialect.go b/gdb/orm/dialect.go new file mode 100644 index 0000000..67e6a67 --- /dev/null +++ b/gdb/orm/dialect.go @@ -0,0 +1,109 @@ +// +// dialect.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "database/sql" + "fmt" +) + +type Dialect struct { + DriverName string + PlaceholderChar string + IncludeIndexInPlaceholder bool + AddTableNameInSelectColumns bool + PlaceHolderGenerator func(n int) []string + QueryListTables string + QueryTableSchema string +} + +func getListOfTables(query string) func(db *sql.DB) ([]string, error) { + return func(db *sql.DB) ([]string, error) { + rows, err := db.Query(query) + if err != nil { + return nil, err + } + var tables []string + for rows.Next() { + var table string + err = rows.Scan(&table) + if err != nil { + return nil, err + } + tables = append(tables, table) + } + return tables, nil + } +} + +type columnSpec struct { + //0|id|INTEGER|0||1 + Name string + Type string + Nullable bool + DefaultValue sql.NullString + IsPrimaryKey bool +} + +func getTableSchema(query string) func(db *sql.DB, query string) ([]columnSpec, error) { + return func(db *sql.DB, table string) ([]columnSpec, error) { + rows, err := db.Query(fmt.Sprintf(query, table)) + if err != nil { + return nil, err + } + var output []columnSpec + for rows.Next() { + var cs columnSpec + var nullable string + var pk int + err = rows.Scan(&cs.Name, &cs.Type, &nullable, &cs.DefaultValue, &pk) + if err != nil { + return nil, err + } + cs.Nullable = nullable == "notnull" + cs.IsPrimaryKey = pk == 1 + output = append(output, cs) + } + return output, nil + + } +} + +var Dialects = &struct { + MySQL *Dialect + PostgreSQL *Dialect + SQLite3 *Dialect +}{ + MySQL: &Dialect{ + DriverName: "mysql", + PlaceholderChar: "?", + IncludeIndexInPlaceholder: false, + AddTableNameInSelectColumns: true, + PlaceHolderGenerator: questionMarks, + QueryListTables: "SHOW TABLES", + QueryTableSchema: "DESCRIBE %s", + }, + PostgreSQL: &Dialect{ + DriverName: "postgres", + PlaceholderChar: "$", + IncludeIndexInPlaceholder: true, + AddTableNameInSelectColumns: true, + PlaceHolderGenerator: postgresPlaceholder, + QueryListTables: `\dt`, + QueryTableSchema: `\d %s`, + }, + SQLite3: &Dialect{ + DriverName: "sqlite3", + PlaceholderChar: "?", + IncludeIndexInPlaceholder: false, + AddTableNameInSelectColumns: false, + PlaceHolderGenerator: questionMarks, + QueryListTables: "SELECT name FROM sqlite_schema WHERE type='table'", + QueryTableSchema: `SELECT name,type,"notnull","dflt_value","pk" FROM PRAGMA_TABLE_INFO('%s')`, + }, +} diff --git a/gdb/orm/field.go b/gdb/orm/field.go new file mode 100644 index 0000000..aad677c --- /dev/null +++ b/gdb/orm/field.go @@ -0,0 +1,75 @@ +// +// field.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "database/sql/driver" + "reflect" + "strings" + + "git.hexq.cn/tiglog/golib/helper" +) + +type field struct { + Name string + IsPK bool + Virtual bool + IsCreatedAt bool + IsUpdatedAt bool + IsDeletedAt bool + Nullable bool + Default any + Type reflect.Type +} + +func getFieldConfiguratorFor(fieldConfigurators []*FieldConfigurator, name string) *FieldConfigurator { + for _, fc := range fieldConfigurators { + if fc.fieldName == name { + return fc + } + } + return &FieldConfigurator{} +} + +func fieldMetadata(ft reflect.StructField, fieldConfigurators []*FieldConfigurator) []*field { + var fms []*field + fc := getFieldConfiguratorFor(fieldConfigurators, ft.Name) + baseFm := &field{} + baseFm.Type = ft.Type + fms = append(fms, baseFm) + if fc.column != "" { + baseFm.Name = fc.column + } else { + baseFm.Name = helper.SnakeString(ft.Name) + } + if strings.ToLower(ft.Name) == "id" || fc.primaryKey { + baseFm.IsPK = true + } + if strings.ToLower(ft.Name) == "createdat" || fc.isCreatedAt { + baseFm.IsCreatedAt = true + } + if strings.ToLower(ft.Name) == "updatedat" || fc.isUpdatedAt { + baseFm.IsUpdatedAt = true + } + if strings.ToLower(ft.Name) == "deletedat" || fc.isDeletedAt { + baseFm.IsDeletedAt = true + } + if ft.Type.Kind() == reflect.Struct || ft.Type.Kind() == reflect.Ptr { + t := ft.Type + if ft.Type.Kind() == reflect.Ptr { + t = ft.Type.Elem() + } + if !t.Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) { + for i := 0; i < t.NumField(); i++ { + fms = append(fms, fieldMetadata(t.Field(i), fieldConfigurators)...) + } + fms = fms[1:] + } + } + return fms +} diff --git a/gdb/orm/orm.go b/gdb/orm/orm.go new file mode 100644 index 0000000..d7bc6e1 --- /dev/null +++ b/gdb/orm/orm.go @@ -0,0 +1,654 @@ +// +// orm.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "database/sql" + "fmt" + "reflect" + "time" + + // Drivers + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" +) + +var globalConnections = map[string]*connection{} + +// Schematic prints all information ORM inferred from your entities in startup, remember to pass +// your entities in Entities when you call SetupConnections if you want their data inferred +// otherwise Schematic does not print correct data since GoLobby ORM also +// incrementally cache your entities metadata and schema. +func Schematic() { + for conn, connObj := range globalConnections { + fmt.Printf("----------------%s---------------\n", conn) + connObj.Schematic() + fmt.Println("-----------------------------------") + } +} + +type ConnectionConfig struct { + // Name of your database connection, it's up to you to name them anything + // just remember that having a connection name is mandatory if + // you have multiple connections + Name string + // If you already have an active database connection configured pass it in this value and + // do not pass Driver and DSN fields. + DB *sql.DB + // Which dialect of sql to generate queries for, you don't need it most of the times when you are using + // traditional databases such as mysql, sqlite3, postgres. + Dialect *Dialect + // List of entities that you want to use for this connection, remember that you can ignore this field + // and GoLobby ORM will build our metadata cache incrementally but you will lose schematic + // information that we can provide you and also potentialy validations that we + // can do with the database + Entities []Entity + // Database validations, check if all tables exists and also table schemas contains all necessary columns. + // Check if all infered tables exist in your database + DatabaseValidations bool +} + +// SetupConnections declares a new connections for ORM. +func SetupConnections(configs ...ConnectionConfig) error { + + for _, c := range configs { + if err := setupConnection(c); err != nil { + return err + } + } + for _, conn := range globalConnections { + if !conn.DatabaseValidations { + continue + } + + tables, err := getListOfTables(conn.Dialect.QueryListTables)(conn.DB) + if err != nil { + return err + } + + for _, table := range tables { + if conn.DatabaseValidations { + spec, err := getTableSchema(conn.Dialect.QueryTableSchema)(conn.DB, table) + if err != nil { + return err + } + conn.DBSchema[table] = spec + } else { + conn.DBSchema[table] = nil + } + } + + // check tables existence + if conn.DatabaseValidations { + err := conn.validateAllTablesArePresent() + if err != nil { + return err + } + } + + if conn.DatabaseValidations { + err = conn.validateTablesSchemas() + if err != nil { + return err + } + } + + } + + return nil +} + +func setupConnection(config ConnectionConfig) error { + schemas := map[string]*schema{} + if config.Name == "" { + config.Name = "default" + } + + for _, entity := range config.Entities { + s := schemaOfHeavyReflectionStuff(entity) + var configurator EntityConfigurator + entity.ConfigureEntity(&configurator) + schemas[configurator.table] = s + } + + s := &connection{ + Name: config.Name, + DB: config.DB, + Dialect: config.Dialect, + Schemas: schemas, + DBSchema: make(map[string][]columnSpec), + DatabaseValidations: config.DatabaseValidations, + } + + globalConnections[fmt.Sprintf("%s", config.Name)] = s + + return nil +} + +// Entity defines the interface that each of your structs that +// you want to use as database entities should have, +// it's a simple one and its ConfigureEntity. +type Entity interface { + // ConfigureEntity should be defined for all of your database entities + // and it can define Table, DB and also relations of your Entity. + ConfigureEntity(e *EntityConfigurator) +} + +// InsertAll given entities into database based on their ConfigureEntity +// we can find table and also DB name. +func InsertAll(objs ...Entity) error { + if len(objs) == 0 { + return nil + } + s := getSchemaFor(objs[0]) + cols := s.Columns(false) + var values [][]interface{} + for _, obj := range objs { + createdAtF := s.createdAt() + if createdAtF != nil { + genericSet(obj, createdAtF.Name, sql.NullTime{Time: time.Now(), Valid: true}) + } + updatedAtF := s.updatedAt() + if updatedAtF != nil { + genericSet(obj, updatedAtF.Name, sql.NullTime{Time: time.Now(), Valid: true}) + } + values = append(values, genericValuesOf(obj, false)) + } + + is := insertStmt{ + PlaceHolderGenerator: s.getDialect().PlaceHolderGenerator, + Table: s.getTable(), + Columns: cols, + Values: values, + } + + q, args := is.ToSql() + + _, err := s.getConnection().exec(q, args...) + if err != nil { + return err + } + return nil +} + +// Insert given entity into database based on their ConfigureEntity +// we can find table and also DB name. +func Insert(o Entity) error { + s := getSchemaFor(o) + cols := s.Columns(false) + var values [][]interface{} + createdAtF := s.createdAt() + if createdAtF != nil { + genericSet(o, createdAtF.Name, sql.NullTime{Time: time.Now(), Valid: true}) + } + updatedAtF := s.updatedAt() + if updatedAtF != nil { + genericSet(o, updatedAtF.Name, sql.NullTime{Time: time.Now(), Valid: true}) + } + values = append(values, genericValuesOf(o, false)) + + is := insertStmt{ + PlaceHolderGenerator: s.getDialect().PlaceHolderGenerator, + Table: s.getTable(), + Columns: cols, + Values: values, + } + + if s.getDialect().DriverName == "postgres" { + is.Returning = s.pkName() + } + q, args := is.ToSql() + + res, err := s.getConnection().exec(q, args...) + if err != nil { + return err + } + id, err := res.LastInsertId() + if err != nil { + return err + } + + if s.pkName() != "" { + // intermediate tables usually have no single pk column. + s.setPK(o, id) + } + return nil +} + +func isZero(val interface{}) bool { + switch val.(type) { + case int64: + return val.(int64) == 0 + case int: + return val.(int) == 0 + case string: + return val.(string) == "" + default: + return reflect.ValueOf(val).Elem().IsZero() + } +} + +// Save saves given entity, if primary key is set +// we will make an update query and if +// primary key is zero value we will +// insert it. +func Save(obj Entity) error { + if isZero(getSchemaFor(obj).getPK(obj)) { + return Insert(obj) + } else { + return Update(obj) + } +} + +// Find finds the Entity you want based on generic type and primary key you passed. +func Find[T Entity](id interface{}) (T, error) { + var q string + out := new(T) + md := getSchemaFor(*out) + q, args, err := NewQueryBuilder[T](md). + SetDialect(md.getDialect()). + Table(md.Table). + Select(md.Columns(true)...). + Where(md.pkName(), id). + ToSql() + if err != nil { + return *out, err + } + err = bind[T](out, q, args) + + if err != nil { + return *out, err + } + + return *out, nil +} + +func toKeyValues(obj Entity, withPK bool) []any { + var tuples []any + vs := genericValuesOf(obj, withPK) + cols := getSchemaFor(obj).Columns(withPK) + for i, col := range cols { + tuples = append(tuples, col, vs[i]) + } + return tuples +} + +// Update given Entity in database. +func Update(obj Entity) error { + s := getSchemaFor(obj) + q, args, err := NewQueryBuilder[Entity](s). + SetDialect(s.getDialect()). + Set(toKeyValues(obj, false)...). + Where(s.pkName(), genericGetPKValue(obj)).Table(s.Table).ToSql() + + if err != nil { + return err + } + _, err = s.getConnection().exec(q, args...) + return err +} + +// Delete given Entity from database +func Delete(obj Entity) error { + s := getSchemaFor(obj) + genericSet(obj, "deleted_at", sql.NullTime{Time: time.Now(), Valid: true}) + query, args, err := NewQueryBuilder[Entity](s).SetDialect(s.getDialect()).Table(s.Table).Where(s.pkName(), genericGetPKValue(obj)).SetDelete().ToSql() + if err != nil { + return err + } + _, err = s.getConnection().exec(query, args...) + return err +} + +func bind[T Entity](output interface{}, q string, args []interface{}) error { + outputMD := getSchemaFor(*new(T)) + rows, err := outputMD.getConnection().query(q, args...) + if err != nil { + return err + } + return newBinder(outputMD).bind(rows, output) +} + +// HasManyConfig contains all information we need for querying HasMany relationships. +// We can infer both fields if you have them in standard way but you +// can specify them if you want custom ones. +type HasManyConfig struct { + // PropertyTable is table of the property of HasMany relationship, + // consider `Comment` in Post and Comment relationship, + // each Post HasMany Comment, so PropertyTable is + // `comments`. + PropertyTable string + // PropertyForeignKey is the foreign key field name in the property table, + // for example in Post HasMany Comment, if comment has `post_id` field, + // it's the PropertyForeignKey field. + PropertyForeignKey string +} + +// HasMany configures a QueryBuilder for a HasMany relationship +// this relationship will be defined for owner argument +// that has many of PROPERTY generic type for example +// HasMany[Comment](&Post{}) +// is for Post HasMany Comment relationship. +func HasMany[PROPERTY Entity](owner Entity) *QueryBuilder[PROPERTY] { + outSchema := getSchemaFor(*new(PROPERTY)) + + q := NewQueryBuilder[PROPERTY](outSchema) + // getting config from our cache + c, ok := getSchemaFor(owner).relations[outSchema.Table].(HasManyConfig) + if !ok { + q.err = fmt.Errorf("wrong config passed for HasMany") + } + + s := getSchemaFor(owner) + return q. + SetDialect(s.getDialect()). + Table(c.PropertyTable). + Select(outSchema.Columns(true)...). + Where(c.PropertyForeignKey, genericGetPKValue(owner)) +} + +// HasOneConfig contains all information we need for a HasOne relationship, +// it's similar to HasManyConfig. +type HasOneConfig struct { + // PropertyTable is table of the property of HasOne relationship, + // consider `HeaderPicture` in Post and HeaderPicture relationship, + // each Post HasOne HeaderPicture, so PropertyTable is + // `header_pictures`. + PropertyTable string + // PropertyForeignKey is the foreign key field name in the property table, + // forexample in Post HasOne HeaderPicture, if header_picture has `post_id` field, + // it's the PropertyForeignKey field. + PropertyForeignKey string +} + +// HasOne configures a QueryBuilder for a HasOne relationship +// this relationship will be defined for owner argument +// that has one of PROPERTY generic type for example +// HasOne[HeaderPicture](&Post{}) +// is for Post HasOne HeaderPicture relationship. +func HasOne[PROPERTY Entity](owner Entity) *QueryBuilder[PROPERTY] { + property := getSchemaFor(*new(PROPERTY)) + q := NewQueryBuilder[PROPERTY](property) + c, ok := getSchemaFor(owner).relations[property.Table].(HasOneConfig) + if !ok { + q.err = fmt.Errorf("wrong config passed for HasOne") + } + + // settings default config Values + return q. + SetDialect(property.getDialect()). + Table(c.PropertyTable). + Select(property.Columns(true)...). + Where(c.PropertyForeignKey, genericGetPKValue(owner)) +} + +// BelongsToConfig contains all information we need for a BelongsTo relationship +// BelongsTo is a relationship between a Comment and it's Post, +// A Comment BelongsTo Post. +type BelongsToConfig struct { + // OwnerTable is the table that contains owner of a BelongsTo + // relationship. + OwnerTable string + // LocalForeignKey is name of the field that links property + // to its owner in BelongsTo relation. for example when + // a Comment BelongsTo Post, LocalForeignKey is + // post_id of Comment. + LocalForeignKey string + // ForeignColumnName is name of the field that LocalForeignKey + // field value will point to it, for example when + // a Comment BelongsTo Post, ForeignColumnName is + // id of Post. + ForeignColumnName string +} + +// BelongsTo configures a QueryBuilder for a BelongsTo relationship between +// OWNER type parameter and property argument, so +// property BelongsTo OWNER. +func BelongsTo[OWNER Entity](property Entity) *QueryBuilder[OWNER] { + owner := getSchemaFor(*new(OWNER)) + q := NewQueryBuilder[OWNER](owner) + c, ok := getSchemaFor(property).relations[owner.Table].(BelongsToConfig) + if !ok { + q.err = fmt.Errorf("wrong config passed for BelongsTo") + } + + ownerIDidx := 0 + for idx, field := range owner.fields { + if field.Name == c.LocalForeignKey { + ownerIDidx = idx + } + } + + ownerID := genericValuesOf(property, true)[ownerIDidx] + + return q. + SetDialect(owner.getDialect()). + Table(c.OwnerTable).Select(owner.Columns(true)...). + Where(c.ForeignColumnName, ownerID) + +} + +// BelongsToManyConfig contains information that we +// need for creating many to many queries. +type BelongsToManyConfig struct { + // IntermediateTable is the name of the middle table + // in a BelongsToMany (Many to Many) relationship. + // for example when we have Post BelongsToMany + // Category, this table will be post_categories + // table, remember that this field cannot be + // inferred. + IntermediateTable string + // IntermediatePropertyID is the name of the field name + // of property foreign key in intermediate table, + // for example when we have Post BelongsToMany + // Category, in post_categories table, it would + // be post_id. + IntermediatePropertyID string + // IntermediateOwnerID is the name of the field name + // of property foreign key in intermediate table, + // for example when we have Post BelongsToMany + // Category, in post_categories table, it would + // be category_id. + IntermediateOwnerID string + // Table name of the owner in BelongsToMany relation, + // for example in Post BelongsToMany Category + // Owner table is name of Category table + // for example `categories`. + OwnerTable string + // OwnerLookupColumn is name of the field in the owner + // table that is used in query, for example in Post BelongsToMany Category + // Owner lookup field would be Category primary key which is id. + OwnerLookupColumn string +} + +// BelongsToMany configures a QueryBuilder for a BelongsToMany relationship +func BelongsToMany[OWNER Entity](property Entity) *QueryBuilder[OWNER] { + out := *new(OWNER) + outSchema := getSchemaFor(out) + q := NewQueryBuilder[OWNER](outSchema) + c, ok := getSchemaFor(property).relations[outSchema.Table].(BelongsToManyConfig) + if !ok { + q.err = fmt.Errorf("wrong config passed for HasMany") + } + return q. + Select(outSchema.Columns(true)...). + Table(outSchema.Table). + WhereIn(c.OwnerLookupColumn, Raw(fmt.Sprintf(`SELECT %s FROM %s WHERE %s = ?`, + c.IntermediatePropertyID, + c.IntermediateTable, c.IntermediateOwnerID), genericGetPKValue(property))) +} + +// Add adds `items` to `to` using relations defined between items and to in ConfigureEntity method of `to`. +func Add(to Entity, items ...Entity) error { + if len(items) == 0 { + return nil + } + rels := getSchemaFor(to).relations + tname := getSchemaFor(items[0]).Table + c, ok := rels[tname] + if !ok { + return fmt.Errorf("no config found for given to and item...") + } + switch c.(type) { + case HasManyConfig: + return addProperty(to, items...) + case HasOneConfig: + return addProperty(to, items[0]) + case BelongsToManyConfig: + return addM2M(to, items...) + default: + return fmt.Errorf("cannot add for relation: %T", rels[getSchemaFor(items[0]).Table]) + } +} + +func addM2M(to Entity, items ...Entity) error { + //TODO: Optimize this + rels := getSchemaFor(to).relations + tname := getSchemaFor(items[0]).Table + c := rels[tname].(BelongsToManyConfig) + var values [][]interface{} + ownerPk := genericGetPKValue(to) + for _, item := range items { + pk := genericGetPKValue(item) + if isZero(pk) { + err := Insert(item) + if err != nil { + return err + } + pk = genericGetPKValue(item) + } + values = append(values, []interface{}{ownerPk, pk}) + } + i := insertStmt{ + PlaceHolderGenerator: getSchemaFor(to).getDialect().PlaceHolderGenerator, + Table: c.IntermediateTable, + Columns: []string{c.IntermediateOwnerID, c.IntermediatePropertyID}, + Values: values, + } + + q, args := i.ToSql() + + _, err := getConnectionFor(items[0]).DB.Exec(q, args...) + if err != nil { + return err + } + + return err +} + +// addHasMany(Post, comments) +func addProperty(to Entity, items ...Entity) error { + var lastTable string + for _, obj := range items { + s := getSchemaFor(obj) + if lastTable == "" { + lastTable = s.Table + } else { + if lastTable != s.Table { + return fmt.Errorf("cannot batch insert for two different tables: %s and %s", s.Table, lastTable) + } + } + } + i := insertStmt{ + PlaceHolderGenerator: getSchemaFor(to).getDialect().PlaceHolderGenerator, + Table: getSchemaFor(items[0]).getTable(), + } + ownerPKIdx := -1 + ownerPKName := getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey + for idx, col := range getSchemaFor(items[0]).Columns(false) { + if col == ownerPKName { + ownerPKIdx = idx + } + } + + ownerPK := genericGetPKValue(to) + if ownerPKIdx != -1 { + cols := getSchemaFor(items[0]).Columns(false) + i.Columns = append(i.Columns, cols...) + // Owner PK is present in the items struct + for _, item := range items { + vals := genericValuesOf(item, false) + if cols[ownerPKIdx] != getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey { + return fmt.Errorf("owner pk idx is not correct") + } + vals[ownerPKIdx] = ownerPK + i.Values = append(i.Values, vals) + } + } else { + ownerPKIdx = 0 + cols := getSchemaFor(items[0]).Columns(false) + cols = append(cols[:ownerPKIdx+1], cols[ownerPKIdx:]...) + cols[ownerPKIdx] = getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey + i.Columns = append(i.Columns, cols...) + for _, item := range items { + vals := genericValuesOf(item, false) + if cols[ownerPKIdx] != getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey { + return fmt.Errorf("owner pk idx is not correct") + } + vals = append(vals[:ownerPKIdx+1], vals[ownerPKIdx:]...) + vals[ownerPKIdx] = ownerPK + i.Values = append(i.Values, vals) + } + } + + q, args := i.ToSql() + + _, err := getConnectionFor(items[0]).DB.Exec(q, args...) + if err != nil { + return err + } + + return err + +} + +// Query creates a new QueryBuilder for given type parameter, sets dialect and table as well. +func Query[E Entity]() *QueryBuilder[E] { + s := getSchemaFor(*new(E)) + q := NewQueryBuilder[E](s) + q.SetDialect(s.getDialect()).Table(s.Table) + return q +} + +// ExecRaw executes given query string and arguments on given type parameter database connection. +func ExecRaw[E Entity](q string, args ...interface{}) (int64, int64, error) { + e := new(E) + + res, err := getSchemaFor(*e).getSQLDB().Exec(q, args...) + if err != nil { + return 0, 0, err + } + + id, err := res.LastInsertId() + if err != nil { + return 0, 0, err + } + + affected, err := res.RowsAffected() + if err != nil { + return 0, 0, err + } + + return id, affected, nil +} + +// QueryRaw queries given query string and arguments on given type parameter database connection. +func QueryRaw[OUTPUT Entity](q string, args ...interface{}) ([]OUTPUT, error) { + o := new(OUTPUT) + rows, err := getSchemaFor(*o).getSQLDB().Query(q, args...) + if err != nil { + return nil, err + } + var output []OUTPUT + err = newBinder(getSchemaFor(*o)).bind(rows, &output) + if err != nil { + return nil, err + } + return output, nil +} diff --git a/gdb/orm/orm_test.go b/gdb/orm/orm_test.go new file mode 100644 index 0000000..b0aa1cb --- /dev/null +++ b/gdb/orm/orm_test.go @@ -0,0 +1,588 @@ +// +// orm_test.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm_test + +import ( + "database/sql" + "testing" + + "git.hexq.cn/tiglog/golib/gdb/orm" + "github.com/stretchr/testify/assert" +) + +type AuthorEmail struct { + ID int64 + Email string +} + +func (a AuthorEmail) ConfigureEntity(e *orm.EntityConfigurator) { + e. + Table("emails"). + Connection("default"). + BelongsTo(&Post{}, orm.BelongsToConfig{}) +} + +type HeaderPicture struct { + ID int64 + PostID int64 + Link string +} + +func (h HeaderPicture) ConfigureEntity(e *orm.EntityConfigurator) { + e.Table("header_pictures").BelongsTo(&Post{}, orm.BelongsToConfig{}) +} + +type Post struct { + ID int64 + BodyText string + CreatedAt sql.NullTime + UpdatedAt sql.NullTime + DeletedAt sql.NullTime +} + +func (p Post) ConfigureEntity(e *orm.EntityConfigurator) { + e.Field("BodyText").ColumnName("body") + e.Field("ID").ColumnName("id") + e. + Table("posts"). + HasMany(Comment{}, orm.HasManyConfig{}). + HasOne(HeaderPicture{}, orm.HasOneConfig{}). + HasOne(AuthorEmail{}, orm.HasOneConfig{}). + BelongsToMany(Category{}, orm.BelongsToManyConfig{IntermediateTable: "post_categories"}) + +} + +func (p *Post) Categories() ([]Category, error) { + return orm.BelongsToMany[Category](p).All() +} + +func (p *Post) Comments() *orm.QueryBuilder[Comment] { + return orm.HasMany[Comment](p) +} + +type Comment struct { + ID int64 + PostID int64 + Body string +} + +func (c Comment) ConfigureEntity(e *orm.EntityConfigurator) { + e.Table("comments").BelongsTo(&Post{}, orm.BelongsToConfig{}) +} + +func (c *Comment) Post() (Post, error) { + return orm.BelongsTo[Post](c).Get() +} + +type Category struct { + ID int64 + Title string +} + +func (c Category) ConfigureEntity(e *orm.EntityConfigurator) { + e.Table("categories").BelongsToMany(Post{}, orm.BelongsToManyConfig{IntermediateTable: "post_categories"}) +} + +func (c Category) Posts() ([]Post, error) { + return orm.BelongsToMany[Post](c).All() +} + +// enough models let's test +// Entities is mandatory +// Errors should be carried + +func setup() error { + db, err := sql.Open("sqlite3", ":memory:") + if err != nil { + return err + } + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`) + if err != nil { + return err + } + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`) + if err != nil { + return err + } + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`) + if err != nil { + return err + } + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`) + if err != nil { + return err + } + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`) + if err != nil { + return err + } + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`) + if err != nil { + return err + } + return orm.SetupConnections(orm.ConnectionConfig{ + Name: "default", + DB: db, + Dialect: orm.Dialects.SQLite3, + Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}}, + DatabaseValidations: true, + }) +} + +func TestFind(t *testing.T) { + err := setup() + assert.NoError(t, err) + err = orm.InsertAll(&Post{ + BodyText: "my body for insert", + }) + + assert.NoError(t, err) + + post, err := orm.Find[Post](1) + assert.NoError(t, err) + assert.Equal(t, "my body for insert", post.BodyText) + assert.Equal(t, int64(1), post.ID) +} + +func TestInsert(t *testing.T) { + err := setup() + assert.NoError(t, err) + post := &Post{ + BodyText: "my body for insert", + } + err = orm.Insert(post) + assert.NoError(t, err) + assert.Equal(t, int64(1), post.ID) + var p Post + assert.NoError(t, + orm.GetConnection("default").DB.QueryRow(`SELECT id, body FROM posts where id = ?`, 1).Scan(&p.ID, &p.BodyText)) + + assert.Equal(t, "my body for insert", p.BodyText) +} +func TestInsertAll(t *testing.T) { + err := setup() + assert.NoError(t, err) + post1 := &Post{ + BodyText: "Body1", + } + post2 := &Post{ + BodyText: "Body2", + } + + post3 := &Post{ + BodyText: "Body3", + } + + err = orm.InsertAll(post1, post2, post3) + assert.NoError(t, err) + var counter int + assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT count(id) FROM posts`).Scan(&counter)) + assert.Equal(t, 3, counter) + +} +func TestUpdateORM(t *testing.T) { + err := setup() + assert.NoError(t, err) + post := &Post{ + BodyText: "my body for insert", + } + err = orm.Insert(post) + assert.NoError(t, err) + assert.Equal(t, int64(1), post.ID) + + post.BodyText += " update text" + assert.NoError(t, orm.Update(post)) + + var body string + assert.NoError(t, + orm.GetConnection("default").DB.QueryRow(`SELECT body FROM posts where id = ?`, post.ID).Scan(&body)) + + assert.Equal(t, "my body for insert update text", body) +} + +func TestDeleteORM(t *testing.T) { + err := setup() + assert.NoError(t, err) + post := &Post{ + BodyText: "my body for insert", + } + err = orm.Insert(post) + assert.NoError(t, err) + assert.Equal(t, int64(1), post.ID) + + assert.NoError(t, orm.Delete(post)) + + var count int + assert.NoError(t, + orm.GetConnection("default").DB.QueryRow(`SELECT count(id) FROM posts where id = ?`, post.ID).Scan(&count)) + + assert.Equal(t, 0, count) +} +func TestAdd_HasMany(t *testing.T) { + err := setup() + assert.NoError(t, err) + post := &Post{ + BodyText: "my body for insert", + } + err = orm.Insert(post) + assert.NoError(t, err) + assert.Equal(t, int64(1), post.ID) + + err = orm.Add(post, []orm.Entity{ + Comment{ + Body: "comment 1", + }, + Comment{ + Body: "comment 2", + }, + }...) + // orm.Query(qm.WhereBetween()) + assert.NoError(t, err) + var count int + assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(id) FROM comments`).Scan(&count)) + assert.Equal(t, 2, count) + + comment, err := orm.Find[Comment](1) + assert.NoError(t, err) + + assert.Equal(t, int64(1), comment.PostID) + +} + +func TestAdd_ManyToMany(t *testing.T) { + err := setup() + assert.NoError(t, err) + post := &Post{ + BodyText: "my body for insert", + } + err = orm.Insert(post) + assert.NoError(t, err) + assert.Equal(t, int64(1), post.ID) + + err = orm.Add(post, []orm.Entity{ + &Category{ + Title: "cat 1", + }, + &Category{ + Title: "cat 2", + }, + }...) + assert.NoError(t, err) + var count int + assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(post_id) FROM post_categories`).Scan(&count)) + assert.Equal(t, 2, count) + assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(id) FROM categories`).Scan(&count)) + assert.Equal(t, 2, count) + + categories, err := post.Categories() + assert.NoError(t, err) + + assert.Equal(t, 2, len(categories)) + assert.Equal(t, int64(1), categories[0].ID) + assert.Equal(t, int64(2), categories[1].ID) + +} + +func TestSave(t *testing.T) { + t.Run("save should insert", func(t *testing.T) { + err := setup() + assert.NoError(t, err) + post := &Post{ + BodyText: "1", + } + assert.NoError(t, orm.Save(post)) + assert.Equal(t, int64(1), post.ID) + }) + + t.Run("save should update", func(t *testing.T) { + err := setup() + assert.NoError(t, err) + + post := &Post{ + BodyText: "1", + } + assert.NoError(t, orm.Save(post)) + assert.Equal(t, int64(1), post.ID) + + post.BodyText += "2" + assert.NoError(t, orm.Save(post)) + + myPost, err := orm.Find[Post](1) + assert.NoError(t, err) + + assert.EqualValues(t, post.BodyText, myPost.BodyText) + }) + +} + +func TestHasMany(t *testing.T) { + err := setup() + assert.NoError(t, err) + + post := &Post{ + BodyText: "first post", + } + assert.NoError(t, orm.Save(post)) + assert.Equal(t, int64(1), post.ID) + + assert.NoError(t, orm.Save(&Comment{ + PostID: post.ID, + Body: "comment 1", + })) + assert.NoError(t, orm.Save(&Comment{ + PostID: post.ID, + Body: "comment 2", + })) + + comments, err := orm.HasMany[Comment](post).All() + assert.NoError(t, err) + + assert.Len(t, comments, 2) + + assert.Equal(t, post.ID, comments[0].PostID) + assert.Equal(t, post.ID, comments[1].PostID) +} + +func TestBelongsTo(t *testing.T) { + err := setup() + assert.NoError(t, err) + + post := &Post{ + BodyText: "first post", + } + assert.NoError(t, orm.Save(post)) + assert.Equal(t, int64(1), post.ID) + + comment := &Comment{ + PostID: post.ID, + Body: "comment 1", + } + assert.NoError(t, orm.Save(comment)) + + post2, err := orm.BelongsTo[Post](comment).Get() + assert.NoError(t, err) + + assert.Equal(t, post.BodyText, post2.BodyText) +} + +func TestHasOne(t *testing.T) { + err := setup() + assert.NoError(t, err) + + post := &Post{ + BodyText: "first post", + } + assert.NoError(t, orm.Save(post)) + assert.Equal(t, int64(1), post.ID) + + headerPicture := &HeaderPicture{ + PostID: post.ID, + Link: "google", + } + assert.NoError(t, orm.Save(headerPicture)) + + c1, err := orm.HasOne[HeaderPicture](post).Get() + assert.NoError(t, err) + + assert.Equal(t, headerPicture.PostID, c1.PostID) +} + +func TestBelongsToMany(t *testing.T) { + err := setup() + assert.NoError(t, err) + + post := &Post{ + BodyText: "first Post", + } + + assert.NoError(t, orm.Save(post)) + assert.Equal(t, int64(1), post.ID) + + category := &Category{ + Title: "first category", + } + assert.NoError(t, orm.Save(category)) + assert.Equal(t, int64(1), category.ID) + + _, _, err = orm.ExecRaw[Category](`INSERT INTO post_categories (post_id, category_id) VALUES (?,?)`, post.ID, category.ID) + assert.NoError(t, err) + + categories, err := orm.BelongsToMany[Category](post).All() + assert.NoError(t, err) + + assert.Len(t, categories, 1) +} + +func TestSchematic(t *testing.T) { + err := setup() + assert.NoError(t, err) + + orm.Schematic() +} + +func TestAddProperty(t *testing.T) { + t.Run("having pk value", func(t *testing.T) { + err := setup() + assert.NoError(t, err) + + post := &Post{ + BodyText: "first post", + } + + assert.NoError(t, orm.Save(post)) + assert.EqualValues(t, 1, post.ID) + + err = orm.Add(post, &Comment{PostID: post.ID, Body: "firstComment"}) + assert.NoError(t, err) + + var comment Comment + assert.NoError(t, orm.GetConnection("default"). + DB. + QueryRow(`SELECT id, post_id, body FROM comments WHERE post_id=?`, post.ID). + Scan(&comment.ID, &comment.PostID, &comment.Body)) + + assert.EqualValues(t, post.ID, comment.PostID) + }) + t.Run("not having PK value", func(t *testing.T) { + err := setup() + assert.NoError(t, err) + + post := &Post{ + BodyText: "first post", + } + assert.NoError(t, orm.Save(post)) + assert.EqualValues(t, 1, post.ID) + + err = orm.Add(post, &AuthorEmail{Email: "myemail"}) + assert.NoError(t, err) + + emails, err := orm.QueryRaw[AuthorEmail](`SELECT id, email FROM emails WHERE post_id=?`, post.ID) + + assert.NoError(t, err) + assert.Equal(t, []AuthorEmail{{ID: 1, Email: "myemail"}}, emails) + }) +} + +func TestQuery(t *testing.T) { + t.Run("querying single row", func(t *testing.T) { + err := setup() + assert.NoError(t, err) + + assert.NoError(t, orm.Save(&Post{BodyText: "body 1"})) + // post, err := orm.Query[Post]().Where("id", 1).First() + post, err := orm.Query[Post]().WherePK(1).First().Get() + assert.NoError(t, err) + assert.EqualValues(t, "body 1", post.BodyText) + assert.EqualValues(t, 1, post.ID) + + }) + t.Run("querying multiple rows", func(t *testing.T) { + err := setup() + assert.NoError(t, err) + + assert.NoError(t, orm.Save(&Post{BodyText: "body 1"})) + assert.NoError(t, orm.Save(&Post{BodyText: "body 2"})) + assert.NoError(t, orm.Save(&Post{BodyText: "body 3"})) + posts, err := orm.Query[Post]().All() + assert.NoError(t, err) + assert.Len(t, posts, 3) + assert.Equal(t, "body 1", posts[0].BodyText) + }) + + t.Run("updating a row using query interface", func(t *testing.T) { + err := setup() + assert.NoError(t, err) + + assert.NoError(t, orm.Save(&Post{BodyText: "body 1"})) + + affected, err := orm.Query[Post]().Where("id", 1).Set("body", "body jadid").Update() + assert.NoError(t, err) + assert.EqualValues(t, 1, affected) + + post, err := orm.Find[Post](1) + assert.NoError(t, err) + assert.Equal(t, "body jadid", post.BodyText) + }) + + t.Run("deleting a row using query interface", func(t *testing.T) { + err := setup() + assert.NoError(t, err) + + assert.NoError(t, orm.Save(&Post{BodyText: "body 1"})) + + affected, err := orm.Query[Post]().WherePK(1).Delete() + assert.NoError(t, err) + assert.EqualValues(t, 1, affected) + count, err := orm.Query[Post]().WherePK(1).Count().Get() + assert.NoError(t, err) + assert.EqualValues(t, 0, count) + }) + + t.Run("count", func(t *testing.T) { + err := setup() + assert.NoError(t, err) + + count, err := orm.Query[Post]().WherePK(1).Count().Get() + assert.NoError(t, err) + assert.EqualValues(t, 0, count) + }) + + t.Run("latest", func(t *testing.T) { + err := setup() + assert.NoError(t, err) + + assert.NoError(t, orm.Save(&Post{BodyText: "body 1"})) + assert.NoError(t, orm.Save(&Post{BodyText: "body 2"})) + + post, err := orm.Query[Post]().Latest().Get() + assert.NoError(t, err) + + assert.EqualValues(t, "body 2", post.BodyText) + }) +} + +func TestSetup(t *testing.T) { + t.Run("tables are out of sync", func(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + // _, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`) + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`) + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`) + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`) + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`) + // _, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`) + + err = orm.SetupConnections(orm.ConnectionConfig{ + Name: "default", + DB: db, + Dialect: orm.Dialects.SQLite3, + Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}}, + DatabaseValidations: true, + }) + assert.Error(t, err) + + }) + t.Run("schemas are wrong", func(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`) + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`) + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`) + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, body text)`) // missing post_id + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`) + _, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`) + + err = orm.SetupConnections(orm.ConnectionConfig{ + Name: "default", + DB: db, + Dialect: orm.Dialects.SQLite3, + Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}}, + DatabaseValidations: true, + }) + assert.Error(t, err) + + }) +} diff --git a/gdb/orm/query.go b/gdb/orm/query.go new file mode 100644 index 0000000..8c3f23d --- /dev/null +++ b/gdb/orm/query.go @@ -0,0 +1,811 @@ +// +// query.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "database/sql" + "fmt" + "strings" +) + +const ( + queryTypeSELECT = iota + 1 + queryTypeUPDATE + queryTypeDelete +) + +// QueryBuilder is our query builder, almost all methods and functions in GoLobby ORM +// create or configure instance of QueryBuilder. +type QueryBuilder[OUTPUT any] struct { + typ int + schema *schema + // general parts + where *whereClause + table string + placeholderGenerator func(n int) []string + + // select parts + orderBy *orderByClause + groupBy *GroupBy + selected *selected + subQuery *struct { + q string + args []interface{} + placeholderGenerator func(n int) []string + } + joins []*Join + limit *Limit + offset *Offset + + // update parts + sets [][2]interface{} + + // execution parts + db *sql.DB + err error +} + +// Finisher APIs + +// execute is a finisher executes QueryBuilder query, remember to use this when you have an Update +// or Delete Query. +func (q *QueryBuilder[OUTPUT]) execute() (sql.Result, error) { + if q.err != nil { + return nil, q.err + } + if q.typ == queryTypeSELECT { + return nil, fmt.Errorf("query type is SELECT") + } + query, args, err := q.ToSql() + if err != nil { + return nil, err + } + return q.schema.getConnection().exec(query, args...) +} + +// Get limit results to 1, runs query generated by query builder, scans result into OUTPUT. +func (q *QueryBuilder[OUTPUT]) Get() (OUTPUT, error) { + if q.err != nil { + return *new(OUTPUT), q.err + } + queryString, args, err := q.ToSql() + if err != nil { + return *new(OUTPUT), err + } + rows, err := q.schema.getConnection().query(queryString, args...) + if err != nil { + return *new(OUTPUT), err + } + var output OUTPUT + err = newBinder(q.schema).bind(rows, &output) + if err != nil { + return *new(OUTPUT), err + } + return output, nil +} + +// All is a finisher, create the Select query based on QueryBuilder and scan results into +// slice of type parameter E. +func (q *QueryBuilder[OUTPUT]) All() ([]OUTPUT, error) { + if q.err != nil { + return nil, q.err + } + q.SetSelect() + queryString, args, err := q.ToSql() + if err != nil { + return nil, err + } + rows, err := q.schema.getConnection().query(queryString, args...) + if err != nil { + return nil, err + } + var output []OUTPUT + err = newBinder(q.schema).bind(rows, &output) + if err != nil { + return nil, err + } + return output, nil +} + +// Delete is a finisher, creates a delete query from query builder and executes it. +func (q *QueryBuilder[OUTPUT]) Delete() (rowsAffected int64, err error) { + if q.err != nil { + return 0, q.err + } + q.SetDelete() + res, err := q.execute() + if err != nil { + return 0, q.err + } + return res.RowsAffected() +} + +// Update is a finisher, creates an Update query from QueryBuilder and executes in into database, returns +func (q *QueryBuilder[OUTPUT]) Update() (rowsAffected int64, err error) { + if q.err != nil { + return 0, q.err + } + q.SetUpdate() + res, err := q.execute() + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func copyQueryBuilder[T1 any, T2 any](q *QueryBuilder[T1], q2 *QueryBuilder[T2]) { + q2.db = q.db + q2.err = q.err + q2.groupBy = q.groupBy + q2.joins = q.joins + q2.limit = q.limit + q2.offset = q.offset + q2.orderBy = q.orderBy + q2.placeholderGenerator = q.placeholderGenerator + q2.schema = q.schema + q2.selected = q.selected + q2.sets = q.sets + + q2.subQuery = q.subQuery + q2.table = q.table + q2.typ = q.typ + q2.where = q.where +} + +// Count creates and execute a select query from QueryBuilder and set it's field list of selection +// to COUNT(id). +func (q *QueryBuilder[OUTPUT]) Count() *QueryBuilder[int] { + q.selected = &selected{Columns: []string{"COUNT(id)"}} + q.SetSelect() + qCount := NewQueryBuilder[int](q.schema) + + copyQueryBuilder(q, qCount) + + return qCount +} + +// First returns first record of database using OrderBy primary key +// ascending order. +func (q *QueryBuilder[OUTPUT]) First() *QueryBuilder[OUTPUT] { + q.OrderBy(q.schema.pkName(), ASC).Limit(1) + return q +} + +// Latest is like Get but it also do a OrderBy(primary key, DESC) +func (q *QueryBuilder[OUTPUT]) Latest() *QueryBuilder[OUTPUT] { + q.OrderBy(q.schema.pkName(), DESC).Limit(1) + return q +} + +// WherePK adds a where clause to QueryBuilder and also gets primary key name +// from type parameter schema. +func (q *QueryBuilder[OUTPUT]) WherePK(value interface{}) *QueryBuilder[OUTPUT] { + return q.Where(q.schema.pkName(), value) +} + +func (d *QueryBuilder[OUTPUT]) toSqlDelete() (string, []interface{}, error) { + base := fmt.Sprintf("DELETE FROM %s", d.table) + var args []interface{} + if d.where != nil { + d.where.PlaceHolderGenerator = d.placeholderGenerator + where, whereArgs, err := d.where.ToSql() + if err != nil { + return "", nil, err + } + base += " WHERE " + where + args = append(args, whereArgs...) + } + return base, args, nil +} +func pop(phs *[]string) string { + top := (*phs)[len(*phs)-1] + *phs = (*phs)[:len(*phs)-1] + return top +} + +func (u *QueryBuilder[OUTPUT]) kvString() string { + phs := u.placeholderGenerator(len(u.sets)) + var sets []string + for _, pair := range u.sets { + sets = append(sets, fmt.Sprintf("%s=%s", pair[0], pop(&phs))) + } + return strings.Join(sets, ",") +} + +func (u *QueryBuilder[OUTPUT]) args() []interface{} { + var values []interface{} + for _, pair := range u.sets { + values = append(values, pair[1]) + } + return values +} + +func (u *QueryBuilder[OUTPUT]) toSqlUpdate() (string, []interface{}, error) { + if u.table == "" { + return "", nil, fmt.Errorf("table cannot be empty") + } + base := fmt.Sprintf("UPDATE %s SET %s", u.table, u.kvString()) + args := u.args() + if u.where != nil { + u.where.PlaceHolderGenerator = u.placeholderGenerator + where, whereArgs, err := u.where.ToSql() + if err != nil { + return "", nil, err + } + args = append(args, whereArgs...) + base += " WHERE " + where + } + return base, args, nil +} +func (s *QueryBuilder[OUTPUT]) toSqlSelect() (string, []interface{}, error) { + if s.err != nil { + return "", nil, s.err + } + base := "SELECT" + var args []interface{} + // select + if s.selected == nil { + s.selected = &selected{ + Columns: []string{"*"}, + } + } + base += " " + s.selected.String() + // from + if s.table == "" && s.subQuery == nil { + return "", nil, fmt.Errorf("Table name cannot be empty") + } else if s.table != "" && s.subQuery != nil { + return "", nil, fmt.Errorf("cannot have both Table and subquery") + } + if s.table != "" { + base += " " + "FROM " + s.table + } + if s.subQuery != nil { + s.subQuery.placeholderGenerator = s.placeholderGenerator + base += " " + "FROM (" + s.subQuery.q + " )" + args = append(args, s.subQuery.args...) + } + // Joins + if s.joins != nil { + for _, join := range s.joins { + base += " " + join.String() + } + } + // whereClause + if s.where != nil { + s.where.PlaceHolderGenerator = s.placeholderGenerator + where, whereArgs, err := s.where.ToSql() + if err != nil { + return "", nil, err + } + base += " WHERE " + where + args = append(args, whereArgs...) + } + + // orderByClause + if s.orderBy != nil { + base += " " + s.orderBy.String() + } + + // GroupBy + if s.groupBy != nil { + base += " " + s.groupBy.String() + } + + // Limit + if s.limit != nil { + base += " " + s.limit.String() + } + + // Offset + if s.offset != nil { + base += " " + s.offset.String() + } + + return base, args, nil +} + +// ToSql creates sql query from QueryBuilder based on internal fields it would decide what kind +// of query to build. +func (q *QueryBuilder[OUTPUT]) ToSql() (string, []interface{}, error) { + if q.err != nil { + return "", nil, q.err + } + if q.typ == queryTypeSELECT { + return q.toSqlSelect() + } else if q.typ == queryTypeDelete { + return q.toSqlDelete() + } else if q.typ == queryTypeUPDATE { + return q.toSqlUpdate() + } else { + return "", nil, fmt.Errorf("no sql type matched") + } +} + +type orderByOrder string + +const ( + ASC = "ASC" + DESC = "DESC" +) + +type orderByClause struct { + Columns [][2]string +} + +func (o orderByClause) String() string { + var tuples []string + for _, pair := range o.Columns { + tuples = append(tuples, fmt.Sprintf("%s %s", pair[0], pair[1])) + } + return fmt.Sprintf("ORDER BY %s", strings.Join(tuples, ",")) +} + +type GroupBy struct { + Columns []string +} + +func (g GroupBy) String() string { + return fmt.Sprintf("GROUP BY %s", strings.Join(g.Columns, ",")) +} + +type joinType string + +const ( + JoinTypeInner = "INNER" + JoinTypeLeft = "LEFT" + JoinTypeRight = "RIGHT" + JoinTypeFull = "FULL OUTER" + JoinTypeSelf = "SELF" +) + +type JoinOn struct { + Lhs string + Rhs string +} + +func (j JoinOn) String() string { + return fmt.Sprintf("%s = %s", j.Lhs, j.Rhs) +} + +type Join struct { + Type joinType + Table string + On JoinOn +} + +func (j Join) String() string { + return fmt.Sprintf("%s JOIN %s ON %s", j.Type, j.Table, j.On.String()) +} + +type Limit struct { + N int +} + +func (l Limit) String() string { + return fmt.Sprintf("LIMIT %d", l.N) +} + +type Offset struct { + N int +} + +func (o Offset) String() string { + return fmt.Sprintf("OFFSET %d", o.N) +} + +type selected struct { + Columns []string +} + +func (s selected) String() string { + return fmt.Sprintf("%s", strings.Join(s.Columns, ",")) +} + +// OrderBy adds an OrderBy section to QueryBuilder. +func (q *QueryBuilder[OUTPUT]) OrderBy(column string, how string) *QueryBuilder[OUTPUT] { + q.SetSelect() + if q.orderBy == nil { + q.orderBy = &orderByClause{} + } + q.orderBy.Columns = append(q.orderBy.Columns, [2]string{column, how}) + return q +} + +// LeftJoin adds a left join section to QueryBuilder. +func (q *QueryBuilder[OUTPUT]) LeftJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] { + q.SetSelect() + q.joins = append(q.joins, &Join{ + Type: JoinTypeLeft, + Table: table, + On: JoinOn{ + Lhs: onLhs, + Rhs: onRhs, + }, + }) + return q +} + +// RightJoin adds a right join section to QueryBuilder. +func (q *QueryBuilder[OUTPUT]) RightJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] { + q.SetSelect() + q.joins = append(q.joins, &Join{ + Type: JoinTypeRight, + Table: table, + On: JoinOn{ + Lhs: onLhs, + Rhs: onRhs, + }, + }) + return q +} + +// InnerJoin adds a inner join section to QueryBuilder. +func (q *QueryBuilder[OUTPUT]) InnerJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] { + q.SetSelect() + q.joins = append(q.joins, &Join{ + Type: JoinTypeInner, + Table: table, + On: JoinOn{ + Lhs: onLhs, + Rhs: onRhs, + }, + }) + return q +} + +// Join adds a inner join section to QueryBuilder. +func (q *QueryBuilder[OUTPUT]) Join(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] { + return q.InnerJoin(table, onLhs, onRhs) +} + +// FullOuterJoin adds a full outer join section to QueryBuilder. +func (q *QueryBuilder[OUTPUT]) FullOuterJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] { + q.SetSelect() + q.joins = append(q.joins, &Join{ + Type: JoinTypeFull, + Table: table, + On: JoinOn{ + Lhs: onLhs, + Rhs: onRhs, + }, + }) + return q +} + +// Where Adds a where clause to query, if already have where clause append to it +// as AndWhere. +func (q *QueryBuilder[OUTPUT]) Where(parts ...interface{}) *QueryBuilder[OUTPUT] { + if q.where != nil { + return q.addWhere("AND", parts...) + } + if len(parts) == 1 { + if r, isRaw := parts[0].(*raw); isRaw { + q.where = &whereClause{raw: r.sql, args: r.args, PlaceHolderGenerator: q.placeholderGenerator} + return q + } else { + q.err = fmt.Errorf("when you have one argument passed to where, it should be *raw") + return q + } + + } else if len(parts) == 2 { + if strings.Index(parts[0].(string), " ") == -1 { + // Equal mode + q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: Eq, Rhs: parts[1]}, PlaceHolderGenerator: q.placeholderGenerator} + } + return q + } else if len(parts) == 3 { + // operator mode + q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2]}, PlaceHolderGenerator: q.placeholderGenerator} + return q + } else if len(parts) > 3 && parts[1].(string) == "IN" { + q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2:]}, PlaceHolderGenerator: q.placeholderGenerator} + return q + } else { + q.err = fmt.Errorf("wrong number of arguments passed to Where") + return q + } +} + +type binaryOp string + +const ( + Eq = "=" + GT = ">" + LT = "<" + GE = ">=" + LE = "<=" + NE = "!=" + Between = "BETWEEN" + Like = "LIKE" + In = "IN" +) + +type cond struct { + PlaceHolderGenerator func(n int) []string + + Lhs string + Op binaryOp + Rhs interface{} +} + +func (b cond) ToSql() (string, []interface{}, error) { + var phs []string + if b.Op == In { + rhs, isInterfaceSlice := b.Rhs.([]interface{}) + if isInterfaceSlice { + phs = b.PlaceHolderGenerator(len(rhs)) + return fmt.Sprintf("%s IN (%s)", b.Lhs, strings.Join(phs, ",")), rhs, nil + } else if rawThing, isRaw := b.Rhs.(*raw); isRaw { + return fmt.Sprintf("%s IN (%s)", b.Lhs, rawThing.sql), rawThing.args, nil + } else { + return "", nil, fmt.Errorf("Right hand side of Cond when operator is IN should be either a interface{} slice or *raw") + } + + } else { + phs = b.PlaceHolderGenerator(1) + return fmt.Sprintf("%s %s %s", b.Lhs, b.Op, pop(&phs)), []interface{}{b.Rhs}, nil + } +} + +const ( + nextType_AND = "AND" + nextType_OR = "OR" +) + +type whereClause struct { + PlaceHolderGenerator func(n int) []string + nextTyp string + next *whereClause + cond + raw string + args []interface{} +} + +func (w whereClause) ToSql() (string, []interface{}, error) { + var base string + var args []interface{} + var err error + if w.raw != "" { + base = w.raw + args = w.args + } else { + w.cond.PlaceHolderGenerator = w.PlaceHolderGenerator + base, args, err = w.cond.ToSql() + if err != nil { + return "", nil, err + } + } + if w.next == nil { + return base, args, nil + } + if w.next != nil { + next, nextArgs, err := w.next.ToSql() + if err != nil { + return "", nil, err + } + base += " " + w.nextTyp + " " + next + args = append(args, nextArgs...) + return base, args, nil + } + + return base, args, nil +} + +//func (q *QueryBuilder[OUTPUT]) WhereKeyValue(m map) {} + +// WhereIn adds a where clause to QueryBuilder using In operator. +func (q *QueryBuilder[OUTPUT]) WhereIn(column string, values ...interface{}) *QueryBuilder[OUTPUT] { + return q.Where(append([]interface{}{column, In}, values...)...) +} + +// AndWhere appends a where clause to query builder as And where clause. +func (q *QueryBuilder[OUTPUT]) AndWhere(parts ...interface{}) *QueryBuilder[OUTPUT] { + return q.addWhere(nextType_AND, parts...) +} + +// OrWhere appends a where clause to query builder as Or where clause. +func (q *QueryBuilder[OUTPUT]) OrWhere(parts ...interface{}) *QueryBuilder[OUTPUT] { + return q.addWhere(nextType_OR, parts...) +} + +func (q *QueryBuilder[OUTPUT]) addWhere(typ string, parts ...interface{}) *QueryBuilder[OUTPUT] { + w := q.where + for { + if w == nil { + break + } else if w.next == nil { + w.next = &whereClause{PlaceHolderGenerator: q.placeholderGenerator} + w.nextTyp = typ + w = w.next + break + } else { + w = w.next + } + } + if w == nil { + w = &whereClause{PlaceHolderGenerator: q.placeholderGenerator} + } + if len(parts) == 1 { + w.raw = parts[0].(*raw).sql + w.args = parts[0].(*raw).args + return q + } else if len(parts) == 2 { + // Equal mode + w.cond = cond{Lhs: parts[0].(string), Op: Eq, Rhs: parts[1]} + return q + } else if len(parts) == 3 { + // operator mode + w.cond = cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2]} + return q + } else { + panic("wrong number of arguments passed to Where") + } +} + +// Offset adds offset section to query builder. +func (q *QueryBuilder[OUTPUT]) Offset(n int) *QueryBuilder[OUTPUT] { + q.SetSelect() + q.offset = &Offset{N: n} + return q +} + +// Limit adds limit section to query builder. +func (q *QueryBuilder[OUTPUT]) Limit(n int) *QueryBuilder[OUTPUT] { + q.SetSelect() + q.limit = &Limit{N: n} + return q +} + +// Table sets table of QueryBuilder. +func (q *QueryBuilder[OUTPUT]) Table(t string) *QueryBuilder[OUTPUT] { + q.table = t + return q +} + +// SetSelect sets query type of QueryBuilder to Select. +func (q *QueryBuilder[OUTPUT]) SetSelect() *QueryBuilder[OUTPUT] { + q.typ = queryTypeSELECT + return q +} + +// GroupBy adds a group by section to QueryBuilder. +func (q *QueryBuilder[OUTPUT]) GroupBy(columns ...string) *QueryBuilder[OUTPUT] { + q.SetSelect() + if q.groupBy == nil { + q.groupBy = &GroupBy{} + } + q.groupBy.Columns = append(q.groupBy.Columns, columns...) + return q +} + +// Select adds columns to QueryBuilder select field list. +func (q *QueryBuilder[OUTPUT]) Select(columns ...string) *QueryBuilder[OUTPUT] { + q.SetSelect() + if q.selected == nil { + q.selected = &selected{} + } + q.selected.Columns = append(q.selected.Columns, columns...) + return q +} + +// FromQuery sets subquery of QueryBuilder to be given subquery so +// when doing select instead of from table we do from(subquery). +func (q *QueryBuilder[OUTPUT]) FromQuery(subQuery *QueryBuilder[OUTPUT]) *QueryBuilder[OUTPUT] { + q.SetSelect() + subQuery.SetSelect() + subQuery.placeholderGenerator = q.placeholderGenerator + subQueryString, args, err := subQuery.ToSql() + q.err = err + q.subQuery = &struct { + q string + args []interface{} + placeholderGenerator func(n int) []string + }{ + subQueryString, args, q.placeholderGenerator, + } + return q +} + +func (q *QueryBuilder[OUTPUT]) SetUpdate() *QueryBuilder[OUTPUT] { + q.typ = queryTypeUPDATE + return q +} + +func (q *QueryBuilder[OUTPUT]) Set(keyValues ...any) *QueryBuilder[OUTPUT] { + if len(keyValues)%2 != 0 { + q.err = fmt.Errorf("when using Set, passed argument count should be even: %w", q.err) + return q + } + q.SetUpdate() + for i := 0; i < len(keyValues); i++ { + if i != 0 && i%2 == 1 { + q.sets = append(q.sets, [2]any{keyValues[i-1], keyValues[i]}) + } + } + return q +} + +func (q *QueryBuilder[OUTPUT]) SetDialect(dialect *Dialect) *QueryBuilder[OUTPUT] { + q.placeholderGenerator = dialect.PlaceHolderGenerator + return q +} +func (q *QueryBuilder[OUTPUT]) SetDelete() *QueryBuilder[OUTPUT] { + q.typ = queryTypeDelete + return q +} + +type raw struct { + sql string + args []interface{} +} + +// Raw creates a Raw sql query chunk that you can add to several components of QueryBuilder like +// Wheres. +func Raw(sql string, args ...interface{}) *raw { + return &raw{sql: sql, args: args} +} + +func NewQueryBuilder[OUTPUT any](s *schema) *QueryBuilder[OUTPUT] { + return &QueryBuilder[OUTPUT]{schema: s} +} + +type insertStmt struct { + PlaceHolderGenerator func(n int) []string + Table string + Columns []string + Values [][]interface{} + Returning string +} + +func (i insertStmt) flatValues() []interface{} { + var values []interface{} + for _, row := range i.Values { + values = append(values, row...) + } + return values +} + +func (i insertStmt) getValuesStr() string { + phs := i.PlaceHolderGenerator(len(i.Values) * len(i.Values[0])) + + var output []string + for _, valueRow := range i.Values { + output = append(output, fmt.Sprintf("(%s)", strings.Join(phs[:len(valueRow)], ","))) + phs = phs[len(valueRow):] + } + return strings.Join(output, ",") +} + +func (i insertStmt) ToSql() (string, []interface{}) { + base := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", + i.Table, + strings.Join(i.Columns, ","), + i.getValuesStr(), + ) + if i.Returning != "" { + base += "RETURNING " + i.Returning + } + return base, i.flatValues() +} + +func postgresPlaceholder(n int) []string { + output := []string{} + for i := 1; i < n+1; i++ { + output = append(output, fmt.Sprintf("$%d", i)) + } + return output +} + +func questionMarks(n int) []string { + output := []string{} + for i := 0; i < n; i++ { + output = append(output, "?") + } + + return output +} diff --git a/gdb/orm/query_test.go b/gdb/orm/query_test.go new file mode 100644 index 0000000..6bc0a71 --- /dev/null +++ b/gdb/orm/query_test.go @@ -0,0 +1,243 @@ +// +// query_test.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +type Dummy struct{} + +func (d Dummy) ConfigureEntity(e *EntityConfigurator) { + // TODO implement me + panic("implement me") +} + +func TestSelect(t *testing.T) { + t.Run("only select * from Table", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil) + s.Table("users").SetSelect() + str, args, err := s.ToSql() + assert.NoError(t, err) + assert.Empty(t, args) + assert.Equal(t, "SELECT * FROM users", str) + }) + t.Run("select with whereClause", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil) + + s.Table("users").SetDialect(Dialects.MySQL). + Where("age", 10). + AndWhere("age", "<", 10). + Where("name", "Amirreza"). + OrWhere("age", GT, 11). + SetSelect() + + str, args, err := s.ToSql() + assert.NoError(t, err) + assert.EqualValues(t, []interface{}{10, 10, "Amirreza", 11}, args) + assert.Equal(t, "SELECT * FROM users WHERE age = ? AND age < ? AND name = ? OR age > ?", str) + }) + t.Run("select with order by", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil).Table("users").OrderBy("created_at", ASC).OrderBy("updated_at", DESC) + str, args, err := s.ToSql() + assert.NoError(t, err) + assert.Empty(t, args) + assert.Equal(t, "SELECT * FROM users ORDER BY created_at ASC,updated_at DESC", str) + }) + + t.Run("select with group by", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil).Table("users").GroupBy("created_at", "updated_at") + str, args, err := s.ToSql() + assert.NoError(t, err) + assert.Empty(t, args) + assert.Equal(t, "SELECT * FROM users GROUP BY created_at,updated_at", str) + }) + + t.Run("Select with limit", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil).Table("users").Limit(10) + str, args, err := s.ToSql() + assert.NoError(t, err) + assert.Empty(t, args) + assert.Equal(t, "SELECT * FROM users LIMIT 10", str) + }) + + t.Run("Select with offset", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil).Table("users").Offset(10) + str, args, err := s.ToSql() + assert.NoError(t, err) + assert.Empty(t, args) + assert.Equal(t, "SELECT * FROM users OFFSET 10", str) + }) + + t.Run("select with join", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").RightJoin("addresses", "users.id", "addresses.user_id") + str, args, err := s.ToSql() + assert.NoError(t, err) + assert.Empty(t, args) + assert.Equal(t, `SELECT id,name FROM users RIGHT JOIN addresses ON users.id = addresses.user_id`, str) + }) + + t.Run("select with multiple joins", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil).Table("users"). + Select("id", "name"). + RightJoin("addresses", "users.id", "addresses.user_id"). + LeftJoin("user_credits", "users.id", "user_credits.user_id") + sql, args, err := s.ToSql() + assert.NoError(t, err) + assert.Empty(t, args) + assert.Equal(t, `SELECT id,name FROM users RIGHT JOIN addresses ON users.id = addresses.user_id LEFT JOIN user_credits ON users.id = user_credits.user_id`, sql) + }) + + t.Run("select with subquery", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil).SetDialect(Dialects.MySQL) + s.FromQuery(NewQueryBuilder[Dummy](nil).Table("users").Where("age", "<", 10)) + sql, args, err := s.ToSql() + assert.NoError(t, err) + assert.EqualValues(t, []interface{}{10}, args) + assert.Equal(t, `SELECT * FROM (SELECT * FROM users WHERE age < ? )`, sql) + + }) + + t.Run("select with inner join", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").InnerJoin("addresses", "users.id", "addresses.user_id") + str, args, err := s.ToSql() + assert.NoError(t, err) + assert.Empty(t, args) + assert.Equal(t, `SELECT id,name FROM users INNER JOIN addresses ON users.id = addresses.user_id`, str) + }) + + t.Run("select with join", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").Join("addresses", "users.id", "addresses.user_id") + str, args, err := s.ToSql() + assert.NoError(t, err) + assert.Empty(t, args) + assert.Equal(t, `SELECT id,name FROM users INNER JOIN addresses ON users.id = addresses.user_id`, str) + }) + + t.Run("select with full outer join", func(t *testing.T) { + s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").FullOuterJoin("addresses", "users.id", "addresses.user_id") + str, args, err := s.ToSql() + assert.NoError(t, err) + assert.Empty(t, args) + assert.Equal(t, `SELECT id,name FROM users FULL OUTER JOIN addresses ON users.id = addresses.user_id`, str) + }) + t.Run("raw where", func(t *testing.T) { + sql, args, err := + NewQueryBuilder[Dummy](nil). + SetDialect(Dialects.MySQL). + Table("users"). + Where(Raw("id = ?", 1)). + AndWhere(Raw("age < ?", 10)). + SetSelect(). + ToSql() + + assert.NoError(t, err) + assert.EqualValues(t, []interface{}{1, 10}, args) + assert.Equal(t, `SELECT * FROM users WHERE id = ? AND age < ?`, sql) + }) + t.Run("no sql type matched", func(t *testing.T) { + sql, args, err := NewQueryBuilder[Dummy](nil).ToSql() + assert.Error(t, err) + assert.Empty(t, args) + assert.Empty(t, sql) + }) + + t.Run("raw where in", func(t *testing.T) { + sql, args, err := + NewQueryBuilder[Dummy](nil). + SetDialect(Dialects.MySQL). + Table("users"). + WhereIn("id", Raw("SELECT user_id FROM user_books WHERE book_id = ?", 10)). + SetSelect(). + ToSql() + + assert.NoError(t, err) + assert.EqualValues(t, []interface{}{10}, args) + assert.Equal(t, `SELECT * FROM users WHERE id IN (SELECT user_id FROM user_books WHERE book_id = ?)`, sql) + }) + t.Run("where in", func(t *testing.T) { + sql, args, err := + NewQueryBuilder[Dummy](nil). + SetDialect(Dialects.MySQL). + Table("users"). + WhereIn("id", 1, 2, 3, 4, 5, 6). + SetSelect(). + ToSql() + + assert.NoError(t, err) + assert.EqualValues(t, []interface{}{1, 2, 3, 4, 5, 6}, args) + assert.Equal(t, `SELECT * FROM users WHERE id IN (?,?,?,?,?,?)`, sql) + + }) +} +func TestUpdate(t *testing.T) { + t.Run("update no whereClause", func(t *testing.T) { + u := NewQueryBuilder[Dummy](nil).Table("users").Set("name", "amirreza").SetDialect(Dialects.MySQL) + sql, args, err := u.ToSql() + assert.NoError(t, err) + assert.Equal(t, `UPDATE users SET name=?`, sql) + assert.Equal(t, []interface{}{"amirreza"}, args) + }) + t.Run("update with whereClause", func(t *testing.T) { + u := NewQueryBuilder[Dummy](nil).Table("users").Set("name", "amirreza").Where("age", "<", 18).SetDialect(Dialects.MySQL) + sql, args, err := u.ToSql() + assert.NoError(t, err) + assert.Equal(t, `UPDATE users SET name=? WHERE age < ?`, sql) + assert.Equal(t, []interface{}{"amirreza", 18}, args) + + }) +} +func TestDelete(t *testing.T) { + t.Run("delete without whereClause", func(t *testing.T) { + d := NewQueryBuilder[Dummy](nil).Table("users").SetDelete() + sql, args, err := d.ToSql() + assert.NoError(t, err) + assert.Equal(t, `DELETE FROM users`, sql) + assert.Empty(t, args) + }) + t.Run("delete with whereClause", func(t *testing.T) { + d := NewQueryBuilder[Dummy](nil).Table("users").SetDialect(Dialects.MySQL).Where("created_at", ">", "2012-01-10").SetDelete() + sql, args, err := d.ToSql() + assert.NoError(t, err) + assert.Equal(t, `DELETE FROM users WHERE created_at > ?`, sql) + assert.EqualValues(t, []interface{}{"2012-01-10"}, args) + }) +} + +func TestInsert(t *testing.T) { + t.Run("insert into multiple rows", func(t *testing.T) { + i := insertStmt{} + i.Table = "users" + i.PlaceHolderGenerator = Dialects.MySQL.PlaceHolderGenerator + i.Columns = []string{"name", "age"} + i.Values = append(i.Values, []interface{}{"amirreza", 11}, []interface{}{"parsa", 10}) + s, args := i.ToSql() + assert.Equal(t, `INSERT INTO users (name,age) VALUES (?,?),(?,?)`, s) + assert.EqualValues(t, []interface{}{"amirreza", 11, "parsa", 10}, args) + }) + + t.Run("insert into single row", func(t *testing.T) { + i := insertStmt{} + i.Table = "users" + i.PlaceHolderGenerator = Dialects.MySQL.PlaceHolderGenerator + i.Columns = []string{"name", "age"} + i.Values = append(i.Values, []interface{}{"amirreza", 11}) + s, args := i.ToSql() + assert.Equal(t, `INSERT INTO users (name,age) VALUES (?,?)`, s) + assert.Equal(t, []interface{}{"amirreza", 11}, args) + }) +} + +func TestPostgresPlaceholder(t *testing.T) { + t.Run("for 5 it should have 5", func(t *testing.T) { + phs := postgresPlaceholder(5) + assert.EqualValues(t, []string{"$1", "$2", "$3", "$4", "$5"}, phs) + }) +} diff --git a/gdb/orm/schema.go b/gdb/orm/schema.go new file mode 100644 index 0000000..806b0b5 --- /dev/null +++ b/gdb/orm/schema.go @@ -0,0 +1,306 @@ +// +// schema.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "reflect" +) + +func getConnectionFor(e Entity) *connection { + configurator := newEntityConfigurator() + e.ConfigureEntity(configurator) + + if len(globalConnections) > 1 && (configurator.connection == "" || configurator.table == "") { + panic("need table and DB name when having more than 1 DB registered") + } + if len(globalConnections) == 1 { + for _, db := range globalConnections { + return db + } + } + if db, exists := globalConnections[fmt.Sprintf("%s", configurator.connection)]; exists { + return db + } + panic("no db found") +} + +func getSchemaFor(e Entity) *schema { + configurator := newEntityConfigurator() + c := getConnectionFor(e) + e.ConfigureEntity(configurator) + s := c.getSchema(configurator.table) + if s == nil { + s = schemaOfHeavyReflectionStuff(e) + c.setSchema(e, s) + } + return s +} + +type schema struct { + Connection string + Table string + fields []*field + relations map[string]interface{} + setPK func(o Entity, value interface{}) + getPK func(o Entity) interface{} + columnConstraints []*FieldConfigurator +} + +func (s *schema) getField(sf reflect.StructField) *field { + for _, f := range s.fields { + if sf.Name == f.Name { + return f + } + } + return nil +} + +func (s *schema) getDialect() *Dialect { + return GetConnection(s.Connection).Dialect +} +func (s *schema) Columns(withPK bool) []string { + var cols []string + for _, field := range s.fields { + if field.Virtual { + continue + } + if !withPK && field.IsPK { + continue + } + if s.getDialect().AddTableNameInSelectColumns { + cols = append(cols, s.Table+"."+field.Name) + } else { + cols = append(cols, field.Name) + } + } + return cols +} + +func (s *schema) pkName() string { + for _, field := range s.fields { + if field.IsPK { + return field.Name + } + } + return "" +} + +func genericFieldsOf(obj Entity) []*field { + t := reflect.TypeOf(obj) + for t.Kind() == reflect.Ptr { + t = t.Elem() + + } + if t.Kind() == reflect.Slice { + t = t.Elem() + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + } + var ec EntityConfigurator + obj.ConfigureEntity(&ec) + + var fms []*field + for i := 0; i < t.NumField(); i++ { + ft := t.Field(i) + fm := fieldMetadata(ft, ec.columnConstraints) + fms = append(fms, fm...) + } + return fms +} + +func valuesOfField(vf reflect.Value) []interface{} { + var values []interface{} + if vf.Type().Kind() == reflect.Struct || vf.Type().Kind() == reflect.Ptr { + t := vf.Type() + if vf.Type().Kind() == reflect.Ptr { + t = vf.Type().Elem() + } + if !t.Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) { + // go into + // it does not implement driver.Valuer interface + for i := 0; i < vf.NumField(); i++ { + vif := vf.Field(i) + values = append(values, valuesOfField(vif)...) + } + } else { + values = append(values, vf.Interface()) + } + } else { + values = append(values, vf.Interface()) + } + return values +} +func genericValuesOf(o Entity, withPK bool) []interface{} { + t := reflect.TypeOf(o) + v := reflect.ValueOf(o) + if t.Kind() == reflect.Ptr { + t = t.Elem() + v = v.Elem() + } + fields := getSchemaFor(o).fields + pkIdx := -1 + for i, field := range fields { + if field.IsPK { + pkIdx = i + } + + } + + var values []interface{} + + for i := 0; i < t.NumField(); i++ { + if !withPK && i == pkIdx { + continue + } + if fields[i].Virtual { + continue + } + vf := v.Field(i) + values = append(values, valuesOfField(vf)...) + } + return values +} + +func genericSetPkValue(obj Entity, value interface{}) { + genericSet(obj, getSchemaFor(obj).pkName(), value) +} + +func genericGetPKValue(obj Entity) interface{} { + t := reflect.TypeOf(obj) + val := reflect.ValueOf(obj) + if t.Kind() == reflect.Ptr { + val = val.Elem() + } + + fields := getSchemaFor(obj).fields + for i, field := range fields { + if field.IsPK { + return val.Field(i).Interface() + } + } + return "" +} + +func (s *schema) createdAt() *field { + for _, f := range s.fields { + if f.IsCreatedAt { + return f + } + } + return nil +} +func (s *schema) updatedAt() *field { + for _, f := range s.fields { + if f.IsUpdatedAt { + return f + } + } + return nil +} + +func (s *schema) deletedAt() *field { + for _, f := range s.fields { + if f.IsDeletedAt { + return f + } + } + return nil +} +func pointersOf(v reflect.Value) map[string]interface{} { + m := map[string]interface{}{} + actualV := v + for actualV.Type().Kind() == reflect.Ptr { + actualV = actualV.Elem() + } + for i := 0; i < actualV.NumField(); i++ { + f := actualV.Field(i) + if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) { + fm := pointersOf(f) + for k, p := range fm { + m[k] = p + } + } else { + fm := fieldMetadata(actualV.Type().Field(i), nil)[0] + m[fm.Name] = actualV.Field(i) + } + } + + return m +} +func genericSet(obj Entity, name string, value interface{}) { + n2p := pointersOf(reflect.ValueOf(obj)) + var val interface{} + for k, v := range n2p { + if k == name { + val = v + } + } + val.(reflect.Value).Set(reflect.ValueOf(value)) +} +func schemaOfHeavyReflectionStuff(v Entity) *schema { + userEntityConfigurator := newEntityConfigurator() + v.ConfigureEntity(userEntityConfigurator) + for _, relation := range userEntityConfigurator.resolveRelations { + relation() + } + schema := &schema{} + if userEntityConfigurator.connection != "" { + schema.Connection = userEntityConfigurator.connection + } + if userEntityConfigurator.table != "" { + schema.Table = userEntityConfigurator.table + } else { + panic("you need to have table name for getting schema.") + } + + schema.columnConstraints = userEntityConfigurator.columnConstraints + if schema.Connection == "" { + schema.Connection = "default" + } + if schema.fields == nil { + schema.fields = genericFieldsOf(v) + } + if schema.getPK == nil { + schema.getPK = genericGetPKValue + } + + if schema.setPK == nil { + schema.setPK = genericSetPkValue + } + + schema.relations = userEntityConfigurator.relations + + return schema +} + +func (s *schema) getTable() string { + return s.Table +} + +func (s *schema) getSQLDB() *sql.DB { + return s.getConnection().DB +} + +func (s *schema) getConnection() *connection { + if len(globalConnections) > 1 && (s.Connection == "" || s.Table == "") { + panic("need table and DB name when having more than 1 DB registered") + } + if len(globalConnections) == 1 { + for _, db := range globalConnections { + return db + } + } + if db, exists := globalConnections[fmt.Sprintf("%s", s.Connection)]; exists { + return db + } + panic("no db found") +} diff --git a/gdb/orm/schema_test.go b/gdb/orm/schema_test.go new file mode 100644 index 0000000..cdca8c4 --- /dev/null +++ b/gdb/orm/schema_test.go @@ -0,0 +1,76 @@ +// +// schema_test.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "database/sql" + "testing" + + "github.com/stretchr/testify/assert" +) + +func setup(t *testing.T) { + db, err := sql.Open("sqlite3", ":memory:") + err = SetupConnections(ConnectionConfig{ + Name: "default", + DB: db, + Dialect: Dialects.SQLite3, + }) + // orm.Schematic() + _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`) + _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`) + _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`) + _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`) + _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`) + _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`) + assert.NoError(t, err) +} + +type Object struct { + ID int64 + Name string + Timestamps +} + +func (o Object) ConfigureEntity(e *EntityConfigurator) { + e.Table("objects").Connection("default") +} + +func TestGenericFieldsOf(t *testing.T) { + t.Run("fields of with id and timestamps embedded", func(t *testing.T) { + fs := genericFieldsOf(&Object{}) + assert.Len(t, fs, 5) + assert.Equal(t, "id", fs[0].Name) + assert.True(t, fs[0].IsPK) + assert.Equal(t, "name", fs[1].Name) + assert.Equal(t, "created_at", fs[2].Name) + assert.Equal(t, "updated_at", fs[3].Name) + assert.Equal(t, "deleted_at", fs[4].Name) + }) +} + +func TestGenericValuesOf(t *testing.T) { + t.Run("values of", func(t *testing.T) { + + setup(t) + vs := genericValuesOf(Object{}, true) + assert.Len(t, vs, 5) + }) +} + +func TestEntityConfigurator(t *testing.T) { + t.Run("test has many with user provided values", func(t *testing.T) { + setup(t) + var ec EntityConfigurator + ec.Table("users").Connection("default").HasMany(Object{}, HasManyConfig{ + "objects", "user_id", + }) + + }) + +} diff --git a/gdb/orm/timestamps.go b/gdb/orm/timestamps.go new file mode 100644 index 0000000..6ad40b7 --- /dev/null +++ b/gdb/orm/timestamps.go @@ -0,0 +1,18 @@ +// +// timestamps.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package orm + +import ( + "database/sql" +) + +type Timestamps struct { + CreatedAt sql.NullTime + UpdatedAt sql.NullTime + DeletedAt sql.NullTime +}