566 lines
11 KiB
Go
566 lines
11 KiB
Go
|
package mongo
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"math"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"encoding/json"
|
||
|
|
||
|
"git.hexq.cn/tiglog/mydb"
|
||
|
"git.hexq.cn/tiglog/mydb/internal/immutable"
|
||
|
mgo "gopkg.in/mgo.v2"
|
||
|
"gopkg.in/mgo.v2/bson"
|
||
|
)
|
||
|
|
||
|
type resultQuery struct {
|
||
|
c *Collection
|
||
|
|
||
|
fields []string
|
||
|
limit int
|
||
|
offset int
|
||
|
sort []string
|
||
|
conditions interface{}
|
||
|
groupBy []interface{}
|
||
|
|
||
|
pageSize uint
|
||
|
pageNumber uint
|
||
|
cursorColumn string
|
||
|
cursorValue interface{}
|
||
|
cursorCond mydb.Cond
|
||
|
cursorReverseOrder bool
|
||
|
}
|
||
|
|
||
|
type result struct {
|
||
|
iter *mgo.Iter
|
||
|
err error
|
||
|
errMu sync.Mutex
|
||
|
|
||
|
fn func(*resultQuery) error
|
||
|
prev *result
|
||
|
}
|
||
|
|
||
|
// Compile-time check: *result must satisfy immutable.Immutable so that build
// can replay its frames with immutable.FastForward.
var _ immutable.Immutable = (*result)(nil)
|
||
|
|
||
|
func (res *result) frame(fn func(*resultQuery) error) *result {
|
||
|
return &result{prev: res, fn: fn}
|
||
|
}
|
||
|
|
||
|
func (r *resultQuery) and(terms ...interface{}) error {
|
||
|
if r.conditions == nil {
|
||
|
return r.where(terms...)
|
||
|
}
|
||
|
|
||
|
r.conditions = map[string]interface{}{
|
||
|
"$and": []interface{}{
|
||
|
r.conditions,
|
||
|
r.c.compileQuery(terms...),
|
||
|
},
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (r *resultQuery) where(terms ...interface{}) error {
|
||
|
r.conditions = r.c.compileQuery(terms...)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (res *result) And(terms ...interface{}) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
return r.and(terms...)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (res *result) Where(terms ...interface{}) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
return r.where(terms...)
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (res *result) Paginate(pageSize uint) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
r.pageSize = pageSize
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (res *result) Page(pageNumber uint) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
r.pageNumber = pageNumber
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (res *result) Cursor(cursorColumn string) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
r.cursorColumn = cursorColumn
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (res *result) NextPage(cursorValue interface{}) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
r.cursorValue = cursorValue
|
||
|
r.cursorReverseOrder = false
|
||
|
r.cursorCond = mydb.Cond{
|
||
|
r.cursorColumn: bson.M{"$gt": cursorValue},
|
||
|
}
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (res *result) PrevPage(cursorValue interface{}) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
r.cursorValue = cursorValue
|
||
|
r.cursorReverseOrder = true
|
||
|
r.cursorCond = mydb.Cond{
|
||
|
r.cursorColumn: bson.M{"$lt": cursorValue},
|
||
|
}
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
func (res *result) TotalEntries() (uint64, error) {
|
||
|
return res.Count()
|
||
|
}
|
||
|
|
||
|
func (res *result) TotalPages() (uint, error) {
|
||
|
count, err := res.Count()
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
rq, err := res.build()
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
if rq.pageSize < 1 {
|
||
|
return 1, nil
|
||
|
}
|
||
|
|
||
|
total := uint(math.Ceil(float64(count) / float64(rq.pageSize)))
|
||
|
return total, nil
|
||
|
}
|
||
|
|
||
|
// Limit determines the maximum limit of results to be returned.
|
||
|
func (res *result) Limit(n int) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
r.limit = n
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// Offset determines how many documents will be skipped before starting to grab
|
||
|
// results.
|
||
|
func (res *result) Offset(n int) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
r.offset = n
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// OrderBy determines sorting of results according to the provided names. Fields
|
||
|
// may be prefixed by - (minus) which means descending order, ascending order
|
||
|
// would be used otherwise.
|
||
|
func (res *result) OrderBy(fields ...interface{}) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
ss := make([]string, len(fields))
|
||
|
for i, field := range fields {
|
||
|
ss[i] = fmt.Sprintf(`%v`, field)
|
||
|
}
|
||
|
r.sort = ss
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// String satisfies fmt.Stringer
|
||
|
func (res *result) String() string {
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
// Select marks the specific fields the user wants to retrieve.
|
||
|
func (res *result) Select(fields ...interface{}) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
fieldslen := len(fields)
|
||
|
r.fields = make([]string, 0, fieldslen)
|
||
|
for i := 0; i < fieldslen; i++ {
|
||
|
r.fields = append(r.fields, fmt.Sprintf(`%v`, fields[i]))
|
||
|
}
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// All dumps all results into a pointer to an slice of structs or maps.
|
||
|
func (res *result) All(dst interface{}) error {
|
||
|
rq, err := res.build()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
q, err := rq.query()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
defer func(start time.Time) {
|
||
|
queryLog(&mydb.QueryStatus{
|
||
|
RawQuery: rq.debugQuery("Find.All"),
|
||
|
Err: err,
|
||
|
Start: start,
|
||
|
End: time.Now(),
|
||
|
})
|
||
|
}(time.Now())
|
||
|
|
||
|
err = q.All(dst)
|
||
|
if errors.Is(err, mgo.ErrNotFound) {
|
||
|
return mydb.ErrNoMoreRows
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// GroupBy is used to group results that have the same value in the same column
|
||
|
// or columns.
|
||
|
func (res *result) GroupBy(fields ...interface{}) mydb.Result {
|
||
|
return res.frame(func(r *resultQuery) error {
|
||
|
r.groupBy = fields
|
||
|
return nil
|
||
|
})
|
||
|
}
|
||
|
|
||
|
// One fetches only one result from the resultset.
|
||
|
func (res *result) One(dst interface{}) error {
|
||
|
rq, err := res.build()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
q, err := rq.query()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
defer func(start time.Time) {
|
||
|
queryLog(&mydb.QueryStatus{
|
||
|
RawQuery: rq.debugQuery("Find.One"),
|
||
|
Err: err,
|
||
|
Start: start,
|
||
|
End: time.Now(),
|
||
|
})
|
||
|
}(time.Now())
|
||
|
|
||
|
err = q.One(dst)
|
||
|
if errors.Is(err, mgo.ErrNotFound) {
|
||
|
return mydb.ErrNoMoreRows
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (res *result) Err() error {
|
||
|
res.errMu.Lock()
|
||
|
defer res.errMu.Unlock()
|
||
|
|
||
|
return res.err
|
||
|
}
|
||
|
|
||
|
func (res *result) setErr(err error) {
|
||
|
res.errMu.Lock()
|
||
|
defer res.errMu.Unlock()
|
||
|
res.err = err
|
||
|
}
|
||
|
|
||
|
func (res *result) Next(dst interface{}) bool {
|
||
|
if res.iter == nil {
|
||
|
rq, err := res.build()
|
||
|
if err != nil {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
q, err := rq.query()
|
||
|
if err != nil {
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
defer func(start time.Time) {
|
||
|
queryLog(&mydb.QueryStatus{
|
||
|
RawQuery: rq.debugQuery("Find.Next"),
|
||
|
Err: err,
|
||
|
Start: start,
|
||
|
End: time.Now(),
|
||
|
})
|
||
|
}(time.Now())
|
||
|
|
||
|
res.iter = q.Iter()
|
||
|
}
|
||
|
|
||
|
if !res.iter.Next(dst) {
|
||
|
res.setErr(res.iter.Err())
|
||
|
return false
|
||
|
}
|
||
|
|
||
|
return true
|
||
|
}
|
||
|
|
||
|
// Delete remove the matching items from the collection.
|
||
|
func (res *result) Delete() error {
|
||
|
rq, err := res.build()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
defer func(start time.Time) {
|
||
|
queryLog(&mydb.QueryStatus{
|
||
|
RawQuery: rq.debugQuery("Remove"),
|
||
|
Err: err,
|
||
|
Start: start,
|
||
|
End: time.Now(),
|
||
|
})
|
||
|
}(time.Now())
|
||
|
|
||
|
_, err = rq.c.collection.RemoveAll(rq.conditions)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// Close closes the result set.
|
||
|
func (r *result) Close() error {
|
||
|
var err error
|
||
|
if r.iter != nil {
|
||
|
err = r.iter.Close()
|
||
|
r.iter = nil
|
||
|
}
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// Update modified matching items from the collection with values of the given
|
||
|
// map or struct.
|
||
|
func (res *result) Update(src interface{}) (err error) {
|
||
|
updateSet := map[string]interface{}{"$set": src}
|
||
|
|
||
|
rq, err := res.build()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
defer func(start time.Time) {
|
||
|
queryLog(&mydb.QueryStatus{
|
||
|
RawQuery: rq.debugQuery("Update"),
|
||
|
Err: err,
|
||
|
Start: start,
|
||
|
End: time.Now(),
|
||
|
})
|
||
|
}(time.Now())
|
||
|
|
||
|
_, err = rq.c.collection.UpdateAll(rq.conditions, updateSet)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (res *result) build() (*resultQuery, error) {
|
||
|
rqi, err := immutable.FastForward(res)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
rq := rqi.(*resultQuery)
|
||
|
if !rq.cursorCond.Empty() {
|
||
|
if err := rq.and(rq.cursorCond); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if rq.cursorColumn != "" {
|
||
|
if rq.cursorReverseOrder {
|
||
|
rq.sort = append(rq.sort, "-"+rq.cursorColumn)
|
||
|
} else {
|
||
|
rq.sort = append(rq.sort, rq.cursorColumn)
|
||
|
}
|
||
|
}
|
||
|
return rq, nil
|
||
|
}
|
||
|
|
||
|
// query executes a mgo query.
|
||
|
func (r *resultQuery) query() (*mgo.Query, error) {
|
||
|
if len(r.groupBy) > 0 {
|
||
|
return nil, mydb.ErrUnsupported
|
||
|
}
|
||
|
|
||
|
q := r.c.collection.Find(r.conditions)
|
||
|
|
||
|
if r.pageSize > 0 {
|
||
|
r.offset = int(r.pageSize * r.pageNumber)
|
||
|
r.limit = int(r.pageSize)
|
||
|
}
|
||
|
|
||
|
if r.offset > 0 {
|
||
|
q.Skip(r.offset)
|
||
|
}
|
||
|
|
||
|
if r.limit > 0 {
|
||
|
q.Limit(r.limit)
|
||
|
}
|
||
|
|
||
|
if len(r.sort) > 0 {
|
||
|
q.Sort(r.sort...)
|
||
|
}
|
||
|
|
||
|
selectedFields := bson.M{}
|
||
|
if len(r.fields) > 0 {
|
||
|
for _, field := range r.fields {
|
||
|
if field == `*` {
|
||
|
break
|
||
|
}
|
||
|
selectedFields[field] = true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if r.cursorReverseOrder {
|
||
|
ids := make([]bson.ObjectId, 0, r.limit)
|
||
|
|
||
|
iter := q.Select(bson.M{"_id": true}).Iter()
|
||
|
defer iter.Close()
|
||
|
|
||
|
var item map[string]bson.ObjectId
|
||
|
for iter.Next(&item) {
|
||
|
ids = append(ids, item["_id"])
|
||
|
}
|
||
|
|
||
|
r.conditions = bson.M{"_id": bson.M{"$in": ids}}
|
||
|
|
||
|
q = r.c.collection.Find(r.conditions)
|
||
|
}
|
||
|
|
||
|
if len(selectedFields) > 0 {
|
||
|
q.Select(selectedFields)
|
||
|
}
|
||
|
|
||
|
return q, nil
|
||
|
}
|
||
|
|
||
|
func (res *result) Exists() (bool, error) {
|
||
|
total, err := res.Count()
|
||
|
if err != nil {
|
||
|
return false, err
|
||
|
}
|
||
|
if total > 0 {
|
||
|
return true, nil
|
||
|
}
|
||
|
return false, nil
|
||
|
}
|
||
|
|
||
|
// Count counts matching elements.
|
||
|
func (res *result) Count() (total uint64, err error) {
|
||
|
rq, err := res.build()
|
||
|
if err != nil {
|
||
|
return 0, err
|
||
|
}
|
||
|
|
||
|
defer func(start time.Time) {
|
||
|
queryLog(&mydb.QueryStatus{
|
||
|
RawQuery: rq.debugQuery("Find.Count"),
|
||
|
Err: err,
|
||
|
Start: start,
|
||
|
End: time.Now(),
|
||
|
})
|
||
|
}(time.Now())
|
||
|
|
||
|
q := rq.c.collection.Find(rq.conditions)
|
||
|
|
||
|
var c int
|
||
|
c, err = q.Count()
|
||
|
|
||
|
return uint64(c), err
|
||
|
}
|
||
|
|
||
|
func (res *result) Prev() immutable.Immutable {
|
||
|
if res == nil {
|
||
|
return nil
|
||
|
}
|
||
|
return res.prev
|
||
|
}
|
||
|
|
||
|
func (res *result) Fn(in interface{}) error {
|
||
|
if res.fn == nil {
|
||
|
return nil
|
||
|
}
|
||
|
return res.fn(in.(*resultQuery))
|
||
|
}
|
||
|
|
||
|
func (res *result) Base() interface{} {
|
||
|
return &resultQuery{}
|
||
|
}
|
||
|
|
||
|
func (r *resultQuery) debugQuery(action string) string {
|
||
|
query := fmt.Sprintf("mydb.%s.%s", r.c.collection.Name, action)
|
||
|
|
||
|
if r.conditions != nil {
|
||
|
query = fmt.Sprintf("%s.conds(%v)", query, r.conditions)
|
||
|
}
|
||
|
if r.limit > 0 {
|
||
|
query = fmt.Sprintf("%s.limit(%d)", query, r.limit)
|
||
|
}
|
||
|
if r.offset > 0 {
|
||
|
query = fmt.Sprintf("%s.offset(%d)", query, r.offset)
|
||
|
}
|
||
|
if len(r.fields) > 0 {
|
||
|
selectedFields := bson.M{}
|
||
|
for _, field := range r.fields {
|
||
|
if field == `*` {
|
||
|
break
|
||
|
}
|
||
|
selectedFields[field] = true
|
||
|
}
|
||
|
if len(selectedFields) > 0 {
|
||
|
query = fmt.Sprintf("%s.select(%v)", query, selectedFields)
|
||
|
}
|
||
|
}
|
||
|
if len(r.groupBy) > 0 {
|
||
|
escaped := make([]string, len(r.groupBy))
|
||
|
for i := range r.groupBy {
|
||
|
escaped[i] = string(mustJSON(r.groupBy[i]))
|
||
|
}
|
||
|
query = fmt.Sprintf("%s.groupBy(%v)", query, strings.Join(escaped, ", "))
|
||
|
}
|
||
|
if len(r.sort) > 0 {
|
||
|
escaped := make([]string, len(r.sort))
|
||
|
for i := range r.sort {
|
||
|
escaped[i] = string(mustJSON(r.sort[i]))
|
||
|
}
|
||
|
query = fmt.Sprintf("%s.sort(%s)", query, strings.Join(escaped, ", "))
|
||
|
}
|
||
|
return query
|
||
|
}
|
||
|
|
||
|
// mustJSON returns the JSON encoding of in and panics if in cannot be
// encoded. debugQuery uses it to render groupBy and sort entries.
func mustJSON(in interface{}) (out []byte) {
	encoded, err := json.Marshal(in)
	if err != nil {
		panic(err)
	}
	out = encoded
	return out
}
|
||
|
|
||
|
func queryLog(status *mydb.QueryStatus) {
|
||
|
diff := status.End.Sub(status.Start)
|
||
|
|
||
|
slowQuery := false
|
||
|
if diff >= time.Millisecond*100 {
|
||
|
status.Err = mydb.ErrWarnSlowQuery
|
||
|
slowQuery = true
|
||
|
}
|
||
|
|
||
|
if status.Err != nil || slowQuery {
|
||
|
mydb.LC().Warn(status)
|
||
|
return
|
||
|
}
|
||
|
|
||
|
mydb.LC().Debug(status)
|
||
|
}
|