// // db_query.go // Copyright (C) 2022 tiglog // // Distributed under terms of the MIT license. // package sqldb import ( "errors" "fmt" "strconv" "strings" "github.com/jmoiron/sqlx" ) type Query struct { db *Engine table string fields []string wheres []string // 不能太复杂 joins []string orderBy string groupBy string offset int limit int } func NewQueryBuild(table string, db *Engine) *Query { return &Query{ db: db, table: table, fields: []string{}, wheres: []string{}, joins: []string{}, offset: 0, limit: 0, } } func (q *Query) Table(table string) *Query { q.table = table return q } // 设置 select fields func (q *Query) Select(fields ...string) *Query { q.fields = fields return q } // 增加一个 select field func (q *Query) AddFields(fields ...string) *Query { q.fields = append(q.fields, fields...) return q } func (q *Query) Where(query string) *Query { q.wheres = []string{query} return q } func (q *Query) AndWhere(query string) *Query { q.wheres = append(q.wheres, "and "+query) return q } func (q *Query) OrWhere(query string) *Query { q.wheres = append(q.wheres, "or "+query) return q } func (q *Query) Join(table string, on string) *Query { var join = "join " + table if on != "" { join = join + " on " + on } q.joins = append(q.joins, join) return q } func (q *Query) LeftJoin(table string, on string) *Query { var join = "left join " + table if on != "" { join = join + " on " + on } q.joins = append(q.joins, join) return q } func (q *Query) RightJoin(table string, on string) *Query { var join = "right join " + table if on != "" { join = join + " on " + on } q.joins = append(q.joins, join) return q } func (q *Query) InnerJoin(table string, on string) *Query { var join = "inner join " + table if on != "" { join = join + " on " + on } q.joins = append(q.joins, join) return q } func (q *Query) OrderBy(order string) *Query { q.orderBy = order return q } func (q *Query) GroupBy(group string) *Query { q.groupBy = group return q } func (q *Query) Offset(offset int) *Query { q.offset = offset return q } func (q *Query) Limit(limit int) *Query { q.limit = limit return q } // returningId postgres 数据库返回 LastInsertId 处理 // TODO returningId 暂时不处理 func (q *Query) getInsertSql(named, returningId bool) string { fields_str := strings.Join(q.fields, ",") var pl string if named { var tmp []string for _, field := range q.fields { tmp = append(tmp, ":"+field) } pl = strings.Join(tmp, ",") } else { pl = strings.Repeat("?,", len(q.fields)) pl = strings.TrimRight(pl, ",") } sql := fmt.Sprintf("insert into %s (%s) values (%s);", q.table, fields_str, pl) sql = q.db.Rebind(sql) // fmt.Println(sql) return sql } // return RowsAffected, error func (q *Query) Insert(args ...interface{}) (int64, error) { if len(q.fields) == 0 { return 0, errors.New("empty fields") } sql := q.getInsertSql(false, false) result, err := q.db.Exec(sql, args...) if err != nil { return 0, err } return result.RowsAffected() } // return RowsAffected, error func (q *Query) NamedInsert(arg interface{}) (int64, error) { if len(q.fields) == 0 { return 0, errors.New("empty fields") } sql := q.getInsertSql(true, false) result, err := q.db.NamedExec(sql, arg) if err != nil { return 0, err } return result.RowsAffected() } func (q *Query) getQuerySql() string { var ( fields_str string = "*" join_str string where_str string offlim string ) if len(q.fields) > 0 { fields_str = strings.Join(q.fields, ",") } if len(q.joins) > 0 { join_str = strings.Join(q.joins, " ") } if len(q.wheres) > 0 { where_str = "where " + strings.Join(q.wheres, " ") } if q.offset > 0 { offlim = " offset " + strconv.Itoa(q.offset) } if q.limit > 0 { offlim = " limit " + strconv.Itoa(q.limit) } // select fields from table t join where groupby orderby offset limit sql := fmt.Sprintf("select %s from %s t %s %s %s %s%s", fields_str, q.table, join_str, where_str, q.groupBy, q.orderBy, offlim) return sql } func (q *Query) One(dest interface{}, args ...interface{}) error { q.Limit(1) sql := q.getQuerySql() sql = q.db.Rebind(sql) return q.db.Get(dest, sql, args...) } func (q *Query) NamedOne(dest interface{}, arg interface{}) error { q.Limit(1) sql := q.getQuerySql() rows, err := q.db.NamedQuery(sql, arg) if err != nil { return err } if rows.Next() { return rows.Scan(dest) } return errors.New("nr") // no record } func (q *Query) All(dest interface{}, args ...interface{}) error { sql := q.getQuerySql() sql = q.db.Rebind(sql) return q.db.Select(dest, sql, args...) } // 为了省内存,直接返回迭代器 func (q *Query) NamedAll(dest interface{}, arg interface{}) (*sqlx.Rows, error) { sql := q.getQuerySql() return q.db.NamedQuery(sql, arg) } // set age=? / age=:age func (q *Query) NamedUpdate(set string, arg interface{}) (int64, error) { var where_str string if len(q.wheres) > 0 { where_str = strings.Join(q.wheres, " ") } if set == "" || where_str == "" { return 0, errors.New("empty set or where") } // update table t where sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str) sql = q.db.Rebind(sql) result, err := q.db.NamedExec(sql, arg) if err != nil { return 0, err } return result.RowsAffected() } // 顺序容易弄反,记得先是 set 的参数,再是 where 里面的参数 func (q *Query) Update(set string, args ...interface{}) (int64, error) { var where_str string if len(q.wheres) > 0 { where_str = strings.Join(q.wheres, " ") } if set == "" || where_str == "" { return 0, errors.New("empty set or where") } // update table t where sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str) sql = q.db.Rebind(sql) result, err := q.db.Exec(sql, args...) if err != nil { return 0, err } return result.RowsAffected() } // 普通的删除 func (q *Query) Delete(args ...interface{}) (int64, error) { var where_str string if len(q.wheres) == 0 { return 0, errors.New("missing where clause") } where_str = strings.Join(q.wheres, " ") sql := fmt.Sprintf("delete from %s where %s", q.table, where_str) sql = q.db.Rebind(sql) result, err := q.db.Exec(sql, args...) if err != nil { return 0, err } return result.RowsAffected() } func (q *Query) NamedDelete(arg interface{}) (int64, error) { if len(q.wheres) == 0 { return 0, errors.New("missing where clause") } var where_str string where_str = strings.Join(q.wheres, " ") sql := fmt.Sprintf("delete from %s where %s", q.table, where_str) sql = q.db.Rebind(sql) result, err := q.db.NamedExec(sql, arg) if err != nil { return 0, err } return result.RowsAffected() } func (q *Query) Count(args ...interface{}) (int64, error) { var where_str string if len(q.wheres) > 0 { where_str = " where " + strings.Join(q.wheres, " ") } sql := fmt.Sprintf("select count(1) as num from %s t%s", q.table, where_str) sql = q.db.Rebind(sql) var num int64 err := q.db.Get(&num, sql, args...) return num, err }