feat: 增加 orm 库
This commit is contained in:
parent
30c596f8ef
commit
0cc0b5f310
3
gdb/orm/.gitignore
vendored
Normal file
3
gdb/orm/.gitignore
vendored
Normal file
@ -0,0 +1,3 @@
|
|||||||
|
**/.idea/*
|
||||||
|
cover.out
|
||||||
|
**db
|
157
gdb/orm/binder.go
Normal file
157
gdb/orm/binder.go
Normal file
@ -0,0 +1,157 @@
|
|||||||
|
//
|
||||||
|
// binder.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"unsafe"
|
||||||
|
)
|
||||||
|
|
||||||
|
// makeNewPointersOf creates a map of [field name] -> pointer to fill it
|
||||||
|
// recursively. it will go down until reaches a driver.Valuer implementation, it will stop there.
|
||||||
|
func (b *binder) makeNewPointersOf(v reflect.Value) interface{} {
|
||||||
|
m := map[string]interface{}{}
|
||||||
|
actualV := v
|
||||||
|
for actualV.Type().Kind() == reflect.Ptr {
|
||||||
|
actualV = actualV.Elem()
|
||||||
|
}
|
||||||
|
if actualV.Type().Kind() == reflect.Struct {
|
||||||
|
for i := 0; i < actualV.NumField(); i++ {
|
||||||
|
f := actualV.Field(i)
|
||||||
|
if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
|
||||||
|
f = reflect.NewAt(actualV.Type().Field(i).Type, unsafe.Pointer(actualV.Field(i).UnsafeAddr()))
|
||||||
|
fm := b.makeNewPointersOf(f).(map[string]interface{})
|
||||||
|
for k, p := range fm {
|
||||||
|
m[k] = p
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
var fm *field
|
||||||
|
fm = b.s.getField(actualV.Type().Field(i))
|
||||||
|
if fm == nil {
|
||||||
|
fm = fieldMetadata(actualV.Type().Field(i), b.s.columnConstraints)[0]
|
||||||
|
}
|
||||||
|
m[fm.Name] = reflect.NewAt(actualV.Field(i).Type(), unsafe.Pointer(actualV.Field(i).UnsafeAddr())).Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return v.Addr().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
// ptrsFor first allocates for all struct fields recursively until reaches a driver.Value impl
|
||||||
|
// then it will put them in a map with their correct field name as key, then loops over cts
|
||||||
|
// and for each one gets appropriate one from the map and adds it to pointer list.
|
||||||
|
func (b *binder) ptrsFor(v reflect.Value, cts []*sql.ColumnType) []interface{} {
|
||||||
|
ptrs := b.makeNewPointersOf(v)
|
||||||
|
var scanInto []interface{}
|
||||||
|
if reflect.TypeOf(ptrs).Kind() == reflect.Map {
|
||||||
|
nameToPtr := ptrs.(map[string]interface{})
|
||||||
|
for _, ct := range cts {
|
||||||
|
if nameToPtr[ct.Name()] != nil {
|
||||||
|
scanInto = append(scanInto, nameToPtr[ct.Name()])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
scanInto = append(scanInto, ptrs)
|
||||||
|
}
|
||||||
|
|
||||||
|
return scanInto
|
||||||
|
}
|
||||||
|
|
||||||
|
type binder struct {
|
||||||
|
s *schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBinder(s *schema) *binder {
|
||||||
|
return &binder{s: s}
|
||||||
|
}
|
||||||
|
|
||||||
|
// bind binds given rows to the given object at obj. obj should be a pointer
|
||||||
|
func (b *binder) bind(rows *sql.Rows, obj interface{}) error {
|
||||||
|
cts, err := rows.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
t := reflect.TypeOf(obj)
|
||||||
|
v := reflect.ValueOf(obj)
|
||||||
|
if t.Kind() != reflect.Ptr {
|
||||||
|
return fmt.Errorf("obj should be a ptr")
|
||||||
|
}
|
||||||
|
// since passed input is always a pointer one deref is necessary
|
||||||
|
t = t.Elem()
|
||||||
|
v = v.Elem()
|
||||||
|
if t.Kind() == reflect.Slice {
|
||||||
|
// getting slice elemnt type -> slice[t]
|
||||||
|
t = t.Elem()
|
||||||
|
for rows.Next() {
|
||||||
|
var rowValue reflect.Value
|
||||||
|
// Since reflect.SetupConnections returns a pointer to the type, we need to unwrap it to get actual
|
||||||
|
rowValue = reflect.New(t).Elem()
|
||||||
|
// till we reach a not pointer type continue newing the underlying type.
|
||||||
|
for rowValue.IsZero() && rowValue.Type().Kind() == reflect.Ptr {
|
||||||
|
rowValue = reflect.New(rowValue.Type().Elem()).Elem()
|
||||||
|
}
|
||||||
|
newCts := make([]*sql.ColumnType, len(cts))
|
||||||
|
copy(newCts, cts)
|
||||||
|
ptrs := b.ptrsFor(rowValue, newCts)
|
||||||
|
err = rows.Scan(ptrs...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
for rowValue.Type() != t {
|
||||||
|
tmp := reflect.New(rowValue.Type())
|
||||||
|
tmp.Elem().Set(rowValue)
|
||||||
|
rowValue = tmp
|
||||||
|
}
|
||||||
|
v = reflect.Append(v, rowValue)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
for rows.Next() {
|
||||||
|
ptrs := b.ptrsFor(v, cts)
|
||||||
|
err = rows.Scan(ptrs...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// v is either struct or slice
|
||||||
|
reflect.ValueOf(obj).Elem().Set(v)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func bindToMap(rows *sql.Rows) ([]map[string]interface{}, error) {
|
||||||
|
cts, err := rows.ColumnTypes()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var ms []map[string]interface{}
|
||||||
|
for rows.Next() {
|
||||||
|
var ptrs []interface{}
|
||||||
|
for _, ct := range cts {
|
||||||
|
ptrs = append(ptrs, reflect.New(ct.ScanType()).Interface())
|
||||||
|
}
|
||||||
|
|
||||||
|
err = rows.Scan(ptrs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
m := map[string]interface{}{}
|
||||||
|
for i, ptr := range ptrs {
|
||||||
|
m[cts[i].Name()] = reflect.ValueOf(ptr).Elem().Interface()
|
||||||
|
}
|
||||||
|
|
||||||
|
ms = append(ms, m)
|
||||||
|
}
|
||||||
|
return ms, nil
|
||||||
|
}
|
91
gdb/orm/binder_test.go
Normal file
91
gdb/orm/binder_test.go
Normal file
@ -0,0 +1,91 @@
|
|||||||
|
//
|
||||||
|
// binder_test.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/DATA-DOG/go-sqlmock"
|
||||||
|
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
ID int64
|
||||||
|
Name string
|
||||||
|
Timestamps
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u User) ConfigureEntity(e *EntityConfigurator) {
|
||||||
|
e.Table("users")
|
||||||
|
}
|
||||||
|
|
||||||
|
type Address struct {
|
||||||
|
ID int
|
||||||
|
Path string
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBind(t *testing.T) {
|
||||||
|
t.Run("single result", func(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
mock.
|
||||||
|
ExpectQuery("SELECT .* FROM users").
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "created_at", "updated_at", "deleted_at"}).
|
||||||
|
AddRow(1, "amirreza", sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{}))
|
||||||
|
rows, err := db.Query(`SELECT * FROM users`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
u := &User{}
|
||||||
|
md := schemaOfHeavyReflectionStuff(u)
|
||||||
|
err = newBinder(md).bind(rows, u)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "amirreza", u.Name)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multi result", func(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
mock.
|
||||||
|
ExpectQuery("SELECT .* FROM users").
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"id", "name"}).AddRow(1, "amirreza").AddRow(2, "milad"))
|
||||||
|
|
||||||
|
rows, err := db.Query(`SELECT * FROM users`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
md := schemaOfHeavyReflectionStuff(&User{})
|
||||||
|
var users []*User
|
||||||
|
err = newBinder(md).bind(rows, &users)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "amirreza", users[0].Name)
|
||||||
|
assert.Equal(t, "milad", users[1].Name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBindMap(t *testing.T) {
|
||||||
|
db, mock, err := sqlmock.New()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
mock.
|
||||||
|
ExpectQuery("SELECT .* FROM users").
|
||||||
|
WillReturnRows(sqlmock.NewRows([]string{"id", "name", "created_at", "updated_at", "deleted_at"}).
|
||||||
|
AddRow(1, "amirreza", sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{Time: time.Now(), Valid: true}, sql.NullTime{}))
|
||||||
|
rows, err := db.Query(`SELECT * FROM users`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
ms, err := bindToMap(rows)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, ms)
|
||||||
|
|
||||||
|
assert.Len(t, ms, 1)
|
||||||
|
}
|
192
gdb/orm/configurators.go
Normal file
192
gdb/orm/configurators.go
Normal file
@ -0,0 +1,192 @@
|
|||||||
|
//
|
||||||
|
// configurators.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
|
||||||
|
"git.hexq.cn/tiglog/golib/helper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type EntityConfigurator struct {
|
||||||
|
connection string
|
||||||
|
table string
|
||||||
|
this Entity
|
||||||
|
relations map[string]interface{}
|
||||||
|
resolveRelations []func()
|
||||||
|
columnConstraints []*FieldConfigurator
|
||||||
|
}
|
||||||
|
|
||||||
|
func newEntityConfigurator() *EntityConfigurator {
|
||||||
|
return &EntityConfigurator{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ec *EntityConfigurator) Table(name string) *EntityConfigurator {
|
||||||
|
ec.table = name
|
||||||
|
return ec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ec *EntityConfigurator) Connection(name string) *EntityConfigurator {
|
||||||
|
ec.connection = name
|
||||||
|
return ec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ec *EntityConfigurator) HasMany(property Entity, config HasManyConfig) *EntityConfigurator {
|
||||||
|
if ec.relations == nil {
|
||||||
|
ec.relations = map[string]interface{}{}
|
||||||
|
}
|
||||||
|
ec.resolveRelations = append(ec.resolveRelations, func() {
|
||||||
|
if config.PropertyForeignKey != "" && config.PropertyTable != "" {
|
||||||
|
ec.relations[config.PropertyTable] = config
|
||||||
|
return
|
||||||
|
}
|
||||||
|
configurator := newEntityConfigurator()
|
||||||
|
property.ConfigureEntity(configurator)
|
||||||
|
|
||||||
|
if config.PropertyTable == "" {
|
||||||
|
config.PropertyTable = configurator.table
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.PropertyForeignKey == "" {
|
||||||
|
config.PropertyForeignKey = helper.NewPluralizeClient().Singular(ec.table) + "_id"
|
||||||
|
}
|
||||||
|
|
||||||
|
ec.relations[configurator.table] = config
|
||||||
|
|
||||||
|
return
|
||||||
|
})
|
||||||
|
return ec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ec *EntityConfigurator) HasOne(property Entity, config HasOneConfig) *EntityConfigurator {
|
||||||
|
if ec.relations == nil {
|
||||||
|
ec.relations = map[string]interface{}{}
|
||||||
|
}
|
||||||
|
ec.resolveRelations = append(ec.resolveRelations, func() {
|
||||||
|
if config.PropertyForeignKey != "" && config.PropertyTable != "" {
|
||||||
|
ec.relations[config.PropertyTable] = config
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
configurator := newEntityConfigurator()
|
||||||
|
property.ConfigureEntity(configurator)
|
||||||
|
|
||||||
|
if config.PropertyTable == "" {
|
||||||
|
config.PropertyTable = configurator.table
|
||||||
|
}
|
||||||
|
if config.PropertyForeignKey == "" {
|
||||||
|
config.PropertyForeignKey = helper.NewPluralizeClient().Singular(ec.table) + "_id"
|
||||||
|
}
|
||||||
|
|
||||||
|
ec.relations[configurator.table] = config
|
||||||
|
return
|
||||||
|
})
|
||||||
|
return ec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ec *EntityConfigurator) BelongsTo(owner Entity, config BelongsToConfig) *EntityConfigurator {
|
||||||
|
if ec.relations == nil {
|
||||||
|
ec.relations = map[string]interface{}{}
|
||||||
|
}
|
||||||
|
ec.resolveRelations = append(ec.resolveRelations, func() {
|
||||||
|
if config.ForeignColumnName != "" && config.LocalForeignKey != "" && config.OwnerTable != "" {
|
||||||
|
ec.relations[config.OwnerTable] = config
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ownerConfigurator := newEntityConfigurator()
|
||||||
|
owner.ConfigureEntity(ownerConfigurator)
|
||||||
|
if config.OwnerTable == "" {
|
||||||
|
config.OwnerTable = ownerConfigurator.table
|
||||||
|
}
|
||||||
|
if config.LocalForeignKey == "" {
|
||||||
|
config.LocalForeignKey = helper.NewPluralizeClient().Singular(ownerConfigurator.table) + "_id"
|
||||||
|
}
|
||||||
|
if config.ForeignColumnName == "" {
|
||||||
|
config.ForeignColumnName = "id"
|
||||||
|
}
|
||||||
|
ec.relations[ownerConfigurator.table] = config
|
||||||
|
})
|
||||||
|
return ec
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ec *EntityConfigurator) BelongsToMany(owner Entity, config BelongsToManyConfig) *EntityConfigurator {
|
||||||
|
if ec.relations == nil {
|
||||||
|
ec.relations = map[string]interface{}{}
|
||||||
|
}
|
||||||
|
ec.resolveRelations = append(ec.resolveRelations, func() {
|
||||||
|
ownerConfigurator := newEntityConfigurator()
|
||||||
|
owner.ConfigureEntity(ownerConfigurator)
|
||||||
|
|
||||||
|
if config.OwnerLookupColumn == "" {
|
||||||
|
var pkName string
|
||||||
|
for _, field := range genericFieldsOf(owner) {
|
||||||
|
if field.IsPK {
|
||||||
|
pkName = field.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
config.OwnerLookupColumn = pkName
|
||||||
|
|
||||||
|
}
|
||||||
|
if config.OwnerTable == "" {
|
||||||
|
config.OwnerTable = ownerConfigurator.table
|
||||||
|
}
|
||||||
|
if config.IntermediateTable == "" {
|
||||||
|
panic("cannot infer intermediate table yet")
|
||||||
|
}
|
||||||
|
if config.IntermediatePropertyID == "" {
|
||||||
|
config.IntermediatePropertyID = helper.NewPluralizeClient().Singular(ownerConfigurator.table) + "_id"
|
||||||
|
}
|
||||||
|
if config.IntermediateOwnerID == "" {
|
||||||
|
config.IntermediateOwnerID = helper.NewPluralizeClient().Singular(ec.table) + "_id"
|
||||||
|
}
|
||||||
|
|
||||||
|
ec.relations[ownerConfigurator.table] = config
|
||||||
|
})
|
||||||
|
return ec
|
||||||
|
}
|
||||||
|
|
||||||
|
type FieldConfigurator struct {
|
||||||
|
fieldName string
|
||||||
|
nullable sql.NullBool
|
||||||
|
primaryKey bool
|
||||||
|
column string
|
||||||
|
isCreatedAt bool
|
||||||
|
isUpdatedAt bool
|
||||||
|
isDeletedAt bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ec *EntityConfigurator) Field(name string) *FieldConfigurator {
|
||||||
|
cc := &FieldConfigurator{fieldName: name}
|
||||||
|
ec.columnConstraints = append(ec.columnConstraints, cc)
|
||||||
|
return cc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *FieldConfigurator) IsPrimaryKey() *FieldConfigurator {
|
||||||
|
fc.primaryKey = true
|
||||||
|
return fc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *FieldConfigurator) IsCreatedAt() *FieldConfigurator {
|
||||||
|
fc.isCreatedAt = true
|
||||||
|
return fc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *FieldConfigurator) IsUpdatedAt() *FieldConfigurator {
|
||||||
|
fc.isUpdatedAt = true
|
||||||
|
return fc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *FieldConfigurator) IsDeletedAt() *FieldConfigurator {
|
||||||
|
fc.isDeletedAt = true
|
||||||
|
return fc
|
||||||
|
}
|
||||||
|
|
||||||
|
func (fc *FieldConfigurator) ColumnName(name string) *FieldConfigurator {
|
||||||
|
fc.column = name
|
||||||
|
return fc
|
||||||
|
}
|
160
gdb/orm/connection.go
Normal file
160
gdb/orm/connection.go
Normal file
@ -0,0 +1,160 @@
|
|||||||
|
//
|
||||||
|
// connection.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"git.hexq.cn/tiglog/golib/helper/table"
|
||||||
|
)
|
||||||
|
|
||||||
|
type connection struct {
|
||||||
|
Name string
|
||||||
|
Dialect *Dialect
|
||||||
|
DB *sql.DB
|
||||||
|
Schemas map[string]*schema
|
||||||
|
DBSchema map[string][]columnSpec
|
||||||
|
DatabaseValidations bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) inferedTables() []string {
|
||||||
|
var tables []string
|
||||||
|
for t, s := range c.Schemas {
|
||||||
|
tables = append(tables, t)
|
||||||
|
for _, relC := range s.relations {
|
||||||
|
if belongsToManyConfig, is := relC.(BelongsToManyConfig); is {
|
||||||
|
tables = append(tables, belongsToManyConfig.IntermediateTable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return tables
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) validateAllTablesArePresent() error {
|
||||||
|
for _, inferedTable := range c.inferedTables() {
|
||||||
|
if _, exists := c.DBSchema[inferedTable]; !exists {
|
||||||
|
return fmt.Errorf("orm infered %s but it's not found in your database, your database is out of sync", inferedTable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) validateTablesSchemas() error {
|
||||||
|
// check for entity tables: there should not be any struct field that does not have a coresponding column
|
||||||
|
for table, sc := range c.Schemas {
|
||||||
|
if columns, exists := c.DBSchema[table]; exists {
|
||||||
|
for _, f := range sc.fields {
|
||||||
|
found := false
|
||||||
|
for _, c := range columns {
|
||||||
|
if c.Name == f.Name {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("column %s not found while it was inferred", f.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return fmt.Errorf("tables are out of sync, %s was inferred but not present in database", table)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check for relation tables: for HasMany,HasOne relations check if OWNER pk column is in PROPERTY,
|
||||||
|
// for BelongsToMany check intermediate table has 2 pk for two entities
|
||||||
|
|
||||||
|
for table, sc := range c.Schemas {
|
||||||
|
for _, rel := range sc.relations {
|
||||||
|
switch rel.(type) {
|
||||||
|
case BelongsToConfig:
|
||||||
|
columns := c.DBSchema[table]
|
||||||
|
var found bool
|
||||||
|
for _, col := range columns {
|
||||||
|
if col.Name == rel.(BelongsToConfig).LocalForeignKey {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return fmt.Errorf("cannot find local foreign key %s for relation", rel.(BelongsToConfig).LocalForeignKey)
|
||||||
|
}
|
||||||
|
case BelongsToManyConfig:
|
||||||
|
columns := c.DBSchema[rel.(BelongsToManyConfig).IntermediateTable]
|
||||||
|
var foundOwner bool
|
||||||
|
var foundProperty bool
|
||||||
|
|
||||||
|
for _, col := range columns {
|
||||||
|
if col.Name == rel.(BelongsToManyConfig).IntermediateOwnerID {
|
||||||
|
foundOwner = true
|
||||||
|
}
|
||||||
|
if col.Name == rel.(BelongsToManyConfig).IntermediatePropertyID {
|
||||||
|
foundProperty = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !foundOwner || !foundProperty {
|
||||||
|
return fmt.Errorf("table schema for %s is not correct one of foreign keys is not present", rel.(BelongsToManyConfig).IntermediateTable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) Schematic() {
|
||||||
|
fmt.Printf("SQL Dialect: %s\n", c.Dialect.DriverName)
|
||||||
|
for t, schema := range c.Schemas {
|
||||||
|
fmt.Printf("t: %s\n", t)
|
||||||
|
w := table.NewWriter()
|
||||||
|
w.AppendHeader(table.Row{"SQL Name", "Type", "Is Primary Key", "Is Virtual"})
|
||||||
|
for _, field := range schema.fields {
|
||||||
|
w.AppendRow(table.Row{field.Name, field.Type, field.IsPK, field.Virtual})
|
||||||
|
}
|
||||||
|
fmt.Println(w.Render())
|
||||||
|
for _, rel := range schema.relations {
|
||||||
|
switch rel.(type) {
|
||||||
|
case HasOneConfig:
|
||||||
|
fmt.Printf("%s 1-1 %s => %+v\n", t, rel.(HasOneConfig).PropertyTable, rel)
|
||||||
|
case HasManyConfig:
|
||||||
|
fmt.Printf("%s 1-N %s => %+v\n", t, rel.(HasManyConfig).PropertyTable, rel)
|
||||||
|
|
||||||
|
case BelongsToConfig:
|
||||||
|
fmt.Printf("%s N-1 %s => %+v\n", t, rel.(BelongsToConfig).OwnerTable, rel)
|
||||||
|
|
||||||
|
case BelongsToManyConfig:
|
||||||
|
fmt.Printf("%s N-N %s => %+v\n", t, rel.(BelongsToManyConfig).IntermediateTable, rel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Println("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) getSchema(t string) *schema {
|
||||||
|
return c.Schemas[t]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) setSchema(e Entity, s *schema) {
|
||||||
|
var configurator EntityConfigurator
|
||||||
|
e.ConfigureEntity(&configurator)
|
||||||
|
c.Schemas[configurator.table] = s
|
||||||
|
}
|
||||||
|
|
||||||
|
func GetConnection(name string) *connection {
|
||||||
|
return globalConnections[name]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) exec(q string, args ...any) (sql.Result, error) {
|
||||||
|
return c.DB.Exec(q, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) query(q string, args ...any) (*sql.Rows, error) {
|
||||||
|
return c.DB.Query(q, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *connection) queryRow(q string, args ...any) *sql.Row {
|
||||||
|
return c.DB.QueryRow(q, args...)
|
||||||
|
}
|
109
gdb/orm/dialect.go
Normal file
109
gdb/orm/dialect.go
Normal file
@ -0,0 +1,109 @@
|
|||||||
|
//
|
||||||
|
// dialect.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dialect struct {
|
||||||
|
DriverName string
|
||||||
|
PlaceholderChar string
|
||||||
|
IncludeIndexInPlaceholder bool
|
||||||
|
AddTableNameInSelectColumns bool
|
||||||
|
PlaceHolderGenerator func(n int) []string
|
||||||
|
QueryListTables string
|
||||||
|
QueryTableSchema string
|
||||||
|
}
|
||||||
|
|
||||||
|
func getListOfTables(query string) func(db *sql.DB) ([]string, error) {
|
||||||
|
return func(db *sql.DB) ([]string, error) {
|
||||||
|
rows, err := db.Query(query)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var tables []string
|
||||||
|
for rows.Next() {
|
||||||
|
var table string
|
||||||
|
err = rows.Scan(&table)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
tables = append(tables, table)
|
||||||
|
}
|
||||||
|
return tables, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type columnSpec struct {
|
||||||
|
//0|id|INTEGER|0||1
|
||||||
|
Name string
|
||||||
|
Type string
|
||||||
|
Nullable bool
|
||||||
|
DefaultValue sql.NullString
|
||||||
|
IsPrimaryKey bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func getTableSchema(query string) func(db *sql.DB, query string) ([]columnSpec, error) {
|
||||||
|
return func(db *sql.DB, table string) ([]columnSpec, error) {
|
||||||
|
rows, err := db.Query(fmt.Sprintf(query, table))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var output []columnSpec
|
||||||
|
for rows.Next() {
|
||||||
|
var cs columnSpec
|
||||||
|
var nullable string
|
||||||
|
var pk int
|
||||||
|
err = rows.Scan(&cs.Name, &cs.Type, &nullable, &cs.DefaultValue, &pk)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
cs.Nullable = nullable == "notnull"
|
||||||
|
cs.IsPrimaryKey = pk == 1
|
||||||
|
output = append(output, cs)
|
||||||
|
}
|
||||||
|
return output, nil
|
||||||
|
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var Dialects = &struct {
|
||||||
|
MySQL *Dialect
|
||||||
|
PostgreSQL *Dialect
|
||||||
|
SQLite3 *Dialect
|
||||||
|
}{
|
||||||
|
MySQL: &Dialect{
|
||||||
|
DriverName: "mysql",
|
||||||
|
PlaceholderChar: "?",
|
||||||
|
IncludeIndexInPlaceholder: false,
|
||||||
|
AddTableNameInSelectColumns: true,
|
||||||
|
PlaceHolderGenerator: questionMarks,
|
||||||
|
QueryListTables: "SHOW TABLES",
|
||||||
|
QueryTableSchema: "DESCRIBE %s",
|
||||||
|
},
|
||||||
|
PostgreSQL: &Dialect{
|
||||||
|
DriverName: "postgres",
|
||||||
|
PlaceholderChar: "$",
|
||||||
|
IncludeIndexInPlaceholder: true,
|
||||||
|
AddTableNameInSelectColumns: true,
|
||||||
|
PlaceHolderGenerator: postgresPlaceholder,
|
||||||
|
QueryListTables: `\dt`,
|
||||||
|
QueryTableSchema: `\d %s`,
|
||||||
|
},
|
||||||
|
SQLite3: &Dialect{
|
||||||
|
DriverName: "sqlite3",
|
||||||
|
PlaceholderChar: "?",
|
||||||
|
IncludeIndexInPlaceholder: false,
|
||||||
|
AddTableNameInSelectColumns: false,
|
||||||
|
PlaceHolderGenerator: questionMarks,
|
||||||
|
QueryListTables: "SELECT name FROM sqlite_schema WHERE type='table'",
|
||||||
|
QueryTableSchema: `SELECT name,type,"notnull","dflt_value","pk" FROM PRAGMA_TABLE_INFO('%s')`,
|
||||||
|
},
|
||||||
|
}
|
75
gdb/orm/field.go
Normal file
75
gdb/orm/field.go
Normal file
@ -0,0 +1,75 @@
|
|||||||
|
//
|
||||||
|
// field.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql/driver"
|
||||||
|
"reflect"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"git.hexq.cn/tiglog/golib/helper"
|
||||||
|
)
|
||||||
|
|
||||||
|
type field struct {
|
||||||
|
Name string
|
||||||
|
IsPK bool
|
||||||
|
Virtual bool
|
||||||
|
IsCreatedAt bool
|
||||||
|
IsUpdatedAt bool
|
||||||
|
IsDeletedAt bool
|
||||||
|
Nullable bool
|
||||||
|
Default any
|
||||||
|
Type reflect.Type
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFieldConfiguratorFor(fieldConfigurators []*FieldConfigurator, name string) *FieldConfigurator {
|
||||||
|
for _, fc := range fieldConfigurators {
|
||||||
|
if fc.fieldName == name {
|
||||||
|
return fc
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return &FieldConfigurator{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func fieldMetadata(ft reflect.StructField, fieldConfigurators []*FieldConfigurator) []*field {
|
||||||
|
var fms []*field
|
||||||
|
fc := getFieldConfiguratorFor(fieldConfigurators, ft.Name)
|
||||||
|
baseFm := &field{}
|
||||||
|
baseFm.Type = ft.Type
|
||||||
|
fms = append(fms, baseFm)
|
||||||
|
if fc.column != "" {
|
||||||
|
baseFm.Name = fc.column
|
||||||
|
} else {
|
||||||
|
baseFm.Name = helper.SnakeString(ft.Name)
|
||||||
|
}
|
||||||
|
if strings.ToLower(ft.Name) == "id" || fc.primaryKey {
|
||||||
|
baseFm.IsPK = true
|
||||||
|
}
|
||||||
|
if strings.ToLower(ft.Name) == "createdat" || fc.isCreatedAt {
|
||||||
|
baseFm.IsCreatedAt = true
|
||||||
|
}
|
||||||
|
if strings.ToLower(ft.Name) == "updatedat" || fc.isUpdatedAt {
|
||||||
|
baseFm.IsUpdatedAt = true
|
||||||
|
}
|
||||||
|
if strings.ToLower(ft.Name) == "deletedat" || fc.isDeletedAt {
|
||||||
|
baseFm.IsDeletedAt = true
|
||||||
|
}
|
||||||
|
if ft.Type.Kind() == reflect.Struct || ft.Type.Kind() == reflect.Ptr {
|
||||||
|
t := ft.Type
|
||||||
|
if ft.Type.Kind() == reflect.Ptr {
|
||||||
|
t = ft.Type.Elem()
|
||||||
|
}
|
||||||
|
if !t.Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
fms = append(fms, fieldMetadata(t.Field(i), fieldConfigurators)...)
|
||||||
|
}
|
||||||
|
fms = fms[1:]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return fms
|
||||||
|
}
|
654
gdb/orm/orm.go
Normal file
654
gdb/orm/orm.go
Normal file
@ -0,0 +1,654 @@
|
|||||||
|
//
|
||||||
|
// orm.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
// Drivers
|
||||||
|
_ "github.com/go-sql-driver/mysql"
|
||||||
|
_ "github.com/lib/pq"
|
||||||
|
_ "github.com/mattn/go-sqlite3"
|
||||||
|
)
|
||||||
|
|
||||||
|
var globalConnections = map[string]*connection{}
|
||||||
|
|
||||||
|
// Schematic prints all information ORM inferred from your entities in startup, remember to pass
|
||||||
|
// your entities in Entities when you call SetupConnections if you want their data inferred
|
||||||
|
// otherwise Schematic does not print correct data since GoLobby ORM also
|
||||||
|
// incrementally cache your entities metadata and schema.
|
||||||
|
func Schematic() {
|
||||||
|
for conn, connObj := range globalConnections {
|
||||||
|
fmt.Printf("----------------%s---------------\n", conn)
|
||||||
|
connObj.Schematic()
|
||||||
|
fmt.Println("-----------------------------------")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ConnectionConfig struct {
|
||||||
|
// Name of your database connection, it's up to you to name them anything
|
||||||
|
// just remember that having a connection name is mandatory if
|
||||||
|
// you have multiple connections
|
||||||
|
Name string
|
||||||
|
// If you already have an active database connection configured pass it in this value and
|
||||||
|
// do not pass Driver and DSN fields.
|
||||||
|
DB *sql.DB
|
||||||
|
// Which dialect of sql to generate queries for, you don't need it most of the times when you are using
|
||||||
|
// traditional databases such as mysql, sqlite3, postgres.
|
||||||
|
Dialect *Dialect
|
||||||
|
// List of entities that you want to use for this connection, remember that you can ignore this field
|
||||||
|
// and GoLobby ORM will build our metadata cache incrementally but you will lose schematic
|
||||||
|
// information that we can provide you and also potentialy validations that we
|
||||||
|
// can do with the database
|
||||||
|
Entities []Entity
|
||||||
|
// Database validations, check if all tables exists and also table schemas contains all necessary columns.
|
||||||
|
// Check if all infered tables exist in your database
|
||||||
|
DatabaseValidations bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetupConnections declares a new connections for ORM.
|
||||||
|
func SetupConnections(configs ...ConnectionConfig) error {
|
||||||
|
|
||||||
|
for _, c := range configs {
|
||||||
|
if err := setupConnection(c); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, conn := range globalConnections {
|
||||||
|
if !conn.DatabaseValidations {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
tables, err := getListOfTables(conn.Dialect.QueryListTables)(conn.DB)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range tables {
|
||||||
|
if conn.DatabaseValidations {
|
||||||
|
spec, err := getTableSchema(conn.Dialect.QueryTableSchema)(conn.DB, table)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
conn.DBSchema[table] = spec
|
||||||
|
} else {
|
||||||
|
conn.DBSchema[table] = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// check tables existence
|
||||||
|
if conn.DatabaseValidations {
|
||||||
|
err := conn.validateAllTablesArePresent()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn.DatabaseValidations {
|
||||||
|
err = conn.validateTablesSchemas()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupConnection(config ConnectionConfig) error {
|
||||||
|
schemas := map[string]*schema{}
|
||||||
|
if config.Name == "" {
|
||||||
|
config.Name = "default"
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, entity := range config.Entities {
|
||||||
|
s := schemaOfHeavyReflectionStuff(entity)
|
||||||
|
var configurator EntityConfigurator
|
||||||
|
entity.ConfigureEntity(&configurator)
|
||||||
|
schemas[configurator.table] = s
|
||||||
|
}
|
||||||
|
|
||||||
|
s := &connection{
|
||||||
|
Name: config.Name,
|
||||||
|
DB: config.DB,
|
||||||
|
Dialect: config.Dialect,
|
||||||
|
Schemas: schemas,
|
||||||
|
DBSchema: make(map[string][]columnSpec),
|
||||||
|
DatabaseValidations: config.DatabaseValidations,
|
||||||
|
}
|
||||||
|
|
||||||
|
globalConnections[fmt.Sprintf("%s", config.Name)] = s
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Entity defines the interface that each of your structs that
|
||||||
|
// you want to use as database entities should have,
|
||||||
|
// it's a simple one and its ConfigureEntity.
|
||||||
|
type Entity interface {
|
||||||
|
// ConfigureEntity should be defined for all of your database entities
|
||||||
|
// and it can define Table, DB and also relations of your Entity.
|
||||||
|
ConfigureEntity(e *EntityConfigurator)
|
||||||
|
}
|
||||||
|
|
||||||
|
// InsertAll given entities into database based on their ConfigureEntity
|
||||||
|
// we can find table and also DB name.
|
||||||
|
func InsertAll(objs ...Entity) error {
|
||||||
|
if len(objs) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
s := getSchemaFor(objs[0])
|
||||||
|
cols := s.Columns(false)
|
||||||
|
var values [][]interface{}
|
||||||
|
for _, obj := range objs {
|
||||||
|
createdAtF := s.createdAt()
|
||||||
|
if createdAtF != nil {
|
||||||
|
genericSet(obj, createdAtF.Name, sql.NullTime{Time: time.Now(), Valid: true})
|
||||||
|
}
|
||||||
|
updatedAtF := s.updatedAt()
|
||||||
|
if updatedAtF != nil {
|
||||||
|
genericSet(obj, updatedAtF.Name, sql.NullTime{Time: time.Now(), Valid: true})
|
||||||
|
}
|
||||||
|
values = append(values, genericValuesOf(obj, false))
|
||||||
|
}
|
||||||
|
|
||||||
|
is := insertStmt{
|
||||||
|
PlaceHolderGenerator: s.getDialect().PlaceHolderGenerator,
|
||||||
|
Table: s.getTable(),
|
||||||
|
Columns: cols,
|
||||||
|
Values: values,
|
||||||
|
}
|
||||||
|
|
||||||
|
q, args := is.ToSql()
|
||||||
|
|
||||||
|
_, err := s.getConnection().exec(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Insert given entity into database based on their ConfigureEntity
|
||||||
|
// we can find table and also DB name.
|
||||||
|
func Insert(o Entity) error {
|
||||||
|
s := getSchemaFor(o)
|
||||||
|
cols := s.Columns(false)
|
||||||
|
var values [][]interface{}
|
||||||
|
createdAtF := s.createdAt()
|
||||||
|
if createdAtF != nil {
|
||||||
|
genericSet(o, createdAtF.Name, sql.NullTime{Time: time.Now(), Valid: true})
|
||||||
|
}
|
||||||
|
updatedAtF := s.updatedAt()
|
||||||
|
if updatedAtF != nil {
|
||||||
|
genericSet(o, updatedAtF.Name, sql.NullTime{Time: time.Now(), Valid: true})
|
||||||
|
}
|
||||||
|
values = append(values, genericValuesOf(o, false))
|
||||||
|
|
||||||
|
is := insertStmt{
|
||||||
|
PlaceHolderGenerator: s.getDialect().PlaceHolderGenerator,
|
||||||
|
Table: s.getTable(),
|
||||||
|
Columns: cols,
|
||||||
|
Values: values,
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.getDialect().DriverName == "postgres" {
|
||||||
|
is.Returning = s.pkName()
|
||||||
|
}
|
||||||
|
q, args := is.ToSql()
|
||||||
|
|
||||||
|
res, err := s.getConnection().exec(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
id, err := res.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.pkName() != "" {
|
||||||
|
// intermediate tables usually have no single pk column.
|
||||||
|
s.setPK(o, id)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isZero(val interface{}) bool {
|
||||||
|
switch val.(type) {
|
||||||
|
case int64:
|
||||||
|
return val.(int64) == 0
|
||||||
|
case int:
|
||||||
|
return val.(int) == 0
|
||||||
|
case string:
|
||||||
|
return val.(string) == ""
|
||||||
|
default:
|
||||||
|
return reflect.ValueOf(val).Elem().IsZero()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Save saves given entity, if primary key is set
|
||||||
|
// we will make an update query and if
|
||||||
|
// primary key is zero value we will
|
||||||
|
// insert it.
|
||||||
|
func Save(obj Entity) error {
|
||||||
|
if isZero(getSchemaFor(obj).getPK(obj)) {
|
||||||
|
return Insert(obj)
|
||||||
|
} else {
|
||||||
|
return Update(obj)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find finds the Entity you want based on generic type and primary key you passed.
|
||||||
|
func Find[T Entity](id interface{}) (T, error) {
|
||||||
|
var q string
|
||||||
|
out := new(T)
|
||||||
|
md := getSchemaFor(*out)
|
||||||
|
q, args, err := NewQueryBuilder[T](md).
|
||||||
|
SetDialect(md.getDialect()).
|
||||||
|
Table(md.Table).
|
||||||
|
Select(md.Columns(true)...).
|
||||||
|
Where(md.pkName(), id).
|
||||||
|
ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return *out, err
|
||||||
|
}
|
||||||
|
err = bind[T](out, q, args)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return *out, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return *out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toKeyValues(obj Entity, withPK bool) []any {
|
||||||
|
var tuples []any
|
||||||
|
vs := genericValuesOf(obj, withPK)
|
||||||
|
cols := getSchemaFor(obj).Columns(withPK)
|
||||||
|
for i, col := range cols {
|
||||||
|
tuples = append(tuples, col, vs[i])
|
||||||
|
}
|
||||||
|
return tuples
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update given Entity in database.
|
||||||
|
func Update(obj Entity) error {
|
||||||
|
s := getSchemaFor(obj)
|
||||||
|
q, args, err := NewQueryBuilder[Entity](s).
|
||||||
|
SetDialect(s.getDialect()).
|
||||||
|
Set(toKeyValues(obj, false)...).
|
||||||
|
Where(s.pkName(), genericGetPKValue(obj)).Table(s.Table).ToSql()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = s.getConnection().exec(q, args...)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete given Entity from database
|
||||||
|
func Delete(obj Entity) error {
|
||||||
|
s := getSchemaFor(obj)
|
||||||
|
genericSet(obj, "deleted_at", sql.NullTime{Time: time.Now(), Valid: true})
|
||||||
|
query, args, err := NewQueryBuilder[Entity](s).SetDialect(s.getDialect()).Table(s.Table).Where(s.pkName(), genericGetPKValue(obj)).SetDelete().ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = s.getConnection().exec(query, args...)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
func bind[T Entity](output interface{}, q string, args []interface{}) error {
|
||||||
|
outputMD := getSchemaFor(*new(T))
|
||||||
|
rows, err := outputMD.getConnection().query(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return newBinder(outputMD).bind(rows, output)
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasManyConfig contains all information we need for querying HasMany relationships.
|
||||||
|
// We can infer both fields if you have them in standard way but you
|
||||||
|
// can specify them if you want custom ones.
|
||||||
|
type HasManyConfig struct {
|
||||||
|
// PropertyTable is table of the property of HasMany relationship,
|
||||||
|
// consider `Comment` in Post and Comment relationship,
|
||||||
|
// each Post HasMany Comment, so PropertyTable is
|
||||||
|
// `comments`.
|
||||||
|
PropertyTable string
|
||||||
|
// PropertyForeignKey is the foreign key field name in the property table,
|
||||||
|
// for example in Post HasMany Comment, if comment has `post_id` field,
|
||||||
|
// it's the PropertyForeignKey field.
|
||||||
|
PropertyForeignKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasMany configures a QueryBuilder for a HasMany relationship
|
||||||
|
// this relationship will be defined for owner argument
|
||||||
|
// that has many of PROPERTY generic type for example
|
||||||
|
// HasMany[Comment](&Post{})
|
||||||
|
// is for Post HasMany Comment relationship.
|
||||||
|
func HasMany[PROPERTY Entity](owner Entity) *QueryBuilder[PROPERTY] {
|
||||||
|
outSchema := getSchemaFor(*new(PROPERTY))
|
||||||
|
|
||||||
|
q := NewQueryBuilder[PROPERTY](outSchema)
|
||||||
|
// getting config from our cache
|
||||||
|
c, ok := getSchemaFor(owner).relations[outSchema.Table].(HasManyConfig)
|
||||||
|
if !ok {
|
||||||
|
q.err = fmt.Errorf("wrong config passed for HasMany")
|
||||||
|
}
|
||||||
|
|
||||||
|
s := getSchemaFor(owner)
|
||||||
|
return q.
|
||||||
|
SetDialect(s.getDialect()).
|
||||||
|
Table(c.PropertyTable).
|
||||||
|
Select(outSchema.Columns(true)...).
|
||||||
|
Where(c.PropertyForeignKey, genericGetPKValue(owner))
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasOneConfig contains all information we need for a HasOne relationship,
|
||||||
|
// it's similar to HasManyConfig.
|
||||||
|
type HasOneConfig struct {
|
||||||
|
// PropertyTable is table of the property of HasOne relationship,
|
||||||
|
// consider `HeaderPicture` in Post and HeaderPicture relationship,
|
||||||
|
// each Post HasOne HeaderPicture, so PropertyTable is
|
||||||
|
// `header_pictures`.
|
||||||
|
PropertyTable string
|
||||||
|
// PropertyForeignKey is the foreign key field name in the property table,
|
||||||
|
// forexample in Post HasOne HeaderPicture, if header_picture has `post_id` field,
|
||||||
|
// it's the PropertyForeignKey field.
|
||||||
|
PropertyForeignKey string
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasOne configures a QueryBuilder for a HasOne relationship
|
||||||
|
// this relationship will be defined for owner argument
|
||||||
|
// that has one of PROPERTY generic type for example
|
||||||
|
// HasOne[HeaderPicture](&Post{})
|
||||||
|
// is for Post HasOne HeaderPicture relationship.
|
||||||
|
func HasOne[PROPERTY Entity](owner Entity) *QueryBuilder[PROPERTY] {
|
||||||
|
property := getSchemaFor(*new(PROPERTY))
|
||||||
|
q := NewQueryBuilder[PROPERTY](property)
|
||||||
|
c, ok := getSchemaFor(owner).relations[property.Table].(HasOneConfig)
|
||||||
|
if !ok {
|
||||||
|
q.err = fmt.Errorf("wrong config passed for HasOne")
|
||||||
|
}
|
||||||
|
|
||||||
|
// settings default config Values
|
||||||
|
return q.
|
||||||
|
SetDialect(property.getDialect()).
|
||||||
|
Table(c.PropertyTable).
|
||||||
|
Select(property.Columns(true)...).
|
||||||
|
Where(c.PropertyForeignKey, genericGetPKValue(owner))
|
||||||
|
}
|
||||||
|
|
||||||
|
// BelongsToConfig contains all information we need for a BelongsTo relationship
|
||||||
|
// BelongsTo is a relationship between a Comment and it's Post,
|
||||||
|
// A Comment BelongsTo Post.
|
||||||
|
type BelongsToConfig struct {
|
||||||
|
// OwnerTable is the table that contains owner of a BelongsTo
|
||||||
|
// relationship.
|
||||||
|
OwnerTable string
|
||||||
|
// LocalForeignKey is name of the field that links property
|
||||||
|
// to its owner in BelongsTo relation. for example when
|
||||||
|
// a Comment BelongsTo Post, LocalForeignKey is
|
||||||
|
// post_id of Comment.
|
||||||
|
LocalForeignKey string
|
||||||
|
// ForeignColumnName is name of the field that LocalForeignKey
|
||||||
|
// field value will point to it, for example when
|
||||||
|
// a Comment BelongsTo Post, ForeignColumnName is
|
||||||
|
// id of Post.
|
||||||
|
ForeignColumnName string
|
||||||
|
}
|
||||||
|
|
||||||
|
// BelongsTo configures a QueryBuilder for a BelongsTo relationship between
|
||||||
|
// OWNER type parameter and property argument, so
|
||||||
|
// property BelongsTo OWNER.
|
||||||
|
func BelongsTo[OWNER Entity](property Entity) *QueryBuilder[OWNER] {
|
||||||
|
owner := getSchemaFor(*new(OWNER))
|
||||||
|
q := NewQueryBuilder[OWNER](owner)
|
||||||
|
c, ok := getSchemaFor(property).relations[owner.Table].(BelongsToConfig)
|
||||||
|
if !ok {
|
||||||
|
q.err = fmt.Errorf("wrong config passed for BelongsTo")
|
||||||
|
}
|
||||||
|
|
||||||
|
ownerIDidx := 0
|
||||||
|
for idx, field := range owner.fields {
|
||||||
|
if field.Name == c.LocalForeignKey {
|
||||||
|
ownerIDidx = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ownerID := genericValuesOf(property, true)[ownerIDidx]
|
||||||
|
|
||||||
|
return q.
|
||||||
|
SetDialect(owner.getDialect()).
|
||||||
|
Table(c.OwnerTable).Select(owner.Columns(true)...).
|
||||||
|
Where(c.ForeignColumnName, ownerID)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// BelongsToManyConfig contains information that we
|
||||||
|
// need for creating many to many queries.
|
||||||
|
type BelongsToManyConfig struct {
|
||||||
|
// IntermediateTable is the name of the middle table
|
||||||
|
// in a BelongsToMany (Many to Many) relationship.
|
||||||
|
// for example when we have Post BelongsToMany
|
||||||
|
// Category, this table will be post_categories
|
||||||
|
// table, remember that this field cannot be
|
||||||
|
// inferred.
|
||||||
|
IntermediateTable string
|
||||||
|
// IntermediatePropertyID is the name of the field name
|
||||||
|
// of property foreign key in intermediate table,
|
||||||
|
// for example when we have Post BelongsToMany
|
||||||
|
// Category, in post_categories table, it would
|
||||||
|
// be post_id.
|
||||||
|
IntermediatePropertyID string
|
||||||
|
// IntermediateOwnerID is the name of the field name
|
||||||
|
// of property foreign key in intermediate table,
|
||||||
|
// for example when we have Post BelongsToMany
|
||||||
|
// Category, in post_categories table, it would
|
||||||
|
// be category_id.
|
||||||
|
IntermediateOwnerID string
|
||||||
|
// Table name of the owner in BelongsToMany relation,
|
||||||
|
// for example in Post BelongsToMany Category
|
||||||
|
// Owner table is name of Category table
|
||||||
|
// for example `categories`.
|
||||||
|
OwnerTable string
|
||||||
|
// OwnerLookupColumn is name of the field in the owner
|
||||||
|
// table that is used in query, for example in Post BelongsToMany Category
|
||||||
|
// Owner lookup field would be Category primary key which is id.
|
||||||
|
OwnerLookupColumn string
|
||||||
|
}
|
||||||
|
|
||||||
|
// BelongsToMany configures a QueryBuilder for a BelongsToMany relationship
|
||||||
|
func BelongsToMany[OWNER Entity](property Entity) *QueryBuilder[OWNER] {
|
||||||
|
out := *new(OWNER)
|
||||||
|
outSchema := getSchemaFor(out)
|
||||||
|
q := NewQueryBuilder[OWNER](outSchema)
|
||||||
|
c, ok := getSchemaFor(property).relations[outSchema.Table].(BelongsToManyConfig)
|
||||||
|
if !ok {
|
||||||
|
q.err = fmt.Errorf("wrong config passed for HasMany")
|
||||||
|
}
|
||||||
|
return q.
|
||||||
|
Select(outSchema.Columns(true)...).
|
||||||
|
Table(outSchema.Table).
|
||||||
|
WhereIn(c.OwnerLookupColumn, Raw(fmt.Sprintf(`SELECT %s FROM %s WHERE %s = ?`,
|
||||||
|
c.IntermediatePropertyID,
|
||||||
|
c.IntermediateTable, c.IntermediateOwnerID), genericGetPKValue(property)))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add adds `items` to `to` using relations defined between items and to in ConfigureEntity method of `to`.
|
||||||
|
func Add(to Entity, items ...Entity) error {
|
||||||
|
if len(items) == 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
rels := getSchemaFor(to).relations
|
||||||
|
tname := getSchemaFor(items[0]).Table
|
||||||
|
c, ok := rels[tname]
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("no config found for given to and item...")
|
||||||
|
}
|
||||||
|
switch c.(type) {
|
||||||
|
case HasManyConfig:
|
||||||
|
return addProperty(to, items...)
|
||||||
|
case HasOneConfig:
|
||||||
|
return addProperty(to, items[0])
|
||||||
|
case BelongsToManyConfig:
|
||||||
|
return addM2M(to, items...)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("cannot add for relation: %T", rels[getSchemaFor(items[0]).Table])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func addM2M(to Entity, items ...Entity) error {
|
||||||
|
//TODO: Optimize this
|
||||||
|
rels := getSchemaFor(to).relations
|
||||||
|
tname := getSchemaFor(items[0]).Table
|
||||||
|
c := rels[tname].(BelongsToManyConfig)
|
||||||
|
var values [][]interface{}
|
||||||
|
ownerPk := genericGetPKValue(to)
|
||||||
|
for _, item := range items {
|
||||||
|
pk := genericGetPKValue(item)
|
||||||
|
if isZero(pk) {
|
||||||
|
err := Insert(item)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
pk = genericGetPKValue(item)
|
||||||
|
}
|
||||||
|
values = append(values, []interface{}{ownerPk, pk})
|
||||||
|
}
|
||||||
|
i := insertStmt{
|
||||||
|
PlaceHolderGenerator: getSchemaFor(to).getDialect().PlaceHolderGenerator,
|
||||||
|
Table: c.IntermediateTable,
|
||||||
|
Columns: []string{c.IntermediateOwnerID, c.IntermediatePropertyID},
|
||||||
|
Values: values,
|
||||||
|
}
|
||||||
|
|
||||||
|
q, args := i.ToSql()
|
||||||
|
|
||||||
|
_, err := getConnectionFor(items[0]).DB.Exec(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
// addHasMany(Post, comments)
|
||||||
|
func addProperty(to Entity, items ...Entity) error {
|
||||||
|
var lastTable string
|
||||||
|
for _, obj := range items {
|
||||||
|
s := getSchemaFor(obj)
|
||||||
|
if lastTable == "" {
|
||||||
|
lastTable = s.Table
|
||||||
|
} else {
|
||||||
|
if lastTable != s.Table {
|
||||||
|
return fmt.Errorf("cannot batch insert for two different tables: %s and %s", s.Table, lastTable)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
i := insertStmt{
|
||||||
|
PlaceHolderGenerator: getSchemaFor(to).getDialect().PlaceHolderGenerator,
|
||||||
|
Table: getSchemaFor(items[0]).getTable(),
|
||||||
|
}
|
||||||
|
ownerPKIdx := -1
|
||||||
|
ownerPKName := getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey
|
||||||
|
for idx, col := range getSchemaFor(items[0]).Columns(false) {
|
||||||
|
if col == ownerPKName {
|
||||||
|
ownerPKIdx = idx
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ownerPK := genericGetPKValue(to)
|
||||||
|
if ownerPKIdx != -1 {
|
||||||
|
cols := getSchemaFor(items[0]).Columns(false)
|
||||||
|
i.Columns = append(i.Columns, cols...)
|
||||||
|
// Owner PK is present in the items struct
|
||||||
|
for _, item := range items {
|
||||||
|
vals := genericValuesOf(item, false)
|
||||||
|
if cols[ownerPKIdx] != getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey {
|
||||||
|
return fmt.Errorf("owner pk idx is not correct")
|
||||||
|
}
|
||||||
|
vals[ownerPKIdx] = ownerPK
|
||||||
|
i.Values = append(i.Values, vals)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
ownerPKIdx = 0
|
||||||
|
cols := getSchemaFor(items[0]).Columns(false)
|
||||||
|
cols = append(cols[:ownerPKIdx+1], cols[ownerPKIdx:]...)
|
||||||
|
cols[ownerPKIdx] = getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey
|
||||||
|
i.Columns = append(i.Columns, cols...)
|
||||||
|
for _, item := range items {
|
||||||
|
vals := genericValuesOf(item, false)
|
||||||
|
if cols[ownerPKIdx] != getSchemaFor(items[0]).relations[getSchemaFor(to).Table].(BelongsToConfig).LocalForeignKey {
|
||||||
|
return fmt.Errorf("owner pk idx is not correct")
|
||||||
|
}
|
||||||
|
vals = append(vals[:ownerPKIdx+1], vals[ownerPKIdx:]...)
|
||||||
|
vals[ownerPKIdx] = ownerPK
|
||||||
|
i.Values = append(i.Values, vals)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
q, args := i.ToSql()
|
||||||
|
|
||||||
|
_, err := getConnectionFor(items[0]).DB.Exec(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// Query creates a new QueryBuilder for given type parameter, sets dialect and table as well.
|
||||||
|
func Query[E Entity]() *QueryBuilder[E] {
|
||||||
|
s := getSchemaFor(*new(E))
|
||||||
|
q := NewQueryBuilder[E](s)
|
||||||
|
q.SetDialect(s.getDialect()).Table(s.Table)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExecRaw executes given query string and arguments on given type parameter database connection.
|
||||||
|
func ExecRaw[E Entity](q string, args ...interface{}) (int64, int64, error) {
|
||||||
|
e := new(E)
|
||||||
|
|
||||||
|
res, err := getSchemaFor(*e).getSQLDB().Exec(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
id, err := res.LastInsertId()
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
affected, err := res.RowsAffected()
|
||||||
|
if err != nil {
|
||||||
|
return 0, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return id, affected, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// QueryRaw queries given query string and arguments on given type parameter database connection.
|
||||||
|
func QueryRaw[OUTPUT Entity](q string, args ...interface{}) ([]OUTPUT, error) {
|
||||||
|
o := new(OUTPUT)
|
||||||
|
rows, err := getSchemaFor(*o).getSQLDB().Query(q, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var output []OUTPUT
|
||||||
|
err = newBinder(getSchemaFor(*o)).bind(rows, &output)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return output, nil
|
||||||
|
}
|
588
gdb/orm/orm_test.go
Normal file
588
gdb/orm/orm_test.go
Normal file
@ -0,0 +1,588 @@
|
|||||||
|
//
|
||||||
|
// orm_test.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"git.hexq.cn/tiglog/golib/gdb/orm"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type AuthorEmail struct {
|
||||||
|
ID int64
|
||||||
|
Email string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a AuthorEmail) ConfigureEntity(e *orm.EntityConfigurator) {
|
||||||
|
e.
|
||||||
|
Table("emails").
|
||||||
|
Connection("default").
|
||||||
|
BelongsTo(&Post{}, orm.BelongsToConfig{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type HeaderPicture struct {
|
||||||
|
ID int64
|
||||||
|
PostID int64
|
||||||
|
Link string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h HeaderPicture) ConfigureEntity(e *orm.EntityConfigurator) {
|
||||||
|
e.Table("header_pictures").BelongsTo(&Post{}, orm.BelongsToConfig{})
|
||||||
|
}
|
||||||
|
|
||||||
|
type Post struct {
|
||||||
|
ID int64
|
||||||
|
BodyText string
|
||||||
|
CreatedAt sql.NullTime
|
||||||
|
UpdatedAt sql.NullTime
|
||||||
|
DeletedAt sql.NullTime
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p Post) ConfigureEntity(e *orm.EntityConfigurator) {
|
||||||
|
e.Field("BodyText").ColumnName("body")
|
||||||
|
e.Field("ID").ColumnName("id")
|
||||||
|
e.
|
||||||
|
Table("posts").
|
||||||
|
HasMany(Comment{}, orm.HasManyConfig{}).
|
||||||
|
HasOne(HeaderPicture{}, orm.HasOneConfig{}).
|
||||||
|
HasOne(AuthorEmail{}, orm.HasOneConfig{}).
|
||||||
|
BelongsToMany(Category{}, orm.BelongsToManyConfig{IntermediateTable: "post_categories"})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Post) Categories() ([]Category, error) {
|
||||||
|
return orm.BelongsToMany[Category](p).All()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *Post) Comments() *orm.QueryBuilder[Comment] {
|
||||||
|
return orm.HasMany[Comment](p)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Comment struct {
|
||||||
|
ID int64
|
||||||
|
PostID int64
|
||||||
|
Body string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Comment) ConfigureEntity(e *orm.EntityConfigurator) {
|
||||||
|
e.Table("comments").BelongsTo(&Post{}, orm.BelongsToConfig{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Comment) Post() (Post, error) {
|
||||||
|
return orm.BelongsTo[Post](c).Get()
|
||||||
|
}
|
||||||
|
|
||||||
|
type Category struct {
|
||||||
|
ID int64
|
||||||
|
Title string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Category) ConfigureEntity(e *orm.EntityConfigurator) {
|
||||||
|
e.Table("categories").BelongsToMany(Post{}, orm.BelongsToManyConfig{IntermediateTable: "post_categories"})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c Category) Posts() ([]Post, error) {
|
||||||
|
return orm.BelongsToMany[Post](c).All()
|
||||||
|
}
|
||||||
|
|
||||||
|
// enough models let's test
|
||||||
|
// Entities is mandatory
|
||||||
|
// Errors should be carried
|
||||||
|
|
||||||
|
func setup() error {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return orm.SetupConnections(orm.ConnectionConfig{
|
||||||
|
Name: "default",
|
||||||
|
DB: db,
|
||||||
|
Dialect: orm.Dialects.SQLite3,
|
||||||
|
Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}},
|
||||||
|
DatabaseValidations: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFind(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
err = orm.InsertAll(&Post{
|
||||||
|
BodyText: "my body for insert",
|
||||||
|
})
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
post, err := orm.Find[Post](1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "my body for insert", post.BodyText)
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInsert(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "my body for insert",
|
||||||
|
}
|
||||||
|
err = orm.Insert(post)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
var p Post
|
||||||
|
assert.NoError(t,
|
||||||
|
orm.GetConnection("default").DB.QueryRow(`SELECT id, body FROM posts where id = ?`, 1).Scan(&p.ID, &p.BodyText))
|
||||||
|
|
||||||
|
assert.Equal(t, "my body for insert", p.BodyText)
|
||||||
|
}
|
||||||
|
func TestInsertAll(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
post1 := &Post{
|
||||||
|
BodyText: "Body1",
|
||||||
|
}
|
||||||
|
post2 := &Post{
|
||||||
|
BodyText: "Body2",
|
||||||
|
}
|
||||||
|
|
||||||
|
post3 := &Post{
|
||||||
|
BodyText: "Body3",
|
||||||
|
}
|
||||||
|
|
||||||
|
err = orm.InsertAll(post1, post2, post3)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
var counter int
|
||||||
|
assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT count(id) FROM posts`).Scan(&counter))
|
||||||
|
assert.Equal(t, 3, counter)
|
||||||
|
|
||||||
|
}
|
||||||
|
func TestUpdateORM(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "my body for insert",
|
||||||
|
}
|
||||||
|
err = orm.Insert(post)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
|
||||||
|
post.BodyText += " update text"
|
||||||
|
assert.NoError(t, orm.Update(post))
|
||||||
|
|
||||||
|
var body string
|
||||||
|
assert.NoError(t,
|
||||||
|
orm.GetConnection("default").DB.QueryRow(`SELECT body FROM posts where id = ?`, post.ID).Scan(&body))
|
||||||
|
|
||||||
|
assert.Equal(t, "my body for insert update text", body)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDeleteORM(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "my body for insert",
|
||||||
|
}
|
||||||
|
err = orm.Insert(post)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
|
||||||
|
assert.NoError(t, orm.Delete(post))
|
||||||
|
|
||||||
|
var count int
|
||||||
|
assert.NoError(t,
|
||||||
|
orm.GetConnection("default").DB.QueryRow(`SELECT count(id) FROM posts where id = ?`, post.ID).Scan(&count))
|
||||||
|
|
||||||
|
assert.Equal(t, 0, count)
|
||||||
|
}
|
||||||
|
func TestAdd_HasMany(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "my body for insert",
|
||||||
|
}
|
||||||
|
err = orm.Insert(post)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
|
||||||
|
err = orm.Add(post, []orm.Entity{
|
||||||
|
Comment{
|
||||||
|
Body: "comment 1",
|
||||||
|
},
|
||||||
|
Comment{
|
||||||
|
Body: "comment 2",
|
||||||
|
},
|
||||||
|
}...)
|
||||||
|
// orm.Query(qm.WhereBetween())
|
||||||
|
assert.NoError(t, err)
|
||||||
|
var count int
|
||||||
|
assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(id) FROM comments`).Scan(&count))
|
||||||
|
assert.Equal(t, 2, count)
|
||||||
|
|
||||||
|
comment, err := orm.Find[Comment](1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, int64(1), comment.PostID)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAdd_ManyToMany(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "my body for insert",
|
||||||
|
}
|
||||||
|
err = orm.Insert(post)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
|
||||||
|
err = orm.Add(post, []orm.Entity{
|
||||||
|
&Category{
|
||||||
|
Title: "cat 1",
|
||||||
|
},
|
||||||
|
&Category{
|
||||||
|
Title: "cat 2",
|
||||||
|
},
|
||||||
|
}...)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
var count int
|
||||||
|
assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(post_id) FROM post_categories`).Scan(&count))
|
||||||
|
assert.Equal(t, 2, count)
|
||||||
|
assert.NoError(t, orm.GetConnection("default").DB.QueryRow(`SELECT COUNT(id) FROM categories`).Scan(&count))
|
||||||
|
assert.Equal(t, 2, count)
|
||||||
|
|
||||||
|
categories, err := post.Categories()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, 2, len(categories))
|
||||||
|
assert.Equal(t, int64(1), categories[0].ID)
|
||||||
|
assert.Equal(t, int64(2), categories[1].ID)
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSave(t *testing.T) {
|
||||||
|
t.Run("save should insert", func(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "1",
|
||||||
|
}
|
||||||
|
assert.NoError(t, orm.Save(post))
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("save should update", func(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "1",
|
||||||
|
}
|
||||||
|
assert.NoError(t, orm.Save(post))
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
|
||||||
|
post.BodyText += "2"
|
||||||
|
assert.NoError(t, orm.Save(post))
|
||||||
|
|
||||||
|
myPost, err := orm.Find[Post](1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.EqualValues(t, post.BodyText, myPost.BodyText)
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasMany(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "first post",
|
||||||
|
}
|
||||||
|
assert.NoError(t, orm.Save(post))
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
|
||||||
|
assert.NoError(t, orm.Save(&Comment{
|
||||||
|
PostID: post.ID,
|
||||||
|
Body: "comment 1",
|
||||||
|
}))
|
||||||
|
assert.NoError(t, orm.Save(&Comment{
|
||||||
|
PostID: post.ID,
|
||||||
|
Body: "comment 2",
|
||||||
|
}))
|
||||||
|
|
||||||
|
comments, err := orm.HasMany[Comment](post).All()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, comments, 2)
|
||||||
|
|
||||||
|
assert.Equal(t, post.ID, comments[0].PostID)
|
||||||
|
assert.Equal(t, post.ID, comments[1].PostID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBelongsTo(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "first post",
|
||||||
|
}
|
||||||
|
assert.NoError(t, orm.Save(post))
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
|
||||||
|
comment := &Comment{
|
||||||
|
PostID: post.ID,
|
||||||
|
Body: "comment 1",
|
||||||
|
}
|
||||||
|
assert.NoError(t, orm.Save(comment))
|
||||||
|
|
||||||
|
post2, err := orm.BelongsTo[Post](comment).Get()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, post.BodyText, post2.BodyText)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHasOne(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "first post",
|
||||||
|
}
|
||||||
|
assert.NoError(t, orm.Save(post))
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
|
||||||
|
headerPicture := &HeaderPicture{
|
||||||
|
PostID: post.ID,
|
||||||
|
Link: "google",
|
||||||
|
}
|
||||||
|
assert.NoError(t, orm.Save(headerPicture))
|
||||||
|
|
||||||
|
c1, err := orm.HasOne[HeaderPicture](post).Get()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, headerPicture.PostID, c1.PostID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBelongsToMany(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "first Post",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, orm.Save(post))
|
||||||
|
assert.Equal(t, int64(1), post.ID)
|
||||||
|
|
||||||
|
category := &Category{
|
||||||
|
Title: "first category",
|
||||||
|
}
|
||||||
|
assert.NoError(t, orm.Save(category))
|
||||||
|
assert.Equal(t, int64(1), category.ID)
|
||||||
|
|
||||||
|
_, _, err = orm.ExecRaw[Category](`INSERT INTO post_categories (post_id, category_id) VALUES (?,?)`, post.ID, category.ID)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
categories, err := orm.BelongsToMany[Category](post).All()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Len(t, categories, 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSchematic(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
orm.Schematic()
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAddProperty(t *testing.T) {
|
||||||
|
t.Run("having pk value", func(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "first post",
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.NoError(t, orm.Save(post))
|
||||||
|
assert.EqualValues(t, 1, post.ID)
|
||||||
|
|
||||||
|
err = orm.Add(post, &Comment{PostID: post.ID, Body: "firstComment"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var comment Comment
|
||||||
|
assert.NoError(t, orm.GetConnection("default").
|
||||||
|
DB.
|
||||||
|
QueryRow(`SELECT id, post_id, body FROM comments WHERE post_id=?`, post.ID).
|
||||||
|
Scan(&comment.ID, &comment.PostID, &comment.Body))
|
||||||
|
|
||||||
|
assert.EqualValues(t, post.ID, comment.PostID)
|
||||||
|
})
|
||||||
|
t.Run("not having PK value", func(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
post := &Post{
|
||||||
|
BodyText: "first post",
|
||||||
|
}
|
||||||
|
assert.NoError(t, orm.Save(post))
|
||||||
|
assert.EqualValues(t, 1, post.ID)
|
||||||
|
|
||||||
|
err = orm.Add(post, &AuthorEmail{Email: "myemail"})
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
emails, err := orm.QueryRaw[AuthorEmail](`SELECT id, email FROM emails WHERE post_id=?`, post.ID)
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, []AuthorEmail{{ID: 1, Email: "myemail"}}, emails)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestQuery(t *testing.T) {
|
||||||
|
t.Run("querying single row", func(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NoError(t, orm.Save(&Post{BodyText: "body 1"}))
|
||||||
|
// post, err := orm.Query[Post]().Where("id", 1).First()
|
||||||
|
post, err := orm.Query[Post]().WherePK(1).First().Get()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, "body 1", post.BodyText)
|
||||||
|
assert.EqualValues(t, 1, post.ID)
|
||||||
|
|
||||||
|
})
|
||||||
|
t.Run("querying multiple rows", func(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NoError(t, orm.Save(&Post{BodyText: "body 1"}))
|
||||||
|
assert.NoError(t, orm.Save(&Post{BodyText: "body 2"}))
|
||||||
|
assert.NoError(t, orm.Save(&Post{BodyText: "body 3"}))
|
||||||
|
posts, err := orm.Query[Post]().All()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Len(t, posts, 3)
|
||||||
|
assert.Equal(t, "body 1", posts[0].BodyText)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("updating a row using query interface", func(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NoError(t, orm.Save(&Post{BodyText: "body 1"}))
|
||||||
|
|
||||||
|
affected, err := orm.Query[Post]().Where("id", 1).Set("body", "body jadid").Update()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, affected)
|
||||||
|
|
||||||
|
post, err := orm.Find[Post](1)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "body jadid", post.BodyText)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("deleting a row using query interface", func(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NoError(t, orm.Save(&Post{BodyText: "body 1"}))
|
||||||
|
|
||||||
|
affected, err := orm.Query[Post]().WherePK(1).Delete()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 1, affected)
|
||||||
|
count, err := orm.Query[Post]().WherePK(1).Count().Get()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 0, count)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("count", func(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
count, err := orm.Query[Post]().WherePK(1).Count().Get()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, 0, count)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("latest", func(t *testing.T) {
|
||||||
|
err := setup()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NoError(t, orm.Save(&Post{BodyText: "body 1"}))
|
||||||
|
assert.NoError(t, orm.Save(&Post{BodyText: "body 2"}))
|
||||||
|
|
||||||
|
post, err := orm.Query[Post]().Latest().Get()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
assert.EqualValues(t, "body 2", post.BodyText)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetup(t *testing.T) {
|
||||||
|
t.Run("tables are out of sync", func(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
// _, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`)
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`)
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`)
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`)
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`)
|
||||||
|
// _, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`)
|
||||||
|
|
||||||
|
err = orm.SetupConnections(orm.ConnectionConfig{
|
||||||
|
Name: "default",
|
||||||
|
DB: db,
|
||||||
|
Dialect: orm.Dialects.SQLite3,
|
||||||
|
Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}},
|
||||||
|
DatabaseValidations: true,
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
})
|
||||||
|
t.Run("schemas are wrong", func(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`)
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`)
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`)
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, body text)`) // missing post_id
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`)
|
||||||
|
_, err = db.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`)
|
||||||
|
|
||||||
|
err = orm.SetupConnections(orm.ConnectionConfig{
|
||||||
|
Name: "default",
|
||||||
|
DB: db,
|
||||||
|
Dialect: orm.Dialects.SQLite3,
|
||||||
|
Entities: []orm.Entity{&Post{}, &Comment{}, &Category{}, &HeaderPicture{}},
|
||||||
|
DatabaseValidations: true,
|
||||||
|
})
|
||||||
|
assert.Error(t, err)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
811
gdb/orm/query.go
Normal file
811
gdb/orm/query.go
Normal file
@ -0,0 +1,811 @@
|
|||||||
|
//
|
||||||
|
// query.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
queryTypeSELECT = iota + 1
|
||||||
|
queryTypeUPDATE
|
||||||
|
queryTypeDelete
|
||||||
|
)
|
||||||
|
|
||||||
|
// QueryBuilder is our query builder, almost all methods and functions in GoLobby ORM
|
||||||
|
// create or configure instance of QueryBuilder.
|
||||||
|
type QueryBuilder[OUTPUT any] struct {
|
||||||
|
typ int
|
||||||
|
schema *schema
|
||||||
|
// general parts
|
||||||
|
where *whereClause
|
||||||
|
table string
|
||||||
|
placeholderGenerator func(n int) []string
|
||||||
|
|
||||||
|
// select parts
|
||||||
|
orderBy *orderByClause
|
||||||
|
groupBy *GroupBy
|
||||||
|
selected *selected
|
||||||
|
subQuery *struct {
|
||||||
|
q string
|
||||||
|
args []interface{}
|
||||||
|
placeholderGenerator func(n int) []string
|
||||||
|
}
|
||||||
|
joins []*Join
|
||||||
|
limit *Limit
|
||||||
|
offset *Offset
|
||||||
|
|
||||||
|
// update parts
|
||||||
|
sets [][2]interface{}
|
||||||
|
|
||||||
|
// execution parts
|
||||||
|
db *sql.DB
|
||||||
|
err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Finisher APIs
|
||||||
|
|
||||||
|
// execute is a finisher executes QueryBuilder query, remember to use this when you have an Update
|
||||||
|
// or Delete Query.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) execute() (sql.Result, error) {
|
||||||
|
if q.err != nil {
|
||||||
|
return nil, q.err
|
||||||
|
}
|
||||||
|
if q.typ == queryTypeSELECT {
|
||||||
|
return nil, fmt.Errorf("query type is SELECT")
|
||||||
|
}
|
||||||
|
query, args, err := q.ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return q.schema.getConnection().exec(query, args...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get limit results to 1, runs query generated by query builder, scans result into OUTPUT.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Get() (OUTPUT, error) {
|
||||||
|
if q.err != nil {
|
||||||
|
return *new(OUTPUT), q.err
|
||||||
|
}
|
||||||
|
queryString, args, err := q.ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return *new(OUTPUT), err
|
||||||
|
}
|
||||||
|
rows, err := q.schema.getConnection().query(queryString, args...)
|
||||||
|
if err != nil {
|
||||||
|
return *new(OUTPUT), err
|
||||||
|
}
|
||||||
|
var output OUTPUT
|
||||||
|
err = newBinder(q.schema).bind(rows, &output)
|
||||||
|
if err != nil {
|
||||||
|
return *new(OUTPUT), err
|
||||||
|
}
|
||||||
|
return output, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// All is a finisher, create the Select query based on QueryBuilder and scan results into
|
||||||
|
// slice of type parameter E.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) All() ([]OUTPUT, error) {
|
||||||
|
if q.err != nil {
|
||||||
|
return nil, q.err
|
||||||
|
}
|
||||||
|
q.SetSelect()
|
||||||
|
queryString, args, err := q.ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
rows, err := q.schema.getConnection().query(queryString, args...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var output []OUTPUT
|
||||||
|
err = newBinder(q.schema).bind(rows, &output)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return output, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Delete is a finisher, creates a delete query from query builder and executes it.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Delete() (rowsAffected int64, err error) {
|
||||||
|
if q.err != nil {
|
||||||
|
return 0, q.err
|
||||||
|
}
|
||||||
|
q.SetDelete()
|
||||||
|
res, err := q.execute()
|
||||||
|
if err != nil {
|
||||||
|
return 0, q.err
|
||||||
|
}
|
||||||
|
return res.RowsAffected()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update is a finisher, creates an Update query from QueryBuilder and executes in into database, returns
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Update() (rowsAffected int64, err error) {
|
||||||
|
if q.err != nil {
|
||||||
|
return 0, q.err
|
||||||
|
}
|
||||||
|
q.SetUpdate()
|
||||||
|
res, err := q.execute()
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
return res.RowsAffected()
|
||||||
|
}
|
||||||
|
|
||||||
|
func copyQueryBuilder[T1 any, T2 any](q *QueryBuilder[T1], q2 *QueryBuilder[T2]) {
|
||||||
|
q2.db = q.db
|
||||||
|
q2.err = q.err
|
||||||
|
q2.groupBy = q.groupBy
|
||||||
|
q2.joins = q.joins
|
||||||
|
q2.limit = q.limit
|
||||||
|
q2.offset = q.offset
|
||||||
|
q2.orderBy = q.orderBy
|
||||||
|
q2.placeholderGenerator = q.placeholderGenerator
|
||||||
|
q2.schema = q.schema
|
||||||
|
q2.selected = q.selected
|
||||||
|
q2.sets = q.sets
|
||||||
|
|
||||||
|
q2.subQuery = q.subQuery
|
||||||
|
q2.table = q.table
|
||||||
|
q2.typ = q.typ
|
||||||
|
q2.where = q.where
|
||||||
|
}
|
||||||
|
|
||||||
|
// Count creates and execute a select query from QueryBuilder and set it's field list of selection
|
||||||
|
// to COUNT(id).
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Count() *QueryBuilder[int] {
|
||||||
|
q.selected = &selected{Columns: []string{"COUNT(id)"}}
|
||||||
|
q.SetSelect()
|
||||||
|
qCount := NewQueryBuilder[int](q.schema)
|
||||||
|
|
||||||
|
copyQueryBuilder(q, qCount)
|
||||||
|
|
||||||
|
return qCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// First returns first record of database using OrderBy primary key
|
||||||
|
// ascending order.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) First() *QueryBuilder[OUTPUT] {
|
||||||
|
q.OrderBy(q.schema.pkName(), ASC).Limit(1)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Latest is like Get but it also do a OrderBy(primary key, DESC)
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Latest() *QueryBuilder[OUTPUT] {
|
||||||
|
q.OrderBy(q.schema.pkName(), DESC).Limit(1)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// WherePK adds a where clause to QueryBuilder and also gets primary key name
|
||||||
|
// from type parameter schema.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) WherePK(value interface{}) *QueryBuilder[OUTPUT] {
|
||||||
|
return q.Where(q.schema.pkName(), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *QueryBuilder[OUTPUT]) toSqlDelete() (string, []interface{}, error) {
|
||||||
|
base := fmt.Sprintf("DELETE FROM %s", d.table)
|
||||||
|
var args []interface{}
|
||||||
|
if d.where != nil {
|
||||||
|
d.where.PlaceHolderGenerator = d.placeholderGenerator
|
||||||
|
where, whereArgs, err := d.where.ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
base += " WHERE " + where
|
||||||
|
args = append(args, whereArgs...)
|
||||||
|
}
|
||||||
|
return base, args, nil
|
||||||
|
}
|
||||||
|
func pop(phs *[]string) string {
|
||||||
|
top := (*phs)[len(*phs)-1]
|
||||||
|
*phs = (*phs)[:len(*phs)-1]
|
||||||
|
return top
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *QueryBuilder[OUTPUT]) kvString() string {
|
||||||
|
phs := u.placeholderGenerator(len(u.sets))
|
||||||
|
var sets []string
|
||||||
|
for _, pair := range u.sets {
|
||||||
|
sets = append(sets, fmt.Sprintf("%s=%s", pair[0], pop(&phs)))
|
||||||
|
}
|
||||||
|
return strings.Join(sets, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *QueryBuilder[OUTPUT]) args() []interface{} {
|
||||||
|
var values []interface{}
|
||||||
|
for _, pair := range u.sets {
|
||||||
|
values = append(values, pair[1])
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *QueryBuilder[OUTPUT]) toSqlUpdate() (string, []interface{}, error) {
|
||||||
|
if u.table == "" {
|
||||||
|
return "", nil, fmt.Errorf("table cannot be empty")
|
||||||
|
}
|
||||||
|
base := fmt.Sprintf("UPDATE %s SET %s", u.table, u.kvString())
|
||||||
|
args := u.args()
|
||||||
|
if u.where != nil {
|
||||||
|
u.where.PlaceHolderGenerator = u.placeholderGenerator
|
||||||
|
where, whereArgs, err := u.where.ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
args = append(args, whereArgs...)
|
||||||
|
base += " WHERE " + where
|
||||||
|
}
|
||||||
|
return base, args, nil
|
||||||
|
}
|
||||||
|
func (s *QueryBuilder[OUTPUT]) toSqlSelect() (string, []interface{}, error) {
|
||||||
|
if s.err != nil {
|
||||||
|
return "", nil, s.err
|
||||||
|
}
|
||||||
|
base := "SELECT"
|
||||||
|
var args []interface{}
|
||||||
|
// select
|
||||||
|
if s.selected == nil {
|
||||||
|
s.selected = &selected{
|
||||||
|
Columns: []string{"*"},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
base += " " + s.selected.String()
|
||||||
|
// from
|
||||||
|
if s.table == "" && s.subQuery == nil {
|
||||||
|
return "", nil, fmt.Errorf("Table name cannot be empty")
|
||||||
|
} else if s.table != "" && s.subQuery != nil {
|
||||||
|
return "", nil, fmt.Errorf("cannot have both Table and subquery")
|
||||||
|
}
|
||||||
|
if s.table != "" {
|
||||||
|
base += " " + "FROM " + s.table
|
||||||
|
}
|
||||||
|
if s.subQuery != nil {
|
||||||
|
s.subQuery.placeholderGenerator = s.placeholderGenerator
|
||||||
|
base += " " + "FROM (" + s.subQuery.q + " )"
|
||||||
|
args = append(args, s.subQuery.args...)
|
||||||
|
}
|
||||||
|
// Joins
|
||||||
|
if s.joins != nil {
|
||||||
|
for _, join := range s.joins {
|
||||||
|
base += " " + join.String()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// whereClause
|
||||||
|
if s.where != nil {
|
||||||
|
s.where.PlaceHolderGenerator = s.placeholderGenerator
|
||||||
|
where, whereArgs, err := s.where.ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
base += " WHERE " + where
|
||||||
|
args = append(args, whereArgs...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// orderByClause
|
||||||
|
if s.orderBy != nil {
|
||||||
|
base += " " + s.orderBy.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupBy
|
||||||
|
if s.groupBy != nil {
|
||||||
|
base += " " + s.groupBy.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit
|
||||||
|
if s.limit != nil {
|
||||||
|
base += " " + s.limit.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Offset
|
||||||
|
if s.offset != nil {
|
||||||
|
base += " " + s.offset.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
return base, args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToSql creates sql query from QueryBuilder based on internal fields it would decide what kind
|
||||||
|
// of query to build.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) ToSql() (string, []interface{}, error) {
|
||||||
|
if q.err != nil {
|
||||||
|
return "", nil, q.err
|
||||||
|
}
|
||||||
|
if q.typ == queryTypeSELECT {
|
||||||
|
return q.toSqlSelect()
|
||||||
|
} else if q.typ == queryTypeDelete {
|
||||||
|
return q.toSqlDelete()
|
||||||
|
} else if q.typ == queryTypeUPDATE {
|
||||||
|
return q.toSqlUpdate()
|
||||||
|
} else {
|
||||||
|
return "", nil, fmt.Errorf("no sql type matched")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type orderByOrder string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ASC = "ASC"
|
||||||
|
DESC = "DESC"
|
||||||
|
)
|
||||||
|
|
||||||
|
type orderByClause struct {
|
||||||
|
Columns [][2]string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o orderByClause) String() string {
|
||||||
|
var tuples []string
|
||||||
|
for _, pair := range o.Columns {
|
||||||
|
tuples = append(tuples, fmt.Sprintf("%s %s", pair[0], pair[1]))
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("ORDER BY %s", strings.Join(tuples, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
type GroupBy struct {
|
||||||
|
Columns []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g GroupBy) String() string {
|
||||||
|
return fmt.Sprintf("GROUP BY %s", strings.Join(g.Columns, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
type joinType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
JoinTypeInner = "INNER"
|
||||||
|
JoinTypeLeft = "LEFT"
|
||||||
|
JoinTypeRight = "RIGHT"
|
||||||
|
JoinTypeFull = "FULL OUTER"
|
||||||
|
JoinTypeSelf = "SELF"
|
||||||
|
)
|
||||||
|
|
||||||
|
type JoinOn struct {
|
||||||
|
Lhs string
|
||||||
|
Rhs string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j JoinOn) String() string {
|
||||||
|
return fmt.Sprintf("%s = %s", j.Lhs, j.Rhs)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Join struct {
|
||||||
|
Type joinType
|
||||||
|
Table string
|
||||||
|
On JoinOn
|
||||||
|
}
|
||||||
|
|
||||||
|
func (j Join) String() string {
|
||||||
|
return fmt.Sprintf("%s JOIN %s ON %s", j.Type, j.Table, j.On.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
type Limit struct {
|
||||||
|
N int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l Limit) String() string {
|
||||||
|
return fmt.Sprintf("LIMIT %d", l.N)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Offset struct {
|
||||||
|
N int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Offset) String() string {
|
||||||
|
return fmt.Sprintf("OFFSET %d", o.N)
|
||||||
|
}
|
||||||
|
|
||||||
|
type selected struct {
|
||||||
|
Columns []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s selected) String() string {
|
||||||
|
return fmt.Sprintf("%s", strings.Join(s.Columns, ","))
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrderBy adds an OrderBy section to QueryBuilder.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) OrderBy(column string, how string) *QueryBuilder[OUTPUT] {
|
||||||
|
q.SetSelect()
|
||||||
|
if q.orderBy == nil {
|
||||||
|
q.orderBy = &orderByClause{}
|
||||||
|
}
|
||||||
|
q.orderBy.Columns = append(q.orderBy.Columns, [2]string{column, how})
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// LeftJoin adds a left join section to QueryBuilder.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) LeftJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] {
|
||||||
|
q.SetSelect()
|
||||||
|
q.joins = append(q.joins, &Join{
|
||||||
|
Type: JoinTypeLeft,
|
||||||
|
Table: table,
|
||||||
|
On: JoinOn{
|
||||||
|
Lhs: onLhs,
|
||||||
|
Rhs: onRhs,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// RightJoin adds a right join section to QueryBuilder.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) RightJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] {
|
||||||
|
q.SetSelect()
|
||||||
|
q.joins = append(q.joins, &Join{
|
||||||
|
Type: JoinTypeRight,
|
||||||
|
Table: table,
|
||||||
|
On: JoinOn{
|
||||||
|
Lhs: onLhs,
|
||||||
|
Rhs: onRhs,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// InnerJoin adds a inner join section to QueryBuilder.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) InnerJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] {
|
||||||
|
q.SetSelect()
|
||||||
|
q.joins = append(q.joins, &Join{
|
||||||
|
Type: JoinTypeInner,
|
||||||
|
Table: table,
|
||||||
|
On: JoinOn{
|
||||||
|
Lhs: onLhs,
|
||||||
|
Rhs: onRhs,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Join adds a inner join section to QueryBuilder.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Join(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] {
|
||||||
|
return q.InnerJoin(table, onLhs, onRhs)
|
||||||
|
}
|
||||||
|
|
||||||
|
// FullOuterJoin adds a full outer join section to QueryBuilder.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) FullOuterJoin(table string, onLhs string, onRhs string) *QueryBuilder[OUTPUT] {
|
||||||
|
q.SetSelect()
|
||||||
|
q.joins = append(q.joins, &Join{
|
||||||
|
Type: JoinTypeFull,
|
||||||
|
Table: table,
|
||||||
|
On: JoinOn{
|
||||||
|
Lhs: onLhs,
|
||||||
|
Rhs: onRhs,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Where Adds a where clause to query, if already have where clause append to it
|
||||||
|
// as AndWhere.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Where(parts ...interface{}) *QueryBuilder[OUTPUT] {
|
||||||
|
if q.where != nil {
|
||||||
|
return q.addWhere("AND", parts...)
|
||||||
|
}
|
||||||
|
if len(parts) == 1 {
|
||||||
|
if r, isRaw := parts[0].(*raw); isRaw {
|
||||||
|
q.where = &whereClause{raw: r.sql, args: r.args, PlaceHolderGenerator: q.placeholderGenerator}
|
||||||
|
return q
|
||||||
|
} else {
|
||||||
|
q.err = fmt.Errorf("when you have one argument passed to where, it should be *raw")
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
} else if len(parts) == 2 {
|
||||||
|
if strings.Index(parts[0].(string), " ") == -1 {
|
||||||
|
// Equal mode
|
||||||
|
q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: Eq, Rhs: parts[1]}, PlaceHolderGenerator: q.placeholderGenerator}
|
||||||
|
}
|
||||||
|
return q
|
||||||
|
} else if len(parts) == 3 {
|
||||||
|
// operator mode
|
||||||
|
q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2]}, PlaceHolderGenerator: q.placeholderGenerator}
|
||||||
|
return q
|
||||||
|
} else if len(parts) > 3 && parts[1].(string) == "IN" {
|
||||||
|
q.where = &whereClause{cond: cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2:]}, PlaceHolderGenerator: q.placeholderGenerator}
|
||||||
|
return q
|
||||||
|
} else {
|
||||||
|
q.err = fmt.Errorf("wrong number of arguments passed to Where")
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type binaryOp string
|
||||||
|
|
||||||
|
const (
|
||||||
|
Eq = "="
|
||||||
|
GT = ">"
|
||||||
|
LT = "<"
|
||||||
|
GE = ">="
|
||||||
|
LE = "<="
|
||||||
|
NE = "!="
|
||||||
|
Between = "BETWEEN"
|
||||||
|
Like = "LIKE"
|
||||||
|
In = "IN"
|
||||||
|
)
|
||||||
|
|
||||||
|
type cond struct {
|
||||||
|
PlaceHolderGenerator func(n int) []string
|
||||||
|
|
||||||
|
Lhs string
|
||||||
|
Op binaryOp
|
||||||
|
Rhs interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b cond) ToSql() (string, []interface{}, error) {
|
||||||
|
var phs []string
|
||||||
|
if b.Op == In {
|
||||||
|
rhs, isInterfaceSlice := b.Rhs.([]interface{})
|
||||||
|
if isInterfaceSlice {
|
||||||
|
phs = b.PlaceHolderGenerator(len(rhs))
|
||||||
|
return fmt.Sprintf("%s IN (%s)", b.Lhs, strings.Join(phs, ",")), rhs, nil
|
||||||
|
} else if rawThing, isRaw := b.Rhs.(*raw); isRaw {
|
||||||
|
return fmt.Sprintf("%s IN (%s)", b.Lhs, rawThing.sql), rawThing.args, nil
|
||||||
|
} else {
|
||||||
|
return "", nil, fmt.Errorf("Right hand side of Cond when operator is IN should be either a interface{} slice or *raw")
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
phs = b.PlaceHolderGenerator(1)
|
||||||
|
return fmt.Sprintf("%s %s %s", b.Lhs, b.Op, pop(&phs)), []interface{}{b.Rhs}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
nextType_AND = "AND"
|
||||||
|
nextType_OR = "OR"
|
||||||
|
)
|
||||||
|
|
||||||
|
type whereClause struct {
|
||||||
|
PlaceHolderGenerator func(n int) []string
|
||||||
|
nextTyp string
|
||||||
|
next *whereClause
|
||||||
|
cond
|
||||||
|
raw string
|
||||||
|
args []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w whereClause) ToSql() (string, []interface{}, error) {
|
||||||
|
var base string
|
||||||
|
var args []interface{}
|
||||||
|
var err error
|
||||||
|
if w.raw != "" {
|
||||||
|
base = w.raw
|
||||||
|
args = w.args
|
||||||
|
} else {
|
||||||
|
w.cond.PlaceHolderGenerator = w.PlaceHolderGenerator
|
||||||
|
base, args, err = w.cond.ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if w.next == nil {
|
||||||
|
return base, args, nil
|
||||||
|
}
|
||||||
|
if w.next != nil {
|
||||||
|
next, nextArgs, err := w.next.ToSql()
|
||||||
|
if err != nil {
|
||||||
|
return "", nil, err
|
||||||
|
}
|
||||||
|
base += " " + w.nextTyp + " " + next
|
||||||
|
args = append(args, nextArgs...)
|
||||||
|
return base, args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return base, args, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
//func (q *QueryBuilder[OUTPUT]) WhereKeyValue(m map) {}
|
||||||
|
|
||||||
|
// WhereIn adds a where clause to QueryBuilder using In operator.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) WhereIn(column string, values ...interface{}) *QueryBuilder[OUTPUT] {
|
||||||
|
return q.Where(append([]interface{}{column, In}, values...)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AndWhere appends a where clause to query builder as And where clause.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) AndWhere(parts ...interface{}) *QueryBuilder[OUTPUT] {
|
||||||
|
return q.addWhere(nextType_AND, parts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// OrWhere appends a where clause to query builder as Or where clause.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) OrWhere(parts ...interface{}) *QueryBuilder[OUTPUT] {
|
||||||
|
return q.addWhere(nextType_OR, parts...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryBuilder[OUTPUT]) addWhere(typ string, parts ...interface{}) *QueryBuilder[OUTPUT] {
|
||||||
|
w := q.where
|
||||||
|
for {
|
||||||
|
if w == nil {
|
||||||
|
break
|
||||||
|
} else if w.next == nil {
|
||||||
|
w.next = &whereClause{PlaceHolderGenerator: q.placeholderGenerator}
|
||||||
|
w.nextTyp = typ
|
||||||
|
w = w.next
|
||||||
|
break
|
||||||
|
} else {
|
||||||
|
w = w.next
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if w == nil {
|
||||||
|
w = &whereClause{PlaceHolderGenerator: q.placeholderGenerator}
|
||||||
|
}
|
||||||
|
if len(parts) == 1 {
|
||||||
|
w.raw = parts[0].(*raw).sql
|
||||||
|
w.args = parts[0].(*raw).args
|
||||||
|
return q
|
||||||
|
} else if len(parts) == 2 {
|
||||||
|
// Equal mode
|
||||||
|
w.cond = cond{Lhs: parts[0].(string), Op: Eq, Rhs: parts[1]}
|
||||||
|
return q
|
||||||
|
} else if len(parts) == 3 {
|
||||||
|
// operator mode
|
||||||
|
w.cond = cond{Lhs: parts[0].(string), Op: binaryOp(parts[1].(string)), Rhs: parts[2]}
|
||||||
|
return q
|
||||||
|
} else {
|
||||||
|
panic("wrong number of arguments passed to Where")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Offset adds offset section to query builder.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Offset(n int) *QueryBuilder[OUTPUT] {
|
||||||
|
q.SetSelect()
|
||||||
|
q.offset = &Offset{N: n}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Limit adds limit section to query builder.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Limit(n int) *QueryBuilder[OUTPUT] {
|
||||||
|
q.SetSelect()
|
||||||
|
q.limit = &Limit{N: n}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Table sets table of QueryBuilder.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Table(t string) *QueryBuilder[OUTPUT] {
|
||||||
|
q.table = t
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetSelect sets query type of QueryBuilder to Select.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) SetSelect() *QueryBuilder[OUTPUT] {
|
||||||
|
q.typ = queryTypeSELECT
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// GroupBy adds a group by section to QueryBuilder.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) GroupBy(columns ...string) *QueryBuilder[OUTPUT] {
|
||||||
|
q.SetSelect()
|
||||||
|
if q.groupBy == nil {
|
||||||
|
q.groupBy = &GroupBy{}
|
||||||
|
}
|
||||||
|
q.groupBy.Columns = append(q.groupBy.Columns, columns...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// Select adds columns to QueryBuilder select field list.
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Select(columns ...string) *QueryBuilder[OUTPUT] {
|
||||||
|
q.SetSelect()
|
||||||
|
if q.selected == nil {
|
||||||
|
q.selected = &selected{}
|
||||||
|
}
|
||||||
|
q.selected.Columns = append(q.selected.Columns, columns...)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
// FromQuery sets subquery of QueryBuilder to be given subquery so
|
||||||
|
// when doing select instead of from table we do from(subquery).
|
||||||
|
func (q *QueryBuilder[OUTPUT]) FromQuery(subQuery *QueryBuilder[OUTPUT]) *QueryBuilder[OUTPUT] {
|
||||||
|
q.SetSelect()
|
||||||
|
subQuery.SetSelect()
|
||||||
|
subQuery.placeholderGenerator = q.placeholderGenerator
|
||||||
|
subQueryString, args, err := subQuery.ToSql()
|
||||||
|
q.err = err
|
||||||
|
q.subQuery = &struct {
|
||||||
|
q string
|
||||||
|
args []interface{}
|
||||||
|
placeholderGenerator func(n int) []string
|
||||||
|
}{
|
||||||
|
subQueryString, args, q.placeholderGenerator,
|
||||||
|
}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryBuilder[OUTPUT]) SetUpdate() *QueryBuilder[OUTPUT] {
|
||||||
|
q.typ = queryTypeUPDATE
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryBuilder[OUTPUT]) Set(keyValues ...any) *QueryBuilder[OUTPUT] {
|
||||||
|
if len(keyValues)%2 != 0 {
|
||||||
|
q.err = fmt.Errorf("when using Set, passed argument count should be even: %w", q.err)
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
q.SetUpdate()
|
||||||
|
for i := 0; i < len(keyValues); i++ {
|
||||||
|
if i != 0 && i%2 == 1 {
|
||||||
|
q.sets = append(q.sets, [2]any{keyValues[i-1], keyValues[i]})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func (q *QueryBuilder[OUTPUT]) SetDialect(dialect *Dialect) *QueryBuilder[OUTPUT] {
|
||||||
|
q.placeholderGenerator = dialect.PlaceHolderGenerator
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
func (q *QueryBuilder[OUTPUT]) SetDelete() *QueryBuilder[OUTPUT] {
|
||||||
|
q.typ = queryTypeDelete
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
type raw struct {
|
||||||
|
sql string
|
||||||
|
args []interface{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Raw creates a Raw sql query chunk that you can add to several components of QueryBuilder like
|
||||||
|
// Wheres.
|
||||||
|
func Raw(sql string, args ...interface{}) *raw {
|
||||||
|
return &raw{sql: sql, args: args}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewQueryBuilder[OUTPUT any](s *schema) *QueryBuilder[OUTPUT] {
|
||||||
|
return &QueryBuilder[OUTPUT]{schema: s}
|
||||||
|
}
|
||||||
|
|
||||||
|
type insertStmt struct {
|
||||||
|
PlaceHolderGenerator func(n int) []string
|
||||||
|
Table string
|
||||||
|
Columns []string
|
||||||
|
Values [][]interface{}
|
||||||
|
Returning string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i insertStmt) flatValues() []interface{} {
|
||||||
|
var values []interface{}
|
||||||
|
for _, row := range i.Values {
|
||||||
|
values = append(values, row...)
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i insertStmt) getValuesStr() string {
|
||||||
|
phs := i.PlaceHolderGenerator(len(i.Values) * len(i.Values[0]))
|
||||||
|
|
||||||
|
var output []string
|
||||||
|
for _, valueRow := range i.Values {
|
||||||
|
output = append(output, fmt.Sprintf("(%s)", strings.Join(phs[:len(valueRow)], ",")))
|
||||||
|
phs = phs[len(valueRow):]
|
||||||
|
}
|
||||||
|
return strings.Join(output, ",")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i insertStmt) ToSql() (string, []interface{}) {
|
||||||
|
base := fmt.Sprintf("INSERT INTO %s (%s) VALUES %s",
|
||||||
|
i.Table,
|
||||||
|
strings.Join(i.Columns, ","),
|
||||||
|
i.getValuesStr(),
|
||||||
|
)
|
||||||
|
if i.Returning != "" {
|
||||||
|
base += "RETURNING " + i.Returning
|
||||||
|
}
|
||||||
|
return base, i.flatValues()
|
||||||
|
}
|
||||||
|
|
||||||
|
func postgresPlaceholder(n int) []string {
|
||||||
|
output := []string{}
|
||||||
|
for i := 1; i < n+1; i++ {
|
||||||
|
output = append(output, fmt.Sprintf("$%d", i))
|
||||||
|
}
|
||||||
|
return output
|
||||||
|
}
|
||||||
|
|
||||||
|
func questionMarks(n int) []string {
|
||||||
|
output := []string{}
|
||||||
|
for i := 0; i < n; i++ {
|
||||||
|
output = append(output, "?")
|
||||||
|
}
|
||||||
|
|
||||||
|
return output
|
||||||
|
}
|
243
gdb/orm/query_test.go
Normal file
243
gdb/orm/query_test.go
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
//
|
||||||
|
// query_test.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Dummy struct{}
|
||||||
|
|
||||||
|
func (d Dummy) ConfigureEntity(e *EntityConfigurator) {
|
||||||
|
// TODO implement me
|
||||||
|
panic("implement me")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSelect(t *testing.T) {
|
||||||
|
t.Run("only select * from Table", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil)
|
||||||
|
s.Table("users").SetSelect()
|
||||||
|
str, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Equal(t, "SELECT * FROM users", str)
|
||||||
|
})
|
||||||
|
t.Run("select with whereClause", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil)
|
||||||
|
|
||||||
|
s.Table("users").SetDialect(Dialects.MySQL).
|
||||||
|
Where("age", 10).
|
||||||
|
AndWhere("age", "<", 10).
|
||||||
|
Where("name", "Amirreza").
|
||||||
|
OrWhere("age", GT, 11).
|
||||||
|
SetSelect()
|
||||||
|
|
||||||
|
str, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, []interface{}{10, 10, "Amirreza", 11}, args)
|
||||||
|
assert.Equal(t, "SELECT * FROM users WHERE age = ? AND age < ? AND name = ? OR age > ?", str)
|
||||||
|
})
|
||||||
|
t.Run("select with order by", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil).Table("users").OrderBy("created_at", ASC).OrderBy("updated_at", DESC)
|
||||||
|
str, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Equal(t, "SELECT * FROM users ORDER BY created_at ASC,updated_at DESC", str)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("select with group by", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil).Table("users").GroupBy("created_at", "updated_at")
|
||||||
|
str, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Equal(t, "SELECT * FROM users GROUP BY created_at,updated_at", str)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Select with limit", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil).Table("users").Limit(10)
|
||||||
|
str, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Equal(t, "SELECT * FROM users LIMIT 10", str)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Select with offset", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil).Table("users").Offset(10)
|
||||||
|
str, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Equal(t, "SELECT * FROM users OFFSET 10", str)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("select with join", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").RightJoin("addresses", "users.id", "addresses.user_id")
|
||||||
|
str, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Equal(t, `SELECT id,name FROM users RIGHT JOIN addresses ON users.id = addresses.user_id`, str)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("select with multiple joins", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil).Table("users").
|
||||||
|
Select("id", "name").
|
||||||
|
RightJoin("addresses", "users.id", "addresses.user_id").
|
||||||
|
LeftJoin("user_credits", "users.id", "user_credits.user_id")
|
||||||
|
sql, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Equal(t, `SELECT id,name FROM users RIGHT JOIN addresses ON users.id = addresses.user_id LEFT JOIN user_credits ON users.id = user_credits.user_id`, sql)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("select with subquery", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil).SetDialect(Dialects.MySQL)
|
||||||
|
s.FromQuery(NewQueryBuilder[Dummy](nil).Table("users").Where("age", "<", 10))
|
||||||
|
sql, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, []interface{}{10}, args)
|
||||||
|
assert.Equal(t, `SELECT * FROM (SELECT * FROM users WHERE age < ? )`, sql)
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("select with inner join", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").InnerJoin("addresses", "users.id", "addresses.user_id")
|
||||||
|
str, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Equal(t, `SELECT id,name FROM users INNER JOIN addresses ON users.id = addresses.user_id`, str)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("select with join", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").Join("addresses", "users.id", "addresses.user_id")
|
||||||
|
str, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Equal(t, `SELECT id,name FROM users INNER JOIN addresses ON users.id = addresses.user_id`, str)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("select with full outer join", func(t *testing.T) {
|
||||||
|
s := NewQueryBuilder[Dummy](nil).Table("users").Select("id", "name").FullOuterJoin("addresses", "users.id", "addresses.user_id")
|
||||||
|
str, args, err := s.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Equal(t, `SELECT id,name FROM users FULL OUTER JOIN addresses ON users.id = addresses.user_id`, str)
|
||||||
|
})
|
||||||
|
t.Run("raw where", func(t *testing.T) {
|
||||||
|
sql, args, err :=
|
||||||
|
NewQueryBuilder[Dummy](nil).
|
||||||
|
SetDialect(Dialects.MySQL).
|
||||||
|
Table("users").
|
||||||
|
Where(Raw("id = ?", 1)).
|
||||||
|
AndWhere(Raw("age < ?", 10)).
|
||||||
|
SetSelect().
|
||||||
|
ToSql()
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, []interface{}{1, 10}, args)
|
||||||
|
assert.Equal(t, `SELECT * FROM users WHERE id = ? AND age < ?`, sql)
|
||||||
|
})
|
||||||
|
t.Run("no sql type matched", func(t *testing.T) {
|
||||||
|
sql, args, err := NewQueryBuilder[Dummy](nil).ToSql()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
assert.Empty(t, sql)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("raw where in", func(t *testing.T) {
|
||||||
|
sql, args, err :=
|
||||||
|
NewQueryBuilder[Dummy](nil).
|
||||||
|
SetDialect(Dialects.MySQL).
|
||||||
|
Table("users").
|
||||||
|
WhereIn("id", Raw("SELECT user_id FROM user_books WHERE book_id = ?", 10)).
|
||||||
|
SetSelect().
|
||||||
|
ToSql()
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, []interface{}{10}, args)
|
||||||
|
assert.Equal(t, `SELECT * FROM users WHERE id IN (SELECT user_id FROM user_books WHERE book_id = ?)`, sql)
|
||||||
|
})
|
||||||
|
t.Run("where in", func(t *testing.T) {
|
||||||
|
sql, args, err :=
|
||||||
|
NewQueryBuilder[Dummy](nil).
|
||||||
|
SetDialect(Dialects.MySQL).
|
||||||
|
Table("users").
|
||||||
|
WhereIn("id", 1, 2, 3, 4, 5, 6).
|
||||||
|
SetSelect().
|
||||||
|
ToSql()
|
||||||
|
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.EqualValues(t, []interface{}{1, 2, 3, 4, 5, 6}, args)
|
||||||
|
assert.Equal(t, `SELECT * FROM users WHERE id IN (?,?,?,?,?,?)`, sql)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
func TestUpdate(t *testing.T) {
|
||||||
|
t.Run("update no whereClause", func(t *testing.T) {
|
||||||
|
u := NewQueryBuilder[Dummy](nil).Table("users").Set("name", "amirreza").SetDialect(Dialects.MySQL)
|
||||||
|
sql, args, err := u.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, `UPDATE users SET name=?`, sql)
|
||||||
|
assert.Equal(t, []interface{}{"amirreza"}, args)
|
||||||
|
})
|
||||||
|
t.Run("update with whereClause", func(t *testing.T) {
|
||||||
|
u := NewQueryBuilder[Dummy](nil).Table("users").Set("name", "amirreza").Where("age", "<", 18).SetDialect(Dialects.MySQL)
|
||||||
|
sql, args, err := u.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, `UPDATE users SET name=? WHERE age < ?`, sql)
|
||||||
|
assert.Equal(t, []interface{}{"amirreza", 18}, args)
|
||||||
|
|
||||||
|
})
|
||||||
|
}
|
||||||
|
func TestDelete(t *testing.T) {
|
||||||
|
t.Run("delete without whereClause", func(t *testing.T) {
|
||||||
|
d := NewQueryBuilder[Dummy](nil).Table("users").SetDelete()
|
||||||
|
sql, args, err := d.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, `DELETE FROM users`, sql)
|
||||||
|
assert.Empty(t, args)
|
||||||
|
})
|
||||||
|
t.Run("delete with whereClause", func(t *testing.T) {
|
||||||
|
d := NewQueryBuilder[Dummy](nil).Table("users").SetDialect(Dialects.MySQL).Where("created_at", ">", "2012-01-10").SetDelete()
|
||||||
|
sql, args, err := d.ToSql()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, `DELETE FROM users WHERE created_at > ?`, sql)
|
||||||
|
assert.EqualValues(t, []interface{}{"2012-01-10"}, args)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInsert(t *testing.T) {
|
||||||
|
t.Run("insert into multiple rows", func(t *testing.T) {
|
||||||
|
i := insertStmt{}
|
||||||
|
i.Table = "users"
|
||||||
|
i.PlaceHolderGenerator = Dialects.MySQL.PlaceHolderGenerator
|
||||||
|
i.Columns = []string{"name", "age"}
|
||||||
|
i.Values = append(i.Values, []interface{}{"amirreza", 11}, []interface{}{"parsa", 10})
|
||||||
|
s, args := i.ToSql()
|
||||||
|
assert.Equal(t, `INSERT INTO users (name,age) VALUES (?,?),(?,?)`, s)
|
||||||
|
assert.EqualValues(t, []interface{}{"amirreza", 11, "parsa", 10}, args)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("insert into single row", func(t *testing.T) {
|
||||||
|
i := insertStmt{}
|
||||||
|
i.Table = "users"
|
||||||
|
i.PlaceHolderGenerator = Dialects.MySQL.PlaceHolderGenerator
|
||||||
|
i.Columns = []string{"name", "age"}
|
||||||
|
i.Values = append(i.Values, []interface{}{"amirreza", 11})
|
||||||
|
s, args := i.ToSql()
|
||||||
|
assert.Equal(t, `INSERT INTO users (name,age) VALUES (?,?)`, s)
|
||||||
|
assert.Equal(t, []interface{}{"amirreza", 11}, args)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPostgresPlaceholder(t *testing.T) {
|
||||||
|
t.Run("for 5 it should have 5", func(t *testing.T) {
|
||||||
|
phs := postgresPlaceholder(5)
|
||||||
|
assert.EqualValues(t, []string{"$1", "$2", "$3", "$4", "$5"}, phs)
|
||||||
|
})
|
||||||
|
}
|
306
gdb/orm/schema.go
Normal file
306
gdb/orm/schema.go
Normal file
@ -0,0 +1,306 @@
|
|||||||
|
//
|
||||||
|
// schema.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"database/sql/driver"
|
||||||
|
"fmt"
|
||||||
|
"reflect"
|
||||||
|
)
|
||||||
|
|
||||||
|
func getConnectionFor(e Entity) *connection {
|
||||||
|
configurator := newEntityConfigurator()
|
||||||
|
e.ConfigureEntity(configurator)
|
||||||
|
|
||||||
|
if len(globalConnections) > 1 && (configurator.connection == "" || configurator.table == "") {
|
||||||
|
panic("need table and DB name when having more than 1 DB registered")
|
||||||
|
}
|
||||||
|
if len(globalConnections) == 1 {
|
||||||
|
for _, db := range globalConnections {
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if db, exists := globalConnections[fmt.Sprintf("%s", configurator.connection)]; exists {
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
panic("no db found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSchemaFor(e Entity) *schema {
|
||||||
|
configurator := newEntityConfigurator()
|
||||||
|
c := getConnectionFor(e)
|
||||||
|
e.ConfigureEntity(configurator)
|
||||||
|
s := c.getSchema(configurator.table)
|
||||||
|
if s == nil {
|
||||||
|
s = schemaOfHeavyReflectionStuff(e)
|
||||||
|
c.setSchema(e, s)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
type schema struct {
|
||||||
|
Connection string
|
||||||
|
Table string
|
||||||
|
fields []*field
|
||||||
|
relations map[string]interface{}
|
||||||
|
setPK func(o Entity, value interface{})
|
||||||
|
getPK func(o Entity) interface{}
|
||||||
|
columnConstraints []*FieldConfigurator
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schema) getField(sf reflect.StructField) *field {
|
||||||
|
for _, f := range s.fields {
|
||||||
|
if sf.Name == f.Name {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schema) getDialect() *Dialect {
|
||||||
|
return GetConnection(s.Connection).Dialect
|
||||||
|
}
|
||||||
|
func (s *schema) Columns(withPK bool) []string {
|
||||||
|
var cols []string
|
||||||
|
for _, field := range s.fields {
|
||||||
|
if field.Virtual {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !withPK && field.IsPK {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if s.getDialect().AddTableNameInSelectColumns {
|
||||||
|
cols = append(cols, s.Table+"."+field.Name)
|
||||||
|
} else {
|
||||||
|
cols = append(cols, field.Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return cols
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schema) pkName() string {
|
||||||
|
for _, field := range s.fields {
|
||||||
|
if field.IsPK {
|
||||||
|
return field.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func genericFieldsOf(obj Entity) []*field {
|
||||||
|
t := reflect.TypeOf(obj)
|
||||||
|
for t.Kind() == reflect.Ptr {
|
||||||
|
t = t.Elem()
|
||||||
|
|
||||||
|
}
|
||||||
|
if t.Kind() == reflect.Slice {
|
||||||
|
t = t.Elem()
|
||||||
|
for t.Kind() == reflect.Ptr {
|
||||||
|
t = t.Elem()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var ec EntityConfigurator
|
||||||
|
obj.ConfigureEntity(&ec)
|
||||||
|
|
||||||
|
var fms []*field
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
ft := t.Field(i)
|
||||||
|
fm := fieldMetadata(ft, ec.columnConstraints)
|
||||||
|
fms = append(fms, fm...)
|
||||||
|
}
|
||||||
|
return fms
|
||||||
|
}
|
||||||
|
|
||||||
|
func valuesOfField(vf reflect.Value) []interface{} {
|
||||||
|
var values []interface{}
|
||||||
|
if vf.Type().Kind() == reflect.Struct || vf.Type().Kind() == reflect.Ptr {
|
||||||
|
t := vf.Type()
|
||||||
|
if vf.Type().Kind() == reflect.Ptr {
|
||||||
|
t = vf.Type().Elem()
|
||||||
|
}
|
||||||
|
if !t.Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
|
||||||
|
// go into
|
||||||
|
// it does not implement driver.Valuer interface
|
||||||
|
for i := 0; i < vf.NumField(); i++ {
|
||||||
|
vif := vf.Field(i)
|
||||||
|
values = append(values, valuesOfField(vif)...)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
values = append(values, vf.Interface())
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
values = append(values, vf.Interface())
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
func genericValuesOf(o Entity, withPK bool) []interface{} {
|
||||||
|
t := reflect.TypeOf(o)
|
||||||
|
v := reflect.ValueOf(o)
|
||||||
|
if t.Kind() == reflect.Ptr {
|
||||||
|
t = t.Elem()
|
||||||
|
v = v.Elem()
|
||||||
|
}
|
||||||
|
fields := getSchemaFor(o).fields
|
||||||
|
pkIdx := -1
|
||||||
|
for i, field := range fields {
|
||||||
|
if field.IsPK {
|
||||||
|
pkIdx = i
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
var values []interface{}
|
||||||
|
|
||||||
|
for i := 0; i < t.NumField(); i++ {
|
||||||
|
if !withPK && i == pkIdx {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if fields[i].Virtual {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
vf := v.Field(i)
|
||||||
|
values = append(values, valuesOfField(vf)...)
|
||||||
|
}
|
||||||
|
return values
|
||||||
|
}
|
||||||
|
|
||||||
|
func genericSetPkValue(obj Entity, value interface{}) {
|
||||||
|
genericSet(obj, getSchemaFor(obj).pkName(), value)
|
||||||
|
}
|
||||||
|
|
||||||
|
func genericGetPKValue(obj Entity) interface{} {
|
||||||
|
t := reflect.TypeOf(obj)
|
||||||
|
val := reflect.ValueOf(obj)
|
||||||
|
if t.Kind() == reflect.Ptr {
|
||||||
|
val = val.Elem()
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := getSchemaFor(obj).fields
|
||||||
|
for i, field := range fields {
|
||||||
|
if field.IsPK {
|
||||||
|
return val.Field(i).Interface()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schema) createdAt() *field {
|
||||||
|
for _, f := range s.fields {
|
||||||
|
if f.IsCreatedAt {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func (s *schema) updatedAt() *field {
|
||||||
|
for _, f := range s.fields {
|
||||||
|
if f.IsUpdatedAt {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schema) deletedAt() *field {
|
||||||
|
for _, f := range s.fields {
|
||||||
|
if f.IsDeletedAt {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
func pointersOf(v reflect.Value) map[string]interface{} {
|
||||||
|
m := map[string]interface{}{}
|
||||||
|
actualV := v
|
||||||
|
for actualV.Type().Kind() == reflect.Ptr {
|
||||||
|
actualV = actualV.Elem()
|
||||||
|
}
|
||||||
|
for i := 0; i < actualV.NumField(); i++ {
|
||||||
|
f := actualV.Field(i)
|
||||||
|
if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
|
||||||
|
fm := pointersOf(f)
|
||||||
|
for k, p := range fm {
|
||||||
|
m[k] = p
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fm := fieldMetadata(actualV.Type().Field(i), nil)[0]
|
||||||
|
m[fm.Name] = actualV.Field(i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
func genericSet(obj Entity, name string, value interface{}) {
|
||||||
|
n2p := pointersOf(reflect.ValueOf(obj))
|
||||||
|
var val interface{}
|
||||||
|
for k, v := range n2p {
|
||||||
|
if k == name {
|
||||||
|
val = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
val.(reflect.Value).Set(reflect.ValueOf(value))
|
||||||
|
}
|
||||||
|
func schemaOfHeavyReflectionStuff(v Entity) *schema {
|
||||||
|
userEntityConfigurator := newEntityConfigurator()
|
||||||
|
v.ConfigureEntity(userEntityConfigurator)
|
||||||
|
for _, relation := range userEntityConfigurator.resolveRelations {
|
||||||
|
relation()
|
||||||
|
}
|
||||||
|
schema := &schema{}
|
||||||
|
if userEntityConfigurator.connection != "" {
|
||||||
|
schema.Connection = userEntityConfigurator.connection
|
||||||
|
}
|
||||||
|
if userEntityConfigurator.table != "" {
|
||||||
|
schema.Table = userEntityConfigurator.table
|
||||||
|
} else {
|
||||||
|
panic("you need to have table name for getting schema.")
|
||||||
|
}
|
||||||
|
|
||||||
|
schema.columnConstraints = userEntityConfigurator.columnConstraints
|
||||||
|
if schema.Connection == "" {
|
||||||
|
schema.Connection = "default"
|
||||||
|
}
|
||||||
|
if schema.fields == nil {
|
||||||
|
schema.fields = genericFieldsOf(v)
|
||||||
|
}
|
||||||
|
if schema.getPK == nil {
|
||||||
|
schema.getPK = genericGetPKValue
|
||||||
|
}
|
||||||
|
|
||||||
|
if schema.setPK == nil {
|
||||||
|
schema.setPK = genericSetPkValue
|
||||||
|
}
|
||||||
|
|
||||||
|
schema.relations = userEntityConfigurator.relations
|
||||||
|
|
||||||
|
return schema
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schema) getTable() string {
|
||||||
|
return s.Table
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schema) getSQLDB() *sql.DB {
|
||||||
|
return s.getConnection().DB
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *schema) getConnection() *connection {
|
||||||
|
if len(globalConnections) > 1 && (s.Connection == "" || s.Table == "") {
|
||||||
|
panic("need table and DB name when having more than 1 DB registered")
|
||||||
|
}
|
||||||
|
if len(globalConnections) == 1 {
|
||||||
|
for _, db := range globalConnections {
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if db, exists := globalConnections[fmt.Sprintf("%s", s.Connection)]; exists {
|
||||||
|
return db
|
||||||
|
}
|
||||||
|
panic("no db found")
|
||||||
|
}
|
76
gdb/orm/schema_test.go
Normal file
76
gdb/orm/schema_test.go
Normal file
@ -0,0 +1,76 @@
|
|||||||
|
//
|
||||||
|
// schema_test.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func setup(t *testing.T) {
|
||||||
|
db, err := sql.Open("sqlite3", ":memory:")
|
||||||
|
err = SetupConnections(ConnectionConfig{
|
||||||
|
Name: "default",
|
||||||
|
DB: db,
|
||||||
|
Dialect: Dialects.SQLite3,
|
||||||
|
})
|
||||||
|
// orm.Schematic()
|
||||||
|
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS posts (id INTEGER PRIMARY KEY, body text, created_at TIMESTAMP, updated_at TIMESTAMP, deleted_at TIMESTAMP)`)
|
||||||
|
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS emails (id INTEGER PRIMARY KEY, post_id INTEGER, email text)`)
|
||||||
|
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS header_pictures (id INTEGER PRIMARY KEY, post_id INTEGER, link text)`)
|
||||||
|
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS comments (id INTEGER PRIMARY KEY, post_id INTEGER, body text)`)
|
||||||
|
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS categories (id INTEGER PRIMARY KEY, title text)`)
|
||||||
|
_, err = GetConnection("default").DB.Exec(`CREATE TABLE IF NOT EXISTS post_categories (post_id INTEGER, category_id INTEGER, PRIMARY KEY(post_id, category_id))`)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
type Object struct {
|
||||||
|
ID int64
|
||||||
|
Name string
|
||||||
|
Timestamps
|
||||||
|
}
|
||||||
|
|
||||||
|
func (o Object) ConfigureEntity(e *EntityConfigurator) {
|
||||||
|
e.Table("objects").Connection("default")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericFieldsOf(t *testing.T) {
|
||||||
|
t.Run("fields of with id and timestamps embedded", func(t *testing.T) {
|
||||||
|
fs := genericFieldsOf(&Object{})
|
||||||
|
assert.Len(t, fs, 5)
|
||||||
|
assert.Equal(t, "id", fs[0].Name)
|
||||||
|
assert.True(t, fs[0].IsPK)
|
||||||
|
assert.Equal(t, "name", fs[1].Name)
|
||||||
|
assert.Equal(t, "created_at", fs[2].Name)
|
||||||
|
assert.Equal(t, "updated_at", fs[3].Name)
|
||||||
|
assert.Equal(t, "deleted_at", fs[4].Name)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestGenericValuesOf(t *testing.T) {
|
||||||
|
t.Run("values of", func(t *testing.T) {
|
||||||
|
|
||||||
|
setup(t)
|
||||||
|
vs := genericValuesOf(Object{}, true)
|
||||||
|
assert.Len(t, vs, 5)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEntityConfigurator(t *testing.T) {
|
||||||
|
t.Run("test has many with user provided values", func(t *testing.T) {
|
||||||
|
setup(t)
|
||||||
|
var ec EntityConfigurator
|
||||||
|
ec.Table("users").Connection("default").HasMany(Object{}, HasManyConfig{
|
||||||
|
"objects", "user_id",
|
||||||
|
})
|
||||||
|
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
18
gdb/orm/timestamps.go
Normal file
18
gdb/orm/timestamps.go
Normal file
@ -0,0 +1,18 @@
|
|||||||
|
//
|
||||||
|
// timestamps.go
|
||||||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||||
|
//
|
||||||
|
// Distributed under terms of the MIT license.
|
||||||
|
//
|
||||||
|
|
||||||
|
package orm
|
||||||
|
|
||||||
|
import (
|
||||||
|
"database/sql"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Timestamps struct {
|
||||||
|
CreatedAt sql.NullTime
|
||||||
|
UpdatedAt sql.NullTime
|
||||||
|
DeletedAt sql.NullTime
|
||||||
|
}
|
Loading…
Reference in New Issue
Block a user