package mongo import ( "errors" "fmt" "math" "strings" "sync" "time" "encoding/json" "git.hexq.cn/tiglog/mydb" "git.hexq.cn/tiglog/mydb/internal/immutable" mgo "gopkg.in/mgo.v2" "gopkg.in/mgo.v2/bson" ) type resultQuery struct { c *Collection fields []string limit int offset int sort []string conditions interface{} groupBy []interface{} pageSize uint pageNumber uint cursorColumn string cursorValue interface{} cursorCond mydb.Cond cursorReverseOrder bool } type result struct { iter *mgo.Iter err error errMu sync.Mutex fn func(*resultQuery) error prev *result } var _ = immutable.Immutable(&result{}) func (res *result) frame(fn func(*resultQuery) error) *result { return &result{prev: res, fn: fn} } func (r *resultQuery) and(terms ...interface{}) error { if r.conditions == nil { return r.where(terms...) } r.conditions = map[string]interface{}{ "$and": []interface{}{ r.conditions, r.c.compileQuery(terms...), }, } return nil } func (r *resultQuery) where(terms ...interface{}) error { r.conditions = r.c.compileQuery(terms...) return nil } func (res *result) And(terms ...interface{}) mydb.Result { return res.frame(func(r *resultQuery) error { return r.and(terms...) }) } func (res *result) Where(terms ...interface{}) mydb.Result { return res.frame(func(r *resultQuery) error { return r.where(terms...) }) } func (res *result) Paginate(pageSize uint) mydb.Result { return res.frame(func(r *resultQuery) error { r.pageSize = pageSize return nil }) } func (res *result) Page(pageNumber uint) mydb.Result { return res.frame(func(r *resultQuery) error { r.pageNumber = pageNumber return nil }) } func (res *result) Cursor(cursorColumn string) mydb.Result { return res.frame(func(r *resultQuery) error { r.cursorColumn = cursorColumn return nil }) } func (res *result) NextPage(cursorValue interface{}) mydb.Result { return res.frame(func(r *resultQuery) error { r.cursorValue = cursorValue r.cursorReverseOrder = false r.cursorCond = mydb.Cond{ r.cursorColumn: bson.M{"$gt": cursorValue}, } return nil }) } func (res *result) PrevPage(cursorValue interface{}) mydb.Result { return res.frame(func(r *resultQuery) error { r.cursorValue = cursorValue r.cursorReverseOrder = true r.cursorCond = mydb.Cond{ r.cursorColumn: bson.M{"$lt": cursorValue}, } return nil }) } func (res *result) TotalEntries() (uint64, error) { return res.Count() } func (res *result) TotalPages() (uint, error) { count, err := res.Count() if err != nil { return 0, err } rq, err := res.build() if err != nil { return 0, err } if rq.pageSize < 1 { return 1, nil } total := uint(math.Ceil(float64(count) / float64(rq.pageSize))) return total, nil } // Limit determines the maximum limit of results to be returned. func (res *result) Limit(n int) mydb.Result { return res.frame(func(r *resultQuery) error { r.limit = n return nil }) } // Offset determines how many documents will be skipped before starting to grab // results. func (res *result) Offset(n int) mydb.Result { return res.frame(func(r *resultQuery) error { r.offset = n return nil }) } // OrderBy determines sorting of results according to the provided names. Fields // may be prefixed by - (minus) which means descending order, ascending order // would be used otherwise. func (res *result) OrderBy(fields ...interface{}) mydb.Result { return res.frame(func(r *resultQuery) error { ss := make([]string, len(fields)) for i, field := range fields { ss[i] = fmt.Sprintf(`%v`, field) } r.sort = ss return nil }) } // String satisfies fmt.Stringer func (res *result) String() string { return "" } // Select marks the specific fields the user wants to retrieve. func (res *result) Select(fields ...interface{}) mydb.Result { return res.frame(func(r *resultQuery) error { fieldslen := len(fields) r.fields = make([]string, 0, fieldslen) for i := 0; i < fieldslen; i++ { r.fields = append(r.fields, fmt.Sprintf(`%v`, fields[i])) } return nil }) } // All dumps all results into a pointer to an slice of structs or maps. func (res *result) All(dst interface{}) error { rq, err := res.build() if err != nil { return err } q, err := rq.query() if err != nil { return err } defer func(start time.Time) { queryLog(&mydb.QueryStatus{ RawQuery: rq.debugQuery("Find.All"), Err: err, Start: start, End: time.Now(), }) }(time.Now()) err = q.All(dst) if errors.Is(err, mgo.ErrNotFound) { return mydb.ErrNoMoreRows } return err } // GroupBy is used to group results that have the same value in the same column // or columns. func (res *result) GroupBy(fields ...interface{}) mydb.Result { return res.frame(func(r *resultQuery) error { r.groupBy = fields return nil }) } // One fetches only one result from the resultset. func (res *result) One(dst interface{}) error { rq, err := res.build() if err != nil { return err } q, err := rq.query() if err != nil { return err } defer func(start time.Time) { queryLog(&mydb.QueryStatus{ RawQuery: rq.debugQuery("Find.One"), Err: err, Start: start, End: time.Now(), }) }(time.Now()) err = q.One(dst) if errors.Is(err, mgo.ErrNotFound) { return mydb.ErrNoMoreRows } return err } func (res *result) Err() error { res.errMu.Lock() defer res.errMu.Unlock() return res.err } func (res *result) setErr(err error) { res.errMu.Lock() defer res.errMu.Unlock() res.err = err } func (res *result) Next(dst interface{}) bool { if res.iter == nil { rq, err := res.build() if err != nil { return false } q, err := rq.query() if err != nil { return false } defer func(start time.Time) { queryLog(&mydb.QueryStatus{ RawQuery: rq.debugQuery("Find.Next"), Err: err, Start: start, End: time.Now(), }) }(time.Now()) res.iter = q.Iter() } if !res.iter.Next(dst) { res.setErr(res.iter.Err()) return false } return true } // Delete remove the matching items from the collection. func (res *result) Delete() error { rq, err := res.build() if err != nil { return err } defer func(start time.Time) { queryLog(&mydb.QueryStatus{ RawQuery: rq.debugQuery("Remove"), Err: err, Start: start, End: time.Now(), }) }(time.Now()) _, err = rq.c.collection.RemoveAll(rq.conditions) if err != nil { return err } return nil } // Close closes the result set. func (r *result) Close() error { var err error if r.iter != nil { err = r.iter.Close() r.iter = nil } return err } // Update modified matching items from the collection with values of the given // map or struct. func (res *result) Update(src interface{}) (err error) { updateSet := map[string]interface{}{"$set": src} rq, err := res.build() if err != nil { return err } defer func(start time.Time) { queryLog(&mydb.QueryStatus{ RawQuery: rq.debugQuery("Update"), Err: err, Start: start, End: time.Now(), }) }(time.Now()) _, err = rq.c.collection.UpdateAll(rq.conditions, updateSet) if err != nil { return err } return nil } func (res *result) build() (*resultQuery, error) { rqi, err := immutable.FastForward(res) if err != nil { return nil, err } rq := rqi.(*resultQuery) if !rq.cursorCond.Empty() { if err := rq.and(rq.cursorCond); err != nil { return nil, err } } if rq.cursorColumn != "" { if rq.cursorReverseOrder { rq.sort = append(rq.sort, "-"+rq.cursorColumn) } else { rq.sort = append(rq.sort, rq.cursorColumn) } } return rq, nil } // query executes a mgo query. func (r *resultQuery) query() (*mgo.Query, error) { if len(r.groupBy) > 0 { return nil, mydb.ErrUnsupported } q := r.c.collection.Find(r.conditions) if r.pageSize > 0 { r.offset = int(r.pageSize * r.pageNumber) r.limit = int(r.pageSize) } if r.offset > 0 { q.Skip(r.offset) } if r.limit > 0 { q.Limit(r.limit) } if len(r.sort) > 0 { q.Sort(r.sort...) } selectedFields := bson.M{} if len(r.fields) > 0 { for _, field := range r.fields { if field == `*` { break } selectedFields[field] = true } } if r.cursorReverseOrder { ids := make([]bson.ObjectId, 0, r.limit) iter := q.Select(bson.M{"_id": true}).Iter() defer iter.Close() var item map[string]bson.ObjectId for iter.Next(&item) { ids = append(ids, item["_id"]) } r.conditions = bson.M{"_id": bson.M{"$in": ids}} q = r.c.collection.Find(r.conditions) } if len(selectedFields) > 0 { q.Select(selectedFields) } return q, nil } func (res *result) Exists() (bool, error) { total, err := res.Count() if err != nil { return false, err } if total > 0 { return true, nil } return false, nil } // Count counts matching elements. func (res *result) Count() (total uint64, err error) { rq, err := res.build() if err != nil { return 0, err } defer func(start time.Time) { queryLog(&mydb.QueryStatus{ RawQuery: rq.debugQuery("Find.Count"), Err: err, Start: start, End: time.Now(), }) }(time.Now()) q := rq.c.collection.Find(rq.conditions) var c int c, err = q.Count() return uint64(c), err } func (res *result) Prev() immutable.Immutable { if res == nil { return nil } return res.prev } func (res *result) Fn(in interface{}) error { if res.fn == nil { return nil } return res.fn(in.(*resultQuery)) } func (res *result) Base() interface{} { return &resultQuery{} } func (r *resultQuery) debugQuery(action string) string { query := fmt.Sprintf("mydb.%s.%s", r.c.collection.Name, action) if r.conditions != nil { query = fmt.Sprintf("%s.conds(%v)", query, r.conditions) } if r.limit > 0 { query = fmt.Sprintf("%s.limit(%d)", query, r.limit) } if r.offset > 0 { query = fmt.Sprintf("%s.offset(%d)", query, r.offset) } if len(r.fields) > 0 { selectedFields := bson.M{} for _, field := range r.fields { if field == `*` { break } selectedFields[field] = true } if len(selectedFields) > 0 { query = fmt.Sprintf("%s.select(%v)", query, selectedFields) } } if len(r.groupBy) > 0 { escaped := make([]string, len(r.groupBy)) for i := range r.groupBy { escaped[i] = string(mustJSON(r.groupBy[i])) } query = fmt.Sprintf("%s.groupBy(%v)", query, strings.Join(escaped, ", ")) } if len(r.sort) > 0 { escaped := make([]string, len(r.sort)) for i := range r.sort { escaped[i] = string(mustJSON(r.sort[i])) } query = fmt.Sprintf("%s.sort(%s)", query, strings.Join(escaped, ", ")) } return query } func mustJSON(in interface{}) (out []byte) { out, err := json.Marshal(in) if err != nil { panic(err) } return out } func queryLog(status *mydb.QueryStatus) { diff := status.End.Sub(status.Start) slowQuery := false if diff >= time.Millisecond*100 { status.Err = mydb.ErrWarnSlowQuery slowQuery = true } if status.Err != nil || slowQuery { mydb.LC().Warn(status) return } mydb.LC().Debug(status) }