347 lines
7.7 KiB
Go
347 lines
7.7 KiB
Go
package mongo
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
"sync"
|
|
|
|
"reflect"
|
|
|
|
"git.hexq.cn/tiglog/mydb"
|
|
"git.hexq.cn/tiglog/mydb/internal/adapter"
|
|
mgo "gopkg.in/mgo.v2"
|
|
"gopkg.in/mgo.v2/bson"
|
|
)
|
|
|
|
// Collection represents a mongodb collection.
|
|
type Collection struct {
|
|
parent *Source
|
|
collection *mgo.Collection
|
|
}
|
|
|
|
var (
	// idCacheMutex guards every read and write of idCache.
	idCacheMutex sync.RWMutex

	// idCache remembers, per struct type, the name of the field that maps to
	// the document _id. It should become a struct if more than the _id field
	// ever needs to be cached here.
	idCache = map[reflect.Type]string{}
)
|
|
|
|
// Find creates a result set with the given conditions.
|
|
func (col *Collection) Find(terms ...interface{}) mydb.Result {
|
|
fields := []string{"*"}
|
|
|
|
conditions := col.compileQuery(terms...)
|
|
|
|
res := &result{}
|
|
res = res.frame(func(r *resultQuery) error {
|
|
r.c = col
|
|
r.conditions = conditions
|
|
r.fields = fields
|
|
return nil
|
|
})
|
|
|
|
return res
|
|
}
|
|
|
|
var comparisonOperators = map[adapter.ComparisonOperator]string{
|
|
adapter.ComparisonOperatorEqual: "$eq",
|
|
adapter.ComparisonOperatorNotEqual: "$ne",
|
|
|
|
adapter.ComparisonOperatorLessThan: "$lt",
|
|
adapter.ComparisonOperatorGreaterThan: "$gt",
|
|
|
|
adapter.ComparisonOperatorLessThanOrEqualTo: "$lte",
|
|
adapter.ComparisonOperatorGreaterThanOrEqualTo: "$gte",
|
|
|
|
adapter.ComparisonOperatorIn: "$in",
|
|
adapter.ComparisonOperatorNotIn: "$nin",
|
|
}
|
|
|
|
func compare(field string, cmp *adapter.Comparison) (string, interface{}) {
|
|
op := cmp.Operator()
|
|
value := cmp.Value()
|
|
|
|
switch op {
|
|
case adapter.ComparisonOperatorEqual:
|
|
return field, value
|
|
case adapter.ComparisonOperatorBetween:
|
|
values := value.([]interface{})
|
|
return field, bson.M{
|
|
"$gte": values[0],
|
|
"$lte": values[1],
|
|
}
|
|
case adapter.ComparisonOperatorNotBetween:
|
|
values := value.([]interface{})
|
|
return "$or", []bson.M{
|
|
{field: bson.M{"$gt": values[1]}},
|
|
{field: bson.M{"$lt": values[0]}},
|
|
}
|
|
case adapter.ComparisonOperatorIs:
|
|
if value == nil {
|
|
return field, bson.M{"$exists": false}
|
|
}
|
|
return field, bson.M{"$eq": value}
|
|
case adapter.ComparisonOperatorIsNot:
|
|
if value == nil {
|
|
return field, bson.M{"$exists": true}
|
|
}
|
|
return field, bson.M{"$ne": value}
|
|
case adapter.ComparisonOperatorRegExp, adapter.ComparisonOperatorLike:
|
|
return field, bson.RegEx{Pattern: value.(string), Options: ""}
|
|
case adapter.ComparisonOperatorNotRegExp, adapter.ComparisonOperatorNotLike:
|
|
return field, bson.M{"$not": bson.RegEx{Pattern: value.(string), Options: ""}}
|
|
}
|
|
|
|
if cmpOp, ok := comparisonOperators[op]; ok {
|
|
return field, bson.M{
|
|
cmpOp: value,
|
|
}
|
|
}
|
|
|
|
panic(fmt.Sprintf("Unsupported operator %v", op))
|
|
}
|
|
|
|
// compileStatement transforms conditions into something *mgo.Session can
|
|
// understand.
|
|
func compileStatement(cond mydb.Cond) bson.M {
|
|
conds := bson.M{}
|
|
|
|
// Walking over conditions
|
|
for fieldI, value := range cond {
|
|
field := strings.TrimSpace(fmt.Sprintf("%v", fieldI))
|
|
|
|
if cmp, ok := value.(*mydb.Comparison); ok {
|
|
k, v := compare(field, cmp.Comparison)
|
|
conds[k] = v
|
|
continue
|
|
}
|
|
|
|
var op string
|
|
chunks := strings.SplitN(field, ` `, 2)
|
|
|
|
if len(chunks) > 1 {
|
|
switch chunks[1] {
|
|
case `IN`:
|
|
op = `$in`
|
|
case `NOT IN`:
|
|
op = `$nin`
|
|
case `>`:
|
|
op = `$gt`
|
|
case `<`:
|
|
op = `$lt`
|
|
case `<=`:
|
|
op = `$lte`
|
|
case `>=`:
|
|
op = `$gte`
|
|
default:
|
|
op = chunks[1]
|
|
}
|
|
}
|
|
field = chunks[0]
|
|
|
|
if op == "" {
|
|
conds[field] = value
|
|
} else {
|
|
conds[field] = bson.M{op: value}
|
|
}
|
|
}
|
|
|
|
return conds
|
|
}
|
|
|
|
// compileConditions compiles terms into something *mgo.Session can
|
|
// understand.
|
|
func (col *Collection) compileConditions(term interface{}) interface{} {
|
|
|
|
switch t := term.(type) {
|
|
case []interface{}:
|
|
values := []interface{}{}
|
|
for i := range t {
|
|
value := col.compileConditions(t[i])
|
|
if value != nil {
|
|
values = append(values, value)
|
|
}
|
|
}
|
|
if len(values) > 0 {
|
|
return values
|
|
}
|
|
case mydb.Cond:
|
|
return compileStatement(t)
|
|
case adapter.LogicalExpr:
|
|
values := []interface{}{}
|
|
|
|
for _, s := range t.Expressions() {
|
|
values = append(values, col.compileConditions(s))
|
|
}
|
|
|
|
var op string
|
|
switch t.Operator() {
|
|
case adapter.LogicalOperatorOr:
|
|
op = `$or`
|
|
default:
|
|
op = `$and`
|
|
}
|
|
|
|
return bson.M{op: values}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// compileQuery compiles terms into something that *mgo.Session can
|
|
// understand.
|
|
func (col *Collection) compileQuery(terms ...interface{}) interface{} {
|
|
compiled := col.compileConditions(terms)
|
|
if compiled == nil {
|
|
return nil
|
|
}
|
|
|
|
conditions := compiled.([]interface{})
|
|
if len(conditions) == 1 {
|
|
return conditions[0]
|
|
}
|
|
// this should be correct.
|
|
// query = map[string]interface{}{"$and": conditions}
|
|
|
|
// attempt to workaround https://jira.mongomydb.org/browse/SERVER-4572
|
|
mapped := map[string]interface{}{}
|
|
for _, v := range conditions {
|
|
for kk := range v.(map[string]interface{}) {
|
|
mapped[kk] = v.(map[string]interface{})[kk]
|
|
}
|
|
}
|
|
|
|
return mapped
|
|
}
|
|
|
|
// Name returns the name of the table or tables that form the collection.
|
|
func (col *Collection) Name() string {
|
|
return col.collection.Name
|
|
}
|
|
|
|
// Truncate deletes all rows from the table.
|
|
func (col *Collection) Truncate() error {
|
|
err := col.collection.DropCollection()
|
|
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (col *Collection) Session() mydb.Session {
|
|
return col.parent
|
|
}
|
|
|
|
func (col *Collection) Count() (uint64, error) {
|
|
return col.Find().Count()
|
|
}
|
|
|
|
func (col *Collection) InsertReturning(item interface{}) error {
|
|
return mydb.ErrUnsupported
|
|
}
|
|
|
|
func (col *Collection) UpdateReturning(item interface{}) error {
|
|
return mydb.ErrUnsupported
|
|
}
|
|
|
|
// Insert inserts a record (map or struct) into the collection.
|
|
func (col *Collection) Insert(item interface{}) (mydb.InsertResult, error) {
|
|
var err error
|
|
|
|
id := getID(item)
|
|
|
|
if col.parent.versionAtLeast(2, 6, 0, 0) {
|
|
// this breaks MongoDb older than 2.6
|
|
if _, err = col.collection.Upsert(bson.M{"_id": id}, item); err != nil {
|
|
return nil, err
|
|
}
|
|
} else {
|
|
// Allocating a new ID.
|
|
if err = col.collection.Insert(bson.M{"_id": id}); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Now append data the user wants to append.
|
|
if err = col.collection.Update(bson.M{"_id": id}, item); err != nil {
|
|
// Cleanup allocated ID
|
|
if err := col.collection.Remove(bson.M{"_id": id}); err != nil {
|
|
return nil, err
|
|
}
|
|
return nil, err
|
|
}
|
|
}
|
|
|
|
return mydb.NewInsertResult(id), nil
|
|
}
|
|
|
|
// Exists returns true if the collection exists.
|
|
func (col *Collection) Exists() (bool, error) {
|
|
query := col.parent.database.C(`system.namespaces`).Find(map[string]string{`name`: fmt.Sprintf(`%s.%s`, col.parent.database.Name, col.collection.Name)})
|
|
count, err := query.Count()
|
|
return count > 0, err
|
|
}
|
|
|
|
// Fetches object _id or generates a new one if object doesn't have one or the one it has is invalid
|
|
func getID(item interface{}) interface{} {
|
|
v := reflect.ValueOf(item) // convert interface to Value
|
|
v = reflect.Indirect(v) // convert pointers
|
|
|
|
switch v.Kind() {
|
|
case reflect.Map:
|
|
if inItem, ok := item.(map[string]interface{}); ok {
|
|
if id, ok := inItem["_id"]; ok {
|
|
bsonID, ok := id.(bson.ObjectId)
|
|
if ok {
|
|
return bsonID
|
|
}
|
|
}
|
|
}
|
|
case reflect.Struct:
|
|
t := v.Type()
|
|
|
|
idCacheMutex.RLock()
|
|
fieldName, found := idCache[t]
|
|
idCacheMutex.RUnlock()
|
|
|
|
if !found {
|
|
for n := 0; n < t.NumField(); n++ {
|
|
field := t.Field(n)
|
|
if field.PkgPath != "" {
|
|
continue // Private field
|
|
}
|
|
|
|
tag := field.Tag.Get("bson")
|
|
if tag == "" {
|
|
tag = field.Tag.Get("db")
|
|
}
|
|
|
|
if tag == "" {
|
|
continue
|
|
}
|
|
|
|
parts := strings.Split(tag, ",")
|
|
|
|
if parts[0] == "_id" {
|
|
fieldName = field.Name
|
|
idCacheMutex.RLock()
|
|
idCache[t] = fieldName
|
|
idCacheMutex.RUnlock()
|
|
break
|
|
}
|
|
}
|
|
}
|
|
if fieldName != "" {
|
|
if bsonID, ok := v.FieldByName(fieldName).Interface().(bson.ObjectId); ok {
|
|
if bsonID.Valid() {
|
|
return bsonID
|
|
}
|
|
} else {
|
|
return v.FieldByName(fieldName).Interface()
|
|
}
|
|
}
|
|
}
|
|
|
|
return bson.NewObjectId()
|
|
}
|