From aa784c4187fc783f905d5bd1df9bb9e892478103 Mon Sep 17 00:00:00 2001 From: tiglog Date: Tue, 19 Sep 2023 10:45:38 +0800 Subject: [PATCH] refactor: remove the db library, keep only the mgodb already in use MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gdb/orm/.gitignore | 3 - gdb/orm/binder.go | 157 -- gdb/orm/binder_test.go | 92 - gdb/orm/configurators.go | 192 -- gdb/orm/connection.go | 160 -- gdb/orm/dialect.go | 109 -- gdb/orm/field.go | 75 - gdb/orm/orm.go | 699 ------- gdb/orm/orm_test.go | 588 ------ gdb/orm/query.go | 811 -------- gdb/orm/query_test.go | 243 --- gdb/orm/schema.go | 306 --- gdb/orm/schema_test.go | 77 - gdb/orm/timestamps.go | 21 - gdb/sqldb/column.go | 78 - gdb/sqldb/context_test.go | 82 - gdb/sqldb/db.go | 1049 ---------- gdb/sqldb/db_test.go | 187 -- gdb/sqldb/dialect.go | 108 -- gdb/sqldb/dialect_mysql.go | 172 -- gdb/sqldb/dialect_mysql_test.go | 195 -- gdb/sqldb/dialect_oracle.go | 142 -- gdb/sqldb/dialect_postgres.go | 152 -- gdb/sqldb/dialect_postgres_test.go | 161 -- gdb/sqldb/dialect_sqlite.go | 115 -- gdb/sqldb/doc.go | 13 - gdb/sqldb/errors.go | 34 - gdb/sqldb/hooks.go | 45 - gdb/sqldb/index.go | 51 - gdb/sqldb/lockerror.go | 59 - gdb/sqldb/logging.go | 45 - gdb/sqldb/nulltypes.go | 68 - gdb/sqldb/query_builder.go | 40 - gdb/sqldb/select.go | 361 ---- gdb/sqldb/sqldb.go | 675 ------- gdb/sqldb/sqldb_test.go | 2875 ---------------------------- gdb/sqldb/table.go | 258 --- gdb/sqldb/table_bindings.go | 308 --- gdb/sqldb/test_all.sh | 24 - gdb/sqldb/transaction.go | 242 --- gdb/sqldb/transaction_test.go | 340 ---- 41 files changed, 11412 deletions(-) delete mode 100644 gdb/orm/.gitignore delete mode 100644 gdb/orm/binder.go delete mode 100644 gdb/orm/binder_test.go delete mode 100644 gdb/orm/configurators.go delete mode 100644 gdb/orm/connection.go delete mode 100644 gdb/orm/dialect.go delete mode 100644 gdb/orm/field.go delete mode 100644 gdb/orm/orm.go delete mode 100644 gdb/orm/orm_test.go delete mode 100644 gdb/orm/query.go delete mode 100644 gdb/orm/query_test.go delete mode 100644 gdb/orm/schema.go delete mode 100644 gdb/orm/schema_test.go delete mode 100644 gdb/orm/timestamps.go delete mode 100644 gdb/sqldb/column.go delete mode 100644 gdb/sqldb/context_test.go delete mode 100644 gdb/sqldb/db.go delete mode 100644 gdb/sqldb/db_test.go delete mode 100644 gdb/sqldb/dialect.go delete mode 100644 gdb/sqldb/dialect_mysql.go delete mode 100644 gdb/sqldb/dialect_mysql_test.go delete mode 100644 gdb/sqldb/dialect_oracle.go delete mode 100644 gdb/sqldb/dialect_postgres.go delete mode 100644 gdb/sqldb/dialect_postgres_test.go delete mode 100644 gdb/sqldb/dialect_sqlite.go delete mode 100644 gdb/sqldb/doc.go delete mode 100644 gdb/sqldb/errors.go delete mode 100644 gdb/sqldb/hooks.go delete mode 100644 gdb/sqldb/index.go delete mode 100644 gdb/sqldb/lockerror.go delete mode 100644 gdb/sqldb/logging.go delete mode 100644 gdb/sqldb/nulltypes.go delete mode 100644 gdb/sqldb/query_builder.go delete mode 100644 gdb/sqldb/select.go delete mode 100644 gdb/sqldb/sqldb.go delete mode 100644 gdb/sqldb/sqldb_test.go delete mode 100644 gdb/sqldb/table.go delete mode 100644 gdb/sqldb/table_bindings.go delete mode 100755 gdb/sqldb/test_all.sh delete mode 100644 gdb/sqldb/transaction.go delete mode 100644 gdb/sqldb/transaction_test.go diff --git a/gdb/orm/.gitignore b/gdb/orm/.gitignore deleted file mode 100644 index 
e702335..0000000 --- a/gdb/orm/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -**/.idea/* -cover.out -**db \ No newline at end of file diff --git a/gdb/orm/binder.go b/gdb/orm/binder.go deleted file mode 100644 index 3a65598..0000000 --- a/gdb/orm/binder.go +++ /dev/null @@ -1,157 +0,0 @@ -// -// binder.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package orm - -import ( - "database/sql" - "database/sql/driver" - "fmt" - "reflect" - "unsafe" -) - -// makeNewPointersOf creates a map of [field name] -> pointer to fill it -// recursively. it will go down until reaches a driver.Valuer implementation, it will stop there. -func (b *binder) makeNewPointersOf(v reflect.Value) interface{} { - m := map[string]interface{}{} - actualV := v - for actualV.Type().Kind() == reflect.Ptr { - actualV = actualV.Elem() - } - if actualV.Type().Kind() == reflect.Struct { - for i := 0; i < actualV.NumField(); i++ { - f := actualV.Field(i) - if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) { - f = reflect.NewAt(actualV.Type().Field(i).Type, unsafe.Pointer(actualV.Field(i).UnsafeAddr())) - fm := b.makeNewPointersOf(f).(map[string]interface{}) - for k, p := range fm { - m[k] = p - } - } else { - var fm *field - fm = b.s.getField(actualV.Type().Field(i)) - if fm == nil { - fm = fieldMetadata(actualV.Type().Field(i), b.s.columnConstraints)[0] - } - m[fm.Name] = reflect.NewAt(actualV.Field(i).Type(), unsafe.Pointer(actualV.Field(i).UnsafeAddr())).Interface() - } - } - } else { - return v.Addr().Interface() - } - - return m -} - -// ptrsFor first allocates for all struct fields recursively until reaches a driver.Value impl -// then it will put them in a map with their correct field name as key, then loops over cts -// and for each one gets appropriate one from the map and adds it to pointer list. -func (b *binder) ptrsFor(v reflect.Value, cts []*sql.ColumnType) []interface{} { - ptrs := b.makeNewPointersOf(v) - var scanInto []interface{} - if reflect.TypeOf(ptrs).Kind() == reflect.Map { - nameToPtr := ptrs.(map[string]interface{}) - for _, ct := range cts { - if nameToPtr[ct.Name()] != nil { - scanInto = append(scanInto, nameToPtr[ct.Name()]) - } - } - } else { - scanInto = append(scanInto, ptrs) - } - - return scanInto -} - -type binder struct { - s *schema -} - -func newBinder(s *schema) *binder { - return &binder{s: s} -} - -// bind binds given rows to the given object at obj. obj should be a pointer -func (b *binder) bind(rows *sql.Rows, obj interface{}) error { - cts, err := rows.ColumnTypes() - if err != nil { - return err - } - - t := reflect.TypeOf(obj) - v := reflect.ValueOf(obj) - if t.Kind() != reflect.Ptr { - return fmt.Errorf("obj should be a ptr") - } - // since passed input is always a pointer one deref is necessary - t = t.Elem() - v = v.Elem() - if t.Kind() == reflect.Slice { - // getting slice elemnt type -> slice[t] - t = t.Elem() - for rows.Next() { - var rowValue reflect.Value - // Since reflect.SetupConnections returns a pointer to the type, we need to unwrap it to get actual - rowValue = reflect.New(t).Elem() - // till we reach a not pointer type continue newing the underlying type. - for rowValue.IsZero() && rowValue.Type().Kind() == reflect.Ptr { - rowValue = reflect.New(rowValue.Type().Elem()).Elem() - } - newCts := make([]*sql.ColumnType, len(cts)) - copy(newCts, cts) - ptrs := b.ptrsFor(rowValue, newCts) - err = rows.Scan(ptrs...) 
- if err != nil { - return err - } - for rowValue.Type() != t { - tmp := reflect.New(rowValue.Type()) - tmp.Elem().Set(rowValue) - rowValue = tmp - } - v = reflect.Append(v, rowValue) - } - } else { - for rows.Next() { - ptrs := b.ptrsFor(v, cts) - err = rows.Scan(ptrs...) - if err != nil { - return err - } - } - } - // v is either struct or slice - reflect.ValueOf(obj).Elem().Set(v) - return nil -} - -func bindToMap(rows *sql.Rows) ([]map[string]interface{}, error) { - cts, err := rows.ColumnTypes() - if err != nil { - return nil, err - } - var ms []map[string]interface{} - for rows.Next() { - var ptrs []interface{} - for _, ct := range cts { - ptrs = append(ptrs, reflect.New(ct.ScanType()).Interface()) - } - - err = rows.Scan(ptrs...) - if err != nil { - return nil, err - } - m := map[string]interface{}{} - for i, ptr := range ptrs { - m[cts[i].Name()] = reflect.ValueOf(ptr).Elem().Interface() - } - - ms = append(ms, m) - } - return ms, nil -} diff --git a/gdb/orm/binder_test.go b/gdb/orm/binder_test.go deleted file mode 100644 index 437032c..0000000 --- a/gdb/orm/binder_test.go +++ /dev/null @@ -1,92 +0,0 @@ -// -// binder_test.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package orm - -import ( - "database/sql" - "testing" - "time" - - "github.com/DATA-DOG/go-sqlmock" - - _ "github.com/lib/pq" - "github.com/stretchr/testify/assert" -) - -type User struct { - ID int64 - Name string - Timestamps - SoftDelete -} - -func (u User) ConfigureEntity(e *EntityConfigurator) { - e.Table("users") -} - -type Address struct { - ID int - Path string -} - -func TestBind(t *testing.T) { - t.Run("single result", func(t *testing.T) { - db, mock, err := sqlmock.New() - assert.NoError(t, err) - mock. - ExpectQuery("SELECT .* FROM users"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "created_at", "updated_at", "deleted_at"}). - AddRow(1, "amirreza", sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{})) - rows, err := db.Query(`SELECT * FROM users`) - assert.NoError(t, err) - - u := &User{} - md := schemaOfHeavyReflectionStuff(u) - err = newBinder(md).bind(rows, u) - assert.NoError(t, err) - - assert.Equal(t, "amirreza", u.Name) - }) - - t.Run("multi result", func(t *testing.T) { - db, mock, err := sqlmock.New() - assert.NoError(t, err) - mock. - ExpectQuery("SELECT .* FROM users"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "amirreza").AddRow(2, "milad")) - - rows, err := db.Query(`SELECT * FROM users`) - assert.NoError(t, err) - - md := schemaOfHeavyReflectionStuff(&User{}) - var users []*User - err = newBinder(md).bind(rows, &users) - assert.NoError(t, err) - - assert.Equal(t, "amirreza", users[0].Name) - assert.Equal(t, "milad", users[1].Name) - }) -} - -func TestBindMap(t *testing.T) { - db, mock, err := sqlmock.New() - assert.NoError(t, err) - mock. - ExpectQuery("SELECT .* FROM users"). - WillReturnRows(sqlmock.NewRows([]string{"id", "name", "created_at", "updated_at", "deleted_at"}). 
- AddRow(1, "amirreza", sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{})) - rows, err := db.Query(`SELECT * FROM users`) - assert.NoError(t, err) - - ms, err := bindToMap(rows) - - assert.NoError(t, err) - assert.NotEmpty(t, ms) - - assert.Len(t, ms, 1) -} diff --git a/gdb/orm/configurators.go b/gdb/orm/configurators.go deleted file mode 100644 index 073a726..0000000 --- a/gdb/orm/configurators.go +++ /dev/null @@ -1,192 +0,0 @@ -// -// configurators.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package orm - -import ( - "database/sql" - - "git.hexq.cn/tiglog/golib/helper" -) - -type EntityConfigurator struct { - connection string - table string - this Entity - relations map[string]interface{} - resolveRelations []func() - columnConstraints []*FieldConfigurator -} - -func newEntityConfigurator() *EntityConfigurator { - return &EntityConfigurator{} -} - -func (ec *EntityConfigurator) Table(name string) *EntityConfigurator { - ec.table = name - return ec -} - -func (ec *EntityConfigurator) Connection(name string) *EntityConfigurator { - ec.connection = name - return ec -} - -func (ec *EntityConfigurator) HasMany(property Entity, config HasManyConfig) *EntityConfigurator { - if ec.relations == nil { - ec.relations = map[string]interface{}{} - } - ec.resolveRelations = append(ec.resolveRelations, func() { - if config.PropertyForeignKey != "" && config.PropertyTable != "" { - ec.relations[config.PropertyTable] = config - return - } - configurator := newEntityConfigurator() - property.ConfigureEntity(configurator) - - if config.PropertyTable == "" { - config.PropertyTable = configurator.table - } - - if config.PropertyForeignKey == "" { - config.PropertyForeignKey = helper.NewPluralizeClient().Singular(ec.table) + "_id" - } - - ec.relations[configurator.table] = config - - return - }) - return ec -} - -func (ec *EntityConfigurator) HasOne(property Entity, config HasOneConfig) *EntityConfigurator { - if ec.relations == nil { - ec.relations = map[string]interface{}{} - } - ec.resolveRelations = append(ec.resolveRelations, func() { - if config.PropertyForeignKey != "" && config.PropertyTable != "" { - ec.relations[config.PropertyTable] = config - return - } - - configurator := newEntityConfigurator() - property.ConfigureEntity(configurator) - - if config.PropertyTable == "" { - config.PropertyTable = configurator.table - } - if config.PropertyForeignKey == "" { - config.PropertyForeignKey = helper.NewPluralizeClient().Singular(ec.table) + "_id" - } - - ec.relations[configurator.table] = config - return - }) - return ec -} - -func (ec *EntityConfigurator) BelongsTo(owner Entity, config BelongsToConfig) *EntityConfigurator { - if ec.relations == nil { - ec.relations = map[string]interface{}{} - } - ec.resolveRelations = append(ec.resolveRelations, func() { - if config.ForeignColumnName != "" && config.LocalForeignKey != "" && config.OwnerTable != "" { - ec.relations[config.OwnerTable] = config - return - } - ownerConfigurator := newEntityConfigurator() - owner.ConfigureEntity(ownerConfigurator) - if config.OwnerTable == "" { - config.OwnerTable = ownerConfigurator.table - } - if config.LocalForeignKey == "" { - config.LocalForeignKey = helper.NewPluralizeClient().Singular(ownerConfigurator.table) + "_id" - } - if config.ForeignColumnName == "" { - config.ForeignColumnName = "id" - } - ec.relations[ownerConfigurator.table] = config - }) - return ec -} - -func (ec *EntityConfigurator) 
BelongsToMany(owner Entity, config BelongsToManyConfig) *EntityConfigurator { - if ec.relations == nil { - ec.relations = map[string]interface{}{} - } - ec.resolveRelations = append(ec.resolveRelations, func() { - ownerConfigurator := newEntityConfigurator() - owner.ConfigureEntity(ownerConfigurator) - - if config.OwnerLookupColumn == "" { - var pkName string - for _, field := range genericFieldsOf(owner) { - if field.IsPK { - pkName = field.Name - } - } - config.OwnerLookupColumn = pkName - - } - if config.OwnerTable == "" { - config.OwnerTable = ownerConfigurator.table - } - if config.IntermediateTable == "" { - panic("cannot infer intermediate table yet") - } - if config.IntermediatePropertyID == "" { - config.IntermediatePropertyID = helper.NewPluralizeClient().Singular(ownerConfigurator.table) + "_id" - } - if config.IntermediateOwnerID == "" { - config.IntermediateOwnerID = helper.NewPluralizeClient().Singular(ec.table) + "_id" - } - - ec.relations[ownerConfigurator.table] = config - }) - return ec -} - -type FieldConfigurator struct { - fieldName string - nullable sql.NullBool - primaryKey bool - column string - isCreatedAt bool - isUpdatedAt bool - isDeletedAt bool -} - -func (ec *EntityConfigurator) Field(name string) *FieldConfigurator { - cc := &FieldConfigurator{fieldName: name} - ec.columnConstraints = append(ec.columnConstraints, cc) - return cc -} - -func (fc *FieldConfigurator) IsPrimaryKey() *FieldConfigurator { - fc.primaryKey = true - return fc -} - -func (fc *FieldConfigurator) IsCreatedAt() *FieldConfigurator { - fc.isCreatedAt = true - return fc -} - -func (fc *FieldConfigurator) IsUpdatedAt() *FieldConfigurator { - fc.isUpdatedAt = true - return fc -} - -func (fc *FieldConfigurator) IsDeletedAt() *FieldConfigurator { - fc.isDeletedAt = true - return fc -} - -func (fc *FieldConfigurator) ColumnName(name string) *FieldConfigurator { - fc.column = name - return fc -} diff --git a/gdb/orm/connection.go b/gdb/orm/connection.go deleted file mode 100644 index dde2388..0000000 --- a/gdb/orm/connection.go +++ /dev/null @@ -1,160 +0,0 @@ -// -// connection.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. 
-// - -package orm - -import ( - "database/sql" - "fmt" - - "git.hexq.cn/tiglog/golib/helper/table" -) - -type connection struct { - Name string - Dialect *Dialect - DB *sql.DB - Schemas map[string]*schema - DBSchema map[string][]columnSpec - DatabaseValidations bool -} - -func (c *connection) inferedTables() []string { - var tables []string - for t, s := range c.Schemas { - tables = append(tables, t) - for _, relC := range s.relations { - if belongsToManyConfig, is := relC.(BelongsToManyConfig); is { - tables = append(tables, belongsToManyConfig.IntermediateTable) - } - } - } - return tables -} - -func (c *connection) validateAllTablesArePresent() error { - for _, inferedTable := range c.inferedTables() { - if _, exists := c.DBSchema[inferedTable]; !exists { - return fmt.Errorf("orm infered %s but it's not found in your database, your database is out of sync", inferedTable) - } - } - return nil -} - -func (c *connection) validateTablesSchemas() error { - // check for entity tables: there should not be any struct field that does not have a coresponding column - for table, sc := range c.Schemas { - if columns, exists := c.DBSchema[table]; exists { - for _, f := range sc.fields { - found := false - for _, c := range columns { - if c.Name == f.Name { - found = true - } - } - if !found { - return fmt.Errorf("column %s not found while it was inferred", f.Name) - } - } - } else { - return fmt.Errorf("tables are out of sync, %s was inferred but not present in database", table) - } - } - - // check for relation tables: for HasMany,HasOne relations check if OWNER pk column is in PROPERTY, - // for BelongsToMany check intermediate table has 2 pk for two entities - - for table, sc := range c.Schemas { - for _, rel := range sc.relations { - switch rel.(type) { - case BelongsToConfig: - columns := c.DBSchema[table] - var found bool - for _, col := range columns { - if col.Name == rel.(BelongsToConfig).LocalForeignKey { - found = true - } - } - if !found { - return fmt.Errorf("cannot find local foreign key %s for relation", rel.(BelongsToConfig).LocalForeignKey) - } - case BelongsToManyConfig: - columns := c.DBSchema[rel.(BelongsToManyConfig).IntermediateTable] - var foundOwner bool - var foundProperty bool - - for _, col := range columns { - if col.Name == rel.(BelongsToManyConfig).IntermediateOwnerID { - foundOwner = true - } - if col.Name == rel.(BelongsToManyConfig).IntermediatePropertyID { - foundProperty = true - } - } - if !foundOwner || !foundProperty { - return fmt.Errorf("table schema for %s is not correct one of foreign keys is not present", rel.(BelongsToManyConfig).IntermediateTable) - } - } - } - } - - return nil -} - -func (c *connection) Schematic() { - fmt.Printf("SQL Dialect: %s\n", c.Dialect.DriverName) - for t, schema := range c.Schemas { - fmt.Printf("t: %s\n", t) - w := table.NewWriter() - w.AppendHeader(table.Row{"SQL Name", "Type", "Is Primary Key", "Is Virtual"}) - for _, field := range schema.fields { - w.AppendRow(table.Row{field.Name, field.Type, field.IsPK, field.Virtual}) - } - fmt.Println(w.Render()) - for _, rel := range schema.relations { - switch rel.(type) { - case HasOneConfig: - fmt.Printf("%s 1-1 %s => %+v\n", t, rel.(HasOneConfig).PropertyTable, rel) - case HasManyConfig: - fmt.Printf("%s 1-N %s => %+v\n", t, rel.(HasManyConfig).PropertyTable, rel) - - case BelongsToConfig: - fmt.Printf("%s N-1 %s => %+v\n", t, rel.(BelongsToConfig).OwnerTable, rel) - - case BelongsToManyConfig: - fmt.Printf("%s N-N %s => %+v\n", t, rel.(BelongsToManyConfig).IntermediateTable, rel) 
- } - } - fmt.Println("") - } -} - -func (c *connection) getSchema(t string) *schema { - return c.Schemas[t] -} - -func (c *connection) setSchema(e Entity, s *schema) { - var configurator EntityConfigurator - e.ConfigureEntity(&configurator) - c.Schemas[configurator.table] = s -} - -func GetConnection(name string) *connection { - return globalConnections[name] -} - -func (c *connection) exec(q string, args ...any) (sql.Result, error) { - return c.DB.Exec(q, args...) -} - -func (c *connection) query(q string, args ...any) (*sql.Rows, error) { - return c.DB.Query(q, args...) -} - -func (c *connection) queryRow(q string, args ...any) *sql.Row { - return c.DB.QueryRow(q, args...) -} diff --git a/gdb/orm/dialect.go b/gdb/orm/dialect.go deleted file mode 100644 index 67e6a67..0000000 --- a/gdb/orm/dialect.go +++ /dev/null @@ -1,109 +0,0 @@ -// -// dialect.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package orm - -import ( - "database/sql" - "fmt" -) - -type Dialect struct { - DriverName string - PlaceholderChar string - IncludeIndexInPlaceholder bool - AddTableNameInSelectColumns bool - PlaceHolderGenerator func(n int) []string - QueryListTables string - QueryTableSchema string -} - -func getListOfTables(query string) func(db *sql.DB) ([]string, error) { - return func(db *sql.DB) ([]string, error) { - rows, err := db.Query(query) - if err != nil { - return nil, err - } - var tables []string - for rows.Next() { - var table string - err = rows.Scan(&table) - if err != nil { - return nil, err - } - tables = append(tables, table) - } - return tables, nil - } -} - -type columnSpec struct { - //0|id|INTEGER|0||1 - Name string - Type string - Nullable bool - DefaultValue sql.NullString - IsPrimaryKey bool -} - -func getTableSchema(query string) func(db *sql.DB, query string) ([]columnSpec, error) { - return func(db *sql.DB, table string) ([]columnSpec, error) { - rows, err := db.Query(fmt.Sprintf(query, table)) - if err != nil { - return nil, err - } - var output []columnSpec - for rows.Next() { - var cs columnSpec - var nullable string - var pk int - err = rows.Scan(&cs.Name, &cs.Type, &nullable, &cs.DefaultValue, &pk) - if err != nil { - return nil, err - } - cs.Nullable = nullable == "notnull" - cs.IsPrimaryKey = pk == 1 - output = append(output, cs) - } - return output, nil - - } -} - -var Dialects = &struct { - MySQL *Dialect - PostgreSQL *Dialect - SQLite3 *Dialect -}{ - MySQL: &Dialect{ - DriverName: "mysql", - PlaceholderChar: "?", - IncludeIndexInPlaceholder: false, - AddTableNameInSelectColumns: true, - PlaceHolderGenerator: questionMarks, - QueryListTables: "SHOW TABLES", - QueryTableSchema: "DESCRIBE %s", - }, - PostgreSQL: &Dialect{ - DriverName: "postgres", - PlaceholderChar: "$", - IncludeIndexInPlaceholder: true, - AddTableNameInSelectColumns: true, - PlaceHolderGenerator: postgresPlaceholder, - QueryListTables: `\dt`, - QueryTableSchema: `\d %s`, - }, - SQLite3: &Dialect{ - DriverName: "sqlite3", - PlaceholderChar: "?", - IncludeIndexInPlaceholder: false, - AddTableNameInSelectColumns: false, - PlaceHolderGenerator: questionMarks, - QueryListTables: "SELECT name FROM sqlite_schema WHERE type='table'", - QueryTableSchema: `SELECT name,type,"notnull","dflt_value","pk" FROM PRAGMA_TABLE_INFO('%s')`, - }, -} diff --git a/gdb/orm/field.go b/gdb/orm/field.go deleted file mode 100644 index aad677c..0000000 --- a/gdb/orm/field.go +++ /dev/null @@ -1,75 +0,0 @@ -// -// field.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the 
MIT license. -// - -package orm - -import ( - "database/sql/driver" - "reflect" - "strings" - - "git.hexq.cn/tiglog/golib/helper" -) - -type field struct { - Name string - IsPK bool - Virtual bool - IsCreatedAt bool - IsUpdatedAt bool - IsDeletedAt bool - Nullable bool - Default any - Type reflect.Type -} - -func getFieldConfiguratorFor(fieldConfigurators []*FieldConfigurator, name string) *FieldConfigurator { - for _, fc := range fieldConfigurators { - if fc.fieldName == name { - return fc - } - } - return &FieldConfigurator{} -} - -func fieldMetadata(ft reflect.StructField, fieldConfigurators []*FieldConfigurator) []*field { - var fms []*field - fc := getFieldConfiguratorFor(fieldConfigurators, ft.Name) - baseFm := &field{} - baseFm.Type = ft.Type - fms = append(fms, baseFm) - if fc.column != "" { - baseFm.Name = fc.column - } else { - baseFm.Name = helper.SnakeString(ft.Name) - } - if strings.ToLower(ft.Name) == "id" || fc.primaryKey { - baseFm.IsPK = true - } - if strings.ToLower(ft.Name) == "createdat" || fc.isCreatedAt { - baseFm.IsCreatedAt = true - } - if strings.ToLower(ft.Name) == "updatedat" || fc.isUpdatedAt { - baseFm.IsUpdatedAt = true - } - if strings.ToLower(ft.Name) == "deletedat" || fc.isDeletedAt { - baseFm.IsDeletedAt = true - } - if ft.Type.Kind() == reflect.Struct || ft.Type.Kind() == reflect.Ptr { - t := ft.Type - if ft.Type.Kind() == reflect.Ptr { - t = ft.Type.Elem() - } - if !t.Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) { - for i := 0; i < t.NumField(); i++ { - fms = append(fms, fieldMetadata(t.Field(i), fieldConfigurators)...) - } - fms = fms[1:] - } - } - return fms -} diff --git a/gdb/orm/orm.go b/gdb/orm/orm.go deleted file mode 100644 index 9dee1bb..0000000 --- a/gdb/orm/orm.go +++ /dev/null @@ -1,699 +0,0 @@ -// -// orm.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package orm - -import ( - "database/sql" - "fmt" - "reflect" - "time" - - // Drivers - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" -) - -var globalConnections = map[string]*connection{} - -// Schematic prints all information ORM inferred from your entities in startup, remember to pass -// your entities in Entities when you call SetupConnections if you want their data inferred -// otherwise Schematic does not print correct data since GoLobby ORM also -// incrementally cache your entities metadata and schema. -func Schematic() { - for conn, connObj := range globalConnections { - fmt.Printf("----------------%s---------------\n", conn) - connObj.Schematic() - fmt.Println("-----------------------------------") - } -} - -type ConnectionConfig struct { - // Name of your database connection, it's up to you to name them anything - // just remember that having a connection name is mandatory if - // you have multiple connections - Name string - - // 不是必需,不指定 Dialect 时必需 - Driver string - // If you already have an active database connection configured pass it in this value and - // do not pass Driver and DSN fields. - DB *sql.DB - // Which dialect of sql to generate queries for, you don't need it most of the times when you are using - // traditional databases such as mysql, sqlite3, postgres. 
- Dialect *Dialect - // List of entities that you want to use for this connection, remember that you can ignore this field - // and GoLobby ORM will build our metadata cache incrementally but you will lose schematic - // information that we can provide you and also potentialy validations that we - // can do with the database - Entities []Entity - // Database validations, check if all tables exists and also table schemas contains all necessary columns. - // Check if all infered tables exist in your database - DatabaseValidations bool -} - -type DbOption struct { - Type string - Dsn string - MaxIdle int - MaxOpen int -} - -func NewDb(opt DbOption) (*sql.DB, error) { - db, err := sql.Open(opt.Type, opt.Dsn) - if err != nil { - return nil, err - } - if opt.Type == "sqlite3" { - if opt.Dsn == ":memory:" { - db.SetMaxOpenConns(1) - } - } else { - db.SetMaxIdleConns(opt.MaxIdle) - db.SetMaxOpenConns(opt.MaxOpen) - } - return db, nil -} - -func getDialectByDriver(driver string) (*Dialect, error) { - switch driver { - case "postgres": - return Dialects.PostgreSQL, nil - case "mysql": - return Dialects.MySQL, nil - case "sqlite3": - return Dialects.SQLite3, nil - } - return nil, fmt.Errorf("unsupported db driver %s", driver) -} - -// SetupConnections declares a new connections for ORM. -func SetupConnections(configs ...ConnectionConfig) error { - - for _, c := range configs { - if err := setupConnection(c); err != nil { - return err - } - } - for _, conn := range globalConnections { - if !conn.DatabaseValidations { - continue - } - - tables, err := getListOfTables(conn.Dialect.QueryListTables)(conn.DB) - if err != nil { - return err - } - - for _, table := range tables { - if conn.DatabaseValidations { - spec, err := getTableSchema(conn.Dialect.QueryTableSchema)(conn.DB, table) - if err != nil { - return err - } - conn.DBSchema[table] = spec - } else { - conn.DBSchema[table] = nil - } - } - - // check tables existence - if conn.DatabaseValidations { - err := conn.validateAllTablesArePresent() - if err != nil { - return err - } - } - - if conn.DatabaseValidations { - err = conn.validateTablesSchemas() - if err != nil { - return err - } - } - - } - - return nil -} - -func setupConnection(config ConnectionConfig) error { - schemas := map[string]*schema{} - if config.Name == "" { - config.Name = "default" - } - if config.Dialect == nil { - dialect, err := getDialectByDriver(config.Driver) - if err != nil { - return err - } - config.Dialect = dialect - } - - for _, entity := range config.Entities { - s := schemaOfHeavyReflectionStuff(entity) - var configurator EntityConfigurator - entity.ConfigureEntity(&configurator) - schemas[configurator.table] = s - } - - s := &connection{ - Name: config.Name, - DB: config.DB, - Dialect: config.Dialect, - Schemas: schemas, - DBSchema: make(map[string][]columnSpec), - DatabaseValidations: config.DatabaseValidations, - } - - globalConnections[fmt.Sprintf("%s", config.Name)] = s - - return nil -} - -// Entity defines the interface that each of your structs that -// you want to use as database entities should have, -// it's a simple one and its ConfigureEntity. -type Entity interface { - // ConfigureEntity should be defined for all of your database entities - // and it can define Table, DB and also relations of your Entity. - ConfigureEntity(e *EntityConfigurator) -} - -// InsertAll given entities into database based on their ConfigureEntity -// we can find table and also DB name. 
-func InsertAll(objs ...Entity) error { - if len(objs) == 0 { - return nil - } - s := getSchemaFor(objs[0]) - cols := s.Columns(false) - var values [][]interface{} - for _, obj := range objs { - createdAtF := s.createdAt() - if createdAtF != nil { - genericSet(obj, createdAtF.Name, sql.NullTime{Time: time.Now(), Valid: true}) - } - updatedAtF := s.updatedAt() - if updatedAtF != nil { - genericSet(obj, updatedAtF.Name, sql.NullTime{Time: time.Now(), Valid: true}) - } - values = append(values, genericValuesOf(obj, false)) - } - - is := insertStmt{ - PlaceHolderGenerator: s.getDialect().PlaceHolderGenerator, - Table: s.getTable(), - Columns: cols, - Values: values, - } - - q, args := is.ToSql() - - _, err := s.getConnection().exec(q, args...) - if err != nil { - return err - } - return nil -} - -// Insert given entity into database based on their ConfigureEntity -// we can find table and also DB name. -func Insert(o Entity) error { - s := getSchemaFor(o) - cols := s.Columns(false) - var values [][]interface{} - createdAtF := s.createdAt() - if createdAtF != nil { - genericSet(o, createdAtF.Name, sql.NullTime{Time: time.Now(), Valid: true}) - } - updatedAtF := s.updatedAt() - if updatedAtF != nil { - genericSet(o, updatedAtF.Name, sql.NullTime{Time: time.Now(), Valid: true}) - } - values = append(values, genericValuesOf(o, false)) - - is := insertStmt{ - PlaceHolderGenerator: s.getDialect().PlaceHolderGenerator, - Table: s.getTable(), - Columns: cols, - Values: values, - } - - if s.getDialect().DriverName == "postgres" { - is.Returning = s.pkName() - } - q, args := is.ToSql() - - res, err := s.getConnection().exec(q, args...) - if err != nil { - return err - } - id, err := res.LastInsertId() - if err != nil { - return err - } - - if s.pkName() != "" { - // intermediate tables usually have no single pk column. - s.setPK(o, id) - } - return nil -} - -func isZero(val interface{}) bool { - switch val.(type) { - case int64: - return val.(int64) == 0 - case int: - return val.(int) == 0 - case string: - return val.(string) == "" - default: - return reflect.ValueOf(val).Elem().IsZero() - } -} - -// Save saves given entity, if primary key is set -// we will make an update query and if -// primary key is zero value we will -// insert it. -func Save(obj Entity) error { - if isZero(getSchemaFor(obj).getPK(obj)) { - return Insert(obj) - } else { - return Update(obj) - } -} - -// Find finds the Entity you want based on generic type and primary key you passed. -func Find[T Entity](id interface{}) (T, error) { - var q string - out := new(T) - md := getSchemaFor(*out) - q, args, err := NewQueryBuilder[T](md). - SetDialect(md.getDialect()). - Table(md.Table). - Select(md.Columns(true)...). - Where(md.pkName(), id). - ToSql() - if err != nil { - return *out, err - } - err = bind[T](out, q, args) - - if err != nil { - return *out, err - } - - return *out, nil -} - -func toKeyValues(obj Entity, withPK bool) []any { - var tuples []any - vs := genericValuesOf(obj, withPK) - cols := getSchemaFor(obj).Columns(withPK) - for i, col := range cols { - tuples = append(tuples, col, vs[i]) - } - return tuples -} - -// Update given Entity in database. -func Update(obj Entity) error { - s := getSchemaFor(obj) - q, args, err := NewQueryBuilder[Entity](s). - SetDialect(s.getDialect()). - Set(toKeyValues(obj, false)...). - Where(s.pkName(), genericGetPKValue(obj)).Table(s.Table).ToSql() - - if err != nil { - return err - } - _, err = s.getConnection().exec(q, args...) 
- return err -} - -// Delete given Entity from database -func Delete(obj Entity) error { - s := getSchemaFor(obj) - genericSet(obj, "deleted_at", sql.NullTime{Time: time.Now(), Valid: true}) - query, args, err := NewQueryBuilder[Entity](s).SetDialect(s.getDialect()).Table(s.Table).Where(s.pkName(), genericGetPKValue(obj)).SetDelete().ToSql() - if err != nil { - return err - } - _, err = s.getConnection().exec(query, args...) - return err -} - -func bind[T Entity](output interface{}, q string, args []interface{}) error { - outputMD := getSchemaFor(*new(T)) - rows, err := outputMD.getConnection().query(q, args...) - if err != nil { - return err - } - return newBinder(outputMD).bind(rows, output) -} - -// HasManyConfig contains all information we need for querying HasMany relationships. -// We can infer both fields if you have them in standard way but you -// can specify them if you want custom ones. -type HasManyConfig struct { - // PropertyTable is table of the property of HasMany relationship, - // consider `Comment` in Post and Comment relationship, - // each Post HasMany Comment, so PropertyTable is - // `comments`. - PropertyTable string - // PropertyForeignKey is the foreign key field name in the property table, - // for example in Post HasMany Comment, if comment has `post_id` field, - // it's the PropertyForeignKey field. - PropertyForeignKey string -} - -// HasMany configures a QueryBuilder for a HasMany relationship -// this relationship will be defined for owner argument -// that has many of PROPERTY generic type for example -// HasMany[Comment](&Post{}) -// is for Post HasMany Comment relationship. -func HasMany[PROPERTY Entity](owner Entity) *QueryBuilder[PROPERTY] { - outSchema := getSchemaFor(*new(PROPERTY)) - - q := NewQueryBuilder[PROPERTY](outSchema) - // getting config from our cache - c, ok := getSchemaFor(owner).relations[outSchema.Table].(HasManyConfig) - if !ok { - q.err = fmt.Errorf("wrong config passed for HasMany") - } - - s := getSchemaFor(owner) - return q. - SetDialect(s.getDialect()). - Table(c.PropertyTable). - Select(outSchema.Columns(true)...). - Where(c.PropertyForeignKey, genericGetPKValue(owner)) -} - -// HasOneConfig contains all information we need for a HasOne relationship, -// it's similar to HasManyConfig. -type HasOneConfig struct { - // PropertyTable is table of the property of HasOne relationship, - // consider `HeaderPicture` in Post and HeaderPicture relationship, - // each Post HasOne HeaderPicture, so PropertyTable is - // `header_pictures`. - PropertyTable string - // PropertyForeignKey is the foreign key field name in the property table, - // forexample in Post HasOne HeaderPicture, if header_picture has `post_id` field, - // it's the PropertyForeignKey field. - PropertyForeignKey string -} - -// HasOne configures a QueryBuilder for a HasOne relationship -// this relationship will be defined for owner argument -// that has one of PROPERTY generic type for example -// HasOne[HeaderPicture](&Post{}) -// is for Post HasOne HeaderPicture relationship. -func HasOne[PROPERTY Entity](owner Entity) *QueryBuilder[PROPERTY] { - property := getSchemaFor(*new(PROPERTY)) - q := NewQueryBuilder[PROPERTY](property) - c, ok := getSchemaFor(owner).relations[property.Table].(HasOneConfig) - if !ok { - q.err = fmt.Errorf("wrong config passed for HasOne") - } - - // settings default config Values - return q. - SetDialect(property.getDialect()). - Table(c.PropertyTable). - Select(property.Columns(true)...). 
- Where(c.PropertyForeignKey, genericGetPKValue(owner)) -} - -// BelongsToConfig contains all information we need for a BelongsTo relationship -// BelongsTo is a relationship between a Comment and it's Post, -// A Comment BelongsTo Post. -type BelongsToConfig struct { - // OwnerTable is the table that contains owner of a BelongsTo - // relationship. - OwnerTable string - // LocalForeignKey is name of the field that links property - // to its owner in BelongsTo relation. for example when - // a Comment BelongsTo Post, LocalForeignKey is - // post_id of Comment. - LocalForeignKey string - // ForeignColumnName is name of the field that LocalForeignKey - // field value will point to it, for example when - // a Comment BelongsTo Post, ForeignColumnName is - // id of Post. - ForeignColumnName string -} - -// BelongsTo configures a QueryBuilder for a BelongsTo relationship between -// OWNER type parameter and property argument, so -// property BelongsTo OWNER. -func BelongsTo[OWNER Entity](property Entity) *QueryBuilder[OWNER] { - owner := getSchemaFor(*new(OWNER)) - q := NewQueryBuilder[OWNER](owner) - c, ok := getSchemaFor(property).relations[owner.Table].(BelongsToConfig) - if !ok { - q.err = fmt.Errorf("wrong config passed for BelongsTo") - } - - ownerIDidx := 0 - for idx, field := range owner.fields { - if field.Name == c.LocalForeignKey { - ownerIDidx = idx - } - } - - ownerID := genericValuesOf(property, true)[ownerIDidx] - - return q. - SetDialect(owner.getDialect()). - Table(c.OwnerTable).Select(owner.Columns(true)...). - Where(c.ForeignColumnName, ownerID) - -} - -// BelongsToManyConfig contains information that we -// need for creating many to many queries. -type BelongsToManyConfig struct { - // IntermediateTable is the name of the middle table - // in a BelongsToMany (Many to Many) relationship. - // for example when we have Post BelongsToMany - // Category, this table will be post_categories - // table, remember that this field cannot be - // inferred. - IntermediateTable string - // IntermediatePropertyID is the name of the field name - // of property foreign key in intermediate table, - // for example when we have Post BelongsToMany - // Category, in post_categories table, it would - // be post_id. - IntermediatePropertyID string - // IntermediateOwnerID is the name of the field name - // of property foreign key in intermediate table, - // for example when we have Post BelongsToMany - // Category, in post_categories table, it would - // be category_id. - IntermediateOwnerID string - // Table name of the owner in BelongsToMany relation, - // for example in Post BelongsToMany Category - // Owner table is name of Category table - // for example `categories`. - OwnerTable string - // OwnerLookupColumn is name of the field in the owner - // table that is used in query, for example in Post BelongsToMany Category - // Owner lookup field would be Category primary key which is id. - OwnerLookupColumn string -} - -// BelongsToMany configures a QueryBuilder for a BelongsToMany relationship -func BelongsToMany[OWNER Entity](property Entity) *QueryBuilder[OWNER] { - out := *new(OWNER) - outSchema := getSchemaFor(out) - q := NewQueryBuilder[OWNER](outSchema) - c, ok := getSchemaFor(property).relations[outSchema.Table].(BelongsToManyConfig) - if !ok { - q.err = fmt.Errorf("wrong config passed for HasMany") - } - return q. - Select(outSchema.Columns(true)...). - Table(outSchema.Table). 
- WhereIn(c.OwnerLookupColumn, Raw(fmt.Sprintf(`SELECT %s FROM %s WHERE %s = ?`, - c.IntermediatePropertyID, - c.IntermediateTable, c.IntermediateOwnerID), genericGetPKValue(property))) -} - -// Add adds `items` to `to` using relations defined between items and to in ConfigureEntity method of `to`. -func Add(to Entity, items ...Entity) error { - if len(items) == 0 { - return nil - } - rels := getSchemaFor(to).relations - tname := getSchemaFor(items[0]).Table - c, ok := rels[tname] - if !ok { - return fmt.Errorf("no config found for given to and item...") - } - switch c.(type) { - case HasManyConfig: - return addProperty(to, items...) - case HasOneConfig: - return addProperty(to, items[0]) - case BelongsToManyConfig: - return addM2M(to, items...) - default: - return fmt.Errorf("cannot add for relation: %T", rels[getSchemaFor(items[0]).Table]) - } -} - -func addM2M(to Entity, items ...Entity) error { - //TODO: Optimize this - rels := getSchemaFor(to).relations - tname := getSchemaFor(items[0]).Table - c := rels[tname].(BelongsToManyConfig) - var values [][]interface{} - ownerPk := genericGetPKValue(to) - for _, item := range items { - pk := genericGetPKValue(item) - if isZero(pk) { - err := Insert(item) - if err != nil { - return err - } - pk = genericGetPKValue(item) - } - values = append(values, []interface{}{ownerPk, pk}) - } - i := insertStmt{ - PlaceHolderGenerator: getSchemaFor(to).getDialect().PlaceHolderGenerator, - Table: c.IntermediateTable, - Columns: []string{c.IntermediateOwnerID, c.IntermediatePropertyID}, - Values: values, - } - - q, args := i.ToSql() - - _, err := getConnectionFor(items[0]).DB.Exec(q, args...) - if err != nil { - return err - } - - return err -} - -// addHasMany(Post, comments) -func addProperty(to Entity, items ...Entity) error { - var lastTable string - for _, obj := range items { - s := getSchemaFor(obj) - if lastTable == "" { - lastTable = s.Table - } else { - if lastTable != s.Table { - return fmt.Errorf("cannot batch insert for two different tables: %s and %s", s.Table, lastTable) - } - } - } - i := insertStmt{ - PlaceHolderGenerator: getSchemaFor(to).getDialect().PlaceHolderGenerator, - Table: getSchemaFor(items[0]).getTable(), - } - ownerPKIdx := -1 - ownerPKName := getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey - for idx, col := range getSchemaFor(items[0]).Columns(false) { - if col == ownerPKName { - ownerPKIdx = idx - } - } - - ownerPK := genericGetPKValue(to) - if ownerPKIdx != -1 { - cols := getSchemaFor(items[0]).Columns(false) - i.Columns = append(i.Columns, cols...) - // Owner PK is present in the items struct - for _, item := range items { - vals := genericValuesOf(item, false) - if cols[ownerPKIdx] != getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey { - return fmt.Errorf("owner pk idx is not correct") - } - vals[ownerPKIdx] = ownerPK - i.Values = append(i.Values, vals) - } - } else { - ownerPKIdx = 0 - cols := getSchemaFor(items[0]).Columns(false) - cols = append(cols[:ownerPKIdx+1], cols[ownerPKIdx:]...) - cols[ownerPKIdx] = getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey - i.Columns = append(i.Columns, cols...) - for _, item := range items { - vals := genericValuesOf(item, false) - if cols[ownerPKIdx] != getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey { - return fmt.Errorf("owner pk idx is not correct") - } - vals = append(vals[:ownerPKIdx+1], vals[ownerPKIdx:]...) 
- vals[ownerPKIdx] = ownerPK - i.Values = append(i.Values, vals) - } - } - - q, args := i.ToSql() - - _, err := getConnectionFor(items[0]).DB.Exec(q, args...) - if err != nil { - return err - } - - return err - -} - -// Query creates a new QueryBuilder for given type parameter, sets dialect and table as well. -func Query[E Entity]() *QueryBuilder[E] { - s := getSchemaFor(*new(E)) - q := NewQueryBuilder[E](s) - q.SetDialect(s.getDialect()).Table(s.Table) - return q -} - -// ExecRaw executes given query string and arguments on given type parameter database connection. -func ExecRaw[E Entity](q string, args ...interface{}) (int64, int64, error) { - e := new(E) - - res, err := getSchemaFor(*e).getSQLDB().Exec(q, args...) - if err != nil { - return 0, 0, err - } - - id, err := res.LastInsertId() - if err != nil { - return 0, 0, err - } - - affected, err := res.RowsAffected() - if err != nil { - return 0, 0, err - } - - return id, affected, nil -} - -// QueryRaw queries given query string and arguments on given type parameter database connection. -func QueryRaw[OUTPUT Entity](q string, args ...interface{}) ([]OUTPUT, error) { - o := new(OUTPUT) - rows, err := getSchemaFor(*o).getSQLDB().Query(q, args...) - if err != nil { - return nil, err - } - var output []OUTPUT - err = newBinder(getSchemaFor(*o)).bind(rows, &output) - if err != nil { - return nil, err - } - return output, nil -} diff --git a/gdb/orm/orm_test.go b/gdb/orm/orm_test.go deleted file mode 100644 index b0aa1cb..0000000 --- a/gdb/orm/orm_test.go +++ /dev/null @@ -1,588 +0,0 @@ -// -// orm_test.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package orm_test - -import ( - "database/sql" - "testing" - - "git.hexq.cn/tiglog/golib/gdb/orm" - "github.com/stretchr/testify/assert" -) - -type AuthorEmail struct { - ID int64 - Email string -} - -func (a AuthorEmail) ConfigureEntity(e *orm.EntityConfigurator) { - e. - Table("emails"). - Connection("default"). - BelongsTo(&Post{}, orm.BelongsToConfig{}) -} - -type HeaderPicture struct { - ID int64 - PostID int64 - Link string -} - -func (h HeaderPicture) ConfigureEntity(e *orm.EntityConfigurator) { - e.Table("header_pictures").BelongsTo(&Post{}, orm.BelongsToConfig{}) -} - -type Post struct { - ID int64 - BodyText string - CreatedAt sql.NullTime - UpdatedAt sql.NullTime - DeletedAt sql.NullTime -} - -func (p Post) ConfigureEntity(e *orm.EntityConfigurator) { - e.Field("BodyText").ColumnName("body") - e.Field("ID").ColumnName("id") - e. - Table("posts"). - HasMany(Comment{}, orm.HasManyConfig{}). - HasOne(HeaderPicture{}, orm.HasOneConfig{}). - HasOne(AuthorEmail{}, orm.HasOneConfig{}). 
- BelongsToMany(Category{}, orm.BelongsToManyConfig{IntermediateTable: "post_categories"}) - -} - -func (p *Post) Categories() ([]Category, error) { - return orm.BelongsToMany[Category](p).All() -} - -func (p *Post) Comments() *orm.QueryBuilder[Comment] { - return orm.HasMany[Comment](p) -} - -type Comment struct { - ID int64 - PostID int64 - Body string -} - -func (c Comment) ConfigureEntity(e *orm.EntityConfigurator) { - e.Table("comments").BelongsTo(&Post{}, orm.BelongsToConfig{}) -} - -func (c *Comment) Post() (Post, error) { - return orm.BelongsTo[Post](c).Get() -} - -type Category struct { - ID int64 - Title string -} - -func (c Category) ConfigureEntity(e *orm.EntityConfigurator) { - e.Table("categories").BelongsToMany(Post{}, orm.BelongsToManyConfig{IntermediateTable: "post_categories"}) -} - -func (c Category) Posts() ([]Post, error) { - return orm.BelongsToMany[Post](c).All() -} - -// enough models let's test -// Entities is mandatory -// Errors should be carried - -func setup() error { - db, err := sql.Open("sqlite3", ":memory:") - if err != nil { - return err - } - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`) - if err != nil { - return err - } - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`) - if err != nil { - return err - } - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`) - if err != nil { - return err - } - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`) - if err != nil { - return err - } - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`) - if err != nil { - return err - } - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`) - if err != nil { - return err - } - return orm.SetupConnections(orm.ConnectionConfig{ - Name: "default", - DB: db, - Dialect: orm.Dialects.SQLite3, - Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}}, - DatabaseValidations: true, - }) -} - -func TestFind(t *testing.T) { - err := setup() - assert.NoError(t, err) - err = orm.InsertAll(&Post{ - BodyText: "my body for insert", - }) - - assert.NoError(t, err) - - post, err := orm.Find[Post](1) - assert.NoError(t, err) - assert.Equal(t, "my body for insert", post.BodyText) - assert.Equal(t, int64(1), post.ID) -} - -func TestInsert(t *testing.T) { - err := setup() - assert.NoError(t, err) - post := &Post{ - BodyText: "my body for insert", - } - err = orm.Insert(post) - assert.NoError(t, err) - assert.Equal(t, int64(1), post.ID) - var p Post - assert.NoError(t, - orm.GetConnection("default").DB.QueryRow(`SELECT id, body FROM posts where id = ?`, 1).Scan(&p.ID, &p.BodyText)) - - assert.Equal(t, "my body for insert", p.BodyText) -} -func TestInsertAll(t *testing.T) { - err := setup() - assert.NoError(t, err) - post1 := &Post{ - BodyText: "Body1", - } - post2 := &Post{ - BodyText: "Body2", - } - - post3 := &Post{ - BodyText: "Body3", - } - - err = orm.InsertAll(post1, post2, post3) - assert.NoError(t, err) - var counter int - assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT count(id) FROM posts`).Scan(&counter)) - assert.Equal(t, 3, counter) - -} -func TestUpdateORM(t *testing.T) { - err := setup() - assert.NoError(t, err) - post := &Post{ - 
BodyText: "my body for insert", - } - err = orm.Insert(post) - assert.NoError(t, err) - assert.Equal(t, int64(1), post.ID) - - post.BodyText += " update text" - assert.NoError(t, orm.Update(post)) - - var body string - assert.NoError(t, - orm.GetConnection("default").DB.QueryRow(`SELECT body FROM posts where id = ?`, post.ID).Scan(&body)) - - assert.Equal(t, "my body for insert update text", body) -} - -func TestDeleteORM(t *testing.T) { - err := setup() - assert.NoError(t, err) - post := &Post{ - BodyText: "my body for insert", - } - err = orm.Insert(post) - assert.NoError(t, err) - assert.Equal(t, int64(1), post.ID) - - assert.NoError(t, orm.Delete(post)) - - var count int - assert.NoError(t, - orm.GetConnection("default").DB.QueryRow(`SELECT count(id) FROM posts where id = ?`, post.ID).Scan(&count)) - - assert.Equal(t, 0, count) -} -func TestAdd_HasMany(t *testing.T) { - err := setup() - assert.NoError(t, err) - post := &Post{ - BodyText: "my body for insert", - } - err = orm.Insert(post) - assert.NoError(t, err) - assert.Equal(t, int64(1), post.ID) - - err = orm.Add(post, []orm.Entity{ - Comment{ - Body: "comment 1", - }, - Comment{ - Body: "comment 2", - }, - }...) - // orm.Query(qm.WhereBetween()) - assert.NoError(t, err) - var count int - assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(id) FROM comments`).Scan(&count)) - assert.Equal(t, 2, count) - - comment, err := orm.Find[Comment](1) - assert.NoError(t, err) - - assert.Equal(t, int64(1), comment.PostID) - -} - -func TestAdd_ManyToMany(t *testing.T) { - err := setup() - assert.NoError(t, err) - post := &Post{ - BodyText: "my body for insert", - } - err = orm.Insert(post) - assert.NoError(t, err) - assert.Equal(t, int64(1), post.ID) - - err = orm.Add(post, []orm.Entity{ - &Category{ - Title: "cat 1", - }, - &Category{ - Title: "cat 2", - }, - }...) 
- assert.NoError(t, err) - var count int - assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(post_id) FROM post_categories`).Scan(&count)) - assert.Equal(t, 2, count) - assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(id) FROM categories`).Scan(&count)) - assert.Equal(t, 2, count) - - categories, err := post.Categories() - assert.NoError(t, err) - - assert.Equal(t, 2, len(categories)) - assert.Equal(t, int64(1), categories[0].ID) - assert.Equal(t, int64(2), categories[1].ID) - -} - -func TestSave(t *testing.T) { - t.Run("save should insert", func(t *testing.T) { - err := setup() - assert.NoError(t, err) - post := &Post{ - BodyText: "1", - } - assert.NoError(t, orm.Save(post)) - assert.Equal(t, int64(1), post.ID) - }) - - t.Run("save should update", func(t *testing.T) { - err := setup() - assert.NoError(t, err) - - post := &Post{ - BodyText: "1", - } - assert.NoError(t, orm.Save(post)) - assert.Equal(t, int64(1), post.ID) - - post.BodyText += "2" - assert.NoError(t, orm.Save(post)) - - myPost, err := orm.Find[Post](1) - assert.NoError(t, err) - - assert.EqualValues(t, post.BodyText, myPost.BodyText) - }) - -} - -func TestHasMany(t *testing.T) { - err := setup() - assert.NoError(t, err) - - post := &Post{ - BodyText: "first post", - } - assert.NoError(t, orm.Save(post)) - assert.Equal(t, int64(1), post.ID) - - assert.NoError(t, orm.Save(&Comment{ - PostID: post.ID, - Body: "comment 1", - })) - assert.NoError(t, orm.Save(&Comment{ - PostID: post.ID, - Body: "comment 2", - })) - - comments, err := orm.HasMany[Comment](post).All() - assert.NoError(t, err) - - assert.Len(t, comments, 2) - - assert.Equal(t, post.ID, comments[0].PostID) - assert.Equal(t, post.ID, comments[1].PostID) -} - -func TestBelongsTo(t *testing.T) { - err := setup() - assert.NoError(t, err) - - post := &Post{ - BodyText: "first post", - } - assert.NoError(t, orm.Save(post)) - assert.Equal(t, int64(1), post.ID) - - comment := &Comment{ - PostID: post.ID, - Body: "comment 1", - } - assert.NoError(t, orm.Save(comment)) - - post2, err := orm.BelongsTo[Post](comment).Get() - assert.NoError(t, err) - - assert.Equal(t, post.BodyText, post2.BodyText) -} - -func TestHasOne(t *testing.T) { - err := setup() - assert.NoError(t, err) - - post := &Post{ - BodyText: "first post", - } - assert.NoError(t, orm.Save(post)) - assert.Equal(t, int64(1), post.ID) - - headerPicture := &HeaderPicture{ - PostID: post.ID, - Link: "google", - } - assert.NoError(t, orm.Save(headerPicture)) - - c1, err := orm.HasOne[HeaderPicture](post).Get() - assert.NoError(t, err) - - assert.Equal(t, headerPicture.PostID, c1.PostID) -} - -func TestBelongsToMany(t *testing.T) { - err := setup() - assert.NoError(t, err) - - post := &Post{ - BodyText: "first Post", - } - - assert.NoError(t, orm.Save(post)) - assert.Equal(t, int64(1), post.ID) - - category := &Category{ - Title: "first category", - } - assert.NoError(t, orm.Save(category)) - assert.Equal(t, int64(1), category.ID) - - _, _, err = orm.ExecRaw[Category](`INSERT INTO post_categories (post_id, category_id) VALUES (?,?)`, post.ID, category.ID) - assert.NoError(t, err) - - categories, err := orm.BelongsToMany[Category](post).All() - assert.NoError(t, err) - - assert.Len(t, categories, 1) -} - -func TestSchematic(t *testing.T) { - err := setup() - assert.NoError(t, err) - - orm.Schematic() -} - -func TestAddProperty(t *testing.T) { - t.Run("having pk value", func(t *testing.T) { - err := setup() - assert.NoError(t, err) - - post := &Post{ - BodyText: "first post", - 
} - - assert.NoError(t, orm.Save(post)) - assert.EqualValues(t, 1, post.ID) - - err = orm.Add(post, &Comment{PostID: post.ID, Body: "firstComment"}) - assert.NoError(t, err) - - var comment Comment - assert.NoError(t, orm.GetConnection("default"). - DB. - QueryRow(`SELECT id, post_id, body FROM comments WHERE post_id=?`, post.ID). - Scan(&comment.ID, &comment.PostID, &comment.Body)) - - assert.EqualValues(t, post.ID, comment.PostID) - }) - t.Run("not having PK value", func(t *testing.T) { - err := setup() - assert.NoError(t, err) - - post := &Post{ - BodyText: "first post", - } - assert.NoError(t, orm.Save(post)) - assert.EqualValues(t, 1, post.ID) - - err = orm.Add(post, &AuthorEmail{Email: "myemail"}) - assert.NoError(t, err) - - emails, err := orm.QueryRaw[AuthorEmail](`SELECT id, email FROM emails WHERE post_id=?`, post.ID) - - assert.NoError(t, err) - assert.Equal(t, []AuthorEmail{{ID: 1, Email: "myemail"}}, emails) - }) -} - -func TestQuery(t *testing.T) { - t.Run("querying single row", func(t *testing.T) { - err := setup() - assert.NoError(t, err) - - assert.NoError(t, orm.Save(&Post{BodyText: "body 1"})) - // post, err := orm.Query[Post]().Where("id", 1).First() - post, err := orm.Query[Post]().WherePK(1).First().Get() - assert.NoError(t, err) - assert.EqualValues(t, "body 1", post.BodyText) - assert.EqualValues(t, 1, post.ID) - - }) - t.Run("querying multiple rows", func(t *testing.T) { - err := setup() - assert.NoError(t, err) - - assert.NoError(t, orm.Save(&Post{BodyText: "body 1"})) - assert.NoError(t, orm.Save(&Post{BodyText: "body 2"})) - assert.NoError(t, orm.Save(&Post{BodyText: "body 3"})) - posts, err := orm.Query[Post]().All() - assert.NoError(t, err) - assert.Len(t, posts, 3) - assert.Equal(t, "body 1", posts[0].BodyText) - }) - - t.Run("updating a row using query interface", func(t *testing.T) { - err := setup() - assert.NoError(t, err) - - assert.NoError(t, orm.Save(&Post{BodyText: "body 1"})) - - affected, err := orm.Query[Post]().Where("id", 1).Set("body", "body jadid").Update() - assert.NoError(t, err) - assert.EqualValues(t, 1, affected) - - post, err := orm.Find[Post](1) - assert.NoError(t, err) - assert.Equal(t, "body jadid", post.BodyText) - }) - - t.Run("deleting a row using query interface", func(t *testing.T) { - err := setup() - assert.NoError(t, err) - - assert.NoError(t, orm.Save(&Post{BodyText: "body 1"})) - - affected, err := orm.Query[Post]().WherePK(1).Delete() - assert.NoError(t, err) - assert.EqualValues(t, 1, affected) - count, err := orm.Query[Post]().WherePK(1).Count().Get() - assert.NoError(t, err) - assert.EqualValues(t, 0, count) - }) - - t.Run("count", func(t *testing.T) { - err := setup() - assert.NoError(t, err) - - count, err := orm.Query[Post]().WherePK(1).Count().Get() - assert.NoError(t, err) - assert.EqualValues(t, 0, count) - }) - - t.Run("latest", func(t *testing.T) { - err := setup() - assert.NoError(t, err) - - assert.NoError(t, orm.Save(&Post{BodyText: "body 1"})) - assert.NoError(t, orm.Save(&Post{BodyText: "body 2"})) - - post, err := orm.Query[Post]().Latest().Get() - assert.NoError(t, err) - - assert.EqualValues(t, "body 2", post.BodyText) - }) -} - -func TestSetup(t *testing.T) { - t.Run("tables are out of sync", func(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") - // _, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`) - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id 
INTEGER, email text)`) - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`) - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`) - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`) - // _, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`) - - err = orm.SetupConnections(orm.ConnectionConfig{ - Name: "default", - DB: db, - Dialect: orm.Dialects.SQLite3, - Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}}, - DatabaseValidations: true, - }) - assert.Error(t, err) - - }) - t.Run("schemas are wrong", func(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`) - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`) - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`) - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, body text)`) // missing post_id - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`) - _, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`) - - err = orm.SetupConnections(orm.ConnectionConfig{ - Name: "default", - DB: db, - Dialect: orm.Dialects.SQLite3, - Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}}, - DatabaseValidations: true, - }) - assert.Error(t, err) - - }) -} diff --git a/gdb/orm/query.go b/gdb/orm/query.go deleted file mode 100644 index 8c3f23d..0000000 --- a/gdb/orm/query.go +++ /dev/null @@ -1,811 +0,0 @@ -// -// query.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package orm - -import ( - "database/sql" - "fmt" - "strings" -) - -const ( - queryTypeSELECT = iota + 1 - queryTypeUPDATE - queryTypeDelete -) - -// QueryBuilder is our query builder, almost all methods and functions in GoLobby ORM -// create or configure instance of QueryBuilder. -type QueryBuilder[OUTPUT any] struct { - typ int - schema *schema - // general parts - where *whereClause - table string - placeholderGenerator func(n int) []string - - // select parts - orderBy *orderByClause - groupBy *GroupBy - selected *selected - subQuery *struct { - q string - args []interface{} - placeholderGenerator func(n int) []string - } - joins []*Join - limit *Limit - offset *Offset - - // update parts - sets [][2]interface{} - - // execution parts - db *sql.DB - err error -} - -// Finisher APIs - -// execute is a finisher executes QueryBuilder query, remember to use this when you have an Update -// or Delete Query. -func (q *QueryBuilder[OUTPUT]) execute() (sql.Result, error) { - if q.err != nil { - return nil, q.err - } - if q.typ == queryTypeSELECT { - return nil, fmt.Errorf("query type is SELECT") - } - query, args, err := q.ToSql() - if err != nil { - return nil, err - } - return q.schema.getConnection().exec(query, args...) -} - -// Get limit results to 1, runs query generated by query builder, scans result into OUTPUT. 
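As a usage sketch of the finishers defined here (mirroring TestQuery and TestSave in the removed orm_test.go above), the fragment below shows how Get and All were driven by callers. It is not compilable on its own: it assumes the removed orm package were still importable and that the Post entity and the "default" SQLite connection are registered exactly as in the deleted test setup.

func queryUsage() error {
	// First() adds ORDER BY <pk> ASC LIMIT 1; Get() runs the query and binds one row.
	post, err := orm.Query[Post]().WherePK(1).First().Get()
	if err != nil {
		return err
	}
	// All() runs the same pipeline and binds every row into a slice.
	posts, err := orm.Query[Post]().All()
	if err != nil {
		return err
	}
	fmt.Println(post.BodyText, len(posts))
	return nil
}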
-func (q *QueryBuilder[OUTPUT]) Get() (OUTPUT, error) { - if q.err != nil { - return *new(OUTPUT), q.err - } - queryString, args, err := q.ToSql() - if err != nil { - return *new(OUTPUT), err - } - rows, err := q.schema.getConnection().query(queryString, args...) - if err != nil { - return *new(OUTPUT), err - } - var output OUTPUT - err = newBinder(q.schema).bind(rows, &output) - if err != nil { - return *new(OUTPUT), err - } - return output, nil -} - -// All is a finisher, create the Select query based on QueryBuilder and scan results into -// slice of type parameter E. -func (q *QueryBuilder[OUTPUT]) All() ([]OUTPUT, error) { - if q.err != nil { - return nil, q.err - } - q.SetSelect() - queryString, args, err := q.ToSql() - if err != nil { - return nil, err - } - rows, err := q.schema.getConnection().query(queryString, args...) - if err != nil { - return nil, err - } - var output []OUTPUT - err = newBinder(q.schema).bind(rows, &output) - if err != nil { - return nil, err - } - return output, nil -} - -// Delete is a finisher, creates a delete query from query builder and executes it. -func (q *QueryBuilder[OUTPUT]) Delete() (rowsAffected int64, err error) { - if q.err != nil { - return 0, q.err - } - q.SetDelete() - res, err := q.execute() - if err != nil { - return 0, q.err - } - return res.RowsAffected() -} - -// Update is a finisher, creates an Update query from QueryBuilder and executes in into database, returns -func (q *QueryBuilder[OUTPUT]) Update() (rowsAffected int64, err error) { - if q.err != nil { - return 0, q.err - } - q.SetUpdate() - res, err := q.execute() - if err != nil { - return 0, err - } - return res.RowsAffected() -} - -func copyQueryBuilder[T1 any, T2 any](q *QueryBuilder[T1], q2 *QueryBuilder[T2]) { - q2.db = q.db - q2.err = q.err - q2.groupBy = q.groupBy - q2.joins = q.joins - q2.limit = q.limit - q2.offset = q.offset - q2.orderBy = q.orderBy - q2.placeholderGenerator = q.placeholderGenerator - q2.schema = q.schema - q2.selected = q.selected - q2.sets = q.sets - - q2.subQuery = q.subQuery - q2.table = q.table - q2.typ = q.typ - q2.where = q.where -} - -// Count creates and execute a select query from QueryBuilder and set it's field list of selection -// to COUNT(id). -func (q *QueryBuilder[OUTPUT]) Count() *QueryBuilder[int] { - q.selected = &selected{Columns: []string{"COUNT(id)"}} - q.SetSelect() - qCount := NewQueryBuilder[int](q.schema) - - copyQueryBuilder(q, qCount) - - return qCount -} - -// First returns first record of database using OrderBy primary key -// ascending order. -func (q *QueryBuilder[OUTPUT]) First() *QueryBuilder[OUTPUT] { - q.OrderBy(q.schema.pkName(), ASC).Limit(1) - return q -} - -// Latest is like Get but it also do a OrderBy(primary key, DESC) -func (q *QueryBuilder[OUTPUT]) Latest() *QueryBuilder[OUTPUT] { - q.OrderBy(q.schema.pkName(), DESC).Limit(1) - return q -} - -// WherePK adds a where clause to QueryBuilder and also gets primary key name -// from type parameter schema. -func (q *QueryBuilder[OUTPUT]) WherePK(value interface{}) *QueryBuilder[OUTPUT] { - return q.Where(q.schema.pkName(), value) -} - -func (d *QueryBuilder[OUTPUT]) toSqlDelete() (string, []interface{}, error) { - base := fmt.Sprintf("DELETE FROM %s", d.table) - var args []interface{} - if d.where != nil { - d.where.PlaceHolderGenerator = d.placeholderGenerator - where, whereArgs, err := d.where.ToSql() - if err != nil { - return "", nil, err - } - base += " WHERE " + where - args = append(args, whereArgs...) 
- } - return base, args, nil -} -func pop(phs *[]string) string { - top := (*phs)[len(*phs)-1] - *phs = (*phs)[:len(*phs)-1] - return top -} - -func (u *QueryBuilder[OUTPUT]) kvString() string { - phs := u.placeholderGenerator(len(u.sets)) - var sets []string - for _, pair := range u.sets { - sets = append(sets, fmt.Sprintf("%s=%s", pair[0], pop(&phs))) - } - return strings.Join(sets, ",") -} - -func (u *QueryBuilder[OUTPUT]) args() []interface{} { - var values []interface{} - for _, pair := range u.sets { - values = append(values, pair[1]) - } - return values -} - -func (u *QueryBuilder[OUTPUT]) toSqlUpdate() (string, []interface{}, error) { - if u.table == "" { - return "", nil, fmt.Errorf("table cannot be empty") - } - base := fmt.Sprintf("UPDATE %s SET %s", u.table, u.kvString()) - args := u.args() - if u.where != nil { - u.where.PlaceHolderGenerator = u.placeholderGenerator - where, whereArgs, err := u.where.ToSql() - if err != nil { - return "", nil, err - } - args = append(args, whereArgs...) - base += " WHERE " + where - } - return base, args, nil -} -func (s *QueryBuilder[OUTPUT]) toSqlSelect() (string, []interface{}, error) { - if s.err != nil { - return "", nil, s.err - } - base := "SELECT" - var args []interface{} - // select - if s.selected == nil { - s.selected = &selected{ - Columns: []string{"*"}, - } - } - base += " " + s.selected.String() - // from - if s.table == "" && s.subQuery == nil { - return "", nil, fmt.Errorf("Table name cannot be empty") - } else if s.table != "" && s.subQuery != nil { - return "", nil, fmt.Errorf("cannot have both Table and subquery") - } - if s.table != "" { - base += " " + "FROM " + s.table - } - if s.subQuery != nil { - s.subQuery.placeholderGenerator = s.placeholderGenerator - base += " " + "FROM (" + s.subQuery.q + " )" - args = append(args, s.subQuery.args...) - } - // Joins - if s.joins != nil { - for _, join := range s.joins { - base += " " + join.String() - } - } - // whereClause - if s.where != nil { - s.where.PlaceHolderGenerator = s.placeholderGenerator - where, whereArgs, err := s.where.ToSql() - if err != nil { - return "", nil, err - } - base += " WHERE " + where - args = append(args, whereArgs...) - } - - // orderByClause - if s.orderBy != nil { - base += " " + s.orderBy.String() - } - - // GroupBy - if s.groupBy != nil { - base += " " + s.groupBy.String() - } - - // Limit - if s.limit != nil { - base += " " + s.limit.String() - } - - // Offset - if s.offset != nil { - base += " " + s.offset.String() - } - - return base, args, nil -} - -// ToSql creates sql query from QueryBuilder based on internal fields it would decide what kind -// of query to build. 
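To make the UPDATE assembly above concrete, here is a minimal self-contained sketch of what kvString, args and toSqlUpdate jointly produce for a MySQL-style "?" placeholder; every name in it is local to the example, not part of the removed package.

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Each SET pair becomes a "col=?" fragment; its value goes into the arg list,
	// and WHERE arguments are appended after the SET values, as in toSqlUpdate.
	sets := [][2]interface{}{{"name", "amirreza"}, {"age", 30}}

	var fragments []string
	var args []interface{}
	for _, pair := range sets {
		fragments = append(fragments, fmt.Sprintf("%s=?", pair[0]))
		args = append(args, pair[1])
	}

	query := fmt.Sprintf("UPDATE %s SET %s", "users", strings.Join(fragments, ","))
	query += " WHERE age < ?"
	args = append(args, 18)

	fmt.Println(query) // UPDATE users SET name=?,age=? WHERE age < ?
	fmt.Println(args)  // [amirreza 30 18]
}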
-func (q *QueryBuilder[OUTPUT]) ToSql() (string, []interface{}, error) { - if q.err != nil { - return "", nil, q.err - } - if q.typ == queryTypeSELECT { - return q.toSqlSelect() - } else if q.typ == queryTypeDelete { - return q.toSqlDelete() - } else if q.typ == queryTypeUPDATE { - return q.toSqlUpdate() - } else { - return "", nil, fmt.Errorf("no sql type matched") - } -} - -type orderByOrder string - -const ( - ASC = "ASC" - DESC = "DESC" -) - -type orderByClause struct { - Columns [][2]string -} - -func (o orderByClause) String() string { - var tuples []string - for _, pair := range o.Columns { - tuples = append(tuples, fmt.Sprintf("%s %s", pair[0], pair[1])) - } - return fmt.Sprintf("ORDER BY %s", strings.Join(tuples, ",")) -} - -type GroupBy struct { - Columns []string -} - -func (g GroupBy) String() string { - return fmt.Sprintf("GROUP BY %s", strings.Join(g.Columns, ",")) -} - -type joinType string - -const ( - JoinTypeInner = "INNER" - JoinTypeLeft = "LEFT" - JoinTypeRight = "RIGHT" - JoinTypeFull = "FULL OUTER" - JoinTypeSelf = "SELF" -) - -type JoinOn struct { - Lhs string - Rhs string -} - -func (j JoinOn) String() string { - return fmt.Sprintf("%s = %s", j.Lhs, j.Rhs) -} - -type Join struct { - Type joinType - Table string - On JoinOn -} - -func (j Join) String() string { - return fmt.Sprintf("%s JOIN %s ON %s", j.Type, j.Table, j.On.String()) -} - -type Limit struct { - N int -} - -func (l Limit) String() string { - return fmt.Sprintf("LIMIT %d", l.N) -} - -type Offset struct { - N int -} - -func (o Offset) String() string { - return fmt.Sprintf("OFFSET %d", o.N) -} - -type selected struct { - Columns []string -} - -func (s selected) String() string { - return fmt.Sprintf("%s", strings.Join(s.Columns, ",")) -} - -// OrderBy adds an OrderBy section to QueryBuilder. -func (q *QueryBuilder[OUTPUT]) OrderBy(column string, how string) *QueryBuilder[OUTPUT] { - q.SetSelect() - if q.orderBy == nil { - q.orderBy = &orderByClause{} - } - q.orderBy.Columns = append(q.orderBy.Columns, [2]string{column, how}) - return q -} - -// LeftJoin adds a left join section to QueryBuilder. -func (q *QueryBuilder[OUTPUT]) LeftJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] { - q.SetSelect() - q.joins = append(q.joins, &Join{ - Type: JoinTypeLeft, - Table: table, - On: JoinOn{ - Lhs: onLhs, - Rhs: onRhs, - }, - }) - return q -} - -// RightJoin adds a right join section to QueryBuilder. -func (q *QueryBuilder[OUTPUT]) RightJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] { - q.SetSelect() - q.joins = append(q.joins, &Join{ - Type: JoinTypeRight, - Table: table, - On: JoinOn{ - Lhs: onLhs, - Rhs: onRhs, - }, - }) - return q -} - -// InnerJoin adds a inner join section to QueryBuilder. -func (q *QueryBuilder[OUTPUT]) InnerJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] { - q.SetSelect() - q.joins = append(q.joins, &Join{ - Type: JoinTypeInner, - Table: table, - On: JoinOn{ - Lhs: onLhs, - Rhs: onRhs, - }, - }) - return q -} - -// Join adds a inner join section to QueryBuilder. -func (q *QueryBuilder[OUTPUT]) Join(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] { - return q.InnerJoin(table, onLhs, onRhs) -} - -// FullOuterJoin adds a full outer join section to QueryBuilder. 
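toSqlSelect concatenates whichever clause structs are set, each rendering itself independently through its String method. A self-contained sketch of that concatenation order, using only the standard library (no types from the removed package):

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Order used by toSqlSelect: SELECT, FROM, joins, WHERE, ORDER BY, GROUP BY, LIMIT, OFFSET.
	base := "SELECT " + strings.Join([]string{"id", "name"}, ",") + " FROM users"
	base += " " + fmt.Sprintf("%s JOIN %s ON %s = %s", "LEFT", "addresses", "users.id", "addresses.user_id")
	base += " " + fmt.Sprintf("ORDER BY %s %s", "created_at", "ASC")
	base += " " + fmt.Sprintf("LIMIT %d", 10)
	base += " " + fmt.Sprintf("OFFSET %d", 20)
	fmt.Println(base)
	// SELECT id,name FROM users LEFT JOIN addresses ON users.id = addresses.user_id ORDER BY created_at ASC LIMIT 10 OFFSET 20
}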
-func (q *QueryBuilder[OUTPUT]) FullOuterJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] { - q.SetSelect() - q.joins = append(q.joins, &Join{ - Type: JoinTypeFull, - Table: table, - On: JoinOn{ - Lhs: onLhs, - Rhs: onRhs, - }, - }) - return q -} - -// Where Adds a where clause to query, if already have where clause append to it -// as AndWhere. -func (q *QueryBuilder[OUTPUT]) Where(parts ...interface{}) *QueryBuilder[OUTPUT] { - if q.where != nil { - return q.addWhere("AND", parts...) - } - if len(parts) == 1 { - if r, isRaw := parts[0].(*raw); isRaw { - q.where = &whereClause{raw: r.sql, args: r.args, PlaceHolderGenerator: q.placeholderGenerator} - return q - } else { - q.err = fmt.Errorf("when you have one argument passed to where, it should be *raw") - return q - } - - } else if len(parts) == 2 { - if strings.Index(parts[0].(string), " ") == -1 { - // Equal mode - q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: Eq, Rhs: parts[1]}, PlaceHolderGenerator: q.placeholderGenerator} - } - return q - } else if len(parts) == 3 { - // operator mode - q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2]}, PlaceHolderGenerator: q.placeholderGenerator} - return q - } else if len(parts) > 3 && parts[1].(string) == "IN" { - q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2:]}, PlaceHolderGenerator: q.placeholderGenerator} - return q - } else { - q.err = fmt.Errorf("wrong number of arguments passed to Where") - return q - } -} - -type binaryOp string - -const ( - Eq = "=" - GT = ">" - LT = "<" - GE = ">=" - LE = "<=" - NE = "!=" - Between = "BETWEEN" - Like = "LIKE" - In = "IN" -) - -type cond struct { - PlaceHolderGenerator func(n int) []string - - Lhs string - Op binaryOp - Rhs interface{} -} - -func (b cond) ToSql() (string, []interface{}, error) { - var phs []string - if b.Op == In { - rhs, isInterfaceSlice := b.Rhs.([]interface{}) - if isInterfaceSlice { - phs = b.PlaceHolderGenerator(len(rhs)) - return fmt.Sprintf("%s IN (%s)", b.Lhs, strings.Join(phs, ",")), rhs, nil - } else if rawThing, isRaw := b.Rhs.(*raw); isRaw { - return fmt.Sprintf("%s IN (%s)", b.Lhs, rawThing.sql), rawThing.args, nil - } else { - return "", nil, fmt.Errorf("Right hand side of Cond when operator is IN should be either a interface{} slice or *raw") - } - - } else { - phs = b.PlaceHolderGenerator(1) - return fmt.Sprintf("%s %s %s", b.Lhs, b.Op, pop(&phs)), []interface{}{b.Rhs}, nil - } -} - -const ( - nextType_AND = "AND" - nextType_OR = "OR" -) - -type whereClause struct { - PlaceHolderGenerator func(n int) []string - nextTyp string - next *whereClause - cond - raw string - args []interface{} -} - -func (w whereClause) ToSql() (string, []interface{}, error) { - var base string - var args []interface{} - var err error - if w.raw != "" { - base = w.raw - args = w.args - } else { - w.cond.PlaceHolderGenerator = w.PlaceHolderGenerator - base, args, err = w.cond.ToSql() - if err != nil { - return "", nil, err - } - } - if w.next == nil { - return base, args, nil - } - if w.next != nil { - next, nextArgs, err := w.next.ToSql() - if err != nil { - return "", nil, err - } - base += " " + w.nextTyp + " " + next - args = append(args, nextArgs...) - return base, args, nil - } - - return base, args, nil -} - -//func (q *QueryBuilder[OUTPUT]) WhereKeyValue(m map) {} - -// WhereIn adds a where clause to QueryBuilder using In operator. 
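The two branches of cond.ToSql above differ only in how many placeholders they emit: one for a binary operator, one per value for IN. A small self-contained sketch, with placeholder generation inlined rather than taken from a Dialect:

package main

import (
	"fmt"
	"strings"
)

func main() {
	// Binary-operator branch: a single placeholder and a single bound value.
	fmt.Println(fmt.Sprintf("%s %s %s", "age", "<", "?"), []interface{}{18})
	// age < ? [18]

	// IN branch: one placeholder per right-hand value.
	values := []interface{}{1, 2, 3}
	phs := make([]string, len(values))
	for i := range phs {
		phs[i] = "?"
	}
	fmt.Println(fmt.Sprintf("%s IN (%s)", "id", strings.Join(phs, ",")), values)
	// id IN (?,?,?) [1 2 3]
}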
-func (q *QueryBuilder[OUTPUT]) WhereIn(column string, values ...interface{}) *QueryBuilder[OUTPUT] { - return q.Where(append([]interface{}{column, In}, values...)...) -} - -// AndWhere appends a where clause to query builder as And where clause. -func (q *QueryBuilder[OUTPUT]) AndWhere(parts ...interface{}) *QueryBuilder[OUTPUT] { - return q.addWhere(nextType_AND, parts...) -} - -// OrWhere appends a where clause to query builder as Or where clause. -func (q *QueryBuilder[OUTPUT]) OrWhere(parts ...interface{}) *QueryBuilder[OUTPUT] { - return q.addWhere(nextType_OR, parts...) -} - -func (q *QueryBuilder[OUTPUT]) addWhere(typ string, parts ...interface{}) *QueryBuilder[OUTPUT] { - w := q.where - for { - if w == nil { - break - } else if w.next == nil { - w.next = &whereClause{PlaceHolderGenerator: q.placeholderGenerator} - w.nextTyp = typ - w = w.next - break - } else { - w = w.next - } - } - if w == nil { - w = &whereClause{PlaceHolderGenerator: q.placeholderGenerator} - } - if len(parts) == 1 { - w.raw = parts[0].(*raw).sql - w.args = parts[0].(*raw).args - return q - } else if len(parts) == 2 { - // Equal mode - w.cond = cond{Lhs: parts[0].(string), Op: Eq, Rhs: parts[1]} - return q - } else if len(parts) == 3 { - // operator mode - w.cond = cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2]} - return q - } else { - panic("wrong number of arguments passed to Where") - } -} - -// Offset adds offset section to query builder. -func (q *QueryBuilder[OUTPUT]) Offset(n int) *QueryBuilder[OUTPUT] { - q.SetSelect() - q.offset = &Offset{N: n} - return q -} - -// Limit adds limit section to query builder. -func (q *QueryBuilder[OUTPUT]) Limit(n int) *QueryBuilder[OUTPUT] { - q.SetSelect() - q.limit = &Limit{N: n} - return q -} - -// Table sets table of QueryBuilder. -func (q *QueryBuilder[OUTPUT]) Table(t string) *QueryBuilder[OUTPUT] { - q.table = t - return q -} - -// SetSelect sets query type of QueryBuilder to Select. -func (q *QueryBuilder[OUTPUT]) SetSelect() *QueryBuilder[OUTPUT] { - q.typ = queryTypeSELECT - return q -} - -// GroupBy adds a group by section to QueryBuilder. -func (q *QueryBuilder[OUTPUT]) GroupBy(columns ...string) *QueryBuilder[OUTPUT] { - q.SetSelect() - if q.groupBy == nil { - q.groupBy = &GroupBy{} - } - q.groupBy.Columns = append(q.groupBy.Columns, columns...) - return q -} - -// Select adds columns to QueryBuilder select field list. -func (q *QueryBuilder[OUTPUT]) Select(columns ...string) *QueryBuilder[OUTPUT] { - q.SetSelect() - if q.selected == nil { - q.selected = &selected{} - } - q.selected.Columns = append(q.selected.Columns, columns...) - return q -} - -// FromQuery sets subquery of QueryBuilder to be given subquery so -// when doing select instead of from table we do from(subquery). 
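How these chaining helpers compose is easiest to see against the expectations kept in query_test.go further below; the fragment here restates one of them as plain usage. It assumes the surrounding (removed) orm package and the Dummy entity from those tests.

func whereChaining() {
	sql, args, _ := NewQueryBuilder[Dummy](nil).
		SetDialect(Dialects.MySQL).
		Table("users").
		Where("age", 10).
		AndWhere("age", "<", 10).
		OrWhere("age", GT, 11).
		SetSelect().
		ToSql()
	fmt.Println(sql)  // SELECT * FROM users WHERE age = ? AND age < ? OR age > ?
	fmt.Println(args) // [10 10 11]
}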
-func (q *QueryBuilder[OUTPUT]) FromQuery(subQuery *QueryBuilder[OUTPUT]) *QueryBuilder[OUTPUT] { - q.SetSelect() - subQuery.SetSelect() - subQuery.placeholderGenerator = q.placeholderGenerator - subQueryString, args, err := subQuery.ToSql() - q.err = err - q.subQuery = &struct { - q string - args []interface{} - placeholderGenerator func(n int) []string - }{ - subQueryString, args, q.placeholderGenerator, - } - return q -} - -func (q *QueryBuilder[OUTPUT]) SetUpdate() *QueryBuilder[OUTPUT] { - q.typ = queryTypeUPDATE - return q -} - -func (q *QueryBuilder[OUTPUT]) Set(keyValues ...any) *QueryBuilder[OUTPUT] { - if len(keyValues)%2 != 0 { - q.err = fmt.Errorf("when using Set, passed argument count should be even: %w", q.err) - return q - } - q.SetUpdate() - for i := 0; i < len(keyValues); i++ { - if i != 0 && i%2 == 1 { - q.sets = append(q.sets, [2]any{keyValues[i-1], keyValues[i]}) - } - } - return q -} - -func (q *QueryBuilder[OUTPUT]) SetDialect(dialect *Dialect) *QueryBuilder[OUTPUT] { - q.placeholderGenerator = dialect.PlaceHolderGenerator - return q -} -func (q *QueryBuilder[OUTPUT]) SetDelete() *QueryBuilder[OUTPUT] { - q.typ = queryTypeDelete - return q -} - -type raw struct { - sql string - args []interface{} -} - -// Raw creates a Raw sql query chunk that you can add to several components of QueryBuilder like -// Wheres. -func Raw(sql string, args ...interface{}) *raw { - return &raw{sql: sql, args: args} -} - -func NewQueryBuilder[OUTPUT any](s *schema) *QueryBuilder[OUTPUT] { - return &QueryBuilder[OUTPUT]{schema: s} -} - -type insertStmt struct { - PlaceHolderGenerator func(n int) []string - Table string - Columns []string - Values [][]interface{} - Returning string -} - -func (i insertStmt) flatValues() []interface{} { - var values []interface{} - for _, row := range i.Values { - values = append(values, row...) - } - return values -} - -func (i insertStmt) getValuesStr() string { - phs := i.PlaceHolderGenerator(len(i.Values) * len(i.Values[0])) - - var output []string - for _, valueRow := range i.Values { - output = append(output, fmt.Sprintf("(%s)", strings.Join(phs[:len(valueRow)], ","))) - phs = phs[len(valueRow):] - } - return strings.Join(output, ",") -} - -func (i insertStmt) ToSql() (string, []interface{}) { - base := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s", - i.Table, - strings.Join(i.Columns, ","), - i.getValuesStr(), - ) - if i.Returning != "" { - base += "RETURNING " + i.Returning - } - return base, i.flatValues() -} - -func postgresPlaceholder(n int) []string { - output := []string{} - for i := 1; i < n+1; i++ { - output = append(output, fmt.Sprintf("$%d", i)) - } - return output -} - -func questionMarks(n int) []string { - output := []string{} - for i := 0; i < n; i++ { - output = append(output, "?") - } - - return output -} diff --git a/gdb/orm/query_test.go b/gdb/orm/query_test.go deleted file mode 100644 index 6bc0a71..0000000 --- a/gdb/orm/query_test.go +++ /dev/null @@ -1,243 +0,0 @@ -// -// query_test.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. 
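insertStmt above groups placeholders per VALUES row, and the two generators at the end of the removed query.go produce either numbered ($n) or anonymous (?) placeholders. A self-contained sketch of both, with the generators inlined locally:

package main

import (
	"fmt"
	"strings"
)

func main() {
	rows := [][]interface{}{{"amirreza", 11}, {"parsa", 10}}

	// One "?" per value (MySQL/SQLite style), as questionMarks produced.
	var phs []string
	for range rows {
		for range rows[0] {
			phs = append(phs, "?")
		}
	}

	// Numbered form ($1, $2, ...) as postgresPlaceholder produced.
	var pg []string
	for i := 1; i <= len(phs); i++ {
		pg = append(pg, fmt.Sprintf("$%d", i))
	}

	// getValuesStr groups the placeholders row by row: (?,?),(?,?).
	var groups []string
	rest := phs
	for _, row := range rows {
		groups = append(groups, "("+strings.Join(rest[:len(row)], ",")+")")
		rest = rest[len(row):]
	}

	fmt.Printf("INSERT INTO users (name,age) VALUES %s\n", strings.Join(groups, ","))
	// INSERT INTO users (name,age) VALUES (?,?),(?,?)
	fmt.Println(pg) // [$1 $2 $3 $4]
}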
-// - -package orm - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -type Dummy struct{} - -func (d Dummy) ConfigureEntity(e *EntityConfigurator) { - // TODO implement me - panic("implement me") -} - -func TestSelect(t *testing.T) { - t.Run("only select * from Table", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil) - s.Table("users").SetSelect() - str, args, err := s.ToSql() - assert.NoError(t, err) - assert.Empty(t, args) - assert.Equal(t, "SELECT * FROM users", str) - }) - t.Run("select with whereClause", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil) - - s.Table("users").SetDialect(Dialects.MySQL). - Where("age", 10). - AndWhere("age", "<", 10). - Where("name", "Amirreza"). - OrWhere("age", GT, 11). - SetSelect() - - str, args, err := s.ToSql() - assert.NoError(t, err) - assert.EqualValues(t, []interface{}{10, 10, "Amirreza", 11}, args) - assert.Equal(t, "SELECT * FROM users WHERE age = ? AND age < ? AND name = ? OR age > ?", str) - }) - t.Run("select with order by", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil).Table("users").OrderBy("created_at", ASC).OrderBy("updated_at", DESC) - str, args, err := s.ToSql() - assert.NoError(t, err) - assert.Empty(t, args) - assert.Equal(t, "SELECT * FROM users ORDER BY created_at ASC,updated_at DESC", str) - }) - - t.Run("select with group by", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil).Table("users").GroupBy("created_at", "updated_at") - str, args, err := s.ToSql() - assert.NoError(t, err) - assert.Empty(t, args) - assert.Equal(t, "SELECT * FROM users GROUP BY created_at,updated_at", str) - }) - - t.Run("Select with limit", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil).Table("users").Limit(10) - str, args, err := s.ToSql() - assert.NoError(t, err) - assert.Empty(t, args) - assert.Equal(t, "SELECT * FROM users LIMIT 10", str) - }) - - t.Run("Select with offset", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil).Table("users").Offset(10) - str, args, err := s.ToSql() - assert.NoError(t, err) - assert.Empty(t, args) - assert.Equal(t, "SELECT * FROM users OFFSET 10", str) - }) - - t.Run("select with join", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").RightJoin("addresses", "users.id", "addresses.user_id") - str, args, err := s.ToSql() - assert.NoError(t, err) - assert.Empty(t, args) - assert.Equal(t, `SELECT id,name FROM users RIGHT JOIN addresses ON users.id = addresses.user_id`, str) - }) - - t.Run("select with multiple joins", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil).Table("users"). - Select("id", "name"). - RightJoin("addresses", "users.id", "addresses.user_id"). - LeftJoin("user_credits", "users.id", "user_credits.user_id") - sql, args, err := s.ToSql() - assert.NoError(t, err) - assert.Empty(t, args) - assert.Equal(t, `SELECT id,name FROM users RIGHT JOIN addresses ON users.id = addresses.user_id LEFT JOIN user_credits ON users.id = user_credits.user_id`, sql) - }) - - t.Run("select with subquery", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil).SetDialect(Dialects.MySQL) - s.FromQuery(NewQueryBuilder[Dummy](nil).Table("users").Where("age", "<", 10)) - sql, args, err := s.ToSql() - assert.NoError(t, err) - assert.EqualValues(t, []interface{}{10}, args) - assert.Equal(t, `SELECT * FROM (SELECT * FROM users WHERE age < ? 
)`, sql) - - }) - - t.Run("select with inner join", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").InnerJoin("addresses", "users.id", "addresses.user_id") - str, args, err := s.ToSql() - assert.NoError(t, err) - assert.Empty(t, args) - assert.Equal(t, `SELECT id,name FROM users INNER JOIN addresses ON users.id = addresses.user_id`, str) - }) - - t.Run("select with join", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").Join("addresses", "users.id", "addresses.user_id") - str, args, err := s.ToSql() - assert.NoError(t, err) - assert.Empty(t, args) - assert.Equal(t, `SELECT id,name FROM users INNER JOIN addresses ON users.id = addresses.user_id`, str) - }) - - t.Run("select with full outer join", func(t *testing.T) { - s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").FullOuterJoin("addresses", "users.id", "addresses.user_id") - str, args, err := s.ToSql() - assert.NoError(t, err) - assert.Empty(t, args) - assert.Equal(t, `SELECT id,name FROM users FULL OUTER JOIN addresses ON users.id = addresses.user_id`, str) - }) - t.Run("raw where", func(t *testing.T) { - sql, args, err := - NewQueryBuilder[Dummy](nil). - SetDialect(Dialects.MySQL). - Table("users"). - Where(Raw("id = ?", 1)). - AndWhere(Raw("age < ?", 10)). - SetSelect(). - ToSql() - - assert.NoError(t, err) - assert.EqualValues(t, []interface{}{1, 10}, args) - assert.Equal(t, `SELECT * FROM users WHERE id = ? AND age < ?`, sql) - }) - t.Run("no sql type matched", func(t *testing.T) { - sql, args, err := NewQueryBuilder[Dummy](nil).ToSql() - assert.Error(t, err) - assert.Empty(t, args) - assert.Empty(t, sql) - }) - - t.Run("raw where in", func(t *testing.T) { - sql, args, err := - NewQueryBuilder[Dummy](nil). - SetDialect(Dialects.MySQL). - Table("users"). - WhereIn("id", Raw("SELECT user_id FROM user_books WHERE book_id = ?", 10)). - SetSelect(). - ToSql() - - assert.NoError(t, err) - assert.EqualValues(t, []interface{}{10}, args) - assert.Equal(t, `SELECT * FROM users WHERE id IN (SELECT user_id FROM user_books WHERE book_id = ?)`, sql) - }) - t.Run("where in", func(t *testing.T) { - sql, args, err := - NewQueryBuilder[Dummy](nil). - SetDialect(Dialects.MySQL). - Table("users"). - WhereIn("id", 1, 2, 3, 4, 5, 6). - SetSelect(). - ToSql() - - assert.NoError(t, err) - assert.EqualValues(t, []interface{}{1, 2, 3, 4, 5, 6}, args) - assert.Equal(t, `SELECT * FROM users WHERE id IN (?,?,?,?,?,?)`, sql) - - }) -} -func TestUpdate(t *testing.T) { - t.Run("update no whereClause", func(t *testing.T) { - u := NewQueryBuilder[Dummy](nil).Table("users").Set("name", "amirreza").SetDialect(Dialects.MySQL) - sql, args, err := u.ToSql() - assert.NoError(t, err) - assert.Equal(t, `UPDATE users SET name=?`, sql) - assert.Equal(t, []interface{}{"amirreza"}, args) - }) - t.Run("update with whereClause", func(t *testing.T) { - u := NewQueryBuilder[Dummy](nil).Table("users").Set("name", "amirreza").Where("age", "<", 18).SetDialect(Dialects.MySQL) - sql, args, err := u.ToSql() - assert.NoError(t, err) - assert.Equal(t, `UPDATE users SET name=? 
WHERE age < ?`, sql) - assert.Equal(t, []interface{}{"amirreza", 18}, args) - - }) -} -func TestDelete(t *testing.T) { - t.Run("delete without whereClause", func(t *testing.T) { - d := NewQueryBuilder[Dummy](nil).Table("users").SetDelete() - sql, args, err := d.ToSql() - assert.NoError(t, err) - assert.Equal(t, `DELETE FROM users`, sql) - assert.Empty(t, args) - }) - t.Run("delete with whereClause", func(t *testing.T) { - d := NewQueryBuilder[Dummy](nil).Table("users").SetDialect(Dialects.MySQL).Where("created_at", ">", "2012-01-10").SetDelete() - sql, args, err := d.ToSql() - assert.NoError(t, err) - assert.Equal(t, `DELETE FROM users WHERE created_at > ?`, sql) - assert.EqualValues(t, []interface{}{"2012-01-10"}, args) - }) -} - -func TestInsert(t *testing.T) { - t.Run("insert into multiple rows", func(t *testing.T) { - i := insertStmt{} - i.Table = "users" - i.PlaceHolderGenerator = Dialects.MySQL.PlaceHolderGenerator - i.Columns = []string{"name", "age"} - i.Values = append(i.Values, []interface{}{"amirreza", 11}, []interface{}{"parsa", 10}) - s, args := i.ToSql() - assert.Equal(t, `INSERT INTO users (name,age) VALUES (?,?),(?,?)`, s) - assert.EqualValues(t, []interface{}{"amirreza", 11, "parsa", 10}, args) - }) - - t.Run("insert into single row", func(t *testing.T) { - i := insertStmt{} - i.Table = "users" - i.PlaceHolderGenerator = Dialects.MySQL.PlaceHolderGenerator - i.Columns = []string{"name", "age"} - i.Values = append(i.Values, []interface{}{"amirreza", 11}) - s, args := i.ToSql() - assert.Equal(t, `INSERT INTO users (name,age) VALUES (?,?)`, s) - assert.Equal(t, []interface{}{"amirreza", 11}, args) - }) -} - -func TestPostgresPlaceholder(t *testing.T) { - t.Run("for 5 it should have 5", func(t *testing.T) { - phs := postgresPlaceholder(5) - assert.EqualValues(t, []string{"$1", "$2", "$3", "$4", "$5"}, phs) - }) -} diff --git a/gdb/orm/schema.go b/gdb/orm/schema.go deleted file mode 100644 index 806b0b5..0000000 --- a/gdb/orm/schema.go +++ /dev/null @@ -1,306 +0,0 @@ -// -// schema.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. 
-// - -package orm - -import ( - "database/sql" - "database/sql/driver" - "fmt" - "reflect" -) - -func getConnectionFor(e Entity) *connection { - configurator := newEntityConfigurator() - e.ConfigureEntity(configurator) - - if len(globalConnections) > 1 && (configurator.connection == "" || configurator.table == "") { - panic("need table and DB name when having more than 1 DB registered") - } - if len(globalConnections) == 1 { - for _, db := range globalConnections { - return db - } - } - if db, exists := globalConnections[fmt.Sprintf("%s", configurator.connection)]; exists { - return db - } - panic("no db found") -} - -func getSchemaFor(e Entity) *schema { - configurator := newEntityConfigurator() - c := getConnectionFor(e) - e.ConfigureEntity(configurator) - s := c.getSchema(configurator.table) - if s == nil { - s = schemaOfHeavyReflectionStuff(e) - c.setSchema(e, s) - } - return s -} - -type schema struct { - Connection string - Table string - fields []*field - relations map[string]interface{} - setPK func(o Entity, value interface{}) - getPK func(o Entity) interface{} - columnConstraints []*FieldConfigurator -} - -func (s *schema) getField(sf reflect.StructField) *field { - for _, f := range s.fields { - if sf.Name == f.Name { - return f - } - } - return nil -} - -func (s *schema) getDialect() *Dialect { - return GetConnection(s.Connection).Dialect -} -func (s *schema) Columns(withPK bool) []string { - var cols []string - for _, field := range s.fields { - if field.Virtual { - continue - } - if !withPK && field.IsPK { - continue - } - if s.getDialect().AddTableNameInSelectColumns { - cols = append(cols, s.Table+"."+field.Name) - } else { - cols = append(cols, field.Name) - } - } - return cols -} - -func (s *schema) pkName() string { - for _, field := range s.fields { - if field.IsPK { - return field.Name - } - } - return "" -} - -func genericFieldsOf(obj Entity) []*field { - t := reflect.TypeOf(obj) - for t.Kind() == reflect.Ptr { - t = t.Elem() - - } - if t.Kind() == reflect.Slice { - t = t.Elem() - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - } - var ec EntityConfigurator - obj.ConfigureEntity(&ec) - - var fms []*field - for i := 0; i < t.NumField(); i++ { - ft := t.Field(i) - fm := fieldMetadata(ft, ec.columnConstraints) - fms = append(fms, fm...) - } - return fms -} - -func valuesOfField(vf reflect.Value) []interface{} { - var values []interface{} - if vf.Type().Kind() == reflect.Struct || vf.Type().Kind() == reflect.Ptr { - t := vf.Type() - if vf.Type().Kind() == reflect.Ptr { - t = vf.Type().Elem() - } - if !t.Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) { - // go into - // it does not implement driver.Valuer interface - for i := 0; i < vf.NumField(); i++ { - vif := vf.Field(i) - values = append(values, valuesOfField(vif)...) - } - } else { - values = append(values, vf.Interface()) - } - } else { - values = append(values, vf.Interface()) - } - return values -} -func genericValuesOf(o Entity, withPK bool) []interface{} { - t := reflect.TypeOf(o) - v := reflect.ValueOf(o) - if t.Kind() == reflect.Ptr { - t = t.Elem() - v = v.Elem() - } - fields := getSchemaFor(o).fields - pkIdx := -1 - for i, field := range fields { - if field.IsPK { - pkIdx = i - } - - } - - var values []interface{} - - for i := 0; i < t.NumField(); i++ { - if !withPK && i == pkIdx { - continue - } - if fields[i].Virtual { - continue - } - vf := v.Field(i) - values = append(values, valuesOfField(vf)...) 
- } - return values -} - -func genericSetPkValue(obj Entity, value interface{}) { - genericSet(obj, getSchemaFor(obj).pkName(), value) -} - -func genericGetPKValue(obj Entity) interface{} { - t := reflect.TypeOf(obj) - val := reflect.ValueOf(obj) - if t.Kind() == reflect.Ptr { - val = val.Elem() - } - - fields := getSchemaFor(obj).fields - for i, field := range fields { - if field.IsPK { - return val.Field(i).Interface() - } - } - return "" -} - -func (s *schema) createdAt() *field { - for _, f := range s.fields { - if f.IsCreatedAt { - return f - } - } - return nil -} -func (s *schema) updatedAt() *field { - for _, f := range s.fields { - if f.IsUpdatedAt { - return f - } - } - return nil -} - -func (s *schema) deletedAt() *field { - for _, f := range s.fields { - if f.IsDeletedAt { - return f - } - } - return nil -} -func pointersOf(v reflect.Value) map[string]interface{} { - m := map[string]interface{}{} - actualV := v - for actualV.Type().Kind() == reflect.Ptr { - actualV = actualV.Elem() - } - for i := 0; i < actualV.NumField(); i++ { - f := actualV.Field(i) - if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) { - fm := pointersOf(f) - for k, p := range fm { - m[k] = p - } - } else { - fm := fieldMetadata(actualV.Type().Field(i), nil)[0] - m[fm.Name] = actualV.Field(i) - } - } - - return m -} -func genericSet(obj Entity, name string, value interface{}) { - n2p := pointersOf(reflect.ValueOf(obj)) - var val interface{} - for k, v := range n2p { - if k == name { - val = v - } - } - val.(reflect.Value).Set(reflect.ValueOf(value)) -} -func schemaOfHeavyReflectionStuff(v Entity) *schema { - userEntityConfigurator := newEntityConfigurator() - v.ConfigureEntity(userEntityConfigurator) - for _, relation := range userEntityConfigurator.resolveRelations { - relation() - } - schema := &schema{} - if userEntityConfigurator.connection != "" { - schema.Connection = userEntityConfigurator.connection - } - if userEntityConfigurator.table != "" { - schema.Table = userEntityConfigurator.table - } else { - panic("you need to have table name for getting schema.") - } - - schema.columnConstraints = userEntityConfigurator.columnConstraints - if schema.Connection == "" { - schema.Connection = "default" - } - if schema.fields == nil { - schema.fields = genericFieldsOf(v) - } - if schema.getPK == nil { - schema.getPK = genericGetPKValue - } - - if schema.setPK == nil { - schema.setPK = genericSetPkValue - } - - schema.relations = userEntityConfigurator.relations - - return schema -} - -func (s *schema) getTable() string { - return s.Table -} - -func (s *schema) getSQLDB() *sql.DB { - return s.getConnection().DB -} - -func (s *schema) getConnection() *connection { - if len(globalConnections) > 1 && (s.Connection == "" || s.Table == "") { - panic("need table and DB name when having more than 1 DB registered") - } - if len(globalConnections) == 1 { - for _, db := range globalConnections { - return db - } - } - if db, exists := globalConnections[fmt.Sprintf("%s", s.Connection)]; exists { - return db - } - panic("no db found") -} diff --git a/gdb/orm/schema_test.go b/gdb/orm/schema_test.go deleted file mode 100644 index 07bad37..0000000 --- a/gdb/orm/schema_test.go +++ /dev/null @@ -1,77 +0,0 @@ -// -// schema_test.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. 
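The reflection helpers above (genericValuesOf, pointersOf) flatten a struct into per-column values by walking its fields and descending into nested structs that do not implement driver.Valuer. The self-contained sketch below shows the same kind of walk with local types only, using embedded structs as a simpler stand-in for that descent rule:

package main

import (
	"fmt"
	"reflect"
)

type Timestamps struct {
	CreatedAt string
	UpdatedAt string
}

type Post struct {
	ID   int64
	Body string
	Timestamps
}

// fieldValues flattens a struct, recursing into embedded structs,
// roughly what genericValuesOf does before values are bound to placeholders.
func fieldValues(v reflect.Value) []interface{} {
	var out []interface{}
	for i := 0; i < v.NumField(); i++ {
		if v.Type().Field(i).Anonymous && v.Field(i).Kind() == reflect.Struct {
			out = append(out, fieldValues(v.Field(i))...)
			continue
		}
		out = append(out, v.Field(i).Interface())
	}
	return out
}

func main() {
	p := Post{ID: 1, Body: "first post", Timestamps: Timestamps{CreatedAt: "now", UpdatedAt: "later"}}
	fmt.Println(fieldValues(reflect.ValueOf(p))) // [1 first post now later]
}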
-// - -package orm - -import ( - "database/sql" - "testing" - - "github.com/stretchr/testify/assert" -) - -func setup(t *testing.T) { - db, err := sql.Open("sqlite3", ":memory:") - err = SetupConnections(ConnectionConfig{ - Name: "default", - DB: db, - Dialect: Dialects.SQLite3, - }) - // orm.Schematic() - _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`) - _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`) - _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`) - _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`) - _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`) - _, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`) - assert.NoError(t, err) -} - -type Object struct { - ID int64 - Name string - Timestamps - SoftDelete -} - -func (o Object) ConfigureEntity(e *EntityConfigurator) { - e.Table("objects").Connection("default") -} - -func TestGenericFieldsOf(t *testing.T) { - t.Run("fields of with id and timestamps embedded", func(t *testing.T) { - fs := genericFieldsOf(&Object{}) - assert.Len(t, fs, 5) - assert.Equal(t, "id", fs[0].Name) - assert.True(t, fs[0].IsPK) - assert.Equal(t, "name", fs[1].Name) - assert.Equal(t, "created_at", fs[2].Name) - assert.Equal(t, "updated_at", fs[3].Name) - assert.Equal(t, "deleted_at", fs[4].Name) - }) -} - -func TestGenericValuesOf(t *testing.T) { - t.Run("values of", func(t *testing.T) { - - setup(t) - vs := genericValuesOf(Object{}, true) - assert.Len(t, vs, 5) - }) -} - -func TestEntityConfigurator(t *testing.T) { - t.Run("test has many with user provided values", func(t *testing.T) { - setup(t) - var ec EntityConfigurator - ec.Table("users").Connection("default").HasMany(Object{}, HasManyConfig{ - "objects", "user_id", - }) - - }) - -} diff --git a/gdb/orm/timestamps.go b/gdb/orm/timestamps.go deleted file mode 100644 index 3cf2ee1..0000000 --- a/gdb/orm/timestamps.go +++ /dev/null @@ -1,21 +0,0 @@ -// -// timestamps.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package orm - -import ( - "database/sql" -) - -type Timestamps struct { - CreatedAt sql.NullTime - UpdatedAt sql.NullTime -} - -type SoftDelete struct { - DeletedAt sql.NullTime -} diff --git a/gdb/sqldb/column.go b/gdb/sqldb/column.go deleted file mode 100644 index 781c36c..0000000 --- a/gdb/sqldb/column.go +++ /dev/null @@ -1,78 +0,0 @@ -// -// column.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import "reflect" - -// ColumnMap represents a mapping between a Go struct field and a single -// column in a table. -// Unique and MaxSize only inform the -// CreateTables() function and are not used by Insert/Update/Delete/Get. -type ColumnMap struct { - // Column name in db table - ColumnName string - - // If true, this column is skipped in generated SQL statements - Transient bool - - // If true, " unique" is added to create table statements. 
- // Not used elsewhere - Unique bool - - // Query used for getting generated id after insert - GeneratedIdQuery string - - // Passed to Dialect.ToSqlType() to assist in informing the - // correct column type to map to in CreateTables() - MaxSize int - - DefaultValue string - - fieldName string - gotype reflect.Type - isPK bool - isAutoIncr bool - isNotNull bool -} - -// Rename allows you to specify the column name in the table -// -// Example: table.ColMap("Updated").Rename("date_updated") -func (c *ColumnMap) Rename(colname string) *ColumnMap { - c.ColumnName = colname - return c -} - -// SetTransient allows you to mark the column as transient. If true -// this column will be skipped when SQL statements are generated -func (c *ColumnMap) SetTransient(b bool) *ColumnMap { - c.Transient = b - return c -} - -// SetUnique adds "unique" to the create table statements for this -// column, if b is true. -func (c *ColumnMap) SetUnique(b bool) *ColumnMap { - c.Unique = b - return c -} - -// SetNotNull adds "not null" to the create table statements for this -// column, if nn is true. -func (c *ColumnMap) SetNotNull(nn bool) *ColumnMap { - c.isNotNull = nn - return c -} - -// SetMaxSize specifies the max length of values of this column. This is -// passed to the dialect.ToSqlType() function, which can use the value -// to alter the generated type for "create table" statements -func (c *ColumnMap) SetMaxSize(size int) *ColumnMap { - c.MaxSize = size - return c -} diff --git a/gdb/sqldb/context_test.go b/gdb/sqldb/context_test.go deleted file mode 100644 index 96cd8d8..0000000 --- a/gdb/sqldb/context_test.go +++ /dev/null @@ -1,82 +0,0 @@ -// -// context_test.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -//go:build integration -// +build integration - -package sqldb_test - -import ( - "context" - "testing" - "time" - - "github.com/stretchr/testify/assert" -) - -// Drivers that don't support cancellation. -var unsupportedDrivers map[string]bool = map[string]bool{ - "mymysql": true, -} - -type SleepDialect interface { - // string to sleep for d duration - SleepClause(d time.Duration) string -} - -func TestWithNotCanceledContext(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - withCtx := dbmap.WithContext(ctx) - - _, err := withCtx.Exec("SELECT 1") - assert.Nil(t, err) -} - -func TestWithCanceledContext(t *testing.T) { - dialect, driver := dialectAndDriver() - if unsupportedDrivers[driver] { - t.Skipf("Cancellation is not yet supported by all drivers. Not known to be supported in %s.", driver) - } - - sleepDialect, ok := dialect.(SleepDialect) - if !ok { - t.Skipf("Sleep is not supported in all dialects. 
Not known to be supported in %s.", driver) - } - - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) - defer cancel() - - withCtx := dbmap.WithContext(ctx) - - startTime := time.Now() - - _, err := withCtx.Exec("SELECT " + sleepDialect.SleepClause(1*time.Second)) - - if d := time.Since(startTime); d > 500*time.Millisecond { - t.Errorf("too long execution time: %s", d) - } - - switch driver { - case "postgres": - // pq doesn't return standard deadline exceeded error - if err.Error() != "pq: canceling statement due to user request" { - t.Errorf("expected context.DeadlineExceeded, got %v", err) - } - default: - if err != context.DeadlineExceeded { - t.Errorf("expected context.DeadlineExceeded, got %v", err) - } - } -} diff --git a/gdb/sqldb/db.go b/gdb/sqldb/db.go deleted file mode 100644 index 08da704..0000000 --- a/gdb/sqldb/db.go +++ /dev/null @@ -1,1049 +0,0 @@ -// -// db.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "bytes" - "context" - "database/sql" - "database/sql/driver" - "errors" - "fmt" - "log" - "reflect" - "strconv" - "strings" - "time" -) - -var dm *DbMap - -type DbOption struct { - Type string - Dsn string - MaxIdle int - MaxOpen int -} - -func InitDb(opt DbOption) error { - var ( - db *sql.DB - dialect Dialect - err error - ) - switch opt.Type { - case "postgres": - db, err = sql.Open(opt.Type, opt.Dsn) - if err != nil { - return err - } - db.SetMaxIdleConns(opt.MaxIdle) - db.SetMaxIdleConns(opt.MaxOpen) - dialect = PostgresDialect{LowercaseFields: true} - case "sqlite3": - db, err = sql.Open(opt.Type, opt.Dsn) - if err != nil { - return err - } - if opt.Dsn == ":memory:" { - db.SetMaxOpenConns(1) - } - dialect = SqliteDialect{} - case "mysql": - db, err = sql.Open(opt.Type, opt.Dsn) - if err != nil { - return err - } - db.SetMaxIdleConns(opt.MaxIdle) - db.SetMaxIdleConns(opt.MaxOpen) - dialect = MySQLDialect{Engine: "InnoDB", Encoding: "utf8mb4"} - default: - return errors.New("unrecognized database driver") - } - dm = &DbMap{ - Db: db, - Dialect: dialect, - } - return nil -} - -// 拿到数据库操作的句柄 -func GetDbMap() *DbMap { - if dm == nil { - InitDb(DbOption{ - Type: "sqlite3", - Dsn: ":memory:", - }) - } - return dm -} - -// DbMap is the root sqldb mapping object. Create one of these for each -// database schema you wish to map. Each DbMap contains a list of -// mapped tables. -// -// Example: -// -// dialect := sqldb.MySQLDialect{"InnoDB", "UTF8"} -// dbmap := &sqldb.DbMap{Db: db, Dialect: dialect} -type DbMap struct { - ctx context.Context - - // Db handle to use with this map - Db *sql.DB - - // Dialect implementation to use with this map - Dialect Dialect - - TypeConverter TypeConverter - - // ExpandSlices when enabled will convert slice arguments in mappers into flat - // values. It will modify the query, adding more placeholders, and the mapper, - // adding each item of the slice as a new unique entry in the mapper. For - // example, given the scenario bellow: - // - // dbmap.Select(&output, "SELECT 1 FROM example WHERE id IN (:IDs)", map[string]interface{}{ - // "IDs": []int64{1, 2, 3}, - // }) - // - // The executed query would be: - // - // SELECT 1 FROM example WHERE id IN (:IDs0,:IDs1,:IDs2) - // - // With the mapper: - // - // map[string]interface{}{ - // "IDs": []int64{1, 2, 3}, - // "IDs0": int64(1), - // "IDs1": int64(2), - // "IDs2": int64(3), - // } - // - // It is also flexible for custom slice types. 
The value just need to - // implement stringer or numberer interfaces. - // - // type CustomValue string - // - // const ( - // CustomValueHey CustomValue = "hey" - // CustomValueOh CustomValue = "oh" - // ) - // - // type CustomValues []CustomValue - // - // func (c CustomValues) ToStringSlice() []string { - // values := make([]string, len(c)) - // for i := range c { - // values[i] = string(c[i]) - // } - // return values - // } - // - // func query() { - // // ... - // result, err := dbmap.Select(&output, "SELECT 1 FROM example WHERE value IN (:Values)", map[string]interface{}{ - // "Values": CustomValues([]CustomValue{CustomValueHey}), - // }) - // // ... - // } - ExpandSliceArgs bool - - tables []*TableMap - tablesDynamic map[string]*TableMap // tables that use same go-struct and different db table names - logger SqldbLogger - logPrefix string -} - -func (m *DbMap) dynamicTableAdd(tableName string, tbl *TableMap) { - if m.tablesDynamic == nil { - m.tablesDynamic = make(map[string]*TableMap) - } - m.tablesDynamic[tableName] = tbl -} - -func (m *DbMap) dynamicTableFind(tableName string) (*TableMap, bool) { - if m.tablesDynamic == nil { - return nil, false - } - tbl, found := m.tablesDynamic[tableName] - return tbl, found -} - -func (m *DbMap) dynamicTableMap() map[string]*TableMap { - if m.tablesDynamic == nil { - m.tablesDynamic = make(map[string]*TableMap) - } - return m.tablesDynamic -} - -func (m *DbMap) WithContext(ctx context.Context) SqlExecutor { - copy := &DbMap{} - *copy = *m - copy.ctx = ctx - return copy -} - -func (m *DbMap) CreateIndex() error { - var err error - dialect := reflect.TypeOf(m.Dialect) - for _, table := range m.tables { - for _, index := range table.indexes { - err = m.createIndexImpl(dialect, table, index) - if err != nil { - break - } - } - } - - for _, table := range m.dynamicTableMap() { - for _, index := range table.indexes { - err = m.createIndexImpl(dialect, table, index) - if err != nil { - break - } - } - } - - return err -} - -func (m *DbMap) createIndexImpl(dialect reflect.Type, - table *TableMap, - index *IndexMap) error { - s := bytes.Buffer{} - s.WriteString("create") - if index.Unique { - s.WriteString(" unique") - } - s.WriteString(" index") - s.WriteString(fmt.Sprintf(" %s on %s", index.IndexName, table.TableName)) - if dname := dialect.Name(); dname == "PostgresDialect" && index.IndexType != "" { - s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType)) - } - s.WriteString(" (") - for x, col := range index.columns { - if x > 0 { - s.WriteString(", ") - } - s.WriteString(m.Dialect.QuoteField(col)) - } - s.WriteString(")") - - if dname := dialect.Name(); dname == "MySQLDialect" && index.IndexType != "" { - s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType)) - } - s.WriteString(";") - _, err := m.Exec(s.String()) - return err -} - -func (t *TableMap) DropIndex(name string) error { - - var err error - dialect := reflect.TypeOf(t.dbmap.Dialect) - for _, idx := range t.indexes { - if idx.IndexName == name { - s := bytes.Buffer{} - s.WriteString(fmt.Sprintf("DROP INDEX %s", idx.IndexName)) - - if dname := dialect.Name(); dname == "MySQLDialect" { - s.WriteString(fmt.Sprintf(" %s %s", t.dbmap.Dialect.DropIndexSuffix(), t.TableName)) - } - s.WriteString(";") - _, e := t.dbmap.Exec(s.String()) - if e != nil { - err = e - } - break - } - } - t.ResetSql() - return err -} - -// AddTable registers the given interface type with sqldb. The table name -// will be given the name of the TypeOf(i). 
You must call this function, -// or AddTableWithName, for any struct type you wish to persist with -// the given DbMap. -// -// This operation is idempotent. If i's type is already mapped, the -// existing *TableMap is returned -func (m *DbMap) AddTable(i interface{}) *TableMap { - return m.AddTableWithName(i, "") -} - -// AddTableWithName has the same behavior as AddTable, but sets -// table.TableName to name. -func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap { - return m.AddTableWithNameAndSchema(i, "", name) -} - -// AddTableWithNameAndSchema has the same behavior as AddTable, but sets -// table.TableName to name. -func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name string) *TableMap { - t := reflect.TypeOf(i) - if name == "" { - name = t.Name() - } - - // check if we have a table for this type already - // if so, update the name and return the existing pointer - for i := range m.tables { - table := m.tables[i] - if table.gotype == t { - table.TableName = name - return table - } - } - - tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m} - var primaryKey []*ColumnMap - tmap.Columns, primaryKey = m.readStructColumns(t) - m.tables = append(m.tables, tmap) - if len(primaryKey) > 0 { - tmap.keys = append(tmap.keys, primaryKey...) - } - - return tmap -} - -// AddTableDynamic registers the given interface type with sqldb. -// The table name will be dynamically determined at runtime by -// using the GetTableName method on DynamicTable interface -func (m *DbMap) AddTableDynamic(inp DynamicTable, schema string) *TableMap { - - val := reflect.ValueOf(inp) - elm := val.Elem() - t := elm.Type() - name := inp.TableName() - if name == "" { - panic("Missing table name in DynamicTable instance") - } - - // Check if there is another dynamic table with the same name - if _, found := m.dynamicTableFind(name); found { - panic(fmt.Sprintf("A table with the same name %v already exists", name)) - } - - tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m} - var primaryKey []*ColumnMap - tmap.Columns, primaryKey = m.readStructColumns(t) - if len(primaryKey) > 0 { - tmap.keys = append(tmap.keys, primaryKey...) - } - - m.dynamicTableAdd(name, tmap) - - return tmap -} - -func (m *DbMap) readStructColumns(t reflect.Type) (cols []*ColumnMap, primaryKey []*ColumnMap) { - primaryKey = make([]*ColumnMap, 0) - n := t.NumField() - for i := 0; i < n; i++ { - f := t.Field(i) - if f.Anonymous && f.Type.Kind() == reflect.Struct { - // Recursively add nested fields in embedded structs. - subcols, subpk := m.readStructColumns(f.Type) - // Don't append nested fields that have the same field - // name as an already-mapped field. - for _, subcol := range subcols { - shouldAppend := true - for _, col := range cols { - if !subcol.Transient && subcol.fieldName == col.fieldName { - shouldAppend = false - break - } - } - if shouldAppend { - cols = append(cols, subcol) - } - } - if subpk != nil { - primaryKey = append(primaryKey, subpk...) 
- } - } else { - // Tag = Name { ',' Option } - // Option = OptionKey [ ':' OptionValue ] - cArguments := strings.Split(f.Tag.Get("db"), ",") - columnName := cArguments[0] - var maxSize int - var defaultValue string - var isAuto bool - var isPK bool - var isNotNull bool - for _, argString := range cArguments[1:] { - argString = strings.TrimSpace(argString) - arg := strings.SplitN(argString, ":", 2) - - // check mandatory/unexpected option values - switch arg[0] { - case "size", "default": - // options requiring value - if len(arg) == 1 { - panic(fmt.Sprintf("missing option value for option %v on field %v", arg[0], f.Name)) - } - default: - // options where value is invalid (currently all other options) - if len(arg) == 2 { - panic(fmt.Sprintf("unexpected option value for option %v on field %v", arg[0], f.Name)) - } - } - - switch arg[0] { - case "size": - maxSize, _ = strconv.Atoi(arg[1]) - case "default": - defaultValue = arg[1] - case "primarykey": - isPK = true - case "autoincrement": - isAuto = true - case "notnull": - isNotNull = true - default: - panic(fmt.Sprintf("Unrecognized tag option for field %v: %v", f.Name, arg)) - } - } - if columnName == "" { - columnName = f.Name - } - - gotype := f.Type - valueType := gotype - if valueType.Kind() == reflect.Ptr { - valueType = valueType.Elem() - } - value := reflect.New(valueType).Interface() - if m.TypeConverter != nil { - // Make a new pointer to a value of type gotype and - // pass it to the TypeConverter's FromDb method to see - // if a different type should be used for the column - // type during table creation. - scanner, useHolder := m.TypeConverter.FromDb(value) - if useHolder { - value = scanner.Holder - gotype = reflect.TypeOf(value) - } - } - if typer, ok := value.(SqlTyper); ok { - gotype = reflect.TypeOf(typer.SqlType()) - } else if typer, ok := value.(legacySqlTyper); ok { - log.Printf("Deprecation Warning: update your SqlType methods to return a driver.Value") - gotype = reflect.TypeOf(typer.SqlType()) - } else if valuer, ok := value.(driver.Valuer); ok { - // Only check for driver.Valuer if SqlTyper wasn't - // found. - v, err := valuer.Value() - if err == nil && v != nil { - gotype = reflect.TypeOf(v) - } - } - cm := &ColumnMap{ - ColumnName: columnName, - DefaultValue: defaultValue, - Transient: columnName == "-", - fieldName: f.Name, - gotype: gotype, - isPK: isPK, - isAutoIncr: isAuto, - isNotNull: isNotNull, - MaxSize: maxSize, - } - if isPK { - primaryKey = append(primaryKey, cm) - } - // Check for nested fields of the same field name and - // override them. - shouldAppend := true - for index, col := range cols { - if !col.Transient && col.fieldName == cm.fieldName { - cols[index] = cm - shouldAppend = false - break - } - } - if shouldAppend { - cols = append(cols, cm) - } - } - - } - return -} - -// CreateTables iterates through TableMaps registered to this DbMap and -// executes "create table" statements against the database for each. -// -// This is particularly useful in unit tests where you want to create -// and destroy the schema automatically. 
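readStructColumns above parses the db struct tag with the grammar noted in its comment (Tag = Name { ',' Option }, Option = OptionKey [ ':' OptionValue ]). A self-contained sketch of just that parsing step, over a hypothetical tag value:

package main

import (
	"fmt"
	"strconv"
	"strings"
)

func main() {
	// Hypothetical tag as it would appear on a field: `db:"user_name,size:50,notnull,primarykey"`.
	tag := "user_name,size:50,notnull,primarykey"

	parts := strings.Split(tag, ",")
	column := parts[0]

	var maxSize int
	var isPK, isNotNull bool
	for _, opt := range parts[1:] {
		kv := strings.SplitN(strings.TrimSpace(opt), ":", 2)
		switch kv[0] {
		case "size":
			maxSize, _ = strconv.Atoi(kv[1])
		case "primarykey":
			isPK = true
		case "notnull":
			isNotNull = true
		}
	}

	fmt.Println(column, maxSize, isPK, isNotNull) // user_name 50 true true
}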
-func (m *DbMap) CreateTables() error { - return m.createTables(false) -} - -// CreateTablesIfNotExists is similar to CreateTables, but starts -// each statement with "create table if not exists" so that existing -// tables do not raise errors -func (m *DbMap) CreateTablesIfNotExists() error { - return m.createTables(true) -} - -func (m *DbMap) createTables(ifNotExists bool) error { - var err error - for i := range m.tables { - table := m.tables[i] - sql := table.SqlForCreate(ifNotExists) - _, err = m.Exec(sql) - if err != nil { - return err - } - } - - for _, tbl := range m.dynamicTableMap() { - sql := tbl.SqlForCreate(ifNotExists) - _, err = m.Exec(sql) - if err != nil { - return err - } - } - - return err -} - -// DropTable drops an individual table. -// Returns an error when the table does not exist. -func (m *DbMap) DropTable(table interface{}) error { - t := reflect.TypeOf(table) - - tableName := "" - if dyn, ok := table.(DynamicTable); ok { - tableName = dyn.TableName() - } - - return m.dropTable(t, tableName, false) -} - -// DropTableIfExists drops an individual table when the table exists. -func (m *DbMap) DropTableIfExists(table interface{}) error { - t := reflect.TypeOf(table) - - tableName := "" - if dyn, ok := table.(DynamicTable); ok { - tableName = dyn.TableName() - } - - return m.dropTable(t, tableName, true) -} - -// DropTables iterates through TableMaps registered to this DbMap and -// executes "drop table" statements against the database for each. -func (m *DbMap) DropTables() error { - return m.dropTables(false) -} - -// DropTablesIfExists is the same as DropTables, but uses the "if exists" clause to -// avoid errors for tables that do not exist. -func (m *DbMap) DropTablesIfExists() error { - return m.dropTables(true) -} - -// Goes through all the registered tables, dropping them one by one. -// If an error is encountered, then it is returned and the rest of -// the tables are not dropped. -func (m *DbMap) dropTables(addIfExists bool) (err error) { - for _, table := range m.tables { - err = m.dropTableImpl(table, addIfExists) - if err != nil { - return err - } - } - - for _, table := range m.dynamicTableMap() { - err = m.dropTableImpl(table, addIfExists) - if err != nil { - return err - } - } - - return err -} - -// Implementation of dropping a single table. 
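
As the comment above notes, CreateTables and its IfNotExists/IfExists variants were mainly used to build and tear down schemas in tests. A minimal sketch of that pattern, assuming the *DbMap is supplied by the caller:

package example

import (
	"testing"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

// setupSchema creates all registered tables and drops them again when the test ends.
func setupSchema(t *testing.T, dbmap *sqldb.DbMap) {
	if err := dbmap.CreateTablesIfNotExists(); err != nil {
		t.Fatalf("create tables: %v", err)
	}
	t.Cleanup(func() {
		if err := dbmap.DropTablesIfExists(); err != nil {
			t.Errorf("drop tables: %v", err)
		}
	})
}
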
-func (m *DbMap) dropTable(t reflect.Type, name string, addIfExists bool) error { - table := tableOrNil(m, t, name) - if table == nil { - return fmt.Errorf("table %s was not registered", table.TableName) - } - - return m.dropTableImpl(table, addIfExists) -} - -func (m *DbMap) dropTableImpl(table *TableMap, ifExists bool) (err error) { - tableDrop := "drop table" - if ifExists { - tableDrop = m.Dialect.IfTableExists(tableDrop, table.SchemaName, table.TableName) - } - _, err = m.Exec(fmt.Sprintf("%s %s;", tableDrop, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) - return err -} - -// TruncateTables iterates through TableMaps registered to this DbMap and -// executes "truncate table" statements against the database for each, or in the case of -// sqlite, a "delete from" with no "where" clause, which uses the truncate optimization -// (http://www.sqlite.org/lang_delete.html) -func (m *DbMap) TruncateTables() error { - var err error - for i := range m.tables { - table := m.tables[i] - _, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) - if e != nil { - err = e - } - } - - for _, table := range m.dynamicTableMap() { - _, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) - if e != nil { - err = e - } - } - - return err -} - -// Insert runs a SQL INSERT statement for each element in list. List -// items must be pointers. -// -// Any interface whose TableMap has an auto-increment primary key will -// have its last insert id bound to the PK field on the struct. -// -// The hook functions PreInsert() and/or PostInsert() will be executed -// before/after the INSERT statement if the interface defines them. -// -// Panics if any interface in the list has not been registered with AddTable -func (m *DbMap) Insert(list ...interface{}) error { - return insert(m, m, list...) -} - -// Update runs a SQL UPDATE statement for each element in list. List -// items must be pointers. -// -// The hook functions PreUpdate() and/or PostUpdate() will be executed -// before/after the UPDATE statement if the interface defines them. -// -// Returns the number of rows updated. -// -// Returns an error if SetKeys has not been called on the TableMap -// Panics if any interface in the list has not been registered with AddTable -func (m *DbMap) Update(list ...interface{}) (int64, error) { - return update(m, m, nil, list...) -} - -// UpdateColumns runs a SQL UPDATE statement for each element in list. List -// items must be pointers. -// -// Only the columns accepted by filter are included in the UPDATE. -// -// The hook functions PreUpdate() and/or PostUpdate() will be executed -// before/after the UPDATE statement if the interface defines them. -// -// Returns the number of rows updated. -// -// Returns an error if SetKeys has not been called on the TableMap -// Panics if any interface in the list has not been registered with AddTable -func (m *DbMap) UpdateColumns(filter ColumnFilter, list ...interface{}) (int64, error) { - return update(m, m, filter, list...) -} - -// Delete runs a SQL DELETE statement for each element in list. List -// items must be pointers. -// -// The hook functions PreDelete() and/or PostDelete() will be executed -// before/after the DELETE statement if the interface defines them. -// -// Returns the number of rows deleted. 
-// -// Returns an error if SetKeys has not been called on the TableMap -// Panics if any interface in the list has not been registered with AddTable -func (m *DbMap) Delete(list ...interface{}) (int64, error) { - return delete(m, m, list...) -} - -// Get runs a SQL SELECT to fetch a single row from the table based on the -// primary key(s) -// -// i should be an empty value for the struct to load. keys should be -// the primary key value(s) for the row to load. If multiple keys -// exist on the table, the order should match the column order -// specified in SetKeys() when the table mapping was defined. -// -// The hook function PostGet() will be executed after the SELECT -// statement if the interface defines them. -// -// Returns a pointer to a struct that matches or nil if no row is found. -// -// Returns an error if SetKeys has not been called on the TableMap -// Panics if any interface in the list has not been registered with AddTable -func (m *DbMap) Get(i interface{}, keys ...interface{}) (interface{}, error) { - return get(m, m, i, keys...) -} - -// Select runs an arbitrary SQL query, binding the columns in the result -// to fields on the struct specified by i. args represent the bind -// parameters for the SQL statement. -// -// Column names on the SELECT statement should be aliased to the field names -// on the struct i. Returns an error if one or more columns in the result -// do not match. It is OK if fields on i are not part of the SQL -// statement. -// -// The hook function PostGet() will be executed after the SELECT -// statement if the interface defines them. -// -// Values are returned in one of two ways: -// 1. If i is a struct or a pointer to a struct, returns a slice of pointers to -// matching rows of type i. -// 2. If i is a pointer to a slice, the results will be appended to that slice -// and nil returned. -// -// i does NOT need to be registered with AddTable() -func (m *DbMap) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { - if m.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return hookedselect(m, m, i, query, args...) -} - -// Exec runs an arbitrary SQL statement. args represent the bind parameters. -// This is equivalent to running: Exec() using database/sql -func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) { - if m.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - if m.logger != nil { - now := time.Now() - defer m.trace(now, query, args...) - } - return maybeExpandNamedQueryAndExec(m, query, args...) -} - -// SelectInt is a convenience wrapper around the sqldb.SelectInt function -func (m *DbMap) SelectInt(query string, args ...interface{}) (int64, error) { - if m.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectInt(m, query, args...) -} - -// SelectNullInt is a convenience wrapper around the sqldb.SelectNullInt function -func (m *DbMap) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) { - if m.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectNullInt(m, query, args...) -} - -// SelectFloat is a convenience wrapper around the sqldb.SelectFloat function -func (m *DbMap) SelectFloat(query string, args ...interface{}) (float64, error) { - if m.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectFloat(m, query, args...) 
-} - -// SelectNullFloat is a convenience wrapper around the sqldb.SelectNullFloat function -func (m *DbMap) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) { - if m.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectNullFloat(m, query, args...) -} - -// SelectStr is a convenience wrapper around the sqldb.SelectStr function -func (m *DbMap) SelectStr(query string, args ...interface{}) (string, error) { - if m.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectStr(m, query, args...) -} - -// SelectNullStr is a convenience wrapper around the sqldb.SelectNullStr function -func (m *DbMap) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) { - if m.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectNullStr(m, query, args...) -} - -// SelectOne is a convenience wrapper around the sqldb.SelectOne function -func (m *DbMap) SelectOne(holder interface{}, query string, args ...interface{}) error { - if m.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectOne(m, m, holder, query, args...) -} - -// Begin starts a sqldb Transaction -func (m *DbMap) Begin() (*Transaction, error) { - if m.logger != nil { - now := time.Now() - defer m.trace(now, "begin;") - } - tx, err := begin(m) - if err != nil { - return nil, err - } - return &Transaction{ - dbmap: m, - tx: tx, - closed: false, - }, nil -} - -// TableFor returns the *TableMap corresponding to the given Go Type -// If no table is mapped to that type an error is returned. -// If checkPK is true and the mapped table has no registered PKs, an error is returned. -func (m *DbMap) TableFor(t reflect.Type, checkPK bool) (*TableMap, error) { - table := tableOrNil(m, t, "") - if table == nil { - return nil, fmt.Errorf("no table found for type: %v", t.Name()) - } - - if checkPK && len(table.keys) < 1 { - e := fmt.Sprintf("sqldb: no keys defined for table: %s", - table.TableName) - return nil, errors.New(e) - } - - return table, nil -} - -// DynamicTableFor returns the *TableMap for the dynamic table corresponding -// to the input tablename -// If no table is mapped to that tablename an error is returned. -// If checkPK is true and the mapped table has no registered PKs, an error is returned. -func (m *DbMap) DynamicTableFor(tableName string, checkPK bool) (*TableMap, error) { - table, found := m.dynamicTableFind(tableName) - if !found { - return nil, fmt.Errorf("sqldb: no table found for name: %v", tableName) - } - - if checkPK && len(table.keys) < 1 { - e := fmt.Sprintf("sqldb: no keys defined for table: %s", - table.TableName) - return nil, errors.New(e) - } - - return table, nil -} - -// Prepare creates a prepared statement for later queries or executions. -// Multiple queries or executions may be run concurrently from the returned statement. 
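
Putting the methods documented above together, a typical call sequence against a registered type looked like the following sketch. It reuses the illustrative Invoice type from the earlier example and abbreviates error handling; the exact SQL assumes the invoice_test table from that sketch.

package example

import "git.hexq.cn/tiglog/golib/gdb/sqldb"

func crud(dbmap *sqldb.DbMap) error {
	inv := &Invoice{Memo: "hello"}
	if err := dbmap.Insert(inv); err != nil { // auto-increment key is bound back to inv.Id
		return err
	}

	obj, err := dbmap.Get(Invoice{}, inv.Id) // returns nil, nil when no row matches
	if err != nil || obj == nil {
		return err
	}
	got := obj.(*Invoice)

	got.Memo = "updated"
	if _, err := dbmap.Update(got); err != nil { // first return value is rows updated
		return err
	}

	n, err := dbmap.SelectInt("select count(*) from invoice_test")
	if err != nil {
		return err
	}
	_ = n // row count via the SelectInt convenience wrapper

	_, err = dbmap.Delete(got) // first return value is rows deleted
	return err
}
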
-// This is equivalent to running: Prepare() using database/sql -func (m *DbMap) Prepare(query string) (*sql.Stmt, error) { - if m.logger != nil { - now := time.Now() - defer m.trace(now, query, nil) - } - return prepare(m, query) -} - -func tableOrNil(m *DbMap, t reflect.Type, name string) *TableMap { - if name != "" { - // Search by table name (dynamic tables) - if table, found := m.dynamicTableFind(name); found { - return table - } - return nil - } - - for i := range m.tables { - table := m.tables[i] - if table.gotype == t { - return table - } - } - return nil -} - -func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, reflect.Value, error) { - ptrv := reflect.ValueOf(ptr) - if ptrv.Kind() != reflect.Ptr { - e := fmt.Sprintf("sqldb: passed non-pointer: %v (kind=%v)", ptr, - ptrv.Kind()) - return nil, reflect.Value{}, errors.New(e) - } - elem := ptrv.Elem() - ifc := elem.Interface() - var t *TableMap - var err error - tableName := "" - if dyn, isDyn := ptr.(DynamicTable); isDyn { - tableName = dyn.TableName() - t, err = m.DynamicTableFor(tableName, checkPK) - } else { - etype := reflect.TypeOf(ifc) - t, err = m.TableFor(etype, checkPK) - } - - if err != nil { - return nil, reflect.Value{}, err - } - - return t, elem, nil -} - -func (m *DbMap) QueryRow(query string, args ...interface{}) *sql.Row { - if m.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - if m.logger != nil { - now := time.Now() - defer m.trace(now, query, args...) - } - return queryRow(m, query, args...) -} - -func (m *DbMap) Query(q string, args ...interface{}) (*sql.Rows, error) { - if m.ExpandSliceArgs { - expandSliceArgs(&q, args...) - } - - if m.logger != nil { - now := time.Now() - defer m.trace(now, q, args...) - } - return query(m, q, args...) -} - -func (m *DbMap) trace(started time.Time, query string, args ...interface{}) { - if m.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - if m.logger != nil { - var margs = argsString(args...) - m.logger.Printf("%s%s [%s] (%v)", m.logPrefix, query, margs, (time.Now().Sub(started))) - } -} - -type stringer interface { - ToStringSlice() []string -} - -type numberer interface { - ToInt64Slice() []int64 -} - -func expandSliceArgs(query *string, args ...interface{}) { - for _, arg := range args { - mapper, ok := arg.(map[string]interface{}) - if !ok { - continue - } - - for key, value := range mapper { - var replacements []string - - // add flexibility for any custom type to be convert to one of the - // acceptable formats. 
- if v, ok := value.(stringer); ok { - value = v.ToStringSlice() - } - if v, ok := value.(numberer); ok { - value = v.ToInt64Slice() - } - - switch v := value.(type) { - case []string: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []uint: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []uint8: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []uint16: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []uint32: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []uint64: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []int: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []int8: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []int16: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []int32: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []int64: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []float32: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - case []float64: - for id, replace := range v { - mapper[fmt.Sprintf("%s%d", key, id)] = replace - replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) - } - default: - continue - } - - if len(replacements) == 0 { - continue - } - - *query = strings.Replace(*query, fmt.Sprintf(":%s", key), strings.Join(replacements, ","), -1) - } - } -} diff --git a/gdb/sqldb/db_test.go b/gdb/sqldb/db_test.go deleted file mode 100644 index 41ce8b0..0000000 --- a/gdb/sqldb/db_test.go +++ /dev/null @@ -1,187 +0,0 @@ -// -// db_test.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. 
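
The expansion loop above is what backs the ExpandSliceArgs option exercised by the test below: a named placeholder bound to a slice is rewritten to ":key0,:key1,…" and the individual elements are added to the parameter map. A sketch of the intended usage, with illustrative table and column names:

package example

import "git.hexq.cn/tiglog/golib/gdb/sqldb"

func idsIn(dbmap *sqldb.DbMap) ([]int64, error) {
	dbmap.ExpandSliceArgs = true

	// ":ids" is rewritten to ":ids0,:ids1,:ids2" and the map gains the ids0..ids2
	// keys, exactly as the switch above does for each supported slice type.
	var ids []int64
	_, err := dbmap.Select(&ids,
		"select id from invoice_test where id in (:ids)",
		map[string]interface{}{"ids": []int64{1, 2, 3}})
	return ids, err
}
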
-// - -//go:build integration -// +build integration - -package sqldb_test - -import ( - "testing" -) - -type customType1 []string - -func (c customType1) ToStringSlice() []string { - return []string(c) -} - -type customType2 []int64 - -func (c customType2) ToInt64Slice() []int64 { - return []int64(c) -} - -func TestDbMap_Select_expandSliceArgs(t *testing.T) { - tests := []struct { - description string - query string - args []interface{} - wantLen int - }{ - { - description: "it should handle slice placeholders correctly", - query: ` -SELECT 1 FROM crazy_table -WHERE field1 = :Field1 -AND field2 IN (:FieldStringList) -AND field3 IN (:FieldUIntList) -AND field4 IN (:FieldUInt8List) -AND field5 IN (:FieldUInt16List) -AND field6 IN (:FieldUInt32List) -AND field7 IN (:FieldUInt64List) -AND field8 IN (:FieldIntList) -AND field9 IN (:FieldInt8List) -AND field10 IN (:FieldInt16List) -AND field11 IN (:FieldInt32List) -AND field12 IN (:FieldInt64List) -AND field13 IN (:FieldFloat32List) -AND field14 IN (:FieldFloat64List) -`, - args: []interface{}{ - map[string]interface{}{ - "Field1": 123, - "FieldStringList": []string{"h", "e", "y"}, - "FieldUIntList": []uint{1, 2, 3, 4}, - "FieldUInt8List": []uint8{1, 2, 3, 4}, - "FieldUInt16List": []uint16{1, 2, 3, 4}, - "FieldUInt32List": []uint32{1, 2, 3, 4}, - "FieldUInt64List": []uint64{1, 2, 3, 4}, - "FieldIntList": []int{1, 2, 3, 4}, - "FieldInt8List": []int8{1, 2, 3, 4}, - "FieldInt16List": []int16{1, 2, 3, 4}, - "FieldInt32List": []int32{1, 2, 3, 4}, - "FieldInt64List": []int64{1, 2, 3, 4}, - "FieldFloat32List": []float32{1, 2, 3, 4}, - "FieldFloat64List": []float64{1, 2, 3, 4}, - }, - }, - wantLen: 1, - }, - { - description: "it should handle slice placeholders correctly with custom types", - query: ` -SELECT 1 FROM crazy_table -WHERE field2 IN (:FieldStringList) -AND field12 IN (:FieldIntList) -`, - args: []interface{}{ - map[string]interface{}{ - "FieldStringList": customType1{"h", "e", "y"}, - "FieldIntList": customType2{1, 2, 3, 4}, - }, - }, - wantLen: 3, - }, - } - - type dataFormat struct { - Field1 int `db:"field1"` - Field2 string `db:"field2"` - Field3 uint `db:"field3"` - Field4 uint8 `db:"field4"` - Field5 uint16 `db:"field5"` - Field6 uint32 `db:"field6"` - Field7 uint64 `db:"field7"` - Field8 int `db:"field8"` - Field9 int8 `db:"field9"` - Field10 int16 `db:"field10"` - Field11 int32 `db:"field11"` - Field12 int64 `db:"field12"` - Field13 float32 `db:"field13"` - Field14 float64 `db:"field14"` - } - - dbmap := newDBMap(t) - dbmap.ExpandSliceArgs = true - dbmap.AddTableWithName(dataFormat{}, "crazy_table") - - err := dbmap.CreateTables() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - - err = dbmap.Insert( - &dataFormat{ - Field1: 123, - Field2: "h", - Field3: 1, - Field4: 1, - Field5: 1, - Field6: 1, - Field7: 1, - Field8: 1, - Field9: 1, - Field10: 1, - Field11: 1, - Field12: 1, - Field13: 1, - Field14: 1, - }, - &dataFormat{ - Field1: 124, - Field2: "e", - Field3: 2, - Field4: 2, - Field5: 2, - Field6: 2, - Field7: 2, - Field8: 2, - Field9: 2, - Field10: 2, - Field11: 2, - Field12: 2, - Field13: 2, - Field14: 2, - }, - &dataFormat{ - Field1: 125, - Field2: "y", - Field3: 3, - Field4: 3, - Field5: 3, - Field6: 3, - Field7: 3, - Field8: 3, - Field9: 3, - Field10: 3, - Field11: 3, - Field12: 3, - Field13: 3, - Field14: 3, - }, - ) - - if err != nil { - t.Fatal(err) - } - - for _, tt := range tests { - t.Run(tt.description, func(t *testing.T) { - var dummy []int - _, err := dbmap.Select(&dummy, tt.query, tt.args...) 
- if err != nil { - t.Fatal(err) - } - - if len(dummy) != tt.wantLen { - t.Errorf("wrong result count\ngot: %d\nwant: %d", len(dummy), tt.wantLen) - } - }) - } -} diff --git a/gdb/sqldb/dialect.go b/gdb/sqldb/dialect.go deleted file mode 100644 index 88ecba3..0000000 --- a/gdb/sqldb/dialect.go +++ /dev/null @@ -1,108 +0,0 @@ -// -// dialect.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "reflect" -) - -// The Dialect interface encapsulates behaviors that differ across -// SQL databases. At present the Dialect is only used by CreateTables() -// but this could change in the future -type Dialect interface { - // adds a suffix to any query, usually ";" - QuerySuffix() string - - // ToSqlType returns the SQL column type to use when creating a - // table of the given Go Type. maxsize can be used to switch based on - // size. For example, in MySQL []byte could map to BLOB, MEDIUMBLOB, - // or LONGBLOB depending on the maxsize - ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string - - // string to append to primary key column definitions - AutoIncrStr() string - - // string to bind autoincrement columns to. Empty string will - // remove reference to those columns in the INSERT statement. - AutoIncrBindValue() string - - AutoIncrInsertSuffix(col *ColumnMap) string - - // string to append to "create table" statement for vendor specific - // table attributes - CreateTableSuffix() string - - // string to append to "create index" statement - CreateIndexSuffix() string - - // string to append to "drop index" statement - DropIndexSuffix() string - - // string to truncate tables - TruncateClause() string - - // bind variable string to use when forming SQL statements - // in many dbs it is "?", but Postgres appears to use $1 - // - // i is a zero based index of the bind variable in this statement - // - BindVar(i int) string - - // Handles quoting of a field name to ensure that it doesn't raise any - // SQL parsing exceptions by using a reserved word as a field name. - QuoteField(field string) string - - // Handles building up of a schema.database string that is compatible with - // the given dialect - // - // schema - The schema that lives in - // table - The table name - QuotedTableForQuery(schema string, table string) string - - // Existence clause for table creation / deletion - IfSchemaNotExists(command, schema string) string - IfTableExists(command, schema, table string) string - IfTableNotExists(command, schema, table string) string -} - -// IntegerAutoIncrInserter is implemented by dialects that can perform -// inserts with automatically incremented integer primary keys. If -// the dialect can handle automatic assignment of more than just -// integers, see TargetedAutoIncrInserter. -type IntegerAutoIncrInserter interface { - InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) -} - -// TargetedAutoIncrInserter is implemented by dialects that can -// perform automatic assignment of any primary key type (i.e. strings -// for uuids, integers for serials, etc). -type TargetedAutoIncrInserter interface { - // InsertAutoIncrToTarget runs an insert operation and assigns the - // automatically generated primary key directly to the passed in - // target. The target should be a pointer to the primary key - // field of the value being inserted. 
- InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error -} - -// TargetQueryInserter is implemented by dialects that can perform -// assignment of integer primary key type by executing a query -// like "select sequence.currval from dual". -type TargetQueryInserter interface { - // TargetQueryInserter runs an insert operation and assigns the - // automatically generated primary key retrived by the query - // extracted from the GeneratedIdQuery field of the id column. - InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error -} - -func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { - res, err := exec.Exec(insertSql, params...) - if err != nil { - return 0, err - } - return res.LastInsertId() -} diff --git a/gdb/sqldb/dialect_mysql.go b/gdb/sqldb/dialect_mysql.go deleted file mode 100644 index c60cfbe..0000000 --- a/gdb/sqldb/dialect_mysql.go +++ /dev/null @@ -1,172 +0,0 @@ -// -// dialect_mysql.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -// Implementation of Dialect for MySQL databases. -type MySQLDialect struct { - - // Engine is the storage engine to use "InnoDB" vs "MyISAM" for example - Engine string - - // Encoding is the character encoding to use for created tables - Encoding string -} - -func (d MySQLDialect) QuerySuffix() string { return ";" } - -func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { - switch val.Kind() { - case reflect.Ptr: - return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) - case reflect.Bool: - return "boolean" - case reflect.Int8: - return "tinyint" - case reflect.Uint8: - return "tinyint unsigned" - case reflect.Int16: - return "smallint" - case reflect.Uint16: - return "smallint unsigned" - case reflect.Int, reflect.Int32: - return "int" - case reflect.Uint, reflect.Uint32: - return "int unsigned" - case reflect.Int64: - return "bigint" - case reflect.Uint64: - return "bigint unsigned" - case reflect.Float64, reflect.Float32: - return "double" - case reflect.Slice: - if val.Elem().Kind() == reflect.Uint8 { - return "mediumblob" - } - } - - switch val.Name() { - case "NullInt64": - return "bigint" - case "NullFloat64": - return "double" - case "NullBool": - return "tinyint" - case "Time": - return "datetime" - } - - if maxsize < 1 { - maxsize = 255 - } - - /* == About varchar(N) == - * N is number of characters. - * A varchar column can store up to 65535 bytes. - * Remember that 1 character is 3 bytes in utf-8 charset. - * Also remember that each row can store up to 65535 bytes, - * and you have some overheads, so it's not possible for a - * varchar column to have 65535/3 characters really. - * So it would be better to use 'text' type in stead of - * large varchar type. 
- */ - if maxsize < 256 { - return fmt.Sprintf("varchar(%d)", maxsize) - } else { - return "text" - } -} - -// Returns auto_increment -func (d MySQLDialect) AutoIncrStr() string { - return "auto_increment" -} - -func (d MySQLDialect) AutoIncrBindValue() string { - return "null" -} - -func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string { - return "" -} - -// Returns engine=%s charset=%s based on values stored on struct -func (d MySQLDialect) CreateTableSuffix() string { - if d.Engine == "" || d.Encoding == "" { - msg := "sqldb - undefined" - - if d.Engine == "" { - msg += " MySQLDialect.Engine" - } - if d.Engine == "" && d.Encoding == "" { - msg += "," - } - if d.Encoding == "" { - msg += " MySQLDialect.Encoding" - } - msg += ". Check that your MySQLDialect was correctly initialized when declared." - panic(msg) - } - - return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding) -} - -func (d MySQLDialect) CreateIndexSuffix() string { - return "using" -} - -func (d MySQLDialect) DropIndexSuffix() string { - return "on" -} - -func (d MySQLDialect) TruncateClause() string { - return "truncate" -} - -func (d MySQLDialect) SleepClause(s time.Duration) string { - return fmt.Sprintf("sleep(%f)", s.Seconds()) -} - -// Returns "?" -func (d MySQLDialect) BindVar(i int) string { - return "?" -} - -func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { - return standardInsertAutoIncr(exec, insertSql, params...) -} - -func (d MySQLDialect) QuoteField(f string) string { - return "`" + f + "`" -} - -func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { - if strings.TrimSpace(schema) == "" { - return d.QuoteField(table) - } - - return schema + "." + d.QuoteField(table) -} - -func (d MySQLDialect) IfSchemaNotExists(command, schema string) string { - return fmt.Sprintf("%s if not exists", command) -} - -func (d MySQLDialect) IfTableExists(command, schema, table string) string { - return fmt.Sprintf("%s if exists", command) -} - -func (d MySQLDialect) IfTableNotExists(command, schema, table string) string { - return fmt.Sprintf("%s if not exists", command) -} diff --git a/gdb/sqldb/dialect_mysql_test.go b/gdb/sqldb/dialect_mysql_test.go deleted file mode 100644 index e60bc9e..0000000 --- a/gdb/sqldb/dialect_mysql_test.go +++ /dev/null @@ -1,195 +0,0 @@ -// -// dialect_mysql_test.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. 
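
Note that MySQLDialect refuses to build a create-table suffix unless both Engine and Encoding are set, so it was always constructed explicitly; string columns default to varchar(255) and switch to text at 256 characters and above, as the sizing comment explains. The expected outputs below come straight from the dialect code and its tests:

package example

import (
	"fmt"
	"reflect"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

func mysqlDialect() {
	d := sqldb.MySQLDialect{Engine: "InnoDB", Encoding: "UTF8"}

	fmt.Println(d.CreateTableSuffix())                         // " engine=InnoDB charset=UTF8"
	fmt.Println(d.ToSqlType(reflect.TypeOf(""), 0, false))     // "varchar(255)"
	fmt.Println(d.ToSqlType(reflect.TypeOf(""), 50, false))    // "varchar(50)"
	fmt.Println(d.ToSqlType(reflect.TypeOf(""), 1024, false))  // "text"
}
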
-// - -//go:build !integration -// +build !integration - -package sqldb_test - -import ( - "database/sql" - "errors" - "fmt" - "reflect" - "testing" - "time" - - "git.hexq.cn/tiglog/golib/gdb/sqldb" - "github.com/poy/onpar" - "github.com/poy/onpar/expect" - "github.com/poy/onpar/matchers" -) - -func TestMySQLDialect(t *testing.T) { - // o := onpar.New(t) - // defer o.Run() - - type testContext struct { - t *testing.T - dialect sqldb.MySQLDialect - } - - o := onpar.BeforeEach(onpar.New(t), func(t *testing.T) testContext { - return testContext{ - t: t, - dialect: sqldb.MySQLDialect{Engine: "foo", Encoding: "bar"}, - } - }) - defer o.Run() - - o.Group("ToSqlType", func() { - tests := []struct { - name string - value interface{} - maxSize int - autoIncr bool - expected string - }{ - {"bool", true, 0, false, "boolean"}, - {"int8", int8(1), 0, false, "tinyint"}, - {"uint8", uint8(1), 0, false, "tinyint unsigned"}, - {"int16", int16(1), 0, false, "smallint"}, - {"uint16", uint16(1), 0, false, "smallint unsigned"}, - {"int32", int32(1), 0, false, "int"}, - {"int (treated as int32)", int(1), 0, false, "int"}, - {"uint32", uint32(1), 0, false, "int unsigned"}, - {"uint (treated as uint32)", uint(1), 0, false, "int unsigned"}, - {"int64", int64(1), 0, false, "bigint"}, - {"uint64", uint64(1), 0, false, "bigint unsigned"}, - {"float32", float32(1), 0, false, "double"}, - {"float64", float64(1), 0, false, "double"}, - {"[]uint8", []uint8{1}, 0, false, "mediumblob"}, - {"NullInt64", sql.NullInt64{}, 0, false, "bigint"}, - {"NullFloat64", sql.NullFloat64{}, 0, false, "double"}, - {"NullBool", sql.NullBool{}, 0, false, "tinyint"}, - {"Time", time.Time{}, 0, false, "datetime"}, - {"default-size string", "", 0, false, "varchar(255)"}, - {"sized string", "", 50, false, "varchar(50)"}, - {"large string", "", 1024, false, "text"}, - } - for _, t := range tests { - o.Spec(t.name, func(tt testContext) { - typ := reflect.TypeOf(t.value) - sqlType := tt.dialect.ToSqlType(typ, t.maxSize, t.autoIncr) - expect.Expect(tt.t, sqlType).To(matchers.Equal(t.expected)) - }) - } - }) - - o.Spec("AutoIncrStr", func(tt testContext) { - expect.Expect(t, tt.dialect.AutoIncrStr()).To(matchers.Equal("auto_increment")) - }) - - o.Spec("AutoIncrBindValue", func(tt testContext) { - expect.Expect(t, tt.dialect.AutoIncrBindValue()).To(matchers.Equal("null")) - }) - - o.Spec("AutoIncrInsertSuffix", func(tt testContext) { - expect.Expect(t, tt.dialect.AutoIncrInsertSuffix(nil)).To(matchers.Equal("")) - }) - - o.Group("CreateTableSuffix", func() { - o.Group("with an empty engine", func() { - o1 := onpar.BeforeEach(o, func(tt testContext) testContext { - tt.dialect.Encoding = "" - return tt - }) - o1.Spec("panics", func(tt testContext) { - expect.Expect(t, func() { tt.dialect.CreateTableSuffix() }).To(Panic()) - }) - }) - - o.Group("with an empty encoding", func() { - o2 := onpar.BeforeEach(o, func(tt testContext) testContext { - tt.dialect.Encoding = "" - return tt - }) - o2.Spec("panics", func(tt testContext) { - expect.Expect(t, func() { tt.dialect.CreateTableSuffix() }).To(Panic()) - }) - }) - - o.Spec("with an engine and an encoding", func(tt testContext) { - expect.Expect(t, tt.dialect.CreateTableSuffix()).To(matchers.Equal(" engine=foo charset=bar")) - }) - }) - - o.Spec("CreateIndexSuffix", func(tt testContext) { - expect.Expect(t, tt.dialect.CreateIndexSuffix()).To(matchers.Equal("using")) - }) - - o.Spec("DropIndexSuffix", func(tt testContext) { - expect.Expect(t, tt.dialect.DropIndexSuffix()).To(matchers.Equal("on")) - }) - - 
o.Spec("TruncateClause", func(tt testContext) { - expect.Expect(t, tt.dialect.TruncateClause()).To(matchers.Equal("truncate")) - }) - - o.Spec("SleepClause", func(tt testContext) { - expect.Expect(t, tt.dialect.SleepClause(1*time.Second)).To(matchers.Equal("sleep(1.000000)")) - expect.Expect(t, tt.dialect.SleepClause(100*time.Millisecond)).To(matchers.Equal("sleep(0.100000)")) - }) - - o.Spec("BindVar", func(tt testContext) { - expect.Expect(t, tt.dialect.BindVar(0)).To(matchers.Equal("?")) - }) - - o.Spec("QuoteField", func(tt testContext) { - expect.Expect(t, tt.dialect.QuoteField("foo")).To(matchers.Equal("`foo`")) - }) - - o.Group("QuotedTableForQuery", func() { - o.Spec("using the default schema", func(tt testContext) { - expect.Expect(t, tt.dialect.QuotedTableForQuery("", "foo")).To(matchers.Equal("`foo`")) - }) - - o.Spec("with a supplied schema", func(tt testContext) { - expect.Expect(t, tt.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("foo.`bar`")) - }) - }) - - o.Spec("IfSchemaNotExists", func(tt testContext) { - expect.Expect(t, tt.dialect.IfSchemaNotExists("foo", "bar")).To(matchers.Equal("foo if not exists")) - }) - - o.Spec("IfTableExists", func(tt testContext) { - expect.Expect(t, tt.dialect.IfTableExists("foo", "bar", "baz")).To(matchers.Equal("foo if exists")) - }) - - o.Spec("IfTableNotExists", func(tt testContext) { - expect.Expect(t, tt.dialect.IfTableNotExists("foo", "bar", "baz")).To(matchers.Equal("foo if not exists")) - }) -} - -type panicMatcher struct { -} - -func Panic() panicMatcher { - return panicMatcher{} -} - -func (m panicMatcher) Match(actual interface{}) (resultValue interface{}, err error) { - switch f := actual.(type) { - case func(): - panicked := false - func() { - defer func() { - if r := recover(); r != nil { - panicked = true - } - }() - f() - }() - if panicked { - return f, nil - } - return f, errors.New("function did not panic") - default: - return f, fmt.Errorf("%T is not func()", f) - } -} diff --git a/gdb/sqldb/dialect_oracle.go b/gdb/sqldb/dialect_oracle.go deleted file mode 100644 index 127c857..0000000 --- a/gdb/sqldb/dialect_oracle.go +++ /dev/null @@ -1,142 +0,0 @@ -// -// dialect_oracle.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "fmt" - "reflect" - "strings" -) - -// Implementation of Dialect for Oracle databases. 
-type OracleDialect struct{} - -func (d OracleDialect) QuerySuffix() string { return "" } - -func (d OracleDialect) CreateIndexSuffix() string { return "" } - -func (d OracleDialect) DropIndexSuffix() string { return "" } - -func (d OracleDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { - switch val.Kind() { - case reflect.Ptr: - return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) - case reflect.Bool: - return "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - if isAutoIncr { - return "serial" - } - return "integer" - case reflect.Int64, reflect.Uint64: - if isAutoIncr { - return "bigserial" - } - return "bigint" - case reflect.Float64: - return "double precision" - case reflect.Float32: - return "real" - case reflect.Slice: - if val.Elem().Kind() == reflect.Uint8 { - return "bytea" - } - } - - switch val.Name() { - case "NullInt64": - return "bigint" - case "NullFloat64": - return "double precision" - case "NullBool": - return "boolean" - case "NullTime", "Time": - return "timestamp with time zone" - } - - if maxsize > 0 { - return fmt.Sprintf("varchar(%d)", maxsize) - } else { - return "text" - } - -} - -// Returns empty string -func (d OracleDialect) AutoIncrStr() string { - return "" -} - -func (d OracleDialect) AutoIncrBindValue() string { - return "NULL" -} - -func (d OracleDialect) AutoIncrInsertSuffix(col *ColumnMap) string { - return "" -} - -// Returns suffix -func (d OracleDialect) CreateTableSuffix() string { - return "" -} - -func (d OracleDialect) TruncateClause() string { - return "truncate" -} - -// Returns "$(i+1)" -func (d OracleDialect) BindVar(i int) string { - return fmt.Sprintf(":%d", i+1) -} - -// After executing the insert uses the ColMap IdQuery to get the generated id -func (d OracleDialect) InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error { - _, err := exec.Exec(insertSql, params...) - if err != nil { - return err - } - id, err := exec.SelectInt(idSql) - if err != nil { - return err - } - switch target.(type) { - case *int64: - *(target.(*int64)) = id - case *int32: - *(target.(*int32)) = int32(id) - case int: - *(target.(*int)) = int(id) - default: - return fmt.Errorf("Id field can be int, int32 or int64") - } - return nil -} - -func (d OracleDialect) QuoteField(f string) string { - return `"` + strings.ToUpper(f) + `"` -} - -func (d OracleDialect) QuotedTableForQuery(schema string, table string) string { - if strings.TrimSpace(schema) == "" { - return d.QuoteField(table) - } - - return schema + "." + d.QuoteField(table) -} - -func (d OracleDialect) IfSchemaNotExists(command, schema string) string { - return fmt.Sprintf("%s if not exists", command) -} - -func (d OracleDialect) IfTableExists(command, schema, table string) string { - return fmt.Sprintf("%s if exists", command) -} - -func (d OracleDialect) IfTableNotExists(command, schema, table string) string { - return fmt.Sprintf("%s if not exists", command) -} diff --git a/gdb/sqldb/dialect_postgres.go b/gdb/sqldb/dialect_postgres.go deleted file mode 100644 index 2e17200..0000000 --- a/gdb/sqldb/dialect_postgres.go +++ /dev/null @@ -1,152 +0,0 @@ -// -// dialect_postgres.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. 
-// - -package sqldb - -import ( - "fmt" - "reflect" - "strings" - "time" -) - -type PostgresDialect struct { - suffix string - LowercaseFields bool -} - -func (d PostgresDialect) QuerySuffix() string { return ";" } - -func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { - switch val.Kind() { - case reflect.Ptr: - return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) - case reflect.Bool: - return "boolean" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: - if isAutoIncr { - return "serial" - } - return "integer" - case reflect.Int64, reflect.Uint64: - if isAutoIncr { - return "bigserial" - } - return "bigint" - case reflect.Float64: - return "double precision" - case reflect.Float32: - return "real" - case reflect.Slice: - if val.Elem().Kind() == reflect.Uint8 { - return "bytea" - } - } - - switch val.Name() { - case "NullInt64": - return "bigint" - case "NullFloat64": - return "double precision" - case "NullBool": - return "boolean" - case "Time", "NullTime": - return "timestamp with time zone" - } - - if maxsize > 0 { - return fmt.Sprintf("varchar(%d)", maxsize) - } else { - return "text" - } - -} - -// Returns empty string -func (d PostgresDialect) AutoIncrStr() string { - return "" -} - -func (d PostgresDialect) AutoIncrBindValue() string { - return "default" -} - -func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string { - return " returning " + d.QuoteField(col.ColumnName) -} - -// Returns suffix -func (d PostgresDialect) CreateTableSuffix() string { - return d.suffix -} - -func (d PostgresDialect) CreateIndexSuffix() string { - return "using" -} - -func (d PostgresDialect) DropIndexSuffix() string { - return "" -} - -func (d PostgresDialect) TruncateClause() string { - return "truncate" -} - -func (d PostgresDialect) SleepClause(s time.Duration) string { - return fmt.Sprintf("pg_sleep(%f)", s.Seconds()) -} - -// Returns "$(i+1)" -func (d PostgresDialect) BindVar(i int) string { - return fmt.Sprintf("$%d", i+1) -} - -func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error { - rows, err := exec.Query(insertSql, params...) - if err != nil { - return err - } - defer rows.Close() - - if !rows.Next() { - return fmt.Errorf("No serial value returned for insert: %s Encountered error: %s", insertSql, rows.Err()) - } - if err := rows.Scan(target); err != nil { - return err - } - if rows.Next() { - return fmt.Errorf("more than two serial value returned for insert: %s", insertSql) - } - return rows.Err() -} - -func (d PostgresDialect) QuoteField(f string) string { - if d.LowercaseFields { - return `"` + strings.ToLower(f) + `"` - } - return `"` + f + `"` -} - -func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string { - if strings.TrimSpace(schema) == "" { - return d.QuoteField(table) - } - - return schema + "." 
+ d.QuoteField(table) -} - -func (d PostgresDialect) IfSchemaNotExists(command, schema string) string { - return fmt.Sprintf("%s if not exists", command) -} - -func (d PostgresDialect) IfTableExists(command, schema, table string) string { - return fmt.Sprintf("%s if exists", command) -} - -func (d PostgresDialect) IfTableNotExists(command, schema, table string) string { - return fmt.Sprintf("%s if not exists", command) -} diff --git a/gdb/sqldb/dialect_postgres_test.go b/gdb/sqldb/dialect_postgres_test.go deleted file mode 100644 index 45ed541..0000000 --- a/gdb/sqldb/dialect_postgres_test.go +++ /dev/null @@ -1,161 +0,0 @@ -// -// dialect_postgres_test.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -//go:build !integration -// +build !integration - -package sqldb_test - -import ( - "database/sql" - "reflect" - "testing" - "time" - - "git.hexq.cn/tiglog/golib/gdb/sqldb" - "github.com/poy/onpar" - "github.com/poy/onpar/expect" - "github.com/poy/onpar/matchers" -) - -func TestPostgresDialect(t *testing.T) { - - type testContext struct { - t *testing.T - dialect sqldb.PostgresDialect - } - - o := onpar.BeforeEach(onpar.New(t), func(t *testing.T) testContext { - return testContext{ - t: t, - dialect: sqldb.PostgresDialect{ - LowercaseFields: false, - }, - } - }) - defer o.Run() - - o.Group("ToSqlType", func() { - tests := []struct { - name string - value interface{} - maxSize int - autoIncr bool - expected string - }{ - {"bool", true, 0, false, "boolean"}, - {"int8", int8(1), 0, false, "integer"}, - {"uint8", uint8(1), 0, false, "integer"}, - {"int16", int16(1), 0, false, "integer"}, - {"uint16", uint16(1), 0, false, "integer"}, - {"int32", int32(1), 0, false, "integer"}, - {"int (treated as int32)", int(1), 0, false, "integer"}, - {"uint32", uint32(1), 0, false, "integer"}, - {"uint (treated as uint32)", uint(1), 0, false, "integer"}, - {"int64", int64(1), 0, false, "bigint"}, - {"uint64", uint64(1), 0, false, "bigint"}, - {"float32", float32(1), 0, false, "real"}, - {"float64", float64(1), 0, false, "double precision"}, - {"[]uint8", []uint8{1}, 0, false, "bytea"}, - {"NullInt64", sql.NullInt64{}, 0, false, "bigint"}, - {"NullFloat64", sql.NullFloat64{}, 0, false, "double precision"}, - {"NullBool", sql.NullBool{}, 0, false, "boolean"}, - {"Time", time.Time{}, 0, false, "timestamp with time zone"}, - {"default-size string", "", 0, false, "text"}, - {"sized string", "", 50, false, "varchar(50)"}, - {"large string", "", 1024, false, "varchar(1024)"}, - } - for _, t := range tests { - o.Spec(t.name, func(tt testContext) { - typ := reflect.TypeOf(t.value) - sqlType := tt.dialect.ToSqlType(typ, t.maxSize, t.autoIncr) - expect.Expect(tt.t, sqlType).To(matchers.Equal(t.expected)) - }) - } - }) - - o.Spec("AutoIncrStr", func(tt testContext) { - expect.Expect(t, tt.dialect.AutoIncrStr()).To(matchers.Equal("")) - }) - - o.Spec("AutoIncrBindValue", func(tt testContext) { - expect.Expect(t, tt.dialect.AutoIncrBindValue()).To(matchers.Equal("default")) - }) - - o.Spec("AutoIncrInsertSuffix", func(tt testContext) { - cm := sqldb.ColumnMap{ - ColumnName: "foo", - } - expect.Expect(t, tt.dialect.AutoIncrInsertSuffix(&cm)).To(matchers.Equal(` returning "foo"`)) - }) - - o.Spec("CreateTableSuffix", func(tt testContext) { - expect.Expect(t, tt.dialect.CreateTableSuffix()).To(matchers.Equal("")) - }) - - o.Spec("CreateIndexSuffix", func(tt testContext) { - expect.Expect(t, tt.dialect.CreateIndexSuffix()).To(matchers.Equal("using")) - }) - - o.Spec("DropIndexSuffix", 
func(tt testContext) { - expect.Expect(t, tt.dialect.DropIndexSuffix()).To(matchers.Equal("")) - }) - - o.Spec("TruncateClause", func(tt testContext) { - expect.Expect(t, tt.dialect.TruncateClause()).To(matchers.Equal("truncate")) - }) - - o.Spec("SleepClause", func(tt testContext) { - expect.Expect(t, tt.dialect.SleepClause(1*time.Second)).To(matchers.Equal("pg_sleep(1.000000)")) - expect.Expect(t, tt.dialect.SleepClause(100*time.Millisecond)).To(matchers.Equal("pg_sleep(0.100000)")) - }) - - o.Spec("BindVar", func(tt testContext) { - expect.Expect(t, tt.dialect.BindVar(0)).To(matchers.Equal("$1")) - expect.Expect(t, tt.dialect.BindVar(4)).To(matchers.Equal("$5")) - }) - - o.Group("QuoteField", func() { - o.Spec("By default, case is preserved", func(tt testContext) { - expect.Expect(t, tt.dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`)) - expect.Expect(t, tt.dialect.QuoteField("bar")).To(matchers.Equal(`"bar"`)) - }) - - o.Group("With LowercaseFields set to true", func() { - o1 := onpar.BeforeEach(o, func(tt testContext) testContext { - tt.dialect.LowercaseFields = true - return tt - }) - - o1.Spec("fields are lowercased", func(tt testContext) { - expect.Expect(t, tt.dialect.QuoteField("Foo")).To(matchers.Equal(`"foo"`)) - }) - }) - }) - - o.Group("QuotedTableForQuery", func() { - o.Spec("using the default schema", func(tt testContext) { - expect.Expect(t, tt.dialect.QuotedTableForQuery("", "foo")).To(matchers.Equal(`"foo"`)) - }) - - o.Spec("with a supplied schema", func(tt testContext) { - expect.Expect(t, tt.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`foo."bar"`)) - }) - }) - - o.Spec("IfSchemaNotExists", func(tt testContext) { - expect.Expect(t, tt.dialect.IfSchemaNotExists("foo", "bar")).To(matchers.Equal("foo if not exists")) - }) - - o.Spec("IfTableExists", func(tt testContext) { - expect.Expect(t, tt.dialect.IfTableExists("foo", "bar", "baz")).To(matchers.Equal("foo if exists")) - }) - - o.Spec("IfTableNotExists", func(tt testContext) { - expect.Expect(t, tt.dialect.IfTableNotExists("foo", "bar", "baz")).To(matchers.Equal("foo if not exists")) - }) -} diff --git a/gdb/sqldb/dialect_sqlite.go b/gdb/sqldb/dialect_sqlite.go deleted file mode 100644 index 72f6a72..0000000 --- a/gdb/sqldb/dialect_sqlite.go +++ /dev/null @@ -1,115 +0,0 @@ -// -// dialect_sqlite.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. 
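
PostgresDialect, covered above, differs from the MySQL flavour mainly in quoting, placeholder syntax, and auto-increment handling (insert … returning "col"); with LowercaseFields set it also folds quoted identifiers to lower case. The expected values below mirror the dialect tests:

package example

import (
	"fmt"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

func postgresDialect() {
	d := sqldb.PostgresDialect{LowercaseFields: true}

	fmt.Println(d.QuoteField("FirstName")) // "firstname" (quoted, lower-cased)
	fmt.Println(d.BindVar(1))              // "$2"
	fmt.Println(d.AutoIncrInsertSuffix(&sqldb.ColumnMap{ColumnName: "id"})) // ` returning "id"`
}
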
-// - -package sqldb - -import ( - "fmt" - "reflect" -) - -type SqliteDialect struct { - suffix string -} - -func (d SqliteDialect) QuerySuffix() string { return ";" } - -func (d SqliteDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { - switch val.Kind() { - case reflect.Ptr: - return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) - case reflect.Bool: - return "integer" - case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: - return "integer" - case reflect.Float64, reflect.Float32: - return "real" - case reflect.Slice: - if val.Elem().Kind() == reflect.Uint8 { - return "blob" - } - } - - switch val.Name() { - case "NullInt64": - return "integer" - case "NullFloat64": - return "real" - case "NullBool": - return "integer" - case "Time": - return "datetime" - } - - if maxsize < 1 { - maxsize = 255 - } - return fmt.Sprintf("varchar(%d)", maxsize) -} - -// Returns autoincrement -func (d SqliteDialect) AutoIncrStr() string { - return "autoincrement" -} - -func (d SqliteDialect) AutoIncrBindValue() string { - return "null" -} - -func (d SqliteDialect) AutoIncrInsertSuffix(col *ColumnMap) string { - return "" -} - -// Returns suffix -func (d SqliteDialect) CreateTableSuffix() string { - return d.suffix -} - -func (d SqliteDialect) CreateIndexSuffix() string { - return "" -} - -func (d SqliteDialect) DropIndexSuffix() string { - return "" -} - -// With sqlite, there technically isn't a TRUNCATE statement, -// but a DELETE FROM uses a truncate optimization: -// http://www.sqlite.org/lang_delete.html -func (d SqliteDialect) TruncateClause() string { - return "delete from" -} - -// Returns "?" -func (d SqliteDialect) BindVar(i int) string { - return "?" -} - -func (d SqliteDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { - return standardInsertAutoIncr(exec, insertSql, params...) -} - -func (d SqliteDialect) QuoteField(f string) string { - return `"` + f + `"` -} - -// sqlite does not have schemas like PostgreSQL does, so just escape it like normal -func (d SqliteDialect) QuotedTableForQuery(schema string, table string) string { - return d.QuoteField(table) -} - -func (d SqliteDialect) IfSchemaNotExists(command, schema string) string { - return fmt.Sprintf("%s if not exists", command) -} - -func (d SqliteDialect) IfTableExists(command, schema, table string) string { - return fmt.Sprintf("%s if exists", command) -} - -func (d SqliteDialect) IfTableNotExists(command, schema, table string) string { - return fmt.Sprintf("%s if not exists", command) -} diff --git a/gdb/sqldb/doc.go b/gdb/sqldb/doc.go deleted file mode 100644 index 3b84008..0000000 --- a/gdb/sqldb/doc.go +++ /dev/null @@ -1,13 +0,0 @@ -// -// doc.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -// Package sqldb provides a simple way to marshal Go structs to and from -// SQL databases. It uses the database/sql package, and should work with any -// compliant database/sql driver. -// -// Source code and project home: -package sqldb diff --git a/gdb/sqldb/errors.go b/gdb/sqldb/errors.go deleted file mode 100644 index eafba19..0000000 --- a/gdb/sqldb/errors.go +++ /dev/null @@ -1,34 +0,0 @@ -// -// errors.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. 
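
With all four dialect implementations above in view, the practical differences surface in small methods such as BindVar and TruncateClause (SQLite has no TRUNCATE, so it falls back to the optimized DELETE FROM). A quick side-by-side, using only the zero values of each dialect:

package example

import (
	"fmt"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

func dialectDifferences() {
	fmt.Println(sqldb.MySQLDialect{}.BindVar(0), sqldb.MySQLDialect{}.TruncateClause())       // "?"  "truncate"
	fmt.Println(sqldb.SqliteDialect{}.BindVar(0), sqldb.SqliteDialect{}.TruncateClause())     // "?"  "delete from"
	fmt.Println(sqldb.PostgresDialect{}.BindVar(0), sqldb.PostgresDialect{}.TruncateClause()) // "$1" "truncate"
	fmt.Println(sqldb.OracleDialect{}.BindVar(0), sqldb.OracleDialect{}.TruncateClause())     // ":1" "truncate"
}
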
-// - -package sqldb - -import ( - "fmt" -) - -// A non-fatal error, when a select query returns columns that do not exist -// as fields in the struct it is being mapped to -// TODO: discuss wether this needs an error. encoding/json silently ignores missing fields -type NoFieldInTypeError struct { - TypeName string - MissingColNames []string -} - -func (err *NoFieldInTypeError) Error() string { - return fmt.Sprintf("sqldb: no fields %+v in type %s", err.MissingColNames, err.TypeName) -} - -// returns true if the error is non-fatal (ie, we shouldn't immediately return) -func NonFatalError(err error) bool { - switch err.(type) { - case *NoFieldInTypeError: - return true - default: - return false - } -} diff --git a/gdb/sqldb/hooks.go b/gdb/sqldb/hooks.go deleted file mode 100644 index 07c4918..0000000 --- a/gdb/sqldb/hooks.go +++ /dev/null @@ -1,45 +0,0 @@ -// -// hooks.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -//++ TODO v2-phase3: HasPostGet => PostGetter, HasPostDelete => PostDeleter, etc. - -// HasPostGet provides PostGet() which will be executed after the GET statement. -type HasPostGet interface { - PostGet(SqlExecutor) error -} - -// HasPostDelete provides PostDelete() which will be executed after the DELETE statement -type HasPostDelete interface { - PostDelete(SqlExecutor) error -} - -// HasPostUpdate provides PostUpdate() which will be executed after the UPDATE statement -type HasPostUpdate interface { - PostUpdate(SqlExecutor) error -} - -// HasPostInsert provides PostInsert() which will be executed after the INSERT statement -type HasPostInsert interface { - PostInsert(SqlExecutor) error -} - -// HasPreDelete provides PreDelete() which will be executed before the DELETE statement. -type HasPreDelete interface { - PreDelete(SqlExecutor) error -} - -// HasPreUpdate provides PreUpdate() which will be executed before UPDATE statement. -type HasPreUpdate interface { - PreUpdate(SqlExecutor) error -} - -// HasPreInsert provides PreInsert() which will be executed before INSERT statement. -type HasPreInsert interface { - PreInsert(SqlExecutor) error -} diff --git a/gdb/sqldb/index.go b/gdb/sqldb/index.go deleted file mode 100644 index c61d6fd..0000000 --- a/gdb/sqldb/index.go +++ /dev/null @@ -1,51 +0,0 @@ -// -// index.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -// IndexMap represents a mapping between a Go struct field and a single -// index in a table. -// Unique and MaxSize only inform the -// CreateTables() function and are not used by Insert/Update/Delete/Get. -type IndexMap struct { - // Index name in db table - IndexName string - - // If true, " unique" is added to create index statements. - // Not used elsewhere - Unique bool - - // Index type supported by Dialect - // Postgres: B-tree, Hash, GiST and GIN. - // Mysql: Btree, Hash. - // Sqlite: nil. - IndexType string - - // Columns name for single and multiple indexes - columns []string -} - -// Rename allows you to specify the index name in the table -// -// Example: table.IndMap("customer_test_idx").Rename("customer_idx") -func (idx *IndexMap) Rename(indname string) *IndexMap { - idx.IndexName = indname - return idx -} - -// SetUnique adds "unique" to the create index statements for this -// index, if b is true. 
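
The hook interfaces above are satisfied simply by defining the corresponding method on a mapped struct; sqldb looks for them around each statement. A minimal sketch of a pre-insert hook, where the AuditedInvoice type and its Created column are illustrative:

package example

import (
	"time"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

// AuditedInvoice fills its Created column just before every INSERT.
type AuditedInvoice struct {
	Id      int64  `db:"id,primarykey,autoincrement"`
	Memo    string `db:"memo"`
	Created int64  `db:"created"`
}

// PreInsert implements the HasPreInsert hook from hooks.go.
func (a *AuditedInvoice) PreInsert(_ sqldb.SqlExecutor) error {
	a.Created = time.Now().Unix()
	return nil
}
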
-func (idx *IndexMap) SetUnique(b bool) *IndexMap { - idx.Unique = b - return idx -} - -// SetIndexType specifies the index type supported by chousen SQL Dialect -func (idx *IndexMap) SetIndexType(indtype string) *IndexMap { - idx.IndexType = indtype - return idx -} diff --git a/gdb/sqldb/lockerror.go b/gdb/sqldb/lockerror.go deleted file mode 100644 index 2351b3b..0000000 --- a/gdb/sqldb/lockerror.go +++ /dev/null @@ -1,59 +0,0 @@ -// -// lockerror.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "fmt" - "reflect" -) - -// OptimisticLockError is returned by Update() or Delete() if the -// struct being modified has a Version field and the value is not equal to -// the current value in the database -type OptimisticLockError struct { - // Table name where the lock error occurred - TableName string - - // Primary key values of the row being updated/deleted - Keys []interface{} - - // true if a row was found with those keys, indicating the - // LocalVersion is stale. false if no value was found with those - // keys, suggesting the row has been deleted since loaded, or - // was never inserted to begin with - RowExists bool - - // Version value on the struct passed to Update/Delete. This value is - // out of sync with the database. - LocalVersion int64 -} - -// Error returns a description of the cause of the lock error -func (e OptimisticLockError) Error() string { - if e.RowExists { - return fmt.Sprintf("sqldb: OptimisticLockError table=%s keys=%v out of date version=%d", e.TableName, e.Keys, e.LocalVersion) - } - - return fmt.Sprintf("sqldb: OptimisticLockError no row found for table=%s keys=%v", e.TableName, e.Keys) -} - -func lockError(m *DbMap, exec SqlExecutor, tableName string, - existingVer int64, elem reflect.Value, - keys ...interface{}) (int64, error) { - - existing, err := get(m, exec, elem.Interface(), keys...) - if err != nil { - return -1, err - } - - ole := OptimisticLockError{tableName, keys, true, existingVer} - if existing == nil { - ole.RowExists = false - } - return -1, ole -} diff --git a/gdb/sqldb/logging.go b/gdb/sqldb/logging.go deleted file mode 100644 index b7e56b1..0000000 --- a/gdb/sqldb/logging.go +++ /dev/null @@ -1,45 +0,0 @@ -// -// logging.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import "fmt" - -// SqldbLogger is a deprecated alias of Logger. -type SqldbLogger = Logger - -// Logger is the type that sqldb uses to log SQL statements. -// See DbMap.TraceOn. -type Logger interface { - Printf(format string, v ...interface{}) -} - -// TraceOn turns on SQL statement logging for this DbMap. After this is -// called, all SQL statements will be sent to the logger. If prefix is -// a non-empty string, it will be written to the front of all logged -// strings, which can aid in filtering log lines. -// -// Use TraceOn if you want to spy on the SQL statements that sqldb -// generates. -// -// Note that the base log.Logger type satisfies Logger, but adapters can -// easily be written for other logging packages (e.g., the golang-sanctioned -// glog framework). -func (m *DbMap) TraceOn(prefix string, logger Logger) { - m.logger = logger - if prefix == "" { - m.logPrefix = prefix - } else { - m.logPrefix = fmt.Sprintf("%s ", prefix) - } -} - -// TraceOff turns off tracing. It is idempotent. 
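
Since the standard library's *log.Logger already satisfies the Logger interface, as the TraceOn comment notes, turning statement tracing on and off was a one-liner each way:

package example

import (
	"log"
	"os"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

func withTracing(dbmap *sqldb.DbMap) {
	// Every SQL statement is now written to stdout, prefixed with "sqldb:".
	dbmap.TraceOn("sqldb:", log.New(os.Stdout, "", log.Lmicroseconds))
	defer dbmap.TraceOff()

	// ... run queries here ...
}
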
-func (m *DbMap) TraceOff() { - m.logger = nil - m.logPrefix = "" -} diff --git a/gdb/sqldb/nulltypes.go b/gdb/sqldb/nulltypes.go deleted file mode 100644 index c5c2158..0000000 --- a/gdb/sqldb/nulltypes.go +++ /dev/null @@ -1,68 +0,0 @@ -// -// nulltypes.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "database/sql/driver" - "log" - "time" -) - -// A nullable Time value -type NullTime struct { - Time time.Time - Valid bool // Valid is true if Time is not NULL -} - -// Scan implements the Scanner interface. -func (nt *NullTime) Scan(value interface{}) error { - log.Printf("Time scan value is: %#v", value) - switch t := value.(type) { - case time.Time: - nt.Time, nt.Valid = t, true - case []byte: - v := strToTime(string(t)) - if v != nil { - nt.Valid = true - nt.Time = *v - } - case string: - v := strToTime(t) - if v != nil { - nt.Valid = true - nt.Time = *v - } - } - return nil -} - -func strToTime(v string) *time.Time { - for _, dtfmt := range []string{ - "2006-01-02 15:04:05.999999999", - "2006-01-02T15:04:05.999999999", - "2006-01-02 15:04:05", - "2006-01-02T15:04:05", - "2006-01-02 15:04", - "2006-01-02T15:04", - "2006-01-02", - "2006-01-02 15:04:05-07:00", - } { - if t, err := time.Parse(dtfmt, v); err == nil { - return &t - } - } - return nil -} - -// Value implements the driver Valuer interface. -func (nt NullTime) Value() (driver.Value, error) { - if !nt.Valid { - return nil, nil - } - return nt.Time, nil -} diff --git a/gdb/sqldb/query_builder.go b/gdb/sqldb/query_builder.go deleted file mode 100644 index fe6f3cd..0000000 --- a/gdb/sqldb/query_builder.go +++ /dev/null @@ -1,40 +0,0 @@ -// -// query_builder.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import "reflect" - -// 该功能用于改善记录查询,避免直接写表名 -// TODO 实现 query builder - -type join_item struct { - way string - table string - on string -} - -type query_builder struct { - table string - fields string - conds []string - orderBy string - offset int - limit int - joins []join_item -} - -func FromEntity(ent any) *query_builder { - tabM, err := dm.TableFor(reflect.TypeOf(ent), false) - if err != nil { - return nil - } - - return &query_builder{ - table: tabM.TableName, - } -} diff --git a/gdb/sqldb/select.go b/gdb/sqldb/select.go deleted file mode 100644 index 8b9a4a9..0000000 --- a/gdb/sqldb/select.go +++ /dev/null @@ -1,361 +0,0 @@ -// -// select.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "database/sql" - "fmt" - "reflect" -) - -// SelectInt executes the given query, which should be a SELECT statement for a single -// integer column, and returns the value of the first row returned. If no rows are -// found, zero is returned. -func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) { - var h int64 - err := selectVal(e, &h, query, args...) - if err != nil && err != sql.ErrNoRows { - return 0, err - } - return h, nil -} - -// SelectNullInt executes the given query, which should be a SELECT statement for a single -// integer column, and returns the value of the first row returned. If no rows are -// found, the empty sql.NullInt64 value is returned. -func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullInt64, error) { - var h sql.NullInt64 - err := selectVal(e, &h, query, args...) 
- if err != nil && err != sql.ErrNoRows { - return h, err - } - return h, nil -} - -// SelectFloat executes the given query, which should be a SELECT statement for a single -// float column, and returns the value of the first row returned. If no rows are -// found, zero is returned. -func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, error) { - var h float64 - err := selectVal(e, &h, query, args...) - if err != nil && err != sql.ErrNoRows { - return 0, err - } - return h, nil -} - -// SelectNullFloat executes the given query, which should be a SELECT statement for a single -// float column, and returns the value of the first row returned. If no rows are -// found, the empty sql.NullInt64 value is returned. -func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.NullFloat64, error) { - var h sql.NullFloat64 - err := selectVal(e, &h, query, args...) - if err != nil && err != sql.ErrNoRows { - return h, err - } - return h, nil -} - -// SelectStr executes the given query, which should be a SELECT statement for a single -// char/varchar column, and returns the value of the first row returned. If no rows are -// found, an empty string is returned. -func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) { - var h string - err := selectVal(e, &h, query, args...) - if err != nil && err != sql.ErrNoRows { - return "", err - } - return h, nil -} - -// SelectNullStr executes the given query, which should be a SELECT -// statement for a single char/varchar column, and returns the value -// of the first row returned. If no rows are found, the empty -// sql.NullString is returned. -func SelectNullStr(e SqlExecutor, query string, args ...interface{}) (sql.NullString, error) { - var h sql.NullString - err := selectVal(e, &h, query, args...) - if err != nil && err != sql.ErrNoRows { - return h, err - } - return h, nil -} - -// SelectOne executes the given query (which should be a SELECT statement) -// and binds the result to holder, which must be a pointer. -// -// # If no row is found, an error (sql.ErrNoRows specifically) will be returned -// -// If more than one row is found, an error will be returned. -func SelectOne(m *DbMap, e SqlExecutor, holder interface{}, query string, args ...interface{}) error { - t := reflect.TypeOf(holder) - if t.Kind() == reflect.Ptr { - t = t.Elem() - } else { - return fmt.Errorf("sqldb: SelectOne holder must be a pointer, but got: %t", holder) - } - - // Handle pointer to pointer - isptr := false - if t.Kind() == reflect.Ptr { - isptr = true - t = t.Elem() - } - - if t.Kind() == reflect.Struct { - var nonFatalErr error - - list, err := hookedselect(m, e, holder, query, args...) - if err != nil { - if !NonFatalError(err) { // FIXME: double negative, rename NonFatalError to FatalError - return err - } - nonFatalErr = err - } - - dest := reflect.ValueOf(holder) - if isptr { - dest = dest.Elem() - } - - if list != nil && len(list) > 0 { // FIXME: invert if/else - // check for multiple rows - if len(list) > 1 { - return fmt.Errorf("sqldb: multiple rows returned for: %s - %v", query, args) - } - - // Initialize if nil - if dest.IsNil() { - dest.Set(reflect.New(t)) - } - - // only one row found - src := reflect.ValueOf(list[0]) - dest.Elem().Set(src.Elem()) - } else { - // No rows found, return a proper error. - return sql.ErrNoRows - } - - return nonFatalErr - } - - return selectVal(e, holder, query, args...) 
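Taken together, the helpers above cover the common "one value" and "one row" cases; a rough usage sketch, assuming the Invoice fixture and invoice_test table that appear in the test file later in this patch (column naming and error handling are simplified for brevity):

func countInvoices(e SqlExecutor) (int64, error) {
	// Returns 0 rather than an error when the query yields no rows.
	return SelectInt(e, "select count(*) from invoice_test")
}

func loadInvoice(m *DbMap, id int64) (*Invoice, error) {
	var inv Invoice
	// SelectOne reports sql.ErrNoRows when nothing matches; a single map
	// argument is expanded as named parameters (":id").
	err := SelectOne(m, m, &inv, "select * from invoice_test where id = :id",
		map[string]interface{}{"id": id})
	return &inv, err
}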
-} - -func selectVal(e SqlExecutor, holder interface{}, query string, args ...interface{}) error { - if len(args) == 1 { - switch m := e.(type) { - case *DbMap: - query, args = maybeExpandNamedQuery(m, query, args) - case *Transaction: - query, args = maybeExpandNamedQuery(m.dbmap, query, args) - } - } - rows, err := e.Query(query, args...) - if err != nil { - return err - } - defer rows.Close() - - if !rows.Next() { - if err := rows.Err(); err != nil { - return err - } - return sql.ErrNoRows - } - - return rows.Scan(holder) -} - -func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string, - args ...interface{}) ([]interface{}, error) { - - var nonFatalErr error - - list, err := rawselect(m, exec, i, query, args...) - if err != nil { - if !NonFatalError(err) { - return nil, err - } - nonFatalErr = err - } - - // Determine where the results are: written to i, or returned in list - if t, _ := toSliceType(i); t == nil { - for _, v := range list { - if v, ok := v.(HasPostGet); ok { - err := v.PostGet(exec) - if err != nil { - return nil, err - } - } - } - } else { - resultsValue := reflect.Indirect(reflect.ValueOf(i)) - for i := 0; i < resultsValue.Len(); i++ { - if v, ok := resultsValue.Index(i).Interface().(HasPostGet); ok { - err := v.PostGet(exec) - if err != nil { - return nil, err - } - } - } - } - return list, nonFatalErr -} - -func rawselect(m *DbMap, exec SqlExecutor, i interface{}, query string, - args ...interface{}) ([]interface{}, error) { - var ( - appendToSlice = false // Write results to i directly? - intoStruct = true // Selecting into a struct? - pointerElements = true // Are the slice elements pointers (vs values)? - ) - - var nonFatalErr error - - tableName := "" - var dynObj DynamicTable - isDynamic := false - if dynObj, isDynamic = i.(DynamicTable); isDynamic { - tableName = dynObj.TableName() - } - - // get type for i, verifying it's a supported destination - t, err := toType(i) - if err != nil { - var err2 error - if t, err2 = toSliceType(i); t == nil { - if err2 != nil { - return nil, err2 - } - return nil, err - } - pointerElements = t.Kind() == reflect.Ptr - if pointerElements { - t = t.Elem() - } - appendToSlice = true - intoStruct = t.Kind() == reflect.Struct - } - - // If the caller supplied a single struct/map argument, assume a "named - // parameter" query. Extract the named arguments from the struct/map, create - // the flat arg slice, and rewrite the query to use the dialect's placeholder. - if len(args) == 1 { - query, args = maybeExpandNamedQuery(m, query, args) - } - - // Run the query - rows, err := exec.Query(query, args...) - if err != nil { - return nil, err - } - defer rows.Close() - - // Fetch the column names as returned from db - cols, err := rows.Columns() - if err != nil { - return nil, err - } - - if !intoStruct && len(cols) > 1 { - return nil, fmt.Errorf("sqldb: select into non-struct slice requires 1 column, got %d", len(cols)) - } - - var colToFieldIndex [][]int - if intoStruct { - colToFieldIndex, err = columnToFieldIndex(m, t, tableName, cols) - if err != nil { - if !NonFatalError(err) { - return nil, err - } - nonFatalErr = err - } - } - - conv := m.TypeConverter - - // Add results to one of these two slices. 
- var ( - list = make([]interface{}, 0) - sliceValue = reflect.Indirect(reflect.ValueOf(i)) - ) - - for { - if !rows.Next() { - // if error occured return rawselect - if rows.Err() != nil { - return nil, rows.Err() - } - // time to exit from outer "for" loop - break - } - v := reflect.New(t) - - if isDynamic { - v.Interface().(DynamicTable).SetTableName(tableName) - } - - dest := make([]interface{}, len(cols)) - - custScan := make([]CustomScanner, 0) - - for x := range cols { - f := v.Elem() - if intoStruct { - index := colToFieldIndex[x] - if index == nil { - // this field is not present in the struct, so create a dummy - // value for rows.Scan to scan into - var dummy dummyField - dest[x] = &dummy - continue - } - f = f.FieldByIndex(index) - } - target := f.Addr().Interface() - if conv != nil { - scanner, ok := conv.FromDb(target) - if ok { - target = scanner.Holder - custScan = append(custScan, scanner) - } - } - dest[x] = target - } - - err = rows.Scan(dest...) - if err != nil { - return nil, err - } - - for _, c := range custScan { - err = c.Bind() - if err != nil { - return nil, err - } - } - - if appendToSlice { - if !pointerElements { - v = v.Elem() - } - sliceValue.Set(reflect.Append(sliceValue, v)) - } else { - list = append(list, v.Interface()) - } - } - - if appendToSlice && sliceValue.IsNil() { - sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), 0, 0)) - } - - return list, nonFatalErr -} diff --git a/gdb/sqldb/sqldb.go b/gdb/sqldb/sqldb.go deleted file mode 100644 index bae8e17..0000000 --- a/gdb/sqldb/sqldb.go +++ /dev/null @@ -1,675 +0,0 @@ -// -// sqldb.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "context" - "database/sql" - "database/sql/driver" - "fmt" - "reflect" - "regexp" - "strings" - "time" -) - -// OracleString (empty string is null) -// TODO: move to dialect/oracle?, rename to String? -type OracleString struct { - sql.NullString -} - -// Scan implements the Scanner interface. -func (os *OracleString) Scan(value interface{}) error { - if value == nil { - os.String, os.Valid = "", false - return nil - } - os.Valid = true - return os.NullString.Scan(value) -} - -// Value implements the driver Valuer interface. -func (os OracleString) Value() (driver.Value, error) { - if !os.Valid || os.String == "" { - return nil, nil - } - return os.String, nil -} - -// SqlTyper is a type that returns its database type. Most of the -// time, the type can just use "database/sql/driver".Valuer; but when -// it returns nil for its empty value, it needs to implement SqlTyper -// to have its column type detected properly during table creation. -type SqlTyper interface { - SqlType() driver.Value -} - -// legacySqlTyper prevents breaking clients who depended on the previous -// SqlTyper interface -type legacySqlTyper interface { - SqlType() driver.Valuer -} - -// for fields that exists in DB table, but not exists in struct -type dummyField struct{} - -// Scan implements the Scanner interface. -func (nt *dummyField) Scan(value interface{}) error { - return nil -} - -var zeroVal reflect.Value -var versFieldConst = "[sqldb_ver_field]" - -// The TypeConverter interface provides a way to map a value of one -// type to another type when persisting to, or loading from, a database. -// -// Example use cases: Implement type converter to convert bool types to "y"/"n" strings, -// or serialize a struct member as a JSON blob. -type TypeConverter interface { - // ToDb converts val to another type. 
Called before INSERT/UPDATE operations - ToDb(val interface{}) (interface{}, error) - - // FromDb returns a CustomScanner appropriate for this type. This will be used - // to hold values returned from SELECT queries. - // - // In particular the CustomScanner returned should implement a Binder - // function appropriate for the Go type you wish to convert the db value to - // - // If bool==false, then no custom scanner will be used for this field. - FromDb(target interface{}) (CustomScanner, bool) -} - -// SqlExecutor exposes sqldb operations that can be run from Pre/Post -// hooks. This hides whether the current operation that triggered the -// hook is in a transaction. -// -// See the DbMap function docs for each of the functions below for more -// information. -type SqlExecutor interface { - WithContext(ctx context.Context) SqlExecutor - Get(i interface{}, keys ...interface{}) (interface{}, error) - Insert(list ...interface{}) error - Update(list ...interface{}) (int64, error) - Delete(list ...interface{}) (int64, error) - Exec(query string, args ...interface{}) (sql.Result, error) - Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) - SelectInt(query string, args ...interface{}) (int64, error) - SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) - SelectFloat(query string, args ...interface{}) (float64, error) - SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) - SelectStr(query string, args ...interface{}) (string, error) - SelectNullStr(query string, args ...interface{}) (sql.NullString, error) - SelectOne(holder interface{}, query string, args ...interface{}) error - Query(query string, args ...interface{}) (*sql.Rows, error) - QueryRow(query string, args ...interface{}) *sql.Row -} - -// DynamicTable allows the users of sqldb to dynamically -// use different database table names during runtime -// while sharing the same golang struct for in-memory data -type DynamicTable interface { - TableName() string - SetTableName(string) -} - -// Compile-time check that DbMap and Transaction implement the SqlExecutor -// interface. -var _, _ SqlExecutor = &DbMap{}, &Transaction{} - -func argValue(a interface{}) interface{} { - v, ok := a.(driver.Valuer) - if !ok { - return a - } - vV := reflect.ValueOf(v) - if vV.Kind() == reflect.Ptr && vV.IsNil() { - return nil - } - ret, err := v.Value() - if err != nil { - return a - } - return ret -} - -func argsString(args ...interface{}) string { - var margs string - for i, a := range args { - v := argValue(a) - switch v.(type) { - case string: - v = fmt.Sprintf("%q", v) - default: - v = fmt.Sprintf("%v", v) - } - margs += fmt.Sprintf("%d:%s", i+1, v) - if i+1 < len(args) { - margs += " " - } - } - return margs -} - -// Calls the Exec function on the executor, but attempts to expand any eligible named -// query arguments first. -func maybeExpandNamedQueryAndExec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) { - dbMap := extractDbMap(e) - - if len(args) == 1 { - query, args = maybeExpandNamedQuery(dbMap, query, args) - } - - return exec(e, query, args...) -} - -func extractDbMap(e SqlExecutor) *DbMap { - switch m := e.(type) { - case *DbMap: - return m - case *Transaction: - return m.dbmap - } - return nil -} - -// executor exposes the sql.DB and sql.Tx functions so that it can be used -// on internal functions that need to be agnostic to the underlying object. 
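Because both *DbMap and *Transaction satisfy SqlExecutor (see the compile-time check above), data-access helpers can be written once and run either directly against the DbMap or inside a transaction; a brief sketch with a hypothetical helper:

// insertBoth behaves identically whether exec is a *DbMap or a *Transaction.
func insertBoth(exec SqlExecutor, a, b interface{}) error {
	return exec.Insert(a, b)
}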
-type executor interface { - Exec(query string, args ...interface{}) (sql.Result, error) - Prepare(query string) (*sql.Stmt, error) - QueryRow(query string, args ...interface{}) *sql.Row - Query(query string, args ...interface{}) (*sql.Rows, error) - ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) - PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) - QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row - QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) -} - -func extractExecutorAndContext(e SqlExecutor) (executor, context.Context) { - switch m := e.(type) { - case *DbMap: - return m.Db, m.ctx - case *Transaction: - return m.tx, m.ctx - } - return nil, nil -} - -// maybeExpandNamedQuery checks the given arg to see if it's eligible to be used -// as input to a named query. If so, it rewrites the query to use -// dialect-dependent bindvars and instantiates the corresponding slice of -// parameters by extracting data from the map / struct. -// If not, returns the input values unchanged. -func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string, []interface{}) { - var ( - arg = args[0] - argval = reflect.ValueOf(arg) - ) - if argval.Kind() == reflect.Ptr { - argval = argval.Elem() - } - - if argval.Kind() == reflect.Map && argval.Type().Key().Kind() == reflect.String { - return expandNamedQuery(m, query, func(key string) reflect.Value { - return argval.MapIndex(reflect.ValueOf(key)) - }) - } - if argval.Kind() != reflect.Struct { - return query, args - } - if _, ok := arg.(time.Time); ok { - // time.Time is driver.Value - return query, args - } - if _, ok := arg.(driver.Valuer); ok { - // driver.Valuer will be converted to driver.Value. - return query, args - } - - return expandNamedQuery(m, query, argval.FieldByName) -} - -var keyRegexp = regexp.MustCompile(`:[[:word:]]+`) - -// expandNamedQuery accepts a query with placeholders of the form ":key", and a -// single arg of Kind Struct or Map[string]. It returns the query with the -// dialect's placeholders, and a slice of args ready for positional insertion -// into the query. -func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect.Value) (string, []interface{}) { - var ( - n int - args []interface{} - ) - return keyRegexp.ReplaceAllStringFunc(query, func(key string) string { - val := keyGetter(key[1:]) - if !val.IsValid() { - return key - } - args = append(args, val.Interface()) - newVar := m.Dialect.BindVar(n) - n++ - return newVar - }), args -} - -func columnToFieldIndex(m *DbMap, t reflect.Type, name string, cols []string) ([][]int, error) { - colToFieldIndex := make([][]int, len(cols)) - - // check if type t is a mapped table - if so we'll - // check the table for column aliasing below - tableMapped := false - table := tableOrNil(m, t, name) - if table != nil { - tableMapped = true - } - - // Loop over column names and find field in i to bind to - // based on column name. 
all returned columns must match - // a field in the i struct - missingColNames := []string{} - for x := range cols { - colName := strings.ToLower(cols[x]) - field, found := t.FieldByNameFunc(func(fieldName string) bool { - field, _ := t.FieldByName(fieldName) - cArguments := strings.Split(field.Tag.Get("db"), ",") - fieldName = cArguments[0] - - if fieldName == "-" { - return false - } else if fieldName == "" { - fieldName = field.Name - } - if tableMapped { - colMap := colMapOrNil(table, fieldName) - if colMap != nil { - fieldName = colMap.ColumnName - } - } - return colName == strings.ToLower(fieldName) - }) - if found { - colToFieldIndex[x] = field.Index - } - if colToFieldIndex[x] == nil { - missingColNames = append(missingColNames, colName) - } - } - if len(missingColNames) > 0 { - return colToFieldIndex, &NoFieldInTypeError{ - TypeName: t.Name(), - MissingColNames: missingColNames, - } - } - return colToFieldIndex, nil -} - -func fieldByName(val reflect.Value, fieldName string) *reflect.Value { - // try to find field by exact match - f := val.FieldByName(fieldName) - - if f != zeroVal { - return &f - } - - // try to find by case insensitive match - only the Postgres driver - // seems to require this - in the case where columns are aliased in the sql - fieldNameL := strings.ToLower(fieldName) - fieldCount := val.NumField() - t := val.Type() - for i := 0; i < fieldCount; i++ { - sf := t.Field(i) - if strings.ToLower(sf.Name) == fieldNameL { - f := val.Field(i) - return &f - } - } - - return nil -} - -// toSliceType returns the element type of the given object, if the object is a -// "*[]*Element" or "*[]Element". If not, returns nil. -// err is returned if the user was trying to pass a pointer-to-slice but failed. -func toSliceType(i interface{}) (reflect.Type, error) { - t := reflect.TypeOf(i) - if t.Kind() != reflect.Ptr { - // If it's a slice, return a more helpful error message - if t.Kind() == reflect.Slice { - return nil, fmt.Errorf("sqldb: cannot SELECT into a non-pointer slice: %v", t) - } - return nil, nil - } - if t = t.Elem(); t.Kind() != reflect.Slice { - return nil, nil - } - return t.Elem(), nil -} - -func toType(i interface{}) (reflect.Type, error) { - t := reflect.TypeOf(i) - - // If a Pointer to a type, follow - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - - if t.Kind() != reflect.Struct { - return nil, fmt.Errorf("sqldb: cannot SELECT into this type: %v", reflect.TypeOf(i)) - } - return t, nil -} - -type foundTable struct { - table *TableMap - dynName *string -} - -func tableFor(m *DbMap, t reflect.Type, i interface{}) (*foundTable, error) { - if dyn, isDynamic := i.(DynamicTable); isDynamic { - tableName := dyn.TableName() - table, err := m.DynamicTableFor(tableName, true) - if err != nil { - return nil, err - } - return &foundTable{ - table: table, - dynName: &tableName, - }, nil - } - table, err := m.TableFor(t, true) - if err != nil { - return nil, err - } - return &foundTable{table: table}, nil -} - -func get(m *DbMap, exec SqlExecutor, i interface{}, - keys ...interface{}) (interface{}, error) { - - t, err := toType(i) - if err != nil { - return nil, err - } - - foundTable, err := tableFor(m, t, i) - if err != nil { - return nil, err - } - table := foundTable.table - - plan := table.bindGet() - - v := reflect.New(t) - if foundTable.dynName != nil { - retDyn := v.Interface().(DynamicTable) - retDyn.SetTableName(*foundTable.dynName) - } - - dest := make([]interface{}, len(plan.argFields)) - - conv := m.TypeConverter - custScan := make([]CustomScanner, 0) - - 
for x, fieldName := range plan.argFields { - f := v.Elem().FieldByName(fieldName) - target := f.Addr().Interface() - if conv != nil { - scanner, ok := conv.FromDb(target) - if ok { - target = scanner.Holder - custScan = append(custScan, scanner) - } - } - dest[x] = target - } - - row := exec.QueryRow(plan.query, keys...) - err = row.Scan(dest...) - if err != nil { - if err == sql.ErrNoRows { - err = nil - } - return nil, err - } - - for _, c := range custScan { - err = c.Bind() - if err != nil { - return nil, err - } - } - - if v, ok := v.Interface().(HasPostGet); ok { - err := v.PostGet(exec) - if err != nil { - return nil, err - } - } - - return v.Interface(), nil -} - -func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { - count := int64(0) - for _, ptr := range list { - table, elem, err := m.tableForPointer(ptr, true) - if err != nil { - return -1, err - } - - eval := elem.Addr().Interface() - if v, ok := eval.(HasPreDelete); ok { - err = v.PreDelete(exec) - if err != nil { - return -1, err - } - } - - bi, err := table.bindDelete(elem) - if err != nil { - return -1, err - } - - res, err := exec.Exec(bi.query, bi.args...) - if err != nil { - return -1, err - } - rows, err := res.RowsAffected() - if err != nil { - return -1, err - } - - if rows == 0 && bi.existingVersion > 0 { - return lockError(m, exec, table.TableName, - bi.existingVersion, elem, bi.keys...) - } - - count += rows - - if v, ok := eval.(HasPostDelete); ok { - err := v.PostDelete(exec) - if err != nil { - return -1, err - } - } - } - - return count, nil -} - -func update(m *DbMap, exec SqlExecutor, colFilter ColumnFilter, list ...interface{}) (int64, error) { - count := int64(0) - for _, ptr := range list { - table, elem, err := m.tableForPointer(ptr, true) - if err != nil { - return -1, err - } - - eval := elem.Addr().Interface() - if v, ok := eval.(HasPreUpdate); ok { - err = v.PreUpdate(exec) - if err != nil { - return -1, err - } - } - - bi, err := table.bindUpdate(elem, colFilter) - if err != nil { - return -1, err - } - - res, err := exec.Exec(bi.query, bi.args...) - if err != nil { - return -1, err - } - - rows, err := res.RowsAffected() - if err != nil { - return -1, err - } - - if rows == 0 && bi.existingVersion > 0 { - return lockError(m, exec, table.TableName, - bi.existingVersion, elem, bi.keys...) - } - - if bi.versField != "" { - elem.FieldByName(bi.versField).SetInt(bi.existingVersion + 1) - } - - count += rows - - if v, ok := eval.(HasPostUpdate); ok { - err = v.PostUpdate(exec) - if err != nil { - return -1, err - } - } - } - return count, nil -} - -func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { - for _, ptr := range list { - table, elem, err := m.tableForPointer(ptr, false) - if err != nil { - return err - } - - eval := elem.Addr().Interface() - if v, ok := eval.(HasPreInsert); ok { - err := v.PreInsert(exec) - if err != nil { - return err - } - } - - bi, err := table.bindInsert(elem) - if err != nil { - return err - } - - if bi.autoIncrIdx > -1 { - f := elem.FieldByName(bi.autoIncrFieldName) - switch inserter := m.Dialect.(type) { - case IntegerAutoIncrInserter: - id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...) 
- if err != nil { - return err - } - k := f.Kind() - if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) { - f.SetInt(id) - } else if (k == reflect.Uint) || (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) { - f.SetUint(uint64(id)) - } else { - return fmt.Errorf("sqldb: cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName) - } - case TargetedAutoIncrInserter: - err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...) - if err != nil { - return err - } - case TargetQueryInserter: - var idQuery = table.ColMap(bi.autoIncrFieldName).GeneratedIdQuery - if idQuery == "" { - return fmt.Errorf("sqldb: cannot set %s value if its ColumnMap.GeneratedIdQuery is empty", bi.autoIncrFieldName) - } - err := inserter.InsertQueryToTarget(exec, bi.query, idQuery, f.Addr().Interface(), bi.args...) - if err != nil { - return err - } - default: - return fmt.Errorf("sqldb: cannot use autoincrement fields on dialects that do not implement an autoincrementing interface") - } - } else { - _, err := exec.Exec(bi.query, bi.args...) - if err != nil { - return err - } - } - - if v, ok := eval.(HasPostInsert); ok { - err := v.PostInsert(exec) - if err != nil { - return err - } - } - } - return nil -} - -func exec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) { - executor, ctx := extractExecutorAndContext(e) - - if ctx != nil { - return executor.ExecContext(ctx, query, args...) - } - - return executor.Exec(query, args...) -} - -func prepare(e SqlExecutor, query string) (*sql.Stmt, error) { - executor, ctx := extractExecutorAndContext(e) - - if ctx != nil { - return executor.PrepareContext(ctx, query) - } - - return executor.Prepare(query) -} - -func queryRow(e SqlExecutor, query string, args ...interface{}) *sql.Row { - executor, ctx := extractExecutorAndContext(e) - - if ctx != nil { - return executor.QueryRowContext(ctx, query, args...) - } - - return executor.QueryRow(query, args...) -} - -func query(e SqlExecutor, query string, args ...interface{}) (*sql.Rows, error) { - executor, ctx := extractExecutorAndContext(e) - - if ctx != nil { - return executor.QueryContext(ctx, query, args...) - } - - return executor.Query(query, args...) -} - -func begin(m *DbMap) (*sql.Tx, error) { - if m.ctx != nil { - return m.Db.BeginTx(m.ctx, nil) - } - - return m.Db.Begin() -} diff --git a/gdb/sqldb/sqldb_test.go b/gdb/sqldb/sqldb_test.go deleted file mode 100644 index 4d390e1..0000000 --- a/gdb/sqldb/sqldb_test.go +++ /dev/null @@ -1,2875 +0,0 @@ -// -// sqldb_test.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. 
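The exec/prepare/queryRow/query wrappers above switch to the ExecContext/QueryContext variants whenever a context has been attached, so cancellation and timeouts are opt-in through WithContext; a small sketch (the dbmap value and the person_test table are assumed from the surrounding tests):

import (
	"context"
	"time"
)

func countPeople(ctx context.Context, dbmap *DbMap) (int64, error) {
	ctx, cancel := context.WithTimeout(ctx, 3*time.Second)
	defer cancel()
	// The returned executor routes through the *Context database/sql calls
	// and is cancelled when ctx expires.
	return dbmap.WithContext(ctx).SelectInt("select count(*) from person_test")
}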
-// - -//go:build integration -// +build integration - -package sqldb_test - -import ( - "bytes" - "database/sql" - "database/sql/driver" - "encoding/json" - "errors" - "flag" - "fmt" - "log" - "math/rand" - "os" - "reflect" - "strconv" - "strings" - "testing" - "time" - - "git.hexq.cn/tiglog/golib/gdb/sqldb" - _ "github.com/go-sql-driver/mysql" - _ "github.com/lib/pq" - _ "github.com/mattn/go-sqlite3" -) - -var ( - // verify interface compliance - _ = []sqldb.Dialect{ - sqldb.SqliteDialect{}, - sqldb.PostgresDialect{}, - sqldb.MySQLDialect{}, - sqldb.OracleDialect{}, - } - - debug bool -) - -func TestMain(m *testing.M) { - flag.BoolVar(&debug, "trace", true, "Turn on or off database tracing (DbMap.TraceOn)") - flag.Parse() - os.Exit(m.Run()) -} - -type testable interface { - GetId() int64 - Rand() -} - -type Invoice struct { - Id int64 - Created int64 - Updated int64 - Memo string - PersonId int64 - IsPaid bool -} - -type InvoiceWithValuer struct { - Id int64 - Created int64 - Updated int64 - Memo string - Person PersonValuerScanner `db:"personid"` - IsPaid bool -} - -func (me *Invoice) GetId() int64 { return me.Id } -func (me *Invoice) Rand() { - me.Memo = fmt.Sprintf("random %d", rand.Int63()) - me.Created = rand.Int63() - me.Updated = rand.Int63() -} - -type InvoiceTag struct { - Id int64 `db:"myid, primarykey, autoincrement"` - Created int64 `db:"myCreated"` - Updated int64 `db:"date_updated"` - Memo string - PersonId int64 `db:"person_id"` - IsPaid bool `db:"is_Paid"` -} - -func (me *InvoiceTag) GetId() int64 { return me.Id } -func (me *InvoiceTag) Rand() { - me.Memo = fmt.Sprintf("random %d", rand.Int63()) - me.Created = rand.Int63() - me.Updated = rand.Int63() -} - -// See: https://github.com/go-sqldb/sqldb/issues/175 -type AliasTransientField struct { - Id int64 `db:"id"` - Bar int64 `db:"-"` - BarStr string `db:"bar"` -} - -func (me *AliasTransientField) GetId() int64 { return me.Id } -func (me *AliasTransientField) Rand() { - me.BarStr = fmt.Sprintf("random %d", rand.Int63()) -} - -type OverriddenInvoice struct { - Invoice - Id string -} - -type Person struct { - Id int64 - Created int64 - Updated int64 - FName string - LName string - Version int64 -} - -// PersonValuerScanner is used as a field in test types to ensure that we -// make use of "database/sql/driver".Valuer for choosing column types when -// creating tables and that we don't get in the way of the underlying -// database libraries when they make use of either Valuer or -// "database/sql".Scanner. -type PersonValuerScanner struct { - Person -} - -// Value implements "database/sql/driver".Valuer. It will be automatically -// run by the "database/sql" package when inserting/updating data. -func (p PersonValuerScanner) Value() (driver.Value, error) { - return p.Id, nil -} - -// Scan implements "database/sql".Scanner. It will be automatically run -// by the "database/sql" package when reading column data into a field -// of type PersonValuerScanner. -func (p *PersonValuerScanner) Scan(value interface{}) (err error) { - switch src := value.(type) { - case []byte: - // TODO: this case is here for mysql only. For some reason, - // one (both?) of the mysql libraries opt to pass us a []byte - // instead of an int64 for the bigint column. We should add - // table tests around valuers/scanners and try to solve these - // types of odd discrepencies to make it easier for users of - // sqldb to migrate to other database engines. 
- p.Id, err = strconv.ParseInt(string(src), 10, 64) - case int64: - // Most libraries pass in the type we'd expect. - p.Id = src - default: - typ := reflect.TypeOf(value) - return fmt.Errorf("Expected person value to be convertible to int64, got %v (type %s)", value, typ) - } - return -} - -type FNameOnly struct { - FName string -} - -type InvoicePersonView struct { - InvoiceId int64 - PersonId int64 - Memo string - FName string - LegacyVersion int64 -} - -type TableWithNull struct { - Id int64 - Str sql.NullString - Int64 sql.NullInt64 - Float64 sql.NullFloat64 - Bool sql.NullBool - Bytes []byte -} - -type WithIgnoredColumn struct { - internal int64 `db:"-"` - Id int64 - Created int64 -} - -type IdCreated struct { - Id int64 - Created int64 -} - -type IdCreatedExternal struct { - IdCreated - External int64 -} - -type WithStringPk struct { - Id string - Name string -} - -type CustomStringType string - -type TypeConversionExample struct { - Id int64 - PersonJSON Person - Name CustomStringType -} - -type PersonUInt32 struct { - Id uint32 - Name string -} - -type PersonUInt64 struct { - Id uint64 - Name string -} - -type PersonUInt16 struct { - Id uint16 - Name string -} - -type WithEmbeddedStruct struct { - Id int64 - Names -} - -type WithEmbeddedStructConflictingEmbeddedMemberNames struct { - Id int64 - Names - NamesConflict -} - -type WithEmbeddedStructSameMemberName struct { - Id int64 - SameName -} - -type WithEmbeddedStructBeforeAutoincrField struct { - Names - Id int64 -} - -type WithEmbeddedAutoincr struct { - WithEmbeddedStruct - MiddleName string -} - -type Names struct { - FirstName string - LastName string -} - -type NamesConflict struct { - FirstName string - Surname string -} - -type SameName struct { - SameName string -} - -type UniqueColumns struct { - FirstName string - LastName string - City string - ZipCode int64 -} - -type SingleColumnTable struct { - SomeId string -} - -type CustomDate struct { - time.Time -} - -type WithCustomDate struct { - Id int64 - Added CustomDate -} - -type WithNullTime struct { - Id int64 - Time sqldb.NullTime -} - -type testTypeConverter struct{} - -func (me testTypeConverter) ToDb(val interface{}) (interface{}, error) { - - switch t := val.(type) { - case Person: - b, err := json.Marshal(t) - if err != nil { - return "", err - } - return string(b), nil - case CustomStringType: - return string(t), nil - case CustomDate: - return t.Time, nil - } - - return val, nil -} - -func (me testTypeConverter) FromDb(target interface{}) (sqldb.CustomScanner, bool) { - switch target.(type) { - case *Person: - binder := func(holder, target interface{}) error { - s, ok := holder.(*string) - if !ok { - return errors.New("FromDb: Unable to convert Person to *string") - } - b := []byte(*s) - return json.Unmarshal(b, target) - } - return sqldb.CustomScanner{new(string), target, binder}, true - case *CustomStringType: - binder := func(holder, target interface{}) error { - s, ok := holder.(*string) - if !ok { - return errors.New("FromDb: Unable to convert CustomStringType to *string") - } - st, ok := target.(*CustomStringType) - if !ok { - return errors.New(fmt.Sprint("FromDb: Unable to convert target to *CustomStringType: ", reflect.TypeOf(target))) - } - *st = CustomStringType(*s) - return nil - } - return sqldb.CustomScanner{new(string), target, binder}, true - case *CustomDate: - binder := func(holder, target interface{}) error { - t, ok := holder.(*time.Time) - if !ok { - return errors.New("FromDb: Unable to convert CustomDate to *time.Time") - } - dateTarget, ok 
:= target.(*CustomDate) - if !ok { - return errors.New(fmt.Sprint("FromDb: Unable to convert target to *CustomDate: ", reflect.TypeOf(target))) - } - dateTarget.Time = *t - return nil - } - return sqldb.CustomScanner{new(time.Time), target, binder}, true - } - - return sqldb.CustomScanner{}, false -} - -func (p *Person) PreInsert(s sqldb.SqlExecutor) error { - p.Created = time.Now().UnixNano() - p.Updated = p.Created - if p.FName == "badname" { - return fmt.Errorf("Invalid name: %s", p.FName) - } - return nil -} - -func (p *Person) PostInsert(s sqldb.SqlExecutor) error { - p.LName = "postinsert" - return nil -} - -func (p *Person) PreUpdate(s sqldb.SqlExecutor) error { - p.FName = "preupdate" - return nil -} - -func (p *Person) PostUpdate(s sqldb.SqlExecutor) error { - p.LName = "postupdate" - return nil -} - -func (p *Person) PreDelete(s sqldb.SqlExecutor) error { - p.FName = "predelete" - return nil -} - -func (p *Person) PostDelete(s sqldb.SqlExecutor) error { - p.LName = "postdelete" - return nil -} - -func (p *Person) PostGet(s sqldb.SqlExecutor) error { - p.LName = "postget" - return nil -} - -type PersistentUser struct { - Key int32 - Id string - PassedTraining bool -} - -type TenantDynamic struct { - Id int64 `db:"id"` - Name string - Address string - curTable string `db:"-"` -} - -func (curObj *TenantDynamic) TableName() string { - return curObj.curTable -} -func (curObj *TenantDynamic) SetTableName(tblName string) { - curObj.curTable = tblName -} - -var dynTableInst1 = TenantDynamic{curTable: "t_1_tenant_dynamic"} -var dynTableInst2 = TenantDynamic{curTable: "t_2_tenant_dynamic"} - -func dynamicTablesTest(t *testing.T, dbmap *sqldb.DbMap) { - - dynamicTablesTestTableMap(t, dbmap, &dynTableInst1) - dynamicTablesTestTableMap(t, dbmap, &dynTableInst2) - - // TEST - dbmap.Insert using dynTableInst1 - dynTableInst1.Name = "Test Name 1" - dynTableInst1.Address = "Test Address 1" - err := dbmap.Insert(&dynTableInst1) - if err != nil { - t.Errorf("Errow while saving dynTableInst1. Details: %v", err) - } - - // TEST - dbmap.Insert using dynTableInst2 - dynTableInst2.Name = "Test Name 2" - dynTableInst2.Address = "Test Address 2" - err = dbmap.Insert(&dynTableInst2) - if err != nil { - t.Errorf("Errow while saving dynTableInst2. 
Details: %v", err) - } - - dynamicTablesTestSelect(t, dbmap, &dynTableInst1) - dynamicTablesTestSelect(t, dbmap, &dynTableInst2) - dynamicTablesTestSelectOne(t, dbmap, &dynTableInst1) - dynamicTablesTestSelectOne(t, dbmap, &dynTableInst2) - dynamicTablesTestGetUpdateGet(t, dbmap, &dynTableInst1) - dynamicTablesTestGetUpdateGet(t, dbmap, &dynTableInst2) - dynamicTablesTestDelete(t, dbmap, &dynTableInst1) - dynamicTablesTestDelete(t, dbmap, &dynTableInst2) - -} - -func dynamicTablesTestTableMap(t *testing.T, - dbmap *sqldb.DbMap, - inpInst *TenantDynamic) { - - tableName := inpInst.TableName() - - tblMap, err := dbmap.DynamicTableFor(tableName, true) - if err != nil { - t.Errorf("Error while searching for tablemap for tableName: %v, Error:%v", tableName, err) - } - if tblMap == nil { - t.Errorf("Unable to find tablemap for tableName:%v", tableName) - } -} - -func dynamicTablesTestSelect(t *testing.T, - dbmap *sqldb.DbMap, - inpInst *TenantDynamic) { - - // TEST - dbmap.Select using inpInst - - // read the data back from dynInst to see if the - // table mapping is correct - var dbTenantInst1 = TenantDynamic{curTable: inpInst.curTable} - selectSQL1 := "select * from " + inpInst.curTable - dbObjs, err := dbmap.Select(&dbTenantInst1, selectSQL1) - if err != nil { - t.Errorf("Errow in dbmap.Select. SQL: %v, Details: %v", selectSQL1, err) - } - if dbObjs == nil { - t.Fatalf("Nil return from dbmap.Select") - } - rwCnt := len(dbObjs) - if rwCnt != 1 { - t.Errorf("Unexpected row count for tenantInst:%v", rwCnt) - } - - dbInst := dbObjs[0].(*TenantDynamic) - - inpTableName := inpInst.TableName() - resTableName := dbInst.TableName() - if inpTableName != resTableName { - t.Errorf("Mismatched table names %v != %v ", - inpTableName, resTableName) - } - - if inpInst.Id != dbInst.Id { - t.Errorf("Mismatched Id values %v != %v ", - inpInst.Id, dbInst.Id) - } - - if inpInst.Name != dbInst.Name { - t.Errorf("Mismatched Name values %v != %v ", - inpInst.Name, dbInst.Name) - } - - if inpInst.Address != dbInst.Address { - t.Errorf("Mismatched Address values %v != %v ", - inpInst.Address, dbInst.Address) - } -} - -func dynamicTablesTestGetUpdateGet(t *testing.T, - dbmap *sqldb.DbMap, - inpInst *TenantDynamic) { - - // TEST - dbmap.Get, dbmap.Update, dbmap.Get sequence - - // read and update one of the instances to make sure - // that the common sqldb APIs are working well with dynamic table - var inpIface2 = TenantDynamic{curTable: inpInst.curTable} - dbObj, err := dbmap.Get(&inpIface2, inpInst.Id) - if err != nil { - t.Errorf("Errow in dbmap.Get. 
id: %v, Details: %v", inpInst.Id, err) - } - if dbObj == nil { - t.Errorf("Nil return from dbmap.Get") - } - - dbInst := dbObj.(*TenantDynamic) - - { - inpTableName := inpInst.TableName() - resTableName := dbInst.TableName() - if inpTableName != resTableName { - t.Errorf("Mismatched table names %v != %v ", - inpTableName, resTableName) - } - - if inpInst.Id != dbInst.Id { - t.Errorf("Mismatched Id values %v != %v ", - inpInst.Id, dbInst.Id) - } - - if inpInst.Name != dbInst.Name { - t.Errorf("Mismatched Name values %v != %v ", - inpInst.Name, dbInst.Name) - } - - if inpInst.Address != dbInst.Address { - t.Errorf("Mismatched Address values %v != %v ", - inpInst.Address, dbInst.Address) - } - } - - { - updatedName := "Testing Updated Name2" - dbInst.Name = updatedName - cnt, err := dbmap.Update(dbInst) - if err != nil { - t.Errorf("Error from dbmap.Update: %v", err.Error()) - } - if cnt != 1 { - t.Errorf("Update count must be 1, got %v", cnt) - } - - // Read the object again to make sure that the - // data was updated in db - dbObj2, err := dbmap.Get(&inpIface2, inpInst.Id) - if err != nil { - t.Errorf("Errow in dbmap.Get. id: %v, Details: %v", inpInst.Id, err) - } - if dbObj2 == nil { - t.Errorf("Nil return from dbmap.Get") - } - - dbInst2 := dbObj2.(*TenantDynamic) - - inpTableName := inpInst.TableName() - resTableName := dbInst2.TableName() - if inpTableName != resTableName { - t.Errorf("Mismatched table names %v != %v ", - inpTableName, resTableName) - } - - if inpInst.Id != dbInst2.Id { - t.Errorf("Mismatched Id values %v != %v ", - inpInst.Id, dbInst2.Id) - } - - if updatedName != dbInst2.Name { - t.Errorf("Mismatched Name values %v != %v ", - updatedName, dbInst2.Name) - } - - if inpInst.Address != dbInst.Address { - t.Errorf("Mismatched Address values %v != %v ", - inpInst.Address, dbInst.Address) - } - - } -} - -func dynamicTablesTestSelectOne(t *testing.T, - dbmap *sqldb.DbMap, - inpInst *TenantDynamic) { - - // TEST - dbmap.SelectOne - - // read the data back from inpInst to see if the - // table mapping is correct - var dbTenantInst1 = TenantDynamic{curTable: inpInst.curTable} - selectSQL1 := "select * from " + dbTenantInst1.curTable + " where id = :idKey" - params := map[string]interface{}{"idKey": inpInst.Id} - err := dbmap.SelectOne(&dbTenantInst1, selectSQL1, params) - if err != nil { - t.Errorf("Errow in dbmap.SelectOne. SQL: %v, Details: %v", selectSQL1, err) - } - - inpTableName := inpInst.curTable - resTableName := dbTenantInst1.TableName() - if inpTableName != resTableName { - t.Errorf("Mismatched table names %v != %v ", - inpTableName, resTableName) - } - - if inpInst.Id != dbTenantInst1.Id { - t.Errorf("Mismatched Id values %v != %v ", - inpInst.Id, dbTenantInst1.Id) - } - - if inpInst.Name != dbTenantInst1.Name { - t.Errorf("Mismatched Name values %v != %v ", - inpInst.Name, dbTenantInst1.Name) - } - - if inpInst.Address != dbTenantInst1.Address { - t.Errorf("Mismatched Address values %v != %v ", - inpInst.Address, dbTenantInst1.Address) - } -} - -func dynamicTablesTestDelete(t *testing.T, - dbmap *sqldb.DbMap, - inpInst *TenantDynamic) { - - // TEST - dbmap.Delete - cnt, err := dbmap.Delete(inpInst) - if err != nil { - t.Errorf("Errow in dbmap.Delete. 
Details: %v", err) - } - if cnt != 1 { - t.Errorf("Expected delete count for %v : 1, found count:%v", - inpInst.TableName(), cnt) - } - - // Try reading again to make sure instance is gone from db - getInst := TenantDynamic{curTable: inpInst.TableName()} - dbInst, err := dbmap.Get(&getInst, inpInst.Id) - if err != nil { - t.Errorf("Error while trying to read deleted %v object using id: %v", - inpInst.TableName(), inpInst.Id) - } - - if dbInst != nil { - t.Errorf("Found deleted %v instance using id: %v", - inpInst.TableName(), inpInst.Id) - } - - if getInst.Name != "" { - t.Errorf("Found data from deleted %v instance using id: %v", - inpInst.TableName(), inpInst.Id) - } - -} - -func TestCreateTablesIfNotExists(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - err := dbmap.CreateTablesIfNotExists() - if err != nil { - t.Error(err) - } -} - -func TestTruncateTables(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - err := dbmap.CreateTablesIfNotExists() - if err != nil { - t.Error(err) - } - - // Insert some data - p1 := &Person{0, 0, 0, "Bob", "Smith", 0} - dbmap.Insert(p1) - inv := &Invoice{0, 0, 1, "my invoice", 0, true} - dbmap.Insert(inv) - - err = dbmap.TruncateTables() - if err != nil { - t.Error(err) - } - - // Make sure all rows are deleted - rows, _ := dbmap.Select(Person{}, "SELECT * FROM person_test") - if len(rows) != 0 { - t.Errorf("Expected 0 person rows, got %d", len(rows)) - } - rows, _ = dbmap.Select(Invoice{}, "SELECT * FROM invoice_test") - if len(rows) != 0 { - t.Errorf("Expected 0 invoice rows, got %d", len(rows)) - } -} - -func TestCustomDateType(t *testing.T) { - dbmap := newDBMap(t) - dbmap.TypeConverter = testTypeConverter{} - dbmap.AddTable(WithCustomDate{}).SetKeys(true, "Id") - err := dbmap.CreateTables() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - - test1 := &WithCustomDate{Added: CustomDate{Time: time.Now().Truncate(time.Second)}} - err = dbmap.Insert(test1) - if err != nil { - t.Errorf("Could not insert struct with custom date field: %s", err) - t.FailNow() - } - // Unfortunately, the mysql driver doesn't handle time.Time - // values properly during Get(). I can't find a way to work - // around that problem - every other type that I've tried is just - // silently converted. time.Time is the only type that causes - // the issue that this test checks for. As such, if the driver is - // mysql, we'll just skip the rest of this test. 
- if _, driver := dialectAndDriver(); driver == "mysql" { - t.Skip("TestCustomDateType can't run Get() with the mysql driver; skipping the rest of this test...") - } - result, err := dbmap.Get(new(WithCustomDate), test1.Id) - if err != nil { - t.Errorf("Could not get struct with custom date field: %s", err) - t.FailNow() - } - test2 := result.(*WithCustomDate) - if test2.Added.UTC() != test1.Added.UTC() { - t.Errorf("Custom dates do not match: %v != %v", test2.Added.UTC(), test1.Added.UTC()) - } -} - -func TestUIntPrimaryKey(t *testing.T) { - dbmap := newDBMap(t) - dbmap.AddTable(PersonUInt64{}).SetKeys(true, "Id") - dbmap.AddTable(PersonUInt32{}).SetKeys(true, "Id") - dbmap.AddTable(PersonUInt16{}).SetKeys(true, "Id") - err := dbmap.CreateTablesIfNotExists() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - - p1 := &PersonUInt64{0, "name1"} - p2 := &PersonUInt32{0, "name2"} - p3 := &PersonUInt16{0, "name3"} - err = dbmap.Insert(p1, p2, p3) - if err != nil { - t.Error(err) - } - if p1.Id != 1 { - t.Errorf("%d != 1", p1.Id) - } - if p2.Id != 1 { - t.Errorf("%d != 1", p2.Id) - } - if p3.Id != 1 { - t.Errorf("%d != 1", p3.Id) - } -} - -func TestSetUniqueTogether(t *testing.T) { - dbmap := newDBMap(t) - dbmap.AddTable(UniqueColumns{}).SetUniqueTogether("FirstName", "LastName").SetUniqueTogether("City", "ZipCode") - err := dbmap.CreateTablesIfNotExists() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - - n1 := &UniqueColumns{"Steve", "Jobs", "Cupertino", 95014} - err = dbmap.Insert(n1) - if err != nil { - t.Error(err) - } - - // Should fail because of the first constraint - n2 := &UniqueColumns{"Steve", "Jobs", "Sunnyvale", 94085} - err = dbmap.Insert(n2) - if err == nil { - t.Error(err) - } - // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL - errLower := strings.ToLower(err.Error()) - if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") { - t.Error(err) - } - - // Should also fail because of the second unique-together - n3 := &UniqueColumns{"Steve", "Wozniak", "Cupertino", 95014} - err = dbmap.Insert(n3) - if err == nil { - t.Error(err) - } - // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL - errLower = strings.ToLower(err.Error()) - if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") { - t.Error(err) - } - - // This one should finally succeed - n4 := &UniqueColumns{"Steve", "Wozniak", "Sunnyvale", 94085} - err = dbmap.Insert(n4) - if err != nil { - t.Error(err) - } -} - -func TestSetUniqueTogetherIdempotent(t *testing.T) { - dbmap := newDBMap(t) - table := dbmap.AddTable(UniqueColumns{}).SetUniqueTogether("FirstName", "LastName") - table.SetUniqueTogether("FirstName", "LastName") - err := dbmap.CreateTablesIfNotExists() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - - n1 := &UniqueColumns{"Steve", "Jobs", "Cupertino", 95014} - err = dbmap.Insert(n1) - if err != nil { - t.Error(err) - } - - // Should still fail because of the constraint - n2 := &UniqueColumns{"Steve", "Jobs", "Sunnyvale", 94085} - err = dbmap.Insert(n2) - if err == nil { - t.Error(err) - } - - // Should have only created one unique constraint - actualCount := strings.Count(table.SqlForCreate(false), "unique") - if actualCount != 1 { - t.Errorf("expected one unique index, found %d: %s", actualCount, table.SqlForCreate(false)) - } -} - -func TestPersistentUser(t *testing.T) { - dbmap := newDBMap(t) - dbmap.Exec("drop table if exists PersistentUser") - table := 
dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key") - table.ColMap("Key").Rename("mykey") - err := dbmap.CreateTablesIfNotExists() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - pu := &PersistentUser{43, "33r", false} - err = dbmap.Insert(pu) - if err != nil { - panic(err) - } - - // prove we can pass a pointer into Get - pu2, err := dbmap.Get(pu, pu.Key) - if err != nil { - panic(err) - } - if !reflect.DeepEqual(pu, pu2) { - t.Errorf("%v!=%v", pu, pu2) - } - - arr, err := dbmap.Select(pu, "select * from "+tableName(dbmap, PersistentUser{})) - if err != nil { - panic(err) - } - if !reflect.DeepEqual(pu, arr[0]) { - t.Errorf("%v!=%v", pu, arr[0]) - } - - // prove we can get the results back in a slice - var puArr []*PersistentUser - _, err = dbmap.Select(&puArr, "select * from "+tableName(dbmap, PersistentUser{})) - if err != nil { - panic(err) - } - if len(puArr) != 1 { - t.Errorf("Expected one persistentuser, found none") - } - if !reflect.DeepEqual(pu, puArr[0]) { - t.Errorf("%v!=%v", pu, puArr[0]) - } - - // prove we can get the results back in a non-pointer slice - var puValues []PersistentUser - _, err = dbmap.Select(&puValues, "select * from "+tableName(dbmap, PersistentUser{})) - if err != nil { - panic(err) - } - if len(puValues) != 1 { - t.Errorf("Expected one persistentuser, found none") - } - if !reflect.DeepEqual(*pu, puValues[0]) { - t.Errorf("%v!=%v", *pu, puValues[0]) - } - - // prove we can get the results back in a string slice - var idArr []*string - _, err = dbmap.Select(&idArr, "select "+columnName(dbmap, PersistentUser{}, "Id")+" from "+tableName(dbmap, PersistentUser{})) - if err != nil { - panic(err) - } - if len(idArr) != 1 { - t.Errorf("Expected one persistentuser, found none") - } - if !reflect.DeepEqual(pu.Id, *idArr[0]) { - t.Errorf("%v!=%v", pu.Id, *idArr[0]) - } - - // prove we can get the results back in an int slice - var keyArr []*int32 - _, err = dbmap.Select(&keyArr, "select mykey from "+tableName(dbmap, PersistentUser{})) - if err != nil { - panic(err) - } - if len(keyArr) != 1 { - t.Errorf("Expected one persistentuser, found none") - } - if !reflect.DeepEqual(pu.Key, *keyArr[0]) { - t.Errorf("%v!=%v", pu.Key, *keyArr[0]) - } - - // prove we can get the results back in a bool slice - var passedArr []*bool - _, err = dbmap.Select(&passedArr, "select "+columnName(dbmap, PersistentUser{}, "PassedTraining")+" from "+tableName(dbmap, PersistentUser{})) - if err != nil { - panic(err) - } - if len(passedArr) != 1 { - t.Errorf("Expected one persistentuser, found none") - } - if !reflect.DeepEqual(pu.PassedTraining, *passedArr[0]) { - t.Errorf("%v!=%v", pu.PassedTraining, *passedArr[0]) - } - - // prove we can get the results back in a non-pointer slice - var stringArr []string - _, err = dbmap.Select(&stringArr, "select "+columnName(dbmap, PersistentUser{}, "Id")+" from "+tableName(dbmap, PersistentUser{})) - if err != nil { - panic(err) - } - if len(stringArr) != 1 { - t.Errorf("Expected one persistentuser, found none") - } - if !reflect.DeepEqual(pu.Id, stringArr[0]) { - t.Errorf("%v!=%v", pu.Id, stringArr[0]) - } -} - -func TestNamedQueryMap(t *testing.T) { - dbmap := newDBMap(t) - dbmap.Exec("drop table if exists PersistentUser") - table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key") - table.ColMap("Key").Rename("mykey") - err := dbmap.CreateTablesIfNotExists() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - pu := &PersistentUser{43, "33r", false} - pu2 := &PersistentUser{500, "abc", false} - err = 
dbmap.Insert(pu, pu2) - if err != nil { - panic(err) - } - - // Test simple case - var puArr []*PersistentUser - _, err = dbmap.Select(&puArr, "select * from "+tableName(dbmap, PersistentUser{})+" where mykey = :Key", map[string]interface{}{ - "Key": 43, - }) - if err != nil { - t.Errorf("Failed to select: %s", err) - t.FailNow() - } - if len(puArr) != 1 { - t.Errorf("Expected one persistentuser, found none") - } - if !reflect.DeepEqual(pu, puArr[0]) { - t.Errorf("%v!=%v", pu, puArr[0]) - } - - // Test more specific map value type is ok - puArr = nil - _, err = dbmap.Select(&puArr, "select * from "+tableName(dbmap, PersistentUser{})+" where mykey = :Key", map[string]int{ - "Key": 43, - }) - if err != nil { - t.Errorf("Failed to select: %s", err) - t.FailNow() - } - if len(puArr) != 1 { - t.Errorf("Expected one persistentuser, found none") - } - - // Test multiple parameters set. - puArr = nil - _, err = dbmap.Select(&puArr, ` -select * from `+tableName(dbmap, PersistentUser{})+` - where mykey = :Key - and `+columnName(dbmap, PersistentUser{}, "PassedTraining")+` = :PassedTraining - and `+columnName(dbmap, PersistentUser{}, "Id")+` = :Id`, map[string]interface{}{ - "Key": 43, - "PassedTraining": false, - "Id": "33r", - }) - if err != nil { - t.Errorf("Failed to select: %s", err) - t.FailNow() - } - if len(puArr) != 1 { - t.Errorf("Expected one persistentuser, found none") - } - - // Test colon within a non-key string - // Test having extra, unused properties in the map. - puArr = nil - _, err = dbmap.Select(&puArr, ` -select * from `+tableName(dbmap, PersistentUser{})+` - where mykey = :Key - and `+columnName(dbmap, PersistentUser{}, "Id")+` != 'abc:def'`, map[string]interface{}{ - "Key": 43, - "PassedTraining": false, - }) - if err != nil { - t.Errorf("Failed to select: %s", err) - t.FailNow() - } - if len(puArr) != 1 { - t.Errorf("Expected one persistentuser, found none") - } - - // Test to delete with Exec and named params. - result, err := dbmap.Exec("delete from "+tableName(dbmap, PersistentUser{})+" where mykey = :Key", map[string]interface{}{ - "Key": 43, - }) - count, err := result.RowsAffected() - if err != nil { - t.Errorf("Failed to exec: %s", err) - t.FailNow() - } - if count != 1 { - t.Errorf("Expected 1 persistentuser to be deleted, but %d deleted", count) - } -} - -func TestNamedQueryStruct(t *testing.T) { - dbmap := newDBMap(t) - dbmap.Exec("drop table if exists PersistentUser") - table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key") - table.ColMap("Key").Rename("mykey") - err := dbmap.CreateTablesIfNotExists() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - pu := &PersistentUser{43, "33r", false} - pu2 := &PersistentUser{500, "abc", false} - err = dbmap.Insert(pu, pu2) - if err != nil { - panic(err) - } - - // Test select self - var puArr []*PersistentUser - _, err = dbmap.Select(&puArr, ` -select * from `+tableName(dbmap, PersistentUser{})+` - where mykey = :Key - and `+columnName(dbmap, PersistentUser{}, "PassedTraining")+` = :PassedTraining - and `+columnName(dbmap, PersistentUser{}, "Id")+` = :Id`, pu) - if err != nil { - t.Errorf("Failed to select: %s", err) - t.FailNow() - } - if len(puArr) != 1 { - t.Errorf("Expected one persistentuser, found none") - } - if !reflect.DeepEqual(pu, puArr[0]) { - t.Errorf("%v!=%v", pu, puArr[0]) - } - - // Test delete self. 
- result, err := dbmap.Exec(` -delete from `+tableName(dbmap, PersistentUser{})+` - where mykey = :Key - and `+columnName(dbmap, PersistentUser{}, "PassedTraining")+` = :PassedTraining - and `+columnName(dbmap, PersistentUser{}, "Id")+` = :Id`, pu) - count, err := result.RowsAffected() - if err != nil { - t.Errorf("Failed to exec: %s", err) - t.FailNow() - } - if count != 1 { - t.Errorf("Expected 1 persistentuser to be deleted, but %d deleted", count) - } -} - -// Ensure that the slices containing SQL results are non-nil when the result set is empty. -func TestReturnsNonNilSlice(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - noResultsSQL := "select * from invoice_test where " + columnName(dbmap, Invoice{}, "Id") + "=99999" - var r1 []*Invoice - rawSelect(dbmap, &r1, noResultsSQL) - if r1 == nil { - t.Errorf("r1==nil") - } - - r2 := rawSelect(dbmap, Invoice{}, noResultsSQL) - if r2 == nil { - t.Errorf("r2==nil") - } -} - -func TestOverrideVersionCol(t *testing.T) { - dbmap := newDBMap(t) - t1 := dbmap.AddTable(InvoicePersonView{}).SetKeys(false, "InvoiceId", "PersonId") - err := dbmap.CreateTables() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - c1 := t1.SetVersionCol("LegacyVersion") - if c1.ColumnName != "LegacyVersion" { - t.Errorf("Wrong col returned: %v", c1) - } - - ipv := &InvoicePersonView{1, 2, "memo", "fname", 0} - _update(dbmap, ipv) - if ipv.LegacyVersion != 1 { - t.Errorf("LegacyVersion not updated: %d", ipv.LegacyVersion) - } -} - -func TestOptimisticLocking(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - p1 := &Person{0, 0, 0, "Bob", "Smith", 0} - dbmap.Insert(p1) // Version is now 1 - if p1.Version != 1 { - t.Errorf("Insert didn't incr Version: %d != %d", 1, p1.Version) - return - } - if p1.Id == 0 { - t.Errorf("Insert didn't return a generated PK") - return - } - - obj, err := dbmap.Get(Person{}, p1.Id) - if err != nil { - panic(err) - } - p2 := obj.(*Person) - p2.LName = "Edwards" - dbmap.Update(p2) // Version is now 2 - if p2.Version != 2 { - t.Errorf("Update didn't incr Version: %d != %d", 2, p2.Version) - } - - p1.LName = "Howard" - count, err := dbmap.Update(p1) - if _, ok := err.(sqldb.OptimisticLockError); !ok { - t.Errorf("update - Expected sqldb.OptimisticLockError, got: %v", err) - } - if count != -1 { - t.Errorf("update - Expected -1 count, got: %d", count) - } - - count, err = dbmap.Delete(p1) - if _, ok := err.(sqldb.OptimisticLockError); !ok { - t.Errorf("delete - Expected sqldb.OptimisticLockError, got: %v", err) - } - if count != -1 { - t.Errorf("delete - Expected -1 count, got: %d", count) - } -} - -// what happens if a legacy table has a null value? -func TestDoubleAddTable(t *testing.T) { - dbmap := newDBMap(t) - t1 := dbmap.AddTable(TableWithNull{}).SetKeys(false, "Id") - t2 := dbmap.AddTable(TableWithNull{}) - if t1 != t2 { - t.Errorf("%v != %v", t1, t2) - } -} - -// what happens if a legacy table has a null value? 
-func TestNullValues(t *testing.T) { - dbmap := initDBMapNulls(t) - defer dropAndClose(dbmap) - - // insert a row directly - rawExec(dbmap, "insert into "+tableName(dbmap, TableWithNull{})+" values (10, null, "+ - "null, null, null, null)") - - // try to load it - expected := &TableWithNull{Id: 10} - obj := _get(dbmap, TableWithNull{}, 10) - t1 := obj.(*TableWithNull) - if !reflect.DeepEqual(expected, t1) { - t.Errorf("%v != %v", expected, t1) - } - - // update it - t1.Str = sql.NullString{"hi", true} - expected.Str = t1.Str - t1.Int64 = sql.NullInt64{999, true} - expected.Int64 = t1.Int64 - t1.Float64 = sql.NullFloat64{53.33, true} - expected.Float64 = t1.Float64 - t1.Bool = sql.NullBool{true, true} - expected.Bool = t1.Bool - t1.Bytes = []byte{1, 30, 31, 33} - expected.Bytes = t1.Bytes - _update(dbmap, t1) - - obj = _get(dbmap, TableWithNull{}, 10) - t1 = obj.(*TableWithNull) - if t1.Str.String != "hi" { - t.Errorf("%s != hi", t1.Str.String) - } - if !reflect.DeepEqual(expected, t1) { - t.Errorf("%v != %v", expected, t1) - } -} - -func TestScannerValuer(t *testing.T) { - dbmap := newDBMap(t) - dbmap.AddTableWithName(PersonValuerScanner{}, "person_test").SetKeys(true, "Id") - dbmap.AddTableWithName(InvoiceWithValuer{}, "invoice_test").SetKeys(true, "Id") - err := dbmap.CreateTables() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - - pv := PersonValuerScanner{} - pv.FName = "foo" - pv.LName = "bar" - err = dbmap.Insert(&pv) - if err != nil { - t.Errorf("Could not insert PersonValuerScanner using Person table: %v", err) - t.FailNow() - } - - inv := InvoiceWithValuer{} - inv.Memo = "foo" - inv.Person = pv - err = dbmap.Insert(&inv) - if err != nil { - t.Errorf("Could not insert InvoiceWithValuer using Invoice table: %v", err) - t.FailNow() - } - - res, err := dbmap.Get(InvoiceWithValuer{}, inv.Id) - if err != nil { - t.Errorf("Could not get InvoiceWithValuer: %v", err) - t.FailNow() - } - dbInv := res.(*InvoiceWithValuer) - - if dbInv.Person.Id != pv.Id { - t.Errorf("InvoiceWithValuer got wrong person ID: %d (expected) != %d (actual)", pv.Id, dbInv.Person.Id) - } -} - -func TestColumnProps(t *testing.T) { - dbmap := newDBMap(t) - t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id") - t1.ColMap("Created").Rename("date_created") - t1.ColMap("Updated").SetTransient(true) - t1.ColMap("Memo").SetMaxSize(10) - t1.ColMap("PersonId").SetUnique(true) - - err := dbmap.CreateTables() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - - // test transient - inv := &Invoice{0, 0, 1, "my invoice", 0, true} - _insert(dbmap, inv) - obj := _get(dbmap, Invoice{}, inv.Id) - inv = obj.(*Invoice) - if inv.Updated != 0 { - t.Errorf("Saved transient column 'Updated'") - } - - // test max size - inv.Memo = "this memo is too long" - err = dbmap.Insert(inv) - if err == nil { - t.Errorf("max size exceeded, but Insert did not fail.") - } - - // test unique - same person id - inv = &Invoice{0, 0, 1, "my invoice2", 0, false} - err = dbmap.Insert(inv) - if err == nil { - t.Errorf("same PersonId inserted, but Insert did not fail.") - } -} - -func TestRawSelect(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - p1 := &Person{0, 0, 0, "bob", "smith", 0} - _insert(dbmap, p1) - - inv1 := &Invoice{0, 0, 0, "xmas order", p1.Id, true} - _insert(dbmap, inv1) - - expected := &InvoicePersonView{inv1.Id, p1.Id, inv1.Memo, p1.FName, 0} - - query := "select i." + columnName(dbmap, Invoice{}, "Id") + " InvoiceId, p." + columnName(dbmap, Person{}, "Id") + " PersonId, i." 
+ columnName(dbmap, Invoice{}, "Memo") + ", p." + columnName(dbmap, Person{}, "FName") + " " + - "from invoice_test i, person_test p " + - "where i." + columnName(dbmap, Invoice{}, "PersonId") + " = p." + columnName(dbmap, Person{}, "Id") - list := rawSelect(dbmap, InvoicePersonView{}, query) - if len(list) != 1 { - t.Errorf("len(list) != 1: %d", len(list)) - } else if !reflect.DeepEqual(expected, list[0]) { - t.Errorf("%v != %v", expected, list[0]) - } -} - -func TestHooks(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - p1 := &Person{0, 0, 0, "bob", "smith", 0} - _insert(dbmap, p1) - if p1.Created == 0 || p1.Updated == 0 { - t.Errorf("p1.PreInsert() didn't run: %v", p1) - } else if p1.LName != "postinsert" { - t.Errorf("p1.PostInsert() didn't run: %v", p1) - } - - obj := _get(dbmap, Person{}, p1.Id) - p1 = obj.(*Person) - if p1.LName != "postget" { - t.Errorf("p1.PostGet() didn't run: %v", p1) - } - - _update(dbmap, p1) - if p1.FName != "preupdate" { - t.Errorf("p1.PreUpdate() didn't run: %v", p1) - } else if p1.LName != "postupdate" { - t.Errorf("p1.PostUpdate() didn't run: %v", p1) - } - - var persons []*Person - bindVar := dbmap.Dialect.BindVar(0) - rawSelect(dbmap, &persons, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+" = "+bindVar, p1.Id) - if persons[0].LName != "postget" { - t.Errorf("p1.PostGet() didn't run after select: %v", p1) - } - - _del(dbmap, p1) - if p1.FName != "predelete" { - t.Errorf("p1.PreDelete() didn't run: %v", p1) - } else if p1.LName != "postdelete" { - t.Errorf("p1.PostDelete() didn't run: %v", p1) - } - - // Test error case - p2 := &Person{0, 0, 0, "badname", "", 0} - err := dbmap.Insert(p2) - if err == nil { - t.Errorf("p2.PreInsert() didn't return an error") - } -} - -func TestTransaction(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - inv1 := &Invoice{0, 100, 200, "t1", 0, true} - inv2 := &Invoice{0, 100, 200, "t2", 0, false} - - trans, err := dbmap.Begin() - if err != nil { - panic(err) - } - trans.Insert(inv1, inv2) - err = trans.Commit() - if err != nil { - panic(err) - } - - obj, err := dbmap.Get(Invoice{}, inv1.Id) - if err != nil { - panic(err) - } - if !reflect.DeepEqual(inv1, obj) { - t.Errorf("%v != %v", inv1, obj) - } - obj, err = dbmap.Get(Invoice{}, inv2.Id) - if err != nil { - panic(err) - } - if !reflect.DeepEqual(inv2, obj) { - t.Errorf("%v != %v", inv2, obj) - } -} - -func TestTransactionExecNamed(t *testing.T) { - if os.Getenv("SQLDB_TEST_DIALECT") == "postgres" { - return - } - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - trans, err := dbmap.Begin() - if err != nil { - panic(err) - } - defer trans.Rollback() - // exec should support named params - args := map[string]interface{}{ - "created": 100, - "updated": 200, - "memo": "unpaid", - "personID": 0, - "isPaid": false, - } - - result, err := trans.Exec(`INSERT INTO invoice_test (Created, Updated, Memo, PersonId, IsPaid) Values(:created, :updated, :memo, :personID, :isPaid)`, args) - if err != nil { - panic(err) - } - id, err := result.LastInsertId() - if err != nil { - panic(err) - } - var checkMemo = func(want string) { - args := map[string]interface{}{ - "id": id, - } - memo, err := trans.SelectStr("select memo from invoice_test where id = :id", args) - if err != nil { - panic(err) - } - if memo != want { - t.Errorf("%q != %q", want, memo) - } - } - checkMemo("unpaid") - - // exec should still work with ? 
params - result, err = trans.Exec(`INSERT INTO invoice_test (Created, Updated, Memo, PersonId, IsPaid) Values(?, ?, ?, ?, ?)`, 10, 15, "paid", 0, true) - if err != nil { - panic(err) - } - id, err = result.LastInsertId() - if err != nil { - panic(err) - } - checkMemo("paid") - err = trans.Commit() - if err != nil { - panic(err) - } -} - -func TestTransactionExecNamedPostgres(t *testing.T) { - if os.Getenv("SQLDB_TEST_DIALECT") != "postgres" { - return - } - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - trans, err := dbmap.Begin() - if err != nil { - panic(err) - } - // exec should support named params - args := map[string]interface{}{ - "created": 100, - "updated": 200, - "memo": "zzTest", - "personID": 0, - "isPaid": false, - } - _, err = trans.Exec(`INSERT INTO invoice_test ("Created", "Updated", "Memo", "PersonId", "IsPaid") Values(:created, :updated, :memo, :personID, :isPaid)`, args) - if err != nil { - panic(err) - } - var checkMemo = func(want string) { - args := map[string]interface{}{ - "memo": want, - } - memo, err := trans.SelectStr(`select "Memo" from invoice_test where "Memo" = :memo`, args) - if err != nil { - panic(err) - } - if memo != want { - t.Errorf("%q != %q", want, memo) - } - } - checkMemo("zzTest") - - // exec should still work with ? params - _, err = trans.Exec(`INSERT INTO invoice_test ("Created", "Updated", "Memo", "PersonId", "IsPaid") Values($1, $2, $3, $4, $5)`, 10, 15, "yyTest", 0, true) - - if err != nil { - panic(err) - } - checkMemo("yyTest") - err = trans.Commit() - if err != nil { - panic(err) - } -} - -func TestSavepoint(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - inv1 := &Invoice{0, 100, 200, "unpaid", 0, false} - - trans, err := dbmap.Begin() - if err != nil { - panic(err) - } - trans.Insert(inv1) - - var checkMemo = func(want string) { - memo, err := trans.SelectStr("select " + columnName(dbmap, Invoice{}, "Memo") + " from invoice_test") - if err != nil { - panic(err) - } - if memo != want { - t.Errorf("%q != %q", want, memo) - } - } - checkMemo("unpaid") - - err = trans.Savepoint("foo") - if err != nil { - panic(err) - } - checkMemo("unpaid") - - inv1.Memo = "paid" - _, err = trans.Update(inv1) - if err != nil { - panic(err) - } - checkMemo("paid") - - err = trans.RollbackToSavepoint("foo") - if err != nil { - panic(err) - } - checkMemo("unpaid") - - err = trans.Rollback() - if err != nil { - panic(err) - } -} - -func TestMultiple(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - inv1 := &Invoice{0, 100, 200, "a", 0, false} - inv2 := &Invoice{0, 100, 200, "b", 0, true} - _insert(dbmap, inv1, inv2) - - inv1.Memo = "c" - inv2.Memo = "d" - _update(dbmap, inv1, inv2) - - count := _del(dbmap, inv1, inv2) - if count != 2 { - t.Errorf("%d != 2", count) - } -} - -func TestCrud(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - inv := &Invoice{0, 100, 200, "first order", 0, true} - testCrudInternal(t, dbmap, inv) - - invtag := &InvoiceTag{0, 300, 400, "some order", 33, false} - testCrudInternal(t, dbmap, invtag) - - foo := &AliasTransientField{BarStr: "some bar"} - testCrudInternal(t, dbmap, foo) - - dynamicTablesTest(t, dbmap) -} - -func testCrudInternal(t *testing.T, dbmap *sqldb.DbMap, val testable) { - table, err := dbmap.TableFor(reflect.TypeOf(val).Elem(), false) - if err != nil { - t.Errorf("couldn't call TableFor: val=%v err=%v", val, err) - } - - _, err = dbmap.Exec("delete from " + table.TableName) - if err != nil { - t.Errorf("couldn't delete rows from: val=%v err=%v", 
val, err) - } - - // INSERT row - _insert(dbmap, val) - if val.GetId() == 0 { - t.Errorf("val.GetId() was not set on INSERT") - return - } - - // SELECT row - val2 := _get(dbmap, val, val.GetId()) - if !reflect.DeepEqual(val, val2) { - t.Errorf("%v != %v", val, val2) - } - - // UPDATE row and SELECT - val.Rand() - count := _update(dbmap, val) - if count != 1 { - t.Errorf("update 1 != %d", count) - } - val2 = _get(dbmap, val, val.GetId()) - if !reflect.DeepEqual(val, val2) { - t.Errorf("%v != %v", val, val2) - } - - // Select * - rows, err := dbmap.Select(val, "select * from "+dbmap.Dialect.QuoteField(table.TableName)) - if err != nil { - t.Errorf("couldn't select * from %s err=%v", dbmap.Dialect.QuoteField(table.TableName), err) - } else if len(rows) != 1 { - t.Errorf("unexpected row count in %s: %d", dbmap.Dialect.QuoteField(table.TableName), len(rows)) - } else if !reflect.DeepEqual(val, rows[0]) { - t.Errorf("select * result: %v != %v", val, rows[0]) - } - - // DELETE row - deleted := _del(dbmap, val) - if deleted != 1 { - t.Errorf("Did not delete row with Id: %d", val.GetId()) - return - } - - // VERIFY deleted - val2 = _get(dbmap, val, val.GetId()) - if val2 != nil { - t.Errorf("Found invoice with id: %d after Delete()", val.GetId()) - } -} - -func TestWithIgnoredColumn(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - ic := &WithIgnoredColumn{-1, 0, 1} - _insert(dbmap, ic) - expected := &WithIgnoredColumn{0, 1, 1} - ic2 := _get(dbmap, WithIgnoredColumn{}, ic.Id).(*WithIgnoredColumn) - - if !reflect.DeepEqual(expected, ic2) { - t.Errorf("%v != %v", expected, ic2) - } - if _del(dbmap, ic) != 1 { - t.Errorf("Did not delete row with Id: %d", ic.Id) - return - } - if _get(dbmap, WithIgnoredColumn{}, ic.Id) != nil { - t.Errorf("Found id: %d after Delete()", ic.Id) - } -} - -func TestColumnFilter(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - inv1 := &Invoice{0, 100, 200, "a", 0, false} - _insert(dbmap, inv1) - - inv1.Memo = "c" - inv1.IsPaid = true - _updateColumns(dbmap, func(col *sqldb.ColumnMap) bool { - return col.ColumnName == "Memo" - }, inv1) - - inv2 := &Invoice{} - inv2 = _get(dbmap, inv2, inv1.Id).(*Invoice) - if inv2.Memo != "c" { - t.Errorf("Expected column to be updated (%#v)", inv2) - } - if inv2.IsPaid { - t.Error("IsPaid shouldn't have been updated") - } -} - -func TestTypeConversionExample(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - p := Person{FName: "Bob", LName: "Smith"} - tc := &TypeConversionExample{-1, p, CustomStringType("hi")} - _insert(dbmap, tc) - - expected := &TypeConversionExample{1, p, CustomStringType("hi")} - tc2 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample) - if !reflect.DeepEqual(expected, tc2) { - t.Errorf("tc2 %v != %v", expected, tc2) - } - - tc2.Name = CustomStringType("hi2") - tc2.PersonJSON = Person{FName: "Jane", LName: "Doe"} - _update(dbmap, tc2) - - expected = &TypeConversionExample{1, tc2.PersonJSON, CustomStringType("hi2")} - tc3 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample) - if !reflect.DeepEqual(expected, tc3) { - t.Errorf("tc3 %v != %v", expected, tc3) - } - - if _del(dbmap, tc) != 1 { - t.Errorf("Did not delete row with Id: %d", tc.Id) - } - -} - -func TestWithEmbeddedStruct(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - es := &WithEmbeddedStruct{-1, Names{FirstName: "Alice", LastName: "Smith"}} - _insert(dbmap, es) - expected := &WithEmbeddedStruct{1, Names{FirstName: "Alice", LastName: 
"Smith"}} - es2 := _get(dbmap, WithEmbeddedStruct{}, es.Id).(*WithEmbeddedStruct) - if !reflect.DeepEqual(expected, es2) { - t.Errorf("%v != %v", expected, es2) - } - - es2.FirstName = "Bob" - expected.FirstName = "Bob" - _update(dbmap, es2) - es2 = _get(dbmap, WithEmbeddedStruct{}, es.Id).(*WithEmbeddedStruct) - if !reflect.DeepEqual(expected, es2) { - t.Errorf("%v != %v", expected, es2) - } - - ess := rawSelect(dbmap, WithEmbeddedStruct{}, "select * from embedded_struct_test") - if !reflect.DeepEqual(es2, ess[0]) { - t.Errorf("%v != %v", es2, ess[0]) - } -} - -/* -func TestWithEmbeddedStructConflictingEmbeddedMemberNames(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - es := &WithEmbeddedStructConflictingEmbeddedMemberNames{-1, Names{FirstName: "Alice", LastName: "Smith"}, NamesConflict{FirstName: "Andrew", Surname: "Wiggin"}} - _insert(dbmap, es) - expected := &WithEmbeddedStructConflictingEmbeddedMemberNames{-1, Names{FirstName: "Alice", LastName: "Smith"}, NamesConflict{FirstName: "Andrew", Surname: "Wiggin"}} - es2 := _get(dbmap, WithEmbeddedStructConflictingEmbeddedMemberNames{}, es.Id).(*WithEmbeddedStructConflictingEmbeddedMemberNames) - if !reflect.DeepEqual(expected, es2) { - t.Errorf("%v != %v", expected, es2) - } - - es2.Names.FirstName = "Bob" - expected.Names.FirstName = "Bob" - _update(dbmap, es2) - es2 = _get(dbmap, WithEmbeddedStructConflictingEmbeddedMemberNames{}, es.Id).(*WithEmbeddedStructConflictingEmbeddedMemberNames) - if !reflect.DeepEqual(expected, es2) { - t.Errorf("%v != %v", expected, es2) - } - - ess := rawSelect(dbmap, WithEmbeddedStructConflictingEmbeddedMemberNames{}, "select * from embedded_struct_conflict_name_test") - if !reflect.DeepEqual(es2, ess[0]) { - t.Errorf("%v != %v", es2, ess[0]) - } -} - -func TestWithEmbeddedStructSameMemberName(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - es := &WithEmbeddedStructSameMemberName{-1, SameName{SameName: "Alice"}} - _insert(dbmap, es) - expected := &WithEmbeddedStructSameMemberName{-1, SameName{SameName: "Alice"}} - es2 := _get(dbmap, WithEmbeddedStructSameMemberName{}, es.Id).(*WithEmbeddedStructSameMemberName) - if !reflect.DeepEqual(expected, es2) { - t.Errorf("%v != %v", expected, es2) - } - - es2.SameName = SameName{"Bob"} - expected.SameName = SameName{"Bob"} - _update(dbmap, es2) - es2 = _get(dbmap, WithEmbeddedStructSameMemberName{}, es.Id).(*WithEmbeddedStructSameMemberName) - if !reflect.DeepEqual(expected, es2) { - t.Errorf("%v != %v", expected, es2) - } - - ess := rawSelect(dbmap, WithEmbeddedStructSameMemberName{}, "select * from embedded_struct_same_member_name_test") - if !reflect.DeepEqual(es2, ess[0]) { - t.Errorf("%v != %v", es2, ess[0]) - } -} -//*/ - -func TestWithEmbeddedStructBeforeAutoincr(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - esba := &WithEmbeddedStructBeforeAutoincrField{Names: Names{FirstName: "Alice", LastName: "Smith"}} - _insert(dbmap, esba) - var expectedAutoincrId int64 = 1 - if esba.Id != expectedAutoincrId { - t.Errorf("%d != %d", expectedAutoincrId, esba.Id) - } -} - -func TestWithEmbeddedAutoincr(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - esa := &WithEmbeddedAutoincr{ - WithEmbeddedStruct: WithEmbeddedStruct{Names: Names{FirstName: "Alice", LastName: "Smith"}}, - MiddleName: "Rose", - } - _insert(dbmap, esa) - var expectedAutoincrId int64 = 1 - if esa.Id != expectedAutoincrId { - t.Errorf("%d != %d", expectedAutoincrId, esa.Id) - } -} - -func TestSelectVal(t 
*testing.T) { - dbmap := initDBMapNulls(t) - defer dropAndClose(dbmap) - - bindVar := dbmap.Dialect.BindVar(0) - - t1 := TableWithNull{Str: sql.NullString{"abc", true}, - Int64: sql.NullInt64{78, true}, - Float64: sql.NullFloat64{32.2, true}, - Bool: sql.NullBool{true, true}, - Bytes: []byte("hi")} - _insert(dbmap, &t1) - - // SelectInt - i64 := selectInt(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Int64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='abc'") - if i64 != 78 { - t.Errorf("int64 %d != 78", i64) - } - i64 = selectInt(dbmap, "select count(*) from "+tableName(dbmap, TableWithNull{})) - if i64 != 1 { - t.Errorf("int64 count %d != 1", i64) - } - i64 = selectInt(dbmap, "select count(*) from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"="+bindVar, "asdfasdf") - if i64 != 0 { - t.Errorf("int64 no rows %d != 0", i64) - } - - // SelectNullInt - n := selectNullInt(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Int64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='notfound'") - if !reflect.DeepEqual(n, sql.NullInt64{0, false}) { - t.Errorf("nullint %v != 0,false", n) - } - - n = selectNullInt(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Int64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='abc'") - if !reflect.DeepEqual(n, sql.NullInt64{78, true}) { - t.Errorf("nullint %v != 78, true", n) - } - - // SelectFloat - f64 := selectFloat(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Float64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='abc'") - if f64 != 32.2 { - t.Errorf("float64 %f != 32.2", f64) - } - f64 = selectFloat(dbmap, "select min("+columnName(dbmap, TableWithNull{}, "Float64")+") from "+tableName(dbmap, TableWithNull{})) - if f64 != 32.2 { - t.Errorf("float64 min %f != 32.2", f64) - } - f64 = selectFloat(dbmap, "select count(*) from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"="+bindVar, "asdfasdf") - if f64 != 0 { - t.Errorf("float64 no rows %f != 0", f64) - } - - // SelectNullFloat - nf := selectNullFloat(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Float64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='notfound'") - if !reflect.DeepEqual(nf, sql.NullFloat64{0, false}) { - t.Errorf("nullfloat %v != 0,false", nf) - } - - nf = selectNullFloat(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Float64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='abc'") - if !reflect.DeepEqual(nf, sql.NullFloat64{32.2, true}) { - t.Errorf("nullfloat %v != 32.2, true", nf) - } - - // SelectStr - s := selectStr(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Str")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Int64")+"="+bindVar, 78) - if s != "abc" { - t.Errorf("s %s != abc", s) - } - s = selectStr(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Str")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='asdfasdf'") - if s != "" { - t.Errorf("s no rows %s != ''", s) - } - - // SelectNullStr - ns := selectNullStr(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Str")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, 
"Int64")+"="+bindVar, 78) - if !reflect.DeepEqual(ns, sql.NullString{"abc", true}) { - t.Errorf("nullstr %v != abc,true", ns) - } - ns = selectNullStr(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Str")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='asdfasdf'") - if !reflect.DeepEqual(ns, sql.NullString{"", false}) { - t.Errorf("nullstr no rows %v != '',false", ns) - } - - // SelectInt/Str with named parameters - i64 = selectInt(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Int64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"=:abc", map[string]string{"abc": "abc"}) - if i64 != 78 { - t.Errorf("int64 %d != 78", i64) - } - ns = selectNullStr(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Str")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Int64")+"=:num", map[string]int{"num": 78}) - if !reflect.DeepEqual(ns, sql.NullString{"abc", true}) { - t.Errorf("nullstr %v != abc,true", ns) - } -} - -func TestVersionMultipleRows(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - persons := []*Person{ - &Person{0, 0, 0, "Bob", "Smith", 0}, - &Person{0, 0, 0, "Jane", "Smith", 0}, - &Person{0, 0, 0, "Mike", "Smith", 0}, - } - - _insert(dbmap, persons[0], persons[1], persons[2]) - - for x, p := range persons { - if p.Version != 1 { - t.Errorf("person[%d].Version != 1: %d", x, p.Version) - } - } -} - -func TestWithStringPk(t *testing.T) { - dbmap := newDBMap(t) - dbmap.AddTableWithName(WithStringPk{}, "string_pk_test").SetKeys(true, "Id") - _, err := dbmap.Exec("create table string_pk_test (Id varchar(255), Name varchar(255));") - if err != nil { - t.Errorf("couldn't create string_pk_test: %v", err) - } - defer dropAndClose(dbmap) - - row := &WithStringPk{"1", "foo"} - err = dbmap.Insert(row) - if err == nil { - t.Errorf("Expected error when inserting into table w/non Int PK and autoincr set true") - } -} - -// TestSqlExecutorInterfaceSelects ensures that all sqldb.DbMap methods starting with Select... -// are also exposed in the sqldb.SqlExecutor interface. Select... functions can always -// run on Pre/Post hooks. 
-func TestSqlExecutorInterfaceSelects(t *testing.T) { - dbMapType := reflect.TypeOf(&sqldb.DbMap{}) - sqlExecutorType := reflect.TypeOf((*sqldb.SqlExecutor)(nil)).Elem() - numDbMapMethods := dbMapType.NumMethod() - for i := 0; i < numDbMapMethods; i += 1 { - dbMapMethod := dbMapType.Method(i) - if !strings.HasPrefix(dbMapMethod.Name, "Select") { - continue - } - if _, found := sqlExecutorType.MethodByName(dbMapMethod.Name); !found { - t.Errorf("Method %s is defined on sqldb.DbMap but not implemented in sqldb.SqlExecutor", - dbMapMethod.Name) - } - } -} - -func TestNullTime(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - // if time is null - ent := &WithNullTime{ - Id: 0, - Time: sqldb.NullTime{ - Valid: false, - }} - err := dbmap.Insert(ent) - if err != nil { - t.Errorf("failed insert on %s", err.Error()) - } - err = dbmap.SelectOne(ent, `select * from nulltime_test where `+columnName(dbmap, WithNullTime{}, "Id")+`=:Id`, map[string]interface{}{ - "Id": ent.Id, - }) - if err != nil { - t.Errorf("failed select on %s", err.Error()) - } - if ent.Time.Valid { - t.Error("sqldb.NullTime returns valid but expected null.") - } - - // if time is not null - ts, err := time.Parse(time.RFC3339, "2001-01-02T15:04:05-07:00") - if err != nil { - t.Errorf("failed to parse time %s: %s", time.Stamp, err.Error()) - } - ent = &WithNullTime{ - Id: 1, - Time: sqldb.NullTime{ - Valid: true, - Time: ts, - }} - err = dbmap.Insert(ent) - if err != nil { - t.Errorf("failed insert on %s", err.Error()) - } - err = dbmap.SelectOne(ent, `select * from nulltime_test where `+columnName(dbmap, WithNullTime{}, "Id")+`=:Id`, map[string]interface{}{ - "Id": ent.Id, - }) - if err != nil { - t.Errorf("failed select on %s", err.Error()) - } - if !ent.Time.Valid { - t.Error("sqldb.NullTime returns invalid but expected valid.") - } - if ent.Time.Time.UTC() != ts.UTC() { - t.Errorf("expect %v but got %v.", ts, ent.Time.Time) - } - - return -} - -type WithTime struct { - Id int64 - Time time.Time -} - -type Times struct { - One time.Time - Two time.Time -} - -type EmbeddedTime struct { - Id string - Times -} - -func parseTimeOrPanic(format, date string) time.Time { - t1, err := time.Parse(format, date) - if err != nil { - panic(err) - } - return t1 -} - -func TestWithTime(t *testing.T) { - if _, driver := dialectAndDriver(); driver == "mysql" { - t.Skip("mysql drivers don't support time.Time, skipping...") - } - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - t1 := parseTimeOrPanic("2006-01-02 15:04:05 -0700 MST", - "2013-08-09 21:30:43 +0800 CST") - w1 := WithTime{1, t1} - _insert(dbmap, &w1) - - obj := _get(dbmap, WithTime{}, w1.Id) - w2 := obj.(*WithTime) - if w1.Time.UnixNano() != w2.Time.UnixNano() { - t.Errorf("%v != %v", w1, w2) - } -} - -func TestEmbeddedTime(t *testing.T) { - if _, driver := dialectAndDriver(); driver == "mysql" { - t.Skip("mysql drivers don't support time.Time, skipping...") - } - dbmap := newDBMap(t) - dbmap.AddTable(EmbeddedTime{}).SetKeys(false, "Id") - defer dropAndClose(dbmap) - err := dbmap.CreateTables() - if err != nil { - t.Fatal(err) - } - - time1 := parseTimeOrPanic("2006-01-02 15:04:05", "2013-08-09 21:30:43") - - t1 := &EmbeddedTime{Id: "abc", Times: Times{One: time1, Two: time1.Add(10 * time.Second)}} - _insert(dbmap, t1) - - x := _get(dbmap, EmbeddedTime{}, t1.Id) - t2, _ := x.(*EmbeddedTime) - if t1.One.UnixNano() != t2.One.UnixNano() || t1.Two.UnixNano() != t2.Two.UnixNano() { - t.Errorf("%v != %v", t1, t2) - } -} - -func TestWithTimeSelect(t *testing.T) { - 
dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - halfhourago := time.Now().UTC().Add(-30 * time.Minute) - - w1 := WithTime{1, halfhourago.Add(time.Minute * -1)} - w2 := WithTime{2, halfhourago.Add(time.Second)} - _insert(dbmap, &w1, &w2) - - var caseIds []int64 - _, err := dbmap.Select(&caseIds, "SELECT "+columnName(dbmap, WithTime{}, "Id")+" FROM time_test WHERE "+columnName(dbmap, WithTime{}, "Time")+" < "+dbmap.Dialect.BindVar(0), halfhourago) - - if err != nil { - t.Error(err) - } - if len(caseIds) != 1 { - t.Errorf("%d != 1", len(caseIds)) - } - if caseIds[0] != w1.Id { - t.Errorf("%d != %d", caseIds[0], w1.Id) - } -} - -func TestInvoicePersonView(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - // Create some rows - p1 := &Person{0, 0, 0, "bob", "smith", 0} - dbmap.Insert(p1) - - // notice how we can wire up p1.Id to the invoice easily - inv1 := &Invoice{0, 0, 0, "xmas order", p1.Id, false} - dbmap.Insert(inv1) - - // Run your query - query := "select i." + columnName(dbmap, Invoice{}, "Id") + " InvoiceId, p." + columnName(dbmap, Person{}, "Id") + " PersonId, i." + columnName(dbmap, Invoice{}, "Memo") + ", p." + columnName(dbmap, Person{}, "FName") + " " + - "from invoice_test i, person_test p " + - "where i." + columnName(dbmap, Invoice{}, "PersonId") + " = p." + columnName(dbmap, Person{}, "Id") - - // pass a slice of pointers to Select() - // this avoids the need to type assert after the query is run - var list []*InvoicePersonView - _, err := dbmap.Select(&list, query) - if err != nil { - panic(err) - } - - // this should test true - expected := &InvoicePersonView{inv1.Id, p1.Id, inv1.Memo, p1.FName, 0} - if !reflect.DeepEqual(list[0], expected) { - t.Errorf("%v != %v", list[0], expected) - } -} - -func TestQuoteTableNames(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - quotedTableName := dbmap.Dialect.QuoteField("person_test") - - // Use a buffer to hold the log to check generated queries - logBuffer := &bytes.Buffer{} - dbmap.TraceOn("", log.New(logBuffer, "sqldbtest:", log.Lmicroseconds)) - - // Create some rows - p1 := &Person{0, 0, 0, "bob", "smith", 0} - errorTemplate := "Expected quoted table name %v in query but didn't find it" - - // Check if Insert quotes the table name - id := dbmap.Insert(p1) - if !bytes.Contains(logBuffer.Bytes(), []byte(quotedTableName)) { - t.Errorf(errorTemplate, quotedTableName) - } - logBuffer.Reset() - - // Check if Get quotes the table name - dbmap.Get(Person{}, id) - if !bytes.Contains(logBuffer.Bytes(), []byte(quotedTableName)) { - t.Errorf(errorTemplate, quotedTableName) - } - logBuffer.Reset() -} - -func TestSelectTooManyCols(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - p1 := &Person{0, 0, 0, "bob", "smith", 0} - p2 := &Person{0, 0, 0, "jane", "doe", 0} - _insert(dbmap, p1) - _insert(dbmap, p2) - - obj := _get(dbmap, Person{}, p1.Id) - p1 = obj.(*Person) - obj = _get(dbmap, Person{}, p2.Id) - p2 = obj.(*Person) - - params := map[string]interface{}{ - "Id": p1.Id, - } - - var p3 FNameOnly - err := dbmap.SelectOne(&p3, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", params) - if err != nil { - if !sqldb.NonFatalError(err) { - t.Error(err) - } - } else { - t.Errorf("Non-fatal error expected") - } - - if p1.FName != p3.FName { - t.Errorf("%v != %v", p1.FName, p3.FName) - } - - var pSlice []FNameOnly - _, err = dbmap.Select(&pSlice, "select * from person_test order by "+columnName(dbmap, Person{}, "FName")+" asc") - if err != nil { - if 
!sqldb.NonFatalError(err) { - t.Error(err) - } - } else { - t.Errorf("Non-fatal error expected") - } - - if p1.FName != pSlice[0].FName { - t.Errorf("%v != %v", p1.FName, pSlice[0].FName) - } - if p2.FName != pSlice[1].FName { - t.Errorf("%v != %v", p2.FName, pSlice[1].FName) - } -} - -func TestSelectSingleVal(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - p1 := &Person{0, 0, 0, "bob", "smith", 0} - _insert(dbmap, p1) - - obj := _get(dbmap, Person{}, p1.Id) - p1 = obj.(*Person) - - params := map[string]interface{}{ - "Id": p1.Id, - } - - var p2 Person - err := dbmap.SelectOne(&p2, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", params) - if err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(p1, &p2) { - t.Errorf("%v != %v", p1, &p2) - } - - // verify SelectOne allows non-struct holders - var s string - err = dbmap.SelectOne(&s, "select "+columnName(dbmap, Person{}, "FName")+" from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", params) - if err != nil { - t.Error(err) - } - if s != "bob" { - t.Error("Expected bob but got: " + s) - } - - // verify SelectOne requires pointer receiver - err = dbmap.SelectOne(s, "select "+columnName(dbmap, Person{}, "FName")+" from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", params) - if err == nil { - t.Error("SelectOne should have returned error for non-pointer holder") - } - - // verify SelectOne works with uninitialized pointers - var p3 *Person - err = dbmap.SelectOne(&p3, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", params) - if err != nil { - t.Error(err) - } - - if !reflect.DeepEqual(p1, p3) { - t.Errorf("%v != %v", p1, p3) - } - - // verify that the receiver is still nil if nothing was found - var p4 *Person - dbmap.SelectOne(&p3, "select * from person_test where 2<1 AND "+columnName(dbmap, Person{}, "Id")+"=:Id", params) - if p4 != nil { - t.Error("SelectOne should not have changed a nil receiver when no rows were found") - } - - // verify that the error is set to sql.ErrNoRows if not found - err = dbmap.SelectOne(&p2, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", map[string]interface{}{ - "Id": -2222, - }) - if err == nil || err != sql.ErrNoRows { - t.Error("SelectOne should have returned an sql.ErrNoRows") - } - - _insert(dbmap, &Person{0, 0, 0, "bob", "smith", 0}) - err = dbmap.SelectOne(&p2, "select * from person_test where "+columnName(dbmap, Person{}, "FName")+"='bob'") - if err == nil { - t.Error("Expected error when two rows found") - } - - // tests for #150 - var tInt int64 - var tStr string - var tBool bool - var tFloat float64 - primVals := []interface{}{tInt, tStr, tBool, tFloat} - for _, prim := range primVals { - err = dbmap.SelectOne(&prim, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+"=-123") - if err == nil || err != sql.ErrNoRows { - t.Error("primVals: SelectOne should have returned sql.ErrNoRows") - } - } -} - -func TestSelectAlias(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - p1 := &IdCreatedExternal{IdCreated: IdCreated{Id: 1, Created: 3}, External: 2} - - // Insert using embedded IdCreated, which reflects the structure of the table - _insert(dbmap, &p1.IdCreated) - - // Select into IdCreatedExternal type, which includes some fields not present - // in id_created_test - var p2 IdCreatedExternal - err := dbmap.SelectOne(&p2, "select * from id_created_test where "+columnName(dbmap, IdCreatedExternal{}, "Id")+"=1") - if err != nil 
{ - t.Error(err) - } - if p2.Id != 1 || p2.Created != 3 || p2.External != 0 { - t.Error("Expected ignored field defaults to not set") - } - - // Prove that we can supply an aliased value in the select, and that it will - // automatically map to IdCreatedExternal.External - err = dbmap.SelectOne(&p2, "SELECT *, 1 AS external FROM id_created_test") - if err != nil { - t.Error(err) - } - if p2.External != 1 { - t.Error("Expected select as can map to exported field.") - } - - var rows *sql.Rows - var cols []string - rows, err = dbmap.Db.Query("SELECT * FROM id_created_test") - cols, err = rows.Columns() - if err != nil || len(cols) != 2 { - t.Error("Expected ignored column not created") - } -} - -func TestMysqlPanicIfDialectNotInitialized(t *testing.T) { - _, driver := dialectAndDriver() - // this test only applies to MySQL - if os.Getenv("SQLDB_TEST_DIALECT") != "mysql" { - return - } - - // The expected behaviour is to catch a panic. - // Here is the deferred function which will check if a panic has indeed occurred : - defer func() { - r := recover() - if r == nil { - t.Error("db.CreateTables() should panic if db is initialized with an incorrect sqldb.MySQLDialect") - } - }() - - // invalid MySQLDialect : does not contain Engine or Encoding specification - dialect := sqldb.MySQLDialect{} - db := &sqldb.DbMap{Db: connect(driver), Dialect: dialect} - db.AddTableWithName(Invoice{}, "invoice") - // the following call should panic : - db.CreateTables() -} - -func TestSingleColumnKeyDbReturnsZeroRowsUpdatedOnPKChange(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - dbmap.AddTableWithName(SingleColumnTable{}, "single_column_table").SetKeys(false, "SomeId") - err := dbmap.DropTablesIfExists() - if err != nil { - t.Error("Drop tables failed") - } - err = dbmap.CreateTablesIfNotExists() - if err != nil { - t.Error("Create tables failed") - } - err = dbmap.TruncateTables() - if err != nil { - t.Error("Truncate tables failed") - } - - sct := SingleColumnTable{ - SomeId: "A Unique Id String", - } - - count, err := dbmap.Update(&sct) - if err != nil { - t.Error(err) - } - if count != 0 { - t.Errorf("Expected 0 updated rows, got %d", count) - } - -} - -func TestPrepare(t *testing.T) { - dbmap := initDBMap(t) - defer dropAndClose(dbmap) - - inv1 := &Invoice{0, 100, 200, "prepare-foo", 0, false} - inv2 := &Invoice{0, 100, 200, "prepare-bar", 0, false} - _insert(dbmap, inv1, inv2) - - bindVar0 := dbmap.Dialect.BindVar(0) - bindVar1 := dbmap.Dialect.BindVar(1) - stmt, err := dbmap.Prepare(fmt.Sprintf("UPDATE invoice_test SET "+columnName(dbmap, Invoice{}, "Memo")+"=%s WHERE "+columnName(dbmap, Invoice{}, "Id")+"=%s", bindVar0, bindVar1)) - if err != nil { - t.Error(err) - } - defer stmt.Close() - _, err = stmt.Exec("prepare-baz", inv1.Id) - if err != nil { - t.Error(err) - } - err = dbmap.SelectOne(inv1, "SELECT * from invoice_test WHERE "+columnName(dbmap, Invoice{}, "Memo")+"='prepare-baz'") - if err != nil { - t.Error(err) - } - - trans, err := dbmap.Begin() - if err != nil { - t.Error(err) - } - transStmt, err := trans.Prepare(fmt.Sprintf("UPDATE invoice_test SET "+columnName(dbmap, Invoice{}, "IsPaid")+"=%s WHERE "+columnName(dbmap, Invoice{}, "Id")+"=%s", bindVar0, bindVar1)) - if err != nil { - t.Error(err) - } - defer transStmt.Close() - _, err = transStmt.Exec(true, inv2.Id) - if err != nil { - t.Error(err) - } - err = dbmap.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE "+columnName(dbmap, Invoice{}, "IsPaid")+"=%s", bindVar0), true) - if err == nil || err != 
sql.ErrNoRows { - t.Error("SelectOne should have returned an sql.ErrNoRows") - } - err = trans.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE "+columnName(dbmap, Invoice{}, "IsPaid")+"=%s", bindVar0), true) - if err != nil { - t.Error(err) - } - err = trans.Commit() - if err != nil { - t.Error(err) - } - err = dbmap.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE "+columnName(dbmap, Invoice{}, "IsPaid")+"=%s", bindVar0), true) - if err != nil { - t.Error(err) - } -} - -type UUID4 string - -func (u UUID4) Value() (driver.Value, error) { - if u == "" { - return nil, nil - } - - return string(u), nil -} - -type NilPointer struct { - ID string - UserID *UUID4 -} - -func TestCallOfValueMethodOnNilPointer(t *testing.T) { - dbmap := newDBMap(t) - dbmap.AddTable(NilPointer{}).SetKeys(false, "ID") - defer dropAndClose(dbmap) - err := dbmap.CreateTables() - if err != nil { - t.Fatal(err) - } - - nilPointer := &NilPointer{ID: "abc", UserID: nil} - _insert(dbmap, nilPointer) -} - -func BenchmarkNativeCrud(b *testing.B) { - b.StopTimer() - dbmap := initDBMapBench(b) - defer dropAndClose(dbmap) - columnId := columnName(dbmap, Invoice{}, "Id") - columnCreated := columnName(dbmap, Invoice{}, "Created") - columnUpdated := columnName(dbmap, Invoice{}, "Updated") - columnMemo := columnName(dbmap, Invoice{}, "Memo") - columnPersonId := columnName(dbmap, Invoice{}, "PersonId") - b.StartTimer() - - var insert, sel, update, delete string - if os.Getenv("SQLDB_TEST_DIALECT") != "postgres" { - insert = "insert into invoice_test (" + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + ") values (?, ?, ?, ?)" - sel = "select " + columnId + ", " + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + " from invoice_test where " + columnId + "=?" - update = "update invoice_test set " + columnCreated + "=?, " + columnUpdated + "=?, " + columnMemo + "=?, " + columnPersonId + "=? where " + columnId + "=?" - delete = "delete from invoice_test where " + columnId + "=?" 
- } else { - insert = "insert into invoice_test (" + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + ") values ($1, $2, $3, $4)" - sel = "select " + columnId + ", " + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + " from invoice_test where " + columnId + "=$1" - update = "update invoice_test set " + columnCreated + "=$1, " + columnUpdated + "=$2, " + columnMemo + "=$3, " + columnPersonId + "=$4 where " + columnId + "=$5" - delete = "delete from invoice_test where " + columnId + "=$1" - } - - inv := &Invoice{0, 100, 200, "my memo", 0, false} - - for i := 0; i < b.N; i++ { - res, err := dbmap.Db.Exec(insert, inv.Created, inv.Updated, - inv.Memo, inv.PersonId) - if err != nil { - panic(err) - } - - newid, err := res.LastInsertId() - if err != nil { - panic(err) - } - inv.Id = newid - - row := dbmap.Db.QueryRow(sel, inv.Id) - err = row.Scan(&inv.Id, &inv.Created, &inv.Updated, &inv.Memo, - &inv.PersonId) - if err != nil { - panic(err) - } - - inv.Created = 1000 - inv.Updated = 2000 - inv.Memo = "my memo 2" - inv.PersonId = 3000 - - _, err = dbmap.Db.Exec(update, inv.Created, inv.Updated, inv.Memo, - inv.PersonId, inv.Id) - if err != nil { - panic(err) - } - - _, err = dbmap.Db.Exec(delete, inv.Id) - if err != nil { - panic(err) - } - } - -} - -func BenchmarkSqldbCrud(b *testing.B) { - b.StopTimer() - dbmap := initDBMapBench(b) - defer dropAndClose(dbmap) - b.StartTimer() - - inv := &Invoice{0, 100, 200, "my memo", 0, true} - for i := 0; i < b.N; i++ { - err := dbmap.Insert(inv) - if err != nil { - panic(err) - } - - obj, err := dbmap.Get(Invoice{}, inv.Id) - if err != nil { - panic(err) - } - - inv2, ok := obj.(*Invoice) - if !ok { - panic(fmt.Sprintf("expected *Invoice, got: %v", obj)) - } - - inv2.Created = 1000 - inv2.Updated = 2000 - inv2.Memo = "my memo 2" - inv2.PersonId = 3000 - _, err = dbmap.Update(inv2) - if err != nil { - panic(err) - } - - _, err = dbmap.Delete(inv2) - if err != nil { - panic(err) - } - - } -} - -func initDBMapBench(b *testing.B) *sqldb.DbMap { - dbmap := newDBMap(b) - dbmap.Db.Exec("drop table if exists invoice_test") - dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id") - err := dbmap.CreateTables() - if err != nil { - panic(err) - } - return dbmap -} - -func initDBMap(t *testing.T) *sqldb.DbMap { - dbmap := newDBMap(t) - dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id") - dbmap.AddTableWithName(InvoiceTag{}, "invoice_tag_test") //key is set via primarykey attribute - dbmap.AddTableWithName(AliasTransientField{}, "alias_trans_field_test").SetKeys(true, "id") - dbmap.AddTableWithName(OverriddenInvoice{}, "invoice_override_test").SetKeys(false, "Id") - dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id").SetVersionCol("Version") - dbmap.AddTableWithName(WithIgnoredColumn{}, "ignored_column_test").SetKeys(true, "Id") - dbmap.AddTableWithName(IdCreated{}, "id_created_test").SetKeys(true, "Id") - dbmap.AddTableWithName(TypeConversionExample{}, "type_conv_test").SetKeys(true, "Id") - dbmap.AddTableWithName(WithEmbeddedStruct{}, "embedded_struct_test").SetKeys(true, "Id") - //dbmap.AddTableWithName(WithEmbeddedStructConflictingEmbeddedMemberNames{}, "embedded_struct_conflict_name_test").SetKeys(true, "Id") - //dbmap.AddTableWithName(WithEmbeddedStructSameMemberName{}, "embedded_struct_same_member_name_test").SetKeys(true, "Id") - dbmap.AddTableWithName(WithEmbeddedStructBeforeAutoincrField{}, "embedded_struct_before_autoincr_test").SetKeys(true, 
"Id") - dbmap.AddTableDynamic(&dynTableInst1, "").SetKeys(true, "Id").AddIndex("TenantInst1Index", "Btree", []string{"Name"}).SetUnique(true) - dbmap.AddTableDynamic(&dynTableInst2, "").SetKeys(true, "Id").AddIndex("TenantInst2Index", "Btree", []string{"Name"}).SetUnique(true) - dbmap.AddTableWithName(WithEmbeddedAutoincr{}, "embedded_autoincr_test").SetKeys(true, "Id") - dbmap.AddTableWithName(WithTime{}, "time_test").SetKeys(true, "Id") - dbmap.AddTableWithName(WithNullTime{}, "nulltime_test").SetKeys(false, "Id") - dbmap.TypeConverter = testTypeConverter{} - err := dbmap.DropTablesIfExists() - if err != nil { - panic(err) - } - err = dbmap.CreateTables() - if err != nil { - panic(err) - } - - err = dbmap.CreateIndex() - if err != nil { - panic(err) - } - - // See #146 and TestSelectAlias - this type is mapped to the same - // table as IdCreated, but includes an extra field that isn't in the table - dbmap.AddTableWithName(IdCreatedExternal{}, "id_created_test").SetKeys(true, "Id") - - return dbmap -} - -func initDBMapNulls(t *testing.T) *sqldb.DbMap { - dbmap := newDBMap(t) - dbmap.AddTable(TableWithNull{}).SetKeys(false, "Id") - err := dbmap.CreateTables() - if err != nil { - panic(err) - } - return dbmap -} - -type Logger interface { - Logf(format string, args ...any) -} - -type TestLogger struct { - l Logger -} - -func (l TestLogger) Printf(format string, args ...any) { - l.l.Logf(format, args...) -} - -func newDBMap(l Logger) *sqldb.DbMap { - dialect, driver := dialectAndDriver() - dbmap := &sqldb.DbMap{Db: connect(driver), Dialect: dialect} - if debug { - dbmap.TraceOn("", TestLogger{l: l}) - } - return dbmap -} - -func dropAndClose(dbmap *sqldb.DbMap) { - dbmap.DropTablesIfExists() - dbmap.Db.Close() -} - -func connect(driver string) *sql.DB { - dsn := os.Getenv("SQLDB_TEST_DSN") - if dsn == "" { - panic("SQLDB_TEST_DSN env variable is not set. Please see README.md") - } - - db, err := sql.Open(driver, dsn) - if err != nil { - panic("Error connecting to db: " + err.Error()) - } - return db -} - -func dialectAndDriver() (sqldb.Dialect, string) { - switch os.Getenv("SQLDB_TEST_DIALECT") { - case "mysql", "gomysql": - // NOTE: the 'mysql' driver used to use github.com/ziutek/mymysql, but that project - // seems mostly unmaintained recently. We've dropped it from tests, at least for - // now. - return sqldb.MySQLDialect{"InnoDB", "UTF8"}, "mysql" - case "postgres": - return sqldb.PostgresDialect{}, "postgres" - case "sqlite": - return sqldb.SqliteDialect{}, "sqlite3" - } - panic("SQLDB_TEST_DIALECT env variable is not set or is invalid. Please see README.md") -} - -func _insert(dbmap *sqldb.DbMap, list ...interface{}) { - err := dbmap.Insert(list...) - if err != nil { - panic(err) - } -} - -func _update(dbmap *sqldb.DbMap, list ...interface{}) int64 { - count, err := dbmap.Update(list...) - if err != nil { - panic(err) - } - return count -} - -func _updateColumns(dbmap *sqldb.DbMap, filter sqldb.ColumnFilter, list ...interface{}) int64 { - count, err := dbmap.UpdateColumns(filter, list...) - if err != nil { - panic(err) - } - return count -} - -func _del(dbmap *sqldb.DbMap, list ...interface{}) int64 { - count, err := dbmap.Delete(list...) - if err != nil { - panic(err) - } - - return count -} - -func _get(dbmap *sqldb.DbMap, i interface{}, keys ...interface{}) interface{} { - obj, err := dbmap.Get(i, keys...) 
- if err != nil { - panic(err) - } - - return obj -} - -func selectInt(dbmap *sqldb.DbMap, query string, args ...interface{}) int64 { - i64, err := sqldb.SelectInt(dbmap, query, args...) - if err != nil { - panic(err) - } - - return i64 -} - -func selectNullInt(dbmap *sqldb.DbMap, query string, args ...interface{}) sql.NullInt64 { - i64, err := sqldb.SelectNullInt(dbmap, query, args...) - if err != nil { - panic(err) - } - - return i64 -} - -func selectFloat(dbmap *sqldb.DbMap, query string, args ...interface{}) float64 { - f64, err := sqldb.SelectFloat(dbmap, query, args...) - if err != nil { - panic(err) - } - - return f64 -} - -func selectNullFloat(dbmap *sqldb.DbMap, query string, args ...interface{}) sql.NullFloat64 { - f64, err := sqldb.SelectNullFloat(dbmap, query, args...) - if err != nil { - panic(err) - } - - return f64 -} - -func selectStr(dbmap *sqldb.DbMap, query string, args ...interface{}) string { - s, err := sqldb.SelectStr(dbmap, query, args...) - if err != nil { - panic(err) - } - - return s -} - -func selectNullStr(dbmap *sqldb.DbMap, query string, args ...interface{}) sql.NullString { - s, err := sqldb.SelectNullStr(dbmap, query, args...) - if err != nil { - panic(err) - } - - return s -} - -func rawExec(dbmap *sqldb.DbMap, query string, args ...interface{}) sql.Result { - res, err := dbmap.Exec(query, args...) - if err != nil { - panic(err) - } - return res -} - -func rawSelect(dbmap *sqldb.DbMap, i interface{}, query string, args ...interface{}) []interface{} { - list, err := dbmap.Select(i, query, args...) - if err != nil { - panic(err) - } - return list -} - -func tableName(dbmap *sqldb.DbMap, i interface{}) string { - t := reflect.TypeOf(i) - if table, err := dbmap.TableFor(t, false); table != nil && err == nil { - return dbmap.Dialect.QuoteField(table.TableName) - } - return t.Name() -} - -func columnName(dbmap *sqldb.DbMap, i interface{}, fieldName string) string { - t := reflect.TypeOf(i) - if table, err := dbmap.TableFor(t, false); table != nil && err == nil { - return dbmap.Dialect.QuoteField(table.ColMap(fieldName).ColumnName) - } - return fieldName -} diff --git a/gdb/sqldb/table.go b/gdb/sqldb/table.go deleted file mode 100644 index c628aa9..0000000 --- a/gdb/sqldb/table.go +++ /dev/null @@ -1,258 +0,0 @@ -// -// table.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "bytes" - "fmt" - "reflect" - "strings" -) - -// TableMap represents a mapping between a Go struct and a database table -// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these -type TableMap struct { - // Name of database table. - TableName string - SchemaName string - gotype reflect.Type - Columns []*ColumnMap - keys []*ColumnMap - indexes []*IndexMap - uniqueTogether [][]string - version *ColumnMap - insertPlan bindPlan - updatePlan bindPlan - deletePlan bindPlan - getPlan bindPlan - dbmap *DbMap -} - -// ResetSql removes cached insert/update/select/delete SQL strings -// associated with this TableMap. Call this if you've modified -// any column names or the table name itself. -func (t *TableMap) ResetSql() { - t.insertPlan = bindPlan{} - t.updatePlan = bindPlan{} - t.deletePlan = bindPlan{} - t.getPlan = bindPlan{} -} - -// SetKeys lets you specify the fields on a struct that map to primary -// key columns on the table. If isAutoIncr is set, result.LastInsertId() -// will be used after INSERT to bind the generated id to the Go struct. 
-// -// Automatically calls ResetSql() to ensure SQL statements are regenerated. -// -// Panics if isAutoIncr is true, and fieldNames length != 1 -func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap { - if isAutoIncr && len(fieldNames) != 1 { - panic(fmt.Sprintf( - "sqldb: SetKeys: fieldNames length must be 1 if key is auto-increment. (Saw %v fieldNames)", - len(fieldNames))) - } - t.keys = make([]*ColumnMap, 0) - for _, name := range fieldNames { - colmap := t.ColMap(name) - colmap.isPK = true - colmap.isAutoIncr = isAutoIncr - t.keys = append(t.keys, colmap) - } - t.ResetSql() - - return t -} - -// SetUniqueTogether lets you specify uniqueness constraints across multiple -// columns on the table. Each call adds an additional constraint for the -// specified columns. -// -// Automatically calls ResetSql() to ensure SQL statements are regenerated. -// -// Panics if fieldNames length < 2. -func (t *TableMap) SetUniqueTogether(fieldNames ...string) *TableMap { - if len(fieldNames) < 2 { - panic(fmt.Sprintf( - "sqldb: SetUniqueTogether: must provide at least two fieldNames to set uniqueness constraint.")) - } - - columns := make([]string, 0, len(fieldNames)) - for _, name := range fieldNames { - columns = append(columns, name) - } - - for _, existingColumns := range t.uniqueTogether { - if equal(existingColumns, columns) { - return t - } - } - t.uniqueTogether = append(t.uniqueTogether, columns) - t.ResetSql() - - return t -} - -// ColMap returns the ColumnMap pointer matching the given struct field -// name. It panics if the struct does not contain a field matching this -// name. -func (t *TableMap) ColMap(field string) *ColumnMap { - col := colMapOrNil(t, field) - if col == nil { - e := fmt.Sprintf("No ColumnMap in table %s type %s with field %s", - t.TableName, t.gotype.Name(), field) - - panic(e) - } - return col -} - -func colMapOrNil(t *TableMap, field string) *ColumnMap { - for _, col := range t.Columns { - if col.fieldName == field || col.ColumnName == field { - return col - } - } - return nil -} - -// IdxMap returns the IndexMap pointer matching the given index name. -func (t *TableMap) IdxMap(field string) *IndexMap { - for _, idx := range t.indexes { - if idx.IndexName == field { - return idx - } - } - return nil -} - -// AddIndex registers an index with sqldb for the specified table with the given parameters. -// This operation is idempotent. If the index is already mapped, the -// existing *IndexMap is returned. -// Panics if one of the given index columns does not exist. -// -// Automatically calls ResetSql() to ensure SQL statements are regenerated. -func (t *TableMap) AddIndex(name string, idxtype string, columns []string) *IndexMap { - // check if we have an index with this name already - for _, idx := range t.indexes { - if idx.IndexName == name { - return idx - } - } - for _, icol := range columns { - if res := t.ColMap(icol); res == nil { - e := fmt.Sprintf("No ColumnName in table %s to create index on", t.TableName) - panic(e) - } - } - - idx := &IndexMap{IndexName: name, Unique: false, IndexType: idxtype, columns: columns} - t.indexes = append(t.indexes, idx) - t.ResetSql() - return idx -} - -// SetVersionCol sets the column to use as the Version field. By default -// the "Version" field is used. Returns the column found, or panics -// if the struct does not contain a field matching this name. -// -// Automatically calls ResetSql() to ensure SQL statements are regenerated. 
-func (t *TableMap) SetVersionCol(field string) *ColumnMap { - c := t.ColMap(field) - t.version = c - t.ResetSql() - return c -} - -// SqlForCreate gets a sequence of SQL commands that will create -// the specified table and any associated schema -func (t *TableMap) SqlForCreate(ifNotExists bool) string { - s := bytes.Buffer{} - dialect := t.dbmap.Dialect - - if strings.TrimSpace(t.SchemaName) != "" { - schemaCreate := "create schema" - if ifNotExists { - s.WriteString(dialect.IfSchemaNotExists(schemaCreate, t.SchemaName)) - } else { - s.WriteString(schemaCreate) - } - s.WriteString(fmt.Sprintf(" %s;", t.SchemaName)) - } - - tableCreate := "create table" - if ifNotExists { - s.WriteString(dialect.IfTableNotExists(tableCreate, t.SchemaName, t.TableName)) - } else { - s.WriteString(tableCreate) - } - s.WriteString(fmt.Sprintf(" %s (", dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) - - x := 0 - for _, col := range t.Columns { - if !col.Transient { - if x > 0 { - s.WriteString(", ") - } - stype := dialect.ToSqlType(col.gotype, col.MaxSize, col.isAutoIncr) - s.WriteString(fmt.Sprintf("%s %s", dialect.QuoteField(col.ColumnName), stype)) - - if col.isPK || col.isNotNull { - s.WriteString(" not null") - } - if col.isPK && len(t.keys) == 1 { - s.WriteString(" primary key") - } - if col.Unique { - s.WriteString(" unique") - } - if col.isAutoIncr { - s.WriteString(fmt.Sprintf(" %s", dialect.AutoIncrStr())) - } - - x++ - } - } - if len(t.keys) > 1 { - s.WriteString(", primary key (") - for x := range t.keys { - if x > 0 { - s.WriteString(", ") - } - s.WriteString(dialect.QuoteField(t.keys[x].ColumnName)) - } - s.WriteString(")") - } - if len(t.uniqueTogether) > 0 { - for _, columns := range t.uniqueTogether { - s.WriteString(", unique (") - for i, column := range columns { - if i > 0 { - s.WriteString(", ") - } - s.WriteString(dialect.QuoteField(column)) - } - s.WriteString(")") - } - } - s.WriteString(") ") - s.WriteString(dialect.CreateTableSuffix()) - s.WriteString(dialect.QuerySuffix()) - return s.String() -} - -func equal(a, b []string) bool { - if len(a) != len(b) { - return false - } - for i := range a { - if a[i] != b[i] { - return false - } - } - return true -} diff --git a/gdb/sqldb/table_bindings.go b/gdb/sqldb/table_bindings.go deleted file mode 100644 index 13a4ca8..0000000 --- a/gdb/sqldb/table_bindings.go +++ /dev/null @@ -1,308 +0,0 @@ -// -// table_bindings.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "bytes" - "fmt" - "reflect" - "sync" -) - -// CustomScanner binds a database column value to a Go type -type CustomScanner struct { - // After a row is scanned, Holder will contain the value from the database column. - // Initialize the CustomScanner with the concrete Go type you wish the database - // driver to scan the raw column into. - Holder interface{} - // Target typically holds a pointer to the target struct field to bind the Holder - // value to. - Target interface{} - // Binder is a custom function that converts the holder value to the target type - // and sets target accordingly. This function should return an error if a problem - // occurs converting the holder to the target. 
- Binder func(holder interface{}, target interface{}) error -} - -// Used to filter columns when selectively updating -type ColumnFilter func(*ColumnMap) bool - -func acceptAllFilter(col *ColumnMap) bool { - return true -} - -// Bind is called automatically by sqldb after Scan() -func (me CustomScanner) Bind() error { - return me.Binder(me.Holder, me.Target) -} - -type bindPlan struct { - query string - argFields []string - keyFields []string - versField string - autoIncrIdx int - autoIncrFieldName string - once sync.Once -} - -func (plan *bindPlan) createBindInstance(elem reflect.Value, conv TypeConverter) (bindInstance, error) { - bi := bindInstance{query: plan.query, autoIncrIdx: plan.autoIncrIdx, autoIncrFieldName: plan.autoIncrFieldName, versField: plan.versField} - if plan.versField != "" { - bi.existingVersion = elem.FieldByName(plan.versField).Int() - } - - var err error - - for i := 0; i < len(plan.argFields); i++ { - k := plan.argFields[i] - if k == versFieldConst { - newVer := bi.existingVersion + 1 - bi.args = append(bi.args, newVer) - if bi.existingVersion == 0 { - elem.FieldByName(plan.versField).SetInt(int64(newVer)) - } - } else { - val := elem.FieldByName(k).Interface() - if conv != nil { - val, err = conv.ToDb(val) - if err != nil { - return bindInstance{}, err - } - } - bi.args = append(bi.args, val) - } - } - - for i := 0; i < len(plan.keyFields); i++ { - k := plan.keyFields[i] - val := elem.FieldByName(k).Interface() - if conv != nil { - val, err = conv.ToDb(val) - if err != nil { - return bindInstance{}, err - } - } - bi.keys = append(bi.keys, val) - } - - return bi, nil -} - -type bindInstance struct { - query string - args []interface{} - keys []interface{} - existingVersion int64 - versField string - autoIncrIdx int - autoIncrFieldName string -} - -func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { - plan := &t.insertPlan - plan.once.Do(func() { - plan.autoIncrIdx = -1 - - s := bytes.Buffer{} - s2 := bytes.Buffer{} - s.WriteString(fmt.Sprintf("insert into %s (", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) - - x := 0 - first := true - for y := range t.Columns { - col := t.Columns[y] - if !(col.isAutoIncr && t.dbmap.Dialect.AutoIncrBindValue() == "") { - if !col.Transient { - if !first { - s.WriteString(",") - s2.WriteString(",") - } - s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) - - if col.isAutoIncr { - s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue()) - plan.autoIncrIdx = y - plan.autoIncrFieldName = col.fieldName - } else { - if col.DefaultValue == "" { - s2.WriteString(t.dbmap.Dialect.BindVar(x)) - if col == t.version { - plan.versField = col.fieldName - plan.argFields = append(plan.argFields, versFieldConst) - } else { - plan.argFields = append(plan.argFields, col.fieldName) - } - x++ - } else { - s2.WriteString(col.DefaultValue) - } - } - first = false - } - } else { - plan.autoIncrIdx = y - plan.autoIncrFieldName = col.fieldName - } - } - s.WriteString(") values (") - s.WriteString(s2.String()) - s.WriteString(")") - if plan.autoIncrIdx > -1 { - s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx])) - } - s.WriteString(t.dbmap.Dialect.QuerySuffix()) - - plan.query = s.String() - }) - - return plan.createBindInstance(elem, t.dbmap.TypeConverter) -} - -func (t *TableMap) bindUpdate(elem reflect.Value, colFilter ColumnFilter) (bindInstance, error) { - if colFilter == nil { - colFilter = acceptAllFilter - } - - plan := &t.updatePlan - plan.once.Do(func() { - s := bytes.Buffer{} - 
s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) - x := 0 - - for y := range t.Columns { - col := t.Columns[y] - if !col.isAutoIncr && !col.Transient && colFilter(col) { - if x > 0 { - s.WriteString(", ") - } - s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) - s.WriteString("=") - s.WriteString(t.dbmap.Dialect.BindVar(x)) - - if col == t.version { - plan.versField = col.fieldName - plan.argFields = append(plan.argFields, versFieldConst) - } else { - plan.argFields = append(plan.argFields, col.fieldName) - } - x++ - } - } - - s.WriteString(" where ") - for y := range t.keys { - col := t.keys[y] - if y > 0 { - s.WriteString(" and ") - } - s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) - s.WriteString("=") - s.WriteString(t.dbmap.Dialect.BindVar(x)) - - plan.argFields = append(plan.argFields, col.fieldName) - plan.keyFields = append(plan.keyFields, col.fieldName) - x++ - } - if plan.versField != "" { - s.WriteString(" and ") - s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName)) - s.WriteString("=") - s.WriteString(t.dbmap.Dialect.BindVar(x)) - plan.argFields = append(plan.argFields, plan.versField) - } - s.WriteString(t.dbmap.Dialect.QuerySuffix()) - - plan.query = s.String() - }) - - return plan.createBindInstance(elem, t.dbmap.TypeConverter) -} - -func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) { - plan := &t.deletePlan - plan.once.Do(func() { - s := bytes.Buffer{} - s.WriteString(fmt.Sprintf("delete from %s", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) - - for y := range t.Columns { - col := t.Columns[y] - if !col.Transient { - if col == t.version { - plan.versField = col.fieldName - } - } - } - - s.WriteString(" where ") - for x := range t.keys { - k := t.keys[x] - if x > 0 { - s.WriteString(" and ") - } - s.WriteString(t.dbmap.Dialect.QuoteField(k.ColumnName)) - s.WriteString("=") - s.WriteString(t.dbmap.Dialect.BindVar(x)) - - plan.keyFields = append(plan.keyFields, k.fieldName) - plan.argFields = append(plan.argFields, k.fieldName) - } - if plan.versField != "" { - s.WriteString(" and ") - s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName)) - s.WriteString("=") - s.WriteString(t.dbmap.Dialect.BindVar(len(plan.argFields))) - - plan.argFields = append(plan.argFields, plan.versField) - } - s.WriteString(t.dbmap.Dialect.QuerySuffix()) - - plan.query = s.String() - }) - - return plan.createBindInstance(elem, t.dbmap.TypeConverter) -} - -func (t *TableMap) bindGet() *bindPlan { - plan := &t.getPlan - plan.once.Do(func() { - s := bytes.Buffer{} - s.WriteString("select ") - - x := 0 - for _, col := range t.Columns { - if !col.Transient { - if x > 0 { - s.WriteString(",") - } - s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) - plan.argFields = append(plan.argFields, col.fieldName) - x++ - } - } - s.WriteString(" from ") - s.WriteString(t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)) - s.WriteString(" where ") - for x := range t.keys { - col := t.keys[x] - if x > 0 { - s.WriteString(" and ") - } - s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) - s.WriteString("=") - s.WriteString(t.dbmap.Dialect.BindVar(x)) - - plan.keyFields = append(plan.keyFields, col.fieldName) - } - s.WriteString(t.dbmap.Dialect.QuerySuffix()) - - plan.query = s.String() - }) - - return plan -} diff --git a/gdb/sqldb/test_all.sh b/gdb/sqldb/test_all.sh deleted file mode 100755 index 0d4a549..0000000 --- a/gdb/sqldb/test_all.sh +++ /dev/null 
@@ -1,24 +0,0 @@ -#!/bin/bash -ex - -# on macs, you may need to: -# export GOBUILDFLAG=-ldflags -linkmode=external - -echo "Running unit tests" -go test -race - -echo "Testing against postgres" -export SQLDB_TEST_DSN="host=127.0.0.1 user=testuser password=123 dbname=testdb sslmode=disable" -export SQLDB_TEST_DIALECT=postgres -go test -tags integration $GOBUILDFLAG $@ . - -echo "Testing against sqlite" -export SQLDB_TEST_DSN=/tmp/testdb.bin -export SQLDB_TEST_DIALECT=sqlite -go test -tags integration $GOBUILDFLAG $@ . -rm -f /tmp/testdb.bin - -echo "Testing against mysql" -# export SQLDB_TEST_DSN="testuser:123@tcp(127.0.0.1:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local" -export SQLDB_TEST_DSN="testuser:123@tcp(127.0.0.1:3306)/testdb" -export SQLDB_TEST_DIALECT=mysql -go test -tags integration $GOBUILDFLAG $@ . diff --git a/gdb/sqldb/transaction.go b/gdb/sqldb/transaction.go deleted file mode 100644 index b6d0564..0000000 --- a/gdb/sqldb/transaction.go +++ /dev/null @@ -1,242 +0,0 @@ -// -// transaction.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "context" - "database/sql" - "time" -) - -// Transaction represents a database transaction. -// Insert/Update/Delete/Get/Exec operations will be run in the context -// of that transaction. Transactions should be terminated with -// a call to Commit() or Rollback() -type Transaction struct { - ctx context.Context - dbmap *DbMap - tx *sql.Tx - closed bool -} - -func (t *Transaction) WithContext(ctx context.Context) SqlExecutor { - copy := &Transaction{} - *copy = *t - copy.ctx = ctx - return copy -} - -// Insert has the same behavior as DbMap.Insert(), but runs in a transaction. -func (t *Transaction) Insert(list ...interface{}) error { - return insert(t.dbmap, t, list...) -} - -// Update had the same behavior as DbMap.Update(), but runs in a transaction. -func (t *Transaction) Update(list ...interface{}) (int64, error) { - return update(t.dbmap, t, nil, list...) -} - -// UpdateColumns had the same behavior as DbMap.UpdateColumns(), but runs in a transaction. -func (t *Transaction) UpdateColumns(filter ColumnFilter, list ...interface{}) (int64, error) { - return update(t.dbmap, t, filter, list...) -} - -// Delete has the same behavior as DbMap.Delete(), but runs in a transaction. -func (t *Transaction) Delete(list ...interface{}) (int64, error) { - return delete(t.dbmap, t, list...) -} - -// Get has the same behavior as DbMap.Get(), but runs in a transaction. -func (t *Transaction) Get(i interface{}, keys ...interface{}) (interface{}, error) { - return get(t.dbmap, t, i, keys...) -} - -// Select has the same behavior as DbMap.Select(), but runs in a transaction. -func (t *Transaction) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { - if t.dbmap.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return hookedselect(t.dbmap, t, i, query, args...) -} - -// Exec has the same behavior as DbMap.Exec(), but runs in a transaction. -func (t *Transaction) Exec(query string, args ...interface{}) (sql.Result, error) { - if t.dbmap.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - if t.dbmap.logger != nil { - now := time.Now() - defer t.dbmap.trace(now, query, args...) - } - return maybeExpandNamedQueryAndExec(t, query, args...) -} - -// SelectInt is a convenience wrapper around the sqldb.SelectInt function. 
-func (t *Transaction) SelectInt(query string, args ...interface{}) (int64, error) { - if t.dbmap.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectInt(t, query, args...) -} - -// SelectNullInt is a convenience wrapper around the sqldb.SelectNullInt function. -func (t *Transaction) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) { - if t.dbmap.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectNullInt(t, query, args...) -} - -// SelectFloat is a convenience wrapper around the sqldb.SelectFloat function. -func (t *Transaction) SelectFloat(query string, args ...interface{}) (float64, error) { - if t.dbmap.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectFloat(t, query, args...) -} - -// SelectNullFloat is a convenience wrapper around the sqldb.SelectNullFloat function. -func (t *Transaction) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) { - if t.dbmap.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectNullFloat(t, query, args...) -} - -// SelectStr is a convenience wrapper around the sqldb.SelectStr function. -func (t *Transaction) SelectStr(query string, args ...interface{}) (string, error) { - if t.dbmap.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectStr(t, query, args...) -} - -// SelectNullStr is a convenience wrapper around the sqldb.SelectNullStr function. -func (t *Transaction) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) { - if t.dbmap.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectNullStr(t, query, args...) -} - -// SelectOne is a convenience wrapper around the sqldb.SelectOne function. -func (t *Transaction) SelectOne(holder interface{}, query string, args ...interface{}) error { - if t.dbmap.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - return SelectOne(t.dbmap, t, holder, query, args...) -} - -// Commit commits the underlying database transaction. -func (t *Transaction) Commit() error { - if !t.closed { - t.closed = true - if t.dbmap.logger != nil { - now := time.Now() - defer t.dbmap.trace(now, "commit;") - } - return t.tx.Commit() - } - - return sql.ErrTxDone -} - -// Rollback rolls back the underlying database transaction. -func (t *Transaction) Rollback() error { - if !t.closed { - t.closed = true - if t.dbmap.logger != nil { - now := time.Now() - defer t.dbmap.trace(now, "rollback;") - } - return t.tx.Rollback() - } - - return sql.ErrTxDone -} - -// Savepoint creates a savepoint with the given name. The name is interpolated -// directly into the SQL SAVEPOINT statement, so you must sanitize it if it is -// derived from user input. -func (t *Transaction) Savepoint(name string) error { - query := "savepoint " + t.dbmap.Dialect.QuoteField(name) - if t.dbmap.logger != nil { - now := time.Now() - defer t.dbmap.trace(now, query, nil) - } - _, err := exec(t, query) - return err -} - -// RollbackToSavepoint rolls back to the savepoint with the given name. The -// name is interpolated directly into the SQL SAVEPOINT statement, so you must -// sanitize it if it is derived from user input. -func (t *Transaction) RollbackToSavepoint(savepoint string) error { - query := "rollback to savepoint " + t.dbmap.Dialect.QuoteField(savepoint) - if t.dbmap.logger != nil { - now := time.Now() - defer t.dbmap.trace(now, query, nil) - } - _, err := exec(t, query) - return err -} - -// ReleaseSavepint releases the savepoint with the given name. 
The name is -// interpolated directly into the SQL SAVEPOINT statement, so you must sanitize -// it if it is derived from user input. -func (t *Transaction) ReleaseSavepoint(savepoint string) error { - query := "release savepoint " + t.dbmap.Dialect.QuoteField(savepoint) - if t.dbmap.logger != nil { - now := time.Now() - defer t.dbmap.trace(now, query, nil) - } - _, err := exec(t, query) - return err -} - -// Prepare has the same behavior as DbMap.Prepare(), but runs in a transaction. -func (t *Transaction) Prepare(query string) (*sql.Stmt, error) { - if t.dbmap.logger != nil { - now := time.Now() - defer t.dbmap.trace(now, query, nil) - } - return prepare(t, query) -} - -func (t *Transaction) QueryRow(query string, args ...interface{}) *sql.Row { - if t.dbmap.ExpandSliceArgs { - expandSliceArgs(&query, args...) - } - - if t.dbmap.logger != nil { - now := time.Now() - defer t.dbmap.trace(now, query, args...) - } - return queryRow(t, query, args...) -} - -func (t *Transaction) Query(q string, args ...interface{}) (*sql.Rows, error) { - if t.dbmap.ExpandSliceArgs { - expandSliceArgs(&q, args...) - } - - if t.dbmap.logger != nil { - now := time.Now() - defer t.dbmap.trace(now, q, args...) - } - return query(t, q, args...) -} diff --git a/gdb/sqldb/transaction_test.go b/gdb/sqldb/transaction_test.go deleted file mode 100644 index 0686a8d..0000000 --- a/gdb/sqldb/transaction_test.go +++ /dev/null @@ -1,340 +0,0 @@ -// -// transaction_test.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -//go:build integration -// +build integration - -package sqldb_test - -import "testing" - -func TestTransaction_Select_expandSliceArgs(t *testing.T) { - tests := []struct { - description string - query string - args []interface{} - wantLen int - }{ - { - description: "it should handle slice placeholders correctly", - query: ` -SELECT 1 FROM crazy_table -WHERE field1 = :Field1 -AND field2 IN (:FieldStringList) -AND field3 IN (:FieldUIntList) -AND field4 IN (:FieldUInt8List) -AND field5 IN (:FieldUInt16List) -AND field6 IN (:FieldUInt32List) -AND field7 IN (:FieldUInt64List) -AND field8 IN (:FieldIntList) -AND field9 IN (:FieldInt8List) -AND field10 IN (:FieldInt16List) -AND field11 IN (:FieldInt32List) -AND field12 IN (:FieldInt64List) -AND field13 IN (:FieldFloat32List) -AND field14 IN (:FieldFloat64List) -`, - args: []interface{}{ - map[string]interface{}{ - "Field1": 123, - "FieldStringList": []string{"h", "e", "y"}, - "FieldUIntList": []uint{1, 2, 3, 4}, - "FieldUInt8List": []uint8{1, 2, 3, 4}, - "FieldUInt16List": []uint16{1, 2, 3, 4}, - "FieldUInt32List": []uint32{1, 2, 3, 4}, - "FieldUInt64List": []uint64{1, 2, 3, 4}, - "FieldIntList": []int{1, 2, 3, 4}, - "FieldInt8List": []int8{1, 2, 3, 4}, - "FieldInt16List": []int16{1, 2, 3, 4}, - "FieldInt32List": []int32{1, 2, 3, 4}, - "FieldInt64List": []int64{1, 2, 3, 4}, - "FieldFloat32List": []float32{1, 2, 3, 4}, - "FieldFloat64List": []float64{1, 2, 3, 4}, - }, - }, - wantLen: 1, - }, - { - description: "it should handle slice placeholders correctly with custom types", - query: ` -SELECT 1 FROM crazy_table -WHERE field2 IN (:FieldStringList) -AND field12 IN (:FieldIntList) -`, - args: []interface{}{ - map[string]interface{}{ - "FieldStringList": customType1{"h", "e", "y"}, - "FieldIntList": customType2{1, 2, 3, 4}, - }, - }, - wantLen: 3, - }, - } - - type dataFormat struct { - Field1 int `db:"field1"` - Field2 string `db:"field2"` - Field3 uint `db:"field3"` - Field4 uint8 `db:"field4"` - Field5 uint16 `db:"field5"` - 
Field6 uint32 `db:"field6"` - Field7 uint64 `db:"field7"` - Field8 int `db:"field8"` - Field9 int8 `db:"field9"` - Field10 int16 `db:"field10"` - Field11 int32 `db:"field11"` - Field12 int64 `db:"field12"` - Field13 float32 `db:"field13"` - Field14 float64 `db:"field14"` - } - - dbmap := newDBMap(t) - dbmap.ExpandSliceArgs = true - dbmap.AddTableWithName(dataFormat{}, "crazy_table") - - err := dbmap.CreateTables() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - - err = dbmap.Insert( - &dataFormat{ - Field1: 123, - Field2: "h", - Field3: 1, - Field4: 1, - Field5: 1, - Field6: 1, - Field7: 1, - Field8: 1, - Field9: 1, - Field10: 1, - Field11: 1, - Field12: 1, - Field13: 1, - Field14: 1, - }, - &dataFormat{ - Field1: 124, - Field2: "e", - Field3: 2, - Field4: 2, - Field5: 2, - Field6: 2, - Field7: 2, - Field8: 2, - Field9: 2, - Field10: 2, - Field11: 2, - Field12: 2, - Field13: 2, - Field14: 2, - }, - &dataFormat{ - Field1: 125, - Field2: "y", - Field3: 3, - Field4: 3, - Field5: 3, - Field6: 3, - Field7: 3, - Field8: 3, - Field9: 3, - Field10: 3, - Field11: 3, - Field12: 3, - Field13: 3, - Field14: 3, - }, - ) - - if err != nil { - t.Fatal(err) - } - - for _, tt := range tests { - t.Run(tt.description, func(t *testing.T) { - tx, err := dbmap.Begin() - if err != nil { - t.Fatal(err) - } - defer tx.Rollback() - - var dummy []int - _, err = tx.Select(&dummy, tt.query, tt.args...) - if err != nil { - t.Fatal(err) - } - - if len(dummy) != tt.wantLen { - t.Errorf("wrong result count\ngot: %d\nwant: %d", len(dummy), tt.wantLen) - } - }) - } -} - -func TestTransaction_Exec_expandSliceArgs(t *testing.T) { - tests := []struct { - description string - query string - args []interface{} - wantLen int - }{ - { - description: "it should handle slice placeholders correctly", - query: ` -DELETE FROM crazy_table -WHERE field1 = :Field1 -AND field2 IN (:FieldStringList) -AND field3 IN (:FieldUIntList) -AND field4 IN (:FieldUInt8List) -AND field5 IN (:FieldUInt16List) -AND field6 IN (:FieldUInt32List) -AND field7 IN (:FieldUInt64List) -AND field8 IN (:FieldIntList) -AND field9 IN (:FieldInt8List) -AND field10 IN (:FieldInt16List) -AND field11 IN (:FieldInt32List) -AND field12 IN (:FieldInt64List) -AND field13 IN (:FieldFloat32List) -AND field14 IN (:FieldFloat64List) -`, - args: []interface{}{ - map[string]interface{}{ - "Field1": 123, - "FieldStringList": []string{"h", "e", "y"}, - "FieldUIntList": []uint{1, 2, 3, 4}, - "FieldUInt8List": []uint8{1, 2, 3, 4}, - "FieldUInt16List": []uint16{1, 2, 3, 4}, - "FieldUInt32List": []uint32{1, 2, 3, 4}, - "FieldUInt64List": []uint64{1, 2, 3, 4}, - "FieldIntList": []int{1, 2, 3, 4}, - "FieldInt8List": []int8{1, 2, 3, 4}, - "FieldInt16List": []int16{1, 2, 3, 4}, - "FieldInt32List": []int32{1, 2, 3, 4}, - "FieldInt64List": []int64{1, 2, 3, 4}, - "FieldFloat32List": []float32{1, 2, 3, 4}, - "FieldFloat64List": []float64{1, 2, 3, 4}, - }, - }, - wantLen: 1, - }, - { - description: "it should handle slice placeholders correctly with custom types", - query: ` -DELETE FROM crazy_table -WHERE field2 IN (:FieldStringList) -AND field12 IN (:FieldIntList) -`, - args: []interface{}{ - map[string]interface{}{ - "FieldStringList": customType1{"h", "e", "y"}, - "FieldIntList": customType2{1, 2, 3, 4}, - }, - }, - wantLen: 3, - }, - } - - type dataFormat struct { - Field1 int `db:"field1"` - Field2 string `db:"field2"` - Field3 uint `db:"field3"` - Field4 uint8 `db:"field4"` - Field5 uint16 `db:"field5"` - Field6 uint32 `db:"field6"` - Field7 uint64 `db:"field7"` - Field8 
int `db:"field8"` - Field9 int8 `db:"field9"` - Field10 int16 `db:"field10"` - Field11 int32 `db:"field11"` - Field12 int64 `db:"field12"` - Field13 float32 `db:"field13"` - Field14 float64 `db:"field14"` - } - - dbmap := newDBMap(t) - dbmap.ExpandSliceArgs = true - dbmap.AddTableWithName(dataFormat{}, "crazy_table") - - err := dbmap.CreateTables() - if err != nil { - panic(err) - } - defer dropAndClose(dbmap) - - err = dbmap.Insert( - &dataFormat{ - Field1: 123, - Field2: "h", - Field3: 1, - Field4: 1, - Field5: 1, - Field6: 1, - Field7: 1, - Field8: 1, - Field9: 1, - Field10: 1, - Field11: 1, - Field12: 1, - Field13: 1, - Field14: 1, - }, - &dataFormat{ - Field1: 124, - Field2: "e", - Field3: 2, - Field4: 2, - Field5: 2, - Field6: 2, - Field7: 2, - Field8: 2, - Field9: 2, - Field10: 2, - Field11: 2, - Field12: 2, - Field13: 2, - Field14: 2, - }, - &dataFormat{ - Field1: 125, - Field2: "y", - Field3: 3, - Field4: 3, - Field5: 3, - Field6: 3, - Field7: 3, - Field8: 3, - Field9: 3, - Field10: 3, - Field11: 3, - Field12: 3, - Field13: 3, - Field14: 3, - }, - ) - - if err != nil { - t.Fatal(err) - } - - for _, tt := range tests { - t.Run(tt.description, func(t *testing.T) { - tx, err := dbmap.Begin() - if err != nil { - t.Fatal(err) - } - defer tx.Rollback() - - _, err = tx.Exec(tt.query, tt.args...) - if err != nil { - t.Fatal(err) - } - }) - } -}
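Reviewer note (not part of the diff): the short Go sketch below illustrates the transaction API that this patch deletes, to help spot call sites that still depend on gdb/sqldb. It is a hedged sketch, not code from this repository: the Invoice type, the saveInvoice helper, the import path, the prior construction of dbmap, and the assumption that AddTableWithName returns the registered *TableMap are all illustrative; the Begin, Insert, Update, Savepoint, RollbackToSavepoint, Commit and Rollback calls follow the signatures in the removed transaction.go, and SetKeys follows the removed table.go.

package example

import (
	"yourmodule/gdb/sqldb" // hypothetical import path; this package is deleted by this patch
)

// Invoice is an illustrative mapped struct, not a type from this repository.
type Invoice struct {
	Id     int64
	Amount int64
}

// saveInvoice sketches typical use of the removed Transaction API. dbmap is
// assumed to be a fully configured *sqldb.DbMap; its construction (driver,
// dialect) is outside this diff.
func saveInvoice(dbmap *sqldb.DbMap, inv *Invoice) error {
	// Map the struct and mark Id as the auto-increment primary key
	// (SetKeys panics if more than one field is given when isAutoIncr is true).
	dbmap.AddTableWithName(Invoice{}, "invoice").SetKeys(true, "Id")

	tx, err := dbmap.Begin()
	if err != nil {
		return err
	}

	if err := tx.Insert(inv); err != nil {
		tx.Rollback()
		return err
	}

	// Savepoint names are interpolated into the SQL, so pass only trusted values.
	if err := tx.Savepoint("after_insert"); err != nil {
		tx.Rollback()
		return err
	}

	inv.Amount += 10
	if _, err := tx.Update(inv); err != nil {
		// Undo only the failed update, keep the insert, then commit what succeeded.
		if rbErr := tx.RollbackToSavepoint("after_insert"); rbErr != nil {
			tx.Rollback()
			return rbErr
		}
	}

	return tx.Commit()
}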