package sqladapter import ( "errors" "sync" "sync/atomic" "git.hexq.cn/tiglog/mydb" "git.hexq.cn/tiglog/mydb/internal/immutable" ) type Result struct { builder mydb.SQL err atomic.Value iter mydb.Iterator iterMu sync.Mutex prev *Result fn func(*result) error } // result represents a delimited set of items bound by a condition. type result struct { table string limit int offset int pageSize uint pageNumber uint cursorColumn string nextPageCursorValue interface{} prevPageCursorValue interface{} fields []interface{} orderBy []interface{} groupBy []interface{} conds [][]interface{} } func filter(conds []interface{}) []interface{} { return conds } // NewResult creates and Results a new Result set on the given table, this set // is limited by the given exql.Where conditions. func NewResult(builder mydb.SQL, table string, conds []interface{}) *Result { r := &Result{ builder: builder, } return r.from(table).where(conds) } func (r *Result) frame(fn func(*result) error) *Result { return &Result{err: r.err, prev: r, fn: fn} } func (r *Result) SQL() mydb.SQL { if r.prev == nil { return r.builder } return r.prev.SQL() } func (r *Result) from(table string) *Result { return r.frame(func(res *result) error { res.table = table return nil }) } func (r *Result) where(conds []interface{}) *Result { return r.frame(func(res *result) error { res.conds = [][]interface{}{conds} return nil }) } func (r *Result) setErr(err error) { if err == nil { return } r.err.Store(err) } // Err returns the last error that has happened with the result set, // nil otherwise func (r *Result) Err() error { if errV := r.err.Load(); errV != nil { return errV.(error) } return nil } // Where sets conditions for the result set. func (r *Result) Where(conds ...interface{}) mydb.Result { return r.where(conds) } // And adds more conditions on top of the existing ones. func (r *Result) And(conds ...interface{}) mydb.Result { return r.frame(func(res *result) error { res.conds = append(res.conds, conds) return nil }) } // Limit determines the maximum limit of Results to be returned. func (r *Result) Limit(n int) mydb.Result { return r.frame(func(res *result) error { res.limit = n return nil }) } func (r *Result) Paginate(pageSize uint) mydb.Result { return r.frame(func(res *result) error { res.pageSize = pageSize return nil }) } func (r *Result) Page(pageNumber uint) mydb.Result { return r.frame(func(res *result) error { res.pageNumber = pageNumber res.nextPageCursorValue = nil res.prevPageCursorValue = nil return nil }) } func (r *Result) Cursor(cursorColumn string) mydb.Result { return r.frame(func(res *result) error { res.cursorColumn = cursorColumn return nil }) } func (r *Result) NextPage(cursorValue interface{}) mydb.Result { return r.frame(func(res *result) error { res.nextPageCursorValue = cursorValue res.prevPageCursorValue = nil return nil }) } func (r *Result) PrevPage(cursorValue interface{}) mydb.Result { return r.frame(func(res *result) error { res.nextPageCursorValue = nil res.prevPageCursorValue = cursorValue return nil }) } // Offset determines how many documents will be skipped before starting to grab // Results. func (r *Result) Offset(n int) mydb.Result { return r.frame(func(res *result) error { res.offset = n return nil }) } // GroupBy is used to group Results that have the same value in the same column // or columns. func (r *Result) GroupBy(fields ...interface{}) mydb.Result { return r.frame(func(res *result) error { res.groupBy = fields return nil }) } // OrderBy determines sorting of Results according to the provided names. Fields // may be prefixed by - (minus) which means descending order, ascending order // would be used otherwise. func (r *Result) OrderBy(fields ...interface{}) mydb.Result { return r.frame(func(res *result) error { res.orderBy = fields return nil }) } // Select determines which fields to return. func (r *Result) Select(fields ...interface{}) mydb.Result { return r.frame(func(res *result) error { res.fields = fields return nil }) } // String satisfies fmt.Stringer func (r *Result) String() string { query, err := r.Paginator() if err != nil { panic(err.Error()) } return query.String() } // All dumps all Results into a pointer to an slice of structs or maps. func (r *Result) All(dst interface{}) error { query, err := r.Paginator() if err != nil { r.setErr(err) return err } err = query.Iterator().All(dst) r.setErr(err) return err } // One fetches only one Result from the set. func (r *Result) One(dst interface{}) error { one := r.Limit(1).(*Result) query, err := one.Paginator() if err != nil { r.setErr(err) return err } err = query.Iterator().One(dst) r.setErr(err) return err } // Next fetches the next Result from the set. func (r *Result) Next(dst interface{}) bool { r.iterMu.Lock() defer r.iterMu.Unlock() if r.iter == nil { query, err := r.Paginator() if err != nil { r.setErr(err) return false } r.iter = query.Iterator() } if r.iter.Next(dst) { return true } if err := r.iter.Err(); !errors.Is(err, mydb.ErrNoMoreRows) { r.setErr(err) return false } return false } // Delete deletes all matching items from the collection. func (r *Result) Delete() error { query, err := r.buildDelete() if err != nil { r.setErr(err) return err } _, err = query.Exec() r.setErr(err) return err } // Close closes the Result set. func (r *Result) Close() error { if r.iter != nil { err := r.iter.Close() r.setErr(err) return err } return nil } // Update updates matching items from the collection with values of the given // map or struct. func (r *Result) Update(values interface{}) error { query, err := r.buildUpdate(values) if err != nil { r.setErr(err) return err } _, err = query.Exec() r.setErr(err) return err } func (r *Result) TotalPages() (uint, error) { query, err := r.Paginator() if err != nil { r.setErr(err) return 0, err } total, err := query.TotalPages() if err != nil { r.setErr(err) return 0, err } return total, nil } func (r *Result) TotalEntries() (uint64, error) { query, err := r.Paginator() if err != nil { r.setErr(err) return 0, err } total, err := query.TotalEntries() if err != nil { r.setErr(err) return 0, err } return total, nil } // Exists returns true if at least one item on the collection exists. func (r *Result) Exists() (bool, error) { query, err := r.buildCount() if err != nil { r.setErr(err) return false, err } query = query.Limit(1) value := struct { Exists uint64 `db:"_t"` }{} if err := query.One(&value); err != nil { if errors.Is(err, mydb.ErrNoMoreRows) { return false, nil } r.setErr(err) return false, err } if value.Exists > 0 { return true, nil } return false, nil } // Count counts the elements on the set. func (r *Result) Count() (uint64, error) { query, err := r.buildCount() if err != nil { r.setErr(err) return 0, err } counter := struct { Count uint64 `db:"_t"` }{} if err := query.One(&counter); err != nil { if errors.Is(err, mydb.ErrNoMoreRows) { return 0, nil } r.setErr(err) return 0, err } return counter.Count, nil } func (r *Result) Paginator() (mydb.Paginator, error) { if err := r.Err(); err != nil { return nil, err } res, err := r.fastForward() if err != nil { return nil, err } sel := r.SQL().Select(res.fields...). From(res.table). Limit(res.limit). Offset(res.offset). GroupBy(res.groupBy...). OrderBy(res.orderBy...) for i := range res.conds { sel = sel.And(filter(res.conds[i])...) } pag := sel.Paginate(res.pageSize). Page(res.pageNumber). Cursor(res.cursorColumn) if res.nextPageCursorValue != nil { pag = pag.NextPage(res.nextPageCursorValue) } if res.prevPageCursorValue != nil { pag = pag.PrevPage(res.prevPageCursorValue) } return pag, nil } func (r *Result) buildDelete() (mydb.Deleter, error) { if err := r.Err(); err != nil { return nil, err } res, err := r.fastForward() if err != nil { return nil, err } del := r.SQL().DeleteFrom(res.table). Limit(res.limit) for i := range res.conds { del = del.And(filter(res.conds[i])...) } return del, nil } func (r *Result) buildUpdate(values interface{}) (mydb.Updater, error) { if err := r.Err(); err != nil { return nil, err } res, err := r.fastForward() if err != nil { return nil, err } upd := r.SQL().Update(res.table). Set(values). Limit(res.limit) for i := range res.conds { upd = upd.And(filter(res.conds[i])...) } return upd, nil } func (r *Result) buildCount() (mydb.Selector, error) { if err := r.Err(); err != nil { return nil, err } res, err := r.fastForward() if err != nil { return nil, err } sel := r.SQL().Select(mydb.Raw("count(1) AS _t")). From(res.table). GroupBy(res.groupBy...) for i := range res.conds { sel = sel.And(filter(res.conds[i])...) } return sel, nil } func (r *Result) Prev() immutable.Immutable { if r == nil { return nil } return r.prev } func (r *Result) Fn(in interface{}) error { if r.fn == nil { return nil } return r.fn(in.(*result)) } func (r *Result) Base() interface{} { return &result{} } func (r *Result) fastForward() (*result, error) { ff, err := immutable.FastForward(r) if err != nil { return nil, err } return ff.(*result), nil } var _ = immutable.Immutable(&Result{})