Compare commits
3 Commits
36f9acb8f7
...
02b5c78a45
Author | SHA1 | Date | |
---|---|---|---|
02b5c78a45 | |||
99e2411cee | |||
2822110dec |
@ -105,7 +105,7 @@ tmp_dir = "var/tmp"
|
||||
|
||||
[build]
|
||||
# 只需要写你平常编译使用的shell命令。你也可以使用 make
|
||||
cmd = "go build -o ./var/tmp/main entry/web/main.go"
|
||||
cmd = "go build -o ./var/tmp/main entry/http/main.go"
|
||||
# 由 cmd 命令得到的二进制文件名
|
||||
bin = "var/tmp/main"
|
||||
# 自定义的二进制,可以添加额外的编译标识例如添加 GIN_MODE=release
|
||||
|
@ -1,749 +0,0 @@
|
||||
//
|
||||
// adapter_sqlx.go
|
||||
// Copyright (C) 2022 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package gcasbin
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
|
||||
"github.com/casbin/casbin/v2/model"
|
||||
"github.com/casbin/casbin/v2/persist"
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
// defaultTableName if tableName == "", the Adapter will use this default table name.
|
||||
const defaultTableName = "casbin_rule"
|
||||
|
||||
// maxParamLength .
|
||||
const maxParamLength = 7
|
||||
|
||||
// general sql
|
||||
const (
|
||||
sqlCreateTable = `
|
||||
CREATE TABLE %[1]s(
|
||||
p_type VARCHAR(32),
|
||||
v0 VARCHAR(255),
|
||||
v1 VARCHAR(255),
|
||||
v2 VARCHAR(255),
|
||||
v3 VARCHAR(255),
|
||||
v4 VARCHAR(255),
|
||||
v5 VARCHAR(255)
|
||||
);
|
||||
CREATE INDEX idx_%[1]s ON %[1]s (p_type,v0,v1);`
|
||||
sqlTruncateTable = "TRUNCATE TABLE %s"
|
||||
sqlIsTableExist = "SELECT 1 FROM %s"
|
||||
sqlInsertRow = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (?,?,?,?,?,?,?)"
|
||||
sqlUpdateRow = "UPDATE %s SET p_type=?,v0=?,v1=?,v2=?,v3=?,v4=?,v5=? WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
|
||||
sqlDeleteAll = "DELETE FROM %s"
|
||||
sqlDeleteRow = "DELETE FROM %s WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
|
||||
sqlDeleteByArgs = "DELETE FROM %s WHERE p_type=?"
|
||||
sqlSelectAll = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s"
|
||||
sqlSelectWhere = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s WHERE "
|
||||
)
|
||||
|
||||
// for Sqlite3
|
||||
const (
|
||||
sqlCreateTableSqlite3 = `
|
||||
CREATE TABLE IF NOT EXISTS %[1]s(
|
||||
p_type VARCHAR(32) DEFAULT '' NOT NULL,
|
||||
v0 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v1 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v2 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v3 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v4 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v5 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
CHECK (TYPEOF("p_type") = "text" AND
|
||||
LENGTH("p_type") <= 32),
|
||||
CHECK (TYPEOF("v0") = "text" AND
|
||||
LENGTH("v0") <= 255),
|
||||
CHECK (TYPEOF("v1") = "text" AND
|
||||
LENGTH("v1") <= 255),
|
||||
CHECK (TYPEOF("v2") = "text" AND
|
||||
LENGTH("v2") <= 255),
|
||||
CHECK (TYPEOF("v3") = "text" AND
|
||||
LENGTH("v3") <= 255),
|
||||
CHECK (TYPEOF("v4") = "text" AND
|
||||
LENGTH("v4") <= 255),
|
||||
CHECK (TYPEOF("v5") = "text" AND
|
||||
LENGTH("v5") <= 255)
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
|
||||
sqlTruncateTableSqlite3 = "DROP TABLE IF EXISTS %[1]s;" + sqlCreateTableSqlite3
|
||||
)
|
||||
|
||||
// for Mysql
|
||||
const (
|
||||
sqlCreateTableMysql = `
|
||||
CREATE TABLE IF NOT EXISTS %[1]s(
|
||||
p_type VARCHAR(32) DEFAULT '' NOT NULL,
|
||||
v0 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v1 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v2 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v3 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v4 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v5 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
INDEX idx_%[1]s (p_type,v0,v1)
|
||||
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;`
|
||||
)
|
||||
|
||||
// for Postgres
|
||||
const (
|
||||
sqlCreateTablePostgres = `
|
||||
CREATE TABLE IF NOT EXISTS %[1]s(
|
||||
p_type VARCHAR(32) DEFAULT '' NOT NULL,
|
||||
v0 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v1 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v2 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v3 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v4 VARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v5 VARCHAR(255) DEFAULT '' NOT NULL
|
||||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
|
||||
sqlInsertRowPostgres = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES ($1,$2,$3,$4,$5,$6,$7)"
|
||||
sqlUpdateRowPostgres = "UPDATE %s SET p_type=$1,v0=$2,v1=$3,v2=$4,v3=$5,v4=$6,v5=$7 WHERE p_type=$8 AND v0=$9 AND v1=$10 AND v2=$11 AND v3=$12 AND v4=$13 AND v5=$14"
|
||||
sqlDeleteRowPostgres = "DELETE FROM %s WHERE p_type=$1 AND v0=$2 AND v1=$3 AND v2=$4 AND v3=$5 AND v4=$6 AND v5=$7"
|
||||
)
|
||||
|
||||
// for Sqlserver
|
||||
const (
|
||||
sqlCreateTableSqlserver = `
|
||||
CREATE TABLE %[1]s(
|
||||
p_type NVARCHAR(32) DEFAULT '' NOT NULL,
|
||||
v0 NVARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v1 NVARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v2 NVARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v3 NVARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v4 NVARCHAR(255) DEFAULT '' NOT NULL,
|
||||
v5 NVARCHAR(255) DEFAULT '' NOT NULL
|
||||
);
|
||||
CREATE INDEX idx_%[1]s ON %[1]s (p_type, v0, v1);`
|
||||
sqlInsertRowSqlserver = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (@p1,@p2,@p3,@p4,@p5,@p6,@p7)"
|
||||
sqlUpdateRowSqlserver = "UPDATE %s SET p_type=@p1,v0=@p2,v1=@p3,v2=@p4,v3=@p5,v4=@p6,v5=@p7 WHERE p_type=@p8 AND v0=@p9 AND v1=@p10 AND v2=@p11 AND v3=@p12 AND v4=@p13 AND v5=@p14"
|
||||
sqlDeleteRowSqlserver = "DELETE FROM %s WHERE p_type=@p1 AND v0=@p2 AND v1=@p3 AND v2=@p4 AND v3=@p5 AND v4=@p6 AND v5=@p7"
|
||||
)
|
||||
|
||||
// CasbinRule defines the casbin rule model.
|
||||
// It used for save or load policy lines from sqlx connected database.
|
||||
type SqlCasbinRule struct {
|
||||
PType string `db:"p_type"`
|
||||
V0 string `db:"v0"`
|
||||
V1 string `db:"v1"`
|
||||
V2 string `db:"v2"`
|
||||
V3 string `db:"v3"`
|
||||
V4 string `db:"v4"`
|
||||
V5 string `db:"v5"`
|
||||
}
|
||||
|
||||
// Adapter define the sqlx adapter for Casbin.
|
||||
// It can load policy lines or save policy lines from sqlx connected database.
|
||||
type SqlAdapter struct {
|
||||
db *sqlx.DB
|
||||
ctx context.Context
|
||||
tableName string
|
||||
|
||||
isFiltered bool
|
||||
|
||||
SqlCreateTable string
|
||||
SqlTruncateTable string
|
||||
SqlIsTableExist string
|
||||
SqlInsertRow string
|
||||
SqlUpdateRow string
|
||||
SqlDeleteAll string
|
||||
SqlDeleteRow string
|
||||
SqlDeleteByArgs string
|
||||
SqlSelectAll string
|
||||
SqlSelectWhere string
|
||||
}
|
||||
|
||||
// Filter defines the filtering rules for a FilteredAdapter's policy.
|
||||
// Empty values are ignored, but all others must match the filter.
|
||||
type SqlFilter struct {
|
||||
PType []string
|
||||
V0 []string
|
||||
V1 []string
|
||||
V2 []string
|
||||
V3 []string
|
||||
V4 []string
|
||||
V5 []string
|
||||
}
|
||||
|
||||
// NewAdapter the constructor for Adapter.
|
||||
// db should connected to database and controlled by user.
|
||||
// If tableName == "", the Adapter will automatically create a table named "casbin_rule".
|
||||
func NewSqlAdapter(db *sqlx.DB, tableName string) (*SqlAdapter, error) {
|
||||
return NewSqlAdapterContext(context.Background(), db, tableName)
|
||||
}
|
||||
|
||||
// NewAdapterContext the constructor for Adapter.
|
||||
// db should connected to database and controlled by user.
|
||||
// If tableName == "", the Adapter will automatically create a table named "casbin_rule".
|
||||
func NewSqlAdapterContext(ctx context.Context, db *sqlx.DB, tableName string) (*SqlAdapter, error) {
|
||||
if db == nil {
|
||||
return nil, errors.New("db is nil")
|
||||
}
|
||||
|
||||
// check db connecting
|
||||
err := db.PingContext(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch db.DriverName() {
|
||||
case "oci8", "ora", "goracle":
|
||||
return nil, errors.New("sqlxadapter: please checkout 'oracle' branch")
|
||||
}
|
||||
|
||||
if tableName == "" {
|
||||
tableName = defaultTableName
|
||||
}
|
||||
|
||||
adapter := SqlAdapter{
|
||||
db: db,
|
||||
ctx: ctx,
|
||||
tableName: tableName,
|
||||
}
|
||||
|
||||
// generate different databases sql
|
||||
adapter.genSQL()
|
||||
|
||||
if !adapter.IsTableExist() {
|
||||
if err = adapter.CreateTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &adapter, nil
|
||||
}
|
||||
|
||||
// genSQL generate sql based on db driver name.
|
||||
func (p *SqlAdapter) genSQL() {
|
||||
p.SqlCreateTable = fmt.Sprintf(sqlCreateTable, p.tableName)
|
||||
p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTable, p.tableName)
|
||||
|
||||
p.SqlIsTableExist = fmt.Sprintf(sqlIsTableExist, p.tableName)
|
||||
|
||||
p.SqlInsertRow = fmt.Sprintf(sqlInsertRow, p.tableName)
|
||||
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRow, p.tableName)
|
||||
p.SqlDeleteAll = fmt.Sprintf(sqlDeleteAll, p.tableName)
|
||||
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRow, p.tableName)
|
||||
p.SqlDeleteByArgs = fmt.Sprintf(sqlDeleteByArgs, p.tableName)
|
||||
|
||||
p.SqlSelectAll = fmt.Sprintf(sqlSelectAll, p.tableName)
|
||||
p.SqlSelectWhere = fmt.Sprintf(sqlSelectWhere, p.tableName)
|
||||
|
||||
switch p.db.DriverName() {
|
||||
case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres":
|
||||
p.SqlCreateTable = fmt.Sprintf(sqlCreateTablePostgres, p.tableName)
|
||||
p.SqlInsertRow = fmt.Sprintf(sqlInsertRowPostgres, p.tableName)
|
||||
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowPostgres, p.tableName)
|
||||
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowPostgres, p.tableName)
|
||||
case "mysql":
|
||||
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableMysql, p.tableName)
|
||||
case "sqlite3":
|
||||
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlite3, p.tableName)
|
||||
p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTableSqlite3, p.tableName)
|
||||
case "sqlserver":
|
||||
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlserver, p.tableName)
|
||||
p.SqlInsertRow = fmt.Sprintf(sqlInsertRowSqlserver, p.tableName)
|
||||
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowSqlserver, p.tableName)
|
||||
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowSqlserver, p.tableName)
|
||||
}
|
||||
}
|
||||
|
||||
// createTable create a not exists table.
|
||||
func (p *SqlAdapter) CreateTable() error {
|
||||
_, err := p.db.ExecContext(p.ctx, p.SqlCreateTable)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// truncateTable clear the table.
|
||||
func (p *SqlAdapter) TruncateTable() error {
|
||||
_, err := p.db.ExecContext(p.ctx, p.SqlTruncateTable)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// deleteAll clear the table.
|
||||
func (p *SqlAdapter) DeleteAll() error {
|
||||
_, err := p.db.ExecContext(p.ctx, p.SqlDeleteAll)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// isTableExist check the table exists.
|
||||
func (p *SqlAdapter) IsTableExist() bool {
|
||||
_, err := p.db.ExecContext(p.ctx, p.SqlIsTableExist)
|
||||
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// deleteRows delete eligible data.
|
||||
func (p *SqlAdapter) DeleteRows(query string, args ...interface{}) error {
|
||||
query = p.db.Rebind(query)
|
||||
|
||||
_, err := p.db.ExecContext(p.ctx, query, args...)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// truncateAndInsertRows clear table and insert new rows.
|
||||
func (p *SqlAdapter) TruncateAndInsertRows(rules [][]interface{}) error {
|
||||
if err := p.TruncateTable(); err != nil {
|
||||
return err
|
||||
}
|
||||
return p.execTxSqlRows(p.SqlInsertRow, rules)
|
||||
}
|
||||
|
||||
// deleteAllAndInsertRows clear table and insert new rows.
|
||||
func (p *SqlAdapter) DeleteAllAndInsertRows(rules [][]interface{}) error {
|
||||
if err := p.DeleteAll(); err != nil {
|
||||
return err
|
||||
}
|
||||
return p.execTxSqlRows(p.SqlInsertRow, rules)
|
||||
}
|
||||
|
||||
// execTxSqlRows exec sql rows.
|
||||
func (p *SqlAdapter) execTxSqlRows(query string, rules [][]interface{}) (err error) {
|
||||
tx, err := p.db.BeginTx(p.ctx, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var action string
|
||||
|
||||
stmt, err := tx.PrepareContext(p.ctx, query)
|
||||
if err != nil {
|
||||
action = "prepare context"
|
||||
goto ROLLBACK
|
||||
}
|
||||
|
||||
for _, rule := range rules {
|
||||
if _, err = stmt.ExecContext(p.ctx, rule...); err != nil {
|
||||
action = "stmt exec"
|
||||
goto ROLLBACK
|
||||
}
|
||||
}
|
||||
|
||||
if err = stmt.Close(); err != nil {
|
||||
action = "stmt close"
|
||||
goto ROLLBACK
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
action = "commit"
|
||||
goto ROLLBACK
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
ROLLBACK:
|
||||
|
||||
if err1 := tx.Rollback(); err1 != nil {
|
||||
err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// selectRows select eligible data by args from the table.
|
||||
func (p *SqlAdapter) SelectRows(query string, args ...interface{}) ([]*SqlCasbinRule, error) {
|
||||
// make a slice with capacity
|
||||
lines := make([]*SqlCasbinRule, 0, 64)
|
||||
|
||||
if len(args) == 0 {
|
||||
return lines, p.db.SelectContext(p.ctx, &lines, query)
|
||||
}
|
||||
|
||||
query = p.db.Rebind(query)
|
||||
|
||||
return lines, p.db.SelectContext(p.ctx, &lines, query, args...)
|
||||
}
|
||||
|
||||
// selectWhereIn select eligible data by filter from the table.
|
||||
func (p *SqlAdapter) SelectWhereIn(filter *SqlFilter) (lines []*SqlCasbinRule, err error) {
|
||||
var sqlBuf bytes.Buffer
|
||||
|
||||
sqlBuf.Grow(64)
|
||||
sqlBuf.WriteString(p.SqlSelectWhere)
|
||||
|
||||
args := make([]interface{}, 0, 4)
|
||||
|
||||
hasInCond := false
|
||||
|
||||
for _, col := range [maxParamLength]struct {
|
||||
name string
|
||||
arg []string
|
||||
}{
|
||||
{"p_type", filter.PType},
|
||||
{"v0", filter.V0},
|
||||
{"v1", filter.V1},
|
||||
{"v2", filter.V2},
|
||||
{"v3", filter.V3},
|
||||
{"v4", filter.V4},
|
||||
{"v5", filter.V5},
|
||||
} {
|
||||
l := len(col.arg)
|
||||
if l == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
switch sqlBuf.Bytes()[sqlBuf.Len()-1] {
|
||||
case '?', ')':
|
||||
sqlBuf.WriteString(" AND ")
|
||||
}
|
||||
|
||||
sqlBuf.WriteString(col.name)
|
||||
|
||||
if l == 1 {
|
||||
sqlBuf.WriteString("=?")
|
||||
args = append(args, col.arg[0])
|
||||
} else {
|
||||
sqlBuf.WriteString(" IN (?)")
|
||||
args = append(args, col.arg)
|
||||
|
||||
hasInCond = true
|
||||
}
|
||||
}
|
||||
|
||||
var query string
|
||||
|
||||
if hasInCond {
|
||||
if query, args, err = sqlx.In(sqlBuf.String(), args...); err != nil {
|
||||
return
|
||||
}
|
||||
} else {
|
||||
query = sqlBuf.String()
|
||||
}
|
||||
|
||||
return p.SelectRows(query, args...)
|
||||
}
|
||||
|
||||
// LoadPolicy load all policy rules from the storage.
|
||||
func (p *SqlAdapter) LoadPolicy(model model.Model) error {
|
||||
lines, err := p.SelectRows(p.SqlSelectAll)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
p.loadPolicyLine(line, model)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// SavePolicy save policy rules to the storage.
|
||||
func (p *SqlAdapter) SavePolicy(model model.Model) error {
|
||||
args := make([][]interface{}, 0, 64)
|
||||
|
||||
for ptype, ast := range model["p"] {
|
||||
for _, rule := range ast.Policy {
|
||||
arg := p.GenArgs(ptype, rule)
|
||||
args = append(args, arg)
|
||||
}
|
||||
}
|
||||
|
||||
for ptype, ast := range model["g"] {
|
||||
for _, rule := range ast.Policy {
|
||||
arg := p.GenArgs(ptype, rule)
|
||||
args = append(args, arg)
|
||||
}
|
||||
}
|
||||
|
||||
return p.DeleteAllAndInsertRows(args)
|
||||
}
|
||||
|
||||
// AddPolicy add one policy rule to the storage.
|
||||
func (p *SqlAdapter) AddPolicy(sec string, ptype string, rule []string) error {
|
||||
args := p.GenArgs(ptype, rule)
|
||||
|
||||
_, err := p.db.ExecContext(p.ctx, p.SqlInsertRow, args...)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// AddPolicies add multiple policy rules to the storage.
|
||||
func (p *SqlAdapter) AddPolicies(sec string, ptype string, rules [][]string) error {
|
||||
args := make([][]interface{}, 0, 8)
|
||||
|
||||
for _, rule := range rules {
|
||||
arg := p.GenArgs(ptype, rule)
|
||||
args = append(args, arg)
|
||||
}
|
||||
|
||||
return p.execTxSqlRows(p.SqlInsertRow, args)
|
||||
}
|
||||
|
||||
// RemovePolicy remove policy rules from the storage.
|
||||
func (p *SqlAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
|
||||
var sqlBuf bytes.Buffer
|
||||
|
||||
sqlBuf.Grow(64)
|
||||
sqlBuf.WriteString(p.SqlDeleteByArgs)
|
||||
|
||||
args := make([]interface{}, 0, 4)
|
||||
args = append(args, ptype)
|
||||
|
||||
for idx, arg := range rule {
|
||||
if arg != "" {
|
||||
sqlBuf.WriteString(" AND v")
|
||||
sqlBuf.WriteString(strconv.Itoa(idx))
|
||||
sqlBuf.WriteString("=?")
|
||||
|
||||
args = append(args, arg)
|
||||
}
|
||||
}
|
||||
|
||||
return p.DeleteRows(sqlBuf.String(), args...)
|
||||
}
|
||||
|
||||
// RemoveFilteredPolicy remove policy rules that match the filter from the storage.
|
||||
func (p *SqlAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
|
||||
var sqlBuf bytes.Buffer
|
||||
|
||||
sqlBuf.Grow(64)
|
||||
sqlBuf.WriteString(p.SqlDeleteByArgs)
|
||||
|
||||
args := make([]interface{}, 0, 4)
|
||||
args = append(args, ptype)
|
||||
|
||||
var value string
|
||||
|
||||
l := fieldIndex + len(fieldValues)
|
||||
|
||||
for idx := 0; idx < 6; idx++ {
|
||||
if fieldIndex <= idx && idx < l {
|
||||
value = fieldValues[idx-fieldIndex]
|
||||
|
||||
if value != "" {
|
||||
sqlBuf.WriteString(" AND v")
|
||||
sqlBuf.WriteString(strconv.Itoa(idx))
|
||||
sqlBuf.WriteString("=?")
|
||||
|
||||
args = append(args, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return p.DeleteRows(sqlBuf.String(), args...)
|
||||
}
|
||||
|
||||
// RemovePolicies remove policy rules.
|
||||
func (p *SqlAdapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) {
|
||||
args := make([][]interface{}, 0, 8)
|
||||
|
||||
for _, rule := range rules {
|
||||
arg := p.GenArgs(ptype, rule)
|
||||
args = append(args, arg)
|
||||
}
|
||||
|
||||
return p.execTxSqlRows(p.SqlDeleteRow, args)
|
||||
}
|
||||
|
||||
// LoadFilteredPolicy load policy rules that match the filter.
|
||||
// filterPtr must be a pointer.
|
||||
func (p *SqlAdapter) LoadFilteredPolicy(model model.Model, filterPtr interface{}) error {
|
||||
if filterPtr == nil {
|
||||
return p.LoadPolicy(model)
|
||||
}
|
||||
|
||||
filter, ok := filterPtr.(*SqlFilter)
|
||||
if !ok {
|
||||
return errors.New("invalid filter type")
|
||||
}
|
||||
|
||||
lines, err := p.SelectWhereIn(filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, line := range lines {
|
||||
p.loadPolicyLine(line, model)
|
||||
}
|
||||
|
||||
p.isFiltered = true
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsFiltered returns true if the loaded policy rules has been filtered.
|
||||
func (p *SqlAdapter) IsFiltered() bool {
|
||||
return p.isFiltered
|
||||
}
|
||||
|
||||
// UpdatePolicy update a policy rule from storage.
|
||||
// This is part of the Auto-Save feature.
|
||||
func (p *SqlAdapter) UpdatePolicy(sec, ptype string, oldRule, newPolicy []string) error {
|
||||
oldArg := p.GenArgs(ptype, oldRule)
|
||||
newArg := p.GenArgs(ptype, newPolicy)
|
||||
|
||||
_, err := p.db.ExecContext(p.ctx, p.SqlUpdateRow, append(newArg, oldArg...)...)
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdatePolicies updates policy rules to storage.
|
||||
func (p *SqlAdapter) UpdatePolicies(sec, ptype string, oldRules, newRules [][]string) (err error) {
|
||||
if len(oldRules) != len(newRules) {
|
||||
return errors.New("old rules size not equal to new rules size")
|
||||
}
|
||||
|
||||
args := make([][]interface{}, 0, 16)
|
||||
|
||||
for idx := range oldRules {
|
||||
oldArg := p.GenArgs(ptype, oldRules[idx])
|
||||
newArg := p.GenArgs(ptype, newRules[idx])
|
||||
args = append(args, append(newArg, oldArg...))
|
||||
}
|
||||
|
||||
return p.execTxSqlRows(p.SqlUpdateRow, args)
|
||||
}
|
||||
|
||||
// UpdateFilteredPolicies deletes old rules and adds new rules.
|
||||
func (p *SqlAdapter) UpdateFilteredPolicies(sec, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) (oldPolicies [][]string, err error) {
|
||||
var value string
|
||||
|
||||
var whereBuf bytes.Buffer
|
||||
whereBuf.Grow(32)
|
||||
|
||||
l := fieldIndex + len(fieldValues)
|
||||
|
||||
whereArgs := make([]interface{}, 0, 4)
|
||||
whereArgs = append(whereArgs, ptype)
|
||||
|
||||
for idx := 0; idx < 6; idx++ {
|
||||
if fieldIndex <= idx && idx < l {
|
||||
value = fieldValues[idx-fieldIndex]
|
||||
|
||||
if value != "" {
|
||||
whereBuf.WriteString(" AND v")
|
||||
whereBuf.WriteString(strconv.Itoa(idx))
|
||||
whereBuf.WriteString("=?")
|
||||
|
||||
whereArgs = append(whereArgs, value)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var selectBuf bytes.Buffer
|
||||
selectBuf.Grow(64)
|
||||
selectBuf.WriteString(p.SqlSelectWhere)
|
||||
selectBuf.WriteString("p_type=?")
|
||||
selectBuf.Write(whereBuf.Bytes())
|
||||
|
||||
var oldRows []*SqlCasbinRule
|
||||
value = p.db.Rebind(selectBuf.String())
|
||||
oldRows, err = p.SelectRows(value, whereArgs...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var deleteBuf bytes.Buffer
|
||||
deleteBuf.Grow(64)
|
||||
deleteBuf.WriteString(p.SqlDeleteByArgs)
|
||||
deleteBuf.Write(whereBuf.Bytes())
|
||||
|
||||
var tx *sqlx.Tx
|
||||
tx, err = p.db.BeginTxx(p.ctx, nil)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var (
|
||||
stmt *sqlx.Stmt
|
||||
action string
|
||||
)
|
||||
value = p.db.Rebind(deleteBuf.String())
|
||||
if _, err = tx.ExecContext(p.ctx, value, whereArgs...); err != nil {
|
||||
action = "delete old policies"
|
||||
goto ROLLBACK
|
||||
}
|
||||
|
||||
stmt, err = tx.PreparexContext(p.ctx, p.SqlInsertRow)
|
||||
if err != nil {
|
||||
action = "preparex context"
|
||||
goto ROLLBACK
|
||||
}
|
||||
|
||||
for _, policy := range newPolicies {
|
||||
arg := p.GenArgs(ptype, policy)
|
||||
if _, err = stmt.ExecContext(p.ctx, arg...); err != nil {
|
||||
action = "stmt exec context"
|
||||
goto ROLLBACK
|
||||
}
|
||||
}
|
||||
|
||||
if err = stmt.Close(); err != nil {
|
||||
action = "stmt close"
|
||||
goto ROLLBACK
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
action = "commit"
|
||||
goto ROLLBACK
|
||||
}
|
||||
|
||||
oldPolicies = make([][]string, 0, len(oldRows))
|
||||
for _, rule := range oldRows {
|
||||
oldPolicies = append(oldPolicies, []string{rule.PType, rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
ROLLBACK:
|
||||
|
||||
if err1 := tx.Rollback(); err1 != nil {
|
||||
err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// loadPolicyLine load a policy line to model.
|
||||
func (SqlAdapter) loadPolicyLine(line *SqlCasbinRule, model model.Model) {
|
||||
if line == nil {
|
||||
return
|
||||
}
|
||||
|
||||
var lineBuf bytes.Buffer
|
||||
|
||||
lineBuf.Grow(64)
|
||||
lineBuf.WriteString(line.PType)
|
||||
|
||||
args := [6]string{line.V0, line.V1, line.V2, line.V3, line.V4, line.V5}
|
||||
for _, arg := range args {
|
||||
if arg != "" {
|
||||
lineBuf.WriteByte(',')
|
||||
lineBuf.WriteString(arg)
|
||||
}
|
||||
}
|
||||
|
||||
persist.LoadPolicyLine(lineBuf.String(), model)
|
||||
}
|
||||
|
||||
// genArgs generate args from ptype and rule.
|
||||
func (SqlAdapter) GenArgs(ptype string, rule []string) []interface{} {
|
||||
l := len(rule)
|
||||
|
||||
args := make([]interface{}, maxParamLength)
|
||||
args[0] = ptype
|
||||
|
||||
for idx := 0; idx < l; idx++ {
|
||||
args[idx+1] = rule[idx]
|
||||
}
|
||||
|
||||
for idx := l + 1; idx < maxParamLength; idx++ {
|
||||
args[idx] = ""
|
||||
}
|
||||
|
||||
return args
|
||||
}
|
@ -1,466 +0,0 @@
|
||||
//
|
||||
// adapter_sqlx_test.go
|
||||
// Copyright (C) 2022 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package gcasbin_test
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/casbin/casbin/v2"
|
||||
"github.com/casbin/casbin/v2/util"
|
||||
_ "github.com/go-sql-driver/mysql"
|
||||
"github.com/jmoiron/sqlx"
|
||||
"git.hexq.cn/tiglog/golib/gcasbin"
|
||||
)
|
||||
|
||||
const (
|
||||
rbacModelFile = "testdata/rbac_model.conf"
|
||||
rbacPolicyFile = "testdata/rbac_policy.csv"
|
||||
)
|
||||
|
||||
var (
|
||||
dataSourceNames = map[string]string{
|
||||
// "sqlite3": ":memory:",
|
||||
// "mysql": "root:@tcp(127.0.0.1:3306)/sqlx_adapter_test",
|
||||
"postgres": os.Getenv("DB_DSN"),
|
||||
// "sqlserver": "sqlserver://sa:YourPassword@127.0.0.1:1433?database=sqlx_adapter_test&connection+timeout=30",
|
||||
}
|
||||
|
||||
lines = []gcasbin.SqlCasbinRule{
|
||||
{PType: "p", V0: "alice", V1: "data1", V2: "read"},
|
||||
{PType: "p", V0: "bob", V1: "data2", V2: "read"},
|
||||
{PType: "p", V0: "bob", V1: "data2", V2: "write"},
|
||||
{PType: "p", V0: "data2_admin", V1: "data1", V2: "read", V3: "test1", V4: "test2", V5: "test3"},
|
||||
{PType: "p", V0: "data2_admin", V1: "data2", V2: "write", V3: "test1", V4: "test2", V5: "test3"},
|
||||
{PType: "p", V0: "data1_admin", V1: "data2", V2: "write"},
|
||||
{PType: "g", V0: "alice", V1: "data2_admin"},
|
||||
{PType: "g", V0: "bob", V1: "data2_admin", V2: "test"},
|
||||
{PType: "g", V0: "bob", V1: "data1_admin", V2: "test2", V3: "test3", V4: "test4", V5: "test5"},
|
||||
}
|
||||
|
||||
filter = gcasbin.SqlFilter{
|
||||
PType: []string{"p"},
|
||||
V0: []string{"bob", "data2_admin"},
|
||||
V1: []string{"data1", "data2"},
|
||||
V2: []string{"read", "write"},
|
||||
V3: []string{"test1"},
|
||||
V4: []string{"test2"},
|
||||
V5: []string{"test3"},
|
||||
}
|
||||
)
|
||||
|
||||
func TestSqlAdapters(t *testing.T) {
|
||||
for key, value := range dataSourceNames {
|
||||
t.Logf("-------------------- test [%s] start, dataSourceName: [%s]", key, value)
|
||||
|
||||
db, err := sqlx.Connect(key, value)
|
||||
if err != nil {
|
||||
t.Fatalf("sqlx.Connect failed, err: %v", err)
|
||||
}
|
||||
|
||||
t.Log("---------- testTableName start")
|
||||
testTableName(t, db)
|
||||
t.Log("---------- testTableName finished")
|
||||
|
||||
t.Log("---------- testSQL start")
|
||||
testSQL(t, db, "sqlxadapter_sql")
|
||||
t.Log("---------- testSQL finished")
|
||||
|
||||
t.Log("---------- testSaveLoad start")
|
||||
testSaveLoad(t, db, "sqlxadapter_save_load")
|
||||
t.Log("---------- testSaveLoad finished")
|
||||
|
||||
t.Log("---------- testAutoSave start")
|
||||
testAutoSave(t, db, "sqlxadapter_auto_save")
|
||||
t.Log("---------- testAutoSave finished")
|
||||
|
||||
t.Log("---------- testFilteredSqlPolicy start")
|
||||
testFilteredSqlPolicy(t, db, "sqlxadapter_filtered_policy")
|
||||
t.Log("---------- testFilteredSqlPolicy finished")
|
||||
|
||||
// t.Log("---------- testUpdateSqlPolicy start")
|
||||
// testUpdateSqlPolicy(t, db, "sqladapter_filtered_policy")
|
||||
// t.Log("---------- testUpdateSqlPolicy finished")
|
||||
|
||||
// t.Log("---------- testUpdateSqlPolicies start")
|
||||
// testUpdateSqlPolicies(t, db, "sqladapter_filtered_policy")
|
||||
// t.Log("---------- testUpdateSqlPolicies finished")
|
||||
|
||||
// t.Log("---------- testUpdateFilteredSqlPolicies start")
|
||||
// testUpdateFilteredSqlPolicies(t, db, "sqladapter_filtered_policy")
|
||||
// t.Log("---------- testUpdateFilteredSqlPolicies finished")
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func testTableName(t *testing.T, db *sqlx.DB) {
|
||||
_, err := gcasbin.NewSqlAdapter(db, "")
|
||||
if err != nil {
|
||||
t.Fatalf("NewAdapter failed, err: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func testSQL(t *testing.T, db *sqlx.DB, tableName string) {
|
||||
var err error
|
||||
logErr := func(action string) {
|
||||
if err != nil {
|
||||
t.Errorf("%s test failed, err: %v", action, err)
|
||||
}
|
||||
}
|
||||
|
||||
equalValue := func(line1, line2 gcasbin.SqlCasbinRule) bool {
|
||||
if line1.PType != line2.PType ||
|
||||
line1.V0 != line2.V0 ||
|
||||
line1.V1 != line2.V1 ||
|
||||
line1.V2 != line2.V2 ||
|
||||
line1.V3 != line2.V3 ||
|
||||
line1.V4 != line2.V4 ||
|
||||
line1.V5 != line2.V5 {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var a *gcasbin.SqlAdapter
|
||||
a, err = gcasbin.NewSqlAdapter(db, tableName)
|
||||
logErr("NewSqlAdapter")
|
||||
|
||||
// createTable test has passed when adapter create
|
||||
// err = a.CreateTable()
|
||||
// logErr("createTable")
|
||||
|
||||
if b := a.IsTableExist(); b == false {
|
||||
t.Fatal("isTableExist test failed")
|
||||
}
|
||||
|
||||
rules := make([][]interface{}, len(lines))
|
||||
for idx, rule := range lines {
|
||||
args := a.GenArgs(rule.PType, []string{rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
|
||||
rules[idx] = args
|
||||
}
|
||||
|
||||
err = a.TruncateAndInsertRows(rules)
|
||||
logErr("truncateAndInsertRows")
|
||||
|
||||
err = a.DeleteAllAndInsertRows(rules)
|
||||
logErr("truncateAndInsertRows")
|
||||
|
||||
err = a.DeleteRows(a.SqlDeleteByArgs, "g")
|
||||
logErr("deleteRows sqlDeleteByArgs g")
|
||||
|
||||
err = a.DeleteRows(a.SqlDeleteAll)
|
||||
logErr("deleteRows sqlDeleteAll")
|
||||
|
||||
_ = a.TruncateAndInsertRows(rules)
|
||||
|
||||
records, err := a.SelectRows(a.SqlSelectAll)
|
||||
logErr("selectRows sqlSelectAll")
|
||||
for idx, record := range records {
|
||||
line := lines[idx]
|
||||
if !equalValue(*record, line) {
|
||||
t.Fatalf("selectRows records test not equal, query record: %+v, need record: %+v", record, line)
|
||||
}
|
||||
}
|
||||
|
||||
records, err = a.SelectWhereIn(&filter)
|
||||
logErr("selectWhereIn")
|
||||
i := 3
|
||||
for _, record := range records {
|
||||
line := lines[i]
|
||||
if !equalValue(*record, line) {
|
||||
t.Fatalf("selectWhereIn records test not equal, query record: %+v, need record: %+v", record, line)
|
||||
}
|
||||
i++
|
||||
}
|
||||
|
||||
err = a.TruncateTable()
|
||||
logErr("truncateTable")
|
||||
}
|
||||
|
||||
func initSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
|
||||
// Because the DB is empty at first,
|
||||
// so we need to load the policy from the file adapter (.CSV) first.
|
||||
e, _ := casbin.NewEnforcer(rbacModelFile, rbacPolicyFile)
|
||||
|
||||
a, err := gcasbin.NewSqlAdapter(db, tableName)
|
||||
if err != nil {
|
||||
t.Fatal("NewAdapter test failed, err: ", err)
|
||||
}
|
||||
|
||||
// This is a trick to save the current policy to the DB.
|
||||
// We can't call e.SavePolicy() because the adapter in the enforcer is still the file adapter.
|
||||
// The current policy means the policy in the Casbin enforcer (aka in memory).
|
||||
err = a.SavePolicy(e.GetModel())
|
||||
if err != nil {
|
||||
t.Fatal("SavePolicy test failed, err: ", err)
|
||||
}
|
||||
|
||||
// Clear the current policy.
|
||||
e.ClearPolicy()
|
||||
testGetSqlPolicy(t, e, [][]string{})
|
||||
|
||||
// Load the policy from DB.
|
||||
err = a.LoadPolicy(e.GetModel())
|
||||
if err != nil {
|
||||
t.Fatal("LoadPolicy test failed, err: ", err)
|
||||
}
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
||||
}
|
||||
|
||||
func testSaveLoad(t *testing.T, db *sqlx.DB, tableName string) {
|
||||
// Initialize some policy in DB.
|
||||
initSqlPolicy(t, db, tableName)
|
||||
// Note: you don't need to look at the above code
|
||||
// if you already have a working DB with policy inside.
|
||||
|
||||
// Now the DB has policy, so we can provide a normal use case.
|
||||
// Create an adapter and an enforcer.
|
||||
// NewEnforcer() will load the policy automatically.
|
||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
||||
}
|
||||
|
||||
func testAutoSave(t *testing.T, db *sqlx.DB, tableName string) {
|
||||
// Initialize some policy in DB.
|
||||
initSqlPolicy(t, db, tableName)
|
||||
// Note: you don't need to look at the above code
|
||||
// if you already have a working DB with policy inside.
|
||||
|
||||
// Now the DB has policy, so we can provide a normal use case.
|
||||
// Create an adapter and an enforcer.
|
||||
// NewEnforcer() will load the policy automatically.
|
||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
||||
|
||||
// AutoSave is enabled by default.
|
||||
// Now we disable it.
|
||||
e.EnableAutoSave(false)
|
||||
|
||||
var err error
|
||||
logErr := func(action string) {
|
||||
if err != nil {
|
||||
t.Errorf("%s test failed, err: %v", action, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Because AutoSave is disabled, the policy change only affects the policy in Casbin enforcer,
|
||||
// it doesn't affect the policy in the storage.
|
||||
_, err = e.AddPolicy("alice", "data1", "write")
|
||||
logErr("AddPolicy1")
|
||||
// Reload the policy from the storage to see the effect.
|
||||
err = e.LoadPolicy()
|
||||
logErr("LoadPolicy1")
|
||||
// This is still the original policy.
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
||||
|
||||
_, err = e.AddPolicies([][]string{{"alice_1", "data_1", "read_1"}, {"bob_1", "data_1", "write_1"}})
|
||||
logErr("AddPolicies1")
|
||||
// Reload the policy from the storage to see the effect.
|
||||
err = e.LoadPolicy()
|
||||
logErr("LoadPolicy2")
|
||||
// This is still the original policy.
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
||||
|
||||
// Now we enable the AutoSave.
|
||||
e.EnableAutoSave(true)
|
||||
|
||||
// Because AutoSave is enabled, the policy change not only affects the policy in Casbin enforcer,
|
||||
// but also affects the policy in the storage.
|
||||
_, err = e.AddPolicy("alice", "data1", "write")
|
||||
logErr("AddPolicy2")
|
||||
// Reload the policy from the storage to see the effect.
|
||||
err = e.LoadPolicy()
|
||||
logErr("LoadPolicy3")
|
||||
// The policy has a new rule: {"alice", "data1", "write"}.
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
|
||||
|
||||
_, err = e.AddPolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
|
||||
logErr("AddPolicies2")
|
||||
// Reload the policy from the storage to see the effect.
|
||||
err = e.LoadPolicy()
|
||||
logErr("LoadPolicy4")
|
||||
// This is still the original policy.
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"},
|
||||
{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
|
||||
|
||||
_, err = e.RemovePolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
|
||||
logErr("RemovePolicies")
|
||||
err = e.LoadPolicy()
|
||||
logErr("LoadPolicy5")
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
|
||||
|
||||
// Remove the added rule.
|
||||
_, err = e.RemovePolicy("alice", "data1", "write")
|
||||
logErr("RemovePolicy")
|
||||
err = e.LoadPolicy()
|
||||
logErr("LoadPolicy6")
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
||||
|
||||
// Remove "data2_admin" related policy rules via a filter.
|
||||
// Two rules: {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"} are deleted.
|
||||
_, err = e.RemoveFilteredPolicy(0, "data2_admin")
|
||||
logErr("RemoveFilteredPolicy")
|
||||
err = e.LoadPolicy()
|
||||
logErr("LoadPolicy7")
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
|
||||
}
|
||||
|
||||
func testFilteredSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
|
||||
// Initialize some policy in DB.
|
||||
initSqlPolicy(t, db, tableName)
|
||||
// Note: you don't need to look at the above code
|
||||
// if you already have a working DB with policy inside.
|
||||
|
||||
// Now the DB has policy, so we can provide a normal use case.
|
||||
// Create an adapter and an enforcer.
|
||||
// NewEnforcer() will load the policy automatically.
|
||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
||||
// Now set the adapter
|
||||
e.SetAdapter(a)
|
||||
|
||||
var err error
|
||||
logErr := func(action string) {
|
||||
if err != nil {
|
||||
t.Errorf("%s test failed, err: %v", action, err)
|
||||
}
|
||||
}
|
||||
|
||||
// Load only alice's policies
|
||||
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice"}})
|
||||
logErr("LoadFilteredPolicy alice")
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}})
|
||||
|
||||
// Load only bob's policies
|
||||
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"bob"}})
|
||||
logErr("LoadFilteredPolicy bob")
|
||||
testGetSqlPolicy(t, e, [][]string{{"bob", "data2", "write"}})
|
||||
|
||||
// Load policies for data2_admin
|
||||
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"data2_admin"}})
|
||||
logErr("LoadFilteredPolicy data2_admin")
|
||||
testGetSqlPolicy(t, e, [][]string{{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
||||
|
||||
// Load policies for alice and bob
|
||||
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice", "bob"}})
|
||||
logErr("LoadFilteredPolicy alice bob")
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
|
||||
|
||||
_, err = e.AddPolicy("bob", "data1", "write", "test1", "test2", "test3")
|
||||
logErr("AddPolicy")
|
||||
|
||||
err = e.LoadFilteredPolicy(&filter)
|
||||
logErr("LoadFilteredPolicy filter")
|
||||
testGetSqlPolicy(t, e, [][]string{{"bob", "data1", "write", "test1", "test2", "test3"}})
|
||||
}
|
||||
|
||||
func testUpdateSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
|
||||
// Initialize some policy in DB.
|
||||
initSqlPolicy(t, db, tableName)
|
||||
|
||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
||||
|
||||
e.EnableAutoSave(true)
|
||||
e.UpdatePolicy([]string{"alice", "data1", "read"}, []string{"alice", "data1", "write"})
|
||||
e.LoadPolicy()
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
||||
}
|
||||
|
||||
func testUpdateSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
|
||||
// Initialize some policy in DB.
|
||||
initSqlPolicy(t, db, tableName)
|
||||
|
||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
||||
|
||||
e.EnableAutoSave(true)
|
||||
e.UpdatePolicies([][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}}, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}})
|
||||
e.LoadPolicy()
|
||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
||||
}
|
||||
|
||||
func testUpdateFilteredSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
|
||||
// Initialize some policy in DB.
|
||||
initSqlPolicy(t, db, tableName)
|
||||
|
||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
||||
|
||||
e.EnableAutoSave(true)
|
||||
e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read")
|
||||
e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write")
|
||||
e.LoadPolicy()
|
||||
testGetSqlPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}})
|
||||
}
|
||||
|
||||
func testGetSqlPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
|
||||
t.Helper()
|
||||
myRes := e.GetPolicy()
|
||||
t.Log("Policy: ", myRes)
|
||||
|
||||
m := make(map[string]struct{}, len(myRes))
|
||||
for _, record := range myRes {
|
||||
key := strings.Join(record, ",")
|
||||
m[key] = struct{}{}
|
||||
}
|
||||
|
||||
for _, record := range res {
|
||||
key := strings.Join(record, ",")
|
||||
if _, ok := m[key]; !ok {
|
||||
t.Error("Policy: \n", myRes, ", supposed to be \n", res)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func testGetSqlPolicyWithoutOrder(t *testing.T, e *casbin.Enforcer, res [][]string) {
|
||||
myRes := e.GetPolicy()
|
||||
// log.Print("Policy: \n", myRes)
|
||||
|
||||
if !arraySqlEqualsWithoutOrder(myRes, res) {
|
||||
t.Error("Policy: \n", myRes, ", supposed to be \n", res)
|
||||
}
|
||||
}
|
||||
|
||||
func arraySqlEqualsWithoutOrder(a [][]string, b [][]string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
mapA := make(map[int]string)
|
||||
mapB := make(map[int]string)
|
||||
order := make(map[int]struct{})
|
||||
l := len(a)
|
||||
|
||||
for i := 0; i < l; i++ {
|
||||
mapA[i] = util.ArrayToString(a[i])
|
||||
mapB[i] = util.ArrayToString(b[i])
|
||||
}
|
||||
|
||||
for i := 0; i < l; i++ {
|
||||
for j := 0; j < l; j++ {
|
||||
if _, ok := order[j]; ok {
|
||||
if j == l-1 {
|
||||
return false
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if mapA[i] == mapB[j] {
|
||||
order[j] = struct{}{}
|
||||
break
|
||||
} else if j == l-1 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
@ -1,180 +0,0 @@
|
||||
//
|
||||
// base_test.go
|
||||
// Copyright (C) 2022 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
"git.hexq.cn/tiglog/golib/gdb/sqldb"
|
||||
// _ "github.com/go-sql-driver/mysql"
|
||||
)
|
||||
|
||||
type Schema struct {
|
||||
create string
|
||||
drop string
|
||||
}
|
||||
|
||||
var defaultSchema = Schema{
|
||||
create: `
|
||||
CREATE TABLE person (
|
||||
id serial,
|
||||
first_name text,
|
||||
last_name text,
|
||||
email text,
|
||||
added_at int default 0,
|
||||
PRIMARY KEY (id)
|
||||
);
|
||||
|
||||
CREATE TABLE place (
|
||||
country text,
|
||||
city text NULL,
|
||||
telcode integer
|
||||
);
|
||||
|
||||
CREATE TABLE capplace (
|
||||
country text,
|
||||
city text NULL,
|
||||
telcode integer
|
||||
);
|
||||
|
||||
CREATE TABLE nullperson (
|
||||
first_name text NULL,
|
||||
last_name text NULL,
|
||||
email text NULL
|
||||
);
|
||||
|
||||
CREATE TABLE employees (
|
||||
name text,
|
||||
id integer,
|
||||
boss_id integer
|
||||
);
|
||||
|
||||
`,
|
||||
drop: `
|
||||
drop table person;
|
||||
drop table place;
|
||||
drop table capplace;
|
||||
drop table nullperson;
|
||||
drop table employees;
|
||||
`,
|
||||
}
|
||||
|
||||
type Person struct {
|
||||
Id int64 `db:"id"`
|
||||
FirstName string `db:"first_name"`
|
||||
LastName string `db:"last_name"`
|
||||
Email string `db:"email"`
|
||||
AddedAt int64 `db:"added_at"`
|
||||
}
|
||||
|
||||
type Person2 struct {
|
||||
FirstName sql.NullString `db:"first_name"`
|
||||
LastName sql.NullString `db:"last_name"`
|
||||
Email sql.NullString
|
||||
}
|
||||
|
||||
type Place struct {
|
||||
Country string
|
||||
City sql.NullString
|
||||
TelCode int
|
||||
}
|
||||
|
||||
type PlacePtr struct {
|
||||
Country string
|
||||
City *string
|
||||
TelCode int
|
||||
}
|
||||
|
||||
type PersonPlace struct {
|
||||
Person
|
||||
Place
|
||||
}
|
||||
|
||||
type PersonPlacePtr struct {
|
||||
*Person
|
||||
*Place
|
||||
}
|
||||
|
||||
type EmbedConflict struct {
|
||||
FirstName string `db:"first_name"`
|
||||
Person
|
||||
}
|
||||
|
||||
type SliceMember struct {
|
||||
Country string
|
||||
City sql.NullString
|
||||
TelCode int
|
||||
People []Person `db:"-"`
|
||||
Addresses []Place `db:"-"`
|
||||
}
|
||||
|
||||
func loadDefaultFixture(db *sqldb.Engine, t *testing.T) {
|
||||
tx := db.MustBegin()
|
||||
|
||||
s1 := "INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"
|
||||
tx.MustExec(db.Rebind(s1), "Jason", "Moiron", "jmoiron@jmoiron.net")
|
||||
|
||||
s1 = "INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)"
|
||||
tx.MustExec(db.Rebind(s1), "John", "Doe", "johndoeDNE@gmail.net")
|
||||
|
||||
s1 = "INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)"
|
||||
tx.MustExec(db.Rebind(s1), "United States", "New York", "1")
|
||||
|
||||
s1 = "INSERT INTO place (country, telcode) VALUES (?, ?)"
|
||||
tx.MustExec(db.Rebind(s1), "Hong Kong", "852")
|
||||
|
||||
s1 = "INSERT INTO place (country, telcode) VALUES (?, ?)"
|
||||
tx.MustExec(db.Rebind(s1), "Singapore", "65")
|
||||
|
||||
s1 = "INSERT INTO capplace (country, telcode) VALUES (?, ?)"
|
||||
tx.MustExec(db.Rebind(s1), "Sarf Efrica", "27")
|
||||
|
||||
s1 = "INSERT INTO employees (name, id) VALUES (?, ?)"
|
||||
tx.MustExec(db.Rebind(s1), "Peter", "4444")
|
||||
|
||||
s1 = "INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"
|
||||
tx.MustExec(db.Rebind(s1), "Joe", "1", "4444")
|
||||
|
||||
s1 = "INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)"
|
||||
tx.MustExec(db.Rebind(s1), "Martin", "2", "4444")
|
||||
tx.Commit()
|
||||
}
|
||||
|
||||
func MultiExec(e *sqldb.Engine, query string) {
|
||||
stmts := strings.Split(query, ";\n")
|
||||
if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 {
|
||||
stmts = stmts[:len(stmts)-1]
|
||||
}
|
||||
for _, s := range stmts {
|
||||
_, err := e.Exec(s)
|
||||
if err != nil {
|
||||
fmt.Println(err, s)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RunDbTest(t *testing.T, test func(db *sqldb.Engine, t *testing.T)) {
|
||||
// 先初始化数据库
|
||||
url := os.Getenv("DB_URL")
|
||||
var db = sqldb.New(url)
|
||||
|
||||
// 再注册清空数据库
|
||||
defer func() {
|
||||
MultiExec(db, defaultSchema.drop)
|
||||
}()
|
||||
// 再加入一些数据
|
||||
MultiExec(db, defaultSchema.create)
|
||||
loadDefaultFixture(db, t)
|
||||
// 最后测试
|
||||
test(db, t)
|
||||
}
|
78
gdb/sqldb/column.go
Normal file
78
gdb/sqldb/column.go
Normal file
@ -0,0 +1,78 @@
|
||||
//
|
||||
// column.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import "reflect"
|
||||
|
||||
// ColumnMap represents a mapping between a Go struct field and a single
|
||||
// column in a table.
|
||||
// Unique and MaxSize only inform the
|
||||
// CreateTables() function and are not used by Insert/Update/Delete/Get.
|
||||
type ColumnMap struct {
|
||||
// Column name in db table
|
||||
ColumnName string
|
||||
|
||||
// If true, this column is skipped in generated SQL statements
|
||||
Transient bool
|
||||
|
||||
// If true, " unique" is added to create table statements.
|
||||
// Not used elsewhere
|
||||
Unique bool
|
||||
|
||||
// Query used for getting generated id after insert
|
||||
GeneratedIdQuery string
|
||||
|
||||
// Passed to Dialect.ToSqlType() to assist in informing the
|
||||
// correct column type to map to in CreateTables()
|
||||
MaxSize int
|
||||
|
||||
DefaultValue string
|
||||
|
||||
fieldName string
|
||||
gotype reflect.Type
|
||||
isPK bool
|
||||
isAutoIncr bool
|
||||
isNotNull bool
|
||||
}
|
||||
|
||||
// Rename allows you to specify the column name in the table
|
||||
//
|
||||
// Example: table.ColMap("Updated").Rename("date_updated")
|
||||
func (c *ColumnMap) Rename(colname string) *ColumnMap {
|
||||
c.ColumnName = colname
|
||||
return c
|
||||
}
|
||||
|
||||
// SetTransient allows you to mark the column as transient. If true
|
||||
// this column will be skipped when SQL statements are generated
|
||||
func (c *ColumnMap) SetTransient(b bool) *ColumnMap {
|
||||
c.Transient = b
|
||||
return c
|
||||
}
|
||||
|
||||
// SetUnique adds "unique" to the create table statements for this
|
||||
// column, if b is true.
|
||||
func (c *ColumnMap) SetUnique(b bool) *ColumnMap {
|
||||
c.Unique = b
|
||||
return c
|
||||
}
|
||||
|
||||
// SetNotNull adds "not null" to the create table statements for this
|
||||
// column, if nn is true.
|
||||
func (c *ColumnMap) SetNotNull(nn bool) *ColumnMap {
|
||||
c.isNotNull = nn
|
||||
return c
|
||||
}
|
||||
|
||||
// SetMaxSize specifies the max length of values of this column. This is
|
||||
// passed to the dialect.ToSqlType() function, which can use the value
|
||||
// to alter the generated type for "create table" statements
|
||||
func (c *ColumnMap) SetMaxSize(size int) *ColumnMap {
|
||||
c.MaxSize = size
|
||||
return c
|
||||
}
|
82
gdb/sqldb/context_test.go
Normal file
82
gdb/sqldb/context_test.go
Normal file
@ -0,0 +1,82 @@
|
||||
//
|
||||
// context_test.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package sqldb_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Drivers that don't support cancellation.
|
||||
var unsupportedDrivers map[string]bool = map[string]bool{
|
||||
"mymysql": true,
|
||||
}
|
||||
|
||||
type SleepDialect interface {
|
||||
// string to sleep for d duration
|
||||
SleepClause(d time.Duration) string
|
||||
}
|
||||
|
||||
func TestWithNotCanceledContext(t *testing.T) {
|
||||
dbmap := initDBMap(t)
|
||||
defer dropAndClose(dbmap)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
withCtx := dbmap.WithContext(ctx)
|
||||
|
||||
_, err := withCtx.Exec("SELECT 1")
|
||||
assert.Nil(t, err)
|
||||
}
|
||||
|
||||
func TestWithCanceledContext(t *testing.T) {
|
||||
dialect, driver := dialectAndDriver()
|
||||
if unsupportedDrivers[driver] {
|
||||
t.Skipf("Cancellation is not yet supported by all drivers. Not known to be supported in %s.", driver)
|
||||
}
|
||||
|
||||
sleepDialect, ok := dialect.(SleepDialect)
|
||||
if !ok {
|
||||
t.Skipf("Sleep is not supported in all dialects. Not known to be supported in %s.", driver)
|
||||
}
|
||||
|
||||
dbmap := initDBMap(t)
|
||||
defer dropAndClose(dbmap)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
withCtx := dbmap.WithContext(ctx)
|
||||
|
||||
startTime := time.Now()
|
||||
|
||||
_, err := withCtx.Exec("SELECT " + sleepDialect.SleepClause(1*time.Second))
|
||||
|
||||
if d := time.Since(startTime); d > 500*time.Millisecond {
|
||||
t.Errorf("too long execution time: %s", d)
|
||||
}
|
||||
|
||||
switch driver {
|
||||
case "postgres":
|
||||
// pq doesn't return standard deadline exceeded error
|
||||
if err.Error() != "pq: canceling statement due to user request" {
|
||||
t.Errorf("expected context.DeadlineExceeded, got %v", err)
|
||||
}
|
||||
default:
|
||||
if err != context.DeadlineExceeded {
|
||||
t.Errorf("expected context.DeadlineExceeded, got %v", err)
|
||||
}
|
||||
}
|
||||
}
|
1050
gdb/sqldb/db.go
1050
gdb/sqldb/db.go
File diff suppressed because it is too large
Load Diff
@ -1,220 +0,0 @@
|
||||
//
|
||||
// db_func.go
|
||||
// Copyright (C) 2022 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
func (e *Engine) Begin() (*sqlx.Tx, error) {
|
||||
return e.Beginx()
|
||||
}
|
||||
|
||||
// 插入一条记录
|
||||
func (e *Engine) NamedInsertRecord(opt *QueryOption, arg interface{}) (int64, error) { // {{{
|
||||
if len(opt.fields) == 0 {
|
||||
return 0, errors.New("empty fields")
|
||||
}
|
||||
var tmp = make([]string, 0)
|
||||
for _, field := range opt.fields {
|
||||
tmp = append(tmp, fmt.Sprintf(":%s", field))
|
||||
}
|
||||
fields_str := strings.Join(opt.fields, ",")
|
||||
fields_pl := strings.Join(tmp, ",")
|
||||
sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", opt.table, fields_str, fields_pl)
|
||||
if e.DriverName() == "postgres" {
|
||||
sql += " returning id"
|
||||
}
|
||||
// sql = e.Rebind(sql)
|
||||
stmt, err := e.PrepareNamed(sql)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
var id int64
|
||||
err = stmt.Get(&id, arg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, err
|
||||
} // }}}
|
||||
|
||||
// 插入一条记录
|
||||
func (e *Engine) InsertRecord(opt *QueryOption) (int64, error) { // {{{
|
||||
if len(opt.fields) == 0 {
|
||||
return 0, errors.New("empty fields")
|
||||
}
|
||||
fields_str := strings.Join(opt.fields, ",")
|
||||
fields_pl := strings.TrimRight(strings.Repeat("?,", len(opt.fields)), ",")
|
||||
sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);", opt.table, fields_str, fields_pl)
|
||||
if e.DriverName() == "postgres" {
|
||||
sql += " returning id"
|
||||
}
|
||||
sql = e.Rebind(sql)
|
||||
result, err := e.Exec(sql, opt.args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.LastInsertId()
|
||||
} // }}}
|
||||
|
||||
// 查询一条记录
|
||||
// dest 目标对象
|
||||
// table 查询表
|
||||
// query 查询条件
|
||||
// args bindvars
|
||||
func (e *Engine) GetRecord(dest interface{}, opt *QueryOption) error { // {{{
|
||||
if opt.query == "" {
|
||||
return errors.New("empty query")
|
||||
}
|
||||
opt.query = "WHERE " + opt.query
|
||||
sql := fmt.Sprintf("SELECT * FROM %s %s limit 1", opt.table, opt.query)
|
||||
sql = e.Rebind(sql)
|
||||
err := e.Get(dest, sql, opt.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
} // }}}
|
||||
|
||||
// 查询多条记录
|
||||
// dest 目标变量
|
||||
// opt 查询对象
|
||||
// args bindvars
|
||||
func (e *Engine) GetRecords(dest interface{}, opt *QueryOption) error { // {{{
|
||||
var tmp = []string{}
|
||||
if opt.query != "" {
|
||||
tmp = append(tmp, "where", opt.query)
|
||||
}
|
||||
if opt.sort != "" {
|
||||
tmp = append(tmp, "order by", opt.sort)
|
||||
}
|
||||
if opt.offset > 0 {
|
||||
tmp = append(tmp, "offset", strconv.Itoa(opt.offset))
|
||||
}
|
||||
if opt.limit > 0 {
|
||||
tmp = append(tmp, "limit", strconv.Itoa(opt.limit))
|
||||
}
|
||||
sql := fmt.Sprintf("select * from %s %s", opt.table, strings.Join(tmp, " "))
|
||||
sql = e.Rebind(sql)
|
||||
return e.Select(dest, sql, opt.args...)
|
||||
} // }}}
|
||||
|
||||
// 更新一条记录
|
||||
// table 待处理的表
|
||||
// set 需要设置的语句, eg: age=:age
|
||||
// query 查询语句,不能为空,确保误更新所有记录
|
||||
// arg 值
|
||||
func (e *Engine) NamedUpdateRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{
|
||||
if opt.set == "" || opt.query == "" {
|
||||
return 0, errors.New("empty set or query")
|
||||
}
|
||||
sql := fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query)
|
||||
result, err := e.NamedExec(sql, arg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return rows, nil
|
||||
} // }}}
|
||||
|
||||
func (e *Engine) UpdateRecords(opt *QueryOption) (int64, error) { // {{{
|
||||
if opt.set == "" || opt.query == "" {
|
||||
return 0, errors.New("empty set or query")
|
||||
}
|
||||
sql := fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query)
|
||||
sql = e.Rebind(sql)
|
||||
result, err := e.Exec(sql, opt.args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return rows, nil
|
||||
} // }}}
|
||||
|
||||
// 删除若干条记录
|
||||
// opt 的 query 不能为空
|
||||
// arg bindvars
|
||||
func (e *Engine) NamedDeleteRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{
|
||||
if opt.query == "" {
|
||||
return 0, errors.New("emtpy query")
|
||||
}
|
||||
sql := fmt.Sprintf("delete from %s where %s", opt.table, opt.query)
|
||||
result, err := e.NamedExec(sql, arg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return rows, nil
|
||||
} // }}}
|
||||
|
||||
func (e *Engine) DeleteRecords(opt *QueryOption) (int64, error) {
|
||||
if opt.query == "" {
|
||||
return 0, errors.New("emtpy query")
|
||||
}
|
||||
sql := fmt.Sprintf("delete from %s where %s", opt.table, opt.query)
|
||||
sql = e.Rebind(sql)
|
||||
result, err := e.Exec(sql, opt.args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
rows, err := result.RowsAffected()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return rows, nil
|
||||
}
|
||||
|
||||
func (e *Engine) CountRecords(opt *QueryOption) (int, error) {
|
||||
sql := fmt.Sprintf("select count(*) from %s where %s", opt.table, opt.query)
|
||||
sql = e.Rebind(sql)
|
||||
var num int
|
||||
err := e.Get(&num, sql, opt.args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return num, nil
|
||||
}
|
||||
|
||||
// var levels = []int{4, 6, 7}
|
||||
// query, args, err := sqlx.In("SELECT * FROM users WHERE level IN (?);", levels)
|
||||
// sqlx.In returns queries with the `?` bindvar, we can rebind it for our backend
|
||||
// query = db.Rebind(query)
|
||||
// rows, err := db.Query(query, args...)
|
||||
func (e *Engine) In(query string, args ...interface{}) (string, []interface{}, error) {
|
||||
return sqlx.In(query, args...)
|
||||
}
|
||||
|
||||
func IsNoRows(err error) bool {
|
||||
return err == ErrNoRows
|
||||
}
|
||||
|
||||
// 把 fields 转换为 field1=:field1, field2=:field2, ..., fieldN=:fieldN
|
||||
func GetSetString(fields []string) string {
|
||||
items := []string{}
|
||||
for _, field := range fields {
|
||||
if field == "id" {
|
||||
continue
|
||||
}
|
||||
items = append(items, fmt.Sprintf("%s=:%s", field, field))
|
||||
}
|
||||
return strings.Join(items, ",")
|
||||
}
|
@ -1,75 +0,0 @@
|
||||
//
|
||||
// db_func_opt.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
type QueryOption struct {
|
||||
table string
|
||||
query string
|
||||
set string
|
||||
fields []string
|
||||
sort string
|
||||
offset int
|
||||
limit int
|
||||
args []any
|
||||
joins []string
|
||||
}
|
||||
|
||||
func NewQueryOption(table string) *QueryOption {
|
||||
return &QueryOption{
|
||||
table: table,
|
||||
fields: []string{"*"},
|
||||
offset: 0,
|
||||
limit: 0,
|
||||
args: make([]any, 0),
|
||||
joins: make([]string, 0),
|
||||
}
|
||||
}
|
||||
func (opt *QueryOption) Query(query string) *QueryOption {
|
||||
opt.query = query
|
||||
return opt
|
||||
}
|
||||
func (opt *QueryOption) Fields(args []string) *QueryOption {
|
||||
opt.fields = args
|
||||
return opt
|
||||
}
|
||||
func (opt *QueryOption) Select(cols ...string) *QueryOption {
|
||||
opt.fields = cols
|
||||
return opt
|
||||
}
|
||||
func (opt *QueryOption) Offset(offset int) *QueryOption {
|
||||
opt.offset = offset
|
||||
return opt
|
||||
}
|
||||
func (opt *QueryOption) Limit(limit int) *QueryOption {
|
||||
opt.limit = limit
|
||||
return opt
|
||||
}
|
||||
func (opt *QueryOption) Sort(sort string) *QueryOption {
|
||||
opt.sort = sort
|
||||
return opt
|
||||
}
|
||||
func (opt *QueryOption) Set(set string) *QueryOption {
|
||||
opt.set = set
|
||||
return opt
|
||||
}
|
||||
func (opt *QueryOption) Args(args ...any) *QueryOption {
|
||||
opt.args = args
|
||||
return opt
|
||||
}
|
||||
func (opt *QueryOption) Join(table string, cond string) *QueryOption {
|
||||
opt.joins = append(opt.joins, "join "+table+" on "+cond)
|
||||
return opt
|
||||
}
|
||||
func (opt *QueryOption) LeftJoin(table string, cond string) *QueryOption {
|
||||
opt.joins = append(opt.joins, "left join "+table+" on "+cond)
|
||||
return opt
|
||||
}
|
||||
func (opt *QueryOption) RightJoin(table string, cond string) *QueryOption {
|
||||
opt.joins = append(opt.joins, "right join "+table+" on "+cond)
|
||||
return opt
|
||||
}
|
@ -1,114 +0,0 @@
|
||||
//
|
||||
// db_func_test.go
|
||||
// Copyright (C) 2022 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/golib/gdb/sqldb"
|
||||
"git.hexq.cn/tiglog/golib/gtest"
|
||||
)
|
||||
|
||||
// 经过测试,发现数据库里面使用 time 类型容易出现 timezone 不一致的情况
|
||||
// 在存入数据库时,可能会导致时区丢失
|
||||
// 因此,为了更好的兼容性,使用 int 时间戳会更合适
|
||||
func dbFuncTest(db *sqldb.Engine, t *testing.T) {
|
||||
var err error
|
||||
fields := []string{"first_name", "last_name", "email"}
|
||||
p := &Person{
|
||||
FirstName: "三",
|
||||
LastName: "张",
|
||||
Email: "zs@foo.com",
|
||||
}
|
||||
// InsertRecord 的用法
|
||||
opt := sqldb.NewQueryOption("person").Fields(fields)
|
||||
rows, err := db.NamedInsertRecord(opt, p)
|
||||
gtest.Nil(t, err)
|
||||
gtest.True(t, rows > 0)
|
||||
// fmt.Println(rows)
|
||||
|
||||
// GetRecord 的用法
|
||||
var p3 Person
|
||||
opt = sqldb.NewQueryOption("person").Query("email=?").Args("zs@foo.com")
|
||||
err = db.GetRecord(&p3, opt)
|
||||
// fmt.Println(p3)
|
||||
gtest.Equal(t, "张", p3.LastName)
|
||||
gtest.Equal(t, "三", p3.FirstName)
|
||||
gtest.Equal(t, int64(0), p3.AddedAt)
|
||||
gtest.Nil(t, err)
|
||||
|
||||
p2 := &Person{
|
||||
FirstName: "四",
|
||||
LastName: "李",
|
||||
Email: "ls@foo.com",
|
||||
AddedAt: time.Now().Unix(),
|
||||
}
|
||||
fields2 := append(fields, "added_at")
|
||||
opt = sqldb.NewQueryOption("person").Fields(fields2)
|
||||
_, err = db.NamedInsertRecord(opt, p2)
|
||||
gtest.Nil(t, err)
|
||||
|
||||
var p4 Person
|
||||
opt = sqldb.NewQueryOption("person")
|
||||
err = db.GetRecord(&p4, opt)
|
||||
gtest.NotNil(t, err)
|
||||
gtest.Equal(t, "", p4.FirstName)
|
||||
|
||||
opt = sqldb.NewQueryOption("person").Query("first_name=?").Args("四")
|
||||
err = db.GetRecord(&p4, opt)
|
||||
gtest.Nil(t, err)
|
||||
gtest.Equal(t, time.Now().Unix(), p4.AddedAt)
|
||||
gtest.Equal(t, "ls@foo.com", p4.Email)
|
||||
|
||||
// GetRecords
|
||||
var ps []Person
|
||||
opt = sqldb.NewQueryOption("person").Query("id > ?").Args(0)
|
||||
err = db.GetRecords(&ps, opt)
|
||||
gtest.Nil(t, err)
|
||||
gtest.Greater(t, int64(1), ps)
|
||||
|
||||
var ps2 []Person
|
||||
opt = sqldb.NewQueryOption("person").Query("id=?").Args(1)
|
||||
err = db.GetRecords(&ps2, opt)
|
||||
gtest.Equal(t, 1, len(ps2))
|
||||
if len(ps2) > 1 {
|
||||
gtest.Equal(t, int64(1), ps2[0].Id)
|
||||
}
|
||||
|
||||
// DeleteRecords
|
||||
opt = sqldb.NewQueryOption("person").Query("id=?").Args(2)
|
||||
n, err := db.DeleteRecords(opt)
|
||||
gtest.Nil(t, err)
|
||||
gtest.Greater(t, int64(0), n)
|
||||
|
||||
// UpdateRecords
|
||||
opt = sqldb.NewQueryOption("person").Set("first_name=?").Query("email=?").Args("哈哈", "zs@foo.com")
|
||||
n, err = db.UpdateRecords(opt)
|
||||
gtest.Nil(t, err)
|
||||
gtest.Greater(t, int64(0), n)
|
||||
|
||||
// NamedUpdateRecords
|
||||
var p5 = ps[0]
|
||||
p5.FirstName = "中华人民共和国"
|
||||
opt = sqldb.NewQueryOption("person").Set("first_name=:first_name").Query("email=:email")
|
||||
n, err = db.NamedUpdateRecords(opt, p5)
|
||||
gtest.Nil(t, err)
|
||||
gtest.Greater(t, int64(0), n)
|
||||
|
||||
var p6 Person
|
||||
opt = sqldb.NewQueryOption("person").Query("first_name=?").Args(p5.FirstName)
|
||||
err = db.GetRecord(&p6, opt)
|
||||
gtest.Nil(t, err)
|
||||
gtest.Greater(t, int64(0), p6.Id)
|
||||
gtest.Equal(t, p6.FirstName, p5.FirstName)
|
||||
}
|
||||
|
||||
func TestFunc(t *testing.T) {
|
||||
RunDbTest(t, dbFuncTest)
|
||||
}
|
@ -1,20 +0,0 @@
|
||||
//
|
||||
// db_model.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
// TODO 暂时不好实现,以后再说
|
||||
|
||||
type Model struct {
|
||||
db *Engine
|
||||
}
|
||||
|
||||
func NewModel() *Model {
|
||||
return &Model{
|
||||
db: Db,
|
||||
}
|
||||
}
|
@ -1,322 +0,0 @@
|
||||
//
|
||||
// db_query.go
|
||||
// Copyright (C) 2022 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/jmoiron/sqlx"
|
||||
)
|
||||
|
||||
type Query struct {
|
||||
db *Engine
|
||||
table string
|
||||
fields []string
|
||||
wheres []string // 不能太复杂
|
||||
joins []string
|
||||
orderBy string
|
||||
groupBy string
|
||||
offset int
|
||||
limit int
|
||||
}
|
||||
|
||||
func NewQueryBuild(table string, db *Engine) *Query {
|
||||
return &Query{
|
||||
db: db,
|
||||
table: table,
|
||||
fields: []string{},
|
||||
wheres: []string{},
|
||||
joins: []string{},
|
||||
offset: 0,
|
||||
limit: 0,
|
||||
}
|
||||
}
|
||||
|
||||
func (q *Query) Table(table string) *Query {
|
||||
q.table = table
|
||||
return q
|
||||
}
|
||||
|
||||
// 设置 select fields
|
||||
func (q *Query) Select(fields ...string) *Query {
|
||||
q.fields = fields
|
||||
return q
|
||||
}
|
||||
|
||||
// 增加一个 select field
|
||||
func (q *Query) AddFields(fields ...string) *Query {
|
||||
q.fields = append(q.fields, fields...)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) Where(query string) *Query {
|
||||
q.wheres = []string{query}
|
||||
return q
|
||||
}
|
||||
func (q *Query) AndWhere(query string) *Query {
|
||||
q.wheres = append(q.wheres, "and "+query)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) OrWhere(query string) *Query {
|
||||
q.wheres = append(q.wheres, "or "+query)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) Join(table string, on string) *Query {
|
||||
var join = "join " + table
|
||||
if on != "" {
|
||||
join = join + " on " + on
|
||||
}
|
||||
q.joins = append(q.joins, join)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) LeftJoin(table string, on string) *Query {
|
||||
var join = "left join " + table
|
||||
if on != "" {
|
||||
join = join + " on " + on
|
||||
}
|
||||
q.joins = append(q.joins, join)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) RightJoin(table string, on string) *Query {
|
||||
var join = "right join " + table
|
||||
if on != "" {
|
||||
join = join + " on " + on
|
||||
}
|
||||
q.joins = append(q.joins, join)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) InnerJoin(table string, on string) *Query {
|
||||
var join = "inner join " + table
|
||||
if on != "" {
|
||||
join = join + " on " + on
|
||||
}
|
||||
q.joins = append(q.joins, join)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) OrderBy(order string) *Query {
|
||||
q.orderBy = order
|
||||
return q
|
||||
}
|
||||
func (q *Query) GroupBy(group string) *Query {
|
||||
q.groupBy = group
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) Offset(offset int) *Query {
|
||||
q.offset = offset
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) Limit(limit int) *Query {
|
||||
q.limit = limit
|
||||
return q
|
||||
}
|
||||
|
||||
// returningId postgres 数据库返回 LastInsertId 处理
|
||||
// TODO returningId 暂时不处理
|
||||
func (q *Query) getInsertSql(named, returningId bool) string {
|
||||
fields_str := strings.Join(q.fields, ",")
|
||||
var pl string
|
||||
if named {
|
||||
var tmp []string
|
||||
for _, field := range q.fields {
|
||||
tmp = append(tmp, ":"+field)
|
||||
}
|
||||
pl = strings.Join(tmp, ",")
|
||||
} else {
|
||||
pl = strings.Repeat("?,", len(q.fields))
|
||||
pl = strings.TrimRight(pl, ",")
|
||||
}
|
||||
|
||||
sql := fmt.Sprintf("insert into %s (%s) values (%s);", q.table, fields_str, pl)
|
||||
sql = q.db.Rebind(sql)
|
||||
// fmt.Println(sql)
|
||||
return sql
|
||||
}
|
||||
|
||||
// return RowsAffected, error
|
||||
func (q *Query) Insert(args ...interface{}) (int64, error) {
|
||||
if len(q.fields) == 0 {
|
||||
return 0, errors.New("empty fields")
|
||||
}
|
||||
sql := q.getInsertSql(false, false)
|
||||
result, err := q.db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// return RowsAffected, error
|
||||
func (q *Query) NamedInsert(arg interface{}) (int64, error) {
|
||||
if len(q.fields) == 0 {
|
||||
return 0, errors.New("empty fields")
|
||||
}
|
||||
sql := q.getInsertSql(true, false)
|
||||
result, err := q.db.NamedExec(sql, arg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
func (q *Query) getQuerySql() string {
|
||||
var (
|
||||
fields_str string = "*"
|
||||
join_str string
|
||||
where_str string
|
||||
offlim string
|
||||
)
|
||||
if len(q.fields) > 0 {
|
||||
fields_str = strings.Join(q.fields, ",")
|
||||
}
|
||||
|
||||
if len(q.joins) > 0 {
|
||||
join_str = strings.Join(q.joins, " ")
|
||||
}
|
||||
if len(q.wheres) > 0 {
|
||||
where_str = "where " + strings.Join(q.wheres, " ")
|
||||
}
|
||||
|
||||
if q.offset > 0 {
|
||||
offlim = " offset " + strconv.Itoa(q.offset)
|
||||
}
|
||||
if q.limit > 0 {
|
||||
offlim = " limit " + strconv.Itoa(q.limit)
|
||||
}
|
||||
// select fields from table t join where groupby orderby offset limit
|
||||
sql := fmt.Sprintf("select %s from %s t %s %s %s %s%s", fields_str, q.table, join_str, where_str, q.groupBy, q.orderBy, offlim)
|
||||
return sql
|
||||
}
|
||||
|
||||
func (q *Query) One(dest interface{}, args ...interface{}) error {
|
||||
q.Limit(1)
|
||||
sql := q.getQuerySql()
|
||||
sql = q.db.Rebind(sql)
|
||||
return q.db.Get(dest, sql, args...)
|
||||
}
|
||||
|
||||
func (q *Query) NamedOne(dest interface{}, arg interface{}) error {
|
||||
q.Limit(1)
|
||||
sql := q.getQuerySql()
|
||||
rows, err := q.db.NamedQuery(sql, arg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if rows.Next() {
|
||||
return rows.Scan(dest)
|
||||
}
|
||||
return errors.New("nr") // no record
|
||||
}
|
||||
|
||||
func (q *Query) All(dest interface{}, args ...interface{}) error {
|
||||
sql := q.getQuerySql()
|
||||
sql = q.db.Rebind(sql)
|
||||
return q.db.Select(dest, sql, args...)
|
||||
}
|
||||
|
||||
// 为了省内存,直接返回迭代器
|
||||
func (q *Query) NamedAll(dest interface{}, arg interface{}) (*sqlx.Rows, error) {
|
||||
sql := q.getQuerySql()
|
||||
return q.db.NamedQuery(sql, arg)
|
||||
}
|
||||
|
||||
// set age=? / age=:age
|
||||
func (q *Query) NamedUpdate(set string, arg interface{}) (int64, error) {
|
||||
var where_str string
|
||||
if len(q.wheres) > 0 {
|
||||
where_str = strings.Join(q.wheres, " ")
|
||||
}
|
||||
if set == "" || where_str == "" {
|
||||
return 0, errors.New("empty set or where")
|
||||
}
|
||||
|
||||
// update table t where
|
||||
sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str)
|
||||
sql = q.db.Rebind(sql)
|
||||
result, err := q.db.NamedExec(sql, arg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// 顺序容易弄反,记得先是 set 的参数,再是 where 里面的参数
|
||||
func (q *Query) Update(set string, args ...interface{}) (int64, error) {
|
||||
var where_str string
|
||||
if len(q.wheres) > 0 {
|
||||
where_str = strings.Join(q.wheres, " ")
|
||||
}
|
||||
if set == "" || where_str == "" {
|
||||
return 0, errors.New("empty set or where")
|
||||
}
|
||||
|
||||
// update table t where
|
||||
sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str)
|
||||
sql = q.db.Rebind(sql)
|
||||
result, err := q.db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
// 普通的删除
|
||||
func (q *Query) Delete(args ...interface{}) (int64, error) {
|
||||
var where_str string
|
||||
if len(q.wheres) == 0 {
|
||||
return 0, errors.New("missing where clause")
|
||||
}
|
||||
where_str = strings.Join(q.wheres, " ")
|
||||
|
||||
sql := fmt.Sprintf("delete from %s where %s", q.table, where_str)
|
||||
sql = q.db.Rebind(sql)
|
||||
result, err := q.db.Exec(sql, args...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
func (q *Query) NamedDelete(arg interface{}) (int64, error) {
|
||||
if len(q.wheres) == 0 {
|
||||
return 0, errors.New("missing where clause")
|
||||
}
|
||||
var where_str string
|
||||
where_str = strings.Join(q.wheres, " ")
|
||||
|
||||
sql := fmt.Sprintf("delete from %s where %s", q.table, where_str)
|
||||
sql = q.db.Rebind(sql)
|
||||
|
||||
result, err := q.db.NamedExec(sql, arg)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return result.RowsAffected()
|
||||
}
|
||||
|
||||
func (q *Query) Count(args ...interface{}) (int64, error) {
|
||||
var where_str string
|
||||
if len(q.wheres) > 0 {
|
||||
where_str = " where " + strings.Join(q.wheres, " ")
|
||||
}
|
||||
sql := fmt.Sprintf("select count(1) as num from %s t%s", q.table, where_str)
|
||||
sql = q.db.Rebind(sql)
|
||||
var num int64
|
||||
err := q.db.Get(&num, sql, args...)
|
||||
return num, err
|
||||
}
|
@ -1,109 +0,0 @@
|
||||
//
|
||||
// db_query_test.go
|
||||
// Copyright (C) 2022 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/golib/gtest"
|
||||
|
||||
"git.hexq.cn/tiglog/golib/gdb/sqldb"
|
||||
)
|
||||
|
||||
func dbQueryTest(db *sqldb.Engine, t *testing.T) {
|
||||
query := sqldb.NewQueryBuild("person", db)
|
||||
// query one
|
||||
var p1 Person
|
||||
query.Where("id=?")
|
||||
err := query.One(&p1, 1)
|
||||
gtest.Nil(t, err)
|
||||
gtest.Equal(t, int64(1), p1.Id)
|
||||
|
||||
// query all
|
||||
var ps1 []Person
|
||||
query = sqldb.NewQueryBuild("person", db)
|
||||
query.Where("id > ?")
|
||||
err = query.All(&ps1, 1)
|
||||
gtest.Nil(t, err)
|
||||
gtest.True(t, len(ps1) > 0)
|
||||
// fmt.Println(ps1)
|
||||
if len(ps1) > 0 {
|
||||
var val int64 = 2
|
||||
gtest.Equal(t, val, ps1[0].Id)
|
||||
}
|
||||
|
||||
// insert
|
||||
query = sqldb.NewQueryBuild("person", db)
|
||||
query.AddFields("first_name", "last_name", "email")
|
||||
id, err := query.Insert("三", "张", "zs@bar.com")
|
||||
gtest.Nil(t, err)
|
||||
gtest.Greater(t, int64(0), id)
|
||||
// fmt.Println(id)
|
||||
|
||||
// named insert
|
||||
query = sqldb.NewQueryBuild("person", db)
|
||||
query.AddFields("first_name", "last_name", "email")
|
||||
row, err := query.NamedInsert(&Person{
|
||||
FirstName: "四",
|
||||
LastName: "李",
|
||||
Email: "ls@bar.com",
|
||||
AddedAt: time.Now().Unix(),
|
||||
})
|
||||
gtest.Nil(t, err)
|
||||
gtest.Equal(t, int64(1), row)
|
||||
|
||||
// update
|
||||
query = sqldb.NewQueryBuild("person", db)
|
||||
query.Where("email=?")
|
||||
n, err := query.Update("first_name=?", "哈哈", "ls@bar.com")
|
||||
gtest.Nil(t, err)
|
||||
gtest.Equal(t, int64(1), n)
|
||||
|
||||
// named update map
|
||||
query = sqldb.NewQueryBuild("person", db)
|
||||
query.Where("email=:email")
|
||||
n, err = query.NamedUpdate("first_name=:first_name", map[string]interface{}{
|
||||
"email": "ls@bar.com",
|
||||
"first_name": "中华人民共和国",
|
||||
})
|
||||
gtest.Nil(t, err)
|
||||
gtest.Equal(t, int64(1), n)
|
||||
|
||||
// named update struct
|
||||
query = sqldb.NewQueryBuild("person", db)
|
||||
query.Where("email=:email")
|
||||
var p = &Person{
|
||||
Email: "ls@bar.com",
|
||||
LastName: "中华人民共和国,救民于水火",
|
||||
}
|
||||
n, err = query.NamedUpdate("last_name=:last_name", p)
|
||||
gtest.Nil(t, err)
|
||||
gtest.Equal(t, int64(1), n)
|
||||
|
||||
// count
|
||||
query = sqldb.NewQueryBuild("person", db)
|
||||
n, err = query.Count()
|
||||
gtest.Nil(t, err)
|
||||
// fmt.Println(n)
|
||||
gtest.Greater(t, int64(0), n)
|
||||
|
||||
// delete
|
||||
query = sqldb.NewQueryBuild("person", db)
|
||||
n, err = query.Delete()
|
||||
gtest.NotNil(t, err)
|
||||
gtest.Equal(t, int64(0), n)
|
||||
|
||||
n, err = query.Where("id=?").Delete(2)
|
||||
gtest.Nil(t, err)
|
||||
gtest.Equal(t, int64(1), n)
|
||||
}
|
||||
|
||||
func TestQuery(t *testing.T) {
|
||||
RunDbTest(t, dbQueryTest)
|
||||
}
|
187
gdb/sqldb/db_test.go
Normal file
187
gdb/sqldb/db_test.go
Normal file
@ -0,0 +1,187 @@
|
||||
//
|
||||
// db_test.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package sqldb_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
type customType1 []string
|
||||
|
||||
func (c customType1) ToStringSlice() []string {
|
||||
return []string(c)
|
||||
}
|
||||
|
||||
type customType2 []int64
|
||||
|
||||
func (c customType2) ToInt64Slice() []int64 {
|
||||
return []int64(c)
|
||||
}
|
||||
|
||||
func TestDbMap_Select_expandSliceArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
description string
|
||||
query string
|
||||
args []interface{}
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
description: "it should handle slice placeholders correctly",
|
||||
query: `
|
||||
SELECT 1 FROM crazy_table
|
||||
WHERE field1 = :Field1
|
||||
AND field2 IN (:FieldStringList)
|
||||
AND field3 IN (:FieldUIntList)
|
||||
AND field4 IN (:FieldUInt8List)
|
||||
AND field5 IN (:FieldUInt16List)
|
||||
AND field6 IN (:FieldUInt32List)
|
||||
AND field7 IN (:FieldUInt64List)
|
||||
AND field8 IN (:FieldIntList)
|
||||
AND field9 IN (:FieldInt8List)
|
||||
AND field10 IN (:FieldInt16List)
|
||||
AND field11 IN (:FieldInt32List)
|
||||
AND field12 IN (:FieldInt64List)
|
||||
AND field13 IN (:FieldFloat32List)
|
||||
AND field14 IN (:FieldFloat64List)
|
||||
`,
|
||||
args: []interface{}{
|
||||
map[string]interface{}{
|
||||
"Field1": 123,
|
||||
"FieldStringList": []string{"h", "e", "y"},
|
||||
"FieldUIntList": []uint{1, 2, 3, 4},
|
||||
"FieldUInt8List": []uint8{1, 2, 3, 4},
|
||||
"FieldUInt16List": []uint16{1, 2, 3, 4},
|
||||
"FieldUInt32List": []uint32{1, 2, 3, 4},
|
||||
"FieldUInt64List": []uint64{1, 2, 3, 4},
|
||||
"FieldIntList": []int{1, 2, 3, 4},
|
||||
"FieldInt8List": []int8{1, 2, 3, 4},
|
||||
"FieldInt16List": []int16{1, 2, 3, 4},
|
||||
"FieldInt32List": []int32{1, 2, 3, 4},
|
||||
"FieldInt64List": []int64{1, 2, 3, 4},
|
||||
"FieldFloat32List": []float32{1, 2, 3, 4},
|
||||
"FieldFloat64List": []float64{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
description: "it should handle slice placeholders correctly with custom types",
|
||||
query: `
|
||||
SELECT 1 FROM crazy_table
|
||||
WHERE field2 IN (:FieldStringList)
|
||||
AND field12 IN (:FieldIntList)
|
||||
`,
|
||||
args: []interface{}{
|
||||
map[string]interface{}{
|
||||
"FieldStringList": customType1{"h", "e", "y"},
|
||||
"FieldIntList": customType2{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
wantLen: 3,
|
||||
},
|
||||
}
|
||||
|
||||
type dataFormat struct {
|
||||
Field1 int `db:"field1"`
|
||||
Field2 string `db:"field2"`
|
||||
Field3 uint `db:"field3"`
|
||||
Field4 uint8 `db:"field4"`
|
||||
Field5 uint16 `db:"field5"`
|
||||
Field6 uint32 `db:"field6"`
|
||||
Field7 uint64 `db:"field7"`
|
||||
Field8 int `db:"field8"`
|
||||
Field9 int8 `db:"field9"`
|
||||
Field10 int16 `db:"field10"`
|
||||
Field11 int32 `db:"field11"`
|
||||
Field12 int64 `db:"field12"`
|
||||
Field13 float32 `db:"field13"`
|
||||
Field14 float64 `db:"field14"`
|
||||
}
|
||||
|
||||
dbmap := newDBMap(t)
|
||||
dbmap.ExpandSliceArgs = true
|
||||
dbmap.AddTableWithName(dataFormat{}, "crazy_table")
|
||||
|
||||
err := dbmap.CreateTables()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer dropAndClose(dbmap)
|
||||
|
||||
err = dbmap.Insert(
|
||||
&dataFormat{
|
||||
Field1: 123,
|
||||
Field2: "h",
|
||||
Field3: 1,
|
||||
Field4: 1,
|
||||
Field5: 1,
|
||||
Field6: 1,
|
||||
Field7: 1,
|
||||
Field8: 1,
|
||||
Field9: 1,
|
||||
Field10: 1,
|
||||
Field11: 1,
|
||||
Field12: 1,
|
||||
Field13: 1,
|
||||
Field14: 1,
|
||||
},
|
||||
&dataFormat{
|
||||
Field1: 124,
|
||||
Field2: "e",
|
||||
Field3: 2,
|
||||
Field4: 2,
|
||||
Field5: 2,
|
||||
Field6: 2,
|
||||
Field7: 2,
|
||||
Field8: 2,
|
||||
Field9: 2,
|
||||
Field10: 2,
|
||||
Field11: 2,
|
||||
Field12: 2,
|
||||
Field13: 2,
|
||||
Field14: 2,
|
||||
},
|
||||
&dataFormat{
|
||||
Field1: 125,
|
||||
Field2: "y",
|
||||
Field3: 3,
|
||||
Field4: 3,
|
||||
Field5: 3,
|
||||
Field6: 3,
|
||||
Field7: 3,
|
||||
Field8: 3,
|
||||
Field9: 3,
|
||||
Field10: 3,
|
||||
Field11: 3,
|
||||
Field12: 3,
|
||||
Field13: 3,
|
||||
Field14: 3,
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description, func(t *testing.T) {
|
||||
var dummy []int
|
||||
_, err := dbmap.Select(&dummy, tt.query, tt.args...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(dummy) != tt.wantLen {
|
||||
t.Errorf("wrong result count\ngot: %d\nwant: %d", len(dummy), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
108
gdb/sqldb/dialect.go
Normal file
108
gdb/sqldb/dialect.go
Normal file
@ -0,0 +1,108 @@
|
||||
//
|
||||
// dialect.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// The Dialect interface encapsulates behaviors that differ across
|
||||
// SQL databases. At present the Dialect is only used by CreateTables()
|
||||
// but this could change in the future
|
||||
type Dialect interface {
|
||||
// adds a suffix to any query, usually ";"
|
||||
QuerySuffix() string
|
||||
|
||||
// ToSqlType returns the SQL column type to use when creating a
|
||||
// table of the given Go Type. maxsize can be used to switch based on
|
||||
// size. For example, in MySQL []byte could map to BLOB, MEDIUMBLOB,
|
||||
// or LONGBLOB depending on the maxsize
|
||||
ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string
|
||||
|
||||
// string to append to primary key column definitions
|
||||
AutoIncrStr() string
|
||||
|
||||
// string to bind autoincrement columns to. Empty string will
|
||||
// remove reference to those columns in the INSERT statement.
|
||||
AutoIncrBindValue() string
|
||||
|
||||
AutoIncrInsertSuffix(col *ColumnMap) string
|
||||
|
||||
// string to append to "create table" statement for vendor specific
|
||||
// table attributes
|
||||
CreateTableSuffix() string
|
||||
|
||||
// string to append to "create index" statement
|
||||
CreateIndexSuffix() string
|
||||
|
||||
// string to append to "drop index" statement
|
||||
DropIndexSuffix() string
|
||||
|
||||
// string to truncate tables
|
||||
TruncateClause() string
|
||||
|
||||
// bind variable string to use when forming SQL statements
|
||||
// in many dbs it is "?", but Postgres appears to use $1
|
||||
//
|
||||
// i is a zero based index of the bind variable in this statement
|
||||
//
|
||||
BindVar(i int) string
|
||||
|
||||
// Handles quoting of a field name to ensure that it doesn't raise any
|
||||
// SQL parsing exceptions by using a reserved word as a field name.
|
||||
QuoteField(field string) string
|
||||
|
||||
// Handles building up of a schema.database string that is compatible with
|
||||
// the given dialect
|
||||
//
|
||||
// schema - The schema that <table> lives in
|
||||
// table - The table name
|
||||
QuotedTableForQuery(schema string, table string) string
|
||||
|
||||
// Existence clause for table creation / deletion
|
||||
IfSchemaNotExists(command, schema string) string
|
||||
IfTableExists(command, schema, table string) string
|
||||
IfTableNotExists(command, schema, table string) string
|
||||
}
|
||||
|
||||
// IntegerAutoIncrInserter is implemented by dialects that can perform
|
||||
// inserts with automatically incremented integer primary keys. If
|
||||
// the dialect can handle automatic assignment of more than just
|
||||
// integers, see TargetedAutoIncrInserter.
|
||||
type IntegerAutoIncrInserter interface {
|
||||
InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error)
|
||||
}
|
||||
|
||||
// TargetedAutoIncrInserter is implemented by dialects that can
|
||||
// perform automatic assignment of any primary key type (i.e. strings
|
||||
// for uuids, integers for serials, etc).
|
||||
type TargetedAutoIncrInserter interface {
|
||||
// InsertAutoIncrToTarget runs an insert operation and assigns the
|
||||
// automatically generated primary key directly to the passed in
|
||||
// target. The target should be a pointer to the primary key
|
||||
// field of the value being inserted.
|
||||
InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error
|
||||
}
|
||||
|
||||
// TargetQueryInserter is implemented by dialects that can perform
|
||||
// assignment of integer primary key type by executing a query
|
||||
// like "select sequence.currval from dual".
|
||||
type TargetQueryInserter interface {
|
||||
// TargetQueryInserter runs an insert operation and assigns the
|
||||
// automatically generated primary key retrived by the query
|
||||
// extracted from the GeneratedIdQuery field of the id column.
|
||||
InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error
|
||||
}
|
||||
|
||||
func standardInsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
|
||||
res, err := exec.Exec(insertSql, params...)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return res.LastInsertId()
|
||||
}
|
172
gdb/sqldb/dialect_mysql.go
Normal file
172
gdb/sqldb/dialect_mysql.go
Normal file
@ -0,0 +1,172 @@
|
||||
//
|
||||
// dialect_mysql.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Implementation of Dialect for MySQL databases.
|
||||
type MySQLDialect struct {
|
||||
|
||||
// Engine is the storage engine to use "InnoDB" vs "MyISAM" for example
|
||||
Engine string
|
||||
|
||||
// Encoding is the character encoding to use for created tables
|
||||
Encoding string
|
||||
}
|
||||
|
||||
func (d MySQLDialect) QuerySuffix() string { return ";" }
|
||||
|
||||
func (d MySQLDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
|
||||
switch val.Kind() {
|
||||
case reflect.Ptr:
|
||||
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int8:
|
||||
return "tinyint"
|
||||
case reflect.Uint8:
|
||||
return "tinyint unsigned"
|
||||
case reflect.Int16:
|
||||
return "smallint"
|
||||
case reflect.Uint16:
|
||||
return "smallint unsigned"
|
||||
case reflect.Int, reflect.Int32:
|
||||
return "int"
|
||||
case reflect.Uint, reflect.Uint32:
|
||||
return "int unsigned"
|
||||
case reflect.Int64:
|
||||
return "bigint"
|
||||
case reflect.Uint64:
|
||||
return "bigint unsigned"
|
||||
case reflect.Float64, reflect.Float32:
|
||||
return "double"
|
||||
case reflect.Slice:
|
||||
if val.Elem().Kind() == reflect.Uint8 {
|
||||
return "mediumblob"
|
||||
}
|
||||
}
|
||||
|
||||
switch val.Name() {
|
||||
case "NullInt64":
|
||||
return "bigint"
|
||||
case "NullFloat64":
|
||||
return "double"
|
||||
case "NullBool":
|
||||
return "tinyint"
|
||||
case "Time":
|
||||
return "datetime"
|
||||
}
|
||||
|
||||
if maxsize < 1 {
|
||||
maxsize = 255
|
||||
}
|
||||
|
||||
/* == About varchar(N) ==
|
||||
* N is number of characters.
|
||||
* A varchar column can store up to 65535 bytes.
|
||||
* Remember that 1 character is 3 bytes in utf-8 charset.
|
||||
* Also remember that each row can store up to 65535 bytes,
|
||||
* and you have some overheads, so it's not possible for a
|
||||
* varchar column to have 65535/3 characters really.
|
||||
* So it would be better to use 'text' type in stead of
|
||||
* large varchar type.
|
||||
*/
|
||||
if maxsize < 256 {
|
||||
return fmt.Sprintf("varchar(%d)", maxsize)
|
||||
} else {
|
||||
return "text"
|
||||
}
|
||||
}
|
||||
|
||||
// Returns auto_increment
|
||||
func (d MySQLDialect) AutoIncrStr() string {
|
||||
return "auto_increment"
|
||||
}
|
||||
|
||||
func (d MySQLDialect) AutoIncrBindValue() string {
|
||||
return "null"
|
||||
}
|
||||
|
||||
func (d MySQLDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Returns engine=%s charset=%s based on values stored on struct
|
||||
func (d MySQLDialect) CreateTableSuffix() string {
|
||||
if d.Engine == "" || d.Encoding == "" {
|
||||
msg := "sqldb - undefined"
|
||||
|
||||
if d.Engine == "" {
|
||||
msg += " MySQLDialect.Engine"
|
||||
}
|
||||
if d.Engine == "" && d.Encoding == "" {
|
||||
msg += ","
|
||||
}
|
||||
if d.Encoding == "" {
|
||||
msg += " MySQLDialect.Encoding"
|
||||
}
|
||||
msg += ". Check that your MySQLDialect was correctly initialized when declared."
|
||||
panic(msg)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(" engine=%s charset=%s", d.Engine, d.Encoding)
|
||||
}
|
||||
|
||||
func (d MySQLDialect) CreateIndexSuffix() string {
|
||||
return "using"
|
||||
}
|
||||
|
||||
func (d MySQLDialect) DropIndexSuffix() string {
|
||||
return "on"
|
||||
}
|
||||
|
||||
func (d MySQLDialect) TruncateClause() string {
|
||||
return "truncate"
|
||||
}
|
||||
|
||||
func (d MySQLDialect) SleepClause(s time.Duration) string {
|
||||
return fmt.Sprintf("sleep(%f)", s.Seconds())
|
||||
}
|
||||
|
||||
// Returns "?"
|
||||
func (d MySQLDialect) BindVar(i int) string {
|
||||
return "?"
|
||||
}
|
||||
|
||||
func (d MySQLDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
|
||||
return standardInsertAutoIncr(exec, insertSql, params...)
|
||||
}
|
||||
|
||||
func (d MySQLDialect) QuoteField(f string) string {
|
||||
return "`" + f + "`"
|
||||
}
|
||||
|
||||
func (d MySQLDialect) QuotedTableForQuery(schema string, table string) string {
|
||||
if strings.TrimSpace(schema) == "" {
|
||||
return d.QuoteField(table)
|
||||
}
|
||||
|
||||
return schema + "." + d.QuoteField(table)
|
||||
}
|
||||
|
||||
func (d MySQLDialect) IfSchemaNotExists(command, schema string) string {
|
||||
return fmt.Sprintf("%s if not exists", command)
|
||||
}
|
||||
|
||||
func (d MySQLDialect) IfTableExists(command, schema, table string) string {
|
||||
return fmt.Sprintf("%s if exists", command)
|
||||
}
|
||||
|
||||
func (d MySQLDialect) IfTableNotExists(command, schema, table string) string {
|
||||
return fmt.Sprintf("%s if not exists", command)
|
||||
}
|
195
gdb/sqldb/dialect_mysql_test.go
Normal file
195
gdb/sqldb/dialect_mysql_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
//
|
||||
// dialect_mysql_test.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
//go:build !integration
|
||||
// +build !integration
|
||||
|
||||
package sqldb_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/golib/gdb/sqldb"
|
||||
"github.com/poy/onpar"
|
||||
"github.com/poy/onpar/expect"
|
||||
"github.com/poy/onpar/matchers"
|
||||
)
|
||||
|
||||
func TestMySQLDialect(t *testing.T) {
|
||||
// o := onpar.New(t)
|
||||
// defer o.Run()
|
||||
|
||||
type testContext struct {
|
||||
t *testing.T
|
||||
dialect sqldb.MySQLDialect
|
||||
}
|
||||
|
||||
o := onpar.BeforeEach(onpar.New(t), func(t *testing.T) testContext {
|
||||
return testContext{
|
||||
t: t,
|
||||
dialect: sqldb.MySQLDialect{Engine: "foo", Encoding: "bar"},
|
||||
}
|
||||
})
|
||||
defer o.Run()
|
||||
|
||||
o.Group("ToSqlType", func() {
|
||||
tests := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
maxSize int
|
||||
autoIncr bool
|
||||
expected string
|
||||
}{
|
||||
{"bool", true, 0, false, "boolean"},
|
||||
{"int8", int8(1), 0, false, "tinyint"},
|
||||
{"uint8", uint8(1), 0, false, "tinyint unsigned"},
|
||||
{"int16", int16(1), 0, false, "smallint"},
|
||||
{"uint16", uint16(1), 0, false, "smallint unsigned"},
|
||||
{"int32", int32(1), 0, false, "int"},
|
||||
{"int (treated as int32)", int(1), 0, false, "int"},
|
||||
{"uint32", uint32(1), 0, false, "int unsigned"},
|
||||
{"uint (treated as uint32)", uint(1), 0, false, "int unsigned"},
|
||||
{"int64", int64(1), 0, false, "bigint"},
|
||||
{"uint64", uint64(1), 0, false, "bigint unsigned"},
|
||||
{"float32", float32(1), 0, false, "double"},
|
||||
{"float64", float64(1), 0, false, "double"},
|
||||
{"[]uint8", []uint8{1}, 0, false, "mediumblob"},
|
||||
{"NullInt64", sql.NullInt64{}, 0, false, "bigint"},
|
||||
{"NullFloat64", sql.NullFloat64{}, 0, false, "double"},
|
||||
{"NullBool", sql.NullBool{}, 0, false, "tinyint"},
|
||||
{"Time", time.Time{}, 0, false, "datetime"},
|
||||
{"default-size string", "", 0, false, "varchar(255)"},
|
||||
{"sized string", "", 50, false, "varchar(50)"},
|
||||
{"large string", "", 1024, false, "text"},
|
||||
}
|
||||
for _, t := range tests {
|
||||
o.Spec(t.name, func(tt testContext) {
|
||||
typ := reflect.TypeOf(t.value)
|
||||
sqlType := tt.dialect.ToSqlType(typ, t.maxSize, t.autoIncr)
|
||||
expect.Expect(tt.t, sqlType).To(matchers.Equal(t.expected))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
o.Spec("AutoIncrStr", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.AutoIncrStr()).To(matchers.Equal("auto_increment"))
|
||||
})
|
||||
|
||||
o.Spec("AutoIncrBindValue", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.AutoIncrBindValue()).To(matchers.Equal("null"))
|
||||
})
|
||||
|
||||
o.Spec("AutoIncrInsertSuffix", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.AutoIncrInsertSuffix(nil)).To(matchers.Equal(""))
|
||||
})
|
||||
|
||||
o.Group("CreateTableSuffix", func() {
|
||||
o.Group("with an empty engine", func() {
|
||||
o1 := onpar.BeforeEach(o, func(tt testContext) testContext {
|
||||
tt.dialect.Encoding = ""
|
||||
return tt
|
||||
})
|
||||
o1.Spec("panics", func(tt testContext) {
|
||||
expect.Expect(t, func() { tt.dialect.CreateTableSuffix() }).To(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
o.Group("with an empty encoding", func() {
|
||||
o2 := onpar.BeforeEach(o, func(tt testContext) testContext {
|
||||
tt.dialect.Encoding = ""
|
||||
return tt
|
||||
})
|
||||
o2.Spec("panics", func(tt testContext) {
|
||||
expect.Expect(t, func() { tt.dialect.CreateTableSuffix() }).To(Panic())
|
||||
})
|
||||
})
|
||||
|
||||
o.Spec("with an engine and an encoding", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.CreateTableSuffix()).To(matchers.Equal(" engine=foo charset=bar"))
|
||||
})
|
||||
})
|
||||
|
||||
o.Spec("CreateIndexSuffix", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.CreateIndexSuffix()).To(matchers.Equal("using"))
|
||||
})
|
||||
|
||||
o.Spec("DropIndexSuffix", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.DropIndexSuffix()).To(matchers.Equal("on"))
|
||||
})
|
||||
|
||||
o.Spec("TruncateClause", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.TruncateClause()).To(matchers.Equal("truncate"))
|
||||
})
|
||||
|
||||
o.Spec("SleepClause", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.SleepClause(1*time.Second)).To(matchers.Equal("sleep(1.000000)"))
|
||||
expect.Expect(t, tt.dialect.SleepClause(100*time.Millisecond)).To(matchers.Equal("sleep(0.100000)"))
|
||||
})
|
||||
|
||||
o.Spec("BindVar", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.BindVar(0)).To(matchers.Equal("?"))
|
||||
})
|
||||
|
||||
o.Spec("QuoteField", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.QuoteField("foo")).To(matchers.Equal("`foo`"))
|
||||
})
|
||||
|
||||
o.Group("QuotedTableForQuery", func() {
|
||||
o.Spec("using the default schema", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.QuotedTableForQuery("", "foo")).To(matchers.Equal("`foo`"))
|
||||
})
|
||||
|
||||
o.Spec("with a supplied schema", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal("foo.`bar`"))
|
||||
})
|
||||
})
|
||||
|
||||
o.Spec("IfSchemaNotExists", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.IfSchemaNotExists("foo", "bar")).To(matchers.Equal("foo if not exists"))
|
||||
})
|
||||
|
||||
o.Spec("IfTableExists", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.IfTableExists("foo", "bar", "baz")).To(matchers.Equal("foo if exists"))
|
||||
})
|
||||
|
||||
o.Spec("IfTableNotExists", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.IfTableNotExists("foo", "bar", "baz")).To(matchers.Equal("foo if not exists"))
|
||||
})
|
||||
}
|
||||
|
||||
type panicMatcher struct {
|
||||
}
|
||||
|
||||
func Panic() panicMatcher {
|
||||
return panicMatcher{}
|
||||
}
|
||||
|
||||
func (m panicMatcher) Match(actual interface{}) (resultValue interface{}, err error) {
|
||||
switch f := actual.(type) {
|
||||
case func():
|
||||
panicked := false
|
||||
func() {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
panicked = true
|
||||
}
|
||||
}()
|
||||
f()
|
||||
}()
|
||||
if panicked {
|
||||
return f, nil
|
||||
}
|
||||
return f, errors.New("function did not panic")
|
||||
default:
|
||||
return f, fmt.Errorf("%T is not func()", f)
|
||||
}
|
||||
}
|
142
gdb/sqldb/dialect_oracle.go
Normal file
142
gdb/sqldb/dialect_oracle.go
Normal file
@ -0,0 +1,142 @@
|
||||
//
|
||||
// dialect_oracle.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// Implementation of Dialect for Oracle databases.
|
||||
type OracleDialect struct{}
|
||||
|
||||
func (d OracleDialect) QuerySuffix() string { return "" }
|
||||
|
||||
func (d OracleDialect) CreateIndexSuffix() string { return "" }
|
||||
|
||||
func (d OracleDialect) DropIndexSuffix() string { return "" }
|
||||
|
||||
func (d OracleDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
|
||||
switch val.Kind() {
|
||||
case reflect.Ptr:
|
||||
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
|
||||
if isAutoIncr {
|
||||
return "serial"
|
||||
}
|
||||
return "integer"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if isAutoIncr {
|
||||
return "bigserial"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Float64:
|
||||
return "double precision"
|
||||
case reflect.Float32:
|
||||
return "real"
|
||||
case reflect.Slice:
|
||||
if val.Elem().Kind() == reflect.Uint8 {
|
||||
return "bytea"
|
||||
}
|
||||
}
|
||||
|
||||
switch val.Name() {
|
||||
case "NullInt64":
|
||||
return "bigint"
|
||||
case "NullFloat64":
|
||||
return "double precision"
|
||||
case "NullBool":
|
||||
return "boolean"
|
||||
case "NullTime", "Time":
|
||||
return "timestamp with time zone"
|
||||
}
|
||||
|
||||
if maxsize > 0 {
|
||||
return fmt.Sprintf("varchar(%d)", maxsize)
|
||||
} else {
|
||||
return "text"
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Returns empty string
|
||||
func (d OracleDialect) AutoIncrStr() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (d OracleDialect) AutoIncrBindValue() string {
|
||||
return "NULL"
|
||||
}
|
||||
|
||||
func (d OracleDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Returns suffix
|
||||
func (d OracleDialect) CreateTableSuffix() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (d OracleDialect) TruncateClause() string {
|
||||
return "truncate"
|
||||
}
|
||||
|
||||
// Returns "$(i+1)"
|
||||
func (d OracleDialect) BindVar(i int) string {
|
||||
return fmt.Sprintf(":%d", i+1)
|
||||
}
|
||||
|
||||
// After executing the insert uses the ColMap IdQuery to get the generated id
|
||||
func (d OracleDialect) InsertQueryToTarget(exec SqlExecutor, insertSql, idSql string, target interface{}, params ...interface{}) error {
|
||||
_, err := exec.Exec(insertSql, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
id, err := exec.SelectInt(idSql)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch target.(type) {
|
||||
case *int64:
|
||||
*(target.(*int64)) = id
|
||||
case *int32:
|
||||
*(target.(*int32)) = int32(id)
|
||||
case int:
|
||||
*(target.(*int)) = int(id)
|
||||
default:
|
||||
return fmt.Errorf("Id field can be int, int32 or int64")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d OracleDialect) QuoteField(f string) string {
|
||||
return `"` + strings.ToUpper(f) + `"`
|
||||
}
|
||||
|
||||
func (d OracleDialect) QuotedTableForQuery(schema string, table string) string {
|
||||
if strings.TrimSpace(schema) == "" {
|
||||
return d.QuoteField(table)
|
||||
}
|
||||
|
||||
return schema + "." + d.QuoteField(table)
|
||||
}
|
||||
|
||||
func (d OracleDialect) IfSchemaNotExists(command, schema string) string {
|
||||
return fmt.Sprintf("%s if not exists", command)
|
||||
}
|
||||
|
||||
func (d OracleDialect) IfTableExists(command, schema, table string) string {
|
||||
return fmt.Sprintf("%s if exists", command)
|
||||
}
|
||||
|
||||
func (d OracleDialect) IfTableNotExists(command, schema, table string) string {
|
||||
return fmt.Sprintf("%s if not exists", command)
|
||||
}
|
152
gdb/sqldb/dialect_postgres.go
Normal file
152
gdb/sqldb/dialect_postgres.go
Normal file
@ -0,0 +1,152 @@
|
||||
//
|
||||
// dialect_postgres.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type PostgresDialect struct {
|
||||
suffix string
|
||||
LowercaseFields bool
|
||||
}
|
||||
|
||||
func (d PostgresDialect) QuerySuffix() string { return ";" }
|
||||
|
||||
func (d PostgresDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
|
||||
switch val.Kind() {
|
||||
case reflect.Ptr:
|
||||
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
|
||||
case reflect.Bool:
|
||||
return "boolean"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32:
|
||||
if isAutoIncr {
|
||||
return "serial"
|
||||
}
|
||||
return "integer"
|
||||
case reflect.Int64, reflect.Uint64:
|
||||
if isAutoIncr {
|
||||
return "bigserial"
|
||||
}
|
||||
return "bigint"
|
||||
case reflect.Float64:
|
||||
return "double precision"
|
||||
case reflect.Float32:
|
||||
return "real"
|
||||
case reflect.Slice:
|
||||
if val.Elem().Kind() == reflect.Uint8 {
|
||||
return "bytea"
|
||||
}
|
||||
}
|
||||
|
||||
switch val.Name() {
|
||||
case "NullInt64":
|
||||
return "bigint"
|
||||
case "NullFloat64":
|
||||
return "double precision"
|
||||
case "NullBool":
|
||||
return "boolean"
|
||||
case "Time", "NullTime":
|
||||
return "timestamp with time zone"
|
||||
}
|
||||
|
||||
if maxsize > 0 {
|
||||
return fmt.Sprintf("varchar(%d)", maxsize)
|
||||
} else {
|
||||
return "text"
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
// Returns empty string
|
||||
func (d PostgresDialect) AutoIncrStr() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (d PostgresDialect) AutoIncrBindValue() string {
|
||||
return "default"
|
||||
}
|
||||
|
||||
func (d PostgresDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
|
||||
return " returning " + d.QuoteField(col.ColumnName)
|
||||
}
|
||||
|
||||
// Returns suffix
|
||||
func (d PostgresDialect) CreateTableSuffix() string {
|
||||
return d.suffix
|
||||
}
|
||||
|
||||
func (d PostgresDialect) CreateIndexSuffix() string {
|
||||
return "using"
|
||||
}
|
||||
|
||||
func (d PostgresDialect) DropIndexSuffix() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (d PostgresDialect) TruncateClause() string {
|
||||
return "truncate"
|
||||
}
|
||||
|
||||
func (d PostgresDialect) SleepClause(s time.Duration) string {
|
||||
return fmt.Sprintf("pg_sleep(%f)", s.Seconds())
|
||||
}
|
||||
|
||||
// Returns "$(i+1)"
|
||||
func (d PostgresDialect) BindVar(i int) string {
|
||||
return fmt.Sprintf("$%d", i+1)
|
||||
}
|
||||
|
||||
func (d PostgresDialect) InsertAutoIncrToTarget(exec SqlExecutor, insertSql string, target interface{}, params ...interface{}) error {
|
||||
rows, err := exec.Query(insertSql, params...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
return fmt.Errorf("No serial value returned for insert: %s Encountered error: %s", insertSql, rows.Err())
|
||||
}
|
||||
if err := rows.Scan(target); err != nil {
|
||||
return err
|
||||
}
|
||||
if rows.Next() {
|
||||
return fmt.Errorf("more than two serial value returned for insert: %s", insertSql)
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
func (d PostgresDialect) QuoteField(f string) string {
|
||||
if d.LowercaseFields {
|
||||
return `"` + strings.ToLower(f) + `"`
|
||||
}
|
||||
return `"` + f + `"`
|
||||
}
|
||||
|
||||
func (d PostgresDialect) QuotedTableForQuery(schema string, table string) string {
|
||||
if strings.TrimSpace(schema) == "" {
|
||||
return d.QuoteField(table)
|
||||
}
|
||||
|
||||
return schema + "." + d.QuoteField(table)
|
||||
}
|
||||
|
||||
func (d PostgresDialect) IfSchemaNotExists(command, schema string) string {
|
||||
return fmt.Sprintf("%s if not exists", command)
|
||||
}
|
||||
|
||||
func (d PostgresDialect) IfTableExists(command, schema, table string) string {
|
||||
return fmt.Sprintf("%s if exists", command)
|
||||
}
|
||||
|
||||
func (d PostgresDialect) IfTableNotExists(command, schema, table string) string {
|
||||
return fmt.Sprintf("%s if not exists", command)
|
||||
}
|
161
gdb/sqldb/dialect_postgres_test.go
Normal file
161
gdb/sqldb/dialect_postgres_test.go
Normal file
@ -0,0 +1,161 @@
|
||||
//
|
||||
// dialect_postgres_test.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
//go:build !integration
|
||||
// +build !integration
|
||||
|
||||
package sqldb_test
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/golib/gdb/sqldb"
|
||||
"github.com/poy/onpar"
|
||||
"github.com/poy/onpar/expect"
|
||||
"github.com/poy/onpar/matchers"
|
||||
)
|
||||
|
||||
func TestPostgresDialect(t *testing.T) {
|
||||
|
||||
type testContext struct {
|
||||
t *testing.T
|
||||
dialect sqldb.PostgresDialect
|
||||
}
|
||||
|
||||
o := onpar.BeforeEach(onpar.New(t), func(t *testing.T) testContext {
|
||||
return testContext{
|
||||
t: t,
|
||||
dialect: sqldb.PostgresDialect{
|
||||
LowercaseFields: false,
|
||||
},
|
||||
}
|
||||
})
|
||||
defer o.Run()
|
||||
|
||||
o.Group("ToSqlType", func() {
|
||||
tests := []struct {
|
||||
name string
|
||||
value interface{}
|
||||
maxSize int
|
||||
autoIncr bool
|
||||
expected string
|
||||
}{
|
||||
{"bool", true, 0, false, "boolean"},
|
||||
{"int8", int8(1), 0, false, "integer"},
|
||||
{"uint8", uint8(1), 0, false, "integer"},
|
||||
{"int16", int16(1), 0, false, "integer"},
|
||||
{"uint16", uint16(1), 0, false, "integer"},
|
||||
{"int32", int32(1), 0, false, "integer"},
|
||||
{"int (treated as int32)", int(1), 0, false, "integer"},
|
||||
{"uint32", uint32(1), 0, false, "integer"},
|
||||
{"uint (treated as uint32)", uint(1), 0, false, "integer"},
|
||||
{"int64", int64(1), 0, false, "bigint"},
|
||||
{"uint64", uint64(1), 0, false, "bigint"},
|
||||
{"float32", float32(1), 0, false, "real"},
|
||||
{"float64", float64(1), 0, false, "double precision"},
|
||||
{"[]uint8", []uint8{1}, 0, false, "bytea"},
|
||||
{"NullInt64", sql.NullInt64{}, 0, false, "bigint"},
|
||||
{"NullFloat64", sql.NullFloat64{}, 0, false, "double precision"},
|
||||
{"NullBool", sql.NullBool{}, 0, false, "boolean"},
|
||||
{"Time", time.Time{}, 0, false, "timestamp with time zone"},
|
||||
{"default-size string", "", 0, false, "text"},
|
||||
{"sized string", "", 50, false, "varchar(50)"},
|
||||
{"large string", "", 1024, false, "varchar(1024)"},
|
||||
}
|
||||
for _, t := range tests {
|
||||
o.Spec(t.name, func(tt testContext) {
|
||||
typ := reflect.TypeOf(t.value)
|
||||
sqlType := tt.dialect.ToSqlType(typ, t.maxSize, t.autoIncr)
|
||||
expect.Expect(tt.t, sqlType).To(matchers.Equal(t.expected))
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
o.Spec("AutoIncrStr", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.AutoIncrStr()).To(matchers.Equal(""))
|
||||
})
|
||||
|
||||
o.Spec("AutoIncrBindValue", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.AutoIncrBindValue()).To(matchers.Equal("default"))
|
||||
})
|
||||
|
||||
o.Spec("AutoIncrInsertSuffix", func(tt testContext) {
|
||||
cm := sqldb.ColumnMap{
|
||||
ColumnName: "foo",
|
||||
}
|
||||
expect.Expect(t, tt.dialect.AutoIncrInsertSuffix(&cm)).To(matchers.Equal(` returning "foo"`))
|
||||
})
|
||||
|
||||
o.Spec("CreateTableSuffix", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.CreateTableSuffix()).To(matchers.Equal(""))
|
||||
})
|
||||
|
||||
o.Spec("CreateIndexSuffix", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.CreateIndexSuffix()).To(matchers.Equal("using"))
|
||||
})
|
||||
|
||||
o.Spec("DropIndexSuffix", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.DropIndexSuffix()).To(matchers.Equal(""))
|
||||
})
|
||||
|
||||
o.Spec("TruncateClause", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.TruncateClause()).To(matchers.Equal("truncate"))
|
||||
})
|
||||
|
||||
o.Spec("SleepClause", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.SleepClause(1*time.Second)).To(matchers.Equal("pg_sleep(1.000000)"))
|
||||
expect.Expect(t, tt.dialect.SleepClause(100*time.Millisecond)).To(matchers.Equal("pg_sleep(0.100000)"))
|
||||
})
|
||||
|
||||
o.Spec("BindVar", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.BindVar(0)).To(matchers.Equal("$1"))
|
||||
expect.Expect(t, tt.dialect.BindVar(4)).To(matchers.Equal("$5"))
|
||||
})
|
||||
|
||||
o.Group("QuoteField", func() {
|
||||
o.Spec("By default, case is preserved", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.QuoteField("Foo")).To(matchers.Equal(`"Foo"`))
|
||||
expect.Expect(t, tt.dialect.QuoteField("bar")).To(matchers.Equal(`"bar"`))
|
||||
})
|
||||
|
||||
o.Group("With LowercaseFields set to true", func() {
|
||||
o1 := onpar.BeforeEach(o, func(tt testContext) testContext {
|
||||
tt.dialect.LowercaseFields = true
|
||||
return tt
|
||||
})
|
||||
|
||||
o1.Spec("fields are lowercased", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.QuoteField("Foo")).To(matchers.Equal(`"foo"`))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
o.Group("QuotedTableForQuery", func() {
|
||||
o.Spec("using the default schema", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.QuotedTableForQuery("", "foo")).To(matchers.Equal(`"foo"`))
|
||||
})
|
||||
|
||||
o.Spec("with a supplied schema", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.QuotedTableForQuery("foo", "bar")).To(matchers.Equal(`foo."bar"`))
|
||||
})
|
||||
})
|
||||
|
||||
o.Spec("IfSchemaNotExists", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.IfSchemaNotExists("foo", "bar")).To(matchers.Equal("foo if not exists"))
|
||||
})
|
||||
|
||||
o.Spec("IfTableExists", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.IfTableExists("foo", "bar", "baz")).To(matchers.Equal("foo if exists"))
|
||||
})
|
||||
|
||||
o.Spec("IfTableNotExists", func(tt testContext) {
|
||||
expect.Expect(t, tt.dialect.IfTableNotExists("foo", "bar", "baz")).To(matchers.Equal("foo if not exists"))
|
||||
})
|
||||
}
|
115
gdb/sqldb/dialect_sqlite.go
Normal file
115
gdb/sqldb/dialect_sqlite.go
Normal file
@ -0,0 +1,115 @@
|
||||
//
|
||||
// dialect_sqlite.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
type SqliteDialect struct {
|
||||
suffix string
|
||||
}
|
||||
|
||||
func (d SqliteDialect) QuerySuffix() string { return ";" }
|
||||
|
||||
func (d SqliteDialect) ToSqlType(val reflect.Type, maxsize int, isAutoIncr bool) string {
|
||||
switch val.Kind() {
|
||||
case reflect.Ptr:
|
||||
return d.ToSqlType(val.Elem(), maxsize, isAutoIncr)
|
||||
case reflect.Bool:
|
||||
return "integer"
|
||||
case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
|
||||
return "integer"
|
||||
case reflect.Float64, reflect.Float32:
|
||||
return "real"
|
||||
case reflect.Slice:
|
||||
if val.Elem().Kind() == reflect.Uint8 {
|
||||
return "blob"
|
||||
}
|
||||
}
|
||||
|
||||
switch val.Name() {
|
||||
case "NullInt64":
|
||||
return "integer"
|
||||
case "NullFloat64":
|
||||
return "real"
|
||||
case "NullBool":
|
||||
return "integer"
|
||||
case "Time":
|
||||
return "datetime"
|
||||
}
|
||||
|
||||
if maxsize < 1 {
|
||||
maxsize = 255
|
||||
}
|
||||
return fmt.Sprintf("varchar(%d)", maxsize)
|
||||
}
|
||||
|
||||
// Returns autoincrement
|
||||
func (d SqliteDialect) AutoIncrStr() string {
|
||||
return "autoincrement"
|
||||
}
|
||||
|
||||
func (d SqliteDialect) AutoIncrBindValue() string {
|
||||
return "null"
|
||||
}
|
||||
|
||||
func (d SqliteDialect) AutoIncrInsertSuffix(col *ColumnMap) string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Returns suffix
|
||||
func (d SqliteDialect) CreateTableSuffix() string {
|
||||
return d.suffix
|
||||
}
|
||||
|
||||
func (d SqliteDialect) CreateIndexSuffix() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
func (d SqliteDialect) DropIndexSuffix() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// With sqlite, there technically isn't a TRUNCATE statement,
|
||||
// but a DELETE FROM uses a truncate optimization:
|
||||
// http://www.sqlite.org/lang_delete.html
|
||||
func (d SqliteDialect) TruncateClause() string {
|
||||
return "delete from"
|
||||
}
|
||||
|
||||
// Returns "?"
|
||||
func (d SqliteDialect) BindVar(i int) string {
|
||||
return "?"
|
||||
}
|
||||
|
||||
func (d SqliteDialect) InsertAutoIncr(exec SqlExecutor, insertSql string, params ...interface{}) (int64, error) {
|
||||
return standardInsertAutoIncr(exec, insertSql, params...)
|
||||
}
|
||||
|
||||
func (d SqliteDialect) QuoteField(f string) string {
|
||||
return `"` + f + `"`
|
||||
}
|
||||
|
||||
// sqlite does not have schemas like PostgreSQL does, so just escape it like normal
|
||||
func (d SqliteDialect) QuotedTableForQuery(schema string, table string) string {
|
||||
return d.QuoteField(table)
|
||||
}
|
||||
|
||||
func (d SqliteDialect) IfSchemaNotExists(command, schema string) string {
|
||||
return fmt.Sprintf("%s if not exists", command)
|
||||
}
|
||||
|
||||
func (d SqliteDialect) IfTableExists(command, schema, table string) string {
|
||||
return fmt.Sprintf("%s if exists", command)
|
||||
}
|
||||
|
||||
func (d SqliteDialect) IfTableNotExists(command, schema, table string) string {
|
||||
return fmt.Sprintf("%s if not exists", command)
|
||||
}
|
13
gdb/sqldb/doc.go
Normal file
13
gdb/sqldb/doc.go
Normal file
@ -0,0 +1,13 @@
|
||||
//
|
||||
// doc.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
// Package sqldb provides a simple way to marshal Go structs to and from
|
||||
// SQL databases. It uses the database/sql package, and should work with any
|
||||
// compliant database/sql driver.
|
||||
//
|
||||
// Source code and project home:
|
||||
package sqldb
|
34
gdb/sqldb/errors.go
Normal file
34
gdb/sqldb/errors.go
Normal file
@ -0,0 +1,34 @@
|
||||
//
|
||||
// errors.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// A non-fatal error, when a select query returns columns that do not exist
|
||||
// as fields in the struct it is being mapped to
|
||||
// TODO: discuss wether this needs an error. encoding/json silently ignores missing fields
|
||||
type NoFieldInTypeError struct {
|
||||
TypeName string
|
||||
MissingColNames []string
|
||||
}
|
||||
|
||||
func (err *NoFieldInTypeError) Error() string {
|
||||
return fmt.Sprintf("sqldb: no fields %+v in type %s", err.MissingColNames, err.TypeName)
|
||||
}
|
||||
|
||||
// returns true if the error is non-fatal (ie, we shouldn't immediately return)
|
||||
func NonFatalError(err error) bool {
|
||||
switch err.(type) {
|
||||
case *NoFieldInTypeError:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
45
gdb/sqldb/hooks.go
Normal file
45
gdb/sqldb/hooks.go
Normal file
@ -0,0 +1,45 @@
|
||||
//
|
||||
// hooks.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
//++ TODO v2-phase3: HasPostGet => PostGetter, HasPostDelete => PostDeleter, etc.
|
||||
|
||||
// HasPostGet provides PostGet() which will be executed after the GET statement.
|
||||
type HasPostGet interface {
|
||||
PostGet(SqlExecutor) error
|
||||
}
|
||||
|
||||
// HasPostDelete provides PostDelete() which will be executed after the DELETE statement
|
||||
type HasPostDelete interface {
|
||||
PostDelete(SqlExecutor) error
|
||||
}
|
||||
|
||||
// HasPostUpdate provides PostUpdate() which will be executed after the UPDATE statement
|
||||
type HasPostUpdate interface {
|
||||
PostUpdate(SqlExecutor) error
|
||||
}
|
||||
|
||||
// HasPostInsert provides PostInsert() which will be executed after the INSERT statement
|
||||
type HasPostInsert interface {
|
||||
PostInsert(SqlExecutor) error
|
||||
}
|
||||
|
||||
// HasPreDelete provides PreDelete() which will be executed before the DELETE statement.
|
||||
type HasPreDelete interface {
|
||||
PreDelete(SqlExecutor) error
|
||||
}
|
||||
|
||||
// HasPreUpdate provides PreUpdate() which will be executed before UPDATE statement.
|
||||
type HasPreUpdate interface {
|
||||
PreUpdate(SqlExecutor) error
|
||||
}
|
||||
|
||||
// HasPreInsert provides PreInsert() which will be executed before INSERT statement.
|
||||
type HasPreInsert interface {
|
||||
PreInsert(SqlExecutor) error
|
||||
}
|
51
gdb/sqldb/index.go
Normal file
51
gdb/sqldb/index.go
Normal file
@ -0,0 +1,51 @@
|
||||
//
|
||||
// index.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
// IndexMap represents a mapping between a Go struct field and a single
|
||||
// index in a table.
|
||||
// Unique and MaxSize only inform the
|
||||
// CreateTables() function and are not used by Insert/Update/Delete/Get.
|
||||
type IndexMap struct {
|
||||
// Index name in db table
|
||||
IndexName string
|
||||
|
||||
// If true, " unique" is added to create index statements.
|
||||
// Not used elsewhere
|
||||
Unique bool
|
||||
|
||||
// Index type supported by Dialect
|
||||
// Postgres: B-tree, Hash, GiST and GIN.
|
||||
// Mysql: Btree, Hash.
|
||||
// Sqlite: nil.
|
||||
IndexType string
|
||||
|
||||
// Columns name for single and multiple indexes
|
||||
columns []string
|
||||
}
|
||||
|
||||
// Rename allows you to specify the index name in the table
|
||||
//
|
||||
// Example: table.IndMap("customer_test_idx").Rename("customer_idx")
|
||||
func (idx *IndexMap) Rename(indname string) *IndexMap {
|
||||
idx.IndexName = indname
|
||||
return idx
|
||||
}
|
||||
|
||||
// SetUnique adds "unique" to the create index statements for this
|
||||
// index, if b is true.
|
||||
func (idx *IndexMap) SetUnique(b bool) *IndexMap {
|
||||
idx.Unique = b
|
||||
return idx
|
||||
}
|
||||
|
||||
// SetIndexType specifies the index type supported by chousen SQL Dialect
|
||||
func (idx *IndexMap) SetIndexType(indtype string) *IndexMap {
|
||||
idx.IndexType = indtype
|
||||
return idx
|
||||
}
|
59
gdb/sqldb/lockerror.go
Normal file
59
gdb/sqldb/lockerror.go
Normal file
@ -0,0 +1,59 @@
|
||||
//
|
||||
// lockerror.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// OptimisticLockError is returned by Update() or Delete() if the
|
||||
// struct being modified has a Version field and the value is not equal to
|
||||
// the current value in the database
|
||||
type OptimisticLockError struct {
|
||||
// Table name where the lock error occurred
|
||||
TableName string
|
||||
|
||||
// Primary key values of the row being updated/deleted
|
||||
Keys []interface{}
|
||||
|
||||
// true if a row was found with those keys, indicating the
|
||||
// LocalVersion is stale. false if no value was found with those
|
||||
// keys, suggesting the row has been deleted since loaded, or
|
||||
// was never inserted to begin with
|
||||
RowExists bool
|
||||
|
||||
// Version value on the struct passed to Update/Delete. This value is
|
||||
// out of sync with the database.
|
||||
LocalVersion int64
|
||||
}
|
||||
|
||||
// Error returns a description of the cause of the lock error
|
||||
func (e OptimisticLockError) Error() string {
|
||||
if e.RowExists {
|
||||
return fmt.Sprintf("sqldb: OptimisticLockError table=%s keys=%v out of date version=%d", e.TableName, e.Keys, e.LocalVersion)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("sqldb: OptimisticLockError no row found for table=%s keys=%v", e.TableName, e.Keys)
|
||||
}
|
||||
|
||||
func lockError(m *DbMap, exec SqlExecutor, tableName string,
|
||||
existingVer int64, elem reflect.Value,
|
||||
keys ...interface{}) (int64, error) {
|
||||
|
||||
existing, err := get(m, exec, elem.Interface(), keys...)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
ole := OptimisticLockError{tableName, keys, true, existingVer}
|
||||
if existing == nil {
|
||||
ole.RowExists = false
|
||||
}
|
||||
return -1, ole
|
||||
}
|
45
gdb/sqldb/logging.go
Normal file
45
gdb/sqldb/logging.go
Normal file
@ -0,0 +1,45 @@
|
||||
//
|
||||
// logging.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import "fmt"
|
||||
|
||||
// SqldbLogger is a deprecated alias of Logger.
|
||||
type SqldbLogger = Logger
|
||||
|
||||
// Logger is the type that sqldb uses to log SQL statements.
|
||||
// See DbMap.TraceOn.
|
||||
type Logger interface {
|
||||
Printf(format string, v ...interface{})
|
||||
}
|
||||
|
||||
// TraceOn turns on SQL statement logging for this DbMap. After this is
|
||||
// called, all SQL statements will be sent to the logger. If prefix is
|
||||
// a non-empty string, it will be written to the front of all logged
|
||||
// strings, which can aid in filtering log lines.
|
||||
//
|
||||
// Use TraceOn if you want to spy on the SQL statements that sqldb
|
||||
// generates.
|
||||
//
|
||||
// Note that the base log.Logger type satisfies Logger, but adapters can
|
||||
// easily be written for other logging packages (e.g., the golang-sanctioned
|
||||
// glog framework).
|
||||
func (m *DbMap) TraceOn(prefix string, logger Logger) {
|
||||
m.logger = logger
|
||||
if prefix == "" {
|
||||
m.logPrefix = prefix
|
||||
} else {
|
||||
m.logPrefix = fmt.Sprintf("%s ", prefix)
|
||||
}
|
||||
}
|
||||
|
||||
// TraceOff turns off tracing. It is idempotent.
|
||||
func (m *DbMap) TraceOff() {
|
||||
m.logger = nil
|
||||
m.logPrefix = ""
|
||||
}
|
68
gdb/sqldb/nulltypes.go
Normal file
68
gdb/sqldb/nulltypes.go
Normal file
@ -0,0 +1,68 @@
|
||||
//
|
||||
// nulltypes.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"log"
|
||||
"time"
|
||||
)
|
||||
|
||||
// A nullable Time value
|
||||
type NullTime struct {
|
||||
Time time.Time
|
||||
Valid bool // Valid is true if Time is not NULL
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (nt *NullTime) Scan(value interface{}) error {
|
||||
log.Printf("Time scan value is: %#v", value)
|
||||
switch t := value.(type) {
|
||||
case time.Time:
|
||||
nt.Time, nt.Valid = t, true
|
||||
case []byte:
|
||||
v := strToTime(string(t))
|
||||
if v != nil {
|
||||
nt.Valid = true
|
||||
nt.Time = *v
|
||||
}
|
||||
case string:
|
||||
v := strToTime(t)
|
||||
if v != nil {
|
||||
nt.Valid = true
|
||||
nt.Time = *v
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func strToTime(v string) *time.Time {
|
||||
for _, dtfmt := range []string{
|
||||
"2006-01-02 15:04:05.999999999",
|
||||
"2006-01-02T15:04:05.999999999",
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02T15:04:05",
|
||||
"2006-01-02 15:04",
|
||||
"2006-01-02T15:04",
|
||||
"2006-01-02",
|
||||
"2006-01-02 15:04:05-07:00",
|
||||
} {
|
||||
if t, err := time.Parse(dtfmt, v); err == nil {
|
||||
return &t
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (nt NullTime) Value() (driver.Value, error) {
|
||||
if !nt.Valid {
|
||||
return nil, nil
|
||||
}
|
||||
return nt.Time, nil
|
||||
}
|
361
gdb/sqldb/select.go
Normal file
361
gdb/sqldb/select.go
Normal file
@ -0,0 +1,361 @@
|
||||
//
|
||||
// select.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"reflect"
|
||||
)
|
||||
|
||||
// SelectInt executes the given query, which should be a SELECT statement for a single
|
||||
// integer column, and returns the value of the first row returned. If no rows are
|
||||
// found, zero is returned.
|
||||
func SelectInt(e SqlExecutor, query string, args ...interface{}) (int64, error) {
|
||||
var h int64
|
||||
err := selectVal(e, &h, query, args...)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return 0, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// SelectNullInt executes the given query, which should be a SELECT statement for a single
|
||||
// integer column, and returns the value of the first row returned. If no rows are
|
||||
// found, the empty sql.NullInt64 value is returned.
|
||||
func SelectNullInt(e SqlExecutor, query string, args ...interface{}) (sql.NullInt64, error) {
|
||||
var h sql.NullInt64
|
||||
err := selectVal(e, &h, query, args...)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return h, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// SelectFloat executes the given query, which should be a SELECT statement for a single
|
||||
// float column, and returns the value of the first row returned. If no rows are
|
||||
// found, zero is returned.
|
||||
func SelectFloat(e SqlExecutor, query string, args ...interface{}) (float64, error) {
|
||||
var h float64
|
||||
err := selectVal(e, &h, query, args...)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return 0, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// SelectNullFloat executes the given query, which should be a SELECT statement for a single
|
||||
// float column, and returns the value of the first row returned. If no rows are
|
||||
// found, the empty sql.NullInt64 value is returned.
|
||||
func SelectNullFloat(e SqlExecutor, query string, args ...interface{}) (sql.NullFloat64, error) {
|
||||
var h sql.NullFloat64
|
||||
err := selectVal(e, &h, query, args...)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return h, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// SelectStr executes the given query, which should be a SELECT statement for a single
|
||||
// char/varchar column, and returns the value of the first row returned. If no rows are
|
||||
// found, an empty string is returned.
|
||||
func SelectStr(e SqlExecutor, query string, args ...interface{}) (string, error) {
|
||||
var h string
|
||||
err := selectVal(e, &h, query, args...)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return "", err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// SelectNullStr executes the given query, which should be a SELECT
|
||||
// statement for a single char/varchar column, and returns the value
|
||||
// of the first row returned. If no rows are found, the empty
|
||||
// sql.NullString is returned.
|
||||
func SelectNullStr(e SqlExecutor, query string, args ...interface{}) (sql.NullString, error) {
|
||||
var h sql.NullString
|
||||
err := selectVal(e, &h, query, args...)
|
||||
if err != nil && err != sql.ErrNoRows {
|
||||
return h, err
|
||||
}
|
||||
return h, nil
|
||||
}
|
||||
|
||||
// SelectOne executes the given query (which should be a SELECT statement)
|
||||
// and binds the result to holder, which must be a pointer.
|
||||
//
|
||||
// # If no row is found, an error (sql.ErrNoRows specifically) will be returned
|
||||
//
|
||||
// If more than one row is found, an error will be returned.
|
||||
func SelectOne(m *DbMap, e SqlExecutor, holder interface{}, query string, args ...interface{}) error {
|
||||
t := reflect.TypeOf(holder)
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
} else {
|
||||
return fmt.Errorf("sqldb: SelectOne holder must be a pointer, but got: %t", holder)
|
||||
}
|
||||
|
||||
// Handle pointer to pointer
|
||||
isptr := false
|
||||
if t.Kind() == reflect.Ptr {
|
||||
isptr = true
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
if t.Kind() == reflect.Struct {
|
||||
var nonFatalErr error
|
||||
|
||||
list, err := hookedselect(m, e, holder, query, args...)
|
||||
if err != nil {
|
||||
if !NonFatalError(err) { // FIXME: double negative, rename NonFatalError to FatalError
|
||||
return err
|
||||
}
|
||||
nonFatalErr = err
|
||||
}
|
||||
|
||||
dest := reflect.ValueOf(holder)
|
||||
if isptr {
|
||||
dest = dest.Elem()
|
||||
}
|
||||
|
||||
if list != nil && len(list) > 0 { // FIXME: invert if/else
|
||||
// check for multiple rows
|
||||
if len(list) > 1 {
|
||||
return fmt.Errorf("sqldb: multiple rows returned for: %s - %v", query, args)
|
||||
}
|
||||
|
||||
// Initialize if nil
|
||||
if dest.IsNil() {
|
||||
dest.Set(reflect.New(t))
|
||||
}
|
||||
|
||||
// only one row found
|
||||
src := reflect.ValueOf(list[0])
|
||||
dest.Elem().Set(src.Elem())
|
||||
} else {
|
||||
// No rows found, return a proper error.
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
return nonFatalErr
|
||||
}
|
||||
|
||||
return selectVal(e, holder, query, args...)
|
||||
}
|
||||
|
||||
func selectVal(e SqlExecutor, holder interface{}, query string, args ...interface{}) error {
|
||||
if len(args) == 1 {
|
||||
switch m := e.(type) {
|
||||
case *DbMap:
|
||||
query, args = maybeExpandNamedQuery(m, query, args)
|
||||
case *Transaction:
|
||||
query, args = maybeExpandNamedQuery(m.dbmap, query, args)
|
||||
}
|
||||
}
|
||||
rows, err := e.Query(query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
if !rows.Next() {
|
||||
if err := rows.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
return sql.ErrNoRows
|
||||
}
|
||||
|
||||
return rows.Scan(holder)
|
||||
}
|
||||
|
||||
func hookedselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
|
||||
args ...interface{}) ([]interface{}, error) {
|
||||
|
||||
var nonFatalErr error
|
||||
|
||||
list, err := rawselect(m, exec, i, query, args...)
|
||||
if err != nil {
|
||||
if !NonFatalError(err) {
|
||||
return nil, err
|
||||
}
|
||||
nonFatalErr = err
|
||||
}
|
||||
|
||||
// Determine where the results are: written to i, or returned in list
|
||||
if t, _ := toSliceType(i); t == nil {
|
||||
for _, v := range list {
|
||||
if v, ok := v.(HasPostGet); ok {
|
||||
err := v.PostGet(exec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
resultsValue := reflect.Indirect(reflect.ValueOf(i))
|
||||
for i := 0; i < resultsValue.Len(); i++ {
|
||||
if v, ok := resultsValue.Index(i).Interface().(HasPostGet); ok {
|
||||
err := v.PostGet(exec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return list, nonFatalErr
|
||||
}
|
||||
|
||||
func rawselect(m *DbMap, exec SqlExecutor, i interface{}, query string,
|
||||
args ...interface{}) ([]interface{}, error) {
|
||||
var (
|
||||
appendToSlice = false // Write results to i directly?
|
||||
intoStruct = true // Selecting into a struct?
|
||||
pointerElements = true // Are the slice elements pointers (vs values)?
|
||||
)
|
||||
|
||||
var nonFatalErr error
|
||||
|
||||
tableName := ""
|
||||
var dynObj DynamicTable
|
||||
isDynamic := false
|
||||
if dynObj, isDynamic = i.(DynamicTable); isDynamic {
|
||||
tableName = dynObj.TableName()
|
||||
}
|
||||
|
||||
// get type for i, verifying it's a supported destination
|
||||
t, err := toType(i)
|
||||
if err != nil {
|
||||
var err2 error
|
||||
if t, err2 = toSliceType(i); t == nil {
|
||||
if err2 != nil {
|
||||
return nil, err2
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
pointerElements = t.Kind() == reflect.Ptr
|
||||
if pointerElements {
|
||||
t = t.Elem()
|
||||
}
|
||||
appendToSlice = true
|
||||
intoStruct = t.Kind() == reflect.Struct
|
||||
}
|
||||
|
||||
// If the caller supplied a single struct/map argument, assume a "named
|
||||
// parameter" query. Extract the named arguments from the struct/map, create
|
||||
// the flat arg slice, and rewrite the query to use the dialect's placeholder.
|
||||
if len(args) == 1 {
|
||||
query, args = maybeExpandNamedQuery(m, query, args)
|
||||
}
|
||||
|
||||
// Run the query
|
||||
rows, err := exec.Query(query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
// Fetch the column names as returned from db
|
||||
cols, err := rows.Columns()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if !intoStruct && len(cols) > 1 {
|
||||
return nil, fmt.Errorf("sqldb: select into non-struct slice requires 1 column, got %d", len(cols))
|
||||
}
|
||||
|
||||
var colToFieldIndex [][]int
|
||||
if intoStruct {
|
||||
colToFieldIndex, err = columnToFieldIndex(m, t, tableName, cols)
|
||||
if err != nil {
|
||||
if !NonFatalError(err) {
|
||||
return nil, err
|
||||
}
|
||||
nonFatalErr = err
|
||||
}
|
||||
}
|
||||
|
||||
conv := m.TypeConverter
|
||||
|
||||
// Add results to one of these two slices.
|
||||
var (
|
||||
list = make([]interface{}, 0)
|
||||
sliceValue = reflect.Indirect(reflect.ValueOf(i))
|
||||
)
|
||||
|
||||
for {
|
||||
if !rows.Next() {
|
||||
// if error occured return rawselect
|
||||
if rows.Err() != nil {
|
||||
return nil, rows.Err()
|
||||
}
|
||||
// time to exit from outer "for" loop
|
||||
break
|
||||
}
|
||||
v := reflect.New(t)
|
||||
|
||||
if isDynamic {
|
||||
v.Interface().(DynamicTable).SetTableName(tableName)
|
||||
}
|
||||
|
||||
dest := make([]interface{}, len(cols))
|
||||
|
||||
custScan := make([]CustomScanner, 0)
|
||||
|
||||
for x := range cols {
|
||||
f := v.Elem()
|
||||
if intoStruct {
|
||||
index := colToFieldIndex[x]
|
||||
if index == nil {
|
||||
// this field is not present in the struct, so create a dummy
|
||||
// value for rows.Scan to scan into
|
||||
var dummy dummyField
|
||||
dest[x] = &dummy
|
||||
continue
|
||||
}
|
||||
f = f.FieldByIndex(index)
|
||||
}
|
||||
target := f.Addr().Interface()
|
||||
if conv != nil {
|
||||
scanner, ok := conv.FromDb(target)
|
||||
if ok {
|
||||
target = scanner.Holder
|
||||
custScan = append(custScan, scanner)
|
||||
}
|
||||
}
|
||||
dest[x] = target
|
||||
}
|
||||
|
||||
err = rows.Scan(dest...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, c := range custScan {
|
||||
err = c.Bind()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if appendToSlice {
|
||||
if !pointerElements {
|
||||
v = v.Elem()
|
||||
}
|
||||
sliceValue.Set(reflect.Append(sliceValue, v))
|
||||
} else {
|
||||
list = append(list, v.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
if appendToSlice && sliceValue.IsNil() {
|
||||
sliceValue.Set(reflect.MakeSlice(sliceValue.Type(), 0, 0))
|
||||
}
|
||||
|
||||
return list, nonFatalErr
|
||||
}
|
675
gdb/sqldb/sqldb.go
Normal file
675
gdb/sqldb/sqldb.go
Normal file
@ -0,0 +1,675 @@
|
||||
//
|
||||
// sqldb.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
// OracleString (empty string is null)
|
||||
// TODO: move to dialect/oracle?, rename to String?
|
||||
type OracleString struct {
|
||||
sql.NullString
|
||||
}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (os *OracleString) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
os.String, os.Valid = "", false
|
||||
return nil
|
||||
}
|
||||
os.Valid = true
|
||||
return os.NullString.Scan(value)
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (os OracleString) Value() (driver.Value, error) {
|
||||
if !os.Valid || os.String == "" {
|
||||
return nil, nil
|
||||
}
|
||||
return os.String, nil
|
||||
}
|
||||
|
||||
// SqlTyper is a type that returns its database type. Most of the
|
||||
// time, the type can just use "database/sql/driver".Valuer; but when
|
||||
// it returns nil for its empty value, it needs to implement SqlTyper
|
||||
// to have its column type detected properly during table creation.
|
||||
type SqlTyper interface {
|
||||
SqlType() driver.Value
|
||||
}
|
||||
|
||||
// legacySqlTyper prevents breaking clients who depended on the previous
|
||||
// SqlTyper interface
|
||||
type legacySqlTyper interface {
|
||||
SqlType() driver.Valuer
|
||||
}
|
||||
|
||||
// for fields that exists in DB table, but not exists in struct
|
||||
type dummyField struct{}
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (nt *dummyField) Scan(value interface{}) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var zeroVal reflect.Value
|
||||
var versFieldConst = "[sqldb_ver_field]"
|
||||
|
||||
// The TypeConverter interface provides a way to map a value of one
|
||||
// type to another type when persisting to, or loading from, a database.
|
||||
//
|
||||
// Example use cases: Implement type converter to convert bool types to "y"/"n" strings,
|
||||
// or serialize a struct member as a JSON blob.
|
||||
type TypeConverter interface {
|
||||
// ToDb converts val to another type. Called before INSERT/UPDATE operations
|
||||
ToDb(val interface{}) (interface{}, error)
|
||||
|
||||
// FromDb returns a CustomScanner appropriate for this type. This will be used
|
||||
// to hold values returned from SELECT queries.
|
||||
//
|
||||
// In particular the CustomScanner returned should implement a Binder
|
||||
// function appropriate for the Go type you wish to convert the db value to
|
||||
//
|
||||
// If bool==false, then no custom scanner will be used for this field.
|
||||
FromDb(target interface{}) (CustomScanner, bool)
|
||||
}
|
||||
|
||||
// SqlExecutor exposes sqldb operations that can be run from Pre/Post
|
||||
// hooks. This hides whether the current operation that triggered the
|
||||
// hook is in a transaction.
|
||||
//
|
||||
// See the DbMap function docs for each of the functions below for more
|
||||
// information.
|
||||
type SqlExecutor interface {
|
||||
WithContext(ctx context.Context) SqlExecutor
|
||||
Get(i interface{}, keys ...interface{}) (interface{}, error)
|
||||
Insert(list ...interface{}) error
|
||||
Update(list ...interface{}) (int64, error)
|
||||
Delete(list ...interface{}) (int64, error)
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Select(i interface{}, query string, args ...interface{}) ([]interface{}, error)
|
||||
SelectInt(query string, args ...interface{}) (int64, error)
|
||||
SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error)
|
||||
SelectFloat(query string, args ...interface{}) (float64, error)
|
||||
SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error)
|
||||
SelectStr(query string, args ...interface{}) (string, error)
|
||||
SelectNullStr(query string, args ...interface{}) (sql.NullString, error)
|
||||
SelectOne(holder interface{}, query string, args ...interface{}) error
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
// DynamicTable allows the users of sqldb to dynamically
|
||||
// use different database table names during runtime
|
||||
// while sharing the same golang struct for in-memory data
|
||||
type DynamicTable interface {
|
||||
TableName() string
|
||||
SetTableName(string)
|
||||
}
|
||||
|
||||
// Compile-time check that DbMap and Transaction implement the SqlExecutor
|
||||
// interface.
|
||||
var _, _ SqlExecutor = &DbMap{}, &Transaction{}
|
||||
|
||||
func argValue(a interface{}) interface{} {
|
||||
v, ok := a.(driver.Valuer)
|
||||
if !ok {
|
||||
return a
|
||||
}
|
||||
vV := reflect.ValueOf(v)
|
||||
if vV.Kind() == reflect.Ptr && vV.IsNil() {
|
||||
return nil
|
||||
}
|
||||
ret, err := v.Value()
|
||||
if err != nil {
|
||||
return a
|
||||
}
|
||||
return ret
|
||||
}
|
||||
|
||||
func argsString(args ...interface{}) string {
|
||||
var margs string
|
||||
for i, a := range args {
|
||||
v := argValue(a)
|
||||
switch v.(type) {
|
||||
case string:
|
||||
v = fmt.Sprintf("%q", v)
|
||||
default:
|
||||
v = fmt.Sprintf("%v", v)
|
||||
}
|
||||
margs += fmt.Sprintf("%d:%s", i+1, v)
|
||||
if i+1 < len(args) {
|
||||
margs += " "
|
||||
}
|
||||
}
|
||||
return margs
|
||||
}
|
||||
|
||||
// Calls the Exec function on the executor, but attempts to expand any eligible named
|
||||
// query arguments first.
|
||||
func maybeExpandNamedQueryAndExec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) {
|
||||
dbMap := extractDbMap(e)
|
||||
|
||||
if len(args) == 1 {
|
||||
query, args = maybeExpandNamedQuery(dbMap, query, args)
|
||||
}
|
||||
|
||||
return exec(e, query, args...)
|
||||
}
|
||||
|
||||
func extractDbMap(e SqlExecutor) *DbMap {
|
||||
switch m := e.(type) {
|
||||
case *DbMap:
|
||||
return m
|
||||
case *Transaction:
|
||||
return m.dbmap
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// executor exposes the sql.DB and sql.Tx functions so that it can be used
|
||||
// on internal functions that need to be agnostic to the underlying object.
|
||||
type executor interface {
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Prepare(query string) (*sql.Stmt, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
PrepareContext(ctx context.Context, query string) (*sql.Stmt, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
func extractExecutorAndContext(e SqlExecutor) (executor, context.Context) {
|
||||
switch m := e.(type) {
|
||||
case *DbMap:
|
||||
return m.Db, m.ctx
|
||||
case *Transaction:
|
||||
return m.tx, m.ctx
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// maybeExpandNamedQuery checks the given arg to see if it's eligible to be used
|
||||
// as input to a named query. If so, it rewrites the query to use
|
||||
// dialect-dependent bindvars and instantiates the corresponding slice of
|
||||
// parameters by extracting data from the map / struct.
|
||||
// If not, returns the input values unchanged.
|
||||
func maybeExpandNamedQuery(m *DbMap, query string, args []interface{}) (string, []interface{}) {
|
||||
var (
|
||||
arg = args[0]
|
||||
argval = reflect.ValueOf(arg)
|
||||
)
|
||||
if argval.Kind() == reflect.Ptr {
|
||||
argval = argval.Elem()
|
||||
}
|
||||
|
||||
if argval.Kind() == reflect.Map && argval.Type().Key().Kind() == reflect.String {
|
||||
return expandNamedQuery(m, query, func(key string) reflect.Value {
|
||||
return argval.MapIndex(reflect.ValueOf(key))
|
||||
})
|
||||
}
|
||||
if argval.Kind() != reflect.Struct {
|
||||
return query, args
|
||||
}
|
||||
if _, ok := arg.(time.Time); ok {
|
||||
// time.Time is driver.Value
|
||||
return query, args
|
||||
}
|
||||
if _, ok := arg.(driver.Valuer); ok {
|
||||
// driver.Valuer will be converted to driver.Value.
|
||||
return query, args
|
||||
}
|
||||
|
||||
return expandNamedQuery(m, query, argval.FieldByName)
|
||||
}
|
||||
|
||||
var keyRegexp = regexp.MustCompile(`:[[:word:]]+`)
|
||||
|
||||
// expandNamedQuery accepts a query with placeholders of the form ":key", and a
|
||||
// single arg of Kind Struct or Map[string]. It returns the query with the
|
||||
// dialect's placeholders, and a slice of args ready for positional insertion
|
||||
// into the query.
|
||||
func expandNamedQuery(m *DbMap, query string, keyGetter func(key string) reflect.Value) (string, []interface{}) {
|
||||
var (
|
||||
n int
|
||||
args []interface{}
|
||||
)
|
||||
return keyRegexp.ReplaceAllStringFunc(query, func(key string) string {
|
||||
val := keyGetter(key[1:])
|
||||
if !val.IsValid() {
|
||||
return key
|
||||
}
|
||||
args = append(args, val.Interface())
|
||||
newVar := m.Dialect.BindVar(n)
|
||||
n++
|
||||
return newVar
|
||||
}), args
|
||||
}
|
||||
|
||||
func columnToFieldIndex(m *DbMap, t reflect.Type, name string, cols []string) ([][]int, error) {
|
||||
colToFieldIndex := make([][]int, len(cols))
|
||||
|
||||
// check if type t is a mapped table - if so we'll
|
||||
// check the table for column aliasing below
|
||||
tableMapped := false
|
||||
table := tableOrNil(m, t, name)
|
||||
if table != nil {
|
||||
tableMapped = true
|
||||
}
|
||||
|
||||
// Loop over column names and find field in i to bind to
|
||||
// based on column name. all returned columns must match
|
||||
// a field in the i struct
|
||||
missingColNames := []string{}
|
||||
for x := range cols {
|
||||
colName := strings.ToLower(cols[x])
|
||||
field, found := t.FieldByNameFunc(func(fieldName string) bool {
|
||||
field, _ := t.FieldByName(fieldName)
|
||||
cArguments := strings.Split(field.Tag.Get("db"), ",")
|
||||
fieldName = cArguments[0]
|
||||
|
||||
if fieldName == "-" {
|
||||
return false
|
||||
} else if fieldName == "" {
|
||||
fieldName = field.Name
|
||||
}
|
||||
if tableMapped {
|
||||
colMap := colMapOrNil(table, fieldName)
|
||||
if colMap != nil {
|
||||
fieldName = colMap.ColumnName
|
||||
}
|
||||
}
|
||||
return colName == strings.ToLower(fieldName)
|
||||
})
|
||||
if found {
|
||||
colToFieldIndex[x] = field.Index
|
||||
}
|
||||
if colToFieldIndex[x] == nil {
|
||||
missingColNames = append(missingColNames, colName)
|
||||
}
|
||||
}
|
||||
if len(missingColNames) > 0 {
|
||||
return colToFieldIndex, &NoFieldInTypeError{
|
||||
TypeName: t.Name(),
|
||||
MissingColNames: missingColNames,
|
||||
}
|
||||
}
|
||||
return colToFieldIndex, nil
|
||||
}
|
||||
|
||||
func fieldByName(val reflect.Value, fieldName string) *reflect.Value {
|
||||
// try to find field by exact match
|
||||
f := val.FieldByName(fieldName)
|
||||
|
||||
if f != zeroVal {
|
||||
return &f
|
||||
}
|
||||
|
||||
// try to find by case insensitive match - only the Postgres driver
|
||||
// seems to require this - in the case where columns are aliased in the sql
|
||||
fieldNameL := strings.ToLower(fieldName)
|
||||
fieldCount := val.NumField()
|
||||
t := val.Type()
|
||||
for i := 0; i < fieldCount; i++ {
|
||||
sf := t.Field(i)
|
||||
if strings.ToLower(sf.Name) == fieldNameL {
|
||||
f := val.Field(i)
|
||||
return &f
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// toSliceType returns the element type of the given object, if the object is a
|
||||
// "*[]*Element" or "*[]Element". If not, returns nil.
|
||||
// err is returned if the user was trying to pass a pointer-to-slice but failed.
|
||||
func toSliceType(i interface{}) (reflect.Type, error) {
|
||||
t := reflect.TypeOf(i)
|
||||
if t.Kind() != reflect.Ptr {
|
||||
// If it's a slice, return a more helpful error message
|
||||
if t.Kind() == reflect.Slice {
|
||||
return nil, fmt.Errorf("sqldb: cannot SELECT into a non-pointer slice: %v", t)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
if t = t.Elem(); t.Kind() != reflect.Slice {
|
||||
return nil, nil
|
||||
}
|
||||
return t.Elem(), nil
|
||||
}
|
||||
|
||||
func toType(i interface{}) (reflect.Type, error) {
|
||||
t := reflect.TypeOf(i)
|
||||
|
||||
// If a Pointer to a type, follow
|
||||
for t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
|
||||
if t.Kind() != reflect.Struct {
|
||||
return nil, fmt.Errorf("sqldb: cannot SELECT into this type: %v", reflect.TypeOf(i))
|
||||
}
|
||||
return t, nil
|
||||
}
|
||||
|
||||
type foundTable struct {
|
||||
table *TableMap
|
||||
dynName *string
|
||||
}
|
||||
|
||||
func tableFor(m *DbMap, t reflect.Type, i interface{}) (*foundTable, error) {
|
||||
if dyn, isDynamic := i.(DynamicTable); isDynamic {
|
||||
tableName := dyn.TableName()
|
||||
table, err := m.DynamicTableFor(tableName, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &foundTable{
|
||||
table: table,
|
||||
dynName: &tableName,
|
||||
}, nil
|
||||
}
|
||||
table, err := m.TableFor(t, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &foundTable{table: table}, nil
|
||||
}
|
||||
|
||||
func get(m *DbMap, exec SqlExecutor, i interface{},
|
||||
keys ...interface{}) (interface{}, error) {
|
||||
|
||||
t, err := toType(i)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
foundTable, err := tableFor(m, t, i)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
table := foundTable.table
|
||||
|
||||
plan := table.bindGet()
|
||||
|
||||
v := reflect.New(t)
|
||||
if foundTable.dynName != nil {
|
||||
retDyn := v.Interface().(DynamicTable)
|
||||
retDyn.SetTableName(*foundTable.dynName)
|
||||
}
|
||||
|
||||
dest := make([]interface{}, len(plan.argFields))
|
||||
|
||||
conv := m.TypeConverter
|
||||
custScan := make([]CustomScanner, 0)
|
||||
|
||||
for x, fieldName := range plan.argFields {
|
||||
f := v.Elem().FieldByName(fieldName)
|
||||
target := f.Addr().Interface()
|
||||
if conv != nil {
|
||||
scanner, ok := conv.FromDb(target)
|
||||
if ok {
|
||||
target = scanner.Holder
|
||||
custScan = append(custScan, scanner)
|
||||
}
|
||||
}
|
||||
dest[x] = target
|
||||
}
|
||||
|
||||
row := exec.QueryRow(plan.query, keys...)
|
||||
err = row.Scan(dest...)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
err = nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, c := range custScan {
|
||||
err = c.Bind()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := v.Interface().(HasPostGet); ok {
|
||||
err := v.PostGet(exec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return v.Interface(), nil
|
||||
}
|
||||
|
||||
func delete(m *DbMap, exec SqlExecutor, list ...interface{}) (int64, error) {
|
||||
count := int64(0)
|
||||
for _, ptr := range list {
|
||||
table, elem, err := m.tableForPointer(ptr, true)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
eval := elem.Addr().Interface()
|
||||
if v, ok := eval.(HasPreDelete); ok {
|
||||
err = v.PreDelete(exec)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
|
||||
bi, err := table.bindDelete(elem)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
res, err := exec.Exec(bi.query, bi.args...)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
rows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
if rows == 0 && bi.existingVersion > 0 {
|
||||
return lockError(m, exec, table.TableName,
|
||||
bi.existingVersion, elem, bi.keys...)
|
||||
}
|
||||
|
||||
count += rows
|
||||
|
||||
if v, ok := eval.(HasPostDelete); ok {
|
||||
err := v.PostDelete(exec)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func update(m *DbMap, exec SqlExecutor, colFilter ColumnFilter, list ...interface{}) (int64, error) {
|
||||
count := int64(0)
|
||||
for _, ptr := range list {
|
||||
table, elem, err := m.tableForPointer(ptr, true)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
eval := elem.Addr().Interface()
|
||||
if v, ok := eval.(HasPreUpdate); ok {
|
||||
err = v.PreUpdate(exec)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
|
||||
bi, err := table.bindUpdate(elem, colFilter)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
res, err := exec.Exec(bi.query, bi.args...)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
rows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
if rows == 0 && bi.existingVersion > 0 {
|
||||
return lockError(m, exec, table.TableName,
|
||||
bi.existingVersion, elem, bi.keys...)
|
||||
}
|
||||
|
||||
if bi.versField != "" {
|
||||
elem.FieldByName(bi.versField).SetInt(bi.existingVersion + 1)
|
||||
}
|
||||
|
||||
count += rows
|
||||
|
||||
if v, ok := eval.(HasPostUpdate); ok {
|
||||
err = v.PostUpdate(exec)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func insert(m *DbMap, exec SqlExecutor, list ...interface{}) error {
|
||||
for _, ptr := range list {
|
||||
table, elem, err := m.tableForPointer(ptr, false)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
eval := elem.Addr().Interface()
|
||||
if v, ok := eval.(HasPreInsert); ok {
|
||||
err := v.PreInsert(exec)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
bi, err := table.bindInsert(elem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if bi.autoIncrIdx > -1 {
|
||||
f := elem.FieldByName(bi.autoIncrFieldName)
|
||||
switch inserter := m.Dialect.(type) {
|
||||
case IntegerAutoIncrInserter:
|
||||
id, err := inserter.InsertAutoIncr(exec, bi.query, bi.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
k := f.Kind()
|
||||
if (k == reflect.Int) || (k == reflect.Int16) || (k == reflect.Int32) || (k == reflect.Int64) {
|
||||
f.SetInt(id)
|
||||
} else if (k == reflect.Uint) || (k == reflect.Uint16) || (k == reflect.Uint32) || (k == reflect.Uint64) {
|
||||
f.SetUint(uint64(id))
|
||||
} else {
|
||||
return fmt.Errorf("sqldb: cannot set autoincrement value on non-Int field. SQL=%s autoIncrIdx=%d autoIncrFieldName=%s", bi.query, bi.autoIncrIdx, bi.autoIncrFieldName)
|
||||
}
|
||||
case TargetedAutoIncrInserter:
|
||||
err := inserter.InsertAutoIncrToTarget(exec, bi.query, f.Addr().Interface(), bi.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
case TargetQueryInserter:
|
||||
var idQuery = table.ColMap(bi.autoIncrFieldName).GeneratedIdQuery
|
||||
if idQuery == "" {
|
||||
return fmt.Errorf("sqldb: cannot set %s value if its ColumnMap.GeneratedIdQuery is empty", bi.autoIncrFieldName)
|
||||
}
|
||||
err := inserter.InsertQueryToTarget(exec, bi.query, idQuery, f.Addr().Interface(), bi.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
default:
|
||||
return fmt.Errorf("sqldb: cannot use autoincrement fields on dialects that do not implement an autoincrementing interface")
|
||||
}
|
||||
} else {
|
||||
_, err := exec.Exec(bi.query, bi.args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := eval.(HasPostInsert); ok {
|
||||
err := v.PostInsert(exec)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func exec(e SqlExecutor, query string, args ...interface{}) (sql.Result, error) {
|
||||
executor, ctx := extractExecutorAndContext(e)
|
||||
|
||||
if ctx != nil {
|
||||
return executor.ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
return executor.Exec(query, args...)
|
||||
}
|
||||
|
||||
func prepare(e SqlExecutor, query string) (*sql.Stmt, error) {
|
||||
executor, ctx := extractExecutorAndContext(e)
|
||||
|
||||
if ctx != nil {
|
||||
return executor.PrepareContext(ctx, query)
|
||||
}
|
||||
|
||||
return executor.Prepare(query)
|
||||
}
|
||||
|
||||
func queryRow(e SqlExecutor, query string, args ...interface{}) *sql.Row {
|
||||
executor, ctx := extractExecutorAndContext(e)
|
||||
|
||||
if ctx != nil {
|
||||
return executor.QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
return executor.QueryRow(query, args...)
|
||||
}
|
||||
|
||||
func query(e SqlExecutor, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
executor, ctx := extractExecutorAndContext(e)
|
||||
|
||||
if ctx != nil {
|
||||
return executor.QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
return executor.Query(query, args...)
|
||||
}
|
||||
|
||||
func begin(m *DbMap) (*sql.Tx, error) {
|
||||
if m.ctx != nil {
|
||||
return m.Db.BeginTx(m.ctx, nil)
|
||||
}
|
||||
|
||||
return m.Db.Begin()
|
||||
}
|
2875
gdb/sqldb/sqldb_test.go
Normal file
2875
gdb/sqldb/sqldb_test.go
Normal file
File diff suppressed because it is too large
Load Diff
258
gdb/sqldb/table.go
Normal file
258
gdb/sqldb/table.go
Normal file
@ -0,0 +1,258 @@
|
||||
//
|
||||
// table.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// TableMap represents a mapping between a Go struct and a database table
|
||||
// Use dbmap.AddTable() or dbmap.AddTableWithName() to create these
|
||||
type TableMap struct {
|
||||
// Name of database table.
|
||||
TableName string
|
||||
SchemaName string
|
||||
gotype reflect.Type
|
||||
Columns []*ColumnMap
|
||||
keys []*ColumnMap
|
||||
indexes []*IndexMap
|
||||
uniqueTogether [][]string
|
||||
version *ColumnMap
|
||||
insertPlan bindPlan
|
||||
updatePlan bindPlan
|
||||
deletePlan bindPlan
|
||||
getPlan bindPlan
|
||||
dbmap *DbMap
|
||||
}
|
||||
|
||||
// ResetSql removes cached insert/update/select/delete SQL strings
|
||||
// associated with this TableMap. Call this if you've modified
|
||||
// any column names or the table name itself.
|
||||
func (t *TableMap) ResetSql() {
|
||||
t.insertPlan = bindPlan{}
|
||||
t.updatePlan = bindPlan{}
|
||||
t.deletePlan = bindPlan{}
|
||||
t.getPlan = bindPlan{}
|
||||
}
|
||||
|
||||
// SetKeys lets you specify the fields on a struct that map to primary
|
||||
// key columns on the table. If isAutoIncr is set, result.LastInsertId()
|
||||
// will be used after INSERT to bind the generated id to the Go struct.
|
||||
//
|
||||
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
|
||||
//
|
||||
// Panics if isAutoIncr is true, and fieldNames length != 1
|
||||
func (t *TableMap) SetKeys(isAutoIncr bool, fieldNames ...string) *TableMap {
|
||||
if isAutoIncr && len(fieldNames) != 1 {
|
||||
panic(fmt.Sprintf(
|
||||
"sqldb: SetKeys: fieldNames length must be 1 if key is auto-increment. (Saw %v fieldNames)",
|
||||
len(fieldNames)))
|
||||
}
|
||||
t.keys = make([]*ColumnMap, 0)
|
||||
for _, name := range fieldNames {
|
||||
colmap := t.ColMap(name)
|
||||
colmap.isPK = true
|
||||
colmap.isAutoIncr = isAutoIncr
|
||||
t.keys = append(t.keys, colmap)
|
||||
}
|
||||
t.ResetSql()
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// SetUniqueTogether lets you specify uniqueness constraints across multiple
|
||||
// columns on the table. Each call adds an additional constraint for the
|
||||
// specified columns.
|
||||
//
|
||||
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
|
||||
//
|
||||
// Panics if fieldNames length < 2.
|
||||
func (t *TableMap) SetUniqueTogether(fieldNames ...string) *TableMap {
|
||||
if len(fieldNames) < 2 {
|
||||
panic(fmt.Sprintf(
|
||||
"sqldb: SetUniqueTogether: must provide at least two fieldNames to set uniqueness constraint."))
|
||||
}
|
||||
|
||||
columns := make([]string, 0, len(fieldNames))
|
||||
for _, name := range fieldNames {
|
||||
columns = append(columns, name)
|
||||
}
|
||||
|
||||
for _, existingColumns := range t.uniqueTogether {
|
||||
if equal(existingColumns, columns) {
|
||||
return t
|
||||
}
|
||||
}
|
||||
t.uniqueTogether = append(t.uniqueTogether, columns)
|
||||
t.ResetSql()
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// ColMap returns the ColumnMap pointer matching the given struct field
|
||||
// name. It panics if the struct does not contain a field matching this
|
||||
// name.
|
||||
func (t *TableMap) ColMap(field string) *ColumnMap {
|
||||
col := colMapOrNil(t, field)
|
||||
if col == nil {
|
||||
e := fmt.Sprintf("No ColumnMap in table %s type %s with field %s",
|
||||
t.TableName, t.gotype.Name(), field)
|
||||
|
||||
panic(e)
|
||||
}
|
||||
return col
|
||||
}
|
||||
|
||||
func colMapOrNil(t *TableMap, field string) *ColumnMap {
|
||||
for _, col := range t.Columns {
|
||||
if col.fieldName == field || col.ColumnName == field {
|
||||
return col
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IdxMap returns the IndexMap pointer matching the given index name.
|
||||
func (t *TableMap) IdxMap(field string) *IndexMap {
|
||||
for _, idx := range t.indexes {
|
||||
if idx.IndexName == field {
|
||||
return idx
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddIndex registers the index with sqldb for specified table with given parameters.
|
||||
// This operation is idempotent. If index is already mapped, the
|
||||
// existing *IndexMap is returned
|
||||
// Function will panic if one of the given for index columns does not exists
|
||||
//
|
||||
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
|
||||
func (t *TableMap) AddIndex(name string, idxtype string, columns []string) *IndexMap {
|
||||
// check if we have a index with this name already
|
||||
for _, idx := range t.indexes {
|
||||
if idx.IndexName == name {
|
||||
return idx
|
||||
}
|
||||
}
|
||||
for _, icol := range columns {
|
||||
if res := t.ColMap(icol); res == nil {
|
||||
e := fmt.Sprintf("No ColumnName in table %s to create index on", t.TableName)
|
||||
panic(e)
|
||||
}
|
||||
}
|
||||
|
||||
idx := &IndexMap{IndexName: name, Unique: false, IndexType: idxtype, columns: columns}
|
||||
t.indexes = append(t.indexes, idx)
|
||||
t.ResetSql()
|
||||
return idx
|
||||
}
|
||||
|
||||
// SetVersionCol sets the column to use as the Version field. By default
|
||||
// the "Version" field is used. Returns the column found, or panics
|
||||
// if the struct does not contain a field matching this name.
|
||||
//
|
||||
// Automatically calls ResetSql() to ensure SQL statements are regenerated.
|
||||
func (t *TableMap) SetVersionCol(field string) *ColumnMap {
|
||||
c := t.ColMap(field)
|
||||
t.version = c
|
||||
t.ResetSql()
|
||||
return c
|
||||
}
|
||||
|
||||
// SqlForCreateTable gets a sequence of SQL commands that will create
|
||||
// the specified table and any associated schema
|
||||
func (t *TableMap) SqlForCreate(ifNotExists bool) string {
|
||||
s := bytes.Buffer{}
|
||||
dialect := t.dbmap.Dialect
|
||||
|
||||
if strings.TrimSpace(t.SchemaName) != "" {
|
||||
schemaCreate := "create schema"
|
||||
if ifNotExists {
|
||||
s.WriteString(dialect.IfSchemaNotExists(schemaCreate, t.SchemaName))
|
||||
} else {
|
||||
s.WriteString(schemaCreate)
|
||||
}
|
||||
s.WriteString(fmt.Sprintf(" %s;", t.SchemaName))
|
||||
}
|
||||
|
||||
tableCreate := "create table"
|
||||
if ifNotExists {
|
||||
s.WriteString(dialect.IfTableNotExists(tableCreate, t.SchemaName, t.TableName))
|
||||
} else {
|
||||
s.WriteString(tableCreate)
|
||||
}
|
||||
s.WriteString(fmt.Sprintf(" %s (", dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
|
||||
|
||||
x := 0
|
||||
for _, col := range t.Columns {
|
||||
if !col.Transient {
|
||||
if x > 0 {
|
||||
s.WriteString(", ")
|
||||
}
|
||||
stype := dialect.ToSqlType(col.gotype, col.MaxSize, col.isAutoIncr)
|
||||
s.WriteString(fmt.Sprintf("%s %s", dialect.QuoteField(col.ColumnName), stype))
|
||||
|
||||
if col.isPK || col.isNotNull {
|
||||
s.WriteString(" not null")
|
||||
}
|
||||
if col.isPK && len(t.keys) == 1 {
|
||||
s.WriteString(" primary key")
|
||||
}
|
||||
if col.Unique {
|
||||
s.WriteString(" unique")
|
||||
}
|
||||
if col.isAutoIncr {
|
||||
s.WriteString(fmt.Sprintf(" %s", dialect.AutoIncrStr()))
|
||||
}
|
||||
|
||||
x++
|
||||
}
|
||||
}
|
||||
if len(t.keys) > 1 {
|
||||
s.WriteString(", primary key (")
|
||||
for x := range t.keys {
|
||||
if x > 0 {
|
||||
s.WriteString(", ")
|
||||
}
|
||||
s.WriteString(dialect.QuoteField(t.keys[x].ColumnName))
|
||||
}
|
||||
s.WriteString(")")
|
||||
}
|
||||
if len(t.uniqueTogether) > 0 {
|
||||
for _, columns := range t.uniqueTogether {
|
||||
s.WriteString(", unique (")
|
||||
for i, column := range columns {
|
||||
if i > 0 {
|
||||
s.WriteString(", ")
|
||||
}
|
||||
s.WriteString(dialect.QuoteField(column))
|
||||
}
|
||||
s.WriteString(")")
|
||||
}
|
||||
}
|
||||
s.WriteString(") ")
|
||||
s.WriteString(dialect.CreateTableSuffix())
|
||||
s.WriteString(dialect.QuerySuffix())
|
||||
return s.String()
|
||||
}
|
||||
|
||||
func equal(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for i := range a {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
308
gdb/sqldb/table_bindings.go
Normal file
308
gdb/sqldb/table_bindings.go
Normal file
@ -0,0 +1,308 @@
|
||||
//
|
||||
// table_bindings.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// CustomScanner binds a database column value to a Go type
|
||||
type CustomScanner struct {
|
||||
// After a row is scanned, Holder will contain the value from the database column.
|
||||
// Initialize the CustomScanner with the concrete Go type you wish the database
|
||||
// driver to scan the raw column into.
|
||||
Holder interface{}
|
||||
// Target typically holds a pointer to the target struct field to bind the Holder
|
||||
// value to.
|
||||
Target interface{}
|
||||
// Binder is a custom function that converts the holder value to the target type
|
||||
// and sets target accordingly. This function should return error if a problem
|
||||
// occurs converting the holder to the target.
|
||||
Binder func(holder interface{}, target interface{}) error
|
||||
}
|
||||
|
||||
// Used to filter columns when selectively updating
|
||||
type ColumnFilter func(*ColumnMap) bool
|
||||
|
||||
func acceptAllFilter(col *ColumnMap) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// Bind is called automatically by sqldb after Scan()
|
||||
func (me CustomScanner) Bind() error {
|
||||
return me.Binder(me.Holder, me.Target)
|
||||
}
|
||||
|
||||
type bindPlan struct {
|
||||
query string
|
||||
argFields []string
|
||||
keyFields []string
|
||||
versField string
|
||||
autoIncrIdx int
|
||||
autoIncrFieldName string
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func (plan *bindPlan) createBindInstance(elem reflect.Value, conv TypeConverter) (bindInstance, error) {
|
||||
bi := bindInstance{query: plan.query, autoIncrIdx: plan.autoIncrIdx, autoIncrFieldName: plan.autoIncrFieldName, versField: plan.versField}
|
||||
if plan.versField != "" {
|
||||
bi.existingVersion = elem.FieldByName(plan.versField).Int()
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
for i := 0; i < len(plan.argFields); i++ {
|
||||
k := plan.argFields[i]
|
||||
if k == versFieldConst {
|
||||
newVer := bi.existingVersion + 1
|
||||
bi.args = append(bi.args, newVer)
|
||||
if bi.existingVersion == 0 {
|
||||
elem.FieldByName(plan.versField).SetInt(int64(newVer))
|
||||
}
|
||||
} else {
|
||||
val := elem.FieldByName(k).Interface()
|
||||
if conv != nil {
|
||||
val, err = conv.ToDb(val)
|
||||
if err != nil {
|
||||
return bindInstance{}, err
|
||||
}
|
||||
}
|
||||
bi.args = append(bi.args, val)
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < len(plan.keyFields); i++ {
|
||||
k := plan.keyFields[i]
|
||||
val := elem.FieldByName(k).Interface()
|
||||
if conv != nil {
|
||||
val, err = conv.ToDb(val)
|
||||
if err != nil {
|
||||
return bindInstance{}, err
|
||||
}
|
||||
}
|
||||
bi.keys = append(bi.keys, val)
|
||||
}
|
||||
|
||||
return bi, nil
|
||||
}
|
||||
|
||||
type bindInstance struct {
|
||||
query string
|
||||
args []interface{}
|
||||
keys []interface{}
|
||||
existingVersion int64
|
||||
versField string
|
||||
autoIncrIdx int
|
||||
autoIncrFieldName string
|
||||
}
|
||||
|
||||
func (t *TableMap) bindInsert(elem reflect.Value) (bindInstance, error) {
|
||||
plan := &t.insertPlan
|
||||
plan.once.Do(func() {
|
||||
plan.autoIncrIdx = -1
|
||||
|
||||
s := bytes.Buffer{}
|
||||
s2 := bytes.Buffer{}
|
||||
s.WriteString(fmt.Sprintf("insert into %s (", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
|
||||
|
||||
x := 0
|
||||
first := true
|
||||
for y := range t.Columns {
|
||||
col := t.Columns[y]
|
||||
if !(col.isAutoIncr && t.dbmap.Dialect.AutoIncrBindValue() == "") {
|
||||
if !col.Transient {
|
||||
if !first {
|
||||
s.WriteString(",")
|
||||
s2.WriteString(",")
|
||||
}
|
||||
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
|
||||
|
||||
if col.isAutoIncr {
|
||||
s2.WriteString(t.dbmap.Dialect.AutoIncrBindValue())
|
||||
plan.autoIncrIdx = y
|
||||
plan.autoIncrFieldName = col.fieldName
|
||||
} else {
|
||||
if col.DefaultValue == "" {
|
||||
s2.WriteString(t.dbmap.Dialect.BindVar(x))
|
||||
if col == t.version {
|
||||
plan.versField = col.fieldName
|
||||
plan.argFields = append(plan.argFields, versFieldConst)
|
||||
} else {
|
||||
plan.argFields = append(plan.argFields, col.fieldName)
|
||||
}
|
||||
x++
|
||||
} else {
|
||||
s2.WriteString(col.DefaultValue)
|
||||
}
|
||||
}
|
||||
first = false
|
||||
}
|
||||
} else {
|
||||
plan.autoIncrIdx = y
|
||||
plan.autoIncrFieldName = col.fieldName
|
||||
}
|
||||
}
|
||||
s.WriteString(") values (")
|
||||
s.WriteString(s2.String())
|
||||
s.WriteString(")")
|
||||
if plan.autoIncrIdx > -1 {
|
||||
s.WriteString(t.dbmap.Dialect.AutoIncrInsertSuffix(t.Columns[plan.autoIncrIdx]))
|
||||
}
|
||||
s.WriteString(t.dbmap.Dialect.QuerySuffix())
|
||||
|
||||
plan.query = s.String()
|
||||
})
|
||||
|
||||
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
|
||||
}
|
||||
|
||||
func (t *TableMap) bindUpdate(elem reflect.Value, colFilter ColumnFilter) (bindInstance, error) {
|
||||
if colFilter == nil {
|
||||
colFilter = acceptAllFilter
|
||||
}
|
||||
|
||||
plan := &t.updatePlan
|
||||
plan.once.Do(func() {
|
||||
s := bytes.Buffer{}
|
||||
s.WriteString(fmt.Sprintf("update %s set ", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
|
||||
x := 0
|
||||
|
||||
for y := range t.Columns {
|
||||
col := t.Columns[y]
|
||||
if !col.isAutoIncr && !col.Transient && colFilter(col) {
|
||||
if x > 0 {
|
||||
s.WriteString(", ")
|
||||
}
|
||||
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
|
||||
s.WriteString("=")
|
||||
s.WriteString(t.dbmap.Dialect.BindVar(x))
|
||||
|
||||
if col == t.version {
|
||||
plan.versField = col.fieldName
|
||||
plan.argFields = append(plan.argFields, versFieldConst)
|
||||
} else {
|
||||
plan.argFields = append(plan.argFields, col.fieldName)
|
||||
}
|
||||
x++
|
||||
}
|
||||
}
|
||||
|
||||
s.WriteString(" where ")
|
||||
for y := range t.keys {
|
||||
col := t.keys[y]
|
||||
if y > 0 {
|
||||
s.WriteString(" and ")
|
||||
}
|
||||
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
|
||||
s.WriteString("=")
|
||||
s.WriteString(t.dbmap.Dialect.BindVar(x))
|
||||
|
||||
plan.argFields = append(plan.argFields, col.fieldName)
|
||||
plan.keyFields = append(plan.keyFields, col.fieldName)
|
||||
x++
|
||||
}
|
||||
if plan.versField != "" {
|
||||
s.WriteString(" and ")
|
||||
s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName))
|
||||
s.WriteString("=")
|
||||
s.WriteString(t.dbmap.Dialect.BindVar(x))
|
||||
plan.argFields = append(plan.argFields, plan.versField)
|
||||
}
|
||||
s.WriteString(t.dbmap.Dialect.QuerySuffix())
|
||||
|
||||
plan.query = s.String()
|
||||
})
|
||||
|
||||
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
|
||||
}
|
||||
|
||||
func (t *TableMap) bindDelete(elem reflect.Value) (bindInstance, error) {
|
||||
plan := &t.deletePlan
|
||||
plan.once.Do(func() {
|
||||
s := bytes.Buffer{}
|
||||
s.WriteString(fmt.Sprintf("delete from %s", t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName)))
|
||||
|
||||
for y := range t.Columns {
|
||||
col := t.Columns[y]
|
||||
if !col.Transient {
|
||||
if col == t.version {
|
||||
plan.versField = col.fieldName
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
s.WriteString(" where ")
|
||||
for x := range t.keys {
|
||||
k := t.keys[x]
|
||||
if x > 0 {
|
||||
s.WriteString(" and ")
|
||||
}
|
||||
s.WriteString(t.dbmap.Dialect.QuoteField(k.ColumnName))
|
||||
s.WriteString("=")
|
||||
s.WriteString(t.dbmap.Dialect.BindVar(x))
|
||||
|
||||
plan.keyFields = append(plan.keyFields, k.fieldName)
|
||||
plan.argFields = append(plan.argFields, k.fieldName)
|
||||
}
|
||||
if plan.versField != "" {
|
||||
s.WriteString(" and ")
|
||||
s.WriteString(t.dbmap.Dialect.QuoteField(t.version.ColumnName))
|
||||
s.WriteString("=")
|
||||
s.WriteString(t.dbmap.Dialect.BindVar(len(plan.argFields)))
|
||||
|
||||
plan.argFields = append(plan.argFields, plan.versField)
|
||||
}
|
||||
s.WriteString(t.dbmap.Dialect.QuerySuffix())
|
||||
|
||||
plan.query = s.String()
|
||||
})
|
||||
|
||||
return plan.createBindInstance(elem, t.dbmap.TypeConverter)
|
||||
}
|
||||
|
||||
func (t *TableMap) bindGet() *bindPlan {
|
||||
plan := &t.getPlan
|
||||
plan.once.Do(func() {
|
||||
s := bytes.Buffer{}
|
||||
s.WriteString("select ")
|
||||
|
||||
x := 0
|
||||
for _, col := range t.Columns {
|
||||
if !col.Transient {
|
||||
if x > 0 {
|
||||
s.WriteString(",")
|
||||
}
|
||||
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
|
||||
plan.argFields = append(plan.argFields, col.fieldName)
|
||||
x++
|
||||
}
|
||||
}
|
||||
s.WriteString(" from ")
|
||||
s.WriteString(t.dbmap.Dialect.QuotedTableForQuery(t.SchemaName, t.TableName))
|
||||
s.WriteString(" where ")
|
||||
for x := range t.keys {
|
||||
col := t.keys[x]
|
||||
if x > 0 {
|
||||
s.WriteString(" and ")
|
||||
}
|
||||
s.WriteString(t.dbmap.Dialect.QuoteField(col.ColumnName))
|
||||
s.WriteString("=")
|
||||
s.WriteString(t.dbmap.Dialect.BindVar(x))
|
||||
|
||||
plan.keyFields = append(plan.keyFields, col.fieldName)
|
||||
}
|
||||
s.WriteString(t.dbmap.Dialect.QuerySuffix())
|
||||
|
||||
plan.query = s.String()
|
||||
})
|
||||
|
||||
return plan
|
||||
}
|
24
gdb/sqldb/test_all.sh
Executable file
24
gdb/sqldb/test_all.sh
Executable file
@ -0,0 +1,24 @@
|
||||
#!/bin/bash -ex
|
||||
|
||||
# on macs, you may need to:
|
||||
# export GOBUILDFLAG=-ldflags -linkmode=external
|
||||
|
||||
echo "Running unit tests"
|
||||
go test -race
|
||||
|
||||
echo "Testing against postgres"
|
||||
export SQLDB_TEST_DSN="host=127.0.0.1 user=testuser password=123 dbname=testdb sslmode=disable"
|
||||
export SQLDB_TEST_DIALECT=postgres
|
||||
go test -tags integration $GOBUILDFLAG $@ .
|
||||
|
||||
echo "Testing against sqlite"
|
||||
export SQLDB_TEST_DSN=/tmp/testdb.bin
|
||||
export SQLDB_TEST_DIALECT=sqlite
|
||||
go test -tags integration $GOBUILDFLAG $@ .
|
||||
rm -f /tmp/testdb.bin
|
||||
|
||||
echo "Testing against mysql"
|
||||
# export SQLDB_TEST_DSN="testuser:123@tcp(127.0.0.1:3306)/testdb?charset=utf8mb4&parseTime=True&loc=Local"
|
||||
export SQLDB_TEST_DSN="testuser:123@tcp(127.0.0.1:3306)/testdb"
|
||||
export SQLDB_TEST_DIALECT=mysql
|
||||
go test -tags integration $GOBUILDFLAG $@ .
|
242
gdb/sqldb/transaction.go
Normal file
242
gdb/sqldb/transaction.go
Normal file
@ -0,0 +1,242 @@
|
||||
//
|
||||
// transaction.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
package sqldb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
)
|
||||
|
||||
// Transaction represents a database transaction.
|
||||
// Insert/Update/Delete/Get/Exec operations will be run in the context
|
||||
// of that transaction. Transactions should be terminated with
|
||||
// a call to Commit() or Rollback()
|
||||
type Transaction struct {
|
||||
ctx context.Context
|
||||
dbmap *DbMap
|
||||
tx *sql.Tx
|
||||
closed bool
|
||||
}
|
||||
|
||||
func (t *Transaction) WithContext(ctx context.Context) SqlExecutor {
|
||||
copy := &Transaction{}
|
||||
*copy = *t
|
||||
copy.ctx = ctx
|
||||
return copy
|
||||
}
|
||||
|
||||
// Insert has the same behavior as DbMap.Insert(), but runs in a transaction.
|
||||
func (t *Transaction) Insert(list ...interface{}) error {
|
||||
return insert(t.dbmap, t, list...)
|
||||
}
|
||||
|
||||
// Update had the same behavior as DbMap.Update(), but runs in a transaction.
|
||||
func (t *Transaction) Update(list ...interface{}) (int64, error) {
|
||||
return update(t.dbmap, t, nil, list...)
|
||||
}
|
||||
|
||||
// UpdateColumns had the same behavior as DbMap.UpdateColumns(), but runs in a transaction.
|
||||
func (t *Transaction) UpdateColumns(filter ColumnFilter, list ...interface{}) (int64, error) {
|
||||
return update(t.dbmap, t, filter, list...)
|
||||
}
|
||||
|
||||
// Delete has the same behavior as DbMap.Delete(), but runs in a transaction.
|
||||
func (t *Transaction) Delete(list ...interface{}) (int64, error) {
|
||||
return delete(t.dbmap, t, list...)
|
||||
}
|
||||
|
||||
// Get has the same behavior as DbMap.Get(), but runs in a transaction.
|
||||
func (t *Transaction) Get(i interface{}, keys ...interface{}) (interface{}, error) {
|
||||
return get(t.dbmap, t, i, keys...)
|
||||
}
|
||||
|
||||
// Select has the same behavior as DbMap.Select(), but runs in a transaction.
|
||||
func (t *Transaction) Select(i interface{}, query string, args ...interface{}) ([]interface{}, error) {
|
||||
if t.dbmap.ExpandSliceArgs {
|
||||
expandSliceArgs(&query, args...)
|
||||
}
|
||||
|
||||
return hookedselect(t.dbmap, t, i, query, args...)
|
||||
}
|
||||
|
||||
// Exec has the same behavior as DbMap.Exec(), but runs in a transaction.
|
||||
func (t *Transaction) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
if t.dbmap.ExpandSliceArgs {
|
||||
expandSliceArgs(&query, args...)
|
||||
}
|
||||
|
||||
if t.dbmap.logger != nil {
|
||||
now := time.Now()
|
||||
defer t.dbmap.trace(now, query, args...)
|
||||
}
|
||||
return maybeExpandNamedQueryAndExec(t, query, args...)
|
||||
}
|
||||
|
||||
// SelectInt is a convenience wrapper around the sqldb.SelectInt function.
|
||||
func (t *Transaction) SelectInt(query string, args ...interface{}) (int64, error) {
|
||||
if t.dbmap.ExpandSliceArgs {
|
||||
expandSliceArgs(&query, args...)
|
||||
}
|
||||
|
||||
return SelectInt(t, query, args...)
|
||||
}
|
||||
|
||||
// SelectNullInt is a convenience wrapper around the sqldb.SelectNullInt function.
|
||||
func (t *Transaction) SelectNullInt(query string, args ...interface{}) (sql.NullInt64, error) {
|
||||
if t.dbmap.ExpandSliceArgs {
|
||||
expandSliceArgs(&query, args...)
|
||||
}
|
||||
|
||||
return SelectNullInt(t, query, args...)
|
||||
}
|
||||
|
||||
// SelectFloat is a convenience wrapper around the sqldb.SelectFloat function.
|
||||
func (t *Transaction) SelectFloat(query string, args ...interface{}) (float64, error) {
|
||||
if t.dbmap.ExpandSliceArgs {
|
||||
expandSliceArgs(&query, args...)
|
||||
}
|
||||
|
||||
return SelectFloat(t, query, args...)
|
||||
}
|
||||
|
||||
// SelectNullFloat is a convenience wrapper around the sqldb.SelectNullFloat function.
|
||||
func (t *Transaction) SelectNullFloat(query string, args ...interface{}) (sql.NullFloat64, error) {
|
||||
if t.dbmap.ExpandSliceArgs {
|
||||
expandSliceArgs(&query, args...)
|
||||
}
|
||||
|
||||
return SelectNullFloat(t, query, args...)
|
||||
}
|
||||
|
||||
// SelectStr is a convenience wrapper around the sqldb.SelectStr function.
|
||||
func (t *Transaction) SelectStr(query string, args ...interface{}) (string, error) {
|
||||
if t.dbmap.ExpandSliceArgs {
|
||||
expandSliceArgs(&query, args...)
|
||||
}
|
||||
|
||||
return SelectStr(t, query, args...)
|
||||
}
|
||||
|
||||
// SelectNullStr is a convenience wrapper around the sqldb.SelectNullStr function.
|
||||
func (t *Transaction) SelectNullStr(query string, args ...interface{}) (sql.NullString, error) {
|
||||
if t.dbmap.ExpandSliceArgs {
|
||||
expandSliceArgs(&query, args...)
|
||||
}
|
||||
|
||||
return SelectNullStr(t, query, args...)
|
||||
}
|
||||
|
||||
// SelectOne is a convenience wrapper around the sqldb.SelectOne function.
|
||||
func (t *Transaction) SelectOne(holder interface{}, query string, args ...interface{}) error {
|
||||
if t.dbmap.ExpandSliceArgs {
|
||||
expandSliceArgs(&query, args...)
|
||||
}
|
||||
|
||||
return SelectOne(t.dbmap, t, holder, query, args...)
|
||||
}
|
||||
|
||||
// Commit commits the underlying database transaction.
|
||||
func (t *Transaction) Commit() error {
|
||||
if !t.closed {
|
||||
t.closed = true
|
||||
if t.dbmap.logger != nil {
|
||||
now := time.Now()
|
||||
defer t.dbmap.trace(now, "commit;")
|
||||
}
|
||||
return t.tx.Commit()
|
||||
}
|
||||
|
||||
return sql.ErrTxDone
|
||||
}
|
||||
|
||||
// Rollback rolls back the underlying database transaction.
|
||||
func (t *Transaction) Rollback() error {
|
||||
if !t.closed {
|
||||
t.closed = true
|
||||
if t.dbmap.logger != nil {
|
||||
now := time.Now()
|
||||
defer t.dbmap.trace(now, "rollback;")
|
||||
}
|
||||
return t.tx.Rollback()
|
||||
}
|
||||
|
||||
return sql.ErrTxDone
|
||||
}
|
||||
|
||||
// Savepoint creates a savepoint with the given name. The name is interpolated
|
||||
// directly into the SQL SAVEPOINT statement, so you must sanitize it if it is
|
||||
// derived from user input.
|
||||
func (t *Transaction) Savepoint(name string) error {
|
||||
query := "savepoint " + t.dbmap.Dialect.QuoteField(name)
|
||||
if t.dbmap.logger != nil {
|
||||
now := time.Now()
|
||||
defer t.dbmap.trace(now, query, nil)
|
||||
}
|
||||
_, err := exec(t, query)
|
||||
return err
|
||||
}
|
||||
|
||||
// RollbackToSavepoint rolls back to the savepoint with the given name. The
|
||||
// name is interpolated directly into the SQL SAVEPOINT statement, so you must
|
||||
// sanitize it if it is derived from user input.
|
||||
func (t *Transaction) RollbackToSavepoint(savepoint string) error {
|
||||
query := "rollback to savepoint " + t.dbmap.Dialect.QuoteField(savepoint)
|
||||
if t.dbmap.logger != nil {
|
||||
now := time.Now()
|
||||
defer t.dbmap.trace(now, query, nil)
|
||||
}
|
||||
_, err := exec(t, query)
|
||||
return err
|
||||
}
|
||||
|
||||
// ReleaseSavepint releases the savepoint with the given name. The name is
|
||||
// interpolated directly into the SQL SAVEPOINT statement, so you must sanitize
|
||||
// it if it is derived from user input.
|
||||
func (t *Transaction) ReleaseSavepoint(savepoint string) error {
|
||||
query := "release savepoint " + t.dbmap.Dialect.QuoteField(savepoint)
|
||||
if t.dbmap.logger != nil {
|
||||
now := time.Now()
|
||||
defer t.dbmap.trace(now, query, nil)
|
||||
}
|
||||
_, err := exec(t, query)
|
||||
return err
|
||||
}
|
||||
|
||||
// Prepare has the same behavior as DbMap.Prepare(), but runs in a transaction.
|
||||
func (t *Transaction) Prepare(query string) (*sql.Stmt, error) {
|
||||
if t.dbmap.logger != nil {
|
||||
now := time.Now()
|
||||
defer t.dbmap.trace(now, query, nil)
|
||||
}
|
||||
return prepare(t, query)
|
||||
}
|
||||
|
||||
func (t *Transaction) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||
if t.dbmap.ExpandSliceArgs {
|
||||
expandSliceArgs(&query, args...)
|
||||
}
|
||||
|
||||
if t.dbmap.logger != nil {
|
||||
now := time.Now()
|
||||
defer t.dbmap.trace(now, query, args...)
|
||||
}
|
||||
return queryRow(t, query, args...)
|
||||
}
|
||||
|
||||
func (t *Transaction) Query(q string, args ...interface{}) (*sql.Rows, error) {
|
||||
if t.dbmap.ExpandSliceArgs {
|
||||
expandSliceArgs(&q, args...)
|
||||
}
|
||||
|
||||
if t.dbmap.logger != nil {
|
||||
now := time.Now()
|
||||
defer t.dbmap.trace(now, q, args...)
|
||||
}
|
||||
return query(t, q, args...)
|
||||
}
|
340
gdb/sqldb/transaction_test.go
Normal file
340
gdb/sqldb/transaction_test.go
Normal file
@ -0,0 +1,340 @@
|
||||
//
|
||||
// transaction_test.go
|
||||
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||||
//
|
||||
// Distributed under terms of the MIT license.
|
||||
//
|
||||
|
||||
//go:build integration
|
||||
// +build integration
|
||||
|
||||
package sqldb_test
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestTransaction_Select_expandSliceArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
description string
|
||||
query string
|
||||
args []interface{}
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
description: "it should handle slice placeholders correctly",
|
||||
query: `
|
||||
SELECT 1 FROM crazy_table
|
||||
WHERE field1 = :Field1
|
||||
AND field2 IN (:FieldStringList)
|
||||
AND field3 IN (:FieldUIntList)
|
||||
AND field4 IN (:FieldUInt8List)
|
||||
AND field5 IN (:FieldUInt16List)
|
||||
AND field6 IN (:FieldUInt32List)
|
||||
AND field7 IN (:FieldUInt64List)
|
||||
AND field8 IN (:FieldIntList)
|
||||
AND field9 IN (:FieldInt8List)
|
||||
AND field10 IN (:FieldInt16List)
|
||||
AND field11 IN (:FieldInt32List)
|
||||
AND field12 IN (:FieldInt64List)
|
||||
AND field13 IN (:FieldFloat32List)
|
||||
AND field14 IN (:FieldFloat64List)
|
||||
`,
|
||||
args: []interface{}{
|
||||
map[string]interface{}{
|
||||
"Field1": 123,
|
||||
"FieldStringList": []string{"h", "e", "y"},
|
||||
"FieldUIntList": []uint{1, 2, 3, 4},
|
||||
"FieldUInt8List": []uint8{1, 2, 3, 4},
|
||||
"FieldUInt16List": []uint16{1, 2, 3, 4},
|
||||
"FieldUInt32List": []uint32{1, 2, 3, 4},
|
||||
"FieldUInt64List": []uint64{1, 2, 3, 4},
|
||||
"FieldIntList": []int{1, 2, 3, 4},
|
||||
"FieldInt8List": []int8{1, 2, 3, 4},
|
||||
"FieldInt16List": []int16{1, 2, 3, 4},
|
||||
"FieldInt32List": []int32{1, 2, 3, 4},
|
||||
"FieldInt64List": []int64{1, 2, 3, 4},
|
||||
"FieldFloat32List": []float32{1, 2, 3, 4},
|
||||
"FieldFloat64List": []float64{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
description: "it should handle slice placeholders correctly with custom types",
|
||||
query: `
|
||||
SELECT 1 FROM crazy_table
|
||||
WHERE field2 IN (:FieldStringList)
|
||||
AND field12 IN (:FieldIntList)
|
||||
`,
|
||||
args: []interface{}{
|
||||
map[string]interface{}{
|
||||
"FieldStringList": customType1{"h", "e", "y"},
|
||||
"FieldIntList": customType2{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
wantLen: 3,
|
||||
},
|
||||
}
|
||||
|
||||
type dataFormat struct {
|
||||
Field1 int `db:"field1"`
|
||||
Field2 string `db:"field2"`
|
||||
Field3 uint `db:"field3"`
|
||||
Field4 uint8 `db:"field4"`
|
||||
Field5 uint16 `db:"field5"`
|
||||
Field6 uint32 `db:"field6"`
|
||||
Field7 uint64 `db:"field7"`
|
||||
Field8 int `db:"field8"`
|
||||
Field9 int8 `db:"field9"`
|
||||
Field10 int16 `db:"field10"`
|
||||
Field11 int32 `db:"field11"`
|
||||
Field12 int64 `db:"field12"`
|
||||
Field13 float32 `db:"field13"`
|
||||
Field14 float64 `db:"field14"`
|
||||
}
|
||||
|
||||
dbmap := newDBMap(t)
|
||||
dbmap.ExpandSliceArgs = true
|
||||
dbmap.AddTableWithName(dataFormat{}, "crazy_table")
|
||||
|
||||
err := dbmap.CreateTables()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer dropAndClose(dbmap)
|
||||
|
||||
err = dbmap.Insert(
|
||||
&dataFormat{
|
||||
Field1: 123,
|
||||
Field2: "h",
|
||||
Field3: 1,
|
||||
Field4: 1,
|
||||
Field5: 1,
|
||||
Field6: 1,
|
||||
Field7: 1,
|
||||
Field8: 1,
|
||||
Field9: 1,
|
||||
Field10: 1,
|
||||
Field11: 1,
|
||||
Field12: 1,
|
||||
Field13: 1,
|
||||
Field14: 1,
|
||||
},
|
||||
&dataFormat{
|
||||
Field1: 124,
|
||||
Field2: "e",
|
||||
Field3: 2,
|
||||
Field4: 2,
|
||||
Field5: 2,
|
||||
Field6: 2,
|
||||
Field7: 2,
|
||||
Field8: 2,
|
||||
Field9: 2,
|
||||
Field10: 2,
|
||||
Field11: 2,
|
||||
Field12: 2,
|
||||
Field13: 2,
|
||||
Field14: 2,
|
||||
},
|
||||
&dataFormat{
|
||||
Field1: 125,
|
||||
Field2: "y",
|
||||
Field3: 3,
|
||||
Field4: 3,
|
||||
Field5: 3,
|
||||
Field6: 3,
|
||||
Field7: 3,
|
||||
Field8: 3,
|
||||
Field9: 3,
|
||||
Field10: 3,
|
||||
Field11: 3,
|
||||
Field12: 3,
|
||||
Field13: 3,
|
||||
Field14: 3,
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description, func(t *testing.T) {
|
||||
tx, err := dbmap.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
var dummy []int
|
||||
_, err = tx.Select(&dummy, tt.query, tt.args...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if len(dummy) != tt.wantLen {
|
||||
t.Errorf("wrong result count\ngot: %d\nwant: %d", len(dummy), tt.wantLen)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTransaction_Exec_expandSliceArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
description string
|
||||
query string
|
||||
args []interface{}
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
description: "it should handle slice placeholders correctly",
|
||||
query: `
|
||||
DELETE FROM crazy_table
|
||||
WHERE field1 = :Field1
|
||||
AND field2 IN (:FieldStringList)
|
||||
AND field3 IN (:FieldUIntList)
|
||||
AND field4 IN (:FieldUInt8List)
|
||||
AND field5 IN (:FieldUInt16List)
|
||||
AND field6 IN (:FieldUInt32List)
|
||||
AND field7 IN (:FieldUInt64List)
|
||||
AND field8 IN (:FieldIntList)
|
||||
AND field9 IN (:FieldInt8List)
|
||||
AND field10 IN (:FieldInt16List)
|
||||
AND field11 IN (:FieldInt32List)
|
||||
AND field12 IN (:FieldInt64List)
|
||||
AND field13 IN (:FieldFloat32List)
|
||||
AND field14 IN (:FieldFloat64List)
|
||||
`,
|
||||
args: []interface{}{
|
||||
map[string]interface{}{
|
||||
"Field1": 123,
|
||||
"FieldStringList": []string{"h", "e", "y"},
|
||||
"FieldUIntList": []uint{1, 2, 3, 4},
|
||||
"FieldUInt8List": []uint8{1, 2, 3, 4},
|
||||
"FieldUInt16List": []uint16{1, 2, 3, 4},
|
||||
"FieldUInt32List": []uint32{1, 2, 3, 4},
|
||||
"FieldUInt64List": []uint64{1, 2, 3, 4},
|
||||
"FieldIntList": []int{1, 2, 3, 4},
|
||||
"FieldInt8List": []int8{1, 2, 3, 4},
|
||||
"FieldInt16List": []int16{1, 2, 3, 4},
|
||||
"FieldInt32List": []int32{1, 2, 3, 4},
|
||||
"FieldInt64List": []int64{1, 2, 3, 4},
|
||||
"FieldFloat32List": []float32{1, 2, 3, 4},
|
||||
"FieldFloat64List": []float64{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
description: "it should handle slice placeholders correctly with custom types",
|
||||
query: `
|
||||
DELETE FROM crazy_table
|
||||
WHERE field2 IN (:FieldStringList)
|
||||
AND field12 IN (:FieldIntList)
|
||||
`,
|
||||
args: []interface{}{
|
||||
map[string]interface{}{
|
||||
"FieldStringList": customType1{"h", "e", "y"},
|
||||
"FieldIntList": customType2{1, 2, 3, 4},
|
||||
},
|
||||
},
|
||||
wantLen: 3,
|
||||
},
|
||||
}
|
||||
|
||||
type dataFormat struct {
|
||||
Field1 int `db:"field1"`
|
||||
Field2 string `db:"field2"`
|
||||
Field3 uint `db:"field3"`
|
||||
Field4 uint8 `db:"field4"`
|
||||
Field5 uint16 `db:"field5"`
|
||||
Field6 uint32 `db:"field6"`
|
||||
Field7 uint64 `db:"field7"`
|
||||
Field8 int `db:"field8"`
|
||||
Field9 int8 `db:"field9"`
|
||||
Field10 int16 `db:"field10"`
|
||||
Field11 int32 `db:"field11"`
|
||||
Field12 int64 `db:"field12"`
|
||||
Field13 float32 `db:"field13"`
|
||||
Field14 float64 `db:"field14"`
|
||||
}
|
||||
|
||||
dbmap := newDBMap(t)
|
||||
dbmap.ExpandSliceArgs = true
|
||||
dbmap.AddTableWithName(dataFormat{}, "crazy_table")
|
||||
|
||||
err := dbmap.CreateTables()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
defer dropAndClose(dbmap)
|
||||
|
||||
err = dbmap.Insert(
|
||||
&dataFormat{
|
||||
Field1: 123,
|
||||
Field2: "h",
|
||||
Field3: 1,
|
||||
Field4: 1,
|
||||
Field5: 1,
|
||||
Field6: 1,
|
||||
Field7: 1,
|
||||
Field8: 1,
|
||||
Field9: 1,
|
||||
Field10: 1,
|
||||
Field11: 1,
|
||||
Field12: 1,
|
||||
Field13: 1,
|
||||
Field14: 1,
|
||||
},
|
||||
&dataFormat{
|
||||
Field1: 124,
|
||||
Field2: "e",
|
||||
Field3: 2,
|
||||
Field4: 2,
|
||||
Field5: 2,
|
||||
Field6: 2,
|
||||
Field7: 2,
|
||||
Field8: 2,
|
||||
Field9: 2,
|
||||
Field10: 2,
|
||||
Field11: 2,
|
||||
Field12: 2,
|
||||
Field13: 2,
|
||||
Field14: 2,
|
||||
},
|
||||
&dataFormat{
|
||||
Field1: 125,
|
||||
Field2: "y",
|
||||
Field3: 3,
|
||||
Field4: 3,
|
||||
Field5: 3,
|
||||
Field6: 3,
|
||||
Field7: 3,
|
||||
Field8: 3,
|
||||
Field9: 3,
|
||||
Field10: 3,
|
||||
Field11: 3,
|
||||
Field12: 3,
|
||||
Field13: 3,
|
||||
Field14: 3,
|
||||
},
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description, func(t *testing.T) {
|
||||
tx, err := dbmap.Begin()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tx.Rollback()
|
||||
|
||||
_, err = tx.Exec(tt.query, tt.args...)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
8
go.mod
8
go.mod
@ -8,17 +8,20 @@ require (
|
||||
github.com/go-redis/redis/v8 v8.11.5
|
||||
github.com/go-sql-driver/mysql v1.7.1
|
||||
github.com/hibiken/asynq v0.24.1
|
||||
github.com/jmoiron/sqlx v1.3.5
|
||||
github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/mattn/go-runewidth v0.0.14
|
||||
github.com/mattn/go-sqlite3 v1.14.6
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/poy/onpar v0.3.2
|
||||
github.com/rs/xid v1.5.0
|
||||
github.com/stretchr/testify v1.8.3
|
||||
go.mongodb.org/mongo-driver v1.11.7
|
||||
go.uber.org/zap v1.25.0
|
||||
golang.org/x/crypto v0.10.0
|
||||
gopkg.in/natefinch/lumberjack.v2 v2.2.1
|
||||
gopkg.in/yaml.v2 v2.4.0
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
@ -26,6 +29,7 @@ require (
|
||||
github.com/bytedance/sonic v1.9.1 // indirect
|
||||
github.com/cespare/xxhash/v2 v2.2.0 // indirect
|
||||
github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect
|
||||
github.com/gabriel-vasile/mimetype v1.4.2 // indirect
|
||||
github.com/gin-contrib/sse v0.1.0 // indirect
|
||||
@ -47,6 +51,7 @@ require (
|
||||
github.com/modern-go/reflect2 v1.0.2 // indirect
|
||||
github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect
|
||||
github.com/pelletier/go-toml/v2 v2.0.8 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/redis/go-redis/v9 v9.0.3 // indirect
|
||||
github.com/rivo/uniseg v0.2.0 // indirect
|
||||
github.com/robfig/cron/v3 v3.0.1 // indirect
|
||||
@ -65,5 +70,4 @@ require (
|
||||
golang.org/x/text v0.10.0 // indirect
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 // indirect
|
||||
google.golang.org/protobuf v1.30.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
||||
|
7
go.sum
7
go.sum
@ -1,3 +1,4 @@
|
||||
git.sr.ht/~nelsam/hel v0.4.3 h1:9W0zz8zv8CZhFsp8r9Wq6c8gFemBdtMurjZU/JKfvfM=
|
||||
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw=
|
||||
github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0=
|
||||
github.com/benbjohnson/clock v1.3.0 h1:ip6w0uFQkncKQ979AypyG0ER7mqUSBdKLOgAle/AT8A=
|
||||
@ -36,7 +37,6 @@ github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg
|
||||
github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU=
|
||||
github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI=
|
||||
github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo=
|
||||
github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg=
|
||||
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
|
||||
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU=
|
||||
@ -57,8 +57,6 @@ github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs=
|
||||
github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/hibiken/asynq v0.24.1 h1:+5iIEAyA9K/lcSPvx3qoPtsKJeKI5u9aOIvUmSsazEw=
|
||||
github.com/hibiken/asynq v0.24.1/go.mod h1:u5qVeSbrnfT+vtG5Mq8ZPzQu/BmCKMHvTGb91uy9Tts=
|
||||
github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g=
|
||||
github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ=
|
||||
github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4=
|
||||
github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc=
|
||||
github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM=
|
||||
@ -81,7 +79,6 @@ github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible h1:Y6sqxHMyB1D2YSzWkL
|
||||
github.com/lestrrat-go/file-rotatelogs v2.4.0+incompatible/go.mod h1:ZQnN8lSECaebrkQytbHj4xNgtg8CR7RYXnPok8e0EHA=
|
||||
github.com/lestrrat-go/strftime v1.0.6 h1:CFGsDEt1pOpFNU+TJB0nhz9jl+K0hZSLE205AhTIGQQ=
|
||||
github.com/lestrrat-go/strftime v1.0.6/go.mod h1:f7jQKgV5nnJpYgdEasS+/y7EsTb8ykN2z68n3TtcTaw=
|
||||
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA=
|
||||
@ -106,6 +103,8 @@ github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/poy/onpar v0.3.2 h1:yo8ZRqU3C4RlvkXPWUWfonQiTodAgpKQZ1g8VTNU9xU=
|
||||
github.com/poy/onpar v0.3.2/go.mod h1:6XDWG8DJ1HsFX6/Btn0pHl3Jz5d1SEEGNZ5N1gtYo+I=
|
||||
github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k=
|
||||
github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk=
|
||||
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
|
||||
|
Loading…
Reference in New Issue
Block a user