// // db_func.go // Copyright (C) 2022 tiglog // // Distributed under terms of the MIT license. // package sqldb import ( "errors" "fmt" "strconv" "strings" "github.com/jmoiron/sqlx" ) func (e *Engine) Begin() (*sqlx.Tx, error) { return e.Beginx() } // 插入一条记录 func (e *Engine) NamedInsertRecord(opt *QueryOption, arg interface{}) (int64, error) { // {{{ if len(opt.fields) == 0 { return 0, errors.New("empty fields") } var tmp = make([]string, 0) for _, field := range opt.fields { tmp = append(tmp, fmt.Sprintf(":%s", field)) } fields_str := strings.Join(opt.fields, ",") fields_pl := strings.Join(tmp, ",") sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", opt.table, fields_str, fields_pl) if e.DriverName() == "postgres" { sql += " returning id" } // sql = e.Rebind(sql) stmt, err := e.PrepareNamed(sql) if err != nil { return 0, err } var id int64 err = stmt.Get(&id, arg) if err != nil { return 0, err } return id, err } // }}} // 插入一条记录 func (e *Engine) InsertRecord(opt *QueryOption) (int64, error) { // {{{ if len(opt.fields) == 0 { return 0, errors.New("empty fields") } fields_str := strings.Join(opt.fields, ",") fields_pl := strings.TrimRight(strings.Repeat("?,", len(opt.fields)), ",") sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);", opt.table, fields_str, fields_pl) if e.DriverName() == "postgres" { sql += " returning id" } sql = e.Rebind(sql) result, err := e.Exec(sql, opt.args...) if err != nil { return 0, err } return result.LastInsertId() } // }}} // 查询一条记录 // dest 目标对象 // table 查询表 // query 查询条件 // args bindvars func (e *Engine) GetRecord(dest interface{}, opt *QueryOption) error { // {{{ if opt.query == "" { return errors.New("empty query") } opt.query = "WHERE " + opt.query sql := fmt.Sprintf("SELECT * FROM %s %s limit 1", opt.table, opt.query) sql = e.Rebind(sql) err := e.Get(dest, sql, opt.args...) if err != nil { return err } return nil } // }}} // 查询多条记录 // dest 目标变量 // opt 查询对象 // args bindvars func (e *Engine) GetRecords(dest interface{}, opt *QueryOption) error { // {{{ var tmp = []string{} if opt.query != "" { tmp = append(tmp, "where", opt.query) } if opt.sort != "" { tmp = append(tmp, "order by", opt.sort) } if opt.offset > 0 { tmp = append(tmp, "offset", strconv.Itoa(opt.offset)) } if opt.limit > 0 { tmp = append(tmp, "limit", strconv.Itoa(opt.limit)) } sql := fmt.Sprintf("select * from %s %s", opt.table, strings.Join(tmp, " ")) sql = e.Rebind(sql) return e.Select(dest, sql, opt.args...) } // }}} // 更新一条记录 // table 待处理的表 // set 需要设置的语句, eg: age=:age // query 查询语句,不能为空,确保误更新所有记录 // arg 值 func (e *Engine) NamedUpdateRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{ if opt.set == "" || opt.query == "" { return 0, errors.New("empty set or query") } sql := fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query) result, err := e.NamedExec(sql, arg) if err != nil { return 0, err } rows, err := result.RowsAffected() if err != nil { return 0, err } return rows, nil } // }}} func (e *Engine) UpdateRecords(opt *QueryOption) (int64, error) { // {{{ if opt.set == "" || opt.query == "" { return 0, errors.New("empty set or query") } sql := fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query) sql = e.Rebind(sql) result, err := e.Exec(sql, opt.args...) if err != nil { return 0, err } rows, err := result.RowsAffected() if err != nil { return 0, err } return rows, nil } // }}} // 删除若干条记录 // opt 的 query 不能为空 // arg bindvars func (e *Engine) NamedDeleteRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{ if opt.query == "" { return 0, errors.New("emtpy query") } sql := fmt.Sprintf("delete from %s where %s", opt.table, opt.query) result, err := e.NamedExec(sql, arg) if err != nil { return 0, err } rows, err := result.RowsAffected() if err != nil { return 0, err } return rows, nil } // }}} func (e *Engine) DeleteRecords(opt *QueryOption) (int64, error) { if opt.query == "" { return 0, errors.New("emtpy query") } sql := fmt.Sprintf("delete from %s where %s", opt.table, opt.query) sql = e.Rebind(sql) result, err := e.Exec(sql, opt.args...) if err != nil { return 0, err } rows, err := result.RowsAffected() if err != nil { return 0, err } return rows, nil } func (e *Engine) CountRecords(opt *QueryOption) (int, error) { sql := fmt.Sprintf("select count(*) from %s where %s", opt.table, opt.query) sql = e.Rebind(sql) var num int err := e.Get(&num, sql, opt.args...) if err != nil { return 0, err } return num, nil } // var levels = []int{4, 6, 7} // query, args, err := sqlx.In("SELECT * FROM users WHERE level IN (?);", levels) // sqlx.In returns queries with the `?` bindvar, we can rebind it for our backend // query = db.Rebind(query) // rows, err := db.Query(query, args...) func (e *Engine) In(query string, args ...interface{}) (string, []interface{}, error) { return sqlx.In(query, args...) } func IsNoRows(err error) bool { return err == ErrNoRows } // 把 fields 转换为 field1=:field1, field2=:field2, ..., fieldN=:fieldN func GetSetString(fields []string) string { items := []string{} for _, field := range fields { if field == "id" { continue } items = append(items, fmt.Sprintf("%s=:%s", field, field)) } return strings.Join(items, ",") }