golib/gdb/sqldb/db_func.go
2023-06-15 21:22:51 +08:00

221 lines
5.5 KiB
Go

//
// db_func.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/jmoiron/sqlx"
)
// Begin starts a transaction by delegating to e.Beginx and returns the
// resulting *sqlx.Tx together with any error it reported.
func (e *Engine) Begin() (*sqlx.Tx, error) {
	tx, err := e.Beginx()
	return tx, err
}
// NamedInsertRecord inserts one row into opt.table using named parameters.
//
// The column list is opt.fields; every column is bound to the placeholder
// ":<field>", whose value sqlx resolves from arg (a struct or map usable with
// sqlx named queries).
//
// On the "postgres" driver the statement is suffixed with "returning id" and
// the new id is scanned from the returned row. Other drivers return no row for
// a plain INSERT, so the statement is executed and the id is taken from
// LastInsertId instead.
//
// It returns an error when opt.fields is empty.
func (e *Engine) NamedInsertRecord(opt *QueryOption, arg interface{}) (int64, error) { // {{{
	if len(opt.fields) == 0 {
		return 0, errors.New("empty fields")
	}
	placeholders := make([]string, 0, len(opt.fields))
	for _, field := range opt.fields {
		placeholders = append(placeholders, ":"+field)
	}
	sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)",
		opt.table, strings.Join(opt.fields, ","), strings.Join(placeholders, ","))

	if e.DriverName() != "postgres" {
		// Without RETURNING the INSERT yields no row, so scanning a result
		// would report "no rows" even though the row was written.
		result, err := e.NamedExec(sql, arg)
		if err != nil {
			return 0, err
		}
		return result.LastInsertId()
	}

	// PrepareNamed translates ":name" placeholders for the driver itself,
	// so no Rebind is needed here.
	stmt, err := e.PrepareNamed(sql + " returning id")
	if err != nil {
		return 0, err
	}
	// Release the prepared statement once the id has been read.
	defer stmt.Close()

	var id int64
	if err := stmt.Get(&id, arg); err != nil {
		return 0, err
	}
	return id, nil
} // }}}
// InsertRecord inserts one row into opt.table using positional bindvars.
//
// The column list is opt.fields and the values are opt.args, in the same
// order. The "?" placeholders are rebound for the current driver.
//
// On the "postgres" driver the new id is read through "returning id", because
// lib/pq does not implement LastInsertId. On other drivers it comes from
// LastInsertId.
//
// It returns an error when opt.fields is empty.
func (e *Engine) InsertRecord(opt *QueryOption) (int64, error) { // {{{
	if len(opt.fields) == 0 {
		return 0, errors.New("empty fields")
	}
	fieldsStr := strings.Join(opt.fields, ",")
	placeholders := strings.TrimRight(strings.Repeat("?,", len(opt.fields)), ",")
	// No trailing ';': on postgres " returning id" must stay part of the same
	// statement.
	sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", opt.table, fieldsStr, placeholders)

	if e.DriverName() == "postgres" {
		var id int64
		if err := e.Get(&id, e.Rebind(sql+" returning id"), opt.args...); err != nil {
			return 0, err
		}
		return id, nil
	}

	result, err := e.Exec(e.Rebind(sql), opt.args...)
	if err != nil {
		return 0, err
	}
	return result.LastInsertId()
} // }}}
// GetRecord loads at most one row of opt.table matching opt.query into dest.
//
//	dest: destination value, scanned by sqlx Get
//	opt:  table is the table to read, query is the WHERE condition (required),
//	      args are the bindvar values for query
//
// It returns an error when opt.query is empty, and otherwise returns the
// error from Get (including the no-rows error when nothing matches).
func (e *Engine) GetRecord(dest interface{}, opt *QueryOption) error { // {{{
	if opt.query == "" {
		return errors.New("empty query")
	}
	// Build the WHERE clause locally instead of rewriting opt.query, so the
	// same opt can be reused for further queries.
	sql := fmt.Sprintf("SELECT * FROM %s WHERE %s limit 1", opt.table, opt.query)
	return e.Get(dest, e.Rebind(sql), opt.args...)
} // }}}
// GetRecords loads every row of opt.table matching the options into dest.
//
//	dest: destination slice, filled by sqlx Select
//	opt:  query (optional WHERE condition), sort (optional ORDER BY),
//	      limit / offset (applied only when > 0), args (bindvars for query)
func (e *Engine) GetRecords(dest interface{}, opt *QueryOption) error { // {{{
	var clauses []string
	if opt.query != "" {
		clauses = append(clauses, "where", opt.query)
	}
	if opt.sort != "" {
		clauses = append(clauses, "order by", opt.sort)
	}
	// LIMIT must come before OFFSET for MySQL and SQLite; postgres accepts
	// either order.
	if opt.limit > 0 {
		clauses = append(clauses, "limit", strconv.Itoa(opt.limit))
	}
	if opt.offset > 0 {
		clauses = append(clauses, "offset", strconv.Itoa(opt.offset))
	}
	sql := fmt.Sprintf("select * from %s %s", opt.table, strings.Join(clauses, " "))
	return e.Select(dest, e.Rebind(sql), opt.args...)
} // }}}
// NamedUpdateRecords updates the rows of opt.table matching opt.query.
//
//	opt.set:   SET clause with named parameters, e.g. "age=:age"
//	opt.query: WHERE condition; must not be empty, which guards against
//	           accidentally updating every row
//	arg:       source of the named parameter values
//
// It returns the number of affected rows.
func (e *Engine) NamedUpdateRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{
	if opt.set == "" || opt.query == "" {
		return 0, errors.New("empty set or query")
	}
	stmt := fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query)
	res, err := e.NamedExec(stmt, arg)
	if err == nil {
		var affected int64
		if affected, err = res.RowsAffected(); err == nil {
			return affected, nil
		}
	}
	return 0, err
} // }}}
// UpdateRecords is the positional-bindvar form of NamedUpdateRecords.
// opt.set and opt.query use "?" placeholders filled from opt.args in order,
// and the statement is rebound for the current driver. Both opt.set and
// opt.query must be non-empty. It returns the number of affected rows.
func (e *Engine) UpdateRecords(opt *QueryOption) (int64, error) { // {{{
	if opt.set == "" || opt.query == "" {
		return 0, errors.New("empty set or query")
	}
	stmt := e.Rebind(fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query))
	res, err := e.Exec(stmt, opt.args...)
	if err == nil {
		var affected int64
		if affected, err = res.RowsAffected(); err == nil {
			return affected, nil
		}
	}
	return 0, err
} // }}}
// NamedDeleteRecords deletes the rows of opt.table matching opt.query.
//
//	opt.query: WHERE condition with named parameters; must not be empty
//	arg:       source of the named parameter values
//
// It returns the number of affected rows.
func (e *Engine) NamedDeleteRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{
	if opt.query == "" {
		return 0, errors.New("empty query")
	}
	sql := fmt.Sprintf("delete from %s where %s", opt.table, opt.query)
	result, err := e.NamedExec(sql, arg)
	if err != nil {
		return 0, err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return 0, err
	}
	return rows, nil
} // }}}
// DeleteRecords deletes the rows of opt.table matching opt.query, using "?"
// bindvars filled from opt.args and rebound for the current driver.
// opt.query must not be empty. It returns the number of affected rows.
func (e *Engine) DeleteRecords(opt *QueryOption) (int64, error) {
	if opt.query == "" {
		return 0, errors.New("empty query")
	}
	sql := e.Rebind(fmt.Sprintf("delete from %s where %s", opt.table, opt.query))
	result, err := e.Exec(sql, opt.args...)
	if err != nil {
		return 0, err
	}
	rows, err := result.RowsAffected()
	if err != nil {
		return 0, err
	}
	return rows, nil
}
// CountRecords returns the number of rows in opt.table matching opt.query.
// When opt.query is empty, every row is counted. The "?" bindvars in
// opt.query are filled from opt.args and rebound for the current driver.
func (e *Engine) CountRecords(opt *QueryOption) (int, error) {
	sql := "select count(*) from " + opt.table
	// Only emit WHERE when there is a condition; "where " with nothing
	// after it is a syntax error.
	if opt.query != "" {
		sql += " where " + opt.query
	}
	var num int
	if err := e.Get(&num, e.Rebind(sql), opt.args...); err != nil {
		return 0, err
	}
	return num, nil
}
// In expands slice arguments in query into the matching number of "?"
// bindvars by delegating to sqlx.In, and returns the expanded query with the
// flattened argument list. Example:
//
//	var levels = []int{4, 6, 7}
//	query, args, err := sqlx.In("SELECT * FROM users WHERE level IN (?);", levels)
//	// sqlx.In returns queries with the `?` bindvar, we can rebind it for our backend
//	query = db.Rebind(query)
//	rows, err := db.Query(query, args...)
func (e *Engine) In(query string, args ...interface{}) (string, []interface{}, error) {
	expanded, flatArgs, err := sqlx.In(query, args...)
	return expanded, flatArgs, err
}
// IsNoRows reports whether err is ErrNoRows or wraps it (for example via
// fmt.Errorf with %w). A plain == comparison would miss wrapped errors.
func IsNoRows(err error) bool {
	return errors.Is(err, ErrNoRows)
}
// GetSetString turns fields into a named-parameter SET clause of the form
// "field1=:field1,field2=:field2,...,fieldN=:fieldN". Any field named "id" is
// skipped. It returns "" when no fields remain.
func GetSetString(fields []string) string {
	var b strings.Builder
	for _, name := range fields {
		if name == "id" {
			continue
		}
		// Each written item is at least "=:", so a non-empty builder means a
		// previous item exists and needs a separator.
		if b.Len() > 0 {
			b.WriteByte(',')
		}
		b.WriteString(name)
		b.WriteString("=:")
		b.WriteString(name)
	}
	return b.String()
}