From 02b5c78a45a2257ec2b5ab4f2ebb75d76e1b1d45 Mon Sep 17 00:00:00 2001 From: tiglog Date: Thu, 17 Aug 2023 17:16:00 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E5=9F=BA=E4=BA=8E=20gorp=20?= =?UTF-8?q?=E9=87=8D=E6=9E=84?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- gdb/sqldb/base_test.go | 180 -- gdb/sqldb/column.go | 78 + gdb/sqldb/context_test.go | 82 + gdb/sqldb/db.go | 1050 +++++++++- gdb/sqldb/db_func.go | 220 --- gdb/sqldb/db_func_opt.go | 75 - gdb/sqldb/db_func_test.go | 114 -- gdb/sqldb/db_model.go | 20 - gdb/sqldb/db_query.go | 322 ---- gdb/sqldb/db_query_test.go | 109 -- gdb/sqldb/db_test.go | 187 ++ gdb/sqldb/dialect.go | 108 ++ gdb/sqldb/dialect_mysql.go | 172 ++ gdb/sqldb/dialect_mysql_test.go | 195 ++ gdb/sqldb/dialect_oracle.go | 142 ++ gdb/sqldb/dialect_postgres.go | 152 ++ gdb/sqldb/dialect_postgres_test.go | 161 ++ gdb/sqldb/dialect_sqlite.go | 115 ++ gdb/sqldb/doc.go | 13 + gdb/sqldb/errors.go | 34 + gdb/sqldb/hooks.go | 45 + gdb/sqldb/index.go | 51 + gdb/sqldb/lockerror.go | 59 + gdb/sqldb/logging.go | 45 + gdb/sqldb/nulltypes.go | 68 + gdb/sqldb/select.go | 361 ++++ gdb/sqldb/sqldb.go | 675 +++++++ gdb/sqldb/sqldb_test.go | 2875 ++++++++++++++++++++++++++++ gdb/sqldb/table.go | 258 +++ gdb/sqldb/table_bindings.go | 308 +++ gdb/sqldb/test_all.sh | 24 + gdb/sqldb/transaction.go | 242 +++ gdb/sqldb/transaction_test.go | 340 ++++ go.mod | 8 +- go.sum | 7 +- 35 files changed, 7814 insertions(+), 1081 deletions(-) delete mode 100644 gdb/sqldb/base_test.go create mode 100644 gdb/sqldb/column.go create mode 100644 gdb/sqldb/context_test.go delete mode 100644 gdb/sqldb/db_func.go delete mode 100644 gdb/sqldb/db_func_opt.go delete mode 100644 gdb/sqldb/db_func_test.go delete mode 100644 gdb/sqldb/db_model.go delete mode 100644 gdb/sqldb/db_query.go delete mode 100644 gdb/sqldb/db_query_test.go create mode 100644 gdb/sqldb/db_test.go create mode 100644 gdb/sqldb/dialect.go create mode 100644 
gdb/sqldb/dialect_mysql.go create mode 100644 gdb/sqldb/dialect_mysql_test.go create mode 100644 gdb/sqldb/dialect_oracle.go create mode 100644 gdb/sqldb/dialect_postgres.go create mode 100644 gdb/sqldb/dialect_postgres_test.go create mode 100644 gdb/sqldb/dialect_sqlite.go create mode 100644 gdb/sqldb/doc.go create mode 100644 gdb/sqldb/errors.go create mode 100644 gdb/sqldb/hooks.go create mode 100644 gdb/sqldb/index.go create mode 100644 gdb/sqldb/lockerror.go create mode 100644 gdb/sqldb/logging.go create mode 100644 gdb/sqldb/nulltypes.go create mode 100644 gdb/sqldb/select.go create mode 100644 gdb/sqldb/sqldb.go create mode 100644 gdb/sqldb/sqldb_test.go create mode 100644 gdb/sqldb/table.go create mode 100644 gdb/sqldb/table_bindings.go create mode 100755 gdb/sqldb/test_all.sh create mode 100644 gdb/sqldb/transaction.go create mode 100644 gdb/sqldb/transaction_test.go diff --git a/gdb/sqldb/base_test.go b/gdb/sqldb/base_test.go deleted file mode 100644 index 7b7bfeb..0000000 --- a/gdb/sqldb/base_test.go +++ /dev/null @@ -1,180 +0,0 @@ -// -// base_test.go -// Copyright (C) 2022 tiglog -// -// Distributed under terms of the MIT license. 
-// - -package sqldb_test - -import ( - "database/sql" - "fmt" - "os" - "strings" - "testing" - - _ "github.com/lib/pq" - "git.hexq.cn/tiglog/golib/gdb/sqldb" - // _ "github.com/go-sql-driver/mysql" -) - -type Schema struct { - create string - drop string -} - -var defaultSchema = Schema{ - create: ` -CREATE TABLE person ( - id serial, - first_name text, - last_name text, - email text, - added_at int default 0, - PRIMARY KEY (id) -); - -CREATE TABLE place ( - country text, - city text NULL, - telcode integer -); - -CREATE TABLE capplace ( - country text, - city text NULL, - telcode integer -); - -CREATE TABLE nullperson ( - first_name text NULL, - last_name text NULL, - email text NULL -); - -CREATE TABLE employees ( - name text, - id integer, - boss_id integer -); - -`, - drop: ` -drop table person; -drop table place; -drop table capplace; -drop table nullperson; -drop table employees; -`, -} - -type Person struct { - Id int64 `db:"id"` - FirstName string `db:"first_name"` - LastName string `db:"last_name"` - Email string `db:"email"` - AddedAt int64 `db:"added_at"` -} - -type Person2 struct { - FirstName sql.NullString `db:"first_name"` - LastName sql.NullString `db:"last_name"` - Email sql.NullString -} - -type Place struct { - Country string - City sql.NullString - TelCode int -} - -type PlacePtr struct { - Country string - City *string - TelCode int -} - -type PersonPlace struct { - Person - Place -} - -type PersonPlacePtr struct { - *Person - *Place -} - -type EmbedConflict struct { - FirstName string `db:"first_name"` - Person -} - -type SliceMember struct { - Country string - City sql.NullString - TelCode int - People []Person `db:"-"` - Addresses []Place `db:"-"` -} - -func loadDefaultFixture(db *sqldb.Engine, t *testing.T) { - tx := db.MustBegin() - - s1 := "INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)" - tx.MustExec(db.Rebind(s1), "Jason", "Moiron", "jmoiron@jmoiron.net") - - s1 = "INSERT INTO person (first_name, last_name, email) 
VALUES (?, ?, ?)" - tx.MustExec(db.Rebind(s1), "John", "Doe", "johndoeDNE@gmail.net") - - s1 = "INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)" - tx.MustExec(db.Rebind(s1), "United States", "New York", "1") - - s1 = "INSERT INTO place (country, telcode) VALUES (?, ?)" - tx.MustExec(db.Rebind(s1), "Hong Kong", "852") - - s1 = "INSERT INTO place (country, telcode) VALUES (?, ?)" - tx.MustExec(db.Rebind(s1), "Singapore", "65") - - s1 = "INSERT INTO capplace (country, telcode) VALUES (?, ?)" - tx.MustExec(db.Rebind(s1), "Sarf Efrica", "27") - - s1 = "INSERT INTO employees (name, id) VALUES (?, ?)" - tx.MustExec(db.Rebind(s1), "Peter", "4444") - - s1 = "INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)" - tx.MustExec(db.Rebind(s1), "Joe", "1", "4444") - - s1 = "INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)" - tx.MustExec(db.Rebind(s1), "Martin", "2", "4444") - tx.Commit() -} - -func MultiExec(e *sqldb.Engine, query string) { - stmts := strings.Split(query, ";\n") - if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { - stmts = stmts[:len(stmts)-1] - } - for _, s := range stmts { - _, err := e.Exec(s) - if err != nil { - fmt.Println(err, s) - } - } -} - -func RunDbTest(t *testing.T, test func(db *sqldb.Engine, t *testing.T)) { - // 先初始化数据库 - url := os.Getenv("DB_URL") - var db = sqldb.New(url) - - // 再注册清空数据库 - defer func() { - MultiExec(db, defaultSchema.drop) - }() - // 再加入一些数据 - MultiExec(db, defaultSchema.create) - loadDefaultFixture(db, t) - // 最后测试 - test(db, t) -} diff --git a/gdb/sqldb/column.go b/gdb/sqldb/column.go new file mode 100644 index 0000000..781c36c --- /dev/null +++ b/gdb/sqldb/column.go @@ -0,0 +1,78 @@ +// +// column.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import "reflect" + +// ColumnMap represents a mapping between a Go struct field and a single +// column in a table. 
+// Unique and MaxSize only inform the +// CreateTables() function and are not used by Insert/Update/Delete/Get. +type ColumnMap struct { + // Column name in db table + ColumnName string + + // If true, this column is skipped in generated SQL statements + Transient bool + + // If true, " unique" is added to create table statements. + // Not used elsewhere + Unique bool + + // Query used for getting generated id after insert + GeneratedIdQuery string + + // Passed to Dialect.ToSqlType() to assist in informing the + // correct column type to map to in CreateTables() + MaxSize int + + DefaultValue string + + fieldName string + gotype reflect.Type + isPK bool + isAutoIncr bool + isNotNull bool +} + +// Rename allows you to specify the column name in the table +// +// Example: table.ColMap("Updated").Rename("date_updated") +func (c *ColumnMap) Rename(colname string) *ColumnMap { + c.ColumnName = colname + return c +} + +// SetTransient allows you to mark the column as transient. If true +// this column will be skipped when SQL statements are generated +func (c *ColumnMap) SetTransient(b bool) *ColumnMap { + c.Transient = b + return c +} + +// SetUnique adds "unique" to the create table statements for this +// column, if b is true. +func (c *ColumnMap) SetUnique(b bool) *ColumnMap { + c.Unique = b + return c +} + +// SetNotNull adds "not null" to the create table statements for this +// column, if nn is true. +func (c *ColumnMap) SetNotNull(nn bool) *ColumnMap { + c.isNotNull = nn + return c +} + +// SetMaxSize specifies the max length of values of this column. 
This is +// passed to the dialect.ToSqlType() function, which can use the value +// to alter the generated type for "create table" statements +func (c *ColumnMap) SetMaxSize(size int) *ColumnMap { + c.MaxSize = size + return c +} diff --git a/gdb/sqldb/context_test.go b/gdb/sqldb/context_test.go new file mode 100644 index 0000000..96cd8d8 --- /dev/null +++ b/gdb/sqldb/context_test.go @@ -0,0 +1,82 @@ +// +// context_test.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +//go:build integration +// +build integration + +package sqldb_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" +) + +// Drivers that don't support cancellation. +var unsupportedDrivers map[string]bool = map[string]bool{ + "mymysql": true, +} + +type SleepDialect interface { + // string to sleep for d duration + SleepClause(d time.Duration) string +} + +func TestWithNotCanceledContext(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + withCtx := dbmap.WithContext(ctx) + + _, err := withCtx.Exec("SELECT 1") + assert.Nil(t, err) +} + +func TestWithCanceledContext(t *testing.T) { + dialect, driver := dialectAndDriver() + if unsupportedDrivers[driver] { + t.Skipf("Cancellation is not yet supported by all drivers. Not known to be supported in %s.", driver) + } + + sleepDialect, ok := dialect.(SleepDialect) + if !ok { + t.Skipf("Sleep is not supported in all dialects. 
Not known to be supported in %s.", driver) + } + + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond) + defer cancel() + + withCtx := dbmap.WithContext(ctx) + + startTime := time.Now() + + _, err := withCtx.Exec("SELECT " + sleepDialect.SleepClause(1*time.Second)) + + if d := time.Since(startTime); d > 500*time.Millisecond { + t.Errorf("too long execution time: %s", d) + } + + switch driver { + case "postgres": + // pq doesn't return standard deadline exceeded error + if err.Error() != "pq: canceling statement due to user request" { + t.Errorf("expected context.DeadlineExceeded, got %v", err) + } + default: + if err != context.DeadlineExceeded { + t.Errorf("expected context.DeadlineExceeded, got %v", err) + } + } +} diff --git a/gdb/sqldb/db.go b/gdb/sqldb/db.go index cec2764..c36378d 100644 --- a/gdb/sqldb/db.go +++ b/gdb/sqldb/db.go @@ -1,6 +1,6 @@ // // db.go -// Copyright (C) 2022 tiglog +// Copyright (C) 2023 tiglog // // Distributed under terms of the MIT license. 
// @@ -8,51 +8,1031 @@ package sqldb import ( + "bytes" + "context" "database/sql" + "database/sql/driver" "errors" + "fmt" + "log" + "reflect" + "strconv" "strings" - - "github.com/jmoiron/sqlx" + "time" ) -var Db *Engine - -type Engine struct { - *sqlx.DB -} - -var ErrNoRows = sql.ErrNoRows +var DM *DbMap type DbOption struct { - Url string - MaxOpenConns int - MaxIdleConns int + Type string + Dsn string + MaxIdleConnections int + MaxOpenConnections int } -// mysql://[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] -// pgsql://host=X.X.X.X port=54321 user=postgres password=admin123 dbname=postgres sslmode=disable" -func NewWithOption(opt *DbOption) *Engine { - urls := strings.Split(opt.Url, "://") - if len(urls) != 2 { - panic(errors.New("wrong database url:" + opt.Url)) +func InitDb(opt DbOption) error { + var ( + db *sql.DB + dialect Dialect + err error + ) + switch opt.Type { + case "postgres": + db, err = sql.Open(opt.Type, opt.Dsn) + if err != nil { + return err + } + db.SetMaxIdleConns(opt.MaxIdleConnections) + db.SetMaxIdleConns(opt.MaxOpenConnections) + dialect = PostgresDialect{LowercaseFields: true} + case "sqlite3": + db, err = sql.Open(opt.Type, opt.Dsn) + if err != nil { + return err + } + if opt.Dsn == ":memory:" { + db.SetMaxOpenConns(1) + } + dialect = SqliteDialect{} + case "mysql": + db, err = sql.Open(opt.Type, opt.Dsn) + if err != nil { + return err + } + db.SetMaxIdleConns(opt.MaxIdleConnections) + db.SetMaxIdleConns(opt.MaxOpenConnections) + dialect = MySQLDialect{Engine: "InnoDB", Encoding: "utf8mb4"} + default: + return errors.New("unrecognized database driver") } - dbx, err := sqlx.Open(urls[0], urls[1]) - if err != nil { - panic(err) + DM = &DbMap{ + Db: db, + Dialect: dialect, } - dbx.SetMaxIdleConns(opt.MaxIdleConns) - dbx.SetMaxOpenConns(opt.MaxOpenConns) - err = dbx.Ping() - if err != nil { - panic(err) - } - Db = &Engine{ - dbx, - } - return Db + return nil } -func New(url string) *Engine { - opt := 
&DbOption{Url: url, MaxOpenConns: 256, MaxIdleConns: 2} - return NewWithOption(opt) +// DbMap is the root sqldb mapping object. Create one of these for each +// database schema you wish to map. Each DbMap contains a list of +// mapped tables. +// +// Example: +// +// dialect := sqldb.MySQLDialect{"InnoDB", "UTF8"} +// dbmap := &sqldb.DbMap{Db: db, Dialect: dialect} +type DbMap struct { + ctx context.Context + + // Db handle to use with this map + Db *sql.DB + + // Dialect implementation to use with this map + Dialect Dialect + + TypeConverter TypeConverter + + // ExpandSlices when enabled will convert slice arguments in mappers into flat + // values. It will modify the query, adding more placeholders, and the mapper, + // adding each item of the slice as a new unique entry in the mapper. For + // example, given the scenario bellow: + // + // dbmap.Select(&output, "SELECT 1 FROM example WHERE id IN (:IDs)", map[string]interface{}{ + // "IDs": []int64{1, 2, 3}, + // }) + // + // The executed query would be: + // + // SELECT 1 FROM example WHERE id IN (:IDs0,:IDs1,:IDs2) + // + // With the mapper: + // + // map[string]interface{}{ + // "IDs": []int64{1, 2, 3}, + // "IDs0": int64(1), + // "IDs1": int64(2), + // "IDs2": int64(3), + // } + // + // It is also flexible for custom slice types. The value just need to + // implement stringer or numberer interfaces. + // + // type CustomValue string + // + // const ( + // CustomValueHey CustomValue = "hey" + // CustomValueOh CustomValue = "oh" + // ) + // + // type CustomValues []CustomValue + // + // func (c CustomValues) ToStringSlice() []string { + // values := make([]string, len(c)) + // for i := range c { + // values[i] = string(c[i]) + // } + // return values + // } + // + // func query() { + // // ... + // result, err := dbmap.Select(&output, "SELECT 1 FROM example WHERE value IN (:Values)", map[string]interface{}{ + // "Values": CustomValues([]CustomValue{CustomValueHey}), + // }) + // // ... 
+ // } + ExpandSliceArgs bool + + tables []*TableMap + tablesDynamic map[string]*TableMap // tables that use same go-struct and different db table names + logger SqldbLogger + logPrefix string +} + +func (m *DbMap) dynamicTableAdd(tableName string, tbl *TableMap) { + if m.tablesDynamic == nil { + m.tablesDynamic = make(map[string]*TableMap) + } + m.tablesDynamic[tableName] = tbl +} + +func (m *DbMap) dynamicTableFind(tableName string) (*TableMap, bool) { + if m.tablesDynamic == nil { + return nil, false + } + tbl, found := m.tablesDynamic[tableName] + return tbl, found +} + +func (m *DbMap) dynamicTableMap() map[string]*TableMap { + if m.tablesDynamic == nil { + m.tablesDynamic = make(map[string]*TableMap) + } + return m.tablesDynamic +} + +func (m *DbMap) WithContext(ctx context.Context) SqlExecutor { + copy := &DbMap{} + *copy = *m + copy.ctx = ctx + return copy +} + +func (m *DbMap) CreateIndex() error { + var err error + dialect := reflect.TypeOf(m.Dialect) + for _, table := range m.tables { + for _, index := range table.indexes { + err = m.createIndexImpl(dialect, table, index) + if err != nil { + break + } + } + } + + for _, table := range m.dynamicTableMap() { + for _, index := range table.indexes { + err = m.createIndexImpl(dialect, table, index) + if err != nil { + break + } + } + } + + return err +} + +func (m *DbMap) createIndexImpl(dialect reflect.Type, + table *TableMap, + index *IndexMap) error { + s := bytes.Buffer{} + s.WriteString("create") + if index.Unique { + s.WriteString(" unique") + } + s.WriteString(" index") + s.WriteString(fmt.Sprintf(" %s on %s", index.IndexName, table.TableName)) + if dname := dialect.Name(); dname == "PostgresDialect" && index.IndexType != "" { + s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType)) + } + s.WriteString(" (") + for x, col := range index.columns { + if x > 0 { + s.WriteString(", ") + } + s.WriteString(m.Dialect.QuoteField(col)) + } + s.WriteString(")") + + if dname := 
dialect.Name(); dname == "MySQLDialect" && index.IndexType != "" { + s.WriteString(fmt.Sprintf(" %s %s", m.Dialect.CreateIndexSuffix(), index.IndexType)) + } + s.WriteString(";") + _, err := m.Exec(s.String()) + return err +} + +func (t *TableMap) DropIndex(name string) error { + + var err error + dialect := reflect.TypeOf(t.dbmap.Dialect) + for _, idx := range t.indexes { + if idx.IndexName == name { + s := bytes.Buffer{} + s.WriteString(fmt.Sprintf("DROP INDEX %s", idx.IndexName)) + + if dname := dialect.Name(); dname == "MySQLDialect" { + s.WriteString(fmt.Sprintf(" %s %s", t.dbmap.Dialect.DropIndexSuffix(), t.TableName)) + } + s.WriteString(";") + _, e := t.dbmap.Exec(s.String()) + if e != nil { + err = e + } + break + } + } + t.ResetSql() + return err +} + +// AddTable registers the given interface type with sqldb. The table name +// will be given the name of the TypeOf(i). You must call this function, +// or AddTableWithName, for any struct type you wish to persist with +// the given DbMap. +// +// This operation is idempotent. If i's type is already mapped, the +// existing *TableMap is returned +func (m *DbMap) AddTable(i interface{}) *TableMap { + return m.AddTableWithName(i, "") +} + +// AddTableWithName has the same behavior as AddTable, but sets +// table.TableName to name. +func (m *DbMap) AddTableWithName(i interface{}, name string) *TableMap { + return m.AddTableWithNameAndSchema(i, "", name) +} + +// AddTableWithNameAndSchema has the same behavior as AddTable, but sets +// table.TableName to name. 
+func (m *DbMap) AddTableWithNameAndSchema(i interface{}, schema string, name string) *TableMap { + t := reflect.TypeOf(i) + if name == "" { + name = t.Name() + } + + // check if we have a table for this type already + // if so, update the name and return the existing pointer + for i := range m.tables { + table := m.tables[i] + if table.gotype == t { + table.TableName = name + return table + } + } + + tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m} + var primaryKey []*ColumnMap + tmap.Columns, primaryKey = m.readStructColumns(t) + m.tables = append(m.tables, tmap) + if len(primaryKey) > 0 { + tmap.keys = append(tmap.keys, primaryKey...) + } + + return tmap +} + +// AddTableDynamic registers the given interface type with sqldb. +// The table name will be dynamically determined at runtime by +// using the GetTableName method on DynamicTable interface +func (m *DbMap) AddTableDynamic(inp DynamicTable, schema string) *TableMap { + + val := reflect.ValueOf(inp) + elm := val.Elem() + t := elm.Type() + name := inp.TableName() + if name == "" { + panic("Missing table name in DynamicTable instance") + } + + // Check if there is another dynamic table with the same name + if _, found := m.dynamicTableFind(name); found { + panic(fmt.Sprintf("A table with the same name %v already exists", name)) + } + + tmap := &TableMap{gotype: t, TableName: name, SchemaName: schema, dbmap: m} + var primaryKey []*ColumnMap + tmap.Columns, primaryKey = m.readStructColumns(t) + if len(primaryKey) > 0 { + tmap.keys = append(tmap.keys, primaryKey...) + } + + m.dynamicTableAdd(name, tmap) + + return tmap +} + +func (m *DbMap) readStructColumns(t reflect.Type) (cols []*ColumnMap, primaryKey []*ColumnMap) { + primaryKey = make([]*ColumnMap, 0) + n := t.NumField() + for i := 0; i < n; i++ { + f := t.Field(i) + if f.Anonymous && f.Type.Kind() == reflect.Struct { + // Recursively add nested fields in embedded structs. 
+ subcols, subpk := m.readStructColumns(f.Type) + // Don't append nested fields that have the same field + // name as an already-mapped field. + for _, subcol := range subcols { + shouldAppend := true + for _, col := range cols { + if !subcol.Transient && subcol.fieldName == col.fieldName { + shouldAppend = false + break + } + } + if shouldAppend { + cols = append(cols, subcol) + } + } + if subpk != nil { + primaryKey = append(primaryKey, subpk...) + } + } else { + // Tag = Name { ',' Option } + // Option = OptionKey [ ':' OptionValue ] + cArguments := strings.Split(f.Tag.Get("db"), ",") + columnName := cArguments[0] + var maxSize int + var defaultValue string + var isAuto bool + var isPK bool + var isNotNull bool + for _, argString := range cArguments[1:] { + argString = strings.TrimSpace(argString) + arg := strings.SplitN(argString, ":", 2) + + // check mandatory/unexpected option values + switch arg[0] { + case "size", "default": + // options requiring value + if len(arg) == 1 { + panic(fmt.Sprintf("missing option value for option %v on field %v", arg[0], f.Name)) + } + default: + // options where value is invalid (currently all other options) + if len(arg) == 2 { + panic(fmt.Sprintf("unexpected option value for option %v on field %v", arg[0], f.Name)) + } + } + + switch arg[0] { + case "size": + maxSize, _ = strconv.Atoi(arg[1]) + case "default": + defaultValue = arg[1] + case "primarykey": + isPK = true + case "autoincrement": + isAuto = true + case "notnull": + isNotNull = true + default: + panic(fmt.Sprintf("Unrecognized tag option for field %v: %v", f.Name, arg)) + } + } + if columnName == "" { + columnName = f.Name + } + + gotype := f.Type + valueType := gotype + if valueType.Kind() == reflect.Ptr { + valueType = valueType.Elem() + } + value := reflect.New(valueType).Interface() + if m.TypeConverter != nil { + // Make a new pointer to a value of type gotype and + // pass it to the TypeConverter's FromDb method to see + // if a different type should be used 
for the column + // type during table creation. + scanner, useHolder := m.TypeConverter.FromDb(value) + if useHolder { + value = scanner.Holder + gotype = reflect.TypeOf(value) + } + } + if typer, ok := value.(SqlTyper); ok { + gotype = reflect.TypeOf(typer.SqlType()) + } else if typer, ok := value.(legacySqlTyper); ok { + log.Printf("Deprecation Warning: update your SqlType methods to return a driver.Value") + gotype = reflect.TypeOf(typer.SqlType()) + } else if valuer, ok := value.(driver.Valuer); ok { + // Only check for driver.Valuer if SqlTyper wasn't + // found. + v, err := valuer.Value() + if err == nil && v != nil { + gotype = reflect.TypeOf(v) + } + } + cm := &ColumnMap{ + ColumnName: columnName, + DefaultValue: defaultValue, + Transient: columnName == "-", + fieldName: f.Name, + gotype: gotype, + isPK: isPK, + isAutoIncr: isAuto, + isNotNull: isNotNull, + MaxSize: maxSize, + } + if isPK { + primaryKey = append(primaryKey, cm) + } + // Check for nested fields of the same field name and + // override them. + shouldAppend := true + for index, col := range cols { + if !col.Transient && col.fieldName == cm.fieldName { + cols[index] = cm + shouldAppend = false + break + } + } + if shouldAppend { + cols = append(cols, cm) + } + } + + } + return +} + +// CreateTables iterates through TableMaps registered to this DbMap and +// executes "create table" statements against the database for each. +// +// This is particularly useful in unit tests where you want to create +// and destroy the schema automatically. 
+func (m *DbMap) CreateTables() error { + return m.createTables(false) +} + +// CreateTablesIfNotExists is similar to CreateTables, but starts +// each statement with "create table if not exists" so that existing +// tables do not raise errors +func (m *DbMap) CreateTablesIfNotExists() error { + return m.createTables(true) +} + +func (m *DbMap) createTables(ifNotExists bool) error { + var err error + for i := range m.tables { + table := m.tables[i] + sql := table.SqlForCreate(ifNotExists) + _, err = m.Exec(sql) + if err != nil { + return err + } + } + + for _, tbl := range m.dynamicTableMap() { + sql := tbl.SqlForCreate(ifNotExists) + _, err = m.Exec(sql) + if err != nil { + return err + } + } + + return err +} + +// DropTable drops an individual table. +// Returns an error when the table does not exist. +func (m *DbMap) DropTable(table interface{}) error { + t := reflect.TypeOf(table) + + tableName := "" + if dyn, ok := table.(DynamicTable); ok { + tableName = dyn.TableName() + } + + return m.dropTable(t, tableName, false) +} + +// DropTableIfExists drops an individual table when the table exists. +func (m *DbMap) DropTableIfExists(table interface{}) error { + t := reflect.TypeOf(table) + + tableName := "" + if dyn, ok := table.(DynamicTable); ok { + tableName = dyn.TableName() + } + + return m.dropTable(t, tableName, true) +} + +// DropTables iterates through TableMaps registered to this DbMap and +// executes "drop table" statements against the database for each. +func (m *DbMap) DropTables() error { + return m.dropTables(false) +} + +// DropTablesIfExists is the same as DropTables, but uses the "if exists" clause to +// avoid errors for tables that do not exist. +func (m *DbMap) DropTablesIfExists() error { + return m.dropTables(true) +} + +// Goes through all the registered tables, dropping them one by one. +// If an error is encountered, then it is returned and the rest of +// the tables are not dropped. 
+func (m *DbMap) dropTables(addIfExists bool) (err error) { + for _, table := range m.tables { + err = m.dropTableImpl(table, addIfExists) + if err != nil { + return err + } + } + + for _, table := range m.dynamicTableMap() { + err = m.dropTableImpl(table, addIfExists) + if err != nil { + return err + } + } + + return err +} + +// Implementation of dropping a single table. +func (m *DbMap) dropTable(t reflect.Type, name string, addIfExists bool) error { + table := tableOrNil(m, t, name) + if table == nil { + return fmt.Errorf("table %s was not registered", table.TableName) + } + + return m.dropTableImpl(table, addIfExists) +} + +func (m *DbMap) dropTableImpl(table *TableMap, ifExists bool) (err error) { + tableDrop := "drop table" + if ifExists { + tableDrop = m.Dialect.IfTableExists(tableDrop, table.SchemaName, table.TableName) + } + _, err = m.Exec(fmt.Sprintf("%s %s;", tableDrop, m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) + return err +} + +// TruncateTables iterates through TableMaps registered to this DbMap and +// executes "truncate table" statements against the database for each, or in the case of +// sqlite, a "delete from" with no "where" clause, which uses the truncate optimization +// (http://www.sqlite.org/lang_delete.html) +func (m *DbMap) TruncateTables() error { + var err error + for i := range m.tables { + table := m.tables[i] + _, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) + if e != nil { + err = e + } + } + + for _, table := range m.dynamicTableMap() { + _, e := m.Exec(fmt.Sprintf("%s %s;", m.Dialect.TruncateClause(), m.Dialect.QuotedTableForQuery(table.SchemaName, table.TableName))) + if e != nil { + err = e + } + } + + return err +} + +// Insert runs a SQL INSERT statement for each element in list. List +// items must be pointers. 
+// +// Any interface whose TableMap has an auto-increment primary key will +// have its last insert id bound to the PK field on the struct. +// +// The hook functions PreInsert() and/or PostInsert() will be executed +// before/after the INSERT statement if the interface defines them. +// +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Insert(list ...interface{}) error { + return insert(m, m, list...) +} + +// Update runs a SQL UPDATE statement for each element in list. List +// items must be pointers. +// +// The hook functions PreUpdate() and/or PostUpdate() will be executed +// before/after the UPDATE statement if the interface defines them. +// +// Returns the number of rows updated. +// +// Returns an error if SetKeys has not been called on the TableMap +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Update(list ...interface{}) (int64, error) { + return update(m, m, nil, list...) +} + +// UpdateColumns runs a SQL UPDATE statement for each element in list. List +// items must be pointers. +// +// Only the columns accepted by filter are included in the UPDATE. +// +// The hook functions PreUpdate() and/or PostUpdate() will be executed +// before/after the UPDATE statement if the interface defines them. +// +// Returns the number of rows updated. +// +// Returns an error if SetKeys has not been called on the TableMap +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) UpdateColumns(filter ColumnFilter, list ...interface{}) (int64, error) { + return update(m, m, filter, list...) +} + +// Delete runs a SQL DELETE statement for each element in list. List +// items must be pointers. +// +// The hook functions PreDelete() and/or PostDelete() will be executed +// before/after the DELETE statement if the interface defines them. +// +// Returns the number of rows deleted. 
+// +// Returns an error if SetKeys has not been called on the TableMap +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Delete(list ...interface{}) (int64, error) { + return delete(m, m, list...) +} + +// Get runs a SQL SELECT to fetch a single row from the table based on the +// primary key(s) +// +// i should be an empty value for the struct to load. keys should be +// the primary key value(s) for the row to load. If multiple keys +// exist on the table, the order should match the column order +// specified in SetKeys() when the table mapping was defined. +// +// The hook function PostGet() will be executed after the SELECT +// statement if the interface defines them. +// +// Returns a pointer to a struct that matches or nil if no row is found. +// +// Returns an error if SetKeys has not been called on the TableMap +// Panics if any interface in the list has not been registered with AddTable +func (m *DbMap) Get(i interface{}, keys ...interface{}) (interface{}, error) { + return get(m, m, i, keys...) +} + +// Select runs an arbitrary SQL query, binding the columns in the result +// to fields on the struct specified by i. args represent the bind +// parameters for the SQL statement. +// +// Column names on the SELECT statement should be aliased to the field names +// on the struct i. Returns an error if one or more columns in the result +// do not match. It is OK if fields on i are not part of the SQL +// statement. +// +// The hook function PostGet() will be executed after the SELECT +// statement if the interface defines them. +// +// Values are returned in one of two ways: +// 1. If i is a struct or a pointer to a struct, returns a slice of pointers to +// matching rows of type i. +// 2. If i is a pointer to a slice, the results will be appended to that slice +// and nil returned. 
+// +// i does NOT need to be registered with AddTable() +func (m *DbMap) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { + if m.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return hookedselect(m, m, i, query, args...) +} + +// Exec runs an arbitrary SQL statement. args represent the bind parameters. +// This is equivalent to running: Exec() using database/sql +func (m *DbMap) Exec(query string, args ...interface{}) (sql.Result, error) { + if m.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + if m.logger != nil { + now := time.Now() + defer m.trace(now, query, args...) + } + return maybeExpandNamedQueryAndExec(m, query, args...) +} + +// SelectInt is a convenience wrapper around the sqldb.SelectInt function +func (m *DbMap) SelectInt(query string, args ...interface{}) (int64, error) { + if m.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectInt(m, query, args...) +} + +// SelectNullInt is a convenience wrapper around the sqldb.SelectNullInt function +func (m *DbMap) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) { + if m.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectNullInt(m, query, args...) +} + +// SelectFloat is a convenience wrapper around the sqldb.SelectFloat function +func (m *DbMap) SelectFloat(query string, args ...interface{}) (float64, error) { + if m.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectFloat(m, query, args...) +} + +// SelectNullFloat is a convenience wrapper around the sqldb.SelectNullFloat function +func (m *DbMap) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) { + if m.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectNullFloat(m, query, args...) 
+} + +// SelectStr is a convenience wrapper around the sqldb.SelectStr function +func (m *DbMap) SelectStr(query string, args ...interface{}) (string, error) { + if m.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectStr(m, query, args...) +} + +// SelectNullStr is a convenience wrapper around the sqldb.SelectNullStr function +func (m *DbMap) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) { + if m.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectNullStr(m, query, args...) +} + +// SelectOne is a convenience wrapper around the sqldb.SelectOne function +func (m *DbMap) SelectOne(holder interface{}, query string, args ...interface{}) error { + if m.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectOne(m, m, holder, query, args...) +} + +// Begin starts a sqldb Transaction +func (m *DbMap) Begin() (*Transaction, error) { + if m.logger != nil { + now := time.Now() + defer m.trace(now, "begin;") + } + tx, err := begin(m) + if err != nil { + return nil, err + } + return &Transaction{ + dbmap: m, + tx: tx, + closed: false, + }, nil +} + +// TableFor returns the *TableMap corresponding to the given Go Type +// If no table is mapped to that type an error is returned. +// If checkPK is true and the mapped table has no registered PKs, an error is returned. +func (m *DbMap) TableFor(t reflect.Type, checkPK bool) (*TableMap, error) { + table := tableOrNil(m, t, "") + if table == nil { + return nil, fmt.Errorf("no table found for type: %v", t.Name()) + } + + if checkPK && len(table.keys) < 1 { + e := fmt.Sprintf("sqldb: no keys defined for table: %s", + table.TableName) + return nil, errors.New(e) + } + + return table, nil +} + +// DynamicTableFor returns the *TableMap for the dynamic table corresponding +// to the input tablename +// If no table is mapped to that tablename an error is returned. +// If checkPK is true and the mapped table has no registered PKs, an error is returned. 
+func (m *DbMap) DynamicTableFor(tableName string, checkPK bool) (*TableMap, error) { + table, found := m.dynamicTableFind(tableName) + if !found { + return nil, fmt.Errorf("sqldb: no table found for name: %v", tableName) + } + + if checkPK && len(table.keys) < 1 { + e := fmt.Sprintf("sqldb: no keys defined for table: %s", + table.TableName) + return nil, errors.New(e) + } + + return table, nil +} + +// Prepare creates a prepared statement for later queries or executions. +// Multiple queries or executions may be run concurrently from the returned statement. +// This is equivalent to running: Prepare() using database/sql +func (m *DbMap) Prepare(query string) (*sql.Stmt, error) { + if m.logger != nil { + now := time.Now() + defer m.trace(now, query, nil) + } + return prepare(m, query) +} + +func tableOrNil(m *DbMap, t reflect.Type, name string) *TableMap { + if name != "" { + // Search by table name (dynamic tables) + if table, found := m.dynamicTableFind(name); found { + return table + } + return nil + } + + for i := range m.tables { + table := m.tables[i] + if table.gotype == t { + return table + } + } + return nil +} + +func (m *DbMap) tableForPointer(ptr interface{}, checkPK bool) (*TableMap, reflect.Value, error) { + ptrv := reflect.ValueOf(ptr) + if ptrv.Kind() != reflect.Ptr { + e := fmt.Sprintf("sqldb: passed non-pointer: %v (kind=%v)", ptr, + ptrv.Kind()) + return nil, reflect.Value{}, errors.New(e) + } + elem := ptrv.Elem() + ifc := elem.Interface() + var t *TableMap + var err error + tableName := "" + if dyn, isDyn := ptr.(DynamicTable); isDyn { + tableName = dyn.TableName() + t, err = m.DynamicTableFor(tableName, checkPK) + } else { + etype := reflect.TypeOf(ifc) + t, err = m.TableFor(etype, checkPK) + } + + if err != nil { + return nil, reflect.Value{}, err + } + + return t, elem, nil +} + +func (m *DbMap) QueryRow(query string, args ...interface{}) *sql.Row { + if m.ExpandSliceArgs { + expandSliceArgs(&query, args...) 
+ } + + if m.logger != nil { + now := time.Now() + defer m.trace(now, query, args...) + } + return queryRow(m, query, args...) +} + +func (m *DbMap) Query(q string, args ...interface{}) (*sql.Rows, error) { + if m.ExpandSliceArgs { + expandSliceArgs(&q, args...) + } + + if m.logger != nil { + now := time.Now() + defer m.trace(now, q, args...) + } + return query(m, q, args...) +} + +func (m *DbMap) trace(started time.Time, query string, args ...interface{}) { + if m.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + if m.logger != nil { + var margs = argsString(args...) + m.logger.Printf("%s%s [%s] (%v)", m.logPrefix, query, margs, (time.Now().Sub(started))) + } +} + +type stringer interface { + ToStringSlice() []string +} + +type numberer interface { + ToInt64Slice() []int64 +} + +func expandSliceArgs(query *string, args ...interface{}) { + for _, arg := range args { + mapper, ok := arg.(map[string]interface{}) + if !ok { + continue + } + + for key, value := range mapper { + var replacements []string + + // add flexibility for any custom type to be convert to one of the + // acceptable formats. 
+ if v, ok := value.(stringer); ok { + value = v.ToStringSlice() + } + if v, ok := value.(numberer); ok { + value = v.ToInt64Slice() + } + + switch v := value.(type) { + case []string: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + case []uint: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + case []uint8: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + case []uint16: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + case []uint32: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + case []uint64: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + case []int: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + case []int8: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + case []int16: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + case []int32: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + case []int64: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, 
fmt.Sprintf(":%s%d", key, id)) + } + case []float32: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + case []float64: + for id, replace := range v { + mapper[fmt.Sprintf("%s%d", key, id)] = replace + replacements = append(replacements, fmt.Sprintf(":%s%d", key, id)) + } + default: + continue + } + + if len(replacements) == 0 { + continue + } + + *query = strings.Replace(*query, fmt.Sprintf(":%s", key), strings.Join(replacements, ","), -1) + } + } } diff --git a/gdb/sqldb/db_func.go b/gdb/sqldb/db_func.go deleted file mode 100644 index c77153f..0000000 --- a/gdb/sqldb/db_func.go +++ /dev/null @@ -1,220 +0,0 @@ -// -// db_func.go -// Copyright (C) 2022 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -import ( - "errors" - "fmt" - "strconv" - "strings" - - "github.com/jmoiron/sqlx" -) - -func (e *Engine) Begin() (*sqlx.Tx, error) { - return e.Beginx() -} - -// 插入一条记录 -func (e *Engine) NamedInsertRecord(opt *QueryOption, arg interface{}) (int64, error) { // {{{ - if len(opt.fields) == 0 { - return 0, errors.New("empty fields") - } - var tmp = make([]string, 0) - for _, field := range opt.fields { - tmp = append(tmp, fmt.Sprintf(":%s", field)) - } - fields_str := strings.Join(opt.fields, ",") - fields_pl := strings.Join(tmp, ",") - sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", opt.table, fields_str, fields_pl) - if e.DriverName() == "postgres" { - sql += " returning id" - } - // sql = e.Rebind(sql) - stmt, err := e.PrepareNamed(sql) - if err != nil { - return 0, err - } - var id int64 - err = stmt.Get(&id, arg) - if err != nil { - return 0, err - } - return id, err -} // }}} - -// 插入一条记录 -func (e *Engine) InsertRecord(opt *QueryOption) (int64, error) { // {{{ - if len(opt.fields) == 0 { - return 0, errors.New("empty fields") - } - fields_str := strings.Join(opt.fields, ",") - fields_pl := 
strings.TrimRight(strings.Repeat("?,", len(opt.fields)), ",") - sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);", opt.table, fields_str, fields_pl) - if e.DriverName() == "postgres" { - sql += " returning id" - } - sql = e.Rebind(sql) - result, err := e.Exec(sql, opt.args...) - if err != nil { - return 0, err - } - return result.LastInsertId() -} // }}} - -// 查询一条记录 -// dest 目标对象 -// table 查询表 -// query 查询条件 -// args bindvars -func (e *Engine) GetRecord(dest interface{}, opt *QueryOption) error { // {{{ - if opt.query == "" { - return errors.New("empty query") - } - opt.query = "WHERE " + opt.query - sql := fmt.Sprintf("SELECT * FROM %s %s limit 1", opt.table, opt.query) - sql = e.Rebind(sql) - err := e.Get(dest, sql, opt.args...) - if err != nil { - return err - } - return nil -} // }}} - -// 查询多条记录 -// dest 目标变量 -// opt 查询对象 -// args bindvars -func (e *Engine) GetRecords(dest interface{}, opt *QueryOption) error { // {{{ - var tmp = []string{} - if opt.query != "" { - tmp = append(tmp, "where", opt.query) - } - if opt.sort != "" { - tmp = append(tmp, "order by", opt.sort) - } - if opt.offset > 0 { - tmp = append(tmp, "offset", strconv.Itoa(opt.offset)) - } - if opt.limit > 0 { - tmp = append(tmp, "limit", strconv.Itoa(opt.limit)) - } - sql := fmt.Sprintf("select * from %s %s", opt.table, strings.Join(tmp, " ")) - sql = e.Rebind(sql) - return e.Select(dest, sql, opt.args...) 
-} // }}} - -// 更新一条记录 -// table 待处理的表 -// set 需要设置的语句, eg: age=:age -// query 查询语句,不能为空,确保误更新所有记录 -// arg 值 -func (e *Engine) NamedUpdateRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{ - if opt.set == "" || opt.query == "" { - return 0, errors.New("empty set or query") - } - sql := fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query) - result, err := e.NamedExec(sql, arg) - if err != nil { - return 0, err - } - rows, err := result.RowsAffected() - if err != nil { - return 0, err - } - return rows, nil -} // }}} - -func (e *Engine) UpdateRecords(opt *QueryOption) (int64, error) { // {{{ - if opt.set == "" || opt.query == "" { - return 0, errors.New("empty set or query") - } - sql := fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query) - sql = e.Rebind(sql) - result, err := e.Exec(sql, opt.args...) - if err != nil { - return 0, err - } - rows, err := result.RowsAffected() - if err != nil { - return 0, err - } - return rows, nil -} // }}} - -// 删除若干条记录 -// opt 的 query 不能为空 -// arg bindvars -func (e *Engine) NamedDeleteRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{ - if opt.query == "" { - return 0, errors.New("emtpy query") - } - sql := fmt.Sprintf("delete from %s where %s", opt.table, opt.query) - result, err := e.NamedExec(sql, arg) - if err != nil { - return 0, err - } - rows, err := result.RowsAffected() - if err != nil { - return 0, err - } - return rows, nil -} // }}} - -func (e *Engine) DeleteRecords(opt *QueryOption) (int64, error) { - if opt.query == "" { - return 0, errors.New("emtpy query") - } - sql := fmt.Sprintf("delete from %s where %s", opt.table, opt.query) - sql = e.Rebind(sql) - result, err := e.Exec(sql, opt.args...) 
- if err != nil { - return 0, err - } - rows, err := result.RowsAffected() - if err != nil { - return 0, err - } - return rows, nil -} - -func (e *Engine) CountRecords(opt *QueryOption) (int, error) { - sql := fmt.Sprintf("select count(*) from %s where %s", opt.table, opt.query) - sql = e.Rebind(sql) - var num int - err := e.Get(&num, sql, opt.args...) - if err != nil { - return 0, err - } - return num, nil -} - -// var levels = []int{4, 6, 7} -// query, args, err := sqlx.In("SELECT * FROM users WHERE level IN (?);", levels) -// sqlx.In returns queries with the `?` bindvar, we can rebind it for our backend -// query = db.Rebind(query) -// rows, err := db.Query(query, args...) -func (e *Engine) In(query string, args ...interface{}) (string, []interface{}, error) { - return sqlx.In(query, args...) -} - -func IsNoRows(err error) bool { - return err == ErrNoRows -} - -// 把 fields 转换为 field1=:field1, field2=:field2, ..., fieldN=:fieldN -func GetSetString(fields []string) string { - items := []string{} - for _, field := range fields { - if field == "id" { - continue - } - items = append(items, fmt.Sprintf("%s=:%s", field, field)) - } - return strings.Join(items, ",") -} diff --git a/gdb/sqldb/db_func_opt.go b/gdb/sqldb/db_func_opt.go deleted file mode 100644 index 1d596c5..0000000 --- a/gdb/sqldb/db_func_opt.go +++ /dev/null @@ -1,75 +0,0 @@ -// -// db_func_opt.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. 
-// - -package sqldb - -type QueryOption struct { - table string - query string - set string - fields []string - sort string - offset int - limit int - args []any - joins []string -} - -func NewQueryOption(table string) *QueryOption { - return &QueryOption{ - table: table, - fields: []string{"*"}, - offset: 0, - limit: 0, - args: make([]any, 0), - joins: make([]string, 0), - } -} -func (opt *QueryOption) Query(query string) *QueryOption { - opt.query = query - return opt -} -func (opt *QueryOption) Fields(args []string) *QueryOption { - opt.fields = args - return opt -} -func (opt *QueryOption) Select(cols ...string) *QueryOption { - opt.fields = cols - return opt -} -func (opt *QueryOption) Offset(offset int) *QueryOption { - opt.offset = offset - return opt -} -func (opt *QueryOption) Limit(limit int) *QueryOption { - opt.limit = limit - return opt -} -func (opt *QueryOption) Sort(sort string) *QueryOption { - opt.sort = sort - return opt -} -func (opt *QueryOption) Set(set string) *QueryOption { - opt.set = set - return opt -} -func (opt *QueryOption) Args(args ...any) *QueryOption { - opt.args = args - return opt -} -func (opt *QueryOption) Join(table string, cond string) *QueryOption { - opt.joins = append(opt.joins, "join "+table+" on "+cond) - return opt -} -func (opt *QueryOption) LeftJoin(table string, cond string) *QueryOption { - opt.joins = append(opt.joins, "left join "+table+" on "+cond) - return opt -} -func (opt *QueryOption) RightJoin(table string, cond string) *QueryOption { - opt.joins = append(opt.joins, "right join "+table+" on "+cond) - return opt -} diff --git a/gdb/sqldb/db_func_test.go b/gdb/sqldb/db_func_test.go deleted file mode 100644 index 6eccf5b..0000000 --- a/gdb/sqldb/db_func_test.go +++ /dev/null @@ -1,114 +0,0 @@ -// -// db_func_test.go -// Copyright (C) 2022 tiglog -// -// Distributed under terms of the MIT license. 
-// - -package sqldb_test - -import ( - "testing" - "time" - - "git.hexq.cn/tiglog/golib/gdb/sqldb" - "git.hexq.cn/tiglog/golib/gtest" -) - -// 经过测试,发现数据库里面使用 time 类型容易出现 timezone 不一致的情况 -// 在存入数据库时,可能会导致时区丢失 -// 因此,为了更好的兼容性,使用 int 时间戳会更合适 -func dbFuncTest(db *sqldb.Engine, t *testing.T) { - var err error - fields := []string{"first_name", "last_name", "email"} - p := &Person{ - FirstName: "三", - LastName: "张", - Email: "zs@foo.com", - } - // InsertRecord 的用法 - opt := sqldb.NewQueryOption("person").Fields(fields) - rows, err := db.NamedInsertRecord(opt, p) - gtest.Nil(t, err) - gtest.True(t, rows > 0) - // fmt.Println(rows) - - // GetRecord 的用法 - var p3 Person - opt = sqldb.NewQueryOption("person").Query("email=?").Args("zs@foo.com") - err = db.GetRecord(&p3, opt) - // fmt.Println(p3) - gtest.Equal(t, "张", p3.LastName) - gtest.Equal(t, "三", p3.FirstName) - gtest.Equal(t, int64(0), p3.AddedAt) - gtest.Nil(t, err) - - p2 := &Person{ - FirstName: "四", - LastName: "李", - Email: "ls@foo.com", - AddedAt: time.Now().Unix(), - } - fields2 := append(fields, "added_at") - opt = sqldb.NewQueryOption("person").Fields(fields2) - _, err = db.NamedInsertRecord(opt, p2) - gtest.Nil(t, err) - - var p4 Person - opt = sqldb.NewQueryOption("person") - err = db.GetRecord(&p4, opt) - gtest.NotNil(t, err) - gtest.Equal(t, "", p4.FirstName) - - opt = sqldb.NewQueryOption("person").Query("first_name=?").Args("四") - err = db.GetRecord(&p4, opt) - gtest.Nil(t, err) - gtest.Equal(t, time.Now().Unix(), p4.AddedAt) - gtest.Equal(t, "ls@foo.com", p4.Email) - - // GetRecords - var ps []Person - opt = sqldb.NewQueryOption("person").Query("id > ?").Args(0) - err = db.GetRecords(&ps, opt) - gtest.Nil(t, err) - gtest.Greater(t, int64(1), ps) - - var ps2 []Person - opt = sqldb.NewQueryOption("person").Query("id=?").Args(1) - err = db.GetRecords(&ps2, opt) - gtest.Equal(t, 1, len(ps2)) - if len(ps2) > 1 { - gtest.Equal(t, int64(1), ps2[0].Id) - } - - // DeleteRecords - opt = 
sqldb.NewQueryOption("person").Query("id=?").Args(2) - n, err := db.DeleteRecords(opt) - gtest.Nil(t, err) - gtest.Greater(t, int64(0), n) - - // UpdateRecords - opt = sqldb.NewQueryOption("person").Set("first_name=?").Query("email=?").Args("哈哈", "zs@foo.com") - n, err = db.UpdateRecords(opt) - gtest.Nil(t, err) - gtest.Greater(t, int64(0), n) - - // NamedUpdateRecords - var p5 = ps[0] - p5.FirstName = "中华人民共和国" - opt = sqldb.NewQueryOption("person").Set("first_name=:first_name").Query("email=:email") - n, err = db.NamedUpdateRecords(opt, p5) - gtest.Nil(t, err) - gtest.Greater(t, int64(0), n) - - var p6 Person - opt = sqldb.NewQueryOption("person").Query("first_name=?").Args(p5.FirstName) - err = db.GetRecord(&p6, opt) - gtest.Nil(t, err) - gtest.Greater(t, int64(0), p6.Id) - gtest.Equal(t, p6.FirstName, p5.FirstName) -} - -func TestFunc(t *testing.T) { - RunDbTest(t, dbFuncTest) -} diff --git a/gdb/sqldb/db_model.go b/gdb/sqldb/db_model.go deleted file mode 100644 index 7a68fd3..0000000 --- a/gdb/sqldb/db_model.go +++ /dev/null @@ -1,20 +0,0 @@ -// -// db_model.go -// Copyright (C) 2023 tiglog -// -// Distributed under terms of the MIT license. -// - -package sqldb - -// TODO 暂时不好实现,以后再说 - -type Model struct { - db *Engine -} - -func NewModel() *Model { - return &Model{ - db: Db, - } -} diff --git a/gdb/sqldb/db_query.go b/gdb/sqldb/db_query.go deleted file mode 100644 index 5acda3c..0000000 --- a/gdb/sqldb/db_query.go +++ /dev/null @@ -1,322 +0,0 @@ -// -// db_query.go -// Copyright (C) 2022 tiglog -// -// Distributed under terms of the MIT license. 
-// - -package sqldb - -import ( - "errors" - "fmt" - "strconv" - "strings" - - "github.com/jmoiron/sqlx" -) - -type Query struct { - db *Engine - table string - fields []string - wheres []string // 不能太复杂 - joins []string - orderBy string - groupBy string - offset int - limit int -} - -func NewQueryBuild(table string, db *Engine) *Query { - return &Query{ - db: db, - table: table, - fields: []string{}, - wheres: []string{}, - joins: []string{}, - offset: 0, - limit: 0, - } -} - -func (q *Query) Table(table string) *Query { - q.table = table - return q -} - -// 设置 select fields -func (q *Query) Select(fields ...string) *Query { - q.fields = fields - return q -} - -// 增加一个 select field -func (q *Query) AddFields(fields ...string) *Query { - q.fields = append(q.fields, fields...) - return q -} - -func (q *Query) Where(query string) *Query { - q.wheres = []string{query} - return q -} -func (q *Query) AndWhere(query string) *Query { - q.wheres = append(q.wheres, "and "+query) - return q -} - -func (q *Query) OrWhere(query string) *Query { - q.wheres = append(q.wheres, "or "+query) - return q -} - -func (q *Query) Join(table string, on string) *Query { - var join = "join " + table - if on != "" { - join = join + " on " + on - } - q.joins = append(q.joins, join) - return q -} - -func (q *Query) LeftJoin(table string, on string) *Query { - var join = "left join " + table - if on != "" { - join = join + " on " + on - } - q.joins = append(q.joins, join) - return q -} - -func (q *Query) RightJoin(table string, on string) *Query { - var join = "right join " + table - if on != "" { - join = join + " on " + on - } - q.joins = append(q.joins, join) - return q -} - -func (q *Query) InnerJoin(table string, on string) *Query { - var join = "inner join " + table - if on != "" { - join = join + " on " + on - } - q.joins = append(q.joins, join) - return q -} - -func (q *Query) OrderBy(order string) *Query { - q.orderBy = order - return q -} -func (q *Query) GroupBy(group string) *Query 
{ - q.groupBy = group - return q -} - -func (q *Query) Offset(offset int) *Query { - q.offset = offset - return q -} - -func (q *Query) Limit(limit int) *Query { - q.limit = limit - return q -} - -// returningId postgres 数据库返回 LastInsertId 处理 -// TODO returningId 暂时不处理 -func (q *Query) getInsertSql(named, returningId bool) string { - fields_str := strings.Join(q.fields, ",") - var pl string - if named { - var tmp []string - for _, field := range q.fields { - tmp = append(tmp, ":"+field) - } - pl = strings.Join(tmp, ",") - } else { - pl = strings.Repeat("?,", len(q.fields)) - pl = strings.TrimRight(pl, ",") - } - - sql := fmt.Sprintf("insert into %s (%s) values (%s);", q.table, fields_str, pl) - sql = q.db.Rebind(sql) - // fmt.Println(sql) - return sql -} - -// return RowsAffected, error -func (q *Query) Insert(args ...interface{}) (int64, error) { - if len(q.fields) == 0 { - return 0, errors.New("empty fields") - } - sql := q.getInsertSql(false, false) - result, err := q.db.Exec(sql, args...) 
- if err != nil { - return 0, err - } - return result.RowsAffected() -} - -// return RowsAffected, error -func (q *Query) NamedInsert(arg interface{}) (int64, error) { - if len(q.fields) == 0 { - return 0, errors.New("empty fields") - } - sql := q.getInsertSql(true, false) - result, err := q.db.NamedExec(sql, arg) - if err != nil { - return 0, err - } - return result.RowsAffected() -} - -func (q *Query) getQuerySql() string { - var ( - fields_str string = "*" - join_str string - where_str string - offlim string - ) - if len(q.fields) > 0 { - fields_str = strings.Join(q.fields, ",") - } - - if len(q.joins) > 0 { - join_str = strings.Join(q.joins, " ") - } - if len(q.wheres) > 0 { - where_str = "where " + strings.Join(q.wheres, " ") - } - - if q.offset > 0 { - offlim = " offset " + strconv.Itoa(q.offset) - } - if q.limit > 0 { - offlim = " limit " + strconv.Itoa(q.limit) - } - // select fields from table t join where groupby orderby offset limit - sql := fmt.Sprintf("select %s from %s t %s %s %s %s%s", fields_str, q.table, join_str, where_str, q.groupBy, q.orderBy, offlim) - return sql -} - -func (q *Query) One(dest interface{}, args ...interface{}) error { - q.Limit(1) - sql := q.getQuerySql() - sql = q.db.Rebind(sql) - return q.db.Get(dest, sql, args...) -} - -func (q *Query) NamedOne(dest interface{}, arg interface{}) error { - q.Limit(1) - sql := q.getQuerySql() - rows, err := q.db.NamedQuery(sql, arg) - if err != nil { - return err - } - if rows.Next() { - return rows.Scan(dest) - } - return errors.New("nr") // no record -} - -func (q *Query) All(dest interface{}, args ...interface{}) error { - sql := q.getQuerySql() - sql = q.db.Rebind(sql) - return q.db.Select(dest, sql, args...) -} - -// 为了省内存,直接返回迭代器 -func (q *Query) NamedAll(dest interface{}, arg interface{}) (*sqlx.Rows, error) { - sql := q.getQuerySql() - return q.db.NamedQuery(sql, arg) -} - -// set age=? 
/ age=:age -func (q *Query) NamedUpdate(set string, arg interface{}) (int64, error) { - var where_str string - if len(q.wheres) > 0 { - where_str = strings.Join(q.wheres, " ") - } - if set == "" || where_str == "" { - return 0, errors.New("empty set or where") - } - - // update table t where - sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str) - sql = q.db.Rebind(sql) - result, err := q.db.NamedExec(sql, arg) - if err != nil { - return 0, err - } - return result.RowsAffected() -} - -// 顺序容易弄反,记得先是 set 的参数,再是 where 里面的参数 -func (q *Query) Update(set string, args ...interface{}) (int64, error) { - var where_str string - if len(q.wheres) > 0 { - where_str = strings.Join(q.wheres, " ") - } - if set == "" || where_str == "" { - return 0, errors.New("empty set or where") - } - - // update table t where - sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str) - sql = q.db.Rebind(sql) - result, err := q.db.Exec(sql, args...) - if err != nil { - return 0, err - } - return result.RowsAffected() -} - -// 普通的删除 -func (q *Query) Delete(args ...interface{}) (int64, error) { - var where_str string - if len(q.wheres) == 0 { - return 0, errors.New("missing where clause") - } - where_str = strings.Join(q.wheres, " ") - - sql := fmt.Sprintf("delete from %s where %s", q.table, where_str) - sql = q.db.Rebind(sql) - result, err := q.db.Exec(sql, args...) 
- if err != nil { - return 0, err - } - return result.RowsAffected() -} - -func (q *Query) NamedDelete(arg interface{}) (int64, error) { - if len(q.wheres) == 0 { - return 0, errors.New("missing where clause") - } - var where_str string - where_str = strings.Join(q.wheres, " ") - - sql := fmt.Sprintf("delete from %s where %s", q.table, where_str) - sql = q.db.Rebind(sql) - - result, err := q.db.NamedExec(sql, arg) - if err != nil { - return 0, err - } - return result.RowsAffected() -} - -func (q *Query) Count(args ...interface{}) (int64, error) { - var where_str string - if len(q.wheres) > 0 { - where_str = " where " + strings.Join(q.wheres, " ") - } - sql := fmt.Sprintf("select count(1) as num from %s t%s", q.table, where_str) - sql = q.db.Rebind(sql) - var num int64 - err := q.db.Get(&num, sql, args...) - return num, err -} diff --git a/gdb/sqldb/db_query_test.go b/gdb/sqldb/db_query_test.go deleted file mode 100644 index 7790595..0000000 --- a/gdb/sqldb/db_query_test.go +++ /dev/null @@ -1,109 +0,0 @@ -// -// db_query_test.go -// Copyright (C) 2022 tiglog -// -// Distributed under terms of the MIT license. 
-// - -package sqldb_test - -import ( - "testing" - "time" - - "git.hexq.cn/tiglog/golib/gtest" - - "git.hexq.cn/tiglog/golib/gdb/sqldb" -) - -func dbQueryTest(db *sqldb.Engine, t *testing.T) { - query := sqldb.NewQueryBuild("person", db) - // query one - var p1 Person - query.Where("id=?") - err := query.One(&p1, 1) - gtest.Nil(t, err) - gtest.Equal(t, int64(1), p1.Id) - - // query all - var ps1 []Person - query = sqldb.NewQueryBuild("person", db) - query.Where("id > ?") - err = query.All(&ps1, 1) - gtest.Nil(t, err) - gtest.True(t, len(ps1) > 0) - // fmt.Println(ps1) - if len(ps1) > 0 { - var val int64 = 2 - gtest.Equal(t, val, ps1[0].Id) - } - - // insert - query = sqldb.NewQueryBuild("person", db) - query.AddFields("first_name", "last_name", "email") - id, err := query.Insert("三", "张", "zs@bar.com") - gtest.Nil(t, err) - gtest.Greater(t, int64(0), id) - // fmt.Println(id) - - // named insert - query = sqldb.NewQueryBuild("person", db) - query.AddFields("first_name", "last_name", "email") - row, err := query.NamedInsert(&Person{ - FirstName: "四", - LastName: "李", - Email: "ls@bar.com", - AddedAt: time.Now().Unix(), - }) - gtest.Nil(t, err) - gtest.Equal(t, int64(1), row) - - // update - query = sqldb.NewQueryBuild("person", db) - query.Where("email=?") - n, err := query.Update("first_name=?", "哈哈", "ls@bar.com") - gtest.Nil(t, err) - gtest.Equal(t, int64(1), n) - - // named update map - query = sqldb.NewQueryBuild("person", db) - query.Where("email=:email") - n, err = query.NamedUpdate("first_name=:first_name", map[string]interface{}{ - "email": "ls@bar.com", - "first_name": "中华人民共和国", - }) - gtest.Nil(t, err) - gtest.Equal(t, int64(1), n) - - // named update struct - query = sqldb.NewQueryBuild("person", db) - query.Where("email=:email") - var p = &Person{ - Email: "ls@bar.com", - LastName: "中华人民共和国,救民于水火", - } - n, err = query.NamedUpdate("last_name=:last_name", p) - gtest.Nil(t, err) - gtest.Equal(t, int64(1), n) - - // count - query = 
sqldb.NewQueryBuild("person", db) - n, err = query.Count() - gtest.Nil(t, err) - // fmt.Println(n) - gtest.Greater(t, int64(0), n) - - // delete - query = sqldb.NewQueryBuild("person", db) - n, err = query.Delete() - gtest.NotNil(t, err) - gtest.Equal(t, int64(0), n) - - n, err = query.Where("id=?").Delete(2) - gtest.Nil(t, err) - gtest.Equal(t, int64(1), n) -} - -func TestQuery(t *testing.T) { - RunDbTest(t, dbQueryTest) -} diff --git a/gdb/sqldb/db_test.go b/gdb/sqldb/db_test.go new file mode 100644 index 0000000..41ce8b0 --- /dev/null +++ b/gdb/sqldb/db_test.go @@ -0,0 +1,187 @@ +// +// db_test.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +//go:build integration +// +build integration + +package sqldb_test + +import ( + "testing" +) + +type customType1 []string + +func (c customType1) ToStringSlice() []string { + return []string(c) +} + +type customType2 []int64 + +func (c customType2) ToInt64Slice() []int64 { + return []int64(c) +} + +func TestDbMap_Select_expandSliceArgs(t *testing.T) { + tests := []struct { + description string + query string + args []interface{} + wantLen int + }{ + { + description: "it should handle slice placeholders correctly", + query: ` +SELECT 1 FROM crazy_table +WHERE field1 = :Field1 +AND field2 IN (:FieldStringList) +AND field3 IN (:FieldUIntList) +AND field4 IN (:FieldUInt8List) +AND field5 IN (:FieldUInt16List) +AND field6 IN (:FieldUInt32List) +AND field7 IN (:FieldUInt64List) +AND field8 IN (:FieldIntList) +AND field9 IN (:FieldInt8List) +AND field10 IN (:FieldInt16List) +AND field11 IN (:FieldInt32List) +AND field12 IN (:FieldInt64List) +AND field13 IN (:FieldFloat32List) +AND field14 IN (:FieldFloat64List) +`, + args: []interface{}{ + map[string]interface{}{ + "Field1": 123, + "FieldStringList": []string{"h", "e", "y"}, + "FieldUIntList": []uint{1, 2, 3, 4}, + "FieldUInt8List": []uint8{1, 2, 3, 4}, + "FieldUInt16List": []uint16{1, 2, 3, 4}, + "FieldUInt32List": []uint32{1, 2, 3, 
4}, + "FieldUInt64List": []uint64{1, 2, 3, 4}, + "FieldIntList": []int{1, 2, 3, 4}, + "FieldInt8List": []int8{1, 2, 3, 4}, + "FieldInt16List": []int16{1, 2, 3, 4}, + "FieldInt32List": []int32{1, 2, 3, 4}, + "FieldInt64List": []int64{1, 2, 3, 4}, + "FieldFloat32List": []float32{1, 2, 3, 4}, + "FieldFloat64List": []float64{1, 2, 3, 4}, + }, + }, + wantLen: 1, + }, + { + description: "it should handle slice placeholders correctly with custom types", + query: ` +SELECT 1 FROM crazy_table +WHERE field2 IN (:FieldStringList) +AND field12 IN (:FieldIntList) +`, + args: []interface{}{ + map[string]interface{}{ + "FieldStringList": customType1{"h", "e", "y"}, + "FieldIntList": customType2{1, 2, 3, 4}, + }, + }, + wantLen: 3, + }, + } + + type dataFormat struct { + Field1 int `db:"field1"` + Field2 string `db:"field2"` + Field3 uint `db:"field3"` + Field4 uint8 `db:"field4"` + Field5 uint16 `db:"field5"` + Field6 uint32 `db:"field6"` + Field7 uint64 `db:"field7"` + Field8 int `db:"field8"` + Field9 int8 `db:"field9"` + Field10 int16 `db:"field10"` + Field11 int32 `db:"field11"` + Field12 int64 `db:"field12"` + Field13 float32 `db:"field13"` + Field14 float64 `db:"field14"` + } + + dbmap := newDBMap(t) + dbmap.ExpandSliceArgs = true + dbmap.AddTableWithName(dataFormat{}, "crazy_table") + + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + err = dbmap.Insert( + &dataFormat{ + Field1: 123, + Field2: "h", + Field3: 1, + Field4: 1, + Field5: 1, + Field6: 1, + Field7: 1, + Field8: 1, + Field9: 1, + Field10: 1, + Field11: 1, + Field12: 1, + Field13: 1, + Field14: 1, + }, + &dataFormat{ + Field1: 124, + Field2: "e", + Field3: 2, + Field4: 2, + Field5: 2, + Field6: 2, + Field7: 2, + Field8: 2, + Field9: 2, + Field10: 2, + Field11: 2, + Field12: 2, + Field13: 2, + Field14: 2, + }, + &dataFormat{ + Field1: 125, + Field2: "y", + Field3: 3, + Field4: 3, + Field5: 3, + Field6: 3, + Field7: 3, + Field8: 3, + Field9: 3, + Field10: 3, + Field11: 
3, + Field12: 3, + Field13: 3, + Field14: 3, + }, + ) + + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + var dummy []int + _, err := dbmap.Select(&dummy, tt.query, tt.args...) + if err != nil { + t.Fatal(err) + } + + if len(dummy) != tt.wantLen { + t.Errorf("wrong result count\ngot: %d\nwant: %d", len(dummy), tt.wantLen) + } + }) + } +} diff --git a/gdb/sqldb/dialect.go b/gdb/sqldb/dialect.go new file mode 100644 index 0000000..88ecba3 --- /dev/null +++ b/gdb/sqldb/dialect.go @@ -0,0 +1,108 @@ +// +// dialect.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "reflect" +) + +// The Dialect interface encapsulates behaviors that differ across +// SQL databases. At present the Dialect is only used by CreateTables() +// but this could change in the future +type Dialect interface { + // adds a suffix to any query, usually ";" + QuerySuffix() string + + // ToSqlType returns the SQL column type to use when creating a + // table of the given Go Type. maxsize can be used to switch based on + // size. For example, in MySQL []byte could map to BLOB, MEDIUMBLOB, + // or LONGBLOB depending on the maxsize + ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string + + // string to append to primary key column definitions + AutoIncrStr() string + + // string to bind autoincrement columns to. Empty string will + // remove reference to those columns in the INSERT statement. 
+ AutoIncrBindValue() string + + AutoIncrInsertSuffix(col *ColumnMap) string + + // string to append to "create table" statement for vendor specific + // table attributes + CreateTableSuffix() string + + // string to append to "create index" statement + CreateIndexSuffix() string + + // string to append to "drop index" statement + DropIndexSuffix() string + + // string to truncate tables + TruncateClause() string + + // bind variable string to use when forming SQL statements + // in many dbs it is "?", but Postgres appears to use $1 + // + // i is a zero based index of the bind variable in this statement + // + BindVar(i int) string + + // Handles quoting of a field name to ensure that it doesn't raise any + // SQL parsing exceptions by using a reserved word as a field name. + QuoteField(field string) string + + // Handles building up of a schema.database string that is compatible with + // the given dialect + // + // schema - The schema that lives in + // table - The table name + QuotedTableForQuery(schema string, table string) string + + // Existence clause for table creation / deletion + IfSchemaNotExists(command, schema string) string + IfTableExists(command, schema, table string) string + IfTableNotExists(command, schema, table string) string +} + +// IntegerAutoIncrInserter is implemented by dialects that can perform +// inserts with automatically incremented integer primary keys. If +// the dialect can handle automatic assignment of more than just +// integers, see TargetedAutoIncrInserter. +type IntegerAutoIncrInserter interface { + InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) +} + +// TargetedAutoIncrInserter is implemented by dialects that can +// perform automatic assignment of any primary key type (i.e. strings +// for uuids, integers for serials, etc). 
+type TargetedAutoIncrInserter interface { + // InsertAutoIncrToTarget runs an insert operation and assigns the + // automatically generated primary key directly to the passed in + // target. The target should be a pointer to the primary key + // field of the value being inserted. + InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error +} + +// TargetQueryInserter is implemented by dialects that can perform +// assignment of integer primary key type by executing a query +// like "select sequence.currval from dual". +type TargetQueryInserter interface { + // TargetQueryInserter runs an insert operation and assigns the + // automatically generated primary key retrived by the query + // extracted from the GeneratedIdQuery field of the id column. + InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error +} + +func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + res, err := exec.Exec(insertSql, params...) + if err != nil { + return 0, err + } + return res.LastInsertId() +} diff --git a/gdb/sqldb/dialect_mysql.go b/gdb/sqldb/dialect_mysql.go new file mode 100644 index 0000000..c60cfbe --- /dev/null +++ b/gdb/sqldb/dialect_mysql.go @@ -0,0 +1,172 @@ +// +// dialect_mysql.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +// Implementation of Dialect for MySQL databases. 
+type MySQLDialect struct { + + // Engine is the storage engine to use "InnoDB" vs "MyISAM" for example + Engine string + + // Encoding is the character encoding to use for created tables + Encoding string +} + +func (d MySQLDialect) QuerySuffix() string { return ";" } + +func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "boolean" + case reflect.Int8: + return "tinyint" + case reflect.Uint8: + return "tinyint unsigned" + case reflect.Int16: + return "smallint" + case reflect.Uint16: + return "smallint unsigned" + case reflect.Int, reflect.Int32: + return "int" + case reflect.Uint, reflect.Uint32: + return "int unsigned" + case reflect.Int64: + return "bigint" + case reflect.Uint64: + return "bigint unsigned" + case reflect.Float64, reflect.Float32: + return "double" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "mediumblob" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "double" + case "NullBool": + return "tinyint" + case "Time": + return "datetime" + } + + if maxsize < 1 { + maxsize = 255 + } + + /* == About varchar(N) == + * N is number of characters. + * A varchar column can store up to 65535 bytes. + * Remember that 1 character is 3 bytes in utf-8 charset. + * Also remember that each row can store up to 65535 bytes, + * and you have some overheads, so it's not possible for a + * varchar column to have 65535/3 characters really. + * So it would be better to use 'text' type in stead of + * large varchar type. 
+ */ + if maxsize < 256 { + return fmt.Sprintf("varchar(%d)", maxsize) + } else { + return "text" + } +} + +// Returns auto_increment +func (d MySQLDialect) AutoIncrStr() string { + return "auto_increment" +} + +func (d MySQLDialect) AutoIncrBindValue() string { + return "null" +} + +func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return "" +} + +// Returns engine=%s charset=%s based on values stored on struct +func (d MySQLDialect) CreateTableSuffix() string { + if d.Engine == "" || d.Encoding == "" { + msg := "sqldb - undefined" + + if d.Engine == "" { + msg += " MySQLDialect.Engine" + } + if d.Engine == "" && d.Encoding == "" { + msg += "," + } + if d.Encoding == "" { + msg += " MySQLDialect.Encoding" + } + msg += ". Check that your MySQLDialect was correctly initialized when declared." + panic(msg) + } + + return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding) +} + +func (d MySQLDialect) CreateIndexSuffix() string { + return "using" +} + +func (d MySQLDialect) DropIndexSuffix() string { + return "on" +} + +func (d MySQLDialect) TruncateClause() string { + return "truncate" +} + +func (d MySQLDialect) SleepClause(s time.Duration) string { + return fmt.Sprintf("sleep(%f)", s.Seconds()) +} + +// Returns "?" +func (d MySQLDialect) BindVar(i int) string { + return "?" +} + +func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + return standardInsertAutoIncr(exec, insertSql, params...) +} + +func (d MySQLDialect) QuoteField(f string) string { + return "`" + f + "`" +} + +func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return d.QuoteField(table) + } + + return schema + "." 
+ d.QuoteField(table) +} + +func (d MySQLDialect) IfSchemaNotExists(command, schema string) string { + return fmt.Sprintf("%s if not exists", command) +} + +func (d MySQLDialect) IfTableExists(command, schema, table string) string { + return fmt.Sprintf("%s if exists", command) +} + +func (d MySQLDialect) IfTableNotExists(command, schema, table string) string { + return fmt.Sprintf("%s if not exists", command) +} diff --git a/gdb/sqldb/dialect_mysql_test.go b/gdb/sqldb/dialect_mysql_test.go new file mode 100644 index 0000000..e60bc9e --- /dev/null +++ b/gdb/sqldb/dialect_mysql_test.go @@ -0,0 +1,195 @@ +// +// dialect_mysql_test.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +//go:build !integration +// +build !integration + +package sqldb_test + +import ( + "database/sql" + "errors" + "fmt" + "reflect" + "testing" + "time" + + "git.hexq.cn/tiglog/golib/gdb/sqldb" + "github.com/poy/onpar" + "github.com/poy/onpar/expect" + "github.com/poy/onpar/matchers" +) + +func TestMySQLDialect(t *testing.T) { + // o := onpar.New(t) + // defer o.Run() + + type testContext struct { + t *testing.T + dialect sqldb.MySQLDialect + } + + o := onpar.BeforeEach(onpar.New(t), func(t *testing.T) testContext { + return testContext{ + t: t, + dialect: sqldb.MySQLDialect{Engine: "foo", Encoding: "bar"}, + } + }) + defer o.Run() + + o.Group("ToSqlType", func() { + tests := []struct { + name string + value interface{} + maxSize int + autoIncr bool + expected string + }{ + {"bool", true, 0, false, "boolean"}, + {"int8", int8(1), 0, false, "tinyint"}, + {"uint8", uint8(1), 0, false, "tinyint unsigned"}, + {"int16", int16(1), 0, false, "smallint"}, + {"uint16", uint16(1), 0, false, "smallint unsigned"}, + {"int32", int32(1), 0, false, "int"}, + {"int (treated as int32)", int(1), 0, false, "int"}, + {"uint32", uint32(1), 0, false, "int unsigned"}, + {"uint (treated as uint32)", uint(1), 0, false, "int unsigned"}, + {"int64", int64(1), 0, false, 
"bigint"}, + {"uint64", uint64(1), 0, false, "bigint unsigned"}, + {"float32", float32(1), 0, false, "double"}, + {"float64", float64(1), 0, false, "double"}, + {"[]uint8", []uint8{1}, 0, false, "mediumblob"}, + {"NullInt64", sql.NullInt64{}, 0, false, "bigint"}, + {"NullFloat64", sql.NullFloat64{}, 0, false, "double"}, + {"NullBool", sql.NullBool{}, 0, false, "tinyint"}, + {"Time", time.Time{}, 0, false, "datetime"}, + {"default-size string", "", 0, false, "varchar(255)"}, + {"sized string", "", 50, false, "varchar(50)"}, + {"large string", "", 1024, false, "text"}, + } + for _, t := range tests { + o.Spec(t.name, func(tt testContext) { + typ := reflect.TypeOf(t.value) + sqlType := tt.dialect.ToSqlType(typ, t.maxSize, t.autoIncr) + expect.Expect(tt.t, sqlType).To(matchers.Equal(t.expected)) + }) + } + }) + + o.Spec("AutoIncrStr", func(tt testContext) { + expect.Expect(t, tt.dialect.AutoIncrStr()).To(matchers.Equal("auto_increment")) + }) + + o.Spec("AutoIncrBindValue", func(tt testContext) { + expect.Expect(t, tt.dialect.AutoIncrBindValue()).To(matchers.Equal("null")) + }) + + o.Spec("AutoIncrInsertSuffix", func(tt testContext) { + expect.Expect(t, tt.dialect.AutoIncrInsertSuffix(nil)).To(matchers.Equal("")) + }) + + o.Group("CreateTableSuffix", func() { + o.Group("with an empty engine", func() { + o1 := onpar.BeforeEach(o, func(tt testContext) testContext { + tt.dialect.Encoding = "" + return tt + }) + o1.Spec("panics", func(tt testContext) { + expect.Expect(t, func() { tt.dialect.CreateTableSuffix() }).To(Panic()) + }) + }) + + o.Group("with an empty encoding", func() { + o2 := onpar.BeforeEach(o, func(tt testContext) testContext { + tt.dialect.Encoding = "" + return tt + }) + o2.Spec("panics", func(tt testContext) { + expect.Expect(t, func() { tt.dialect.CreateTableSuffix() }).To(Panic()) + }) + }) + + o.Spec("with an engine and an encoding", func(tt testContext) { + expect.Expect(t, tt.dialect.CreateTableSuffix()).To(matchers.Equal(" engine=foo charset=bar")) 
+ }) + }) + + o.Spec("CreateIndexSuffix", func(tt testContext) { + expect.Expect(t, tt.dialect.CreateIndexSuffix()).To(matchers.Equal("using")) + }) + + o.Spec("DropIndexSuffix", func(tt testContext) { + expect.Expect(t, tt.dialect.DropIndexSuffix()).To(matchers.Equal("on")) + }) + + o.Spec("TruncateClause", func(tt testContext) { + expect.Expect(t, tt.dialect.TruncateClause()).To(matchers.Equal("truncate")) + }) + + o.Spec("SleepClause", func(tt testContext) { + expect.Expect(t, tt.dialect.SleepClause(1*time.Second)).To(matchers.Equal("sleep(1.000000)")) + expect.Expect(t, tt.dialect.SleepClause(100*time.Millisecond)).To(matchers.Equal("sleep(0.100000)")) + }) + + o.Spec("BindVar", func(tt testContext) { + expect.Expect(t, tt.dialect.BindVar(0)).To(matchers.Equal("?")) + }) + + o.Spec("QuoteField", func(tt testContext) { + expect.Expect(t, tt.dialect.QuoteField("foo")).To(matchers.Equal("`foo`")) + }) + + o.Group("QuotedTableForQuery", func() { + o.Spec("using the default schema", func(tt testContext) { + expect.Expect(t, tt.dialect.QuotedTableForQuery("", "foo")).To(matchers.Equal("`foo`")) + }) + + o.Spec("with a supplied schema", func(tt testContext) { + expect.Expect(t, tt.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("foo.`bar`")) + }) + }) + + o.Spec("IfSchemaNotExists", func(tt testContext) { + expect.Expect(t, tt.dialect.IfSchemaNotExists("foo", "bar")).To(matchers.Equal("foo if not exists")) + }) + + o.Spec("IfTableExists", func(tt testContext) { + expect.Expect(t, tt.dialect.IfTableExists("foo", "bar", "baz")).To(matchers.Equal("foo if exists")) + }) + + o.Spec("IfTableNotExists", func(tt testContext) { + expect.Expect(t, tt.dialect.IfTableNotExists("foo", "bar", "baz")).To(matchers.Equal("foo if not exists")) + }) +} + +type panicMatcher struct { +} + +func Panic() panicMatcher { + return panicMatcher{} +} + +func (m panicMatcher) Match(actual interface{}) (resultValue interface{}, err error) { + switch f := actual.(type) { + case func(): 
+ panicked := false + func() { + defer func() { + if r := recover(); r != nil { + panicked = true + } + }() + f() + }() + if panicked { + return f, nil + } + return f, errors.New("function did not panic") + default: + return f, fmt.Errorf("%T is not func()", f) + } +} diff --git a/gdb/sqldb/dialect_oracle.go b/gdb/sqldb/dialect_oracle.go new file mode 100644 index 0000000..127c857 --- /dev/null +++ b/gdb/sqldb/dialect_oracle.go @@ -0,0 +1,142 @@ +// +// dialect_oracle.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "fmt" + "reflect" + "strings" +) + +// Implementation of Dialect for Oracle databases. +type OracleDialect struct{} + +func (d OracleDialect) QuerySuffix() string { return "" } + +func (d OracleDialect) CreateIndexSuffix() string { return "" } + +func (d OracleDialect) DropIndexSuffix() string { return "" } + +func (d OracleDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + if isAutoIncr { + return "serial" + } + return "integer" + case reflect.Int64, reflect.Uint64: + if isAutoIncr { + return "bigserial" + } + return "bigint" + case reflect.Float64: + return "double precision" + case reflect.Float32: + return "real" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "bytea" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "double precision" + case "NullBool": + return "boolean" + case "NullTime", "Time": + return "timestamp with time zone" + } + + if maxsize > 0 { + return fmt.Sprintf("varchar(%d)", maxsize) + } else { + return "text" + } + +} + +// Returns empty string +func (d OracleDialect) AutoIncrStr() string { + return "" 
+} + +func (d OracleDialect) AutoIncrBindValue() string { + return "NULL" +} + +func (d OracleDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return "" +} + +// Returns suffix +func (d OracleDialect) CreateTableSuffix() string { + return "" +} + +func (d OracleDialect) TruncateClause() string { + return "truncate" +} + +// Returns "$(i+1)" +func (d OracleDialect) BindVar(i int) string { + return fmt.Sprintf(":%d", i+1) +} + +// After executing the insert uses the ColMap IdQuery to get the generated id +func (d OracleDialect) InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error { + _, err := exec.Exec(insertSql, params...) + if err != nil { + return err + } + id, err := exec.SelectInt(idSql) + if err != nil { + return err + } + switch target.(type) { + case *int64: + *(target.(*int64)) = id + case *int32: + *(target.(*int32)) = int32(id) + case int: + *(target.(*int)) = int(id) + default: + return fmt.Errorf("Id field can be int, int32 or int64") + } + return nil +} + +func (d OracleDialect) QuoteField(f string) string { + return `"` + strings.ToUpper(f) + `"` +} + +func (d OracleDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return d.QuoteField(table) + } + + return schema + "." 
+ d.QuoteField(table) +} + +func (d OracleDialect) IfSchemaNotExists(command, schema string) string { + return fmt.Sprintf("%s if not exists", command) +} + +func (d OracleDialect) IfTableExists(command, schema, table string) string { + return fmt.Sprintf("%s if exists", command) +} + +func (d OracleDialect) IfTableNotExists(command, schema, table string) string { + return fmt.Sprintf("%s if not exists", command) +} diff --git a/gdb/sqldb/dialect_postgres.go b/gdb/sqldb/dialect_postgres.go new file mode 100644 index 0000000..2e17200 --- /dev/null +++ b/gdb/sqldb/dialect_postgres.go @@ -0,0 +1,152 @@ +// +// dialect_postgres.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "fmt" + "reflect" + "strings" + "time" +) + +type PostgresDialect struct { + suffix string + LowercaseFields bool +} + +func (d PostgresDialect) QuerySuffix() string { return ";" } + +func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "boolean" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: + if isAutoIncr { + return "serial" + } + return "integer" + case reflect.Int64, reflect.Uint64: + if isAutoIncr { + return "bigserial" + } + return "bigint" + case reflect.Float64: + return "double precision" + case reflect.Float32: + return "real" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "bytea" + } + } + + switch val.Name() { + case "NullInt64": + return "bigint" + case "NullFloat64": + return "double precision" + case "NullBool": + return "boolean" + case "Time", "NullTime": + return "timestamp with time zone" + } + + if maxsize > 0 { + return fmt.Sprintf("varchar(%d)", maxsize) + } else { + return "text" + } + +} + +// Returns empty string +func (d PostgresDialect) 
AutoIncrStr() string { + return "" +} + +func (d PostgresDialect) AutoIncrBindValue() string { + return "default" +} + +func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return " returning " + d.QuoteField(col.ColumnName) +} + +// Returns suffix +func (d PostgresDialect) CreateTableSuffix() string { + return d.suffix +} + +func (d PostgresDialect) CreateIndexSuffix() string { + return "using" +} + +func (d PostgresDialect) DropIndexSuffix() string { + return "" +} + +func (d PostgresDialect) TruncateClause() string { + return "truncate" +} + +func (d PostgresDialect) SleepClause(s time.Duration) string { + return fmt.Sprintf("pg_sleep(%f)", s.Seconds()) +} + +// Returns "$(i+1)" +func (d PostgresDialect) BindVar(i int) string { + return fmt.Sprintf("$%d", i+1) +} + +func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error { + rows, err := exec.Query(insertSql, params...) + if err != nil { + return err + } + defer rows.Close() + + if !rows.Next() { + return fmt.Errorf("No serial value returned for insert: %s Encountered error: %s", insertSql, rows.Err()) + } + if err := rows.Scan(target); err != nil { + return err + } + if rows.Next() { + return fmt.Errorf("more than two serial value returned for insert: %s", insertSql) + } + return rows.Err() +} + +func (d PostgresDialect) QuoteField(f string) string { + if d.LowercaseFields { + return `"` + strings.ToLower(f) + `"` + } + return `"` + f + `"` +} + +func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string { + if strings.TrimSpace(schema) == "" { + return d.QuoteField(table) + } + + return schema + "." 
+ d.QuoteField(table) +} + +func (d PostgresDialect) IfSchemaNotExists(command, schema string) string { + return fmt.Sprintf("%s if not exists", command) +} + +func (d PostgresDialect) IfTableExists(command, schema, table string) string { + return fmt.Sprintf("%s if exists", command) +} + +func (d PostgresDialect) IfTableNotExists(command, schema, table string) string { + return fmt.Sprintf("%s if not exists", command) +} diff --git a/gdb/sqldb/dialect_postgres_test.go b/gdb/sqldb/dialect_postgres_test.go new file mode 100644 index 0000000..45ed541 --- /dev/null +++ b/gdb/sqldb/dialect_postgres_test.go @@ -0,0 +1,161 @@ +// +// dialect_postgres_test.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +//go:build !integration +// +build !integration + +package sqldb_test + +import ( + "database/sql" + "reflect" + "testing" + "time" + + "git.hexq.cn/tiglog/golib/gdb/sqldb" + "github.com/poy/onpar" + "github.com/poy/onpar/expect" + "github.com/poy/onpar/matchers" +) + +func TestPostgresDialect(t *testing.T) { + + type testContext struct { + t *testing.T + dialect sqldb.PostgresDialect + } + + o := onpar.BeforeEach(onpar.New(t), func(t *testing.T) testContext { + return testContext{ + t: t, + dialect: sqldb.PostgresDialect{ + LowercaseFields: false, + }, + } + }) + defer o.Run() + + o.Group("ToSqlType", func() { + tests := []struct { + name string + value interface{} + maxSize int + autoIncr bool + expected string + }{ + {"bool", true, 0, false, "boolean"}, + {"int8", int8(1), 0, false, "integer"}, + {"uint8", uint8(1), 0, false, "integer"}, + {"int16", int16(1), 0, false, "integer"}, + {"uint16", uint16(1), 0, false, "integer"}, + {"int32", int32(1), 0, false, "integer"}, + {"int (treated as int32)", int(1), 0, false, "integer"}, + {"uint32", uint32(1), 0, false, "integer"}, + {"uint (treated as uint32)", uint(1), 0, false, "integer"}, + {"int64", int64(1), 0, false, "bigint"}, + {"uint64", uint64(1), 0, false, "bigint"}, + 
{"float32", float32(1), 0, false, "real"}, + {"float64", float64(1), 0, false, "double precision"}, + {"[]uint8", []uint8{1}, 0, false, "bytea"}, + {"NullInt64", sql.NullInt64{}, 0, false, "bigint"}, + {"NullFloat64", sql.NullFloat64{}, 0, false, "double precision"}, + {"NullBool", sql.NullBool{}, 0, false, "boolean"}, + {"Time", time.Time{}, 0, false, "timestamp with time zone"}, + {"default-size string", "", 0, false, "text"}, + {"sized string", "", 50, false, "varchar(50)"}, + {"large string", "", 1024, false, "varchar(1024)"}, + } + for _, t := range tests { + o.Spec(t.name, func(tt testContext) { + typ := reflect.TypeOf(t.value) + sqlType := tt.dialect.ToSqlType(typ, t.maxSize, t.autoIncr) + expect.Expect(tt.t, sqlType).To(matchers.Equal(t.expected)) + }) + } + }) + + o.Spec("AutoIncrStr", func(tt testContext) { + expect.Expect(t, tt.dialect.AutoIncrStr()).To(matchers.Equal("")) + }) + + o.Spec("AutoIncrBindValue", func(tt testContext) { + expect.Expect(t, tt.dialect.AutoIncrBindValue()).To(matchers.Equal("default")) + }) + + o.Spec("AutoIncrInsertSuffix", func(tt testContext) { + cm := sqldb.ColumnMap{ + ColumnName: "foo", + } + expect.Expect(t, tt.dialect.AutoIncrInsertSuffix(&cm)).To(matchers.Equal(` returning "foo"`)) + }) + + o.Spec("CreateTableSuffix", func(tt testContext) { + expect.Expect(t, tt.dialect.CreateTableSuffix()).To(matchers.Equal("")) + }) + + o.Spec("CreateIndexSuffix", func(tt testContext) { + expect.Expect(t, tt.dialect.CreateIndexSuffix()).To(matchers.Equal("using")) + }) + + o.Spec("DropIndexSuffix", func(tt testContext) { + expect.Expect(t, tt.dialect.DropIndexSuffix()).To(matchers.Equal("")) + }) + + o.Spec("TruncateClause", func(tt testContext) { + expect.Expect(t, tt.dialect.TruncateClause()).To(matchers.Equal("truncate")) + }) + + o.Spec("SleepClause", func(tt testContext) { + expect.Expect(t, tt.dialect.SleepClause(1*time.Second)).To(matchers.Equal("pg_sleep(1.000000)")) + expect.Expect(t, 
tt.dialect.SleepClause(100*time.Millisecond)).To(matchers.Equal("pg_sleep(0.100000)")) + }) + + o.Spec("BindVar", func(tt testContext) { + expect.Expect(t, tt.dialect.BindVar(0)).To(matchers.Equal("$1")) + expect.Expect(t, tt.dialect.BindVar(4)).To(matchers.Equal("$5")) + }) + + o.Group("QuoteField", func() { + o.Spec("By default, case is preserved", func(tt testContext) { + expect.Expect(t, tt.dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`)) + expect.Expect(t, tt.dialect.QuoteField("bar")).To(matchers.Equal(`"bar"`)) + }) + + o.Group("With LowercaseFields set to true", func() { + o1 := onpar.BeforeEach(o, func(tt testContext) testContext { + tt.dialect.LowercaseFields = true + return tt + }) + + o1.Spec("fields are lowercased", func(tt testContext) { + expect.Expect(t, tt.dialect.QuoteField("Foo")).To(matchers.Equal(`"foo"`)) + }) + }) + }) + + o.Group("QuotedTableForQuery", func() { + o.Spec("using the default schema", func(tt testContext) { + expect.Expect(t, tt.dialect.QuotedTableForQuery("", "foo")).To(matchers.Equal(`"foo"`)) + }) + + o.Spec("with a supplied schema", func(tt testContext) { + expect.Expect(t, tt.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`foo."bar"`)) + }) + }) + + o.Spec("IfSchemaNotExists", func(tt testContext) { + expect.Expect(t, tt.dialect.IfSchemaNotExists("foo", "bar")).To(matchers.Equal("foo if not exists")) + }) + + o.Spec("IfTableExists", func(tt testContext) { + expect.Expect(t, tt.dialect.IfTableExists("foo", "bar", "baz")).To(matchers.Equal("foo if exists")) + }) + + o.Spec("IfTableNotExists", func(tt testContext) { + expect.Expect(t, tt.dialect.IfTableNotExists("foo", "bar", "baz")).To(matchers.Equal("foo if not exists")) + }) +} diff --git a/gdb/sqldb/dialect_sqlite.go b/gdb/sqldb/dialect_sqlite.go new file mode 100644 index 0000000..72f6a72 --- /dev/null +++ b/gdb/sqldb/dialect_sqlite.go @@ -0,0 +1,115 @@ +// +// dialect_sqlite.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT 
license. +// + +package sqldb + +import ( + "fmt" + "reflect" +) + +type SqliteDialect struct { + suffix string +} + +func (d SqliteDialect) QuerySuffix() string { return ";" } + +func (d SqliteDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string { + switch val.Kind() { + case reflect.Ptr: + return d.ToSqlType(val.Elem(), maxsize, isAutoIncr) + case reflect.Bool: + return "integer" + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return "integer" + case reflect.Float64, reflect.Float32: + return "real" + case reflect.Slice: + if val.Elem().Kind() == reflect.Uint8 { + return "blob" + } + } + + switch val.Name() { + case "NullInt64": + return "integer" + case "NullFloat64": + return "real" + case "NullBool": + return "integer" + case "Time": + return "datetime" + } + + if maxsize < 1 { + maxsize = 255 + } + return fmt.Sprintf("varchar(%d)", maxsize) +} + +// Returns autoincrement +func (d SqliteDialect) AutoIncrStr() string { + return "autoincrement" +} + +func (d SqliteDialect) AutoIncrBindValue() string { + return "null" +} + +func (d SqliteDialect) AutoIncrInsertSuffix(col *ColumnMap) string { + return "" +} + +// Returns suffix +func (d SqliteDialect) CreateTableSuffix() string { + return d.suffix +} + +func (d SqliteDialect) CreateIndexSuffix() string { + return "" +} + +func (d SqliteDialect) DropIndexSuffix() string { + return "" +} + +// With sqlite, there technically isn't a TRUNCATE statement, +// but a DELETE FROM uses a truncate optimization: +// http://www.sqlite.org/lang_delete.html +func (d SqliteDialect) TruncateClause() string { + return "delete from" +} + +// Returns "?" +func (d SqliteDialect) BindVar(i int) string { + return "?" +} + +func (d SqliteDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) { + return standardInsertAutoIncr(exec, insertSql, params...) 
+} + +func (d SqliteDialect) QuoteField(f string) string { + return `"` + f + `"` +} + +// sqlite does not have schemas like PostgreSQL does, so just escape it like normal +func (d SqliteDialect) QuotedTableForQuery(schema string, table string) string { + return d.QuoteField(table) +} + +func (d SqliteDialect) IfSchemaNotExists(command, schema string) string { + return fmt.Sprintf("%s if not exists", command) +} + +func (d SqliteDialect) IfTableExists(command, schema, table string) string { + return fmt.Sprintf("%s if exists", command) +} + +func (d SqliteDialect) IfTableNotExists(command, schema, table string) string { + return fmt.Sprintf("%s if not exists", command) +} diff --git a/gdb/sqldb/doc.go b/gdb/sqldb/doc.go new file mode 100644 index 0000000..3b84008 --- /dev/null +++ b/gdb/sqldb/doc.go @@ -0,0 +1,13 @@ +// +// doc.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +// Package sqldb provides a simple way to marshal Go structs to and from +// SQL databases. It uses the database/sql package, and should work with any +// compliant database/sql driver. +// +// Source code and project home: +package sqldb diff --git a/gdb/sqldb/errors.go b/gdb/sqldb/errors.go new file mode 100644 index 0000000..eafba19 --- /dev/null +++ b/gdb/sqldb/errors.go @@ -0,0 +1,34 @@ +// +// errors.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "fmt" +) + +// A non-fatal error, when a select query returns columns that do not exist +// as fields in the struct it is being mapped to +// TODO: discuss wether this needs an error. 
encoding/json silently ignores missing fields +type NoFieldInTypeError struct { + TypeName string + MissingColNames []string +} + +func (err *NoFieldInTypeError) Error() string { + return fmt.Sprintf("sqldb: no fields %+v in type %s", err.MissingColNames, err.TypeName) +} + +// returns true if the error is non-fatal (ie, we shouldn't immediately return) +func NonFatalError(err error) bool { + switch err.(type) { + case *NoFieldInTypeError: + return true + default: + return false + } +} diff --git a/gdb/sqldb/hooks.go b/gdb/sqldb/hooks.go new file mode 100644 index 0000000..07c4918 --- /dev/null +++ b/gdb/sqldb/hooks.go @@ -0,0 +1,45 @@ +// +// hooks.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +//++ TODO v2-phase3: HasPostGet => PostGetter, HasPostDelete => PostDeleter, etc. + +// HasPostGet provides PostGet() which will be executed after the GET statement. +type HasPostGet interface { + PostGet(SqlExecutor) error +} + +// HasPostDelete provides PostDelete() which will be executed after the DELETE statement +type HasPostDelete interface { + PostDelete(SqlExecutor) error +} + +// HasPostUpdate provides PostUpdate() which will be executed after the UPDATE statement +type HasPostUpdate interface { + PostUpdate(SqlExecutor) error +} + +// HasPostInsert provides PostInsert() which will be executed after the INSERT statement +type HasPostInsert interface { + PostInsert(SqlExecutor) error +} + +// HasPreDelete provides PreDelete() which will be executed before the DELETE statement. +type HasPreDelete interface { + PreDelete(SqlExecutor) error +} + +// HasPreUpdate provides PreUpdate() which will be executed before UPDATE statement. +type HasPreUpdate interface { + PreUpdate(SqlExecutor) error +} + +// HasPreInsert provides PreInsert() which will be executed before INSERT statement. 
+type HasPreInsert interface { + PreInsert(SqlExecutor) error +} diff --git a/gdb/sqldb/index.go b/gdb/sqldb/index.go new file mode 100644 index 0000000..c61d6fd --- /dev/null +++ b/gdb/sqldb/index.go @@ -0,0 +1,51 @@ +// +// index.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +// IndexMap represents a mapping between a Go struct field and a single +// index in a table. +// Unique and MaxSize only inform the +// CreateTables() function and are not used by Insert/Update/Delete/Get. +type IndexMap struct { + // Index name in db table + IndexName string + + // If true, " unique" is added to create index statements. + // Not used elsewhere + Unique bool + + // Index type supported by Dialect + // Postgres: B-tree, Hash, GiST and GIN. + // Mysql: Btree, Hash. + // Sqlite: nil. + IndexType string + + // Columns name for single and multiple indexes + columns []string +} + +// Rename allows you to specify the index name in the table +// +// Example: table.IndMap("customer_test_idx").Rename("customer_idx") +func (idx *IndexMap) Rename(indname string) *IndexMap { + idx.IndexName = indname + return idx +} + +// SetUnique adds "unique" to the create index statements for this +// index, if b is true. +func (idx *IndexMap) SetUnique(b bool) *IndexMap { + idx.Unique = b + return idx +} + +// SetIndexType specifies the index type supported by chousen SQL Dialect +func (idx *IndexMap) SetIndexType(indtype string) *IndexMap { + idx.IndexType = indtype + return idx +} diff --git a/gdb/sqldb/lockerror.go b/gdb/sqldb/lockerror.go new file mode 100644 index 0000000..2351b3b --- /dev/null +++ b/gdb/sqldb/lockerror.go @@ -0,0 +1,59 @@ +// +// lockerror.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. 
+// + +package sqldb + +import ( + "fmt" + "reflect" +) + +// OptimisticLockError is returned by Update() or Delete() if the +// struct being modified has a Version field and the value is not equal to +// the current value in the database +type OptimisticLockError struct { + // Table name where the lock error occurred + TableName string + + // Primary key values of the row being updated/deleted + Keys []interface{} + + // true if a row was found with those keys, indicating the + // LocalVersion is stale. false if no value was found with those + // keys, suggesting the row has been deleted since loaded, or + // was never inserted to begin with + RowExists bool + + // Version value on the struct passed to Update/Delete. This value is + // out of sync with the database. + LocalVersion int64 +} + +// Error returns a description of the cause of the lock error +func (e OptimisticLockError) Error() string { + if e.RowExists { + return fmt.Sprintf("sqldb: OptimisticLockError table=%s keys=%v out of date version=%d", e.TableName, e.Keys, e.LocalVersion) + } + + return fmt.Sprintf("sqldb: OptimisticLockError no row found for table=%s keys=%v", e.TableName, e.Keys) +} + +func lockError(m *DbMap, exec SqlExecutor, tableName string, + existingVer int64, elem reflect.Value, + keys ...interface{}) (int64, error) { + + existing, err := get(m, exec, elem.Interface(), keys...) + if err != nil { + return -1, err + } + + ole := OptimisticLockError{tableName, keys, true, existingVer} + if existing == nil { + ole.RowExists = false + } + return -1, ole +} diff --git a/gdb/sqldb/logging.go b/gdb/sqldb/logging.go new file mode 100644 index 0000000..b7e56b1 --- /dev/null +++ b/gdb/sqldb/logging.go @@ -0,0 +1,45 @@ +// +// logging.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import "fmt" + +// SqldbLogger is a deprecated alias of Logger. +type SqldbLogger = Logger + +// Logger is the type that sqldb uses to log SQL statements. 
+// See DbMap.TraceOn. +type Logger interface { + Printf(format string, v ...interface{}) +} + +// TraceOn turns on SQL statement logging for this DbMap. After this is +// called, all SQL statements will be sent to the logger. If prefix is +// a non-empty string, it will be written to the front of all logged +// strings, which can aid in filtering log lines. +// +// Use TraceOn if you want to spy on the SQL statements that sqldb +// generates. +// +// Note that the base log.Logger type satisfies Logger, but adapters can +// easily be written for other logging packages (e.g., the golang-sanctioned +// glog framework). +func (m *DbMap) TraceOn(prefix string, logger Logger) { + m.logger = logger + if prefix == "" { + m.logPrefix = prefix + } else { + m.logPrefix = fmt.Sprintf("%s ", prefix) + } +} + +// TraceOff turns off tracing. It is idempotent. +func (m *DbMap) TraceOff() { + m.logger = nil + m.logPrefix = "" +} diff --git a/gdb/sqldb/nulltypes.go b/gdb/sqldb/nulltypes.go new file mode 100644 index 0000000..c5c2158 --- /dev/null +++ b/gdb/sqldb/nulltypes.go @@ -0,0 +1,68 @@ +// +// nulltypes.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "database/sql/driver" + "log" + "time" +) + +// A nullable Time value +type NullTime struct { + Time time.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. 
+func (nt *NullTime) Scan(value interface{}) error { + log.Printf("Time scan value is: %#v", value) + switch t := value.(type) { + case time.Time: + nt.Time, nt.Valid = t, true + case []byte: + v := strToTime(string(t)) + if v != nil { + nt.Valid = true + nt.Time = *v + } + case string: + v := strToTime(t) + if v != nil { + nt.Valid = true + nt.Time = *v + } + } + return nil +} + +func strToTime(v string) *time.Time { + for _, dtfmt := range []string{ + "2006-01-02 15:04:05.999999999", + "2006-01-02T15:04:05.999999999", + "2006-01-02 15:04:05", + "2006-01-02T15:04:05", + "2006-01-02 15:04", + "2006-01-02T15:04", + "2006-01-02", + "2006-01-02 15:04:05-07:00", + } { + if t, err := time.Parse(dtfmt, v); err == nil { + return &t + } + } + return nil +} + +// Value implements the driver Valuer interface. +func (nt NullTime) Value() (driver.Value, error) { + if !nt.Valid { + return nil, nil + } + return nt.Time, nil +} diff --git a/gdb/sqldb/select.go b/gdb/sqldb/select.go new file mode 100644 index 0000000..8b9a4a9 --- /dev/null +++ b/gdb/sqldb/select.go @@ -0,0 +1,361 @@ +// +// select.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "database/sql" + "fmt" + "reflect" +) + +// SelectInt executes the given query, which should be a SELECT statement for a single +// integer column, and returns the value of the first row returned. If no rows are +// found, zero is returned. +func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) { + var h int64 + err := selectVal(e, &h, query, args...) + if err != nil && err != sql.ErrNoRows { + return 0, err + } + return h, nil +} + +// SelectNullInt executes the given query, which should be a SELECT statement for a single +// integer column, and returns the value of the first row returned. If no rows are +// found, the empty sql.NullInt64 value is returned. 
+func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullInt64, error) { + var h sql.NullInt64 + err := selectVal(e, &h, query, args...) + if err != nil && err != sql.ErrNoRows { + return h, err + } + return h, nil +} + +// SelectFloat executes the given query, which should be a SELECT statement for a single +// float column, and returns the value of the first row returned. If no rows are +// found, zero is returned. +func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, error) { + var h float64 + err := selectVal(e, &h, query, args...) + if err != nil && err != sql.ErrNoRows { + return 0, err + } + return h, nil +} + +// SelectNullFloat executes the given query, which should be a SELECT statement for a single +// float column, and returns the value of the first row returned. If no rows are +// found, the empty sql.NullInt64 value is returned. +func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.NullFloat64, error) { + var h sql.NullFloat64 + err := selectVal(e, &h, query, args...) + if err != nil && err != sql.ErrNoRows { + return h, err + } + return h, nil +} + +// SelectStr executes the given query, which should be a SELECT statement for a single +// char/varchar column, and returns the value of the first row returned. If no rows are +// found, an empty string is returned. +func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) { + var h string + err := selectVal(e, &h, query, args...) + if err != nil && err != sql.ErrNoRows { + return "", err + } + return h, nil +} + +// SelectNullStr executes the given query, which should be a SELECT +// statement for a single char/varchar column, and returns the value +// of the first row returned. If no rows are found, the empty +// sql.NullString is returned. +func SelectNullStr(e SqlExecutor, query string, args ...interface{}) (sql.NullString, error) { + var h sql.NullString + err := selectVal(e, &h, query, args...) 
+ if err != nil && err != sql.ErrNoRows { + return h, err + } + return h, nil +} + +// SelectOne executes the given query (which should be a SELECT statement) +// and binds the result to holder, which must be a pointer. +// +// # If no row is found, an error (sql.ErrNoRows specifically) will be returned +// +// If more than one row is found, an error will be returned. +func SelectOne(m *DbMap, e SqlExecutor, holder interface{}, query string, args ...interface{}) error { + t := reflect.TypeOf(holder) + if t.Kind() == reflect.Ptr { + t = t.Elem() + } else { + return fmt.Errorf("sqldb: SelectOne holder must be a pointer, but got: %t", holder) + } + + // Handle pointer to pointer + isptr := false + if t.Kind() == reflect.Ptr { + isptr = true + t = t.Elem() + } + + if t.Kind() == reflect.Struct { + var nonFatalErr error + + list, err := hookedselect(m, e, holder, query, args...) + if err != nil { + if !NonFatalError(err) { // FIXME: double negative, rename NonFatalError to FatalError + return err + } + nonFatalErr = err + } + + dest := reflect.ValueOf(holder) + if isptr { + dest = dest.Elem() + } + + if list != nil && len(list) > 0 { // FIXME: invert if/else + // check for multiple rows + if len(list) > 1 { + return fmt.Errorf("sqldb: multiple rows returned for: %s - %v", query, args) + } + + // Initialize if nil + if dest.IsNil() { + dest.Set(reflect.New(t)) + } + + // only one row found + src := reflect.ValueOf(list[0]) + dest.Elem().Set(src.Elem()) + } else { + // No rows found, return a proper error. + return sql.ErrNoRows + } + + return nonFatalErr + } + + return selectVal(e, holder, query, args...) +} + +func selectVal(e SqlExecutor, holder interface{}, query string, args ...interface{}) error { + if len(args) == 1 { + switch m := e.(type) { + case *DbMap: + query, args = maybeExpandNamedQuery(m, query, args) + case *Transaction: + query, args = maybeExpandNamedQuery(m.dbmap, query, args) + } + } + rows, err := e.Query(query, args...) 
+ if err != nil { + return err + } + defer rows.Close() + + if !rows.Next() { + if err := rows.Err(); err != nil { + return err + } + return sql.ErrNoRows + } + + return rows.Scan(holder) +} + +func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string, + args ...interface{}) ([]interface{}, error) { + + var nonFatalErr error + + list, err := rawselect(m, exec, i, query, args...) + if err != nil { + if !NonFatalError(err) { + return nil, err + } + nonFatalErr = err + } + + // Determine where the results are: written to i, or returned in list + if t, _ := toSliceType(i); t == nil { + for _, v := range list { + if v, ok := v.(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } + } + } + } else { + resultsValue := reflect.Indirect(reflect.ValueOf(i)) + for i := 0; i < resultsValue.Len(); i++ { + if v, ok := resultsValue.Index(i).Interface().(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } + } + } + } + return list, nonFatalErr +} + +func rawselect(m *DbMap, exec SqlExecutor, i interface{}, query string, + args ...interface{}) ([]interface{}, error) { + var ( + appendToSlice = false // Write results to i directly? + intoStruct = true // Selecting into a struct? + pointerElements = true // Are the slice elements pointers (vs values)? + ) + + var nonFatalErr error + + tableName := "" + var dynObj DynamicTable + isDynamic := false + if dynObj, isDynamic = i.(DynamicTable); isDynamic { + tableName = dynObj.TableName() + } + + // get type for i, verifying it's a supported destination + t, err := toType(i) + if err != nil { + var err2 error + if t, err2 = toSliceType(i); t == nil { + if err2 != nil { + return nil, err2 + } + return nil, err + } + pointerElements = t.Kind() == reflect.Ptr + if pointerElements { + t = t.Elem() + } + appendToSlice = true + intoStruct = t.Kind() == reflect.Struct + } + + // If the caller supplied a single struct/map argument, assume a "named + // parameter" query. 
Extract the named arguments from the struct/map, create + // the flat arg slice, and rewrite the query to use the dialect's placeholder. + if len(args) == 1 { + query, args = maybeExpandNamedQuery(m, query, args) + } + + // Run the query + rows, err := exec.Query(query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + // Fetch the column names as returned from db + cols, err := rows.Columns() + if err != nil { + return nil, err + } + + if !intoStruct && len(cols) > 1 { + return nil, fmt.Errorf("sqldb: select into non-struct slice requires 1 column, got %d", len(cols)) + } + + var colToFieldIndex [][]int + if intoStruct { + colToFieldIndex, err = columnToFieldIndex(m, t, tableName, cols) + if err != nil { + if !NonFatalError(err) { + return nil, err + } + nonFatalErr = err + } + } + + conv := m.TypeConverter + + // Add results to one of these two slices. + var ( + list = make([]interface{}, 0) + sliceValue = reflect.Indirect(reflect.ValueOf(i)) + ) + + for { + if !rows.Next() { + // if error occured return rawselect + if rows.Err() != nil { + return nil, rows.Err() + } + // time to exit from outer "for" loop + break + } + v := reflect.New(t) + + if isDynamic { + v.Interface().(DynamicTable).SetTableName(tableName) + } + + dest := make([]interface{}, len(cols)) + + custScan := make([]CustomScanner, 0) + + for x := range cols { + f := v.Elem() + if intoStruct { + index := colToFieldIndex[x] + if index == nil { + // this field is not present in the struct, so create a dummy + // value for rows.Scan to scan into + var dummy dummyField + dest[x] = &dummy + continue + } + f = f.FieldByIndex(index) + } + target := f.Addr().Interface() + if conv != nil { + scanner, ok := conv.FromDb(target) + if ok { + target = scanner.Holder + custScan = append(custScan, scanner) + } + } + dest[x] = target + } + + err = rows.Scan(dest...) 
+ if err != nil { + return nil, err + } + + for _, c := range custScan { + err = c.Bind() + if err != nil { + return nil, err + } + } + + if appendToSlice { + if !pointerElements { + v = v.Elem() + } + sliceValue.Set(reflect.Append(sliceValue, v)) + } else { + list = append(list, v.Interface()) + } + } + + if appendToSlice && sliceValue.IsNil() { + sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), 0, 0)) + } + + return list, nonFatalErr +} diff --git a/gdb/sqldb/sqldb.go b/gdb/sqldb/sqldb.go new file mode 100644 index 0000000..bae8e17 --- /dev/null +++ b/gdb/sqldb/sqldb.go @@ -0,0 +1,675 @@ +// +// sqldb.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "context" + "database/sql" + "database/sql/driver" + "fmt" + "reflect" + "regexp" + "strings" + "time" +) + +// OracleString (empty string is null) +// TODO: move to dialect/oracle?, rename to String? +type OracleString struct { + sql.NullString +} + +// Scan implements the Scanner interface. +func (os *OracleString) Scan(value interface{}) error { + if value == nil { + os.String, os.Valid = "", false + return nil + } + os.Valid = true + return os.NullString.Scan(value) +} + +// Value implements the driver Valuer interface. +func (os OracleString) Value() (driver.Value, error) { + if !os.Valid || os.String == "" { + return nil, nil + } + return os.String, nil +} + +// SqlTyper is a type that returns its database type. Most of the +// time, the type can just use "database/sql/driver".Valuer; but when +// it returns nil for its empty value, it needs to implement SqlTyper +// to have its column type detected properly during table creation. 
+type SqlTyper interface { + SqlType() driver.Value +} + +// legacySqlTyper prevents breaking clients who depended on the previous +// SqlTyper interface +type legacySqlTyper interface { + SqlType() driver.Valuer +} + +// for fields that exists in DB table, but not exists in struct +type dummyField struct{} + +// Scan implements the Scanner interface. +func (nt *dummyField) Scan(value interface{}) error { + return nil +} + +var zeroVal reflect.Value +var versFieldConst = "[sqldb_ver_field]" + +// The TypeConverter interface provides a way to map a value of one +// type to another type when persisting to, or loading from, a database. +// +// Example use cases: Implement type converter to convert bool types to "y"/"n" strings, +// or serialize a struct member as a JSON blob. +type TypeConverter interface { + // ToDb converts val to another type. Called before INSERT/UPDATE operations + ToDb(val interface{}) (interface{}, error) + + // FromDb returns a CustomScanner appropriate for this type. This will be used + // to hold values returned from SELECT queries. + // + // In particular the CustomScanner returned should implement a Binder + // function appropriate for the Go type you wish to convert the db value to + // + // If bool==false, then no custom scanner will be used for this field. + FromDb(target interface{}) (CustomScanner, bool) +} + +// SqlExecutor exposes sqldb operations that can be run from Pre/Post +// hooks. This hides whether the current operation that triggered the +// hook is in a transaction. +// +// See the DbMap function docs for each of the functions below for more +// information. 
+type SqlExecutor interface { + WithContext(ctx context.Context) SqlExecutor + Get(i interface{}, keys ...interface{}) (interface{}, error) + Insert(list ...interface{}) error + Update(list ...interface{}) (int64, error) + Delete(list ...interface{}) (int64, error) + Exec(query string, args ...interface{}) (sql.Result, error) + Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) + SelectInt(query string, args ...interface{}) (int64, error) + SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) + SelectFloat(query string, args ...interface{}) (float64, error) + SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) + SelectStr(query string, args ...interface{}) (string, error) + SelectNullStr(query string, args ...interface{}) (sql.NullString, error) + SelectOne(holder interface{}, query string, args ...interface{}) error + Query(query string, args ...interface{}) (*sql.Rows, error) + QueryRow(query string, args ...interface{}) *sql.Row +} + +// DynamicTable allows the users of sqldb to dynamically +// use different database table names during runtime +// while sharing the same golang struct for in-memory data +type DynamicTable interface { + TableName() string + SetTableName(string) +} + +// Compile-time check that DbMap and Transaction implement the SqlExecutor +// interface. 
+var _, _ SqlExecutor = &DbMap{}, &Transaction{} + +func argValue(a interface{}) interface{} { + v, ok := a.(driver.Valuer) + if !ok { + return a + } + vV := reflect.ValueOf(v) + if vV.Kind() == reflect.Ptr && vV.IsNil() { + return nil + } + ret, err := v.Value() + if err != nil { + return a + } + return ret +} + +func argsString(args ...interface{}) string { + var margs string + for i, a := range args { + v := argValue(a) + switch v.(type) { + case string: + v = fmt.Sprintf("%q", v) + default: + v = fmt.Sprintf("%v", v) + } + margs += fmt.Sprintf("%d:%s", i+1, v) + if i+1 < len(args) { + margs += " " + } + } + return margs +} + +// Calls the Exec function on the executor, but attempts to expand any eligible named +// query arguments first. +func maybeExpandNamedQueryAndExec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) { + dbMap := extractDbMap(e) + + if len(args) == 1 { + query, args = maybeExpandNamedQuery(dbMap, query, args) + } + + return exec(e, query, args...) +} + +func extractDbMap(e SqlExecutor) *DbMap { + switch m := e.(type) { + case *DbMap: + return m + case *Transaction: + return m.dbmap + } + return nil +} + +// executor exposes the sql.DB and sql.Tx functions so that it can be used +// on internal functions that need to be agnostic to the underlying object. 
+type executor interface { + Exec(query string, args ...interface{}) (sql.Result, error) + Prepare(query string) (*sql.Stmt, error) + QueryRow(query string, args ...interface{}) *sql.Row + Query(query string, args ...interface{}) (*sql.Rows, error) + ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) + QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row + QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) +} + +func extractExecutorAndContext(e SqlExecutor) (executor, context.Context) { + switch m := e.(type) { + case *DbMap: + return m.Db, m.ctx + case *Transaction: + return m.tx, m.ctx + } + return nil, nil +} + +// maybeExpandNamedQuery checks the given arg to see if it's eligible to be used +// as input to a named query. If so, it rewrites the query to use +// dialect-dependent bindvars and instantiates the corresponding slice of +// parameters by extracting data from the map / struct. +// If not, returns the input values unchanged. +func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string, []interface{}) { + var ( + arg = args[0] + argval = reflect.ValueOf(arg) + ) + if argval.Kind() == reflect.Ptr { + argval = argval.Elem() + } + + if argval.Kind() == reflect.Map && argval.Type().Key().Kind() == reflect.String { + return expandNamedQuery(m, query, func(key string) reflect.Value { + return argval.MapIndex(reflect.ValueOf(key)) + }) + } + if argval.Kind() != reflect.Struct { + return query, args + } + if _, ok := arg.(time.Time); ok { + // time.Time is driver.Value + return query, args + } + if _, ok := arg.(driver.Valuer); ok { + // driver.Valuer will be converted to driver.Value. 
+ return query, args + } + + return expandNamedQuery(m, query, argval.FieldByName) +} + +var keyRegexp = regexp.MustCompile(`:[[:word:]]+`) + +// expandNamedQuery accepts a query with placeholders of the form ":key", and a +// single arg of Kind Struct or Map[string]. It returns the query with the +// dialect's placeholders, and a slice of args ready for positional insertion +// into the query. +func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect.Value) (string, []interface{}) { + var ( + n int + args []interface{} + ) + return keyRegexp.ReplaceAllStringFunc(query, func(key string) string { + val := keyGetter(key[1:]) + if !val.IsValid() { + return key + } + args = append(args, val.Interface()) + newVar := m.Dialect.BindVar(n) + n++ + return newVar + }), args +} + +func columnToFieldIndex(m *DbMap, t reflect.Type, name string, cols []string) ([][]int, error) { + colToFieldIndex := make([][]int, len(cols)) + + // check if type t is a mapped table - if so we'll + // check the table for column aliasing below + tableMapped := false + table := tableOrNil(m, t, name) + if table != nil { + tableMapped = true + } + + // Loop over column names and find field in i to bind to + // based on column name. 
all returned columns must match + // a field in the i struct + missingColNames := []string{} + for x := range cols { + colName := strings.ToLower(cols[x]) + field, found := t.FieldByNameFunc(func(fieldName string) bool { + field, _ := t.FieldByName(fieldName) + cArguments := strings.Split(field.Tag.Get("db"), ",") + fieldName = cArguments[0] + + if fieldName == "-" { + return false + } else if fieldName == "" { + fieldName = field.Name + } + if tableMapped { + colMap := colMapOrNil(table, fieldName) + if colMap != nil { + fieldName = colMap.ColumnName + } + } + return colName == strings.ToLower(fieldName) + }) + if found { + colToFieldIndex[x] = field.Index + } + if colToFieldIndex[x] == nil { + missingColNames = append(missingColNames, colName) + } + } + if len(missingColNames) > 0 { + return colToFieldIndex, &NoFieldInTypeError{ + TypeName: t.Name(), + MissingColNames: missingColNames, + } + } + return colToFieldIndex, nil +} + +func fieldByName(val reflect.Value, fieldName string) *reflect.Value { + // try to find field by exact match + f := val.FieldByName(fieldName) + + if f != zeroVal { + return &f + } + + // try to find by case insensitive match - only the Postgres driver + // seems to require this - in the case where columns are aliased in the sql + fieldNameL := strings.ToLower(fieldName) + fieldCount := val.NumField() + t := val.Type() + for i := 0; i < fieldCount; i++ { + sf := t.Field(i) + if strings.ToLower(sf.Name) == fieldNameL { + f := val.Field(i) + return &f + } + } + + return nil +} + +// toSliceType returns the element type of the given object, if the object is a +// "*[]*Element" or "*[]Element". If not, returns nil. +// err is returned if the user was trying to pass a pointer-to-slice but failed. 
+func toSliceType(i interface{}) (reflect.Type, error) { + t := reflect.TypeOf(i) + if t.Kind() != reflect.Ptr { + // If it's a slice, return a more helpful error message + if t.Kind() == reflect.Slice { + return nil, fmt.Errorf("sqldb: cannot SELECT into a non-pointer slice: %v", t) + } + return nil, nil + } + if t = t.Elem(); t.Kind() != reflect.Slice { + return nil, nil + } + return t.Elem(), nil +} + +func toType(i interface{}) (reflect.Type, error) { + t := reflect.TypeOf(i) + + // If a Pointer to a type, follow + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + if t.Kind() != reflect.Struct { + return nil, fmt.Errorf("sqldb: cannot SELECT into this type: %v", reflect.TypeOf(i)) + } + return t, nil +} + +type foundTable struct { + table *TableMap + dynName *string +} + +func tableFor(m *DbMap, t reflect.Type, i interface{}) (*foundTable, error) { + if dyn, isDynamic := i.(DynamicTable); isDynamic { + tableName := dyn.TableName() + table, err := m.DynamicTableFor(tableName, true) + if err != nil { + return nil, err + } + return &foundTable{ + table: table, + dynName: &tableName, + }, nil + } + table, err := m.TableFor(t, true) + if err != nil { + return nil, err + } + return &foundTable{table: table}, nil +} + +func get(m *DbMap, exec SqlExecutor, i interface{}, + keys ...interface{}) (interface{}, error) { + + t, err := toType(i) + if err != nil { + return nil, err + } + + foundTable, err := tableFor(m, t, i) + if err != nil { + return nil, err + } + table := foundTable.table + + plan := table.bindGet() + + v := reflect.New(t) + if foundTable.dynName != nil { + retDyn := v.Interface().(DynamicTable) + retDyn.SetTableName(*foundTable.dynName) + } + + dest := make([]interface{}, len(plan.argFields)) + + conv := m.TypeConverter + custScan := make([]CustomScanner, 0) + + for x, fieldName := range plan.argFields { + f := v.Elem().FieldByName(fieldName) + target := f.Addr().Interface() + if conv != nil { + scanner, ok := conv.FromDb(target) + if ok { + target = 
scanner.Holder + custScan = append(custScan, scanner) + } + } + dest[x] = target + } + + row := exec.QueryRow(plan.query, keys...) + err = row.Scan(dest...) + if err != nil { + if err == sql.ErrNoRows { + err = nil + } + return nil, err + } + + for _, c := range custScan { + err = c.Bind() + if err != nil { + return nil, err + } + } + + if v, ok := v.Interface().(HasPostGet); ok { + err := v.PostGet(exec) + if err != nil { + return nil, err + } + } + + return v.Interface(), nil +} + +func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) { + count := int64(0) + for _, ptr := range list { + table, elem, err := m.tableForPointer(ptr, true) + if err != nil { + return -1, err + } + + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreDelete); ok { + err = v.PreDelete(exec) + if err != nil { + return -1, err + } + } + + bi, err := table.bindDelete(elem) + if err != nil { + return -1, err + } + + res, err := exec.Exec(bi.query, bi.args...) + if err != nil { + return -1, err + } + rows, err := res.RowsAffected() + if err != nil { + return -1, err + } + + if rows == 0 && bi.existingVersion > 0 { + return lockError(m, exec, table.TableName, + bi.existingVersion, elem, bi.keys...) + } + + count += rows + + if v, ok := eval.(HasPostDelete); ok { + err := v.PostDelete(exec) + if err != nil { + return -1, err + } + } + } + + return count, nil +} + +func update(m *DbMap, exec SqlExecutor, colFilter ColumnFilter, list ...interface{}) (int64, error) { + count := int64(0) + for _, ptr := range list { + table, elem, err := m.tableForPointer(ptr, true) + if err != nil { + return -1, err + } + + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreUpdate); ok { + err = v.PreUpdate(exec) + if err != nil { + return -1, err + } + } + + bi, err := table.bindUpdate(elem, colFilter) + if err != nil { + return -1, err + } + + res, err := exec.Exec(bi.query, bi.args...) 
+ if err != nil { + return -1, err + } + + rows, err := res.RowsAffected() + if err != nil { + return -1, err + } + + if rows == 0 && bi.existingVersion > 0 { + return lockError(m, exec, table.TableName, + bi.existingVersion, elem, bi.keys...) + } + + if bi.versField != "" { + elem.FieldByName(bi.versField).SetInt(bi.existingVersion + 1) + } + + count += rows + + if v, ok := eval.(HasPostUpdate); ok { + err = v.PostUpdate(exec) + if err != nil { + return -1, err + } + } + } + return count, nil +} + +func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error { + for _, ptr := range list { + table, elem, err := m.tableForPointer(ptr, false) + if err != nil { + return err + } + + eval := elem.Addr().Interface() + if v, ok := eval.(HasPreInsert); ok { + err := v.PreInsert(exec) + if err != nil { + return err + } + } + + bi, err := table.bindInsert(elem) + if err != nil { + return err + } + + if bi.autoIncrIdx > -1 { + f := elem.FieldByName(bi.autoIncrFieldName) + switch inserter := m.Dialect.(type) { + case IntegerAutoIncrInserter: + id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...) + if err != nil { + return err + } + k := f.Kind() + if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) { + f.SetInt(id) + } else if (k == reflect.Uint) || (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) { + f.SetUint(uint64(id)) + } else { + return fmt.Errorf("sqldb: cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName) + } + case TargetedAutoIncrInserter: + err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...) 
+ if err != nil { + return err + } + case TargetQueryInserter: + var idQuery = table.ColMap(bi.autoIncrFieldName).GeneratedIdQuery + if idQuery == "" { + return fmt.Errorf("sqldb: cannot set %s value if its ColumnMap.GeneratedIdQuery is empty", bi.autoIncrFieldName) + } + err := inserter.InsertQueryToTarget(exec, bi.query, idQuery, f.Addr().Interface(), bi.args...) + if err != nil { + return err + } + default: + return fmt.Errorf("sqldb: cannot use autoincrement fields on dialects that do not implement an autoincrementing interface") + } + } else { + _, err := exec.Exec(bi.query, bi.args...) + if err != nil { + return err + } + } + + if v, ok := eval.(HasPostInsert); ok { + err := v.PostInsert(exec) + if err != nil { + return err + } + } + } + return nil +} + +func exec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) { + executor, ctx := extractExecutorAndContext(e) + + if ctx != nil { + return executor.ExecContext(ctx, query, args...) + } + + return executor.Exec(query, args...) +} + +func prepare(e SqlExecutor, query string) (*sql.Stmt, error) { + executor, ctx := extractExecutorAndContext(e) + + if ctx != nil { + return executor.PrepareContext(ctx, query) + } + + return executor.Prepare(query) +} + +func queryRow(e SqlExecutor, query string, args ...interface{}) *sql.Row { + executor, ctx := extractExecutorAndContext(e) + + if ctx != nil { + return executor.QueryRowContext(ctx, query, args...) + } + + return executor.QueryRow(query, args...) +} + +func query(e SqlExecutor, query string, args ...interface{}) (*sql.Rows, error) { + executor, ctx := extractExecutorAndContext(e) + + if ctx != nil { + return executor.QueryContext(ctx, query, args...) + } + + return executor.Query(query, args...) 
+} + +func begin(m *DbMap) (*sql.Tx, error) { + if m.ctx != nil { + return m.Db.BeginTx(m.ctx, nil) + } + + return m.Db.Begin() +} diff --git a/gdb/sqldb/sqldb_test.go b/gdb/sqldb/sqldb_test.go new file mode 100644 index 0000000..4d390e1 --- /dev/null +++ b/gdb/sqldb/sqldb_test.go @@ -0,0 +1,2875 @@ +// +// sqldb_test.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +//go:build integration +// +build integration + +package sqldb_test + +import ( + "bytes" + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "flag" + "fmt" + "log" + "math/rand" + "os" + "reflect" + "strconv" + "strings" + "testing" + "time" + + "git.hexq.cn/tiglog/golib/gdb/sqldb" + _ "github.com/go-sql-driver/mysql" + _ "github.com/lib/pq" + _ "github.com/mattn/go-sqlite3" +) + +var ( + // verify interface compliance + _ = []sqldb.Dialect{ + sqldb.SqliteDialect{}, + sqldb.PostgresDialect{}, + sqldb.MySQLDialect{}, + sqldb.OracleDialect{}, + } + + debug bool +) + +func TestMain(m *testing.M) { + flag.BoolVar(&debug, "trace", true, "Turn on or off database tracing (DbMap.TraceOn)") + flag.Parse() + os.Exit(m.Run()) +} + +type testable interface { + GetId() int64 + Rand() +} + +type Invoice struct { + Id int64 + Created int64 + Updated int64 + Memo string + PersonId int64 + IsPaid bool +} + +type InvoiceWithValuer struct { + Id int64 + Created int64 + Updated int64 + Memo string + Person PersonValuerScanner `db:"personid"` + IsPaid bool +} + +func (me *Invoice) GetId() int64 { return me.Id } +func (me *Invoice) Rand() { + me.Memo = fmt.Sprintf("random %d", rand.Int63()) + me.Created = rand.Int63() + me.Updated = rand.Int63() +} + +type InvoiceTag struct { + Id int64 `db:"myid, primarykey, autoincrement"` + Created int64 `db:"myCreated"` + Updated int64 `db:"date_updated"` + Memo string + PersonId int64 `db:"person_id"` + IsPaid bool `db:"is_Paid"` +} + +func (me *InvoiceTag) GetId() int64 { return me.Id } +func (me *InvoiceTag) Rand() { + 
me.Memo = fmt.Sprintf("random %d", rand.Int63()) + me.Created = rand.Int63() + me.Updated = rand.Int63() +} + +// See: https://github.com/go-sqldb/sqldb/issues/175 +type AliasTransientField struct { + Id int64 `db:"id"` + Bar int64 `db:"-"` + BarStr string `db:"bar"` +} + +func (me *AliasTransientField) GetId() int64 { return me.Id } +func (me *AliasTransientField) Rand() { + me.BarStr = fmt.Sprintf("random %d", rand.Int63()) +} + +type OverriddenInvoice struct { + Invoice + Id string +} + +type Person struct { + Id int64 + Created int64 + Updated int64 + FName string + LName string + Version int64 +} + +// PersonValuerScanner is used as a field in test types to ensure that we +// make use of "database/sql/driver".Valuer for choosing column types when +// creating tables and that we don't get in the way of the underlying +// database libraries when they make use of either Valuer or +// "database/sql".Scanner. +type PersonValuerScanner struct { + Person +} + +// Value implements "database/sql/driver".Valuer. It will be automatically +// run by the "database/sql" package when inserting/updating data. +func (p PersonValuerScanner) Value() (driver.Value, error) { + return p.Id, nil +} + +// Scan implements "database/sql".Scanner. It will be automatically run +// by the "database/sql" package when reading column data into a field +// of type PersonValuerScanner. +func (p *PersonValuerScanner) Scan(value interface{}) (err error) { + switch src := value.(type) { + case []byte: + // TODO: this case is here for mysql only. For some reason, + // one (both?) of the mysql libraries opt to pass us a []byte + // instead of an int64 for the bigint column. We should add + // table tests around valuers/scanners and try to solve these + // types of odd discrepencies to make it easier for users of + // sqldb to migrate to other database engines. + p.Id, err = strconv.ParseInt(string(src), 10, 64) + case int64: + // Most libraries pass in the type we'd expect. 
+ p.Id = src + default: + typ := reflect.TypeOf(value) + return fmt.Errorf("Expected person value to be convertible to int64, got %v (type %s)", value, typ) + } + return +} + +type FNameOnly struct { + FName string +} + +type InvoicePersonView struct { + InvoiceId int64 + PersonId int64 + Memo string + FName string + LegacyVersion int64 +} + +type TableWithNull struct { + Id int64 + Str sql.NullString + Int64 sql.NullInt64 + Float64 sql.NullFloat64 + Bool sql.NullBool + Bytes []byte +} + +type WithIgnoredColumn struct { + internal int64 `db:"-"` + Id int64 + Created int64 +} + +type IdCreated struct { + Id int64 + Created int64 +} + +type IdCreatedExternal struct { + IdCreated + External int64 +} + +type WithStringPk struct { + Id string + Name string +} + +type CustomStringType string + +type TypeConversionExample struct { + Id int64 + PersonJSON Person + Name CustomStringType +} + +type PersonUInt32 struct { + Id uint32 + Name string +} + +type PersonUInt64 struct { + Id uint64 + Name string +} + +type PersonUInt16 struct { + Id uint16 + Name string +} + +type WithEmbeddedStruct struct { + Id int64 + Names +} + +type WithEmbeddedStructConflictingEmbeddedMemberNames struct { + Id int64 + Names + NamesConflict +} + +type WithEmbeddedStructSameMemberName struct { + Id int64 + SameName +} + +type WithEmbeddedStructBeforeAutoincrField struct { + Names + Id int64 +} + +type WithEmbeddedAutoincr struct { + WithEmbeddedStruct + MiddleName string +} + +type Names struct { + FirstName string + LastName string +} + +type NamesConflict struct { + FirstName string + Surname string +} + +type SameName struct { + SameName string +} + +type UniqueColumns struct { + FirstName string + LastName string + City string + ZipCode int64 +} + +type SingleColumnTable struct { + SomeId string +} + +type CustomDate struct { + time.Time +} + +type WithCustomDate struct { + Id int64 + Added CustomDate +} + +type WithNullTime struct { + Id int64 + Time sqldb.NullTime +} + +type 
testTypeConverter struct{} + +func (me testTypeConverter) ToDb(val interface{}) (interface{}, error) { + + switch t := val.(type) { + case Person: + b, err := json.Marshal(t) + if err != nil { + return "", err + } + return string(b), nil + case CustomStringType: + return string(t), nil + case CustomDate: + return t.Time, nil + } + + return val, nil +} + +func (me testTypeConverter) FromDb(target interface{}) (sqldb.CustomScanner, bool) { + switch target.(type) { + case *Person: + binder := func(holder, target interface{}) error { + s, ok := holder.(*string) + if !ok { + return errors.New("FromDb: Unable to convert Person to *string") + } + b := []byte(*s) + return json.Unmarshal(b, target) + } + return sqldb.CustomScanner{new(string), target, binder}, true + case *CustomStringType: + binder := func(holder, target interface{}) error { + s, ok := holder.(*string) + if !ok { + return errors.New("FromDb: Unable to convert CustomStringType to *string") + } + st, ok := target.(*CustomStringType) + if !ok { + return errors.New(fmt.Sprint("FromDb: Unable to convert target to *CustomStringType: ", reflect.TypeOf(target))) + } + *st = CustomStringType(*s) + return nil + } + return sqldb.CustomScanner{new(string), target, binder}, true + case *CustomDate: + binder := func(holder, target interface{}) error { + t, ok := holder.(*time.Time) + if !ok { + return errors.New("FromDb: Unable to convert CustomDate to *time.Time") + } + dateTarget, ok := target.(*CustomDate) + if !ok { + return errors.New(fmt.Sprint("FromDb: Unable to convert target to *CustomDate: ", reflect.TypeOf(target))) + } + dateTarget.Time = *t + return nil + } + return sqldb.CustomScanner{new(time.Time), target, binder}, true + } + + return sqldb.CustomScanner{}, false +} + +func (p *Person) PreInsert(s sqldb.SqlExecutor) error { + p.Created = time.Now().UnixNano() + p.Updated = p.Created + if p.FName == "badname" { + return fmt.Errorf("Invalid name: %s", p.FName) + } + return nil +} + +func (p *Person) 
PostInsert(s sqldb.SqlExecutor) error { + p.LName = "postinsert" + return nil +} + +func (p *Person) PreUpdate(s sqldb.SqlExecutor) error { + p.FName = "preupdate" + return nil +} + +func (p *Person) PostUpdate(s sqldb.SqlExecutor) error { + p.LName = "postupdate" + return nil +} + +func (p *Person) PreDelete(s sqldb.SqlExecutor) error { + p.FName = "predelete" + return nil +} + +func (p *Person) PostDelete(s sqldb.SqlExecutor) error { + p.LName = "postdelete" + return nil +} + +func (p *Person) PostGet(s sqldb.SqlExecutor) error { + p.LName = "postget" + return nil +} + +type PersistentUser struct { + Key int32 + Id string + PassedTraining bool +} + +type TenantDynamic struct { + Id int64 `db:"id"` + Name string + Address string + curTable string `db:"-"` +} + +func (curObj *TenantDynamic) TableName() string { + return curObj.curTable +} +func (curObj *TenantDynamic) SetTableName(tblName string) { + curObj.curTable = tblName +} + +var dynTableInst1 = TenantDynamic{curTable: "t_1_tenant_dynamic"} +var dynTableInst2 = TenantDynamic{curTable: "t_2_tenant_dynamic"} + +func dynamicTablesTest(t *testing.T, dbmap *sqldb.DbMap) { + + dynamicTablesTestTableMap(t, dbmap, &dynTableInst1) + dynamicTablesTestTableMap(t, dbmap, &dynTableInst2) + + // TEST - dbmap.Insert using dynTableInst1 + dynTableInst1.Name = "Test Name 1" + dynTableInst1.Address = "Test Address 1" + err := dbmap.Insert(&dynTableInst1) + if err != nil { + t.Errorf("Errow while saving dynTableInst1. Details: %v", err) + } + + // TEST - dbmap.Insert using dynTableInst2 + dynTableInst2.Name = "Test Name 2" + dynTableInst2.Address = "Test Address 2" + err = dbmap.Insert(&dynTableInst2) + if err != nil { + t.Errorf("Errow while saving dynTableInst2. 
Details: %v", err) + } + + dynamicTablesTestSelect(t, dbmap, &dynTableInst1) + dynamicTablesTestSelect(t, dbmap, &dynTableInst2) + dynamicTablesTestSelectOne(t, dbmap, &dynTableInst1) + dynamicTablesTestSelectOne(t, dbmap, &dynTableInst2) + dynamicTablesTestGetUpdateGet(t, dbmap, &dynTableInst1) + dynamicTablesTestGetUpdateGet(t, dbmap, &dynTableInst2) + dynamicTablesTestDelete(t, dbmap, &dynTableInst1) + dynamicTablesTestDelete(t, dbmap, &dynTableInst2) + +} + +func dynamicTablesTestTableMap(t *testing.T, + dbmap *sqldb.DbMap, + inpInst *TenantDynamic) { + + tableName := inpInst.TableName() + + tblMap, err := dbmap.DynamicTableFor(tableName, true) + if err != nil { + t.Errorf("Error while searching for tablemap for tableName: %v, Error:%v", tableName, err) + } + if tblMap == nil { + t.Errorf("Unable to find tablemap for tableName:%v", tableName) + } +} + +func dynamicTablesTestSelect(t *testing.T, + dbmap *sqldb.DbMap, + inpInst *TenantDynamic) { + + // TEST - dbmap.Select using inpInst + + // read the data back from dynInst to see if the + // table mapping is correct + var dbTenantInst1 = TenantDynamic{curTable: inpInst.curTable} + selectSQL1 := "select * from " + inpInst.curTable + dbObjs, err := dbmap.Select(&dbTenantInst1, selectSQL1) + if err != nil { + t.Errorf("Errow in dbmap.Select. 
SQL: %v, Details: %v", selectSQL1, err) + } + if dbObjs == nil { + t.Fatalf("Nil return from dbmap.Select") + } + rwCnt := len(dbObjs) + if rwCnt != 1 { + t.Errorf("Unexpected row count for tenantInst:%v", rwCnt) + } + + dbInst := dbObjs[0].(*TenantDynamic) + + inpTableName := inpInst.TableName() + resTableName := dbInst.TableName() + if inpTableName != resTableName { + t.Errorf("Mismatched table names %v != %v ", + inpTableName, resTableName) + } + + if inpInst.Id != dbInst.Id { + t.Errorf("Mismatched Id values %v != %v ", + inpInst.Id, dbInst.Id) + } + + if inpInst.Name != dbInst.Name { + t.Errorf("Mismatched Name values %v != %v ", + inpInst.Name, dbInst.Name) + } + + if inpInst.Address != dbInst.Address { + t.Errorf("Mismatched Address values %v != %v ", + inpInst.Address, dbInst.Address) + } +} + +func dynamicTablesTestGetUpdateGet(t *testing.T, + dbmap *sqldb.DbMap, + inpInst *TenantDynamic) { + + // TEST - dbmap.Get, dbmap.Update, dbmap.Get sequence + + // read and update one of the instances to make sure + // that the common sqldb APIs are working well with dynamic table + var inpIface2 = TenantDynamic{curTable: inpInst.curTable} + dbObj, err := dbmap.Get(&inpIface2, inpInst.Id) + if err != nil { + t.Errorf("Errow in dbmap.Get. 
id: %v, Details: %v", inpInst.Id, err) + } + if dbObj == nil { + t.Errorf("Nil return from dbmap.Get") + } + + dbInst := dbObj.(*TenantDynamic) + + { + inpTableName := inpInst.TableName() + resTableName := dbInst.TableName() + if inpTableName != resTableName { + t.Errorf("Mismatched table names %v != %v ", + inpTableName, resTableName) + } + + if inpInst.Id != dbInst.Id { + t.Errorf("Mismatched Id values %v != %v ", + inpInst.Id, dbInst.Id) + } + + if inpInst.Name != dbInst.Name { + t.Errorf("Mismatched Name values %v != %v ", + inpInst.Name, dbInst.Name) + } + + if inpInst.Address != dbInst.Address { + t.Errorf("Mismatched Address values %v != %v ", + inpInst.Address, dbInst.Address) + } + } + + { + updatedName := "Testing Updated Name2" + dbInst.Name = updatedName + cnt, err := dbmap.Update(dbInst) + if err != nil { + t.Errorf("Error from dbmap.Update: %v", err.Error()) + } + if cnt != 1 { + t.Errorf("Update count must be 1, got %v", cnt) + } + + // Read the object again to make sure that the + // data was updated in db + dbObj2, err := dbmap.Get(&inpIface2, inpInst.Id) + if err != nil { + t.Errorf("Errow in dbmap.Get. 
id: %v, Details: %v", inpInst.Id, err) + } + if dbObj2 == nil { + t.Errorf("Nil return from dbmap.Get") + } + + dbInst2 := dbObj2.(*TenantDynamic) + + inpTableName := inpInst.TableName() + resTableName := dbInst2.TableName() + if inpTableName != resTableName { + t.Errorf("Mismatched table names %v != %v ", + inpTableName, resTableName) + } + + if inpInst.Id != dbInst2.Id { + t.Errorf("Mismatched Id values %v != %v ", + inpInst.Id, dbInst2.Id) + } + + if updatedName != dbInst2.Name { + t.Errorf("Mismatched Name values %v != %v ", + updatedName, dbInst2.Name) + } + + if inpInst.Address != dbInst.Address { + t.Errorf("Mismatched Address values %v != %v ", + inpInst.Address, dbInst.Address) + } + + } +} + +func dynamicTablesTestSelectOne(t *testing.T, + dbmap *sqldb.DbMap, + inpInst *TenantDynamic) { + + // TEST - dbmap.SelectOne + + // read the data back from inpInst to see if the + // table mapping is correct + var dbTenantInst1 = TenantDynamic{curTable: inpInst.curTable} + selectSQL1 := "select * from " + dbTenantInst1.curTable + " where id = :idKey" + params := map[string]interface{}{"idKey": inpInst.Id} + err := dbmap.SelectOne(&dbTenantInst1, selectSQL1, params) + if err != nil { + t.Errorf("Errow in dbmap.SelectOne. 
SQL: %v, Details: %v", selectSQL1, err) + } + + inpTableName := inpInst.curTable + resTableName := dbTenantInst1.TableName() + if inpTableName != resTableName { + t.Errorf("Mismatched table names %v != %v ", + inpTableName, resTableName) + } + + if inpInst.Id != dbTenantInst1.Id { + t.Errorf("Mismatched Id values %v != %v ", + inpInst.Id, dbTenantInst1.Id) + } + + if inpInst.Name != dbTenantInst1.Name { + t.Errorf("Mismatched Name values %v != %v ", + inpInst.Name, dbTenantInst1.Name) + } + + if inpInst.Address != dbTenantInst1.Address { + t.Errorf("Mismatched Address values %v != %v ", + inpInst.Address, dbTenantInst1.Address) + } +} + +func dynamicTablesTestDelete(t *testing.T, + dbmap *sqldb.DbMap, + inpInst *TenantDynamic) { + + // TEST - dbmap.Delete + cnt, err := dbmap.Delete(inpInst) + if err != nil { + t.Errorf("Errow in dbmap.Delete. Details: %v", err) + } + if cnt != 1 { + t.Errorf("Expected delete count for %v : 1, found count:%v", + inpInst.TableName(), cnt) + } + + // Try reading again to make sure instance is gone from db + getInst := TenantDynamic{curTable: inpInst.TableName()} + dbInst, err := dbmap.Get(&getInst, inpInst.Id) + if err != nil { + t.Errorf("Error while trying to read deleted %v object using id: %v", + inpInst.TableName(), inpInst.Id) + } + + if dbInst != nil { + t.Errorf("Found deleted %v instance using id: %v", + inpInst.TableName(), inpInst.Id) + } + + if getInst.Name != "" { + t.Errorf("Found data from deleted %v instance using id: %v", + inpInst.TableName(), inpInst.Id) + } + +} + +func TestCreateTablesIfNotExists(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + err := dbmap.CreateTablesIfNotExists() + if err != nil { + t.Error(err) + } +} + +func TestTruncateTables(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + err := dbmap.CreateTablesIfNotExists() + if err != nil { + t.Error(err) + } + + // Insert some data + p1 := &Person{0, 0, 0, "Bob", "Smith", 0} + dbmap.Insert(p1) + inv := 
&Invoice{0, 0, 1, "my invoice", 0, true} + dbmap.Insert(inv) + + err = dbmap.TruncateTables() + if err != nil { + t.Error(err) + } + + // Make sure all rows are deleted + rows, _ := dbmap.Select(Person{}, "SELECT * FROM person_test") + if len(rows) != 0 { + t.Errorf("Expected 0 person rows, got %d", len(rows)) + } + rows, _ = dbmap.Select(Invoice{}, "SELECT * FROM invoice_test") + if len(rows) != 0 { + t.Errorf("Expected 0 invoice rows, got %d", len(rows)) + } +} + +func TestCustomDateType(t *testing.T) { + dbmap := newDBMap(t) + dbmap.TypeConverter = testTypeConverter{} + dbmap.AddTable(WithCustomDate{}).SetKeys(true, "Id") + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + test1 := &WithCustomDate{Added: CustomDate{Time: time.Now().Truncate(time.Second)}} + err = dbmap.Insert(test1) + if err != nil { + t.Errorf("Could not insert struct with custom date field: %s", err) + t.FailNow() + } + // Unfortunately, the mysql driver doesn't handle time.Time + // values properly during Get(). I can't find a way to work + // around that problem - every other type that I've tried is just + // silently converted. time.Time is the only type that causes + // the issue that this test checks for. As such, if the driver is + // mysql, we'll just skip the rest of this test. 
+ if _, driver := dialectAndDriver(); driver == "mysql" { + t.Skip("TestCustomDateType can't run Get() with the mysql driver; skipping the rest of this test...") + } + result, err := dbmap.Get(new(WithCustomDate), test1.Id) + if err != nil { + t.Errorf("Could not get struct with custom date field: %s", err) + t.FailNow() + } + test2 := result.(*WithCustomDate) + if test2.Added.UTC() != test1.Added.UTC() { + t.Errorf("Custom dates do not match: %v != %v", test2.Added.UTC(), test1.Added.UTC()) + } +} + +func TestUIntPrimaryKey(t *testing.T) { + dbmap := newDBMap(t) + dbmap.AddTable(PersonUInt64{}).SetKeys(true, "Id") + dbmap.AddTable(PersonUInt32{}).SetKeys(true, "Id") + dbmap.AddTable(PersonUInt16{}).SetKeys(true, "Id") + err := dbmap.CreateTablesIfNotExists() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + p1 := &PersonUInt64{0, "name1"} + p2 := &PersonUInt32{0, "name2"} + p3 := &PersonUInt16{0, "name3"} + err = dbmap.Insert(p1, p2, p3) + if err != nil { + t.Error(err) + } + if p1.Id != 1 { + t.Errorf("%d != 1", p1.Id) + } + if p2.Id != 1 { + t.Errorf("%d != 1", p2.Id) + } + if p3.Id != 1 { + t.Errorf("%d != 1", p3.Id) + } +} + +func TestSetUniqueTogether(t *testing.T) { + dbmap := newDBMap(t) + dbmap.AddTable(UniqueColumns{}).SetUniqueTogether("FirstName", "LastName").SetUniqueTogether("City", "ZipCode") + err := dbmap.CreateTablesIfNotExists() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + n1 := &UniqueColumns{"Steve", "Jobs", "Cupertino", 95014} + err = dbmap.Insert(n1) + if err != nil { + t.Error(err) + } + + // Should fail because of the first constraint + n2 := &UniqueColumns{"Steve", "Jobs", "Sunnyvale", 94085} + err = dbmap.Insert(n2) + if err == nil { + t.Error(err) + } + // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL + errLower := strings.ToLower(err.Error()) + if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") { + t.Error(err) + } + + // Should also fail because 
of the second unique-together + n3 := &UniqueColumns{"Steve", "Wozniak", "Cupertino", 95014} + err = dbmap.Insert(n3) + if err == nil { + t.Error(err) + } + // "unique" for Postgres/SQLite, "Duplicate entry" for MySQL + errLower = strings.ToLower(err.Error()) + if !strings.Contains(errLower, "unique") && !strings.Contains(errLower, "duplicate entry") { + t.Error(err) + } + + // This one should finally succeed + n4 := &UniqueColumns{"Steve", "Wozniak", "Sunnyvale", 94085} + err = dbmap.Insert(n4) + if err != nil { + t.Error(err) + } +} + +func TestSetUniqueTogetherIdempotent(t *testing.T) { + dbmap := newDBMap(t) + table := dbmap.AddTable(UniqueColumns{}).SetUniqueTogether("FirstName", "LastName") + table.SetUniqueTogether("FirstName", "LastName") + err := dbmap.CreateTablesIfNotExists() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + n1 := &UniqueColumns{"Steve", "Jobs", "Cupertino", 95014} + err = dbmap.Insert(n1) + if err != nil { + t.Error(err) + } + + // Should still fail because of the constraint + n2 := &UniqueColumns{"Steve", "Jobs", "Sunnyvale", 94085} + err = dbmap.Insert(n2) + if err == nil { + t.Error(err) + } + + // Should have only created one unique constraint + actualCount := strings.Count(table.SqlForCreate(false), "unique") + if actualCount != 1 { + t.Errorf("expected one unique index, found %d: %s", actualCount, table.SqlForCreate(false)) + } +} + +func TestPersistentUser(t *testing.T) { + dbmap := newDBMap(t) + dbmap.Exec("drop table if exists PersistentUser") + table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key") + table.ColMap("Key").Rename("mykey") + err := dbmap.CreateTablesIfNotExists() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + pu := &PersistentUser{43, "33r", false} + err = dbmap.Insert(pu) + if err != nil { + panic(err) + } + + // prove we can pass a pointer into Get + pu2, err := dbmap.Get(pu, pu.Key) + if err != nil { + panic(err) + } + if !reflect.DeepEqual(pu, pu2) { + 
t.Errorf("%v!=%v", pu, pu2) + } + + arr, err := dbmap.Select(pu, "select * from "+tableName(dbmap, PersistentUser{})) + if err != nil { + panic(err) + } + if !reflect.DeepEqual(pu, arr[0]) { + t.Errorf("%v!=%v", pu, arr[0]) + } + + // prove we can get the results back in a slice + var puArr []*PersistentUser + _, err = dbmap.Select(&puArr, "select * from "+tableName(dbmap, PersistentUser{})) + if err != nil { + panic(err) + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu, puArr[0]) { + t.Errorf("%v!=%v", pu, puArr[0]) + } + + // prove we can get the results back in a non-pointer slice + var puValues []PersistentUser + _, err = dbmap.Select(&puValues, "select * from "+tableName(dbmap, PersistentUser{})) + if err != nil { + panic(err) + } + if len(puValues) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(*pu, puValues[0]) { + t.Errorf("%v!=%v", *pu, puValues[0]) + } + + // prove we can get the results back in a string slice + var idArr []*string + _, err = dbmap.Select(&idArr, "select "+columnName(dbmap, PersistentUser{}, "Id")+" from "+tableName(dbmap, PersistentUser{})) + if err != nil { + panic(err) + } + if len(idArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu.Id, *idArr[0]) { + t.Errorf("%v!=%v", pu.Id, *idArr[0]) + } + + // prove we can get the results back in an int slice + var keyArr []*int32 + _, err = dbmap.Select(&keyArr, "select mykey from "+tableName(dbmap, PersistentUser{})) + if err != nil { + panic(err) + } + if len(keyArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu.Key, *keyArr[0]) { + t.Errorf("%v!=%v", pu.Key, *keyArr[0]) + } + + // prove we can get the results back in a bool slice + var passedArr []*bool + _, err = dbmap.Select(&passedArr, "select "+columnName(dbmap, PersistentUser{}, "PassedTraining")+" from "+tableName(dbmap, PersistentUser{})) + 
if err != nil { + panic(err) + } + if len(passedArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu.PassedTraining, *passedArr[0]) { + t.Errorf("%v!=%v", pu.PassedTraining, *passedArr[0]) + } + + // prove we can get the results back in a non-pointer slice + var stringArr []string + _, err = dbmap.Select(&stringArr, "select "+columnName(dbmap, PersistentUser{}, "Id")+" from "+tableName(dbmap, PersistentUser{})) + if err != nil { + panic(err) + } + if len(stringArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu.Id, stringArr[0]) { + t.Errorf("%v!=%v", pu.Id, stringArr[0]) + } +} + +func TestNamedQueryMap(t *testing.T) { + dbmap := newDBMap(t) + dbmap.Exec("drop table if exists PersistentUser") + table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key") + table.ColMap("Key").Rename("mykey") + err := dbmap.CreateTablesIfNotExists() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + pu := &PersistentUser{43, "33r", false} + pu2 := &PersistentUser{500, "abc", false} + err = dbmap.Insert(pu, pu2) + if err != nil { + panic(err) + } + + // Test simple case + var puArr []*PersistentUser + _, err = dbmap.Select(&puArr, "select * from "+tableName(dbmap, PersistentUser{})+" where mykey = :Key", map[string]interface{}{ + "Key": 43, + }) + if err != nil { + t.Errorf("Failed to select: %s", err) + t.FailNow() + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu, puArr[0]) { + t.Errorf("%v!=%v", pu, puArr[0]) + } + + // Test more specific map value type is ok + puArr = nil + _, err = dbmap.Select(&puArr, "select * from "+tableName(dbmap, PersistentUser{})+" where mykey = :Key", map[string]int{ + "Key": 43, + }) + if err != nil { + t.Errorf("Failed to select: %s", err) + t.FailNow() + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + + // Test multiple parameters set. 
+ puArr = nil + _, err = dbmap.Select(&puArr, ` +select * from `+tableName(dbmap, PersistentUser{})+` + where mykey = :Key + and `+columnName(dbmap, PersistentUser{}, "PassedTraining")+` = :PassedTraining + and `+columnName(dbmap, PersistentUser{}, "Id")+` = :Id`, map[string]interface{}{ + "Key": 43, + "PassedTraining": false, + "Id": "33r", + }) + if err != nil { + t.Errorf("Failed to select: %s", err) + t.FailNow() + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + + // Test colon within a non-key string + // Test having extra, unused properties in the map. + puArr = nil + _, err = dbmap.Select(&puArr, ` +select * from `+tableName(dbmap, PersistentUser{})+` + where mykey = :Key + and `+columnName(dbmap, PersistentUser{}, "Id")+` != 'abc:def'`, map[string]interface{}{ + "Key": 43, + "PassedTraining": false, + }) + if err != nil { + t.Errorf("Failed to select: %s", err) + t.FailNow() + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + + // Test to delete with Exec and named params. 
+ result, err := dbmap.Exec("delete from "+tableName(dbmap, PersistentUser{})+" where mykey = :Key", map[string]interface{}{ + "Key": 43, + }) + count, err := result.RowsAffected() + if err != nil { + t.Errorf("Failed to exec: %s", err) + t.FailNow() + } + if count != 1 { + t.Errorf("Expected 1 persistentuser to be deleted, but %d deleted", count) + } +} + +func TestNamedQueryStruct(t *testing.T) { + dbmap := newDBMap(t) + dbmap.Exec("drop table if exists PersistentUser") + table := dbmap.AddTable(PersistentUser{}).SetKeys(false, "Key") + table.ColMap("Key").Rename("mykey") + err := dbmap.CreateTablesIfNotExists() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + pu := &PersistentUser{43, "33r", false} + pu2 := &PersistentUser{500, "abc", false} + err = dbmap.Insert(pu, pu2) + if err != nil { + panic(err) + } + + // Test select self + var puArr []*PersistentUser + _, err = dbmap.Select(&puArr, ` +select * from `+tableName(dbmap, PersistentUser{})+` + where mykey = :Key + and `+columnName(dbmap, PersistentUser{}, "PassedTraining")+` = :PassedTraining + and `+columnName(dbmap, PersistentUser{}, "Id")+` = :Id`, pu) + if err != nil { + t.Errorf("Failed to select: %s", err) + t.FailNow() + } + if len(puArr) != 1 { + t.Errorf("Expected one persistentuser, found none") + } + if !reflect.DeepEqual(pu, puArr[0]) { + t.Errorf("%v!=%v", pu, puArr[0]) + } + + // Test delete self. + result, err := dbmap.Exec(` +delete from `+tableName(dbmap, PersistentUser{})+` + where mykey = :Key + and `+columnName(dbmap, PersistentUser{}, "PassedTraining")+` = :PassedTraining + and `+columnName(dbmap, PersistentUser{}, "Id")+` = :Id`, pu) + count, err := result.RowsAffected() + if err != nil { + t.Errorf("Failed to exec: %s", err) + t.FailNow() + } + if count != 1 { + t.Errorf("Expected 1 persistentuser to be deleted, but %d deleted", count) + } +} + +// Ensure that the slices containing SQL results are non-nil when the result set is empty. 
+func TestReturnsNonNilSlice(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + noResultsSQL := "select * from invoice_test where " + columnName(dbmap, Invoice{}, "Id") + "=99999" + var r1 []*Invoice + rawSelect(dbmap, &r1, noResultsSQL) + if r1 == nil { + t.Errorf("r1==nil") + } + + r2 := rawSelect(dbmap, Invoice{}, noResultsSQL) + if r2 == nil { + t.Errorf("r2==nil") + } +} + +func TestOverrideVersionCol(t *testing.T) { + dbmap := newDBMap(t) + t1 := dbmap.AddTable(InvoicePersonView{}).SetKeys(false, "InvoiceId", "PersonId") + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + c1 := t1.SetVersionCol("LegacyVersion") + if c1.ColumnName != "LegacyVersion" { + t.Errorf("Wrong col returned: %v", c1) + } + + ipv := &InvoicePersonView{1, 2, "memo", "fname", 0} + _update(dbmap, ipv) + if ipv.LegacyVersion != 1 { + t.Errorf("LegacyVersion not updated: %d", ipv.LegacyVersion) + } +} + +func TestOptimisticLocking(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + p1 := &Person{0, 0, 0, "Bob", "Smith", 0} + dbmap.Insert(p1) // Version is now 1 + if p1.Version != 1 { + t.Errorf("Insert didn't incr Version: %d != %d", 1, p1.Version) + return + } + if p1.Id == 0 { + t.Errorf("Insert didn't return a generated PK") + return + } + + obj, err := dbmap.Get(Person{}, p1.Id) + if err != nil { + panic(err) + } + p2 := obj.(*Person) + p2.LName = "Edwards" + dbmap.Update(p2) // Version is now 2 + if p2.Version != 2 { + t.Errorf("Update didn't incr Version: %d != %d", 2, p2.Version) + } + + p1.LName = "Howard" + count, err := dbmap.Update(p1) + if _, ok := err.(sqldb.OptimisticLockError); !ok { + t.Errorf("update - Expected sqldb.OptimisticLockError, got: %v", err) + } + if count != -1 { + t.Errorf("update - Expected -1 count, got: %d", count) + } + + count, err = dbmap.Delete(p1) + if _, ok := err.(sqldb.OptimisticLockError); !ok { + t.Errorf("delete - Expected sqldb.OptimisticLockError, got: %v", err) + } 
+ if count != -1 { + t.Errorf("delete - Expected -1 count, got: %d", count) + } +} + +// what happens if a legacy table has a null value? +func TestDoubleAddTable(t *testing.T) { + dbmap := newDBMap(t) + t1 := dbmap.AddTable(TableWithNull{}).SetKeys(false, "Id") + t2 := dbmap.AddTable(TableWithNull{}) + if t1 != t2 { + t.Errorf("%v != %v", t1, t2) + } +} + +// what happens if a legacy table has a null value? +func TestNullValues(t *testing.T) { + dbmap := initDBMapNulls(t) + defer dropAndClose(dbmap) + + // insert a row directly + rawExec(dbmap, "insert into "+tableName(dbmap, TableWithNull{})+" values (10, null, "+ + "null, null, null, null)") + + // try to load it + expected := &TableWithNull{Id: 10} + obj := _get(dbmap, TableWithNull{}, 10) + t1 := obj.(*TableWithNull) + if !reflect.DeepEqual(expected, t1) { + t.Errorf("%v != %v", expected, t1) + } + + // update it + t1.Str = sql.NullString{"hi", true} + expected.Str = t1.Str + t1.Int64 = sql.NullInt64{999, true} + expected.Int64 = t1.Int64 + t1.Float64 = sql.NullFloat64{53.33, true} + expected.Float64 = t1.Float64 + t1.Bool = sql.NullBool{true, true} + expected.Bool = t1.Bool + t1.Bytes = []byte{1, 30, 31, 33} + expected.Bytes = t1.Bytes + _update(dbmap, t1) + + obj = _get(dbmap, TableWithNull{}, 10) + t1 = obj.(*TableWithNull) + if t1.Str.String != "hi" { + t.Errorf("%s != hi", t1.Str.String) + } + if !reflect.DeepEqual(expected, t1) { + t.Errorf("%v != %v", expected, t1) + } +} + +func TestScannerValuer(t *testing.T) { + dbmap := newDBMap(t) + dbmap.AddTableWithName(PersonValuerScanner{}, "person_test").SetKeys(true, "Id") + dbmap.AddTableWithName(InvoiceWithValuer{}, "invoice_test").SetKeys(true, "Id") + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + pv := PersonValuerScanner{} + pv.FName = "foo" + pv.LName = "bar" + err = dbmap.Insert(&pv) + if err != nil { + t.Errorf("Could not insert PersonValuerScanner using Person table: %v", err) + t.FailNow() + } + + inv 
:= InvoiceWithValuer{} + inv.Memo = "foo" + inv.Person = pv + err = dbmap.Insert(&inv) + if err != nil { + t.Errorf("Could not insert InvoiceWithValuer using Invoice table: %v", err) + t.FailNow() + } + + res, err := dbmap.Get(InvoiceWithValuer{}, inv.Id) + if err != nil { + t.Errorf("Could not get InvoiceWithValuer: %v", err) + t.FailNow() + } + dbInv := res.(*InvoiceWithValuer) + + if dbInv.Person.Id != pv.Id { + t.Errorf("InvoiceWithValuer got wrong person ID: %d (expected) != %d (actual)", pv.Id, dbInv.Person.Id) + } +} + +func TestColumnProps(t *testing.T) { + dbmap := newDBMap(t) + t1 := dbmap.AddTable(Invoice{}).SetKeys(true, "Id") + t1.ColMap("Created").Rename("date_created") + t1.ColMap("Updated").SetTransient(true) + t1.ColMap("Memo").SetMaxSize(10) + t1.ColMap("PersonId").SetUnique(true) + + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + // test transient + inv := &Invoice{0, 0, 1, "my invoice", 0, true} + _insert(dbmap, inv) + obj := _get(dbmap, Invoice{}, inv.Id) + inv = obj.(*Invoice) + if inv.Updated != 0 { + t.Errorf("Saved transient column 'Updated'") + } + + // test max size + inv.Memo = "this memo is too long" + err = dbmap.Insert(inv) + if err == nil { + t.Errorf("max size exceeded, but Insert did not fail.") + } + + // test unique - same person id + inv = &Invoice{0, 0, 1, "my invoice2", 0, false} + err = dbmap.Insert(inv) + if err == nil { + t.Errorf("same PersonId inserted, but Insert did not fail.") + } +} + +func TestRawSelect(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + p1 := &Person{0, 0, 0, "bob", "smith", 0} + _insert(dbmap, p1) + + inv1 := &Invoice{0, 0, 0, "xmas order", p1.Id, true} + _insert(dbmap, inv1) + + expected := &InvoicePersonView{inv1.Id, p1.Id, inv1.Memo, p1.FName, 0} + + query := "select i." + columnName(dbmap, Invoice{}, "Id") + " InvoiceId, p." + columnName(dbmap, Person{}, "Id") + " PersonId, i." + columnName(dbmap, Invoice{}, "Memo") + ", p." 
+ columnName(dbmap, Person{}, "FName") + " " + + "from invoice_test i, person_test p " + + "where i." + columnName(dbmap, Invoice{}, "PersonId") + " = p." + columnName(dbmap, Person{}, "Id") + list := rawSelect(dbmap, InvoicePersonView{}, query) + if len(list) != 1 { + t.Errorf("len(list) != 1: %d", len(list)) + } else if !reflect.DeepEqual(expected, list[0]) { + t.Errorf("%v != %v", expected, list[0]) + } +} + +func TestHooks(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + p1 := &Person{0, 0, 0, "bob", "smith", 0} + _insert(dbmap, p1) + if p1.Created == 0 || p1.Updated == 0 { + t.Errorf("p1.PreInsert() didn't run: %v", p1) + } else if p1.LName != "postinsert" { + t.Errorf("p1.PostInsert() didn't run: %v", p1) + } + + obj := _get(dbmap, Person{}, p1.Id) + p1 = obj.(*Person) + if p1.LName != "postget" { + t.Errorf("p1.PostGet() didn't run: %v", p1) + } + + _update(dbmap, p1) + if p1.FName != "preupdate" { + t.Errorf("p1.PreUpdate() didn't run: %v", p1) + } else if p1.LName != "postupdate" { + t.Errorf("p1.PostUpdate() didn't run: %v", p1) + } + + var persons []*Person + bindVar := dbmap.Dialect.BindVar(0) + rawSelect(dbmap, &persons, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+" = "+bindVar, p1.Id) + if persons[0].LName != "postget" { + t.Errorf("p1.PostGet() didn't run after select: %v", p1) + } + + _del(dbmap, p1) + if p1.FName != "predelete" { + t.Errorf("p1.PreDelete() didn't run: %v", p1) + } else if p1.LName != "postdelete" { + t.Errorf("p1.PostDelete() didn't run: %v", p1) + } + + // Test error case + p2 := &Person{0, 0, 0, "badname", "", 0} + err := dbmap.Insert(p2) + if err == nil { + t.Errorf("p2.PreInsert() didn't return an error") + } +} + +func TestTransaction(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + inv1 := &Invoice{0, 100, 200, "t1", 0, true} + inv2 := &Invoice{0, 100, 200, "t2", 0, false} + + trans, err := dbmap.Begin() + if err != nil { + panic(err) + } + 
trans.Insert(inv1, inv2) + err = trans.Commit() + if err != nil { + panic(err) + } + + obj, err := dbmap.Get(Invoice{}, inv1.Id) + if err != nil { + panic(err) + } + if !reflect.DeepEqual(inv1, obj) { + t.Errorf("%v != %v", inv1, obj) + } + obj, err = dbmap.Get(Invoice{}, inv2.Id) + if err != nil { + panic(err) + } + if !reflect.DeepEqual(inv2, obj) { + t.Errorf("%v != %v", inv2, obj) + } +} + +func TestTransactionExecNamed(t *testing.T) { + if os.Getenv("SQLDB_TEST_DIALECT") == "postgres" { + return + } + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + trans, err := dbmap.Begin() + if err != nil { + panic(err) + } + defer trans.Rollback() + // exec should support named params + args := map[string]interface{}{ + "created": 100, + "updated": 200, + "memo": "unpaid", + "personID": 0, + "isPaid": false, + } + + result, err := trans.Exec(`INSERT INTO invoice_test (Created, Updated, Memo, PersonId, IsPaid) Values(:created, :updated, :memo, :personID, :isPaid)`, args) + if err != nil { + panic(err) + } + id, err := result.LastInsertId() + if err != nil { + panic(err) + } + var checkMemo = func(want string) { + args := map[string]interface{}{ + "id": id, + } + memo, err := trans.SelectStr("select memo from invoice_test where id = :id", args) + if err != nil { + panic(err) + } + if memo != want { + t.Errorf("%q != %q", want, memo) + } + } + checkMemo("unpaid") + + // exec should still work with ? 
params + result, err = trans.Exec(`INSERT INTO invoice_test (Created, Updated, Memo, PersonId, IsPaid) Values(?, ?, ?, ?, ?)`, 10, 15, "paid", 0, true) + if err != nil { + panic(err) + } + id, err = result.LastInsertId() + if err != nil { + panic(err) + } + checkMemo("paid") + err = trans.Commit() + if err != nil { + panic(err) + } +} + +func TestTransactionExecNamedPostgres(t *testing.T) { + if os.Getenv("SQLDB_TEST_DIALECT") != "postgres" { + return + } + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + trans, err := dbmap.Begin() + if err != nil { + panic(err) + } + // exec should support named params + args := map[string]interface{}{ + "created": 100, + "updated": 200, + "memo": "zzTest", + "personID": 0, + "isPaid": false, + } + _, err = trans.Exec(`INSERT INTO invoice_test ("Created", "Updated", "Memo", "PersonId", "IsPaid") Values(:created, :updated, :memo, :personID, :isPaid)`, args) + if err != nil { + panic(err) + } + var checkMemo = func(want string) { + args := map[string]interface{}{ + "memo": want, + } + memo, err := trans.SelectStr(`select "Memo" from invoice_test where "Memo" = :memo`, args) + if err != nil { + panic(err) + } + if memo != want { + t.Errorf("%q != %q", want, memo) + } + } + checkMemo("zzTest") + + // exec should still work with ? 
params + _, err = trans.Exec(`INSERT INTO invoice_test ("Created", "Updated", "Memo", "PersonId", "IsPaid") Values($1, $2, $3, $4, $5)`, 10, 15, "yyTest", 0, true) + + if err != nil { + panic(err) + } + checkMemo("yyTest") + err = trans.Commit() + if err != nil { + panic(err) + } +} + +func TestSavepoint(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + inv1 := &Invoice{0, 100, 200, "unpaid", 0, false} + + trans, err := dbmap.Begin() + if err != nil { + panic(err) + } + trans.Insert(inv1) + + var checkMemo = func(want string) { + memo, err := trans.SelectStr("select " + columnName(dbmap, Invoice{}, "Memo") + " from invoice_test") + if err != nil { + panic(err) + } + if memo != want { + t.Errorf("%q != %q", want, memo) + } + } + checkMemo("unpaid") + + err = trans.Savepoint("foo") + if err != nil { + panic(err) + } + checkMemo("unpaid") + + inv1.Memo = "paid" + _, err = trans.Update(inv1) + if err != nil { + panic(err) + } + checkMemo("paid") + + err = trans.RollbackToSavepoint("foo") + if err != nil { + panic(err) + } + checkMemo("unpaid") + + err = trans.Rollback() + if err != nil { + panic(err) + } +} + +func TestMultiple(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + inv1 := &Invoice{0, 100, 200, "a", 0, false} + inv2 := &Invoice{0, 100, 200, "b", 0, true} + _insert(dbmap, inv1, inv2) + + inv1.Memo = "c" + inv2.Memo = "d" + _update(dbmap, inv1, inv2) + + count := _del(dbmap, inv1, inv2) + if count != 2 { + t.Errorf("%d != 2", count) + } +} + +func TestCrud(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + inv := &Invoice{0, 100, 200, "first order", 0, true} + testCrudInternal(t, dbmap, inv) + + invtag := &InvoiceTag{0, 300, 400, "some order", 33, false} + testCrudInternal(t, dbmap, invtag) + + foo := &AliasTransientField{BarStr: "some bar"} + testCrudInternal(t, dbmap, foo) + + dynamicTablesTest(t, dbmap) +} + +func testCrudInternal(t *testing.T, dbmap *sqldb.DbMap, val testable) { + table, err 
:= dbmap.TableFor(reflect.TypeOf(val).Elem(), false) + if err != nil { + t.Errorf("couldn't call TableFor: val=%v err=%v", val, err) + } + + _, err = dbmap.Exec("delete from " + table.TableName) + if err != nil { + t.Errorf("couldn't delete rows from: val=%v err=%v", val, err) + } + + // INSERT row + _insert(dbmap, val) + if val.GetId() == 0 { + t.Errorf("val.GetId() was not set on INSERT") + return + } + + // SELECT row + val2 := _get(dbmap, val, val.GetId()) + if !reflect.DeepEqual(val, val2) { + t.Errorf("%v != %v", val, val2) + } + + // UPDATE row and SELECT + val.Rand() + count := _update(dbmap, val) + if count != 1 { + t.Errorf("update 1 != %d", count) + } + val2 = _get(dbmap, val, val.GetId()) + if !reflect.DeepEqual(val, val2) { + t.Errorf("%v != %v", val, val2) + } + + // Select * + rows, err := dbmap.Select(val, "select * from "+dbmap.Dialect.QuoteField(table.TableName)) + if err != nil { + t.Errorf("couldn't select * from %s err=%v", dbmap.Dialect.QuoteField(table.TableName), err) + } else if len(rows) != 1 { + t.Errorf("unexpected row count in %s: %d", dbmap.Dialect.QuoteField(table.TableName), len(rows)) + } else if !reflect.DeepEqual(val, rows[0]) { + t.Errorf("select * result: %v != %v", val, rows[0]) + } + + // DELETE row + deleted := _del(dbmap, val) + if deleted != 1 { + t.Errorf("Did not delete row with Id: %d", val.GetId()) + return + } + + // VERIFY deleted + val2 = _get(dbmap, val, val.GetId()) + if val2 != nil { + t.Errorf("Found invoice with id: %d after Delete()", val.GetId()) + } +} + +func TestWithIgnoredColumn(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + ic := &WithIgnoredColumn{-1, 0, 1} + _insert(dbmap, ic) + expected := &WithIgnoredColumn{0, 1, 1} + ic2 := _get(dbmap, WithIgnoredColumn{}, ic.Id).(*WithIgnoredColumn) + + if !reflect.DeepEqual(expected, ic2) { + t.Errorf("%v != %v", expected, ic2) + } + if _del(dbmap, ic) != 1 { + t.Errorf("Did not delete row with Id: %d", ic.Id) + return + } + if _get(dbmap, 
WithIgnoredColumn{}, ic.Id) != nil { + t.Errorf("Found id: %d after Delete()", ic.Id) + } +} + +func TestColumnFilter(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + inv1 := &Invoice{0, 100, 200, "a", 0, false} + _insert(dbmap, inv1) + + inv1.Memo = "c" + inv1.IsPaid = true + _updateColumns(dbmap, func(col *sqldb.ColumnMap) bool { + return col.ColumnName == "Memo" + }, inv1) + + inv2 := &Invoice{} + inv2 = _get(dbmap, inv2, inv1.Id).(*Invoice) + if inv2.Memo != "c" { + t.Errorf("Expected column to be updated (%#v)", inv2) + } + if inv2.IsPaid { + t.Error("IsPaid shouldn't have been updated") + } +} + +func TestTypeConversionExample(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + p := Person{FName: "Bob", LName: "Smith"} + tc := &TypeConversionExample{-1, p, CustomStringType("hi")} + _insert(dbmap, tc) + + expected := &TypeConversionExample{1, p, CustomStringType("hi")} + tc2 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample) + if !reflect.DeepEqual(expected, tc2) { + t.Errorf("tc2 %v != %v", expected, tc2) + } + + tc2.Name = CustomStringType("hi2") + tc2.PersonJSON = Person{FName: "Jane", LName: "Doe"} + _update(dbmap, tc2) + + expected = &TypeConversionExample{1, tc2.PersonJSON, CustomStringType("hi2")} + tc3 := _get(dbmap, TypeConversionExample{}, tc.Id).(*TypeConversionExample) + if !reflect.DeepEqual(expected, tc3) { + t.Errorf("tc3 %v != %v", expected, tc3) + } + + if _del(dbmap, tc) != 1 { + t.Errorf("Did not delete row with Id: %d", tc.Id) + } + +} + +func TestWithEmbeddedStruct(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + es := &WithEmbeddedStruct{-1, Names{FirstName: "Alice", LastName: "Smith"}} + _insert(dbmap, es) + expected := &WithEmbeddedStruct{1, Names{FirstName: "Alice", LastName: "Smith"}} + es2 := _get(dbmap, WithEmbeddedStruct{}, es.Id).(*WithEmbeddedStruct) + if !reflect.DeepEqual(expected, es2) { + t.Errorf("%v != %v", expected, es2) + } + + 
es2.FirstName = "Bob" + expected.FirstName = "Bob" + _update(dbmap, es2) + es2 = _get(dbmap, WithEmbeddedStruct{}, es.Id).(*WithEmbeddedStruct) + if !reflect.DeepEqual(expected, es2) { + t.Errorf("%v != %v", expected, es2) + } + + ess := rawSelect(dbmap, WithEmbeddedStruct{}, "select * from embedded_struct_test") + if !reflect.DeepEqual(es2, ess[0]) { + t.Errorf("%v != %v", es2, ess[0]) + } +} + +/* +func TestWithEmbeddedStructConflictingEmbeddedMemberNames(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + es := &WithEmbeddedStructConflictingEmbeddedMemberNames{-1, Names{FirstName: "Alice", LastName: "Smith"}, NamesConflict{FirstName: "Andrew", Surname: "Wiggin"}} + _insert(dbmap, es) + expected := &WithEmbeddedStructConflictingEmbeddedMemberNames{-1, Names{FirstName: "Alice", LastName: "Smith"}, NamesConflict{FirstName: "Andrew", Surname: "Wiggin"}} + es2 := _get(dbmap, WithEmbeddedStructConflictingEmbeddedMemberNames{}, es.Id).(*WithEmbeddedStructConflictingEmbeddedMemberNames) + if !reflect.DeepEqual(expected, es2) { + t.Errorf("%v != %v", expected, es2) + } + + es2.Names.FirstName = "Bob" + expected.Names.FirstName = "Bob" + _update(dbmap, es2) + es2 = _get(dbmap, WithEmbeddedStructConflictingEmbeddedMemberNames{}, es.Id).(*WithEmbeddedStructConflictingEmbeddedMemberNames) + if !reflect.DeepEqual(expected, es2) { + t.Errorf("%v != %v", expected, es2) + } + + ess := rawSelect(dbmap, WithEmbeddedStructConflictingEmbeddedMemberNames{}, "select * from embedded_struct_conflict_name_test") + if !reflect.DeepEqual(es2, ess[0]) { + t.Errorf("%v != %v", es2, ess[0]) + } +} + +func TestWithEmbeddedStructSameMemberName(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + es := &WithEmbeddedStructSameMemberName{-1, SameName{SameName: "Alice"}} + _insert(dbmap, es) + expected := &WithEmbeddedStructSameMemberName{-1, SameName{SameName: "Alice"}} + es2 := _get(dbmap, WithEmbeddedStructSameMemberName{}, 
es.Id).(*WithEmbeddedStructSameMemberName) + if !reflect.DeepEqual(expected, es2) { + t.Errorf("%v != %v", expected, es2) + } + + es2.SameName = SameName{"Bob"} + expected.SameName = SameName{"Bob"} + _update(dbmap, es2) + es2 = _get(dbmap, WithEmbeddedStructSameMemberName{}, es.Id).(*WithEmbeddedStructSameMemberName) + if !reflect.DeepEqual(expected, es2) { + t.Errorf("%v != %v", expected, es2) + } + + ess := rawSelect(dbmap, WithEmbeddedStructSameMemberName{}, "select * from embedded_struct_same_member_name_test") + if !reflect.DeepEqual(es2, ess[0]) { + t.Errorf("%v != %v", es2, ess[0]) + } +} +//*/ + +func TestWithEmbeddedStructBeforeAutoincr(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + esba := &WithEmbeddedStructBeforeAutoincrField{Names: Names{FirstName: "Alice", LastName: "Smith"}} + _insert(dbmap, esba) + var expectedAutoincrId int64 = 1 + if esba.Id != expectedAutoincrId { + t.Errorf("%d != %d", expectedAutoincrId, esba.Id) + } +} + +func TestWithEmbeddedAutoincr(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + esa := &WithEmbeddedAutoincr{ + WithEmbeddedStruct: WithEmbeddedStruct{Names: Names{FirstName: "Alice", LastName: "Smith"}}, + MiddleName: "Rose", + } + _insert(dbmap, esa) + var expectedAutoincrId int64 = 1 + if esa.Id != expectedAutoincrId { + t.Errorf("%d != %d", expectedAutoincrId, esa.Id) + } +} + +func TestSelectVal(t *testing.T) { + dbmap := initDBMapNulls(t) + defer dropAndClose(dbmap) + + bindVar := dbmap.Dialect.BindVar(0) + + t1 := TableWithNull{Str: sql.NullString{"abc", true}, + Int64: sql.NullInt64{78, true}, + Float64: sql.NullFloat64{32.2, true}, + Bool: sql.NullBool{true, true}, + Bytes: []byte("hi")} + _insert(dbmap, &t1) + + // SelectInt + i64 := selectInt(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Int64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='abc'") + if i64 != 78 { + t.Errorf("int64 %d != 78", i64) + } + i64 = 
selectInt(dbmap, "select count(*) from "+tableName(dbmap, TableWithNull{})) + if i64 != 1 { + t.Errorf("int64 count %d != 1", i64) + } + i64 = selectInt(dbmap, "select count(*) from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"="+bindVar, "asdfasdf") + if i64 != 0 { + t.Errorf("int64 no rows %d != 0", i64) + } + + // SelectNullInt + n := selectNullInt(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Int64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='notfound'") + if !reflect.DeepEqual(n, sql.NullInt64{0, false}) { + t.Errorf("nullint %v != 0,false", n) + } + + n = selectNullInt(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Int64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='abc'") + if !reflect.DeepEqual(n, sql.NullInt64{78, true}) { + t.Errorf("nullint %v != 78, true", n) + } + + // SelectFloat + f64 := selectFloat(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Float64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='abc'") + if f64 != 32.2 { + t.Errorf("float64 %f != 32.2", f64) + } + f64 = selectFloat(dbmap, "select min("+columnName(dbmap, TableWithNull{}, "Float64")+") from "+tableName(dbmap, TableWithNull{})) + if f64 != 32.2 { + t.Errorf("float64 min %f != 32.2", f64) + } + f64 = selectFloat(dbmap, "select count(*) from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"="+bindVar, "asdfasdf") + if f64 != 0 { + t.Errorf("float64 no rows %f != 0", f64) + } + + // SelectNullFloat + nf := selectNullFloat(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Float64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='notfound'") + if !reflect.DeepEqual(nf, sql.NullFloat64{0, false}) { + t.Errorf("nullfloat %v != 0,false", nf) + } + + nf = selectNullFloat(dbmap, 
"select "+columnName(dbmap, TableWithNull{}, "Float64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='abc'") + if !reflect.DeepEqual(nf, sql.NullFloat64{32.2, true}) { + t.Errorf("nullfloat %v != 32.2, true", nf) + } + + // SelectStr + s := selectStr(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Str")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Int64")+"="+bindVar, 78) + if s != "abc" { + t.Errorf("s %s != abc", s) + } + s = selectStr(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Str")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='asdfasdf'") + if s != "" { + t.Errorf("s no rows %s != ''", s) + } + + // SelectNullStr + ns := selectNullStr(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Str")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Int64")+"="+bindVar, 78) + if !reflect.DeepEqual(ns, sql.NullString{"abc", true}) { + t.Errorf("nullstr %v != abc,true", ns) + } + ns = selectNullStr(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Str")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"='asdfasdf'") + if !reflect.DeepEqual(ns, sql.NullString{"", false}) { + t.Errorf("nullstr no rows %v != '',false", ns) + } + + // SelectInt/Str with named parameters + i64 = selectInt(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Int64")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Str")+"=:abc", map[string]string{"abc": "abc"}) + if i64 != 78 { + t.Errorf("int64 %d != 78", i64) + } + ns = selectNullStr(dbmap, "select "+columnName(dbmap, TableWithNull{}, "Str")+" from "+tableName(dbmap, TableWithNull{})+" where "+columnName(dbmap, TableWithNull{}, "Int64")+"=:num", map[string]int{"num": 78}) + if !reflect.DeepEqual(ns, sql.NullString{"abc", true}) { + t.Errorf("nullstr %v != 
abc,true", ns) + } +} + +func TestVersionMultipleRows(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + persons := []*Person{ + &Person{0, 0, 0, "Bob", "Smith", 0}, + &Person{0, 0, 0, "Jane", "Smith", 0}, + &Person{0, 0, 0, "Mike", "Smith", 0}, + } + + _insert(dbmap, persons[0], persons[1], persons[2]) + + for x, p := range persons { + if p.Version != 1 { + t.Errorf("person[%d].Version != 1: %d", x, p.Version) + } + } +} + +func TestWithStringPk(t *testing.T) { + dbmap := newDBMap(t) + dbmap.AddTableWithName(WithStringPk{}, "string_pk_test").SetKeys(true, "Id") + _, err := dbmap.Exec("create table string_pk_test (Id varchar(255), Name varchar(255));") + if err != nil { + t.Errorf("couldn't create string_pk_test: %v", err) + } + defer dropAndClose(dbmap) + + row := &WithStringPk{"1", "foo"} + err = dbmap.Insert(row) + if err == nil { + t.Errorf("Expected error when inserting into table w/non Int PK and autoincr set true") + } +} + +// TestSqlExecutorInterfaceSelects ensures that all sqldb.DbMap methods starting with Select... +// are also exposed in the sqldb.SqlExecutor interface. Select... functions can always +// run on Pre/Post hooks. 
+func TestSqlExecutorInterfaceSelects(t *testing.T) { + dbMapType := reflect.TypeOf(&sqldb.DbMap{}) + sqlExecutorType := reflect.TypeOf((*sqldb.SqlExecutor)(nil)).Elem() + numDbMapMethods := dbMapType.NumMethod() + for i := 0; i < numDbMapMethods; i += 1 { + dbMapMethod := dbMapType.Method(i) + if !strings.HasPrefix(dbMapMethod.Name, "Select") { + continue + } + if _, found := sqlExecutorType.MethodByName(dbMapMethod.Name); !found { + t.Errorf("Method %s is defined on sqldb.DbMap but not implemented in sqldb.SqlExecutor", + dbMapMethod.Name) + } + } +} + +func TestNullTime(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + // if time is null + ent := &WithNullTime{ + Id: 0, + Time: sqldb.NullTime{ + Valid: false, + }} + err := dbmap.Insert(ent) + if err != nil { + t.Errorf("failed insert on %s", err.Error()) + } + err = dbmap.SelectOne(ent, `select * from nulltime_test where `+columnName(dbmap, WithNullTime{}, "Id")+`=:Id`, map[string]interface{}{ + "Id": ent.Id, + }) + if err != nil { + t.Errorf("failed select on %s", err.Error()) + } + if ent.Time.Valid { + t.Error("sqldb.NullTime returns valid but expected null.") + } + + // if time is not null + ts, err := time.Parse(time.RFC3339, "2001-01-02T15:04:05-07:00") + if err != nil { + t.Errorf("failed to parse time %s: %s", time.Stamp, err.Error()) + } + ent = &WithNullTime{ + Id: 1, + Time: sqldb.NullTime{ + Valid: true, + Time: ts, + }} + err = dbmap.Insert(ent) + if err != nil { + t.Errorf("failed insert on %s", err.Error()) + } + err = dbmap.SelectOne(ent, `select * from nulltime_test where `+columnName(dbmap, WithNullTime{}, "Id")+`=:Id`, map[string]interface{}{ + "Id": ent.Id, + }) + if err != nil { + t.Errorf("failed select on %s", err.Error()) + } + if !ent.Time.Valid { + t.Error("sqldb.NullTime returns invalid but expected valid.") + } + if ent.Time.Time.UTC() != ts.UTC() { + t.Errorf("expect %v but got %v.", ts, ent.Time.Time) + } + + return +} + +type WithTime struct { + Id int64 + 
Time time.Time +} + +type Times struct { + One time.Time + Two time.Time +} + +type EmbeddedTime struct { + Id string + Times +} + +func parseTimeOrPanic(format, date string) time.Time { + t1, err := time.Parse(format, date) + if err != nil { + panic(err) + } + return t1 +} + +func TestWithTime(t *testing.T) { + if _, driver := dialectAndDriver(); driver == "mysql" { + t.Skip("mysql drivers don't support time.Time, skipping...") + } + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + t1 := parseTimeOrPanic("2006-01-02 15:04:05 -0700 MST", + "2013-08-09 21:30:43 +0800 CST") + w1 := WithTime{1, t1} + _insert(dbmap, &w1) + + obj := _get(dbmap, WithTime{}, w1.Id) + w2 := obj.(*WithTime) + if w1.Time.UnixNano() != w2.Time.UnixNano() { + t.Errorf("%v != %v", w1, w2) + } +} + +func TestEmbeddedTime(t *testing.T) { + if _, driver := dialectAndDriver(); driver == "mysql" { + t.Skip("mysql drivers don't support time.Time, skipping...") + } + dbmap := newDBMap(t) + dbmap.AddTable(EmbeddedTime{}).SetKeys(false, "Id") + defer dropAndClose(dbmap) + err := dbmap.CreateTables() + if err != nil { + t.Fatal(err) + } + + time1 := parseTimeOrPanic("2006-01-02 15:04:05", "2013-08-09 21:30:43") + + t1 := &EmbeddedTime{Id: "abc", Times: Times{One: time1, Two: time1.Add(10 * time.Second)}} + _insert(dbmap, t1) + + x := _get(dbmap, EmbeddedTime{}, t1.Id) + t2, _ := x.(*EmbeddedTime) + if t1.One.UnixNano() != t2.One.UnixNano() || t1.Two.UnixNano() != t2.Two.UnixNano() { + t.Errorf("%v != %v", t1, t2) + } +} + +func TestWithTimeSelect(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + halfhourago := time.Now().UTC().Add(-30 * time.Minute) + + w1 := WithTime{1, halfhourago.Add(time.Minute * -1)} + w2 := WithTime{2, halfhourago.Add(time.Second)} + _insert(dbmap, &w1, &w2) + + var caseIds []int64 + _, err := dbmap.Select(&caseIds, "SELECT "+columnName(dbmap, WithTime{}, "Id")+" FROM time_test WHERE "+columnName(dbmap, WithTime{}, "Time")+" < "+dbmap.Dialect.BindVar(0), 
halfhourago) + + if err != nil { + t.Error(err) + } + if len(caseIds) != 1 { + t.Errorf("%d != 1", len(caseIds)) + } + if caseIds[0] != w1.Id { + t.Errorf("%d != %d", caseIds[0], w1.Id) + } +} + +func TestInvoicePersonView(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + // Create some rows + p1 := &Person{0, 0, 0, "bob", "smith", 0} + dbmap.Insert(p1) + + // notice how we can wire up p1.Id to the invoice easily + inv1 := &Invoice{0, 0, 0, "xmas order", p1.Id, false} + dbmap.Insert(inv1) + + // Run your query + query := "select i." + columnName(dbmap, Invoice{}, "Id") + " InvoiceId, p." + columnName(dbmap, Person{}, "Id") + " PersonId, i." + columnName(dbmap, Invoice{}, "Memo") + ", p." + columnName(dbmap, Person{}, "FName") + " " + + "from invoice_test i, person_test p " + + "where i." + columnName(dbmap, Invoice{}, "PersonId") + " = p." + columnName(dbmap, Person{}, "Id") + + // pass a slice of pointers to Select() + // this avoids the need to type assert after the query is run + var list []*InvoicePersonView + _, err := dbmap.Select(&list, query) + if err != nil { + panic(err) + } + + // this should test true + expected := &InvoicePersonView{inv1.Id, p1.Id, inv1.Memo, p1.FName, 0} + if !reflect.DeepEqual(list[0], expected) { + t.Errorf("%v != %v", list[0], expected) + } +} + +func TestQuoteTableNames(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + quotedTableName := dbmap.Dialect.QuoteField("person_test") + + // Use a buffer to hold the log to check generated queries + logBuffer := &bytes.Buffer{} + dbmap.TraceOn("", log.New(logBuffer, "sqldbtest:", log.Lmicroseconds)) + + // Create some rows + p1 := &Person{0, 0, 0, "bob", "smith", 0} + errorTemplate := "Expected quoted table name %v in query but didn't find it" + + // Check if Insert quotes the table name + id := dbmap.Insert(p1) + if !bytes.Contains(logBuffer.Bytes(), []byte(quotedTableName)) { + t.Errorf(errorTemplate, quotedTableName) + } + logBuffer.Reset() + 
+ // Check if Get quotes the table name + dbmap.Get(Person{}, id) + if !bytes.Contains(logBuffer.Bytes(), []byte(quotedTableName)) { + t.Errorf(errorTemplate, quotedTableName) + } + logBuffer.Reset() +} + +func TestSelectTooManyCols(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + p1 := &Person{0, 0, 0, "bob", "smith", 0} + p2 := &Person{0, 0, 0, "jane", "doe", 0} + _insert(dbmap, p1) + _insert(dbmap, p2) + + obj := _get(dbmap, Person{}, p1.Id) + p1 = obj.(*Person) + obj = _get(dbmap, Person{}, p2.Id) + p2 = obj.(*Person) + + params := map[string]interface{}{ + "Id": p1.Id, + } + + var p3 FNameOnly + err := dbmap.SelectOne(&p3, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", params) + if err != nil { + if !sqldb.NonFatalError(err) { + t.Error(err) + } + } else { + t.Errorf("Non-fatal error expected") + } + + if p1.FName != p3.FName { + t.Errorf("%v != %v", p1.FName, p3.FName) + } + + var pSlice []FNameOnly + _, err = dbmap.Select(&pSlice, "select * from person_test order by "+columnName(dbmap, Person{}, "FName")+" asc") + if err != nil { + if !sqldb.NonFatalError(err) { + t.Error(err) + } + } else { + t.Errorf("Non-fatal error expected") + } + + if p1.FName != pSlice[0].FName { + t.Errorf("%v != %v", p1.FName, pSlice[0].FName) + } + if p2.FName != pSlice[1].FName { + t.Errorf("%v != %v", p2.FName, pSlice[1].FName) + } +} + +func TestSelectSingleVal(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + p1 := &Person{0, 0, 0, "bob", "smith", 0} + _insert(dbmap, p1) + + obj := _get(dbmap, Person{}, p1.Id) + p1 = obj.(*Person) + + params := map[string]interface{}{ + "Id": p1.Id, + } + + var p2 Person + err := dbmap.SelectOne(&p2, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", params) + if err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(p1, &p2) { + t.Errorf("%v != %v", p1, &p2) + } + + // verify SelectOne allows non-struct holders + var s string + err = 
dbmap.SelectOne(&s, "select "+columnName(dbmap, Person{}, "FName")+" from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", params) + if err != nil { + t.Error(err) + } + if s != "bob" { + t.Error("Expected bob but got: " + s) + } + + // verify SelectOne requires pointer receiver + err = dbmap.SelectOne(s, "select "+columnName(dbmap, Person{}, "FName")+" from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", params) + if err == nil { + t.Error("SelectOne should have returned error for non-pointer holder") + } + + // verify SelectOne works with uninitialized pointers + var p3 *Person + err = dbmap.SelectOne(&p3, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", params) + if err != nil { + t.Error(err) + } + + if !reflect.DeepEqual(p1, p3) { + t.Errorf("%v != %v", p1, p3) + } + + // verify that the receiver is still nil if nothing was found + var p4 *Person + dbmap.SelectOne(&p3, "select * from person_test where 2<1 AND "+columnName(dbmap, Person{}, "Id")+"=:Id", params) + if p4 != nil { + t.Error("SelectOne should not have changed a nil receiver when no rows were found") + } + + // verify that the error is set to sql.ErrNoRows if not found + err = dbmap.SelectOne(&p2, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+"=:Id", map[string]interface{}{ + "Id": -2222, + }) + if err == nil || err != sql.ErrNoRows { + t.Error("SelectOne should have returned an sql.ErrNoRows") + } + + _insert(dbmap, &Person{0, 0, 0, "bob", "smith", 0}) + err = dbmap.SelectOne(&p2, "select * from person_test where "+columnName(dbmap, Person{}, "FName")+"='bob'") + if err == nil { + t.Error("Expected error when two rows found") + } + + // tests for #150 + var tInt int64 + var tStr string + var tBool bool + var tFloat float64 + primVals := []interface{}{tInt, tStr, tBool, tFloat} + for _, prim := range primVals { + err = dbmap.SelectOne(&prim, "select * from person_test where "+columnName(dbmap, Person{}, "Id")+"=-123") + if 
err == nil || err != sql.ErrNoRows { + t.Error("primVals: SelectOne should have returned sql.ErrNoRows") + } + } +} + +func TestSelectAlias(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + p1 := &IdCreatedExternal{IdCreated: IdCreated{Id: 1, Created: 3}, External: 2} + + // Insert using embedded IdCreated, which reflects the structure of the table + _insert(dbmap, &p1.IdCreated) + + // Select into IdCreatedExternal type, which includes some fields not present + // in id_created_test + var p2 IdCreatedExternal + err := dbmap.SelectOne(&p2, "select * from id_created_test where "+columnName(dbmap, IdCreatedExternal{}, "Id")+"=1") + if err != nil { + t.Error(err) + } + if p2.Id != 1 || p2.Created != 3 || p2.External != 0 { + t.Error("Expected ignored field defaults to not set") + } + + // Prove that we can supply an aliased value in the select, and that it will + // automatically map to IdCreatedExternal.External + err = dbmap.SelectOne(&p2, "SELECT *, 1 AS external FROM id_created_test") + if err != nil { + t.Error(err) + } + if p2.External != 1 { + t.Error("Expected select as can map to exported field.") + } + + var rows *sql.Rows + var cols []string + rows, err = dbmap.Db.Query("SELECT * FROM id_created_test") + cols, err = rows.Columns() + if err != nil || len(cols) != 2 { + t.Error("Expected ignored column not created") + } +} + +func TestMysqlPanicIfDialectNotInitialized(t *testing.T) { + _, driver := dialectAndDriver() + // this test only applies to MySQL + if os.Getenv("SQLDB_TEST_DIALECT") != "mysql" { + return + } + + // The expected behaviour is to catch a panic. 
+ // Here is the deferred function which will check if a panic has indeed occurred : + defer func() { + r := recover() + if r == nil { + t.Error("db.CreateTables() should panic if db is initialized with an incorrect sqldb.MySQLDialect") + } + }() + + // invalid MySQLDialect : does not contain Engine or Encoding specification + dialect := sqldb.MySQLDialect{} + db := &sqldb.DbMap{Db: connect(driver), Dialect: dialect} + db.AddTableWithName(Invoice{}, "invoice") + // the following call should panic : + db.CreateTables() +} + +func TestSingleColumnKeyDbReturnsZeroRowsUpdatedOnPKChange(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + dbmap.AddTableWithName(SingleColumnTable{}, "single_column_table").SetKeys(false, "SomeId") + err := dbmap.DropTablesIfExists() + if err != nil { + t.Error("Drop tables failed") + } + err = dbmap.CreateTablesIfNotExists() + if err != nil { + t.Error("Create tables failed") + } + err = dbmap.TruncateTables() + if err != nil { + t.Error("Truncate tables failed") + } + + sct := SingleColumnTable{ + SomeId: "A Unique Id String", + } + + count, err := dbmap.Update(&sct) + if err != nil { + t.Error(err) + } + if count != 0 { + t.Errorf("Expected 0 updated rows, got %d", count) + } + +} + +func TestPrepare(t *testing.T) { + dbmap := initDBMap(t) + defer dropAndClose(dbmap) + + inv1 := &Invoice{0, 100, 200, "prepare-foo", 0, false} + inv2 := &Invoice{0, 100, 200, "prepare-bar", 0, false} + _insert(dbmap, inv1, inv2) + + bindVar0 := dbmap.Dialect.BindVar(0) + bindVar1 := dbmap.Dialect.BindVar(1) + stmt, err := dbmap.Prepare(fmt.Sprintf("UPDATE invoice_test SET "+columnName(dbmap, Invoice{}, "Memo")+"=%s WHERE "+columnName(dbmap, Invoice{}, "Id")+"=%s", bindVar0, bindVar1)) + if err != nil { + t.Error(err) + } + defer stmt.Close() + _, err = stmt.Exec("prepare-baz", inv1.Id) + if err != nil { + t.Error(err) + } + err = dbmap.SelectOne(inv1, "SELECT * from invoice_test WHERE "+columnName(dbmap, Invoice{}, 
"Memo")+"='prepare-baz'") + if err != nil { + t.Error(err) + } + + trans, err := dbmap.Begin() + if err != nil { + t.Error(err) + } + transStmt, err := trans.Prepare(fmt.Sprintf("UPDATE invoice_test SET "+columnName(dbmap, Invoice{}, "IsPaid")+"=%s WHERE "+columnName(dbmap, Invoice{}, "Id")+"=%s", bindVar0, bindVar1)) + if err != nil { + t.Error(err) + } + defer transStmt.Close() + _, err = transStmt.Exec(true, inv2.Id) + if err != nil { + t.Error(err) + } + err = dbmap.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE "+columnName(dbmap, Invoice{}, "IsPaid")+"=%s", bindVar0), true) + if err == nil || err != sql.ErrNoRows { + t.Error("SelectOne should have returned an sql.ErrNoRows") + } + err = trans.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE "+columnName(dbmap, Invoice{}, "IsPaid")+"=%s", bindVar0), true) + if err != nil { + t.Error(err) + } + err = trans.Commit() + if err != nil { + t.Error(err) + } + err = dbmap.SelectOne(inv2, fmt.Sprintf("SELECT * from invoice_test WHERE "+columnName(dbmap, Invoice{}, "IsPaid")+"=%s", bindVar0), true) + if err != nil { + t.Error(err) + } +} + +type UUID4 string + +func (u UUID4) Value() (driver.Value, error) { + if u == "" { + return nil, nil + } + + return string(u), nil +} + +type NilPointer struct { + ID string + UserID *UUID4 +} + +func TestCallOfValueMethodOnNilPointer(t *testing.T) { + dbmap := newDBMap(t) + dbmap.AddTable(NilPointer{}).SetKeys(false, "ID") + defer dropAndClose(dbmap) + err := dbmap.CreateTables() + if err != nil { + t.Fatal(err) + } + + nilPointer := &NilPointer{ID: "abc", UserID: nil} + _insert(dbmap, nilPointer) +} + +func BenchmarkNativeCrud(b *testing.B) { + b.StopTimer() + dbmap := initDBMapBench(b) + defer dropAndClose(dbmap) + columnId := columnName(dbmap, Invoice{}, "Id") + columnCreated := columnName(dbmap, Invoice{}, "Created") + columnUpdated := columnName(dbmap, Invoice{}, "Updated") + columnMemo := columnName(dbmap, Invoice{}, "Memo") + columnPersonId := 
columnName(dbmap, Invoice{}, "PersonId") + b.StartTimer() + + var insert, sel, update, delete string + if os.Getenv("SQLDB_TEST_DIALECT") != "postgres" { + insert = "insert into invoice_test (" + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + ") values (?, ?, ?, ?)" + sel = "select " + columnId + ", " + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + " from invoice_test where " + columnId + "=?" + update = "update invoice_test set " + columnCreated + "=?, " + columnUpdated + "=?, " + columnMemo + "=?, " + columnPersonId + "=? where " + columnId + "=?" + delete = "delete from invoice_test where " + columnId + "=?" + } else { + insert = "insert into invoice_test (" + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + ") values ($1, $2, $3, $4)" + sel = "select " + columnId + ", " + columnCreated + ", " + columnUpdated + ", " + columnMemo + ", " + columnPersonId + " from invoice_test where " + columnId + "=$1" + update = "update invoice_test set " + columnCreated + "=$1, " + columnUpdated + "=$2, " + columnMemo + "=$3, " + columnPersonId + "=$4 where " + columnId + "=$5" + delete = "delete from invoice_test where " + columnId + "=$1" + } + + inv := &Invoice{0, 100, 200, "my memo", 0, false} + + for i := 0; i < b.N; i++ { + res, err := dbmap.Db.Exec(insert, inv.Created, inv.Updated, + inv.Memo, inv.PersonId) + if err != nil { + panic(err) + } + + newid, err := res.LastInsertId() + if err != nil { + panic(err) + } + inv.Id = newid + + row := dbmap.Db.QueryRow(sel, inv.Id) + err = row.Scan(&inv.Id, &inv.Created, &inv.Updated, &inv.Memo, + &inv.PersonId) + if err != nil { + panic(err) + } + + inv.Created = 1000 + inv.Updated = 2000 + inv.Memo = "my memo 2" + inv.PersonId = 3000 + + _, err = dbmap.Db.Exec(update, inv.Created, inv.Updated, inv.Memo, + inv.PersonId, inv.Id) + if err != nil { + panic(err) + } + + _, err = dbmap.Db.Exec(delete, inv.Id) + if err != nil { + 
panic(err) + } + } + +} + +func BenchmarkSqldbCrud(b *testing.B) { + b.StopTimer() + dbmap := initDBMapBench(b) + defer dropAndClose(dbmap) + b.StartTimer() + + inv := &Invoice{0, 100, 200, "my memo", 0, true} + for i := 0; i < b.N; i++ { + err := dbmap.Insert(inv) + if err != nil { + panic(err) + } + + obj, err := dbmap.Get(Invoice{}, inv.Id) + if err != nil { + panic(err) + } + + inv2, ok := obj.(*Invoice) + if !ok { + panic(fmt.Sprintf("expected *Invoice, got: %v", obj)) + } + + inv2.Created = 1000 + inv2.Updated = 2000 + inv2.Memo = "my memo 2" + inv2.PersonId = 3000 + _, err = dbmap.Update(inv2) + if err != nil { + panic(err) + } + + _, err = dbmap.Delete(inv2) + if err != nil { + panic(err) + } + + } +} + +func initDBMapBench(b *testing.B) *sqldb.DbMap { + dbmap := newDBMap(b) + dbmap.Db.Exec("drop table if exists invoice_test") + dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id") + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + return dbmap +} + +func initDBMap(t *testing.T) *sqldb.DbMap { + dbmap := newDBMap(t) + dbmap.AddTableWithName(Invoice{}, "invoice_test").SetKeys(true, "Id") + dbmap.AddTableWithName(InvoiceTag{}, "invoice_tag_test") //key is set via primarykey attribute + dbmap.AddTableWithName(AliasTransientField{}, "alias_trans_field_test").SetKeys(true, "id") + dbmap.AddTableWithName(OverriddenInvoice{}, "invoice_override_test").SetKeys(false, "Id") + dbmap.AddTableWithName(Person{}, "person_test").SetKeys(true, "Id").SetVersionCol("Version") + dbmap.AddTableWithName(WithIgnoredColumn{}, "ignored_column_test").SetKeys(true, "Id") + dbmap.AddTableWithName(IdCreated{}, "id_created_test").SetKeys(true, "Id") + dbmap.AddTableWithName(TypeConversionExample{}, "type_conv_test").SetKeys(true, "Id") + dbmap.AddTableWithName(WithEmbeddedStruct{}, "embedded_struct_test").SetKeys(true, "Id") + //dbmap.AddTableWithName(WithEmbeddedStructConflictingEmbeddedMemberNames{}, "embedded_struct_conflict_name_test").SetKeys(true, 
"Id") + //dbmap.AddTableWithName(WithEmbeddedStructSameMemberName{}, "embedded_struct_same_member_name_test").SetKeys(true, "Id") + dbmap.AddTableWithName(WithEmbeddedStructBeforeAutoincrField{}, "embedded_struct_before_autoincr_test").SetKeys(true, "Id") + dbmap.AddTableDynamic(&dynTableInst1, "").SetKeys(true, "Id").AddIndex("TenantInst1Index", "Btree", []string{"Name"}).SetUnique(true) + dbmap.AddTableDynamic(&dynTableInst2, "").SetKeys(true, "Id").AddIndex("TenantInst2Index", "Btree", []string{"Name"}).SetUnique(true) + dbmap.AddTableWithName(WithEmbeddedAutoincr{}, "embedded_autoincr_test").SetKeys(true, "Id") + dbmap.AddTableWithName(WithTime{}, "time_test").SetKeys(true, "Id") + dbmap.AddTableWithName(WithNullTime{}, "nulltime_test").SetKeys(false, "Id") + dbmap.TypeConverter = testTypeConverter{} + err := dbmap.DropTablesIfExists() + if err != nil { + panic(err) + } + err = dbmap.CreateTables() + if err != nil { + panic(err) + } + + err = dbmap.CreateIndex() + if err != nil { + panic(err) + } + + // See #146 and TestSelectAlias - this type is mapped to the same + // table as IdCreated, but includes an extra field that isn't in the table + dbmap.AddTableWithName(IdCreatedExternal{}, "id_created_test").SetKeys(true, "Id") + + return dbmap +} + +func initDBMapNulls(t *testing.T) *sqldb.DbMap { + dbmap := newDBMap(t) + dbmap.AddTable(TableWithNull{}).SetKeys(false, "Id") + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + return dbmap +} + +type Logger interface { + Logf(format string, args ...any) +} + +type TestLogger struct { + l Logger +} + +func (l TestLogger) Printf(format string, args ...any) { + l.l.Logf(format, args...) 
+} + +func newDBMap(l Logger) *sqldb.DbMap { + dialect, driver := dialectAndDriver() + dbmap := &sqldb.DbMap{Db: connect(driver), Dialect: dialect} + if debug { + dbmap.TraceOn("", TestLogger{l: l}) + } + return dbmap +} + +func dropAndClose(dbmap *sqldb.DbMap) { + dbmap.DropTablesIfExists() + dbmap.Db.Close() +} + +func connect(driver string) *sql.DB { + dsn := os.Getenv("SQLDB_TEST_DSN") + if dsn == "" { + panic("SQLDB_TEST_DSN env variable is not set. Please see README.md") + } + + db, err := sql.Open(driver, dsn) + if err != nil { + panic("Error connecting to db: " + err.Error()) + } + return db +} + +func dialectAndDriver() (sqldb.Dialect, string) { + switch os.Getenv("SQLDB_TEST_DIALECT") { + case "mysql", "gomysql": + // NOTE: the 'mysql' driver used to use github.com/ziutek/mymysql, but that project + // seems mostly unmaintained recently. We've dropped it from tests, at least for + // now. + return sqldb.MySQLDialect{"InnoDB", "UTF8"}, "mysql" + case "postgres": + return sqldb.PostgresDialect{}, "postgres" + case "sqlite": + return sqldb.SqliteDialect{}, "sqlite3" + } + panic("SQLDB_TEST_DIALECT env variable is not set or is invalid. Please see README.md") +} + +func _insert(dbmap *sqldb.DbMap, list ...interface{}) { + err := dbmap.Insert(list...) + if err != nil { + panic(err) + } +} + +func _update(dbmap *sqldb.DbMap, list ...interface{}) int64 { + count, err := dbmap.Update(list...) + if err != nil { + panic(err) + } + return count +} + +func _updateColumns(dbmap *sqldb.DbMap, filter sqldb.ColumnFilter, list ...interface{}) int64 { + count, err := dbmap.UpdateColumns(filter, list...) + if err != nil { + panic(err) + } + return count +} + +func _del(dbmap *sqldb.DbMap, list ...interface{}) int64 { + count, err := dbmap.Delete(list...) + if err != nil { + panic(err) + } + + return count +} + +func _get(dbmap *sqldb.DbMap, i interface{}, keys ...interface{}) interface{} { + obj, err := dbmap.Get(i, keys...) 
+ if err != nil { + panic(err) + } + + return obj +} + +func selectInt(dbmap *sqldb.DbMap, query string, args ...interface{}) int64 { + i64, err := sqldb.SelectInt(dbmap, query, args...) + if err != nil { + panic(err) + } + + return i64 +} + +func selectNullInt(dbmap *sqldb.DbMap, query string, args ...interface{}) sql.NullInt64 { + i64, err := sqldb.SelectNullInt(dbmap, query, args...) + if err != nil { + panic(err) + } + + return i64 +} + +func selectFloat(dbmap *sqldb.DbMap, query string, args ...interface{}) float64 { + f64, err := sqldb.SelectFloat(dbmap, query, args...) + if err != nil { + panic(err) + } + + return f64 +} + +func selectNullFloat(dbmap *sqldb.DbMap, query string, args ...interface{}) sql.NullFloat64 { + f64, err := sqldb.SelectNullFloat(dbmap, query, args...) + if err != nil { + panic(err) + } + + return f64 +} + +func selectStr(dbmap *sqldb.DbMap, query string, args ...interface{}) string { + s, err := sqldb.SelectStr(dbmap, query, args...) + if err != nil { + panic(err) + } + + return s +} + +func selectNullStr(dbmap *sqldb.DbMap, query string, args ...interface{}) sql.NullString { + s, err := sqldb.SelectNullStr(dbmap, query, args...) + if err != nil { + panic(err) + } + + return s +} + +func rawExec(dbmap *sqldb.DbMap, query string, args ...interface{}) sql.Result { + res, err := dbmap.Exec(query, args...) + if err != nil { + panic(err) + } + return res +} + +func rawSelect(dbmap *sqldb.DbMap, i interface{}, query string, args ...interface{}) []interface{} { + list, err := dbmap.Select(i, query, args...) 
+ if err != nil { + panic(err) + } + return list +} + +func tableName(dbmap *sqldb.DbMap, i interface{}) string { + t := reflect.TypeOf(i) + if table, err := dbmap.TableFor(t, false); table != nil && err == nil { + return dbmap.Dialect.QuoteField(table.TableName) + } + return t.Name() +} + +func columnName(dbmap *sqldb.DbMap, i interface{}, fieldName string) string { + t := reflect.TypeOf(i) + if table, err := dbmap.TableFor(t, false); table != nil && err == nil { + return dbmap.Dialect.QuoteField(table.ColMap(fieldName).ColumnName) + } + return fieldName +} diff --git a/gdb/sqldb/table.go b/gdb/sqldb/table.go new file mode 100644 index 0000000..c628aa9 --- /dev/null +++ b/gdb/sqldb/table.go @@ -0,0 +1,258 @@ +// +// table.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "bytes" + "fmt" + "reflect" + "strings" +) + +// TableMap represents a mapping between a Go struct and a database table +// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these +type TableMap struct { + // Name of database table. + TableName string + SchemaName string + gotype reflect.Type + Columns []*ColumnMap + keys []*ColumnMap + indexes []*IndexMap + uniqueTogether [][]string + version *ColumnMap + insertPlan bindPlan + updatePlan bindPlan + deletePlan bindPlan + getPlan bindPlan + dbmap *DbMap +} + +// ResetSql removes cached insert/update/select/delete SQL strings +// associated with this TableMap. Call this if you've modified +// any column names or the table name itself. +func (t *TableMap) ResetSql() { + t.insertPlan = bindPlan{} + t.updatePlan = bindPlan{} + t.deletePlan = bindPlan{} + t.getPlan = bindPlan{} +} + +// SetKeys lets you specify the fields on a struct that map to primary +// key columns on the table. If isAutoIncr is set, result.LastInsertId() +// will be used after INSERT to bind the generated id to the Go struct. 
+// +// Automatically calls ResetSql() to ensure SQL statements are regenerated. +// +// Panics if isAutoIncr is true, and fieldNames length != 1 +func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap { + if isAutoIncr && len(fieldNames) != 1 { + panic(fmt.Sprintf( + "sqldb: SetKeys: fieldNames length must be 1 if key is auto-increment. (Saw %v fieldNames)", + len(fieldNames))) + } + t.keys = make([]*ColumnMap, 0) + for _, name := range fieldNames { + colmap := t.ColMap(name) + colmap.isPK = true + colmap.isAutoIncr = isAutoIncr + t.keys = append(t.keys, colmap) + } + t.ResetSql() + + return t +} + +// SetUniqueTogether lets you specify uniqueness constraints across multiple +// columns on the table. Each call adds an additional constraint for the +// specified columns. +// +// Automatically calls ResetSql() to ensure SQL statements are regenerated. +// +// Panics if fieldNames length < 2. +func (t *TableMap) SetUniqueTogether(fieldNames ...string) *TableMap { + if len(fieldNames) < 2 { + panic(fmt.Sprintf( + "sqldb: SetUniqueTogether: must provide at least two fieldNames to set uniqueness constraint.")) + } + + columns := make([]string, 0, len(fieldNames)) + for _, name := range fieldNames { + columns = append(columns, name) + } + + for _, existingColumns := range t.uniqueTogether { + if equal(existingColumns, columns) { + return t + } + } + t.uniqueTogether = append(t.uniqueTogether, columns) + t.ResetSql() + + return t +} + +// ColMap returns the ColumnMap pointer matching the given struct field +// name. It panics if the struct does not contain a field matching this +// name. 
+func (t *TableMap) ColMap(field string) *ColumnMap { + col := colMapOrNil(t, field) + if col == nil { + e := fmt.Sprintf("No ColumnMap in table %s type %s with field %s", + t.TableName, t.gotype.Name(), field) + + panic(e) + } + return col +} + +func colMapOrNil(t *TableMap, field string) *ColumnMap { + for _, col := range t.Columns { + if col.fieldName == field || col.ColumnName == field { + return col + } + } + return nil +} + +// IdxMap returns the IndexMap pointer matching the given index name. +func (t *TableMap) IdxMap(field string) *IndexMap { + for _, idx := range t.indexes { + if idx.IndexName == field { + return idx + } + } + return nil +} + +// AddIndex registers the index with sqldb for specified table with given parameters. +// This operation is idempotent. If index is already mapped, the +// existing *IndexMap is returned +// Function will panic if one of the given for index columns does not exists +// +// Automatically calls ResetSql() to ensure SQL statements are regenerated. +func (t *TableMap) AddIndex(name string, idxtype string, columns []string) *IndexMap { + // check if we have a index with this name already + for _, idx := range t.indexes { + if idx.IndexName == name { + return idx + } + } + for _, icol := range columns { + if res := t.ColMap(icol); res == nil { + e := fmt.Sprintf("No ColumnName in table %s to create index on", t.TableName) + panic(e) + } + } + + idx := &IndexMap{IndexName: name, Unique: false, IndexType: idxtype, columns: columns} + t.indexes = append(t.indexes, idx) + t.ResetSql() + return idx +} + +// SetVersionCol sets the column to use as the Version field. By default +// the "Version" field is used. Returns the column found, or panics +// if the struct does not contain a field matching this name. +// +// Automatically calls ResetSql() to ensure SQL statements are regenerated. 
+func (t *TableMap) SetVersionCol(field string) *ColumnMap { + c := t.ColMap(field) + t.version = c + t.ResetSql() + return c +} + +// SqlForCreateTable gets a sequence of SQL commands that will create +// the specified table and any associated schema +func (t *TableMap) SqlForCreate(ifNotExists bool) string { + s := bytes.Buffer{} + dialect := t.dbmap.Dialect + + if strings.TrimSpace(t.SchemaName) != "" { + schemaCreate := "create schema" + if ifNotExists { + s.WriteString(dialect.IfSchemaNotExists(schemaCreate, t.SchemaName)) + } else { + s.WriteString(schemaCreate) + } + s.WriteString(fmt.Sprintf(" %s;", t.SchemaName)) + } + + tableCreate := "create table" + if ifNotExists { + s.WriteString(dialect.IfTableNotExists(tableCreate, t.SchemaName, t.TableName)) + } else { + s.WriteString(tableCreate) + } + s.WriteString(fmt.Sprintf(" %s (", dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) + + x := 0 + for _, col := range t.Columns { + if !col.Transient { + if x > 0 { + s.WriteString(", ") + } + stype := dialect.ToSqlType(col.gotype, col.MaxSize, col.isAutoIncr) + s.WriteString(fmt.Sprintf("%s %s", dialect.QuoteField(col.ColumnName), stype)) + + if col.isPK || col.isNotNull { + s.WriteString(" not null") + } + if col.isPK && len(t.keys) == 1 { + s.WriteString(" primary key") + } + if col.Unique { + s.WriteString(" unique") + } + if col.isAutoIncr { + s.WriteString(fmt.Sprintf(" %s", dialect.AutoIncrStr())) + } + + x++ + } + } + if len(t.keys) > 1 { + s.WriteString(", primary key (") + for x := range t.keys { + if x > 0 { + s.WriteString(", ") + } + s.WriteString(dialect.QuoteField(t.keys[x].ColumnName)) + } + s.WriteString(")") + } + if len(t.uniqueTogether) > 0 { + for _, columns := range t.uniqueTogether { + s.WriteString(", unique (") + for i, column := range columns { + if i > 0 { + s.WriteString(", ") + } + s.WriteString(dialect.QuoteField(column)) + } + s.WriteString(")") + } + } + s.WriteString(") ") + s.WriteString(dialect.CreateTableSuffix()) + 
s.WriteString(dialect.QuerySuffix()) + return s.String() +} + +func equal(a, b []string) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} diff --git a/gdb/sqldb/table_bindings.go b/gdb/sqldb/table_bindings.go new file mode 100644 index 0000000..13a4ca8 --- /dev/null +++ b/gdb/sqldb/table_bindings.go @@ -0,0 +1,308 @@ +// +// table_bindings.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "bytes" + "fmt" + "reflect" + "sync" +) + +// CustomScanner binds a database column value to a Go type +type CustomScanner struct { + // After a row is scanned, Holder will contain the value from the database column. + // Initialize the CustomScanner with the concrete Go type you wish the database + // driver to scan the raw column into. + Holder interface{} + // Target typically holds a pointer to the target struct field to bind the Holder + // value to. + Target interface{} + // Binder is a custom function that converts the holder value to the target type + // and sets target accordingly. This function should return error if a problem + // occurs converting the holder to the target. 
+ Binder func(holder interface{}, target interface{}) error +} + +// Used to filter columns when selectively updating +type ColumnFilter func(*ColumnMap) bool + +func acceptAllFilter(col *ColumnMap) bool { + return true +} + +// Bind is called automatically by sqldb after Scan() +func (me CustomScanner) Bind() error { + return me.Binder(me.Holder, me.Target) +} + +type bindPlan struct { + query string + argFields []string + keyFields []string + versField string + autoIncrIdx int + autoIncrFieldName string + once sync.Once +} + +func (plan *bindPlan) createBindInstance(elem reflect.Value, conv TypeConverter) (bindInstance, error) { + bi := bindInstance{query: plan.query, autoIncrIdx: plan.autoIncrIdx, autoIncrFieldName: plan.autoIncrFieldName, versField: plan.versField} + if plan.versField != "" { + bi.existingVersion = elem.FieldByName(plan.versField).Int() + } + + var err error + + for i := 0; i < len(plan.argFields); i++ { + k := plan.argFields[i] + if k == versFieldConst { + newVer := bi.existingVersion + 1 + bi.args = append(bi.args, newVer) + if bi.existingVersion == 0 { + elem.FieldByName(plan.versField).SetInt(int64(newVer)) + } + } else { + val := elem.FieldByName(k).Interface() + if conv != nil { + val, err = conv.ToDb(val) + if err != nil { + return bindInstance{}, err + } + } + bi.args = append(bi.args, val) + } + } + + for i := 0; i < len(plan.keyFields); i++ { + k := plan.keyFields[i] + val := elem.FieldByName(k).Interface() + if conv != nil { + val, err = conv.ToDb(val) + if err != nil { + return bindInstance{}, err + } + } + bi.keys = append(bi.keys, val) + } + + return bi, nil +} + +type bindInstance struct { + query string + args []interface{} + keys []interface{} + existingVersion int64 + versField string + autoIncrIdx int + autoIncrFieldName string +} + +func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) { + plan := &t.insertPlan + plan.once.Do(func() { + plan.autoIncrIdx = -1 + + s := bytes.Buffer{} + s2 := bytes.Buffer{} + 
s.WriteString(fmt.Sprintf("insert into %s (", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) + + x := 0 + first := true + for y := range t.Columns { + col := t.Columns[y] + if !(col.isAutoIncr && t.dbmap.Dialect.AutoIncrBindValue() == "") { + if !col.Transient { + if !first { + s.WriteString(",") + s2.WriteString(",") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + + if col.isAutoIncr { + s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue()) + plan.autoIncrIdx = y + plan.autoIncrFieldName = col.fieldName + } else { + if col.DefaultValue == "" { + s2.WriteString(t.dbmap.Dialect.BindVar(x)) + if col == t.version { + plan.versField = col.fieldName + plan.argFields = append(plan.argFields, versFieldConst) + } else { + plan.argFields = append(plan.argFields, col.fieldName) + } + x++ + } else { + s2.WriteString(col.DefaultValue) + } + } + first = false + } + } else { + plan.autoIncrIdx = y + plan.autoIncrFieldName = col.fieldName + } + } + s.WriteString(") values (") + s.WriteString(s2.String()) + s.WriteString(")") + if plan.autoIncrIdx > -1 { + s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx])) + } + s.WriteString(t.dbmap.Dialect.QuerySuffix()) + + plan.query = s.String() + }) + + return plan.createBindInstance(elem, t.dbmap.TypeConverter) +} + +func (t *TableMap) bindUpdate(elem reflect.Value, colFilter ColumnFilter) (bindInstance, error) { + if colFilter == nil { + colFilter = acceptAllFilter + } + + plan := &t.updatePlan + plan.once.Do(func() { + s := bytes.Buffer{} + s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) + x := 0 + + for y := range t.Columns { + col := t.Columns[y] + if !col.isAutoIncr && !col.Transient && colFilter(col) { + if x > 0 { + s.WriteString(", ") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) + + if col == t.version { + plan.versField = 
col.fieldName + plan.argFields = append(plan.argFields, versFieldConst) + } else { + plan.argFields = append(plan.argFields, col.fieldName) + } + x++ + } + } + + s.WriteString(" where ") + for y := range t.keys { + col := t.keys[y] + if y > 0 { + s.WriteString(" and ") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) + + plan.argFields = append(plan.argFields, col.fieldName) + plan.keyFields = append(plan.keyFields, col.fieldName) + x++ + } + if plan.versField != "" { + s.WriteString(" and ") + s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) + plan.argFields = append(plan.argFields, plan.versField) + } + s.WriteString(t.dbmap.Dialect.QuerySuffix()) + + plan.query = s.String() + }) + + return plan.createBindInstance(elem, t.dbmap.TypeConverter) +} + +func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) { + plan := &t.deletePlan + plan.once.Do(func() { + s := bytes.Buffer{} + s.WriteString(fmt.Sprintf("delete from %s", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))) + + for y := range t.Columns { + col := t.Columns[y] + if !col.Transient { + if col == t.version { + plan.versField = col.fieldName + } + } + } + + s.WriteString(" where ") + for x := range t.keys { + k := t.keys[x] + if x > 0 { + s.WriteString(" and ") + } + s.WriteString(t.dbmap.Dialect.QuoteField(k.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) + + plan.keyFields = append(plan.keyFields, k.fieldName) + plan.argFields = append(plan.argFields, k.fieldName) + } + if plan.versField != "" { + s.WriteString(" and ") + s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(len(plan.argFields))) + + plan.argFields = append(plan.argFields, plan.versField) + } + s.WriteString(t.dbmap.Dialect.QuerySuffix()) + + 
plan.query = s.String() + }) + + return plan.createBindInstance(elem, t.dbmap.TypeConverter) +} + +func (t *TableMap) bindGet() *bindPlan { + plan := &t.getPlan + plan.once.Do(func() { + s := bytes.Buffer{} + s.WriteString("select ") + + x := 0 + for _, col := range t.Columns { + if !col.Transient { + if x > 0 { + s.WriteString(",") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + plan.argFields = append(plan.argFields, col.fieldName) + x++ + } + } + s.WriteString(" from ") + s.WriteString(t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)) + s.WriteString(" where ") + for x := range t.keys { + col := t.keys[x] + if x > 0 { + s.WriteString(" and ") + } + s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName)) + s.WriteString("=") + s.WriteString(t.dbmap.Dialect.BindVar(x)) + + plan.keyFields = append(plan.keyFields, col.fieldName) + } + s.WriteString(t.dbmap.Dialect.QuerySuffix()) + + plan.query = s.String() + }) + + return plan +} diff --git a/gdb/sqldb/test_all.sh b/gdb/sqldb/test_all.sh new file mode 100755 index 0000000..0d4a549 --- /dev/null +++ b/gdb/sqldb/test_all.sh @@ -0,0 +1,24 @@ +#!/bin/bash -ex + +# on macs, you may need to: +# export GOBUILDFLAG=-ldflags -linkmode=external + +echo "Running unit tests" +go test -race + +echo "Testing against postgres" +export SQLDB_TEST_DSN="host=127.0.0.1 user=testuser password=123 dbname=testdb sslmode=disable" +export SQLDB_TEST_DIALECT=postgres +go test -tags integration $GOBUILDFLAG $@ . + +echo "Testing against sqlite" +export SQLDB_TEST_DSN=/tmp/testdb.bin +export SQLDB_TEST_DIALECT=sqlite +go test -tags integration $GOBUILDFLAG $@ . +rm -f /tmp/testdb.bin + +echo "Testing against mysql" +# export SQLDB_TEST_DSN="testuser:123@tcp(127.0.0.1:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local" +export SQLDB_TEST_DSN="testuser:123@tcp(127.0.0.1:3306)/testdb" +export SQLDB_TEST_DIALECT=mysql +go test -tags integration $GOBUILDFLAG $@ . 
diff --git a/gdb/sqldb/transaction.go b/gdb/sqldb/transaction.go
new file mode 100644
index 0000000..b6d0564
--- /dev/null
+++ b/gdb/sqldb/transaction.go
@@ -0,0 +1,242 @@
//
// transaction.go
// Copyright (C) 2023 tiglog
//
// Distributed under terms of the MIT license.
//

package sqldb

import (
	"context"
	"database/sql"
	"time"
)

// Transaction represents a database transaction.
// Insert/Update/Delete/Get/Exec operations will be run in the context
// of that transaction. Transactions should be terminated with
// a call to Commit() or Rollback()
type Transaction struct {
	ctx    context.Context // context applied to statements; see WithContext
	dbmap  *DbMap          // parent mapper providing dialect, logging, and hooks
	tx     *sql.Tx         // underlying database/sql transaction
	closed bool            // set once Commit or Rollback has been called
}

// WithContext returns a shallow copy of this Transaction whose statements
// execute with the given context. The original Transaction is unchanged.
func (t *Transaction) WithContext(ctx context.Context) SqlExecutor {
	copy := &Transaction{}
	*copy = *t
	copy.ctx = ctx
	return copy
}

// Insert has the same behavior as DbMap.Insert(), but runs in a transaction.
func (t *Transaction) Insert(list ...interface{}) error {
	return insert(t.dbmap, t, list...)
}

// Update has the same behavior as DbMap.Update(), but runs in a transaction.
func (t *Transaction) Update(list ...interface{}) (int64, error) {
	return update(t.dbmap, t, nil, list...)
}

// UpdateColumns has the same behavior as DbMap.UpdateColumns(), but runs in a transaction.
func (t *Transaction) UpdateColumns(filter ColumnFilter, list ...interface{}) (int64, error) {
	return update(t.dbmap, t, filter, list...)
}

// Delete has the same behavior as DbMap.Delete(), but runs in a transaction.
func (t *Transaction) Delete(list ...interface{}) (int64, error) {
	return delete(t.dbmap, t, list...)
}

// Get has the same behavior as DbMap.Get(), but runs in a transaction.
func (t *Transaction) Get(i interface{}, keys ...interface{}) (interface{}, error) {
	return get(t.dbmap, t, i, keys...)
}

// Select has the same behavior as DbMap.Select(), but runs in a transaction.
+func (t *Transaction) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) { + if t.dbmap.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return hookedselect(t.dbmap, t, i, query, args...) +} + +// Exec has the same behavior as DbMap.Exec(), but runs in a transaction. +func (t *Transaction) Exec(query string, args ...interface{}) (sql.Result, error) { + if t.dbmap.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, args...) + } + return maybeExpandNamedQueryAndExec(t, query, args...) +} + +// SelectInt is a convenience wrapper around the sqldb.SelectInt function. +func (t *Transaction) SelectInt(query string, args ...interface{}) (int64, error) { + if t.dbmap.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectInt(t, query, args...) +} + +// SelectNullInt is a convenience wrapper around the sqldb.SelectNullInt function. +func (t *Transaction) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) { + if t.dbmap.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectNullInt(t, query, args...) +} + +// SelectFloat is a convenience wrapper around the sqldb.SelectFloat function. +func (t *Transaction) SelectFloat(query string, args ...interface{}) (float64, error) { + if t.dbmap.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectFloat(t, query, args...) +} + +// SelectNullFloat is a convenience wrapper around the sqldb.SelectNullFloat function. +func (t *Transaction) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) { + if t.dbmap.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectNullFloat(t, query, args...) +} + +// SelectStr is a convenience wrapper around the sqldb.SelectStr function. 
+func (t *Transaction) SelectStr(query string, args ...interface{}) (string, error) { + if t.dbmap.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectStr(t, query, args...) +} + +// SelectNullStr is a convenience wrapper around the sqldb.SelectNullStr function. +func (t *Transaction) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) { + if t.dbmap.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectNullStr(t, query, args...) +} + +// SelectOne is a convenience wrapper around the sqldb.SelectOne function. +func (t *Transaction) SelectOne(holder interface{}, query string, args ...interface{}) error { + if t.dbmap.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + return SelectOne(t.dbmap, t, holder, query, args...) +} + +// Commit commits the underlying database transaction. +func (t *Transaction) Commit() error { + if !t.closed { + t.closed = true + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, "commit;") + } + return t.tx.Commit() + } + + return sql.ErrTxDone +} + +// Rollback rolls back the underlying database transaction. +func (t *Transaction) Rollback() error { + if !t.closed { + t.closed = true + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, "rollback;") + } + return t.tx.Rollback() + } + + return sql.ErrTxDone +} + +// Savepoint creates a savepoint with the given name. The name is interpolated +// directly into the SQL SAVEPOINT statement, so you must sanitize it if it is +// derived from user input. +func (t *Transaction) Savepoint(name string) error { + query := "savepoint " + t.dbmap.Dialect.QuoteField(name) + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, nil) + } + _, err := exec(t, query) + return err +} + +// RollbackToSavepoint rolls back to the savepoint with the given name. 
The +// name is interpolated directly into the SQL ROLLBACK TO SAVEPOINT statement, so you must +// sanitize it if it is derived from user input. +func (t *Transaction) RollbackToSavepoint(savepoint string) error { + query := "rollback to savepoint " + t.dbmap.Dialect.QuoteField(savepoint) + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, nil) + } + _, err := exec(t, query) + return err +} + +// ReleaseSavepoint releases the savepoint with the given name. The name is +// interpolated directly into the SQL RELEASE SAVEPOINT statement, so you must sanitize +// it if it is derived from user input. +func (t *Transaction) ReleaseSavepoint(savepoint string) error { + query := "release savepoint " + t.dbmap.Dialect.QuoteField(savepoint) + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, nil) + } + _, err := exec(t, query) + return err +} + +// Prepare has the same behavior as DbMap.Prepare(), but runs in a transaction. +func (t *Transaction) Prepare(query string) (*sql.Stmt, error) { + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, nil) + } + return prepare(t, query) +} + +func (t *Transaction) QueryRow(query string, args ...interface{}) *sql.Row { + if t.dbmap.ExpandSliceArgs { + expandSliceArgs(&query, args...) + } + + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, query, args...) + } + return queryRow(t, query, args...) +} + +func (t *Transaction) Query(q string, args ...interface{}) (*sql.Rows, error) { + if t.dbmap.ExpandSliceArgs { + expandSliceArgs(&q, args...) + } + + if t.dbmap.logger != nil { + now := time.Now() + defer t.dbmap.trace(now, q, args...) + } + return query(t, q, args...) 
+} diff --git a/gdb/sqldb/transaction_test.go b/gdb/sqldb/transaction_test.go new file mode 100644 index 0000000..0686a8d --- /dev/null +++ b/gdb/sqldb/transaction_test.go @@ -0,0 +1,340 @@ +// +// transaction_test.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +//go:build integration +// +build integration + +package sqldb_test + +import "testing" + +func TestTransaction_Select_expandSliceArgs(t *testing.T) { + tests := []struct { + description string + query string + args []interface{} + wantLen int + }{ + { + description: "it should handle slice placeholders correctly", + query: ` +SELECT 1 FROM crazy_table +WHERE field1 = :Field1 +AND field2 IN (:FieldStringList) +AND field3 IN (:FieldUIntList) +AND field4 IN (:FieldUInt8List) +AND field5 IN (:FieldUInt16List) +AND field6 IN (:FieldUInt32List) +AND field7 IN (:FieldUInt64List) +AND field8 IN (:FieldIntList) +AND field9 IN (:FieldInt8List) +AND field10 IN (:FieldInt16List) +AND field11 IN (:FieldInt32List) +AND field12 IN (:FieldInt64List) +AND field13 IN (:FieldFloat32List) +AND field14 IN (:FieldFloat64List) +`, + args: []interface{}{ + map[string]interface{}{ + "Field1": 123, + "FieldStringList": []string{"h", "e", "y"}, + "FieldUIntList": []uint{1, 2, 3, 4}, + "FieldUInt8List": []uint8{1, 2, 3, 4}, + "FieldUInt16List": []uint16{1, 2, 3, 4}, + "FieldUInt32List": []uint32{1, 2, 3, 4}, + "FieldUInt64List": []uint64{1, 2, 3, 4}, + "FieldIntList": []int{1, 2, 3, 4}, + "FieldInt8List": []int8{1, 2, 3, 4}, + "FieldInt16List": []int16{1, 2, 3, 4}, + "FieldInt32List": []int32{1, 2, 3, 4}, + "FieldInt64List": []int64{1, 2, 3, 4}, + "FieldFloat32List": []float32{1, 2, 3, 4}, + "FieldFloat64List": []float64{1, 2, 3, 4}, + }, + }, + wantLen: 1, + }, + { + description: "it should handle slice placeholders correctly with custom types", + query: ` +SELECT 1 FROM crazy_table +WHERE field2 IN (:FieldStringList) +AND field12 IN (:FieldIntList) +`, + args: []interface{}{ + 
map[string]interface{}{ + "FieldStringList": customType1{"h", "e", "y"}, + "FieldIntList": customType2{1, 2, 3, 4}, + }, + }, + wantLen: 3, + }, + } + + type dataFormat struct { + Field1 int `db:"field1"` + Field2 string `db:"field2"` + Field3 uint `db:"field3"` + Field4 uint8 `db:"field4"` + Field5 uint16 `db:"field5"` + Field6 uint32 `db:"field6"` + Field7 uint64 `db:"field7"` + Field8 int `db:"field8"` + Field9 int8 `db:"field9"` + Field10 int16 `db:"field10"` + Field11 int32 `db:"field11"` + Field12 int64 `db:"field12"` + Field13 float32 `db:"field13"` + Field14 float64 `db:"field14"` + } + + dbmap := newDBMap(t) + dbmap.ExpandSliceArgs = true + dbmap.AddTableWithName(dataFormat{}, "crazy_table") + + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + err = dbmap.Insert( + &dataFormat{ + Field1: 123, + Field2: "h", + Field3: 1, + Field4: 1, + Field5: 1, + Field6: 1, + Field7: 1, + Field8: 1, + Field9: 1, + Field10: 1, + Field11: 1, + Field12: 1, + Field13: 1, + Field14: 1, + }, + &dataFormat{ + Field1: 124, + Field2: "e", + Field3: 2, + Field4: 2, + Field5: 2, + Field6: 2, + Field7: 2, + Field8: 2, + Field9: 2, + Field10: 2, + Field11: 2, + Field12: 2, + Field13: 2, + Field14: 2, + }, + &dataFormat{ + Field1: 125, + Field2: "y", + Field3: 3, + Field4: 3, + Field5: 3, + Field6: 3, + Field7: 3, + Field8: 3, + Field9: 3, + Field10: 3, + Field11: 3, + Field12: 3, + Field13: 3, + Field14: 3, + }, + ) + + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + tx, err := dbmap.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + var dummy []int + _, err = tx.Select(&dummy, tt.query, tt.args...) 
+ if err != nil { + t.Fatal(err) + } + + if len(dummy) != tt.wantLen { + t.Errorf("wrong result count\ngot: %d\nwant: %d", len(dummy), tt.wantLen) + } + }) + } +} + +func TestTransaction_Exec_expandSliceArgs(t *testing.T) { + tests := []struct { + description string + query string + args []interface{} + wantLen int + }{ + { + description: "it should handle slice placeholders correctly", + query: ` +DELETE FROM crazy_table +WHERE field1 = :Field1 +AND field2 IN (:FieldStringList) +AND field3 IN (:FieldUIntList) +AND field4 IN (:FieldUInt8List) +AND field5 IN (:FieldUInt16List) +AND field6 IN (:FieldUInt32List) +AND field7 IN (:FieldUInt64List) +AND field8 IN (:FieldIntList) +AND field9 IN (:FieldInt8List) +AND field10 IN (:FieldInt16List) +AND field11 IN (:FieldInt32List) +AND field12 IN (:FieldInt64List) +AND field13 IN (:FieldFloat32List) +AND field14 IN (:FieldFloat64List) +`, + args: []interface{}{ + map[string]interface{}{ + "Field1": 123, + "FieldStringList": []string{"h", "e", "y"}, + "FieldUIntList": []uint{1, 2, 3, 4}, + "FieldUInt8List": []uint8{1, 2, 3, 4}, + "FieldUInt16List": []uint16{1, 2, 3, 4}, + "FieldUInt32List": []uint32{1, 2, 3, 4}, + "FieldUInt64List": []uint64{1, 2, 3, 4}, + "FieldIntList": []int{1, 2, 3, 4}, + "FieldInt8List": []int8{1, 2, 3, 4}, + "FieldInt16List": []int16{1, 2, 3, 4}, + "FieldInt32List": []int32{1, 2, 3, 4}, + "FieldInt64List": []int64{1, 2, 3, 4}, + "FieldFloat32List": []float32{1, 2, 3, 4}, + "FieldFloat64List": []float64{1, 2, 3, 4}, + }, + }, + wantLen: 1, + }, + { + description: "it should handle slice placeholders correctly with custom types", + query: ` +DELETE FROM crazy_table +WHERE field2 IN (:FieldStringList) +AND field12 IN (:FieldIntList) +`, + args: []interface{}{ + map[string]interface{}{ + "FieldStringList": customType1{"h", "e", "y"}, + "FieldIntList": customType2{1, 2, 3, 4}, + }, + }, + wantLen: 3, + }, + } + + type dataFormat struct { + Field1 int `db:"field1"` + Field2 string `db:"field2"` + Field3 uint 
`db:"field3"` + Field4 uint8 `db:"field4"` + Field5 uint16 `db:"field5"` + Field6 uint32 `db:"field6"` + Field7 uint64 `db:"field7"` + Field8 int `db:"field8"` + Field9 int8 `db:"field9"` + Field10 int16 `db:"field10"` + Field11 int32 `db:"field11"` + Field12 int64 `db:"field12"` + Field13 float32 `db:"field13"` + Field14 float64 `db:"field14"` + } + + dbmap := newDBMap(t) + dbmap.ExpandSliceArgs = true + dbmap.AddTableWithName(dataFormat{}, "crazy_table") + + err := dbmap.CreateTables() + if err != nil { + panic(err) + } + defer dropAndClose(dbmap) + + err = dbmap.Insert( + &dataFormat{ + Field1: 123, + Field2: "h", + Field3: 1, + Field4: 1, + Field5: 1, + Field6: 1, + Field7: 1, + Field8: 1, + Field9: 1, + Field10: 1, + Field11: 1, + Field12: 1, + Field13: 1, + Field14: 1, + }, + &dataFormat{ + Field1: 124, + Field2: "e", + Field3: 2, + Field4: 2, + Field5: 2, + Field6: 2, + Field7: 2, + Field8: 2, + Field9: 2, + Field10: 2, + Field11: 2, + Field12: 2, + Field13: 2, + Field14: 2, + }, + &dataFormat{ + Field1: 125, + Field2: "y", + Field3: 3, + Field4: 3, + Field5: 3, + Field6: 3, + Field7: 3, + Field8: 3, + Field9: 3, + Field10: 3, + Field11: 3, + Field12: 3, + Field13: 3, + Field14: 3, + }, + ) + + if err != nil { + t.Fatal(err) + } + + for _, tt := range tests { + t.Run(tt.description, func(t *testing.T) { + tx, err := dbmap.Begin() + if err != nil { + t.Fatal(err) + } + defer tx.Rollback() + + _, err = tx.Exec(tt.query, tt.args...) 
+ if err != nil { + t.Fatal(err) + } + }) + } +} diff --git a/go.mod b/go.mod index bd3fc5b..46ffe0f 100644 --- a/go.mod +++ b/go.mod @@ -8,17 +8,20 @@ require ( github.com/go-redis/redis/v8 v8.11.5 github.com/go-sql-driver/mysql v1.7.1 github.com/hibiken/asynq v0.24.1 - github.com/jmoiron/sqlx v1.3.5 github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible github.com/lib/pq v1.10.9 github.com/mattn/go-runewidth v0.0.14 + github.com/mattn/go-sqlite3 v1.14.6 github.com/pkg/errors v0.9.1 + github.com/poy/onpar v0.3.2 github.com/rs/xid v1.5.0 + github.com/stretchr/testify v1.8.3 go.mongodb.org/mongo-driver v1.11.7 go.uber.org/zap v1.25.0 golang.org/x/crypto v0.10.0 gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/yaml.v2 v2.4.0 + gopkg.in/yaml.v3 v3.0.1 ) require ( @@ -26,6 +29,7 @@ require ( github.com/bytedance/sonic v1.9.1 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gin-contrib/sse v0.1.0 // indirect @@ -47,6 +51,7 @@ require ( github.com/modern-go/reflect2 v1.0.2 // indirect github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/redis/go-redis/v9 v9.0.3 // indirect github.com/rivo/uniseg v0.2.0 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect @@ -65,5 +70,4 @@ require ( golang.org/x/text v0.10.0 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 // indirect google.golang.org/protobuf v1.30.0 // indirect - gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 80a9d1a..14da58e 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,4 @@ +git.sr.ht/~nelsam/hel v0.4.3 h1:9W0zz8zv8CZhFsp8r9Wq6c8gFemBdtMurjZU/JKfvfM= 
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= @@ -36,7 +37,6 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= -github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= @@ -57,8 +57,6 @@ github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hibiken/asynq v0.24.1 h1:+5iIEAyA9K/lcSPvx3qoPtsKJeKI5u9aOIvUmSsazEw= github.com/hibiken/asynq v0.24.1/go.mod h1:u5qVeSbrnfT+vtG5Mq8ZPzQu/BmCKMHvTGb91uy9Tts= -github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= -github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4= github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= @@ -81,7 +79,6 @@ github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible h1:Y6sqxHMyB1D2YSzWkL github.com/lestrrat-go/file-rotatelogs 
v2.4.0+incompatible/go.mod h1:ZQnN8lSECaebrkQytbHj4xNgtg8CR7RYXnPok8e0EHA= github.com/lestrrat-go/strftime v1.0.6 h1:CFGsDEt1pOpFNU+TJB0nhz9jl+K0hZSLE205AhTIGQQ= github.com/lestrrat-go/strftime v1.0.6/go.mod h1:f7jQKgV5nnJpYgdEasS+/y7EsTb8ykN2z68n3TtcTaw= -github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= @@ -106,6 +103,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/poy/onpar v0.3.2 h1:yo8ZRqU3C4RlvkXPWUWfonQiTodAgpKQZ1g8VTNU9xU= +github.com/poy/onpar v0.3.2/go.mod h1:6XDWG8DJ1HsFX6/Btn0pHl3Jz5d1SEEGNZ5N1gtYo+I= github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k= github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=