refactor: 移除 db 库,只保持已有使用的 mgodb
This commit is contained in:
parent
1ceb2331d8
commit
aa784c4187
3
gdb/orm/.gitignore
vendored
3
gdb/orm/.gitignore
vendored
@ -1,3 +0,0 @@
|
|||||||
**/.idea/*
|
|
||||||
cover.out
|
|
||||||
**db
|
|
@ -1,157 +0,0 @@
|
|||||||
//
|
|
||||||
// binder.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"database/sql/driver"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"unsafe"
|
|
||||||
)
|
|
||||||
|
|
||||||
// makeNewPointersOf creates a map of [field name] -> pointer to fill it
|
|
||||||
// recursively. it will go down until reaches a driver.Valuer implementation, it will stop there.
|
|
||||||
func (b *binder) makeNewPointersOf(v reflect.Value) interface{} {
|
|
||||||
m := map[string]interface{}{}
|
|
||||||
actualV := v
|
|
||||||
for actualV.Type().Kind() == reflect.Ptr {
|
|
||||||
actualV = actualV.Elem()
|
|
||||||
}
|
|
||||||
if actualV.Type().Kind() == reflect.Struct {
|
|
||||||
for i := 0; i < actualV.NumField(); i++ {
|
|
||||||
f := actualV.Field(i)
|
|
||||||
if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
|
|
||||||
f = reflect.NewAt(actualV.Type().Field(i).Type, unsafe.Pointer(actualV.Field(i).UnsafeAddr()))
|
|
||||||
fm := b.makeNewPointersOf(f).(map[string]interface{})
|
|
||||||
for k, p := range fm {
|
|
||||||
m[k] = p
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
var fm *field
|
|
||||||
fm = b.s.getField(actualV.Type().Field(i))
|
|
||||||
if fm == nil {
|
|
||||||
fm = fieldMetadata(actualV.Type().Field(i), b.s.columnConstraints)[0]
|
|
||||||
}
|
|
||||||
m[fm.Name] = reflect.NewAt(actualV.Field(i).Type(), unsafe.Pointer(actualV.Field(i).UnsafeAddr())).Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return v.Addr().Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
|
|
||||||
// ptrsFor first allocates for all struct fields recursively until reaches a driver.Value impl
|
|
||||||
// then it will put them in a map with their correct field name as key, then loops over cts
|
|
||||||
// and for each one gets appropriate one from the map and adds it to pointer list.
|
|
||||||
func (b *binder) ptrsFor(v reflect.Value, cts []*sql.ColumnType) []interface{} {
|
|
||||||
ptrs := b.makeNewPointersOf(v)
|
|
||||||
var scanInto []interface{}
|
|
||||||
if reflect.TypeOf(ptrs).Kind() == reflect.Map {
|
|
||||||
nameToPtr := ptrs.(map[string]interface{})
|
|
||||||
for _, ct := range cts {
|
|
||||||
if nameToPtr[ct.Name()] != nil {
|
|
||||||
scanInto = append(scanInto, nameToPtr[ct.Name()])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
scanInto = append(scanInto, ptrs)
|
|
||||||
}
|
|
||||||
|
|
||||||
return scanInto
|
|
||||||
}
|
|
||||||
|
|
||||||
type binder struct {
|
|
||||||
s *schema
|
|
||||||
}
|
|
||||||
|
|
||||||
func newBinder(s *schema) *binder {
|
|
||||||
return &binder{s: s}
|
|
||||||
}
|
|
||||||
|
|
||||||
// bind binds given rows to the given object at obj. obj should be a pointer
|
|
||||||
func (b *binder) bind(rows *sql.Rows, obj interface{}) error {
|
|
||||||
cts, err := rows.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
t := reflect.TypeOf(obj)
|
|
||||||
v := reflect.ValueOf(obj)
|
|
||||||
if t.Kind() != reflect.Ptr {
|
|
||||||
return fmt.Errorf("obj should be a ptr")
|
|
||||||
}
|
|
||||||
// since passed input is always a pointer one deref is necessary
|
|
||||||
t = t.Elem()
|
|
||||||
v = v.Elem()
|
|
||||||
if t.Kind() == reflect.Slice {
|
|
||||||
// getting slice elemnt type -> slice[t]
|
|
||||||
t = t.Elem()
|
|
||||||
for rows.Next() {
|
|
||||||
var rowValue reflect.Value
|
|
||||||
// Since reflect.SetupConnections returns a pointer to the type, we need to unwrap it to get actual
|
|
||||||
rowValue = reflect.New(t).Elem()
|
|
||||||
// till we reach a not pointer type continue newing the underlying type.
|
|
||||||
for rowValue.IsZero() && rowValue.Type().Kind() == reflect.Ptr {
|
|
||||||
rowValue = reflect.New(rowValue.Type().Elem()).Elem()
|
|
||||||
}
|
|
||||||
newCts := make([]*sql.ColumnType, len(cts))
|
|
||||||
copy(newCts, cts)
|
|
||||||
ptrs := b.ptrsFor(rowValue, newCts)
|
|
||||||
err = rows.Scan(ptrs...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
for rowValue.Type() != t {
|
|
||||||
tmp := reflect.New(rowValue.Type())
|
|
||||||
tmp.Elem().Set(rowValue)
|
|
||||||
rowValue = tmp
|
|
||||||
}
|
|
||||||
v = reflect.Append(v, rowValue)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
for rows.Next() {
|
|
||||||
ptrs := b.ptrsFor(v, cts)
|
|
||||||
err = rows.Scan(ptrs...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// v is either struct or slice
|
|
||||||
reflect.ValueOf(obj).Elem().Set(v)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func bindToMap(rows *sql.Rows) ([]map[string]interface{}, error) {
|
|
||||||
cts, err := rows.ColumnTypes()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var ms []map[string]interface{}
|
|
||||||
for rows.Next() {
|
|
||||||
var ptrs []interface{}
|
|
||||||
for _, ct := range cts {
|
|
||||||
ptrs = append(ptrs, reflect.New(ct.ScanType()).Interface())
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rows.Scan(ptrs...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
m := map[string]interface{}{}
|
|
||||||
for i, ptr := range ptrs {
|
|
||||||
m[cts[i].Name()] = reflect.ValueOf(ptr).Elem().Interface()
|
|
||||||
}
|
|
||||||
|
|
||||||
ms = append(ms, m)
|
|
||||||
}
|
|
||||||
return ms, nil
|
|
||||||
}
|
|
@ -1,92 +0,0 @@
|
|||||||
//
|
|
||||||
// binder_test.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/DATA-DOG/go-sqlmock"
|
|
||||||
|
|
||||||
_ "github.com/lib/pq"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
type User struct {
|
|
||||||
ID int64
|
|
||||||
Name string
|
|
||||||
Timestamps
|
|
||||||
SoftDelete
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u User) ConfigureEntity(e *EntityConfigurator) {
|
|
||||||
e.Table("users")
|
|
||||||
}
|
|
||||||
|
|
||||||
type Address struct {
|
|
||||||
ID int
|
|
||||||
Path string
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBind(t *testing.T) {
|
|
||||||
t.Run("single result", func(t *testing.T) {
|
|
||||||
db, mock, err := sqlmock.New()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
mock.
|
|
||||||
ExpectQuery("SELECT .* FROM users").
|
|
||||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "created_at", "updated_at", "deleted_at"}).
|
|
||||||
AddRow(1, "amirreza", sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{}))
|
|
||||||
rows, err := db.Query(`SELECT * FROM users`)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
u := &User{}
|
|
||||||
md := schemaOfHeavyReflectionStuff(u)
|
|
||||||
err = newBinder(md).bind(rows, u)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "amirreza", u.Name)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("multi result", func(t *testing.T) {
|
|
||||||
db, mock, err := sqlmock.New()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
mock.
|
|
||||||
ExpectQuery("SELECT .* FROM users").
|
|
||||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "amirreza").AddRow(2, "milad"))
|
|
||||||
|
|
||||||
rows, err := db.Query(`SELECT * FROM users`)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
md := schemaOfHeavyReflectionStuff(&User{})
|
|
||||||
var users []*User
|
|
||||||
err = newBinder(md).bind(rows, &users)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, "amirreza", users[0].Name)
|
|
||||||
assert.Equal(t, "milad", users[1].Name)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBindMap(t *testing.T) {
|
|
||||||
db, mock, err := sqlmock.New()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
mock.
|
|
||||||
ExpectQuery("SELECT .* FROM users").
|
|
||||||
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "created_at", "updated_at", "deleted_at"}).
|
|
||||||
AddRow(1, "amirreza", sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{}))
|
|
||||||
rows, err := db.Query(`SELECT * FROM users`)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
ms, err := bindToMap(rows)
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotEmpty(t, ms)
|
|
||||||
|
|
||||||
assert.Len(t, ms, 1)
|
|
||||||
}
|
|
@ -1,192 +0,0 @@
|
|||||||
//
|
|
||||||
// configurators.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
|
|
||||||
"git.hexq.cn/tiglog/golib/helper"
|
|
||||||
)
|
|
||||||
|
|
||||||
type EntityConfigurator struct {
|
|
||||||
connection string
|
|
||||||
table string
|
|
||||||
this Entity
|
|
||||||
relations map[string]interface{}
|
|
||||||
resolveRelations []func()
|
|
||||||
columnConstraints []*FieldConfigurator
|
|
||||||
}
|
|
||||||
|
|
||||||
func newEntityConfigurator() *EntityConfigurator {
|
|
||||||
return &EntityConfigurator{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ec *EntityConfigurator) Table(name string) *EntityConfigurator {
|
|
||||||
ec.table = name
|
|
||||||
return ec
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ec *EntityConfigurator) Connection(name string) *EntityConfigurator {
|
|
||||||
ec.connection = name
|
|
||||||
return ec
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ec *EntityConfigurator) HasMany(property Entity, config HasManyConfig) *EntityConfigurator {
|
|
||||||
if ec.relations == nil {
|
|
||||||
ec.relations = map[string]interface{}{}
|
|
||||||
}
|
|
||||||
ec.resolveRelations = append(ec.resolveRelations, func() {
|
|
||||||
if config.PropertyForeignKey != "" && config.PropertyTable != "" {
|
|
||||||
ec.relations[config.PropertyTable] = config
|
|
||||||
return
|
|
||||||
}
|
|
||||||
configurator := newEntityConfigurator()
|
|
||||||
property.ConfigureEntity(configurator)
|
|
||||||
|
|
||||||
if config.PropertyTable == "" {
|
|
||||||
config.PropertyTable = configurator.table
|
|
||||||
}
|
|
||||||
|
|
||||||
if config.PropertyForeignKey == "" {
|
|
||||||
config.PropertyForeignKey = helper.NewPluralizeClient().Singular(ec.table) + "_id"
|
|
||||||
}
|
|
||||||
|
|
||||||
ec.relations[configurator.table] = config
|
|
||||||
|
|
||||||
return
|
|
||||||
})
|
|
||||||
return ec
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ec *EntityConfigurator) HasOne(property Entity, config HasOneConfig) *EntityConfigurator {
|
|
||||||
if ec.relations == nil {
|
|
||||||
ec.relations = map[string]interface{}{}
|
|
||||||
}
|
|
||||||
ec.resolveRelations = append(ec.resolveRelations, func() {
|
|
||||||
if config.PropertyForeignKey != "" && config.PropertyTable != "" {
|
|
||||||
ec.relations[config.PropertyTable] = config
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
configurator := newEntityConfigurator()
|
|
||||||
property.ConfigureEntity(configurator)
|
|
||||||
|
|
||||||
if config.PropertyTable == "" {
|
|
||||||
config.PropertyTable = configurator.table
|
|
||||||
}
|
|
||||||
if config.PropertyForeignKey == "" {
|
|
||||||
config.PropertyForeignKey = helper.NewPluralizeClient().Singular(ec.table) + "_id"
|
|
||||||
}
|
|
||||||
|
|
||||||
ec.relations[configurator.table] = config
|
|
||||||
return
|
|
||||||
})
|
|
||||||
return ec
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ec *EntityConfigurator) BelongsTo(owner Entity, config BelongsToConfig) *EntityConfigurator {
|
|
||||||
if ec.relations == nil {
|
|
||||||
ec.relations = map[string]interface{}{}
|
|
||||||
}
|
|
||||||
ec.resolveRelations = append(ec.resolveRelations, func() {
|
|
||||||
if config.ForeignColumnName != "" && config.LocalForeignKey != "" && config.OwnerTable != "" {
|
|
||||||
ec.relations[config.OwnerTable] = config
|
|
||||||
return
|
|
||||||
}
|
|
||||||
ownerConfigurator := newEntityConfigurator()
|
|
||||||
owner.ConfigureEntity(ownerConfigurator)
|
|
||||||
if config.OwnerTable == "" {
|
|
||||||
config.OwnerTable = ownerConfigurator.table
|
|
||||||
}
|
|
||||||
if config.LocalForeignKey == "" {
|
|
||||||
config.LocalForeignKey = helper.NewPluralizeClient().Singular(ownerConfigurator.table) + "_id"
|
|
||||||
}
|
|
||||||
if config.ForeignColumnName == "" {
|
|
||||||
config.ForeignColumnName = "id"
|
|
||||||
}
|
|
||||||
ec.relations[ownerConfigurator.table] = config
|
|
||||||
})
|
|
||||||
return ec
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ec *EntityConfigurator) BelongsToMany(owner Entity, config BelongsToManyConfig) *EntityConfigurator {
|
|
||||||
if ec.relations == nil {
|
|
||||||
ec.relations = map[string]interface{}{}
|
|
||||||
}
|
|
||||||
ec.resolveRelations = append(ec.resolveRelations, func() {
|
|
||||||
ownerConfigurator := newEntityConfigurator()
|
|
||||||
owner.ConfigureEntity(ownerConfigurator)
|
|
||||||
|
|
||||||
if config.OwnerLookupColumn == "" {
|
|
||||||
var pkName string
|
|
||||||
for _, field := range genericFieldsOf(owner) {
|
|
||||||
if field.IsPK {
|
|
||||||
pkName = field.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
config.OwnerLookupColumn = pkName
|
|
||||||
|
|
||||||
}
|
|
||||||
if config.OwnerTable == "" {
|
|
||||||
config.OwnerTable = ownerConfigurator.table
|
|
||||||
}
|
|
||||||
if config.IntermediateTable == "" {
|
|
||||||
panic("cannot infer intermediate table yet")
|
|
||||||
}
|
|
||||||
if config.IntermediatePropertyID == "" {
|
|
||||||
config.IntermediatePropertyID = helper.NewPluralizeClient().Singular(ownerConfigurator.table) + "_id"
|
|
||||||
}
|
|
||||||
if config.IntermediateOwnerID == "" {
|
|
||||||
config.IntermediateOwnerID = helper.NewPluralizeClient().Singular(ec.table) + "_id"
|
|
||||||
}
|
|
||||||
|
|
||||||
ec.relations[ownerConfigurator.table] = config
|
|
||||||
})
|
|
||||||
return ec
|
|
||||||
}
|
|
||||||
|
|
||||||
type FieldConfigurator struct {
|
|
||||||
fieldName string
|
|
||||||
nullable sql.NullBool
|
|
||||||
primaryKey bool
|
|
||||||
column string
|
|
||||||
isCreatedAt bool
|
|
||||||
isUpdatedAt bool
|
|
||||||
isDeletedAt bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ec *EntityConfigurator) Field(name string) *FieldConfigurator {
|
|
||||||
cc := &FieldConfigurator{fieldName: name}
|
|
||||||
ec.columnConstraints = append(ec.columnConstraints, cc)
|
|
||||||
return cc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fc *FieldConfigurator) IsPrimaryKey() *FieldConfigurator {
|
|
||||||
fc.primaryKey = true
|
|
||||||
return fc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fc *FieldConfigurator) IsCreatedAt() *FieldConfigurator {
|
|
||||||
fc.isCreatedAt = true
|
|
||||||
return fc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fc *FieldConfigurator) IsUpdatedAt() *FieldConfigurator {
|
|
||||||
fc.isUpdatedAt = true
|
|
||||||
return fc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fc *FieldConfigurator) IsDeletedAt() *FieldConfigurator {
|
|
||||||
fc.isDeletedAt = true
|
|
||||||
return fc
|
|
||||||
}
|
|
||||||
|
|
||||||
func (fc *FieldConfigurator) ColumnName(name string) *FieldConfigurator {
|
|
||||||
fc.column = name
|
|
||||||
return fc
|
|
||||||
}
|
|
@ -1,160 +0,0 @@
|
|||||||
//
|
|
||||||
// connection.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
|
|
||||||
"git.hexq.cn/tiglog/golib/helper/table"
|
|
||||||
)
|
|
||||||
|
|
||||||
type connection struct {
|
|
||||||
Name string
|
|
||||||
Dialect *Dialect
|
|
||||||
DB *sql.DB
|
|
||||||
Schemas map[string]*schema
|
|
||||||
DBSchema map[string][]columnSpec
|
|
||||||
DatabaseValidations bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *connection) inferedTables() []string {
|
|
||||||
var tables []string
|
|
||||||
for t, s := range c.Schemas {
|
|
||||||
tables = append(tables, t)
|
|
||||||
for _, relC := range s.relations {
|
|
||||||
if belongsToManyConfig, is := relC.(BelongsToManyConfig); is {
|
|
||||||
tables = append(tables, belongsToManyConfig.IntermediateTable)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return tables
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *connection) validateAllTablesArePresent() error {
|
|
||||||
for _, inferedTable := range c.inferedTables() {
|
|
||||||
if _, exists := c.DBSchema[inferedTable]; !exists {
|
|
||||||
return fmt.Errorf("orm infered %s but it's not found in your database, your database is out of sync", inferedTable)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *connection) validateTablesSchemas() error {
|
|
||||||
// check for entity tables: there should not be any struct field that does not have a coresponding column
|
|
||||||
for table, sc := range c.Schemas {
|
|
||||||
if columns, exists := c.DBSchema[table]; exists {
|
|
||||||
for _, f := range sc.fields {
|
|
||||||
found := false
|
|
||||||
for _, c := range columns {
|
|
||||||
if c.Name == f.Name {
|
|
||||||
found = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
return fmt.Errorf("column %s not found while it was inferred", f.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("tables are out of sync, %s was inferred but not present in database", table)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// check for relation tables: for HasMany,HasOne relations check if OWNER pk column is in PROPERTY,
|
|
||||||
// for BelongsToMany check intermediate table has 2 pk for two entities
|
|
||||||
|
|
||||||
for table, sc := range c.Schemas {
|
|
||||||
for _, rel := range sc.relations {
|
|
||||||
switch rel.(type) {
|
|
||||||
case BelongsToConfig:
|
|
||||||
columns := c.DBSchema[table]
|
|
||||||
var found bool
|
|
||||||
for _, col := range columns {
|
|
||||||
if col.Name == rel.(BelongsToConfig).LocalForeignKey {
|
|
||||||
found = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !found {
|
|
||||||
return fmt.Errorf("cannot find local foreign key %s for relation", rel.(BelongsToConfig).LocalForeignKey)
|
|
||||||
}
|
|
||||||
case BelongsToManyConfig:
|
|
||||||
columns := c.DBSchema[rel.(BelongsToManyConfig).IntermediateTable]
|
|
||||||
var foundOwner bool
|
|
||||||
var foundProperty bool
|
|
||||||
|
|
||||||
for _, col := range columns {
|
|
||||||
if col.Name == rel.(BelongsToManyConfig).IntermediateOwnerID {
|
|
||||||
foundOwner = true
|
|
||||||
}
|
|
||||||
if col.Name == rel.(BelongsToManyConfig).IntermediatePropertyID {
|
|
||||||
foundProperty = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !foundOwner || !foundProperty {
|
|
||||||
return fmt.Errorf("table schema for %s is not correct one of foreign keys is not present", rel.(BelongsToManyConfig).IntermediateTable)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *connection) Schematic() {
|
|
||||||
fmt.Printf("SQL Dialect: %s\n", c.Dialect.DriverName)
|
|
||||||
for t, schema := range c.Schemas {
|
|
||||||
fmt.Printf("t: %s\n", t)
|
|
||||||
w := table.NewWriter()
|
|
||||||
w.AppendHeader(table.Row{"SQL Name", "Type", "Is Primary Key", "Is Virtual"})
|
|
||||||
for _, field := range schema.fields {
|
|
||||||
w.AppendRow(table.Row{field.Name, field.Type, field.IsPK, field.Virtual})
|
|
||||||
}
|
|
||||||
fmt.Println(w.Render())
|
|
||||||
for _, rel := range schema.relations {
|
|
||||||
switch rel.(type) {
|
|
||||||
case HasOneConfig:
|
|
||||||
fmt.Printf("%s 1-1 %s => %+v\n", t, rel.(HasOneConfig).PropertyTable, rel)
|
|
||||||
case HasManyConfig:
|
|
||||||
fmt.Printf("%s 1-N %s => %+v\n", t, rel.(HasManyConfig).PropertyTable, rel)
|
|
||||||
|
|
||||||
case BelongsToConfig:
|
|
||||||
fmt.Printf("%s N-1 %s => %+v\n", t, rel.(BelongsToConfig).OwnerTable, rel)
|
|
||||||
|
|
||||||
case BelongsToManyConfig:
|
|
||||||
fmt.Printf("%s N-N %s => %+v\n", t, rel.(BelongsToManyConfig).IntermediateTable, rel)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
fmt.Println("")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *connection) getSchema(t string) *schema {
|
|
||||||
return c.Schemas[t]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *connection) setSchema(e Entity, s *schema) {
|
|
||||||
var configurator EntityConfigurator
|
|
||||||
e.ConfigureEntity(&configurator)
|
|
||||||
c.Schemas[configurator.table] = s
|
|
||||||
}
|
|
||||||
|
|
||||||
func GetConnection(name string) *connection {
|
|
||||||
return globalConnections[name]
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *connection) exec(q string, args ...any) (sql.Result, error) {
|
|
||||||
return c.DB.Exec(q, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *connection) query(q string, args ...any) (*sql.Rows, error) {
|
|
||||||
return c.DB.Query(q, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *connection) queryRow(q string, args ...any) *sql.Row {
|
|
||||||
return c.DB.QueryRow(q, args...)
|
|
||||||
}
|
|
@ -1,109 +0,0 @@
|
|||||||
//
|
|
||||||
// dialect.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Dialect struct {
|
|
||||||
DriverName string
|
|
||||||
PlaceholderChar string
|
|
||||||
IncludeIndexInPlaceholder bool
|
|
||||||
AddTableNameInSelectColumns bool
|
|
||||||
PlaceHolderGenerator func(n int) []string
|
|
||||||
QueryListTables string
|
|
||||||
QueryTableSchema string
|
|
||||||
}
|
|
||||||
|
|
||||||
func getListOfTables(query string) func(db *sql.DB) ([]string, error) {
|
|
||||||
return func(db *sql.DB) ([]string, error) {
|
|
||||||
rows, err := db.Query(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var tables []string
|
|
||||||
for rows.Next() {
|
|
||||||
var table string
|
|
||||||
err = rows.Scan(&table)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
tables = append(tables, table)
|
|
||||||
}
|
|
||||||
return tables, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type columnSpec struct {
|
|
||||||
//0|id|INTEGER|0||1
|
|
||||||
Name string
|
|
||||||
Type string
|
|
||||||
Nullable bool
|
|
||||||
DefaultValue sql.NullString
|
|
||||||
IsPrimaryKey bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTableSchema(query string) func(db *sql.DB, query string) ([]columnSpec, error) {
|
|
||||||
return func(db *sql.DB, table string) ([]columnSpec, error) {
|
|
||||||
rows, err := db.Query(fmt.Sprintf(query, table))
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var output []columnSpec
|
|
||||||
for rows.Next() {
|
|
||||||
var cs columnSpec
|
|
||||||
var nullable string
|
|
||||||
var pk int
|
|
||||||
err = rows.Scan(&cs.Name, &cs.Type, &nullable, &cs.DefaultValue, &pk)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
cs.Nullable = nullable == "notnull"
|
|
||||||
cs.IsPrimaryKey = pk == 1
|
|
||||||
output = append(output, cs)
|
|
||||||
}
|
|
||||||
return output, nil
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var Dialects = &struct {
|
|
||||||
MySQL *Dialect
|
|
||||||
PostgreSQL *Dialect
|
|
||||||
SQLite3 *Dialect
|
|
||||||
}{
|
|
||||||
MySQL: &Dialect{
|
|
||||||
DriverName: "mysql",
|
|
||||||
PlaceholderChar: "?",
|
|
||||||
IncludeIndexInPlaceholder: false,
|
|
||||||
AddTableNameInSelectColumns: true,
|
|
||||||
PlaceHolderGenerator: questionMarks,
|
|
||||||
QueryListTables: "SHOW TABLES",
|
|
||||||
QueryTableSchema: "DESCRIBE %s",
|
|
||||||
},
|
|
||||||
PostgreSQL: &Dialect{
|
|
||||||
DriverName: "postgres",
|
|
||||||
PlaceholderChar: "$",
|
|
||||||
IncludeIndexInPlaceholder: true,
|
|
||||||
AddTableNameInSelectColumns: true,
|
|
||||||
PlaceHolderGenerator: postgresPlaceholder,
|
|
||||||
QueryListTables: `\dt`,
|
|
||||||
QueryTableSchema: `\d %s`,
|
|
||||||
},
|
|
||||||
SQLite3: &Dialect{
|
|
||||||
DriverName: "sqlite3",
|
|
||||||
PlaceholderChar: "?",
|
|
||||||
IncludeIndexInPlaceholder: false,
|
|
||||||
AddTableNameInSelectColumns: false,
|
|
||||||
PlaceHolderGenerator: questionMarks,
|
|
||||||
QueryListTables: "SELECT name FROM sqlite_schema WHERE type='table'",
|
|
||||||
QueryTableSchema: `SELECT name,type,"notnull","dflt_value","pk" FROM PRAGMA_TABLE_INFO('%s')`,
|
|
||||||
},
|
|
||||||
}
|
|
@ -1,75 +0,0 @@
|
|||||||
//
|
|
||||||
// field.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql/driver"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"git.hexq.cn/tiglog/golib/helper"
|
|
||||||
)
|
|
||||||
|
|
||||||
type field struct {
|
|
||||||
Name string
|
|
||||||
IsPK bool
|
|
||||||
Virtual bool
|
|
||||||
IsCreatedAt bool
|
|
||||||
IsUpdatedAt bool
|
|
||||||
IsDeletedAt bool
|
|
||||||
Nullable bool
|
|
||||||
Default any
|
|
||||||
Type reflect.Type
|
|
||||||
}
|
|
||||||
|
|
||||||
func getFieldConfiguratorFor(fieldConfigurators []*FieldConfigurator, name string) *FieldConfigurator {
|
|
||||||
for _, fc := range fieldConfigurators {
|
|
||||||
if fc.fieldName == name {
|
|
||||||
return fc
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return &FieldConfigurator{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func fieldMetadata(ft reflect.StructField, fieldConfigurators []*FieldConfigurator) []*field {
|
|
||||||
var fms []*field
|
|
||||||
fc := getFieldConfiguratorFor(fieldConfigurators, ft.Name)
|
|
||||||
baseFm := &field{}
|
|
||||||
baseFm.Type = ft.Type
|
|
||||||
fms = append(fms, baseFm)
|
|
||||||
if fc.column != "" {
|
|
||||||
baseFm.Name = fc.column
|
|
||||||
} else {
|
|
||||||
baseFm.Name = helper.SnakeString(ft.Name)
|
|
||||||
}
|
|
||||||
if strings.ToLower(ft.Name) == "id" || fc.primaryKey {
|
|
||||||
baseFm.IsPK = true
|
|
||||||
}
|
|
||||||
if strings.ToLower(ft.Name) == "createdat" || fc.isCreatedAt {
|
|
||||||
baseFm.IsCreatedAt = true
|
|
||||||
}
|
|
||||||
if strings.ToLower(ft.Name) == "updatedat" || fc.isUpdatedAt {
|
|
||||||
baseFm.IsUpdatedAt = true
|
|
||||||
}
|
|
||||||
if strings.ToLower(ft.Name) == "deletedat" || fc.isDeletedAt {
|
|
||||||
baseFm.IsDeletedAt = true
|
|
||||||
}
|
|
||||||
if ft.Type.Kind() == reflect.Struct || ft.Type.Kind() == reflect.Ptr {
|
|
||||||
t := ft.Type
|
|
||||||
if ft.Type.Kind() == reflect.Ptr {
|
|
||||||
t = ft.Type.Elem()
|
|
||||||
}
|
|
||||||
if !t.Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
|
|
||||||
for i := 0; i < t.NumField(); i++ {
|
|
||||||
fms = append(fms, fieldMetadata(t.Field(i), fieldConfigurators)...)
|
|
||||||
}
|
|
||||||
fms = fms[1:]
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fms
|
|
||||||
}
|
|
699
gdb/orm/orm.go
699
gdb/orm/orm.go
@ -1,699 +0,0 @@
|
|||||||
//
|
|
||||||
// orm.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
// Drivers
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
|
||||||
_ "github.com/lib/pq"
|
|
||||||
_ "github.com/mattn/go-sqlite3"
|
|
||||||
)
|
|
||||||
|
|
||||||
var globalConnections = map[string]*connection{}
|
|
||||||
|
|
||||||
// Schematic prints all information ORM inferred from your entities in startup, remember to pass
|
|
||||||
// your entities in Entities when you call SetupConnections if you want their data inferred
|
|
||||||
// otherwise Schematic does not print correct data since GoLobby ORM also
|
|
||||||
// incrementally cache your entities metadata and schema.
|
|
||||||
func Schematic() {
|
|
||||||
for conn, connObj := range globalConnections {
|
|
||||||
fmt.Printf("----------------%s---------------\n", conn)
|
|
||||||
connObj.Schematic()
|
|
||||||
fmt.Println("-----------------------------------")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type ConnectionConfig struct {
|
|
||||||
// Name of your database connection, it's up to you to name them anything
|
|
||||||
// just remember that having a connection name is mandatory if
|
|
||||||
// you have multiple connections
|
|
||||||
Name string
|
|
||||||
|
|
||||||
// 不是必需,不指定 Dialect 时必需
|
|
||||||
Driver string
|
|
||||||
// If you already have an active database connection configured pass it in this value and
|
|
||||||
// do not pass Driver and DSN fields.
|
|
||||||
DB *sql.DB
|
|
||||||
// Which dialect of sql to generate queries for, you don't need it most of the times when you are using
|
|
||||||
// traditional databases such as mysql, sqlite3, postgres.
|
|
||||||
Dialect *Dialect
|
|
||||||
// List of entities that you want to use for this connection, remember that you can ignore this field
|
|
||||||
// and GoLobby ORM will build our metadata cache incrementally but you will lose schematic
|
|
||||||
// information that we can provide you and also potentialy validations that we
|
|
||||||
// can do with the database
|
|
||||||
Entities []Entity
|
|
||||||
// Database validations, check if all tables exists and also table schemas contains all necessary columns.
|
|
||||||
// Check if all infered tables exist in your database
|
|
||||||
DatabaseValidations bool
|
|
||||||
}
|
|
||||||
|
|
||||||
type DbOption struct {
|
|
||||||
Type string
|
|
||||||
Dsn string
|
|
||||||
MaxIdle int
|
|
||||||
MaxOpen int
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewDb(opt DbOption) (*sql.DB, error) {
|
|
||||||
db, err := sql.Open(opt.Type, opt.Dsn)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
if opt.Type == "sqlite3" {
|
|
||||||
if opt.Dsn == ":memory:" {
|
|
||||||
db.SetMaxOpenConns(1)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
db.SetMaxIdleConns(opt.MaxIdle)
|
|
||||||
db.SetMaxOpenConns(opt.MaxOpen)
|
|
||||||
}
|
|
||||||
return db, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getDialectByDriver(driver string) (*Dialect, error) {
|
|
||||||
switch driver {
|
|
||||||
case "postgres":
|
|
||||||
return Dialects.PostgreSQL, nil
|
|
||||||
case "mysql":
|
|
||||||
return Dialects.MySQL, nil
|
|
||||||
case "sqlite3":
|
|
||||||
return Dialects.SQLite3, nil
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("unsupported db driver %s", driver)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetupConnections declares a new connections for ORM.
|
|
||||||
func SetupConnections(configs ...ConnectionConfig) error {
|
|
||||||
|
|
||||||
for _, c := range configs {
|
|
||||||
if err := setupConnection(c); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, conn := range globalConnections {
|
|
||||||
if !conn.DatabaseValidations {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
tables, err := getListOfTables(conn.Dialect.QueryListTables)(conn.DB)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, table := range tables {
|
|
||||||
if conn.DatabaseValidations {
|
|
||||||
spec, err := getTableSchema(conn.Dialect.QueryTableSchema)(conn.DB, table)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
conn.DBSchema[table] = spec
|
|
||||||
} else {
|
|
||||||
conn.DBSchema[table] = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// check tables existence
|
|
||||||
if conn.DatabaseValidations {
|
|
||||||
err := conn.validateAllTablesArePresent()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if conn.DatabaseValidations {
|
|
||||||
err = conn.validateTablesSchemas()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func setupConnection(config ConnectionConfig) error {
|
|
||||||
schemas := map[string]*schema{}
|
|
||||||
if config.Name == "" {
|
|
||||||
config.Name = "default"
|
|
||||||
}
|
|
||||||
if config.Dialect == nil {
|
|
||||||
dialect, err := getDialectByDriver(config.Driver)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
config.Dialect = dialect
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, entity := range config.Entities {
|
|
||||||
s := schemaOfHeavyReflectionStuff(entity)
|
|
||||||
var configurator EntityConfigurator
|
|
||||||
entity.ConfigureEntity(&configurator)
|
|
||||||
schemas[configurator.table] = s
|
|
||||||
}
|
|
||||||
|
|
||||||
s := &connection{
|
|
||||||
Name: config.Name,
|
|
||||||
DB: config.DB,
|
|
||||||
Dialect: config.Dialect,
|
|
||||||
Schemas: schemas,
|
|
||||||
DBSchema: make(map[string][]columnSpec),
|
|
||||||
DatabaseValidations: config.DatabaseValidations,
|
|
||||||
}
|
|
||||||
|
|
||||||
globalConnections[fmt.Sprintf("%s", config.Name)] = s
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Entity defines the interface that each of your structs that
|
|
||||||
// you want to use as database entities should have,
|
|
||||||
// it's a simple one and its ConfigureEntity.
|
|
||||||
type Entity interface {
|
|
||||||
// ConfigureEntity should be defined for all of your database entities
|
|
||||||
// and it can define Table, DB and also relations of your Entity.
|
|
||||||
ConfigureEntity(e *EntityConfigurator)
|
|
||||||
}
|
|
||||||
|
|
||||||
// InsertAll given entities into database based on their ConfigureEntity
|
|
||||||
// we can find table and also DB name.
|
|
||||||
func InsertAll(objs ...Entity) error {
|
|
||||||
if len(objs) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
s := getSchemaFor(objs[0])
|
|
||||||
cols := s.Columns(false)
|
|
||||||
var values [][]interface{}
|
|
||||||
for _, obj := range objs {
|
|
||||||
createdAtF := s.createdAt()
|
|
||||||
if createdAtF != nil {
|
|
||||||
genericSet(obj, createdAtF.Name, sql.NullTime{Time: time.Now(), Valid: true})
|
|
||||||
}
|
|
||||||
updatedAtF := s.updatedAt()
|
|
||||||
if updatedAtF != nil {
|
|
||||||
genericSet(obj, updatedAtF.Name, sql.NullTime{Time: time.Now(), Valid: true})
|
|
||||||
}
|
|
||||||
values = append(values, genericValuesOf(obj, false))
|
|
||||||
}
|
|
||||||
|
|
||||||
is := insertStmt{
|
|
||||||
PlaceHolderGenerator: s.getDialect().PlaceHolderGenerator,
|
|
||||||
Table: s.getTable(),
|
|
||||||
Columns: cols,
|
|
||||||
Values: values,
|
|
||||||
}
|
|
||||||
|
|
||||||
q, args := is.ToSql()
|
|
||||||
|
|
||||||
_, err := s.getConnection().exec(q, args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert given entity into database based on their ConfigureEntity
|
|
||||||
// we can find table and also DB name.
|
|
||||||
func Insert(o Entity) error {
|
|
||||||
s := getSchemaFor(o)
|
|
||||||
cols := s.Columns(false)
|
|
||||||
var values [][]interface{}
|
|
||||||
createdAtF := s.createdAt()
|
|
||||||
if createdAtF != nil {
|
|
||||||
genericSet(o, createdAtF.Name, sql.NullTime{Time: time.Now(), Valid: true})
|
|
||||||
}
|
|
||||||
updatedAtF := s.updatedAt()
|
|
||||||
if updatedAtF != nil {
|
|
||||||
genericSet(o, updatedAtF.Name, sql.NullTime{Time: time.Now(), Valid: true})
|
|
||||||
}
|
|
||||||
values = append(values, genericValuesOf(o, false))
|
|
||||||
|
|
||||||
is := insertStmt{
|
|
||||||
PlaceHolderGenerator: s.getDialect().PlaceHolderGenerator,
|
|
||||||
Table: s.getTable(),
|
|
||||||
Columns: cols,
|
|
||||||
Values: values,
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.getDialect().DriverName == "postgres" {
|
|
||||||
is.Returning = s.pkName()
|
|
||||||
}
|
|
||||||
q, args := is.ToSql()
|
|
||||||
|
|
||||||
res, err := s.getConnection().exec(q, args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
id, err := res.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if s.pkName() != "" {
|
|
||||||
// intermediate tables usually have no single pk column.
|
|
||||||
s.setPK(o, id)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func isZero(val interface{}) bool {
|
|
||||||
switch val.(type) {
|
|
||||||
case int64:
|
|
||||||
return val.(int64) == 0
|
|
||||||
case int:
|
|
||||||
return val.(int) == 0
|
|
||||||
case string:
|
|
||||||
return val.(string) == ""
|
|
||||||
default:
|
|
||||||
return reflect.ValueOf(val).Elem().IsZero()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Save saves given entity, if primary key is set
|
|
||||||
// we will make an update query and if
|
|
||||||
// primary key is zero value we will
|
|
||||||
// insert it.
|
|
||||||
func Save(obj Entity) error {
|
|
||||||
if isZero(getSchemaFor(obj).getPK(obj)) {
|
|
||||||
return Insert(obj)
|
|
||||||
} else {
|
|
||||||
return Update(obj)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Find finds the Entity you want based on generic type and primary key you passed.
|
|
||||||
func Find[T Entity](id interface{}) (T, error) {
|
|
||||||
var q string
|
|
||||||
out := new(T)
|
|
||||||
md := getSchemaFor(*out)
|
|
||||||
q, args, err := NewQueryBuilder[T](md).
|
|
||||||
SetDialect(md.getDialect()).
|
|
||||||
Table(md.Table).
|
|
||||||
Select(md.Columns(true)...).
|
|
||||||
Where(md.pkName(), id).
|
|
||||||
ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return *out, err
|
|
||||||
}
|
|
||||||
err = bind[T](out, q, args)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return *out, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return *out, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func toKeyValues(obj Entity, withPK bool) []any {
|
|
||||||
var tuples []any
|
|
||||||
vs := genericValuesOf(obj, withPK)
|
|
||||||
cols := getSchemaFor(obj).Columns(withPK)
|
|
||||||
for i, col := range cols {
|
|
||||||
tuples = append(tuples, col, vs[i])
|
|
||||||
}
|
|
||||||
return tuples
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update given Entity in database.
|
|
||||||
func Update(obj Entity) error {
|
|
||||||
s := getSchemaFor(obj)
|
|
||||||
q, args, err := NewQueryBuilder[Entity](s).
|
|
||||||
SetDialect(s.getDialect()).
|
|
||||||
Set(toKeyValues(obj, false)...).
|
|
||||||
Where(s.pkName(), genericGetPKValue(obj)).Table(s.Table).ToSql()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = s.getConnection().exec(q, args...)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete given Entity from database
|
|
||||||
func Delete(obj Entity) error {
|
|
||||||
s := getSchemaFor(obj)
|
|
||||||
genericSet(obj, "deleted_at", sql.NullTime{Time: time.Now(), Valid: true})
|
|
||||||
query, args, err := NewQueryBuilder[Entity](s).SetDialect(s.getDialect()).Table(s.Table).Where(s.pkName(), genericGetPKValue(obj)).SetDelete().ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = s.getConnection().exec(query, args...)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func bind[T Entity](output interface{}, q string, args []interface{}) error {
|
|
||||||
outputMD := getSchemaFor(*new(T))
|
|
||||||
rows, err := outputMD.getConnection().query(q, args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return newBinder(outputMD).bind(rows, output)
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasManyConfig contains all information we need for querying HasMany relationships.
|
|
||||||
// We can infer both fields if you have them in standard way but you
|
|
||||||
// can specify them if you want custom ones.
|
|
||||||
type HasManyConfig struct {
|
|
||||||
// PropertyTable is table of the property of HasMany relationship,
|
|
||||||
// consider `Comment` in Post and Comment relationship,
|
|
||||||
// each Post HasMany Comment, so PropertyTable is
|
|
||||||
// `comments`.
|
|
||||||
PropertyTable string
|
|
||||||
// PropertyForeignKey is the foreign key field name in the property table,
|
|
||||||
// for example in Post HasMany Comment, if comment has `post_id` field,
|
|
||||||
// it's the PropertyForeignKey field.
|
|
||||||
PropertyForeignKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasMany configures a QueryBuilder for a HasMany relationship
|
|
||||||
// this relationship will be defined for owner argument
|
|
||||||
// that has many of PROPERTY generic type for example
|
|
||||||
// HasMany[Comment](&Post{})
|
|
||||||
// is for Post HasMany Comment relationship.
|
|
||||||
func HasMany[PROPERTY Entity](owner Entity) *QueryBuilder[PROPERTY] {
|
|
||||||
outSchema := getSchemaFor(*new(PROPERTY))
|
|
||||||
|
|
||||||
q := NewQueryBuilder[PROPERTY](outSchema)
|
|
||||||
// getting config from our cache
|
|
||||||
c, ok := getSchemaFor(owner).relations[outSchema.Table].(HasManyConfig)
|
|
||||||
if !ok {
|
|
||||||
q.err = fmt.Errorf("wrong config passed for HasMany")
|
|
||||||
}
|
|
||||||
|
|
||||||
s := getSchemaFor(owner)
|
|
||||||
return q.
|
|
||||||
SetDialect(s.getDialect()).
|
|
||||||
Table(c.PropertyTable).
|
|
||||||
Select(outSchema.Columns(true)...).
|
|
||||||
Where(c.PropertyForeignKey, genericGetPKValue(owner))
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasOneConfig contains all information we need for a HasOne relationship,
|
|
||||||
// it's similar to HasManyConfig.
|
|
||||||
type HasOneConfig struct {
|
|
||||||
// PropertyTable is table of the property of HasOne relationship,
|
|
||||||
// consider `HeaderPicture` in Post and HeaderPicture relationship,
|
|
||||||
// each Post HasOne HeaderPicture, so PropertyTable is
|
|
||||||
// `header_pictures`.
|
|
||||||
PropertyTable string
|
|
||||||
// PropertyForeignKey is the foreign key field name in the property table,
|
|
||||||
// forexample in Post HasOne HeaderPicture, if header_picture has `post_id` field,
|
|
||||||
// it's the PropertyForeignKey field.
|
|
||||||
PropertyForeignKey string
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasOne configures a QueryBuilder for a HasOne relationship
|
|
||||||
// this relationship will be defined for owner argument
|
|
||||||
// that has one of PROPERTY generic type for example
|
|
||||||
// HasOne[HeaderPicture](&Post{})
|
|
||||||
// is for Post HasOne HeaderPicture relationship.
|
|
||||||
func HasOne[PROPERTY Entity](owner Entity) *QueryBuilder[PROPERTY] {
|
|
||||||
property := getSchemaFor(*new(PROPERTY))
|
|
||||||
q := NewQueryBuilder[PROPERTY](property)
|
|
||||||
c, ok := getSchemaFor(owner).relations[property.Table].(HasOneConfig)
|
|
||||||
if !ok {
|
|
||||||
q.err = fmt.Errorf("wrong config passed for HasOne")
|
|
||||||
}
|
|
||||||
|
|
||||||
// settings default config Values
|
|
||||||
return q.
|
|
||||||
SetDialect(property.getDialect()).
|
|
||||||
Table(c.PropertyTable).
|
|
||||||
Select(property.Columns(true)...).
|
|
||||||
Where(c.PropertyForeignKey, genericGetPKValue(owner))
|
|
||||||
}
|
|
||||||
|
|
||||||
// BelongsToConfig contains all information we need for a BelongsTo relationship
|
|
||||||
// BelongsTo is a relationship between a Comment and it's Post,
|
|
||||||
// A Comment BelongsTo Post.
|
|
||||||
type BelongsToConfig struct {
|
|
||||||
// OwnerTable is the table that contains owner of a BelongsTo
|
|
||||||
// relationship.
|
|
||||||
OwnerTable string
|
|
||||||
// LocalForeignKey is name of the field that links property
|
|
||||||
// to its owner in BelongsTo relation. for example when
|
|
||||||
// a Comment BelongsTo Post, LocalForeignKey is
|
|
||||||
// post_id of Comment.
|
|
||||||
LocalForeignKey string
|
|
||||||
// ForeignColumnName is name of the field that LocalForeignKey
|
|
||||||
// field value will point to it, for example when
|
|
||||||
// a Comment BelongsTo Post, ForeignColumnName is
|
|
||||||
// id of Post.
|
|
||||||
ForeignColumnName string
|
|
||||||
}
|
|
||||||
|
|
||||||
// BelongsTo configures a QueryBuilder for a BelongsTo relationship between
|
|
||||||
// OWNER type parameter and property argument, so
|
|
||||||
// property BelongsTo OWNER.
|
|
||||||
func BelongsTo[OWNER Entity](property Entity) *QueryBuilder[OWNER] {
|
|
||||||
owner := getSchemaFor(*new(OWNER))
|
|
||||||
q := NewQueryBuilder[OWNER](owner)
|
|
||||||
c, ok := getSchemaFor(property).relations[owner.Table].(BelongsToConfig)
|
|
||||||
if !ok {
|
|
||||||
q.err = fmt.Errorf("wrong config passed for BelongsTo")
|
|
||||||
}
|
|
||||||
|
|
||||||
ownerIDidx := 0
|
|
||||||
for idx, field := range owner.fields {
|
|
||||||
if field.Name == c.LocalForeignKey {
|
|
||||||
ownerIDidx = idx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ownerID := genericValuesOf(property, true)[ownerIDidx]
|
|
||||||
|
|
||||||
return q.
|
|
||||||
SetDialect(owner.getDialect()).
|
|
||||||
Table(c.OwnerTable).Select(owner.Columns(true)...).
|
|
||||||
Where(c.ForeignColumnName, ownerID)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// BelongsToManyConfig contains information that we
|
|
||||||
// need for creating many to many queries.
|
|
||||||
type BelongsToManyConfig struct {
|
|
||||||
// IntermediateTable is the name of the middle table
|
|
||||||
// in a BelongsToMany (Many to Many) relationship.
|
|
||||||
// for example when we have Post BelongsToMany
|
|
||||||
// Category, this table will be post_categories
|
|
||||||
// table, remember that this field cannot be
|
|
||||||
// inferred.
|
|
||||||
IntermediateTable string
|
|
||||||
// IntermediatePropertyID is the name of the field name
|
|
||||||
// of property foreign key in intermediate table,
|
|
||||||
// for example when we have Post BelongsToMany
|
|
||||||
// Category, in post_categories table, it would
|
|
||||||
// be post_id.
|
|
||||||
IntermediatePropertyID string
|
|
||||||
// IntermediateOwnerID is the name of the field name
|
|
||||||
// of property foreign key in intermediate table,
|
|
||||||
// for example when we have Post BelongsToMany
|
|
||||||
// Category, in post_categories table, it would
|
|
||||||
// be category_id.
|
|
||||||
IntermediateOwnerID string
|
|
||||||
// Table name of the owner in BelongsToMany relation,
|
|
||||||
// for example in Post BelongsToMany Category
|
|
||||||
// Owner table is name of Category table
|
|
||||||
// for example `categories`.
|
|
||||||
OwnerTable string
|
|
||||||
// OwnerLookupColumn is name of the field in the owner
|
|
||||||
// table that is used in query, for example in Post BelongsToMany Category
|
|
||||||
// Owner lookup field would be Category primary key which is id.
|
|
||||||
OwnerLookupColumn string
|
|
||||||
}
|
|
||||||
|
|
||||||
// BelongsToMany configures a QueryBuilder for a BelongsToMany relationship
|
|
||||||
func BelongsToMany[OWNER Entity](property Entity) *QueryBuilder[OWNER] {
|
|
||||||
out := *new(OWNER)
|
|
||||||
outSchema := getSchemaFor(out)
|
|
||||||
q := NewQueryBuilder[OWNER](outSchema)
|
|
||||||
c, ok := getSchemaFor(property).relations[outSchema.Table].(BelongsToManyConfig)
|
|
||||||
if !ok {
|
|
||||||
q.err = fmt.Errorf("wrong config passed for HasMany")
|
|
||||||
}
|
|
||||||
return q.
|
|
||||||
Select(outSchema.Columns(true)...).
|
|
||||||
Table(outSchema.Table).
|
|
||||||
WhereIn(c.OwnerLookupColumn, Raw(fmt.Sprintf(`SELECT %s FROM %s WHERE %s = ?`,
|
|
||||||
c.IntermediatePropertyID,
|
|
||||||
c.IntermediateTable, c.IntermediateOwnerID), genericGetPKValue(property)))
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add adds `items` to `to` using relations defined between items and to in ConfigureEntity method of `to`.
|
|
||||||
func Add(to Entity, items ...Entity) error {
|
|
||||||
if len(items) == 0 {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
rels := getSchemaFor(to).relations
|
|
||||||
tname := getSchemaFor(items[0]).Table
|
|
||||||
c, ok := rels[tname]
|
|
||||||
if !ok {
|
|
||||||
return fmt.Errorf("no config found for given to and item...")
|
|
||||||
}
|
|
||||||
switch c.(type) {
|
|
||||||
case HasManyConfig:
|
|
||||||
return addProperty(to, items...)
|
|
||||||
case HasOneConfig:
|
|
||||||
return addProperty(to, items[0])
|
|
||||||
case BelongsToManyConfig:
|
|
||||||
return addM2M(to, items...)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("cannot add for relation: %T", rels[getSchemaFor(items[0]).Table])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func addM2M(to Entity, items ...Entity) error {
|
|
||||||
//TODO: Optimize this
|
|
||||||
rels := getSchemaFor(to).relations
|
|
||||||
tname := getSchemaFor(items[0]).Table
|
|
||||||
c := rels[tname].(BelongsToManyConfig)
|
|
||||||
var values [][]interface{}
|
|
||||||
ownerPk := genericGetPKValue(to)
|
|
||||||
for _, item := range items {
|
|
||||||
pk := genericGetPKValue(item)
|
|
||||||
if isZero(pk) {
|
|
||||||
err := Insert(item)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
pk = genericGetPKValue(item)
|
|
||||||
}
|
|
||||||
values = append(values, []interface{}{ownerPk, pk})
|
|
||||||
}
|
|
||||||
i := insertStmt{
|
|
||||||
PlaceHolderGenerator: getSchemaFor(to).getDialect().PlaceHolderGenerator,
|
|
||||||
Table: c.IntermediateTable,
|
|
||||||
Columns: []string{c.IntermediateOwnerID, c.IntermediatePropertyID},
|
|
||||||
Values: values,
|
|
||||||
}
|
|
||||||
|
|
||||||
q, args := i.ToSql()
|
|
||||||
|
|
||||||
_, err := getConnectionFor(items[0]).DB.Exec(q, args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// addHasMany(Post, comments)
|
|
||||||
func addProperty(to Entity, items ...Entity) error {
|
|
||||||
var lastTable string
|
|
||||||
for _, obj := range items {
|
|
||||||
s := getSchemaFor(obj)
|
|
||||||
if lastTable == "" {
|
|
||||||
lastTable = s.Table
|
|
||||||
} else {
|
|
||||||
if lastTable != s.Table {
|
|
||||||
return fmt.Errorf("cannot batch insert for two different tables: %s and %s", s.Table, lastTable)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
i := insertStmt{
|
|
||||||
PlaceHolderGenerator: getSchemaFor(to).getDialect().PlaceHolderGenerator,
|
|
||||||
Table: getSchemaFor(items[0]).getTable(),
|
|
||||||
}
|
|
||||||
ownerPKIdx := -1
|
|
||||||
ownerPKName := getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey
|
|
||||||
for idx, col := range getSchemaFor(items[0]).Columns(false) {
|
|
||||||
if col == ownerPKName {
|
|
||||||
ownerPKIdx = idx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ownerPK := genericGetPKValue(to)
|
|
||||||
if ownerPKIdx != -1 {
|
|
||||||
cols := getSchemaFor(items[0]).Columns(false)
|
|
||||||
i.Columns = append(i.Columns, cols...)
|
|
||||||
// Owner PK is present in the items struct
|
|
||||||
for _, item := range items {
|
|
||||||
vals := genericValuesOf(item, false)
|
|
||||||
if cols[ownerPKIdx] != getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey {
|
|
||||||
return fmt.Errorf("owner pk idx is not correct")
|
|
||||||
}
|
|
||||||
vals[ownerPKIdx] = ownerPK
|
|
||||||
i.Values = append(i.Values, vals)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
ownerPKIdx = 0
|
|
||||||
cols := getSchemaFor(items[0]).Columns(false)
|
|
||||||
cols = append(cols[:ownerPKIdx+1], cols[ownerPKIdx:]...)
|
|
||||||
cols[ownerPKIdx] = getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey
|
|
||||||
i.Columns = append(i.Columns, cols...)
|
|
||||||
for _, item := range items {
|
|
||||||
vals := genericValuesOf(item, false)
|
|
||||||
if cols[ownerPKIdx] != getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey {
|
|
||||||
return fmt.Errorf("owner pk idx is not correct")
|
|
||||||
}
|
|
||||||
vals = append(vals[:ownerPKIdx+1], vals[ownerPKIdx:]...)
|
|
||||||
vals[ownerPKIdx] = ownerPK
|
|
||||||
i.Values = append(i.Values, vals)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
q, args := i.ToSql()
|
|
||||||
|
|
||||||
_, err := getConnectionFor(items[0]).DB.Exec(q, args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Query creates a new QueryBuilder for given type parameter, sets dialect and table as well.
|
|
||||||
func Query[E Entity]() *QueryBuilder[E] {
|
|
||||||
s := getSchemaFor(*new(E))
|
|
||||||
q := NewQueryBuilder[E](s)
|
|
||||||
q.SetDialect(s.getDialect()).Table(s.Table)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// ExecRaw executes given query string and arguments on given type parameter database connection.
|
|
||||||
func ExecRaw[E Entity](q string, args ...interface{}) (int64, int64, error) {
|
|
||||||
e := new(E)
|
|
||||||
|
|
||||||
res, err := getSchemaFor(*e).getSQLDB().Exec(q, args...)
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
id, err := res.LastInsertId()
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
affected, err := res.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return 0, 0, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return id, affected, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// QueryRaw queries given query string and arguments on given type parameter database connection.
|
|
||||||
func QueryRaw[OUTPUT Entity](q string, args ...interface{}) ([]OUTPUT, error) {
|
|
||||||
o := new(OUTPUT)
|
|
||||||
rows, err := getSchemaFor(*o).getSQLDB().Query(q, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var output []OUTPUT
|
|
||||||
err = newBinder(getSchemaFor(*o)).bind(rows, &output)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return output, nil
|
|
||||||
}
|
|
@ -1,588 +0,0 @@
|
|||||||
//
|
|
||||||
// orm_test.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"git.hexq.cn/tiglog/golib/gdb/orm"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
type AuthorEmail struct {
|
|
||||||
ID int64
|
|
||||||
Email string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (a AuthorEmail) ConfigureEntity(e *orm.EntityConfigurator) {
|
|
||||||
e.
|
|
||||||
Table("emails").
|
|
||||||
Connection("default").
|
|
||||||
BelongsTo(&Post{}, orm.BelongsToConfig{})
|
|
||||||
}
|
|
||||||
|
|
||||||
type HeaderPicture struct {
|
|
||||||
ID int64
|
|
||||||
PostID int64
|
|
||||||
Link string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h HeaderPicture) ConfigureEntity(e *orm.EntityConfigurator) {
|
|
||||||
e.Table("header_pictures").BelongsTo(&Post{}, orm.BelongsToConfig{})
|
|
||||||
}
|
|
||||||
|
|
||||||
type Post struct {
|
|
||||||
ID int64
|
|
||||||
BodyText string
|
|
||||||
CreatedAt sql.NullTime
|
|
||||||
UpdatedAt sql.NullTime
|
|
||||||
DeletedAt sql.NullTime
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p Post) ConfigureEntity(e *orm.EntityConfigurator) {
|
|
||||||
e.Field("BodyText").ColumnName("body")
|
|
||||||
e.Field("ID").ColumnName("id")
|
|
||||||
e.
|
|
||||||
Table("posts").
|
|
||||||
HasMany(Comment{}, orm.HasManyConfig{}).
|
|
||||||
HasOne(HeaderPicture{}, orm.HasOneConfig{}).
|
|
||||||
HasOne(AuthorEmail{}, orm.HasOneConfig{}).
|
|
||||||
BelongsToMany(Category{}, orm.BelongsToManyConfig{IntermediateTable: "post_categories"})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Post) Categories() ([]Category, error) {
|
|
||||||
return orm.BelongsToMany[Category](p).All()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (p *Post) Comments() *orm.QueryBuilder[Comment] {
|
|
||||||
return orm.HasMany[Comment](p)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Comment struct {
|
|
||||||
ID int64
|
|
||||||
PostID int64
|
|
||||||
Body string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Comment) ConfigureEntity(e *orm.EntityConfigurator) {
|
|
||||||
e.Table("comments").BelongsTo(&Post{}, orm.BelongsToConfig{})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Comment) Post() (Post, error) {
|
|
||||||
return orm.BelongsTo[Post](c).Get()
|
|
||||||
}
|
|
||||||
|
|
||||||
type Category struct {
|
|
||||||
ID int64
|
|
||||||
Title string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Category) ConfigureEntity(e *orm.EntityConfigurator) {
|
|
||||||
e.Table("categories").BelongsToMany(Post{}, orm.BelongsToManyConfig{IntermediateTable: "post_categories"})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c Category) Posts() ([]Post, error) {
|
|
||||||
return orm.BelongsToMany[Post](c).All()
|
|
||||||
}
|
|
||||||
|
|
||||||
// enough models let's test
|
|
||||||
// Entities is mandatory
|
|
||||||
// Errors should be carried
|
|
||||||
|
|
||||||
func setup() error {
|
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return orm.SetupConnections(orm.ConnectionConfig{
|
|
||||||
Name: "default",
|
|
||||||
DB: db,
|
|
||||||
Dialect: orm.Dialects.SQLite3,
|
|
||||||
Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}},
|
|
||||||
DatabaseValidations: true,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestFind(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
err = orm.InsertAll(&Post{
|
|
||||||
BodyText: "my body for insert",
|
|
||||||
})
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
post, err := orm.Find[Post](1)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "my body for insert", post.BodyText)
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInsert(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "my body for insert",
|
|
||||||
}
|
|
||||||
err = orm.Insert(post)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
var p Post
|
|
||||||
assert.NoError(t,
|
|
||||||
orm.GetConnection("default").DB.QueryRow(`SELECT id, body FROM posts where id = ?`, 1).Scan(&p.ID, &p.BodyText))
|
|
||||||
|
|
||||||
assert.Equal(t, "my body for insert", p.BodyText)
|
|
||||||
}
|
|
||||||
func TestInsertAll(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
post1 := &Post{
|
|
||||||
BodyText: "Body1",
|
|
||||||
}
|
|
||||||
post2 := &Post{
|
|
||||||
BodyText: "Body2",
|
|
||||||
}
|
|
||||||
|
|
||||||
post3 := &Post{
|
|
||||||
BodyText: "Body3",
|
|
||||||
}
|
|
||||||
|
|
||||||
err = orm.InsertAll(post1, post2, post3)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
var counter int
|
|
||||||
assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT count(id) FROM posts`).Scan(&counter))
|
|
||||||
assert.Equal(t, 3, counter)
|
|
||||||
|
|
||||||
}
|
|
||||||
func TestUpdateORM(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "my body for insert",
|
|
||||||
}
|
|
||||||
err = orm.Insert(post)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
|
|
||||||
post.BodyText += " update text"
|
|
||||||
assert.NoError(t, orm.Update(post))
|
|
||||||
|
|
||||||
var body string
|
|
||||||
assert.NoError(t,
|
|
||||||
orm.GetConnection("default").DB.QueryRow(`SELECT body FROM posts where id = ?`, post.ID).Scan(&body))
|
|
||||||
|
|
||||||
assert.Equal(t, "my body for insert update text", body)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDeleteORM(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "my body for insert",
|
|
||||||
}
|
|
||||||
err = orm.Insert(post)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
|
|
||||||
assert.NoError(t, orm.Delete(post))
|
|
||||||
|
|
||||||
var count int
|
|
||||||
assert.NoError(t,
|
|
||||||
orm.GetConnection("default").DB.QueryRow(`SELECT count(id) FROM posts where id = ?`, post.ID).Scan(&count))
|
|
||||||
|
|
||||||
assert.Equal(t, 0, count)
|
|
||||||
}
|
|
||||||
func TestAdd_HasMany(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "my body for insert",
|
|
||||||
}
|
|
||||||
err = orm.Insert(post)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
|
|
||||||
err = orm.Add(post, []orm.Entity{
|
|
||||||
Comment{
|
|
||||||
Body: "comment 1",
|
|
||||||
},
|
|
||||||
Comment{
|
|
||||||
Body: "comment 2",
|
|
||||||
},
|
|
||||||
}...)
|
|
||||||
// orm.Query(qm.WhereBetween())
|
|
||||||
assert.NoError(t, err)
|
|
||||||
var count int
|
|
||||||
assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(id) FROM comments`).Scan(&count))
|
|
||||||
assert.Equal(t, 2, count)
|
|
||||||
|
|
||||||
comment, err := orm.Find[Comment](1)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, int64(1), comment.PostID)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAdd_ManyToMany(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "my body for insert",
|
|
||||||
}
|
|
||||||
err = orm.Insert(post)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
|
|
||||||
err = orm.Add(post, []orm.Entity{
|
|
||||||
&Category{
|
|
||||||
Title: "cat 1",
|
|
||||||
},
|
|
||||||
&Category{
|
|
||||||
Title: "cat 2",
|
|
||||||
},
|
|
||||||
}...)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
var count int
|
|
||||||
assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(post_id) FROM post_categories`).Scan(&count))
|
|
||||||
assert.Equal(t, 2, count)
|
|
||||||
assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(id) FROM categories`).Scan(&count))
|
|
||||||
assert.Equal(t, 2, count)
|
|
||||||
|
|
||||||
categories, err := post.Categories()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, 2, len(categories))
|
|
||||||
assert.Equal(t, int64(1), categories[0].ID)
|
|
||||||
assert.Equal(t, int64(2), categories[1].ID)
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSave(t *testing.T) {
|
|
||||||
t.Run("save should insert", func(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "1",
|
|
||||||
}
|
|
||||||
assert.NoError(t, orm.Save(post))
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("save should update", func(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "1",
|
|
||||||
}
|
|
||||||
assert.NoError(t, orm.Save(post))
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
|
|
||||||
post.BodyText += "2"
|
|
||||||
assert.NoError(t, orm.Save(post))
|
|
||||||
|
|
||||||
myPost, err := orm.Find[Post](1)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.EqualValues(t, post.BodyText, myPost.BodyText)
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHasMany(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "first post",
|
|
||||||
}
|
|
||||||
assert.NoError(t, orm.Save(post))
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
|
|
||||||
assert.NoError(t, orm.Save(&Comment{
|
|
||||||
PostID: post.ID,
|
|
||||||
Body: "comment 1",
|
|
||||||
}))
|
|
||||||
assert.NoError(t, orm.Save(&Comment{
|
|
||||||
PostID: post.ID,
|
|
||||||
Body: "comment 2",
|
|
||||||
}))
|
|
||||||
|
|
||||||
comments, err := orm.HasMany[Comment](post).All()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Len(t, comments, 2)
|
|
||||||
|
|
||||||
assert.Equal(t, post.ID, comments[0].PostID)
|
|
||||||
assert.Equal(t, post.ID, comments[1].PostID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBelongsTo(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "first post",
|
|
||||||
}
|
|
||||||
assert.NoError(t, orm.Save(post))
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
|
|
||||||
comment := &Comment{
|
|
||||||
PostID: post.ID,
|
|
||||||
Body: "comment 1",
|
|
||||||
}
|
|
||||||
assert.NoError(t, orm.Save(comment))
|
|
||||||
|
|
||||||
post2, err := orm.BelongsTo[Post](comment).Get()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, post.BodyText, post2.BodyText)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHasOne(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "first post",
|
|
||||||
}
|
|
||||||
assert.NoError(t, orm.Save(post))
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
|
|
||||||
headerPicture := &HeaderPicture{
|
|
||||||
PostID: post.ID,
|
|
||||||
Link: "google",
|
|
||||||
}
|
|
||||||
assert.NoError(t, orm.Save(headerPicture))
|
|
||||||
|
|
||||||
c1, err := orm.HasOne[HeaderPicture](post).Get()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Equal(t, headerPicture.PostID, c1.PostID)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestBelongsToMany(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "first Post",
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.NoError(t, orm.Save(post))
|
|
||||||
assert.Equal(t, int64(1), post.ID)
|
|
||||||
|
|
||||||
category := &Category{
|
|
||||||
Title: "first category",
|
|
||||||
}
|
|
||||||
assert.NoError(t, orm.Save(category))
|
|
||||||
assert.Equal(t, int64(1), category.ID)
|
|
||||||
|
|
||||||
_, _, err = orm.ExecRaw[Category](`INSERT INTO post_categories (post_id, category_id) VALUES (?,?)`, post.ID, category.ID)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
categories, err := orm.BelongsToMany[Category](post).All()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.Len(t, categories, 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSchematic(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
orm.Schematic()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestAddProperty(t *testing.T) {
|
|
||||||
t.Run("having pk value", func(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "first post",
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.NoError(t, orm.Save(post))
|
|
||||||
assert.EqualValues(t, 1, post.ID)
|
|
||||||
|
|
||||||
err = orm.Add(post, &Comment{PostID: post.ID, Body: "firstComment"})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
var comment Comment
|
|
||||||
assert.NoError(t, orm.GetConnection("default").
|
|
||||||
DB.
|
|
||||||
QueryRow(`SELECT id, post_id, body FROM comments WHERE post_id=?`, post.ID).
|
|
||||||
Scan(&comment.ID, &comment.PostID, &comment.Body))
|
|
||||||
|
|
||||||
assert.EqualValues(t, post.ID, comment.PostID)
|
|
||||||
})
|
|
||||||
t.Run("not having PK value", func(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
post := &Post{
|
|
||||||
BodyText: "first post",
|
|
||||||
}
|
|
||||||
assert.NoError(t, orm.Save(post))
|
|
||||||
assert.EqualValues(t, 1, post.ID)
|
|
||||||
|
|
||||||
err = orm.Add(post, &AuthorEmail{Email: "myemail"})
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
emails, err := orm.QueryRaw[AuthorEmail](`SELECT id, email FROM emails WHERE post_id=?`, post.ID)
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, []AuthorEmail{{ID: 1, Email: "myemail"}}, emails)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestQuery(t *testing.T) {
|
|
||||||
t.Run("querying single row", func(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.NoError(t, orm.Save(&Post{BodyText: "body 1"}))
|
|
||||||
// post, err := orm.Query[Post]().Where("id", 1).First()
|
|
||||||
post, err := orm.Query[Post]().WherePK(1).First().Get()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.EqualValues(t, "body 1", post.BodyText)
|
|
||||||
assert.EqualValues(t, 1, post.ID)
|
|
||||||
|
|
||||||
})
|
|
||||||
t.Run("querying multiple rows", func(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.NoError(t, orm.Save(&Post{BodyText: "body 1"}))
|
|
||||||
assert.NoError(t, orm.Save(&Post{BodyText: "body 2"}))
|
|
||||||
assert.NoError(t, orm.Save(&Post{BodyText: "body 3"}))
|
|
||||||
posts, err := orm.Query[Post]().All()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Len(t, posts, 3)
|
|
||||||
assert.Equal(t, "body 1", posts[0].BodyText)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("updating a row using query interface", func(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.NoError(t, orm.Save(&Post{BodyText: "body 1"}))
|
|
||||||
|
|
||||||
affected, err := orm.Query[Post]().Where("id", 1).Set("body", "body jadid").Update()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.EqualValues(t, 1, affected)
|
|
||||||
|
|
||||||
post, err := orm.Find[Post](1)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, "body jadid", post.BodyText)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("deleting a row using query interface", func(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.NoError(t, orm.Save(&Post{BodyText: "body 1"}))
|
|
||||||
|
|
||||||
affected, err := orm.Query[Post]().WherePK(1).Delete()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.EqualValues(t, 1, affected)
|
|
||||||
count, err := orm.Query[Post]().WherePK(1).Count().Get()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.EqualValues(t, 0, count)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("count", func(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
count, err := orm.Query[Post]().WherePK(1).Count().Get()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.EqualValues(t, 0, count)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("latest", func(t *testing.T) {
|
|
||||||
err := setup()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.NoError(t, orm.Save(&Post{BodyText: "body 1"}))
|
|
||||||
assert.NoError(t, orm.Save(&Post{BodyText: "body 2"}))
|
|
||||||
|
|
||||||
post, err := orm.Query[Post]().Latest().Get()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
|
|
||||||
assert.EqualValues(t, "body 2", post.BodyText)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSetup(t *testing.T) {
|
|
||||||
t.Run("tables are out of sync", func(t *testing.T) {
|
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
|
||||||
// _, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`)
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`)
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`)
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`)
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`)
|
|
||||||
// _, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`)
|
|
||||||
|
|
||||||
err = orm.SetupConnections(orm.ConnectionConfig{
|
|
||||||
Name: "default",
|
|
||||||
DB: db,
|
|
||||||
Dialect: orm.Dialects.SQLite3,
|
|
||||||
Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}},
|
|
||||||
DatabaseValidations: true,
|
|
||||||
})
|
|
||||||
assert.Error(t, err)
|
|
||||||
|
|
||||||
})
|
|
||||||
t.Run("schemas are wrong", func(t *testing.T) {
|
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`)
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`)
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`)
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, body text)`) // missing post_id
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`)
|
|
||||||
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`)
|
|
||||||
|
|
||||||
err = orm.SetupConnections(orm.ConnectionConfig{
|
|
||||||
Name: "default",
|
|
||||||
DB: db,
|
|
||||||
Dialect: orm.Dialects.SQLite3,
|
|
||||||
Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}},
|
|
||||||
DatabaseValidations: true,
|
|
||||||
})
|
|
||||||
assert.Error(t, err)
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
811
gdb/orm/query.go
811
gdb/orm/query.go
@ -1,811 +0,0 @@
|
|||||||
//
|
|
||||||
// query.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
queryTypeSELECT = iota + 1
|
|
||||||
queryTypeUPDATE
|
|
||||||
queryTypeDelete
|
|
||||||
)
|
|
||||||
|
|
||||||
// QueryBuilder is our query builder, almost all methods and functions in GoLobby ORM
|
|
||||||
// create or configure instance of QueryBuilder.
|
|
||||||
type QueryBuilder[OUTPUT any] struct {
|
|
||||||
typ int
|
|
||||||
schema *schema
|
|
||||||
// general parts
|
|
||||||
where *whereClause
|
|
||||||
table string
|
|
||||||
placeholderGenerator func(n int) []string
|
|
||||||
|
|
||||||
// select parts
|
|
||||||
orderBy *orderByClause
|
|
||||||
groupBy *GroupBy
|
|
||||||
selected *selected
|
|
||||||
subQuery *struct {
|
|
||||||
q string
|
|
||||||
args []interface{}
|
|
||||||
placeholderGenerator func(n int) []string
|
|
||||||
}
|
|
||||||
joins []*Join
|
|
||||||
limit *Limit
|
|
||||||
offset *Offset
|
|
||||||
|
|
||||||
// update parts
|
|
||||||
sets [][2]interface{}
|
|
||||||
|
|
||||||
// execution parts
|
|
||||||
db *sql.DB
|
|
||||||
err error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Finisher APIs
|
|
||||||
|
|
||||||
// execute is a finisher executes QueryBuilder query, remember to use this when you have an Update
|
|
||||||
// or Delete Query.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) execute() (sql.Result, error) {
|
|
||||||
if q.err != nil {
|
|
||||||
return nil, q.err
|
|
||||||
}
|
|
||||||
if q.typ == queryTypeSELECT {
|
|
||||||
return nil, fmt.Errorf("query type is SELECT")
|
|
||||||
}
|
|
||||||
query, args, err := q.ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return q.schema.getConnection().exec(query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get limit results to 1, runs query generated by query builder, scans result into OUTPUT.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Get() (OUTPUT, error) {
|
|
||||||
if q.err != nil {
|
|
||||||
return *new(OUTPUT), q.err
|
|
||||||
}
|
|
||||||
queryString, args, err := q.ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return *new(OUTPUT), err
|
|
||||||
}
|
|
||||||
rows, err := q.schema.getConnection().query(queryString, args...)
|
|
||||||
if err != nil {
|
|
||||||
return *new(OUTPUT), err
|
|
||||||
}
|
|
||||||
var output OUTPUT
|
|
||||||
err = newBinder(q.schema).bind(rows, &output)
|
|
||||||
if err != nil {
|
|
||||||
return *new(OUTPUT), err
|
|
||||||
}
|
|
||||||
return output, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// All is a finisher, create the Select query based on QueryBuilder and scan results into
|
|
||||||
// slice of type parameter E.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) All() ([]OUTPUT, error) {
|
|
||||||
if q.err != nil {
|
|
||||||
return nil, q.err
|
|
||||||
}
|
|
||||||
q.SetSelect()
|
|
||||||
queryString, args, err := q.ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
rows, err := q.schema.getConnection().query(queryString, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
var output []OUTPUT
|
|
||||||
err = newBinder(q.schema).bind(rows, &output)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return output, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete is a finisher, creates a delete query from query builder and executes it.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Delete() (rowsAffected int64, err error) {
|
|
||||||
if q.err != nil {
|
|
||||||
return 0, q.err
|
|
||||||
}
|
|
||||||
q.SetDelete()
|
|
||||||
res, err := q.execute()
|
|
||||||
if err != nil {
|
|
||||||
return 0, q.err
|
|
||||||
}
|
|
||||||
return res.RowsAffected()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update is a finisher, creates an Update query from QueryBuilder and executes in into database, returns
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Update() (rowsAffected int64, err error) {
|
|
||||||
if q.err != nil {
|
|
||||||
return 0, q.err
|
|
||||||
}
|
|
||||||
q.SetUpdate()
|
|
||||||
res, err := q.execute()
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return res.RowsAffected()
|
|
||||||
}
|
|
||||||
|
|
||||||
func copyQueryBuilder[T1 any, T2 any](q *QueryBuilder[T1], q2 *QueryBuilder[T2]) {
|
|
||||||
q2.db = q.db
|
|
||||||
q2.err = q.err
|
|
||||||
q2.groupBy = q.groupBy
|
|
||||||
q2.joins = q.joins
|
|
||||||
q2.limit = q.limit
|
|
||||||
q2.offset = q.offset
|
|
||||||
q2.orderBy = q.orderBy
|
|
||||||
q2.placeholderGenerator = q.placeholderGenerator
|
|
||||||
q2.schema = q.schema
|
|
||||||
q2.selected = q.selected
|
|
||||||
q2.sets = q.sets
|
|
||||||
|
|
||||||
q2.subQuery = q.subQuery
|
|
||||||
q2.table = q.table
|
|
||||||
q2.typ = q.typ
|
|
||||||
q2.where = q.where
|
|
||||||
}
|
|
||||||
|
|
||||||
// Count creates and execute a select query from QueryBuilder and set it's field list of selection
|
|
||||||
// to COUNT(id).
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Count() *QueryBuilder[int] {
|
|
||||||
q.selected = &selected{Columns: []string{"COUNT(id)"}}
|
|
||||||
q.SetSelect()
|
|
||||||
qCount := NewQueryBuilder[int](q.schema)
|
|
||||||
|
|
||||||
copyQueryBuilder(q, qCount)
|
|
||||||
|
|
||||||
return qCount
|
|
||||||
}
|
|
||||||
|
|
||||||
// First returns first record of database using OrderBy primary key
|
|
||||||
// ascending order.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) First() *QueryBuilder[OUTPUT] {
|
|
||||||
q.OrderBy(q.schema.pkName(), ASC).Limit(1)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// Latest is like Get but it also do a OrderBy(primary key, DESC)
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Latest() *QueryBuilder[OUTPUT] {
|
|
||||||
q.OrderBy(q.schema.pkName(), DESC).Limit(1)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// WherePK adds a where clause to QueryBuilder and also gets primary key name
|
|
||||||
// from type parameter schema.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) WherePK(value interface{}) *QueryBuilder[OUTPUT] {
|
|
||||||
return q.Where(q.schema.pkName(), value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d *QueryBuilder[OUTPUT]) toSqlDelete() (string, []interface{}, error) {
|
|
||||||
base := fmt.Sprintf("DELETE FROM %s", d.table)
|
|
||||||
var args []interface{}
|
|
||||||
if d.where != nil {
|
|
||||||
d.where.PlaceHolderGenerator = d.placeholderGenerator
|
|
||||||
where, whereArgs, err := d.where.ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
base += " WHERE " + where
|
|
||||||
args = append(args, whereArgs...)
|
|
||||||
}
|
|
||||||
return base, args, nil
|
|
||||||
}
|
|
||||||
func pop(phs *[]string) string {
|
|
||||||
top := (*phs)[len(*phs)-1]
|
|
||||||
*phs = (*phs)[:len(*phs)-1]
|
|
||||||
return top
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *QueryBuilder[OUTPUT]) kvString() string {
|
|
||||||
phs := u.placeholderGenerator(len(u.sets))
|
|
||||||
var sets []string
|
|
||||||
for _, pair := range u.sets {
|
|
||||||
sets = append(sets, fmt.Sprintf("%s=%s", pair[0], pop(&phs)))
|
|
||||||
}
|
|
||||||
return strings.Join(sets, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *QueryBuilder[OUTPUT]) args() []interface{} {
|
|
||||||
var values []interface{}
|
|
||||||
for _, pair := range u.sets {
|
|
||||||
values = append(values, pair[1])
|
|
||||||
}
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
|
|
||||||
func (u *QueryBuilder[OUTPUT]) toSqlUpdate() (string, []interface{}, error) {
|
|
||||||
if u.table == "" {
|
|
||||||
return "", nil, fmt.Errorf("table cannot be empty")
|
|
||||||
}
|
|
||||||
base := fmt.Sprintf("UPDATE %s SET %s", u.table, u.kvString())
|
|
||||||
args := u.args()
|
|
||||||
if u.where != nil {
|
|
||||||
u.where.PlaceHolderGenerator = u.placeholderGenerator
|
|
||||||
where, whereArgs, err := u.where.ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
args = append(args, whereArgs...)
|
|
||||||
base += " WHERE " + where
|
|
||||||
}
|
|
||||||
return base, args, nil
|
|
||||||
}
|
|
||||||
func (s *QueryBuilder[OUTPUT]) toSqlSelect() (string, []interface{}, error) {
|
|
||||||
if s.err != nil {
|
|
||||||
return "", nil, s.err
|
|
||||||
}
|
|
||||||
base := "SELECT"
|
|
||||||
var args []interface{}
|
|
||||||
// select
|
|
||||||
if s.selected == nil {
|
|
||||||
s.selected = &selected{
|
|
||||||
Columns: []string{"*"},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
base += " " + s.selected.String()
|
|
||||||
// from
|
|
||||||
if s.table == "" && s.subQuery == nil {
|
|
||||||
return "", nil, fmt.Errorf("Table name cannot be empty")
|
|
||||||
} else if s.table != "" && s.subQuery != nil {
|
|
||||||
return "", nil, fmt.Errorf("cannot have both Table and subquery")
|
|
||||||
}
|
|
||||||
if s.table != "" {
|
|
||||||
base += " " + "FROM " + s.table
|
|
||||||
}
|
|
||||||
if s.subQuery != nil {
|
|
||||||
s.subQuery.placeholderGenerator = s.placeholderGenerator
|
|
||||||
base += " " + "FROM (" + s.subQuery.q + " )"
|
|
||||||
args = append(args, s.subQuery.args...)
|
|
||||||
}
|
|
||||||
// Joins
|
|
||||||
if s.joins != nil {
|
|
||||||
for _, join := range s.joins {
|
|
||||||
base += " " + join.String()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// whereClause
|
|
||||||
if s.where != nil {
|
|
||||||
s.where.PlaceHolderGenerator = s.placeholderGenerator
|
|
||||||
where, whereArgs, err := s.where.ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
base += " WHERE " + where
|
|
||||||
args = append(args, whereArgs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// orderByClause
|
|
||||||
if s.orderBy != nil {
|
|
||||||
base += " " + s.orderBy.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupBy
|
|
||||||
if s.groupBy != nil {
|
|
||||||
base += " " + s.groupBy.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Limit
|
|
||||||
if s.limit != nil {
|
|
||||||
base += " " + s.limit.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Offset
|
|
||||||
if s.offset != nil {
|
|
||||||
base += " " + s.offset.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return base, args, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// ToSql creates sql query from QueryBuilder based on internal fields it would decide what kind
|
|
||||||
// of query to build.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) ToSql() (string, []interface{}, error) {
|
|
||||||
if q.err != nil {
|
|
||||||
return "", nil, q.err
|
|
||||||
}
|
|
||||||
if q.typ == queryTypeSELECT {
|
|
||||||
return q.toSqlSelect()
|
|
||||||
} else if q.typ == queryTypeDelete {
|
|
||||||
return q.toSqlDelete()
|
|
||||||
} else if q.typ == queryTypeUPDATE {
|
|
||||||
return q.toSqlUpdate()
|
|
||||||
} else {
|
|
||||||
return "", nil, fmt.Errorf("no sql type matched")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type orderByOrder string
|
|
||||||
|
|
||||||
const (
|
|
||||||
ASC = "ASC"
|
|
||||||
DESC = "DESC"
|
|
||||||
)
|
|
||||||
|
|
||||||
type orderByClause struct {
|
|
||||||
Columns [][2]string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o orderByClause) String() string {
|
|
||||||
var tuples []string
|
|
||||||
for _, pair := range o.Columns {
|
|
||||||
tuples = append(tuples, fmt.Sprintf("%s %s", pair[0], pair[1]))
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("ORDER BY %s", strings.Join(tuples, ","))
|
|
||||||
}
|
|
||||||
|
|
||||||
type GroupBy struct {
|
|
||||||
Columns []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g GroupBy) String() string {
|
|
||||||
return fmt.Sprintf("GROUP BY %s", strings.Join(g.Columns, ","))
|
|
||||||
}
|
|
||||||
|
|
||||||
type joinType string
|
|
||||||
|
|
||||||
const (
|
|
||||||
JoinTypeInner = "INNER"
|
|
||||||
JoinTypeLeft = "LEFT"
|
|
||||||
JoinTypeRight = "RIGHT"
|
|
||||||
JoinTypeFull = "FULL OUTER"
|
|
||||||
JoinTypeSelf = "SELF"
|
|
||||||
)
|
|
||||||
|
|
||||||
type JoinOn struct {
|
|
||||||
Lhs string
|
|
||||||
Rhs string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j JoinOn) String() string {
|
|
||||||
return fmt.Sprintf("%s = %s", j.Lhs, j.Rhs)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Join struct {
|
|
||||||
Type joinType
|
|
||||||
Table string
|
|
||||||
On JoinOn
|
|
||||||
}
|
|
||||||
|
|
||||||
func (j Join) String() string {
|
|
||||||
return fmt.Sprintf("%s JOIN %s ON %s", j.Type, j.Table, j.On.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
type Limit struct {
|
|
||||||
N int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l Limit) String() string {
|
|
||||||
return fmt.Sprintf("LIMIT %d", l.N)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Offset struct {
|
|
||||||
N int
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o Offset) String() string {
|
|
||||||
return fmt.Sprintf("OFFSET %d", o.N)
|
|
||||||
}
|
|
||||||
|
|
||||||
type selected struct {
|
|
||||||
Columns []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s selected) String() string {
|
|
||||||
return fmt.Sprintf("%s", strings.Join(s.Columns, ","))
|
|
||||||
}
|
|
||||||
|
|
||||||
// OrderBy adds an OrderBy section to QueryBuilder.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) OrderBy(column string, how string) *QueryBuilder[OUTPUT] {
|
|
||||||
q.SetSelect()
|
|
||||||
if q.orderBy == nil {
|
|
||||||
q.orderBy = &orderByClause{}
|
|
||||||
}
|
|
||||||
q.orderBy.Columns = append(q.orderBy.Columns, [2]string{column, how})
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// LeftJoin adds a left join section to QueryBuilder.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) LeftJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] {
|
|
||||||
q.SetSelect()
|
|
||||||
q.joins = append(q.joins, &Join{
|
|
||||||
Type: JoinTypeLeft,
|
|
||||||
Table: table,
|
|
||||||
On: JoinOn{
|
|
||||||
Lhs: onLhs,
|
|
||||||
Rhs: onRhs,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// RightJoin adds a right join section to QueryBuilder.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) RightJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] {
|
|
||||||
q.SetSelect()
|
|
||||||
q.joins = append(q.joins, &Join{
|
|
||||||
Type: JoinTypeRight,
|
|
||||||
Table: table,
|
|
||||||
On: JoinOn{
|
|
||||||
Lhs: onLhs,
|
|
||||||
Rhs: onRhs,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// InnerJoin adds a inner join section to QueryBuilder.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) InnerJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] {
|
|
||||||
q.SetSelect()
|
|
||||||
q.joins = append(q.joins, &Join{
|
|
||||||
Type: JoinTypeInner,
|
|
||||||
Table: table,
|
|
||||||
On: JoinOn{
|
|
||||||
Lhs: onLhs,
|
|
||||||
Rhs: onRhs,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// Join adds a inner join section to QueryBuilder.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Join(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] {
|
|
||||||
return q.InnerJoin(table, onLhs, onRhs)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FullOuterJoin adds a full outer join section to QueryBuilder.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) FullOuterJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] {
|
|
||||||
q.SetSelect()
|
|
||||||
q.joins = append(q.joins, &Join{
|
|
||||||
Type: JoinTypeFull,
|
|
||||||
Table: table,
|
|
||||||
On: JoinOn{
|
|
||||||
Lhs: onLhs,
|
|
||||||
Rhs: onRhs,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// Where Adds a where clause to query, if already have where clause append to it
|
|
||||||
// as AndWhere.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Where(parts ...interface{}) *QueryBuilder[OUTPUT] {
|
|
||||||
if q.where != nil {
|
|
||||||
return q.addWhere("AND", parts...)
|
|
||||||
}
|
|
||||||
if len(parts) == 1 {
|
|
||||||
if r, isRaw := parts[0].(*raw); isRaw {
|
|
||||||
q.where = &whereClause{raw: r.sql, args: r.args, PlaceHolderGenerator: q.placeholderGenerator}
|
|
||||||
return q
|
|
||||||
} else {
|
|
||||||
q.err = fmt.Errorf("when you have one argument passed to where, it should be *raw")
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
} else if len(parts) == 2 {
|
|
||||||
if strings.Index(parts[0].(string), " ") == -1 {
|
|
||||||
// Equal mode
|
|
||||||
q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: Eq, Rhs: parts[1]}, PlaceHolderGenerator: q.placeholderGenerator}
|
|
||||||
}
|
|
||||||
return q
|
|
||||||
} else if len(parts) == 3 {
|
|
||||||
// operator mode
|
|
||||||
q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2]}, PlaceHolderGenerator: q.placeholderGenerator}
|
|
||||||
return q
|
|
||||||
} else if len(parts) > 3 && parts[1].(string) == "IN" {
|
|
||||||
q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2:]}, PlaceHolderGenerator: q.placeholderGenerator}
|
|
||||||
return q
|
|
||||||
} else {
|
|
||||||
q.err = fmt.Errorf("wrong number of arguments passed to Where")
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
type binaryOp string
|
|
||||||
|
|
||||||
const (
|
|
||||||
Eq = "="
|
|
||||||
GT = ">"
|
|
||||||
LT = "<"
|
|
||||||
GE = ">="
|
|
||||||
LE = "<="
|
|
||||||
NE = "!="
|
|
||||||
Between = "BETWEEN"
|
|
||||||
Like = "LIKE"
|
|
||||||
In = "IN"
|
|
||||||
)
|
|
||||||
|
|
||||||
type cond struct {
|
|
||||||
PlaceHolderGenerator func(n int) []string
|
|
||||||
|
|
||||||
Lhs string
|
|
||||||
Op binaryOp
|
|
||||||
Rhs interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (b cond) ToSql() (string, []interface{}, error) {
|
|
||||||
var phs []string
|
|
||||||
if b.Op == In {
|
|
||||||
rhs, isInterfaceSlice := b.Rhs.([]interface{})
|
|
||||||
if isInterfaceSlice {
|
|
||||||
phs = b.PlaceHolderGenerator(len(rhs))
|
|
||||||
return fmt.Sprintf("%s IN (%s)", b.Lhs, strings.Join(phs, ",")), rhs, nil
|
|
||||||
} else if rawThing, isRaw := b.Rhs.(*raw); isRaw {
|
|
||||||
return fmt.Sprintf("%s IN (%s)", b.Lhs, rawThing.sql), rawThing.args, nil
|
|
||||||
} else {
|
|
||||||
return "", nil, fmt.Errorf("Right hand side of Cond when operator is IN should be either a interface{} slice or *raw")
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
|
||||||
phs = b.PlaceHolderGenerator(1)
|
|
||||||
return fmt.Sprintf("%s %s %s", b.Lhs, b.Op, pop(&phs)), []interface{}{b.Rhs}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const (
|
|
||||||
nextType_AND = "AND"
|
|
||||||
nextType_OR = "OR"
|
|
||||||
)
|
|
||||||
|
|
||||||
type whereClause struct {
|
|
||||||
PlaceHolderGenerator func(n int) []string
|
|
||||||
nextTyp string
|
|
||||||
next *whereClause
|
|
||||||
cond
|
|
||||||
raw string
|
|
||||||
args []interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w whereClause) ToSql() (string, []interface{}, error) {
|
|
||||||
var base string
|
|
||||||
var args []interface{}
|
|
||||||
var err error
|
|
||||||
if w.raw != "" {
|
|
||||||
base = w.raw
|
|
||||||
args = w.args
|
|
||||||
} else {
|
|
||||||
w.cond.PlaceHolderGenerator = w.PlaceHolderGenerator
|
|
||||||
base, args, err = w.cond.ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if w.next == nil {
|
|
||||||
return base, args, nil
|
|
||||||
}
|
|
||||||
if w.next != nil {
|
|
||||||
next, nextArgs, err := w.next.ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, err
|
|
||||||
}
|
|
||||||
base += " " + w.nextTyp + " " + next
|
|
||||||
args = append(args, nextArgs...)
|
|
||||||
return base, args, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return base, args, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
//func (q *QueryBuilder[OUTPUT]) WhereKeyValue(m map) {}
|
|
||||||
|
|
||||||
// WhereIn adds a where clause to QueryBuilder using In operator.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) WhereIn(column string, values ...interface{}) *QueryBuilder[OUTPUT] {
|
|
||||||
return q.Where(append([]interface{}{column, In}, values...)...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AndWhere appends a where clause to query builder as And where clause.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) AndWhere(parts ...interface{}) *QueryBuilder[OUTPUT] {
|
|
||||||
return q.addWhere(nextType_AND, parts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// OrWhere appends a where clause to query builder as Or where clause.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) OrWhere(parts ...interface{}) *QueryBuilder[OUTPUT] {
|
|
||||||
return q.addWhere(nextType_OR, parts...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *QueryBuilder[OUTPUT]) addWhere(typ string, parts ...interface{}) *QueryBuilder[OUTPUT] {
|
|
||||||
w := q.where
|
|
||||||
for {
|
|
||||||
if w == nil {
|
|
||||||
break
|
|
||||||
} else if w.next == nil {
|
|
||||||
w.next = &whereClause{PlaceHolderGenerator: q.placeholderGenerator}
|
|
||||||
w.nextTyp = typ
|
|
||||||
w = w.next
|
|
||||||
break
|
|
||||||
} else {
|
|
||||||
w = w.next
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if w == nil {
|
|
||||||
w = &whereClause{PlaceHolderGenerator: q.placeholderGenerator}
|
|
||||||
}
|
|
||||||
if len(parts) == 1 {
|
|
||||||
w.raw = parts[0].(*raw).sql
|
|
||||||
w.args = parts[0].(*raw).args
|
|
||||||
return q
|
|
||||||
} else if len(parts) == 2 {
|
|
||||||
// Equal mode
|
|
||||||
w.cond = cond{Lhs: parts[0].(string), Op: Eq, Rhs: parts[1]}
|
|
||||||
return q
|
|
||||||
} else if len(parts) == 3 {
|
|
||||||
// operator mode
|
|
||||||
w.cond = cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2]}
|
|
||||||
return q
|
|
||||||
} else {
|
|
||||||
panic("wrong number of arguments passed to Where")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Offset adds offset section to query builder.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Offset(n int) *QueryBuilder[OUTPUT] {
|
|
||||||
q.SetSelect()
|
|
||||||
q.offset = &Offset{N: n}
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// Limit adds limit section to query builder.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Limit(n int) *QueryBuilder[OUTPUT] {
|
|
||||||
q.SetSelect()
|
|
||||||
q.limit = &Limit{N: n}
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// Table sets table of QueryBuilder.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Table(t string) *QueryBuilder[OUTPUT] {
|
|
||||||
q.table = t
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetSelect sets query type of QueryBuilder to Select.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) SetSelect() *QueryBuilder[OUTPUT] {
|
|
||||||
q.typ = queryTypeSELECT
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// GroupBy adds a group by section to QueryBuilder.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) GroupBy(columns ...string) *QueryBuilder[OUTPUT] {
|
|
||||||
q.SetSelect()
|
|
||||||
if q.groupBy == nil {
|
|
||||||
q.groupBy = &GroupBy{}
|
|
||||||
}
|
|
||||||
q.groupBy.Columns = append(q.groupBy.Columns, columns...)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// Select adds columns to QueryBuilder select field list.
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Select(columns ...string) *QueryBuilder[OUTPUT] {
|
|
||||||
q.SetSelect()
|
|
||||||
if q.selected == nil {
|
|
||||||
q.selected = &selected{}
|
|
||||||
}
|
|
||||||
q.selected.Columns = append(q.selected.Columns, columns...)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
// FromQuery sets subquery of QueryBuilder to be given subquery so
|
|
||||||
// when doing select instead of from table we do from(subquery).
|
|
||||||
func (q *QueryBuilder[OUTPUT]) FromQuery(subQuery *QueryBuilder[OUTPUT]) *QueryBuilder[OUTPUT] {
|
|
||||||
q.SetSelect()
|
|
||||||
subQuery.SetSelect()
|
|
||||||
subQuery.placeholderGenerator = q.placeholderGenerator
|
|
||||||
subQueryString, args, err := subQuery.ToSql()
|
|
||||||
q.err = err
|
|
||||||
q.subQuery = &struct {
|
|
||||||
q string
|
|
||||||
args []interface{}
|
|
||||||
placeholderGenerator func(n int) []string
|
|
||||||
}{
|
|
||||||
subQueryString, args, q.placeholderGenerator,
|
|
||||||
}
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *QueryBuilder[OUTPUT]) SetUpdate() *QueryBuilder[OUTPUT] {
|
|
||||||
q.typ = queryTypeUPDATE
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *QueryBuilder[OUTPUT]) Set(keyValues ...any) *QueryBuilder[OUTPUT] {
|
|
||||||
if len(keyValues)%2 != 0 {
|
|
||||||
q.err = fmt.Errorf("when using Set, passed argument count should be even: %w", q.err)
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
q.SetUpdate()
|
|
||||||
for i := 0; i < len(keyValues); i++ {
|
|
||||||
if i != 0 && i%2 == 1 {
|
|
||||||
q.sets = append(q.sets, [2]any{keyValues[i-1], keyValues[i]})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *QueryBuilder[OUTPUT]) SetDialect(dialect *Dialect) *QueryBuilder[OUTPUT] {
|
|
||||||
q.placeholderGenerator = dialect.PlaceHolderGenerator
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
func (q *QueryBuilder[OUTPUT]) SetDelete() *QueryBuilder[OUTPUT] {
|
|
||||||
q.typ = queryTypeDelete
|
|
||||||
return q
|
|
||||||
}
|
|
||||||
|
|
||||||
type raw struct {
|
|
||||||
sql string
|
|
||||||
args []interface{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Raw creates a Raw sql query chunk that you can add to several components of QueryBuilder like
|
|
||||||
// Wheres.
|
|
||||||
func Raw(sql string, args ...interface{}) *raw {
|
|
||||||
return &raw{sql: sql, args: args}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewQueryBuilder[OUTPUT any](s *schema) *QueryBuilder[OUTPUT] {
|
|
||||||
return &QueryBuilder[OUTPUT]{schema: s}
|
|
||||||
}
|
|
||||||
|
|
||||||
type insertStmt struct {
|
|
||||||
PlaceHolderGenerator func(n int) []string
|
|
||||||
Table string
|
|
||||||
Columns []string
|
|
||||||
Values [][]interface{}
|
|
||||||
Returning string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i insertStmt) flatValues() []interface{} {
|
|
||||||
var values []interface{}
|
|
||||||
for _, row := range i.Values {
|
|
||||||
values = append(values, row...)
|
|
||||||
}
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i insertStmt) getValuesStr() string {
|
|
||||||
phs := i.PlaceHolderGenerator(len(i.Values) * len(i.Values[0]))
|
|
||||||
|
|
||||||
var output []string
|
|
||||||
for _, valueRow := range i.Values {
|
|
||||||
output = append(output, fmt.Sprintf("(%s)", strings.Join(phs[:len(valueRow)], ",")))
|
|
||||||
phs = phs[len(valueRow):]
|
|
||||||
}
|
|
||||||
return strings.Join(output, ",")
|
|
||||||
}
|
|
||||||
|
|
||||||
func (i insertStmt) ToSql() (string, []interface{}) {
|
|
||||||
base := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s",
|
|
||||||
i.Table,
|
|
||||||
strings.Join(i.Columns, ","),
|
|
||||||
i.getValuesStr(),
|
|
||||||
)
|
|
||||||
if i.Returning != "" {
|
|
||||||
base += "RETURNING " + i.Returning
|
|
||||||
}
|
|
||||||
return base, i.flatValues()
|
|
||||||
}
|
|
||||||
|
|
||||||
func postgresPlaceholder(n int) []string {
|
|
||||||
output := []string{}
|
|
||||||
for i := 1; i < n+1; i++ {
|
|
||||||
output = append(output, fmt.Sprintf("$%d", i))
|
|
||||||
}
|
|
||||||
return output
|
|
||||||
}
|
|
||||||
|
|
||||||
func questionMarks(n int) []string {
|
|
||||||
output := []string{}
|
|
||||||
for i := 0; i < n; i++ {
|
|
||||||
output = append(output, "?")
|
|
||||||
}
|
|
||||||
|
|
||||||
return output
|
|
||||||
}
|
|
@ -1,243 +0,0 @@
|
|||||||
//
|
|
||||||
// query_test.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Dummy struct{}
|
|
||||||
|
|
||||||
func (d Dummy) ConfigureEntity(e *EntityConfigurator) {
|
|
||||||
// TODO implement me
|
|
||||||
panic("implement me")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSelect(t *testing.T) {
|
|
||||||
t.Run("only select * from Table", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil)
|
|
||||||
s.Table("users").SetSelect()
|
|
||||||
str, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
assert.Equal(t, "SELECT * FROM users", str)
|
|
||||||
})
|
|
||||||
t.Run("select with whereClause", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil)
|
|
||||||
|
|
||||||
s.Table("users").SetDialect(Dialects.MySQL).
|
|
||||||
Where("age", 10).
|
|
||||||
AndWhere("age", "<", 10).
|
|
||||||
Where("name", "Amirreza").
|
|
||||||
OrWhere("age", GT, 11).
|
|
||||||
SetSelect()
|
|
||||||
|
|
||||||
str, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.EqualValues(t, []interface{}{10, 10, "Amirreza", 11}, args)
|
|
||||||
assert.Equal(t, "SELECT * FROM users WHERE age = ? AND age < ? AND name = ? OR age > ?", str)
|
|
||||||
})
|
|
||||||
t.Run("select with order by", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil).Table("users").OrderBy("created_at", ASC).OrderBy("updated_at", DESC)
|
|
||||||
str, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
assert.Equal(t, "SELECT * FROM users ORDER BY created_at ASC,updated_at DESC", str)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("select with group by", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil).Table("users").GroupBy("created_at", "updated_at")
|
|
||||||
str, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
assert.Equal(t, "SELECT * FROM users GROUP BY created_at,updated_at", str)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Select with limit", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil).Table("users").Limit(10)
|
|
||||||
str, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
assert.Equal(t, "SELECT * FROM users LIMIT 10", str)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("Select with offset", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil).Table("users").Offset(10)
|
|
||||||
str, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
assert.Equal(t, "SELECT * FROM users OFFSET 10", str)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("select with join", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").RightJoin("addresses", "users.id", "addresses.user_id")
|
|
||||||
str, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
assert.Equal(t, `SELECT id,name FROM users RIGHT JOIN addresses ON users.id = addresses.user_id`, str)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("select with multiple joins", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil).Table("users").
|
|
||||||
Select("id", "name").
|
|
||||||
RightJoin("addresses", "users.id", "addresses.user_id").
|
|
||||||
LeftJoin("user_credits", "users.id", "user_credits.user_id")
|
|
||||||
sql, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
assert.Equal(t, `SELECT id,name FROM users RIGHT JOIN addresses ON users.id = addresses.user_id LEFT JOIN user_credits ON users.id = user_credits.user_id`, sql)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("select with subquery", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil).SetDialect(Dialects.MySQL)
|
|
||||||
s.FromQuery(NewQueryBuilder[Dummy](nil).Table("users").Where("age", "<", 10))
|
|
||||||
sql, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.EqualValues(t, []interface{}{10}, args)
|
|
||||||
assert.Equal(t, `SELECT * FROM (SELECT * FROM users WHERE age < ? )`, sql)
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("select with inner join", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").InnerJoin("addresses", "users.id", "addresses.user_id")
|
|
||||||
str, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
assert.Equal(t, `SELECT id,name FROM users INNER JOIN addresses ON users.id = addresses.user_id`, str)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("select with join", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").Join("addresses", "users.id", "addresses.user_id")
|
|
||||||
str, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
assert.Equal(t, `SELECT id,name FROM users INNER JOIN addresses ON users.id = addresses.user_id`, str)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("select with full outer join", func(t *testing.T) {
|
|
||||||
s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").FullOuterJoin("addresses", "users.id", "addresses.user_id")
|
|
||||||
str, args, err := s.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
assert.Equal(t, `SELECT id,name FROM users FULL OUTER JOIN addresses ON users.id = addresses.user_id`, str)
|
|
||||||
})
|
|
||||||
t.Run("raw where", func(t *testing.T) {
|
|
||||||
sql, args, err :=
|
|
||||||
NewQueryBuilder[Dummy](nil).
|
|
||||||
SetDialect(Dialects.MySQL).
|
|
||||||
Table("users").
|
|
||||||
Where(Raw("id = ?", 1)).
|
|
||||||
AndWhere(Raw("age < ?", 10)).
|
|
||||||
SetSelect().
|
|
||||||
ToSql()
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.EqualValues(t, []interface{}{1, 10}, args)
|
|
||||||
assert.Equal(t, `SELECT * FROM users WHERE id = ? AND age < ?`, sql)
|
|
||||||
})
|
|
||||||
t.Run("no sql type matched", func(t *testing.T) {
|
|
||||||
sql, args, err := NewQueryBuilder[Dummy](nil).ToSql()
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
assert.Empty(t, sql)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("raw where in", func(t *testing.T) {
|
|
||||||
sql, args, err :=
|
|
||||||
NewQueryBuilder[Dummy](nil).
|
|
||||||
SetDialect(Dialects.MySQL).
|
|
||||||
Table("users").
|
|
||||||
WhereIn("id", Raw("SELECT user_id FROM user_books WHERE book_id = ?", 10)).
|
|
||||||
SetSelect().
|
|
||||||
ToSql()
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.EqualValues(t, []interface{}{10}, args)
|
|
||||||
assert.Equal(t, `SELECT * FROM users WHERE id IN (SELECT user_id FROM user_books WHERE book_id = ?)`, sql)
|
|
||||||
})
|
|
||||||
t.Run("where in", func(t *testing.T) {
|
|
||||||
sql, args, err :=
|
|
||||||
NewQueryBuilder[Dummy](nil).
|
|
||||||
SetDialect(Dialects.MySQL).
|
|
||||||
Table("users").
|
|
||||||
WhereIn("id", 1, 2, 3, 4, 5, 6).
|
|
||||||
SetSelect().
|
|
||||||
ToSql()
|
|
||||||
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.EqualValues(t, []interface{}{1, 2, 3, 4, 5, 6}, args)
|
|
||||||
assert.Equal(t, `SELECT * FROM users WHERE id IN (?,?,?,?,?,?)`, sql)
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
func TestUpdate(t *testing.T) {
|
|
||||||
t.Run("update no whereClause", func(t *testing.T) {
|
|
||||||
u := NewQueryBuilder[Dummy](nil).Table("users").Set("name", "amirreza").SetDialect(Dialects.MySQL)
|
|
||||||
sql, args, err := u.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, `UPDATE users SET name=?`, sql)
|
|
||||||
assert.Equal(t, []interface{}{"amirreza"}, args)
|
|
||||||
})
|
|
||||||
t.Run("update with whereClause", func(t *testing.T) {
|
|
||||||
u := NewQueryBuilder[Dummy](nil).Table("users").Set("name", "amirreza").Where("age", "<", 18).SetDialect(Dialects.MySQL)
|
|
||||||
sql, args, err := u.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, `UPDATE users SET name=? WHERE age < ?`, sql)
|
|
||||||
assert.Equal(t, []interface{}{"amirreza", 18}, args)
|
|
||||||
|
|
||||||
})
|
|
||||||
}
|
|
||||||
func TestDelete(t *testing.T) {
|
|
||||||
t.Run("delete without whereClause", func(t *testing.T) {
|
|
||||||
d := NewQueryBuilder[Dummy](nil).Table("users").SetDelete()
|
|
||||||
sql, args, err := d.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, `DELETE FROM users`, sql)
|
|
||||||
assert.Empty(t, args)
|
|
||||||
})
|
|
||||||
t.Run("delete with whereClause", func(t *testing.T) {
|
|
||||||
d := NewQueryBuilder[Dummy](nil).Table("users").SetDialect(Dialects.MySQL).Where("created_at", ">", "2012-01-10").SetDelete()
|
|
||||||
sql, args, err := d.ToSql()
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.Equal(t, `DELETE FROM users WHERE created_at > ?`, sql)
|
|
||||||
assert.EqualValues(t, []interface{}{"2012-01-10"}, args)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestInsert(t *testing.T) {
|
|
||||||
t.Run("insert into multiple rows", func(t *testing.T) {
|
|
||||||
i := insertStmt{}
|
|
||||||
i.Table = "users"
|
|
||||||
i.PlaceHolderGenerator = Dialects.MySQL.PlaceHolderGenerator
|
|
||||||
i.Columns = []string{"name", "age"}
|
|
||||||
i.Values = append(i.Values, []interface{}{"amirreza", 11}, []interface{}{"parsa", 10})
|
|
||||||
s, args := i.ToSql()
|
|
||||||
assert.Equal(t, `INSERT INTO users (name,age) VALUES (?,?),(?,?)`, s)
|
|
||||||
assert.EqualValues(t, []interface{}{"amirreza", 11, "parsa", 10}, args)
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("insert into single row", func(t *testing.T) {
|
|
||||||
i := insertStmt{}
|
|
||||||
i.Table = "users"
|
|
||||||
i.PlaceHolderGenerator = Dialects.MySQL.PlaceHolderGenerator
|
|
||||||
i.Columns = []string{"name", "age"}
|
|
||||||
i.Values = append(i.Values, []interface{}{"amirreza", 11})
|
|
||||||
s, args := i.ToSql()
|
|
||||||
assert.Equal(t, `INSERT INTO users (name,age) VALUES (?,?)`, s)
|
|
||||||
assert.Equal(t, []interface{}{"amirreza", 11}, args)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPostgresPlaceholder(t *testing.T) {
|
|
||||||
t.Run("for 5 it should have 5", func(t *testing.T) {
|
|
||||||
phs := postgresPlaceholder(5)
|
|
||||||
assert.EqualValues(t, []string{"$1", "$2", "$3", "$4", "$5"}, phs)
|
|
||||||
})
|
|
||||||
}
|
|
@ -1,306 +0,0 @@
|
|||||||
//
|
|
||||||
// schema.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"database/sql/driver"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
func getConnectionFor(e Entity) *connection {
|
|
||||||
configurator := newEntityConfigurator()
|
|
||||||
e.ConfigureEntity(configurator)
|
|
||||||
|
|
||||||
if len(globalConnections) > 1 && (configurator.connection == "" || configurator.table == "") {
|
|
||||||
panic("need table and DB name when having more than 1 DB registered")
|
|
||||||
}
|
|
||||||
if len(globalConnections) == 1 {
|
|
||||||
for _, db := range globalConnections {
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if db, exists := globalConnections[fmt.Sprintf("%s", configurator.connection)]; exists {
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
panic("no db found")
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSchemaFor(e Entity) *schema {
|
|
||||||
configurator := newEntityConfigurator()
|
|
||||||
c := getConnectionFor(e)
|
|
||||||
e.ConfigureEntity(configurator)
|
|
||||||
s := c.getSchema(configurator.table)
|
|
||||||
if s == nil {
|
|
||||||
s = schemaOfHeavyReflectionStuff(e)
|
|
||||||
c.setSchema(e, s)
|
|
||||||
}
|
|
||||||
return s
|
|
||||||
}
|
|
||||||
|
|
||||||
type schema struct {
|
|
||||||
Connection string
|
|
||||||
Table string
|
|
||||||
fields []*field
|
|
||||||
relations map[string]interface{}
|
|
||||||
setPK func(o Entity, value interface{})
|
|
||||||
getPK func(o Entity) interface{}
|
|
||||||
columnConstraints []*FieldConfigurator
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *schema) getField(sf reflect.StructField) *field {
|
|
||||||
for _, f := range s.fields {
|
|
||||||
if sf.Name == f.Name {
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *schema) getDialect() *Dialect {
|
|
||||||
return GetConnection(s.Connection).Dialect
|
|
||||||
}
|
|
||||||
func (s *schema) Columns(withPK bool) []string {
|
|
||||||
var cols []string
|
|
||||||
for _, field := range s.fields {
|
|
||||||
if field.Virtual {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !withPK && field.IsPK {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if s.getDialect().AddTableNameInSelectColumns {
|
|
||||||
cols = append(cols, s.Table+"."+field.Name)
|
|
||||||
} else {
|
|
||||||
cols = append(cols, field.Name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return cols
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *schema) pkName() string {
|
|
||||||
for _, field := range s.fields {
|
|
||||||
if field.IsPK {
|
|
||||||
return field.Name
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func genericFieldsOf(obj Entity) []*field {
|
|
||||||
t := reflect.TypeOf(obj)
|
|
||||||
for t.Kind() == reflect.Ptr {
|
|
||||||
t = t.Elem()
|
|
||||||
|
|
||||||
}
|
|
||||||
if t.Kind() == reflect.Slice {
|
|
||||||
t = t.Elem()
|
|
||||||
for t.Kind() == reflect.Ptr {
|
|
||||||
t = t.Elem()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
var ec EntityConfigurator
|
|
||||||
obj.ConfigureEntity(&ec)
|
|
||||||
|
|
||||||
var fms []*field
|
|
||||||
for i := 0; i < t.NumField(); i++ {
|
|
||||||
ft := t.Field(i)
|
|
||||||
fm := fieldMetadata(ft, ec.columnConstraints)
|
|
||||||
fms = append(fms, fm...)
|
|
||||||
}
|
|
||||||
return fms
|
|
||||||
}
|
|
||||||
|
|
||||||
func valuesOfField(vf reflect.Value) []interface{} {
|
|
||||||
var values []interface{}
|
|
||||||
if vf.Type().Kind() == reflect.Struct || vf.Type().Kind() == reflect.Ptr {
|
|
||||||
t := vf.Type()
|
|
||||||
if vf.Type().Kind() == reflect.Ptr {
|
|
||||||
t = vf.Type().Elem()
|
|
||||||
}
|
|
||||||
if !t.Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
|
|
||||||
// go into
|
|
||||||
// it does not implement driver.Valuer interface
|
|
||||||
for i := 0; i < vf.NumField(); i++ {
|
|
||||||
vif := vf.Field(i)
|
|
||||||
values = append(values, valuesOfField(vif)...)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
values = append(values, vf.Interface())
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
values = append(values, vf.Interface())
|
|
||||||
}
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
func genericValuesOf(o Entity, withPK bool) []interface{} {
|
|
||||||
t := reflect.TypeOf(o)
|
|
||||||
v := reflect.ValueOf(o)
|
|
||||||
if t.Kind() == reflect.Ptr {
|
|
||||||
t = t.Elem()
|
|
||||||
v = v.Elem()
|
|
||||||
}
|
|
||||||
fields := getSchemaFor(o).fields
|
|
||||||
pkIdx := -1
|
|
||||||
for i, field := range fields {
|
|
||||||
if field.IsPK {
|
|
||||||
pkIdx = i
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
var values []interface{}
|
|
||||||
|
|
||||||
for i := 0; i < t.NumField(); i++ {
|
|
||||||
if !withPK && i == pkIdx {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if fields[i].Virtual {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
vf := v.Field(i)
|
|
||||||
values = append(values, valuesOfField(vf)...)
|
|
||||||
}
|
|
||||||
return values
|
|
||||||
}
|
|
||||||
|
|
||||||
func genericSetPkValue(obj Entity, value interface{}) {
|
|
||||||
genericSet(obj, getSchemaFor(obj).pkName(), value)
|
|
||||||
}
|
|
||||||
|
|
||||||
func genericGetPKValue(obj Entity) interface{} {
|
|
||||||
t := reflect.TypeOf(obj)
|
|
||||||
val := reflect.ValueOf(obj)
|
|
||||||
if t.Kind() == reflect.Ptr {
|
|
||||||
val = val.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
fields := getSchemaFor(obj).fields
|
|
||||||
for i, field := range fields {
|
|
||||||
if field.IsPK {
|
|
||||||
return val.Field(i).Interface()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *schema) createdAt() *field {
|
|
||||||
for _, f := range s.fields {
|
|
||||||
if f.IsCreatedAt {
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func (s *schema) updatedAt() *field {
|
|
||||||
for _, f := range s.fields {
|
|
||||||
if f.IsUpdatedAt {
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *schema) deletedAt() *field {
|
|
||||||
for _, f := range s.fields {
|
|
||||||
if f.IsDeletedAt {
|
|
||||||
return f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
func pointersOf(v reflect.Value) map[string]interface{} {
|
|
||||||
m := map[string]interface{}{}
|
|
||||||
actualV := v
|
|
||||||
for actualV.Type().Kind() == reflect.Ptr {
|
|
||||||
actualV = actualV.Elem()
|
|
||||||
}
|
|
||||||
for i := 0; i < actualV.NumField(); i++ {
|
|
||||||
f := actualV.Field(i)
|
|
||||||
if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
|
|
||||||
fm := pointersOf(f)
|
|
||||||
for k, p := range fm {
|
|
||||||
m[k] = p
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
fm := fieldMetadata(actualV.Type().Field(i), nil)[0]
|
|
||||||
m[fm.Name] = actualV.Field(i)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return m
|
|
||||||
}
|
|
||||||
func genericSet(obj Entity, name string, value interface{}) {
|
|
||||||
n2p := pointersOf(reflect.ValueOf(obj))
|
|
||||||
var val interface{}
|
|
||||||
for k, v := range n2p {
|
|
||||||
if k == name {
|
|
||||||
val = v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
val.(reflect.Value).Set(reflect.ValueOf(value))
|
|
||||||
}
|
|
||||||
func schemaOfHeavyReflectionStuff(v Entity) *schema {
|
|
||||||
userEntityConfigurator := newEntityConfigurator()
|
|
||||||
v.ConfigureEntity(userEntityConfigurator)
|
|
||||||
for _, relation := range userEntityConfigurator.resolveRelations {
|
|
||||||
relation()
|
|
||||||
}
|
|
||||||
schema := &schema{}
|
|
||||||
if userEntityConfigurator.connection != "" {
|
|
||||||
schema.Connection = userEntityConfigurator.connection
|
|
||||||
}
|
|
||||||
if userEntityConfigurator.table != "" {
|
|
||||||
schema.Table = userEntityConfigurator.table
|
|
||||||
} else {
|
|
||||||
panic("you need to have table name for getting schema.")
|
|
||||||
}
|
|
||||||
|
|
||||||
schema.columnConstraints = userEntityConfigurator.columnConstraints
|
|
||||||
if schema.Connection == "" {
|
|
||||||
schema.Connection = "default"
|
|
||||||
}
|
|
||||||
if schema.fields == nil {
|
|
||||||
schema.fields = genericFieldsOf(v)
|
|
||||||
}
|
|
||||||
if schema.getPK == nil {
|
|
||||||
schema.getPK = genericGetPKValue
|
|
||||||
}
|
|
||||||
|
|
||||||
if schema.setPK == nil {
|
|
||||||
schema.setPK = genericSetPkValue
|
|
||||||
}
|
|
||||||
|
|
||||||
schema.relations = userEntityConfigurator.relations
|
|
||||||
|
|
||||||
return schema
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *schema) getTable() string {
|
|
||||||
return s.Table
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *schema) getSQLDB() *sql.DB {
|
|
||||||
return s.getConnection().DB
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *schema) getConnection() *connection {
|
|
||||||
if len(globalConnections) > 1 && (s.Connection == "" || s.Table == "") {
|
|
||||||
panic("need table and DB name when having more than 1 DB registered")
|
|
||||||
}
|
|
||||||
if len(globalConnections) == 1 {
|
|
||||||
for _, db := range globalConnections {
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if db, exists := globalConnections[fmt.Sprintf("%s", s.Connection)]; exists {
|
|
||||||
return db
|
|
||||||
}
|
|
||||||
panic("no db found")
|
|
||||||
}
|
|
@ -1,77 +0,0 @@
|
|||||||
//
|
|
||||||
// schema_test.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func setup(t *testing.T) {
|
|
||||||
db, err := sql.Open("sqlite3", ":memory:")
|
|
||||||
err = SetupConnections(ConnectionConfig{
|
|
||||||
Name: "default",
|
|
||||||
DB: db,
|
|
||||||
Dialect: Dialects.SQLite3,
|
|
||||||
})
|
|
||||||
// orm.Schematic()
|
|
||||||
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`)
|
|
||||||
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`)
|
|
||||||
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`)
|
|
||||||
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`)
|
|
||||||
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`)
|
|
||||||
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Object struct {
|
|
||||||
ID int64
|
|
||||||
Name string
|
|
||||||
Timestamps
|
|
||||||
SoftDelete
|
|
||||||
}
|
|
||||||
|
|
||||||
func (o Object) ConfigureEntity(e *EntityConfigurator) {
|
|
||||||
e.Table("objects").Connection("default")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenericFieldsOf(t *testing.T) {
|
|
||||||
t.Run("fields of with id and timestamps embedded", func(t *testing.T) {
|
|
||||||
fs := genericFieldsOf(&Object{})
|
|
||||||
assert.Len(t, fs, 5)
|
|
||||||
assert.Equal(t, "id", fs[0].Name)
|
|
||||||
assert.True(t, fs[0].IsPK)
|
|
||||||
assert.Equal(t, "name", fs[1].Name)
|
|
||||||
assert.Equal(t, "created_at", fs[2].Name)
|
|
||||||
assert.Equal(t, "updated_at", fs[3].Name)
|
|
||||||
assert.Equal(t, "deleted_at", fs[4].Name)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestGenericValuesOf(t *testing.T) {
|
|
||||||
t.Run("values of", func(t *testing.T) {
|
|
||||||
|
|
||||||
setup(t)
|
|
||||||
vs := genericValuesOf(Object{}, true)
|
|
||||||
assert.Len(t, vs, 5)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestEntityConfigurator(t *testing.T) {
|
|
||||||
t.Run("test has many with user provided values", func(t *testing.T) {
|
|
||||||
setup(t)
|
|
||||||
var ec EntityConfigurator
|
|
||||||
ec.Table("users").Connection("default").HasMany(Object{}, HasManyConfig{
|
|
||||||
"objects", "user_id",
|
|
||||||
})
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
@ -1,21 +0,0 @@
|
|||||||
//
|
|
||||||
// timestamps.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package orm
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Timestamps struct {
|
|
||||||
CreatedAt sql.NullTime
|
|
||||||
UpdatedAt sql.NullTime
|
|
||||||
}
|
|
||||||
|
|
||||||
type SoftDelete struct {
|
|
||||||
DeletedAt sql.NullTime
|
|
||||||
}
|
|
@ -1,78 +0,0 @@
|
|||||||
//
|
|
||||||
// column.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import "reflect"
|
|
||||||
|
|
||||||
// ColumnMap represents a mapping between a Go struct field and a single
|
|
||||||
// column in a table.
|
|
||||||
// Unique and MaxSize only inform the
|
|
||||||
// CreateTables() function and are not used by Insert/Update/Delete/Get.
|
|
||||||
type ColumnMap struct {
|
|
||||||
// Column name in db table
|
|
||||||
ColumnName string
|
|
||||||
|
|
||||||
// If true, this column is skipped in generated SQL statements
|
|
||||||
Transient bool
|
|
||||||
|
|
||||||
// If true, " unique" is added to create table statements.
|
|
||||||
// Not used elsewhere
|
|
||||||
Unique bool
|
|
||||||
|
|
||||||
// Query used for getting generated id after insert
|
|
||||||
GeneratedIdQuery string
|
|
||||||
|
|
||||||
// Passed to Dialect.ToSqlType() to assist in informing the
|
|
||||||
// correct column type to map to in CreateTables()
|
|
||||||
MaxSize int
|
|
||||||
|
|
||||||
DefaultValue string
|
|
||||||
|
|
||||||
fieldName string
|
|
||||||
gotype reflect.Type
|
|
||||||
isPK bool
|
|
||||||
isAutoIncr bool
|
|
||||||
isNotNull bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rename allows you to specify the column name in the table
|
|
||||||
//
|
|
||||||
// Example: table.ColMap("Updated").Rename("date_updated")
|
|
||||||
func (c *ColumnMap) Rename(colname string) *ColumnMap {
|
|
||||||
c.ColumnName = colname
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetTransient allows you to mark the column as transient. If true
|
|
||||||
// this column will be skipped when SQL statements are generated
|
|
||||||
func (c *ColumnMap) SetTransient(b bool) *ColumnMap {
|
|
||||||
c.Transient = b
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUnique adds "unique" to the create table statements for this
|
|
||||||
// column, if b is true.
|
|
||||||
func (c *ColumnMap) SetUnique(b bool) *ColumnMap {
|
|
||||||
c.Unique = b
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNotNull adds "not null" to the create table statements for this
|
|
||||||
// column, if nn is true.
|
|
||||||
func (c *ColumnMap) SetNotNull(nn bool) *ColumnMap {
|
|
||||||
c.isNotNull = nn
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetMaxSize specifies the max length of values of this column. This is
|
|
||||||
// passed to the dialect.ToSqlType() function, which can use the value
|
|
||||||
// to alter the generated type for "create table" statements
|
|
||||||
func (c *ColumnMap) SetMaxSize(size int) *ColumnMap {
|
|
||||||
c.MaxSize = size
|
|
||||||
return c
|
|
||||||
}
|
|
@ -1,82 +0,0 @@
|
|||||||
//
|
|
||||||
// context_test.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
//go:build integration
|
|
||||||
// +build integration
|
|
||||||
|
|
||||||
package sqldb_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Drivers that don't support cancellation.
|
|
||||||
var unsupportedDrivers map[string]bool = map[string]bool{
|
|
||||||
"mymysql": true,
|
|
||||||
}
|
|
||||||
|
|
||||||
type SleepDialect interface {
|
|
||||||
// string to sleep for d duration
|
|
||||||
SleepClause(d time.Duration) string
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithNotCanceledContext(t *testing.T) {
|
|
||||||
dbmap := initDBMap(t)
|
|
||||||
defer dropAndClose(dbmap)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
withCtx := dbmap.WithContext(ctx)
|
|
||||||
|
|
||||||
_, err := withCtx.Exec("SELECT 1")
|
|
||||||
assert.Nil(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestWithCanceledContext(t *testing.T) {
|
|
||||||
dialect, driver := dialectAndDriver()
|
|
||||||
if unsupportedDrivers[driver] {
|
|
||||||
t.Skipf("Cancellation is not yet supported by all drivers. Not known to be supported in %s.", driver)
|
|
||||||
}
|
|
||||||
|
|
||||||
sleepDialect, ok := dialect.(SleepDialect)
|
|
||||||
if !ok {
|
|
||||||
t.Skipf("Sleep is not supported in all dialects. Not known to be supported in %s.", driver)
|
|
||||||
}
|
|
||||||
|
|
||||||
dbmap := initDBMap(t)
|
|
||||||
defer dropAndClose(dbmap)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
withCtx := dbmap.WithContext(ctx)
|
|
||||||
|
|
||||||
startTime := time.Now()
|
|
||||||
|
|
||||||
_, err := withCtx.Exec("SELECT " + sleepDialect.SleepClause(1*time.Second))
|
|
||||||
|
|
||||||
if d := time.Since(startTime); d > 500*time.Millisecond {
|
|
||||||
t.Errorf("too long execution time: %s", d)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch driver {
|
|
||||||
case "postgres":
|
|
||||||
// pq doesn't return standard deadline exceeded error
|
|
||||||
if err.Error() != "pq: canceling statement due to user request" {
|
|
||||||
t.Errorf("expected context.DeadlineExceeded, got %v", err)
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
if err != context.DeadlineExceeded {
|
|
||||||
t.Errorf("expected context.DeadlineExceeded, got %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
1049
gdb/sqldb/db.go
1049
gdb/sqldb/db.go
File diff suppressed because it is too large
Load Diff
@ -1,187 +0,0 @@
|
|||||||
//
|
|
||||||
// db_test.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
//go:build integration
|
|
||||||
// +build integration
|
|
||||||
|
|
||||||
package sqldb_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
type customType1 []string
|
|
||||||
|
|
||||||
func (c customType1) ToStringSlice() []string {
|
|
||||||
return []string(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
type customType2 []int64
|
|
||||||
|
|
||||||
func (c customType2) ToInt64Slice() []int64 {
|
|
||||||
return []int64(c)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestDbMap_Select_expandSliceArgs(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
description string
|
|
||||||
query string
|
|
||||||
args []interface{}
|
|
||||||
wantLen int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
description: "it should handle slice placeholders correctly",
|
|
||||||
query: `
|
|
||||||
SELECT 1 FROM crazy_table
|
|
||||||
WHERE field1 = :Field1
|
|
||||||
AND field2 IN (:FieldStringList)
|
|
||||||
AND field3 IN (:FieldUIntList)
|
|
||||||
AND field4 IN (:FieldUInt8List)
|
|
||||||
AND field5 IN (:FieldUInt16List)
|
|
||||||
AND field6 IN (:FieldUInt32List)
|
|
||||||
AND field7 IN (:FieldUInt64List)
|
|
||||||
AND field8 IN (:FieldIntList)
|
|
||||||
AND field9 IN (:FieldInt8List)
|
|
||||||
AND field10 IN (:FieldInt16List)
|
|
||||||
AND field11 IN (:FieldInt32List)
|
|
||||||
AND field12 IN (:FieldInt64List)
|
|
||||||
AND field13 IN (:FieldFloat32List)
|
|
||||||
AND field14 IN (:FieldFloat64List)
|
|
||||||
`,
|
|
||||||
args: []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"Field1": 123,
|
|
||||||
"FieldStringList": []string{"h", "e", "y"},
|
|
||||||
"FieldUIntList": []uint{1, 2, 3, 4},
|
|
||||||
"FieldUInt8List": []uint8{1, 2, 3, 4},
|
|
||||||
"FieldUInt16List": []uint16{1, 2, 3, 4},
|
|
||||||
"FieldUInt32List": []uint32{1, 2, 3, 4},
|
|
||||||
"FieldUInt64List": []uint64{1, 2, 3, 4},
|
|
||||||
"FieldIntList": []int{1, 2, 3, 4},
|
|
||||||
"FieldInt8List": []int8{1, 2, 3, 4},
|
|
||||||
"FieldInt16List": []int16{1, 2, 3, 4},
|
|
||||||
"FieldInt32List": []int32{1, 2, 3, 4},
|
|
||||||
"FieldInt64List": []int64{1, 2, 3, 4},
|
|
||||||
"FieldFloat32List": []float32{1, 2, 3, 4},
|
|
||||||
"FieldFloat64List": []float64{1, 2, 3, 4},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantLen: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "it should handle slice placeholders correctly with custom types",
|
|
||||||
query: `
|
|
||||||
SELECT 1 FROM crazy_table
|
|
||||||
WHERE field2 IN (:FieldStringList)
|
|
||||||
AND field12 IN (:FieldIntList)
|
|
||||||
`,
|
|
||||||
args: []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"FieldStringList": customType1{"h", "e", "y"},
|
|
||||||
"FieldIntList": customType2{1, 2, 3, 4},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantLen: 3,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
type dataFormat struct {
|
|
||||||
Field1 int `db:"field1"`
|
|
||||||
Field2 string `db:"field2"`
|
|
||||||
Field3 uint `db:"field3"`
|
|
||||||
Field4 uint8 `db:"field4"`
|
|
||||||
Field5 uint16 `db:"field5"`
|
|
||||||
Field6 uint32 `db:"field6"`
|
|
||||||
Field7 uint64 `db:"field7"`
|
|
||||||
Field8 int `db:"field8"`
|
|
||||||
Field9 int8 `db:"field9"`
|
|
||||||
Field10 int16 `db:"field10"`
|
|
||||||
Field11 int32 `db:"field11"`
|
|
||||||
Field12 int64 `db:"field12"`
|
|
||||||
Field13 float32 `db:"field13"`
|
|
||||||
Field14 float64 `db:"field14"`
|
|
||||||
}
|
|
||||||
|
|
||||||
dbmap := newDBMap(t)
|
|
||||||
dbmap.ExpandSliceArgs = true
|
|
||||||
dbmap.AddTableWithName(dataFormat{}, "crazy_table")
|
|
||||||
|
|
||||||
err := dbmap.CreateTables()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
defer dropAndClose(dbmap)
|
|
||||||
|
|
||||||
err = dbmap.Insert(
|
|
||||||
&dataFormat{
|
|
||||||
Field1: 123,
|
|
||||||
Field2: "h",
|
|
||||||
Field3: 1,
|
|
||||||
Field4: 1,
|
|
||||||
Field5: 1,
|
|
||||||
Field6: 1,
|
|
||||||
Field7: 1,
|
|
||||||
Field8: 1,
|
|
||||||
Field9: 1,
|
|
||||||
Field10: 1,
|
|
||||||
Field11: 1,
|
|
||||||
Field12: 1,
|
|
||||||
Field13: 1,
|
|
||||||
Field14: 1,
|
|
||||||
},
|
|
||||||
&dataFormat{
|
|
||||||
Field1: 124,
|
|
||||||
Field2: "e",
|
|
||||||
Field3: 2,
|
|
||||||
Field4: 2,
|
|
||||||
Field5: 2,
|
|
||||||
Field6: 2,
|
|
||||||
Field7: 2,
|
|
||||||
Field8: 2,
|
|
||||||
Field9: 2,
|
|
||||||
Field10: 2,
|
|
||||||
Field11: 2,
|
|
||||||
Field12: 2,
|
|
||||||
Field13: 2,
|
|
||||||
Field14: 2,
|
|
||||||
},
|
|
||||||
&dataFormat{
|
|
||||||
Field1: 125,
|
|
||||||
Field2: "y",
|
|
||||||
Field3: 3,
|
|
||||||
Field4: 3,
|
|
||||||
Field5: 3,
|
|
||||||
Field6: 3,
|
|
||||||
Field7: 3,
|
|
||||||
Field8: 3,
|
|
||||||
Field9: 3,
|
|
||||||
Field10: 3,
|
|
||||||
Field11: 3,
|
|
||||||
Field12: 3,
|
|
||||||
Field13: 3,
|
|
||||||
Field14: 3,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.description, func(t *testing.T) {
|
|
||||||
var dummy []int
|
|
||||||
_, err := dbmap.Select(&dummy, tt.query, tt.args...)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(dummy) != tt.wantLen {
|
|
||||||
t.Errorf("wrong result count\ngot: %d\nwant: %d", len(dummy), tt.wantLen)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,108 +0,0 @@
|
|||||||
//
|
|
||||||
// dialect.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
// The Dialect interface encapsulates behaviors that differ across
|
|
||||||
// SQL databases. At present the Dialect is only used by CreateTables()
|
|
||||||
// but this could change in the future
|
|
||||||
type Dialect interface {
|
|
||||||
// adds a suffix to any query, usually ";"
|
|
||||||
QuerySuffix() string
|
|
||||||
|
|
||||||
// ToSqlType returns the SQL column type to use when creating a
|
|
||||||
// table of the given Go Type. maxsize can be used to switch based on
|
|
||||||
// size. For example, in MySQL []byte could map to BLOB, MEDIUMBLOB,
|
|
||||||
// or LONGBLOB depending on the maxsize
|
|
||||||
ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string
|
|
||||||
|
|
||||||
// string to append to primary key column definitions
|
|
||||||
AutoIncrStr() string
|
|
||||||
|
|
||||||
// string to bind autoincrement columns to. Empty string will
|
|
||||||
// remove reference to those columns in the INSERT statement.
|
|
||||||
AutoIncrBindValue() string
|
|
||||||
|
|
||||||
AutoIncrInsertSuffix(col *ColumnMap) string
|
|
||||||
|
|
||||||
// string to append to "create table" statement for vendor specific
|
|
||||||
// table attributes
|
|
||||||
CreateTableSuffix() string
|
|
||||||
|
|
||||||
// string to append to "create index" statement
|
|
||||||
CreateIndexSuffix() string
|
|
||||||
|
|
||||||
// string to append to "drop index" statement
|
|
||||||
DropIndexSuffix() string
|
|
||||||
|
|
||||||
// string to truncate tables
|
|
||||||
TruncateClause() string
|
|
||||||
|
|
||||||
// bind variable string to use when forming SQL statements
|
|
||||||
// in many dbs it is "?", but Postgres appears to use $1
|
|
||||||
//
|
|
||||||
// i is a zero based index of the bind variable in this statement
|
|
||||||
//
|
|
||||||
BindVar(i int) string
|
|
||||||
|
|
||||||
// Handles quoting of a field name to ensure that it doesn't raise any
|
|
||||||
// SQL parsing exceptions by using a reserved word as a field name.
|
|
||||||
QuoteField(field string) string
|
|
||||||
|
|
||||||
// Handles building up of a schema.database string that is compatible with
|
|
||||||
// the given dialect
|
|
||||||
//
|
|
||||||
// schema - The schema that <table> lives in
|
|
||||||
// table - The table name
|
|
||||||
QuotedTableForQuery(schema string, table string) string
|
|
||||||
|
|
||||||
// Existence clause for table creation / deletion
|
|
||||||
IfSchemaNotExists(command, schema string) string
|
|
||||||
IfTableExists(command, schema, table string) string
|
|
||||||
IfTableNotExists(command, schema, table string) string
|
|
||||||
}
|
|
||||||
|
|
||||||
// IntegerAutoIncrInserter is implemented by dialects that can perform
|
|
||||||
// inserts with automatically incremented integer primary keys. If
|
|
||||||
// the dialect can handle automatic assignment of more than just
|
|
||||||
// integers, see TargetedAutoIncrInserter.
|
|
||||||
type IntegerAutoIncrInserter interface {
|
|
||||||
InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TargetedAutoIncrInserter is implemented by dialects that can
|
|
||||||
// perform automatic assignment of any primary key type (i.e. strings
|
|
||||||
// for uuids, integers for serials, etc).
|
|
||||||
type TargetedAutoIncrInserter interface {
|
|
||||||
// InsertAutoIncrToTarget runs an insert operation and assigns the
|
|
||||||
// automatically generated primary key directly to the passed in
|
|
||||||
// target. The target should be a pointer to the primary key
|
|
||||||
// field of the value being inserted.
|
|
||||||
InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// TargetQueryInserter is implemented by dialects that can perform
|
|
||||||
// assignment of integer primary key type by executing a query
|
|
||||||
// like "select sequence.currval from dual".
|
|
||||||
type TargetQueryInserter interface {
|
|
||||||
// TargetQueryInserter runs an insert operation and assigns the
|
|
||||||
// automatically generated primary key retrived by the query
|
|
||||||
// extracted from the GeneratedIdQuery field of the id column.
|
|
||||||
InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error
|
|
||||||
}
|
|
||||||
|
|
||||||
func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
|
|
||||||
res, err := exec.Exec(insertSql, params...)
|
|
||||||
if err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return res.LastInsertId()
|
|
||||||
}
|
|
@ -1,172 +0,0 @@
|
|||||||
//
|
|
||||||
// dialect_mysql.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Implementation of Dialect for MySQL databases.
|
|
||||||
type MySQLDialect struct {
|
|
||||||
|
|
||||||
// Engine is the storage engine to use "InnoDB" vs "MyISAM" for example
|
|
||||||
Engine string
|
|
||||||
|
|
||||||
// Encoding is the character encoding to use for created tables
|
|
||||||
Encoding string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) QuerySuffix() string { return ";" }
|
|
||||||
|
|
||||||
func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
|
|
||||||
switch val.Kind() {
|
|
||||||
case reflect.Ptr:
|
|
||||||
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
|
|
||||||
case reflect.Bool:
|
|
||||||
return "boolean"
|
|
||||||
case reflect.Int8:
|
|
||||||
return "tinyint"
|
|
||||||
case reflect.Uint8:
|
|
||||||
return "tinyint unsigned"
|
|
||||||
case reflect.Int16:
|
|
||||||
return "smallint"
|
|
||||||
case reflect.Uint16:
|
|
||||||
return "smallint unsigned"
|
|
||||||
case reflect.Int, reflect.Int32:
|
|
||||||
return "int"
|
|
||||||
case reflect.Uint, reflect.Uint32:
|
|
||||||
return "int unsigned"
|
|
||||||
case reflect.Int64:
|
|
||||||
return "bigint"
|
|
||||||
case reflect.Uint64:
|
|
||||||
return "bigint unsigned"
|
|
||||||
case reflect.Float64, reflect.Float32:
|
|
||||||
return "double"
|
|
||||||
case reflect.Slice:
|
|
||||||
if val.Elem().Kind() == reflect.Uint8 {
|
|
||||||
return "mediumblob"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch val.Name() {
|
|
||||||
case "NullInt64":
|
|
||||||
return "bigint"
|
|
||||||
case "NullFloat64":
|
|
||||||
return "double"
|
|
||||||
case "NullBool":
|
|
||||||
return "tinyint"
|
|
||||||
case "Time":
|
|
||||||
return "datetime"
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxsize < 1 {
|
|
||||||
maxsize = 255
|
|
||||||
}
|
|
||||||
|
|
||||||
/* == About varchar(N) ==
|
|
||||||
* N is number of characters.
|
|
||||||
* A varchar column can store up to 65535 bytes.
|
|
||||||
* Remember that 1 character is 3 bytes in utf-8 charset.
|
|
||||||
* Also remember that each row can store up to 65535 bytes,
|
|
||||||
* and you have some overheads, so it's not possible for a
|
|
||||||
* varchar column to have 65535/3 characters really.
|
|
||||||
* So it would be better to use 'text' type in stead of
|
|
||||||
* large varchar type.
|
|
||||||
*/
|
|
||||||
if maxsize < 256 {
|
|
||||||
return fmt.Sprintf("varchar(%d)", maxsize)
|
|
||||||
} else {
|
|
||||||
return "text"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns auto_increment
|
|
||||||
func (d MySQLDialect) AutoIncrStr() string {
|
|
||||||
return "auto_increment"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) AutoIncrBindValue() string {
|
|
||||||
return "null"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns engine=%s charset=%s based on values stored on struct
|
|
||||||
func (d MySQLDialect) CreateTableSuffix() string {
|
|
||||||
if d.Engine == "" || d.Encoding == "" {
|
|
||||||
msg := "sqldb - undefined"
|
|
||||||
|
|
||||||
if d.Engine == "" {
|
|
||||||
msg += " MySQLDialect.Engine"
|
|
||||||
}
|
|
||||||
if d.Engine == "" && d.Encoding == "" {
|
|
||||||
msg += ","
|
|
||||||
}
|
|
||||||
if d.Encoding == "" {
|
|
||||||
msg += " MySQLDialect.Encoding"
|
|
||||||
}
|
|
||||||
msg += ". Check that your MySQLDialect was correctly initialized when declared."
|
|
||||||
panic(msg)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) CreateIndexSuffix() string {
|
|
||||||
return "using"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) DropIndexSuffix() string {
|
|
||||||
return "on"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) TruncateClause() string {
|
|
||||||
return "truncate"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) SleepClause(s time.Duration) string {
|
|
||||||
return fmt.Sprintf("sleep(%f)", s.Seconds())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns "?"
|
|
||||||
func (d MySQLDialect) BindVar(i int) string {
|
|
||||||
return "?"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
|
|
||||||
return standardInsertAutoIncr(exec, insertSql, params...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) QuoteField(f string) string {
|
|
||||||
return "`" + f + "`"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string {
|
|
||||||
if strings.TrimSpace(schema) == "" {
|
|
||||||
return d.QuoteField(table)
|
|
||||||
}
|
|
||||||
|
|
||||||
return schema + "." + d.QuoteField(table)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) IfSchemaNotExists(command, schema string) string {
|
|
||||||
return fmt.Sprintf("%s if not exists", command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) IfTableExists(command, schema, table string) string {
|
|
||||||
return fmt.Sprintf("%s if exists", command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d MySQLDialect) IfTableNotExists(command, schema, table string) string {
|
|
||||||
return fmt.Sprintf("%s if not exists", command)
|
|
||||||
}
|
|
@ -1,195 +0,0 @@
|
|||||||
//
|
|
||||||
// dialect_mysql_test.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
//go:build !integration
|
|
||||||
// +build !integration
|
|
||||||
|
|
||||||
package sqldb_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.hexq.cn/tiglog/golib/gdb/sqldb"
|
|
||||||
"github.com/poy/onpar"
|
|
||||||
"github.com/poy/onpar/expect"
|
|
||||||
"github.com/poy/onpar/matchers"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestMySQLDialect(t *testing.T) {
|
|
||||||
// o := onpar.New(t)
|
|
||||||
// defer o.Run()
|
|
||||||
|
|
||||||
type testContext struct {
|
|
||||||
t *testing.T
|
|
||||||
dialect sqldb.MySQLDialect
|
|
||||||
}
|
|
||||||
|
|
||||||
o := onpar.BeforeEach(onpar.New(t), func(t *testing.T) testContext {
|
|
||||||
return testContext{
|
|
||||||
t: t,
|
|
||||||
dialect: sqldb.MySQLDialect{Engine: "foo", Encoding: "bar"},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
defer o.Run()
|
|
||||||
|
|
||||||
o.Group("ToSqlType", func() {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
value interface{}
|
|
||||||
maxSize int
|
|
||||||
autoIncr bool
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"bool", true, 0, false, "boolean"},
|
|
||||||
{"int8", int8(1), 0, false, "tinyint"},
|
|
||||||
{"uint8", uint8(1), 0, false, "tinyint unsigned"},
|
|
||||||
{"int16", int16(1), 0, false, "smallint"},
|
|
||||||
{"uint16", uint16(1), 0, false, "smallint unsigned"},
|
|
||||||
{"int32", int32(1), 0, false, "int"},
|
|
||||||
{"int (treated as int32)", int(1), 0, false, "int"},
|
|
||||||
{"uint32", uint32(1), 0, false, "int unsigned"},
|
|
||||||
{"uint (treated as uint32)", uint(1), 0, false, "int unsigned"},
|
|
||||||
{"int64", int64(1), 0, false, "bigint"},
|
|
||||||
{"uint64", uint64(1), 0, false, "bigint unsigned"},
|
|
||||||
{"float32", float32(1), 0, false, "double"},
|
|
||||||
{"float64", float64(1), 0, false, "double"},
|
|
||||||
{"[]uint8", []uint8{1}, 0, false, "mediumblob"},
|
|
||||||
{"NullInt64", sql.NullInt64{}, 0, false, "bigint"},
|
|
||||||
{"NullFloat64", sql.NullFloat64{}, 0, false, "double"},
|
|
||||||
{"NullBool", sql.NullBool{}, 0, false, "tinyint"},
|
|
||||||
{"Time", time.Time{}, 0, false, "datetime"},
|
|
||||||
{"default-size string", "", 0, false, "varchar(255)"},
|
|
||||||
{"sized string", "", 50, false, "varchar(50)"},
|
|
||||||
{"large string", "", 1024, false, "text"},
|
|
||||||
}
|
|
||||||
for _, t := range tests {
|
|
||||||
o.Spec(t.name, func(tt testContext) {
|
|
||||||
typ := reflect.TypeOf(t.value)
|
|
||||||
sqlType := tt.dialect.ToSqlType(typ, t.maxSize, t.autoIncr)
|
|
||||||
expect.Expect(tt.t, sqlType).To(matchers.Equal(t.expected))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("AutoIncrStr", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.AutoIncrStr()).To(matchers.Equal("auto_increment"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("AutoIncrBindValue", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.AutoIncrBindValue()).To(matchers.Equal("null"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("AutoIncrInsertSuffix", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.AutoIncrInsertSuffix(nil)).To(matchers.Equal(""))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Group("CreateTableSuffix", func() {
|
|
||||||
o.Group("with an empty engine", func() {
|
|
||||||
o1 := onpar.BeforeEach(o, func(tt testContext) testContext {
|
|
||||||
tt.dialect.Encoding = ""
|
|
||||||
return tt
|
|
||||||
})
|
|
||||||
o1.Spec("panics", func(tt testContext) {
|
|
||||||
expect.Expect(t, func() { tt.dialect.CreateTableSuffix() }).To(Panic())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Group("with an empty encoding", func() {
|
|
||||||
o2 := onpar.BeforeEach(o, func(tt testContext) testContext {
|
|
||||||
tt.dialect.Encoding = ""
|
|
||||||
return tt
|
|
||||||
})
|
|
||||||
o2.Spec("panics", func(tt testContext) {
|
|
||||||
expect.Expect(t, func() { tt.dialect.CreateTableSuffix() }).To(Panic())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("with an engine and an encoding", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.CreateTableSuffix()).To(matchers.Equal(" engine=foo charset=bar"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("CreateIndexSuffix", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.CreateIndexSuffix()).To(matchers.Equal("using"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("DropIndexSuffix", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.DropIndexSuffix()).To(matchers.Equal("on"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("TruncateClause", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.TruncateClause()).To(matchers.Equal("truncate"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("SleepClause", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.SleepClause(1*time.Second)).To(matchers.Equal("sleep(1.000000)"))
|
|
||||||
expect.Expect(t, tt.dialect.SleepClause(100*time.Millisecond)).To(matchers.Equal("sleep(0.100000)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("BindVar", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.BindVar(0)).To(matchers.Equal("?"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("QuoteField", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.QuoteField("foo")).To(matchers.Equal("`foo`"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Group("QuotedTableForQuery", func() {
|
|
||||||
o.Spec("using the default schema", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.QuotedTableForQuery("", "foo")).To(matchers.Equal("`foo`"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("with a supplied schema", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("foo.`bar`"))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("IfSchemaNotExists", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.IfSchemaNotExists("foo", "bar")).To(matchers.Equal("foo if not exists"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("IfTableExists", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.IfTableExists("foo", "bar", "baz")).To(matchers.Equal("foo if exists"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("IfTableNotExists", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.IfTableNotExists("foo", "bar", "baz")).To(matchers.Equal("foo if not exists"))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
type panicMatcher struct {
|
|
||||||
}
|
|
||||||
|
|
||||||
func Panic() panicMatcher {
|
|
||||||
return panicMatcher{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m panicMatcher) Match(actual interface{}) (resultValue interface{}, err error) {
|
|
||||||
switch f := actual.(type) {
|
|
||||||
case func():
|
|
||||||
panicked := false
|
|
||||||
func() {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
panicked = true
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
f()
|
|
||||||
}()
|
|
||||||
if panicked {
|
|
||||||
return f, nil
|
|
||||||
}
|
|
||||||
return f, errors.New("function did not panic")
|
|
||||||
default:
|
|
||||||
return f, fmt.Errorf("%T is not func()", f)
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,142 +0,0 @@
|
|||||||
//
|
|
||||||
// dialect_oracle.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Implementation of Dialect for Oracle databases.
|
|
||||||
type OracleDialect struct{}
|
|
||||||
|
|
||||||
func (d OracleDialect) QuerySuffix() string { return "" }
|
|
||||||
|
|
||||||
func (d OracleDialect) CreateIndexSuffix() string { return "" }
|
|
||||||
|
|
||||||
func (d OracleDialect) DropIndexSuffix() string { return "" }
|
|
||||||
|
|
||||||
func (d OracleDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
|
|
||||||
switch val.Kind() {
|
|
||||||
case reflect.Ptr:
|
|
||||||
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
|
|
||||||
case reflect.Bool:
|
|
||||||
return "boolean"
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
|
|
||||||
if isAutoIncr {
|
|
||||||
return "serial"
|
|
||||||
}
|
|
||||||
return "integer"
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
if isAutoIncr {
|
|
||||||
return "bigserial"
|
|
||||||
}
|
|
||||||
return "bigint"
|
|
||||||
case reflect.Float64:
|
|
||||||
return "double precision"
|
|
||||||
case reflect.Float32:
|
|
||||||
return "real"
|
|
||||||
case reflect.Slice:
|
|
||||||
if val.Elem().Kind() == reflect.Uint8 {
|
|
||||||
return "bytea"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch val.Name() {
|
|
||||||
case "NullInt64":
|
|
||||||
return "bigint"
|
|
||||||
case "NullFloat64":
|
|
||||||
return "double precision"
|
|
||||||
case "NullBool":
|
|
||||||
return "boolean"
|
|
||||||
case "NullTime", "Time":
|
|
||||||
return "timestamp with time zone"
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxsize > 0 {
|
|
||||||
return fmt.Sprintf("varchar(%d)", maxsize)
|
|
||||||
} else {
|
|
||||||
return "text"
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns empty string
|
|
||||||
func (d OracleDialect) AutoIncrStr() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d OracleDialect) AutoIncrBindValue() string {
|
|
||||||
return "NULL"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d OracleDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns suffix
|
|
||||||
func (d OracleDialect) CreateTableSuffix() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d OracleDialect) TruncateClause() string {
|
|
||||||
return "truncate"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns "$(i+1)"
|
|
||||||
func (d OracleDialect) BindVar(i int) string {
|
|
||||||
return fmt.Sprintf(":%d", i+1)
|
|
||||||
}
|
|
||||||
|
|
||||||
// After executing the insert uses the ColMap IdQuery to get the generated id
|
|
||||||
func (d OracleDialect) InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error {
|
|
||||||
_, err := exec.Exec(insertSql, params...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
id, err := exec.SelectInt(idSql)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
switch target.(type) {
|
|
||||||
case *int64:
|
|
||||||
*(target.(*int64)) = id
|
|
||||||
case *int32:
|
|
||||||
*(target.(*int32)) = int32(id)
|
|
||||||
case int:
|
|
||||||
*(target.(*int)) = int(id)
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("Id field can be int, int32 or int64")
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d OracleDialect) QuoteField(f string) string {
|
|
||||||
return `"` + strings.ToUpper(f) + `"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d OracleDialect) QuotedTableForQuery(schema string, table string) string {
|
|
||||||
if strings.TrimSpace(schema) == "" {
|
|
||||||
return d.QuoteField(table)
|
|
||||||
}
|
|
||||||
|
|
||||||
return schema + "." + d.QuoteField(table)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d OracleDialect) IfSchemaNotExists(command, schema string) string {
|
|
||||||
return fmt.Sprintf("%s if not exists", command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d OracleDialect) IfTableExists(command, schema, table string) string {
|
|
||||||
return fmt.Sprintf("%s if exists", command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d OracleDialect) IfTableNotExists(command, schema, table string) string {
|
|
||||||
return fmt.Sprintf("%s if not exists", command)
|
|
||||||
}
|
|
@ -1,152 +0,0 @@
|
|||||||
//
|
|
||||||
// dialect_postgres.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
type PostgresDialect struct {
|
|
||||||
suffix string
|
|
||||||
LowercaseFields bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) QuerySuffix() string { return ";" }
|
|
||||||
|
|
||||||
func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
|
|
||||||
switch val.Kind() {
|
|
||||||
case reflect.Ptr:
|
|
||||||
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
|
|
||||||
case reflect.Bool:
|
|
||||||
return "boolean"
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
|
|
||||||
if isAutoIncr {
|
|
||||||
return "serial"
|
|
||||||
}
|
|
||||||
return "integer"
|
|
||||||
case reflect.Int64, reflect.Uint64:
|
|
||||||
if isAutoIncr {
|
|
||||||
return "bigserial"
|
|
||||||
}
|
|
||||||
return "bigint"
|
|
||||||
case reflect.Float64:
|
|
||||||
return "double precision"
|
|
||||||
case reflect.Float32:
|
|
||||||
return "real"
|
|
||||||
case reflect.Slice:
|
|
||||||
if val.Elem().Kind() == reflect.Uint8 {
|
|
||||||
return "bytea"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch val.Name() {
|
|
||||||
case "NullInt64":
|
|
||||||
return "bigint"
|
|
||||||
case "NullFloat64":
|
|
||||||
return "double precision"
|
|
||||||
case "NullBool":
|
|
||||||
return "boolean"
|
|
||||||
case "Time", "NullTime":
|
|
||||||
return "timestamp with time zone"
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxsize > 0 {
|
|
||||||
return fmt.Sprintf("varchar(%d)", maxsize)
|
|
||||||
} else {
|
|
||||||
return "text"
|
|
||||||
}
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns empty string
|
|
||||||
func (d PostgresDialect) AutoIncrStr() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) AutoIncrBindValue() string {
|
|
||||||
return "default"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
|
|
||||||
return " returning " + d.QuoteField(col.ColumnName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns suffix
|
|
||||||
func (d PostgresDialect) CreateTableSuffix() string {
|
|
||||||
return d.suffix
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) CreateIndexSuffix() string {
|
|
||||||
return "using"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) DropIndexSuffix() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) TruncateClause() string {
|
|
||||||
return "truncate"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) SleepClause(s time.Duration) string {
|
|
||||||
return fmt.Sprintf("pg_sleep(%f)", s.Seconds())
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns "$(i+1)"
|
|
||||||
func (d PostgresDialect) BindVar(i int) string {
|
|
||||||
return fmt.Sprintf("$%d", i+1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error {
|
|
||||||
rows, err := exec.Query(insertSql, params...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if !rows.Next() {
|
|
||||||
return fmt.Errorf("No serial value returned for insert: %s Encountered error: %s", insertSql, rows.Err())
|
|
||||||
}
|
|
||||||
if err := rows.Scan(target); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if rows.Next() {
|
|
||||||
return fmt.Errorf("more than two serial value returned for insert: %s", insertSql)
|
|
||||||
}
|
|
||||||
return rows.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) QuoteField(f string) string {
|
|
||||||
if d.LowercaseFields {
|
|
||||||
return `"` + strings.ToLower(f) + `"`
|
|
||||||
}
|
|
||||||
return `"` + f + `"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string {
|
|
||||||
if strings.TrimSpace(schema) == "" {
|
|
||||||
return d.QuoteField(table)
|
|
||||||
}
|
|
||||||
|
|
||||||
return schema + "." + d.QuoteField(table)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) IfSchemaNotExists(command, schema string) string {
|
|
||||||
return fmt.Sprintf("%s if not exists", command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) IfTableExists(command, schema, table string) string {
|
|
||||||
return fmt.Sprintf("%s if exists", command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d PostgresDialect) IfTableNotExists(command, schema, table string) string {
|
|
||||||
return fmt.Sprintf("%s if not exists", command)
|
|
||||||
}
|
|
@ -1,161 +0,0 @@
|
|||||||
//
|
|
||||||
// dialect_postgres_test.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
//go:build !integration
|
|
||||||
// +build !integration
|
|
||||||
|
|
||||||
package sqldb_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"reflect"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"git.hexq.cn/tiglog/golib/gdb/sqldb"
|
|
||||||
"github.com/poy/onpar"
|
|
||||||
"github.com/poy/onpar/expect"
|
|
||||||
"github.com/poy/onpar/matchers"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPostgresDialect(t *testing.T) {
|
|
||||||
|
|
||||||
type testContext struct {
|
|
||||||
t *testing.T
|
|
||||||
dialect sqldb.PostgresDialect
|
|
||||||
}
|
|
||||||
|
|
||||||
o := onpar.BeforeEach(onpar.New(t), func(t *testing.T) testContext {
|
|
||||||
return testContext{
|
|
||||||
t: t,
|
|
||||||
dialect: sqldb.PostgresDialect{
|
|
||||||
LowercaseFields: false,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
})
|
|
||||||
defer o.Run()
|
|
||||||
|
|
||||||
o.Group("ToSqlType", func() {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
value interface{}
|
|
||||||
maxSize int
|
|
||||||
autoIncr bool
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"bool", true, 0, false, "boolean"},
|
|
||||||
{"int8", int8(1), 0, false, "integer"},
|
|
||||||
{"uint8", uint8(1), 0, false, "integer"},
|
|
||||||
{"int16", int16(1), 0, false, "integer"},
|
|
||||||
{"uint16", uint16(1), 0, false, "integer"},
|
|
||||||
{"int32", int32(1), 0, false, "integer"},
|
|
||||||
{"int (treated as int32)", int(1), 0, false, "integer"},
|
|
||||||
{"uint32", uint32(1), 0, false, "integer"},
|
|
||||||
{"uint (treated as uint32)", uint(1), 0, false, "integer"},
|
|
||||||
{"int64", int64(1), 0, false, "bigint"},
|
|
||||||
{"uint64", uint64(1), 0, false, "bigint"},
|
|
||||||
{"float32", float32(1), 0, false, "real"},
|
|
||||||
{"float64", float64(1), 0, false, "double precision"},
|
|
||||||
{"[]uint8", []uint8{1}, 0, false, "bytea"},
|
|
||||||
{"NullInt64", sql.NullInt64{}, 0, false, "bigint"},
|
|
||||||
{"NullFloat64", sql.NullFloat64{}, 0, false, "double precision"},
|
|
||||||
{"NullBool", sql.NullBool{}, 0, false, "boolean"},
|
|
||||||
{"Time", time.Time{}, 0, false, "timestamp with time zone"},
|
|
||||||
{"default-size string", "", 0, false, "text"},
|
|
||||||
{"sized string", "", 50, false, "varchar(50)"},
|
|
||||||
{"large string", "", 1024, false, "varchar(1024)"},
|
|
||||||
}
|
|
||||||
for _, t := range tests {
|
|
||||||
o.Spec(t.name, func(tt testContext) {
|
|
||||||
typ := reflect.TypeOf(t.value)
|
|
||||||
sqlType := tt.dialect.ToSqlType(typ, t.maxSize, t.autoIncr)
|
|
||||||
expect.Expect(tt.t, sqlType).To(matchers.Equal(t.expected))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("AutoIncrStr", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.AutoIncrStr()).To(matchers.Equal(""))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("AutoIncrBindValue", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.AutoIncrBindValue()).To(matchers.Equal("default"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("AutoIncrInsertSuffix", func(tt testContext) {
|
|
||||||
cm := sqldb.ColumnMap{
|
|
||||||
ColumnName: "foo",
|
|
||||||
}
|
|
||||||
expect.Expect(t, tt.dialect.AutoIncrInsertSuffix(&cm)).To(matchers.Equal(` returning "foo"`))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("CreateTableSuffix", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.CreateTableSuffix()).To(matchers.Equal(""))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("CreateIndexSuffix", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.CreateIndexSuffix()).To(matchers.Equal("using"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("DropIndexSuffix", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.DropIndexSuffix()).To(matchers.Equal(""))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("TruncateClause", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.TruncateClause()).To(matchers.Equal("truncate"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("SleepClause", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.SleepClause(1*time.Second)).To(matchers.Equal("pg_sleep(1.000000)"))
|
|
||||||
expect.Expect(t, tt.dialect.SleepClause(100*time.Millisecond)).To(matchers.Equal("pg_sleep(0.100000)"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("BindVar", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.BindVar(0)).To(matchers.Equal("$1"))
|
|
||||||
expect.Expect(t, tt.dialect.BindVar(4)).To(matchers.Equal("$5"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Group("QuoteField", func() {
|
|
||||||
o.Spec("By default, case is preserved", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`))
|
|
||||||
expect.Expect(t, tt.dialect.QuoteField("bar")).To(matchers.Equal(`"bar"`))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Group("With LowercaseFields set to true", func() {
|
|
||||||
o1 := onpar.BeforeEach(o, func(tt testContext) testContext {
|
|
||||||
tt.dialect.LowercaseFields = true
|
|
||||||
return tt
|
|
||||||
})
|
|
||||||
|
|
||||||
o1.Spec("fields are lowercased", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.QuoteField("Foo")).To(matchers.Equal(`"foo"`))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Group("QuotedTableForQuery", func() {
|
|
||||||
o.Spec("using the default schema", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.QuotedTableForQuery("", "foo")).To(matchers.Equal(`"foo"`))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("with a supplied schema", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`foo."bar"`))
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("IfSchemaNotExists", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.IfSchemaNotExists("foo", "bar")).To(matchers.Equal("foo if not exists"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("IfTableExists", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.IfTableExists("foo", "bar", "baz")).To(matchers.Equal("foo if exists"))
|
|
||||||
})
|
|
||||||
|
|
||||||
o.Spec("IfTableNotExists", func(tt testContext) {
|
|
||||||
expect.Expect(t, tt.dialect.IfTableNotExists("foo", "bar", "baz")).To(matchers.Equal("foo if not exists"))
|
|
||||||
})
|
|
||||||
}
|
|
@ -1,115 +0,0 @@
|
|||||||
//
|
|
||||||
// dialect_sqlite.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
type SqliteDialect struct {
|
|
||||||
suffix string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d SqliteDialect) QuerySuffix() string { return ";" }
|
|
||||||
|
|
||||||
func (d SqliteDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
|
|
||||||
switch val.Kind() {
|
|
||||||
case reflect.Ptr:
|
|
||||||
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
|
|
||||||
case reflect.Bool:
|
|
||||||
return "integer"
|
|
||||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
|
||||||
return "integer"
|
|
||||||
case reflect.Float64, reflect.Float32:
|
|
||||||
return "real"
|
|
||||||
case reflect.Slice:
|
|
||||||
if val.Elem().Kind() == reflect.Uint8 {
|
|
||||||
return "blob"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
switch val.Name() {
|
|
||||||
case "NullInt64":
|
|
||||||
return "integer"
|
|
||||||
case "NullFloat64":
|
|
||||||
return "real"
|
|
||||||
case "NullBool":
|
|
||||||
return "integer"
|
|
||||||
case "Time":
|
|
||||||
return "datetime"
|
|
||||||
}
|
|
||||||
|
|
||||||
if maxsize < 1 {
|
|
||||||
maxsize = 255
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("varchar(%d)", maxsize)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns autoincrement
|
|
||||||
func (d SqliteDialect) AutoIncrStr() string {
|
|
||||||
return "autoincrement"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d SqliteDialect) AutoIncrBindValue() string {
|
|
||||||
return "null"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d SqliteDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns suffix
|
|
||||||
func (d SqliteDialect) CreateTableSuffix() string {
|
|
||||||
return d.suffix
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d SqliteDialect) CreateIndexSuffix() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d SqliteDialect) DropIndexSuffix() string {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// With sqlite, there technically isn't a TRUNCATE statement,
|
|
||||||
// but a DELETE FROM uses a truncate optimization:
|
|
||||||
// http://www.sqlite.org/lang_delete.html
|
|
||||||
func (d SqliteDialect) TruncateClause() string {
|
|
||||||
return "delete from"
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns "?"
|
|
||||||
func (d SqliteDialect) BindVar(i int) string {
|
|
||||||
return "?"
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d SqliteDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
|
|
||||||
return standardInsertAutoIncr(exec, insertSql, params...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d SqliteDialect) QuoteField(f string) string {
|
|
||||||
return `"` + f + `"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// sqlite does not have schemas like PostgreSQL does, so just escape it like normal
|
|
||||||
func (d SqliteDialect) QuotedTableForQuery(schema string, table string) string {
|
|
||||||
return d.QuoteField(table)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d SqliteDialect) IfSchemaNotExists(command, schema string) string {
|
|
||||||
return fmt.Sprintf("%s if not exists", command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d SqliteDialect) IfTableExists(command, schema, table string) string {
|
|
||||||
return fmt.Sprintf("%s if exists", command)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (d SqliteDialect) IfTableNotExists(command, schema, table string) string {
|
|
||||||
return fmt.Sprintf("%s if not exists", command)
|
|
||||||
}
|
|
@ -1,13 +0,0 @@
|
|||||||
//
|
|
||||||
// doc.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
// Package sqldb provides a simple way to marshal Go structs to and from
|
|
||||||
// SQL databases. It uses the database/sql package, and should work with any
|
|
||||||
// compliant database/sql driver.
|
|
||||||
//
|
|
||||||
// Source code and project home:
|
|
||||||
package sqldb
|
|
@ -1,34 +0,0 @@
|
|||||||
//
|
|
||||||
// errors.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A non-fatal error, when a select query returns columns that do not exist
|
|
||||||
// as fields in the struct it is being mapped to
|
|
||||||
// TODO: discuss wether this needs an error. encoding/json silently ignores missing fields
|
|
||||||
type NoFieldInTypeError struct {
|
|
||||||
TypeName string
|
|
||||||
MissingColNames []string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (err *NoFieldInTypeError) Error() string {
|
|
||||||
return fmt.Sprintf("sqldb: no fields %+v in type %s", err.MissingColNames, err.TypeName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// returns true if the error is non-fatal (ie, we shouldn't immediately return)
|
|
||||||
func NonFatalError(err error) bool {
|
|
||||||
switch err.(type) {
|
|
||||||
case *NoFieldInTypeError:
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,45 +0,0 @@
|
|||||||
//
|
|
||||||
// hooks.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
//++ TODO v2-phase3: HasPostGet => PostGetter, HasPostDelete => PostDeleter, etc.
|
|
||||||
|
|
||||||
// HasPostGet provides PostGet() which will be executed after the GET statement.
|
|
||||||
type HasPostGet interface {
|
|
||||||
PostGet(SqlExecutor) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasPostDelete provides PostDelete() which will be executed after the DELETE statement
|
|
||||||
type HasPostDelete interface {
|
|
||||||
PostDelete(SqlExecutor) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasPostUpdate provides PostUpdate() which will be executed after the UPDATE statement
|
|
||||||
type HasPostUpdate interface {
|
|
||||||
PostUpdate(SqlExecutor) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasPostInsert provides PostInsert() which will be executed after the INSERT statement
|
|
||||||
type HasPostInsert interface {
|
|
||||||
PostInsert(SqlExecutor) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasPreDelete provides PreDelete() which will be executed before the DELETE statement.
|
|
||||||
type HasPreDelete interface {
|
|
||||||
PreDelete(SqlExecutor) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasPreUpdate provides PreUpdate() which will be executed before UPDATE statement.
|
|
||||||
type HasPreUpdate interface {
|
|
||||||
PreUpdate(SqlExecutor) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasPreInsert provides PreInsert() which will be executed before INSERT statement.
|
|
||||||
type HasPreInsert interface {
|
|
||||||
PreInsert(SqlExecutor) error
|
|
||||||
}
|
|
@ -1,51 +0,0 @@
|
|||||||
//
|
|
||||||
// index.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
// IndexMap represents a mapping between a Go struct field and a single
|
|
||||||
// index in a table.
|
|
||||||
// Unique and MaxSize only inform the
|
|
||||||
// CreateTables() function and are not used by Insert/Update/Delete/Get.
|
|
||||||
type IndexMap struct {
|
|
||||||
// Index name in db table
|
|
||||||
IndexName string
|
|
||||||
|
|
||||||
// If true, " unique" is added to create index statements.
|
|
||||||
// Not used elsewhere
|
|
||||||
Unique bool
|
|
||||||
|
|
||||||
// Index type supported by Dialect
|
|
||||||
// Postgres: B-tree, Hash, GiST and GIN.
|
|
||||||
// Mysql: Btree, Hash.
|
|
||||||
// Sqlite: nil.
|
|
||||||
IndexType string
|
|
||||||
|
|
||||||
// Columns name for single and multiple indexes
|
|
||||||
columns []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rename allows you to specify the index name in the table
|
|
||||||
//
|
|
||||||
// Example: table.IndMap("customer_test_idx").Rename("customer_idx")
|
|
||||||
func (idx *IndexMap) Rename(indname string) *IndexMap {
|
|
||||||
idx.IndexName = indname
|
|
||||||
return idx
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUnique adds "unique" to the create index statements for this
|
|
||||||
// index, if b is true.
|
|
||||||
func (idx *IndexMap) SetUnique(b bool) *IndexMap {
|
|
||||||
idx.Unique = b
|
|
||||||
return idx
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetIndexType specifies the index type supported by chousen SQL Dialect
|
|
||||||
func (idx *IndexMap) SetIndexType(indtype string) *IndexMap {
|
|
||||||
idx.IndexType = indtype
|
|
||||||
return idx
|
|
||||||
}
|
|
@ -1,59 +0,0 @@
|
|||||||
//
|
|
||||||
// lockerror.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OptimisticLockError is returned by Update() or Delete() if the
|
|
||||||
// struct being modified has a Version field and the value is not equal to
|
|
||||||
// the current value in the database
|
|
||||||
type OptimisticLockError struct {
|
|
||||||
// Table name where the lock error occurred
|
|
||||||
TableName string
|
|
||||||
|
|
||||||
// Primary key values of the row being updated/deleted
|
|
||||||
Keys []interface{}
|
|
||||||
|
|
||||||
// true if a row was found with those keys, indicating the
|
|
||||||
// LocalVersion is stale. false if no value was found with those
|
|
||||||
// keys, suggesting the row has been deleted since loaded, or
|
|
||||||
// was never inserted to begin with
|
|
||||||
RowExists bool
|
|
||||||
|
|
||||||
// Version value on the struct passed to Update/Delete. This value is
|
|
||||||
// out of sync with the database.
|
|
||||||
LocalVersion int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// Error returns a description of the cause of the lock error
|
|
||||||
func (e OptimisticLockError) Error() string {
|
|
||||||
if e.RowExists {
|
|
||||||
return fmt.Sprintf("sqldb: OptimisticLockError table=%s keys=%v out of date version=%d", e.TableName, e.Keys, e.LocalVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Sprintf("sqldb: OptimisticLockError no row found for table=%s keys=%v", e.TableName, e.Keys)
|
|
||||||
}
|
|
||||||
|
|
||||||
func lockError(m *DbMap, exec SqlExecutor, tableName string,
|
|
||||||
existingVer int64, elem reflect.Value,
|
|
||||||
keys ...interface{}) (int64, error) {
|
|
||||||
|
|
||||||
existing, err := get(m, exec, elem.Interface(), keys...)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
ole := OptimisticLockError{tableName, keys, true, existingVer}
|
|
||||||
if existing == nil {
|
|
||||||
ole.RowExists = false
|
|
||||||
}
|
|
||||||
return -1, ole
|
|
||||||
}
|
|
@ -1,45 +0,0 @@
|
|||||||
//
|
|
||||||
// logging.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import "fmt"
|
|
||||||
|
|
||||||
// SqldbLogger is a deprecated alias of Logger.
|
|
||||||
type SqldbLogger = Logger
|
|
||||||
|
|
||||||
// Logger is the type that sqldb uses to log SQL statements.
|
|
||||||
// See DbMap.TraceOn.
|
|
||||||
type Logger interface {
|
|
||||||
Printf(format string, v ...interface{})
|
|
||||||
}
|
|
||||||
|
|
||||||
// TraceOn turns on SQL statement logging for this DbMap. After this is
|
|
||||||
// called, all SQL statements will be sent to the logger. If prefix is
|
|
||||||
// a non-empty string, it will be written to the front of all logged
|
|
||||||
// strings, which can aid in filtering log lines.
|
|
||||||
//
|
|
||||||
// Use TraceOn if you want to spy on the SQL statements that sqldb
|
|
||||||
// generates.
|
|
||||||
//
|
|
||||||
// Note that the base log.Logger type satisfies Logger, but adapters can
|
|
||||||
// easily be written for other logging packages (e.g., the golang-sanctioned
|
|
||||||
// glog framework).
|
|
||||||
func (m *DbMap) TraceOn(prefix string, logger Logger) {
|
|
||||||
m.logger = logger
|
|
||||||
if prefix == "" {
|
|
||||||
m.logPrefix = prefix
|
|
||||||
} else {
|
|
||||||
m.logPrefix = fmt.Sprintf("%s ", prefix)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// TraceOff turns off tracing. It is idempotent.
|
|
||||||
func (m *DbMap) TraceOff() {
|
|
||||||
m.logger = nil
|
|
||||||
m.logPrefix = ""
|
|
||||||
}
|
|
@ -1,68 +0,0 @@
|
|||||||
//
|
|
||||||
// nulltypes.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql/driver"
|
|
||||||
"log"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// A nullable Time value
|
|
||||||
type NullTime struct {
|
|
||||||
Time time.Time
|
|
||||||
Valid bool // Valid is true if Time is not NULL
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan implements the Scanner interface.
|
|
||||||
func (nt *NullTime) Scan(value interface{}) error {
|
|
||||||
log.Printf("Time scan value is: %#v", value)
|
|
||||||
switch t := value.(type) {
|
|
||||||
case time.Time:
|
|
||||||
nt.Time, nt.Valid = t, true
|
|
||||||
case []byte:
|
|
||||||
v := strToTime(string(t))
|
|
||||||
if v != nil {
|
|
||||||
nt.Valid = true
|
|
||||||
nt.Time = *v
|
|
||||||
}
|
|
||||||
case string:
|
|
||||||
v := strToTime(t)
|
|
||||||
if v != nil {
|
|
||||||
nt.Valid = true
|
|
||||||
nt.Time = *v
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func strToTime(v string) *time.Time {
|
|
||||||
for _, dtfmt := range []string{
|
|
||||||
"2006-01-02 15:04:05.999999999",
|
|
||||||
"2006-01-02T15:04:05.999999999",
|
|
||||||
"2006-01-02 15:04:05",
|
|
||||||
"2006-01-02T15:04:05",
|
|
||||||
"2006-01-02 15:04",
|
|
||||||
"2006-01-02T15:04",
|
|
||||||
"2006-01-02",
|
|
||||||
"2006-01-02 15:04:05-07:00",
|
|
||||||
} {
|
|
||||||
if t, err := time.Parse(dtfmt, v); err == nil {
|
|
||||||
return &t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value implements the driver Valuer interface.
|
|
||||||
func (nt NullTime) Value() (driver.Value, error) {
|
|
||||||
if !nt.Valid {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return nt.Time, nil
|
|
||||||
}
|
|
@ -1,40 +0,0 @@
|
|||||||
//
|
|
||||||
// query_builder.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import "reflect"
|
|
||||||
|
|
||||||
// 该功能用于改善记录查询,避免直接写表名
|
|
||||||
// TODO 实现 query builder
|
|
||||||
|
|
||||||
type join_item struct {
|
|
||||||
way string
|
|
||||||
table string
|
|
||||||
on string
|
|
||||||
}
|
|
||||||
|
|
||||||
type query_builder struct {
|
|
||||||
table string
|
|
||||||
fields string
|
|
||||||
conds []string
|
|
||||||
orderBy string
|
|
||||||
offset int
|
|
||||||
limit int
|
|
||||||
joins []join_item
|
|
||||||
}
|
|
||||||
|
|
||||||
func FromEntity(ent any) *query_builder {
|
|
||||||
tabM, err := dm.TableFor(reflect.TypeOf(ent), false)
|
|
||||||
if err != nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return &query_builder{
|
|
||||||
table: tabM.TableName,
|
|
||||||
}
|
|
||||||
}
|
|
@ -1,361 +0,0 @@
|
|||||||
//
|
|
||||||
// select.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"database/sql"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
)
|
|
||||||
|
|
||||||
// SelectInt executes the given query, which should be a SELECT statement for a single
|
|
||||||
// integer column, and returns the value of the first row returned. If no rows are
|
|
||||||
// found, zero is returned.
|
|
||||||
func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) {
|
|
||||||
var h int64
|
|
||||||
err := selectVal(e, &h, query, args...)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectNullInt executes the given query, which should be a SELECT statement for a single
|
|
||||||
// integer column, and returns the value of the first row returned. If no rows are
|
|
||||||
// found, the empty sql.NullInt64 value is returned.
|
|
||||||
func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullInt64, error) {
|
|
||||||
var h sql.NullInt64
|
|
||||||
err := selectVal(e, &h, query, args...)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
return h, err
|
|
||||||
}
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectFloat executes the given query, which should be a SELECT statement for a single
|
|
||||||
// float column, and returns the value of the first row returned. If no rows are
|
|
||||||
// found, zero is returned.
|
|
||||||
func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, error) {
|
|
||||||
var h float64
|
|
||||||
err := selectVal(e, &h, query, args...)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectNullFloat executes the given query, which should be a SELECT statement for a single
|
|
||||||
// float column, and returns the value of the first row returned. If no rows are
|
|
||||||
// found, the empty sql.NullInt64 value is returned.
|
|
||||||
func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.NullFloat64, error) {
|
|
||||||
var h sql.NullFloat64
|
|
||||||
err := selectVal(e, &h, query, args...)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
return h, err
|
|
||||||
}
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectStr executes the given query, which should be a SELECT statement for a single
|
|
||||||
// char/varchar column, and returns the value of the first row returned. If no rows are
|
|
||||||
// found, an empty string is returned.
|
|
||||||
func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) {
|
|
||||||
var h string
|
|
||||||
err := selectVal(e, &h, query, args...)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectNullStr executes the given query, which should be a SELECT
|
|
||||||
// statement for a single char/varchar column, and returns the value
|
|
||||||
// of the first row returned. If no rows are found, the empty
|
|
||||||
// sql.NullString is returned.
|
|
||||||
func SelectNullStr(e SqlExecutor, query string, args ...interface{}) (sql.NullString, error) {
|
|
||||||
var h sql.NullString
|
|
||||||
err := selectVal(e, &h, query, args...)
|
|
||||||
if err != nil && err != sql.ErrNoRows {
|
|
||||||
return h, err
|
|
||||||
}
|
|
||||||
return h, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectOne executes the given query (which should be a SELECT statement)
|
|
||||||
// and binds the result to holder, which must be a pointer.
|
|
||||||
//
|
|
||||||
// # If no row is found, an error (sql.ErrNoRows specifically) will be returned
|
|
||||||
//
|
|
||||||
// If more than one row is found, an error will be returned.
|
|
||||||
func SelectOne(m *DbMap, e SqlExecutor, holder interface{}, query string, args ...interface{}) error {
|
|
||||||
t := reflect.TypeOf(holder)
|
|
||||||
if t.Kind() == reflect.Ptr {
|
|
||||||
t = t.Elem()
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("sqldb: SelectOne holder must be a pointer, but got: %t", holder)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle pointer to pointer
|
|
||||||
isptr := false
|
|
||||||
if t.Kind() == reflect.Ptr {
|
|
||||||
isptr = true
|
|
||||||
t = t.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.Kind() == reflect.Struct {
|
|
||||||
var nonFatalErr error
|
|
||||||
|
|
||||||
list, err := hookedselect(m, e, holder, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
if !NonFatalError(err) { // FIXME: double negative, rename NonFatalError to FatalError
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
nonFatalErr = err
|
|
||||||
}
|
|
||||||
|
|
||||||
dest := reflect.ValueOf(holder)
|
|
||||||
if isptr {
|
|
||||||
dest = dest.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if list != nil && len(list) > 0 { // FIXME: invert if/else
|
|
||||||
// check for multiple rows
|
|
||||||
if len(list) > 1 {
|
|
||||||
return fmt.Errorf("sqldb: multiple rows returned for: %s - %v", query, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Initialize if nil
|
|
||||||
if dest.IsNil() {
|
|
||||||
dest.Set(reflect.New(t))
|
|
||||||
}
|
|
||||||
|
|
||||||
// only one row found
|
|
||||||
src := reflect.ValueOf(list[0])
|
|
||||||
dest.Elem().Set(src.Elem())
|
|
||||||
} else {
|
|
||||||
// No rows found, return a proper error.
|
|
||||||
return sql.ErrNoRows
|
|
||||||
}
|
|
||||||
|
|
||||||
return nonFatalErr
|
|
||||||
}
|
|
||||||
|
|
||||||
return selectVal(e, holder, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func selectVal(e SqlExecutor, holder interface{}, query string, args ...interface{}) error {
|
|
||||||
if len(args) == 1 {
|
|
||||||
switch m := e.(type) {
|
|
||||||
case *DbMap:
|
|
||||||
query, args = maybeExpandNamedQuery(m, query, args)
|
|
||||||
case *Transaction:
|
|
||||||
query, args = maybeExpandNamedQuery(m.dbmap, query, args)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
rows, err := e.Query(query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
if !rows.Next() {
|
|
||||||
if err := rows.Err(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return sql.ErrNoRows
|
|
||||||
}
|
|
||||||
|
|
||||||
return rows.Scan(holder)
|
|
||||||
}
|
|
||||||
|
|
||||||
func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
|
|
||||||
args ...interface{}) ([]interface{}, error) {
|
|
||||||
|
|
||||||
var nonFatalErr error
|
|
||||||
|
|
||||||
list, err := rawselect(m, exec, i, query, args...)
|
|
||||||
if err != nil {
|
|
||||||
if !NonFatalError(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
nonFatalErr = err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Determine where the results are: written to i, or returned in list
|
|
||||||
if t, _ := toSliceType(i); t == nil {
|
|
||||||
for _, v := range list {
|
|
||||||
if v, ok := v.(HasPostGet); ok {
|
|
||||||
err := v.PostGet(exec)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
resultsValue := reflect.Indirect(reflect.ValueOf(i))
|
|
||||||
for i := 0; i < resultsValue.Len(); i++ {
|
|
||||||
if v, ok := resultsValue.Index(i).Interface().(HasPostGet); ok {
|
|
||||||
err := v.PostGet(exec)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return list, nonFatalErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func rawselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
|
|
||||||
args ...interface{}) ([]interface{}, error) {
|
|
||||||
var (
|
|
||||||
appendToSlice = false // Write results to i directly?
|
|
||||||
intoStruct = true // Selecting into a struct?
|
|
||||||
pointerElements = true // Are the slice elements pointers (vs values)?
|
|
||||||
)
|
|
||||||
|
|
||||||
var nonFatalErr error
|
|
||||||
|
|
||||||
tableName := ""
|
|
||||||
var dynObj DynamicTable
|
|
||||||
isDynamic := false
|
|
||||||
if dynObj, isDynamic = i.(DynamicTable); isDynamic {
|
|
||||||
tableName = dynObj.TableName()
|
|
||||||
}
|
|
||||||
|
|
||||||
// get type for i, verifying it's a supported destination
|
|
||||||
t, err := toType(i)
|
|
||||||
if err != nil {
|
|
||||||
var err2 error
|
|
||||||
if t, err2 = toSliceType(i); t == nil {
|
|
||||||
if err2 != nil {
|
|
||||||
return nil, err2
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
pointerElements = t.Kind() == reflect.Ptr
|
|
||||||
if pointerElements {
|
|
||||||
t = t.Elem()
|
|
||||||
}
|
|
||||||
appendToSlice = true
|
|
||||||
intoStruct = t.Kind() == reflect.Struct
|
|
||||||
}
|
|
||||||
|
|
||||||
// If the caller supplied a single struct/map argument, assume a "named
|
|
||||||
// parameter" query. Extract the named arguments from the struct/map, create
|
|
||||||
// the flat arg slice, and rewrite the query to use the dialect's placeholder.
|
|
||||||
if len(args) == 1 {
|
|
||||||
query, args = maybeExpandNamedQuery(m, query, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Run the query
|
|
||||||
rows, err := exec.Query(query, args...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
defer rows.Close()
|
|
||||||
|
|
||||||
// Fetch the column names as returned from db
|
|
||||||
cols, err := rows.Columns()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !intoStruct && len(cols) > 1 {
|
|
||||||
return nil, fmt.Errorf("sqldb: select into non-struct slice requires 1 column, got %d", len(cols))
|
|
||||||
}
|
|
||||||
|
|
||||||
var colToFieldIndex [][]int
|
|
||||||
if intoStruct {
|
|
||||||
colToFieldIndex, err = columnToFieldIndex(m, t, tableName, cols)
|
|
||||||
if err != nil {
|
|
||||||
if !NonFatalError(err) {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
nonFatalErr = err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
conv := m.TypeConverter
|
|
||||||
|
|
||||||
// Add results to one of these two slices.
|
|
||||||
var (
|
|
||||||
list = make([]interface{}, 0)
|
|
||||||
sliceValue = reflect.Indirect(reflect.ValueOf(i))
|
|
||||||
)
|
|
||||||
|
|
||||||
for {
|
|
||||||
if !rows.Next() {
|
|
||||||
// if error occured return rawselect
|
|
||||||
if rows.Err() != nil {
|
|
||||||
return nil, rows.Err()
|
|
||||||
}
|
|
||||||
// time to exit from outer "for" loop
|
|
||||||
break
|
|
||||||
}
|
|
||||||
v := reflect.New(t)
|
|
||||||
|
|
||||||
if isDynamic {
|
|
||||||
v.Interface().(DynamicTable).SetTableName(tableName)
|
|
||||||
}
|
|
||||||
|
|
||||||
dest := make([]interface{}, len(cols))
|
|
||||||
|
|
||||||
custScan := make([]CustomScanner, 0)
|
|
||||||
|
|
||||||
for x := range cols {
|
|
||||||
f := v.Elem()
|
|
||||||
if intoStruct {
|
|
||||||
index := colToFieldIndex[x]
|
|
||||||
if index == nil {
|
|
||||||
// this field is not present in the struct, so create a dummy
|
|
||||||
// value for rows.Scan to scan into
|
|
||||||
var dummy dummyField
|
|
||||||
dest[x] = &dummy
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
f = f.FieldByIndex(index)
|
|
||||||
}
|
|
||||||
target := f.Addr().Interface()
|
|
||||||
if conv != nil {
|
|
||||||
scanner, ok := conv.FromDb(target)
|
|
||||||
if ok {
|
|
||||||
target = scanner.Holder
|
|
||||||
custScan = append(custScan, scanner)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
dest[x] = target
|
|
||||||
}
|
|
||||||
|
|
||||||
err = rows.Scan(dest...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range custScan {
|
|
||||||
err = c.Bind()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if appendToSlice {
|
|
||||||
if !pointerElements {
|
|
||||||
v = v.Elem()
|
|
||||||
}
|
|
||||||
sliceValue.Set(reflect.Append(sliceValue, v))
|
|
||||||
} else {
|
|
||||||
list = append(list, v.Interface())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if appendToSlice && sliceValue.IsNil() {
|
|
||||||
sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), 0, 0))
|
|
||||||
}
|
|
||||||
|
|
||||||
return list, nonFatalErr
|
|
||||||
}
|
|
@ -1,675 +0,0 @@
|
|||||||
//
|
|
||||||
// sqldb.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"database/sql/driver"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"regexp"
|
|
||||||
"strings"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// OracleString (empty string is null)
|
|
||||||
// TODO: move to dialect/oracle?, rename to String?
|
|
||||||
type OracleString struct {
|
|
||||||
sql.NullString
|
|
||||||
}
|
|
||||||
|
|
||||||
// Scan implements the Scanner interface.
|
|
||||||
func (os *OracleString) Scan(value interface{}) error {
|
|
||||||
if value == nil {
|
|
||||||
os.String, os.Valid = "", false
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
os.Valid = true
|
|
||||||
return os.NullString.Scan(value)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Value implements the driver Valuer interface.
|
|
||||||
func (os OracleString) Value() (driver.Value, error) {
|
|
||||||
if !os.Valid || os.String == "" {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return os.String, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SqlTyper is a type that returns its database type. Most of the
|
|
||||||
// time, the type can just use "database/sql/driver".Valuer; but when
|
|
||||||
// it returns nil for its empty value, it needs to implement SqlTyper
|
|
||||||
// to have its column type detected properly during table creation.
|
|
||||||
type SqlTyper interface {
|
|
||||||
SqlType() driver.Value
|
|
||||||
}
|
|
||||||
|
|
||||||
// legacySqlTyper prevents breaking clients who depended on the previous
|
|
||||||
// SqlTyper interface
|
|
||||||
type legacySqlTyper interface {
|
|
||||||
SqlType() driver.Valuer
|
|
||||||
}
|
|
||||||
|
|
||||||
// for fields that exists in DB table, but not exists in struct
|
|
||||||
type dummyField struct{}
|
|
||||||
|
|
||||||
// Scan implements the Scanner interface.
|
|
||||||
func (nt *dummyField) Scan(value interface{}) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var zeroVal reflect.Value
|
|
||||||
var versFieldConst = "[sqldb_ver_field]"
|
|
||||||
|
|
||||||
// The TypeConverter interface provides a way to map a value of one
|
|
||||||
// type to another type when persisting to, or loading from, a database.
|
|
||||||
//
|
|
||||||
// Example use cases: Implement type converter to convert bool types to "y"/"n" strings,
|
|
||||||
// or serialize a struct member as a JSON blob.
|
|
||||||
type TypeConverter interface {
|
|
||||||
// ToDb converts val to another type. Called before INSERT/UPDATE operations
|
|
||||||
ToDb(val interface{}) (interface{}, error)
|
|
||||||
|
|
||||||
// FromDb returns a CustomScanner appropriate for this type. This will be used
|
|
||||||
// to hold values returned from SELECT queries.
|
|
||||||
//
|
|
||||||
// In particular the CustomScanner returned should implement a Binder
|
|
||||||
// function appropriate for the Go type you wish to convert the db value to
|
|
||||||
//
|
|
||||||
// If bool==false, then no custom scanner will be used for this field.
|
|
||||||
FromDb(target interface{}) (CustomScanner, bool)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SqlExecutor exposes sqldb operations that can be run from Pre/Post
|
|
||||||
// hooks. This hides whether the current operation that triggered the
|
|
||||||
// hook is in a transaction.
|
|
||||||
//
|
|
||||||
// See the DbMap function docs for each of the functions below for more
|
|
||||||
// information.
|
|
||||||
type SqlExecutor interface {
|
|
||||||
WithContext(ctx context.Context) SqlExecutor
|
|
||||||
Get(i interface{}, keys ...interface{}) (interface{}, error)
|
|
||||||
Insert(list ...interface{}) error
|
|
||||||
Update(list ...interface{}) (int64, error)
|
|
||||||
Delete(list ...interface{}) (int64, error)
|
|
||||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
|
||||||
Select(i interface{}, query string, args ...interface{}) ([]interface{}, error)
|
|
||||||
SelectInt(query string, args ...interface{}) (int64, error)
|
|
||||||
SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error)
|
|
||||||
SelectFloat(query string, args ...interface{}) (float64, error)
|
|
||||||
SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error)
|
|
||||||
SelectStr(query string, args ...interface{}) (string, error)
|
|
||||||
SelectNullStr(query string, args ...interface{}) (sql.NullString, error)
|
|
||||||
SelectOne(holder interface{}, query string, args ...interface{}) error
|
|
||||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
|
||||||
QueryRow(query string, args ...interface{}) *sql.Row
|
|
||||||
}
|
|
||||||
|
|
||||||
// DynamicTable allows the users of sqldb to dynamically
|
|
||||||
// use different database table names during runtime
|
|
||||||
// while sharing the same golang struct for in-memory data
|
|
||||||
type DynamicTable interface {
|
|
||||||
TableName() string
|
|
||||||
SetTableName(string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Compile-time check that DbMap and Transaction implement the SqlExecutor
|
|
||||||
// interface.
|
|
||||||
var _, _ SqlExecutor = &DbMap{}, &Transaction{}
|
|
||||||
|
|
||||||
func argValue(a interface{}) interface{} {
|
|
||||||
v, ok := a.(driver.Valuer)
|
|
||||||
if !ok {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
vV := reflect.ValueOf(v)
|
|
||||||
if vV.Kind() == reflect.Ptr && vV.IsNil() {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
ret, err := v.Value()
|
|
||||||
if err != nil {
|
|
||||||
return a
|
|
||||||
}
|
|
||||||
return ret
|
|
||||||
}
|
|
||||||
|
|
||||||
func argsString(args ...interface{}) string {
|
|
||||||
var margs string
|
|
||||||
for i, a := range args {
|
|
||||||
v := argValue(a)
|
|
||||||
switch v.(type) {
|
|
||||||
case string:
|
|
||||||
v = fmt.Sprintf("%q", v)
|
|
||||||
default:
|
|
||||||
v = fmt.Sprintf("%v", v)
|
|
||||||
}
|
|
||||||
margs += fmt.Sprintf("%d:%s", i+1, v)
|
|
||||||
if i+1 < len(args) {
|
|
||||||
margs += " "
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return margs
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calls the Exec function on the executor, but attempts to expand any eligible named
|
|
||||||
// query arguments first.
|
|
||||||
func maybeExpandNamedQueryAndExec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
dbMap := extractDbMap(e)
|
|
||||||
|
|
||||||
if len(args) == 1 {
|
|
||||||
query, args = maybeExpandNamedQuery(dbMap, query, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
return exec(e, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractDbMap(e SqlExecutor) *DbMap {
|
|
||||||
switch m := e.(type) {
|
|
||||||
case *DbMap:
|
|
||||||
return m
|
|
||||||
case *Transaction:
|
|
||||||
return m.dbmap
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// executor exposes the sql.DB and sql.Tx functions so that it can be used
|
|
||||||
// on internal functions that need to be agnostic to the underlying object.
|
|
||||||
type executor interface {
|
|
||||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
|
||||||
Prepare(query string) (*sql.Stmt, error)
|
|
||||||
QueryRow(query string, args ...interface{}) *sql.Row
|
|
||||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
|
||||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
|
||||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
|
||||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
|
||||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
|
||||||
}
|
|
||||||
|
|
||||||
func extractExecutorAndContext(e SqlExecutor) (executor, context.Context) {
|
|
||||||
switch m := e.(type) {
|
|
||||||
case *DbMap:
|
|
||||||
return m.Db, m.ctx
|
|
||||||
case *Transaction:
|
|
||||||
return m.tx, m.ctx
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// maybeExpandNamedQuery checks the given arg to see if it's eligible to be used
|
|
||||||
// as input to a named query. If so, it rewrites the query to use
|
|
||||||
// dialect-dependent bindvars and instantiates the corresponding slice of
|
|
||||||
// parameters by extracting data from the map / struct.
|
|
||||||
// If not, returns the input values unchanged.
|
|
||||||
func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string, []interface{}) {
|
|
||||||
var (
|
|
||||||
arg = args[0]
|
|
||||||
argval = reflect.ValueOf(arg)
|
|
||||||
)
|
|
||||||
if argval.Kind() == reflect.Ptr {
|
|
||||||
argval = argval.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if argval.Kind() == reflect.Map && argval.Type().Key().Kind() == reflect.String {
|
|
||||||
return expandNamedQuery(m, query, func(key string) reflect.Value {
|
|
||||||
return argval.MapIndex(reflect.ValueOf(key))
|
|
||||||
})
|
|
||||||
}
|
|
||||||
if argval.Kind() != reflect.Struct {
|
|
||||||
return query, args
|
|
||||||
}
|
|
||||||
if _, ok := arg.(time.Time); ok {
|
|
||||||
// time.Time is driver.Value
|
|
||||||
return query, args
|
|
||||||
}
|
|
||||||
if _, ok := arg.(driver.Valuer); ok {
|
|
||||||
// driver.Valuer will be converted to driver.Value.
|
|
||||||
return query, args
|
|
||||||
}
|
|
||||||
|
|
||||||
return expandNamedQuery(m, query, argval.FieldByName)
|
|
||||||
}
|
|
||||||
|
|
||||||
var keyRegexp = regexp.MustCompile(`:[[:word:]]+`)
|
|
||||||
|
|
||||||
// expandNamedQuery accepts a query with placeholders of the form ":key", and a
|
|
||||||
// single arg of Kind Struct or Map[string]. It returns the query with the
|
|
||||||
// dialect's placeholders, and a slice of args ready for positional insertion
|
|
||||||
// into the query.
|
|
||||||
func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect.Value) (string, []interface{}) {
|
|
||||||
var (
|
|
||||||
n int
|
|
||||||
args []interface{}
|
|
||||||
)
|
|
||||||
return keyRegexp.ReplaceAllStringFunc(query, func(key string) string {
|
|
||||||
val := keyGetter(key[1:])
|
|
||||||
if !val.IsValid() {
|
|
||||||
return key
|
|
||||||
}
|
|
||||||
args = append(args, val.Interface())
|
|
||||||
newVar := m.Dialect.BindVar(n)
|
|
||||||
n++
|
|
||||||
return newVar
|
|
||||||
}), args
|
|
||||||
}
|
|
||||||
|
|
||||||
func columnToFieldIndex(m *DbMap, t reflect.Type, name string, cols []string) ([][]int, error) {
|
|
||||||
colToFieldIndex := make([][]int, len(cols))
|
|
||||||
|
|
||||||
// check if type t is a mapped table - if so we'll
|
|
||||||
// check the table for column aliasing below
|
|
||||||
tableMapped := false
|
|
||||||
table := tableOrNil(m, t, name)
|
|
||||||
if table != nil {
|
|
||||||
tableMapped = true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Loop over column names and find field in i to bind to
|
|
||||||
// based on column name. all returned columns must match
|
|
||||||
// a field in the i struct
|
|
||||||
missingColNames := []string{}
|
|
||||||
for x := range cols {
|
|
||||||
colName := strings.ToLower(cols[x])
|
|
||||||
field, found := t.FieldByNameFunc(func(fieldName string) bool {
|
|
||||||
field, _ := t.FieldByName(fieldName)
|
|
||||||
cArguments := strings.Split(field.Tag.Get("db"), ",")
|
|
||||||
fieldName = cArguments[0]
|
|
||||||
|
|
||||||
if fieldName == "-" {
|
|
||||||
return false
|
|
||||||
} else if fieldName == "" {
|
|
||||||
fieldName = field.Name
|
|
||||||
}
|
|
||||||
if tableMapped {
|
|
||||||
colMap := colMapOrNil(table, fieldName)
|
|
||||||
if colMap != nil {
|
|
||||||
fieldName = colMap.ColumnName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return colName == strings.ToLower(fieldName)
|
|
||||||
})
|
|
||||||
if found {
|
|
||||||
colToFieldIndex[x] = field.Index
|
|
||||||
}
|
|
||||||
if colToFieldIndex[x] == nil {
|
|
||||||
missingColNames = append(missingColNames, colName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(missingColNames) > 0 {
|
|
||||||
return colToFieldIndex, &NoFieldInTypeError{
|
|
||||||
TypeName: t.Name(),
|
|
||||||
MissingColNames: missingColNames,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return colToFieldIndex, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func fieldByName(val reflect.Value, fieldName string) *reflect.Value {
|
|
||||||
// try to find field by exact match
|
|
||||||
f := val.FieldByName(fieldName)
|
|
||||||
|
|
||||||
if f != zeroVal {
|
|
||||||
return &f
|
|
||||||
}
|
|
||||||
|
|
||||||
// try to find by case insensitive match - only the Postgres driver
|
|
||||||
// seems to require this - in the case where columns are aliased in the sql
|
|
||||||
fieldNameL := strings.ToLower(fieldName)
|
|
||||||
fieldCount := val.NumField()
|
|
||||||
t := val.Type()
|
|
||||||
for i := 0; i < fieldCount; i++ {
|
|
||||||
sf := t.Field(i)
|
|
||||||
if strings.ToLower(sf.Name) == fieldNameL {
|
|
||||||
f := val.Field(i)
|
|
||||||
return &f
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// toSliceType returns the element type of the given object, if the object is a
|
|
||||||
// "*[]*Element" or "*[]Element". If not, returns nil.
|
|
||||||
// err is returned if the user was trying to pass a pointer-to-slice but failed.
|
|
||||||
func toSliceType(i interface{}) (reflect.Type, error) {
|
|
||||||
t := reflect.TypeOf(i)
|
|
||||||
if t.Kind() != reflect.Ptr {
|
|
||||||
// If it's a slice, return a more helpful error message
|
|
||||||
if t.Kind() == reflect.Slice {
|
|
||||||
return nil, fmt.Errorf("sqldb: cannot SELECT into a non-pointer slice: %v", t)
|
|
||||||
}
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
if t = t.Elem(); t.Kind() != reflect.Slice {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
return t.Elem(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func toType(i interface{}) (reflect.Type, error) {
|
|
||||||
t := reflect.TypeOf(i)
|
|
||||||
|
|
||||||
// If a Pointer to a type, follow
|
|
||||||
for t.Kind() == reflect.Ptr {
|
|
||||||
t = t.Elem()
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.Kind() != reflect.Struct {
|
|
||||||
return nil, fmt.Errorf("sqldb: cannot SELECT into this type: %v", reflect.TypeOf(i))
|
|
||||||
}
|
|
||||||
return t, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type foundTable struct {
|
|
||||||
table *TableMap
|
|
||||||
dynName *string
|
|
||||||
}
|
|
||||||
|
|
||||||
func tableFor(m *DbMap, t reflect.Type, i interface{}) (*foundTable, error) {
|
|
||||||
if dyn, isDynamic := i.(DynamicTable); isDynamic {
|
|
||||||
tableName := dyn.TableName()
|
|
||||||
table, err := m.DynamicTableFor(tableName, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &foundTable{
|
|
||||||
table: table,
|
|
||||||
dynName: &tableName,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
table, err := m.TableFor(t, true)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &foundTable{table: table}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func get(m *DbMap, exec SqlExecutor, i interface{},
|
|
||||||
keys ...interface{}) (interface{}, error) {
|
|
||||||
|
|
||||||
t, err := toType(i)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
foundTable, err := tableFor(m, t, i)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
table := foundTable.table
|
|
||||||
|
|
||||||
plan := table.bindGet()
|
|
||||||
|
|
||||||
v := reflect.New(t)
|
|
||||||
if foundTable.dynName != nil {
|
|
||||||
retDyn := v.Interface().(DynamicTable)
|
|
||||||
retDyn.SetTableName(*foundTable.dynName)
|
|
||||||
}
|
|
||||||
|
|
||||||
dest := make([]interface{}, len(plan.argFields))
|
|
||||||
|
|
||||||
conv := m.TypeConverter
|
|
||||||
custScan := make([]CustomScanner, 0)
|
|
||||||
|
|
||||||
for x, fieldName := range plan.argFields {
|
|
||||||
f := v.Elem().FieldByName(fieldName)
|
|
||||||
target := f.Addr().Interface()
|
|
||||||
if conv != nil {
|
|
||||||
scanner, ok := conv.FromDb(target)
|
|
||||||
if ok {
|
|
||||||
target = scanner.Holder
|
|
||||||
custScan = append(custScan, scanner)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
dest[x] = target
|
|
||||||
}
|
|
||||||
|
|
||||||
row := exec.QueryRow(plan.query, keys...)
|
|
||||||
err = row.Scan(dest...)
|
|
||||||
if err != nil {
|
|
||||||
if err == sql.ErrNoRows {
|
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, c := range custScan {
|
|
||||||
err = c.Bind()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := v.Interface().(HasPostGet); ok {
|
|
||||||
err := v.PostGet(exec)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return v.Interface(), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
|
|
||||||
count := int64(0)
|
|
||||||
for _, ptr := range list {
|
|
||||||
table, elem, err := m.tableForPointer(ptr, true)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
eval := elem.Addr().Interface()
|
|
||||||
if v, ok := eval.(HasPreDelete); ok {
|
|
||||||
err = v.PreDelete(exec)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bi, err := table.bindDelete(elem)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := exec.Exec(bi.query, bi.args...)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
rows, err := res.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows == 0 && bi.existingVersion > 0 {
|
|
||||||
return lockError(m, exec, table.TableName,
|
|
||||||
bi.existingVersion, elem, bi.keys...)
|
|
||||||
}
|
|
||||||
|
|
||||||
count += rows
|
|
||||||
|
|
||||||
if v, ok := eval.(HasPostDelete); ok {
|
|
||||||
err := v.PostDelete(exec)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func update(m *DbMap, exec SqlExecutor, colFilter ColumnFilter, list ...interface{}) (int64, error) {
|
|
||||||
count := int64(0)
|
|
||||||
for _, ptr := range list {
|
|
||||||
table, elem, err := m.tableForPointer(ptr, true)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
eval := elem.Addr().Interface()
|
|
||||||
if v, ok := eval.(HasPreUpdate); ok {
|
|
||||||
err = v.PreUpdate(exec)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bi, err := table.bindUpdate(elem, colFilter)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := exec.Exec(bi.query, bi.args...)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
rows, err := res.RowsAffected()
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
|
|
||||||
if rows == 0 && bi.existingVersion > 0 {
|
|
||||||
return lockError(m, exec, table.TableName,
|
|
||||||
bi.existingVersion, elem, bi.keys...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if bi.versField != "" {
|
|
||||||
elem.FieldByName(bi.versField).SetInt(bi.existingVersion + 1)
|
|
||||||
}
|
|
||||||
|
|
||||||
count += rows
|
|
||||||
|
|
||||||
if v, ok := eval.(HasPostUpdate); ok {
|
|
||||||
err = v.PostUpdate(exec)
|
|
||||||
if err != nil {
|
|
||||||
return -1, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return count, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
|
|
||||||
for _, ptr := range list {
|
|
||||||
table, elem, err := m.tableForPointer(ptr, false)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
eval := elem.Addr().Interface()
|
|
||||||
if v, ok := eval.(HasPreInsert); ok {
|
|
||||||
err := v.PreInsert(exec)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
bi, err := table.bindInsert(elem)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if bi.autoIncrIdx > -1 {
|
|
||||||
f := elem.FieldByName(bi.autoIncrFieldName)
|
|
||||||
switch inserter := m.Dialect.(type) {
|
|
||||||
case IntegerAutoIncrInserter:
|
|
||||||
id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
k := f.Kind()
|
|
||||||
if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) {
|
|
||||||
f.SetInt(id)
|
|
||||||
} else if (k == reflect.Uint) || (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) {
|
|
||||||
f.SetUint(uint64(id))
|
|
||||||
} else {
|
|
||||||
return fmt.Errorf("sqldb: cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName)
|
|
||||||
}
|
|
||||||
case TargetedAutoIncrInserter:
|
|
||||||
err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
case TargetQueryInserter:
|
|
||||||
var idQuery = table.ColMap(bi.autoIncrFieldName).GeneratedIdQuery
|
|
||||||
if idQuery == "" {
|
|
||||||
return fmt.Errorf("sqldb: cannot set %s value if its ColumnMap.GeneratedIdQuery is empty", bi.autoIncrFieldName)
|
|
||||||
}
|
|
||||||
err := inserter.InsertQueryToTarget(exec, bi.query, idQuery, f.Addr().Interface(), bi.args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("sqldb: cannot use autoincrement fields on dialects that do not implement an autoincrementing interface")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
_, err := exec.Exec(bi.query, bi.args...)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := eval.(HasPostInsert); ok {
|
|
||||||
err := v.PostInsert(exec)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func exec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
executor, ctx := extractExecutorAndContext(e)
|
|
||||||
|
|
||||||
if ctx != nil {
|
|
||||||
return executor.ExecContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return executor.Exec(query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepare(e SqlExecutor, query string) (*sql.Stmt, error) {
|
|
||||||
executor, ctx := extractExecutorAndContext(e)
|
|
||||||
|
|
||||||
if ctx != nil {
|
|
||||||
return executor.PrepareContext(ctx, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
return executor.Prepare(query)
|
|
||||||
}
|
|
||||||
|
|
||||||
func queryRow(e SqlExecutor, query string, args ...interface{}) *sql.Row {
|
|
||||||
executor, ctx := extractExecutorAndContext(e)
|
|
||||||
|
|
||||||
if ctx != nil {
|
|
||||||
return executor.QueryRowContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return executor.QueryRow(query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func query(e SqlExecutor, query string, args ...interface{}) (*sql.Rows, error) {
|
|
||||||
executor, ctx := extractExecutorAndContext(e)
|
|
||||||
|
|
||||||
if ctx != nil {
|
|
||||||
return executor.QueryContext(ctx, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return executor.Query(query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func begin(m *DbMap) (*sql.Tx, error) {
|
|
||||||
if m.ctx != nil {
|
|
||||||
return m.Db.BeginTx(m.ctx, nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
return m.Db.Begin()
|
|
||||||
}
|
|
File diff suppressed because it is too large
Load Diff
@ -1,258 +0,0 @@
|
|||||||
//
|
|
||||||
// table.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"strings"
|
|
||||||
)
|
|
||||||
|
|
||||||
// TableMap represents a mapping between a Go struct and a database table
|
|
||||||
// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these
|
|
||||||
type TableMap struct {
|
|
||||||
// Name of database table.
|
|
||||||
TableName string
|
|
||||||
SchemaName string
|
|
||||||
gotype reflect.Type
|
|
||||||
Columns []*ColumnMap
|
|
||||||
keys []*ColumnMap
|
|
||||||
indexes []*IndexMap
|
|
||||||
uniqueTogether [][]string
|
|
||||||
version *ColumnMap
|
|
||||||
insertPlan bindPlan
|
|
||||||
updatePlan bindPlan
|
|
||||||
deletePlan bindPlan
|
|
||||||
getPlan bindPlan
|
|
||||||
dbmap *DbMap
|
|
||||||
}
|
|
||||||
|
|
||||||
// ResetSql removes cached insert/update/select/delete SQL strings
|
|
||||||
// associated with this TableMap. Call this if you've modified
|
|
||||||
// any column names or the table name itself.
|
|
||||||
func (t *TableMap) ResetSql() {
|
|
||||||
t.insertPlan = bindPlan{}
|
|
||||||
t.updatePlan = bindPlan{}
|
|
||||||
t.deletePlan = bindPlan{}
|
|
||||||
t.getPlan = bindPlan{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetKeys lets you specify the fields on a struct that map to primary
|
|
||||||
// key columns on the table. If isAutoIncr is set, result.LastInsertId()
|
|
||||||
// will be used after INSERT to bind the generated id to the Go struct.
|
|
||||||
//
|
|
||||||
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
|
|
||||||
//
|
|
||||||
// Panics if isAutoIncr is true, and fieldNames length != 1
|
|
||||||
func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap {
|
|
||||||
if isAutoIncr && len(fieldNames) != 1 {
|
|
||||||
panic(fmt.Sprintf(
|
|
||||||
"sqldb: SetKeys: fieldNames length must be 1 if key is auto-increment. (Saw %v fieldNames)",
|
|
||||||
len(fieldNames)))
|
|
||||||
}
|
|
||||||
t.keys = make([]*ColumnMap, 0)
|
|
||||||
for _, name := range fieldNames {
|
|
||||||
colmap := t.ColMap(name)
|
|
||||||
colmap.isPK = true
|
|
||||||
colmap.isAutoIncr = isAutoIncr
|
|
||||||
t.keys = append(t.keys, colmap)
|
|
||||||
}
|
|
||||||
t.ResetSql()
|
|
||||||
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetUniqueTogether lets you specify uniqueness constraints across multiple
|
|
||||||
// columns on the table. Each call adds an additional constraint for the
|
|
||||||
// specified columns.
|
|
||||||
//
|
|
||||||
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
|
|
||||||
//
|
|
||||||
// Panics if fieldNames length < 2.
|
|
||||||
func (t *TableMap) SetUniqueTogether(fieldNames ...string) *TableMap {
|
|
||||||
if len(fieldNames) < 2 {
|
|
||||||
panic(fmt.Sprintf(
|
|
||||||
"sqldb: SetUniqueTogether: must provide at least two fieldNames to set uniqueness constraint."))
|
|
||||||
}
|
|
||||||
|
|
||||||
columns := make([]string, 0, len(fieldNames))
|
|
||||||
for _, name := range fieldNames {
|
|
||||||
columns = append(columns, name)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, existingColumns := range t.uniqueTogether {
|
|
||||||
if equal(existingColumns, columns) {
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
}
|
|
||||||
t.uniqueTogether = append(t.uniqueTogether, columns)
|
|
||||||
t.ResetSql()
|
|
||||||
|
|
||||||
return t
|
|
||||||
}
|
|
||||||
|
|
||||||
// ColMap returns the ColumnMap pointer matching the given struct field
|
|
||||||
// name. It panics if the struct does not contain a field matching this
|
|
||||||
// name.
|
|
||||||
func (t *TableMap) ColMap(field string) *ColumnMap {
|
|
||||||
col := colMapOrNil(t, field)
|
|
||||||
if col == nil {
|
|
||||||
e := fmt.Sprintf("No ColumnMap in table %s type %s with field %s",
|
|
||||||
t.TableName, t.gotype.Name(), field)
|
|
||||||
|
|
||||||
panic(e)
|
|
||||||
}
|
|
||||||
return col
|
|
||||||
}
|
|
||||||
|
|
||||||
func colMapOrNil(t *TableMap, field string) *ColumnMap {
|
|
||||||
for _, col := range t.Columns {
|
|
||||||
if col.fieldName == field || col.ColumnName == field {
|
|
||||||
return col
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IdxMap returns the IndexMap pointer matching the given index name.
|
|
||||||
func (t *TableMap) IdxMap(field string) *IndexMap {
|
|
||||||
for _, idx := range t.indexes {
|
|
||||||
if idx.IndexName == field {
|
|
||||||
return idx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddIndex registers the index with sqldb for specified table with given parameters.
|
|
||||||
// This operation is idempotent. If index is already mapped, the
|
|
||||||
// existing *IndexMap is returned
|
|
||||||
// Function will panic if one of the given for index columns does not exists
|
|
||||||
//
|
|
||||||
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
|
|
||||||
func (t *TableMap) AddIndex(name string, idxtype string, columns []string) *IndexMap {
|
|
||||||
// check if we have a index with this name already
|
|
||||||
for _, idx := range t.indexes {
|
|
||||||
if idx.IndexName == name {
|
|
||||||
return idx
|
|
||||||
}
|
|
||||||
}
|
|
||||||
for _, icol := range columns {
|
|
||||||
if res := t.ColMap(icol); res == nil {
|
|
||||||
e := fmt.Sprintf("No ColumnName in table %s to create index on", t.TableName)
|
|
||||||
panic(e)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
idx := &IndexMap{IndexName: name, Unique: false, IndexType: idxtype, columns: columns}
|
|
||||||
t.indexes = append(t.indexes, idx)
|
|
||||||
t.ResetSql()
|
|
||||||
return idx
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetVersionCol sets the column to use as the Version field. By default
|
|
||||||
// the "Version" field is used. Returns the column found, or panics
|
|
||||||
// if the struct does not contain a field matching this name.
|
|
||||||
//
|
|
||||||
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
|
|
||||||
func (t *TableMap) SetVersionCol(field string) *ColumnMap {
|
|
||||||
c := t.ColMap(field)
|
|
||||||
t.version = c
|
|
||||||
t.ResetSql()
|
|
||||||
return c
|
|
||||||
}
|
|
||||||
|
|
||||||
// SqlForCreateTable gets a sequence of SQL commands that will create
|
|
||||||
// the specified table and any associated schema
|
|
||||||
func (t *TableMap) SqlForCreate(ifNotExists bool) string {
|
|
||||||
s := bytes.Buffer{}
|
|
||||||
dialect := t.dbmap.Dialect
|
|
||||||
|
|
||||||
if strings.TrimSpace(t.SchemaName) != "" {
|
|
||||||
schemaCreate := "create schema"
|
|
||||||
if ifNotExists {
|
|
||||||
s.WriteString(dialect.IfSchemaNotExists(schemaCreate, t.SchemaName))
|
|
||||||
} else {
|
|
||||||
s.WriteString(schemaCreate)
|
|
||||||
}
|
|
||||||
s.WriteString(fmt.Sprintf(" %s;", t.SchemaName))
|
|
||||||
}
|
|
||||||
|
|
||||||
tableCreate := "create table"
|
|
||||||
if ifNotExists {
|
|
||||||
s.WriteString(dialect.IfTableNotExists(tableCreate, t.SchemaName, t.TableName))
|
|
||||||
} else {
|
|
||||||
s.WriteString(tableCreate)
|
|
||||||
}
|
|
||||||
s.WriteString(fmt.Sprintf(" %s (", dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
|
|
||||||
|
|
||||||
x := 0
|
|
||||||
for _, col := range t.Columns {
|
|
||||||
if !col.Transient {
|
|
||||||
if x > 0 {
|
|
||||||
s.WriteString(", ")
|
|
||||||
}
|
|
||||||
stype := dialect.ToSqlType(col.gotype, col.MaxSize, col.isAutoIncr)
|
|
||||||
s.WriteString(fmt.Sprintf("%s %s", dialect.QuoteField(col.ColumnName), stype))
|
|
||||||
|
|
||||||
if col.isPK || col.isNotNull {
|
|
||||||
s.WriteString(" not null")
|
|
||||||
}
|
|
||||||
if col.isPK && len(t.keys) == 1 {
|
|
||||||
s.WriteString(" primary key")
|
|
||||||
}
|
|
||||||
if col.Unique {
|
|
||||||
s.WriteString(" unique")
|
|
||||||
}
|
|
||||||
if col.isAutoIncr {
|
|
||||||
s.WriteString(fmt.Sprintf(" %s", dialect.AutoIncrStr()))
|
|
||||||
}
|
|
||||||
|
|
||||||
x++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(t.keys) > 1 {
|
|
||||||
s.WriteString(", primary key (")
|
|
||||||
for x := range t.keys {
|
|
||||||
if x > 0 {
|
|
||||||
s.WriteString(", ")
|
|
||||||
}
|
|
||||||
s.WriteString(dialect.QuoteField(t.keys[x].ColumnName))
|
|
||||||
}
|
|
||||||
s.WriteString(")")
|
|
||||||
}
|
|
||||||
if len(t.uniqueTogether) > 0 {
|
|
||||||
for _, columns := range t.uniqueTogether {
|
|
||||||
s.WriteString(", unique (")
|
|
||||||
for i, column := range columns {
|
|
||||||
if i > 0 {
|
|
||||||
s.WriteString(", ")
|
|
||||||
}
|
|
||||||
s.WriteString(dialect.QuoteField(column))
|
|
||||||
}
|
|
||||||
s.WriteString(")")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.WriteString(") ")
|
|
||||||
s.WriteString(dialect.CreateTableSuffix())
|
|
||||||
s.WriteString(dialect.QuerySuffix())
|
|
||||||
return s.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func equal(a, b []string) bool {
|
|
||||||
if len(a) != len(b) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for i := range a {
|
|
||||||
if a[i] != b[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
@ -1,308 +0,0 @@
|
|||||||
//
|
|
||||||
// table_bindings.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"fmt"
|
|
||||||
"reflect"
|
|
||||||
"sync"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CustomScanner binds a database column value to a Go type
|
|
||||||
type CustomScanner struct {
|
|
||||||
// After a row is scanned, Holder will contain the value from the database column.
|
|
||||||
// Initialize the CustomScanner with the concrete Go type you wish the database
|
|
||||||
// driver to scan the raw column into.
|
|
||||||
Holder interface{}
|
|
||||||
// Target typically holds a pointer to the target struct field to bind the Holder
|
|
||||||
// value to.
|
|
||||||
Target interface{}
|
|
||||||
// Binder is a custom function that converts the holder value to the target type
|
|
||||||
// and sets target accordingly. This function should return error if a problem
|
|
||||||
// occurs converting the holder to the target.
|
|
||||||
Binder func(holder interface{}, target interface{}) error
|
|
||||||
}
|
|
||||||
|
|
||||||
// Used to filter columns when selectively updating
|
|
||||||
type ColumnFilter func(*ColumnMap) bool
|
|
||||||
|
|
||||||
func acceptAllFilter(col *ColumnMap) bool {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Bind is called automatically by sqldb after Scan()
|
|
||||||
func (me CustomScanner) Bind() error {
|
|
||||||
return me.Binder(me.Holder, me.Target)
|
|
||||||
}
|
|
||||||
|
|
||||||
type bindPlan struct {
|
|
||||||
query string
|
|
||||||
argFields []string
|
|
||||||
keyFields []string
|
|
||||||
versField string
|
|
||||||
autoIncrIdx int
|
|
||||||
autoIncrFieldName string
|
|
||||||
once sync.Once
|
|
||||||
}
|
|
||||||
|
|
||||||
func (plan *bindPlan) createBindInstance(elem reflect.Value, conv TypeConverter) (bindInstance, error) {
|
|
||||||
bi := bindInstance{query: plan.query, autoIncrIdx: plan.autoIncrIdx, autoIncrFieldName: plan.autoIncrFieldName, versField: plan.versField}
|
|
||||||
if plan.versField != "" {
|
|
||||||
bi.existingVersion = elem.FieldByName(plan.versField).Int()
|
|
||||||
}
|
|
||||||
|
|
||||||
var err error
|
|
||||||
|
|
||||||
for i := 0; i < len(plan.argFields); i++ {
|
|
||||||
k := plan.argFields[i]
|
|
||||||
if k == versFieldConst {
|
|
||||||
newVer := bi.existingVersion + 1
|
|
||||||
bi.args = append(bi.args, newVer)
|
|
||||||
if bi.existingVersion == 0 {
|
|
||||||
elem.FieldByName(plan.versField).SetInt(int64(newVer))
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
val := elem.FieldByName(k).Interface()
|
|
||||||
if conv != nil {
|
|
||||||
val, err = conv.ToDb(val)
|
|
||||||
if err != nil {
|
|
||||||
return bindInstance{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bi.args = append(bi.args, val)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < len(plan.keyFields); i++ {
|
|
||||||
k := plan.keyFields[i]
|
|
||||||
val := elem.FieldByName(k).Interface()
|
|
||||||
if conv != nil {
|
|
||||||
val, err = conv.ToDb(val)
|
|
||||||
if err != nil {
|
|
||||||
return bindInstance{}, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
bi.keys = append(bi.keys, val)
|
|
||||||
}
|
|
||||||
|
|
||||||
return bi, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
type bindInstance struct {
|
|
||||||
query string
|
|
||||||
args []interface{}
|
|
||||||
keys []interface{}
|
|
||||||
existingVersion int64
|
|
||||||
versField string
|
|
||||||
autoIncrIdx int
|
|
||||||
autoIncrFieldName string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) {
|
|
||||||
plan := &t.insertPlan
|
|
||||||
plan.once.Do(func() {
|
|
||||||
plan.autoIncrIdx = -1
|
|
||||||
|
|
||||||
s := bytes.Buffer{}
|
|
||||||
s2 := bytes.Buffer{}
|
|
||||||
s.WriteString(fmt.Sprintf("insert into %s (", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
|
|
||||||
|
|
||||||
x := 0
|
|
||||||
first := true
|
|
||||||
for y := range t.Columns {
|
|
||||||
col := t.Columns[y]
|
|
||||||
if !(col.isAutoIncr && t.dbmap.Dialect.AutoIncrBindValue() == "") {
|
|
||||||
if !col.Transient {
|
|
||||||
if !first {
|
|
||||||
s.WriteString(",")
|
|
||||||
s2.WriteString(",")
|
|
||||||
}
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
|
|
||||||
|
|
||||||
if col.isAutoIncr {
|
|
||||||
s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue())
|
|
||||||
plan.autoIncrIdx = y
|
|
||||||
plan.autoIncrFieldName = col.fieldName
|
|
||||||
} else {
|
|
||||||
if col.DefaultValue == "" {
|
|
||||||
s2.WriteString(t.dbmap.Dialect.BindVar(x))
|
|
||||||
if col == t.version {
|
|
||||||
plan.versField = col.fieldName
|
|
||||||
plan.argFields = append(plan.argFields, versFieldConst)
|
|
||||||
} else {
|
|
||||||
plan.argFields = append(plan.argFields, col.fieldName)
|
|
||||||
}
|
|
||||||
x++
|
|
||||||
} else {
|
|
||||||
s2.WriteString(col.DefaultValue)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
first = false
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
plan.autoIncrIdx = y
|
|
||||||
plan.autoIncrFieldName = col.fieldName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.WriteString(") values (")
|
|
||||||
s.WriteString(s2.String())
|
|
||||||
s.WriteString(")")
|
|
||||||
if plan.autoIncrIdx > -1 {
|
|
||||||
s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx]))
|
|
||||||
}
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuerySuffix())
|
|
||||||
|
|
||||||
plan.query = s.String()
|
|
||||||
})
|
|
||||||
|
|
||||||
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TableMap) bindUpdate(elem reflect.Value, colFilter ColumnFilter) (bindInstance, error) {
|
|
||||||
if colFilter == nil {
|
|
||||||
colFilter = acceptAllFilter
|
|
||||||
}
|
|
||||||
|
|
||||||
plan := &t.updatePlan
|
|
||||||
plan.once.Do(func() {
|
|
||||||
s := bytes.Buffer{}
|
|
||||||
s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
|
|
||||||
x := 0
|
|
||||||
|
|
||||||
for y := range t.Columns {
|
|
||||||
col := t.Columns[y]
|
|
||||||
if !col.isAutoIncr && !col.Transient && colFilter(col) {
|
|
||||||
if x > 0 {
|
|
||||||
s.WriteString(", ")
|
|
||||||
}
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
|
|
||||||
s.WriteString("=")
|
|
||||||
s.WriteString(t.dbmap.Dialect.BindVar(x))
|
|
||||||
|
|
||||||
if col == t.version {
|
|
||||||
plan.versField = col.fieldName
|
|
||||||
plan.argFields = append(plan.argFields, versFieldConst)
|
|
||||||
} else {
|
|
||||||
plan.argFields = append(plan.argFields, col.fieldName)
|
|
||||||
}
|
|
||||||
x++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.WriteString(" where ")
|
|
||||||
for y := range t.keys {
|
|
||||||
col := t.keys[y]
|
|
||||||
if y > 0 {
|
|
||||||
s.WriteString(" and ")
|
|
||||||
}
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
|
|
||||||
s.WriteString("=")
|
|
||||||
s.WriteString(t.dbmap.Dialect.BindVar(x))
|
|
||||||
|
|
||||||
plan.argFields = append(plan.argFields, col.fieldName)
|
|
||||||
plan.keyFields = append(plan.keyFields, col.fieldName)
|
|
||||||
x++
|
|
||||||
}
|
|
||||||
if plan.versField != "" {
|
|
||||||
s.WriteString(" and ")
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName))
|
|
||||||
s.WriteString("=")
|
|
||||||
s.WriteString(t.dbmap.Dialect.BindVar(x))
|
|
||||||
plan.argFields = append(plan.argFields, plan.versField)
|
|
||||||
}
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuerySuffix())
|
|
||||||
|
|
||||||
plan.query = s.String()
|
|
||||||
})
|
|
||||||
|
|
||||||
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) {
|
|
||||||
plan := &t.deletePlan
|
|
||||||
plan.once.Do(func() {
|
|
||||||
s := bytes.Buffer{}
|
|
||||||
s.WriteString(fmt.Sprintf("delete from %s", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
|
|
||||||
|
|
||||||
for y := range t.Columns {
|
|
||||||
col := t.Columns[y]
|
|
||||||
if !col.Transient {
|
|
||||||
if col == t.version {
|
|
||||||
plan.versField = col.fieldName
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.WriteString(" where ")
|
|
||||||
for x := range t.keys {
|
|
||||||
k := t.keys[x]
|
|
||||||
if x > 0 {
|
|
||||||
s.WriteString(" and ")
|
|
||||||
}
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuoteField(k.ColumnName))
|
|
||||||
s.WriteString("=")
|
|
||||||
s.WriteString(t.dbmap.Dialect.BindVar(x))
|
|
||||||
|
|
||||||
plan.keyFields = append(plan.keyFields, k.fieldName)
|
|
||||||
plan.argFields = append(plan.argFields, k.fieldName)
|
|
||||||
}
|
|
||||||
if plan.versField != "" {
|
|
||||||
s.WriteString(" and ")
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName))
|
|
||||||
s.WriteString("=")
|
|
||||||
s.WriteString(t.dbmap.Dialect.BindVar(len(plan.argFields)))
|
|
||||||
|
|
||||||
plan.argFields = append(plan.argFields, plan.versField)
|
|
||||||
}
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuerySuffix())
|
|
||||||
|
|
||||||
plan.query = s.String()
|
|
||||||
})
|
|
||||||
|
|
||||||
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TableMap) bindGet() *bindPlan {
|
|
||||||
plan := &t.getPlan
|
|
||||||
plan.once.Do(func() {
|
|
||||||
s := bytes.Buffer{}
|
|
||||||
s.WriteString("select ")
|
|
||||||
|
|
||||||
x := 0
|
|
||||||
for _, col := range t.Columns {
|
|
||||||
if !col.Transient {
|
|
||||||
if x > 0 {
|
|
||||||
s.WriteString(",")
|
|
||||||
}
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
|
|
||||||
plan.argFields = append(plan.argFields, col.fieldName)
|
|
||||||
x++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
s.WriteString(" from ")
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))
|
|
||||||
s.WriteString(" where ")
|
|
||||||
for x := range t.keys {
|
|
||||||
col := t.keys[x]
|
|
||||||
if x > 0 {
|
|
||||||
s.WriteString(" and ")
|
|
||||||
}
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
|
|
||||||
s.WriteString("=")
|
|
||||||
s.WriteString(t.dbmap.Dialect.BindVar(x))
|
|
||||||
|
|
||||||
plan.keyFields = append(plan.keyFields, col.fieldName)
|
|
||||||
}
|
|
||||||
s.WriteString(t.dbmap.Dialect.QuerySuffix())
|
|
||||||
|
|
||||||
plan.query = s.String()
|
|
||||||
})
|
|
||||||
|
|
||||||
return plan
|
|
||||||
}
|
|
@ -1,24 +0,0 @@
|
|||||||
#!/bin/bash -ex
|
|
||||||
|
|
||||||
# on macs, you may need to:
|
|
||||||
# export GOBUILDFLAG=-ldflags -linkmode=external
|
|
||||||
|
|
||||||
echo "Running unit tests"
|
|
||||||
go test -race
|
|
||||||
|
|
||||||
echo "Testing against postgres"
|
|
||||||
export SQLDB_TEST_DSN="host=127.0.0.1 user=testuser password=123 dbname=testdb sslmode=disable"
|
|
||||||
export SQLDB_TEST_DIALECT=postgres
|
|
||||||
go test -tags integration $GOBUILDFLAG $@ .
|
|
||||||
|
|
||||||
echo "Testing against sqlite"
|
|
||||||
export SQLDB_TEST_DSN=/tmp/testdb.bin
|
|
||||||
export SQLDB_TEST_DIALECT=sqlite
|
|
||||||
go test -tags integration $GOBUILDFLAG $@ .
|
|
||||||
rm -f /tmp/testdb.bin
|
|
||||||
|
|
||||||
echo "Testing against mysql"
|
|
||||||
# export SQLDB_TEST_DSN="testuser:123@tcp(127.0.0.1:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local"
|
|
||||||
export SQLDB_TEST_DSN="testuser:123@tcp(127.0.0.1:3306)/testdb"
|
|
||||||
export SQLDB_TEST_DIALECT=mysql
|
|
||||||
go test -tags integration $GOBUILDFLAG $@ .
|
|
@ -1,242 +0,0 @@
|
|||||||
//
|
|
||||||
// transaction.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package sqldb
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"database/sql"
|
|
||||||
"time"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Transaction represents a database transaction.
|
|
||||||
// Insert/Update/Delete/Get/Exec operations will be run in the context
|
|
||||||
// of that transaction. Transactions should be terminated with
|
|
||||||
// a call to Commit() or Rollback()
|
|
||||||
type Transaction struct {
|
|
||||||
ctx context.Context
|
|
||||||
dbmap *DbMap
|
|
||||||
tx *sql.Tx
|
|
||||||
closed bool
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Transaction) WithContext(ctx context.Context) SqlExecutor {
|
|
||||||
copy := &Transaction{}
|
|
||||||
*copy = *t
|
|
||||||
copy.ctx = ctx
|
|
||||||
return copy
|
|
||||||
}
|
|
||||||
|
|
||||||
// Insert has the same behavior as DbMap.Insert(), but runs in a transaction.
|
|
||||||
func (t *Transaction) Insert(list ...interface{}) error {
|
|
||||||
return insert(t.dbmap, t, list...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update had the same behavior as DbMap.Update(), but runs in a transaction.
|
|
||||||
func (t *Transaction) Update(list ...interface{}) (int64, error) {
|
|
||||||
return update(t.dbmap, t, nil, list...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateColumns had the same behavior as DbMap.UpdateColumns(), but runs in a transaction.
|
|
||||||
func (t *Transaction) UpdateColumns(filter ColumnFilter, list ...interface{}) (int64, error) {
|
|
||||||
return update(t.dbmap, t, filter, list...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Delete has the same behavior as DbMap.Delete(), but runs in a transaction.
|
|
||||||
func (t *Transaction) Delete(list ...interface{}) (int64, error) {
|
|
||||||
return delete(t.dbmap, t, list...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get has the same behavior as DbMap.Get(), but runs in a transaction.
|
|
||||||
func (t *Transaction) Get(i interface{}, keys ...interface{}) (interface{}, error) {
|
|
||||||
return get(t.dbmap, t, i, keys...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Select has the same behavior as DbMap.Select(), but runs in a transaction.
|
|
||||||
func (t *Transaction) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
|
|
||||||
if t.dbmap.ExpandSliceArgs {
|
|
||||||
expandSliceArgs(&query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return hookedselect(t.dbmap, t, i, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Exec has the same behavior as DbMap.Exec(), but runs in a transaction.
|
|
||||||
func (t *Transaction) Exec(query string, args ...interface{}) (sql.Result, error) {
|
|
||||||
if t.dbmap.ExpandSliceArgs {
|
|
||||||
expandSliceArgs(&query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.dbmap.logger != nil {
|
|
||||||
now := time.Now()
|
|
||||||
defer t.dbmap.trace(now, query, args...)
|
|
||||||
}
|
|
||||||
return maybeExpandNamedQueryAndExec(t, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectInt is a convenience wrapper around the sqldb.SelectInt function.
|
|
||||||
func (t *Transaction) SelectInt(query string, args ...interface{}) (int64, error) {
|
|
||||||
if t.dbmap.ExpandSliceArgs {
|
|
||||||
expandSliceArgs(&query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return SelectInt(t, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectNullInt is a convenience wrapper around the sqldb.SelectNullInt function.
|
|
||||||
func (t *Transaction) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
|
|
||||||
if t.dbmap.ExpandSliceArgs {
|
|
||||||
expandSliceArgs(&query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return SelectNullInt(t, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectFloat is a convenience wrapper around the sqldb.SelectFloat function.
|
|
||||||
func (t *Transaction) SelectFloat(query string, args ...interface{}) (float64, error) {
|
|
||||||
if t.dbmap.ExpandSliceArgs {
|
|
||||||
expandSliceArgs(&query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return SelectFloat(t, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectNullFloat is a convenience wrapper around the sqldb.SelectNullFloat function.
|
|
||||||
func (t *Transaction) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
|
|
||||||
if t.dbmap.ExpandSliceArgs {
|
|
||||||
expandSliceArgs(&query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return SelectNullFloat(t, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectStr is a convenience wrapper around the sqldb.SelectStr function.
|
|
||||||
func (t *Transaction) SelectStr(query string, args ...interface{}) (string, error) {
|
|
||||||
if t.dbmap.ExpandSliceArgs {
|
|
||||||
expandSliceArgs(&query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return SelectStr(t, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectNullStr is a convenience wrapper around the sqldb.SelectNullStr function.
|
|
||||||
func (t *Transaction) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
|
|
||||||
if t.dbmap.ExpandSliceArgs {
|
|
||||||
expandSliceArgs(&query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return SelectNullStr(t, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SelectOne is a convenience wrapper around the sqldb.SelectOne function.
|
|
||||||
func (t *Transaction) SelectOne(holder interface{}, query string, args ...interface{}) error {
|
|
||||||
if t.dbmap.ExpandSliceArgs {
|
|
||||||
expandSliceArgs(&query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
return SelectOne(t.dbmap, t, holder, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Commit commits the underlying database transaction.
|
|
||||||
func (t *Transaction) Commit() error {
|
|
||||||
if !t.closed {
|
|
||||||
t.closed = true
|
|
||||||
if t.dbmap.logger != nil {
|
|
||||||
now := time.Now()
|
|
||||||
defer t.dbmap.trace(now, "commit;")
|
|
||||||
}
|
|
||||||
return t.tx.Commit()
|
|
||||||
}
|
|
||||||
|
|
||||||
return sql.ErrTxDone
|
|
||||||
}
|
|
||||||
|
|
||||||
// Rollback rolls back the underlying database transaction.
|
|
||||||
func (t *Transaction) Rollback() error {
|
|
||||||
if !t.closed {
|
|
||||||
t.closed = true
|
|
||||||
if t.dbmap.logger != nil {
|
|
||||||
now := time.Now()
|
|
||||||
defer t.dbmap.trace(now, "rollback;")
|
|
||||||
}
|
|
||||||
return t.tx.Rollback()
|
|
||||||
}
|
|
||||||
|
|
||||||
return sql.ErrTxDone
|
|
||||||
}
|
|
||||||
|
|
||||||
// Savepoint creates a savepoint with the given name. The name is interpolated
|
|
||||||
// directly into the SQL SAVEPOINT statement, so you must sanitize it if it is
|
|
||||||
// derived from user input.
|
|
||||||
func (t *Transaction) Savepoint(name string) error {
|
|
||||||
query := "savepoint " + t.dbmap.Dialect.QuoteField(name)
|
|
||||||
if t.dbmap.logger != nil {
|
|
||||||
now := time.Now()
|
|
||||||
defer t.dbmap.trace(now, query, nil)
|
|
||||||
}
|
|
||||||
_, err := exec(t, query)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// RollbackToSavepoint rolls back to the savepoint with the given name. The
|
|
||||||
// name is interpolated directly into the SQL SAVEPOINT statement, so you must
|
|
||||||
// sanitize it if it is derived from user input.
|
|
||||||
func (t *Transaction) RollbackToSavepoint(savepoint string) error {
|
|
||||||
query := "rollback to savepoint " + t.dbmap.Dialect.QuoteField(savepoint)
|
|
||||||
if t.dbmap.logger != nil {
|
|
||||||
now := time.Now()
|
|
||||||
defer t.dbmap.trace(now, query, nil)
|
|
||||||
}
|
|
||||||
_, err := exec(t, query)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// ReleaseSavepint releases the savepoint with the given name. The name is
|
|
||||||
// interpolated directly into the SQL SAVEPOINT statement, so you must sanitize
|
|
||||||
// it if it is derived from user input.
|
|
||||||
func (t *Transaction) ReleaseSavepoint(savepoint string) error {
|
|
||||||
query := "release savepoint " + t.dbmap.Dialect.QuoteField(savepoint)
|
|
||||||
if t.dbmap.logger != nil {
|
|
||||||
now := time.Now()
|
|
||||||
defer t.dbmap.trace(now, query, nil)
|
|
||||||
}
|
|
||||||
_, err := exec(t, query)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// Prepare has the same behavior as DbMap.Prepare(), but runs in a transaction.
|
|
||||||
func (t *Transaction) Prepare(query string) (*sql.Stmt, error) {
|
|
||||||
if t.dbmap.logger != nil {
|
|
||||||
now := time.Now()
|
|
||||||
defer t.dbmap.trace(now, query, nil)
|
|
||||||
}
|
|
||||||
return prepare(t, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Transaction) QueryRow(query string, args ...interface{}) *sql.Row {
|
|
||||||
if t.dbmap.ExpandSliceArgs {
|
|
||||||
expandSliceArgs(&query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.dbmap.logger != nil {
|
|
||||||
now := time.Now()
|
|
||||||
defer t.dbmap.trace(now, query, args...)
|
|
||||||
}
|
|
||||||
return queryRow(t, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *Transaction) Query(q string, args ...interface{}) (*sql.Rows, error) {
|
|
||||||
if t.dbmap.ExpandSliceArgs {
|
|
||||||
expandSliceArgs(&q, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
if t.dbmap.logger != nil {
|
|
||||||
now := time.Now()
|
|
||||||
defer t.dbmap.trace(now, q, args...)
|
|
||||||
}
|
|
||||||
return query(t, q, args...)
|
|
||||||
}
|
|
@ -1,340 +0,0 @@
|
|||||||
//
|
|
||||||
// transaction_test.go
|
|
||||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
//go:build integration
|
|
||||||
// +build integration
|
|
||||||
|
|
||||||
package sqldb_test
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestTransaction_Select_expandSliceArgs(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
description string
|
|
||||||
query string
|
|
||||||
args []interface{}
|
|
||||||
wantLen int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
description: "it should handle slice placeholders correctly",
|
|
||||||
query: `
|
|
||||||
SELECT 1 FROM crazy_table
|
|
||||||
WHERE field1 = :Field1
|
|
||||||
AND field2 IN (:FieldStringList)
|
|
||||||
AND field3 IN (:FieldUIntList)
|
|
||||||
AND field4 IN (:FieldUInt8List)
|
|
||||||
AND field5 IN (:FieldUInt16List)
|
|
||||||
AND field6 IN (:FieldUInt32List)
|
|
||||||
AND field7 IN (:FieldUInt64List)
|
|
||||||
AND field8 IN (:FieldIntList)
|
|
||||||
AND field9 IN (:FieldInt8List)
|
|
||||||
AND field10 IN (:FieldInt16List)
|
|
||||||
AND field11 IN (:FieldInt32List)
|
|
||||||
AND field12 IN (:FieldInt64List)
|
|
||||||
AND field13 IN (:FieldFloat32List)
|
|
||||||
AND field14 IN (:FieldFloat64List)
|
|
||||||
`,
|
|
||||||
args: []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"Field1": 123,
|
|
||||||
"FieldStringList": []string{"h", "e", "y"},
|
|
||||||
"FieldUIntList": []uint{1, 2, 3, 4},
|
|
||||||
"FieldUInt8List": []uint8{1, 2, 3, 4},
|
|
||||||
"FieldUInt16List": []uint16{1, 2, 3, 4},
|
|
||||||
"FieldUInt32List": []uint32{1, 2, 3, 4},
|
|
||||||
"FieldUInt64List": []uint64{1, 2, 3, 4},
|
|
||||||
"FieldIntList": []int{1, 2, 3, 4},
|
|
||||||
"FieldInt8List": []int8{1, 2, 3, 4},
|
|
||||||
"FieldInt16List": []int16{1, 2, 3, 4},
|
|
||||||
"FieldInt32List": []int32{1, 2, 3, 4},
|
|
||||||
"FieldInt64List": []int64{1, 2, 3, 4},
|
|
||||||
"FieldFloat32List": []float32{1, 2, 3, 4},
|
|
||||||
"FieldFloat64List": []float64{1, 2, 3, 4},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantLen: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "it should handle slice placeholders correctly with custom types",
|
|
||||||
query: `
|
|
||||||
SELECT 1 FROM crazy_table
|
|
||||||
WHERE field2 IN (:FieldStringList)
|
|
||||||
AND field12 IN (:FieldIntList)
|
|
||||||
`,
|
|
||||||
args: []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"FieldStringList": customType1{"h", "e", "y"},
|
|
||||||
"FieldIntList": customType2{1, 2, 3, 4},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantLen: 3,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
type dataFormat struct {
|
|
||||||
Field1 int `db:"field1"`
|
|
||||||
Field2 string `db:"field2"`
|
|
||||||
Field3 uint `db:"field3"`
|
|
||||||
Field4 uint8 `db:"field4"`
|
|
||||||
Field5 uint16 `db:"field5"`
|
|
||||||
Field6 uint32 `db:"field6"`
|
|
||||||
Field7 uint64 `db:"field7"`
|
|
||||||
Field8 int `db:"field8"`
|
|
||||||
Field9 int8 `db:"field9"`
|
|
||||||
Field10 int16 `db:"field10"`
|
|
||||||
Field11 int32 `db:"field11"`
|
|
||||||
Field12 int64 `db:"field12"`
|
|
||||||
Field13 float32 `db:"field13"`
|
|
||||||
Field14 float64 `db:"field14"`
|
|
||||||
}
|
|
||||||
|
|
||||||
dbmap := newDBMap(t)
|
|
||||||
dbmap.ExpandSliceArgs = true
|
|
||||||
dbmap.AddTableWithName(dataFormat{}, "crazy_table")
|
|
||||||
|
|
||||||
err := dbmap.CreateTables()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
defer dropAndClose(dbmap)
|
|
||||||
|
|
||||||
err = dbmap.Insert(
|
|
||||||
&dataFormat{
|
|
||||||
Field1: 123,
|
|
||||||
Field2: "h",
|
|
||||||
Field3: 1,
|
|
||||||
Field4: 1,
|
|
||||||
Field5: 1,
|
|
||||||
Field6: 1,
|
|
||||||
Field7: 1,
|
|
||||||
Field8: 1,
|
|
||||||
Field9: 1,
|
|
||||||
Field10: 1,
|
|
||||||
Field11: 1,
|
|
||||||
Field12: 1,
|
|
||||||
Field13: 1,
|
|
||||||
Field14: 1,
|
|
||||||
},
|
|
||||||
&dataFormat{
|
|
||||||
Field1: 124,
|
|
||||||
Field2: "e",
|
|
||||||
Field3: 2,
|
|
||||||
Field4: 2,
|
|
||||||
Field5: 2,
|
|
||||||
Field6: 2,
|
|
||||||
Field7: 2,
|
|
||||||
Field8: 2,
|
|
||||||
Field9: 2,
|
|
||||||
Field10: 2,
|
|
||||||
Field11: 2,
|
|
||||||
Field12: 2,
|
|
||||||
Field13: 2,
|
|
||||||
Field14: 2,
|
|
||||||
},
|
|
||||||
&dataFormat{
|
|
||||||
Field1: 125,
|
|
||||||
Field2: "y",
|
|
||||||
Field3: 3,
|
|
||||||
Field4: 3,
|
|
||||||
Field5: 3,
|
|
||||||
Field6: 3,
|
|
||||||
Field7: 3,
|
|
||||||
Field8: 3,
|
|
||||||
Field9: 3,
|
|
||||||
Field10: 3,
|
|
||||||
Field11: 3,
|
|
||||||
Field12: 3,
|
|
||||||
Field13: 3,
|
|
||||||
Field14: 3,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.description, func(t *testing.T) {
|
|
||||||
tx, err := dbmap.Begin()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
var dummy []int
|
|
||||||
_, err = tx.Select(&dummy, tt.query, tt.args...)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(dummy) != tt.wantLen {
|
|
||||||
t.Errorf("wrong result count\ngot: %d\nwant: %d", len(dummy), tt.wantLen)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestTransaction_Exec_expandSliceArgs(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
description string
|
|
||||||
query string
|
|
||||||
args []interface{}
|
|
||||||
wantLen int
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
description: "it should handle slice placeholders correctly",
|
|
||||||
query: `
|
|
||||||
DELETE FROM crazy_table
|
|
||||||
WHERE field1 = :Field1
|
|
||||||
AND field2 IN (:FieldStringList)
|
|
||||||
AND field3 IN (:FieldUIntList)
|
|
||||||
AND field4 IN (:FieldUInt8List)
|
|
||||||
AND field5 IN (:FieldUInt16List)
|
|
||||||
AND field6 IN (:FieldUInt32List)
|
|
||||||
AND field7 IN (:FieldUInt64List)
|
|
||||||
AND field8 IN (:FieldIntList)
|
|
||||||
AND field9 IN (:FieldInt8List)
|
|
||||||
AND field10 IN (:FieldInt16List)
|
|
||||||
AND field11 IN (:FieldInt32List)
|
|
||||||
AND field12 IN (:FieldInt64List)
|
|
||||||
AND field13 IN (:FieldFloat32List)
|
|
||||||
AND field14 IN (:FieldFloat64List)
|
|
||||||
`,
|
|
||||||
args: []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"Field1": 123,
|
|
||||||
"FieldStringList": []string{"h", "e", "y"},
|
|
||||||
"FieldUIntList": []uint{1, 2, 3, 4},
|
|
||||||
"FieldUInt8List": []uint8{1, 2, 3, 4},
|
|
||||||
"FieldUInt16List": []uint16{1, 2, 3, 4},
|
|
||||||
"FieldUInt32List": []uint32{1, 2, 3, 4},
|
|
||||||
"FieldUInt64List": []uint64{1, 2, 3, 4},
|
|
||||||
"FieldIntList": []int{1, 2, 3, 4},
|
|
||||||
"FieldInt8List": []int8{1, 2, 3, 4},
|
|
||||||
"FieldInt16List": []int16{1, 2, 3, 4},
|
|
||||||
"FieldInt32List": []int32{1, 2, 3, 4},
|
|
||||||
"FieldInt64List": []int64{1, 2, 3, 4},
|
|
||||||
"FieldFloat32List": []float32{1, 2, 3, 4},
|
|
||||||
"FieldFloat64List": []float64{1, 2, 3, 4},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantLen: 1,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
description: "it should handle slice placeholders correctly with custom types",
|
|
||||||
query: `
|
|
||||||
DELETE FROM crazy_table
|
|
||||||
WHERE field2 IN (:FieldStringList)
|
|
||||||
AND field12 IN (:FieldIntList)
|
|
||||||
`,
|
|
||||||
args: []interface{}{
|
|
||||||
map[string]interface{}{
|
|
||||||
"FieldStringList": customType1{"h", "e", "y"},
|
|
||||||
"FieldIntList": customType2{1, 2, 3, 4},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
wantLen: 3,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
type dataFormat struct {
|
|
||||||
Field1 int `db:"field1"`
|
|
||||||
Field2 string `db:"field2"`
|
|
||||||
Field3 uint `db:"field3"`
|
|
||||||
Field4 uint8 `db:"field4"`
|
|
||||||
Field5 uint16 `db:"field5"`
|
|
||||||
Field6 uint32 `db:"field6"`
|
|
||||||
Field7 uint64 `db:"field7"`
|
|
||||||
Field8 int `db:"field8"`
|
|
||||||
Field9 int8 `db:"field9"`
|
|
||||||
Field10 int16 `db:"field10"`
|
|
||||||
Field11 int32 `db:"field11"`
|
|
||||||
Field12 int64 `db:"field12"`
|
|
||||||
Field13 float32 `db:"field13"`
|
|
||||||
Field14 float64 `db:"field14"`
|
|
||||||
}
|
|
||||||
|
|
||||||
dbmap := newDBMap(t)
|
|
||||||
dbmap.ExpandSliceArgs = true
|
|
||||||
dbmap.AddTableWithName(dataFormat{}, "crazy_table")
|
|
||||||
|
|
||||||
err := dbmap.CreateTables()
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
defer dropAndClose(dbmap)
|
|
||||||
|
|
||||||
err = dbmap.Insert(
|
|
||||||
&dataFormat{
|
|
||||||
Field1: 123,
|
|
||||||
Field2: "h",
|
|
||||||
Field3: 1,
|
|
||||||
Field4: 1,
|
|
||||||
Field5: 1,
|
|
||||||
Field6: 1,
|
|
||||||
Field7: 1,
|
|
||||||
Field8: 1,
|
|
||||||
Field9: 1,
|
|
||||||
Field10: 1,
|
|
||||||
Field11: 1,
|
|
||||||
Field12: 1,
|
|
||||||
Field13: 1,
|
|
||||||
Field14: 1,
|
|
||||||
},
|
|
||||||
&dataFormat{
|
|
||||||
Field1: 124,
|
|
||||||
Field2: "e",
|
|
||||||
Field3: 2,
|
|
||||||
Field4: 2,
|
|
||||||
Field5: 2,
|
|
||||||
Field6: 2,
|
|
||||||
Field7: 2,
|
|
||||||
Field8: 2,
|
|
||||||
Field9: 2,
|
|
||||||
Field10: 2,
|
|
||||||
Field11: 2,
|
|
||||||
Field12: 2,
|
|
||||||
Field13: 2,
|
|
||||||
Field14: 2,
|
|
||||||
},
|
|
||||||
&dataFormat{
|
|
||||||
Field1: 125,
|
|
||||||
Field2: "y",
|
|
||||||
Field3: 3,
|
|
||||||
Field4: 3,
|
|
||||||
Field5: 3,
|
|
||||||
Field6: 3,
|
|
||||||
Field7: 3,
|
|
||||||
Field8: 3,
|
|
||||||
Field9: 3,
|
|
||||||
Field10: 3,
|
|
||||||
Field11: 3,
|
|
||||||
Field12: 3,
|
|
||||||
Field13: 3,
|
|
||||||
Field14: 3,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.description, func(t *testing.T) {
|
|
||||||
tx, err := dbmap.Begin()
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
defer tx.Rollback()
|
|
||||||
|
|
||||||
_, err = tx.Exec(tt.query, tt.args...)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal(err)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user