golib/gdb/sqldb/db_query.go
2023-06-15 21:22:51 +08:00

323 lines
7.0 KiB
Go

//
// db_query.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package sqldb
import (
"errors"
"fmt"
"strconv"
"strings"
"github.com/jmoiron/sqlx"
)
// Query is a small chainable SQL statement builder bound to an Engine.
// The clause fragments it stores are raw SQL text; they are not escaped
// or parameterized by the builder.
type Query struct {
	db    *Engine
	table string

	fields []string // column list: select columns, or insert target columns
	wheres []string // condition fragments joined by spaces; keep them simple
	joins  []string // complete join clauses, joined by spaces

	orderBy, groupBy string
	offset, limit    int
}
// NewQueryBuild returns a Query for table that executes through db.
// The column, condition and join lists start out empty but non-nil,
// and offset/limit start at zero (no offset/limit clause).
func NewQueryBuild(table string, db *Engine) *Query {
	q := new(Query)
	q.db, q.table = db, table
	q.fields = []string{}
	q.wheres = []string{}
	q.joins = []string{}
	return q
}
// Table switches the table the query operates on.
func (q *Query) Table(table string) *Query {
	q.table = table
	return q
}

// Select replaces the column list with fields.
//
// The arguments are copied: when a caller passes an existing slice as
// `Select(cols...)`, Go hands over that very slice, and a later AddFields
// (which appends) could otherwise write into the caller's backing array.
func (q *Query) Select(fields ...string) *Query {
	q.fields = append([]string(nil), fields...)
	return q
}

// AddFields appends one or more columns to the current column list.
func (q *Query) AddFields(fields ...string) *Query {
	q.fields = append(q.fields, fields...)
	return q
}
// Where replaces all existing conditions with query.
func (q *Query) Where(query string) *Query {
	q.wheres = []string{query}
	return q
}

// AndWhere adds query joined with "and". When no condition exists yet,
// query becomes the first condition (no leading conjunction), so the
// rendered clause is never "where and ...".
func (q *Query) AndWhere(query string) *Query {
	return q.appendWhere("and", query)
}

// OrWhere adds query joined with "or". When no condition exists yet,
// query becomes the first condition (no leading conjunction).
func (q *Query) OrWhere(query string) *Query {
	return q.appendWhere("or", query)
}

// appendWhere appends query to the condition list, prefixed with conj
// unless it is the first condition.
func (q *Query) appendWhere(conj, query string) *Query {
	if len(q.wheres) == 0 {
		q.wheres = append(q.wheres, query)
		return q
	}
	q.wheres = append(q.wheres, conj+" "+query)
	return q
}
// Join appends a "join <table>[ on <on>]" clause.
func (q *Query) Join(table string, on string) *Query {
	return q.addJoin("join", table, on)
}

// LeftJoin appends a "left join <table>[ on <on>]" clause.
func (q *Query) LeftJoin(table string, on string) *Query {
	return q.addJoin("left join", table, on)
}

// RightJoin appends a "right join <table>[ on <on>]" clause.
func (q *Query) RightJoin(table string, on string) *Query {
	return q.addJoin("right join", table, on)
}

// InnerJoin appends an "inner join <table>[ on <on>]" clause.
func (q *Query) InnerJoin(table string, on string) *Query {
	return q.addJoin("inner join", table, on)
}

// addJoin appends "<kind> <table>" to the join list, followed by
// " on <on>" when on is non-empty.
func (q *Query) addJoin(kind, table, on string) *Query {
	clause := kind + " " + table
	if on != "" {
		clause += " on " + on
	}
	q.joins = append(q.joins, clause)
	return q
}
// OrderBy stores the ordering expression used when the select statement
// is rendered. A later call overwrites an earlier one.
func (q *Query) OrderBy(order string) *Query {
	q.orderBy = order
	return q
}

// GroupBy stores the grouping expression used when the select statement
// is rendered. A later call overwrites an earlier one.
func (q *Query) GroupBy(group string) *Query {
	q.groupBy = group
	return q
}

// Offset sets how many rows to skip; values <= 0 emit no offset clause.
func (q *Query) Offset(offset int) *Query {
	q.offset = offset
	return q
}

// Limit caps the number of rows returned; values <= 0 emit no limit clause.
func (q *Query) Limit(limit int) *Query {
	q.limit = limit
	return q
}
// getInsertSql renders an insert statement for q.fields into q.table and
// passes it through q.db.Rebind before returning it.
//
// When named is true each column gets a ":column" placeholder (for
// NamedExec); otherwise one "?" placeholder is emitted per column.
// returningId is intended for PostgreSQL LastInsertId handling via
// "returning id"; it is currently ignored (TODO).
func (q *Query) getInsertSql(named, returningId bool) string {
	placeholders := make([]string, 0, len(q.fields))
	for _, col := range q.fields {
		if named {
			placeholders = append(placeholders, ":"+col)
		} else {
			placeholders = append(placeholders, "?")
		}
	}
	stmt := fmt.Sprintf("insert into %s (%s) values (%s);",
		q.table, strings.Join(q.fields, ","), strings.Join(placeholders, ","))
	return q.db.Rebind(stmt)
}
// Insert executes a positional insert of q.fields into q.table. args
// supply the values for the "?" placeholders, one per field in order.
// It returns the RowsAffected reported for the statement, or an error
// when no fields were selected or execution fails.
func (q *Query) Insert(args ...interface{}) (int64, error) {
	if len(q.fields) < 1 {
		return 0, errors.New("empty fields")
	}
	res, err := q.db.Exec(q.getInsertSql(false, false), args...)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// NamedInsert executes an insert of q.fields into q.table using ":field"
// placeholders, bound from arg through q.db.NamedExec. arg is presumably
// a struct or map that provides every selected field name (TODO confirm
// against NamedExec semantics). It returns the RowsAffected reported for
// the statement, or an error when no fields were selected or execution fails.
func (q *Query) NamedInsert(arg interface{}) (int64, error) {
	if len(q.fields) < 1 {
		return 0, errors.New("empty fields")
	}
	res, err := q.db.NamedExec(q.getInsertSql(true, false), arg)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// getQuerySql renders the select statement described by q:
//
//	select <fields|*> from <table> t <joins> where <conds> group by ... order by ... limit N offset M
//
// The table is always aliased as "t". Empty parts are omitted. Bind
// variables are left as written; callers Rebind when needed.
//
// Limit and offset are both emitted when set, with limit first, because
// MySQL and SQLite require LIMIT before OFFSET (PostgreSQL accepts that
// order too). Previously a set limit silently discarded the offset.
func (q *Query) getQuerySql() string {
	fieldsStr := "*"
	if len(q.fields) > 0 {
		fieldsStr = strings.Join(q.fields, ",")
	}

	var joinStr, whereStr string
	if len(q.joins) > 0 {
		joinStr = strings.Join(q.joins, " ")
	}
	if len(q.wheres) > 0 {
		whereStr = "where " + strings.Join(q.wheres, " ")
	}

	var offlim string
	if q.limit > 0 {
		offlim += " limit " + strconv.Itoa(q.limit)
	}
	if q.offset > 0 {
		offlim += " offset " + strconv.Itoa(q.offset)
	}

	// Accept both "id desc" and "order by id desc" from callers.
	groupStr := withClauseKeyword(q.groupBy, "group by")
	orderStr := withClauseKeyword(q.orderBy, "order by")

	return fmt.Sprintf("select %s from %s t %s %s %s %s%s",
		fieldsStr, q.table, joinStr, whereStr, groupStr, orderStr, offlim)
}

// withClauseKeyword returns clause prefixed with keyword (e.g. "order by").
// It returns "" for a blank clause. It returns clause unchanged when it
// already starts with the keyword (case-insensitive).
func withClauseKeyword(clause, keyword string) string {
	trimmed := strings.TrimSpace(clause)
	if trimmed == "" {
		return ""
	}
	lower := strings.ToLower(trimmed)
	if lower == keyword || strings.HasPrefix(lower, keyword+" ") {
		return clause
	}
	return keyword + " " + trimmed
}
// One fetches a single row into dest via q.db.Get, passing args
// positionally. It sets limit 1 on q itself, so the change persists after
// the call. The SQL is passed through q.db.Rebind first (presumably to
// adapt "?" bindvars to the driver — confirm against Engine).
func (q *Query) One(dest interface{}, args ...interface{}) error {
	stmt := q.Limit(1).getQuerySql()
	return q.db.Get(dest, q.db.Rebind(stmt), args...)
}
// NamedOne runs the select built from q with limit 1. It binds named
// parameters (":name") from arg through q.db.NamedQuery and scans the
// first row into dest. It sets limit 1 on q itself, so the change
// persists after the call.
//
// It returns errors.New("nr") ("no record") when the query yields no row,
// and returns any row-iteration error instead of masking it as "nr".
//
// NOTE(review): rows.Scan(dest) receives a single destination, so this only
// works when the query selects exactly one column; struct destinations
// would need StructScan — confirm intended usage.
func (q *Query) NamedOne(dest interface{}, arg interface{}) error {
	q.Limit(1)
	rows, err := q.db.NamedQuery(q.getQuerySql(), arg)
	if err != nil {
		return err
	}
	// When a row is found, Next has not exhausted the result set, so without
	// Close the underlying connection would stay checked out (leaked).
	defer rows.Close()

	if !rows.Next() {
		// Next returns false both at end of results and on failure.
		if iterErr := rows.Err(); iterErr != nil {
			return iterErr
		}
		return errors.New("nr") // no record
	}
	return rows.Scan(dest)
}
// All loads every row matched by q into dest via q.db.Select, passing args
// positionally. The SQL is passed through q.db.Rebind first.
func (q *Query) All(dest interface{}, args ...interface{}) error {
	stmt := q.db.Rebind(q.getQuerySql())
	return q.db.Select(dest, stmt, args...)
}
// NamedAll runs the select built from q, binding named parameters from arg.
// To save memory it returns the row iterator directly instead of loading
// all rows. The caller owns the returned rows and must Close them (or
// iterate to the end).
//
// NOTE(review): dest is not used by this method; it appears to be kept
// only for signature symmetry with All.
func (q *Query) NamedAll(dest interface{}, arg interface{}) (*sqlx.Rows, error) {
	stmt := q.getQuerySql()
	return q.db.NamedQuery(stmt, arg)
}
// NamedUpdate executes "update <table> t set <set> where <conditions>"
// through q.db.NamedExec, taking named parameter values from arg.
// set is a raw assignment list such as "age=?" or "age=:age".
//
// It refuses to run without a set clause or a where condition (guarding
// against accidental full-table updates) and returns RowsAffected on success.
func (q *Query) NamedUpdate(set string, arg interface{}) (int64, error) {
	cond := strings.Join(q.wheres, " ")
	if set == "" || cond == "" {
		return 0, errors.New("empty set or where")
	}
	stmt := q.db.Rebind(fmt.Sprintf("update %s t set %s where %s", q.table, set, cond))
	res, err := q.db.NamedExec(stmt, arg)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// Update executes "update <table> t set <set> where <conditions>" with
// positional args. The argument order is easy to get backwards: pass the
// set-clause arguments first, then the where-clause arguments.
//
// It refuses to run without a set clause or a where condition and returns
// RowsAffected on success.
func (q *Query) Update(set string, args ...interface{}) (int64, error) {
	cond := strings.Join(q.wheres, " ")
	if set == "" || cond == "" {
		return 0, errors.New("empty set or where")
	}
	stmt := q.db.Rebind(fmt.Sprintf("update %s t set %s where %s", q.table, set, cond))
	res, err := q.db.Exec(stmt, args...)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// Delete executes a plain "delete from <table> where <conditions>" with
// positional args and returns RowsAffected. It refuses to run without a
// where condition, which guards against deleting the whole table.
func (q *Query) Delete(args ...interface{}) (int64, error) {
	if len(q.wheres) < 1 {
		return 0, errors.New("missing where clause")
	}
	stmt := fmt.Sprintf("delete from %s where %s", q.table, strings.Join(q.wheres, " "))
	res, err := q.db.Exec(q.db.Rebind(stmt), args...)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// NamedDelete executes "delete from <table> where <conditions>" through
// q.db.NamedExec, taking named parameter values from arg, and returns
// RowsAffected. It refuses to run without a where condition.
func (q *Query) NamedDelete(arg interface{}) (int64, error) {
	if len(q.wheres) < 1 {
		return 0, errors.New("missing where clause")
	}
	stmt := fmt.Sprintf("delete from %s where %s", q.table, strings.Join(q.wheres, " "))
	res, err := q.db.NamedExec(q.db.Rebind(stmt), arg)
	if err != nil {
		return 0, err
	}
	return res.RowsAffected()
}
// Count returns the number of rows matched by q's joins and where
// conditions, using positional args. Group by, order by, offset and
// limit are ignored.
//
// Joins are included so that conditions referencing joined tables
// (e.g. "u.active = 1") are valid; previously they were dropped, which
// produced an SQL error for such conditions. Without joins the generated
// SQL is unchanged.
func (q *Query) Count(args ...interface{}) (int64, error) {
	var joinStr, whereStr string
	if len(q.joins) > 0 {
		joinStr = " " + strings.Join(q.joins, " ")
	}
	if len(q.wheres) > 0 {
		whereStr = " where " + strings.Join(q.wheres, " ")
	}
	stmt := fmt.Sprintf("select count(1) as num from %s t%s%s", q.table, joinStr, whereStr)
	stmt = q.db.Rebind(stmt)

	var num int64
	err := q.db.Get(&num, stmt, args...)
	return num, err
}