mydb/internal/sqladapter/result.go
2023-09-18 15:15:42 +08:00

499 lines
9.3 KiB
Go

package sqladapter
import (
"errors"
"sync"
"sync/atomic"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/immutable"
)
type Result struct {
builder mydb.SQL
err atomic.Value
iter mydb.Iterator
iterMu sync.Mutex
prev *Result
fn func(*result) error
}
// result represents a delimited set of items bound by a condition.
type result struct {
table string
limit int
offset int
pageSize uint
pageNumber uint
cursorColumn string
nextPageCursorValue interface{}
prevPageCursorValue interface{}
fields []interface{}
orderBy []interface{}
groupBy []interface{}
conds [][]interface{}
}
func filter(conds []interface{}) []interface{} {
return conds
}
// NewResult creates and Results a new Result set on the given table, this set
// is limited by the given exql.Where conditions.
func NewResult(builder mydb.SQL, table string, conds []interface{}) *Result {
r := &Result{
builder: builder,
}
return r.from(table).where(conds)
}
func (r *Result) frame(fn func(*result) error) *Result {
return &Result{err: r.err, prev: r, fn: fn}
}
func (r *Result) SQL() mydb.SQL {
if r.prev == nil {
return r.builder
}
return r.prev.SQL()
}
func (r *Result) from(table string) *Result {
return r.frame(func(res *result) error {
res.table = table
return nil
})
}
func (r *Result) where(conds []interface{}) *Result {
return r.frame(func(res *result) error {
res.conds = [][]interface{}{conds}
return nil
})
}
func (r *Result) setErr(err error) {
if err == nil {
return
}
r.err.Store(err)
}
// Err returns the last error that has happened with the result set,
// nil otherwise
func (r *Result) Err() error {
if errV := r.err.Load(); errV != nil {
return errV.(error)
}
return nil
}
// Where sets conditions for the result set.
func (r *Result) Where(conds ...interface{}) mydb.Result {
return r.where(conds)
}
// And adds more conditions on top of the existing ones.
func (r *Result) And(conds ...interface{}) mydb.Result {
return r.frame(func(res *result) error {
res.conds = append(res.conds, conds)
return nil
})
}
// Limit determines the maximum limit of Results to be returned.
func (r *Result) Limit(n int) mydb.Result {
return r.frame(func(res *result) error {
res.limit = n
return nil
})
}
func (r *Result) Paginate(pageSize uint) mydb.Result {
return r.frame(func(res *result) error {
res.pageSize = pageSize
return nil
})
}
func (r *Result) Page(pageNumber uint) mydb.Result {
return r.frame(func(res *result) error {
res.pageNumber = pageNumber
res.nextPageCursorValue = nil
res.prevPageCursorValue = nil
return nil
})
}
func (r *Result) Cursor(cursorColumn string) mydb.Result {
return r.frame(func(res *result) error {
res.cursorColumn = cursorColumn
return nil
})
}
func (r *Result) NextPage(cursorValue interface{}) mydb.Result {
return r.frame(func(res *result) error {
res.nextPageCursorValue = cursorValue
res.prevPageCursorValue = nil
return nil
})
}
func (r *Result) PrevPage(cursorValue interface{}) mydb.Result {
return r.frame(func(res *result) error {
res.nextPageCursorValue = nil
res.prevPageCursorValue = cursorValue
return nil
})
}
// Offset determines how many documents will be skipped before starting to grab
// Results.
func (r *Result) Offset(n int) mydb.Result {
return r.frame(func(res *result) error {
res.offset = n
return nil
})
}
// GroupBy is used to group Results that have the same value in the same column
// or columns.
func (r *Result) GroupBy(fields ...interface{}) mydb.Result {
return r.frame(func(res *result) error {
res.groupBy = fields
return nil
})
}
// OrderBy determines sorting of Results according to the provided names. Fields
// may be prefixed by - (minus) which means descending order, ascending order
// would be used otherwise.
func (r *Result) OrderBy(fields ...interface{}) mydb.Result {
return r.frame(func(res *result) error {
res.orderBy = fields
return nil
})
}
// Select determines which fields to return.
func (r *Result) Select(fields ...interface{}) mydb.Result {
return r.frame(func(res *result) error {
res.fields = fields
return nil
})
}
// String satisfies fmt.Stringer
func (r *Result) String() string {
query, err := r.Paginator()
if err != nil {
panic(err.Error())
}
return query.String()
}
// All dumps all Results into a pointer to an slice of structs or maps.
func (r *Result) All(dst interface{}) error {
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return err
}
err = query.Iterator().All(dst)
r.setErr(err)
return err
}
// One fetches only one Result from the set.
func (r *Result) One(dst interface{}) error {
one := r.Limit(1).(*Result)
query, err := one.Paginator()
if err != nil {
r.setErr(err)
return err
}
err = query.Iterator().One(dst)
r.setErr(err)
return err
}
// Next fetches the next Result from the set.
func (r *Result) Next(dst interface{}) bool {
r.iterMu.Lock()
defer r.iterMu.Unlock()
if r.iter == nil {
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return false
}
r.iter = query.Iterator()
}
if r.iter.Next(dst) {
return true
}
if err := r.iter.Err(); !errors.Is(err, mydb.ErrNoMoreRows) {
r.setErr(err)
return false
}
return false
}
// Delete deletes all matching items from the collection.
func (r *Result) Delete() error {
query, err := r.buildDelete()
if err != nil {
r.setErr(err)
return err
}
_, err = query.Exec()
r.setErr(err)
return err
}
// Close closes the Result set.
func (r *Result) Close() error {
if r.iter != nil {
err := r.iter.Close()
r.setErr(err)
return err
}
return nil
}
// Update updates matching items from the collection with values of the given
// map or struct.
func (r *Result) Update(values interface{}) error {
query, err := r.buildUpdate(values)
if err != nil {
r.setErr(err)
return err
}
_, err = query.Exec()
r.setErr(err)
return err
}
func (r *Result) TotalPages() (uint, error) {
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return 0, err
}
total, err := query.TotalPages()
if err != nil {
r.setErr(err)
return 0, err
}
return total, nil
}
func (r *Result) TotalEntries() (uint64, error) {
query, err := r.Paginator()
if err != nil {
r.setErr(err)
return 0, err
}
total, err := query.TotalEntries()
if err != nil {
r.setErr(err)
return 0, err
}
return total, nil
}
// Exists returns true if at least one item on the collection exists.
func (r *Result) Exists() (bool, error) {
query, err := r.buildCount()
if err != nil {
r.setErr(err)
return false, err
}
query = query.Limit(1)
value := struct {
Exists uint64 `db:"_t"`
}{}
if err := query.One(&value); err != nil {
if errors.Is(err, mydb.ErrNoMoreRows) {
return false, nil
}
r.setErr(err)
return false, err
}
if value.Exists > 0 {
return true, nil
}
return false, nil
}
// Count counts the elements on the set.
func (r *Result) Count() (uint64, error) {
query, err := r.buildCount()
if err != nil {
r.setErr(err)
return 0, err
}
counter := struct {
Count uint64 `db:"_t"`
}{}
if err := query.One(&counter); err != nil {
if errors.Is(err, mydb.ErrNoMoreRows) {
return 0, nil
}
r.setErr(err)
return 0, err
}
return counter.Count, nil
}
func (r *Result) Paginator() (mydb.Paginator, error) {
if err := r.Err(); err != nil {
return nil, err
}
res, err := r.fastForward()
if err != nil {
return nil, err
}
sel := r.SQL().Select(res.fields...).
From(res.table).
Limit(res.limit).
Offset(res.offset).
GroupBy(res.groupBy...).
OrderBy(res.orderBy...)
for i := range res.conds {
sel = sel.And(filter(res.conds[i])...)
}
pag := sel.Paginate(res.pageSize).
Page(res.pageNumber).
Cursor(res.cursorColumn)
if res.nextPageCursorValue != nil {
pag = pag.NextPage(res.nextPageCursorValue)
}
if res.prevPageCursorValue != nil {
pag = pag.PrevPage(res.prevPageCursorValue)
}
return pag, nil
}
func (r *Result) buildDelete() (mydb.Deleter, error) {
if err := r.Err(); err != nil {
return nil, err
}
res, err := r.fastForward()
if err != nil {
return nil, err
}
del := r.SQL().DeleteFrom(res.table).
Limit(res.limit)
for i := range res.conds {
del = del.And(filter(res.conds[i])...)
}
return del, nil
}
func (r *Result) buildUpdate(values interface{}) (mydb.Updater, error) {
if err := r.Err(); err != nil {
return nil, err
}
res, err := r.fastForward()
if err != nil {
return nil, err
}
upd := r.SQL().Update(res.table).
Set(values).
Limit(res.limit)
for i := range res.conds {
upd = upd.And(filter(res.conds[i])...)
}
return upd, nil
}
func (r *Result) buildCount() (mydb.Selector, error) {
if err := r.Err(); err != nil {
return nil, err
}
res, err := r.fastForward()
if err != nil {
return nil, err
}
sel := r.SQL().Select(mydb.Raw("count(1) AS _t")).
From(res.table).
GroupBy(res.groupBy...)
for i := range res.conds {
sel = sel.And(filter(res.conds[i])...)
}
return sel, nil
}
func (r *Result) Prev() immutable.Immutable {
if r == nil {
return nil
}
return r.prev
}
func (r *Result) Fn(in interface{}) error {
if r.fn == nil {
return nil
}
return r.fn(in.(*result))
}
func (r *Result) Base() interface{} {
return &result{}
}
func (r *Result) fastForward() (*result, error) {
ff, err := immutable.FastForward(r)
if err != nil {
return nil, err
}
return ff.(*result), nil
}
var _ = immutable.Immutable(&Result{})