323 lines
7.0 KiB
Go
323 lines
7.0 KiB
Go
//
|
|
// db_query.go
|
|
// Copyright (C) 2022 tiglog <me@tiglog.com>
|
|
//
|
|
// Distributed under terms of the MIT license.
|
|
//
|
|
|
|
package sqldb
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/jmoiron/sqlx"
|
|
)
|
|
|
|
type Query struct {
|
|
db *Engine
|
|
table string
|
|
fields []string
|
|
wheres []string // 不能太复杂
|
|
joins []string
|
|
orderBy string
|
|
groupBy string
|
|
offset int
|
|
limit int
|
|
}
|
|
|
|
func NewQueryBuild(table string, db *Engine) *Query {
|
|
return &Query{
|
|
db: db,
|
|
table: table,
|
|
fields: []string{},
|
|
wheres: []string{},
|
|
joins: []string{},
|
|
offset: 0,
|
|
limit: 0,
|
|
}
|
|
}
|
|
|
|
func (q *Query) Table(table string) *Query {
|
|
q.table = table
|
|
return q
|
|
}
|
|
|
|
// 设置 select fields
|
|
func (q *Query) Select(fields ...string) *Query {
|
|
q.fields = fields
|
|
return q
|
|
}
|
|
|
|
// 增加一个 select field
|
|
func (q *Query) AddFields(fields ...string) *Query {
|
|
q.fields = append(q.fields, fields...)
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Where(query string) *Query {
|
|
q.wheres = []string{query}
|
|
return q
|
|
}
|
|
func (q *Query) AndWhere(query string) *Query {
|
|
q.wheres = append(q.wheres, "and "+query)
|
|
return q
|
|
}
|
|
|
|
func (q *Query) OrWhere(query string) *Query {
|
|
q.wheres = append(q.wheres, "or "+query)
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Join(table string, on string) *Query {
|
|
var join = "join " + table
|
|
if on != "" {
|
|
join = join + " on " + on
|
|
}
|
|
q.joins = append(q.joins, join)
|
|
return q
|
|
}
|
|
|
|
func (q *Query) LeftJoin(table string, on string) *Query {
|
|
var join = "left join " + table
|
|
if on != "" {
|
|
join = join + " on " + on
|
|
}
|
|
q.joins = append(q.joins, join)
|
|
return q
|
|
}
|
|
|
|
func (q *Query) RightJoin(table string, on string) *Query {
|
|
var join = "right join " + table
|
|
if on != "" {
|
|
join = join + " on " + on
|
|
}
|
|
q.joins = append(q.joins, join)
|
|
return q
|
|
}
|
|
|
|
func (q *Query) InnerJoin(table string, on string) *Query {
|
|
var join = "inner join " + table
|
|
if on != "" {
|
|
join = join + " on " + on
|
|
}
|
|
q.joins = append(q.joins, join)
|
|
return q
|
|
}
|
|
|
|
func (q *Query) OrderBy(order string) *Query {
|
|
q.orderBy = order
|
|
return q
|
|
}
|
|
func (q *Query) GroupBy(group string) *Query {
|
|
q.groupBy = group
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Offset(offset int) *Query {
|
|
q.offset = offset
|
|
return q
|
|
}
|
|
|
|
func (q *Query) Limit(limit int) *Query {
|
|
q.limit = limit
|
|
return q
|
|
}
|
|
|
|
// returningId postgres 数据库返回 LastInsertId 处理
|
|
// TODO returningId 暂时不处理
|
|
func (q *Query) getInsertSql(named, returningId bool) string {
|
|
fields_str := strings.Join(q.fields, ",")
|
|
var pl string
|
|
if named {
|
|
var tmp []string
|
|
for _, field := range q.fields {
|
|
tmp = append(tmp, ":"+field)
|
|
}
|
|
pl = strings.Join(tmp, ",")
|
|
} else {
|
|
pl = strings.Repeat("?,", len(q.fields))
|
|
pl = strings.TrimRight(pl, ",")
|
|
}
|
|
|
|
sql := fmt.Sprintf("insert into %s (%s) values (%s);", q.table, fields_str, pl)
|
|
sql = q.db.Rebind(sql)
|
|
// fmt.Println(sql)
|
|
return sql
|
|
}
|
|
|
|
// return RowsAffected, error
|
|
func (q *Query) Insert(args ...interface{}) (int64, error) {
|
|
if len(q.fields) == 0 {
|
|
return 0, errors.New("empty fields")
|
|
}
|
|
sql := q.getInsertSql(false, false)
|
|
result, err := q.db.Exec(sql, args...)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
// return RowsAffected, error
|
|
func (q *Query) NamedInsert(arg interface{}) (int64, error) {
|
|
if len(q.fields) == 0 {
|
|
return 0, errors.New("empty fields")
|
|
}
|
|
sql := q.getInsertSql(true, false)
|
|
result, err := q.db.NamedExec(sql, arg)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
func (q *Query) getQuerySql() string {
|
|
var (
|
|
fields_str string = "*"
|
|
join_str string
|
|
where_str string
|
|
offlim string
|
|
)
|
|
if len(q.fields) > 0 {
|
|
fields_str = strings.Join(q.fields, ",")
|
|
}
|
|
|
|
if len(q.joins) > 0 {
|
|
join_str = strings.Join(q.joins, " ")
|
|
}
|
|
if len(q.wheres) > 0 {
|
|
where_str = "where " + strings.Join(q.wheres, " ")
|
|
}
|
|
|
|
if q.offset > 0 {
|
|
offlim = " offset " + strconv.Itoa(q.offset)
|
|
}
|
|
if q.limit > 0 {
|
|
offlim = " limit " + strconv.Itoa(q.limit)
|
|
}
|
|
// select fields from table t join where groupby orderby offset limit
|
|
sql := fmt.Sprintf("select %s from %s t %s %s %s %s%s", fields_str, q.table, join_str, where_str, q.groupBy, q.orderBy, offlim)
|
|
return sql
|
|
}
|
|
|
|
func (q *Query) One(dest interface{}, args ...interface{}) error {
|
|
q.Limit(1)
|
|
sql := q.getQuerySql()
|
|
sql = q.db.Rebind(sql)
|
|
return q.db.Get(dest, sql, args...)
|
|
}
|
|
|
|
func (q *Query) NamedOne(dest interface{}, arg interface{}) error {
|
|
q.Limit(1)
|
|
sql := q.getQuerySql()
|
|
rows, err := q.db.NamedQuery(sql, arg)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if rows.Next() {
|
|
return rows.Scan(dest)
|
|
}
|
|
return errors.New("nr") // no record
|
|
}
|
|
|
|
func (q *Query) All(dest interface{}, args ...interface{}) error {
|
|
sql := q.getQuerySql()
|
|
sql = q.db.Rebind(sql)
|
|
return q.db.Select(dest, sql, args...)
|
|
}
|
|
|
|
// 为了省内存,直接返回迭代器
|
|
func (q *Query) NamedAll(dest interface{}, arg interface{}) (*sqlx.Rows, error) {
|
|
sql := q.getQuerySql()
|
|
return q.db.NamedQuery(sql, arg)
|
|
}
|
|
|
|
// set age=? / age=:age
|
|
func (q *Query) NamedUpdate(set string, arg interface{}) (int64, error) {
|
|
var where_str string
|
|
if len(q.wheres) > 0 {
|
|
where_str = strings.Join(q.wheres, " ")
|
|
}
|
|
if set == "" || where_str == "" {
|
|
return 0, errors.New("empty set or where")
|
|
}
|
|
|
|
// update table t where
|
|
sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str)
|
|
sql = q.db.Rebind(sql)
|
|
result, err := q.db.NamedExec(sql, arg)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
// 顺序容易弄反,记得先是 set 的参数,再是 where 里面的参数
|
|
func (q *Query) Update(set string, args ...interface{}) (int64, error) {
|
|
var where_str string
|
|
if len(q.wheres) > 0 {
|
|
where_str = strings.Join(q.wheres, " ")
|
|
}
|
|
if set == "" || where_str == "" {
|
|
return 0, errors.New("empty set or where")
|
|
}
|
|
|
|
// update table t where
|
|
sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str)
|
|
sql = q.db.Rebind(sql)
|
|
result, err := q.db.Exec(sql, args...)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
// 普通的删除
|
|
func (q *Query) Delete(args ...interface{}) (int64, error) {
|
|
var where_str string
|
|
if len(q.wheres) == 0 {
|
|
return 0, errors.New("missing where clause")
|
|
}
|
|
where_str = strings.Join(q.wheres, " ")
|
|
|
|
sql := fmt.Sprintf("delete from %s where %s", q.table, where_str)
|
|
sql = q.db.Rebind(sql)
|
|
result, err := q.db.Exec(sql, args...)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
func (q *Query) NamedDelete(arg interface{}) (int64, error) {
|
|
if len(q.wheres) == 0 {
|
|
return 0, errors.New("missing where clause")
|
|
}
|
|
var where_str string
|
|
where_str = strings.Join(q.wheres, " ")
|
|
|
|
sql := fmt.Sprintf("delete from %s where %s", q.table, where_str)
|
|
sql = q.db.Rebind(sql)
|
|
|
|
result, err := q.db.NamedExec(sql, arg)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return result.RowsAffected()
|
|
}
|
|
|
|
func (q *Query) Count(args ...interface{}) (int64, error) {
|
|
var where_str string
|
|
if len(q.wheres) > 0 {
|
|
where_str = " where " + strings.Join(q.wheres, " ")
|
|
}
|
|
sql := fmt.Sprintf("select count(1) as num from %s t%s", q.table, where_str)
|
|
sql = q.db.Rebind(sql)
|
|
var num int64
|
|
err := q.db.Get(&num, sql, args...)
|
|
return num, err
|
|
}
|