Compare commits

..

3 Commits

SHA1 Message Date
02b5c78a45 refactor: rewrite based on gorp 2023-08-17 17:16:00 +08:00
99e2411cee chore: web => http 2023-08-17 17:15:25 +08:00
2822110dec chore: remove sqlx 2023-08-17 17:13:45 +08:00
38 changed files with 7815 additions and 2297 deletions


@ -105,7 +105,7 @@ tmp_dir = "var/tmp"
 [build]
 # Just write the shell command you normally use to build; you can also use make
-cmd = "go build -o ./var/tmp/main entry/web/main.go"
+cmd = "go build -o ./var/tmp/main entry/http/main.go"
 # The name of the binary file produced by the cmd command
 bin = "var/tmp/main"
 # For a custom binary you can add extra build flags, e.g. GIN_MODE=release


@ -1,749 +0,0 @@
//
// adapter_sqlx.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package gcasbin
import (
"bytes"
"context"
"errors"
"fmt"
"strconv"
"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/persist"
"github.com/jmoiron/sqlx"
)
// defaultTableName is the table name used when tableName == "".
const defaultTableName = "casbin_rule"
// maxParamLength is the number of columns per rule (p_type plus v0..v5).
const maxParamLength = 7
// general sql
const (
sqlCreateTable = `
CREATE TABLE %[1]s(
p_type VARCHAR(32),
v0 VARCHAR(255),
v1 VARCHAR(255),
v2 VARCHAR(255),
v3 VARCHAR(255),
v4 VARCHAR(255),
v5 VARCHAR(255)
);
CREATE INDEX idx_%[1]s ON %[1]s (p_type,v0,v1);`
sqlTruncateTable = "TRUNCATE TABLE %s"
sqlIsTableExist = "SELECT 1 FROM %s"
sqlInsertRow = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (?,?,?,?,?,?,?)"
sqlUpdateRow = "UPDATE %s SET p_type=?,v0=?,v1=?,v2=?,v3=?,v4=?,v5=? WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
sqlDeleteAll = "DELETE FROM %s"
sqlDeleteRow = "DELETE FROM %s WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
sqlDeleteByArgs = "DELETE FROM %s WHERE p_type=?"
sqlSelectAll = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s"
sqlSelectWhere = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s WHERE "
)
// for Sqlite3
const (
sqlCreateTableSqlite3 = `
CREATE TABLE IF NOT EXISTS %[1]s(
p_type VARCHAR(32) DEFAULT '' NOT NULL,
v0 VARCHAR(255) DEFAULT '' NOT NULL,
v1 VARCHAR(255) DEFAULT '' NOT NULL,
v2 VARCHAR(255) DEFAULT '' NOT NULL,
v3 VARCHAR(255) DEFAULT '' NOT NULL,
v4 VARCHAR(255) DEFAULT '' NOT NULL,
v5 VARCHAR(255) DEFAULT '' NOT NULL,
CHECK (TYPEOF("p_type") = "text" AND
LENGTH("p_type") <= 32),
CHECK (TYPEOF("v0") = "text" AND
LENGTH("v0") <= 255),
CHECK (TYPEOF("v1") = "text" AND
LENGTH("v1") <= 255),
CHECK (TYPEOF("v2") = "text" AND
LENGTH("v2") <= 255),
CHECK (TYPEOF("v3") = "text" AND
LENGTH("v3") <= 255),
CHECK (TYPEOF("v4") = "text" AND
LENGTH("v4") <= 255),
CHECK (TYPEOF("v5") = "text" AND
LENGTH("v5") <= 255)
);
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
sqlTruncateTableSqlite3 = "DROP TABLE IF EXISTS %[1]s;" + sqlCreateTableSqlite3
)
// for Mysql
const (
sqlCreateTableMysql = `
CREATE TABLE IF NOT EXISTS %[1]s(
p_type VARCHAR(32) DEFAULT '' NOT NULL,
v0 VARCHAR(255) DEFAULT '' NOT NULL,
v1 VARCHAR(255) DEFAULT '' NOT NULL,
v2 VARCHAR(255) DEFAULT '' NOT NULL,
v3 VARCHAR(255) DEFAULT '' NOT NULL,
v4 VARCHAR(255) DEFAULT '' NOT NULL,
v5 VARCHAR(255) DEFAULT '' NOT NULL,
INDEX idx_%[1]s (p_type,v0,v1)
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;`
)
// for Postgres
const (
sqlCreateTablePostgres = `
CREATE TABLE IF NOT EXISTS %[1]s(
p_type VARCHAR(32) DEFAULT '' NOT NULL,
v0 VARCHAR(255) DEFAULT '' NOT NULL,
v1 VARCHAR(255) DEFAULT '' NOT NULL,
v2 VARCHAR(255) DEFAULT '' NOT NULL,
v3 VARCHAR(255) DEFAULT '' NOT NULL,
v4 VARCHAR(255) DEFAULT '' NOT NULL,
v5 VARCHAR(255) DEFAULT '' NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
sqlInsertRowPostgres = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES ($1,$2,$3,$4,$5,$6,$7)"
sqlUpdateRowPostgres = "UPDATE %s SET p_type=$1,v0=$2,v1=$3,v2=$4,v3=$5,v4=$6,v5=$7 WHERE p_type=$8 AND v0=$9 AND v1=$10 AND v2=$11 AND v3=$12 AND v4=$13 AND v5=$14"
sqlDeleteRowPostgres = "DELETE FROM %s WHERE p_type=$1 AND v0=$2 AND v1=$3 AND v2=$4 AND v3=$5 AND v4=$6 AND v5=$7"
)
// for Sqlserver
const (
sqlCreateTableSqlserver = `
CREATE TABLE %[1]s(
p_type NVARCHAR(32) DEFAULT '' NOT NULL,
v0 NVARCHAR(255) DEFAULT '' NOT NULL,
v1 NVARCHAR(255) DEFAULT '' NOT NULL,
v2 NVARCHAR(255) DEFAULT '' NOT NULL,
v3 NVARCHAR(255) DEFAULT '' NOT NULL,
v4 NVARCHAR(255) DEFAULT '' NOT NULL,
v5 NVARCHAR(255) DEFAULT '' NOT NULL
);
CREATE INDEX idx_%[1]s ON %[1]s (p_type, v0, v1);`
sqlInsertRowSqlserver = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (@p1,@p2,@p3,@p4,@p5,@p6,@p7)"
sqlUpdateRowSqlserver = "UPDATE %s SET p_type=@p1,v0=@p2,v1=@p3,v2=@p4,v3=@p5,v4=@p6,v5=@p7 WHERE p_type=@p8 AND v0=@p9 AND v1=@p10 AND v2=@p11 AND v3=@p12 AND v4=@p13 AND v5=@p14"
sqlDeleteRowSqlserver = "DELETE FROM %s WHERE p_type=@p1 AND v0=@p2 AND v1=@p3 AND v2=@p4 AND v3=@p5 AND v4=@p6 AND v5=@p7"
)
// SqlCasbinRule defines the casbin rule model.
// It is used to save and load policy lines from a sqlx-connected database.
type SqlCasbinRule struct {
PType string `db:"p_type"`
V0 string `db:"v0"`
V1 string `db:"v1"`
V2 string `db:"v2"`
V3 string `db:"v3"`
V4 string `db:"v4"`
V5 string `db:"v5"`
}
// SqlAdapter defines the sqlx adapter for Casbin.
// It can load and save policy lines from a sqlx-connected database.
type SqlAdapter struct {
db *sqlx.DB
ctx context.Context
tableName string
isFiltered bool
SqlCreateTable string
SqlTruncateTable string
SqlIsTableExist string
SqlInsertRow string
SqlUpdateRow string
SqlDeleteAll string
SqlDeleteRow string
SqlDeleteByArgs string
SqlSelectAll string
SqlSelectWhere string
}
// SqlFilter defines the filtering rules for a FilteredAdapter's policy.
// Empty values are ignored, but all others must match the filter.
type SqlFilter struct {
PType []string
V0 []string
V1 []string
V2 []string
V3 []string
V4 []string
V5 []string
}
// NewSqlAdapter is the constructor for SqlAdapter.
// db should be connected to the database and is managed by the caller.
// If tableName == "", the Adapter will automatically create a table named "casbin_rule".
func NewSqlAdapter(db *sqlx.DB, tableName string) (*SqlAdapter, error) {
return NewSqlAdapterContext(context.Background(), db, tableName)
}
// NewSqlAdapterContext is the constructor for SqlAdapter with a caller-supplied context.
// db should be connected to the database and is managed by the caller.
// If tableName == "", the Adapter will automatically create a table named "casbin_rule".
func NewSqlAdapterContext(ctx context.Context, db *sqlx.DB, tableName string) (*SqlAdapter, error) {
if db == nil {
return nil, errors.New("db is nil")
}
// check the db connection
err := db.PingContext(ctx)
if err != nil {
return nil, err
}
switch db.DriverName() {
case "oci8", "ora", "goracle":
return nil, errors.New("sqlxadapter: please checkout 'oracle' branch")
}
if tableName == "" {
tableName = defaultTableName
}
adapter := SqlAdapter{
db: db,
ctx: ctx,
tableName: tableName,
}
// generate the SQL for the specific database driver
adapter.genSQL()
if !adapter.IsTableExist() {
if err = adapter.CreateTable(); err != nil {
return nil, err
}
}
return &adapter, nil
}
// genSQL generates the SQL statements based on the db driver name.
func (p *SqlAdapter) genSQL() {
p.SqlCreateTable = fmt.Sprintf(sqlCreateTable, p.tableName)
p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTable, p.tableName)
p.SqlIsTableExist = fmt.Sprintf(sqlIsTableExist, p.tableName)
p.SqlInsertRow = fmt.Sprintf(sqlInsertRow, p.tableName)
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRow, p.tableName)
p.SqlDeleteAll = fmt.Sprintf(sqlDeleteAll, p.tableName)
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRow, p.tableName)
p.SqlDeleteByArgs = fmt.Sprintf(sqlDeleteByArgs, p.tableName)
p.SqlSelectAll = fmt.Sprintf(sqlSelectAll, p.tableName)
p.SqlSelectWhere = fmt.Sprintf(sqlSelectWhere, p.tableName)
switch p.db.DriverName() {
case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres":
p.SqlCreateTable = fmt.Sprintf(sqlCreateTablePostgres, p.tableName)
p.SqlInsertRow = fmt.Sprintf(sqlInsertRowPostgres, p.tableName)
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowPostgres, p.tableName)
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowPostgres, p.tableName)
case "mysql":
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableMysql, p.tableName)
case "sqlite3":
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlite3, p.tableName)
p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTableSqlite3, p.tableName)
case "sqlserver":
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlserver, p.tableName)
p.SqlInsertRow = fmt.Sprintf(sqlInsertRowSqlserver, p.tableName)
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowSqlserver, p.tableName)
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowSqlserver, p.tableName)
}
}
// CreateTable creates the table if it does not exist.
func (p *SqlAdapter) CreateTable() error {
_, err := p.db.ExecContext(p.ctx, p.SqlCreateTable)
return err
}
// TruncateTable clears the table.
func (p *SqlAdapter) TruncateTable() error {
_, err := p.db.ExecContext(p.ctx, p.SqlTruncateTable)
return err
}
// DeleteAll deletes all rows from the table.
func (p *SqlAdapter) DeleteAll() error {
_, err := p.db.ExecContext(p.ctx, p.SqlDeleteAll)
return err
}
// IsTableExist checks whether the table exists.
func (p *SqlAdapter) IsTableExist() bool {
_, err := p.db.ExecContext(p.ctx, p.SqlIsTableExist)
return err == nil
}
// DeleteRows deletes matching rows.
func (p *SqlAdapter) DeleteRows(query string, args ...interface{}) error {
query = p.db.Rebind(query)
_, err := p.db.ExecContext(p.ctx, query, args...)
return err
}
// TruncateAndInsertRows truncates the table and inserts new rows.
func (p *SqlAdapter) TruncateAndInsertRows(rules [][]interface{}) error {
if err := p.TruncateTable(); err != nil {
return err
}
return p.execTxSqlRows(p.SqlInsertRow, rules)
}
// DeleteAllAndInsertRows deletes all rows and inserts new rows.
func (p *SqlAdapter) DeleteAllAndInsertRows(rules [][]interface{}) error {
if err := p.DeleteAll(); err != nil {
return err
}
return p.execTxSqlRows(p.SqlInsertRow, rules)
}
// execTxSqlRows executes the statement for each rule inside a transaction.
func (p *SqlAdapter) execTxSqlRows(query string, rules [][]interface{}) (err error) {
tx, err := p.db.BeginTx(p.ctx, nil)
if err != nil {
return
}
var action string
stmt, err := tx.PrepareContext(p.ctx, query)
if err != nil {
action = "prepare context"
goto ROLLBACK
}
for _, rule := range rules {
if _, err = stmt.ExecContext(p.ctx, rule...); err != nil {
action = "stmt exec"
goto ROLLBACK
}
}
if err = stmt.Close(); err != nil {
action = "stmt close"
goto ROLLBACK
}
if err = tx.Commit(); err != nil {
action = "commit"
goto ROLLBACK
}
return
ROLLBACK:
if err1 := tx.Rollback(); err1 != nil {
err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1)
}
return
}
// SelectRows selects matching rows by args from the table.
func (p *SqlAdapter) SelectRows(query string, args ...interface{}) ([]*SqlCasbinRule, error) {
// make a slice with capacity
lines := make([]*SqlCasbinRule, 0, 64)
if len(args) == 0 {
return lines, p.db.SelectContext(p.ctx, &lines, query)
}
query = p.db.Rebind(query)
return lines, p.db.SelectContext(p.ctx, &lines, query, args...)
}
// SelectWhereIn selects matching rows by filter from the table.
func (p *SqlAdapter) SelectWhereIn(filter *SqlFilter) (lines []*SqlCasbinRule, err error) {
var sqlBuf bytes.Buffer
sqlBuf.Grow(64)
sqlBuf.WriteString(p.SqlSelectWhere)
args := make([]interface{}, 0, 4)
hasInCond := false
for _, col := range [maxParamLength]struct {
name string
arg []string
}{
{"p_type", filter.PType},
{"v0", filter.V0},
{"v1", filter.V1},
{"v2", filter.V2},
{"v3", filter.V3},
{"v4", filter.V4},
{"v5", filter.V5},
} {
l := len(col.arg)
if l == 0 {
continue
}
switch sqlBuf.Bytes()[sqlBuf.Len()-1] {
case '?', ')':
sqlBuf.WriteString(" AND ")
}
sqlBuf.WriteString(col.name)
if l == 1 {
sqlBuf.WriteString("=?")
args = append(args, col.arg[0])
} else {
sqlBuf.WriteString(" IN (?)")
args = append(args, col.arg)
hasInCond = true
}
}
var query string
if hasInCond {
if query, args, err = sqlx.In(sqlBuf.String(), args...); err != nil {
return
}
} else {
query = sqlBuf.String()
}
return p.SelectRows(query, args...)
}
// LoadPolicy loads all policy rules from the storage.
func (p *SqlAdapter) LoadPolicy(model model.Model) error {
lines, err := p.SelectRows(p.SqlSelectAll)
if err != nil {
return err
}
for _, line := range lines {
p.loadPolicyLine(line, model)
}
return nil
}
// SavePolicy saves policy rules to the storage.
func (p *SqlAdapter) SavePolicy(model model.Model) error {
args := make([][]interface{}, 0, 64)
for ptype, ast := range model["p"] {
for _, rule := range ast.Policy {
arg := p.GenArgs(ptype, rule)
args = append(args, arg)
}
}
for ptype, ast := range model["g"] {
for _, rule := range ast.Policy {
arg := p.GenArgs(ptype, rule)
args = append(args, arg)
}
}
return p.DeleteAllAndInsertRows(args)
}
// AddPolicy adds one policy rule to the storage.
func (p *SqlAdapter) AddPolicy(sec string, ptype string, rule []string) error {
args := p.GenArgs(ptype, rule)
_, err := p.db.ExecContext(p.ctx, p.SqlInsertRow, args...)
return err
}
// AddPolicies adds multiple policy rules to the storage.
func (p *SqlAdapter) AddPolicies(sec string, ptype string, rules [][]string) error {
args := make([][]interface{}, 0, 8)
for _, rule := range rules {
arg := p.GenArgs(ptype, rule)
args = append(args, arg)
}
return p.execTxSqlRows(p.SqlInsertRow, args)
}
// RemovePolicy removes a policy rule from the storage.
func (p *SqlAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
var sqlBuf bytes.Buffer
sqlBuf.Grow(64)
sqlBuf.WriteString(p.SqlDeleteByArgs)
args := make([]interface{}, 0, 4)
args = append(args, ptype)
for idx, arg := range rule {
if arg != "" {
sqlBuf.WriteString(" AND v")
sqlBuf.WriteString(strconv.Itoa(idx))
sqlBuf.WriteString("=?")
args = append(args, arg)
}
}
return p.DeleteRows(sqlBuf.String(), args...)
}
// RemoveFilteredPolicy removes policy rules that match the filter from the storage.
func (p *SqlAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
var sqlBuf bytes.Buffer
sqlBuf.Grow(64)
sqlBuf.WriteString(p.SqlDeleteByArgs)
args := make([]interface{}, 0, 4)
args = append(args, ptype)
var value string
l := fieldIndex + len(fieldValues)
for idx := 0; idx < 6; idx++ {
if fieldIndex <= idx && idx < l {
value = fieldValues[idx-fieldIndex]
if value != "" {
sqlBuf.WriteString(" AND v")
sqlBuf.WriteString(strconv.Itoa(idx))
sqlBuf.WriteString("=?")
args = append(args, value)
}
}
}
return p.DeleteRows(sqlBuf.String(), args...)
}
// RemovePolicies removes multiple policy rules.
func (p *SqlAdapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) {
args := make([][]interface{}, 0, 8)
for _, rule := range rules {
arg := p.GenArgs(ptype, rule)
args = append(args, arg)
}
return p.execTxSqlRows(p.SqlDeleteRow, args)
}
// LoadFilteredPolicy loads policy rules that match the filter.
// filterPtr must be a pointer.
func (p *SqlAdapter) LoadFilteredPolicy(model model.Model, filterPtr interface{}) error {
if filterPtr == nil {
return p.LoadPolicy(model)
}
filter, ok := filterPtr.(*SqlFilter)
if !ok {
return errors.New("invalid filter type")
}
lines, err := p.SelectWhereIn(filter)
if err != nil {
return err
}
for _, line := range lines {
p.loadPolicyLine(line, model)
}
p.isFiltered = true
return nil
}
// IsFiltered returns true if the loaded policy rules have been filtered.
func (p *SqlAdapter) IsFiltered() bool {
return p.isFiltered
}
// UpdatePolicy updates a policy rule in the storage.
// This is part of the Auto-Save feature.
func (p *SqlAdapter) UpdatePolicy(sec, ptype string, oldRule, newPolicy []string) error {
oldArg := p.GenArgs(ptype, oldRule)
newArg := p.GenArgs(ptype, newPolicy)
_, err := p.db.ExecContext(p.ctx, p.SqlUpdateRow, append(newArg, oldArg...)...)
return err
}
// UpdatePolicies updates policy rules in the storage.
func (p *SqlAdapter) UpdatePolicies(sec, ptype string, oldRules, newRules [][]string) (err error) {
if len(oldRules) != len(newRules) {
return errors.New("old rules size not equal to new rules size")
}
args := make([][]interface{}, 0, 16)
for idx := range oldRules {
oldArg := p.GenArgs(ptype, oldRules[idx])
newArg := p.GenArgs(ptype, newRules[idx])
args = append(args, append(newArg, oldArg...))
}
return p.execTxSqlRows(p.SqlUpdateRow, args)
}
// UpdateFilteredPolicies deletes old rules and adds new rules.
func (p *SqlAdapter) UpdateFilteredPolicies(sec, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) (oldPolicies [][]string, err error) {
var value string
var whereBuf bytes.Buffer
whereBuf.Grow(32)
l := fieldIndex + len(fieldValues)
whereArgs := make([]interface{}, 0, 4)
whereArgs = append(whereArgs, ptype)
for idx := 0; idx < 6; idx++ {
if fieldIndex <= idx && idx < l {
value = fieldValues[idx-fieldIndex]
if value != "" {
whereBuf.WriteString(" AND v")
whereBuf.WriteString(strconv.Itoa(idx))
whereBuf.WriteString("=?")
whereArgs = append(whereArgs, value)
}
}
}
var selectBuf bytes.Buffer
selectBuf.Grow(64)
selectBuf.WriteString(p.SqlSelectWhere)
selectBuf.WriteString("p_type=?")
selectBuf.Write(whereBuf.Bytes())
var oldRows []*SqlCasbinRule
value = p.db.Rebind(selectBuf.String())
oldRows, err = p.SelectRows(value, whereArgs...)
if err != nil {
return
}
var deleteBuf bytes.Buffer
deleteBuf.Grow(64)
deleteBuf.WriteString(p.SqlDeleteByArgs)
deleteBuf.Write(whereBuf.Bytes())
var tx *sqlx.Tx
tx, err = p.db.BeginTxx(p.ctx, nil)
if err != nil {
return
}
var (
stmt *sqlx.Stmt
action string
)
value = p.db.Rebind(deleteBuf.String())
if _, err = tx.ExecContext(p.ctx, value, whereArgs...); err != nil {
action = "delete old policies"
goto ROLLBACK
}
stmt, err = tx.PreparexContext(p.ctx, p.SqlInsertRow)
if err != nil {
action = "preparex context"
goto ROLLBACK
}
for _, policy := range newPolicies {
arg := p.GenArgs(ptype, policy)
if _, err = stmt.ExecContext(p.ctx, arg...); err != nil {
action = "stmt exec context"
goto ROLLBACK
}
}
if err = stmt.Close(); err != nil {
action = "stmt close"
goto ROLLBACK
}
if err = tx.Commit(); err != nil {
action = "commit"
goto ROLLBACK
}
oldPolicies = make([][]string, 0, len(oldRows))
for _, rule := range oldRows {
oldPolicies = append(oldPolicies, []string{rule.PType, rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
}
return
ROLLBACK:
if err1 := tx.Rollback(); err1 != nil {
err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1)
}
return
}
// loadPolicyLine loads a policy line into the model.
func (SqlAdapter) loadPolicyLine(line *SqlCasbinRule, model model.Model) {
if line == nil {
return
}
var lineBuf bytes.Buffer
lineBuf.Grow(64)
lineBuf.WriteString(line.PType)
args := [6]string{line.V0, line.V1, line.V2, line.V3, line.V4, line.V5}
for _, arg := range args {
if arg != "" {
lineBuf.WriteByte(',')
lineBuf.WriteString(arg)
}
}
persist.LoadPolicyLine(lineBuf.String(), model)
}
// GenArgs generates args from ptype and rule.
func (SqlAdapter) GenArgs(ptype string, rule []string) []interface{} {
l := len(rule)
args := make([]interface{}, maxParamLength)
args[0] = ptype
for idx := 0; idx < l; idx++ {
args[idx+1] = rule[idx]
}
for idx := l + 1; idx < maxParamLength; idx++ {
args[idx] = ""
}
return args
}
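For reference, a minimal sketch of how this (now removed) sqlx adapter is wired into a Casbin enforcer, mirroring the usage in the test file below; the Postgres DSN and the model path are illustrative placeholders, not values taken from this changeset:

package main

import (
	"log"

	"github.com/casbin/casbin/v2"
	"github.com/jmoiron/sqlx"
	_ "github.com/lib/pq"

	"git.hexq.cn/tiglog/golib/gcasbin"
)

func main() {
	// Placeholder DSN; in the tests it comes from the DB_DSN environment variable.
	db, err := sqlx.Connect("postgres", "postgres://user:pass@localhost:5432/app?sslmode=disable")
	if err != nil {
		log.Fatal(err)
	}
	// An empty table name falls back to the default "casbin_rule" table.
	a, err := gcasbin.NewSqlAdapter(db, "")
	if err != nil {
		log.Fatal(err)
	}
	// NewEnforcer loads the policy from the adapter automatically.
	e, err := casbin.NewEnforcer("testdata/rbac_model.conf", a)
	if err != nil {
		log.Fatal(err)
	}
	ok, _ := e.Enforce("alice", "data1", "read")
	log.Println("alice can read data1:", ok)
}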


@ -1,466 +0,0 @@
//
// adapter_sqlx_test.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package gcasbin_test
import (
"os"
"strings"
"testing"
"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/util"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"git.hexq.cn/tiglog/golib/gcasbin"
)
const (
rbacModelFile = "testdata/rbac_model.conf"
rbacPolicyFile = "testdata/rbac_policy.csv"
)
var (
dataSourceNames = map[string]string{
// "sqlite3": ":memory:",
// "mysql": "root:@tcp(127.0.0.1:3306)/sqlx_adapter_test",
"postgres": os.Getenv("DB_DSN"),
// "sqlserver": "sqlserver://sa:YourPassword@127.0.0.1:1433?database=sqlx_adapter_test&connection+timeout=30",
}
lines = []gcasbin.SqlCasbinRule{
{PType: "p", V0: "alice", V1: "data1", V2: "read"},
{PType: "p", V0: "bob", V1: "data2", V2: "read"},
{PType: "p", V0: "bob", V1: "data2", V2: "write"},
{PType: "p", V0: "data2_admin", V1: "data1", V2: "read", V3: "test1", V4: "test2", V5: "test3"},
{PType: "p", V0: "data2_admin", V1: "data2", V2: "write", V3: "test1", V4: "test2", V5: "test3"},
{PType: "p", V0: "data1_admin", V1: "data2", V2: "write"},
{PType: "g", V0: "alice", V1: "data2_admin"},
{PType: "g", V0: "bob", V1: "data2_admin", V2: "test"},
{PType: "g", V0: "bob", V1: "data1_admin", V2: "test2", V3: "test3", V4: "test4", V5: "test5"},
}
filter = gcasbin.SqlFilter{
PType: []string{"p"},
V0: []string{"bob", "data2_admin"},
V1: []string{"data1", "data2"},
V2: []string{"read", "write"},
V3: []string{"test1"},
V4: []string{"test2"},
V5: []string{"test3"},
}
)
func TestSqlAdapters(t *testing.T) {
for key, value := range dataSourceNames {
t.Logf("-------------------- test [%s] start, dataSourceName: [%s]", key, value)
db, err := sqlx.Connect(key, value)
if err != nil {
t.Fatalf("sqlx.Connect failed, err: %v", err)
}
t.Log("---------- testTableName start")
testTableName(t, db)
t.Log("---------- testTableName finished")
t.Log("---------- testSQL start")
testSQL(t, db, "sqlxadapter_sql")
t.Log("---------- testSQL finished")
t.Log("---------- testSaveLoad start")
testSaveLoad(t, db, "sqlxadapter_save_load")
t.Log("---------- testSaveLoad finished")
t.Log("---------- testAutoSave start")
testAutoSave(t, db, "sqlxadapter_auto_save")
t.Log("---------- testAutoSave finished")
t.Log("---------- testFilteredSqlPolicy start")
testFilteredSqlPolicy(t, db, "sqlxadapter_filtered_policy")
t.Log("---------- testFilteredSqlPolicy finished")
// t.Log("---------- testUpdateSqlPolicy start")
// testUpdateSqlPolicy(t, db, "sqladapter_filtered_policy")
// t.Log("---------- testUpdateSqlPolicy finished")
// t.Log("---------- testUpdateSqlPolicies start")
// testUpdateSqlPolicies(t, db, "sqladapter_filtered_policy")
// t.Log("---------- testUpdateSqlPolicies finished")
// t.Log("---------- testUpdateFilteredSqlPolicies start")
// testUpdateFilteredSqlPolicies(t, db, "sqladapter_filtered_policy")
// t.Log("---------- testUpdateFilteredSqlPolicies finished")
}
}
func testTableName(t *testing.T, db *sqlx.DB) {
_, err := gcasbin.NewSqlAdapter(db, "")
if err != nil {
t.Fatalf("NewAdapter failed, err: %v", err)
}
}
func testSQL(t *testing.T, db *sqlx.DB, tableName string) {
var err error
logErr := func(action string) {
if err != nil {
t.Errorf("%s test failed, err: %v", action, err)
}
}
equalValue := func(line1, line2 gcasbin.SqlCasbinRule) bool {
if line1.PType != line2.PType ||
line1.V0 != line2.V0 ||
line1.V1 != line2.V1 ||
line1.V2 != line2.V2 ||
line1.V3 != line2.V3 ||
line1.V4 != line2.V4 ||
line1.V5 != line2.V5 {
return false
}
return true
}
var a *gcasbin.SqlAdapter
a, err = gcasbin.NewSqlAdapter(db, tableName)
logErr("NewSqlAdapter")
// the CreateTable path has already been exercised when the adapter was created
// err = a.CreateTable()
// logErr("createTable")
if b := a.IsTableExist(); b == false {
t.Fatal("isTableExist test failed")
}
rules := make([][]interface{}, len(lines))
for idx, rule := range lines {
args := a.GenArgs(rule.PType, []string{rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
rules[idx] = args
}
err = a.TruncateAndInsertRows(rules)
logErr("truncateAndInsertRows")
err = a.DeleteAllAndInsertRows(rules)
logErr("truncateAndInsertRows")
err = a.DeleteRows(a.SqlDeleteByArgs, "g")
logErr("deleteRows sqlDeleteByArgs g")
err = a.DeleteRows(a.SqlDeleteAll)
logErr("deleteRows sqlDeleteAll")
_ = a.TruncateAndInsertRows(rules)
records, err := a.SelectRows(a.SqlSelectAll)
logErr("selectRows sqlSelectAll")
for idx, record := range records {
line := lines[idx]
if !equalValue(*record, line) {
t.Fatalf("selectRows records test not equal, query record: %+v, need record: %+v", record, line)
}
}
records, err = a.SelectWhereIn(&filter)
logErr("selectWhereIn")
i := 3
for _, record := range records {
line := lines[i]
if !equalValue(*record, line) {
t.Fatalf("selectWhereIn records test not equal, query record: %+v, need record: %+v", record, line)
}
i++
}
err = a.TruncateTable()
logErr("truncateTable")
}
func initSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
// Because the DB is empty at first,
// we need to load the policy from the file adapter (.CSV) first.
e, _ := casbin.NewEnforcer(rbacModelFile, rbacPolicyFile)
a, err := gcasbin.NewSqlAdapter(db, tableName)
if err != nil {
t.Fatal("NewAdapter test failed, err: ", err)
}
// This is a trick to save the current policy to the DB.
// We can't call e.SavePolicy() because the adapter in the enforcer is still the file adapter.
// The current policy means the policy in the Casbin enforcer (aka in memory).
err = a.SavePolicy(e.GetModel())
if err != nil {
t.Fatal("SavePolicy test failed, err: ", err)
}
// Clear the current policy.
e.ClearPolicy()
testGetSqlPolicy(t, e, [][]string{})
// Load the policy from DB.
err = a.LoadPolicy(e.GetModel())
if err != nil {
t.Fatal("LoadPolicy test failed, err: ", err)
}
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
func testSaveLoad(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.
// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
func testAutoSave(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.
// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
// AutoSave is enabled by default.
// Now we disable it.
e.EnableAutoSave(false)
var err error
logErr := func(action string) {
if err != nil {
t.Errorf("%s test failed, err: %v", action, err)
}
}
// Because AutoSave is disabled, the policy change only affects the policy in Casbin enforcer,
// it doesn't affect the policy in the storage.
_, err = e.AddPolicy("alice", "data1", "write")
logErr("AddPolicy1")
// Reload the policy from the storage to see the effect.
err = e.LoadPolicy()
logErr("LoadPolicy1")
// This is still the original policy.
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
_, err = e.AddPolicies([][]string{{"alice_1", "data_1", "read_1"}, {"bob_1", "data_1", "write_1"}})
logErr("AddPolicies1")
// Reload the policy from the storage to see the effect.
err = e.LoadPolicy()
logErr("LoadPolicy2")
// This is still the original policy.
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
// Now we enable the AutoSave.
e.EnableAutoSave(true)
// Because AutoSave is enabled, the policy change not only affects the policy in Casbin enforcer,
// but also affects the policy in the storage.
_, err = e.AddPolicy("alice", "data1", "write")
logErr("AddPolicy2")
// Reload the policy from the storage to see the effect.
err = e.LoadPolicy()
logErr("LoadPolicy3")
// The policy has a new rule: {"alice", "data1", "write"}.
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
_, err = e.AddPolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
logErr("AddPolicies2")
// Reload the policy from the storage to see the effect.
err = e.LoadPolicy()
logErr("LoadPolicy4")
// The policy now also contains the two new rules.
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"},
{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
_, err = e.RemovePolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
logErr("RemovePolicies")
err = e.LoadPolicy()
logErr("LoadPolicy5")
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
// Remove the added rule.
_, err = e.RemovePolicy("alice", "data1", "write")
logErr("RemovePolicy")
err = e.LoadPolicy()
logErr("LoadPolicy6")
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
// Remove "data2_admin" related policy rules via a filter.
// Two rules: {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"} are deleted.
_, err = e.RemoveFilteredPolicy(0, "data2_admin")
logErr("RemoveFilteredPolicy")
err = e.LoadPolicy()
logErr("LoadPolicy7")
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
}
func testFilteredSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.
// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
// Now set the adapter
e.SetAdapter(a)
var err error
logErr := func(action string) {
if err != nil {
t.Errorf("%s test failed, err: %v", action, err)
}
}
// Load only alice's policies
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice"}})
logErr("LoadFilteredPolicy alice")
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}})
// Load only bob's policies
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"bob"}})
logErr("LoadFilteredPolicy bob")
testGetSqlPolicy(t, e, [][]string{{"bob", "data2", "write"}})
// Load policies for data2_admin
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"data2_admin"}})
logErr("LoadFilteredPolicy data2_admin")
testGetSqlPolicy(t, e, [][]string{{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
// Load policies for alice and bob
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice", "bob"}})
logErr("LoadFilteredPolicy alice bob")
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
_, err = e.AddPolicy("bob", "data1", "write", "test1", "test2", "test3")
logErr("AddPolicy")
err = e.LoadFilteredPolicy(&filter)
logErr("LoadFilteredPolicy filter")
testGetSqlPolicy(t, e, [][]string{{"bob", "data1", "write", "test1", "test2", "test3"}})
}
func testUpdateSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
e.EnableAutoSave(true)
e.UpdatePolicy([]string{"alice", "data1", "read"}, []string{"alice", "data1", "write"})
e.LoadPolicy()
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
func testUpdateSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
e.EnableAutoSave(true)
e.UpdatePolicies([][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}}, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}})
e.LoadPolicy()
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
func testUpdateFilteredSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
e.EnableAutoSave(true)
e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read")
e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write")
e.LoadPolicy()
testGetSqlPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}})
}
func testGetSqlPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
t.Helper()
myRes := e.GetPolicy()
t.Log("Policy: ", myRes)
m := make(map[string]struct{}, len(myRes))
for _, record := range myRes {
key := strings.Join(record, ",")
m[key] = struct{}{}
}
for _, record := range res {
key := strings.Join(record, ",")
if _, ok := m[key]; !ok {
t.Error("Policy: \n", myRes, ", supposed to be \n", res)
break
}
}
}
func testGetSqlPolicyWithoutOrder(t *testing.T, e *casbin.Enforcer, res [][]string) {
myRes := e.GetPolicy()
// log.Print("Policy: \n", myRes)
if !arraySqlEqualsWithoutOrder(myRes, res) {
t.Error("Policy: \n", myRes, ", supposed to be \n", res)
}
}
func arraySqlEqualsWithoutOrder(a [][]string, b [][]string) bool {
if len(a) != len(b) {
return false
}
mapA := make(map[int]string)
mapB := make(map[int]string)
order := make(map[int]struct{})
l := len(a)
for i := 0; i < l; i++ {
mapA[i] = util.ArrayToString(a[i])
mapB[i] = util.ArrayToString(b[i])
}
for i := 0; i < l; i++ {
for j := 0; j < l; j++ {
if _, ok := order[j]; ok {
if j == l-1 {
return false
} else {
continue
}
}
if mapA[i] == mapB[j] {
order[j] = struct{}{}
break
} else if j == l-1 {
return false
}
}
}
return true
}


@ -1,180 +0,0 @@
//
// base_test.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb_test
import (
"database/sql"
"fmt"
"os"
"strings"
"testing"
_ "github.com/lib/pq"
"git.hexq.cn/tiglog/golib/gdb/sqldb"
// _ "github.com/go-sql-driver/mysql"
)
type Schema struct {
create string
drop string
}
var defaultSchema = Schema{
create: `
CREATE TABLE person (
id serial,
first_name text,
last_name text,
email text,
added_at int default 0,
PRIMARY KEY (id)
);
CREATE TABLE place (
country text,
city text NULL,
telcode integer
);
CREATE TABLE capplace (
country text,
city text NULL,
telcode integer
);
CREATE TABLE nullperson (
first_name text NULL,
last_name text NULL,
email text NULL
);
CREATE TABLE employees (
name text,
id integer,
boss_id integer
);
`,
drop: `
drop table person;
drop table place;
drop table capplace;
drop table nullperson;
drop table employees;
`,
}
type Person struct {
Id int64 `db:"id"`
FirstName string `db:"first_name"`
LastName string `db:"last_name"`
Email string `db:"email"`
AddedAt int64 `db:"added_at"`
}
type Person2 struct {
FirstName sql.NullString `db:"first_name"`
LastName sql.NullString `db:"last_name"`
Email sql.NullString
}
type Place struct {
Country string
City sql.NullString
TelCode int
}
type PlacePtr struct {
Country string
City *string
TelCode int
}
type PersonPlace struct {
Person
Place
}
type PersonPlacePtr struct {
*Person
*Place
}
type EmbedConflict struct {
FirstName string `db:"first_name"`
Person
}
type SliceMember struct {
Country string
City sql.NullString
TelCode int
People []Person `db:"-"`
Addresses []Place `db:"-"`
}
func loadDefaultFixture(db *sqldb.Engine, t *testing.T) {
tx := db.MustBegin()
s1 := "INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"
tx.MustExec(db.Rebind(s1), "Jason", "Moiron", "jmoiron@jmoiron.net")
s1 = "INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"
tx.MustExec(db.Rebind(s1), "John", "Doe", "johndoeDNE@gmail.net")
s1 = "INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"
tx.MustExec(db.Rebind(s1), "United States", "New York", "1")
s1 = "INSERT INTO place (country, telcode) VALUES (?, ?)"
tx.MustExec(db.Rebind(s1), "Hong Kong", "852")
s1 = "INSERT INTO place (country, telcode) VALUES (?, ?)"
tx.MustExec(db.Rebind(s1), "Singapore", "65")
s1 = "INSERT INTO capplace (country, telcode) VALUES (?, ?)"
tx.MustExec(db.Rebind(s1), "Sarf Efrica", "27")
s1 = "INSERT INTO employees (name, id) VALUES (?, ?)"
tx.MustExec(db.Rebind(s1), "Peter", "4444")
s1 = "INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"
tx.MustExec(db.Rebind(s1), "Joe", "1", "4444")
s1 = "INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"
tx.MustExec(db.Rebind(s1), "Martin", "2", "4444")
tx.Commit()
}
func MultiExec(e *sqldb.Engine, query string) {
stmts := strings.Split(query, ";\n")
if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 {
stmts = stmts[:len(stmts)-1]
}
for _, s := range stmts {
_, err := e.Exec(s)
if err != nil {
fmt.Println(err, s)
}
}
}
func RunDbTest(t *testing.T, test func(db *sqldb.Engine, t *testing.T)) {
// First, initialize the database
url := os.Getenv("DB_URL")
var db = sqldb.New(url)
// Then register cleanup to drop the schema
defer func() {
MultiExec(db, defaultSchema.drop)
}()
// Then insert some fixture data
MultiExec(db, defaultSchema.create)
loadDefaultFixture(db, t)
// Finally, run the test
test(db, t)
}

gdb/sqldb/column.go Normal file

@ -0,0 +1,78 @@
//
// column.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import "reflect"
// ColumnMap represents a mapping between a Go struct field and a single
// column in a table.
// Unique and MaxSize only inform the
// CreateTables() function and are not used by Insert/Update/Delete/Get.
type ColumnMap struct {
// Column name in db table
ColumnName string
// If true, this column is skipped in generated SQL statements
Transient bool
// If true, " unique" is added to create table statements.
// Not used elsewhere
Unique bool
// Query used for getting generated id after insert
GeneratedIdQuery string
// Passed to Dialect.ToSqlType() to assist in informing the
// correct column type to map to in CreateTables()
MaxSize int
DefaultValue string
fieldName string
gotype reflect.Type
isPK bool
isAutoIncr bool
isNotNull bool
}
// Rename allows you to specify the column name in the table
//
// Example: table.ColMap("Updated").Rename("date_updated")
func (c *ColumnMap) Rename(colname string) *ColumnMap {
c.ColumnName = colname
return c
}
// SetTransient allows you to mark the column as transient. If true
// this column will be skipped when SQL statements are generated
func (c *ColumnMap) SetTransient(b bool) *ColumnMap {
c.Transient = b
return c
}
// SetUnique adds "unique" to the create table statements for this
// column, if b is true.
func (c *ColumnMap) SetUnique(b bool) *ColumnMap {
c.Unique = b
return c
}
// SetNotNull adds "not null" to the create table statements for this
// column, if nn is true.
func (c *ColumnMap) SetNotNull(nn bool) *ColumnMap {
c.isNotNull = nn
return c
}
// SetMaxSize specifies the max length of values of this column. This is
// passed to the dialect.ToSqlType() function, which can use the value
// to alter the generated type for "create table" statements
func (c *ColumnMap) SetMaxSize(size int) *ColumnMap {
c.MaxSize = size
return c
}
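A small illustrative sketch of the chained ColumnMap setters above. Obtaining a *ColumnMap normally goes through the table-mapping API (e.g. table.ColMap("Updated"), as the Rename doc comment suggests), which lives in the large file whose diff is suppressed elsewhere in this changeset, so here a ColumnMap is constructed directly:

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

func main() {
	// Construct a ColumnMap directly for illustration; in real use it would come
	// from the table-mapping code, e.g. table.ColMap("Updated").
	col := &sqldb.ColumnMap{ColumnName: "Updated"}

	// The setters return *ColumnMap, so they can be chained.
	col.Rename("date_updated").
		SetMaxSize(64).
		SetNotNull(true).
		SetUnique(false)

	fmt.Println(col.ColumnName, col.MaxSize, col.Unique) // date_updated 64 false
}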

gdb/sqldb/context_test.go Normal file

@ -0,0 +1,82 @@
//
// context_test.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
//go:build integration
// +build integration
package sqldb_test
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
)
// Drivers that don't support cancellation.
var unsupportedDrivers map[string]bool = map[string]bool{
"mymysql": true,
}
type SleepDialect interface {
// SleepClause returns a SQL fragment that sleeps for duration d
SleepClause(d time.Duration) string
}
func TestWithNotCanceledContext(t *testing.T) {
dbmap := initDBMap(t)
defer dropAndClose(dbmap)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
withCtx := dbmap.WithContext(ctx)
_, err := withCtx.Exec("SELECT 1")
assert.Nil(t, err)
}
func TestWithCanceledContext(t *testing.T) {
dialect, driver := dialectAndDriver()
if unsupportedDrivers[driver] {
t.Skipf("Cancellation is not yet supported by all drivers. Not known to be supported in %s.", driver)
}
sleepDialect, ok := dialect.(SleepDialect)
if !ok {
t.Skipf("Sleep is not supported in all dialects. Not known to be supported in %s.", driver)
}
dbmap := initDBMap(t)
defer dropAndClose(dbmap)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
withCtx := dbmap.WithContext(ctx)
startTime := time.Now()
_, err := withCtx.Exec("SELECT " + sleepDialect.SleepClause(1*time.Second))
if d := time.Since(startTime); d > 500*time.Millisecond {
t.Errorf("too long execution time: %s", d)
}
switch driver {
case "postgres":
// pq doesn't return standard deadline exceeded error
if err.Error() != "pq: canceling statement due to user request" {
t.Errorf("expected context.DeadlineExceeded, got %v", err)
}
default:
if err != context.DeadlineExceeded {
t.Errorf("expected context.DeadlineExceeded, got %v", err)
}
}
}

File diff suppressed because it is too large


@ -1,220 +0,0 @@
//
// db_func.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/jmoiron/sqlx"
)
func (e *Engine) Begin() (*sqlx.Tx, error) {
return e.Beginx()
}
// Insert one record (named parameters)
func (e *Engine) NamedInsertRecord(opt *QueryOption, arg interface{}) (int64, error) { // {{{
if len(opt.fields) == 0 {
return 0, errors.New("empty fields")
}
var tmp = make([]string, 0)
for _, field := range opt.fields {
tmp = append(tmp, fmt.Sprintf(":%s", field))
}
fields_str := strings.Join(opt.fields, ",")
fields_pl := strings.Join(tmp, ",")
sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", opt.table, fields_str, fields_pl)
if e.DriverName() == "postgres" {
sql += " returning id"
}
// sql = e.Rebind(sql)
stmt, err := e.PrepareNamed(sql)
if err != nil {
return 0, err
}
var id int64
err = stmt.Get(&id, arg)
if err != nil {
return 0, err
}
return id, err
} // }}}
// Insert one record
func (e *Engine) InsertRecord(opt *QueryOption) (int64, error) { // {{{
if len(opt.fields) == 0 {
return 0, errors.New("empty fields")
}
fields_str := strings.Join(opt.fields, ",")
fields_pl := strings.TrimRight(strings.Repeat("?,", len(opt.fields)), ",")
sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);", opt.table, fields_str, fields_pl)
if e.DriverName() == "postgres" {
sql += " returning id"
}
sql = e.Rebind(sql)
result, err := e.Exec(sql, opt.args...)
if err != nil {
return 0, err
}
return result.LastInsertId()
} // }}}
// Query a single record
// dest: destination object
// table: the table to query
// query: the query condition
// args: bindvars
func (e *Engine) GetRecord(dest interface{}, opt *QueryOption) error { // {{{
if opt.query == "" {
return errors.New("empty query")
}
opt.query = "WHERE " + opt.query
sql := fmt.Sprintf("SELECT * FROM %s %s limit 1", opt.table, opt.query)
sql = e.Rebind(sql)
err := e.Get(dest, sql, opt.args...)
if err != nil {
return err
}
return nil
} // }}}
// Query multiple records
// dest: destination variable
// opt: query options
// args: bindvars
func (e *Engine) GetRecords(dest interface{}, opt *QueryOption) error { // {{{
var tmp = []string{}
if opt.query != "" {
tmp = append(tmp, "where", opt.query)
}
if opt.sort != "" {
tmp = append(tmp, "order by", opt.sort)
}
if opt.offset > 0 {
tmp = append(tmp, "offset", strconv.Itoa(opt.offset))
}
if opt.limit > 0 {
tmp = append(tmp, "limit", strconv.Itoa(opt.limit))
}
sql := fmt.Sprintf("select * from %s %s", opt.table, strings.Join(tmp, " "))
sql = e.Rebind(sql)
return e.Select(dest, sql, opt.args...)
} // }}}
// Update records
// table: the table to update
// set: the SET clause, e.g. age=:age
// query: the WHERE condition; must not be empty, to avoid accidentally updating every row
// arg: values
func (e *Engine) NamedUpdateRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{
if opt.set == "" || opt.query == "" {
return 0, errors.New("empty set or query")
}
sql := fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query)
result, err := e.NamedExec(sql, arg)
if err != nil {
return 0, err
}
rows, err := result.RowsAffected()
if err != nil {
return 0, err
}
return rows, nil
} // }}}
func (e *Engine) UpdateRecords(opt *QueryOption) (int64, error) { // {{{
if opt.set == "" || opt.query == "" {
return 0, errors.New("empty set or query")
}
sql := fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query)
sql = e.Rebind(sql)
result, err := e.Exec(sql, opt.args...)
if err != nil {
return 0, err
}
rows, err := result.RowsAffected()
if err != nil {
return 0, err
}
return rows, nil
} // }}}
// Delete matching records
// opt.query must not be empty
// arg: bindvars
func (e *Engine) NamedDeleteRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{
if opt.query == "" {
return 0, errors.New("emtpy query")
}
sql := fmt.Sprintf("delete from %s where %s", opt.table, opt.query)
result, err := e.NamedExec(sql, arg)
if err != nil {
return 0, err
}
rows, err := result.RowsAffected()
if err != nil {
return 0, err
}
return rows, nil
} // }}}
func (e *Engine) DeleteRecords(opt *QueryOption) (int64, error) {
if opt.query == "" {
return 0, errors.New("emtpy query")
}
sql := fmt.Sprintf("delete from %s where %s", opt.table, opt.query)
sql = e.Rebind(sql)
result, err := e.Exec(sql, opt.args...)
if err != nil {
return 0, err
}
rows, err := result.RowsAffected()
if err != nil {
return 0, err
}
return rows, nil
}
func (e *Engine) CountRecords(opt *QueryOption) (int, error) {
sql := fmt.Sprintf("select count(*) from %s where %s", opt.table, opt.query)
sql = e.Rebind(sql)
var num int
err := e.Get(&num, sql, opt.args...)
if err != nil {
return 0, err
}
return num, nil
}
// var levels = []int{4, 6, 7}
// query, args, err := sqlx.In("SELECT * FROM users WHERE level IN (?);", levels)
// sqlx.In returns queries with the `?` bindvar, we can rebind it for our backend
// query = db.Rebind(query)
// rows, err := db.Query(query, args...)
func (e *Engine) In(query string, args ...interface{}) (string, []interface{}, error) {
return sqlx.In(query, args...)
}
func IsNoRows(err error) bool {
return err == ErrNoRows
}
// GetSetString converts fields into field1=:field1, field2=:field2, ..., fieldN=:fieldN (the id field is skipped)
func GetSetString(fields []string) string {
items := []string{}
for _, field := range fields {
if field == "id" {
continue
}
items = append(items, fmt.Sprintf("%s=:%s", field, field))
}
return strings.Join(items, ",")
}
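A minimal usage sketch of the record helpers above combined with GetSetString; the person table comes from the test schema elsewhere in this changeset, while the DSN (DB_URL) and the concrete values are placeholders:

package main

import (
	"log"
	"os"

	_ "github.com/lib/pq"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

func main() {
	// sqldb.New is used the same way in RunDbTest; DB_URL is a placeholder.
	db := sqldb.New(os.Getenv("DB_URL"))

	fields := []string{"id", "first_name", "email"}
	// GetSetString skips "id" and yields "first_name=:first_name,email=:email".
	set := sqldb.GetSetString(fields)

	opt := sqldb.NewQueryOption("person").Set(set).Query("id=:id")
	rows, err := db.NamedUpdateRecords(opt, map[string]interface{}{
		"id":         1,
		"first_name": "Jason",
		"email":      "jmoiron@jmoiron.net",
	})
	if err != nil {
		log.Fatal(err)
	}
	log.Println("rows updated:", rows)
}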


@ -1,75 +0,0 @@
//
// db_func_opt.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
type QueryOption struct {
table string
query string
set string
fields []string
sort string
offset int
limit int
args []any
joins []string
}
func NewQueryOption(table string) *QueryOption {
return &QueryOption{
table: table,
fields: []string{"*"},
offset: 0,
limit: 0,
args: make([]any, 0),
joins: make([]string, 0),
}
}
func (opt *QueryOption) Query(query string) *QueryOption {
opt.query = query
return opt
}
func (opt *QueryOption) Fields(args []string) *QueryOption {
opt.fields = args
return opt
}
func (opt *QueryOption) Select(cols ...string) *QueryOption {
opt.fields = cols
return opt
}
func (opt *QueryOption) Offset(offset int) *QueryOption {
opt.offset = offset
return opt
}
func (opt *QueryOption) Limit(limit int) *QueryOption {
opt.limit = limit
return opt
}
func (opt *QueryOption) Sort(sort string) *QueryOption {
opt.sort = sort
return opt
}
func (opt *QueryOption) Set(set string) *QueryOption {
opt.set = set
return opt
}
func (opt *QueryOption) Args(args ...any) *QueryOption {
opt.args = args
return opt
}
func (opt *QueryOption) Join(table string, cond string) *QueryOption {
opt.joins = append(opt.joins, "join "+table+" on "+cond)
return opt
}
func (opt *QueryOption) LeftJoin(table string, cond string) *QueryOption {
opt.joins = append(opt.joins, "left join "+table+" on "+cond)
return opt
}
func (opt *QueryOption) RightJoin(table string, cond string) *QueryOption {
opt.joins = append(opt.joins, "right join "+table+" on "+cond)
return opt
}


@ -1,114 +0,0 @@
//
// db_func_test.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb_test
import (
"testing"
"time"
"git.hexq.cn/tiglog/golib/gdb/sqldb"
"git.hexq.cn/tiglog/golib/gtest"
)
// Testing showed that using a time type in the database easily leads to timezone inconsistencies:
// the timezone may be lost when the value is stored.
// For better compatibility, an integer timestamp is therefore more suitable.
func dbFuncTest(db *sqldb.Engine, t *testing.T) {
var err error
fields := []string{"first_name", "last_name", "email"}
p := &Person{
FirstName: "三",
LastName: "张",
Email: "zs@foo.com",
}
// NamedInsertRecord usage
opt := sqldb.NewQueryOption("person").Fields(fields)
rows, err := db.NamedInsertRecord(opt, p)
gtest.Nil(t, err)
gtest.True(t, rows > 0)
// fmt.Println(rows)
// GetRecord usage
var p3 Person
opt = sqldb.NewQueryOption("person").Query("email=?").Args("zs@foo.com")
err = db.GetRecord(&p3, opt)
// fmt.Println(p3)
gtest.Equal(t, "张", p3.LastName)
gtest.Equal(t, "三", p3.FirstName)
gtest.Equal(t, int64(0), p3.AddedAt)
gtest.Nil(t, err)
p2 := &Person{
FirstName: "四",
LastName: "李",
Email: "ls@foo.com",
AddedAt: time.Now().Unix(),
}
fields2 := append(fields, "added_at")
opt = sqldb.NewQueryOption("person").Fields(fields2)
_, err = db.NamedInsertRecord(opt, p2)
gtest.Nil(t, err)
var p4 Person
opt = sqldb.NewQueryOption("person")
err = db.GetRecord(&p4, opt)
gtest.NotNil(t, err)
gtest.Equal(t, "", p4.FirstName)
opt = sqldb.NewQueryOption("person").Query("first_name=?").Args("四")
err = db.GetRecord(&p4, opt)
gtest.Nil(t, err)
gtest.Equal(t, time.Now().Unix(), p4.AddedAt)
gtest.Equal(t, "ls@foo.com", p4.Email)
// GetRecords
var ps []Person
opt = sqldb.NewQueryOption("person").Query("id > ?").Args(0)
err = db.GetRecords(&ps, opt)
gtest.Nil(t, err)
gtest.True(t, len(ps) > 1)
var ps2 []Person
opt = sqldb.NewQueryOption("person").Query("id=?").Args(1)
err = db.GetRecords(&ps2, opt)
gtest.Equal(t, 1, len(ps2))
if len(ps2) > 1 {
gtest.Equal(t, int64(1), ps2[0].Id)
}
// DeleteRecords
opt = sqldb.NewQueryOption("person").Query("id=?").Args(2)
n, err := db.DeleteRecords(opt)
gtest.Nil(t, err)
gtest.Greater(t, int64(0), n)
// UpdateRecords
opt = sqldb.NewQueryOption("person").Set("first_name=?").Query("email=?").Args("哈哈", "zs@foo.com")
n, err = db.UpdateRecords(opt)
gtest.Nil(t, err)
gtest.Greater(t, int64(0), n)
// NamedUpdateRecords
var p5 = ps[0]
p5.FirstName = "中华人民共和国"
opt = sqldb.NewQueryOption("person").Set("first_name=:first_name").Query("email=:email")
n, err = db.NamedUpdateRecords(opt, p5)
gtest.Nil(t, err)
gtest.Greater(t, int64(0), n)
var p6 Person
opt = sqldb.NewQueryOption("person").Query("first_name=?").Args(p5.FirstName)
err = db.GetRecord(&p6, opt)
gtest.Nil(t, err)
gtest.Greater(t, int64(0), p6.Id)
gtest.Equal(t, p6.FirstName, p5.FirstName)
}
func TestFunc(t *testing.T) {
RunDbTest(t, dbFuncTest)
}


@ -1,20 +0,0 @@
//
// db_model.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
// TODO: hard to implement cleanly for now; revisit later.
type Model struct {
db *Engine
}
func NewModel() *Model {
return &Model{
db: Db,
}
}


@ -1,322 +0,0 @@
//
// db_query.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/jmoiron/sqlx"
)
type Query struct {
db *Engine
table string
fields []string
wheres []string // must not be too complex
joins []string
orderBy string
groupBy string
offset int
limit int
}
func NewQueryBuild(table string, db *Engine) *Query {
return &Query{
db: db,
table: table,
fields: []string{},
wheres: []string{},
joins: []string{},
offset: 0,
limit: 0,
}
}
func (q *Query) Table(table string) *Query {
q.table = table
return q
}
// Set the select fields
func (q *Query) Select(fields ...string) *Query {
q.fields = fields
return q
}
// Add one or more select fields
func (q *Query) AddFields(fields ...string) *Query {
q.fields = append(q.fields, fields...)
return q
}
func (q *Query) Where(query string) *Query {
q.wheres = []string{query}
return q
}
func (q *Query) AndWhere(query string) *Query {
q.wheres = append(q.wheres, "and "+query)
return q
}
func (q *Query) OrWhere(query string) *Query {
q.wheres = append(q.wheres, "or "+query)
return q
}
func (q *Query) Join(table string, on string) *Query {
var join = "join " + table
if on != "" {
join = join + " on " + on
}
q.joins = append(q.joins, join)
return q
}
func (q *Query) LeftJoin(table string, on string) *Query {
var join = "left join " + table
if on != "" {
join = join + " on " + on
}
q.joins = append(q.joins, join)
return q
}
func (q *Query) RightJoin(table string, on string) *Query {
var join = "right join " + table
if on != "" {
join = join + " on " + on
}
q.joins = append(q.joins, join)
return q
}
func (q *Query) InnerJoin(table string, on string) *Query {
var join = "inner join " + table
if on != "" {
join = join + " on " + on
}
q.joins = append(q.joins, join)
return q
}
func (q *Query) OrderBy(order string) *Query {
q.orderBy = order
return q
}
func (q *Query) GroupBy(group string) *Query {
q.groupBy = group
return q
}
func (q *Query) Offset(offset int) *Query {
q.offset = offset
return q
}
func (q *Query) Limit(limit int) *Query {
q.limit = limit
return q
}
// returningId: handle LastInsertId for postgres via RETURNING id
// TODO: returningId is not handled yet
func (q *Query) getInsertSql(named, returningId bool) string {
fields_str := strings.Join(q.fields, ",")
var pl string
if named {
var tmp []string
for _, field := range q.fields {
tmp = append(tmp, ":"+field)
}
pl = strings.Join(tmp, ",")
} else {
pl = strings.Repeat("?,", len(q.fields))
pl = strings.TrimRight(pl, ",")
}
sql := fmt.Sprintf("insert into %s (%s) values (%s);", q.table, fields_str, pl)
sql = q.db.Rebind(sql)
// fmt.Println(sql)
return sql
}
// return RowsAffected, error
func (q *Query) Insert(args ...interface{}) (int64, error) {
if len(q.fields) == 0 {
return 0, errors.New("empty fields")
}
sql := q.getInsertSql(false, false)
result, err := q.db.Exec(sql, args...)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
// return RowsAffected, error
func (q *Query) NamedInsert(arg interface{}) (int64, error) {
if len(q.fields) == 0 {
return 0, errors.New("empty fields")
}
sql := q.getInsertSql(true, false)
result, err := q.db.NamedExec(sql, arg)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
func (q *Query) getQuerySql() string {
var (
fields_str string = "*"
join_str string
where_str string
offlim string
)
if len(q.fields) > 0 {
fields_str = strings.Join(q.fields, ",")
}
if len(q.joins) > 0 {
join_str = strings.Join(q.joins, " ")
}
if len(q.wheres) > 0 {
where_str = "where " + strings.Join(q.wheres, " ")
}
if q.offset > 0 {
offlim = " offset " + strconv.Itoa(q.offset)
}
if q.limit > 0 {
offlim = " limit " + strconv.Itoa(q.limit)
}
// select fields from table t join where groupby orderby offset limit
sql := fmt.Sprintf("select %s from %s t %s %s %s %s%s", fields_str, q.table, join_str, where_str, q.groupBy, q.orderBy, offlim)
return sql
}
func (q *Query) One(dest interface{}, args ...interface{}) error {
q.Limit(1)
sql := q.getQuerySql()
sql = q.db.Rebind(sql)
return q.db.Get(dest, sql, args...)
}
func (q *Query) NamedOne(dest interface{}, arg interface{}) error {
q.Limit(1)
sql := q.getQuerySql()
rows, err := q.db.NamedQuery(sql, arg)
if err != nil {
return err
}
if rows.Next() {
return rows.Scan(dest)
}
return errors.New("nr") // no record
}
func (q *Query) All(dest interface{}, args ...interface{}) error {
sql := q.getQuerySql()
sql = q.db.Rebind(sql)
return q.db.Select(dest, sql, args...)
}
// To save memory, return the iterator directly
func (q *Query) NamedAll(dest interface{}, arg interface{}) (*sqlx.Rows, error) {
sql := q.getQuerySql()
return q.db.NamedQuery(sql, arg)
}
// set age=? / age=:age
func (q *Query) NamedUpdate(set string, arg interface{}) (int64, error) {
var where_str string
if len(q.wheres) > 0 {
where_str = strings.Join(q.wheres, " ")
}
if set == "" || where_str == "" {
return 0, errors.New("empty set or where")
}
// update table t where
sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str)
sql = q.db.Rebind(sql)
result, err := q.db.NamedExec(sql, arg)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
// Easy to get the order backwards: the set arguments come first, then the where arguments
func (q *Query) Update(set string, args ...interface{}) (int64, error) {
var where_str string
if len(q.wheres) > 0 {
where_str = strings.Join(q.wheres, " ")
}
if set == "" || where_str == "" {
return 0, errors.New("empty set or where")
}
// update table t where
sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str)
sql = q.db.Rebind(sql)
result, err := q.db.Exec(sql, args...)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
// plain delete with positional arguments
func (q *Query) Delete(args ...interface{}) (int64, error) {
var where_str string
if len(q.wheres) == 0 {
return 0, errors.New("missing where clause")
}
where_str = strings.Join(q.wheres, " ")
sql := fmt.Sprintf("delete from %s where %s", q.table, where_str)
sql = q.db.Rebind(sql)
result, err := q.db.Exec(sql, args...)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
func (q *Query) NamedDelete(arg interface{}) (int64, error) {
if len(q.wheres) == 0 {
return 0, errors.New("missing where clause")
}
var where_str string
where_str = strings.Join(q.wheres, " ")
sql := fmt.Sprintf("delete from %s where %s", q.table, where_str)
sql = q.db.Rebind(sql)
result, err := q.db.NamedExec(sql, arg)
if err != nil {
return 0, err
}
return result.RowsAffected()
}
func (q *Query) Count(args ...interface{}) (int64, error) {
var where_str string
if len(q.wheres) > 0 {
where_str = " where " + strings.Join(q.wheres, " ")
}
sql := fmt.Sprintf("select count(1) as num from %s t%s", q.table, where_str)
sql = q.db.Rebind(sql)
var num int64
err := q.db.Get(&num, sql, args...)
return num, err
}
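
A short usage sketch of the sqlx-based query builder above (removed in this change), following the NewQueryBuild constructor and *sqldb.Engine type used by the deleted test below; it mainly illustrates the argument-order rule for Update (set arguments first, then where arguments) and the chainable Where. Opening the Engine is assumed to happen elsewhere.

package example

import (
	"log"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

// demoQuery assumes db was opened elsewhere; Engine construction is not shown here.
func demoQuery(db *sqldb.Engine) {
	// UPDATE person t SET first_name=? WHERE email=?
	// arguments: first the SET value, then the WHERE value
	q := sqldb.NewQueryBuild("person", db)
	q.Where("email=?")
	if _, err := q.Update("first_name=?", "foo", "ls@bar.com"); err != nil {
		log.Println("update failed:", err)
	}

	// DELETE FROM person WHERE id=?  (Where is chainable)
	q = sqldb.NewQueryBuild("person", db)
	if _, err := q.Where("id=?").Delete(2); err != nil {
		log.Println("delete failed:", err)
	}
}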

View File

@ -1,109 +0,0 @@
//
// db_query_test.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb_test
import (
"testing"
"time"
"git.hexq.cn/tiglog/golib/gtest"
"git.hexq.cn/tiglog/golib/gdb/sqldb"
)
func dbQueryTest(db *sqldb.Engine, t *testing.T) {
query := sqldb.NewQueryBuild("person", db)
// query one
var p1 Person
query.Where("id=?")
err := query.One(&p1, 1)
gtest.Nil(t, err)
gtest.Equal(t, int64(1), p1.Id)
// query all
var ps1 []Person
query = sqldb.NewQueryBuild("person", db)
query.Where("id > ?")
err = query.All(&ps1, 1)
gtest.Nil(t, err)
gtest.True(t, len(ps1) > 0)
// fmt.Println(ps1)
if len(ps1) > 0 {
var val int64 = 2
gtest.Equal(t, val, ps1[0].Id)
}
// insert
query = sqldb.NewQueryBuild("person", db)
query.AddFields("first_name", "last_name", "email")
id, err := query.Insert("三", "张", "zs@bar.com")
gtest.Nil(t, err)
gtest.Greater(t, int64(0), id)
// fmt.Println(id)
// named insert
query = sqldb.NewQueryBuild("person", db)
query.AddFields("first_name", "last_name", "email")
row, err := query.NamedInsert(&Person{
FirstName: "四",
LastName: "李",
Email: "ls@bar.com",
AddedAt: time.Now().Unix(),
})
gtest.Nil(t, err)
gtest.Equal(t, int64(1), row)
// update
query = sqldb.NewQueryBuild("person", db)
query.Where("email=?")
n, err := query.Update("first_name=?", "哈哈", "ls@bar.com")
gtest.Nil(t, err)
gtest.Equal(t, int64(1), n)
// named update map
query = sqldb.NewQueryBuild("person", db)
query.Where("email=:email")
n, err = query.NamedUpdate("first_name=:first_name", map[string]interface{}{
"email": "ls@bar.com",
"first_name": "中华人民共和国",
})
gtest.Nil(t, err)
gtest.Equal(t, int64(1), n)
// named update struct
query = sqldb.NewQueryBuild("person", db)
query.Where("email=:email")
var p = &Person{
Email: "ls@bar.com",
LastName: "中华人民共和国,救民于水火",
}
n, err = query.NamedUpdate("last_name=:last_name", p)
gtest.Nil(t, err)
gtest.Equal(t, int64(1), n)
// count
query = sqldb.NewQueryBuild("person", db)
n, err = query.Count()
gtest.Nil(t, err)
// fmt.Println(n)
gtest.Greater(t, int64(0), n)
// delete
query = sqldb.NewQueryBuild("person", db)
n, err = query.Delete()
gtest.NotNil(t, err)
gtest.Equal(t, int64(0), n)
n, err = query.Where("id=?").Delete(2)
gtest.Nil(t, err)
gtest.Equal(t, int64(1), n)
}
func TestQuery(t *testing.T) {
RunDbTest(t, dbQueryTest)
}

187
gdb/sqldb/db_test.go Normal file
View File

@ -0,0 +1,187 @@
//
// db_test.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
//go:build integration
// +build integration
package sqldb_test
import (
"testing"
)
type customType1 []string
func (c customType1) ToStringSlice() []string {
return []string(c)
}
type customType2 []int64
func (c customType2) ToInt64Slice() []int64 {
return []int64(c)
}
func TestDbMap_Select_expandSliceArgs(t *testing.T) {
tests := []struct {
description string
query string
args []interface{}
wantLen int
}{
{
description: "it should handle slice placeholders correctly",
query: `
SELECT 1 FROM crazy_table
WHERE field1 = :Field1
AND field2 IN (:FieldStringList)
AND field3 IN (:FieldUIntList)
AND field4 IN (:FieldUInt8List)
AND field5 IN (:FieldUInt16List)
AND field6 IN (:FieldUInt32List)
AND field7 IN (:FieldUInt64List)
AND field8 IN (:FieldIntList)
AND field9 IN (:FieldInt8List)
AND field10 IN (:FieldInt16List)
AND field11 IN (:FieldInt32List)
AND field12 IN (:FieldInt64List)
AND field13 IN (:FieldFloat32List)
AND field14 IN (:FieldFloat64List)
`,
args: []interface{}{
map[string]interface{}{
"Field1": 123,
"FieldStringList": []string{"h", "e", "y"},
"FieldUIntList": []uint{1, 2, 3, 4},
"FieldUInt8List": []uint8{1, 2, 3, 4},
"FieldUInt16List": []uint16{1, 2, 3, 4},
"FieldUInt32List": []uint32{1, 2, 3, 4},
"FieldUInt64List": []uint64{1, 2, 3, 4},
"FieldIntList": []int{1, 2, 3, 4},
"FieldInt8List": []int8{1, 2, 3, 4},
"FieldInt16List": []int16{1, 2, 3, 4},
"FieldInt32List": []int32{1, 2, 3, 4},
"FieldInt64List": []int64{1, 2, 3, 4},
"FieldFloat32List": []float32{1, 2, 3, 4},
"FieldFloat64List": []float64{1, 2, 3, 4},
},
},
wantLen: 1,
},
{
description: "it should handle slice placeholders correctly with custom types",
query: `
SELECT 1 FROM crazy_table
WHERE field2 IN (:FieldStringList)
AND field12 IN (:FieldIntList)
`,
args: []interface{}{
map[string]interface{}{
"FieldStringList": customType1{"h", "e", "y"},
"FieldIntList": customType2{1, 2, 3, 4},
},
},
wantLen: 3,
},
}
type dataFormat struct {
Field1 int `db:"field1"`
Field2 string `db:"field2"`
Field3 uint `db:"field3"`
Field4 uint8 `db:"field4"`
Field5 uint16 `db:"field5"`
Field6 uint32 `db:"field6"`
Field7 uint64 `db:"field7"`
Field8 int `db:"field8"`
Field9 int8 `db:"field9"`
Field10 int16 `db:"field10"`
Field11 int32 `db:"field11"`
Field12 int64 `db:"field12"`
Field13 float32 `db:"field13"`
Field14 float64 `db:"field14"`
}
dbmap := newDBMap(t)
dbmap.ExpandSliceArgs = true
dbmap.AddTableWithName(dataFormat{}, "crazy_table")
err := dbmap.CreateTables()
if err != nil {
panic(err)
}
defer dropAndClose(dbmap)
err = dbmap.Insert(
&dataFormat{
Field1: 123,
Field2: "h",
Field3: 1,
Field4: 1,
Field5: 1,
Field6: 1,
Field7: 1,
Field8: 1,
Field9: 1,
Field10: 1,
Field11: 1,
Field12: 1,
Field13: 1,
Field14: 1,
},
&dataFormat{
Field1: 124,
Field2: "e",
Field3: 2,
Field4: 2,
Field5: 2,
Field6: 2,
Field7: 2,
Field8: 2,
Field9: 2,
Field10: 2,
Field11: 2,
Field12: 2,
Field13: 2,
Field14: 2,
},
&dataFormat{
Field1: 125,
Field2: "y",
Field3: 3,
Field4: 3,
Field5: 3,
Field6: 3,
Field7: 3,
Field8: 3,
Field9: 3,
Field10: 3,
Field11: 3,
Field12: 3,
Field13: 3,
Field14: 3,
},
)
if err != nil {
t.Fatal(err)
}
for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
var dummy []int
_, err := dbmap.Select(&dummy, tt.query, tt.args...)
if err != nil {
t.Fatal(err)
}
if len(dummy) != tt.wantLen {
t.Errorf("wrong result count\ngot: %d\nwant: %d", len(dummy), tt.wantLen)
}
})
}
}

108
gdb/sqldb/dialect.go Normal file
View File

@ -0,0 +1,108 @@
//
// dialect.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"reflect"
)
// The Dialect interface encapsulates behaviors that differ across
// SQL databases. At present the Dialect is only used by CreateTables()
// but this could change in the future
type Dialect interface {
// adds a suffix to any query, usually ";"
QuerySuffix() string
// ToSqlType returns the SQL column type to use when creating a
// table of the given Go Type. maxsize can be used to switch based on
// size. For example, in MySQL []byte could map to BLOB, MEDIUMBLOB,
// or LONGBLOB depending on the maxsize
ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string
// string to append to primary key column definitions
AutoIncrStr() string
// string to bind autoincrement columns to. Empty string will
// remove reference to those columns in the INSERT statement.
AutoIncrBindValue() string
AutoIncrInsertSuffix(col *ColumnMap) string
// string to append to "create table" statement for vendor specific
// table attributes
CreateTableSuffix() string
// string to append to "create index" statement
CreateIndexSuffix() string
// string to append to "drop index" statement
DropIndexSuffix() string
// string to truncate tables
TruncateClause() string
// bind variable string to use when forming SQL statements
// in many dbs it is "?", but Postgres uses numbered placeholders like $1
//
// i is a zero based index of the bind variable in this statement
//
BindVar(i int) string
// Handles quoting of a field name to ensure that it doesn't raise any
// SQL parsing exceptions by using a reserved word as a field name.
QuoteField(field string) string
// Handles building up of a schema.database string that is compatible with
// the given dialect
//
// schema - The schema that <table> lives in
// table - The table name
QuotedTableForQuery(schema string, table string) string
// Existence clause for table creation / deletion
IfSchemaNotExists(command, schema string) string
IfTableExists(command, schema, table string) string
IfTableNotExists(command, schema, table string) string
}
// IntegerAutoIncrInserter is implemented by dialects that can perform
// inserts with automatically incremented integer primary keys. If
// the dialect can handle automatic assignment of more than just
// integers, see TargetedAutoIncrInserter.
type IntegerAutoIncrInserter interface {
InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error)
}
// TargetedAutoIncrInserter is implemented by dialects that can
// perform automatic assignment of any primary key type (i.e. strings
// for uuids, integers for serials, etc).
type TargetedAutoIncrInserter interface {
// InsertAutoIncrToTarget runs an insert operation and assigns the
// automatically generated primary key directly to the passed in
// target. The target should be a pointer to the primary key
// field of the value being inserted.
InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error
}
// TargetQueryInserter is implemented by dialects that can perform
// assignment of integer primary key type by executing a query
// like "select sequence.currval from dual".
type TargetQueryInserter interface {
// TargetQueryInserter runs an insert operation and assigns the
// automatically generated primary key retrieved by the query
// extracted from the GeneratedIdQuery field of the id column.
InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error
}
func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
res, err := exec.Exec(insertSql, params...)
if err != nil {
return 0, err
}
return res.LastInsertId()
}
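
As a quick illustration of the BindVar and QuoteField differences described in the interface comments, this toy sketch prints the placeholders and quoting each bundled dialect produces; it only calls methods defined in this package's dialect files.

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

func main() {
	dialects := map[string]sqldb.Dialect{
		"mysql":    sqldb.MySQLDialect{Engine: "InnoDB", Encoding: "utf8mb4"},
		"postgres": sqldb.PostgresDialect{},
		"sqlite":   sqldb.SqliteDialect{},
		"oracle":   sqldb.OracleDialect{},
	}
	for name, d := range dialects {
		// first three bind variables per dialect: ?,?,? vs $1,$2,$3 vs :1,:2,:3
		fmt.Printf("%-8s %s %s %s  quoted: %s\n",
			name, d.BindVar(0), d.BindVar(1), d.BindVar(2), d.QuoteField("name"))
	}
}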

172
gdb/sqldb/dialect_mysql.go Normal file
View File

@ -0,0 +1,172 @@
//
// dialect_mysql.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"fmt"
"reflect"
"strings"
"time"
)
// Implementation of Dialect for MySQL databases.
type MySQLDialect struct {
// Engine is the storage engine to use "InnoDB" vs "MyISAM" for example
Engine string
// Encoding is the character encoding to use for created tables
Encoding string
}
func (d MySQLDialect) QuerySuffix() string { return ";" }
func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "boolean"
case reflect.Int8:
return "tinyint"
case reflect.Uint8:
return "tinyint unsigned"
case reflect.Int16:
return "smallint"
case reflect.Uint16:
return "smallint unsigned"
case reflect.Int, reflect.Int32:
return "int"
case reflect.Uint, reflect.Uint32:
return "int unsigned"
case reflect.Int64:
return "bigint"
case reflect.Uint64:
return "bigint unsigned"
case reflect.Float64, reflect.Float32:
return "double"
case reflect.Slice:
if val.Elem().Kind() == reflect.Uint8 {
return "mediumblob"
}
}
switch val.Name() {
case "NullInt64":
return "bigint"
case "NullFloat64":
return "double"
case "NullBool":
return "tinyint"
case "Time":
return "datetime"
}
if maxsize < 1 {
maxsize = 255
}
/* == About varchar(N) ==
* N is number of characters.
* A varchar column can store up to 65535 bytes.
* Remember that 1 character is 3 bytes in utf-8 charset.
* Also remember that each row can store up to 65535 bytes,
* and you have some overheads, so it's not possible for a
* varchar column to have 65535/3 characters really.
 * So it is better to use the 'text' type instead of a
 * large varchar type.
*/
if maxsize < 256 {
return fmt.Sprintf("varchar(%d)", maxsize)
} else {
return "text"
}
}
// Returns auto_increment
func (d MySQLDialect) AutoIncrStr() string {
return "auto_increment"
}
func (d MySQLDialect) AutoIncrBindValue() string {
return "null"
}
func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}
// Returns engine=%s charset=%s based on values stored on struct
func (d MySQLDialect) CreateTableSuffix() string {
if d.Engine == "" || d.Encoding == "" {
msg := "sqldb - undefined"
if d.Engine == "" {
msg += " MySQLDialect.Engine"
}
if d.Engine == "" && d.Encoding == "" {
msg += ","
}
if d.Encoding == "" {
msg += " MySQLDialect.Encoding"
}
msg += ". Check that your MySQLDialect was correctly initialized when declared."
panic(msg)
}
return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding)
}
func (d MySQLDialect) CreateIndexSuffix() string {
return "using"
}
func (d MySQLDialect) DropIndexSuffix() string {
return "on"
}
func (d MySQLDialect) TruncateClause() string {
return "truncate"
}
func (d MySQLDialect) SleepClause(s time.Duration) string {
return fmt.Sprintf("sleep(%f)", s.Seconds())
}
// Returns "?"
func (d MySQLDialect) BindVar(i int) string {
return "?"
}
func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
return standardInsertAutoIncr(exec, insertSql, params...)
}
func (d MySQLDialect) QuoteField(f string) string {
return "`" + f + "`"
}
func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string {
if strings.TrimSpace(schema) == "" {
return d.QuoteField(table)
}
return schema + "." + d.QuoteField(table)
}
func (d MySQLDialect) IfSchemaNotExists(command, schema string) string {
return fmt.Sprintf("%s if not exists", command)
}
func (d MySQLDialect) IfTableExists(command, schema, table string) string {
return fmt.Sprintf("%s if exists", command)
}
func (d MySQLDialect) IfTableNotExists(command, schema, table string) string {
return fmt.Sprintf("%s if not exists", command)
}
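
Because CreateTableSuffix panics when Engine or Encoding is empty, a DbMap using this dialect should set both fields up front. A minimal sketch, mirroring the gorp-style DbMap construction (opening the *sql.DB is assumed):

package example

import (
	"database/sql"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

func newMySQLDbMap(db *sql.DB) *sqldb.DbMap {
	// Both Engine and Encoding must be set, otherwise CreateTableSuffix panics.
	dialect := sqldb.MySQLDialect{Engine: "InnoDB", Encoding: "utf8mb4"}
	return &sqldb.DbMap{Db: db, Dialect: dialect}
}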

View File

@ -0,0 +1,195 @@
//
// dialect_mysql_test.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
//go:build !integration
// +build !integration
package sqldb_test
import (
"database/sql"
"errors"
"fmt"
"reflect"
"testing"
"time"
"git.hexq.cn/tiglog/golib/gdb/sqldb"
"github.com/poy/onpar"
"github.com/poy/onpar/expect"
"github.com/poy/onpar/matchers"
)
func TestMySQLDialect(t *testing.T) {
// o := onpar.New(t)
// defer o.Run()
type testContext struct {
t *testing.T
dialect sqldb.MySQLDialect
}
o := onpar.BeforeEach(onpar.New(t), func(t *testing.T) testContext {
return testContext{
t: t,
dialect: sqldb.MySQLDialect{Engine: "foo", Encoding: "bar"},
}
})
defer o.Run()
o.Group("ToSqlType", func() {
tests := []struct {
name string
value interface{}
maxSize int
autoIncr bool
expected string
}{
{"bool", true, 0, false, "boolean"},
{"int8", int8(1), 0, false, "tinyint"},
{"uint8", uint8(1), 0, false, "tinyint unsigned"},
{"int16", int16(1), 0, false, "smallint"},
{"uint16", uint16(1), 0, false, "smallint unsigned"},
{"int32", int32(1), 0, false, "int"},
{"int (treated as int32)", int(1), 0, false, "int"},
{"uint32", uint32(1), 0, false, "int unsigned"},
{"uint (treated as uint32)", uint(1), 0, false, "int unsigned"},
{"int64", int64(1), 0, false, "bigint"},
{"uint64", uint64(1), 0, false, "bigint unsigned"},
{"float32", float32(1), 0, false, "double"},
{"float64", float64(1), 0, false, "double"},
{"[]uint8", []uint8{1}, 0, false, "mediumblob"},
{"NullInt64", sql.NullInt64{}, 0, false, "bigint"},
{"NullFloat64", sql.NullFloat64{}, 0, false, "double"},
{"NullBool", sql.NullBool{}, 0, false, "tinyint"},
{"Time", time.Time{}, 0, false, "datetime"},
{"default-size string", "", 0, false, "varchar(255)"},
{"sized string", "", 50, false, "varchar(50)"},
{"large string", "", 1024, false, "text"},
}
for _, t := range tests {
o.Spec(t.name, func(tt testContext) {
typ := reflect.TypeOf(t.value)
sqlType := tt.dialect.ToSqlType(typ, t.maxSize, t.autoIncr)
expect.Expect(tt.t, sqlType).To(matchers.Equal(t.expected))
})
}
})
o.Spec("AutoIncrStr", func(tt testContext) {
expect.Expect(t, tt.dialect.AutoIncrStr()).To(matchers.Equal("auto_increment"))
})
o.Spec("AutoIncrBindValue", func(tt testContext) {
expect.Expect(t, tt.dialect.AutoIncrBindValue()).To(matchers.Equal("null"))
})
o.Spec("AutoIncrInsertSuffix", func(tt testContext) {
expect.Expect(t, tt.dialect.AutoIncrInsertSuffix(nil)).To(matchers.Equal(""))
})
o.Group("CreateTableSuffix", func() {
o.Group("with an empty engine", func() {
o1 := onpar.BeforeEach(o, func(tt testContext) testContext {
tt.dialect.Engine = ""
return tt
})
o1.Spec("panics", func(tt testContext) {
expect.Expect(t, func() { tt.dialect.CreateTableSuffix() }).To(Panic())
})
})
o.Group("with an empty encoding", func() {
o2 := onpar.BeforeEach(o, func(tt testContext) testContext {
tt.dialect.Encoding = ""
return tt
})
o2.Spec("panics", func(tt testContext) {
expect.Expect(t, func() { tt.dialect.CreateTableSuffix() }).To(Panic())
})
})
o.Spec("with an engine and an encoding", func(tt testContext) {
expect.Expect(t, tt.dialect.CreateTableSuffix()).To(matchers.Equal(" engine=foo charset=bar"))
})
})
o.Spec("CreateIndexSuffix", func(tt testContext) {
expect.Expect(t, tt.dialect.CreateIndexSuffix()).To(matchers.Equal("using"))
})
o.Spec("DropIndexSuffix", func(tt testContext) {
expect.Expect(t, tt.dialect.DropIndexSuffix()).To(matchers.Equal("on"))
})
o.Spec("TruncateClause", func(tt testContext) {
expect.Expect(t, tt.dialect.TruncateClause()).To(matchers.Equal("truncate"))
})
o.Spec("SleepClause", func(tt testContext) {
expect.Expect(t, tt.dialect.SleepClause(1*time.Second)).To(matchers.Equal("sleep(1.000000)"))
expect.Expect(t, tt.dialect.SleepClause(100*time.Millisecond)).To(matchers.Equal("sleep(0.100000)"))
})
o.Spec("BindVar", func(tt testContext) {
expect.Expect(t, tt.dialect.BindVar(0)).To(matchers.Equal("?"))
})
o.Spec("QuoteField", func(tt testContext) {
expect.Expect(t, tt.dialect.QuoteField("foo")).To(matchers.Equal("`foo`"))
})
o.Group("QuotedTableForQuery", func() {
o.Spec("using the default schema", func(tt testContext) {
expect.Expect(t, tt.dialect.QuotedTableForQuery("", "foo")).To(matchers.Equal("`foo`"))
})
o.Spec("with a supplied schema", func(tt testContext) {
expect.Expect(t, tt.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("foo.`bar`"))
})
})
o.Spec("IfSchemaNotExists", func(tt testContext) {
expect.Expect(t, tt.dialect.IfSchemaNotExists("foo", "bar")).To(matchers.Equal("foo if not exists"))
})
o.Spec("IfTableExists", func(tt testContext) {
expect.Expect(t, tt.dialect.IfTableExists("foo", "bar", "baz")).To(matchers.Equal("foo if exists"))
})
o.Spec("IfTableNotExists", func(tt testContext) {
expect.Expect(t, tt.dialect.IfTableNotExists("foo", "bar", "baz")).To(matchers.Equal("foo if not exists"))
})
}
type panicMatcher struct {
}
func Panic() panicMatcher {
return panicMatcher{}
}
func (m panicMatcher) Match(actual interface{}) (resultValue interface{}, err error) {
switch f := actual.(type) {
case func():
panicked := false
func() {
defer func() {
if r := recover(); r != nil {
panicked = true
}
}()
f()
}()
if panicked {
return f, nil
}
return f, errors.New("function did not panic")
default:
return f, fmt.Errorf("%T is not func()", f)
}
}

142
gdb/sqldb/dialect_oracle.go Normal file
View File

@ -0,0 +1,142 @@
//
// dialect_oracle.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"fmt"
"reflect"
"strings"
)
// Implementation of Dialect for Oracle databases.
type OracleDialect struct{}
func (d OracleDialect) QuerySuffix() string { return "" }
func (d OracleDialect) CreateIndexSuffix() string { return "" }
func (d OracleDialect) DropIndexSuffix() string { return "" }
func (d OracleDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
if isAutoIncr {
return "serial"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if isAutoIncr {
return "bigserial"
}
return "bigint"
case reflect.Float64:
return "double precision"
case reflect.Float32:
return "real"
case reflect.Slice:
if val.Elem().Kind() == reflect.Uint8 {
return "bytea"
}
}
switch val.Name() {
case "NullInt64":
return "bigint"
case "NullFloat64":
return "double precision"
case "NullBool":
return "boolean"
case "NullTime", "Time":
return "timestamp with time zone"
}
if maxsize > 0 {
return fmt.Sprintf("varchar(%d)", maxsize)
} else {
return "text"
}
}
// Returns empty string
func (d OracleDialect) AutoIncrStr() string {
return ""
}
func (d OracleDialect) AutoIncrBindValue() string {
return "NULL"
}
func (d OracleDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}
// Returns suffix
func (d OracleDialect) CreateTableSuffix() string {
return ""
}
func (d OracleDialect) TruncateClause() string {
return "truncate"
}
// Returns ":(i+1)"
func (d OracleDialect) BindVar(i int) string {
return fmt.Sprintf(":%d", i+1)
}
// After executing the insert uses the ColMap IdQuery to get the generated id
func (d OracleDialect) InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error {
_, err := exec.Exec(insertSql, params...)
if err != nil {
return err
}
id, err := exec.SelectInt(idSql)
if err != nil {
return err
}
switch target.(type) {
case *int64:
*(target.(*int64)) = id
case *int32:
*(target.(*int32)) = int32(id)
case *int:
*(target.(*int)) = int(id)
default:
return fmt.Errorf("sqldb: Id target must be a *int, *int32 or *int64")
}
return nil
}
func (d OracleDialect) QuoteField(f string) string {
return `"` + strings.ToUpper(f) + `"`
}
func (d OracleDialect) QuotedTableForQuery(schema string, table string) string {
if strings.TrimSpace(schema) == "" {
return d.QuoteField(table)
}
return schema + "." + d.QuoteField(table)
}
func (d OracleDialect) IfSchemaNotExists(command, schema string) string {
return fmt.Sprintf("%s if not exists", command)
}
func (d OracleDialect) IfTableExists(command, schema, table string) string {
return fmt.Sprintf("%s if exists", command)
}
func (d OracleDialect) IfTableNotExists(command, schema, table string) string {
return fmt.Sprintf("%s if not exists", command)
}
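
InsertQueryToTarget expects target to be a pointer to the id field (the switch above accepts only *int, *int32 and *int64). A caller-side sketch; the SQL strings and the person_seq sequence name are illustrative placeholders, not from this repo:

package example

import "git.hexq.cn/tiglog/golib/gdb/sqldb"

func insertWithOracleId(exec sqldb.SqlExecutor) (int64, error) {
	d := sqldb.OracleDialect{}
	var id int64
	// idSql is typically something like "select <sequence>.currval from dual"
	err := d.InsertQueryToTarget(exec,
		"insert into person (name) values (:1)",
		"select person_seq.currval from dual",
		&id, // must be a pointer, e.g. *int64
		"tiglog")
	return id, err
}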

View File

@ -0,0 +1,152 @@
//
// dialect_postgres.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"fmt"
"reflect"
"strings"
"time"
)
type PostgresDialect struct {
suffix string
LowercaseFields bool
}
func (d PostgresDialect) QuerySuffix() string { return ";" }
func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "boolean"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
if isAutoIncr {
return "serial"
}
return "integer"
case reflect.Int64, reflect.Uint64:
if isAutoIncr {
return "bigserial"
}
return "bigint"
case reflect.Float64:
return "double precision"
case reflect.Float32:
return "real"
case reflect.Slice:
if val.Elem().Kind() == reflect.Uint8 {
return "bytea"
}
}
switch val.Name() {
case "NullInt64":
return "bigint"
case "NullFloat64":
return "double precision"
case "NullBool":
return "boolean"
case "Time", "NullTime":
return "timestamp with time zone"
}
if maxsize > 0 {
return fmt.Sprintf("varchar(%d)", maxsize)
} else {
return "text"
}
}
// Returns empty string
func (d PostgresDialect) AutoIncrStr() string {
return ""
}
func (d PostgresDialect) AutoIncrBindValue() string {
return "default"
}
func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return " returning " + d.QuoteField(col.ColumnName)
}
// Returns suffix
func (d PostgresDialect) CreateTableSuffix() string {
return d.suffix
}
func (d PostgresDialect) CreateIndexSuffix() string {
return "using"
}
func (d PostgresDialect) DropIndexSuffix() string {
return ""
}
func (d PostgresDialect) TruncateClause() string {
return "truncate"
}
func (d PostgresDialect) SleepClause(s time.Duration) string {
return fmt.Sprintf("pg_sleep(%f)", s.Seconds())
}
// Returns "$(i+1)"
func (d PostgresDialect) BindVar(i int) string {
return fmt.Sprintf("$%d", i+1)
}
func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error {
rows, err := exec.Query(insertSql, params...)
if err != nil {
return err
}
defer rows.Close()
if !rows.Next() {
return fmt.Errorf("No serial value returned for insert: %s Encountered error: %s", insertSql, rows.Err())
}
if err := rows.Scan(target); err != nil {
return err
}
if rows.Next() {
return fmt.Errorf("more than two serial value returned for insert: %s", insertSql)
}
return rows.Err()
}
func (d PostgresDialect) QuoteField(f string) string {
if d.LowercaseFields {
return `"` + strings.ToLower(f) + `"`
}
return `"` + f + `"`
}
func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string {
if strings.TrimSpace(schema) == "" {
return d.QuoteField(table)
}
return schema + "." + d.QuoteField(table)
}
func (d PostgresDialect) IfSchemaNotExists(command, schema string) string {
return fmt.Sprintf("%s if not exists", command)
}
func (d PostgresDialect) IfTableExists(command, schema, table string) string {
return fmt.Sprintf("%s if exists", command)
}
func (d PostgresDialect) IfTableNotExists(command, schema, table string) string {
return fmt.Sprintf("%s if not exists", command)
}
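
A small sketch of how LowercaseFields changes quoting and how BindVar numbers placeholders; it only calls the methods defined above, with the expected output noted in comments.

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

func main() {
	d := sqldb.PostgresDialect{}
	fmt.Println(d.QuoteField("CreatedAt"))  // "CreatedAt" (case preserved)
	fmt.Println(d.BindVar(0), d.BindVar(1)) // $1 $2

	d.LowercaseFields = true
	fmt.Println(d.QuoteField("CreatedAt"))            // "createdat"
	fmt.Println(d.QuotedTableForQuery("app", "User")) // app."user"
}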

View File

@ -0,0 +1,161 @@
//
// dialect_postgres_test.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
//go:build !integration
// +build !integration
package sqldb_test
import (
"database/sql"
"reflect"
"testing"
"time"
"git.hexq.cn/tiglog/golib/gdb/sqldb"
"github.com/poy/onpar"
"github.com/poy/onpar/expect"
"github.com/poy/onpar/matchers"
)
func TestPostgresDialect(t *testing.T) {
type testContext struct {
t *testing.T
dialect sqldb.PostgresDialect
}
o := onpar.BeforeEach(onpar.New(t), func(t *testing.T) testContext {
return testContext{
t: t,
dialect: sqldb.PostgresDialect{
LowercaseFields: false,
},
}
})
defer o.Run()
o.Group("ToSqlType", func() {
tests := []struct {
name string
value interface{}
maxSize int
autoIncr bool
expected string
}{
{"bool", true, 0, false, "boolean"},
{"int8", int8(1), 0, false, "integer"},
{"uint8", uint8(1), 0, false, "integer"},
{"int16", int16(1), 0, false, "integer"},
{"uint16", uint16(1), 0, false, "integer"},
{"int32", int32(1), 0, false, "integer"},
{"int (treated as int32)", int(1), 0, false, "integer"},
{"uint32", uint32(1), 0, false, "integer"},
{"uint (treated as uint32)", uint(1), 0, false, "integer"},
{"int64", int64(1), 0, false, "bigint"},
{"uint64", uint64(1), 0, false, "bigint"},
{"float32", float32(1), 0, false, "real"},
{"float64", float64(1), 0, false, "double precision"},
{"[]uint8", []uint8{1}, 0, false, "bytea"},
{"NullInt64", sql.NullInt64{}, 0, false, "bigint"},
{"NullFloat64", sql.NullFloat64{}, 0, false, "double precision"},
{"NullBool", sql.NullBool{}, 0, false, "boolean"},
{"Time", time.Time{}, 0, false, "timestamp with time zone"},
{"default-size string", "", 0, false, "text"},
{"sized string", "", 50, false, "varchar(50)"},
{"large string", "", 1024, false, "varchar(1024)"},
}
for _, t := range tests {
o.Spec(t.name, func(tt testContext) {
typ := reflect.TypeOf(t.value)
sqlType := tt.dialect.ToSqlType(typ, t.maxSize, t.autoIncr)
expect.Expect(tt.t, sqlType).To(matchers.Equal(t.expected))
})
}
})
o.Spec("AutoIncrStr", func(tt testContext) {
expect.Expect(t, tt.dialect.AutoIncrStr()).To(matchers.Equal(""))
})
o.Spec("AutoIncrBindValue", func(tt testContext) {
expect.Expect(t, tt.dialect.AutoIncrBindValue()).To(matchers.Equal("default"))
})
o.Spec("AutoIncrInsertSuffix", func(tt testContext) {
cm := sqldb.ColumnMap{
ColumnName: "foo",
}
expect.Expect(t, tt.dialect.AutoIncrInsertSuffix(&cm)).To(matchers.Equal(` returning "foo"`))
})
o.Spec("CreateTableSuffix", func(tt testContext) {
expect.Expect(t, tt.dialect.CreateTableSuffix()).To(matchers.Equal(""))
})
o.Spec("CreateIndexSuffix", func(tt testContext) {
expect.Expect(t, tt.dialect.CreateIndexSuffix()).To(matchers.Equal("using"))
})
o.Spec("DropIndexSuffix", func(tt testContext) {
expect.Expect(t, tt.dialect.DropIndexSuffix()).To(matchers.Equal(""))
})
o.Spec("TruncateClause", func(tt testContext) {
expect.Expect(t, tt.dialect.TruncateClause()).To(matchers.Equal("truncate"))
})
o.Spec("SleepClause", func(tt testContext) {
expect.Expect(t, tt.dialect.SleepClause(1*time.Second)).To(matchers.Equal("pg_sleep(1.000000)"))
expect.Expect(t, tt.dialect.SleepClause(100*time.Millisecond)).To(matchers.Equal("pg_sleep(0.100000)"))
})
o.Spec("BindVar", func(tt testContext) {
expect.Expect(t, tt.dialect.BindVar(0)).To(matchers.Equal("$1"))
expect.Expect(t, tt.dialect.BindVar(4)).To(matchers.Equal("$5"))
})
o.Group("QuoteField", func() {
o.Spec("By default, case is preserved", func(tt testContext) {
expect.Expect(t, tt.dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`))
expect.Expect(t, tt.dialect.QuoteField("bar")).To(matchers.Equal(`"bar"`))
})
o.Group("With LowercaseFields set to true", func() {
o1 := onpar.BeforeEach(o, func(tt testContext) testContext {
tt.dialect.LowercaseFields = true
return tt
})
o1.Spec("fields are lowercased", func(tt testContext) {
expect.Expect(t, tt.dialect.QuoteField("Foo")).To(matchers.Equal(`"foo"`))
})
})
})
o.Group("QuotedTableForQuery", func() {
o.Spec("using the default schema", func(tt testContext) {
expect.Expect(t, tt.dialect.QuotedTableForQuery("", "foo")).To(matchers.Equal(`"foo"`))
})
o.Spec("with a supplied schema", func(tt testContext) {
expect.Expect(t, tt.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`foo."bar"`))
})
})
o.Spec("IfSchemaNotExists", func(tt testContext) {
expect.Expect(t, tt.dialect.IfSchemaNotExists("foo", "bar")).To(matchers.Equal("foo if not exists"))
})
o.Spec("IfTableExists", func(tt testContext) {
expect.Expect(t, tt.dialect.IfTableExists("foo", "bar", "baz")).To(matchers.Equal("foo if exists"))
})
o.Spec("IfTableNotExists", func(tt testContext) {
expect.Expect(t, tt.dialect.IfTableNotExists("foo", "bar", "baz")).To(matchers.Equal("foo if not exists"))
})
}

115
gdb/sqldb/dialect_sqlite.go Normal file
View File

@ -0,0 +1,115 @@
//
// dialect_sqlite.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"fmt"
"reflect"
)
type SqliteDialect struct {
suffix string
}
func (d SqliteDialect) QuerySuffix() string { return ";" }
func (d SqliteDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
switch val.Kind() {
case reflect.Ptr:
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
case reflect.Bool:
return "integer"
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
return "integer"
case reflect.Float64, reflect.Float32:
return "real"
case reflect.Slice:
if val.Elem().Kind() == reflect.Uint8 {
return "blob"
}
}
switch val.Name() {
case "NullInt64":
return "integer"
case "NullFloat64":
return "real"
case "NullBool":
return "integer"
case "Time":
return "datetime"
}
if maxsize < 1 {
maxsize = 255
}
return fmt.Sprintf("varchar(%d)", maxsize)
}
// Returns autoincrement
func (d SqliteDialect) AutoIncrStr() string {
return "autoincrement"
}
func (d SqliteDialect) AutoIncrBindValue() string {
return "null"
}
func (d SqliteDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
return ""
}
// Returns suffix
func (d SqliteDialect) CreateTableSuffix() string {
return d.suffix
}
func (d SqliteDialect) CreateIndexSuffix() string {
return ""
}
func (d SqliteDialect) DropIndexSuffix() string {
return ""
}
// With sqlite, there technically isn't a TRUNCATE statement,
// but a DELETE FROM uses a truncate optimization:
// http://www.sqlite.org/lang_delete.html
func (d SqliteDialect) TruncateClause() string {
return "delete from"
}
// Returns "?"
func (d SqliteDialect) BindVar(i int) string {
return "?"
}
func (d SqliteDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
return standardInsertAutoIncr(exec, insertSql, params...)
}
func (d SqliteDialect) QuoteField(f string) string {
return `"` + f + `"`
}
// sqlite does not have schemas like PostgreSQL does, so just escape it like normal
func (d SqliteDialect) QuotedTableForQuery(schema string, table string) string {
return d.QuoteField(table)
}
func (d SqliteDialect) IfSchemaNotExists(command, schema string) string {
return fmt.Sprintf("%s if not exists", command)
}
func (d SqliteDialect) IfTableExists(command, schema, table string) string {
return fmt.Sprintf("%s if exists", command)
}
func (d SqliteDialect) IfTableNotExists(command, schema, table string) string {
return fmt.Sprintf("%s if not exists", command)
}

13
gdb/sqldb/doc.go Normal file
View File

@ -0,0 +1,13 @@
//
// doc.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
// Package sqldb provides a simple way to marshal Go structs to and from
// SQL databases. It uses the database/sql package, and should work with any
// compliant database/sql driver.
//
// Source code and project home:
package sqldb

34
gdb/sqldb/errors.go Normal file
View File

@ -0,0 +1,34 @@
//
// errors.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"fmt"
)
// A non-fatal error, when a select query returns columns that do not exist
// as fields in the struct it is being mapped to
// TODO: discuss whether this needs an error. encoding/json silently ignores missing fields
type NoFieldInTypeError struct {
TypeName string
MissingColNames []string
}
func (err *NoFieldInTypeError) Error() string {
return fmt.Sprintf("sqldb: no fields %+v in type %s", err.MissingColNames, err.TypeName)
}
// returns true if the error is non-fatal (ie, we shouldn't immediately return)
func NonFatalError(err error) bool {
switch err.(type) {
case *NoFieldInTypeError:
return true
default:
return false
}
}

45
gdb/sqldb/hooks.go Normal file
View File

@ -0,0 +1,45 @@
//
// hooks.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
//++ TODO v2-phase3: HasPostGet => PostGetter, HasPostDelete => PostDeleter, etc.
// HasPostGet provides PostGet() which will be executed after the GET statement.
type HasPostGet interface {
PostGet(SqlExecutor) error
}
// HasPostDelete provides PostDelete() which will be executed after the DELETE statement
type HasPostDelete interface {
PostDelete(SqlExecutor) error
}
// HasPostUpdate provides PostUpdate() which will be executed after the UPDATE statement
type HasPostUpdate interface {
PostUpdate(SqlExecutor) error
}
// HasPostInsert provides PostInsert() which will be executed after the INSERT statement
type HasPostInsert interface {
PostInsert(SqlExecutor) error
}
// HasPreDelete provides PreDelete() which will be executed before the DELETE statement.
type HasPreDelete interface {
PreDelete(SqlExecutor) error
}
// HasPreUpdate provides PreUpdate() which will be executed before UPDATE statement.
type HasPreUpdate interface {
PreUpdate(SqlExecutor) error
}
// HasPreInsert provides PreInsert() which will be executed before INSERT statement.
type HasPreInsert interface {
PreInsert(SqlExecutor) error
}
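
A sketch of how a mapped struct opts into these hooks; the Person struct and its fields are illustrative, not from this repo. The compile-time check at the end is a common way to make sure the method set really satisfies the interface.

package example

import (
	"time"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

type Person struct {
	Id      int64  `db:"id"`
	Name    string `db:"name"`
	AddedAt int64  `db:"added_at"`
}

// PreInsert is called by sqldb before the INSERT statement runs.
func (p *Person) PreInsert(s sqldb.SqlExecutor) error {
	p.AddedAt = time.Now().Unix()
	return nil
}

// compile-time check that *Person satisfies the hook interface
var _ sqldb.HasPreInsert = (*Person)(nil)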

51
gdb/sqldb/index.go Normal file
View File

@ -0,0 +1,51 @@
//
// index.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
// IndexMap represents a mapping between a Go struct field and a single
// index in a table.
// Unique and MaxSize only inform the
// CreateTables() function and are not used by Insert/Update/Delete/Get.
type IndexMap struct {
// Index name in db table
IndexName string
// If true, " unique" is added to create index statements.
// Not used elsewhere
Unique bool
// Index type supported by Dialect
// Postgres: B-tree, Hash, GiST and GIN.
// Mysql: Btree, Hash.
// Sqlite: nil.
IndexType string
// Columns name for single and multiple indexes
columns []string
}
// Rename allows you to specify the index name in the table
//
// Example: table.IndMap("customer_test_idx").Rename("customer_idx")
func (idx *IndexMap) Rename(indname string) *IndexMap {
idx.IndexName = indname
return idx
}
// SetUnique adds "unique" to the create index statements for this
// index, if b is true.
func (idx *IndexMap) SetUnique(b bool) *IndexMap {
idx.Unique = b
return idx
}
// SetIndexType specifies the index type supported by the chosen SQL Dialect
func (idx *IndexMap) SetIndexType(indtype string) *IndexMap {
idx.IndexType = indtype
return idx
}
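
Since all three setters return the *IndexMap, they can be chained. A sketch; obtaining the *IndexMap (e.g. via table.IndMap("customer_test_idx") as in the Rename comment above) is assumed:

package example

import "git.hexq.cn/tiglog/golib/gdb/sqldb"

// configureIndex renames the index, marks it unique and picks an index type
// that the active dialect supports (e.g. "btree" for Postgres/MySQL).
func configureIndex(idx *sqldb.IndexMap) {
	idx.Rename("customer_idx").
		SetUnique(true).
		SetIndexType("btree")
}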

59
gdb/sqldb/lockerror.go Normal file
View File

@ -0,0 +1,59 @@
//
// lockerror.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"fmt"
"reflect"
)
// OptimisticLockError is returned by Update() or Delete() if the
// struct being modified has a Version field and the value is not equal to
// the current value in the database
type OptimisticLockError struct {
// Table name where the lock error occurred
TableName string
// Primary key values of the row being updated/deleted
Keys []interface{}
// true if a row was found with those keys, indicating the
// LocalVersion is stale. false if no value was found with those
// keys, suggesting the row has been deleted since loaded, or
// was never inserted to begin with
RowExists bool
// Version value on the struct passed to Update/Delete. This value is
// out of sync with the database.
LocalVersion int64
}
// Error returns a description of the cause of the lock error
func (e OptimisticLockError) Error() string {
if e.RowExists {
return fmt.Sprintf("sqldb: OptimisticLockError table=%s keys=%v out of date version=%d", e.TableName, e.Keys, e.LocalVersion)
}
return fmt.Sprintf("sqldb: OptimisticLockError no row found for table=%s keys=%v", e.TableName, e.Keys)
}
func lockError(m *DbMap, exec SqlExecutor, tableName string,
existingVer int64, elem reflect.Value,
keys ...interface{}) (int64, error) {
existing, err := get(m, exec, elem.Interface(), keys...)
if err != nil {
return -1, err
}
ole := OptimisticLockError{tableName, keys, true, existingVer}
if existing == nil {
ole.RowExists = false
}
return -1, ole
}
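
Callers can detect the lock error with errors.As (OptimisticLockError is returned by value), reload the row and retry. A sketch assuming an Update call on a versioned row:

package example

import (
	"errors"
	"log"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

func updateWithConflictCheck(exec sqldb.SqlExecutor, row interface{}) {
	_, err := exec.Update(row)
	var ole sqldb.OptimisticLockError
	if errors.As(err, &ole) {
		// the local Version field is stale (or the row is gone); reload and retry
		log.Printf("conflict on %s keys=%v exists=%v", ole.TableName, ole.Keys, ole.RowExists)
		return
	}
	if err != nil {
		log.Println("update failed:", err)
	}
}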

45
gdb/sqldb/logging.go Normal file
View File

@ -0,0 +1,45 @@
//
// logging.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import "fmt"
// SqldbLogger is a deprecated alias of Logger.
type SqldbLogger = Logger
// Logger is the type that sqldb uses to log SQL statements.
// See DbMap.TraceOn.
type Logger interface {
Printf(format string, v ...interface{})
}
// TraceOn turns on SQL statement logging for this DbMap. After this is
// called, all SQL statements will be sent to the logger. If prefix is
// a non-empty string, it will be written to the front of all logged
// strings, which can aid in filtering log lines.
//
// Use TraceOn if you want to spy on the SQL statements that sqldb
// generates.
//
// Note that the base log.Logger type satisfies Logger, but adapters can
// easily be written for other logging packages (e.g., the golang-sanctioned
// glog framework).
func (m *DbMap) TraceOn(prefix string, logger Logger) {
m.logger = logger
if prefix == "" {
m.logPrefix = prefix
} else {
m.logPrefix = fmt.Sprintf("%s ", prefix)
}
}
// TraceOff turns off tracing. It is idempotent.
func (m *DbMap) TraceOff() {
m.logger = nil
m.logPrefix = ""
}
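
The standard library *log.Logger already satisfies the Logger interface, so enabling and disabling tracing is a one-liner each. A minimal sketch:

package example

import (
	"log"
	"os"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

func enableTracing(dbmap *sqldb.DbMap) {
	// every generated SQL statement is now printed with the "[sqldb]" prefix
	dbmap.TraceOn("[sqldb]", log.New(os.Stdout, "", log.Lmicroseconds))
	// ... run queries ...
	dbmap.TraceOff()
}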

68
gdb/sqldb/nulltypes.go Normal file
View File

@ -0,0 +1,68 @@
//
// nulltypes.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"database/sql/driver"
"log"
"time"
)
// A nullable Time value
type NullTime struct {
Time time.Time
Valid bool // Valid is true if Time is not NULL
}
// Scan implements the Scanner interface.
func (nt *NullTime) Scan(value interface{}) error {
log.Printf("Time scan value is: %#v", value)
switch t := value.(type) {
case time.Time:
nt.Time, nt.Valid = t, true
case []byte:
v := strToTime(string(t))
if v != nil {
nt.Valid = true
nt.Time = *v
}
case string:
v := strToTime(t)
if v != nil {
nt.Valid = true
nt.Time = *v
}
}
return nil
}
func strToTime(v string) *time.Time {
for _, dtfmt := range []string{
"2006-01-02 15:04:05.999999999",
"2006-01-02T15:04:05.999999999",
"2006-01-02 15:04:05",
"2006-01-02T15:04:05",
"2006-01-02 15:04",
"2006-01-02T15:04",
"2006-01-02",
"2006-01-02 15:04:05-07:00",
} {
if t, err := time.Parse(dtfmt, v); err == nil {
return &t
}
}
return nil
}
// Value implements the driver Valuer interface.
func (nt NullTime) Value() (driver.Value, error) {
if !nt.Valid {
return nil, nil
}
return nt.Time, nil
}
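
A sketch of using NullTime as a struct field so NULL timestamps round-trip cleanly; the Article struct and column names are illustrative:

package example

import (
	"time"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

type Article struct {
	Id          int64          `db:"id"`
	PublishedAt sqldb.NullTime `db:"published_at"` // NULL until published
}

// publish marks the article as published now; Valid=false would write NULL.
func publish(a *Article) {
	a.PublishedAt = sqldb.NullTime{Time: time.Now(), Valid: true}
}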

361
gdb/sqldb/select.go Normal file
View File

@ -0,0 +1,361 @@
//
// select.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"database/sql"
"fmt"
"reflect"
)
// SelectInt executes the given query, which should be a SELECT statement for a single
// integer column, and returns the value of the first row returned. If no rows are
// found, zero is returned.
func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) {
var h int64
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
return h, nil
}
// SelectNullInt executes the given query, which should be a SELECT statement for a single
// integer column, and returns the value of the first row returned. If no rows are
// found, the empty sql.NullInt64 value is returned.
func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullInt64, error) {
var h sql.NullInt64
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return h, err
}
return h, nil
}
// SelectFloat executes the given query, which should be a SELECT statement for a single
// float column, and returns the value of the first row returned. If no rows are
// found, zero is returned.
func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, error) {
var h float64
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return 0, err
}
return h, nil
}
// SelectNullFloat executes the given query, which should be a SELECT statement for a single
// float column, and returns the value of the first row returned. If no rows are
// found, the empty sql.NullInt64 value is returned.
func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.NullFloat64, error) {
var h sql.NullFloat64
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return h, err
}
return h, nil
}
// SelectStr executes the given query, which should be a SELECT statement for a single
// char/varchar column, and returns the value of the first row returned. If no rows are
// found, an empty string is returned.
func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) {
var h string
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return "", err
}
return h, nil
}
// SelectNullStr executes the given query, which should be a SELECT
// statement for a single char/varchar column, and returns the value
// of the first row returned. If no rows are found, the empty
// sql.NullString is returned.
func SelectNullStr(e SqlExecutor, query string, args ...interface{}) (sql.NullString, error) {
var h sql.NullString
err := selectVal(e, &h, query, args...)
if err != nil && err != sql.ErrNoRows {
return h, err
}
return h, nil
}
// SelectOne executes the given query (which should be a SELECT statement)
// and binds the result to holder, which must be a pointer.
//
// # If no row is found, an error (sql.ErrNoRows specifically) will be returned
//
// If more than one row is found, an error will be returned.
func SelectOne(m *DbMap, e SqlExecutor, holder interface{}, query string, args ...interface{}) error {
t := reflect.TypeOf(holder)
if t.Kind() == reflect.Ptr {
t = t.Elem()
} else {
return fmt.Errorf("sqldb: SelectOne holder must be a pointer, but got: %t", holder)
}
// Handle pointer to pointer
isptr := false
if t.Kind() == reflect.Ptr {
isptr = true
t = t.Elem()
}
if t.Kind() == reflect.Struct {
var nonFatalErr error
list, err := hookedselect(m, e, holder, query, args...)
if err != nil {
if !NonFatalError(err) { // FIXME: double negative, rename NonFatalError to FatalError
return err
}
nonFatalErr = err
}
dest := reflect.ValueOf(holder)
if isptr {
dest = dest.Elem()
}
if list != nil && len(list) > 0 { // FIXME: invert if/else
// check for multiple rows
if len(list) > 1 {
return fmt.Errorf("sqldb: multiple rows returned for: %s - %v", query, args)
}
// Initialize if nil
if dest.IsNil() {
dest.Set(reflect.New(t))
}
// only one row found
src := reflect.ValueOf(list[0])
dest.Elem().Set(src.Elem())
} else {
// No rows found, return a proper error.
return sql.ErrNoRows
}
return nonFatalErr
}
return selectVal(e, holder, query, args...)
}
func selectVal(e SqlExecutor, holder interface{}, query string, args ...interface{}) error {
if len(args) == 1 {
switch m := e.(type) {
case *DbMap:
query, args = maybeExpandNamedQuery(m, query, args)
case *Transaction:
query, args = maybeExpandNamedQuery(m.dbmap, query, args)
}
}
rows, err := e.Query(query, args...)
if err != nil {
return err
}
defer rows.Close()
if !rows.Next() {
if err := rows.Err(); err != nil {
return err
}
return sql.ErrNoRows
}
return rows.Scan(holder)
}
func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
args ...interface{}) ([]interface{}, error) {
var nonFatalErr error
list, err := rawselect(m, exec, i, query, args...)
if err != nil {
if !NonFatalError(err) {
return nil, err
}
nonFatalErr = err
}
// Determine where the results are: written to i, or returned in list
if t, _ := toSliceType(i); t == nil {
for _, v := range list {
if v, ok := v.(HasPostGet); ok {
err := v.PostGet(exec)
if err != nil {
return nil, err
}
}
}
} else {
resultsValue := reflect.Indirect(reflect.ValueOf(i))
for i := 0; i < resultsValue.Len(); i++ {
if v, ok := resultsValue.Index(i).Interface().(HasPostGet); ok {
err := v.PostGet(exec)
if err != nil {
return nil, err
}
}
}
}
return list, nonFatalErr
}
func rawselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
args ...interface{}) ([]interface{}, error) {
var (
appendToSlice = false // Write results to i directly?
intoStruct = true // Selecting into a struct?
pointerElements = true // Are the slice elements pointers (vs values)?
)
var nonFatalErr error
tableName := ""
var dynObj DynamicTable
isDynamic := false
if dynObj, isDynamic = i.(DynamicTable); isDynamic {
tableName = dynObj.TableName()
}
// get type for i, verifying it's a supported destination
t, err := toType(i)
if err != nil {
var err2 error
if t, err2 = toSliceType(i); t == nil {
if err2 != nil {
return nil, err2
}
return nil, err
}
pointerElements = t.Kind() == reflect.Ptr
if pointerElements {
t = t.Elem()
}
appendToSlice = true
intoStruct = t.Kind() == reflect.Struct
}
// If the caller supplied a single struct/map argument, assume a "named
// parameter" query. Extract the named arguments from the struct/map, create
// the flat arg slice, and rewrite the query to use the dialect's placeholder.
if len(args) == 1 {
query, args = maybeExpandNamedQuery(m, query, args)
}
// Run the query
rows, err := exec.Query(query, args...)
if err != nil {
return nil, err
}
defer rows.Close()
// Fetch the column names as returned from db
cols, err := rows.Columns()
if err != nil {
return nil, err
}
if !intoStruct && len(cols) > 1 {
return nil, fmt.Errorf("sqldb: select into non-struct slice requires 1 column, got %d", len(cols))
}
var colToFieldIndex [][]int
if intoStruct {
colToFieldIndex, err = columnToFieldIndex(m, t, tableName, cols)
if err != nil {
if !NonFatalError(err) {
return nil, err
}
nonFatalErr = err
}
}
conv := m.TypeConverter
// Add results to one of these two slices.
var (
list = make([]interface{}, 0)
sliceValue = reflect.Indirect(reflect.ValueOf(i))
)
for {
if !rows.Next() {
// if an error occurred, return it
if rows.Err() != nil {
return nil, rows.Err()
}
// time to exit from outer "for" loop
break
}
v := reflect.New(t)
if isDynamic {
v.Interface().(DynamicTable).SetTableName(tableName)
}
dest := make([]interface{}, len(cols))
custScan := make([]CustomScanner, 0)
for x := range cols {
f := v.Elem()
if intoStruct {
index := colToFieldIndex[x]
if index == nil {
// this field is not present in the struct, so create a dummy
// value for rows.Scan to scan into
var dummy dummyField
dest[x] = &dummy
continue
}
f = f.FieldByIndex(index)
}
target := f.Addr().Interface()
if conv != nil {
scanner, ok := conv.FromDb(target)
if ok {
target = scanner.Holder
custScan = append(custScan, scanner)
}
}
dest[x] = target
}
err = rows.Scan(dest...)
if err != nil {
return nil, err
}
for _, c := range custScan {
err = c.Bind()
if err != nil {
return nil, err
}
}
if appendToSlice {
if !pointerElements {
v = v.Elem()
}
sliceValue.Set(reflect.Append(sliceValue, v))
} else {
list = append(list, v.Interface())
}
}
if appendToSlice && sliceValue.IsNil() {
sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), 0, 0))
}
return list, nonFatalErr
}
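
A usage sketch of the typed select helpers via a DbMap (which implements SqlExecutor): SelectOne returns sql.ErrNoRows when nothing matches, while the scalar helpers return their zero value. The person table, struct and named parameter are illustrative.

package example

import (
	"database/sql"
	"errors"
	"log"

	"git.hexq.cn/tiglog/golib/gdb/sqldb"
)

type personRow struct {
	Id   int64  `db:"id"`
	Name string `db:"name"`
}

func selectExamples(dbmap *sqldb.DbMap) {
	// zero is returned when no rows match
	n, err := dbmap.SelectInt("select count(*) from person")
	if err != nil {
		log.Println(err)
	}
	log.Println("count:", n)

	// named parameters in a map are expanded via maybeExpandNamedQuery
	var p personRow
	err = dbmap.SelectOne(&p, "select id, name from person where id = :id",
		map[string]interface{}{"id": 1})
	if errors.Is(err, sql.ErrNoRows) {
		log.Println("not found")
	}
}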

675
gdb/sqldb/sqldb.go Normal file
View File

@ -0,0 +1,675 @@
//
// sqldb.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"context"
"database/sql"
"database/sql/driver"
"fmt"
"reflect"
"regexp"
"strings"
"time"
)
// OracleString (empty string is null)
// TODO: move to dialect/oracle?, rename to String?
type OracleString struct {
sql.NullString
}
// Scan implements the Scanner interface.
func (os *OracleString) Scan(value interface{}) error {
if value == nil {
os.String, os.Valid = "", false
return nil
}
os.Valid = true
return os.NullString.Scan(value)
}
// Value implements the driver Valuer interface.
func (os OracleString) Value() (driver.Value, error) {
if !os.Valid || os.String == "" {
return nil, nil
}
return os.String, nil
}
// SqlTyper is a type that returns its database type. Most of the
// time, the type can just use "database/sql/driver".Valuer; but when
// it returns nil for its empty value, it needs to implement SqlTyper
// to have its column type detected properly during table creation.
type SqlTyper interface {
SqlType() driver.Value
}
// legacySqlTyper prevents breaking clients who depended on the previous
// SqlTyper interface
type legacySqlTyper interface {
SqlType() driver.Valuer
}
// for fields that exists in DB table, but not exists in struct
type dummyField struct{}
// Scan implements the Scanner interface.
func (nt *dummyField) Scan(value interface{}) error {
return nil
}
var zeroVal reflect.Value
var versFieldConst = "[sqldb_ver_field]"
// The TypeConverter interface provides a way to map a value of one
// type to another type when persisting to, or loading from, a database.
//
// Example use cases: Implement type converter to convert bool types to "y"/"n" strings,
// or serialize a struct member as a JSON blob.
type TypeConverter interface {
// ToDb converts val to another type. Called before INSERT/UPDATE operations
ToDb(val interface{}) (interface{}, error)
// FromDb returns a CustomScanner appropriate for this type. This will be used
// to hold values returned from SELECT queries.
//
// In particular the CustomScanner returned should implement a Binder
// function appropriate for the Go type you wish to convert the db value to
//
// If bool==false, then no custom scanner will be used for this field.
FromDb(target interface{}) (CustomScanner, bool)
}
// SqlExecutor exposes sqldb operations that can be run from Pre/Post
// hooks. This hides whether the current operation that triggered the
// hook is in a transaction.
//
// See the DbMap function docs for each of the functions below for more
// information.
type SqlExecutor interface {
WithContext(ctx context.Context) SqlExecutor
Get(i interface{}, keys ...interface{}) (interface{}, error)
Insert(list ...interface{}) error
Update(list ...interface{}) (int64, error)
Delete(list ...interface{}) (int64, error)
Exec(query string, args ...interface{}) (sql.Result, error)
Select(i interface{}, query string, args ...interface{}) ([]interface{}, error)
SelectInt(query string, args ...interface{}) (int64, error)
SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error)
SelectFloat(query string, args ...interface{}) (float64, error)
SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error)
SelectStr(query string, args ...interface{}) (string, error)
SelectNullStr(query string, args ...interface{}) (sql.NullString, error)
SelectOne(holder interface{}, query string, args ...interface{}) error
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
// DynamicTable allows the users of sqldb to dynamically
// use different database table names during runtime
// while sharing the same golang struct for in-memory data
type DynamicTable interface {
TableName() string
SetTableName(string)
}
// Compile-time check that DbMap and Transaction implement the SqlExecutor
// interface.
var _, _ SqlExecutor = &DbMap{}, &Transaction{}
func argValue(a interface{}) interface{} {
v, ok := a.(driver.Valuer)
if !ok {
return a
}
vV := reflect.ValueOf(v)
if vV.Kind() == reflect.Ptr && vV.IsNil() {
return nil
}
ret, err := v.Value()
if err != nil {
return a
}
return ret
}
func argsString(args ...interface{}) string {
var margs string
for i, a := range args {
v := argValue(a)
switch v.(type) {
case string:
v = fmt.Sprintf("%q", v)
default:
v = fmt.Sprintf("%v", v)
}
margs += fmt.Sprintf("%d:%s", i+1, v)
if i+1 < len(args) {
margs += " "
}
}
return margs
}
// Calls the Exec function on the executor, but attempts to expand any eligible named
// query arguments first.
func maybeExpandNamedQueryAndExec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) {
dbMap := extractDbMap(e)
if len(args) == 1 {
query, args = maybeExpandNamedQuery(dbMap, query, args)
}
return exec(e, query, args...)
}
func extractDbMap(e SqlExecutor) *DbMap {
switch m := e.(type) {
case *DbMap:
return m
case *Transaction:
return m.dbmap
}
return nil
}
// executor exposes the sql.DB and sql.Tx functions so that it can be used
// on internal functions that need to be agnostic to the underlying object.
type executor interface {
Exec(query string, args ...interface{}) (sql.Result, error)
Prepare(query string) (*sql.Stmt, error)
QueryRow(query string, args ...interface{}) *sql.Row
Query(query string, args ...interface{}) (*sql.Rows, error)
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
func extractExecutorAndContext(e SqlExecutor) (executor, context.Context) {
switch m := e.(type) {
case *DbMap:
return m.Db, m.ctx
case *Transaction:
return m.tx, m.ctx
}
return nil, nil
}
// maybeExpandNamedQuery checks the given arg to see if it's eligible to be used
// as input to a named query. If so, it rewrites the query to use
// dialect-dependent bindvars and instantiates the corresponding slice of
// parameters by extracting data from the map / struct.
// If not, returns the input values unchanged.
func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string, []interface{}) {
var (
arg = args[0]
argval = reflect.ValueOf(arg)
)
if argval.Kind() == reflect.Ptr {
argval = argval.Elem()
}
if argval.Kind() == reflect.Map && argval.Type().Key().Kind() == reflect.String {
return expandNamedQuery(m, query, func(key string) reflect.Value {
return argval.MapIndex(reflect.ValueOf(key))
})
}
if argval.Kind() != reflect.Struct {
return query, args
}
if _, ok := arg.(time.Time); ok {
// time.Time is driver.Value
return query, args
}
if _, ok := arg.(driver.Valuer); ok {
// driver.Valuer will be converted to driver.Value.
return query, args
}
return expandNamedQuery(m, query, argval.FieldByName)
}
var keyRegexp = regexp.MustCompile(`:[[:word:]]+`)
// expandNamedQuery accepts a query with placeholders of the form ":key", and a
// single arg of Kind Struct or Map[string]. It returns the query with the
// dialect's placeholders, and a slice of args ready for positional insertion
// into the query.
func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect.Value) (string, []interface{}) {
var (
n int
args []interface{}
)
return keyRegexp.ReplaceAllStringFunc(query, func(key string) string {
val := keyGetter(key[1:])
if !val.IsValid() {
return key
}
args = append(args, val.Interface())
newVar := m.Dialect.BindVar(n)
n++
return newVar
}), args
}
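// Illustrative sketch of the named-query form handled above: when Exec
// receives exactly one map (with string keys) or struct argument, the ":Name"
// placeholders are rewritten to the dialect's bind variables and the values
// are pulled out positionally. The accounts table and its columns are
// hypothetical.
func exampleNamedExec(e SqlExecutor) (sql.Result, error) {
	return e.Exec(
		"update accounts set status = :Status where id = :Id",
		map[string]interface{}{"Status": "inactive", "Id": int64(42)},
	)
}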
func columnToFieldIndex(m *DbMap, t reflect.Type, name string, cols []string) ([][]int, error) {
colToFieldIndex := make([][]int, len(cols))
// check if type t is a mapped table - if so we'll
// check the table for column aliasing below
tableMapped := false
table := tableOrNil(m, t, name)
if table != nil {
tableMapped = true
}
// Loop over column names and find the field in i to bind to
// based on column name. All returned columns must match
// a field in the i struct.
missingColNames := []string{}
for x := range cols {
colName := strings.ToLower(cols[x])
field, found := t.FieldByNameFunc(func(fieldName string) bool {
field, _ := t.FieldByName(fieldName)
cArguments := strings.Split(field.Tag.Get("db"), ",")
fieldName = cArguments[0]
if fieldName == "-" {
return false
} else if fieldName == "" {
fieldName = field.Name
}
if tableMapped {
colMap := colMapOrNil(table, fieldName)
if colMap != nil {
fieldName = colMap.ColumnName
}
}
return colName == strings.ToLower(fieldName)
})
if found {
colToFieldIndex[x] = field.Index
}
if colToFieldIndex[x] == nil {
missingColNames = append(missingColNames, colName)
}
}
if len(missingColNames) > 0 {
return colToFieldIndex, &NoFieldInTypeError{
TypeName: t.Name(),
MissingColNames: missingColNames,
}
}
return colToFieldIndex, nil
}
func fieldByName(val reflect.Value, fieldName string) *reflect.Value {
// try to find field by exact match
f := val.FieldByName(fieldName)
if f != zeroVal {
return &f
}
// try to find by case insensitive match - only the Postgres driver
// seems to require this - in the case where columns are aliased in the sql
fieldNameL := strings.ToLower(fieldName)
fieldCount := val.NumField()
t := val.Type()
for i := 0; i < fieldCount; i++ {
sf := t.Field(i)
if strings.ToLower(sf.Name) == fieldNameL {
f := val.Field(i)
return &f
}
}
return nil
}
// toSliceType returns the element type of the given object, if the object is a
// "*[]*Element" or "*[]Element". If not, returns nil.
// err is returned if the user was trying to pass a pointer-to-slice but failed.
func toSliceType(i interface{}) (reflect.Type, error) {
t := reflect.TypeOf(i)
if t.Kind() != reflect.Ptr {
// If it's a slice, return a more helpful error message
if t.Kind() == reflect.Slice {
return nil, fmt.Errorf("sqldb: cannot SELECT into a non-pointer slice: %v", t)
}
return nil, nil
}
if t = t.Elem(); t.Kind() != reflect.Slice {
return nil, nil
}
return t.Elem(), nil
}
func toType(i interface{}) (reflect.Type, error) {
t := reflect.TypeOf(i)
// If a Pointer to a type, follow
for t.Kind() == reflect.Ptr {
t = t.Elem()
}
if t.Kind() != reflect.Struct {
return nil, fmt.Errorf("sqldb: cannot SELECT into this type: %v", reflect.TypeOf(i))
}
return t, nil
}
type foundTable struct {
table *TableMap
dynName *string
}
func tableFor(m *DbMap, t reflect.Type, i interface{}) (*foundTable, error) {
if dyn, isDynamic := i.(DynamicTable); isDynamic {
tableName := dyn.TableName()
table, err := m.DynamicTableFor(tableName, true)
if err != nil {
return nil, err
}
return &foundTable{
table: table,
dynName: &tableName,
}, nil
}
table, err := m.TableFor(t, true)
if err != nil {
return nil, err
}
return &foundTable{table: table}, nil
}
func get(m *DbMap, exec SqlExecutor, i interface{},
keys ...interface{}) (interface{}, error) {
t, err := toType(i)
if err != nil {
return nil, err
}
foundTable, err := tableFor(m, t, i)
if err != nil {
return nil, err
}
table := foundTable.table
plan := table.bindGet()
v := reflect.New(t)
if foundTable.dynName != nil {
retDyn := v.Interface().(DynamicTable)
retDyn.SetTableName(*foundTable.dynName)
}
dest := make([]interface{}, len(plan.argFields))
conv := m.TypeConverter
custScan := make([]CustomScanner, 0)
for x, fieldName := range plan.argFields {
f := v.Elem().FieldByName(fieldName)
target := f.Addr().Interface()
if conv != nil {
scanner, ok := conv.FromDb(target)
if ok {
target = scanner.Holder
custScan = append(custScan, scanner)
}
}
dest[x] = target
}
row := exec.QueryRow(plan.query, keys...)
err = row.Scan(dest...)
if err != nil {
if err == sql.ErrNoRows {
err = nil
}
return nil, err
}
for _, c := range custScan {
err = c.Bind()
if err != nil {
return nil, err
}
}
if v, ok := v.Interface().(HasPostGet); ok {
err := v.PostGet(exec)
if err != nil {
return nil, err
}
}
return v.Interface(), nil
}
func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
count := int64(0)
for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, true)
if err != nil {
return -1, err
}
eval := elem.Addr().Interface()
if v, ok := eval.(HasPreDelete); ok {
err = v.PreDelete(exec)
if err != nil {
return -1, err
}
}
bi, err := table.bindDelete(elem)
if err != nil {
return -1, err
}
res, err := exec.Exec(bi.query, bi.args...)
if err != nil {
return -1, err
}
rows, err := res.RowsAffected()
if err != nil {
return -1, err
}
if rows == 0 && bi.existingVersion > 0 {
return lockError(m, exec, table.TableName,
bi.existingVersion, elem, bi.keys...)
}
count += rows
if v, ok := eval.(HasPostDelete); ok {
err := v.PostDelete(exec)
if err != nil {
return -1, err
}
}
}
return count, nil
}
func update(m *DbMap, exec SqlExecutor, colFilter ColumnFilter, list ...interface{}) (int64, error) {
count := int64(0)
for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, true)
if err != nil {
return -1, err
}
eval := elem.Addr().Interface()
if v, ok := eval.(HasPreUpdate); ok {
err = v.PreUpdate(exec)
if err != nil {
return -1, err
}
}
bi, err := table.bindUpdate(elem, colFilter)
if err != nil {
return -1, err
}
res, err := exec.Exec(bi.query, bi.args...)
if err != nil {
return -1, err
}
rows, err := res.RowsAffected()
if err != nil {
return -1, err
}
if rows == 0 && bi.existingVersion > 0 {
return lockError(m, exec, table.TableName,
bi.existingVersion, elem, bi.keys...)
}
if bi.versField != "" {
elem.FieldByName(bi.versField).SetInt(bi.existingVersion + 1)
}
count += rows
if v, ok := eval.(HasPostUpdate); ok {
err = v.PostUpdate(exec)
if err != nil {
return -1, err
}
}
}
return count, nil
}
func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
for _, ptr := range list {
table, elem, err := m.tableForPointer(ptr, false)
if err != nil {
return err
}
eval := elem.Addr().Interface()
if v, ok := eval.(HasPreInsert); ok {
err := v.PreInsert(exec)
if err != nil {
return err
}
}
bi, err := table.bindInsert(elem)
if err != nil {
return err
}
if bi.autoIncrIdx > -1 {
f := elem.FieldByName(bi.autoIncrFieldName)
switch inserter := m.Dialect.(type) {
case IntegerAutoIncrInserter:
id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...)
if err != nil {
return err
}
k := f.Kind()
if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) {
f.SetInt(id)
} else if (k == reflect.Uint) || (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) {
f.SetUint(uint64(id))
} else {
return fmt.Errorf("sqldb: cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName)
}
case TargetedAutoIncrInserter:
err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...)
if err != nil {
return err
}
case TargetQueryInserter:
var idQuery = table.ColMap(bi.autoIncrFieldName).GeneratedIdQuery
if idQuery == "" {
return fmt.Errorf("sqldb: cannot set %s value if its ColumnMap.GeneratedIdQuery is empty", bi.autoIncrFieldName)
}
err := inserter.InsertQueryToTarget(exec, bi.query, idQuery, f.Addr().Interface(), bi.args...)
if err != nil {
return err
}
default:
return fmt.Errorf("sqldb: cannot use autoincrement fields on dialects that do not implement an autoincrementing interface")
}
} else {
_, err := exec.Exec(bi.query, bi.args...)
if err != nil {
return err
}
}
if v, ok := eval.(HasPostInsert); ok {
err := v.PostInsert(exec)
if err != nil {
return err
}
}
}
return nil
}
func exec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) {
executor, ctx := extractExecutorAndContext(e)
if ctx != nil {
return executor.ExecContext(ctx, query, args...)
}
return executor.Exec(query, args...)
}
func prepare(e SqlExecutor, query string) (*sql.Stmt, error) {
executor, ctx := extractExecutorAndContext(e)
if ctx != nil {
return executor.PrepareContext(ctx, query)
}
return executor.Prepare(query)
}
func queryRow(e SqlExecutor, query string, args ...interface{}) *sql.Row {
executor, ctx := extractExecutorAndContext(e)
if ctx != nil {
return executor.QueryRowContext(ctx, query, args...)
}
return executor.QueryRow(query, args...)
}
func query(e SqlExecutor, query string, args ...interface{}) (*sql.Rows, error) {
executor, ctx := extractExecutorAndContext(e)
if ctx != nil {
return executor.QueryContext(ctx, query, args...)
}
return executor.Query(query, args...)
}
func begin(m *DbMap) (*sql.Tx, error) {
if m.ctx != nil {
return m.Db.BeginTx(m.ctx, nil)
}
return m.Db.Begin()
}

2875
gdb/sqldb/sqldb_test.go Normal file

File diff suppressed because it is too large

258
gdb/sqldb/table.go Normal file
View File

@ -0,0 +1,258 @@
//
// table.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"bytes"
"fmt"
"reflect"
"strings"
)
// TableMap represents a mapping between a Go struct and a database table
// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these
type TableMap struct {
// Name of database table.
TableName string
SchemaName string
gotype reflect.Type
Columns []*ColumnMap
keys []*ColumnMap
indexes []*IndexMap
uniqueTogether [][]string
version *ColumnMap
insertPlan bindPlan
updatePlan bindPlan
deletePlan bindPlan
getPlan bindPlan
dbmap *DbMap
}
// ResetSql removes cached insert/update/select/delete SQL strings
// associated with this TableMap. Call this if you've modified
// any column names or the table name itself.
func (t *TableMap) ResetSql() {
t.insertPlan = bindPlan{}
t.updatePlan = bindPlan{}
t.deletePlan = bindPlan{}
t.getPlan = bindPlan{}
}
// SetKeys lets you specify the fields on a struct that map to primary
// key columns on the table. If isAutoIncr is set, result.LastInsertId()
// will be used after INSERT to bind the generated id to the Go struct.
//
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
//
// Panics if isAutoIncr is true, and fieldNames length != 1
func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap {
if isAutoIncr && len(fieldNames) != 1 {
panic(fmt.Sprintf(
"sqldb: SetKeys: fieldNames length must be 1 if key is auto-increment. (Saw %v fieldNames)",
len(fieldNames)))
}
t.keys = make([]*ColumnMap, 0)
for _, name := range fieldNames {
colmap := t.ColMap(name)
colmap.isPK = true
colmap.isAutoIncr = isAutoIncr
t.keys = append(t.keys, colmap)
}
t.ResetSql()
return t
}
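// Illustrative sketch of declaring a primary key: map a hypothetical struct to
// a "users" table and mark Id as the auto-increment key. This assumes
// AddTableWithName (the registration entry point used by the tests in this
// change) returns the new *TableMap, as it does in upstream gorp, so the
// calls can be chained.
func exampleSetKeys(dbmap *DbMap) *TableMap {
	type exampleUser struct {
		Id   int64  `db:"id"`
		Name string `db:"name"`
	}
	return dbmap.AddTableWithName(exampleUser{}, "users").SetKeys(true, "Id")
}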
// SetUniqueTogether lets you specify uniqueness constraints across multiple
// columns on the table. Each call adds an additional constraint for the
// specified columns.
//
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
//
// Panics if fieldNames length < 2.
func (t *TableMap) SetUniqueTogether(fieldNames ...string) *TableMap {
if len(fieldNames) < 2 {
panic("sqldb: SetUniqueTogether: must provide at least two fieldNames to set uniqueness constraint.")
}
columns := make([]string, 0, len(fieldNames))
for _, name := range fieldNames {
columns = append(columns, name)
}
for _, existingColumns := range t.uniqueTogether {
if equal(existingColumns, columns) {
return t
}
}
t.uniqueTogether = append(t.uniqueTogether, columns)
t.ResetSql()
return t
}
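// Illustrative sketch of a composite uniqueness constraint on a hypothetical
// membership mapping: no two rows may share the same (AccountId, GroupId)
// pair. The fields carry no db tags, so (assuming the usual gorp behavior of
// defaulting a column name to its field name) the identifiers in the generated
// DDL match the names passed here.
func exampleUniqueTogether(dbmap *DbMap) *TableMap {
	type exampleMembership struct {
		Id        int64
		AccountId int64
		GroupId   int64
	}
	return dbmap.AddTableWithName(exampleMembership{}, "memberships").
		SetKeys(true, "Id").
		SetUniqueTogether("AccountId", "GroupId")
}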
// ColMap returns the ColumnMap pointer matching the given struct field
// name. It panics if the struct does not contain a field matching this
// name.
func (t *TableMap) ColMap(field string) *ColumnMap {
col := colMapOrNil(t, field)
if col == nil {
e := fmt.Sprintf("No ColumnMap in table %s type %s with field %s",
t.TableName, t.gotype.Name(), field)
panic(e)
}
return col
}
func colMapOrNil(t *TableMap, field string) *ColumnMap {
for _, col := range t.Columns {
if col.fieldName == field || col.ColumnName == field {
return col
}
}
return nil
}
// IdxMap returns the IndexMap pointer matching the given index name.
func (t *TableMap) IdxMap(field string) *IndexMap {
for _, idx := range t.indexes {
if idx.IndexName == field {
return idx
}
}
return nil
}
// AddIndex registers an index with sqldb for the specified table using the
// given parameters. This operation is idempotent: if the index is already
// mapped, the existing *IndexMap is returned.
// Panics if one of the given index columns does not exist.
//
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
func (t *TableMap) AddIndex(name string, idxtype string, columns []string) *IndexMap {
// check if we have a index with this name already
for _, idx := range t.indexes {
if idx.IndexName == name {
return idx
}
}
for _, icol := range columns {
if res := t.ColMap(icol); res == nil {
e := fmt.Sprintf("No ColumnName in table %s to create index on", t.TableName)
panic(e)
}
}
idx := &IndexMap{IndexName: name, Unique: false, IndexType: idxtype, columns: columns}
t.indexes = append(t.indexes, idx)
t.ResetSql()
return idx
}
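// Illustrative sketch of registering an index over the column mapped from a
// hypothetical Email field. The index type string is passed through to the
// dialect (upstream gorp's dialects expect values such as "Btree"), and the
// index itself is built by a separate DbMap-level call (CreateIndex in
// upstream gorp), not by AddIndex.
func exampleAddIndex(t *TableMap) *IndexMap {
	return t.AddIndex("idx_users_email", "Btree", []string{"Email"})
}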
// SetVersionCol sets the column to use as the Version field. By default
// the "Version" field is used. Returns the column found, or panics
// if the struct does not contain a field matching this name.
//
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
func (t *TableMap) SetVersionCol(field string) *ColumnMap {
c := t.ColMap(field)
t.version = c
t.ResetSql()
return c
}
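// Illustrative sketch of optimistic locking: nominate an integer field on a
// hypothetical struct as the version column. With this set, bindUpdate and
// bindDelete (in table_bindings.go) add the version to the WHERE clause and
// bump it on a successful update.
func exampleVersionCol(dbmap *DbMap) *ColumnMap {
	type exampleDoc struct {
		Id      int64
		Body    string
		Version int64
	}
	return dbmap.AddTableWithName(exampleDoc{}, "docs").
		SetKeys(true, "Id").
		SetVersionCol("Version")
}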
// SqlForCreate gets a sequence of SQL commands that will create
// the specified table and any associated schema
func (t *TableMap) SqlForCreate(ifNotExists bool) string {
s := bytes.Buffer{}
dialect := t.dbmap.Dialect
if strings.TrimSpace(t.SchemaName) != "" {
schemaCreate := "create schema"
if ifNotExists {
s.WriteString(dialect.IfSchemaNotExists(schemaCreate, t.SchemaName))
} else {
s.WriteString(schemaCreate)
}
s.WriteString(fmt.Sprintf(" %s;", t.SchemaName))
}
tableCreate := "create table"
if ifNotExists {
s.WriteString(dialect.IfTableNotExists(tableCreate, t.SchemaName, t.TableName))
} else {
s.WriteString(tableCreate)
}
s.WriteString(fmt.Sprintf(" %s (", dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
x := 0
for _, col := range t.Columns {
if !col.Transient {
if x > 0 {
s.WriteString(", ")
}
stype := dialect.ToSqlType(col.gotype, col.MaxSize, col.isAutoIncr)
s.WriteString(fmt.Sprintf("%s %s", dialect.QuoteField(col.ColumnName), stype))
if col.isPK || col.isNotNull {
s.WriteString(" not null")
}
if col.isPK && len(t.keys) == 1 {
s.WriteString(" primary key")
}
if col.Unique {
s.WriteString(" unique")
}
if col.isAutoIncr {
s.WriteString(fmt.Sprintf(" %s", dialect.AutoIncrStr()))
}
x++
}
}
if len(t.keys) > 1 {
s.WriteString(", primary key (")
for x := range t.keys {
if x > 0 {
s.WriteString(", ")
}
s.WriteString(dialect.QuoteField(t.keys[x].ColumnName))
}
s.WriteString(")")
}
if len(t.uniqueTogether) > 0 {
for _, columns := range t.uniqueTogether {
s.WriteString(", unique (")
for i, column := range columns {
if i > 0 {
s.WriteString(", ")
}
s.WriteString(dialect.QuoteField(column))
}
s.WriteString(")")
}
}
s.WriteString(") ")
s.WriteString(dialect.CreateTableSuffix())
s.WriteString(dialect.QuerySuffix())
return s.String()
}
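// Illustrative sketch: render the CREATE TABLE statement for a freshly mapped,
// hypothetical struct. CreateTables() (used by the tests in this change)
// issues equivalent DDL for every registered table; calling SqlForCreate
// directly is mainly useful for logging or for writing migration files.
func exampleCreateSql(dbmap *DbMap) string {
	type exampleTag struct {
		Id   int64
		Name string
	}
	return dbmap.AddTableWithName(exampleTag{}, "tags").
		SetKeys(true, "Id").
		SqlForCreate(true) // "create table if not exists ..." for the dialect
}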
func equal(a, b []string) bool {
if len(a) != len(b) {
return false
}
for i := range a {
if a[i] != b[i] {
return false
}
}
return true
}

308
gdb/sqldb/table_bindings.go Normal file
View File

@ -0,0 +1,308 @@
//
// table_bindings.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"bytes"
"fmt"
"reflect"
"sync"
)
// CustomScanner binds a database column value to a Go type
type CustomScanner struct {
// After a row is scanned, Holder will contain the value from the database column.
// Initialize the CustomScanner with the concrete Go type you wish the database
// driver to scan the raw column into.
Holder interface{}
// Target typically holds a pointer to the target struct field to bind the Holder
// value to.
Target interface{}
// Binder is a custom function that converts the holder value to the target type
// and sets target accordingly. This function should return error if a problem
// occurs converting the holder to the target.
Binder func(holder interface{}, target interface{}) error
}
// ColumnFilter is used to filter columns when selectively updating
type ColumnFilter func(*ColumnMap) bool
func acceptAllFilter(col *ColumnMap) bool {
return true
}
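// Illustrative sketch of a ColumnFilter: restrict an UpdateColumns call (see
// Transaction.UpdateColumns in transaction.go) to two hypothetical columns.
// Because bindUpdate builds its plan inside plan.once.Do (below), the filter
// in effect for the first call determines the cached column list for that
// table.
func exampleOnlyNameAndEmail(col *ColumnMap) bool {
	return col.ColumnName == "name" || col.ColumnName == "email"
}

// Usage, with u being a pointer to a mapped struct:
//	count, err := tx.UpdateColumns(exampleOnlyNameAndEmail, u)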
// Bind is called automatically by sqldb after Scan()
func (me CustomScanner) Bind() error {
return me.Binder(me.Holder, me.Target)
}
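// Illustrative sketch of building a CustomScanner by hand, as a
// TypeConverter.FromDb implementation might: the driver scans the column into
// an int64 holder, and Bind() converts it into the *bool target. The field
// this is bound to is hypothetical.
func exampleBoolScanner(target *bool) CustomScanner {
	return CustomScanner{
		Holder: new(int64),
		Target: target,
		Binder: func(holder interface{}, target interface{}) error {
			h, ok := holder.(*int64)
			if !ok {
				return fmt.Errorf("expected *int64 holder, got %T", holder)
			}
			b, ok := target.(*bool)
			if !ok {
				return fmt.Errorf("expected *bool target, got %T", target)
			}
			*b = *h != 0
			return nil
		},
	}
}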
type bindPlan struct {
query string
argFields []string
keyFields []string
versField string
autoIncrIdx int
autoIncrFieldName string
once sync.Once
}
func (plan *bindPlan) createBindInstance(elem reflect.Value, conv TypeConverter) (bindInstance, error) {
bi := bindInstance{query: plan.query, autoIncrIdx: plan.autoIncrIdx, autoIncrFieldName: plan.autoIncrFieldName, versField: plan.versField}
if plan.versField != "" {
bi.existingVersion = elem.FieldByName(plan.versField).Int()
}
var err error
for i := 0; i < len(plan.argFields); i++ {
k := plan.argFields[i]
if k == versFieldConst {
newVer := bi.existingVersion + 1
bi.args = append(bi.args, newVer)
if bi.existingVersion == 0 {
elem.FieldByName(plan.versField).SetInt(int64(newVer))
}
} else {
val := elem.FieldByName(k).Interface()
if conv != nil {
val, err = conv.ToDb(val)
if err != nil {
return bindInstance{}, err
}
}
bi.args = append(bi.args, val)
}
}
for i := 0; i < len(plan.keyFields); i++ {
k := plan.keyFields[i]
val := elem.FieldByName(k).Interface()
if conv != nil {
val, err = conv.ToDb(val)
if err != nil {
return bindInstance{}, err
}
}
bi.keys = append(bi.keys, val)
}
return bi, nil
}
type bindInstance struct {
query string
args []interface{}
keys []interface{}
existingVersion int64
versField string
autoIncrIdx int
autoIncrFieldName string
}
func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) {
plan := &t.insertPlan
plan.once.Do(func() {
plan.autoIncrIdx = -1
s := bytes.Buffer{}
s2 := bytes.Buffer{}
s.WriteString(fmt.Sprintf("insert into %s (", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
x := 0
first := true
for y := range t.Columns {
col := t.Columns[y]
if !(col.isAutoIncr && t.dbmap.Dialect.AutoIncrBindValue() == "") {
if !col.Transient {
if !first {
s.WriteString(",")
s2.WriteString(",")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
if col.isAutoIncr {
s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue())
plan.autoIncrIdx = y
plan.autoIncrFieldName = col.fieldName
} else {
if col.DefaultValue == "" {
s2.WriteString(t.dbmap.Dialect.BindVar(x))
if col == t.version {
plan.versField = col.fieldName
plan.argFields = append(plan.argFields, versFieldConst)
} else {
plan.argFields = append(plan.argFields, col.fieldName)
}
x++
} else {
s2.WriteString(col.DefaultValue)
}
}
first = false
}
} else {
plan.autoIncrIdx = y
plan.autoIncrFieldName = col.fieldName
}
}
s.WriteString(") values (")
s.WriteString(s2.String())
s.WriteString(")")
if plan.autoIncrIdx > -1 {
s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx]))
}
s.WriteString(t.dbmap.Dialect.QuerySuffix())
plan.query = s.String()
})
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
}
func (t *TableMap) bindUpdate(elem reflect.Value, colFilter ColumnFilter) (bindInstance, error) {
if colFilter == nil {
colFilter = acceptAllFilter
}
plan := &t.updatePlan
plan.once.Do(func() {
s := bytes.Buffer{}
s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
x := 0
for y := range t.Columns {
col := t.Columns[y]
if !col.isAutoIncr && !col.Transient && colFilter(col) {
if x > 0 {
s.WriteString(", ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
if col == t.version {
plan.versField = col.fieldName
plan.argFields = append(plan.argFields, versFieldConst)
} else {
plan.argFields = append(plan.argFields, col.fieldName)
}
x++
}
}
s.WriteString(" where ")
for y := range t.keys {
col := t.keys[y]
if y > 0 {
s.WriteString(" and ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.argFields = append(plan.argFields, col.fieldName)
plan.keyFields = append(plan.keyFields, col.fieldName)
x++
}
if plan.versField != "" {
s.WriteString(" and ")
s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.argFields = append(plan.argFields, plan.versField)
}
s.WriteString(t.dbmap.Dialect.QuerySuffix())
plan.query = s.String()
})
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
}
func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) {
plan := &t.deletePlan
plan.once.Do(func() {
s := bytes.Buffer{}
s.WriteString(fmt.Sprintf("delete from %s", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
for y := range t.Columns {
col := t.Columns[y]
if !col.Transient {
if col == t.version {
plan.versField = col.fieldName
}
}
}
s.WriteString(" where ")
for x := range t.keys {
k := t.keys[x]
if x > 0 {
s.WriteString(" and ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(k.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.keyFields = append(plan.keyFields, k.fieldName)
plan.argFields = append(plan.argFields, k.fieldName)
}
if plan.versField != "" {
s.WriteString(" and ")
s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(len(plan.argFields)))
plan.argFields = append(plan.argFields, plan.versField)
}
s.WriteString(t.dbmap.Dialect.QuerySuffix())
plan.query = s.String()
})
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
}
func (t *TableMap) bindGet() *bindPlan {
plan := &t.getPlan
plan.once.Do(func() {
s := bytes.Buffer{}
s.WriteString("select ")
x := 0
for _, col := range t.Columns {
if !col.Transient {
if x > 0 {
s.WriteString(",")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
plan.argFields = append(plan.argFields, col.fieldName)
x++
}
}
s.WriteString(" from ")
s.WriteString(t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))
s.WriteString(" where ")
for x := range t.keys {
col := t.keys[x]
if x > 0 {
s.WriteString(" and ")
}
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
s.WriteString("=")
s.WriteString(t.dbmap.Dialect.BindVar(x))
plan.keyFields = append(plan.keyFields, col.fieldName)
}
s.WriteString(t.dbmap.Dialect.QuerySuffix())
plan.query = s.String()
})
return plan
}

24
gdb/sqldb/test_all.sh Executable file
View File

@ -0,0 +1,24 @@
#!/bin/bash -ex
# on macs, you may need to:
# export GOBUILDFLAG=-ldflags -linkmode=external
echo "Running unit tests"
go test -race
echo "Testing against postgres"
export SQLDB_TEST_DSN="host=127.0.0.1 user=testuser password=123 dbname=testdb sslmode=disable"
export SQLDB_TEST_DIALECT=postgres
go test -tags integration $GOBUILDFLAG $@ .
echo "Testing against sqlite"
export SQLDB_TEST_DSN=/tmp/testdb.bin
export SQLDB_TEST_DIALECT=sqlite
go test -tags integration $GOBUILDFLAG $@ .
rm -f /tmp/testdb.bin
echo "Testing against mysql"
# export SQLDB_TEST_DSN="testuser:123@tcp(127.0.0.1:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local"
export SQLDB_TEST_DSN="testuser:123@tcp(127.0.0.1:3306)/testdb"
export SQLDB_TEST_DIALECT=mysql
go test -tags integration $GOBUILDFLAG $@ .

242
gdb/sqldb/transaction.go Normal file
View File

@ -0,0 +1,242 @@
//
// transaction.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"context"
"database/sql"
"time"
)
// Transaction represents a database transaction.
// Insert/Update/Delete/Get/Exec operations will be run in the context
// of that transaction. Transactions should be terminated with
// a call to Commit() or Rollback()
type Transaction struct {
ctx context.Context
dbmap *DbMap
tx *sql.Tx
closed bool
}
func (t *Transaction) WithContext(ctx context.Context) SqlExecutor {
copy := &Transaction{}
*copy = *t
copy.ctx = ctx
return copy
}
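// Illustrative sketch of the intended lifecycle: obtain the transaction from
// DbMap.Begin() (as the tests in this change do), run writes through it, and
// finish with exactly one Commit or Rollback; a second call returns
// sql.ErrTxDone. The row argument is a pointer to a hypothetical mapped
// struct.
func exampleTxLifecycle(dbmap *DbMap, row interface{}) error {
	tx, err := dbmap.Begin()
	if err != nil {
		return err
	}
	if err := tx.Insert(row); err != nil {
		tx.Rollback() // best effort; the insert error is the one worth returning
		return err
	}
	return tx.Commit()
}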
// Insert has the same behavior as DbMap.Insert(), but runs in a transaction.
func (t *Transaction) Insert(list ...interface{}) error {
return insert(t.dbmap, t, list...)
}
// Update has the same behavior as DbMap.Update(), but runs in a transaction.
func (t *Transaction) Update(list ...interface{}) (int64, error) {
return update(t.dbmap, t, nil, list...)
}
// UpdateColumns has the same behavior as DbMap.UpdateColumns(), but runs in a transaction.
func (t *Transaction) UpdateColumns(filter ColumnFilter, list ...interface{}) (int64, error) {
return update(t.dbmap, t, filter, list...)
}
// Delete has the same behavior as DbMap.Delete(), but runs in a transaction.
func (t *Transaction) Delete(list ...interface{}) (int64, error) {
return delete(t.dbmap, t, list...)
}
// Get has the same behavior as DbMap.Get(), but runs in a transaction.
func (t *Transaction) Get(i interface{}, keys ...interface{}) (interface{}, error) {
return get(t.dbmap, t, i, keys...)
}
// Select has the same behavior as DbMap.Select(), but runs in a transaction.
func (t *Transaction) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
if t.dbmap.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}
return hookedselect(t.dbmap, t, i, query, args...)
}
// Exec has the same behavior as DbMap.Exec(), but runs in a transaction.
func (t *Transaction) Exec(query string, args ...interface{}) (sql.Result, error) {
if t.dbmap.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}
if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, query, args...)
}
return maybeExpandNamedQueryAndExec(t, query, args...)
}
// SelectInt is a convenience wrapper around the sqldb.SelectInt function.
func (t *Transaction) SelectInt(query string, args ...interface{}) (int64, error) {
if t.dbmap.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}
return SelectInt(t, query, args...)
}
// SelectNullInt is a convenience wrapper around the sqldb.SelectNullInt function.
func (t *Transaction) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
if t.dbmap.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}
return SelectNullInt(t, query, args...)
}
// SelectFloat is a convenience wrapper around the sqldb.SelectFloat function.
func (t *Transaction) SelectFloat(query string, args ...interface{}) (float64, error) {
if t.dbmap.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}
return SelectFloat(t, query, args...)
}
// SelectNullFloat is a convenience wrapper around the sqldb.SelectNullFloat function.
func (t *Transaction) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
if t.dbmap.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}
return SelectNullFloat(t, query, args...)
}
// SelectStr is a convenience wrapper around the sqldb.SelectStr function.
func (t *Transaction) SelectStr(query string, args ...interface{}) (string, error) {
if t.dbmap.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}
return SelectStr(t, query, args...)
}
// SelectNullStr is a convenience wrapper around the sqldb.SelectNullStr function.
func (t *Transaction) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
if t.dbmap.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}
return SelectNullStr(t, query, args...)
}
// SelectOne is a convenience wrapper around the sqldb.SelectOne function.
func (t *Transaction) SelectOne(holder interface{}, query string, args ...interface{}) error {
if t.dbmap.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}
return SelectOne(t.dbmap, t, holder, query, args...)
}
// Commit commits the underlying database transaction.
func (t *Transaction) Commit() error {
if !t.closed {
t.closed = true
if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, "commit;")
}
return t.tx.Commit()
}
return sql.ErrTxDone
}
// Rollback rolls back the underlying database transaction.
func (t *Transaction) Rollback() error {
if !t.closed {
t.closed = true
if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, "rollback;")
}
return t.tx.Rollback()
}
return sql.ErrTxDone
}
// Savepoint creates a savepoint with the given name. The name is interpolated
// directly into the SQL SAVEPOINT statement, so you must sanitize it if it is
// derived from user input.
func (t *Transaction) Savepoint(name string) error {
query := "savepoint " + t.dbmap.Dialect.QuoteField(name)
if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, query, nil)
}
_, err := exec(t, query)
return err
}
// RollbackToSavepoint rolls back to the savepoint with the given name. The
// name is interpolated directly into the SQL SAVEPOINT statement, so you must
// sanitize it if it is derived from user input.
func (t *Transaction) RollbackToSavepoint(savepoint string) error {
query := "rollback to savepoint " + t.dbmap.Dialect.QuoteField(savepoint)
if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, query, nil)
}
_, err := exec(t, query)
return err
}
// ReleaseSavepoint releases the savepoint with the given name. The name is
// interpolated directly into the SQL statement, so you must sanitize
// it if it is derived from user input.
func (t *Transaction) ReleaseSavepoint(savepoint string) error {
query := "release savepoint " + t.dbmap.Dialect.QuoteField(savepoint)
if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, query, nil)
}
_, err := exec(t, query)
return err
}
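// Illustrative sketch of the savepoint helpers: wrap one risky statement so a
// failure only discards that portion of the enclosing transaction. The
// savepoint name and the audit_log statement are hypothetical; per the
// comments above, the name must not come from untrusted input.
func exampleWithSavepoint(tx *Transaction) error {
	if err := tx.Savepoint("bulk_load"); err != nil {
		return err
	}
	if _, err := tx.Exec(
		"insert into audit_log (msg) values (:Msg)",
		map[string]interface{}{"Msg": "loaded"},
	); err != nil {
		// Undo only the work since the savepoint; the outer transaction survives.
		return tx.RollbackToSavepoint("bulk_load")
	}
	return tx.ReleaseSavepoint("bulk_load")
}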
// Prepare has the same behavior as DbMap.Prepare(), but runs in a transaction.
func (t *Transaction) Prepare(query string) (*sql.Stmt, error) {
if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, query, nil)
}
return prepare(t, query)
}
func (t *Transaction) QueryRow(query string, args ...interface{}) *sql.Row {
if t.dbmap.ExpandSliceArgs {
expandSliceArgs(&query, args...)
}
if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, query, args...)
}
return queryRow(t, query, args...)
}
func (t *Transaction) Query(q string, args ...interface{}) (*sql.Rows, error) {
if t.dbmap.ExpandSliceArgs {
expandSliceArgs(&q, args...)
}
if t.dbmap.logger != nil {
now := time.Now()
defer t.dbmap.trace(now, q, args...)
}
return query(t, q, args...)
}

340
gdb/sqldb/transaction_test.go Normal file
View File

@ -0,0 +1,340 @@
//
// transaction_test.go
// Copyright (C) 2023 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
//go:build integration
// +build integration
package sqldb_test
import "testing"
func TestTransaction_Select_expandSliceArgs(t *testing.T) {
tests := []struct {
description string
query string
args []interface{}
wantLen int
}{
{
description: "it should handle slice placeholders correctly",
query: `
SELECT 1 FROM crazy_table
WHERE field1 = :Field1
AND field2 IN (:FieldStringList)
AND field3 IN (:FieldUIntList)
AND field4 IN (:FieldUInt8List)
AND field5 IN (:FieldUInt16List)
AND field6 IN (:FieldUInt32List)
AND field7 IN (:FieldUInt64List)
AND field8 IN (:FieldIntList)
AND field9 IN (:FieldInt8List)
AND field10 IN (:FieldInt16List)
AND field11 IN (:FieldInt32List)
AND field12 IN (:FieldInt64List)
AND field13 IN (:FieldFloat32List)
AND field14 IN (:FieldFloat64List)
`,
args: []interface{}{
map[string]interface{}{
"Field1": 123,
"FieldStringList": []string{"h", "e", "y"},
"FieldUIntList": []uint{1, 2, 3, 4},
"FieldUInt8List": []uint8{1, 2, 3, 4},
"FieldUInt16List": []uint16{1, 2, 3, 4},
"FieldUInt32List": []uint32{1, 2, 3, 4},
"FieldUInt64List": []uint64{1, 2, 3, 4},
"FieldIntList": []int{1, 2, 3, 4},
"FieldInt8List": []int8{1, 2, 3, 4},
"FieldInt16List": []int16{1, 2, 3, 4},
"FieldInt32List": []int32{1, 2, 3, 4},
"FieldInt64List": []int64{1, 2, 3, 4},
"FieldFloat32List": []float32{1, 2, 3, 4},
"FieldFloat64List": []float64{1, 2, 3, 4},
},
},
wantLen: 1,
},
{
description: "it should handle slice placeholders correctly with custom types",
query: `
SELECT 1 FROM crazy_table
WHERE field2 IN (:FieldStringList)
AND field12 IN (:FieldIntList)
`,
args: []interface{}{
map[string]interface{}{
"FieldStringList": customType1{"h", "e", "y"},
"FieldIntList": customType2{1, 2, 3, 4},
},
},
wantLen: 3,
},
}
type dataFormat struct {
Field1 int `db:"field1"`
Field2 string `db:"field2"`
Field3 uint `db:"field3"`
Field4 uint8 `db:"field4"`
Field5 uint16 `db:"field5"`
Field6 uint32 `db:"field6"`
Field7 uint64 `db:"field7"`
Field8 int `db:"field8"`
Field9 int8 `db:"field9"`
Field10 int16 `db:"field10"`
Field11 int32 `db:"field11"`
Field12 int64 `db:"field12"`
Field13 float32 `db:"field13"`
Field14 float64 `db:"field14"`
}
dbmap := newDBMap(t)
dbmap.ExpandSliceArgs = true
dbmap.AddTableWithName(dataFormat{}, "crazy_table")
err := dbmap.CreateTables()
if err != nil {
panic(err)
}
defer dropAndClose(dbmap)
err = dbmap.Insert(
&dataFormat{
Field1: 123,
Field2: "h",
Field3: 1,
Field4: 1,
Field5: 1,
Field6: 1,
Field7: 1,
Field8: 1,
Field9: 1,
Field10: 1,
Field11: 1,
Field12: 1,
Field13: 1,
Field14: 1,
},
&dataFormat{
Field1: 124,
Field2: "e",
Field3: 2,
Field4: 2,
Field5: 2,
Field6: 2,
Field7: 2,
Field8: 2,
Field9: 2,
Field10: 2,
Field11: 2,
Field12: 2,
Field13: 2,
Field14: 2,
},
&dataFormat{
Field1: 125,
Field2: "y",
Field3: 3,
Field4: 3,
Field5: 3,
Field6: 3,
Field7: 3,
Field8: 3,
Field9: 3,
Field10: 3,
Field11: 3,
Field12: 3,
Field13: 3,
Field14: 3,
},
)
if err != nil {
t.Fatal(err)
}
for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
tx, err := dbmap.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
var dummy []int
_, err = tx.Select(&dummy, tt.query, tt.args...)
if err != nil {
t.Fatal(err)
}
if len(dummy) != tt.wantLen {
t.Errorf("wrong result count\ngot: %d\nwant: %d", len(dummy), tt.wantLen)
}
})
}
}
func TestTransaction_Exec_expandSliceArgs(t *testing.T) {
tests := []struct {
description string
query string
args []interface{}
wantLen int
}{
{
description: "it should handle slice placeholders correctly",
query: `
DELETE FROM crazy_table
WHERE field1 = :Field1
AND field2 IN (:FieldStringList)
AND field3 IN (:FieldUIntList)
AND field4 IN (:FieldUInt8List)
AND field5 IN (:FieldUInt16List)
AND field6 IN (:FieldUInt32List)
AND field7 IN (:FieldUInt64List)
AND field8 IN (:FieldIntList)
AND field9 IN (:FieldInt8List)
AND field10 IN (:FieldInt16List)
AND field11 IN (:FieldInt32List)
AND field12 IN (:FieldInt64List)
AND field13 IN (:FieldFloat32List)
AND field14 IN (:FieldFloat64List)
`,
args: []interface{}{
map[string]interface{}{
"Field1": 123,
"FieldStringList": []string{"h", "e", "y"},
"FieldUIntList": []uint{1, 2, 3, 4},
"FieldUInt8List": []uint8{1, 2, 3, 4},
"FieldUInt16List": []uint16{1, 2, 3, 4},
"FieldUInt32List": []uint32{1, 2, 3, 4},
"FieldUInt64List": []uint64{1, 2, 3, 4},
"FieldIntList": []int{1, 2, 3, 4},
"FieldInt8List": []int8{1, 2, 3, 4},
"FieldInt16List": []int16{1, 2, 3, 4},
"FieldInt32List": []int32{1, 2, 3, 4},
"FieldInt64List": []int64{1, 2, 3, 4},
"FieldFloat32List": []float32{1, 2, 3, 4},
"FieldFloat64List": []float64{1, 2, 3, 4},
},
},
wantLen: 1,
},
{
description: "it should handle slice placeholders correctly with custom types",
query: `
DELETE FROM crazy_table
WHERE field2 IN (:FieldStringList)
AND field12 IN (:FieldIntList)
`,
args: []interface{}{
map[string]interface{}{
"FieldStringList": customType1{"h", "e", "y"},
"FieldIntList": customType2{1, 2, 3, 4},
},
},
wantLen: 3,
},
}
type dataFormat struct {
Field1 int `db:"field1"`
Field2 string `db:"field2"`
Field3 uint `db:"field3"`
Field4 uint8 `db:"field4"`
Field5 uint16 `db:"field5"`
Field6 uint32 `db:"field6"`
Field7 uint64 `db:"field7"`
Field8 int `db:"field8"`
Field9 int8 `db:"field9"`
Field10 int16 `db:"field10"`
Field11 int32 `db:"field11"`
Field12 int64 `db:"field12"`
Field13 float32 `db:"field13"`
Field14 float64 `db:"field14"`
}
dbmap := newDBMap(t)
dbmap.ExpandSliceArgs = true
dbmap.AddTableWithName(dataFormat{}, "crazy_table")
err := dbmap.CreateTables()
if err != nil {
panic(err)
}
defer dropAndClose(dbmap)
err = dbmap.Insert(
&dataFormat{
Field1: 123,
Field2: "h",
Field3: 1,
Field4: 1,
Field5: 1,
Field6: 1,
Field7: 1,
Field8: 1,
Field9: 1,
Field10: 1,
Field11: 1,
Field12: 1,
Field13: 1,
Field14: 1,
},
&dataFormat{
Field1: 124,
Field2: "e",
Field3: 2,
Field4: 2,
Field5: 2,
Field6: 2,
Field7: 2,
Field8: 2,
Field9: 2,
Field10: 2,
Field11: 2,
Field12: 2,
Field13: 2,
Field14: 2,
},
&dataFormat{
Field1: 125,
Field2: "y",
Field3: 3,
Field4: 3,
Field5: 3,
Field6: 3,
Field7: 3,
Field8: 3,
Field9: 3,
Field10: 3,
Field11: 3,
Field12: 3,
Field13: 3,
Field14: 3,
},
)
if err != nil {
t.Fatal(err)
}
for _, tt := range tests {
t.Run(tt.description, func(t *testing.T) {
tx, err := dbmap.Begin()
if err != nil {
t.Fatal(err)
}
defer tx.Rollback()
_, err = tx.Exec(tt.query, tt.args...)
if err != nil {
t.Fatal(err)
}
})
}
}

8
go.mod
View File

@ -8,17 +8,20 @@ require (
github.com/go-redis/redis/v8 v8.11.5 github.com/go-redis/redis/v8 v8.11.5
github.com/go-sql-driver/mysql v1.7.1 github.com/go-sql-driver/mysql v1.7.1
github.com/hibiken/asynq v0.24.1 github.com/hibiken/asynq v0.24.1
github.com/jmoiron/sqlx v1.3.5
github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible
github.com/lib/pq v1.10.9 github.com/lib/pq v1.10.9
github.com/mattn/go-runewidth v0.0.14 github.com/mattn/go-runewidth v0.0.14
github.com/mattn/go-sqlite3 v1.14.6
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/poy/onpar v0.3.2
github.com/rs/xid v1.5.0 github.com/rs/xid v1.5.0
github.com/stretchr/testify v1.8.3
go.mongodb.org/mongo-driver v1.11.7 go.mongodb.org/mongo-driver v1.11.7
go.uber.org/zap v1.25.0 go.uber.org/zap v1.25.0
golang.org/x/crypto v0.10.0 golang.org/x/crypto v0.10.0
gopkg.in/natefinch/lumberjack.v2 v2.2.1 gopkg.in/natefinch/lumberjack.v2 v2.2.1
gopkg.in/yaml.v2 v2.4.0 gopkg.in/yaml.v2 v2.4.0
gopkg.in/yaml.v3 v3.0.1
) )
require ( require (
@ -26,6 +29,7 @@ require (
github.com/bytedance/sonic v1.9.1 // indirect github.com/bytedance/sonic v1.9.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
github.com/gabriel-vasile/mimetype v1.4.2 // indirect github.com/gabriel-vasile/mimetype v1.4.2 // indirect
github.com/gin-contrib/sse v0.1.0 // indirect github.com/gin-contrib/sse v0.1.0 // indirect
@ -47,6 +51,7 @@ require (
github.com/modern-go/reflect2 v1.0.2 // indirect github.com/modern-go/reflect2 v1.0.2 // indirect
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect
github.com/pelletier/go-toml/v2 v2.0.8 // indirect github.com/pelletier/go-toml/v2 v2.0.8 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/redis/go-redis/v9 v9.0.3 // indirect github.com/redis/go-redis/v9 v9.0.3 // indirect
github.com/rivo/uniseg v0.2.0 // indirect github.com/rivo/uniseg v0.2.0 // indirect
github.com/robfig/cron/v3 v3.0.1 // indirect github.com/robfig/cron/v3 v3.0.1 // indirect
@ -65,5 +70,4 @@ require (
golang.org/x/text v0.10.0 // indirect golang.org/x/text v0.10.0 // indirect
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 // indirect golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 // indirect
google.golang.org/protobuf v1.30.0 // indirect google.golang.org/protobuf v1.30.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
) )

7
go.sum
View File

@ -1,3 +1,4 @@
git.sr.ht/~nelsam/hel v0.4.3 h1:9W0zz8zv8CZhFsp8r9Wq6c8gFemBdtMurjZU/JKfvfM=
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw=
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0=
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A= github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
@ -36,7 +37,6 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
@ -57,8 +57,6 @@ github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs=
github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/hibiken/asynq v0.24.1 h1:+5iIEAyA9K/lcSPvx3qoPtsKJeKI5u9aOIvUmSsazEw= github.com/hibiken/asynq v0.24.1 h1:+5iIEAyA9K/lcSPvx3qoPtsKJeKI5u9aOIvUmSsazEw=
github.com/hibiken/asynq v0.24.1/go.mod h1:u5qVeSbrnfT+vtG5Mq8ZPzQu/BmCKMHvTGb91uy9Tts= github.com/hibiken/asynq v0.24.1/go.mod h1:u5qVeSbrnfT+vtG5Mq8ZPzQu/BmCKMHvTGb91uy9Tts=
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4= github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4=
github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc= github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc=
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
@ -81,7 +79,6 @@ github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible h1:Y6sqxHMyB1D2YSzWkL
github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible/go.mod h1:ZQnN8lSECaebrkQytbHj4xNgtg8CR7RYXnPok8e0EHA= github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible/go.mod h1:ZQnN8lSECaebrkQytbHj4xNgtg8CR7RYXnPok8e0EHA=
github.com/lestrrat-go/strftime v1.0.6 h1:CFGsDEt1pOpFNU+TJB0nhz9jl+K0hZSLE205AhTIGQQ= github.com/lestrrat-go/strftime v1.0.6 h1:CFGsDEt1pOpFNU+TJB0nhz9jl+K0hZSLE205AhTIGQQ=
github.com/lestrrat-go/strftime v1.0.6/go.mod h1:f7jQKgV5nnJpYgdEasS+/y7EsTb8ykN2z68n3TtcTaw= github.com/lestrrat-go/strftime v1.0.6/go.mod h1:f7jQKgV5nnJpYgdEasS+/y7EsTb8ykN2z68n3TtcTaw=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
@ -106,6 +103,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/poy/onpar v0.3.2 h1:yo8ZRqU3C4RlvkXPWUWfonQiTodAgpKQZ1g8VTNU9xU=
github.com/poy/onpar v0.3.2/go.mod h1:6XDWG8DJ1HsFX6/Btn0pHl3Jz5d1SEEGNZ5N1gtYo+I=
github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k= github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k=
github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk=
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=