commit 4ac92c26b052d2a21a05e9d83096a9ce1d3afbaf Author: tiglog Date: Mon Sep 18 15:15:42 2023 +0800 first commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..2946070 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +*.sw? +*.db +*.tmp +generated_*.go diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..c5a605a --- /dev/null +++ b/Makefile @@ -0,0 +1,39 @@ +SHELL ?= /bin/bash + +PARALLEL_FLAGS ?= --halt-on-error 2 --jobs=2 -v -u + +TEST_FLAGS ?= + +UPPER_DB_LOG ?= WARN + +export TEST_FLAGS +export PARALLEL_FLAGS +export UPPER_DB_LOG + +test: go-test-internal test-adapters + +benchmark: go-benchmark-internal + +go-benchmark-%: + go test -v -benchtime=500ms -bench=. ./$*/... + +go-test-%: + go test -v ./$*/... + +test-adapters: \ + test-adapter-postgresql \ + # test-adapter-mysql \ + # test-adapter-sqlite \ + # test-adapter-mongo + +test-adapter-%: + ($(MAKE) -C adapter/$* test-extended || exit 1) + +test-generic: + export TEST_FLAGS="-run TestGeneric"; \ + $(MAKE) test-adapters + +goimports: + for FILE in $$(find -name "*.go" | grep -v vendor); do \ + goimports -w $$FILE; \ + done diff --git a/adapter.go b/adapter.go new file mode 100644 index 0000000..1ef65b2 --- /dev/null +++ b/adapter.go @@ -0,0 +1,54 @@ +package mydb + +import ( + "fmt" + "sync" +) + +var ( + adapterMap = make(map[string]Adapter) + adapterMapMu sync.RWMutex +) + +// Adapter interface defines an adapter +type Adapter interface { + Open(ConnectionURL) (Session, error) +} + +type missingAdapter struct { + name string +} + +func (ma *missingAdapter) Open(ConnectionURL) (Session, error) { + return nil, fmt.Errorf("mydb: Missing adapter %q, did you forget to import it?", ma.name) +} + +// RegisterAdapter registers a generic database adapter. +func RegisterAdapter(name string, adapter Adapter) { + adapterMapMu.Lock() + defer adapterMapMu.Unlock() + + if name == "" { + panic(`Missing adapter name`) + } + if _, ok := adapterMap[name]; ok { + panic(`db.RegisterAdapter() called twice for adapter: ` + name) + } + adapterMap[name] = adapter +} + +// LookupAdapter returns a previously registered adapter by name. +func LookupAdapter(name string) Adapter { + adapterMapMu.RLock() + defer adapterMapMu.RUnlock() + + if adapter, ok := adapterMap[name]; ok { + return adapter + } + return &missingAdapter{name: name} +} + +// Open attempts to stablish a connection with a database. +func Open(adapterName string, settings ConnectionURL) (Session, error) { + return LookupAdapter(adapterName).Open(settings) +} diff --git a/adapter/mongo/Makefile b/adapter/mongo/Makefile new file mode 100644 index 0000000..7660d30 --- /dev/null +++ b/adapter/mongo/Makefile @@ -0,0 +1,43 @@ +SHELL ?= bash + +MONGO_VERSION ?= 4 +MONGO_SUPPORTED ?= $(MONGO_VERSION) 3 +PROJECT ?= upper_mongo_$(MONGO_VERSION) + +DB_HOST ?= 127.0.0.1 +DB_PORT ?= 27017 + +DB_NAME ?= admin +DB_USERNAME ?= upperio_user +DB_PASSWORD ?= upperio//s3cr37 + +TEST_FLAGS ?= +PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1 + +export MONGO_VERSION + +export DB_HOST +export DB_NAME +export DB_PASSWORD +export DB_PORT +export DB_USERNAME + +export TEST_FLAGS + +test: + go test -v -failfast -race -timeout 20m $(TEST_FLAGS) + +test-no-race: + go test -v -failfast $(TEST_FLAGS) + +server-up: server-down + docker-compose -p $(PROJECT) up -d && \ + sleep 10 + +server-down: + docker-compose -p $(PROJECT) down + +test-extended: + parallel $(PARALLEL_FLAGS) \ + "MONGO_VERSION={} DB_PORT=\$$((27017+{#})) $(MAKE) server-up test server-down" ::: \ + $(MONGO_SUPPORTED) diff --git a/adapter/mongo/README.md b/adapter/mongo/README.md new file mode 100644 index 0000000..cb56a27 --- /dev/null +++ b/adapter/mongo/README.md @@ -0,0 +1,4 @@ +# MongoDB adapter for upper/db + +Please read the full docs, acknowledgements and examples at +[https://upper.io/v4/adapter/mongo/](https://upper.io/v4/adapter/mongo/). diff --git a/adapter/mongo/collection.go b/adapter/mongo/collection.go new file mode 100644 index 0000000..0c27c06 --- /dev/null +++ b/adapter/mongo/collection.go @@ -0,0 +1,346 @@ +package mongo + +import ( + "fmt" + "strings" + "sync" + + "reflect" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/adapter" + mgo "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +// Collection represents a mongodb collection. +type Collection struct { + parent *Source + collection *mgo.Collection +} + +var ( + // idCache should be a struct if we're going to cache more than just + // _id field here + idCache = make(map[reflect.Type]string) + idCacheMutex sync.RWMutex +) + +// Find creates a result set with the given conditions. +func (col *Collection) Find(terms ...interface{}) mydb.Result { + fields := []string{"*"} + + conditions := col.compileQuery(terms...) + + res := &result{} + res = res.frame(func(r *resultQuery) error { + r.c = col + r.conditions = conditions + r.fields = fields + return nil + }) + + return res +} + +var comparisonOperators = map[adapter.ComparisonOperator]string{ + adapter.ComparisonOperatorEqual: "$eq", + adapter.ComparisonOperatorNotEqual: "$ne", + + adapter.ComparisonOperatorLessThan: "$lt", + adapter.ComparisonOperatorGreaterThan: "$gt", + + adapter.ComparisonOperatorLessThanOrEqualTo: "$lte", + adapter.ComparisonOperatorGreaterThanOrEqualTo: "$gte", + + adapter.ComparisonOperatorIn: "$in", + adapter.ComparisonOperatorNotIn: "$nin", +} + +func compare(field string, cmp *adapter.Comparison) (string, interface{}) { + op := cmp.Operator() + value := cmp.Value() + + switch op { + case adapter.ComparisonOperatorEqual: + return field, value + case adapter.ComparisonOperatorBetween: + values := value.([]interface{}) + return field, bson.M{ + "$gte": values[0], + "$lte": values[1], + } + case adapter.ComparisonOperatorNotBetween: + values := value.([]interface{}) + return "$or", []bson.M{ + {field: bson.M{"$gt": values[1]}}, + {field: bson.M{"$lt": values[0]}}, + } + case adapter.ComparisonOperatorIs: + if value == nil { + return field, bson.M{"$exists": false} + } + return field, bson.M{"$eq": value} + case adapter.ComparisonOperatorIsNot: + if value == nil { + return field, bson.M{"$exists": true} + } + return field, bson.M{"$ne": value} + case adapter.ComparisonOperatorRegExp, adapter.ComparisonOperatorLike: + return field, bson.RegEx{Pattern: value.(string), Options: ""} + case adapter.ComparisonOperatorNotRegExp, adapter.ComparisonOperatorNotLike: + return field, bson.M{"$not": bson.RegEx{Pattern: value.(string), Options: ""}} + } + + if cmpOp, ok := comparisonOperators[op]; ok { + return field, bson.M{ + cmpOp: value, + } + } + + panic(fmt.Sprintf("Unsupported operator %v", op)) +} + +// compileStatement transforms conditions into something *mgo.Session can +// understand. +func compileStatement(cond mydb.Cond) bson.M { + conds := bson.M{} + + // Walking over conditions + for fieldI, value := range cond { + field := strings.TrimSpace(fmt.Sprintf("%v", fieldI)) + + if cmp, ok := value.(*mydb.Comparison); ok { + k, v := compare(field, cmp.Comparison) + conds[k] = v + continue + } + + var op string + chunks := strings.SplitN(field, ` `, 2) + + if len(chunks) > 1 { + switch chunks[1] { + case `IN`: + op = `$in` + case `NOT IN`: + op = `$nin` + case `>`: + op = `$gt` + case `<`: + op = `$lt` + case `<=`: + op = `$lte` + case `>=`: + op = `$gte` + default: + op = chunks[1] + } + } + field = chunks[0] + + if op == "" { + conds[field] = value + } else { + conds[field] = bson.M{op: value} + } + } + + return conds +} + +// compileConditions compiles terms into something *mgo.Session can +// understand. +func (col *Collection) compileConditions(term interface{}) interface{} { + + switch t := term.(type) { + case []interface{}: + values := []interface{}{} + for i := range t { + value := col.compileConditions(t[i]) + if value != nil { + values = append(values, value) + } + } + if len(values) > 0 { + return values + } + case mydb.Cond: + return compileStatement(t) + case adapter.LogicalExpr: + values := []interface{}{} + + for _, s := range t.Expressions() { + values = append(values, col.compileConditions(s)) + } + + var op string + switch t.Operator() { + case adapter.LogicalOperatorOr: + op = `$or` + default: + op = `$and` + } + + return bson.M{op: values} + } + return nil +} + +// compileQuery compiles terms into something that *mgo.Session can +// understand. +func (col *Collection) compileQuery(terms ...interface{}) interface{} { + compiled := col.compileConditions(terms) + if compiled == nil { + return nil + } + + conditions := compiled.([]interface{}) + if len(conditions) == 1 { + return conditions[0] + } + // this should be correct. + // query = map[string]interface{}{"$and": conditions} + + // attempt to workaround https://jira.mongomydb.org/browse/SERVER-4572 + mapped := map[string]interface{}{} + for _, v := range conditions { + for kk := range v.(map[string]interface{}) { + mapped[kk] = v.(map[string]interface{})[kk] + } + } + + return mapped +} + +// Name returns the name of the table or tables that form the collection. +func (col *Collection) Name() string { + return col.collection.Name +} + +// Truncate deletes all rows from the table. +func (col *Collection) Truncate() error { + err := col.collection.DropCollection() + + if err != nil { + return err + } + + return nil +} + +func (col *Collection) Session() mydb.Session { + return col.parent +} + +func (col *Collection) Count() (uint64, error) { + return col.Find().Count() +} + +func (col *Collection) InsertReturning(item interface{}) error { + return mydb.ErrUnsupported +} + +func (col *Collection) UpdateReturning(item interface{}) error { + return mydb.ErrUnsupported +} + +// Insert inserts a record (map or struct) into the collection. +func (col *Collection) Insert(item interface{}) (mydb.InsertResult, error) { + var err error + + id := getID(item) + + if col.parent.versionAtLeast(2, 6, 0, 0) { + // this breaks MongoDb older than 2.6 + if _, err = col.collection.Upsert(bson.M{"_id": id}, item); err != nil { + return nil, err + } + } else { + // Allocating a new ID. + if err = col.collection.Insert(bson.M{"_id": id}); err != nil { + return nil, err + } + + // Now append data the user wants to append. + if err = col.collection.Update(bson.M{"_id": id}, item); err != nil { + // Cleanup allocated ID + if err := col.collection.Remove(bson.M{"_id": id}); err != nil { + return nil, err + } + return nil, err + } + } + + return mydb.NewInsertResult(id), nil +} + +// Exists returns true if the collection exists. +func (col *Collection) Exists() (bool, error) { + query := col.parent.database.C(`system.namespaces`).Find(map[string]string{`name`: fmt.Sprintf(`%s.%s`, col.parent.database.Name, col.collection.Name)}) + count, err := query.Count() + return count > 0, err +} + +// Fetches object _id or generates a new one if object doesn't have one or the one it has is invalid +func getID(item interface{}) interface{} { + v := reflect.ValueOf(item) // convert interface to Value + v = reflect.Indirect(v) // convert pointers + + switch v.Kind() { + case reflect.Map: + if inItem, ok := item.(map[string]interface{}); ok { + if id, ok := inItem["_id"]; ok { + bsonID, ok := id.(bson.ObjectId) + if ok { + return bsonID + } + } + } + case reflect.Struct: + t := v.Type() + + idCacheMutex.RLock() + fieldName, found := idCache[t] + idCacheMutex.RUnlock() + + if !found { + for n := 0; n < t.NumField(); n++ { + field := t.Field(n) + if field.PkgPath != "" { + continue // Private field + } + + tag := field.Tag.Get("bson") + if tag == "" { + tag = field.Tag.Get("db") + } + + if tag == "" { + continue + } + + parts := strings.Split(tag, ",") + + if parts[0] == "_id" { + fieldName = field.Name + idCacheMutex.RLock() + idCache[t] = fieldName + idCacheMutex.RUnlock() + break + } + } + } + if fieldName != "" { + if bsonID, ok := v.FieldByName(fieldName).Interface().(bson.ObjectId); ok { + if bsonID.Valid() { + return bsonID + } + } else { + return v.FieldByName(fieldName).Interface() + } + } + } + + return bson.NewObjectId() +} diff --git a/adapter/mongo/connection.go b/adapter/mongo/connection.go new file mode 100644 index 0000000..48377c4 --- /dev/null +++ b/adapter/mongo/connection.go @@ -0,0 +1,98 @@ +package mongo + +import ( + "fmt" + "net/url" + "strings" +) + +const connectionScheme = `mongodb` + +// ConnectionURL implements a MongoDB connection struct. +type ConnectionURL struct { + User string + Password string + Host string + Database string + Options map[string]string +} + +func (c ConnectionURL) String() (s string) { + + if c.Database == "" { + return "" + } + + vv := url.Values{} + + // Do we have any options? + if c.Options == nil { + c.Options = map[string]string{} + } + + // Converting options into URL values. + for k, v := range c.Options { + vv.Set(k, v) + } + + // Has user? + var userInfo *url.Userinfo + + if c.User != "" { + if c.Password == "" { + userInfo = url.User(c.User) + } else { + userInfo = url.UserPassword(c.User, c.Password) + } + } + + // Building URL. + u := url.URL{ + Scheme: connectionScheme, + Path: c.Database, + Host: c.Host, + User: userInfo, + RawQuery: vv.Encode(), + } + + return u.String() +} + +// ParseURL parses s into a ConnectionURL struct. +func ParseURL(s string) (conn ConnectionURL, err error) { + var u *url.URL + + if !strings.HasPrefix(s, connectionScheme+"://") { + return conn, fmt.Errorf(`Expecting mongodb:// connection scheme.`) + } + + if u, err = url.Parse(s); err != nil { + return conn, err + } + + conn.Host = u.Host + + // Deleting / from start of the string. + conn.Database = strings.Trim(u.Path, "/") + + // Adding user / password. + if u.User != nil { + conn.User = u.User.Username() + conn.Password, _ = u.User.Password() + } + + // Adding options. + conn.Options = map[string]string{} + + var vv url.Values + + if vv, err = url.ParseQuery(u.RawQuery); err != nil { + return conn, err + } + + for k := range vv { + conn.Options[k] = vv.Get(k) + } + + return conn, err +} diff --git a/adapter/mongo/connection_test.go b/adapter/mongo/connection_test.go new file mode 100644 index 0000000..e60b323 --- /dev/null +++ b/adapter/mongo/connection_test.go @@ -0,0 +1,114 @@ +package mongo + +import ( + "testing" +) + +func TestConnectionURL(t *testing.T) { + + c := ConnectionURL{} + + // Default connection string is only the protocol. + if c.String() != "" { + t.Fatal(`Expecting default connectiong string to be empty, got:`, c.String()) + } + + // Adding a database name. + c.Database = "myfilename" + + if c.String() != "mongodb://myfilename" { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Adding an option. + c.Options = map[string]string{ + "cache": "foobar", + "mode": "ro", + } + + // Adding username and password + c.User = "user" + c.Password = "pass" + + // Setting host. + c.Host = "localhost" + + if c.String() != "mongodb://user:pass@localhost/myfilename?cache=foobar&mode=ro" { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Setting host and port. + c.Host = "localhost:27017" + + if c.String() != "mongodb://user:pass@localhost:27017/myfilename?cache=foobar&mode=ro" { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Setting cluster. + c.Host = "localhost,1.2.3.4,example.org:1234" + + if c.String() != "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/myfilename?cache=foobar&mode=ro" { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Setting another database. + c.Database = "another_database" + + if c.String() != "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/another_database?cache=foobar&mode=ro" { + t.Fatal(`Test failed, got:`, c.String()) + } + +} + +func TestParseConnectionURL(t *testing.T) { + var u ConnectionURL + var s string + var err error + + s = "mongodb:///mydatabase" + + if u, err = ParseURL(s); err != nil { + t.Fatal(err) + } + + if u.Database != "mydatabase" { + t.Fatal("Failed to parse database.") + } + + s = "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/another_database?cache=foobar&mode=ro" + + if u, err = ParseURL(s); err != nil { + t.Fatal(err) + } + + if u.Database != "another_database" { + t.Fatal("Failed to get database.") + } + + if u.Options["cache"] != "foobar" { + t.Fatal("Expecting option.") + } + + if u.Options["mode"] != "ro" { + t.Fatal("Expecting option.") + } + + if u.User != "user" { + t.Fatal("Expecting user.") + } + + if u.Password != "pass" { + t.Fatal("Expecting password.") + } + + if u.Host != "localhost,1.2.3.4,example.org:1234" { + t.Fatal("Expecting host.") + } + + s = "http://example.org" + + if _, err = ParseURL(s); err == nil { + t.Fatal("Expecting error.") + } + +} diff --git a/adapter/mongo/database.go b/adapter/mongo/database.go new file mode 100644 index 0000000..e2cdb50 --- /dev/null +++ b/adapter/mongo/database.go @@ -0,0 +1,245 @@ +// Package mongo wraps the gopkg.in/mgo.v2 MongoDB driver. See +// https://github.com/upper/db/adapter/mongo for documentation, particularities and usage +// examples. +package mongo + +import ( + "context" + "database/sql" + "strings" + "sync" + "time" + + "git.hexq.cn/tiglog/mydb" + mgo "gopkg.in/mgo.v2" +) + +// Adapter holds the name of the mongodb adapter. +const Adapter = `mongo` + +var connTimeout = time.Second * 5 + +// Source represents a MongoDB database. +type Source struct { + mydb.Settings + + ctx context.Context + + name string + connURL mydb.ConnectionURL + session *mgo.Session + database *mgo.Database + version []int + collections map[string]*Collection + collectionsMu sync.Mutex +} + +type mongoAdapter struct { +} + +func (mongoAdapter) Open(dsn mydb.ConnectionURL) (mydb.Session, error) { + return Open(dsn) +} + +func init() { + mydb.RegisterAdapter(Adapter, mydb.Adapter(&mongoAdapter{})) +} + +// Open stablishes a new connection to a SQL server. +func Open(settings mydb.ConnectionURL) (mydb.Session, error) { + d := &Source{Settings: mydb.NewSettings(), ctx: context.Background()} + if err := d.Open(settings); err != nil { + return nil, err + } + return d, nil +} + +func (s *Source) TxContext(context.Context, func(tx mydb.Session) error, *sql.TxOptions) error { + return mydb.ErrNotSupportedByAdapter +} + +func (s *Source) Tx(func(mydb.Session) error) error { + return mydb.ErrNotSupportedByAdapter +} + +func (s *Source) SQL() mydb.SQL { + // Not supported + panic("sql builder is not supported by mongodb") +} + +func (s *Source) ConnectionURL() mydb.ConnectionURL { + return s.connURL +} + +// SetConnMaxLifetime is not supported. +func (s *Source) SetConnMaxLifetime(time.Duration) { + s.Settings.SetConnMaxLifetime(time.Duration(0)) +} + +// SetMaxIdleConns is not supported. +func (s *Source) SetMaxIdleConns(int) { + s.Settings.SetMaxIdleConns(0) +} + +// SetMaxOpenConns is not supported. +func (s *Source) SetMaxOpenConns(int) { + s.Settings.SetMaxOpenConns(0) +} + +// Name returns the name of the database. +func (s *Source) Name() string { + return s.name +} + +// Open attempts to connect to the database. +func (s *Source) Open(connURL mydb.ConnectionURL) error { + s.connURL = connURL + return s.open() +} + +// Clone returns a cloned mydb.Session session. +func (s *Source) Clone() (mydb.Session, error) { + newSession := s.session.Copy() + clone := &Source{ + Settings: mydb.NewSettings(), + + name: s.name, + connURL: s.connURL, + session: newSession, + database: newSession.DB(s.database.Name), + version: s.version, + collections: map[string]*Collection{}, + } + return clone, nil +} + +// Ping checks whether a connection to the database is still alive by pinging +// it, establishing a connection if necessary. +func (s *Source) Ping() error { + return s.session.Ping() +} + +func (s *Source) Reset() { + s.collectionsMu.Lock() + defer s.collectionsMu.Unlock() + s.collections = make(map[string]*Collection) +} + +// Driver returns the underlying *mgo.Session instance. +func (s *Source) Driver() interface{} { + return s.session +} + +func (s *Source) open() error { + var err error + + if s.session, err = mgo.DialWithTimeout(s.connURL.String(), connTimeout); err != nil { + return err + } + + s.collections = map[string]*Collection{} + s.database = s.session.DB("") + + return nil +} + +// Close terminates the current database session. +func (s *Source) Close() error { + if s.session != nil { + s.session.Close() + } + return nil +} + +// Collections returns a list of non-system tables from the database. +func (s *Source) Collections() (cols []mydb.Collection, err error) { + var rawcols []string + var col string + + if rawcols, err = s.database.CollectionNames(); err != nil { + return nil, err + } + + cols = make([]mydb.Collection, 0, len(rawcols)) + + for _, col = range rawcols { + if !strings.HasPrefix(col, "system.") { + cols = append(cols, s.Collection(col)) + } + } + + return cols, nil +} + +func (s *Source) Delete(mydb.Record) error { + return mydb.ErrNotImplemented +} + +func (s *Source) Get(mydb.Record, interface{}) error { + return mydb.ErrNotImplemented +} + +func (s *Source) Save(mydb.Record) error { + return mydb.ErrNotImplemented +} + +func (s *Source) Context() context.Context { + return s.ctx +} + +func (s *Source) WithContext(ctx context.Context) mydb.Session { + return &Source{ + ctx: ctx, + Settings: s.Settings, + name: s.name, + connURL: s.connURL, + session: s.session, + database: s.database, + version: s.version, + } +} + +// Collection returns a collection by name. +func (s *Source) Collection(name string) mydb.Collection { + s.collectionsMu.Lock() + defer s.collectionsMu.Unlock() + + var col *Collection + var ok bool + + if col, ok = s.collections[name]; !ok { + col = &Collection{ + parent: s, + collection: s.database.C(name), + } + s.collections[name] = col + } + + return col +} + +func (s *Source) versionAtLeast(version ...int) bool { + // only fetch this once - it makes a db call + if len(s.version) == 0 { + buildInfo, err := s.database.Session.BuildInfo() + if err != nil { + return false + } + s.version = buildInfo.VersionArray + } + + // Check major version first + if s.version[0] > version[0] { + return true + } + + for i := range version { + if i == len(s.version) { + return false + } + if s.version[i] < version[i] { + return false + } + } + return true +} diff --git a/adapter/mongo/docker-compose.yml b/adapter/mongo/docker-compose.yml new file mode 100644 index 0000000..af27525 --- /dev/null +++ b/adapter/mongo/docker-compose.yml @@ -0,0 +1,13 @@ +version: '3' + +services: + + server: + image: mongo:${MONGO_VERSION:-3} + environment: + MONGO_INITDB_ROOT_USERNAME: ${DB_USERNAME:-upperio_user} + MONGO_INITDB_ROOT_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37} + MONGO_INITDB_DATABASE: ${DB_NAME:-upperio} + ports: + - '${BIND_HOST:-127.0.0.1}:${DB_PORT:-27017}:27017' + diff --git a/adapter/mongo/generic_test.go b/adapter/mongo/generic_test.go new file mode 100644 index 0000000..13840af --- /dev/null +++ b/adapter/mongo/generic_test.go @@ -0,0 +1,20 @@ +package mongo + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type GenericTests struct { + testsuite.GenericTestSuite +} + +func (s *GenericTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestGeneric(t *testing.T) { + suite.Run(t, &GenericTests{}) +} diff --git a/adapter/mongo/helper_test.go b/adapter/mongo/helper_test.go new file mode 100644 index 0000000..6db7323 --- /dev/null +++ b/adapter/mongo/helper_test.go @@ -0,0 +1,77 @@ +package mongo + +import ( + "fmt" + "os" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/testsuite" + mgo "gopkg.in/mgo.v2" +) + +var settings = ConnectionURL{ + Database: os.Getenv("DB_NAME"), + User: os.Getenv("DB_USERNAME"), + Password: os.Getenv("DB_PASSWORD"), + Host: os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"), +} + +type Helper struct { + sess mydb.Session +} + +func (h *Helper) Session() mydb.Session { + return h.sess +} + +func (h *Helper) Adapter() string { + return "mongo" +} + +func (h *Helper) TearDown() error { + return h.sess.Close() +} + +func (h *Helper) TearUp() error { + var err error + + h.sess, err = Open(settings) + if err != nil { + return err + } + + mgod, ok := h.sess.Driver().(*mgo.Session) + if !ok { + panic("expecting mgo.Session") + } + + var col *mgo.Collection + col = mgod.DB(settings.Database).C("birthdays") + _ = col.DropCollection() + + col = mgod.DB(settings.Database).C("fibonacci") + _ = col.DropCollection() + + col = mgod.DB(settings.Database).C("is_even") + _ = col.DropCollection() + + col = mgod.DB(settings.Database).C("CaSe_TesT") + _ = col.DropCollection() + + // Getting a pointer to the "artist" collection. + artist := h.sess.Collection("artist") + + _ = artist.Truncate() + for i := 0; i < 999; i++ { + _, err = artist.Insert(artistType{ + Name: fmt.Sprintf("artist-%d", i), + }) + if err != nil { + return err + } + } + + return nil +} + +var _ testsuite.Helper = &Helper{} diff --git a/adapter/mongo/mongo_test.go b/adapter/mongo/mongo_test.go new file mode 100644 index 0000000..8e0f42e --- /dev/null +++ b/adapter/mongo/mongo_test.go @@ -0,0 +1,754 @@ +// Tests for the mongodb adapter. +package mongo + +import ( + "fmt" + "log" + "math/rand" + "strings" + "testing" + "time" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" + "gopkg.in/mgo.v2/bson" +) + +type artistType struct { + ID bson.ObjectId `bson:"_id,omitempty"` + Name string `bson:"name"` +} + +// Structure for testing conversions and datatypes. +type testValuesStruct struct { + Uint uint `bson:"_uint"` + Uint8 uint8 `bson:"_uint8"` + Uint16 uint16 `bson:"_uint16"` + Uint32 uint32 `bson:"_uint32"` + Uint64 uint64 `bson:"_uint64"` + + Int int `bson:"_int"` + Int8 int8 `bson:"_int8"` + Int16 int16 `bson:"_int16"` + Int32 int32 `bson:"_int32"` + Int64 int64 `bson:"_int64"` + + Float32 float32 `bson:"_float32"` + Float64 float64 `bson:"_float64"` + + Bool bool `bson:"_bool"` + String string `bson:"_string"` + + Date time.Time `bson:"_date"` + DateN *time.Time `bson:"_nildate"` + DateP *time.Time `bson:"_ptrdate"` + Time time.Duration `bson:"_time"` +} + +var testValues testValuesStruct + +func init() { + t := time.Date(2012, 7, 28, 1, 2, 3, 0, time.Local) + + testValues = testValuesStruct{ + 1, 1, 1, 1, 1, + -1, -1, -1, -1, -1, + 1.337, 1.337, + true, + "Hello world!", + t, + nil, + &t, + time.Second * time.Duration(7331), + } +} + +type AdapterTests struct { + testsuite.Suite +} + +func (s *AdapterTests) SetupSuite() { + s.Helper = &Helper{} +} + +func (s *AdapterTests) TestOpenWithWrongData() { + var err error + var rightSettings, wrongSettings ConnectionURL + + // Attempt to open with safe settings. + rightSettings = ConnectionURL{ + Database: settings.Database, + Host: settings.Host, + User: settings.User, + Password: settings.Password, + } + + // Attempt to open an empty database. + _, err = Open(rightSettings) + s.NoError(err) + + // Attempt to open with wrong password. + wrongSettings = ConnectionURL{ + Database: settings.Database, + Host: settings.Host, + User: settings.User, + Password: "fail", + } + + _, err = Open(wrongSettings) + s.Error(err) + + // Attempt to open with wrong database. + wrongSettings = ConnectionURL{ + Database: "fail", + Host: settings.Host, + User: settings.User, + Password: settings.Password, + } + + _, err = Open(wrongSettings) + s.Error(err) + + // Attempt to open with wrong username. + wrongSettings = ConnectionURL{ + Database: settings.Database, + Host: settings.Host, + User: "fail", + Password: settings.Password, + } + + _, err = Open(wrongSettings) + s.Error(err) +} + +func (s *AdapterTests) TestTruncate() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a list of all collections in this database. + collections, err := sess.Collections() + s.NoError(err) + + for _, col := range collections { + // The collection may ot may not exists. + if ok, _ := col.Exists(); ok { + // Truncating the structure, if exists. + err = col.Truncate() + s.NoError(err) + } + } +} + +func (s *AdapterTests) TestInsert() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist := sess.Collection("artist") + _ = artist.Truncate() + + // Inserting a map. + record, err := artist.Insert(map[string]string{ + "name": "Ozzie", + }) + s.NoError(err) + + id := record.ID() + s.NotZero(record.ID()) + + _, ok := id.(bson.ObjectId) + s.True(ok) + + s.True(id.(bson.ObjectId).Valid()) + + // Inserting a struct. + record, err = artist.Insert(struct { + Name string + }{ + "Flea", + }) + s.NoError(err) + + id = record.ID() + s.NotZero(id) + + _, ok = id.(bson.ObjectId) + s.True(ok) + s.True(id.(bson.ObjectId).Valid()) + + // Inserting a struct (using tags to specify the field name). + record, err = artist.Insert(struct { + ArtistName string `bson:"name"` + }{ + "Slash", + }) + s.NoError(err) + + id = record.ID() + s.NotNil(id) + + _, ok = id.(bson.ObjectId) + + s.True(ok) + s.True(id.(bson.ObjectId).Valid()) + + // Inserting a pointer to a struct + record, err = artist.Insert(&struct { + ArtistName string `bson:"name"` + }{ + "Metallica", + }) + s.NoError(err) + + id = record.ID() + s.NotNil(id) + + _, ok = id.(bson.ObjectId) + s.True(ok) + s.True(id.(bson.ObjectId).Valid()) + + // Inserting a pointer to a map + record, err = artist.Insert(&map[string]string{ + "name": "Freddie", + }) + s.NoError(err) + s.NotZero(id) + + _, ok = id.(bson.ObjectId) + s.True(ok) + + id = record.ID() + s.NotNil(id) + + s.True(id.(bson.ObjectId).Valid()) + + // Counting elements, must be exactly 6 elements. + total, err := artist.Find().Count() + s.NoError(err) + s.Equal(uint64(5), total) +} + +func (s *AdapterTests) TestGetNonExistentRow_Issue426() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + defer sess.Close() + + artist := sess.Collection("artist") + + var one artistType + err = artist.Find(mydb.Cond{"name": "nothing"}).One(&one) + + s.NotZero(err) + s.Equal(mydb.ErrNoMoreRows, err) + + var all []artistType + err = artist.Find(mydb.Cond{"name": "nothing"}).All(&all) + + s.Zero(err, "All should not return mgo.ErrNotFound") + s.Equal(0, len(all)) +} + +func (s *AdapterTests) TestResultCount() { + var err error + var res mydb.Result + + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + defer sess.Close() + + // We should close the database when it's no longer in use. + artist := sess.Collection("artist") + + res = artist.Find() + + // Counting all the matching rows. + total, err := res.Count() + s.NoError(err) + s.NotZero(total) +} + +func (s *AdapterTests) TestGroup() { + var stats mydb.Collection + + sess, err := Open(settings) + s.NoError(err) + + type statsT struct { + Numeric int `db:"numeric" bson:"numeric"` + Value int `db:"value" bson:"value"` + } + + defer sess.Close() + + stats = sess.Collection("statsTest") + + // Truncating table. + _ = stats.Truncate() + + // Adding row append. + for i := 0; i < 1000; i++ { + numeric, value := rand.Intn(10), rand.Intn(100) + _, err = stats.Insert(statsT{numeric, value}) + s.NoError(err) + } + + // mydb.statsTest.group({key: {numeric: true}, initial: {sum: 0}, reduce: function(doc, prev) { prev.sum += 1}}); + + // Testing GROUP BY + res := stats.Find().GroupBy(bson.M{ + "key": bson.M{"numeric": true}, + "initial": bson.M{"sum": 0}, + "reduce": `function(doc, prev) { prev.sum += 1}`, + }) + + var results []map[string]interface{} + + err = res.All(&results) + s.Equal(mydb.ErrUnsupported, err) +} + +func (s *AdapterTests) TestResultNonExistentCount() { + sess, err := Open(settings) + s.NoError(err) + + defer sess.Close() + + total, err := sess.Collection("notartist").Find().Count() + s.NoError(err) + s.Zero(total) +} + +func (s *AdapterTests) TestResultFetch() { + + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + artist := sess.Collection("artist") + + // Testing map + res := artist.Find() + + rowM := map[string]interface{}{} + + for res.Next(&rowM) { + s.NotZero(rowM["_id"]) + + _, ok := rowM["_id"].(bson.ObjectId) + s.True(ok) + + s.True(rowM["_id"].(bson.ObjectId).Valid()) + + name, ok := rowM["name"].(string) + s.True(ok) + s.NotZero(name) + } + + err = res.Close() + s.NoError(err) + + // Testing struct + rowS := struct { + ID bson.ObjectId `bson:"_id"` + Name string `bson:"name"` + }{} + + res = artist.Find() + + for res.Next(&rowS) { + s.True(rowS.ID.Valid()) + s.NotZero(rowS.Name) + } + + err = res.Close() + s.NoError(err) + + // Testing tagged struct + rowT := struct { + Value1 bson.ObjectId `bson:"_id"` + Value2 string `bson:"name"` + }{} + + res = artist.Find() + + for res.Next(&rowT) { + s.True(rowT.Value1.Valid()) + s.NotZero(rowT.Value2) + } + + err = res.Close() + s.NoError(err) + + // Testing Result.All() with a slice of maps. + res = artist.Find() + + allRowsM := []map[string]interface{}{} + err = res.All(&allRowsM) + s.NoError(err) + + for _, singleRowM := range allRowsM { + s.NotZero(singleRowM["_id"]) + } + + // Testing Result.All() with a slice of structs. + res = artist.Find() + + allRowsS := []struct { + ID bson.ObjectId `bson:"_id"` + Name string + }{} + err = res.All(&allRowsS) + s.NoError(err) + + for _, singleRowS := range allRowsS { + s.True(singleRowS.ID.Valid()) + } + + // Testing Result.All() with a slice of tagged structs. + res = artist.Find() + + allRowsT := []struct { + Value1 bson.ObjectId `bson:"_id"` + Value2 string `bson:"name"` + }{} + err = res.All(&allRowsT) + s.NoError(err) + + for _, singleRowT := range allRowsT { + s.True(singleRowT.Value1.Valid()) + } +} + +func (s *AdapterTests) TestUpdate() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist := sess.Collection("artist") + + // Value + value := struct { + ID bson.ObjectId `bson:"_id"` + Name string + }{} + + // Getting the first artist. + res := artist.Find(mydb.Cond{"_id": mydb.NotEq(nil)}).Limit(1) + + err = res.One(&value) + s.NoError(err) + + // Updating with a map + rowM := map[string]interface{}{ + "name": strings.ToUpper(value.Name), + } + + err = res.Update(rowM) + s.NoError(err) + + err = res.One(&value) + s.NoError(err) + + s.Equal(value.Name, rowM["name"]) + + // Updating with a struct + rowS := struct { + Name string + }{strings.ToLower(value.Name)} + + err = res.Update(rowS) + s.NoError(err) + + err = res.One(&value) + s.NoError(err) + + s.Equal(value.Name, rowS.Name) + + // Updating with a tagged struct + rowT := struct { + Value1 string `bson:"name"` + }{strings.Replace(value.Name, "z", "Z", -1)} + + err = res.Update(rowT) + s.NoError(err) + + err = res.One(&value) + s.NoError(err) + + s.Equal(value.Name, rowT.Value1) +} + +func (s *AdapterTests) TestOperators() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist := sess.Collection("artist") + + rowS := struct { + ID uint64 + Name string + }{} + + res := artist.Find(mydb.Cond{"_id": mydb.NotIn(0, -1)}) + + err = res.One(&rowS) + s.NoError(err) + + err = res.Close() + s.NoError(err) +} + +func (s *AdapterTests) TestDelete() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist := sess.Collection("artist") + + // Getting the first artist. + res := artist.Find(mydb.Cond{"_id": mydb.NotEq(nil)}).Limit(1) + + var first struct { + ID bson.ObjectId `bson:"_id"` + } + + err = res.One(&first) + s.NoError(err) + + res = artist.Find(mydb.Cond{"_id": mydb.Eq(first.ID)}) + + // Trying to remove the row. + err = res.Delete() + s.NoError(err) +} + +func (s *AdapterTests) TestDataTypes() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "data_types" collection. + dataTypes := sess.Collection("data_types") + + // Inserting our test subject. + record, err := dataTypes.Insert(testValues) + s.NoError(err) + + id := record.ID() + s.NotZero(id) + + // Trying to get the same subject we added. + res := dataTypes.Find(mydb.Cond{"_id": mydb.Eq(id)}) + + exists, err := res.Count() + s.NoError(err) + s.NotZero(exists) + + // Trying to dump the subject into an empty structure of the same type. + var item testValuesStruct + err = res.One(&item) + s.NoError(err) + + // The original value and the test subject must match. + s.Equal(testValues, item) +} + +func (s *AdapterTests) TestPaginator() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist := sess.Collection("artist") + + err = artist.Truncate() + s.NoError(err) + + for i := 0; i < 999; i++ { + _, err = artist.Insert(artistType{ + Name: fmt.Sprintf("artist-%d", i), + }) + s.NoError(err) + } + + q := sess.Collection("artist").Find().Paginate(15) + paginator := q.Paginate(13) + + var zerothPage []artistType + err = paginator.Page(0).All(&zerothPage) + s.NoError(err) + s.Equal(13, len(zerothPage)) + + var secondPage []artistType + err = paginator.Page(2).All(&secondPage) + s.NoError(err) + s.Equal(13, len(secondPage)) + + tp, err := paginator.TotalPages() + s.NoError(err) + s.NotZero(tp) + s.Equal(uint(77), tp) + + ti, err := paginator.TotalEntries() + s.NoError(err) + s.NotZero(ti) + s.Equal(uint64(999), ti) + + var seventySixthPage []artistType + err = paginator.Page(76).All(&seventySixthPage) + s.NoError(err) + s.Equal(11, len(seventySixthPage)) + + var seventySeventhPage []artistType + err = paginator.Page(77).All(&seventySeventhPage) + s.NoError(err) + s.Equal(0, len(seventySeventhPage)) + + var hundredthPage []artistType + err = paginator.Page(100).All(&hundredthPage) + s.NoError(err) + s.Equal(0, len(hundredthPage)) + + for i := uint(0); i < tp; i++ { + current := paginator.Page(i) + + var items []artistType + err := current.All(&items) + s.NoError(err) + if len(items) < 1 { + break + } + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", int64(13*int(i)+j)), items[j].Name) + } + } + + paginator = paginator.Cursor("_id") + { + current := paginator.Page(0) + for i := 0; ; i++ { + var items []artistType + err := current.All(&items) + s.NoError(err) + + if len(items) < 1 { + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", int64(13*int(i)+j)), items[j].Name) + } + current = current.NextPage(items[len(items)-1].ID) + } + } + + { + log.Printf("Page 76") + current := paginator.Page(76) + for i := 76; ; i-- { + var items []artistType + + err := current.All(&items) + s.NoError(err) + + if len(items) < 1 { + s.Equal(0, len(items)) + break + } + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", 13*int(i)+j), items[j].Name) + } + + current = current.PrevPage(items[0].ID) + } + } + + { + resultPaginator := sess.Collection("artist").Find().Paginate(15) + + count, err := resultPaginator.TotalPages() + s.Equal(uint(67), count) + s.NoError(err) + + var items []artistType + err = resultPaginator.Page(5).All(&items) + s.NoError(err) + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", 15*5+j), items[j].Name) + } + + resultPaginator = resultPaginator.Cursor("_id").Page(0) + for i := 0; ; i++ { + var items []artistType + + err = resultPaginator.All(&items) + s.NoError(err) + + if len(items) < 1 { + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", 15*i+j), items[j].Name) + } + resultPaginator = resultPaginator.NextPage(items[len(items)-1].ID) + } + + resultPaginator = resultPaginator.Cursor("_id").Page(66) + for i := 66; ; i-- { + var items []artistType + + err = resultPaginator.All(&items) + s.NoError(err) + + if len(items) < 1 { + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", 15*i+j), items[j].Name) + } + resultPaginator = resultPaginator.PrevPage(items[0].ID) + } + } +} + +func TestAdapter(t *testing.T) { + suite.Run(t, &AdapterTests{}) +} diff --git a/adapter/mongo/result.go b/adapter/mongo/result.go new file mode 100644 index 0000000..f19d9ef --- /dev/null +++ b/adapter/mongo/result.go @@ -0,0 +1,565 @@ +package mongo + +import ( + "errors" + "fmt" + "math" + "strings" + "sync" + "time" + + "encoding/json" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/immutable" + mgo "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +type resultQuery struct { + c *Collection + + fields []string + limit int + offset int + sort []string + conditions interface{} + groupBy []interface{} + + pageSize uint + pageNumber uint + cursorColumn string + cursorValue interface{} + cursorCond mydb.Cond + cursorReverseOrder bool +} + +type result struct { + iter *mgo.Iter + err error + errMu sync.Mutex + + fn func(*resultQuery) error + prev *result +} + +var _ = immutable.Immutable(&result{}) + +func (res *result) frame(fn func(*resultQuery) error) *result { + return &result{prev: res, fn: fn} +} + +func (r *resultQuery) and(terms ...interface{}) error { + if r.conditions == nil { + return r.where(terms...) + } + + r.conditions = map[string]interface{}{ + "$and": []interface{}{ + r.conditions, + r.c.compileQuery(terms...), + }, + } + return nil +} + +func (r *resultQuery) where(terms ...interface{}) error { + r.conditions = r.c.compileQuery(terms...) + return nil +} + +func (res *result) And(terms ...interface{}) mydb.Result { + return res.frame(func(r *resultQuery) error { + return r.and(terms...) + }) +} + +func (res *result) Where(terms ...interface{}) mydb.Result { + return res.frame(func(r *resultQuery) error { + return r.where(terms...) + }) +} + +func (res *result) Paginate(pageSize uint) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.pageSize = pageSize + return nil + }) +} + +func (res *result) Page(pageNumber uint) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.pageNumber = pageNumber + return nil + }) +} + +func (res *result) Cursor(cursorColumn string) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.cursorColumn = cursorColumn + return nil + }) +} + +func (res *result) NextPage(cursorValue interface{}) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.cursorValue = cursorValue + r.cursorReverseOrder = false + r.cursorCond = mydb.Cond{ + r.cursorColumn: bson.M{"$gt": cursorValue}, + } + return nil + }) +} + +func (res *result) PrevPage(cursorValue interface{}) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.cursorValue = cursorValue + r.cursorReverseOrder = true + r.cursorCond = mydb.Cond{ + r.cursorColumn: bson.M{"$lt": cursorValue}, + } + return nil + }) +} + +func (res *result) TotalEntries() (uint64, error) { + return res.Count() +} + +func (res *result) TotalPages() (uint, error) { + count, err := res.Count() + if err != nil { + return 0, err + } + + rq, err := res.build() + if err != nil { + return 0, err + } + + if rq.pageSize < 1 { + return 1, nil + } + + total := uint(math.Ceil(float64(count) / float64(rq.pageSize))) + return total, nil +} + +// Limit determines the maximum limit of results to be returned. +func (res *result) Limit(n int) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.limit = n + return nil + }) +} + +// Offset determines how many documents will be skipped before starting to grab +// results. +func (res *result) Offset(n int) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.offset = n + return nil + }) +} + +// OrderBy determines sorting of results according to the provided names. Fields +// may be prefixed by - (minus) which means descending order, ascending order +// would be used otherwise. +func (res *result) OrderBy(fields ...interface{}) mydb.Result { + return res.frame(func(r *resultQuery) error { + ss := make([]string, len(fields)) + for i, field := range fields { + ss[i] = fmt.Sprintf(`%v`, field) + } + r.sort = ss + return nil + }) +} + +// String satisfies fmt.Stringer +func (res *result) String() string { + return "" +} + +// Select marks the specific fields the user wants to retrieve. +func (res *result) Select(fields ...interface{}) mydb.Result { + return res.frame(func(r *resultQuery) error { + fieldslen := len(fields) + r.fields = make([]string, 0, fieldslen) + for i := 0; i < fieldslen; i++ { + r.fields = append(r.fields, fmt.Sprintf(`%v`, fields[i])) + } + return nil + }) +} + +// All dumps all results into a pointer to an slice of structs or maps. +func (res *result) All(dst interface{}) error { + rq, err := res.build() + if err != nil { + return err + } + + q, err := rq.query() + if err != nil { + return err + } + + defer func(start time.Time) { + queryLog(&mydb.QueryStatus{ + RawQuery: rq.debugQuery("Find.All"), + Err: err, + Start: start, + End: time.Now(), + }) + }(time.Now()) + + err = q.All(dst) + if errors.Is(err, mgo.ErrNotFound) { + return mydb.ErrNoMoreRows + } + return err +} + +// GroupBy is used to group results that have the same value in the same column +// or columns. +func (res *result) GroupBy(fields ...interface{}) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.groupBy = fields + return nil + }) +} + +// One fetches only one result from the resultset. +func (res *result) One(dst interface{}) error { + rq, err := res.build() + if err != nil { + return err + } + + q, err := rq.query() + if err != nil { + return err + } + + defer func(start time.Time) { + queryLog(&mydb.QueryStatus{ + RawQuery: rq.debugQuery("Find.One"), + Err: err, + Start: start, + End: time.Now(), + }) + }(time.Now()) + + err = q.One(dst) + if errors.Is(err, mgo.ErrNotFound) { + return mydb.ErrNoMoreRows + } + return err +} + +func (res *result) Err() error { + res.errMu.Lock() + defer res.errMu.Unlock() + + return res.err +} + +func (res *result) setErr(err error) { + res.errMu.Lock() + defer res.errMu.Unlock() + res.err = err +} + +func (res *result) Next(dst interface{}) bool { + if res.iter == nil { + rq, err := res.build() + if err != nil { + return false + } + + q, err := rq.query() + if err != nil { + return false + } + + defer func(start time.Time) { + queryLog(&mydb.QueryStatus{ + RawQuery: rq.debugQuery("Find.Next"), + Err: err, + Start: start, + End: time.Now(), + }) + }(time.Now()) + + res.iter = q.Iter() + } + + if !res.iter.Next(dst) { + res.setErr(res.iter.Err()) + return false + } + + return true +} + +// Delete remove the matching items from the collection. +func (res *result) Delete() error { + rq, err := res.build() + if err != nil { + return err + } + + defer func(start time.Time) { + queryLog(&mydb.QueryStatus{ + RawQuery: rq.debugQuery("Remove"), + Err: err, + Start: start, + End: time.Now(), + }) + }(time.Now()) + + _, err = rq.c.collection.RemoveAll(rq.conditions) + if err != nil { + return err + } + + return nil +} + +// Close closes the result set. +func (r *result) Close() error { + var err error + if r.iter != nil { + err = r.iter.Close() + r.iter = nil + } + return err +} + +// Update modified matching items from the collection with values of the given +// map or struct. +func (res *result) Update(src interface{}) (err error) { + updateSet := map[string]interface{}{"$set": src} + + rq, err := res.build() + if err != nil { + return err + } + + defer func(start time.Time) { + queryLog(&mydb.QueryStatus{ + RawQuery: rq.debugQuery("Update"), + Err: err, + Start: start, + End: time.Now(), + }) + }(time.Now()) + + _, err = rq.c.collection.UpdateAll(rq.conditions, updateSet) + if err != nil { + return err + } + return nil +} + +func (res *result) build() (*resultQuery, error) { + rqi, err := immutable.FastForward(res) + if err != nil { + return nil, err + } + + rq := rqi.(*resultQuery) + if !rq.cursorCond.Empty() { + if err := rq.and(rq.cursorCond); err != nil { + return nil, err + } + } + + if rq.cursorColumn != "" { + if rq.cursorReverseOrder { + rq.sort = append(rq.sort, "-"+rq.cursorColumn) + } else { + rq.sort = append(rq.sort, rq.cursorColumn) + } + } + return rq, nil +} + +// query executes a mgo query. +func (r *resultQuery) query() (*mgo.Query, error) { + if len(r.groupBy) > 0 { + return nil, mydb.ErrUnsupported + } + + q := r.c.collection.Find(r.conditions) + + if r.pageSize > 0 { + r.offset = int(r.pageSize * r.pageNumber) + r.limit = int(r.pageSize) + } + + if r.offset > 0 { + q.Skip(r.offset) + } + + if r.limit > 0 { + q.Limit(r.limit) + } + + if len(r.sort) > 0 { + q.Sort(r.sort...) + } + + selectedFields := bson.M{} + if len(r.fields) > 0 { + for _, field := range r.fields { + if field == `*` { + break + } + selectedFields[field] = true + } + } + + if r.cursorReverseOrder { + ids := make([]bson.ObjectId, 0, r.limit) + + iter := q.Select(bson.M{"_id": true}).Iter() + defer iter.Close() + + var item map[string]bson.ObjectId + for iter.Next(&item) { + ids = append(ids, item["_id"]) + } + + r.conditions = bson.M{"_id": bson.M{"$in": ids}} + + q = r.c.collection.Find(r.conditions) + } + + if len(selectedFields) > 0 { + q.Select(selectedFields) + } + + return q, nil +} + +func (res *result) Exists() (bool, error) { + total, err := res.Count() + if err != nil { + return false, err + } + if total > 0 { + return true, nil + } + return false, nil +} + +// Count counts matching elements. +func (res *result) Count() (total uint64, err error) { + rq, err := res.build() + if err != nil { + return 0, err + } + + defer func(start time.Time) { + queryLog(&mydb.QueryStatus{ + RawQuery: rq.debugQuery("Find.Count"), + Err: err, + Start: start, + End: time.Now(), + }) + }(time.Now()) + + q := rq.c.collection.Find(rq.conditions) + + var c int + c, err = q.Count() + + return uint64(c), err +} + +func (res *result) Prev() immutable.Immutable { + if res == nil { + return nil + } + return res.prev +} + +func (res *result) Fn(in interface{}) error { + if res.fn == nil { + return nil + } + return res.fn(in.(*resultQuery)) +} + +func (res *result) Base() interface{} { + return &resultQuery{} +} + +func (r *resultQuery) debugQuery(action string) string { + query := fmt.Sprintf("mydb.%s.%s", r.c.collection.Name, action) + + if r.conditions != nil { + query = fmt.Sprintf("%s.conds(%v)", query, r.conditions) + } + if r.limit > 0 { + query = fmt.Sprintf("%s.limit(%d)", query, r.limit) + } + if r.offset > 0 { + query = fmt.Sprintf("%s.offset(%d)", query, r.offset) + } + if len(r.fields) > 0 { + selectedFields := bson.M{} + for _, field := range r.fields { + if field == `*` { + break + } + selectedFields[field] = true + } + if len(selectedFields) > 0 { + query = fmt.Sprintf("%s.select(%v)", query, selectedFields) + } + } + if len(r.groupBy) > 0 { + escaped := make([]string, len(r.groupBy)) + for i := range r.groupBy { + escaped[i] = string(mustJSON(r.groupBy[i])) + } + query = fmt.Sprintf("%s.groupBy(%v)", query, strings.Join(escaped, ", ")) + } + if len(r.sort) > 0 { + escaped := make([]string, len(r.sort)) + for i := range r.sort { + escaped[i] = string(mustJSON(r.sort[i])) + } + query = fmt.Sprintf("%s.sort(%s)", query, strings.Join(escaped, ", ")) + } + return query +} + +func mustJSON(in interface{}) (out []byte) { + out, err := json.Marshal(in) + if err != nil { + panic(err) + } + return out +} + +func queryLog(status *mydb.QueryStatus) { + diff := status.End.Sub(status.Start) + + slowQuery := false + if diff >= time.Millisecond*100 { + status.Err = mydb.ErrWarnSlowQuery + slowQuery = true + } + + if status.Err != nil || slowQuery { + mydb.LC().Warn(status) + return + } + + mydb.LC().Debug(status) +} diff --git a/adapter/mysql/Makefile b/adapter/mysql/Makefile new file mode 100644 index 0000000..f3a8759 --- /dev/null +++ b/adapter/mysql/Makefile @@ -0,0 +1,43 @@ +SHELL ?= bash + +MYSQL_VERSION ?= 8.1 +MYSQL_SUPPORTED ?= $(MYSQL_VERSION) 5.7 +PROJECT ?= upper_mysql_$(MYSQL_VERSION) + +DB_HOST ?= 127.0.0.1 +DB_PORT ?= 3306 + +DB_NAME ?= upperio +DB_USERNAME ?= upperio_user +DB_PASSWORD ?= upperio//s3cr37 + +TEST_FLAGS ?= +PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1 + +export MYSQL_VERSION + +export DB_HOST +export DB_NAME +export DB_PASSWORD +export DB_PORT +export DB_USERNAME + +export TEST_FLAGS + +test: + go test -v -failfast -race -timeout 20m $(TEST_FLAGS) + +test-no-race: + go test -v -failfast $(TEST_FLAGS) + +server-up: server-down + docker-compose -p $(PROJECT) up -d && \ + sleep 15 + +server-down: + docker-compose -p $(PROJECT) down + +test-extended: + parallel $(PARALLEL_FLAGS) \ + "MYSQL_VERSION={} DB_PORT=\$$((3306+{#})) $(MAKE) server-up test server-down" ::: \ + $(MYSQL_SUPPORTED) diff --git a/adapter/mysql/README.md b/adapter/mysql/README.md new file mode 100644 index 0000000..f427fee --- /dev/null +++ b/adapter/mysql/README.md @@ -0,0 +1,5 @@ +# MySQL adapter for upper/db + +Please read the full docs, acknowledgements and examples at +[https://upper.io/v4/adapter/mysql/](https://upper.io/v4/adapter/mysql/). + diff --git a/adapter/mysql/collection.go b/adapter/mysql/collection.go new file mode 100644 index 0000000..57a5275 --- /dev/null +++ b/adapter/mysql/collection.go @@ -0,0 +1,56 @@ +package mysql + +import ( + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +type collectionAdapter struct { +} + +func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) { + columnNames, columnValues, err := sqlbuilder.Map(item, nil) + if err != nil { + return nil, err + } + + pKey, err := col.PrimaryKeys() + if err != nil { + return nil, err + } + + q := col.SQL().InsertInto(col.Name()). + Columns(columnNames...). + Values(columnValues...) + + res, err := q.Exec() + if err != nil { + return nil, err + } + + lastID, err := res.LastInsertId() + if err == nil && len(pKey) <= 1 { + return lastID, nil + } + + keyMap := mydb.Cond{} + for i := range columnNames { + for j := 0; j < len(pKey); j++ { + if pKey[j] == columnNames[i] { + keyMap[pKey[j]] = columnValues[i] + } + } + } + + // There was an auto column among primary keys, let's search for it. + if lastID > 0 { + for j := 0; j < len(pKey); j++ { + if keyMap[pKey[j]] == nil { + keyMap[pKey[j]] = lastID + } + } + } + + return keyMap, nil +} diff --git a/adapter/mysql/connection.go b/adapter/mysql/connection.go new file mode 100644 index 0000000..b9e6237 --- /dev/null +++ b/adapter/mysql/connection.go @@ -0,0 +1,244 @@ +package mysql + +import ( + "errors" + "fmt" + "net" + "net/url" + "strings" +) + +// From https://github.com/go-sql-driver/mysql/blob/master/utils.go +var ( + errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name") +) + +// From https://github.com/go-sql-driver/mysql/blob/master/utils.go +type config struct { + user string + passwd string + net string + addr string + dbname string + params map[string]string +} + +// ConnectionURL implements a MySQL connection struct. +type ConnectionURL struct { + User string + Password string + Database string + Host string + Socket string + Options map[string]string +} + +func (c ConnectionURL) String() (s string) { + + if c.Database == "" { + return "" + } + + // Adding username. + if c.User != "" { + s = s + c.User + // Adding password. + if c.Password != "" { + s = s + ":" + c.Password + } + s = s + "@" + } + + // Adding protocol and address + if c.Socket != "" { + s = s + fmt.Sprintf("unix(%s)", c.Socket) + } else if c.Host != "" { + host, port, err := net.SplitHostPort(c.Host) + if err != nil { + host = c.Host + port = "3306" + } + s = s + fmt.Sprintf("tcp(%s:%s)", host, port) + } + + // Adding database + s = s + "/" + c.Database + + // Do we have any options? + if c.Options == nil { + c.Options = map[string]string{} + } + + // Default options. + if _, ok := c.Options["charset"]; !ok { + c.Options["charset"] = "utf8" + } + + if _, ok := c.Options["parseTime"]; !ok { + c.Options["parseTime"] = "true" + } + + // Converting options into URL values. + vv := url.Values{} + + for k, v := range c.Options { + vv.Set(k, v) + } + + // Inserting options. + if p := vv.Encode(); p != "" { + s = s + "?" + p + } + + return s +} + +// ParseURL parses s into a ConnectionURL struct. +func ParseURL(s string) (conn ConnectionURL, err error) { + var cfg *config + + if cfg, err = parseDSN(s); err != nil { + return + } + + conn.User = cfg.user + conn.Password = cfg.passwd + + if cfg.net == "unix" { + conn.Socket = cfg.addr + } else if cfg.net == "tcp" { + conn.Host = cfg.addr + } + + conn.Database = cfg.dbname + + conn.Options = map[string]string{} + + for k, v := range cfg.params { + conn.Options[k] = v + } + + return +} + +// from https://github.com/go-sql-driver/mysql/blob/master/utils.go +// parseDSN parses the DSN string to a config +func parseDSN(dsn string) (cfg *config, err error) { + // New config with some default values + cfg = &config{} + + // TODO: use strings.IndexByte when we can depend on Go 1.2 + + // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] + // Find the last '/' (since the password or the net addr might contain a '/') + foundSlash := false + for i := len(dsn) - 1; i >= 0; i-- { + if dsn[i] == '/' { + foundSlash = true + var j, k int + + // left part is empty if i <= 0 + if i > 0 { + // [username[:password]@][protocol[(address)]] + // Find the last '@' in dsn[:i] + for j = i; j >= 0; j-- { + if dsn[j] == '@' { + // username[:password] + // Find the first ':' in dsn[:j] + for k = 0; k < j; k++ { + if dsn[k] == ':' { + cfg.passwd = dsn[k+1 : j] + break + } + } + cfg.user = dsn[:k] + + break + } + } + + // [protocol[(address)]] + // Find the first '(' in dsn[j+1:i] + for k = j + 1; k < i; k++ { + if dsn[k] == '(' { + // dsn[i-1] must be == ')' if an address is specified + if dsn[i-1] != ')' { + if strings.ContainsRune(dsn[k+1:i], ')') { + return nil, errInvalidDSNUnescaped + } + return nil, errInvalidDSNAddr + } + cfg.addr = dsn[k+1 : i-1] + break + } + } + cfg.net = dsn[j+1 : k] + } + + // dbname[?param1=value1&...¶mN=valueN] + // Find the first '?' in dsn[i+1:] + for j = i + 1; j < len(dsn); j++ { + if dsn[j] == '?' { + if err = parseDSNParams(cfg, dsn[j+1:]); err != nil { + return + } + break + } + } + cfg.dbname = dsn[i+1 : j] + + break + } + } + + if !foundSlash && len(dsn) > 0 { + return nil, errInvalidDSNNoSlash + } + + // Set default network if empty + if cfg.net == "" { + cfg.net = "tcp" + } + + // Set default address if empty + if cfg.addr == "" { + switch cfg.net { + case "tcp": + cfg.addr = "127.0.0.1:3306" + case "unix": + cfg.addr = "/tmp/mysql.sock" + default: + return nil, errors.New("Default addr for network '" + cfg.net + "' unknown") + } + + } + + return +} + +// From https://github.com/go-sql-driver/mysql/blob/master/utils.go +// parseDSNParams parses the DSN "query string" +// Values must be url.QueryEscape'ed +func parseDSNParams(cfg *config, params string) (err error) { + for _, v := range strings.Split(params, "&") { + param := strings.SplitN(v, "=", 2) + if len(param) != 2 { + continue + } + + value := param[1] + + // lazy init + if cfg.params == nil { + cfg.params = make(map[string]string) + } + + if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil { + return + } + } + + return +} diff --git a/adapter/mysql/connection_test.go b/adapter/mysql/connection_test.go new file mode 100644 index 0000000..172856f --- /dev/null +++ b/adapter/mysql/connection_test.go @@ -0,0 +1,117 @@ +package mysql + +import ( + "testing" +) + +func TestConnectionURL(t *testing.T) { + + c := ConnectionURL{} + + // Zero value equals to an empty string. + if c.String() != "" { + t.Fatal(`Expecting default connectiong string to be empty, got:`, c.String()) + } + + // Adding a database name. + c.Database = "mydbname" + + if c.String() != "/mydbname?charset=utf8&parseTime=true" { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Adding an option. + c.Options = map[string]string{ + "charset": "utf8mb4,utf8", + "sys_var": "esc@ped", + } + + if c.String() != "/mydbname?charset=utf8mb4%2Cutf8&parseTime=true&sys_var=esc%40ped" { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Setting default options + c.Options = nil + + // Setting user and password. + c.User = "user" + c.Password = "pass" + + if c.String() != `user:pass@/mydbname?charset=utf8&parseTime=true` { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Setting host. + c.Host = "1.2.3.4:3306" + + if c.String() != `user:pass@tcp(1.2.3.4:3306)/mydbname?charset=utf8&parseTime=true` { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Setting socket. + c.Socket = "/path/to/socket" + + if c.String() != `user:pass@unix(/path/to/socket)/mydbname?charset=utf8&parseTime=true` { + t.Fatal(`Test failed, got:`, c.String()) + } + +} + +func TestParseConnectionURL(t *testing.T) { + var u ConnectionURL + var s string + var err error + + s = "user:pass@unix(/path/to/socket)/mydbname?charset=utf8" + + if u, err = ParseURL(s); err != nil { + t.Fatal(err) + } + + if u.User != "user" { + t.Fatal("Expecting username.") + } + + if u.Password != "pass" { + t.Fatal("Expecting password.") + } + + if u.Socket != "/path/to/socket" { + t.Fatal("Expecting socket.") + } + + if u.Database != "mydbname" { + t.Fatal("Expecting database.") + } + + if u.Options["charset"] != "utf8" { + t.Fatal("Expecting charset.") + } + + s = "user:pass@tcp(1.2.3.4:5678)/mydbname?charset=utf8" + + if u, err = ParseURL(s); err != nil { + t.Fatal(err) + } + + if u.User != "user" { + t.Fatal("Expecting username.") + } + + if u.Password != "pass" { + t.Fatal("Expecting password.") + } + + if u.Host != "1.2.3.4:5678" { + t.Fatal("Expecting host.") + } + + if u.Database != "mydbname" { + t.Fatal("Expecting database.") + } + + if u.Options["charset"] != "utf8" { + t.Fatal("Expecting charset.") + } + +} diff --git a/adapter/mysql/custom_types.go b/adapter/mysql/custom_types.go new file mode 100644 index 0000000..034f354 --- /dev/null +++ b/adapter/mysql/custom_types.go @@ -0,0 +1,151 @@ +package mysql + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "errors" + "reflect" + + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// JSON represents a MySQL's JSON value: +// https://www.mysql.org/docs/9.6/static/datatype-json.html. JSON +// satisfies sqlbuilder.ScannerValuer. +type JSON struct { + V interface{} +} + +// MarshalJSON encodes the wrapper value as JSON. +func (j JSON) MarshalJSON() ([]byte, error) { + return json.Marshal(j.V) +} + +// UnmarshalJSON decodes the given JSON into the wrapped value. +func (j *JSON) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + j.V = v + return nil +} + +// Scan satisfies the sql.Scanner interface. +func (j *JSON) Scan(src interface{}) error { + if j.V == nil { + return nil + } + if src == nil { + dv := reflect.Indirect(reflect.ValueOf(j.V)) + dv.Set(reflect.Zero(dv.Type())) + return nil + } + b, ok := src.([]byte) + if !ok { + return errors.New("Scan source was not []bytes") + } + + if err := json.Unmarshal(b, j.V); err != nil { + return err + } + return nil +} + +// Value satisfies the driver.Valuer interface. +func (j JSON) Value() (driver.Value, error) { + if j.V == nil { + return nil, nil + } + if v, ok := j.V.(json.RawMessage); ok { + return string(v), nil + } + b, err := json.Marshal(j.V) + if err != nil { + return nil, err + } + return string(b), nil +} + +// JSONMap represents a map of interfaces with string keys +// (`map[string]interface{}`) that is compatible with MySQL's JSON type. +// JSONMap satisfies sqlbuilder.ScannerValuer. +type JSONMap map[string]interface{} + +// Value satisfies the driver.Valuer interface. +func (m JSONMap) Value() (driver.Value, error) { + return JSONValue(m) +} + +// Scan satisfies the sql.Scanner interface. +func (m *JSONMap) Scan(src interface{}) error { + *m = map[string]interface{}(nil) + return ScanJSON(m, src) +} + +// JSONArray represents an array of any type (`[]interface{}`) that is +// compatible with MySQL's JSON type. JSONArray satisfies +// sqlbuilder.ScannerValuer. +type JSONArray []interface{} + +// Value satisfies the driver.Valuer interface. +func (a JSONArray) Value() (driver.Value, error) { + return JSONValue(a) +} + +// Scan satisfies the sql.Scanner interface. +func (a *JSONArray) Scan(src interface{}) error { + return ScanJSON(a, src) +} + +// JSONValue takes an interface and provides a driver.Value that can be +// stored as a JSON column. +func JSONValue(i interface{}) (driver.Value, error) { + v := JSON{i} + return v.Value() +} + +// ScanJSON decodes a JSON byte stream into the passed dst value. +func ScanJSON(dst interface{}, src interface{}) error { + v := JSON{dst} + return v.Scan(src) +} + +// EncodeJSON is deprecated and going to be removed. Use ScanJSON instead. +func EncodeJSON(i interface{}) (driver.Value, error) { + return JSONValue(i) +} + +// DecodeJSON is deprecated and going to be removed. Use JSONValue instead. +func DecodeJSON(dst interface{}, src interface{}) error { + return ScanJSON(dst, src) +} + +// JSONConverter provides a helper method WrapValue that satisfies +// sqlbuilder.ValueWrapper, can be used to encode Go structs into JSON +// MySQL types and vice versa. +// +// Example: +// +// type MyCustomStruct struct { +// ID int64 `db:"id" json:"id"` +// Name string `db:"name" json:"name"` +// ... +// mysql.JSONConverter +// } +type JSONConverter struct{} + +func (*JSONConverter) ConvertValue(in interface{}) interface { + sql.Scanner + driver.Valuer +} { + return &JSON{in} +} + +// Type checks. +var ( + _ sqlbuilder.ScannerValuer = &JSONMap{} + _ sqlbuilder.ScannerValuer = &JSONArray{} + _ sqlbuilder.ScannerValuer = &JSON{} +) diff --git a/adapter/mysql/database.go b/adapter/mysql/database.go new file mode 100644 index 0000000..e601eb7 --- /dev/null +++ b/adapter/mysql/database.go @@ -0,0 +1,168 @@ +// Package mysql wraps the github.com/go-sql-driver/mysql MySQL driver. See +// https://github.com/upper/db/adapter/mysql for documentation, particularities and usage +// examples. +package mysql + +import ( + "reflect" + "strings" + + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" + _ "github.com/go-sql-driver/mysql" // MySQL driver. +) + +// database is the actual implementation of Database +type database struct { +} + +func (*database) Template() *exql.Template { + return template +} + +func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) { + return sql.Open("mysql", dsn) +} + +func (*database) Collections(sess sqladapter.Session) (collections []string, err error) { + q := sess.SQL(). + Select("table_name"). + From("information_schema.tables"). + Where("table_schema = ?", sess.Name()) + + iter := q.Iterator() + defer iter.Close() + + for iter.Next() { + var tableName string + if err := iter.Scan(&tableName); err != nil { + return nil, err + } + collections = append(collections, tableName) + } + if err := iter.Err(); err != nil { + return nil, err + } + + return collections, nil +} + +func (d *database) ConvertValue(in interface{}) interface{} { + switch v := in.(type) { + case *map[string]interface{}: + return (*JSONMap)(v) + + case map[string]interface{}: + return (*JSONMap)(&v) + } + + dv := reflect.ValueOf(in) + if dv.IsValid() { + if dv.Type().Kind() == reflect.Ptr { + dv = dv.Elem() + } + + switch dv.Kind() { + case reflect.Map: + if reflect.TypeOf(in).Kind() == reflect.Ptr { + w := reflect.ValueOf(in) + z := reflect.New(w.Elem().Type()) + w.Elem().Set(z.Elem()) + } + return &JSON{in} + case reflect.Slice: + return &JSON{in} + } + } + + return in +} + +func (*database) Err(err error) error { + if err != nil { + // This error is not exported so we have to check it by its string value. + s := err.Error() + if strings.Contains(s, `many connections`) { + return mydb.ErrTooManyClients + } + } + return err +} + +func (*database) NewCollection() sqladapter.CollectionAdapter { + return &collectionAdapter{} +} + +func (*database) LookupName(sess sqladapter.Session) (string, error) { + q := sess.SQL(). + Select(mydb.Raw("DATABASE() AS name")) + + iter := q.Iterator() + defer iter.Close() + + if iter.Next() { + var name string + if err := iter.Scan(&name); err != nil { + return "", err + } + return name, nil + } + + return "", iter.Err() +} + +func (*database) TableExists(sess sqladapter.Session, name string) error { + q := sess.SQL(). + Select("table_name"). + From("information_schema.tables"). + Where("table_schema = ? AND table_name = ?", sess.Name(), name) + + iter := q.Iterator() + defer iter.Close() + + if iter.Next() { + var name string + if err := iter.Scan(&name); err != nil { + return err + } + return nil + } + if err := iter.Err(); err != nil { + return err + } + + return mydb.ErrCollectionDoesNotExist +} + +func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) { + q := sess.SQL(). + Select("k.column_name"). + From("information_schema.key_column_usage AS k"). + Where(` + k.constraint_name = 'PRIMARY' + AND k.table_schema = ? + AND k.table_name = ? + `, sess.Name(), tableName). + OrderBy("k.ordinal_position") + + iter := q.Iterator() + defer iter.Close() + + pk := []string{} + + for iter.Next() { + var k string + if err := iter.Scan(&k); err != nil { + return nil, err + } + pk = append(pk, k) + } + if err := iter.Err(); err != nil { + return nil, err + } + + return pk, nil +} diff --git a/adapter/mysql/docker-compose.yml b/adapter/mysql/docker-compose.yml new file mode 100644 index 0000000..18ab349 --- /dev/null +++ b/adapter/mysql/docker-compose.yml @@ -0,0 +1,14 @@ +version: '3' + +services: + + server: + image: mysql:${MYSQL_VERSION:-5} + environment: + MYSQL_USER: ${DB_USERNAME:-upperio_user} + MYSQL_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37} + MYSQL_ALLOW_EMPTY_PASSWORD: 1 + MYSQL_DATABASE: ${DB_NAME:-upperio} + ports: + - '${DB_HOST:-127.0.0.1}:${DB_PORT:-3306}:3306' + diff --git a/adapter/mysql/generic_test.go b/adapter/mysql/generic_test.go new file mode 100644 index 0000000..db7e430 --- /dev/null +++ b/adapter/mysql/generic_test.go @@ -0,0 +1,20 @@ +package mysql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type GenericTests struct { + testsuite.GenericTestSuite +} + +func (s *GenericTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestGeneric(t *testing.T) { + suite.Run(t, &GenericTests{}) +} diff --git a/adapter/mysql/helper_test.go b/adapter/mysql/helper_test.go new file mode 100644 index 0000000..a14b897 --- /dev/null +++ b/adapter/mysql/helper_test.go @@ -0,0 +1,276 @@ +package mysql + +import ( + "database/sql" + "fmt" + "os" + "time" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/testsuite" +) + +var settings = ConnectionURL{ + Database: os.Getenv("DB_NAME"), + User: os.Getenv("DB_USERNAME"), + Password: os.Getenv("DB_PASSWORD"), + Host: os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"), + Options: map[string]string{ + // See https://github.com/go-sql-driver/mysql/issues/9 + "parseTime": "true", + // Might require you to use mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql -u root -p mysql + "time_zone": fmt.Sprintf(`'%s'`, testsuite.TimeZone), + "loc": testsuite.TimeZone, + }, +} + +type Helper struct { + sess mydb.Session +} + +func cleanUp(sess mydb.Session) error { + if activeStatements := sqladapter.NumActiveStatements(); activeStatements > 128 { + return fmt.Errorf("Expecting active statements to be at most 128, got %d", activeStatements) + } + + sess.Reset() + + if activeStatements := sqladapter.NumActiveStatements(); activeStatements != 0 { + return fmt.Errorf("Expecting active statements to be 0, got %d", activeStatements) + } + + var err error + var stats map[string]int + for i := 0; i < 10; i++ { + stats, err = getStats(sess) + if err != nil { + return err + } + + if stats["Prepared_stmt_count"] != 0 { + time.Sleep(time.Millisecond * 200) // Sometimes it takes a bit to clean prepared statements + err = fmt.Errorf(`Expecting "Prepared_stmt_count" to be 0, got %d`, stats["Prepared_stmt_count"]) + continue + } + break + } + + return err +} + +func getStats(sess mydb.Session) (map[string]int, error) { + stats := make(map[string]int) + + res, err := sess.Driver().(*sql.DB).Query(`SHOW GLOBAL STATUS LIKE '%stmt%'`) + if err != nil { + return nil, err + } + var result struct { + VariableName string `db:"Variable_name"` + Value int `db:"Value"` + } + + iter := sess.SQL().NewIterator(res) + for iter.Next(&result) { + stats[result.VariableName] = result.Value + } + + return stats, nil +} + +func (h *Helper) Session() mydb.Session { + return h.sess +} + +func (h *Helper) Adapter() string { + return "mysql" +} + +func (h *Helper) TearDown() error { + if err := cleanUp(h.sess); err != nil { + return err + } + + return h.sess.Close() +} + +func (h *Helper) TearUp() error { + var err error + + h.sess, err = Open(settings) + if err != nil { + return err + } + + batch := []string{ + `DROP TABLE IF EXISTS artist`, + `CREATE TABLE artist ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY(id), + name VARCHAR(60) + )`, + + `DROP TABLE IF EXISTS publication`, + `CREATE TABLE publication ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY(id), + title VARCHAR(80), + author_id BIGINT(20) + )`, + + `DROP TABLE IF EXISTS review`, + `CREATE TABLE review ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY(id), + publication_id BIGINT(20), + name VARCHAR(80), + comments TEXT, + created DATETIME NOT NULL + )`, + + `DROP TABLE IF EXISTS data_types`, + `CREATE TABLE data_types ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY(id), + _uint INT(10) UNSIGNED DEFAULT 0, + _uint8 INT(10) UNSIGNED DEFAULT 0, + _uint16 INT(10) UNSIGNED DEFAULT 0, + _uint32 INT(10) UNSIGNED DEFAULT 0, + _uint64 INT(10) UNSIGNED DEFAULT 0, + _int INT(10) DEFAULT 0, + _int8 INT(10) DEFAULT 0, + _int16 INT(10) DEFAULT 0, + _int32 INT(10) DEFAULT 0, + _int64 INT(10) DEFAULT 0, + _float32 DECIMAL(10,6), + _float64 DECIMAL(10,6), + _bool TINYINT(1), + _string text, + _blob blob, + _date TIMESTAMP NULL, + _nildate DATETIME NULL, + _ptrdate DATETIME NULL, + _defaultdate TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + _time BIGINT UNSIGNED NOT NULL DEFAULT 0 + )`, + + `DROP TABLE IF EXISTS stats_test`, + `CREATE TABLE stats_test ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id), + ` + "`numeric`" + ` INT(10), + ` + "`value`" + ` INT(10) + )`, + + `DROP TABLE IF EXISTS composite_keys`, + `CREATE TABLE composite_keys ( + code VARCHAR(255) default '', + user_id VARCHAR(255) default '', + some_val VARCHAR(255) default '', + primary key (code, user_id) + )`, + + `DROP TABLE IF EXISTS admin`, + `CREATE TABLE admin ( + ID int(11) NOT NULL AUTO_INCREMENT, + Accounts varchar(255) DEFAULT '', + LoginPassWord varchar(255) DEFAULT '', + Date TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + PRIMARY KEY (ID,Date) + ) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8`, + + `DROP TABLE IF EXISTS my_types`, + + `CREATE TABLE my_types (id int(11) NOT NULL AUTO_INCREMENT, PRIMARY KEY(id) + , json_map JSON + , json_map_ptr JSON + + , auto_json_map JSON + , auto_json_map_string JSON + , auto_json_map_integer JSON + + , json_object JSON + , json_array JSON + + , custom_json_object JSON + , auto_custom_json_object JSON + + , custom_json_object_ptr JSON + , auto_custom_json_object_ptr JSON + + , custom_json_object_array JSON + , auto_custom_json_object_array JSON + , auto_custom_json_object_map JSON + + , integer_compat_value_json_array JSON + , string_compat_value_json_array JSON + , uinteger_compat_value_json_array JSON + + )`, + + `DROP TABLE IF EXISTS ` + "`" + `birthdays` + "`" + ``, + `CREATE TABLE ` + "`" + `birthdays` + "`" + ` ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id), + name VARCHAR(50), + born DATE, + born_ut BIGINT(20) SIGNED + ) CHARSET=utf8`, + + `DROP TABLE IF EXISTS ` + "`" + `fibonacci` + "`" + ``, + `CREATE TABLE ` + "`" + `fibonacci` + "`" + ` ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id), + input BIGINT(20) UNSIGNED NOT NULL, + output BIGINT(20) UNSIGNED NOT NULL + ) CHARSET=utf8`, + + `DROP TABLE IF EXISTS ` + "`" + `is_even` + "`" + ``, + `CREATE TABLE ` + "`" + `is_even` + "`" + ` ( + input BIGINT(20) UNSIGNED NOT NULL, + is_even TINYINT(1) + ) CHARSET=utf8`, + + `DROP TABLE IF EXISTS ` + "`" + `CaSe_TesT` + "`" + ``, + `CREATE TABLE ` + "`" + `CaSe_TesT` + "`" + ` ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id), + case_test VARCHAR(60) + ) CHARSET=utf8`, + + `DROP TABLE IF EXISTS accounts`, + + `CREATE TABLE accounts ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY(id), + name varchar(255), + disabled BOOL, + created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP + )`, + + `DROP TABLE IF EXISTS users`, + + `CREATE TABLE users ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY(id), + account_id BIGINT(20), + username varchar(255) UNIQUE + )`, + + `DROP TABLE IF EXISTS logs`, + + `CREATE TABLE logs ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY(id), + message VARCHAR(255) + )`, + } + + for _, query := range batch { + driver := h.sess.Driver().(*sql.DB) + if _, err := driver.Exec(query); err != nil { + return err + } + } + + return nil +} + +var _ testsuite.Helper = &Helper{} diff --git a/adapter/mysql/mysql.go b/adapter/mysql/mysql.go new file mode 100644 index 0000000..00a3876 --- /dev/null +++ b/adapter/mysql/mysql.go @@ -0,0 +1,30 @@ +package mysql + +import ( + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// Adapter is the public name of the adapter. +const Adapter = `mysql` + +var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{}) + +// Open establishes a connection to the database server and returns a +// mydb.Session instance (which is compatible with mydb.Session). +func Open(connURL mydb.ConnectionURL) (mydb.Session, error) { + return registeredAdapter.OpenDSN(connURL) +} + +// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value. +func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { + return registeredAdapter.NewTx(sqlTx) +} + +// New creates a sqlbuilder.Sesion instance by wrapping a *sql.DB value. +func New(sqlDB *sql.DB) (mydb.Session, error) { + return registeredAdapter.New(sqlDB) +} diff --git a/adapter/mysql/mysql_test.go b/adapter/mysql/mysql_test.go new file mode 100644 index 0000000..7e0d4ff --- /dev/null +++ b/adapter/mysql/mysql_test.go @@ -0,0 +1,379 @@ +package mysql + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "math/rand" + "strconv" + "testing" + "time" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type int64Compat int64 + +type uintCompat uint + +type stringCompat string + +type uintCompatArray []uintCompat + +func (u *int64Compat) Scan(src interface{}) error { + if src != nil { + switch v := src.(type) { + case int64: + *u = int64Compat((src).(int64)) + case []byte: + i, err := strconv.ParseInt(string(v), 10, 64) + if err != nil { + return err + } + *u = int64Compat(i) + default: + panic(fmt.Sprintf("expected type %T", src)) + } + } + return nil +} + +type customJSON struct { + N string `json:"name"` + V float64 `json:"value"` +} + +func (c customJSON) Value() (driver.Value, error) { + return JSONValue(c) +} + +func (c *customJSON) Scan(src interface{}) error { + return ScanJSON(c, src) +} + +type autoCustomJSON struct { + N string `json:"name"` + V float64 `json:"value"` + + *JSONConverter +} + +var ( + _ = driver.Valuer(&customJSON{}) + _ = sql.Scanner(&customJSON{}) +) + +type AdapterTests struct { + testsuite.Suite +} + +func (s *AdapterTests) SetupSuite() { + s.Helper = &Helper{} +} + +func (s *AdapterTests) TestInsertReturningCompositeKey_Issue383() { + sess := s.Session() + + type Admin struct { + ID int `db:"ID,omitempty"` + Accounts string `db:"Accounts"` + LoginPassWord string `db:"LoginPassWord"` + Date time.Time `db:"Date"` + } + + dateNow := time.Now() + + a := Admin{ + Accounts: "admin", + LoginPassWord: "E10ADC3949BA59ABBE56E057F20F883E", + Date: dateNow, + } + + adminCollection := sess.Collection("admin") + err := adminCollection.InsertReturning(&a) + s.NoError(err) + + s.NotZero(a.ID) + s.NotZero(a.Date) + s.Equal("admin", a.Accounts) + s.Equal("E10ADC3949BA59ABBE56E057F20F883E", a.LoginPassWord) + + b := Admin{ + Accounts: "admin2", + LoginPassWord: "E10ADC3949BA59ABBE56E057F20F883E", + Date: dateNow, + } + + err = adminCollection.InsertReturning(&b) + s.NoError(err) + + s.NotZero(b.ID) + s.NotZero(b.Date) + s.Equal("admin2", b.Accounts) + s.Equal("E10ADC3949BA59ABBE56E057F20F883E", a.LoginPassWord) +} + +func (s *AdapterTests) TestIssue469_BadConnection() { + var err error + sess := s.Session() + + // Ask the MySQL server to disconnect sessions that remain inactive for more + // than 1 second. + _, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`) + s.NoError(err) + + // Remain inactive for 2 seconds. + time.Sleep(time.Second * 2) + + // A query should start a new connection, even if the server disconnected us. + _, err = sess.Collection("artist").Find().Count() + s.NoError(err) + + // This is a new session, ask the MySQL server to disconnect sessions that + // remain inactive for more than 1 second. + _, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`) + s.NoError(err) + + // Remain inactive for 2 seconds. + time.Sleep(time.Second * 2) + + // At this point the server should have disconnected us. Let's try to create + // a transaction anyway. + err = sess.Tx(func(sess mydb.Session) error { + var err error + + _, err = sess.Collection("artist").Find().Count() + if err != nil { + return err + } + return nil + }) + s.NoError(err) + + // This is a new session, ask the MySQL server to disconnect sessions that + // remain inactive for more than 1 second. + _, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`) + s.NoError(err) + + err = sess.Tx(func(sess mydb.Session) error { + var err error + + // This query should succeed. + _, err = sess.Collection("artist").Find().Count() + if err != nil { + panic(err.Error()) + } + + // Remain inactive for 2 seconds. + time.Sleep(time.Second * 2) + + // This query should fail because the server disconnected us in the middle + // of a transaction. + _, err = sess.Collection("artist").Find().Count() + if err != nil { + return err + } + + return nil + }) + + s.Error(err, "Expecting an error (can't recover from this)") +} + +func (s *AdapterTests) TestMySQLTypes() { + sess := s.Session() + + type MyType struct { + ID int64 `db:"id,omitempty"` + + JSONMap JSONMap `db:"json_map"` + + JSONObject JSONMap `db:"json_object"` + JSONArray JSONArray `db:"json_array"` + + CustomJSONObject customJSON `db:"custom_json_object"` + AutoCustomJSONObject autoCustomJSON `db:"auto_custom_json_object"` + + CustomJSONObjectPtr *customJSON `db:"custom_json_object_ptr,omitempty"` + AutoCustomJSONObjectPtr *autoCustomJSON `db:"auto_custom_json_object_ptr,omitempty"` + + AutoCustomJSONObjectArray []autoCustomJSON `db:"auto_custom_json_object_array"` + AutoCustomJSONObjectMap map[string]autoCustomJSON `db:"auto_custom_json_object_map"` + + Int64CompatValueJSONArray []int64Compat `db:"integer_compat_value_json_array"` + UIntCompatValueJSONArray uintCompatArray `db:"uinteger_compat_value_json_array"` + StringCompatValueJSONArray []stringCompat `db:"string_compat_value_json_array"` + } + + origMyTypeTests := []MyType{ + MyType{ + Int64CompatValueJSONArray: []int64Compat{1, -2, 3, -4}, + UIntCompatValueJSONArray: []uintCompat{1, 2, 3, 4}, + StringCompatValueJSONArray: []stringCompat{"a", "b", "", "c"}, + }, + MyType{ + Int64CompatValueJSONArray: []int64Compat(nil), + UIntCompatValueJSONArray: []uintCompat(nil), + StringCompatValueJSONArray: []stringCompat(nil), + }, + MyType{ + AutoCustomJSONObjectArray: []autoCustomJSON{ + autoCustomJSON{ + N: "Hello", + }, + autoCustomJSON{ + N: "World", + }, + }, + AutoCustomJSONObjectMap: map[string]autoCustomJSON{ + "a": autoCustomJSON{ + N: "Hello", + }, + "b": autoCustomJSON{ + N: "World", + }, + }, + JSONArray: JSONArray{float64(1), float64(2), float64(3), float64(4)}, + }, + MyType{ + JSONArray: JSONArray{}, + }, + MyType{ + JSONArray: JSONArray(nil), + }, + MyType{}, + MyType{ + CustomJSONObject: customJSON{ + N: "Hello", + }, + AutoCustomJSONObject: autoCustomJSON{ + N: "World", + }, + }, + MyType{ + CustomJSONObject: customJSON{}, + AutoCustomJSONObject: autoCustomJSON{}, + }, + MyType{ + CustomJSONObject: customJSON{ + N: "Hello 1", + }, + AutoCustomJSONObject: autoCustomJSON{ + N: "World 2", + }, + }, + MyType{ + CustomJSONObjectPtr: nil, + AutoCustomJSONObjectPtr: nil, + }, + MyType{ + CustomJSONObjectPtr: &customJSON{}, + AutoCustomJSONObjectPtr: &autoCustomJSON{}, + }, + MyType{ + CustomJSONObjectPtr: &customJSON{ + N: "Hello 3", + }, + AutoCustomJSONObjectPtr: &autoCustomJSON{ + N: "World 4", + }, + }, + MyType{}, + MyType{ + CustomJSONObject: customJSON{ + V: 4.4, + }, + }, + MyType{ + CustomJSONObject: customJSON{}, + }, + MyType{ + CustomJSONObject: customJSON{ + N: "Peter", + V: 5.56, + }, + }, + } + + for i := 0; i < 100; i++ { + + myTypeTests := make([]MyType, len(origMyTypeTests)) + perm := rand.Perm(len(origMyTypeTests)) + for i, v := range perm { + myTypeTests[v] = origMyTypeTests[i] + } + + for i := range myTypeTests { + result, err := sess.Collection("my_types").Insert(myTypeTests[i]) + s.NoError(err) + + var actual MyType + err = sess.Collection("my_types").Find(result).One(&actual) + s.NoError(err) + + expected := myTypeTests[i] + expected.ID = result.ID().(int64) + s.Equal(expected, actual) + } + + for i := range myTypeTests { + res, err := sess.SQL().InsertInto("my_types").Values(myTypeTests[i]).Exec() + s.NoError(err) + + id, err := res.LastInsertId() + s.NoError(err) + s.NotEqual(0, id) + + var actual MyType + err = sess.Collection("my_types").Find(id).One(&actual) + s.NoError(err) + + expected := myTypeTests[i] + expected.ID = id + + s.Equal(expected, actual) + + var actual2 MyType + err = sess.SQL().SelectFrom("my_types").Where("id = ?", id).One(&actual2) + s.NoError(err) + s.Equal(expected, actual2) + } + + inserter := sess.SQL().InsertInto("my_types") + for i := range myTypeTests { + inserter = inserter.Values(myTypeTests[i]) + } + _, err := inserter.Exec() + s.NoError(err) + + err = sess.Collection("my_types").Truncate() + s.NoError(err) + + batch := sess.SQL().InsertInto("my_types").Batch(50) + go func() { + defer batch.Done() + for i := range myTypeTests { + batch.Values(myTypeTests[i]) + } + }() + + err = batch.Wait() + s.NoError(err) + + var values []MyType + err = sess.SQL().SelectFrom("my_types").All(&values) + s.NoError(err) + + for i := range values { + expected := myTypeTests[i] + expected.ID = values[i].ID + s.Equal(expected, values[i]) + } + } +} + +func TestAdapter(t *testing.T) { + suite.Run(t, &AdapterTests{}) +} diff --git a/adapter/mysql/record_test.go b/adapter/mysql/record_test.go new file mode 100644 index 0000000..7eb597c --- /dev/null +++ b/adapter/mysql/record_test.go @@ -0,0 +1,20 @@ +package mysql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type RecordTests struct { + testsuite.RecordTestSuite +} + +func (s *RecordTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestRecord(t *testing.T) { + suite.Run(t, &RecordTests{}) +} diff --git a/adapter/mysql/sql_test.go b/adapter/mysql/sql_test.go new file mode 100644 index 0000000..5605532 --- /dev/null +++ b/adapter/mysql/sql_test.go @@ -0,0 +1,20 @@ +package mysql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type SQLTests struct { + testsuite.SQLTestSuite +} + +func (s *SQLTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestSQL(t *testing.T) { + suite.Run(t, &SQLTests{}) +} diff --git a/adapter/mysql/template.go b/adapter/mysql/template.go new file mode 100644 index 0000000..4bd33a6 --- /dev/null +++ b/adapter/mysql/template.go @@ -0,0 +1,198 @@ +package mysql + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +const ( + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = "`{{.Value}}`" + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` + {{if .SortColumns}} + ORDER BY {{.SortColumns}} + {{end}} + ` + + adapterWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + + adapterUsingLayout = ` + {{if .Columns}} + USING ({{.Columns}}) + {{end}} + ` + + adapterJoinLayout = ` + {{if .Table}} + {{ if .On }} + {{.Type}} JOIN {{.Table}} + {{.On}} + {{ else if .Using }} + {{.Type}} JOIN {{.Table}} + {{.Using}} + {{ else if .Type | eq "CROSS" }} + {{.Type}} JOIN {{.Table}} + {{else}} + NATURAL {{.Type}} JOIN {{.Table}} + {{end}} + {{end}} + ` + + adapterOnLayout = ` + {{if .Conds}} + ON {{.Conds}} + {{end}} + ` + + adapterSelectLayout = ` + SELECT + {{if .Distinct}} + DISTINCT + {{end}} + + {{if defined .Columns}} + {{.Columns | compile}} + {{else}} + * + {{end}} + + {{if defined .Table}} + FROM {{.Table | compile}} + {{end}} + + {{.Joins | compile}} + + {{.Where | compile}} + + {{if defined .GroupBy}} + {{.GroupBy | compile}} + {{end}} + + {{.OrderBy | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + ` + + // The argument for LIMIT when only OFFSET is specified is a pretty odd magic + // number; this comes directly from MySQL's manual, see: + // https://dev.mysql.com/doc/refman/5.7/en/select.html + // + // "To retrieve all rows from a certain offset up to the end of the result + // set, you can use some large number for the second parameter. This + // statement retrieves all rows from the 96th row to the last: + // SELECT * FROM tbl LIMIT 95,18446744073709551615; " + // + // ¯\_(ツ)_/¯ + ` + {{if .Offset}} + {{if not .Limit}} + LIMIT 18446744073709551615 + {{end}} + OFFSET {{.Offset}} + {{end}} + ` + adapterDeleteLayout = ` + DELETE + FROM {{.Table | compile}} + {{.Where | compile}} + ` + adapterUpdateLayout = ` + UPDATE + {{.Table | compile}} + SET {{.ColumnValues | compile}} + {{.Where | compile}} + ` + + adapterSelectCountLayout = ` + SELECT + COUNT(1) AS _t + FROM {{.Table | compile}} + {{.Where | compile}} + ` + + adapterInsertLayout = ` + INSERT INTO {{.Table | compile}} + {{if defined .Columns}}({{.Columns | compile}}){{end}} + VALUES + {{if defined .Values}} + {{.Values | compile}} + {{else}} + () + {{end}} + {{if defined .Returning}} + RETURNING {{.Returning | compile}} + {{end}} + ` + + adapterTruncateLayout = ` + TRUNCATE TABLE {{.Table | compile}} + ` + + adapterDropDatabaseLayout = ` + DROP DATABASE {{.Database | compile}} + ` + + adapterDropTableLayout = ` + DROP TABLE {{.Table | compile}} + ` + + adapterGroupByLayout = ` + {{if .GroupColumns}} + GROUP BY {{.GroupColumns}} + {{end}} + ` +) + +var template = &exql.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + JoinLayout: adapterJoinLayout, + OnLayout: adapterOnLayout, + UsingLayout: adapterUsingLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), +} diff --git a/adapter/mysql/template_test.go b/adapter/mysql/template_test.go new file mode 100644 index 0000000..4bd5ac1 --- /dev/null +++ b/adapter/mysql/template_test.go @@ -0,0 +1,269 @@ +package mysql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" + "github.com/stretchr/testify/assert" +) + +func TestTemplateSelect(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + "SELECT * FROM `artist`", + b.SelectFrom("artist").String(), + ) + + assert.Equal( + "SELECT * FROM `artist`", + b.Select().From("artist").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` ORDER BY `name` DESC", + b.Select().From("artist").OrderBy("name DESC").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` ORDER BY `name` DESC", + b.Select().From("artist").OrderBy("-name").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` ORDER BY `name` ASC", + b.Select().From("artist").OrderBy("name").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` ORDER BY `name` ASC", + b.Select().From("artist").OrderBy("name ASC").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` LIMIT 18446744073709551615 OFFSET 5", + b.Select().From("artist").Limit(-1).Offset(5).String(), + ) + + assert.Equal( + "SELECT `id` FROM `artist`", + b.Select("id").From("artist").String(), + ) + + assert.Equal( + "SELECT `id`, `name` FROM `artist`", + b.Select("id", "name").From("artist").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`name` = $1)", + b.SelectFrom("artist").Where("name", "Haruki").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (name LIKE $1)", + b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(), + ) + + assert.Equal( + "SELECT `id` FROM `artist` WHERE (name LIKE $1 OR name LIKE $2)", + b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` > $1)", + b.SelectFrom("artist").Where("id >", 2).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (id <= 2 AND name != $1)", + b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` IN ($1, $2, $3, $4))", + b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (name IS NOT NULL)", + b.SelectFrom("artist").Where("name IS NOT NULL").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` AS `a`, `publication` AS `p` WHERE (p.author_id = a.id) LIMIT 1", + b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + "SELECT `id` FROM `artist` NATURAL JOIN `publication`", + b.Select("id").From("artist").Join("publication").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.author_id = a.id) LIMIT 1", + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.author_id = a.id) WHERE (`a`.`id` = $1) LIMIT 1", + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` JOIN `publication` AS `p` ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1", + b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1", + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` AS `a` LEFT JOIN `publication` AS `p1` ON (p1.id = a.id) RIGHT JOIN `publication` AS `p2` ON (p2.id = a.id)", + b.SelectFrom("artist a"). + LeftJoin("publication p1").On("p1.id = a.id"). + RightJoin("publication p2").On("p2.id = a.id"). + String(), + ) + + assert.Equal( + "SELECT * FROM `artist` CROSS JOIN `publication`", + b.SelectFrom("artist").CrossJoin("publication").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` JOIN `publication` USING (`id`)", + b.SelectFrom("artist").Join("publication").Using("id").String(), + ) + + assert.Equal( + "SELECT DATE()", + b.Select(mydb.Raw("DATE()")).String(), + ) + + // Issue #408 + { + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)", + b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id IN": []int{1, 2}}).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` = $1 AND `name` LIKE $2)", + b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": []byte{1, 2}}).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)", + b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": mydb.In(1, 2)}).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)", + b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": mydb.AnyOf([]int{1, 2})}).String(), + ) + } +} + +func TestTemplateInsert(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + "INSERT INTO `artist` VALUES ($1, $2), ($3, $4), ($5, $6)", + b.InsertInto("artist"). + Values(10, "Ryuichi Sakamoto"). + Values(11, "Alondra de la Parra"). + Values(12, "Haruki Murakami"). + String(), + ) + + assert.Equal( + "INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)", + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(), + ) + + assert.Equal( + "INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2) RETURNING `id`", + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(), + ) + + assert.Equal( + "INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)", + b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(), + ) + + assert.Equal( + "INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)", + b.InsertInto("artist").Values(struct { + ID int `db:"id"` + Name string `db:"name"` + }{12, "Chavela Vargas"}).String(), + ) + + assert.Equal( + "INSERT INTO `artist` (`name`, `id`) VALUES ($1, $2)", + b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(), + ) +} + +func TestTemplateUpdate(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + "UPDATE `artist` SET `name` = $1", + b.Update("artist").Set("name", "Artist").String(), + ) + + assert.Equal( + "UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)", + b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(), + ) + + assert.Equal( + "UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)", + b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + "UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)", + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + "UPDATE `artist` SET `name` = $1, `last_name` = $2 WHERE (`id` < $3)", + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + "UPDATE `artist` SET `name` = $1 || ' ' || $2 || id, `id` = id + $3 WHERE (id > $4)", + b.Update("artist").Set( + "name = ? || ' ' || ? || id", "Artist", "#", + "id = id + ?", 10, + ).Where("id > ?", 0).String(), + ) +} + +func TestTemplateDelete(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + "DELETE FROM `artist` WHERE (name = $1)", + b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").String(), + ) + + assert.Equal( + "DELETE FROM `artist` WHERE (id > 5)", + b.DeleteFrom("artist").Where("id > 5").String(), + ) +} diff --git a/adapter/postgresql/Makefile b/adapter/postgresql/Makefile new file mode 100644 index 0000000..c9a0ff2 --- /dev/null +++ b/adapter/postgresql/Makefile @@ -0,0 +1,44 @@ +SHELL ?= bash + +POSTGRES_VERSION ?= 15-alpine +POSTGRES_SUPPORTED ?= $(POSTGRES_VERSION) 14-alpine 13-alpine 12-alpine + +PROJECT ?= upper_postgres_$(POSTGRES_VERSION) + +DB_HOST ?= 127.0.0.1 +DB_PORT ?= 5432 + +DB_NAME ?= upperio +DB_USERNAME ?= upperio_user +DB_PASSWORD ?= upperio//s3cr37 + +TEST_FLAGS ?= +PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1 + +export POSTGRES_VERSION + +export DB_HOST +export DB_NAME +export DB_PASSWORD +export DB_PORT +export DB_USERNAME + +export TEST_FLAGS + +test: + go test -v -failfast -race -timeout 20m $(TEST_FLAGS) + +test-no-race: + go test -v -failfast $(TEST_FLAGS) + +server-up: server-down + docker-compose -p $(PROJECT) up -d && \ + sleep 10 + +server-down: + docker-compose -p $(PROJECT) down + +test-extended: + parallel $(PARALLEL_FLAGS) \ + "POSTGRES_VERSION={} DB_PORT=\$$((5432+{#})) $(MAKE) server-up test server-down" ::: \ + $(POSTGRES_SUPPORTED) diff --git a/adapter/postgresql/README.md b/adapter/postgresql/README.md new file mode 100644 index 0000000..7e72601 --- /dev/null +++ b/adapter/postgresql/README.md @@ -0,0 +1,5 @@ +# PostgreSQL adapter for upper/db + +Please read the full docs, acknowledgements and examples at +[https://upper.io/v4/adapter/postgresql/](https://upper.io/v4/adapter/postgresql/). + diff --git a/adapter/postgresql/collection.go b/adapter/postgresql/collection.go new file mode 100644 index 0000000..df41dd1 --- /dev/null +++ b/adapter/postgresql/collection.go @@ -0,0 +1,50 @@ +package postgresql + +import ( + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" +) + +type collectionAdapter struct { +} + +func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) { + pKey, err := col.PrimaryKeys() + if err != nil { + return nil, err + } + + q := col.SQL().InsertInto(col.Name()).Values(item) + + if len(pKey) == 0 { + // There is no primary key. + res, err := q.Exec() + if err != nil { + return nil, err + } + + // Attempt to use LastInsertId() (probably won't work, but the Exec() + // succeeded, so we can safely ignore the error from LastInsertId()). + lastID, err := res.LastInsertId() + if err != nil { + return nil, nil + } + return lastID, nil + } + + // Asking the database to return the primary key after insertion. + q = q.Returning(pKey...) + + var keyMap mydb.Cond + if err := q.Iterator().One(&keyMap); err != nil { + return nil, err + } + + // The IDSetter interface does not match, look for another interface match. + if len(keyMap) == 1 { + return keyMap[pKey[0]], nil + } + + // This was a compound key and no interface matched it, let's return a map. + return keyMap, nil +} diff --git a/adapter/postgresql/connection.go b/adapter/postgresql/connection.go new file mode 100644 index 0000000..2c2e4a7 --- /dev/null +++ b/adapter/postgresql/connection.go @@ -0,0 +1,289 @@ +package postgresql + +import ( + "fmt" + "net" + "net/url" + "sort" + "strings" + "time" + "unicode" +) + +// scanner implements a tokenizer for libpq-style option strings. +type scanner struct { + s []rune + i int +} + +// Next returns the next rune. It returns 0, false if the end of the text has +// been reached. +func (s *scanner) Next() (rune, bool) { + if s.i >= len(s.s) { + return 0, false + } + r := s.s[s.i] + s.i++ + return r, true +} + +// SkipSpaces returns the next non-whitespace rune. It returns 0, false if the +// end of the text has been reached. +func (s *scanner) SkipSpaces() (rune, bool) { + r, ok := s.Next() + for unicode.IsSpace(r) && ok { + r, ok = s.Next() + } + return r, ok +} + +type values map[string]string + +func (vs values) Set(k, v string) { + vs[k] = v +} + +func (vs values) Get(k string) (v string) { + return vs[k] +} + +func (vs values) Isset(k string) bool { + _, ok := vs[k] + return ok +} + +// ConnectionURL represents a parsed PostgreSQL connection URL. +// +// You can use a ConnectionURL struct as an argument for Open: +// +// var settings = postgresql.ConnectionURL{ +// Host: "localhost", // PostgreSQL server IP or name. +// Database: "peanuts", // Database name. +// User: "cbrown", // Optional user name. +// Password: "snoopy", // Optional user password. +// } +// +// sess, err = postgresql.Open(settings) +// +// If you already have a valid DSN, you can use ParseURL to convert it into +// a ConnectionURL before passing it to Open. +type ConnectionURL struct { + User string + Password string + Host string + Socket string + Database string + Options map[string]string + + timezone *time.Location +} + +var escaper = strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + +// ParseURL parses the given DSN into a ConnectionURL struct. +// A typical PostgreSQL connection URL looks like: +// +// postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full +func ParseURL(s string) (u *ConnectionURL, err error) { + o := make(values) + + if strings.HasPrefix(s, "postgres://") || strings.HasPrefix(s, "postgresql://") { + s, err = parseURL(s) + if err != nil { + return u, err + } + } + + if err := parseOpts(s, o); err != nil { + return u, err + } + u = &ConnectionURL{} + + u.User = o.Get("user") + u.Password = o.Get("password") + + h := o.Get("host") + p := o.Get("port") + + if strings.HasPrefix(h, "/") { + u.Socket = h + } else { + if p == "" { + u.Host = h + } else { + u.Host = fmt.Sprintf("%s:%s", h, p) + } + } + + u.Database = o.Get("dbname") + + u.Options = make(map[string]string) + + for k := range o { + switch k { + case "user", "password", "host", "port", "dbname": + // Skip + default: + u.Options[k] = o[k] + } + } + + if timezone, ok := u.Options["timezone"]; ok { + u.timezone, _ = time.LoadLocation(timezone) + } + + return u, err +} + +// parseOpts parses the options from name and adds them to the values. +// +// The parsing code is based on conninfo_parse from libpq's fe-connect.c +func parseOpts(name string, o values) error { + s := newScanner(name) + + for { + var ( + keyRunes, valRunes []rune + r rune + ok bool + ) + + if r, ok = s.SkipSpaces(); !ok { + break + } + + // Scan the key + for !unicode.IsSpace(r) && r != '=' { + keyRunes = append(keyRunes, r) + if r, ok = s.Next(); !ok { + break + } + } + + // Skip any whitespace if we're not at the = yet + if r != '=' { + r, ok = s.SkipSpaces() + } + + // The current character should be = + if r != '=' || !ok { + return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes)) + } + + // Skip any whitespace after the = + if r, ok = s.SkipSpaces(); !ok { + // If we reach the end here, the last value is just an empty string as per libpq. + o.Set(string(keyRunes), "") + break + } + + if r != '\'' { + for !unicode.IsSpace(r) { + if r == '\\' { + if r, ok = s.Next(); !ok { + return fmt.Errorf(`missing character after backslash`) + } + } + valRunes = append(valRunes, r) + + if r, ok = s.Next(); !ok { + break + } + } + } else { + quote: + for { + if r, ok = s.Next(); !ok { + return fmt.Errorf(`unterminated quoted string literal in connection string`) + } + switch r { + case '\'': + break quote + case '\\': + r, _ = s.Next() + fallthrough + default: + valRunes = append(valRunes, r) + } + } + } + + o.Set(string(keyRunes), string(valRunes)) + } + + return nil +} + +// newScanner returns a new scanner initialized with the option string s. +func newScanner(s string) *scanner { + return &scanner{[]rune(s), 0} +} + +// ParseURL no longer needs to be used by clients of this library since supplying a URL as a +// connection string to sql.Open() is now supported: +// +// sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full") +// +// It remains exported here for backwards-compatibility. +// +// ParseURL converts a url to a connection string for driver.Open. +// Example: +// +// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full" +// +// converts to: +// +// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full" +// +// A minimal example: +// +// "postgres://" +// +// # This will be blank, causing driver.Open to use all of the defaults +// +// NOTE: vendored/copied from github.com/lib/pq +func parseURL(uri string) (string, error) { + u, err := url.Parse(uri) + if err != nil { + return "", err + } + + if u.Scheme != "postgres" && u.Scheme != "postgresql" { + return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme) + } + + var kvs []string + escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + accrue := func(k, v string) { + if v != "" { + kvs = append(kvs, k+"="+escaper.Replace(v)) + } + } + + if u.User != nil { + v := u.User.Username() + accrue("user", v) + + v, _ = u.User.Password() + accrue("password", v) + } + + if host, port, err := net.SplitHostPort(u.Host); err != nil { + accrue("host", u.Host) + } else { + accrue("host", host) + accrue("port", port) + } + + if u.Path != "" { + accrue("dbname", u.Path[1:]) + } + + q := u.Query() + for k := range q { + accrue(k, q.Get(k)) + } + + sort.Strings(kvs) // Makes testing easier (not a performance concern) + return strings.Join(kvs, " "), nil +} diff --git a/adapter/postgresql/connection_pgx.go b/adapter/postgresql/connection_pgx.go new file mode 100644 index 0000000..8cecdb7 --- /dev/null +++ b/adapter/postgresql/connection_pgx.go @@ -0,0 +1,73 @@ +//go:build !pq +// +build !pq + +package postgresql + +import ( + "net" + "sort" + "strings" +) + +// String reassembles the parsed PostgreSQL connection URL into a valid DSN. +func (c ConnectionURL) String() (s string) { + u := []string{} + + // TODO: This surely needs some sort of escaping. + if c.User != "" { + u = append(u, "user="+escaper.Replace(c.User)) + } + + if c.Password != "" { + u = append(u, "password="+escaper.Replace(c.Password)) + } + + if c.Host != "" { + host, port, err := net.SplitHostPort(c.Host) + if err == nil { + if host == "" { + host = "127.0.0.1" + } + if port == "" { + port = "5432" + } + u = append(u, "host="+escaper.Replace(host)) + u = append(u, "port="+escaper.Replace(port)) + } else { + u = append(u, "host="+escaper.Replace(c.Host)) + } + } + + if c.Socket != "" { + u = append(u, "host="+escaper.Replace(c.Socket)) + } + + if c.Database != "" { + u = append(u, "dbname="+escaper.Replace(c.Database)) + } + + // Is there actually any connection data? + if len(u) == 0 { + return "" + } + + if c.Options == nil { + c.Options = map[string]string{} + } + + // If not present, SSL mode is assumed "prefer". + if sslMode, ok := c.Options["sslmode"]; !ok || sslMode == "" { + c.Options["sslmode"] = "prefer" + } + + // Disabled by default + c.Options["statement_cache_capacity"] = "0" + + for k, v := range c.Options { + u = append(u, escaper.Replace(k)+"="+escaper.Replace(v)) + } + + sort.Strings(u) + + return strings.Join(u, " ") +} diff --git a/adapter/postgresql/connection_pgx_test.go b/adapter/postgresql/connection_pgx_test.go new file mode 100644 index 0000000..380c9db --- /dev/null +++ b/adapter/postgresql/connection_pgx_test.go @@ -0,0 +1,108 @@ +//go:build !pq +// +build !pq + +package postgresql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConnectionURL(t *testing.T) { + c := ConnectionURL{} + + // Default connection string is empty. + assert.Equal(t, "", c.String(), "Expecting default connectiong string to be empty") + + // Adding a host with port. + c.Host = "localhost:1234" + assert.Equal(t, "host=localhost port=1234 sslmode=prefer statement_cache_capacity=0", c.String()) + + // Adding a host. + c.Host = "localhost" + assert.Equal(t, "host=localhost sslmode=prefer statement_cache_capacity=0", c.String()) + + // Adding a username. + c.User = "Anakin" + assert.Equal(t, `host=localhost sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String()) + + // Adding a password with special characters. + c.Password = "Some Sort of ' Password" + assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String()) + + // Adding a port. + c.Host = "localhost:1234" + assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String()) + + // Adding a database. + c.Database = "MyDatabase" + assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String()) + + // Adding options. + c.Options = map[string]string{ + "sslmode": "verify-full", + } + assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=verify-full statement_cache_capacity=0 user=Anakin`, c.String()) +} + +func TestParseConnectionURL(t *testing.T) { + + { + s := "postgres://anakin:skywalker@localhost/jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "localhost", u.Host) + assert.Equal(t, "jedis", u.Database) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + // case with port + s := "postgres://anakin:skywalker@localhost:1234/jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "jedis", u.Database) + assert.Equal(t, "localhost:1234", u.Host) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + s := "postgres://anakin:skywalker@localhost/jedis?sslmode=verify-full" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "verify-full", u.Options["sslmode"]) + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "jedis", u.Database) + assert.Equal(t, "localhost", u.Host) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "verify-full", u.Options["sslmode"]) + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full timezone=UTC" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, 2, len(u.Options), "Expecting exactly two options") + assert.Equal(t, "verify-full", u.Options["sslmode"]) + assert.Equal(t, "UTC", u.Options["timezone"]) + } +} diff --git a/adapter/postgresql/connection_pq.go b/adapter/postgresql/connection_pq.go new file mode 100644 index 0000000..fc3bddd --- /dev/null +++ b/adapter/postgresql/connection_pq.go @@ -0,0 +1,70 @@ +//go:build pq +// +build pq + +package postgresql + +import ( + "net" + "sort" + "strings" +) + +// String reassembles the parsed PostgreSQL connection URL into a valid DSN. +func (c ConnectionURL) String() (s string) { + u := []string{} + + // TODO: This surely needs some sort of escaping. + if c.User != "" { + u = append(u, "user="+escaper.Replace(c.User)) + } + + if c.Password != "" { + u = append(u, "password="+escaper.Replace(c.Password)) + } + + if c.Host != "" { + host, port, err := net.SplitHostPort(c.Host) + if err == nil { + if host == "" { + host = "127.0.0.1" + } + if port == "" { + port = "5432" + } + u = append(u, "host="+escaper.Replace(host)) + u = append(u, "port="+escaper.Replace(port)) + } else { + u = append(u, "host="+escaper.Replace(c.Host)) + } + } + + if c.Socket != "" { + u = append(u, "host="+escaper.Replace(c.Socket)) + } + + if c.Database != "" { + u = append(u, "dbname="+escaper.Replace(c.Database)) + } + + // Is there actually any connection data? + if len(u) == 0 { + return "" + } + + if c.Options == nil { + c.Options = map[string]string{} + } + + // If not present, SSL mode is assumed "prefer". + if sslMode, ok := c.Options["sslmode"]; !ok || sslMode == "" { + c.Options["sslmode"] = "prefer" + } + + for k, v := range c.Options { + u = append(u, escaper.Replace(k)+"="+escaper.Replace(v)) + } + + sort.Strings(u) + + return strings.Join(u, " ") +} diff --git a/adapter/postgresql/connection_pq_test.go b/adapter/postgresql/connection_pq_test.go new file mode 100644 index 0000000..7646a7b --- /dev/null +++ b/adapter/postgresql/connection_pq_test.go @@ -0,0 +1,108 @@ +//go:build pq +// +build pq + +package postgresql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConnectionURL(t *testing.T) { + c := ConnectionURL{} + + // Default connection string is empty. + assert.Equal(t, "", c.String(), "Expecting default connectiong string to be empty") + + // Adding a host with port. + c.Host = "localhost:1234" + assert.Equal(t, "host=localhost port=1234 sslmode=prefer", c.String()) + + // Adding a host. + c.Host = "localhost" + assert.Equal(t, "host=localhost sslmode=prefer", c.String()) + + // Adding a username. + c.User = "Anakin" + assert.Equal(t, `host=localhost sslmode=prefer user=Anakin`, c.String()) + + // Adding a password with special characters. + c.Password = "Some Sort of ' Password" + assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password sslmode=prefer user=Anakin`, c.String()) + + // Adding a port. + c.Host = "localhost:1234" + assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer user=Anakin`, c.String()) + + // Adding a database. + c.Database = "MyDatabase" + assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer user=Anakin`, c.String()) + + // Adding options. + c.Options = map[string]string{ + "sslmode": "verify-full", + } + assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=verify-full user=Anakin`, c.String()) +} + +func TestParseConnectionURL(t *testing.T) { + + { + s := "postgres://anakin:skywalker@localhost/jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "localhost", u.Host) + assert.Equal(t, "jedis", u.Database) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + // case with port + s := "postgres://anakin:skywalker@localhost:1234/jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "jedis", u.Database) + assert.Equal(t, "localhost:1234", u.Host) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + s := "postgres://anakin:skywalker@localhost/jedis?sslmode=verify-full" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "verify-full", u.Options["sslmode"]) + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "jedis", u.Database) + assert.Equal(t, "localhost", u.Host) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "verify-full", u.Options["sslmode"]) + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full timezone=UTC" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, 2, len(u.Options), "Expecting exactly two options") + assert.Equal(t, "verify-full", u.Options["sslmode"]) + assert.Equal(t, "UTC", u.Options["timezone"]) + } +} diff --git a/adapter/postgresql/custom_types.go b/adapter/postgresql/custom_types.go new file mode 100644 index 0000000..7b11668 --- /dev/null +++ b/adapter/postgresql/custom_types.go @@ -0,0 +1,126 @@ +package postgresql + +import ( + "context" + "database/sql" + "database/sql/driver" + "time" + + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// JSONBMap represents a map of interfaces with string keys +// (`map[string]interface{}`) that is compatible with PostgreSQL's JSONB type. +// JSONBMap satisfies sqlbuilder.ScannerValuer. +type JSONBMap map[string]interface{} + +// Value satisfies the driver.Valuer interface. +func (m JSONBMap) Value() (driver.Value, error) { + return JSONBValue(m) +} + +// Scan satisfies the sql.Scanner interface. +func (m *JSONBMap) Scan(src interface{}) error { + *m = map[string]interface{}(nil) + return ScanJSONB(m, src) +} + +// JSONBArray represents an array of any type (`[]interface{}`) that is +// compatible with PostgreSQL's JSONB type. JSONBArray satisfies +// sqlbuilder.ScannerValuer. +type JSONBArray []interface{} + +// Value satisfies the driver.Valuer interface. +func (a JSONBArray) Value() (driver.Value, error) { + return JSONBValue(a) +} + +// Scan satisfies the sql.Scanner interface. +func (a *JSONBArray) Scan(src interface{}) error { + return ScanJSONB(a, src) +} + +// JSONBValue takes an interface and provides a driver.Value that can be +// stored as a JSONB column. +func JSONBValue(i interface{}) (driver.Value, error) { + v := JSONB{i} + return v.Value() +} + +// ScanJSONB decodes a JSON byte stream into the passed dst value. +func ScanJSONB(dst interface{}, src interface{}) error { + v := JSONB{dst} + return v.Scan(src) +} + +type JSONBConverter struct { +} + +func (*JSONBConverter) ConvertValue(in interface{}) interface { + sql.Scanner + driver.Valuer +} { + return &JSONB{in} +} + +type timeWrapper struct { + v **time.Time + loc *time.Location +} + +func (t timeWrapper) Value() (driver.Value, error) { + if *t.v != nil { + return **t.v, nil + } + return nil, nil +} + +func (t *timeWrapper) Scan(src interface{}) error { + if src == nil { + nilTime := (*time.Time)(nil) + if t.v == nil { + t.v = &nilTime + } else { + *(t.v) = nilTime + } + return nil + } + tz := src.(time.Time) + if t.loc != nil && (tz.Location() == time.Local) { + tz = tz.In(t.loc) + } + if tz.Location().String() == "" { + tz = tz.In(time.UTC) + } + if *(t.v) == nil { + *(t.v) = &tz + } else { + **t.v = tz + } + return nil +} + +func (d *database) ConvertValueContext(ctx context.Context, in interface{}) interface{} { + tz, _ := ctx.Value("timezone").(*time.Location) + + switch v := in.(type) { + case *time.Time: + return &timeWrapper{&v, tz} + case **time.Time: + return &timeWrapper{v, tz} + } + + return d.ConvertValue(in) +} + +// Type checks. +var ( + _ sqlbuilder.ScannerValuer = &StringArray{} + _ sqlbuilder.ScannerValuer = &Int64Array{} + _ sqlbuilder.ScannerValuer = &Float64Array{} + _ sqlbuilder.ScannerValuer = &Float32Array{} + _ sqlbuilder.ScannerValuer = &BoolArray{} + _ sqlbuilder.ScannerValuer = &JSONBMap{} + _ sqlbuilder.ScannerValuer = &JSONBArray{} + _ sqlbuilder.ScannerValuer = &JSONB{} +) diff --git a/adapter/postgresql/custom_types_pgx.go b/adapter/postgresql/custom_types_pgx.go new file mode 100644 index 0000000..7f5324c --- /dev/null +++ b/adapter/postgresql/custom_types_pgx.go @@ -0,0 +1,286 @@ +//go:build !pq +// +build !pq + +package postgresql + +import ( + "database/sql/driver" + + "github.com/jackc/pgtype" +) + +// JSONB represents a PostgreSQL's JSONB value: +// https://www.postgresql.org/docs/9.6/static/datatype-json.html. JSONB +// satisfies sqlbuilder.ScannerValuer. +type JSONB struct { + Data interface{} +} + +// MarshalJSON encodes the wrapper value as JSON. +func (j JSONB) MarshalJSON() ([]byte, error) { + t := &pgtype.JSONB{} + if err := t.Set(j.Data); err != nil { + return nil, err + } + return t.MarshalJSON() +} + +// UnmarshalJSON decodes the given JSON into the wrapped value. +func (j *JSONB) UnmarshalJSON(b []byte) error { + t := &pgtype.JSONB{} + if err := t.UnmarshalJSON(b); err != nil { + return err + } + if j.Data == nil { + j.Data = t.Get() + return nil + } + if err := t.AssignTo(&j.Data); err != nil { + return err + } + return nil +} + +// Scan satisfies the sql.Scanner interface. +func (j *JSONB) Scan(src interface{}) error { + t := &pgtype.JSONB{} + if err := t.Scan(src); err != nil { + return err + } + if j.Data == nil { + j.Data = t.Get() + return nil + } + if err := t.AssignTo(j.Data); err != nil { + return err + } + return nil +} + +// Value satisfies the driver.Valuer interface. +func (j JSONB) Value() (driver.Value, error) { + t := &pgtype.JSONB{} + if err := t.Set(j.Data); err != nil { + return nil, err + } + return t.Value() +} + +// StringArray represents a one-dimensional array of strings (`[]string{}`) +// that is compatible with PostgreSQL's text array (`text[]`). StringArray +// satisfies sqlbuilder.ScannerValuer. +type StringArray []string + +// Value satisfies the driver.Valuer interface. +func (a StringArray) Value() (driver.Value, error) { + t := pgtype.TextArray{} + if err := t.Set(a); err != nil { + return nil, err + } + return t.Value() +} + +// Scan satisfies the sql.Scanner interface. +func (sa *StringArray) Scan(src interface{}) error { + d := []string{} + t := pgtype.TextArray{} + if err := t.Scan(src); err != nil { + return err + } + if err := t.AssignTo(&d); err != nil { + return err + } + *sa = StringArray(d) + return nil +} + +type Bytea []byte + +func (b Bytea) Value() (driver.Value, error) { + t := pgtype.Bytea{Bytes: b} + if err := t.Set(b); err != nil { + return nil, err + } + return t.Value() +} + +func (b *Bytea) Scan(src interface{}) error { + d := []byte{} + t := pgtype.Bytea{} + if err := t.Scan(src); err != nil { + return err + } + if err := t.AssignTo(&d); err != nil { + return err + } + *b = Bytea(d) + return nil +} + +// ByteaArray represents a one-dimensional array of strings (`[]string{}`) +// that is compatible with PostgreSQL's text array (`text[]`). ByteaArray +// satisfies sqlbuilder.ScannerValuer. +type ByteaArray [][]byte + +// Value satisfies the driver.Valuer interface. +func (a ByteaArray) Value() (driver.Value, error) { + t := pgtype.ByteaArray{} + if err := t.Set(a); err != nil { + return nil, err + } + return t.Value() +} + +// Scan satisfies the sql.Scanner interface. +func (ba *ByteaArray) Scan(src interface{}) error { + d := [][]byte{} + t := pgtype.ByteaArray{} + if err := t.Scan(src); err != nil { + return err + } + if err := t.AssignTo(&d); err != nil { + return err + } + *ba = ByteaArray(d) + return nil +} + +// Int64Array represents a one-dimensional array of int64s (`[]int64{}`) that +// is compatible with PostgreSQL's integer array (`integer[]`). Int64Array +// satisfies sqlbuilder.ScannerValuer. +type Int64Array []int64 + +// Value satisfies the driver.Valuer interface. +func (i64a Int64Array) Value() (driver.Value, error) { + t := pgtype.Int8Array{} + if err := t.Set(i64a); err != nil { + return nil, err + } + return t.Value() +} + +// Scan satisfies the sql.Scanner interface. +func (i64a *Int64Array) Scan(src interface{}) error { + d := []int64{} + t := pgtype.Int8Array{} + if err := t.Scan(src); err != nil { + return err + } + if err := t.AssignTo(&d); err != nil { + return err + } + *i64a = Int64Array(d) + return nil +} + +// Int32Array represents a one-dimensional array of int32s (`[]int32{}`) that +// is compatible with PostgreSQL's integer array (`integer[]`). Int32Array +// satisfies sqlbuilder.ScannerValuer. +type Int32Array []int32 + +// Value satisfies the driver.Valuer interface. +func (i32a Int32Array) Value() (driver.Value, error) { + t := pgtype.Int4Array{} + if err := t.Set(i32a); err != nil { + return nil, err + } + return t.Value() +} + +// Scan satisfies the sql.Scanner interface. +func (i32a *Int32Array) Scan(src interface{}) error { + d := []int32{} + t := pgtype.Int4Array{} + if err := t.Scan(src); err != nil { + return err + } + if err := t.AssignTo(&d); err != nil { + return err + } + *i32a = Int32Array(d) + return nil +} + +// Float64Array represents a one-dimensional array of float64s (`[]float64{}`) +// that is compatible with PostgreSQL's double precision array (`double +// precision[]`). Float64Array satisfies sqlbuilder.ScannerValuer. +type Float64Array []float64 + +// Value satisfies the driver.Valuer interface. +func (f64a Float64Array) Value() (driver.Value, error) { + t := pgtype.Float8Array{} + if err := t.Set(f64a); err != nil { + return nil, err + } + return t.Value() +} + +// Scan satisfies the sql.Scanner interface. +func (f64a *Float64Array) Scan(src interface{}) error { + d := []float64{} + t := pgtype.Float8Array{} + if err := t.Scan(src); err != nil { + return err + } + if err := t.AssignTo(&d); err != nil { + return err + } + *f64a = Float64Array(d) + return nil +} + +// Float32Array represents a one-dimensional array of float32s (`[]float32{}`) +// that is compatible with PostgreSQL's double precision array (`double +// precision[]`). Float32Array satisfies sqlbuilder.ScannerValuer. +type Float32Array []float32 + +// Value satisfies the driver.Valuer interface. +func (f32a Float32Array) Value() (driver.Value, error) { + t := pgtype.Float8Array{} + if err := t.Set(f32a); err != nil { + return nil, err + } + return t.Value() +} + +// Scan satisfies the sql.Scanner interface. +func (f32a *Float32Array) Scan(src interface{}) error { + d := []float32{} + t := pgtype.Float8Array{} + if err := t.Scan(src); err != nil { + return err + } + if err := t.AssignTo(&d); err != nil { + return err + } + *f32a = Float32Array(d) + return nil +} + +// BoolArray represents a one-dimensional array of int64s (`[]bool{}`) that +// is compatible with PostgreSQL's boolean type (`boolean[]`). BoolArray +// satisfies sqlbuilder.ScannerValuer. +type BoolArray []bool + +// Value satisfies the driver.Valuer interface. +func (ba BoolArray) Value() (driver.Value, error) { + t := pgtype.BoolArray{} + if err := t.Set(ba); err != nil { + return nil, err + } + return t.Value() +} + +// Scan satisfies the sql.Scanner interface. +func (ba *BoolArray) Scan(src interface{}) error { + d := []bool{} + t := pgtype.BoolArray{} + if err := t.Scan(src); err != nil { + return err + } + if err := t.AssignTo(&d); err != nil { + return err + } + *ba = BoolArray(d) + return nil +} diff --git a/adapter/postgresql/custom_types_pq.go b/adapter/postgresql/custom_types_pq.go new file mode 100644 index 0000000..f93ca4c --- /dev/null +++ b/adapter/postgresql/custom_types_pq.go @@ -0,0 +1,249 @@ +//go:build pq +// +build pq + +package postgresql + +import ( + "bytes" + "database/sql/driver" + "encoding/hex" + "encoding/json" + "errors" + "fmt" + "reflect" + "strconv" + "time" + + "github.com/lib/pq" +) + +// JSONB represents a PostgreSQL's JSONB value: +// https://www.postgresql.org/docs/9.6/static/datatype-json.html. JSONB +// satisfies sqlbuilder.ScannerValuer. +type JSONB struct { + Data interface{} +} + +// MarshalJSON encodes the wrapper value as JSON. +func (j JSONB) MarshalJSON() ([]byte, error) { + return json.Marshal(j.Data) +} + +// UnmarshalJSON decodes the given JSON into the wrapped value. +func (j *JSONB) UnmarshalJSON(b []byte) error { + var v interface{} + if err := json.Unmarshal(b, &v); err != nil { + return err + } + j.Data = v + return nil +} + +// Scan satisfies the sql.Scanner interface. +func (j *JSONB) Scan(src interface{}) error { + if j.Data == nil { + return nil + } + if src == nil { + dv := reflect.Indirect(reflect.ValueOf(j.Data)) + dv.Set(reflect.Zero(dv.Type())) + return nil + } + + b, ok := src.([]byte) + if !ok { + return errors.New("Scan source was not []bytes") + } + + if err := json.Unmarshal(b, j.Data); err != nil { + return err + } + return nil +} + +// Value satisfies the driver.Valuer interface. +func (j JSONB) Value() (driver.Value, error) { + // See https://github.com/lib/pq/issues/528#issuecomment-257197239 on why are + // we returning string instead of []byte. + if j.Data == nil { + return nil, nil + } + if v, ok := j.Data.(json.RawMessage); ok { + return string(v), nil + } + b, err := json.Marshal(j.Data) + if err != nil { + return nil, err + } + return string(b), nil +} + +// StringArray represents a one-dimensional array of strings (`[]string{}`) +// that is compatible with PostgreSQL's text array (`text[]`). StringArray +// satisfies sqlbuilder.ScannerValuer. +type StringArray pq.StringArray + +// Value satisfies the driver.Valuer interface. +func (a StringArray) Value() (driver.Value, error) { + return pq.StringArray(a).Value() +} + +// Scan satisfies the sql.Scanner interface. +func (a *StringArray) Scan(src interface{}) error { + s := pq.StringArray(*a) + if err := s.Scan(src); err != nil { + return err + } + *a = StringArray(s) + return nil +} + +// Int64Array represents a one-dimensional array of int64s (`[]int64{}`) that +// is compatible with PostgreSQL's integer array (`integer[]`). Int64Array +// satisfies sqlbuilder.ScannerValuer. +type Int64Array pq.Int64Array + +// Value satisfies the driver.Valuer interface. +func (i Int64Array) Value() (driver.Value, error) { + return pq.Int64Array(i).Value() +} + +// Scan satisfies the sql.Scanner interface. +func (i *Int64Array) Scan(src interface{}) error { + s := pq.Int64Array(*i) + if err := s.Scan(src); err != nil { + return err + } + *i = Int64Array(s) + return nil +} + +// Float64Array represents a one-dimensional array of float64s (`[]float64{}`) +// that is compatible with PostgreSQL's double precision array (`double +// precision[]`). Float64Array satisfies sqlbuilder.ScannerValuer. +type Float64Array pq.Float64Array + +// Value satisfies the driver.Valuer interface. +func (f Float64Array) Value() (driver.Value, error) { + return pq.Float64Array(f).Value() +} + +// Scan satisfies the sql.Scanner interface. +func (f *Float64Array) Scan(src interface{}) error { + s := pq.Float64Array(*f) + if err := s.Scan(src); err != nil { + return err + } + *f = Float64Array(s) + return nil +} + +// Float32Array represents a one-dimensional array of float32s (`[]float32{}`) +// that is compatible with PostgreSQL's double precision array (`double +// precision[]`). Float32Array satisfies sqlbuilder.ScannerValuer. +type Float32Array pq.Float32Array + +// Value satisfies the driver.Valuer interface. +func (f Float32Array) Value() (driver.Value, error) { + return pq.Float32Array(f).Value() +} + +// Scan satisfies the sql.Scanner interface. +func (f *Float32Array) Scan(src interface{}) error { + s := pq.Float32Array(*f) + if err := s.Scan(src); err != nil { + return err + } + *f = Float32Array(s) + return nil +} + +// BoolArray represents a one-dimensional array of int64s (`[]bool{}`) that +// is compatible with PostgreSQL's boolean type (`boolean[]`). BoolArray +// satisfies sqlbuilder.ScannerValuer. +type BoolArray pq.BoolArray + +// Value satisfies the driver.Valuer interface. +func (b BoolArray) Value() (driver.Value, error) { + return pq.BoolArray(b).Value() +} + +// Scan satisfies the sql.Scanner interface. +func (b *BoolArray) Scan(src interface{}) error { + s := pq.BoolArray(*b) + if err := s.Scan(src); err != nil { + return err + } + *b = BoolArray(s) + return nil +} + +type Bytea []byte + +// Scan satisfies the sql.Scanner interface. +func (b *Bytea) Scan(src interface{}) error { + decoded, err := parseBytea(src.([]byte)) + if err != nil { + return err + } + if len(decoded) < 1 { + *b = nil + return nil + } + (*b) = make(Bytea, len(decoded)) + for i := range decoded { + (*b)[i] = decoded[i] + } + return nil +} + +type Time time.Time + +// Parse a bytea value received from the server. Both "hex" and the legacy +// "escape" format are supported. +func parseBytea(s []byte) (result []byte, err error) { + if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { + // bytea_output = hex + s = s[2:] // trim off leading "\\x" + result = make([]byte, hex.DecodedLen(len(s))) + _, err := hex.Decode(result, s) + if err != nil { + return nil, err + } + } else { + // bytea_output = escape + for len(s) > 0 { + if s[0] == '\\' { + // escaped '\\' + if len(s) >= 2 && s[1] == '\\' { + result = append(result, '\\') + s = s[2:] + continue + } + + // '\\' followed by an octal number + if len(s) < 4 { + return nil, fmt.Errorf("invalid bytea sequence %v", s) + } + r, err := strconv.ParseInt(string(s[1:4]), 8, 9) + if err != nil { + return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) + } + result = append(result, byte(r)) + s = s[4:] + } else { + // We hit an unescaped, raw byte. Try to read in as many as + // possible in one go. + i := bytes.IndexByte(s, '\\') + if i == -1 { + result = append(result, s...) + break + } + result = append(result, s[:i]...) + s = s[i:] + } + } + } + + return result, nil +} diff --git a/adapter/postgresql/custom_types_test.go b/adapter/postgresql/custom_types_test.go new file mode 100644 index 0000000..8624fda --- /dev/null +++ b/adapter/postgresql/custom_types_test.go @@ -0,0 +1,105 @@ +package postgresql + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +type testStruct struct { + X int `json:"x"` + Z string `json:"z"` + V interface{} `json:"v"` +} + +func TestScanJSONB(t *testing.T) { + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello", "v": 1}`)) + assert.NoError(t, err) + assert.Equal(t, "Hello", a.Z) + assert.Equal(t, float64(1), a.V) + assert.Equal(t, 5, a.X) + } + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello", "v": null}`)) + assert.NoError(t, err) + assert.Equal(t, "Hello", a.Z) + assert.Equal(t, nil, a.V) + assert.Equal(t, 5, a.X) + } + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello"}`)) + assert.NoError(t, err) + assert.Equal(t, "Hello", a.Z) + assert.Equal(t, nil, a.V) + assert.Equal(t, 5, a.X) + } + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{"v": "Hello"}`)) + assert.NoError(t, err) + assert.Equal(t, "Hello", a.V) + } + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{"v": true}`)) + assert.NoError(t, err) + assert.Equal(t, true, a.V) + } + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{}`)) + assert.NoError(t, err) + assert.Equal(t, nil, a.V) + } + { + var a []byte + err := ScanJSONB(&a, []byte(`{"age":[{"\u003e":"1h"}]}`)) + assert.NoError(t, err) + assert.Equal(t, `{"age":[{"\u003e":"1h"}]}`, string(a)) + } + { + var a json.RawMessage + err := ScanJSONB(&a, []byte(`{"age":[{"\u003e":"1h"}]}`)) + assert.NoError(t, err) + assert.Equal(t, `{"age":[{"\u003e":"1h"}]}`, string(a)) + } + { + var a json.RawMessage + err := ScanJSONB(&a, []byte("{\"age\":[{\"\u003e\":\"1h\"}]}")) + assert.NoError(t, err) + assert.Equal(t, `{"age":[{">":"1h"}]}`, string(a)) + } + { + a := []*testStruct{} + err := json.Unmarshal([]byte(`[{}]`), &a) + assert.NoError(t, err) + assert.Equal(t, 1, len(a)) + assert.Nil(t, a[0].V) + } + { + a := []*testStruct{} + err := json.Unmarshal([]byte(`[{"v": true}]`), &a) + assert.NoError(t, err) + assert.Equal(t, 1, len(a)) + assert.Equal(t, true, a[0].V) + } + { + a := []*testStruct{} + err := json.Unmarshal([]byte(`[{"v": null}]`), &a) + assert.NoError(t, err) + assert.Equal(t, 1, len(a)) + assert.Nil(t, a[0].V) + } + { + a := []*testStruct{} + err := json.Unmarshal([]byte(`[{"v": 12.34}]`), &a) + assert.NoError(t, err) + assert.Equal(t, 1, len(a)) + assert.Equal(t, 12.34, a[0].V) + } +} diff --git a/adapter/postgresql/database.go b/adapter/postgresql/database.go new file mode 100644 index 0000000..f718e13 --- /dev/null +++ b/adapter/postgresql/database.go @@ -0,0 +1,180 @@ +// Package postgresql provides an adapter for PostgreSQL. +// See https://github.com/upper/db/adapter/postgresql for documentation, +// particularities and usage examples. +package postgresql + +import ( + "fmt" + "strings" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +type database struct { +} + +func (*database) Template() *exql.Template { + return template +} + +func (*database) Collections(sess sqladapter.Session) (collections []string, err error) { + q := sess.SQL(). + Select("table_name"). + From("information_schema.tables"). + Where("table_schema = ?", "public") + + iter := q.Iterator() + defer iter.Close() + + for iter.Next() { + var name string + if err := iter.Scan(&name); err != nil { + return nil, err + } + collections = append(collections, name) + } + if err := iter.Err(); err != nil { + return nil, err + } + + return collections, nil +} + +func (*database) ConvertValue(in interface{}) interface{} { + switch v := in.(type) { + case *[]int64: + return (*Int64Array)(v) + case *[]string: + return (*StringArray)(v) + case *[]float64: + return (*Float64Array)(v) + case *[]bool: + return (*BoolArray)(v) + case *map[string]interface{}: + return (*JSONBMap)(v) + + case []int64: + return (*Int64Array)(&v) + case []string: + return (*StringArray)(&v) + case []float64: + return (*Float64Array)(&v) + case []bool: + return (*BoolArray)(&v) + case map[string]interface{}: + return (*JSONBMap)(&v) + + } + return in +} + +func (*database) CompileStatement(sess sqladapter.Session, stmt *exql.Statement, args []interface{}) (string, []interface{}, error) { + compiled, err := stmt.Compile(template) + if err != nil { + return "", nil, err + } + + query, args := sqlbuilder.Preprocess(compiled, args) + query = string(sqladapter.ReplaceWithDollarSign([]byte(query))) + return query, args, nil +} + +func (*database) Err(err error) error { + if err != nil { + s := err.Error() + // These errors are not exported so we have to check them by they string value. + if strings.Contains(s, `too many clients`) || strings.Contains(s, `remaining connection slots are reserved`) || strings.Contains(s, `too many open`) { + return mydb.ErrTooManyClients + } + } + return err +} + +func (*database) NewCollection() sqladapter.CollectionAdapter { + return &collectionAdapter{} +} + +func (*database) LookupName(sess sqladapter.Session) (string, error) { + q := sess.SQL(). + Select(mydb.Raw("CURRENT_DATABASE() AS name")) + + iter := q.Iterator() + defer iter.Close() + + if iter.Next() { + var name string + if err := iter.Scan(&name); err != nil { + return "", err + } + return name, nil + } + + return "", iter.Err() +} + +func (*database) TableExists(sess sqladapter.Session, name string) error { + q := sess.SQL(). + Select("table_name"). + From("information_schema.tables"). + Where("table_catalog = ? AND table_name = ?", sess.Name(), name) + + iter := q.Iterator() + defer iter.Close() + + if iter.Next() { + var name string + if err := iter.Scan(&name); err != nil { + return err + } + return nil + } + if err := iter.Err(); err != nil { + return err + } + + return mydb.ErrCollectionDoesNotExist +} + +func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) { + q := sess.SQL(). + Select("pg_attribute.attname AS pkey"). + From("pg_index", "pg_class", "pg_attribute"). + Where(` + pg_class.oid = '` + quotedTableName(tableName) + `'::regclass + AND indrelid = pg_class.oid + AND pg_attribute.attrelid = pg_class.oid + AND pg_attribute.attnum = ANY(pg_index.indkey) + AND indisprimary + `).OrderBy("pkey") + + iter := q.Iterator() + defer iter.Close() + + pk := []string{} + + for iter.Next() { + var k string + if err := iter.Scan(&k); err != nil { + return nil, err + } + pk = append(pk, k) + } + if err := iter.Err(); err != nil { + return nil, err + } + + return pk, nil +} + +// quotedTableName returns a valid regclass name for both regular tables and +// for schemas. +func quotedTableName(s string) string { + chunks := strings.Split(s, ".") + for i := range chunks { + chunks[i] = fmt.Sprintf("%q", chunks[i]) + } + return strings.Join(chunks, ".") +} diff --git a/adapter/postgresql/database_pgx.go b/adapter/postgresql/database_pgx.go new file mode 100644 index 0000000..a40aee3 --- /dev/null +++ b/adapter/postgresql/database_pgx.go @@ -0,0 +1,26 @@ +//go:build !pq +// +build !pq + +package postgresql + +import ( + "context" + "database/sql" + "time" + + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + _ "github.com/jackc/pgx/v4/stdlib" +) + +func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) { + connURL, err := ParseURL(dsn) + if err != nil { + return nil, err + } + if tz := connURL.Options["timezone"]; tz != "" { + loc, _ := time.LoadLocation(tz) + ctx := context.WithValue(sess.Context(), "timezone", loc) + sess.SetContext(ctx) + } + return sql.Open("pgx", dsn) +} diff --git a/adapter/postgresql/database_pq.go b/adapter/postgresql/database_pq.go new file mode 100644 index 0000000..da3b904 --- /dev/null +++ b/adapter/postgresql/database_pq.go @@ -0,0 +1,26 @@ +//go:build pq +// +build pq + +package postgresql + +import ( + "context" + "database/sql" + "time" + + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + _ "github.com/lib/pq" +) + +func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) { + connURL, err := ParseURL(dsn) + if err != nil { + return nil, err + } + if tz := connURL.Options["timezone"]; tz != "" { + loc, _ := time.LoadLocation(tz) + ctx := context.WithValue(sess.Context(), "timezone", loc) + sess.SetContext(ctx) + } + return sql.Open("postgres", dsn) +} diff --git a/adapter/postgresql/docker-compose.yml b/adapter/postgresql/docker-compose.yml new file mode 100644 index 0000000..4f4884a --- /dev/null +++ b/adapter/postgresql/docker-compose.yml @@ -0,0 +1,13 @@ +version: '3' + +services: + + server: + image: postgres:${POSTGRES_VERSION:-11} + environment: + POSTGRES_USER: ${DB_USERNAME:-upperio_user} + POSTGRES_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37} + POSTGRES_DB: ${DB_NAME:-upperio} + ports: + - '${DB_HOST:-127.0.0.1}:${DB_PORT:-5432}:5432' + diff --git a/adapter/postgresql/generic_test.go b/adapter/postgresql/generic_test.go new file mode 100644 index 0000000..feec428 --- /dev/null +++ b/adapter/postgresql/generic_test.go @@ -0,0 +1,20 @@ +package postgresql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type GenericTests struct { + testsuite.GenericTestSuite +} + +func (s *GenericTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestGeneric(t *testing.T) { + suite.Run(t, &GenericTests{}) +} diff --git a/adapter/postgresql/helper_test.go b/adapter/postgresql/helper_test.go new file mode 100644 index 0000000..acb490e --- /dev/null +++ b/adapter/postgresql/helper_test.go @@ -0,0 +1,321 @@ +package postgresql + +import ( + "database/sql" + "fmt" + "os" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/testsuite" +) + +var settings = ConnectionURL{ + Database: os.Getenv("DB_NAME"), + User: os.Getenv("DB_USERNAME"), + Password: os.Getenv("DB_PASSWORD"), + Host: os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"), + Options: map[string]string{ + "timezone": testsuite.TimeZone, + }, +} + +const preparedStatementsKey = "pg_prepared_statements_count" + +type Helper struct { + sess mydb.Session +} + +func cleanUp(sess mydb.Session) error { + if activeStatements := sqladapter.NumActiveStatements(); activeStatements > 128 { + return fmt.Errorf("Expecting active statements to be less than 128, got %d", activeStatements) + } + + sess.Reset() + + stats, err := getStats(sess) + if err != nil { + return err + } + + if stats[preparedStatementsKey] != 0 { + return fmt.Errorf(`Expecting %q to be 0, got %d`, preparedStatementsKey, stats[preparedStatementsKey]) + } + + return nil +} + +func getStats(sess mydb.Session) (map[string]int, error) { + stats := make(map[string]int) + + row := sess.Driver().(*sql.DB).QueryRow(`SELECT count(1) AS value FROM pg_prepared_statements`) + + var value int + err := row.Scan(&value) + if err != nil { + return nil, err + } + + stats[preparedStatementsKey] = value + + return stats, nil +} + +func (h *Helper) Session() mydb.Session { + return h.sess +} + +func (h *Helper) Adapter() string { + return Adapter +} + +func (h *Helper) TearDown() error { + if err := cleanUp(h.sess); err != nil { + return err + } + + return h.sess.Close() +} + +func (h *Helper) TearUp() error { + var err error + + h.sess, err = Open(settings) + if err != nil { + return err + } + + batch := []string{ + `DROP TABLE IF EXISTS artist`, + `CREATE TABLE artist ( + id serial primary key, + name varchar(60) + )`, + + `DROP TABLE IF EXISTS publication`, + `CREATE TABLE publication ( + id serial primary key, + title varchar(80), + author_id integer + )`, + + `DROP TABLE IF EXISTS review`, + `CREATE TABLE review ( + id serial primary key, + publication_id integer, + name varchar(80), + comments text, + created timestamp without time zone + )`, + + `DROP TABLE IF EXISTS data_types`, + `CREATE TABLE data_types ( + id serial primary key, + _uint integer, + _uint8 integer, + _uint16 integer, + _uint32 integer, + _uint64 integer, + _int integer, + _int8 integer, + _int16 integer, + _int32 integer, + _int64 integer, + _float32 numeric(10,6), + _float64 numeric(10,6), + _bool boolean, + _string text, + _blob bytea, + _date timestamp with time zone, + _nildate timestamp without time zone null, + _ptrdate timestamp without time zone, + _defaultdate timestamp without time zone DEFAULT now(), + _time bigint + )`, + + `DROP TABLE IF EXISTS stats_test`, + `CREATE TABLE stats_test ( + id serial primary key, + numeric integer, + value integer + )`, + + `DROP TABLE IF EXISTS composite_keys`, + `CREATE TABLE composite_keys ( + code varchar(255) default '', + user_id varchar(255) default '', + some_val varchar(255) default '', + primary key (code, user_id) + )`, + + `DROP TABLE IF EXISTS option_types`, + `CREATE TABLE option_types ( + id serial primary key, + name varchar(255) default '', + tags varchar(64)[], + settings jsonb + )`, + + `DROP TABLE IF EXISTS test_schema.test`, + `DROP SCHEMA IF EXISTS test_schema`, + + `CREATE SCHEMA test_schema`, + `CREATE TABLE test_schema.test (id integer)`, + + `DROP TABLE IF EXISTS pg_types`, + `CREATE TABLE pg_types (id serial primary key + , uint8_value smallint + , uint8_value_array bytea + + , int64_value smallint + , int64_value_array smallint[] + + , integer_array integer[] + , string_array text[] + , jsonb_map jsonb + , raw_jsonb_map jsonb + , raw_jsonb_text jsonb + + , integer_array_ptr integer[] + , string_array_ptr text[] + , jsonb_map_ptr jsonb + + , auto_integer_array integer[] + , auto_string_array text[] + , auto_jsonb_map jsonb + , auto_jsonb_map_string jsonb + , auto_jsonb_map_integer jsonb + + , jsonb_object jsonb + , jsonb_array jsonb + + , custom_jsonb_object jsonb + , auto_custom_jsonb_object jsonb + + , custom_jsonb_object_ptr jsonb + , auto_custom_jsonb_object_ptr jsonb + + , custom_jsonb_object_array jsonb + , auto_custom_jsonb_object_array jsonb + , auto_custom_jsonb_object_map jsonb + + , string_value varchar(255) + , integer_value int + , varchar_value varchar(64) + , decimal_value decimal + + , integer_compat_value int + , uinteger_compat_value int + , string_compat_value text + + , integer_compat_value_jsonb_array jsonb + , string_compat_value_jsonb_array jsonb + , uinteger_compat_value_jsonb_array jsonb + + , string_value_ptr varchar(255) + , integer_value_ptr int + , varchar_value_ptr varchar(64) + , decimal_value_ptr decimal + + , uuid_value_string UUID + + )`, + + `DROP TABLE IF EXISTS issue_370`, + `CREATE TABLE issue_370 ( + id UUID PRIMARY KEY, + name VARCHAR(25) + )`, + + `CREATE EXTENSION IF NOT EXISTS "uuid-ossp"`, + + `DROP TABLE IF EXISTS issue_602_organizations`, + `CREATE TABLE issue_602_organizations ( + name character varying(256) NOT NULL, + created_at timestamp without time zone DEFAULT now() NOT NULL, + updated_at timestamp without time zone DEFAULT now() NOT NULL, + id uuid DEFAULT public.uuid_generate_v4() NOT NULL + )`, + + `ALTER TABLE ONLY issue_602_organizations ADD CONSTRAINT issue_602_organizations_pkey PRIMARY KEY (id)`, + + `DROP TABLE IF EXISTS issue_370_2`, + `CREATE TABLE issue_370_2 ( + id INTEGER[3] PRIMARY KEY, + name VARCHAR(25) + )`, + + `DROP TABLE IF EXISTS varchar_primary_key`, + `CREATE TABLE varchar_primary_key ( + address VARCHAR(42) PRIMARY KEY NOT NULL, + name VARCHAR(25) + )`, + + `DROP TABLE IF EXISTS "birthdays"`, + `CREATE TABLE "birthdays" ( + "id" serial primary key, + "name" CHARACTER VARYING(50), + "born" TIMESTAMP WITH TIME ZONE, + "born_ut" INT + )`, + + `DROP TABLE IF EXISTS "fibonacci"`, + `CREATE TABLE "fibonacci" ( + "id" serial primary key, + "input" NUMERIC, + "output" NUMERIC + )`, + + `DROP TABLE IF EXISTS "is_even"`, + `CREATE TABLE "is_even" ( + "input" NUMERIC, + "is_even" BOOL + )`, + + `DROP TABLE IF EXISTS "CaSe_TesT"`, + `CREATE TABLE "CaSe_TesT" ( + "id" SERIAL PRIMARY KEY, + "case_test" VARCHAR(60) + )`, + + `DROP TABLE IF EXISTS accounts`, + `CREATE TABLE accounts ( + id serial primary key, + name varchar(255), + disabled boolean, + created_at timestamp with time zone + )`, + + `DROP TABLE IF EXISTS users`, + `CREATE TABLE users ( + id serial primary key, + account_id integer, + username varchar(255) UNIQUE + )`, + + `DROP TABLE IF EXISTS logs`, + `CREATE TABLE logs ( + id serial primary key, + message VARCHAR + )`, + } + + driver := h.sess.Driver().(*sql.DB) + tx, err := driver.Begin() + if err != nil { + return err + } + for _, query := range batch { + if _, err := tx.Exec(query); err != nil { + _ = tx.Rollback() + return err + } + } + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +var _ testsuite.Helper = &Helper{} diff --git a/adapter/postgresql/postgresql.go b/adapter/postgresql/postgresql.go new file mode 100644 index 0000000..2e770a8 --- /dev/null +++ b/adapter/postgresql/postgresql.go @@ -0,0 +1,30 @@ +package postgresql + +import ( + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// Adapter is the internal name of the adapter. +const Adapter = "postgresql" + +var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{}) + +// Open establishes a connection to the database server and returns a +// sqlbuilder.Session instance (which is compatible with mydb.Session). +func Open(connURL mydb.ConnectionURL) (mydb.Session, error) { + return registeredAdapter.OpenDSN(connURL) +} + +// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value. +func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { + return registeredAdapter.NewTx(sqlTx) +} + +// New creates a sqlbuilder.Sesion instance by wrapping a *sql.DB value. +func New(sqlDB *sql.DB) (mydb.Session, error) { + return registeredAdapter.New(sqlDB) +} diff --git a/adapter/postgresql/postgresql_test.go b/adapter/postgresql/postgresql_test.go new file mode 100644 index 0000000..b40f6ef --- /dev/null +++ b/adapter/postgresql/postgresql_test.go @@ -0,0 +1,1404 @@ +package postgresql + +import ( + "context" + "database/sql" + "database/sql/driver" + "encoding/json" + "fmt" + "math/rand" + "strings" + "sync" + "testing" + "time" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" +) + +type customJSONBObjectArray []customJSONB + +func (customJSONBObjectArray) ConvertValue(in interface{}) interface { + sql.Scanner + driver.Valuer +} { + return &JSONB{in} +} + +type customJSONBObjectMap map[string]customJSONB + +func (c customJSONBObjectMap) Value() (driver.Value, error) { + return JSONBValue(c) +} + +func (c *customJSONBObjectMap) Scan(src interface{}) error { + return ScanJSONB(c, src) +} + +type customJSONB struct { + N string `json:"name"` + V float64 `json:"value"` + + *JSONBConverter +} + +type int64Compat int64 + +type uintCompat uint + +type stringCompat string + +type uint8Compat uint8 + +type uint8CompatArray []uint8Compat + +func (ua uint8CompatArray) Value() (driver.Value, error) { + v := make([]byte, len(ua)) + for i := range ua { + v[i] = byte(ua[i]) + } + return v, nil +} + +func (ua *uint8CompatArray) Scan(src interface{}) error { + decoded := Bytea{} + if err := decoded.Scan(src); err != nil { + return nil + } + if len(decoded) < 1 { + *ua = nil + return nil + } + *ua = make([]uint8Compat, len(decoded)) + for i := range decoded { + (*ua)[i] = uint8Compat(decoded[i]) + } + return nil +} + +type int64CompatArray []int64Compat + +func (i64a int64CompatArray) Value() (driver.Value, error) { + v := make(Int64Array, len(i64a)) + for i := range i64a { + v[i] = int64(i64a[i]) + } + return v.Value() +} + +func (i64a *int64CompatArray) Scan(src interface{}) error { + s := Int64Array{} + if err := s.Scan(src); err != nil { + return err + } + dst := make([]int64Compat, len(s)) + for i := range s { + dst[i] = int64Compat(s[i]) + } + if len(dst) < 1 { + return nil + } + *i64a = dst + return nil +} + +type uintCompatArray []uintCompat + +type AdapterTests struct { + testsuite.Suite +} + +func (s *AdapterTests) SetupSuite() { + s.Helper = &Helper{} +} + +func (s *AdapterTests) Test_Issue469_BadConnection() { + sess := s.Session() + + // Ask the PostgreSQL server to disconnect sessions that remain inactive for more + // than 1 second. + _, err := sess.SQL().Exec(`SET SESSION idle_in_transaction_session_timeout=1000`) + s.NoError(err) + + // Remain inactive for 2 seconds. + time.Sleep(time.Second * 2) + + // A query should start a new connection, even if the server disconnected us. + _, err = sess.Collection("artist").Find().Count() + s.NoError(err) + + // This is a new session, ask the PostgreSQL server to disconnect sessions that + // remain inactive for more than 1 second. + _, err = sess.SQL().Exec(`SET SESSION idle_in_transaction_session_timeout=1000`) + s.NoError(err) + + // Remain inactive for 2 seconds. + time.Sleep(time.Second * 2) + + // At this point the server should have disconnected us. Let's try to create + // a transaction anyway. + err = sess.Tx(func(sess mydb.Session) error { + var err error + + _, err = sess.Collection("artist").Find().Count() + if err != nil { + return err + } + return nil + }) + s.NoError(err) + + // This is a new session, ask the PostgreSQL server to disconnect sessions that + // remain inactive for more than 1 second. + _, err = sess.SQL().Exec(`SET SESSION idle_in_transaction_session_timeout=1000`) + s.NoError(err) + + err = sess.Tx(func(sess mydb.Session) error { + var err error + + // This query should succeed. + _, err = sess.Collection("artist").Find().Count() + if err != nil { + panic(err.Error()) + } + + // Remain inactive for 2 seconds. + time.Sleep(time.Second * 2) + + // This query should fail because the server disconnected us in the middle + // of a transaction. + _, err = sess.Collection("artist").Find().Count() + if err != nil { + return err + } + + return nil + }) + + s.Error(err, "Expecting an error (can't recover from this)") +} + +func testPostgreSQLTypes(t *testing.T, sess mydb.Session) { + type PGTypeInline struct { + IntegerArrayPtr *Int64Array `db:"integer_array_ptr,omitempty"` + StringArrayPtr *StringArray `db:"string_array_ptr,omitempty"` + JSONBMapPtr *JSONBMap `db:"jsonb_map_ptr,omitempty"` + } + + type PGTypeAutoInline struct { + AutoIntegerArray []int64 `db:"auto_integer_array"` + AutoStringArray []string `db:"auto_string_array"` + AutoJSONBMap map[string]interface{} `db:"auto_jsonb_map"` + AutoJSONBMapString map[string]interface{} `db:"auto_jsonb_map_string"` + AutoJSONBMapInteger map[string]interface{} `db:"auto_jsonb_map_integer"` + } + + type PGType struct { + ID int64 `db:"id,omitempty"` + + UInt8Value uint8Compat `db:"uint8_value"` + UInt8ValueArray uint8CompatArray `db:"uint8_value_array"` + + Int64Value int64Compat `db:"int64_value"` + Int64ValueArray *int64CompatArray `db:"int64_value_array"` + + IntegerArray Int64Array `db:"integer_array"` + StringArray StringArray `db:"string_array,stringarray"` + JSONBMap JSONBMap `db:"jsonb_map"` + + RawJSONBMap *json.RawMessage `db:"raw_jsonb_map,omitempty"` + RawJSONBText *json.RawMessage `db:"raw_jsonb_text,omitempty"` + + PGTypeInline `db:",inline"` + + PGTypeAutoInline `db:",inline"` + + JSONBObject JSONB `db:"jsonb_object"` + JSONBArray JSONBArray `db:"jsonb_array"` + + CustomJSONBObject customJSONB `db:"custom_jsonb_object"` + AutoCustomJSONBObject customJSONB `db:"auto_custom_jsonb_object"` + + CustomJSONBObjectPtr *customJSONB `db:"custom_jsonb_object_ptr,omitempty"` + AutoCustomJSONBObjectPtr *customJSONB `db:"auto_custom_jsonb_object_ptr,omitempty"` + + AutoCustomJSONBObjectArray *customJSONBObjectArray `db:"auto_custom_jsonb_object_array"` + AutoCustomJSONBObjectMap *customJSONBObjectMap `db:"auto_custom_jsonb_object_map"` + + StringValue string `db:"string_value"` + IntegerValue int64 `db:"integer_value"` + VarcharValue string `db:"varchar_value"` + DecimalValue float64 `db:"decimal_value"` + + Int64CompatValue int64Compat `db:"integer_compat_value"` + UIntCompatValue uintCompat `db:"uinteger_compat_value"` + StringCompatValue stringCompat `db:"string_compat_value"` + + Int64CompatValueJSONBArray JSONBArray `db:"integer_compat_value_jsonb_array"` + UIntCompatValueJSONBArray JSONBArray `db:"uinteger_compat_value_jsonb_array"` + StringCompatValueJSONBArray JSONBArray `db:"string_compat_value_jsonb_array"` + + StringValuePtr *string `db:"string_value_ptr,omitempty"` + IntegerValuePtr *int64 `db:"integer_value_ptr,omitempty"` + VarcharValuePtr *string `db:"varchar_value_ptr,omitempty"` + DecimalValuePtr *float64 `db:"decimal_value_ptr,omitempty"` + + UUIDValueString *string `db:"uuid_value_string,omitempty"` + } + + integerValue := int64(10) + stringValue := string("ten") + decimalValue := float64(10.0) + + uuidStringValue := "52356d08-6a16-4839-9224-75f0a547e13c" + + integerArrayValue := Int64Array{1, 2, 3, 4} + stringArrayValue := StringArray{"a", "b", "c"} + jsonbMapValue := JSONBMap{"Hello": "World"} + rawJSONBMap := json.RawMessage(`{"foo": "bar"}`) + rawJSONBText := json.RawMessage(`{"age": [{">": "1h"}]}`) + + testValue := "Hello world!" + + origPgTypeTests := []PGType{ + PGType{ + UUIDValueString: &uuidStringValue, + }, + PGType{ + UInt8Value: 7, + UInt8ValueArray: uint8CompatArray{1, 2, 3, 4, 5, 6}, + }, + PGType{ + Int64Value: -1, + Int64ValueArray: &int64CompatArray{1, 2, 3, -4, 5, 6}, + }, + PGType{ + UInt8Value: 1, + UInt8ValueArray: uint8CompatArray{1, 2, 3, 4, 5, 6}, + }, + PGType{ + Int64Value: 1, + Int64ValueArray: &int64CompatArray{7, 7, 7}, + }, + PGType{ + Int64Value: 1, + }, + PGType{ + Int64Value: 99, + Int64ValueArray: nil, + }, + PGType{ + Int64CompatValue: -5, + UIntCompatValue: 3, + StringCompatValue: "abc", + }, + PGType{ + Int64CompatValueJSONBArray: JSONBArray{1.0, -2.0, 3.0, -4.0}, + UIntCompatValueJSONBArray: JSONBArray{1.0, 2.0, 3.0, 4.0}, + StringCompatValueJSONBArray: JSONBArray{"a", "b", "", "c"}, + }, + PGType{ + Int64CompatValueJSONBArray: JSONBArray(nil), + UIntCompatValueJSONBArray: JSONBArray(nil), + StringCompatValueJSONBArray: JSONBArray(nil), + }, + PGType{ + IntegerValuePtr: &integerValue, + StringValuePtr: &stringValue, + DecimalValuePtr: &decimalValue, + PGTypeAutoInline: PGTypeAutoInline{ + AutoJSONBMapString: map[string]interface{}{"a": "x", "b": "67"}, + AutoJSONBMapInteger: map[string]interface{}{"a": 12.0, "b": 13.0}, + }, + }, + PGType{ + RawJSONBMap: &rawJSONBMap, + RawJSONBText: &rawJSONBText, + }, + PGType{ + IntegerValue: integerValue, + StringValue: stringValue, + DecimalValue: decimalValue, + }, + PGType{ + IntegerArray: []int64{1, 2, 3, 4}, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoIntegerArray: Int64Array{1, 2, 3, 4}, + AutoStringArray: nil, + }, + }, + PGType{ + AutoCustomJSONBObjectArray: &customJSONBObjectArray{ + customJSONB{ + N: "Hello", + }, + customJSONB{ + N: "World", + }, + }, + AutoCustomJSONBObjectMap: &customJSONBObjectMap{ + "a": customJSONB{ + N: "Hello", + }, + "b": customJSONB{ + N: "World", + }, + }, + PGTypeAutoInline: PGTypeAutoInline{ + AutoJSONBMap: map[string]interface{}{ + "Hello": "world", + "Roses": "red", + }, + }, + JSONBArray: JSONBArray{float64(1), float64(2), float64(3), float64(4)}, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoIntegerArray: nil, + }, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoJSONBMap: map[string]interface{}{}, + }, + JSONBArray: JSONBArray{}, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoJSONBMap: map[string]interface{}(nil), + }, + JSONBArray: JSONBArray(nil), + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoStringArray: []string{"aaa", "bbb", "ccc"}, + }, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoStringArray: nil, + }, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoJSONBMap: map[string]interface{}{"hello": "world!"}, + }, + }, + PGType{ + IntegerArray: []int64{1, 2, 3, 4}, + StringArray: []string{"a", "boo", "bar"}, + }, + PGType{ + StringValue: stringValue, + DecimalValue: decimalValue, + }, + PGType{ + IntegerArray: []int64{}, + }, + PGType{ + StringArray: []string{}, + }, + PGType{ + IntegerArray: []int64{}, + StringArray: []string{}, + }, + PGType{}, + PGType{ + IntegerArray: []int64{1}, + StringArray: []string{"a"}, + }, + PGType{ + PGTypeInline: PGTypeInline{ + IntegerArrayPtr: &integerArrayValue, + StringArrayPtr: &stringArrayValue, + JSONBMapPtr: &jsonbMapValue, + }, + }, + PGType{ + IntegerArray: []int64{0, 0, 0, 0}, + StringValue: testValue, + CustomJSONBObject: customJSONB{ + N: "Hello", + }, + AutoCustomJSONBObject: customJSONB{ + N: "World", + }, + StringArray: []string{"", "", "", ``, `""`}, + }, + PGType{ + CustomJSONBObject: customJSONB{}, + AutoCustomJSONBObject: customJSONB{}, + }, + PGType{ + CustomJSONBObject: customJSONB{ + N: "Hello 1", + }, + AutoCustomJSONBObject: customJSONB{ + N: "World 2", + }, + }, + PGType{ + CustomJSONBObjectPtr: nil, + AutoCustomJSONBObjectPtr: nil, + }, + PGType{ + CustomJSONBObjectPtr: &customJSONB{}, + AutoCustomJSONBObjectPtr: &customJSONB{}, + }, + PGType{ + CustomJSONBObjectPtr: &customJSONB{ + N: "Hello 3", + }, + AutoCustomJSONBObjectPtr: &customJSONB{ + N: "World 4", + }, + }, + PGType{ + StringValue: testValue, + }, + PGType{ + IntegerValue: integerValue, + IntegerValuePtr: &integerValue, + CustomJSONBObject: customJSONB{ + V: 4.4, + }, + }, + PGType{ + StringArray: []string{"a", "boo", "bar"}, + }, + PGType{ + StringArray: []string{"a", "boo", "bar", `""`}, + CustomJSONBObject: customJSONB{}, + }, + PGType{ + IntegerArray: []int64{0}, + StringArray: []string{""}, + }, + PGType{ + CustomJSONBObject: customJSONB{ + N: "Peter", + V: 5.56, + }, + }, + } + + for i := 0; i < 100; i++ { + pgTypeTests := make([]PGType, len(origPgTypeTests)) + perm := rand.Perm(len(origPgTypeTests)) + for i, v := range perm { + pgTypeTests[v] = origPgTypeTests[i] + } + + for i := range pgTypeTests { + record, err := sess.Collection("pg_types").Insert(pgTypeTests[i]) + assert.NoError(t, err) + + var actual PGType + err = sess.Collection("pg_types").Find(record.ID()).One(&actual) + assert.NoError(t, err) + + expected := pgTypeTests[i] + expected.ID = record.ID().(int64) + assert.Equal(t, expected, actual) + } + + for i := range pgTypeTests { + row, err := sess.SQL().InsertInto("pg_types").Values(pgTypeTests[i]).Returning("id").QueryRow() + assert.NoError(t, err) + + var id int64 + err = row.Scan(&id) + assert.NoError(t, err) + + var actual PGType + err = sess.Collection("pg_types").Find(id).One(&actual) + assert.NoError(t, err) + + expected := pgTypeTests[i] + expected.ID = id + + assert.Equal(t, expected, actual) + + var actual2 PGType + err = sess.SQL().SelectFrom("pg_types").Where("id = ?", id).One(&actual2) + assert.NoError(t, err) + assert.Equal(t, expected, actual2) + } + + inserter := sess.SQL().InsertInto("pg_types") + for i := range pgTypeTests { + inserter = inserter.Values(pgTypeTests[i]) + } + _, err := inserter.Exec() + assert.NoError(t, err) + + err = sess.Collection("pg_types").Truncate() + assert.NoError(t, err) + + batch := sess.SQL().InsertInto("pg_types").Batch(50) + go func() { + defer batch.Done() + for i := range pgTypeTests { + batch.Values(pgTypeTests[i]) + } + }() + + err = batch.Wait() + assert.NoError(t, err) + + var values []PGType + err = sess.SQL().SelectFrom("pg_types").All(&values) + assert.NoError(t, err) + + for i := range values { + expected := pgTypeTests[i] + expected.ID = values[i].ID + + assert.Equal(t, expected, values[i]) + } + } +} + +func (s *AdapterTests) TestOptionTypes() { + sess := s.Session() + + optionTypes := sess.Collection("option_types") + err := optionTypes.Truncate() + s.NoError(err) + + // TODO: lets do some benchmarking on these auto-wrapped option types.. + + // TODO: add nullable jsonb field mapped to a []string + + // A struct with wrapped option types defined in the struct tags + // for postgres string array and jsonb types + type optionType struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags []string `db:"tags"` + Settings map[string]interface{} `db:"settings"` + } + + // Item 1 + item1 := optionType{ + Name: "Food", + Tags: []string{"toronto", "pizza"}, + Settings: map[string]interface{}{"a": 1, "b": 2}, + } + + record, err := optionTypes.Insert(item1) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item1Chk optionType + err = optionTypes.Find(record).One(&item1Chk) + s.NoError(err) + + s.Equal(float64(1), item1Chk.Settings["a"]) + s.Equal("toronto", item1Chk.Tags[0]) + + // Item 1 B + item1b := &optionType{ + Name: "Golang", + Tags: []string{"love", "it"}, + Settings: map[string]interface{}{"go": 1, "lang": 2}, + } + + record, err = optionTypes.Insert(item1b) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item1bChk optionType + err = optionTypes.Find(record).One(&item1bChk) + s.NoError(err) + + s.Equal(float64(1), item1bChk.Settings["go"]) + s.Equal("love", item1bChk.Tags[0]) + + // Item 1 C + item1c := &optionType{ + Name: "Sup", Tags: []string{}, Settings: map[string]interface{}{}, + } + + record, err = optionTypes.Insert(item1c) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item1cChk optionType + err = optionTypes.Find(record).One(&item1cChk) + s.NoError(err) + + s.Zero(len(item1cChk.Tags)) + s.Zero(len(item1cChk.Settings)) + + // An option type to pointer jsonb field + type optionType2 struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags StringArray `db:"tags"` + Settings *JSONBMap `db:"settings"` + } + + item2 := optionType2{ + Name: "JS", Tags: []string{"hi", "bye"}, Settings: nil, + } + + record, err = optionTypes.Insert(item2) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item2Chk optionType2 + res := optionTypes.Find(record) + err = res.One(&item2Chk) + s.NoError(err) + + s.Equal(record.ID().(int64), item2Chk.ID) + + s.Equal(item2Chk.Name, item2.Name) + + s.Equal(item2Chk.Tags[0], item2.Tags[0]) + s.Equal(len(item2Chk.Tags), len(item2.Tags)) + + // Update the value + m := JSONBMap{} + m["lang"] = "javascript" + m["num"] = 31337 + item2.Settings = &m + err = res.Update(item2) + s.NoError(err) + + err = res.One(&item2Chk) + s.NoError(err) + + s.Equal(float64(31337), (*item2Chk.Settings)["num"].(float64)) + + s.Equal("javascript", (*item2Chk.Settings)["lang"]) + + // An option type to pointer string array field + type optionType3 struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags *StringArray `db:"tags"` + Settings JSONBMap `db:"settings"` + } + + item3 := optionType3{ + Name: "Julia", + Tags: nil, + Settings: JSONBMap{"girl": true, "lang": true}, + } + + record, err = optionTypes.Insert(item3) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item3Chk optionType2 + err = optionTypes.Find(record).One(&item3Chk) + s.NoError(err) +} + +type Settings struct { + Name string `json:"name"` + Num int64 `json:"num"` +} + +func (s *Settings) Scan(src interface{}) error { + return ScanJSONB(s, src) +} +func (s Settings) Value() (driver.Value, error) { + return JSONBValue(s) +} + +func (s *AdapterTests) TestOptionTypeJsonbStruct() { + sess := s.Session() + + optionTypes := sess.Collection("option_types") + + err := optionTypes.Truncate() + s.NoError(err) + + type OptionType struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags StringArray `db:"tags"` + Settings Settings `db:"settings"` + } + + item1 := &OptionType{ + Name: "Hi", + Tags: []string{"aah", "ok"}, + Settings: Settings{Name: "a", Num: 123}, + } + + record, err := optionTypes.Insert(item1) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item1Chk OptionType + err = optionTypes.Find(record).One(&item1Chk) + s.NoError(err) + + s.Equal(2, len(item1Chk.Tags)) + s.Equal("aah", item1Chk.Tags[0]) + s.Equal("a", item1Chk.Settings.Name) + s.Equal(int64(123), item1Chk.Settings.Num) +} + +func (s *AdapterTests) TestSchemaCollection() { + sess := s.Session() + + col := sess.Collection("test_schema.test") + _, err := col.Insert(map[string]int{"id": 9}) + s.Equal(nil, err) + + var dump []map[string]int + err = col.Find().All(&dump) + s.Nil(err) + s.Equal(1, len(dump)) + s.Equal(9, dump[0]["id"]) +} + +func (s *AdapterTests) Test_Issue340_MaxOpenConns() { + sess := s.Session() + + sess.SetMaxOpenConns(5) + + var wg sync.WaitGroup + for i := 0; i < 30; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + + _, err := sess.SQL().Exec(fmt.Sprintf(`SELECT pg_sleep(1.%d)`, i)) + if err != nil { + s.T().Errorf("%v", err) + } + }(i) + } + + wg.Wait() + + sess.SetMaxOpenConns(0) +} + +func (s *AdapterTests) Test_Issue370_InsertUUID() { + sess := s.Session() + + { + type itemT struct { + ID *uuid.UUID `db:"id"` + Name string `db:"name"` + } + + newUUID := uuid.New() + + item1 := itemT{ + ID: &newUUID, + Name: "Jonny", + } + + col := sess.Collection("issue_370") + err := col.Truncate() + s.NoError(err) + + err = col.InsertReturning(&item1) + s.NoError(err) + + var item2 itemT + err = col.Find(item1.ID).One(&item2) + s.NoError(err) + s.Equal(item1.Name, item2.Name) + + var item3 itemT + err = col.Find(mydb.Cond{"id": item1.ID}).One(&item3) + s.NoError(err) + s.Equal(item1.Name, item3.Name) + } + + { + type itemT struct { + ID uuid.UUID `db:"id"` + Name string `db:"name"` + } + + item1 := itemT{ + ID: uuid.New(), + Name: "Jonny", + } + + col := sess.Collection("issue_370") + err := col.Truncate() + s.NoError(err) + + err = col.InsertReturning(&item1) + s.NoError(err) + + var item2 itemT + err = col.Find(item1.ID).One(&item2) + s.NoError(err) + s.Equal(item1.Name, item2.Name) + + var item3 itemT + err = col.Find(mydb.Cond{"id": item1.ID}).One(&item3) + s.NoError(err) + s.Equal(item1.Name, item3.Name) + } + + { + type itemT struct { + ID Int64Array `db:"id"` + Name string `db:"name"` + } + + item1 := itemT{ + ID: Int64Array{1, 2, 3}, + Name: "Vojtech", + } + + col := sess.Collection("issue_370_2") + err := col.Truncate() + s.NoError(err) + + err = col.InsertReturning(&item1) + s.NoError(err) + + var item2 itemT + err = col.Find(item1.ID).One(&item2) + s.NoError(err) + s.Equal(item1.Name, item2.Name) + + var item3 itemT + err = col.Find(mydb.Cond{"id": item1.ID}).One(&item3) + s.NoError(err) + s.Equal(item1.Name, item3.Name) + } +} + +type issue602Organization struct { + ID string `json:"id" db:"id,omitempty"` + Name string `json:"name" db:"name"` + CreatedAt time.Time `json:"created_at,omitempty" db:"created_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty" db:"updated_at,omitempty"` +} + +type issue602OrganizationStore struct { + mydb.Store +} + +func (r *issue602Organization) BeforeUpdate(mydb.Session) error { + return nil +} + +func (r *issue602Organization) Store(sess mydb.Session) mydb.Store { + return issue602OrganizationStore{sess.Collection("issue_602_organizations")} +} + +var _ interface { + mydb.Record + mydb.BeforeUpdateHook +} = &issue602Organization{} + +func (s *AdapterTests) Test_Issue602_IncorrectBinaryFormat() { + settingsWithBinaryMode := ConnectionURL{ + Database: settings.Database, + User: settings.User, + Password: settings.Password, + Host: settings.Host, + Options: map[string]string{ + "timezone": testsuite.TimeZone, + //"binary_parameters": "yes", + }, + } + + sess, err := Open(settingsWithBinaryMode) + if err != nil { + s.T().Errorf("%v", err) + } + + { + item := issue602Organization{ + Name: "Jonny", + } + + col := sess.Collection("issue_602_organizations") + err := col.Truncate() + s.NoError(err) + + err = sess.Save(&item) + s.NoError(err) + } + + { + item := issue602Organization{ + Name: "Jonny", + } + + col := sess.Collection("issue_602_organizations") + err := col.Truncate() + s.NoError(err) + + err = col.InsertReturning(&item) + s.NoError(err) + } + + { + newUUID := uuid.New() + + item := issue602Organization{ + ID: newUUID.String(), + Name: "Jonny", + } + + col := sess.Collection("issue_602_organizations") + err := col.Truncate() + s.NoError(err) + + id, err := col.Insert(item) + s.NoError(err) + s.NotZero(id) + } + + { + item := issue602Organization{ + Name: "Jonny", + } + + col := sess.Collection("issue_602_organizations") + err := col.Truncate() + s.NoError(err) + + err = sess.Save(&item) + s.NoError(err) + } +} + +func (s *AdapterTests) TestInsertVarcharPrimaryKey() { + sess := s.Session() + + { + type itemT struct { + Address string `db:"address"` + Name string `db:"name"` + } + + item1 := itemT{ + Address: "1234", + Name: "Jonny", + } + + col := sess.Collection("varchar_primary_key") + err := col.Truncate() + s.NoError(err) + + err = col.InsertReturning(&item1) + s.NoError(err) + + var item2 itemT + err = col.Find(mydb.Cond{"address": item1.Address}).One(&item2) + s.NoError(err) + s.Equal(item1.Name, item2.Name) + + var item3 itemT + err = col.Find(mydb.Cond{"address": item1.Address}).One(&item3) + s.NoError(err) + s.Equal(item1.Name, item3.Name) + } +} + +func (s *AdapterTests) Test_Issue409_TxOptions() { + sess := s.Session() + + { + err := sess.TxContext(context.Background(), func(tx mydb.Session) error { + col := tx.Collection("publication") + + row := map[string]interface{}{ + "title": "foo", + "author_id": 1, + } + err := col.InsertReturning(&row) + s.Error(err) + + return err + }, &sql.TxOptions{ + ReadOnly: true, + }) + s.Error(err) + s.True(strings.Contains(err.Error(), "read-only transaction")) + } +} + +func (s *AdapterTests) TestEscapeQuestionMark() { + sess := s.Session() + + var val bool + + { + res, err := sess.SQL().QueryRow(`SELECT '{"mykey":["val1", "val2"]}'::jsonb->'mykey' ?? ?`, "val2") + s.NoError(err) + + err = res.Scan(&val) + s.NoError(err) + s.Equal(true, val) + } + + { + res, err := sess.SQL().QueryRow(`SELECT ?::jsonb->'mykey' ?? ?`, `{"mykey":["val1", "val2"]}`, `val2`) + s.NoError(err) + + err = res.Scan(&val) + s.NoError(err) + s.Equal(true, val) + } + + { + res, err := sess.SQL().QueryRow(`SELECT ?::jsonb->? ?? ?`, `{"mykey":["val1", "val2"]}`, `mykey`, `val2`) + s.NoError(err) + + err = res.Scan(&val) + s.NoError(err) + s.Equal(true, val) + } +} + +func (s *AdapterTests) Test_Issue391_TextMode() { + testPostgreSQLTypes(s.T(), s.Session()) +} + +func (s *AdapterTests) Test_Issue391_BinaryMode() { + settingsWithBinaryMode := ConnectionURL{ + Database: settings.Database, + User: settings.User, + Password: settings.Password, + Host: settings.Host, + Options: map[string]string{ + "timezone": testsuite.TimeZone, + //"binary_parameters": "yes", + }, + } + + sess, err := Open(settingsWithBinaryMode) + if err != nil { + s.T().Errorf("%v", err) + } + defer sess.Close() + + testPostgreSQLTypes(s.T(), sess) +} + +func (s *AdapterTests) TestStringAndInt64Array() { + sess := s.Session() + driver := sess.Driver().(*sql.DB) + + defer func() { + _, _ = driver.Exec(`DROP TABLE IF EXISTS array_types`) + }() + + if _, err := driver.Exec(` + CREATE TABLE array_types ( + id serial primary key, + integers bigint[] DEFAULT NULL, + strings varchar(64)[] + )`); err != nil { + s.NoError(err) + } + + arrayTypes := sess.Collection("array_types") + err := arrayTypes.Truncate() + s.NoError(err) + + type arrayType struct { + ID int64 `db:"id,pk"` + Integers Int64Array `db:"integers"` + Strings StringArray `db:"strings"` + } + + tt := []arrayType{ + // Test nil arrays. + arrayType{ + ID: 1, + Integers: nil, + Strings: nil, + }, + + // Test empty arrays. + arrayType{ + ID: 2, + Integers: []int64{}, + Strings: []string{}, + }, + + // Test non-empty arrays. + arrayType{ + ID: 3, + Integers: []int64{1, 2, 3}, + Strings: []string{"1", "2", "3"}, + }, + } + + for _, item := range tt { + record, err := arrayTypes.Insert(item) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var itemCheck arrayType + err = arrayTypes.Find(record).One(&itemCheck) + s.NoError(err) + s.Len(itemCheck.Integers, len(item.Integers)) + s.Len(itemCheck.Strings, len(item.Strings)) + + s.Equal(item, itemCheck) + } +} + +func (s *AdapterTests) Test_Issue210() { + list := []string{ + `DROP TABLE IF EXISTS testing123`, + `DROP TABLE IF EXISTS hello`, + `CREATE TABLE IF NOT EXISTS testing123 ( + ID INT PRIMARY KEY NOT NULL, + NAME TEXT NOT NULL + ) + `, + `CREATE TABLE IF NOT EXISTS hello ( + ID INT PRIMARY KEY NOT NULL, + NAME TEXT NOT NULL + )`, + } + + sess := s.Session() + + err := sess.Tx(func(tx mydb.Session) error { + for i := range list { + _, err := tx.SQL().Exec(list[i]) + s.NoError(err) + if err != nil { + return err + } + } + return nil + }) + s.NoError(err) + + _, err = sess.Collection("testing123").Find().Count() + s.NoError(err) + + _, err = sess.Collection("hello").Find().Count() + s.NoError(err) +} + +func (s *AdapterTests) TestPreparedStatements() { + sess := s.Session() + + var val int + + { + stmt, err := sess.SQL().Prepare(`SELECT 1`) + s.NoError(err) + s.NotNil(stmt) + + q, err := stmt.Query() + s.NoError(err) + s.NotNil(q) + s.True(q.Next()) + + err = q.Scan(&val) + s.NoError(err) + + err = q.Close() + s.NoError(err) + + s.Equal(1, val) + + err = stmt.Close() + s.NoError(err) + } + + { + err := sess.Tx(func(tx mydb.Session) error { + stmt, err := tx.SQL().Prepare(`SELECT 2`) + s.NoError(err) + s.NotNil(stmt) + + q, err := stmt.Query() + s.NoError(err) + s.NotNil(q) + s.True(q.Next()) + + err = q.Scan(&val) + s.NoError(err) + + err = q.Close() + s.NoError(err) + + s.Equal(2, val) + + err = stmt.Close() + s.NoError(err) + + return nil + }) + s.NoError(err) + } + + { + stmt, err := sess.SQL().Select(3).Prepare() + s.NoError(err) + s.NotNil(stmt) + + q, err := stmt.Query() + s.NoError(err) + s.NotNil(q) + s.True(q.Next()) + + err = q.Scan(&val) + s.NoError(err) + + err = q.Close() + s.NoError(err) + + s.Equal(3, val) + + err = stmt.Close() + s.NoError(err) + } +} + +func (s *AdapterTests) TestNonTrivialSubqueries() { + sess := s.Session() + + // Creating test data + artist := sess.Collection("artist") + + artistNames := []string{"Ozzie", "Flea", "Slash", "Chrono"} + for _, artistName := range artistNames { + _, err := artist.Insert(map[string]string{ + "name": artistName, + }) + s.NoError(err) + } + + { + q, err := sess.SQL().Query(`WITH test AS (?) ?`, + sess.SQL().Select("id AS foo").From("artist"), + sess.SQL().Select("foo").From("test").Where("foo > ?", 0), + ) + + s.NoError(err) + s.NotNil(q) + + s.True(q.Next()) + + var number int + s.NoError(q.Scan(&number)) + + s.Equal(1, number) + s.NoError(q.Close()) + } + + { + builder := sess.SQL() + row, err := builder.QueryRow(`WITH test AS (?) ?`, + builder.Select("id AS foo").From("artist"), + builder.Select("foo").From("test").Where("foo > ?", 0), + ) + + s.NoError(err) + s.NotNil(row) + + var number int + s.NoError(row.Scan(&number)) + + s.Equal(1, number) + } + + { + res, err := sess.SQL().Exec( + `UPDATE artist a1 SET id = ?`, + sess.SQL().Select(mydb.Raw("id + 5")). + From("artist a2"). + Where("a2.id = a1.id"), + ) + + s.NoError(err) + s.NotNil(res) + } + + { + builder := sess.SQL() + + q, err := builder.Query(mydb.Raw(`WITH test AS (?) ?`, + builder.Select("id AS foo").From("artist"), + builder.Select("foo").From("test").Where("foo > ?", 0).OrderBy("foo"), + )) + + s.NoError(err) + s.NotNil(q) + + s.True(q.Next()) + + var number int + s.NoError(q.Scan(&number)) + + s.Equal(6, number) + s.NoError(q.Close()) + } + + { + res, err := sess.SQL().Exec(mydb.Raw(`UPDATE artist a1 SET id = ?`, + sess.SQL().Select(mydb.Raw("id + 7")).From("artist a2").Where("a2.id = a1.id"), + )) + + s.NoError(err) + s.NotNil(res) + } +} + +func (s *AdapterTests) Test_Issue601_ErrorCarrying() { + var items []interface{} + var err error + + sess := s.Session() + + _, err = sess.SQL().Exec(`DROP TABLE IF EXISTS issue_601`) + s.NoError(err) + + err = sess.Collection("issue_601").Find().All(&items) + s.Error(err) + + _, err = sess.SQL().Exec(`CREATE TABLE issue_601 (id INTEGER PRIMARY KEY)`) + s.NoError(err) + + err = sess.Collection("issue_601").Find().All(&items) + s.NoError(err) +} + +func TestAdapter(t *testing.T) { + suite.Run(t, &AdapterTests{}) +} diff --git a/adapter/postgresql/record_test.go b/adapter/postgresql/record_test.go new file mode 100644 index 0000000..a993c28 --- /dev/null +++ b/adapter/postgresql/record_test.go @@ -0,0 +1,20 @@ +package postgresql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type RecordTests struct { + testsuite.RecordTestSuite +} + +func (s *RecordTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestRecord(t *testing.T) { + suite.Run(t, &RecordTests{}) +} diff --git a/adapter/postgresql/sql_test.go b/adapter/postgresql/sql_test.go new file mode 100644 index 0000000..21eae9e --- /dev/null +++ b/adapter/postgresql/sql_test.go @@ -0,0 +1,20 @@ +package postgresql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type SQLTests struct { + testsuite.SQLTestSuite +} + +func (s *SQLTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestSQL(t *testing.T) { + suite.Run(t, &SQLTests{}) +} diff --git a/adapter/postgresql/template.go b/adapter/postgresql/template.go new file mode 100644 index 0000000..aa6eb69 --- /dev/null +++ b/adapter/postgresql/template.go @@ -0,0 +1,189 @@ +package postgresql + +import ( + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/cache" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +const ( + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `"{{.Value}}"` + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` + {{if .SortColumns}} + ORDER BY {{.SortColumns}} + {{end}} + ` + + adapterWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + + adapterUsingLayout = ` + {{if .Columns}} + USING ({{.Columns}}) + {{end}} + ` + + adapterJoinLayout = ` + {{if .Table}} + {{ if .On }} + {{.Type}} JOIN {{.Table}} + {{.On}} + {{ else if .Using }} + {{.Type}} JOIN {{.Table}} + {{.Using}} + {{ else if .Type | eq "CROSS" }} + {{.Type}} JOIN {{.Table}} + {{else}} + NATURAL {{.Type}} JOIN {{.Table}} + {{end}} + {{end}} + ` + + adapterOnLayout = ` + {{if .Conds}} + ON {{.Conds}} + {{end}} + ` + + adapterSelectLayout = ` + SELECT + {{if .Distinct}} + DISTINCT + {{end}} + + {{if defined .Columns}} + {{.Columns | compile}} + {{else}} + * + {{end}} + + {{if defined .Table}} + FROM {{.Table | compile}} + {{end}} + + {{.Joins | compile}} + + {{.Where | compile}} + + {{if defined .GroupBy}} + {{.GroupBy | compile}} + {{end}} + + {{.OrderBy | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + adapterDeleteLayout = ` + DELETE + FROM {{.Table | compile}} + {{.Where | compile}} + ` + adapterUpdateLayout = ` + UPDATE + {{.Table | compile}} + SET {{.ColumnValues | compile}} + {{.Where | compile}} + ` + + adapterSelectCountLayout = ` + SELECT + COUNT(1) AS _t + FROM {{.Table | compile}} + {{.Where | compile}} + ` + + adapterInsertLayout = ` + INSERT INTO {{.Table | compile}} + {{if defined .Columns}}({{.Columns | compile}}){{end}} + VALUES + {{if defined .Values}} + {{.Values | compile}} + {{else}} + (default) + {{end}} + {{if defined .Returning}} + RETURNING {{.Returning | compile}} + {{end}} + ` + + adapterTruncateLayout = ` + TRUNCATE TABLE {{.Table | compile}} RESTART IDENTITY + ` + + adapterDropDatabaseLayout = ` + DROP DATABASE {{.Database | compile}} + ` + + adapterDropTableLayout = ` + DROP TABLE {{.Table | compile}} + ` + + adapterGroupByLayout = ` + {{if .GroupColumns}} + GROUP BY {{.GroupColumns}} + {{end}} + ` +) + +var template = &exql.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + JoinLayout: adapterJoinLayout, + OnLayout: adapterOnLayout, + UsingLayout: adapterUsingLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + ComparisonOperator: map[adapter.ComparisonOperator]string{ + adapter.ComparisonOperatorRegExp: "~", + adapter.ComparisonOperatorNotRegExp: "!~", + }, +} diff --git a/adapter/postgresql/template_test.go b/adapter/postgresql/template_test.go new file mode 100644 index 0000000..1ac1fcb --- /dev/null +++ b/adapter/postgresql/template_test.go @@ -0,0 +1,262 @@ +package postgresql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" + "github.com/stretchr/testify/assert" +) + +func TestTemplateSelect(t *testing.T) { + + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `SELECT * FROM "artist"`, + b.SelectFrom("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist"`, + b.Select().From("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("name DESC").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("-name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name ASC").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT 1 OFFSET 5`, + b.Select().From("artist").Limit(1).Offset(5).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT 1 OFFSET 5`, + b.Select().From("artist").Offset(5).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" OFFSET 5`, + b.Select().From("artist").Limit(-1).Offset(5).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" OFFSET 5`, + b.Select().From("artist").Offset(5).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist"`, + b.Select("id").From("artist").String(), + ) + + assert.Equal( + `SELECT "id", "name" FROM "artist"`, + b.Select("id", "name").From("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("name" = $1)`, + b.SelectFrom("artist").Where("name", "Haruki").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name LIKE $1)`, + b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" WHERE (name LIKE $1 OR name LIKE $2)`, + b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" > $1)`, + b.SelectFrom("artist").Where("id >", 2).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (id <= 2 AND name != $1)`, + b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN ($1, $2, $3, $4))`, + b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name IS NOT NULL)`, + b.SelectFrom("artist").Where("name IS NOT NULL").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a", "publication" AS "p" WHERE (p.author_id = a.id) LIMIT 1`, + b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" NATURAL JOIN "publication"`, + b.Select("id").From("artist").Join("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE ("a"."id" = $1) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1`, + b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" LEFT JOIN "publication" AS "p1" ON (p1.id = a.id) RIGHT JOIN "publication" AS "p2" ON (p2.id = a.id)`, + b.SelectFrom("artist a"). + LeftJoin("publication p1").On("p1.id = a.id"). + RightJoin("publication p2").On("p2.id = a.id"). + String(), + ) + + assert.Equal( + `SELECT * FROM "artist" CROSS JOIN "publication"`, + b.SelectFrom("artist").CrossJoin("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" USING ("id")`, + b.SelectFrom("artist").Join("publication").Using("id").String(), + ) + + assert.Equal( + `SELECT DATE()`, + b.Select(mydb.Raw("DATE()")).String(), + ) +} + +func TestTemplateInsert(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `INSERT INTO "artist" VALUES ($1, $2), ($3, $4), ($5, $6)`, + b.InsertInto("artist"). + Values(10, "Ryuichi Sakamoto"). + Values(11, "Alondra de la Parra"). + Values(12, "Haruki Murakami"). + String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(struct { + ID int `db:"id"` + Name string `db:"name"` + }{12, "Chavela Vargas"}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, + b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(), + ) +} + +func TestTemplateUpdate(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `UPDATE "artist" SET "name" = $1`, + b.Update("artist").Set("name", "Artist").String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 || ' ' || $2 || id, "id" = id + $3 WHERE (id > $4)`, + b.Update("artist").Set( + "name = ? || ' ' || ? || id", "Artist", "#", + "id = id + ?", 10, + ).Where("id > ?", 0).String(), + ) +} + +func TestTemplateDelete(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `DELETE FROM "artist" WHERE (name = $1)`, + b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Limit(1).String(), + ) + + assert.Equal( + `DELETE FROM "artist" WHERE (id > 5)`, + b.DeleteFrom("artist").Where("id > 5").String(), + ) +} diff --git a/adapter/sqlite/Makefile b/adapter/sqlite/Makefile new file mode 100644 index 0000000..13bb432 --- /dev/null +++ b/adapter/sqlite/Makefile @@ -0,0 +1,27 @@ +SHELL ?= bash +DB_NAME ?= sqlite3-test.db +TEST_FLAGS ?= + +export DB_NAME + +export TEST_FLAGS + +build: + go build && go install + +require-client: + @if [ -z "$$(which sqlite3)" ]; then \ + echo 'Missing "sqlite3" command. Please install SQLite3 and try again.' && \ + exit 1; \ + fi + +reset-db: require-client + rm -f $(DB_NAME) + +test: reset-db + go test -v -failfast -race -timeout 20m $(TEST_FLAGS) + +test-no-race: + go test -v -failfast $(TEST_FLAGS) + +test-extended: test diff --git a/adapter/sqlite/README.md b/adapter/sqlite/README.md new file mode 100644 index 0000000..5d73e1a --- /dev/null +++ b/adapter/sqlite/README.md @@ -0,0 +1,4 @@ +# SQLite adapter for upper/db + +Please read the full docs, acknowledgements and examples at +[https://upper.io/v4/adapter/sqlite/](https://upper.io/v4/adapter/sqlite/). diff --git a/adapter/sqlite/collection.go b/adapter/sqlite/collection.go new file mode 100644 index 0000000..ca0c8c2 --- /dev/null +++ b/adapter/sqlite/collection.go @@ -0,0 +1,49 @@ +package sqlite + +import ( + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +type collectionAdapter struct { +} + +func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) { + columnNames, columnValues, err := sqlbuilder.Map(item, nil) + if err != nil { + return nil, err + } + + pKey, err := col.PrimaryKeys() + if err != nil { + return nil, err + } + + q := col.SQL().InsertInto(col.Name()). + Columns(columnNames...). + Values(columnValues...) + + var res sql.Result + if res, err = q.Exec(); err != nil { + return nil, err + } + + if len(pKey) <= 1 { + return res.LastInsertId() + } + + keyMap := mydb.Cond{} + + for i := range columnNames { + for j := 0; j < len(pKey); j++ { + if pKey[j] == columnNames[i] { + keyMap[pKey[j]] = columnValues[i] + } + } + } + + return keyMap, nil +} diff --git a/adapter/sqlite/connection.go b/adapter/sqlite/connection.go new file mode 100644 index 0000000..ef6bf9a --- /dev/null +++ b/adapter/sqlite/connection.go @@ -0,0 +1,89 @@ +package sqlite + +import ( + "fmt" + "net/url" + "path/filepath" + "runtime" + "strings" +) + +const connectionScheme = `file` + +// ConnectionURL implements a SQLite connection struct. +type ConnectionURL struct { + Database string + Options map[string]string +} + +func (c ConnectionURL) String() (s string) { + vv := url.Values{} + + if c.Database == "" { + return "" + } + + // Did the user provided a full database path? + if !strings.HasPrefix(c.Database, "/") { + c.Database, _ = filepath.Abs(c.Database) + if runtime.GOOS == "windows" { + // Closes https://github.com/upper/db/issues/60 + c.Database = "/" + strings.Replace(c.Database, `\`, `/`, -1) + } + } + + // Do we have any options? + if c.Options == nil { + c.Options = map[string]string{} + } + + if _, ok := c.Options["_busy_timeout"]; !ok { + c.Options["_busy_timeout"] = "10000" + } + + // Converting options into URL values. + for k, v := range c.Options { + vv.Set(k, v) + } + + // Building URL. + u := url.URL{ + Scheme: connectionScheme, + Path: c.Database, + RawQuery: vv.Encode(), + } + + return u.String() +} + +// ParseURL parses s into a ConnectionURL struct. +func ParseURL(s string) (conn ConnectionURL, err error) { + var u *url.URL + + if !strings.HasPrefix(s, connectionScheme+"://") { + return conn, fmt.Errorf(`Expecting file:// connection scheme.`) + } + + if u, err = url.Parse(s); err != nil { + return conn, err + } + + conn.Database = u.Host + u.Path + conn.Options = map[string]string{} + + var vv url.Values + + if vv, err = url.ParseQuery(u.RawQuery); err != nil { + return conn, err + } + + for k := range vv { + conn.Options[k] = vv.Get(k) + } + + if _, ok := conn.Options["cache"]; !ok { + conn.Options["cache"] = "shared" + } + + return conn, err +} diff --git a/adapter/sqlite/connection_test.go b/adapter/sqlite/connection_test.go new file mode 100644 index 0000000..1f99f3d --- /dev/null +++ b/adapter/sqlite/connection_test.go @@ -0,0 +1,88 @@ +package sqlite + +import ( + "path/filepath" + "testing" +) + +func TestConnectionURL(t *testing.T) { + + c := ConnectionURL{} + + // Default connection string is only the protocol. + if c.String() != "" { + t.Fatal(`Expecting default connectiong string to be empty, got:`, c.String()) + } + + // Adding a database name. + c.Database = "myfilename" + + absoluteName, _ := filepath.Abs(c.Database) + + if c.String() != "file://"+absoluteName+"?_busy_timeout=10000" { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Adding an option. + c.Options = map[string]string{ + "cache": "foobar", + "mode": "ro", + } + + if c.String() != "file://"+absoluteName+"?_busy_timeout=10000&cache=foobar&mode=ro" { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Setting another database. + c.Database = "/another/database" + + if c.String() != `file:///another/database?_busy_timeout=10000&cache=foobar&mode=ro` { + t.Fatal(`Test failed, got:`, c.String()) + } + +} + +func TestParseConnectionURL(t *testing.T) { + var u ConnectionURL + var s string + var err error + + s = "file://mydatabase.db" + + if u, err = ParseURL(s); err != nil { + t.Fatal(err) + } + + if u.Database != "mydatabase.db" { + t.Fatal("Failed to parse database.") + } + + if u.Options["cache"] != "shared" { + t.Fatal("If not defined, cache should be shared by default.") + } + + s = "file:///path/to/my/database.db?_busy_timeout=10000&mode=ro&cache=foobar" + + if u, err = ParseURL(s); err != nil { + t.Fatal(err) + } + + if u.Database != "/path/to/my/database.db" { + t.Fatal("Failed to parse username.") + } + + if u.Options["cache"] != "foobar" { + t.Fatal("Expecting option.") + } + + if u.Options["mode"] != "ro" { + t.Fatal("Expecting option.") + } + + s = "http://example.org" + + if _, err = ParseURL(s); err == nil { + t.Fatal("Expecting error.") + } + +} diff --git a/adapter/sqlite/database.go b/adapter/sqlite/database.go new file mode 100644 index 0000000..dfbada3 --- /dev/null +++ b/adapter/sqlite/database.go @@ -0,0 +1,168 @@ +// Package sqlite wraps the github.com/lib/sqlite SQLite driver. See +// https://github.com/upper/db/adapter/sqlite for documentation, particularities and +// usage examples. +package sqlite + +import ( + "context" + "database/sql" + "fmt" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/compat" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" + _ "github.com/mattn/go-sqlite3" // SQLite3 driver. +) + +// database is the actual implementation of Database +type database struct { +} + +func (*database) Template() *exql.Template { + return template +} + +func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) { + return sql.Open("sqlite3", dsn) +} + +func (*database) Collections(sess sqladapter.Session) (collections []string, err error) { + q := sess.SQL(). + Select("tbl_name"). + From("sqlite_master"). + Where("type = ?", "table") + + iter := q.Iterator() + defer iter.Close() + + for iter.Next() { + var tableName string + if err := iter.Scan(&tableName); err != nil { + return nil, err + } + collections = append(collections, tableName) + } + if err := iter.Err(); err != nil { + return nil, err + } + + return collections, nil +} + +func (*database) StatementExec(sess sqladapter.Session, ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) { + if sess.Transaction() != nil { + return compat.ExecContext(sess.Driver().(*sql.Tx), ctx, query, args) + } + + sqlTx, err := compat.BeginTx(sess.Driver().(*sql.DB), ctx, nil) + if err != nil { + return nil, err + } + + if res, err = compat.ExecContext(sqlTx, ctx, query, args); err != nil { + _ = sqlTx.Rollback() + return nil, err + } + + if err = sqlTx.Commit(); err != nil { + return nil, err + } + + return res, err +} + +func (*database) NewCollection() sqladapter.CollectionAdapter { + return &collectionAdapter{} +} + +func (*database) LookupName(sess sqladapter.Session) (string, error) { + connURL := sess.ConnectionURL() + if connURL != nil { + connURL, err := ParseURL(connURL.String()) + if err != nil { + return "", err + } + return connURL.Database, nil + } + + // sess.ConnectionURL() is nil if using sqlite.New + rows, err := sess.SQL().Query(exql.RawSQL("PRAGMA database_list")) + if err != nil { + return "", err + } + dbInfo := struct { + Name string `db:"name"` + File string `db:"file"` + }{} + + if err := sess.SQL().NewIterator(rows).One(&dbInfo); err != nil { + return "", err + } + if dbInfo.File != "" { + return dbInfo.File, nil + } + // dbInfo.File is empty if in memory mode + return dbInfo.Name, nil +} + +func (*database) TableExists(sess sqladapter.Session, name string) error { + q := sess.SQL(). + Select("tbl_name"). + From("sqlite_master"). + Where("type = 'table' AND tbl_name = ?", name) + + iter := q.Iterator() + defer iter.Close() + + if iter.Next() { + var name string + if err := iter.Scan(&name); err != nil { + return err + } + return nil + } + if err := iter.Err(); err != nil { + return err + } + + return mydb.ErrCollectionDoesNotExist +} + +func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) { + pk := make([]string, 0, 1) + + stmt := exql.RawSQL(fmt.Sprintf("PRAGMA TABLE_INFO('%s')", tableName)) + + rows, err := sess.SQL().Query(stmt) + if err != nil { + return nil, err + } + + columns := []struct { + Name string `db:"name"` + PK int `db:"pk"` + }{} + + if err := sess.SQL().NewIterator(rows).All(&columns); err != nil { + return nil, err + } + + maxValue := -1 + + for _, column := range columns { + if column.PK > 0 && column.PK > maxValue { + maxValue = column.PK + } + } + + if maxValue > 0 { + for _, column := range columns { + if column.PK > 0 { + pk = append(pk, column.Name) + } + } + } + + return pk, nil +} diff --git a/adapter/sqlite/generic_test.go b/adapter/sqlite/generic_test.go new file mode 100644 index 0000000..df65544 --- /dev/null +++ b/adapter/sqlite/generic_test.go @@ -0,0 +1,20 @@ +package sqlite + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type GenericTests struct { + testsuite.GenericTestSuite +} + +func (s *GenericTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestGeneric(t *testing.T) { + suite.Run(t, &GenericTests{}) +} diff --git a/adapter/sqlite/helper_test.go b/adapter/sqlite/helper_test.go new file mode 100644 index 0000000..8c699da --- /dev/null +++ b/adapter/sqlite/helper_test.go @@ -0,0 +1,170 @@ +package sqlite + +import ( + "database/sql" + "os" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/testsuite" +) + +var settings = ConnectionURL{ + Database: os.Getenv("DB_NAME"), +} + +type Helper struct { + sess mydb.Session +} + +func (h *Helper) Session() mydb.Session { + return h.sess +} + +func (h *Helper) Adapter() string { + return "sqlite" +} + +func (h *Helper) TearDown() error { + return h.sess.Close() +} + +func (h *Helper) TearUp() error { + var err error + + h.sess, err = Open(settings) + if err != nil { + return err + } + + batch := []string{ + `PRAGMA foreign_keys=OFF`, + + `BEGIN TRANSACTION`, + + `DROP TABLE IF EXISTS artist`, + `CREATE TABLE artist ( + id integer primary key, + name varchar(60) + )`, + + `DROP TABLE IF EXISTS publication`, + `CREATE TABLE publication ( + id integer primary key, + title varchar(80), + author_id integer + )`, + + `DROP TABLE IF EXISTS review`, + `CREATE TABLE review ( + id integer primary key, + publication_id integer, + name varchar(80), + comments text, + created datetime + )`, + + `DROP TABLE IF EXISTS data_types`, + `CREATE TABLE data_types ( + id integer primary key, + _uint integer, + _uintptr integer, + _uint8 integer, + _uint16 int, + _uint32 int, + _uint64 int, + _int integer, + _int8 integer, + _int16 integer, + _int32 integer, + _int64 integer, + _float32 real, + _float64 real, + _byte integer, + _rune integer, + _bool integer, + _string text, + _blob blob, + _date datetime, + _nildate datetime, + _ptrdate datetime, + _defaultdate datetime default current_timestamp, + _time text + )`, + + `DROP TABLE IF EXISTS stats_test`, + `CREATE TABLE stats_test ( + id integer primary key, + numeric integer, + value integer + )`, + + `DROP TABLE IF EXISTS composite_keys`, + `CREATE TABLE composite_keys ( + code VARCHAR(255) default '', + user_id VARCHAR(255) default '', + some_val VARCHAR(255) default '', + primary key (code, user_id) + )`, + + `DROP TABLE IF EXISTS "birthdays"`, + `CREATE TABLE "birthdays" ( + "id" INTEGER PRIMARY KEY, + "name" VARCHAR(50) DEFAULT NULL, + "born" DATETIME DEFAULT NULL, + "born_ut" INTEGER + )`, + + `DROP TABLE IF EXISTS "fibonacci"`, + `CREATE TABLE "fibonacci" ( + "id" INTEGER PRIMARY KEY, + "input" INTEGER, + "output" INTEGER + )`, + + `DROP TABLE IF EXISTS "is_even"`, + `CREATE TABLE "is_even" ( + "input" INTEGER, + "is_even" INTEGER + )`, + + `DROP TABLE IF EXISTS "CaSe_TesT"`, + `CREATE TABLE "CaSe_TesT" ( + "id" INTEGER PRIMARY KEY, + "case_test" VARCHAR + )`, + + `DROP TABLE IF EXISTS accounts`, + `CREATE TABLE accounts ( + id integer primary key, + name varchar, + disabled integer, + created_at datetime default current_timestamp + )`, + + `DROP TABLE IF EXISTS users`, + `CREATE TABLE users ( + id integer primary key, + account_id integer, + username varchar UNIQUE + )`, + + `DROP TABLE IF EXISTS logs`, + `CREATE TABLE logs ( + id integer primary key, + message VARCHAR + )`, + + `COMMIT`, + } + + for _, query := range batch { + driver := h.sess.Driver().(*sql.DB) + if _, err := driver.Exec(query); err != nil { + return err + } + } + + return nil +} + +var _ testsuite.Helper = &Helper{} diff --git a/adapter/sqlite/record_test.go b/adapter/sqlite/record_test.go new file mode 100644 index 0000000..7a193c3 --- /dev/null +++ b/adapter/sqlite/record_test.go @@ -0,0 +1,20 @@ +package sqlite + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type RecordTests struct { + testsuite.RecordTestSuite +} + +func (s *RecordTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestRecord(t *testing.T) { + suite.Run(t, &RecordTests{}) +} diff --git a/adapter/sqlite/sql_test.go b/adapter/sqlite/sql_test.go new file mode 100644 index 0000000..4725c88 --- /dev/null +++ b/adapter/sqlite/sql_test.go @@ -0,0 +1,20 @@ +package sqlite + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type SQLTests struct { + testsuite.SQLTestSuite +} + +func (s *SQLTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestSQL(t *testing.T) { + suite.Run(t, &SQLTests{}) +} diff --git a/adapter/sqlite/sqlite.go b/adapter/sqlite/sqlite.go new file mode 100644 index 0000000..519489d --- /dev/null +++ b/adapter/sqlite/sqlite.go @@ -0,0 +1,30 @@ +package sqlite + +import ( + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// Adapter is the public name of the adapter. +const Adapter = `sqlite` + +var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{}) + +// Open establishes a connection to the database server and returns a +// mydb.Session instance (which is compatible with mydb.Session). +func Open(connURL mydb.ConnectionURL) (mydb.Session, error) { + return registeredAdapter.OpenDSN(connURL) +} + +// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value. +func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { + return registeredAdapter.NewTx(sqlTx) +} + +// New creates a sqlbuilder.Sesion instance by wrapping a *sql.DB value. +func New(sqlDB *sql.DB) (mydb.Session, error) { + return registeredAdapter.New(sqlDB) +} diff --git a/adapter/sqlite/sqlite_test.go b/adapter/sqlite/sqlite_test.go new file mode 100644 index 0000000..5e99ef5 --- /dev/null +++ b/adapter/sqlite/sqlite_test.go @@ -0,0 +1,55 @@ +package sqlite + +import ( + "path/filepath" + "testing" + + "database/sql" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type AdapterTests struct { + testsuite.Suite +} + +func (s *AdapterTests) SetupSuite() { + s.Helper = &Helper{} +} + +func (s *AdapterTests) Test_Issue633_OpenSession() { + sess, err := Open(settings) + s.NoError(err) + defer sess.Close() + + absoluteName, _ := filepath.Abs(settings.Database) + s.Equal(absoluteName, sess.Name()) +} + +func (s *AdapterTests) Test_Issue633_NewAdapterWithFile() { + sqldb, err := sql.Open("sqlite3", settings.Database) + s.NoError(err) + + sess, err := New(sqldb) + s.NoError(err) + defer sess.Close() + + absoluteName, _ := filepath.Abs(settings.Database) + s.Equal(absoluteName, sess.Name()) +} + +func (s *AdapterTests) Test_Issue633_NewAdapterWithMemory() { + sqldb, err := sql.Open("sqlite3", ":memory:") + s.NoError(err) + + sess, err := New(sqldb) + s.NoError(err) + defer sess.Close() + + s.Equal("main", sess.Name()) +} + +func TestAdapter(t *testing.T) { + suite.Run(t, &AdapterTests{}) +} diff --git a/adapter/sqlite/template.go b/adapter/sqlite/template.go new file mode 100644 index 0000000..e046d82 --- /dev/null +++ b/adapter/sqlite/template.go @@ -0,0 +1,187 @@ +package sqlite + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +const ( + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `"{{.Value}}"` + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` + {{if .SortColumns}} + ORDER BY {{.SortColumns}} + {{end}} + ` + + adapterWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + + adapterUsingLayout = ` + {{if .Columns}} + USING ({{.Columns}}) + {{end}} + ` + + adapterJoinLayout = ` + {{if .Table}} + {{ if .On }} + {{.Type}} JOIN {{.Table}} + {{.On}} + {{ else if .Using }} + {{.Type}} JOIN {{.Table}} + {{.Using}} + {{ else if .Type | eq "CROSS" }} + {{.Type}} JOIN {{.Table}} + {{else}} + NATURAL {{.Type}} JOIN {{.Table}} + {{end}} + {{end}} + ` + + adapterOnLayout = ` + {{if .Conds}} + ON {{.Conds}} + {{end}} + ` + + adapterSelectLayout = ` + SELECT + {{if .Distinct}} + DISTINCT + {{end}} + + {{if defined .Columns}} + {{.Columns | compile}} + {{else}} + * + {{end}} + + {{if defined .Table}} + FROM {{.Table | compile}} + {{end}} + + {{.Joins | compile}} + + {{.Where | compile}} + + {{if defined .GroupBy}} + {{.GroupBy | compile}} + {{end}} + + {{.OrderBy | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + {{if not .Limit}} + LIMIT -1 + {{end}} + OFFSET {{.Offset}} + {{end}} + ` + adapterDeleteLayout = ` + DELETE + FROM {{.Table | compile}} + {{.Where | compile}} + ` + adapterUpdateLayout = ` + UPDATE + {{.Table | compile}} + SET {{.ColumnValues | compile}} + {{.Where | compile}} + ` + + adapterSelectCountLayout = ` + SELECT + COUNT(1) AS _t + FROM {{.Table | compile}} + {{.Where | compile}} + ` + + adapterInsertLayout = ` + INSERT INTO {{.Table | compile}} + {{if .Columns }}({{.Columns | compile}}){{end}} + {{if defined .Values}} + VALUES + {{.Values | compile}} + {{else}} + DEFAULT VALUES + {{end}} + {{if defined .Returning}} + RETURNING {{.Returning | compile}} + {{end}} + ` + + adapterTruncateLayout = ` + DELETE FROM {{.Table | compile}} + ` + + adapterDropDatabaseLayout = ` + DROP DATABASE {{.Database | compile}} + ` + + adapterDropTableLayout = ` + DROP TABLE {{.Table | compile}} + ` + + adapterGroupByLayout = ` + {{if .GroupColumns}} + GROUP BY {{.GroupColumns}} + {{end}} + ` +) + +var template = &exql.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + JoinLayout: adapterJoinLayout, + OnLayout: adapterOnLayout, + UsingLayout: adapterUsingLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), +} diff --git a/adapter/sqlite/template_test.go b/adapter/sqlite/template_test.go new file mode 100644 index 0000000..f3a4793 --- /dev/null +++ b/adapter/sqlite/template_test.go @@ -0,0 +1,246 @@ +package sqlite + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" + "github.com/stretchr/testify/assert" +) + +func TestTemplateSelect(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `SELECT * FROM "artist"`, + b.SelectFrom("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist"`, + b.Select().From("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("name DESC").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("-name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name ASC").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT -1 OFFSET 5`, + b.Select().From("artist").Limit(-1).Offset(5).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist"`, + b.Select("id").From("artist").String(), + ) + + assert.Equal( + `SELECT "id", "name" FROM "artist"`, + b.Select("id", "name").From("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("name" = $1)`, + b.SelectFrom("artist").Where("name", "Haruki").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name LIKE $1)`, + b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" WHERE (name LIKE $1 OR name LIKE $2)`, + b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" > $1)`, + b.SelectFrom("artist").Where("id >", 2).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (id <= 2 AND name != $1)`, + b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN ($1, $2, $3, $4))`, + b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name IS NOT NULL)`, + b.SelectFrom("artist").Where("name IS NOT NULL").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a", "publication" AS "p" WHERE (p.author_id = a.id) LIMIT 1`, + b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" NATURAL JOIN "publication"`, + b.Select("id").From("artist").Join("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE ("a"."id" = $1) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1`, + b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" LEFT JOIN "publication" AS "p1" ON (p1.id = a.id) RIGHT JOIN "publication" AS "p2" ON (p2.id = a.id)`, + b.SelectFrom("artist a"). + LeftJoin("publication p1").On("p1.id = a.id"). + RightJoin("publication p2").On("p2.id = a.id"). + String(), + ) + + assert.Equal( + `SELECT * FROM "artist" CROSS JOIN "publication"`, + b.SelectFrom("artist").CrossJoin("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" USING ("id")`, + b.SelectFrom("artist").Join("publication").Using("id").String(), + ) + + assert.Equal( + `SELECT DATE()`, + b.Select(mydb.Raw("DATE()")).String(), + ) +} + +func TestTemplateInsert(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `INSERT INTO "artist" VALUES ($1, $2), ($3, $4), ($5, $6)`, + b.InsertInto("artist"). + Values(10, "Ryuichi Sakamoto"). + Values(11, "Alondra de la Parra"). + Values(12, "Haruki Murakami"). + String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(struct { + ID int `db:"id"` + Name string `db:"name"` + }{12, "Chavela Vargas"}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, + b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(), + ) +} + +func TestTemplateUpdate(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `UPDATE "artist" SET "name" = $1`, + b.Update("artist").Set("name", "Artist").String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 || ' ' || $2 || id, "id" = id + $3 WHERE (id > $4)`, + b.Update("artist").Set( + "name = ? || ' ' || ? || id", "Artist", "#", + "id = id + ?", 10, + ).Where("id > ?", 0).String(), + ) +} + +func TestTemplateDelete(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `DELETE FROM "artist" WHERE (name = $1)`, + b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").String(), + ) + + assert.Equal( + `DELETE FROM "artist" WHERE (id > 5)`, + b.DeleteFrom("artist").Where("id > 5").String(), + ) +} diff --git a/clauses.go b/clauses.go new file mode 100644 index 0000000..7e8d69d --- /dev/null +++ b/clauses.go @@ -0,0 +1,468 @@ +package mydb + +import ( + "context" + "fmt" +) + +// Selector represents a SELECT statement. +type Selector interface { + // Columns defines which columns to retrive. + // + // You should call From() after Columns() if you want to query data from an + // specific table. + // + // s.Columns("name", "last_name").From(...) + // + // It is also possible to use an alias for the column, this could be handy if + // you plan to use the alias later, use the "AS" keyword to denote an alias. + // + // s.Columns("name AS n") + // + // or the shortcut: + // + // s.Columns("name n") + // + // If you don't want the column to be escaped use the db.Raw + // function. + // + // s.Columns(db.Raw("MAX(id)")) + // + // The above statement is equivalent to: + // + // s.Columns(db.Func("MAX", "id")) + Columns(columns ...interface{}) Selector + + // From represents a FROM clause and is tipically used after Columns(). + // + // FROM defines from which table data is going to be retrieved + // + // s.Columns(...).From("people") + // + // It is also possible to use an alias for the table, this could be handy if + // you plan to use the alias later: + // + // s.Columns(...).From("people AS p").Where("p.name = ?", ...) + // + // Or with the shortcut: + // + // s.Columns(...).From("people p").Where("p.name = ?", ...) + From(tables ...interface{}) Selector + + // Distict represents a DISTINCT clause + // + // DISTINCT is used to ask the database to return only values that are + // different. + Distinct(columns ...interface{}) Selector + + // As defines an alias for a table. + As(string) Selector + + // Where specifies the conditions that columns must match in order to be + // retrieved. + // + // Where accepts raw strings and fmt.Stringer to define conditions and + // interface{} to specify parameters. Be careful not to embed any parameters + // within the SQL part as that could lead to security problems. You can use + // que question mark (?) as placeholder for parameters. + // + // s.Where("name = ?", "max") + // + // s.Where("name = ? AND last_name = ?", "Mary", "Doe") + // + // s.Where("last_name IS NULL") + // + // You can also use other types of parameters besides only strings, like: + // + // s.Where("online = ? AND last_logged <= ?", true, time.Now()) + // + // and Where() will transform them into strings before feeding them to the + // database. + // + // When an unknown type is provided, Where() will first try to match it with + // the Marshaler interface, then with fmt.Stringer and finally, if the + // argument does not satisfy any of those interfaces Where() will use + // fmt.Sprintf("%v", arg) to transform the type into a string. + // + // Subsequent calls to Where() will overwrite previously set conditions, if + // you want these new conditions to be appended use And() instead. + Where(conds ...interface{}) Selector + + // And appends more constraints to the WHERE clause without overwriting + // conditions that have been already set. + And(conds ...interface{}) Selector + + // GroupBy represents a GROUP BY statement. + // + // GROUP BY defines which columns should be used to aggregate and group + // results. + // + // s.GroupBy("country_id") + // + // GroupBy accepts more than one column: + // + // s.GroupBy("country_id", "city_id") + GroupBy(columns ...interface{}) Selector + + // Having(...interface{}) Selector + + // OrderBy represents a ORDER BY statement. + // + // ORDER BY is used to define which columns are going to be used to sort + // results. + // + // Use the column name to sort results in ascendent order. + // + // // "last_name" ASC + // s.OrderBy("last_name") + // + // Prefix the column name with the minus sign (-) to sort results in + // descendent order. + // + // // "last_name" DESC + // s.OrderBy("-last_name") + // + // If you would rather be very explicit, you can also use ASC and DESC. + // + // s.OrderBy("last_name ASC") + // + // s.OrderBy("last_name DESC", "name ASC") + OrderBy(columns ...interface{}) Selector + + // Join represents a JOIN statement. + // + // JOIN statements are used to define external tables that the user wants to + // include as part of the result. + // + // You can use the On() method after Join() to define the conditions of the + // join. + // + // s.Join("author").On("author.id = book.author_id") + // + // If you don't specify conditions for the join, a NATURAL JOIN will be used. + // + // On() accepts the same arguments as Where() + // + // You can also use Using() after Join(). + // + // s.Join("employee").Using("department_id") + Join(table ...interface{}) Selector + + // FullJoin is like Join() but with FULL JOIN. + FullJoin(...interface{}) Selector + + // CrossJoin is like Join() but with CROSS JOIN. + CrossJoin(...interface{}) Selector + + // RightJoin is like Join() but with RIGHT JOIN. + RightJoin(...interface{}) Selector + + // LeftJoin is like Join() but with LEFT JOIN. + LeftJoin(...interface{}) Selector + + // Using represents the USING clause. + // + // USING is used to specifiy columns to join results. + // + // s.LeftJoin(...).Using("country_id") + Using(...interface{}) Selector + + // On represents the ON clause. + // + // ON is used to define conditions on a join. + // + // s.Join(...).On("b.author_id = a.id") + On(...interface{}) Selector + + // Limit represents the LIMIT parameter. + // + // LIMIT defines the maximum number of rows to return from the table. A + // negative limit cancels any previous limit settings. + // + // s.Limit(42) + Limit(int) Selector + + // Offset represents the OFFSET parameter. + // + // OFFSET defines how many results are going to be skipped before starting to + // return results. A negative offset cancels any previous offset settings. + // + // s.Offset(56) + Offset(int) Selector + + // Amend lets you alter the query's text just before sending it to the + // database server. + Amend(func(queryIn string) (queryOut string)) Selector + + // Paginate returns a paginator that can display a paginated lists of items. + // Paginators ignore previous Offset and Limit settings. Page numbering + // starts at 1. + Paginate(uint) Paginator + + // Iterator provides methods to iterate over the results returned by the + // Selector. + Iterator() Iterator + + // IteratorContext provides methods to iterate over the results returned by + // the Selector. + IteratorContext(ctx context.Context) Iterator + + // SQLPreparer provides methods for creating prepared statements. + SQLPreparer + + // SQLGetter provides methods to compile and execute a query that returns + // results. + SQLGetter + + // ResultMapper provides methods to retrieve and map results. + ResultMapper + + // fmt.Stringer provides `String() string`, you can use `String()` to compile + // the `Selector` into a string. + fmt.Stringer + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} +} + +// Inserter represents an INSERT statement. +type Inserter interface { + // Columns represents the COLUMNS clause. + // + // COLUMNS defines the columns that we are going to provide values for. + // + // i.Columns("name", "last_name").Values(...) + Columns(...string) Inserter + + // Values represents the VALUES clause. + // + // VALUES defines the values of the columns. + // + // i.Columns(...).Values("María", "Méndez") + // + // i.Values(map[string][string]{"name": "María"}) + Values(...interface{}) Inserter + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} + + // Returning represents a RETURNING clause. + // + // RETURNING specifies which columns should be returned after INSERT. + // + // RETURNING may not be supported by all SQL databases. + Returning(columns ...string) Inserter + + // Iterator provides methods to iterate over the results returned by the + // Inserter. This is only possible when using Returning(). + Iterator() Iterator + + // IteratorContext provides methods to iterate over the results returned by + // the Inserter. This is only possible when using Returning(). + IteratorContext(ctx context.Context) Iterator + + // Amend lets you alter the query's text just before sending it to the + // database server. + Amend(func(queryIn string) (queryOut string)) Inserter + + // Batch provies a BatchInserter that can be used to insert many elements at + // once by issuing several calls to Values(). It accepts a size parameter + // which defines the batch size. If size is < 1, the batch size is set to 1. + Batch(size int) BatchInserter + + // SQLExecer provides the Exec method. + SQLExecer + + // SQLPreparer provides methods for creating prepared statements. + SQLPreparer + + // SQLGetter provides methods to return query results from INSERT statements + // that support such feature (e.g.: queries with Returning). + SQLGetter + + // fmt.Stringer provides `String() string`, you can use `String()` to compile + // the `Inserter` into a string. + fmt.Stringer +} + +// Deleter represents a DELETE statement. +type Deleter interface { + // Where represents the WHERE clause. + // + // See Selector.Where for documentation and usage examples. + Where(...interface{}) Deleter + + // And appends more constraints to the WHERE clause without overwriting + // conditions that have been already set. + And(conds ...interface{}) Deleter + + // Limit represents the LIMIT clause. + // + // See Selector.Limit for documentation and usage examples. + Limit(int) Deleter + + // Amend lets you alter the query's text just before sending it to the + // database server. + Amend(func(queryIn string) (queryOut string)) Deleter + + // SQLPreparer provides methods for creating prepared statements. + SQLPreparer + + // SQLExecer provides the Exec method. + SQLExecer + + // fmt.Stringer provides `String() string`, you can use `String()` to compile + // the `Inserter` into a string. + fmt.Stringer + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} +} + +// Updater represents an UPDATE statement. +type Updater interface { + // Set represents the SET clause. + Set(...interface{}) Updater + + // Where represents the WHERE clause. + // + // See Selector.Where for documentation and usage examples. + Where(...interface{}) Updater + + // And appends more constraints to the WHERE clause without overwriting + // conditions that have been already set. + And(conds ...interface{}) Updater + + // Limit represents the LIMIT parameter. + // + // See Selector.Limit for documentation and usage examples. + Limit(int) Updater + + // SQLPreparer provides methods for creating prepared statements. + SQLPreparer + + // SQLExecer provides the Exec method. + SQLExecer + + // fmt.Stringer provides `String() string`, you can use `String()` to compile + // the `Inserter` into a string. + fmt.Stringer + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} + + // Amend lets you alter the query's text just before sending it to the + // database server. + Amend(func(queryIn string) (queryOut string)) Updater +} + +// Paginator provides tools for splitting the results of a query into chunks +// containing a fixed number of items. +type Paginator interface { + // Page sets the page number. + Page(uint) Paginator + + // Cursor defines the column that is going to be taken as basis for + // cursor-based pagination. + // + // Example: + // + // a = q.Paginate(10).Cursor("id") + // b = q.Paginate(12).Cursor("-id") + // + // You can set "" as cursorColumn to disable cursors. + Cursor(cursorColumn string) Paginator + + // NextPage returns the next page according to the cursor. It expects a + // cursorValue, which is the value the cursor column has on the last item of + // the current result set (lower bound). + // + // Example: + // + // p = q.NextPage(items[len(items)-1].ID) + NextPage(cursorValue interface{}) Paginator + + // PrevPage returns the previous page according to the cursor. It expects a + // cursorValue, which is the value the cursor column has on the fist item of + // the current result set (mydb bound). + // + // Example: + // + // p = q.PrevPage(items[0].ID) + PrevPage(cursorValue interface{}) Paginator + + // TotalPages returns the total number of pages in the query. + TotalPages() (uint, error) + + // TotalEntries returns the total number of entries in the query. + TotalEntries() (uint64, error) + + // SQLPreparer provides methods for creating prepared statements. + SQLPreparer + + // SQLGetter provides methods to compile and execute a query that returns + // results. + SQLGetter + + // Iterator provides methods to iterate over the results returned by the + // Selector. + Iterator() Iterator + + // IteratorContext provides methods to iterate over the results returned by + // the Selector. + IteratorContext(ctx context.Context) Iterator + + // ResultMapper provides methods to retrieve and map results. + ResultMapper + + // fmt.Stringer provides `String() string`, you can use `String()` to compile + // the `Selector` into a string. + fmt.Stringer + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} +} + +// ResultMapper defined methods for a result mapper. +type ResultMapper interface { + // All dumps all the results into the given slice, All() expects a pointer to + // slice of maps or structs. + // + // The behaviour of One() extends to each one of the results. + All(destSlice interface{}) error + + // One maps the row that is in the current query cursor into the + // given interface, which can be a pointer to either a map or a + // struct. + // + // If dest is a pointer to map, each one of the columns will create a new map + // key and the values of the result will be set as values for the keys. + // + // Depending on the type of map key and value, the results columns and values + // may need to be transformed. + // + // If dest if a pointer to struct, each one of the fields will be tested for + // a `db` tag which defines the column mapping. The value of the result will + // be set as the value of the field. + One(dest interface{}) error +} + +// BatchInserter provides an interface to do massive insertions in batches. +type BatchInserter interface { + // Values pushes column values to be inserted as part of the batch. + Values(...interface{}) BatchInserter + + // NextResult dumps the next slice of results to dst, which can mean having + // the IDs of all inserted elements in the batch. + NextResult(dst interface{}) bool + + // Done signals that no more elements are going to be added. + Done() + + // Wait blocks until the whole batch is executed. + Wait() error + + // Err returns the last error that happened while executing the batch (or nil + // if no error happened). + Err() error +} diff --git a/collection.go b/collection.go new file mode 100644 index 0000000..6ac09ec --- /dev/null +++ b/collection.go @@ -0,0 +1,45 @@ +package mydb + +// Collection defines methods to work with database tables or collections. +type Collection interface { + + // Name returns the name of the collection. + Name() string + + // Session returns the Session that was used to create the collection + // reference. + Session() Session + + // Find defines a new result set. + Find(...interface{}) Result + + Count() (uint64, error) + + // Insert inserts a new item into the collection, the type of this item could + // be a map, a struct or pointer to either of them. If the call succeeds and + // if the collection has a primary key, Insert returns the ID of the newly + // added element as an `interface{}`. The underlying type of this ID depends + // on both the database adapter and the column storing the ID. The ID + // returned by Insert() could be passed directly to Find() to retrieve the + // newly added element. + Insert(interface{}) (InsertResult, error) + + // InsertReturning is like Insert() but it takes a pointer to map or struct + // and, if the operation succeeds, updates it with data from the newly + // inserted row. If the database does not support transactions this method + // returns db.ErrUnsupported. + InsertReturning(interface{}) error + + // UpdateReturning takes a pointer to a map or struct and tries to update the + // row the item is refering to. If the element is updated sucessfully, + // UpdateReturning will fetch the row and update the fields of the passed + // item. If the database does not support transactions this method returns + // db.ErrUnsupported + UpdateReturning(interface{}) error + + // Exists returns true if the collection exists, false otherwise. + Exists() (bool, error) + + // Truncate removes all elements on the collection. + Truncate() error +} diff --git a/comparison.go b/comparison.go new file mode 100644 index 0000000..d79113a --- /dev/null +++ b/comparison.go @@ -0,0 +1,158 @@ +package mydb + +import ( + "reflect" + "time" + + "git.hexq.cn/tiglog/mydb/internal/adapter" +) + +// Comparison represents a relationship between values. +type Comparison struct { + *adapter.Comparison +} + +// Gte is a comparison that means: is greater than or equal to value. +func Gte(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, value)} +} + +// Lte is a comparison that means: is less than or equal to value. +func Lte(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, value)} +} + +// Eq is a comparison that means: is equal to value. +func Eq(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorEqual, value)} +} + +// NotEq is a comparison that means: is not equal to value. +func NotEq(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotEqual, value)} +} + +// Gt is a comparison that means: is greater than value. +func Gt(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, value)} +} + +// Lt is a comparison that means: is less than value. +func Lt(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, value)} +} + +// In is a comparison that means: is any of the values. +func In(value ...interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, toInterfaceArray(value))} +} + +// AnyOf is a comparison that means: is any of the values of the slice. +func AnyOf(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, toInterfaceArray(value))} +} + +// NotIn is a comparison that means: is none of the values. +func NotIn(value ...interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotIn, toInterfaceArray(value))} +} + +// NotAnyOf is a comparison that means: is none of the values of the slice. +func NotAnyOf(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotIn, toInterfaceArray(value))} +} + +// After is a comparison that means: is after the (time.Time) value. +func After(value time.Time) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, value)} +} + +// Before is a comparison that means: is before the (time.Time) value. +func Before(value time.Time) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, value)} +} + +// OnOrAfter is a comparison that means: is on or after the (time.Time) value. +func OnOrAfter(value time.Time) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, value)} +} + +// OnOrBefore is a comparison that means: is on or before the (time.Time) value. +func OnOrBefore(value time.Time) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, value)} +} + +// Between is a comparison that means: is between lowerBound and upperBound. +func Between(lowerBound interface{}, upperBound interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorBetween, []interface{}{lowerBound, upperBound})} +} + +// NotBetween is a comparison that means: is not between lowerBound and upperBound. +func NotBetween(lowerBound interface{}, upperBound interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotBetween, []interface{}{lowerBound, upperBound})} +} + +// Is is a comparison that means: is equivalent to nil, true or false. +func Is(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, value)} +} + +// IsNot is a comparison that means: is not equivalent to nil, true nor false. +func IsNot(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, value)} +} + +// IsNull is a comparison that means: is equivalent to nil. +func IsNull() *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, nil)} +} + +// IsNotNull is a comparison that means: is not equivalent to nil. +func IsNotNull() *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, nil)} +} + +// Like is a comparison that checks whether the reference matches the wildcard +// value. +func Like(value string) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLike, value)} +} + +// NotLike is a comparison that checks whether the reference does not match the +// wildcard value. +func NotLike(value string) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotLike, value)} +} + +// RegExp is a comparison that checks whether the reference matches the regular +// expression. +func RegExp(value string) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorRegExp, value)} +} + +// NotRegExp is a comparison that checks whether the reference does not match +// the regular expression. +func NotRegExp(value string) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotRegExp, value)} +} + +// Op returns a custom comparison operator. +func Op(customOperator string, value interface{}) *Comparison { + return &Comparison{adapter.NewCustomComparisonOperator(customOperator, value)} +} + +func toInterfaceArray(value interface{}) []interface{} { + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Ptr: + return toInterfaceArray(rv.Elem().Interface()) + case reflect.Slice: + elems := rv.Len() + args := make([]interface{}, elems) + for i := 0; i < elems; i++ { + args[i] = rv.Index(i).Interface() + } + return args + } + return []interface{}{value} +} diff --git a/comparison_test.go b/comparison_test.go new file mode 100644 index 0000000..03ea789 --- /dev/null +++ b/comparison_test.go @@ -0,0 +1,111 @@ +package mydb + +import ( + "testing" + "time" + + "git.hexq.cn/tiglog/mydb/internal/adapter" + "github.com/stretchr/testify/assert" +) + +func TestComparison(t *testing.T) { + testTimeVal := time.Now() + + testCases := []struct { + expects *adapter.Comparison + result *Comparison + }{ + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, 1), + Gte(1), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, 22), + Lte(22), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorEqual, 6), + Eq(6), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorNotEqual, 67), + NotEq(67), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, 4), + Gt(4), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, 47), + Lt(47), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, []interface{}{1, 22, 34}), + In(1, 22, 34), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, testTimeVal), + After(testTimeVal), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, testTimeVal), + Before(testTimeVal), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, testTimeVal), + OnOrAfter(testTimeVal), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, testTimeVal), + OnOrBefore(testTimeVal), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorBetween, []interface{}{11, 35}), + Between(11, 35), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorNotBetween, []interface{}{11, 35}), + NotBetween(11, 35), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, 178), + Is(178), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, 32), + IsNot(32), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, nil), + IsNull(), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, nil), + IsNotNull(), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorLike, "%a%"), + Like("%a%"), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorNotLike, "%z%"), + NotLike("%z%"), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorRegExp, ".*"), + RegExp(".*"), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorNotRegExp, ".*"), + NotRegExp(".*"), + }, + { + adapter.NewCustomComparisonOperator("~", 56), + Op("~", 56), + }, + } + + for i := range testCases { + assert.Equal(t, testCases[i].expects, testCases[i].result.Comparison) + } +} diff --git a/cond.go b/cond.go new file mode 100644 index 0000000..385c7e6 --- /dev/null +++ b/cond.go @@ -0,0 +1,109 @@ +package mydb + +import ( + "fmt" + "sort" + + "git.hexq.cn/tiglog/mydb/internal/adapter" +) + +// LogicalExpr represents an expression to be used in logical statements. +type LogicalExpr = adapter.LogicalExpr + +// LogicalOperator represents a logical operation. +type LogicalOperator = adapter.LogicalOperator + +// Cond is a map that defines conditions for a query. +// +// Each entry of the map represents a condition (a column-value relation bound +// by a comparison Operator). The comparison can be specified after the column +// name, if no comparison operator is provided the equality operator is used as +// default. +// +// Examples: +// +// // Age equals 18. +// db.Cond{"age": 18} +// +// // Age is greater than or equal to 18. +// db.Cond{"age >=": 18} +// +// // id is any of the values 1, 2 or 3. +// db.Cond{"id IN": []{1, 2, 3}} +// +// // Age is lower than 18 (MongoDB syntax) +// db.Cond{"age $lt": 18} +// +// // age > 32 and age < 35 +// db.Cond{"age >": 32, "age <": 35} +type Cond map[interface{}]interface{} + +// Empty returns false if there are no conditions. +func (c Cond) Empty() bool { + for range c { + return false + } + return true +} + +// Constraints returns each one of the Cond map entires as a constraint. +func (c Cond) Constraints() []adapter.Constraint { + z := make([]adapter.Constraint, 0, len(c)) + for _, k := range c.keys() { + z = append(z, adapter.NewConstraint(k, c[k])) + } + return z +} + +// Operator returns the equality operator. +func (c Cond) Operator() LogicalOperator { + return adapter.DefaultLogicalOperator +} + +func (c Cond) keys() []interface{} { + keys := make(condKeys, 0, len(c)) + for k := range c { + keys = append(keys, k) + } + if len(c) > 1 { + sort.Sort(keys) + } + return keys +} + +// Expressions returns all the expressions contained in the condition. +func (c Cond) Expressions() []LogicalExpr { + z := make([]LogicalExpr, 0, len(c)) + for _, k := range c.keys() { + z = append(z, Cond{k: c[k]}) + } + return z +} + +type condKeys []interface{} + +func (ck condKeys) Len() int { + return len(ck) +} + +func (ck condKeys) Less(i, j int) bool { + return fmt.Sprintf("%v", ck[i]) < fmt.Sprintf("%v", ck[j]) +} + +func (ck condKeys) Swap(i, j int) { + ck[i], ck[j] = ck[j], ck[i] +} + +func defaultJoin(in ...adapter.LogicalExpr) []adapter.LogicalExpr { + for i := range in { + cond, ok := in[i].(Cond) + if ok && !cond.Empty() { + in[i] = And(cond) + } + } + return in +} + +var ( + _ = LogicalExpr(Cond{}) +) diff --git a/cond_test.go b/cond_test.go new file mode 100644 index 0000000..9e4232a --- /dev/null +++ b/cond_test.go @@ -0,0 +1,69 @@ +package mydb + +import ( + "testing" +) + +func TestCond(t *testing.T) { + c := Cond{} + + if !c.Empty() { + t.Fatal("Cond is empty.") + } + + c = Cond{"id": 1} + if c.Empty() { + t.Fatal("Cond is not empty.") + } +} + +func TestCondAnd(t *testing.T) { + a := And() + + if !a.Empty() { + t.Fatal("Cond is empty") + } + + _ = a.And(Cond{"id": 1}) + + if !a.Empty() { + t.Fatal("Cond is still empty") + } + + a = a.And(Cond{"name": "Ana"}) + + if a.Empty() { + t.Fatal("Cond is not empty anymore") + } + + a = a.And().And() + + if a.Empty() { + t.Fatal("Cond is not empty anymore") + } +} + +func TestCondOr(t *testing.T) { + a := Or() + + if !a.Empty() { + t.Fatal("Cond is empty") + } + + _ = a.Or(Cond{"id": 1}) + + if !a.Empty() { + t.Fatal("Cond is empty") + } + + a = a.Or(Cond{"name": "Ana"}) + + if a.Empty() { + t.Fatal("Cond is not empty") + } + + a = a.Or().Or() + if a.Empty() { + t.Fatal("Cond is not empty") + } +} diff --git a/connection_url.go b/connection_url.go new file mode 100644 index 0000000..20ff9ce --- /dev/null +++ b/connection_url.go @@ -0,0 +1,8 @@ +package mydb + +// ConnectionURL represents a data source name (DSN). +type ConnectionURL interface { + // String returns the connection string that is going to be passed to the + // adapter. + String() string +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..92abf8f --- /dev/null +++ b/errors.go @@ -0,0 +1,42 @@ +package mydb + +import ( + "errors" +) + +// Error messages +var ( + ErrMissingAdapter = errors.New(`mydb: missing adapter`) + ErrAlreadyWithinTransaction = errors.New(`mydb: already within a transaction`) + ErrCollectionDoesNotExist = errors.New(`mydb: collection does not exist`) + ErrExpectingNonNilModel = errors.New(`mydb: expecting non nil model`) + ErrExpectingPointerToStruct = errors.New(`mydb: expecting pointer to struct`) + ErrGivingUpTryingToConnect = errors.New(`mydb: giving up trying to connect: too many clients`) + ErrInvalidCollection = errors.New(`mydb: invalid collection`) + ErrMissingCollectionName = errors.New(`mydb: missing collection name`) + ErrMissingConditions = errors.New(`mydb: missing selector conditions`) + ErrMissingConnURL = errors.New(`mydb: missing DSN`) + ErrMissingDatabaseName = errors.New(`mydb: missing database name`) + ErrNoMoreRows = errors.New(`mydb: no more rows in this result set`) + ErrNotConnected = errors.New(`mydb: not connected to a database`) + ErrNotImplemented = errors.New(`mydb: call not implemented`) + ErrQueryIsPending = errors.New(`mydb: can't execute this instruction while the result set is still open`) + ErrQueryLimitParam = errors.New(`mydb: a query can accept only one limit parameter`) + ErrQueryOffsetParam = errors.New(`mydb: a query can accept only one offset parameter`) + ErrQuerySortParam = errors.New(`mydb: a query can accept only one order-by parameter`) + ErrSockerOrHost = errors.New(`mydb: you may connect either to a UNIX socket or a TCP address, but not both`) + ErrTooManyClients = errors.New(`mydb: can't connect to database server: too many clients`) + ErrUndefined = errors.New(`mydb: value is undefined`) + ErrUnknownConditionType = errors.New(`mydb: arguments of type %T can't be used as constraints`) + ErrUnsupported = errors.New(`mydb: action is not supported by the DBMS`) + ErrUnsupportedDestination = errors.New(`mydb: unsupported destination type`) + ErrUnsupportedType = errors.New(`mydb: type does not support marshaling`) + ErrUnsupportedValue = errors.New(`mydb: value does not support unmarshaling`) + ErrNilRecord = errors.New(`mydb: invalid item (nil)`) + ErrRecordIDIsZero = errors.New(`mydb: item ID is not defined`) + ErrMissingPrimaryKeys = errors.New(`mydb: collection %q has no primary keys`) + ErrWarnSlowQuery = errors.New(`mydb: slow query`) + ErrTransactionAborted = errors.New(`mydb: transaction was aborted`) + ErrNotWithinTransaction = errors.New(`mydb: not within transaction`) + ErrNotSupportedByAdapter = errors.New(`mydb: not supported by adapter`) +) diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..d26e677 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,14 @@ +package mydb + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorWrap(t *testing.T) { + adapterFakeErr := fmt.Errorf("could not find item in %q: %w", "users", ErrCollectionDoesNotExist) + assert.True(t, errors.Is(adapterFakeErr, ErrCollectionDoesNotExist)) +} diff --git a/function.go b/function.go new file mode 100644 index 0000000..6f9daf5 --- /dev/null +++ b/function.go @@ -0,0 +1,25 @@ +package mydb + +import "git.hexq.cn/tiglog/mydb/internal/adapter" + +// FuncExpr represents functions. +type FuncExpr = adapter.FuncExpr + +// Func returns a database function expression. +// +// Examples: +// +// // MOD(29, 9) +// db.Func("MOD", 29, 9) +// +// // CONCAT("foo", "bar") +// db.Func("CONCAT", "foo", "bar") +// +// // NOW() +// db.Func("NOW") +// +// // RTRIM("Hello ") +// db.Func("RTRIM", "Hello ") +func Func(name string, args ...interface{}) *FuncExpr { + return adapter.NewFuncExpr(name, args) +} diff --git a/function_test.go b/function_test.go new file mode 100644 index 0000000..b1f2ef9 --- /dev/null +++ b/function_test.go @@ -0,0 +1,51 @@ +package mydb + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFunction(t *testing.T) { + { + fn := Func("MOD", 29, 9) + assert.Equal(t, "MOD", fn.Name()) + assert.Equal(t, []interface{}{29, 9}, fn.Arguments()) + } + + { + fn := Func("HELLO") + assert.Equal(t, "HELLO", fn.Name()) + assert.Equal(t, []interface{}(nil), fn.Arguments()) + } + + { + fn := Func("CONCAT", "a") + assert.Equal(t, "CONCAT", fn.Name()) + assert.Equal(t, []interface{}{"a"}, fn.Arguments()) + } + + { + fn := Func("CONCAT", "a", "b", "c") + assert.Equal(t, "CONCAT", fn.Name()) + assert.Equal(t, []interface{}{"a", "b", "c"}, fn.Arguments()) + } + + { + fn := Func("IN", []interface{}{"a", "b", "c"}) + assert.Equal(t, "IN", fn.Name()) + assert.Equal(t, []interface{}{[]interface{}{"a", "b", "c"}}, fn.Arguments()) + } + + { + fn := Func("IN", []interface{}{"a"}) + assert.Equal(t, "IN", fn.Name()) + assert.Equal(t, []interface{}{[]interface{}{"a"}}, fn.Arguments()) + } + + { + fn := Func("IN", []interface{}(nil)) + assert.Equal(t, "IN", fn.Name()) + assert.Equal(t, []interface{}{[]interface{}(nil)}, fn.Arguments()) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..58d9f56 --- /dev/null +++ b/go.mod @@ -0,0 +1,33 @@ +module git.hexq.cn/tiglog/mydb + +go 1.20 + +require ( + github.com/go-sql-driver/mysql v1.7.1 + github.com/google/uuid v1.1.1 + github.com/ipfs/go-detect-race v0.0.1 + github.com/jackc/pgtype v1.14.0 + github.com/jackc/pgx/v4 v4.18.1 + github.com/lib/pq v1.10.9 + github.com/mattn/go-sqlite3 v1.14.17 + github.com/segmentio/fasthash v1.0.3 + github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.8.4 + gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgconn v1.14.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgproto3/v2 v2.3.2 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/crypto v0.12.0 // indirect + golang.org/x/sys v0.12.0 // indirect + golang.org/x/text v0.13.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5e38d2e --- /dev/null +++ b/go.sum @@ -0,0 +1,226 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/ipfs/go-detect-race v0.0.1 h1:qX/xay2W3E4Q1U7d9lNs1sU9nvguX0a7319XbyQ6cOk= +github.com/ipfs/go-detect-race v0.0.1/go.mod h1:8BNT7shDZPo99Q74BpGMK+4D8Mn4j46UU0LZ723meps= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.14.0/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E= +github.com/jackc/pgconn v1.14.1 h1:smbxIaZA08n6YuxEX1sDyjV/qkbtUtkH20qLkR9MUR4= +github.com/jackc/pgconn v1.14.1/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0= +github.com/jackc/pgproto3/v2 v2.3.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw= +github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/pgx/v4 v4.18.1 h1:YP7G1KABtKpB5IHrO9vYwSrCOhs7p3uqhvhhQBptya0= +github.com/jackc/pgx/v4 v4.18.1/go.mod h1:FydWkUyadDmdNH/mHnGob881GawxeEm7TcMCzkb+qQE= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM= +github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 h1:VpOs+IwYnYBaFnrNAeB8UUWtL3vEUnzSCL1nVjPhqrw= +gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/internal/adapter/comparison.go b/internal/adapter/comparison.go new file mode 100644 index 0000000..1f63a20 --- /dev/null +++ b/internal/adapter/comparison.go @@ -0,0 +1,60 @@ +package adapter + +// ComparisonOperator is the base type for comparison operators. +type ComparisonOperator uint8 + +// Comparison operators +const ( + ComparisonOperatorNone ComparisonOperator = iota + ComparisonOperatorCustom + + ComparisonOperatorEqual + ComparisonOperatorNotEqual + + ComparisonOperatorLessThan + ComparisonOperatorGreaterThan + + ComparisonOperatorLessThanOrEqualTo + ComparisonOperatorGreaterThanOrEqualTo + + ComparisonOperatorBetween + ComparisonOperatorNotBetween + + ComparisonOperatorIn + ComparisonOperatorNotIn + + ComparisonOperatorIs + ComparisonOperatorIsNot + + ComparisonOperatorLike + ComparisonOperatorNotLike + + ComparisonOperatorRegExp + ComparisonOperatorNotRegExp +) + +type Comparison struct { + t ComparisonOperator + op string + v interface{} +} + +func (c *Comparison) CustomOperator() string { + return c.op +} + +func (c *Comparison) Operator() ComparisonOperator { + return c.t +} + +func (c *Comparison) Value() interface{} { + return c.v +} + +func NewComparisonOperator(t ComparisonOperator, v interface{}) *Comparison { + return &Comparison{t: t, v: v} +} + +func NewCustomComparisonOperator(op string, v interface{}) *Comparison { + return &Comparison{t: ComparisonOperatorCustom, op: op, v: v} +} diff --git a/internal/adapter/constraint.go b/internal/adapter/constraint.go new file mode 100644 index 0000000..84a63de --- /dev/null +++ b/internal/adapter/constraint.go @@ -0,0 +1,51 @@ +package adapter + +// ConstraintValuer allows constraints to use specific values of their own. +type ConstraintValuer interface { + ConstraintValue() interface{} +} + +// Constraint interface represents a single condition, like "a = 1". where `a` +// is the key and `1` is the value. This is an exported interface but it's +// rarely used directly, you may want to use the `db.Cond{}` map instead. +type Constraint interface { + // Key is the leftmost part of the constraint and usually contains a column + // name. + Key() interface{} + + // Value if the rightmost part of the constraint and usually contains a + // column value. + Value() interface{} +} + +// Constraints interface represents an array of constraints, like "a = 1, b = +// 2, c = 3". +type Constraints interface { + // Constraints returns an array of constraints. + Constraints() []Constraint +} + +type constraint struct { + k interface{} + v interface{} +} + +func (c constraint) Key() interface{} { + return c.k +} + +func (c constraint) Value() interface{} { + if constraintValuer, ok := c.v.(ConstraintValuer); ok { + return constraintValuer.ConstraintValue() + } + return c.v +} + +// NewConstraint creates a constraint. +func NewConstraint(key interface{}, value interface{}) Constraint { + return &constraint{k: key, v: value} +} + +var ( + _ = Constraint(&constraint{}) +) diff --git a/internal/adapter/func.go b/internal/adapter/func.go new file mode 100644 index 0000000..e6353ba --- /dev/null +++ b/internal/adapter/func.go @@ -0,0 +1,18 @@ +package adapter + +type FuncExpr struct { + name string + args []interface{} +} + +func (f *FuncExpr) Arguments() []interface{} { + return f.args +} + +func (f *FuncExpr) Name() string { + return f.name +} + +func NewFuncExpr(name string, args []interface{}) *FuncExpr { + return &FuncExpr{name: name, args: args} +} diff --git a/internal/adapter/logical_expr.go b/internal/adapter/logical_expr.go new file mode 100644 index 0000000..fce009a --- /dev/null +++ b/internal/adapter/logical_expr.go @@ -0,0 +1,100 @@ +package adapter + +import "git.hexq.cn/tiglog/mydb/internal/immutable" + +// LogicalExpr represents a group formed by one or more sentences joined by +// an Operator like "AND" or "OR". +type LogicalExpr interface { + // Expressions returns child sentences. + Expressions() []LogicalExpr + + // Operator returns the Operator that joins all the sentences in the group. + Operator() LogicalOperator + + // Empty returns true if the compound has zero children, false otherwise. + Empty() bool +} + +// LogicalOperator represents the operation on a compound statement. +type LogicalOperator uint + +// LogicalExpr Operators. +const ( + LogicalOperatorNone LogicalOperator = iota + LogicalOperatorAnd + LogicalOperatorOr +) + +const DefaultLogicalOperator = LogicalOperatorAnd + +type LogicalExprGroup struct { + op LogicalOperator + + prev *LogicalExprGroup + fn func(*[]LogicalExpr) error +} + +func NewLogicalExprGroup(op LogicalOperator, conds ...LogicalExpr) *LogicalExprGroup { + group := &LogicalExprGroup{op: op} + if len(conds) == 0 { + return group + } + return group.Frame(func(in *[]LogicalExpr) error { + *in = append(*in, conds...) + return nil + }) +} + +// Expressions returns each one of the conditions as a compound. +func (g *LogicalExprGroup) Expressions() []LogicalExpr { + conds, err := immutable.FastForward(g) + if err == nil { + return *(conds.(*[]LogicalExpr)) + } + return nil +} + +// Operator is undefined for a logical group. +func (g *LogicalExprGroup) Operator() LogicalOperator { + if g.op == LogicalOperatorNone { + panic("operator is not defined") + } + return g.op +} + +// Empty returns true if this condition has no elements. False otherwise. +func (g *LogicalExprGroup) Empty() bool { + if g.fn != nil { + return false + } + if g.prev != nil { + return g.prev.Empty() + } + return true +} + +func (g *LogicalExprGroup) Frame(fn func(*[]LogicalExpr) error) *LogicalExprGroup { + return &LogicalExprGroup{prev: g, op: g.op, fn: fn} +} + +func (g *LogicalExprGroup) Prev() immutable.Immutable { + if g == nil { + return nil + } + return g.prev +} + +func (g *LogicalExprGroup) Fn(in interface{}) error { + if g.fn == nil { + return nil + } + return g.fn(in.(*[]LogicalExpr)) +} + +func (g *LogicalExprGroup) Base() interface{} { + return &[]LogicalExpr{} +} + +var ( + _ = immutable.Immutable(&LogicalExprGroup{}) +) diff --git a/internal/adapter/raw.go b/internal/adapter/raw.go new file mode 100644 index 0000000..fc688d1 --- /dev/null +++ b/internal/adapter/raw.go @@ -0,0 +1,49 @@ +package adapter + +// RawExpr interface represents values that can bypass SQL filters. This is an +// exported interface but it's rarely used directly, you may want to use the +// `db.Raw()` function instead. +type RawExpr struct { + value string + args *[]interface{} +} + +func (r *RawExpr) Arguments() []interface{} { + if r.args != nil { + return *r.args + } + return nil +} + +func (r RawExpr) Raw() string { + return r.value +} + +func (r RawExpr) String() string { + return r.Raw() +} + +// Expressions returns a logical expressio.n +func (r *RawExpr) Expressions() []LogicalExpr { + return []LogicalExpr{r} +} + +// Operator returns the default compound operator. +func (r RawExpr) Operator() LogicalOperator { + return LogicalOperatorNone +} + +// Empty return false if this struct has no value. +func (r *RawExpr) Empty() bool { + return r.value == "" +} + +func NewRawExpr(value string, args []interface{}) *RawExpr { + r := &RawExpr{value: value, args: nil} + if len(args) > 0 { + r.args = &args + } + return r +} + +var _ = LogicalExpr(&RawExpr{}) diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..a72a621 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,113 @@ +package cache + +import ( + "container/list" + "errors" + "sync" +) + +const defaultCapacity = 128 + +// Cache holds a map of volatile key -> values. +type Cache struct { + keys *list.List + items map[uint64]*list.Element + mu sync.RWMutex + capacity int +} + +type cacheItem struct { + key uint64 + value interface{} +} + +// NewCacheWithCapacity initializes a new caching space with the given +// capacity. +func NewCacheWithCapacity(capacity int) (*Cache, error) { + if capacity < 1 { + return nil, errors.New("Capacity must be greater than zero.") + } + c := &Cache{ + capacity: capacity, + } + c.init() + return c, nil +} + +// NewCache initializes a new caching space with default settings. +func NewCache() *Cache { + c, err := NewCacheWithCapacity(defaultCapacity) + if err != nil { + panic(err.Error()) // Should never happen as we're not providing a negative defaultCapacity. + } + return c +} + +func (c *Cache) init() { + c.items = make(map[uint64]*list.Element) + c.keys = list.New() +} + +// Read attempts to retrieve a cached value as a string, if the value does not +// exists returns an empty string and false. +func (c *Cache) Read(h Hashable) (string, bool) { + if v, ok := c.ReadRaw(h); ok { + if s, ok := v.(string); ok { + return s, true + } + } + return "", false +} + +// ReadRaw attempts to retrieve a cached value as an interface{}, if the value +// does not exists returns nil and false. +func (c *Cache) ReadRaw(h Hashable) (interface{}, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + item, ok := c.items[h.Hash()] + if ok { + return item.Value.(*cacheItem).value, true + } + + return nil, false +} + +// Write stores a value in memory. If the value already exists its overwritten. +func (c *Cache) Write(h Hashable, value interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + key := h.Hash() + + if item, ok := c.items[key]; ok { + item.Value.(*cacheItem).value = value + c.keys.MoveToFront(item) + return + } + + c.items[key] = c.keys.PushFront(&cacheItem{key, value}) + + for c.keys.Len() > c.capacity { + item := c.keys.Remove(c.keys.Back()).(*cacheItem) + delete(c.items, item.key) + if p, ok := item.value.(HasOnEvict); ok { + p.OnEvict() + } + } +} + +// Clear generates a new memory space, leaving the old memory unreferenced, so +// it can be claimed by the garbage collector. +func (c *Cache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + for _, item := range c.items { + if p, ok := item.Value.(*cacheItem).value.(HasOnEvict); ok { + p.OnEvict() + } + } + + c.init() +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..11aeaef --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,97 @@ +package cache + +import ( + "fmt" + "hash/fnv" + "testing" +) + +var c *Cache + +type cacheableT struct { + Name string +} + +func (ct *cacheableT) Hash() uint64 { + s := fnv.New64() + s.Sum([]byte(ct.Name)) + return s.Sum64() +} + +var ( + key = cacheableT{"foo"} + value = "bar" +) + +func TestNewCache(t *testing.T) { + c = NewCache() + if c == nil { + t.Fatal("Expecting a new cache object.") + } +} + +func TestCacheReadNonExistentValue(t *testing.T) { + if _, ok := c.Read(&key); ok { + t.Fatal("Expecting false.") + } +} + +func TestCacheWritingValue(t *testing.T) { + c.Write(&key, value) + c.Write(&key, value) +} + +func TestCacheReadExistentValue(t *testing.T) { + s, ok := c.Read(&key) + + if !ok { + t.Fatal("Expecting true.") + } + + if s != value { + t.Fatal("Expecting value.") + } +} + +func BenchmarkNewCache(b *testing.B) { + for i := 0; i < b.N; i++ { + NewCache() + } +} + +func BenchmarkNewCacheAndClear(b *testing.B) { + for i := 0; i < b.N; i++ { + c := NewCache() + c.Clear() + } +} + +func BenchmarkReadNonExistentValue(b *testing.B) { + z := NewCache() + for i := 0; i < b.N; i++ { + z.Read(&key) + } +} + +func BenchmarkWriteSameValue(b *testing.B) { + z := NewCache() + for i := 0; i < b.N; i++ { + z.Write(&key, value) + } +} + +func BenchmarkWriteNewValue(b *testing.B) { + z := NewCache() + for i := 0; i < b.N; i++ { + key := cacheableT{fmt.Sprintf("item-%d", i)} + z.Write(&key, value) + } +} + +func BenchmarkReadExistentValue(b *testing.B) { + z := NewCache() + z.Write(&key, value) + for i := 0; i < b.N; i++ { + z.Read(&key) + } +} diff --git a/internal/cache/hash.go b/internal/cache/hash.go new file mode 100644 index 0000000..4b866a9 --- /dev/null +++ b/internal/cache/hash.go @@ -0,0 +1,109 @@ +package cache + +import ( + "fmt" + + "github.com/segmentio/fasthash/fnv1a" +) + +const ( + hashTypeInt uint64 = 1 << iota + hashTypeSignedInt + hashTypeBool + hashTypeString + hashTypeHashable + hashTypeNil +) + +type hasher struct { + t uint64 + v interface{} +} + +func (h *hasher) Hash() uint64 { + return NewHash(h.t, h.v) +} + +func NewHashable(t uint64, v interface{}) Hashable { + return &hasher{t: t, v: v} +} + +func InitHash(t uint64) uint64 { + return fnv1a.AddUint64(fnv1a.Init64, t) +} + +func NewHash(t uint64, in ...interface{}) uint64 { + return AddToHash(InitHash(t), in...) +} + +func AddToHash(h uint64, in ...interface{}) uint64 { + for i := range in { + if in[i] == nil { + continue + } + h = addToHash(h, in[i]) + } + return h +} + +func addToHash(h uint64, in interface{}) uint64 { + switch v := in.(type) { + case uint64: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), v) + case uint32: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case uint16: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case uint8: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case uint: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case int64: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int32: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int16: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int8: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case bool: + if v { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 1) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 2) + } + case string: + return fnv1a.AddString64(fnv1a.AddUint64(h, hashTypeString), v) + case Hashable: + if in == nil { + panic(fmt.Sprintf("could not hash nil element %T", in)) + } + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeHashable), v.Hash()) + case nil: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeNil), 0) + default: + panic(fmt.Sprintf("unsupported value type %T", in)) + } +} diff --git a/internal/cache/interface.go b/internal/cache/interface.go new file mode 100644 index 0000000..aee9ac7 --- /dev/null +++ b/internal/cache/interface.go @@ -0,0 +1,13 @@ +package cache + +// Hashable types must implement a method that returns a key. This key will be +// associated with a cached value. +type Hashable interface { + Hash() uint64 +} + +// HasOnEvict type is (optionally) implemented by cache objects to clean after +// themselves. +type HasOnEvict interface { + OnEvict() +} diff --git a/internal/immutable/immutable.go b/internal/immutable/immutable.go new file mode 100644 index 0000000..57d29ce --- /dev/null +++ b/internal/immutable/immutable.go @@ -0,0 +1,28 @@ +package immutable + +// Immutable represents an immutable chain that, if passed to FastForward, +// applies Fn() to every element of a chain, the first element of this chain is +// represented by Base(). +type Immutable interface { + // Prev is the previous element on a chain. + Prev() Immutable + // Fn a function that is able to modify the passed element. + Fn(interface{}) error + // Base is the first element on a chain, there's no previous element before + // the Base element. + Base() interface{} +} + +// FastForward applies all Fn methods in order on the given new Base. +func FastForward(curr Immutable) (interface{}, error) { + prev := curr.Prev() + if prev == nil { + return curr.Base(), nil + } + in, err := FastForward(prev) + if err != nil { + return nil, err + } + err = curr.Fn(in) + return in, err +} diff --git a/internal/reflectx/LICENSE b/internal/reflectx/LICENSE new file mode 100644 index 0000000..0d31edf --- /dev/null +++ b/internal/reflectx/LICENSE @@ -0,0 +1,23 @@ + Copyright (c) 2013, Jason Moiron + + Permission is hereby granted, free of charge, to any person + obtaining a copy of this software and associated documentation + files (the "Software"), to deal in the Software without + restriction, including without limitation the rights to use, + copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following + conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + diff --git a/internal/reflectx/README.md b/internal/reflectx/README.md new file mode 100644 index 0000000..76f1b5d --- /dev/null +++ b/internal/reflectx/README.md @@ -0,0 +1,17 @@ +# reflectx + +The sqlx package has special reflect needs. In particular, it needs to: + +* be able to map a name to a field +* understand embedded structs +* understand mapping names to fields by a particular tag +* user specified name -> field mapping functions + +These behaviors mimic the behaviors by the standard library marshallers and also the +behavior of standard Go accessors. + +The first two are amply taken care of by `Reflect.Value.FieldByName`, and the third is +addressed by `Reflect.Value.FieldByNameFunc`, but these don't quite understand struct +tags in the ways that are vital to most marshalers, and they are slow. + +This reflectx package extends reflect to achieve these goals. diff --git a/internal/reflectx/reflect.go b/internal/reflectx/reflect.go new file mode 100644 index 0000000..02df3da --- /dev/null +++ b/internal/reflectx/reflect.go @@ -0,0 +1,404 @@ +// Package reflectx implements extensions to the standard reflect lib suitable +// for implementing marshaling and unmarshaling packages. The main Mapper type +// allows for Go-compatible named attribute access, including accessing embedded +// struct attributes and the ability to use functions and struct tags to +// customize field names. +package reflectx + +import ( + "fmt" + "reflect" + "runtime" + "strings" + "sync" +) + +// A FieldInfo is a collection of metadata about a struct field. +type FieldInfo struct { + Index []int + Path string + Field reflect.StructField + Zero reflect.Value + Name string + Options map[string]string + Embedded bool + Children []*FieldInfo + Parent *FieldInfo +} + +// A StructMap is an index of field metadata for a struct. +type StructMap struct { + Tree *FieldInfo + Index []*FieldInfo + Paths map[string]*FieldInfo + Names map[string]*FieldInfo +} + +// GetByPath returns a *FieldInfo for a given string path. +func (f StructMap) GetByPath(path string) *FieldInfo { + return f.Paths[path] +} + +// GetByTraversal returns a *FieldInfo for a given integer path. It is +// analogous to reflect.FieldByIndex. +func (f StructMap) GetByTraversal(index []int) *FieldInfo { + if len(index) == 0 { + return nil + } + + tree := f.Tree + for _, i := range index { + if i >= len(tree.Children) || tree.Children[i] == nil { + return nil + } + tree = tree.Children[i] + } + return tree +} + +// Mapper is a general purpose mapper of names to struct fields. A Mapper +// behaves like most marshallers, optionally obeying a field tag for name +// mapping and a function to provide a basic mapping of fields to names. +type Mapper struct { + cache map[reflect.Type]*StructMap + tagName string + tagMapFunc func(string) string + mapFunc func(string) string + mutex sync.Mutex +} + +// NewMapper returns a new mapper which optionally obeys the field tag given +// by tagName. If tagName is the empty string, it is ignored. +func NewMapper(tagName string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + } +} + +// NewMapperTagFunc returns a new mapper which contains a mapper for field names +// AND a mapper for tag values. This is useful for tags like json which can +// have values like "name,omitempty". +func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + mapFunc: mapFunc, + tagMapFunc: tagMapFunc, + } +} + +// NewMapperFunc returns a new mapper which optionally obeys a field tag and +// a struct field name mapper func given by f. Tags will take precedence, but +// for any other field, the mapped name will be f(field.Name) +func NewMapperFunc(tagName string, f func(string) string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + mapFunc: f, + } +} + +// TypeMap returns a mapping of field strings to int slices representing +// the traversal down the struct to reach the field. +func (m *Mapper) TypeMap(t reflect.Type) *StructMap { + m.mutex.Lock() + mapping, ok := m.cache[t] + if !ok { + mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc) + m.cache[t] = mapping + } + m.mutex.Unlock() + return mapping +} + +// FieldMap returns the mapper's mapping of field names to reflect values. Panics +// if v's Kind is not Struct, or v is not Indirectable to a struct kind. +func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + r := map[string]reflect.Value{} + tm := m.TypeMap(v.Type()) + for tagName, fi := range tm.Names { + r[tagName] = FieldByIndexes(v, fi.Index) + } + return r +} + +// ValidFieldMap returns the mapper's mapping of field names to reflect valid +// field values. Panics if v's Kind is not Struct, or v is not Indirectable to +// a struct kind. +func (m *Mapper) ValidFieldMap(v reflect.Value) map[string]reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + r := map[string]reflect.Value{} + tm := m.TypeMap(v.Type()) + for tagName, fi := range tm.Names { + v := ValidFieldByIndexes(v, fi.Index) + if v.IsValid() { + r[tagName] = v + } + } + return r +} + +// FieldByName returns a field by the its mapped name as a reflect.Value. +// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. +// Returns zero Value if the name is not found. +func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + tm := m.TypeMap(v.Type()) + fi, ok := tm.Names[name] + if !ok { + return v + } + return FieldByIndexes(v, fi.Index) +} + +// FieldsByName returns a slice of values corresponding to the slice of names +// for the value. Panics if v's Kind is not Struct or v is not Indirectable +// to a struct Kind. Returns zero Value for each name not found. +func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + tm := m.TypeMap(v.Type()) + vals := make([]reflect.Value, 0, len(names)) + for _, name := range names { + fi, ok := tm.Names[name] + if !ok { + vals = append(vals, *new(reflect.Value)) + } else { + vals = append(vals, FieldByIndexes(v, fi.Index)) + } + } + return vals +} + +// TraversalsByName returns a slice of int slices which represent the struct +// traversals for each mapped name. Panics if t is not a struct or Indirectable +// to a struct. Returns empty int slice for each name not found. +func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { + t = Deref(t) + mustBe(t, reflect.Struct) + tm := m.TypeMap(t) + + r := make([][]int, 0, len(names)) + for _, name := range names { + fi, ok := tm.Names[name] + if !ok { + r = append(r, []int{}) + } else { + r = append(r, fi.Index) + } + } + return r +} + +// FieldByIndexes returns a value for a particular struct traversal. +func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { + for _, i := range indexes { + v = reflect.Indirect(v).Field(i) + // if this is a pointer, it's possible it is nil + if v.Kind() == reflect.Ptr && v.IsNil() { + alloc := reflect.New(Deref(v.Type())) + v.Set(alloc) + } + if v.Kind() == reflect.Map && v.IsNil() { + v.Set(reflect.MakeMap(v.Type())) + } + } + return v +} + +// ValidFieldByIndexes returns a value for a particular struct traversal. +func ValidFieldByIndexes(v reflect.Value, indexes []int) reflect.Value { + + for _, i := range indexes { + v = reflect.Indirect(v) + if !v.IsValid() { + return reflect.Value{} + } + v = v.Field(i) + // if this is a pointer, it's possible it is nil + if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Map) && v.IsNil() { + return reflect.Value{} + } + } + + return v +} + +// FieldByIndexesReadOnly returns a value for a particular struct traversal, +// but is not concerned with allocating nil pointers because the value is +// going to be used for reading and not setting. +func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value { + for _, i := range indexes { + v = reflect.Indirect(v).Field(i) + } + return v +} + +// Deref is Indirect for reflect.Types +func Deref(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +// -- helpers & utilities -- + +type kinder interface { + Kind() reflect.Kind +} + +// mustBe checks a value against a kind, panicing with a reflect.ValueError +// if the kind isn't that which is required. +func mustBe(v kinder, expected reflect.Kind) { + k := v.Kind() + if k != expected { + panic(&reflect.ValueError{Method: methodName(), Kind: k}) + } +} + +// methodName is returns the caller of the function calling methodName +func methodName() string { + pc, _, _, _ := runtime.Caller(2) + f := runtime.FuncForPC(pc) + if f == nil { + return "unknown method" + } + return f.Name() +} + +type typeQueue struct { + t reflect.Type + fi *FieldInfo + pp string // Parent path +} + +// A copying append that creates a new slice each time. +func apnd(is []int, i int) []int { + x := make([]int, len(is)+1) + copy(x, is) + x[len(x)-1] = i + return x +} + +// getMapping returns a mapping for the t type, using the tagName, mapFunc and +// tagMapFunc to determine the canonical names of fields. +func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) string) *StructMap { + m := []*FieldInfo{} + + root := &FieldInfo{} + queue := []typeQueue{} + queue = append(queue, typeQueue{Deref(t), root, ""}) + + for len(queue) != 0 { + // pop the first item off of the queue + tq := queue[0] + queue = queue[1:] + nChildren := 0 + if tq.t.Kind() == reflect.Struct { + nChildren = tq.t.NumField() + } + tq.fi.Children = make([]*FieldInfo, nChildren) + + // iterate through all of its fields + for fieldPos := 0; fieldPos < nChildren; fieldPos++ { + f := tq.t.Field(fieldPos) + + fi := FieldInfo{} + fi.Field = f + fi.Zero = reflect.New(f.Type).Elem() + fi.Options = map[string]string{} + + var tag, name string + if tagName != "" && strings.Contains(string(f.Tag), tagName+":") { + tag = f.Tag.Get(tagName) + name = tag + } else { + if mapFunc != nil { + name = mapFunc(f.Name) + } + } + + parts := strings.Split(name, ",") + if len(parts) > 1 { + name = parts[0] + for _, opt := range parts[1:] { + kv := strings.Split(opt, "=") + if len(kv) > 1 { + fi.Options[kv[0]] = kv[1] + } else { + fi.Options[kv[0]] = "" + } + } + } + + if tagMapFunc != nil { + tag = tagMapFunc(tag) + } + + fi.Name = name + + if tq.pp == "" || (tq.pp == "" && tag == "") { + fi.Path = fi.Name + } else { + fi.Path = fmt.Sprintf("%s.%s", tq.pp, fi.Name) + } + + // if the name is "-", disabled via a tag, skip it + if name == "-" { + continue + } + + // skip unexported fields + if len(f.PkgPath) != 0 && !f.Anonymous { + continue + } + + // bfs search of anonymous embedded structs + if f.Anonymous { + pp := tq.pp + if tag != "" { + pp = fi.Path + } + + fi.Embedded = true + fi.Index = apnd(tq.fi.Index, fieldPos) + nChildren := 0 + ft := Deref(f.Type) + if ft.Kind() == reflect.Struct { + nChildren = ft.NumField() + } + fi.Children = make([]*FieldInfo, nChildren) + queue = append(queue, typeQueue{Deref(f.Type), &fi, pp}) + } else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) { + fi.Index = apnd(tq.fi.Index, fieldPos) + fi.Children = make([]*FieldInfo, Deref(f.Type).NumField()) + queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path}) + } + + fi.Index = apnd(tq.fi.Index, fieldPos) + fi.Parent = tq.fi + tq.fi.Children[fieldPos] = &fi + m = append(m, &fi) + } + } + + flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}} + for _, fi := range flds.Index { + flds.Paths[fi.Path] = fi + if fi.Name != "" && !fi.Embedded { + flds.Names[fi.Path] = fi + } + } + + return flds +} diff --git a/internal/reflectx/reflect_test.go b/internal/reflectx/reflect_test.go new file mode 100644 index 0000000..8072244 --- /dev/null +++ b/internal/reflectx/reflect_test.go @@ -0,0 +1,587 @@ +package reflectx + +import ( + "reflect" + "strings" + "testing" +) + +func ival(v reflect.Value) int { + return v.Interface().(int) +} + +func TestBasic(t *testing.T) { + type Foo struct { + A int + B int + C int + } + + f := Foo{1, 2, 3} + fv := reflect.ValueOf(f) + m := NewMapperFunc("", func(s string) string { return s }) + + v := m.FieldByName(fv, "A") + if ival(v) != f.A { + t.Errorf("Expecting %d, got %d", ival(v), f.A) + } + v = m.FieldByName(fv, "B") + if ival(v) != f.B { + t.Errorf("Expecting %d, got %d", f.B, ival(v)) + } + v = m.FieldByName(fv, "C") + if ival(v) != f.C { + t.Errorf("Expecting %d, got %d", f.C, ival(v)) + } +} + +func TestBasicEmbedded(t *testing.T) { + type Foo struct { + A int + } + + type Bar struct { + Foo // `db:""` is implied for an embedded struct + B int + C int `db:"-"` + } + + type Baz struct { + A int + Bar `db:"Bar"` + } + + m := NewMapperFunc("db", func(s string) string { return s }) + + z := Baz{} + z.A = 1 + z.B = 2 + z.C = 4 + z.Bar.Foo.A = 3 + + zv := reflect.ValueOf(z) + fields := m.TypeMap(reflect.TypeOf(z)) + + if len(fields.Index) != 5 { + t.Errorf("Expecting 5 fields") + } + + // for _, fi := range fields.Index { + // log.Println(fi) + // } + + v := m.FieldByName(zv, "A") + if ival(v) != z.A { + t.Errorf("Expecting %d, got %d", z.A, ival(v)) + } + v = m.FieldByName(zv, "Bar.B") + if ival(v) != z.Bar.B { + t.Errorf("Expecting %d, got %d", z.Bar.B, ival(v)) + } + v = m.FieldByName(zv, "Bar.A") + if ival(v) != z.Bar.Foo.A { + t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) + } + v = m.FieldByName(zv, "Bar.C") + if _, ok := v.Interface().(int); ok { + t.Errorf("Expecting Bar.C to not exist") + } + + fi := fields.GetByPath("Bar.C") + if fi != nil { + t.Errorf("Bar.C should not exist") + } +} + +func TestEmbeddedSimple(t *testing.T) { + type UUID [16]byte + type MyID struct { + UUID + } + type Item struct { + ID MyID + } + z := Item{} + + m := NewMapper("db") + m.TypeMap(reflect.TypeOf(z)) +} + +func TestBasicEmbeddedWithTags(t *testing.T) { + type Foo struct { + A int `db:"a"` + } + + type Bar struct { + Foo // `db:""` is implied for an embedded struct + B int `db:"b"` + } + + type Baz struct { + A int `db:"a"` + Bar // `db:""` is implied for an embedded struct + } + + m := NewMapper("db") + + z := Baz{} + z.A = 1 + z.B = 2 + z.Bar.Foo.A = 3 + + zv := reflect.ValueOf(z) + fields := m.TypeMap(reflect.TypeOf(z)) + + if len(fields.Index) != 5 { + t.Errorf("Expecting 5 fields") + } + + // for _, fi := range fields.index { + // log.Println(fi) + // } + + v := m.FieldByName(zv, "a") + if ival(v) != z.Bar.Foo.A { // the dominant field + t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) + } + v = m.FieldByName(zv, "b") + if ival(v) != z.B { + t.Errorf("Expecting %d, got %d", z.B, ival(v)) + } +} + +func TestFlatTags(t *testing.T) { + m := NewMapper("db") + + type Asset struct { + Title string `db:"title"` + } + type Post struct { + Author string `db:"author,required"` + Asset Asset `db:""` + } + // Post columns: (author title) + + post := Post{Author: "Joe", Asset: Asset{Title: "Hello"}} + pv := reflect.ValueOf(post) + + v := m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } + v = m.FieldByName(pv, "title") + if v.Interface().(string) != post.Asset.Title { + t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) + } +} + +func TestNestedStruct(t *testing.T) { + m := NewMapper("db") + + type Details struct { + Active bool `db:"active"` + } + type Asset struct { + Title string `db:"title"` + Details Details `db:"details"` + } + type Post struct { + Author string `db:"author,required"` + Asset `db:"asset"` + } + // Post columns: (author asset.title asset.details.active) + + post := Post{ + Author: "Joe", + Asset: Asset{Title: "Hello", Details: Details{Active: true}}, + } + pv := reflect.ValueOf(post) + + v := m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } + v = m.FieldByName(pv, "title") + if _, ok := v.Interface().(string); ok { + t.Errorf("Expecting field to not exist") + } + v = m.FieldByName(pv, "asset.title") + if v.Interface().(string) != post.Asset.Title { + t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset.details.active") + if v.Interface().(bool) != post.Asset.Details.Active { + t.Errorf("Expecting %v, got %v", post.Asset.Details.Active, v.Interface().(bool)) + } +} + +func TestInlineStruct(t *testing.T) { + m := NewMapperTagFunc("db", strings.ToLower, nil) + + type Employee struct { + Name string + ID int + } + type Boss Employee + type person struct { + Employee `db:"employee"` + Boss `db:"boss"` + } + // employees columns: (employee.name employee.id boss.name boss.id) + + em := person{Employee: Employee{Name: "Joe", ID: 2}, Boss: Boss{Name: "Dick", ID: 1}} + ev := reflect.ValueOf(em) + + fields := m.TypeMap(reflect.TypeOf(em)) + if len(fields.Index) != 6 { + t.Errorf("Expecting 6 fields") + } + + v := m.FieldByName(ev, "employee.name") + if v.Interface().(string) != em.Employee.Name { + t.Errorf("Expecting %s, got %s", em.Employee.Name, v.Interface().(string)) + } + v = m.FieldByName(ev, "boss.id") + if ival(v) != em.Boss.ID { + t.Errorf("Expecting %v, got %v", em.Boss.ID, ival(v)) + } +} + +func TestFieldsEmbedded(t *testing.T) { + m := NewMapper("db") + + type Person struct { + Name string `db:"name"` + } + type Place struct { + Name string `db:"name"` + } + type Article struct { + Title string `db:"title"` + } + type PP struct { + Person `db:"person,required"` + Place `db:",someflag"` + Article `db:",required"` + } + // PP columns: (person.name name title) + + pp := PP{} + pp.Person.Name = "Peter" + pp.Place.Name = "Toronto" + pp.Article.Title = "Best city ever" + + fields := m.TypeMap(reflect.TypeOf(pp)) + // for i, f := range fields { + // log.Println(i, f) + // } + + ppv := reflect.ValueOf(pp) + + v := m.FieldByName(ppv, "person.name") + if v.Interface().(string) != pp.Person.Name { + t.Errorf("Expecting %s, got %s", pp.Person.Name, v.Interface().(string)) + } + + v = m.FieldByName(ppv, "name") + if v.Interface().(string) != pp.Place.Name { + t.Errorf("Expecting %s, got %s", pp.Place.Name, v.Interface().(string)) + } + + v = m.FieldByName(ppv, "title") + if v.Interface().(string) != pp.Article.Title { + t.Errorf("Expecting %s, got %s", pp.Article.Title, v.Interface().(string)) + } + + fi := fields.GetByPath("person") + if _, ok := fi.Options["required"]; !ok { + t.Errorf("Expecting required option to be set") + } + if !fi.Embedded { + t.Errorf("Expecting field to be embedded") + } + if len(fi.Index) != 1 || fi.Index[0] != 0 { + t.Errorf("Expecting index to be [0]") + } + + fi = fields.GetByPath("person.name") + if fi == nil { + t.Errorf("Expecting person.name to exist") + } + if fi.Path != "person.name" { + t.Errorf("Expecting %s, got %s", "person.name", fi.Path) + } + + fi = fields.GetByTraversal([]int{1, 0}) + if fi == nil { + t.Errorf("Expecting traveral to exist") + } + if fi.Path != "name" { + t.Errorf("Expecting %s, got %s", "name", fi.Path) + } + + fi = fields.GetByTraversal([]int{2}) + if fi == nil { + t.Errorf("Expecting traversal to exist") + } + if _, ok := fi.Options["required"]; !ok { + t.Errorf("Expecting required option to be set") + } + + trs := m.TraversalsByName(reflect.TypeOf(pp), []string{"person.name", "name", "title"}) + if !reflect.DeepEqual(trs, [][]int{{0, 0}, {1, 0}, {2, 0}}) { + t.Errorf("Expecting traversal: %v", trs) + } +} + +func TestPtrFields(t *testing.T) { + m := NewMapperTagFunc("db", strings.ToLower, nil) + type Asset struct { + Title string + } + type Post struct { + *Asset `db:"asset"` + Author string + } + + post := &Post{Author: "Joe", Asset: &Asset{Title: "Hiyo"}} + pv := reflect.ValueOf(post) + + fields := m.TypeMap(reflect.TypeOf(post)) + if len(fields.Index) != 3 { + t.Errorf("Expecting 3 fields") + } + + v := m.FieldByName(pv, "asset.title") + if v.Interface().(string) != post.Asset.Title { + t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } +} + +func TestNamedPtrFields(t *testing.T) { + m := NewMapperTagFunc("db", strings.ToLower, nil) + + type User struct { + Name string + } + + type Asset struct { + Title string + + Owner *User `db:"owner"` + } + type Post struct { + Author string + + Asset1 *Asset `db:"asset1"` + Asset2 *Asset `db:"asset2"` + } + + post := &Post{Author: "Joe", Asset1: &Asset{Title: "Hiyo", Owner: &User{"Username"}}} // Let Asset2 be nil + pv := reflect.ValueOf(post) + + fields := m.TypeMap(reflect.TypeOf(post)) + if len(fields.Index) != 9 { + t.Errorf("Expecting 9 fields") + } + + v := m.FieldByName(pv, "asset1.title") + if v.Interface().(string) != post.Asset1.Title { + t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset1.owner.name") + if v.Interface().(string) != post.Asset1.Owner.Name { + t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset2.title") + if v.Interface().(string) != post.Asset2.Title { + t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset2.owner.name") + if v.Interface().(string) != post.Asset2.Owner.Name { + t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string)) + } + v = m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } +} + +func TestFieldMap(t *testing.T) { + type Foo struct { + A int + B int + C int + } + + f := Foo{1, 2, 3} + m := NewMapperFunc("db", strings.ToLower) + + fm := m.FieldMap(reflect.ValueOf(f)) + + if len(fm) != 3 { + t.Errorf("Expecting %d keys, got %d", 3, len(fm)) + } + if fm["a"].Interface().(int) != 1 { + t.Errorf("Expecting %d, got %d", 1, ival(fm["a"])) + } + if fm["b"].Interface().(int) != 2 { + t.Errorf("Expecting %d, got %d", 2, ival(fm["b"])) + } + if fm["c"].Interface().(int) != 3 { + t.Errorf("Expecting %d, got %d", 3, ival(fm["c"])) + } +} + +func TestTagNameMapping(t *testing.T) { + type Strategy struct { + StrategyID string `protobuf:"bytes,1,opt,name=strategy_id" json:"strategy_id,omitempty"` + StrategyName string + } + + m := NewMapperTagFunc("json", strings.ToUpper, func(value string) string { + if strings.Contains(value, ",") { + return strings.Split(value, ",")[0] + } + return value + }) + strategy := Strategy{"1", "Alpah"} + mapping := m.TypeMap(reflect.TypeOf(strategy)) + + for _, key := range []string{"strategy_id", "STRATEGYNAME"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } +} + +func TestMapping(t *testing.T) { + type Person struct { + ID int + Name string + WearsGlasses bool `db:"wears_glasses"` + } + + m := NewMapperFunc("db", strings.ToLower) + p := Person{1, "Jason", true} + mapping := m.TypeMap(reflect.TypeOf(p)) + + for _, key := range []string{"id", "name", "wears_glasses"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } + + type SportsPerson struct { + Weight int + Age int + Person + } + s := SportsPerson{Weight: 100, Age: 30, Person: p} + mapping = m.TypeMap(reflect.TypeOf(s)) + for _, key := range []string{"id", "name", "wears_glasses", "weight", "age"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } + + type RugbyPlayer struct { + Position int + IsIntense bool `db:"is_intense"` + IsAllBlack bool `db:"-"` + SportsPerson + } + r := RugbyPlayer{12, true, false, s} + mapping = m.TypeMap(reflect.TypeOf(r)) + for _, key := range []string{"id", "name", "wears_glasses", "weight", "age", "position", "is_intense"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } + + if fi := mapping.GetByPath("isallblack"); fi != nil { + t.Errorf("Expecting to ignore `IsAllBlack` field") + } +} + +type E1 struct { + A int +} +type E2 struct { + E1 + B int +} +type E3 struct { + E2 + C int +} +type E4 struct { + E3 + D int +} + +func BenchmarkFieldNameL1(b *testing.B) { + e4 := E4{D: 1} + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.FieldByName("D") + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldNameL4(b *testing.B) { + e4 := E4{} + e4.A = 1 + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.FieldByName("A") + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldPosL1(b *testing.B) { + e4 := E4{D: 1} + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.Field(1) + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldPosL4(b *testing.B) { + e4 := E4{} + e4.A = 1 + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.Field(0) + f = f.Field(0) + f = f.Field(0) + f = f.Field(0) + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldByIndexL4(b *testing.B) { + e4 := E4{} + e4.A = 1 + idx := []int{0, 0, 0, 0} + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := FieldByIndexes(v, idx) + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go new file mode 100644 index 0000000..01c492a --- /dev/null +++ b/internal/sqladapter/collection.go @@ -0,0 +1,369 @@ +package sqladapter + +import ( + "fmt" + "reflect" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// CollectionAdapter defines methods to be implemented by SQL adapters. +type CollectionAdapter interface { + // Insert prepares and executes an INSERT statament. When the item is + // succefully added, Insert returns a unique identifier of the newly added + // element (or nil if the unique identifier couldn't be determined). + Insert(Collection, interface{}) (interface{}, error) +} + +// Collection satisfies mydb.Collection. +type Collection interface { + // Insert inserts a new item into the collection. + Insert(interface{}) (mydb.InsertResult, error) + + // Name returns the name of the collection. + Name() string + + // Session returns the mydb.Session the collection belongs to. + Session() mydb.Session + + // Exists returns true if the collection exists, false otherwise. + Exists() (bool, error) + + // Find defined a new result set. + Find(conds ...interface{}) mydb.Result + + Count() (uint64, error) + + // Truncate removes all elements on the collection and resets the + // collection's IDs. + Truncate() error + + // InsertReturning inserts a new item into the collection and refreshes the + // item with actual data from the database. This is useful to get automatic + // values, such as timestamps, or IDs. + InsertReturning(item interface{}) error + + // UpdateReturning updates a record from the collection and refreshes the item + // with actual data from the database. This is useful to get automatic + // values, such as timestamps, or IDs. + UpdateReturning(item interface{}) error + + // PrimaryKeys returns the names of all primary keys in the table. + PrimaryKeys() ([]string, error) + + // SQLBuilder returns a mydb.SQL instance. + SQL() mydb.SQL +} + +type finder interface { + Find(Collection, *Result, ...interface{}) mydb.Result +} + +type condsFilter interface { + FilterConds(...interface{}) []interface{} +} + +// collection is the implementation of Collection. +type collection struct { + name string + adapter CollectionAdapter +} + +type collectionWithSession struct { + *collection + + session Session +} + +func newCollection(name string, adapter CollectionAdapter) *collection { + if adapter == nil { + panic("mydb: nil adapter") + } + return &collection{ + name: name, + adapter: adapter, + } +} + +func (c *collectionWithSession) SQL() mydb.SQL { + return c.session.SQL() +} + +func (c *collectionWithSession) Session() mydb.Session { + return c.session +} + +func (c *collectionWithSession) Name() string { + return c.name +} + +func (c *collectionWithSession) Count() (uint64, error) { + return c.Find().Count() +} + +func (c *collectionWithSession) Insert(item interface{}) (mydb.InsertResult, error) { + id, err := c.adapter.Insert(c, item) + if err != nil { + return nil, err + } + + return mydb.NewInsertResult(id), nil +} + +func (c *collectionWithSession) PrimaryKeys() ([]string, error) { + return c.session.PrimaryKeys(c.Name()) +} + +func (c *collectionWithSession) filterConds(conds ...interface{}) ([]interface{}, error) { + pk, err := c.PrimaryKeys() + if err != nil { + return nil, err + } + if len(conds) == 1 && len(pk) == 1 { + if id := conds[0]; IsKeyValue(id) { + conds[0] = mydb.Cond{pk[0]: mydb.Eq(id)} + } + } + if tr, ok := c.adapter.(condsFilter); ok { + return tr.FilterConds(conds...), nil + } + return conds, nil +} + +func (c *collectionWithSession) Find(conds ...interface{}) mydb.Result { + filteredConds, err := c.filterConds(conds...) + if err != nil { + res := &Result{} + res.setErr(err) + return res + } + + res := NewResult( + c.session.SQL(), + c.Name(), + filteredConds, + ) + if f, ok := c.adapter.(finder); ok { + return f.Find(c, res, conds...) + } + return res +} + +func (c *collectionWithSession) Exists() (bool, error) { + if err := c.session.TableExists(c.Name()); err != nil { + return false, err + } + return true, nil +} + +func (c *collectionWithSession) InsertReturning(item interface{}) error { + if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr { + return fmt.Errorf("Expecting a pointer but got %T", item) + } + + // Grab primary keys + pks, err := c.PrimaryKeys() + if err != nil { + return err + } + + if len(pks) == 0 { + if ok, err := c.Exists(); !ok { + return err + } + return fmt.Errorf(mydb.ErrMissingPrimaryKeys.Error(), c.Name()) + } + + var tx Session + isTransaction := c.session.IsTransaction() + if isTransaction { + tx = c.session + } else { + var err error + tx, err = c.session.NewTransaction(c.session.Context(), nil) + if err != nil { + return err + } + defer tx.Close() + } + + // Allocate a clone of item. + newItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface() + var newItemFieldMap map[string]reflect.Value + + itemValue := reflect.ValueOf(item) + + col := tx.Collection(c.Name()) + + // Insert item as is and grab the returning ID. + var newItemRes mydb.Result + id, err := col.Insert(item) + if err != nil { + goto cancel + } + if id == nil { + err = fmt.Errorf("InsertReturning: Could not get a valid ID after inserting. Does the %q table have a primary key?", c.Name()) + goto cancel + } + + if len(pks) > 1 { + newItemRes = col.Find(id) + } else { + // We have one primary key, build a explicit mydb.Cond with it to prevent + // string keys to be considered as raw conditions. + newItemRes = col.Find(mydb.Cond{pks[0]: id}) // We already checked that pks is not empty, so pks[0] is defined. + } + + // Fetch the row that was just interted into newItem + err = newItemRes.One(newItem) + if err != nil { + goto cancel + } + + switch reflect.ValueOf(newItem).Elem().Kind() { + case reflect.Struct: + // Get valid fields from newItem to overwrite those that are on item. + newItemFieldMap = sqlbuilder.Mapper.ValidFieldMap(reflect.ValueOf(newItem)) + for fieldName := range newItemFieldMap { + sqlbuilder.Mapper.FieldByName(itemValue, fieldName).Set(newItemFieldMap[fieldName]) + } + case reflect.Map: + newItemV := reflect.ValueOf(newItem).Elem() + itemV := reflect.ValueOf(item) + if itemV.Kind() == reflect.Ptr { + itemV = itemV.Elem() + } + for _, keyV := range newItemV.MapKeys() { + itemV.SetMapIndex(keyV, newItemV.MapIndex(keyV)) + } + default: + err = fmt.Errorf("InsertReturning: expecting a pointer to map or struct, got %T", newItem) + goto cancel + } + + if !isTransaction { + // This is only executed if t.Session() was **not** a transaction and if + // sess was created with sess.NewTransaction(). + return tx.Commit() + } + + return err + +cancel: + // This goto label should only be used when we got an error within a + // transaction and we don't want to continue. + + if !isTransaction { + // This is only executed if t.Session() was **not** a transaction and if + // sess was created with sess.NewTransaction(). + _ = tx.Rollback() + } + return err +} + +func (c *collectionWithSession) UpdateReturning(item interface{}) error { + if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr { + return fmt.Errorf("Expecting a pointer but got %T", item) + } + + // Grab primary keys + pks, err := c.PrimaryKeys() + if err != nil { + return err + } + + if len(pks) == 0 { + if ok, err := c.Exists(); !ok { + return err + } + return fmt.Errorf(mydb.ErrMissingPrimaryKeys.Error(), c.Name()) + } + + var tx Session + isTransaction := c.session.IsTransaction() + + if isTransaction { + tx = c.session + } else { + // Not within a transaction, let's create one. + var err error + tx, err = c.session.NewTransaction(c.session.Context(), nil) + if err != nil { + return err + } + defer tx.Close() + } + + // Allocate a clone of item. + defaultItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface() + var defaultItemFieldMap map[string]reflect.Value + + itemValue := reflect.ValueOf(item) + + conds := mydb.Cond{} + for _, pk := range pks { + conds[pk] = mydb.Eq(sqlbuilder.Mapper.FieldByName(itemValue, pk).Interface()) + } + + col := tx.(Session).Collection(c.Name()) + + err = col.Find(conds).Update(item) + if err != nil { + goto cancel + } + + if err = col.Find(conds).One(defaultItem); err != nil { + goto cancel + } + + switch reflect.ValueOf(defaultItem).Elem().Kind() { + case reflect.Struct: + // Get valid fields from defaultItem to overwrite those that are on item. + defaultItemFieldMap = sqlbuilder.Mapper.ValidFieldMap(reflect.ValueOf(defaultItem)) + for fieldName := range defaultItemFieldMap { + sqlbuilder.Mapper.FieldByName(itemValue, fieldName).Set(defaultItemFieldMap[fieldName]) + } + case reflect.Map: + defaultItemV := reflect.ValueOf(defaultItem).Elem() + itemV := reflect.ValueOf(item) + if itemV.Kind() == reflect.Ptr { + itemV = itemV.Elem() + } + for _, keyV := range defaultItemV.MapKeys() { + itemV.SetMapIndex(keyV, defaultItemV.MapIndex(keyV)) + } + default: + panic("default") + } + + if !isTransaction { + // This is only executed if t.Session() was **not** a transaction and if + // sess was created with sess.NewTransaction(). + return tx.Commit() + } + return err + +cancel: + // This goto label should only be used when we got an error within a + // transaction and we don't want to continue. + + if !isTransaction { + // This is only executed if t.Session() was **not** a transaction and if + // sess was created with sess.NewTransaction(). + _ = tx.Rollback() + } + return err +} + +func (c *collectionWithSession) Truncate() error { + stmt := exql.Statement{ + Type: exql.Truncate, + Table: exql.TableWithName(c.Name()), + } + if _, err := c.session.SQL().Exec(&stmt); err != nil { + return err + } + return nil +} diff --git a/internal/sqladapter/compat/query.go b/internal/sqladapter/compat/query.go new file mode 100644 index 0000000..93cb8fc --- /dev/null +++ b/internal/sqladapter/compat/query.go @@ -0,0 +1,72 @@ +// +build !go1.8 + +package compat + +import ( + "context" + "database/sql" +) + +type PreparedExecer interface { + Exec(...interface{}) (sql.Result, error) +} + +func PreparedExecContext(p PreparedExecer, ctx context.Context, args []interface{}) (sql.Result, error) { + return p.Exec(args...) +} + +type Execer interface { + Exec(string, ...interface{}) (sql.Result, error) +} + +func ExecContext(p Execer, ctx context.Context, query string, args []interface{}) (sql.Result, error) { + return p.Exec(query, args...) +} + +type PreparedQueryer interface { + Query(...interface{}) (*sql.Rows, error) +} + +func PreparedQueryContext(p PreparedQueryer, ctx context.Context, args []interface{}) (*sql.Rows, error) { + return p.Query(args...) +} + +type Queryer interface { + Query(string, ...interface{}) (*sql.Rows, error) +} + +func QueryContext(p Queryer, ctx context.Context, query string, args []interface{}) (*sql.Rows, error) { + return p.Query(query, args...) +} + +type PreparedRowQueryer interface { + QueryRow(...interface{}) *sql.Row +} + +func PreparedQueryRowContext(p PreparedRowQueryer, ctx context.Context, args []interface{}) *sql.Row { + return p.QueryRow(args...) +} + +type RowQueryer interface { + QueryRow(string, ...interface{}) *sql.Row +} + +func QueryRowContext(p RowQueryer, ctx context.Context, query string, args []interface{}) *sql.Row { + return p.QueryRow(query, args...) +} + +type Preparer interface { + Prepare(string) (*sql.Stmt, error) +} + +func PrepareContext(p Preparer, ctx context.Context, query string) (*sql.Stmt, error) { + return p.Prepare(query) +} + +type TxStarter interface { + Begin() (*sql.Tx, error) +} + +func BeginTx(p TxStarter, ctx context.Context, opts interface{}) (*sql.Tx, error) { + return p.Begin() +} diff --git a/internal/sqladapter/compat/query_go18.go b/internal/sqladapter/compat/query_go18.go new file mode 100644 index 0000000..a3abbaf --- /dev/null +++ b/internal/sqladapter/compat/query_go18.go @@ -0,0 +1,72 @@ +// +build go1.8 + +package compat + +import ( + "context" + "database/sql" +) + +type PreparedExecer interface { + ExecContext(context.Context, ...interface{}) (sql.Result, error) +} + +func PreparedExecContext(p PreparedExecer, ctx context.Context, args []interface{}) (sql.Result, error) { + return p.ExecContext(ctx, args...) +} + +type Execer interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) +} + +func ExecContext(p Execer, ctx context.Context, query string, args []interface{}) (sql.Result, error) { + return p.ExecContext(ctx, query, args...) +} + +type PreparedQueryer interface { + QueryContext(context.Context, ...interface{}) (*sql.Rows, error) +} + +func PreparedQueryContext(p PreparedQueryer, ctx context.Context, args []interface{}) (*sql.Rows, error) { + return p.QueryContext(ctx, args...) +} + +type Queryer interface { + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) +} + +func QueryContext(p Queryer, ctx context.Context, query string, args []interface{}) (*sql.Rows, error) { + return p.QueryContext(ctx, query, args...) +} + +type PreparedRowQueryer interface { + QueryRowContext(context.Context, ...interface{}) *sql.Row +} + +func PreparedQueryRowContext(p PreparedRowQueryer, ctx context.Context, args []interface{}) *sql.Row { + return p.QueryRowContext(ctx, args...) +} + +type RowQueryer interface { + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func QueryRowContext(p RowQueryer, ctx context.Context, query string, args []interface{}) *sql.Row { + return p.QueryRowContext(ctx, query, args...) +} + +type Preparer interface { + PrepareContext(context.Context, string) (*sql.Stmt, error) +} + +func PrepareContext(p Preparer, ctx context.Context, query string) (*sql.Stmt, error) { + return p.PrepareContext(ctx, query) +} + +type TxStarter interface { + BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error) +} + +func BeginTx(p TxStarter, ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) { + return p.BeginTx(ctx, opts) +} diff --git a/internal/sqladapter/exql/column.go b/internal/sqladapter/exql/column.go new file mode 100644 index 0000000..ec10825 --- /dev/null +++ b/internal/sqladapter/exql/column.go @@ -0,0 +1,83 @@ +package exql + +import ( + "fmt" + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +type columnWithAlias struct { + Name string + Alias string +} + +// Column represents a SQL column. +type Column struct { + Name interface{} +} + +var _ = Fragment(&Column{}) + +// ColumnWithName creates and returns a Column with the given name. +func ColumnWithName(name string) *Column { + return &Column{Name: name} +} + +// Hash returns a unique identifier for the struct. +func (c *Column) Hash() uint64 { + if c == nil { + return cache.NewHash(FragmentType_Column, nil) + } + return cache.NewHash(FragmentType_Column, c.Name) +} + +// Compile transforms the ColumnValue into an equivalent SQL representation. +func (c *Column) Compile(layout *Template) (compiled string, err error) { + if z, ok := layout.Read(c); ok { + return z, nil + } + + var alias string + switch value := c.Name.(type) { + case string: + value = trimString(value) + + chunks := separateByAS(value) + if len(chunks) == 1 { + chunks = separateBySpace(value) + } + + name := chunks[0] + nameChunks := strings.SplitN(name, layout.ColumnSeparator, 2) + + for i := range nameChunks { + nameChunks[i] = trimString(nameChunks[i]) + if nameChunks[i] == "*" { + continue + } + nameChunks[i] = layout.MustCompile(layout.IdentifierQuote, Raw{Value: nameChunks[i]}) + } + + compiled = strings.Join(nameChunks, layout.ColumnSeparator) + + if len(chunks) > 1 { + alias = trimString(chunks[1]) + alias = layout.MustCompile(layout.IdentifierQuote, Raw{Value: alias}) + } + case compilable: + compiled, err = value.Compile(layout) + if err != nil { + return "", err + } + default: + return "", fmt.Errorf(errExpectingHashableFmt, c.Name) + } + + if alias != "" { + compiled = layout.MustCompile(layout.ColumnAliasLayout, columnWithAlias{compiled, alias}) + } + + layout.Write(c, compiled) + return +} diff --git a/internal/sqladapter/exql/column_test.go b/internal/sqladapter/exql/column_test.go new file mode 100644 index 0000000..7706538 --- /dev/null +++ b/internal/sqladapter/exql/column_test.go @@ -0,0 +1,88 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestColumnString(t *testing.T) { + column := Column{Name: "role.name"} + s, err := column.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"role"."name"`, s) +} + +func TestColumnAs(t *testing.T) { + column := Column{Name: "role.name as foo"} + s, err := column.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"role"."name" AS "foo"`, s) +} + +func TestColumnImplicitAs(t *testing.T) { + column := Column{Name: "role.name foo"} + s, err := column.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"role"."name" AS "foo"`, s) +} + +func TestColumnRaw(t *testing.T) { + column := Column{Name: &Raw{Value: "role.name As foo"}} + s, err := column.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `role.name As foo`, s) +} + +func BenchmarkColumnWithName(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = ColumnWithName("a") + } +} + +func BenchmarkColumnHash(b *testing.B) { + c := Column{Name: "name"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkColumnCompile(b *testing.B) { + c := Column{Name: "name"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := Column{Name: "name"} + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithDotCompile(b *testing.B) { + c := Column{Name: "role.name"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithImplicitAsKeywordCompile(b *testing.B) { + c := Column{Name: "role.name foo"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithAsKeywordCompile(b *testing.B) { + c := Column{Name: "role.name AS foo"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/column_value.go b/internal/sqladapter/exql/column_value.go new file mode 100644 index 0000000..4d1b303 --- /dev/null +++ b/internal/sqladapter/exql/column_value.go @@ -0,0 +1,112 @@ +package exql + +import ( + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// ColumnValue represents a bundle between a column and a corresponding value. +type ColumnValue struct { + Column Fragment + Operator string + Value Fragment +} + +var _ = Fragment(&ColumnValue{}) + +type columnValueT struct { + Column string + Operator string + Value string +} + +// Hash returns a unique identifier for the struct. +func (c *ColumnValue) Hash() uint64 { + if c == nil { + return cache.NewHash(FragmentType_ColumnValue, nil) + } + return cache.NewHash(FragmentType_ColumnValue, c.Column, c.Operator, c.Value) +} + +// Compile transforms the ColumnValue into an equivalent SQL representation. +func (c *ColumnValue) Compile(layout *Template) (compiled string, err error) { + if z, ok := layout.Read(c); ok { + return z, nil + } + + column, err := c.Column.Compile(layout) + if err != nil { + return "", err + } + + data := columnValueT{ + Column: column, + Operator: c.Operator, + } + + if c.Value != nil { + data.Value, err = c.Value.Compile(layout) + if err != nil { + return "", err + } + } + + compiled = strings.TrimSpace(layout.MustCompile(layout.ColumnValue, data)) + + layout.Write(c, compiled) + + return +} + +// ColumnValues represents an array of ColumnValue +type ColumnValues struct { + ColumnValues []Fragment +} + +var _ = Fragment(&ColumnValues{}) + +// JoinColumnValues returns an array of ColumnValue +func JoinColumnValues(values ...Fragment) *ColumnValues { + return &ColumnValues{ColumnValues: values} +} + +// Insert adds a column to the columns array. +func (c *ColumnValues) Insert(values ...Fragment) *ColumnValues { + c.ColumnValues = append(c.ColumnValues, values...) + return c +} + +// Hash returns a unique identifier for the struct. +func (c *ColumnValues) Hash() uint64 { + h := cache.InitHash(FragmentType_ColumnValues) + for i := range c.ColumnValues { + h = cache.AddToHash(h, c.ColumnValues[i]) + } + return h +} + +// Compile transforms the ColumnValues into its SQL representation. +func (c *ColumnValues) Compile(layout *Template) (compiled string, err error) { + + if z, ok := layout.Read(c); ok { + return z, nil + } + + l := len(c.ColumnValues) + + out := make([]string, l) + + for i := range c.ColumnValues { + out[i], err = c.ColumnValues[i].Compile(layout) + if err != nil { + return "", err + } + } + + compiled = strings.TrimSpace(strings.Join(out, layout.IdentifierSeparator)) + + layout.Write(c, compiled) + + return +} diff --git a/internal/sqladapter/exql/column_value_test.go b/internal/sqladapter/exql/column_value_test.go new file mode 100644 index 0000000..33ecc36 --- /dev/null +++ b/internal/sqladapter/exql/column_value_test.go @@ -0,0 +1,115 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestColumnValue(t *testing.T) { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + s, err := cv.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"id" = '1'`, s) + + cv = &ColumnValue{Column: ColumnWithName("date"), Operator: "=", Value: &Raw{Value: "NOW()"}} + s, err = cv.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"date" = NOW()`, s) +} + +func TestColumnValues(t *testing.T) { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})}, + ) + + s, err := cvs.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"id" > '8', "other"."id" < 100, "name" = 'Haruki Murakami', "created" >= NOW(), "modified" <= NOW()`, s) +} + +func BenchmarkNewColumnValue(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = &ColumnValue{Column: ColumnWithName("a"), Operator: "=", Value: NewValue(Raw{Value: "7"})} + } +} + +func BenchmarkColumnValueHash(b *testing.B) { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + b.ResetTimer() + for i := 0; i < b.N; i++ { + cv.Hash() + } +} + +func BenchmarkColumnValueCompile(b *testing.B) { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = cv.Compile(defaultTemplate) + } +} + +func BenchmarkColumnValueCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + _, _ = cv.Compile(defaultTemplate) + } +} + +func BenchmarkJoinColumnValues(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + ) + } +} + +func BenchmarkColumnValuesHash(b *testing.B) { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + cvs.Hash() + } +} + +func BenchmarkColumnValuesCompile(b *testing.B) { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = cvs.Compile(defaultTemplate) + } +} + +func BenchmarkColumnValuesCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})}, + ) + _, _ = cvs.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/columns.go b/internal/sqladapter/exql/columns.go new file mode 100644 index 0000000..d060b4e --- /dev/null +++ b/internal/sqladapter/exql/columns.go @@ -0,0 +1,83 @@ +package exql + +import ( + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// Columns represents an array of Column. +type Columns struct { + Columns []Fragment +} + +var _ = Fragment(&Columns{}) + +// Hash returns a unique identifier. +func (c *Columns) Hash() uint64 { + if c == nil { + return cache.NewHash(FragmentType_Columns, nil) + } + h := cache.InitHash(FragmentType_Columns) + for i := range c.Columns { + h = cache.AddToHash(h, c.Columns[i]) + } + return h +} + +// JoinColumns creates and returns an array of Column. +func JoinColumns(columns ...Fragment) *Columns { + return &Columns{Columns: columns} +} + +// OnConditions creates and retuens a new On. +func OnConditions(conditions ...Fragment) *On { + return &On{Conditions: conditions} +} + +// UsingColumns builds a Using from the given columns. +func UsingColumns(columns ...Fragment) *Using { + return &Using{Columns: columns} +} + +// Append +func (c *Columns) Append(a *Columns) *Columns { + c.Columns = append(c.Columns, a.Columns...) + return c +} + +// IsEmpty +func (c *Columns) IsEmpty() bool { + if c == nil || len(c.Columns) < 1 { + return true + } + return false +} + +// Compile transforms the Columns into an equivalent SQL representation. +func (c *Columns) Compile(layout *Template) (compiled string, err error) { + if z, ok := layout.Read(c); ok { + return z, nil + } + + l := len(c.Columns) + + if l > 0 { + out := make([]string, l) + + for i := 0; i < l; i++ { + out[i], err = c.Columns[i].Compile(layout) + if err != nil { + return "", err + } + } + + compiled = strings.Join(out, layout.IdentifierSeparator) + } else { + compiled = "*" + } + + layout.Write(c, compiled) + + return +} diff --git a/internal/sqladapter/exql/columns_test.go b/internal/sqladapter/exql/columns_test.go new file mode 100644 index 0000000..39cbd5c --- /dev/null +++ b/internal/sqladapter/exql/columns_test.go @@ -0,0 +1,72 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestColumns(t *testing.T) { + columns := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + + s, err := columns.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"id", "customer", "service_id", "role"."name", "role"."id"`, s) +} + +func BenchmarkJoinColumns(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinColumns( + &Column{Name: "a"}, + &Column{Name: "b"}, + &Column{Name: "c"}, + ) + } +} + +func BenchmarkColumnsHash(b *testing.B) { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkColumnsCompile(b *testing.B) { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnsCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + _, _ = c.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/database.go b/internal/sqladapter/exql/database.go new file mode 100644 index 0000000..36bf37c --- /dev/null +++ b/internal/sqladapter/exql/database.go @@ -0,0 +1,37 @@ +package exql + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// Database represents a SQL database. +type Database struct { + Name string +} + +var _ = Fragment(&Database{}) + +// DatabaseWithName returns a Database with the given name. +func DatabaseWithName(name string) *Database { + return &Database{Name: name} +} + +// Hash returns a unique identifier for the struct. +func (d *Database) Hash() uint64 { + if d == nil { + return cache.NewHash(FragmentType_Database, nil) + } + return cache.NewHash(FragmentType_Database, d.Name) +} + +// Compile transforms the Database into an equivalent SQL representation. +func (d *Database) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(d); ok { + return c, nil + } + + compiled = layout.MustCompile(layout.IdentifierQuote, Raw{Value: d.Name}) + + layout.Write(d, compiled) + return +} diff --git a/internal/sqladapter/exql/database_test.go b/internal/sqladapter/exql/database_test.go new file mode 100644 index 0000000..aba55be --- /dev/null +++ b/internal/sqladapter/exql/database_test.go @@ -0,0 +1,45 @@ +package exql + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDatabaseCompile(t *testing.T) { + column := Database{Name: "name"} + s, err := column.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"name"`, s) +} + +func BenchmarkDatabaseHash(b *testing.B) { + c := Database{Name: "name"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkDatabaseCompile(b *testing.B) { + c := Database{Name: "name"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkDatabaseCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := Database{Name: "name"} + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkDatabaseCompileNoCache2(b *testing.B) { + for i := 0; i < b.N; i++ { + c := Database{Name: strconv.Itoa(i)} + _, _ = c.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/default.go b/internal/sqladapter/exql/default.go new file mode 100644 index 0000000..12b4354 --- /dev/null +++ b/internal/sqladapter/exql/default.go @@ -0,0 +1,192 @@ +package exql + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +const ( + defaultColumnSeparator = `.` + defaultIdentifierSeparator = `, ` + defaultIdentifierQuote = `"{{.Value}}"` + defaultValueSeparator = `, ` + defaultValueQuote = `'{{.}}'` + defaultAndKeyword = `AND` + defaultOrKeyword = `OR` + defaultDescKeyword = `DESC` + defaultAscKeyword = `ASC` + defaultAssignmentOperator = `=` + defaultClauseGroup = `({{.}})` + defaultClauseOperator = ` {{.}} ` + defaultColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + defaultTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + defaultColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + defaultSortByColumnLayout = `{{.Column}} {{.Order}}` + + defaultOrderByLayout = ` + {{if .SortColumns}} + ORDER BY {{.SortColumns}} + {{end}} + ` + + defaultWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + + defaultUsingLayout = ` + {{if .Columns}} + USING ({{.Columns}}) + {{end}} + ` + + defaultJoinLayout = ` + {{if .Table}} + {{ if .On }} + {{.Type}} JOIN {{.Table}} + {{.On}} + {{ else if .Using }} + {{.Type}} JOIN {{.Table}} + {{.Using}} + {{ else if .Type | eq "CROSS" }} + {{.Type}} JOIN {{.Table}} + {{else}} + NATURAL {{.Type}} JOIN {{.Table}} + {{end}} + {{end}} + ` + + defaultOnLayout = ` + {{if .Conds}} + ON {{.Conds}} + {{end}} + ` + + defaultSelectLayout = ` + SELECT + {{if .Distinct}} + DISTINCT + {{end}} + + {{if .Columns}} + {{.Columns | compile}} + {{else}} + * + {{end}} + + {{if defined .Table}} + FROM {{.Table | compile}} + {{end}} + + {{.Joins | compile}} + + {{.Where | compile}} + + {{.GroupBy | compile}} + + {{.OrderBy | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + defaultDeleteLayout = ` + DELETE + FROM {{.Table | compile}} + {{.Where | compile}} + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + defaultUpdateLayout = ` + UPDATE + {{.Table | compile}} + SET {{.ColumnValues | compile}} + {{.Where | compile}} + ` + + defaultCountLayout = ` + SELECT + COUNT(1) AS _t + FROM {{.Table | compile}} + {{.Where | compile}} + + {{if .Limit}} + LIMIT {{.Limit | compile}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + + defaultInsertLayout = ` + INSERT INTO {{.Table | compile}} + {{if .Columns }}({{.Columns | compile}}){{end}} + VALUES + {{.Values | compile}} + {{if .Returning}} + RETURNING {{.Returning | compile}} + {{end}} + ` + + defaultTruncateLayout = ` + TRUNCATE TABLE {{.Table | compile}} + ` + + defaultDropDatabaseLayout = ` + DROP DATABASE {{.Database | compile}} + ` + + defaultDropTableLayout = ` + DROP TABLE {{.Table | compile}} + ` + + defaultGroupByLayout = ` + {{if .GroupColumns}} + GROUP BY {{.GroupColumns}} + {{end}} + ` +) + +var defaultTemplate = &Template{ + AndKeyword: defaultAndKeyword, + AscKeyword: defaultAscKeyword, + AssignmentOperator: defaultAssignmentOperator, + ClauseGroup: defaultClauseGroup, + ClauseOperator: defaultClauseOperator, + ColumnAliasLayout: defaultColumnAliasLayout, + ColumnSeparator: defaultColumnSeparator, + ColumnValue: defaultColumnValue, + CountLayout: defaultCountLayout, + DeleteLayout: defaultDeleteLayout, + DescKeyword: defaultDescKeyword, + DropDatabaseLayout: defaultDropDatabaseLayout, + DropTableLayout: defaultDropTableLayout, + GroupByLayout: defaultGroupByLayout, + IdentifierQuote: defaultIdentifierQuote, + IdentifierSeparator: defaultIdentifierSeparator, + InsertLayout: defaultInsertLayout, + JoinLayout: defaultJoinLayout, + OnLayout: defaultOnLayout, + OrKeyword: defaultOrKeyword, + OrderByLayout: defaultOrderByLayout, + SelectLayout: defaultSelectLayout, + SortByColumnLayout: defaultSortByColumnLayout, + TableAliasLayout: defaultTableAliasLayout, + TruncateLayout: defaultTruncateLayout, + UpdateLayout: defaultUpdateLayout, + UsingLayout: defaultUsingLayout, + ValueQuote: defaultValueQuote, + ValueSeparator: defaultValueSeparator, + WhereLayout: defaultWhereLayout, + + Cache: cache.NewCache(), +} diff --git a/internal/sqladapter/exql/errors.go b/internal/sqladapter/exql/errors.go new file mode 100644 index 0000000..b9c8b85 --- /dev/null +++ b/internal/sqladapter/exql/errors.go @@ -0,0 +1,5 @@ +package exql + +const ( + errExpectingHashableFmt = "expecting hashable value, got %T" +) diff --git a/internal/sqladapter/exql/group_by.go b/internal/sqladapter/exql/group_by.go new file mode 100644 index 0000000..6583d44 --- /dev/null +++ b/internal/sqladapter/exql/group_by.go @@ -0,0 +1,60 @@ +package exql + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// GroupBy represents a SQL's "group by" statement. +type GroupBy struct { + Columns Fragment +} + +var _ = Fragment(&GroupBy{}) + +type groupByT struct { + GroupColumns string +} + +// Hash returns a unique identifier. +func (g *GroupBy) Hash() uint64 { + if g == nil { + return cache.NewHash(FragmentType_GroupBy, nil) + } + return cache.NewHash(FragmentType_GroupBy, g.Columns) +} + +// GroupByColumns creates and returns a GroupBy with the given column. +func GroupByColumns(columns ...Fragment) *GroupBy { + return &GroupBy{Columns: JoinColumns(columns...)} +} + +func (g *GroupBy) IsEmpty() bool { + if g == nil || g.Columns == nil { + return true + } + return g.Columns.(hasIsEmpty).IsEmpty() +} + +// Compile transforms the GroupBy into an equivalent SQL representation. +func (g *GroupBy) Compile(layout *Template) (compiled string, err error) { + + if c, ok := layout.Read(g); ok { + return c, nil + } + + if g.Columns != nil { + columns, err := g.Columns.Compile(layout) + if err != nil { + return "", err + } + + data := groupByT{ + GroupColumns: columns, + } + compiled = layout.MustCompile(layout.GroupByLayout, data) + } + + layout.Write(g, compiled) + + return +} diff --git a/internal/sqladapter/exql/group_by_test.go b/internal/sqladapter/exql/group_by_test.go new file mode 100644 index 0000000..cdc1e6f --- /dev/null +++ b/internal/sqladapter/exql/group_by_test.go @@ -0,0 +1,71 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestGroupBy(t *testing.T) { + columns := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + + s := mustTrim(columns.Compile(defaultTemplate)) + assert.Equal(t, `GROUP BY "id", "customer", "service_id", "role"."name", "role"."id"`, s) +} + +func BenchmarkGroupByColumns(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = GroupByColumns( + &Column{Name: "a"}, + &Column{Name: "b"}, + &Column{Name: "c"}, + ) + } +} + +func BenchmarkGroupByHash(b *testing.B) { + c := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkGroupByCompile(b *testing.B) { + c := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkGroupByCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := GroupByColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + _, _ = c.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/interfaces.go b/internal/sqladapter/exql/interfaces.go new file mode 100644 index 0000000..efa34ed --- /dev/null +++ b/internal/sqladapter/exql/interfaces.go @@ -0,0 +1,20 @@ +package exql + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// Fragment is any interface that can be both cached and compiled. +type Fragment interface { + cache.Hashable + + compilable +} + +type compilable interface { + Compile(*Template) (string, error) +} + +type hasIsEmpty interface { + IsEmpty() bool +} diff --git a/internal/sqladapter/exql/join.go b/internal/sqladapter/exql/join.go new file mode 100644 index 0000000..eed8f5b --- /dev/null +++ b/internal/sqladapter/exql/join.go @@ -0,0 +1,195 @@ +package exql + +import ( + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +type innerJoinT struct { + Type string + Table string + On string + Using string +} + +// Joins represents the union of different join conditions. +type Joins struct { + Conditions []Fragment +} + +var _ = Fragment(&Joins{}) + +// Hash returns a unique identifier for the struct. +func (j *Joins) Hash() uint64 { + if j == nil { + return cache.NewHash(FragmentType_Joins, nil) + } + h := cache.InitHash(FragmentType_Joins) + for i := range j.Conditions { + h = cache.AddToHash(h, j.Conditions[i]) + } + return h +} + +// Compile transforms the Where into an equivalent SQL representation. +func (j *Joins) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(j); ok { + return c, nil + } + + l := len(j.Conditions) + + chunks := make([]string, 0, l) + + if l > 0 { + for i := 0; i < l; i++ { + chunk, err := j.Conditions[i].Compile(layout) + if err != nil { + return "", err + } + chunks = append(chunks, chunk) + } + } + + compiled = strings.Join(chunks, " ") + + layout.Write(j, compiled) + + return +} + +// JoinConditions creates a Joins object. +func JoinConditions(joins ...*Join) *Joins { + fragments := make([]Fragment, len(joins)) + for i := range fragments { + fragments[i] = joins[i] + } + return &Joins{Conditions: fragments} +} + +// Join represents a generic JOIN statement. +type Join struct { + Type string + Table Fragment + On Fragment + Using Fragment +} + +var _ = Fragment(&Join{}) + +// Hash returns a unique identifier for the struct. +func (j *Join) Hash() uint64 { + if j == nil { + return cache.NewHash(FragmentType_Join, nil) + } + return cache.NewHash(FragmentType_Join, j.Type, j.Table, j.On, j.Using) +} + +// Compile transforms the Join into its equivalent SQL representation. +func (j *Join) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(j); ok { + return c, nil + } + + if j.Table == nil { + return "", nil + } + + table, err := j.Table.Compile(layout) + if err != nil { + return "", err + } + + on, err := layout.doCompile(j.On) + if err != nil { + return "", err + } + + using, err := layout.doCompile(j.Using) + if err != nil { + return "", err + } + + data := innerJoinT{ + Type: j.Type, + Table: table, + On: on, + Using: using, + } + + compiled = layout.MustCompile(layout.JoinLayout, data) + layout.Write(j, compiled) + return +} + +// On represents JOIN conditions. +type On Where + +var _ = Fragment(&On{}) + +func (o *On) Hash() uint64 { + if o == nil { + return cache.NewHash(FragmentType_On, nil) + } + return cache.NewHash(FragmentType_On, (*Where)(o)) +} + +// Compile transforms the On into an equivalent SQL representation. +func (o *On) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(o); ok { + return c, nil + } + + grouped, err := groupCondition(layout, o.Conditions, layout.MustCompile(layout.ClauseOperator, layout.AndKeyword)) + if err != nil { + return "", err + } + + if grouped != "" { + compiled = layout.MustCompile(layout.OnLayout, conds{grouped}) + } + + layout.Write(o, compiled) + return +} + +// Using represents a USING function. +type Using Columns + +var _ = Fragment(&Using{}) + +type usingT struct { + Columns string +} + +func (u *Using) Hash() uint64 { + if u == nil { + return cache.NewHash(FragmentType_Using, nil) + } + return cache.NewHash(FragmentType_Using, (*Columns)(u)) +} + +// Compile transforms the Using into an equivalent SQL representation. +func (u *Using) Compile(layout *Template) (compiled string, err error) { + if u == nil { + return "", nil + } + + if c, ok := layout.Read(u); ok { + return c, nil + } + + if len(u.Columns) > 0 { + c := Columns(*u) + columns, err := c.Compile(layout) + if err != nil { + return "", err + } + data := usingT{Columns: columns} + compiled = layout.MustCompile(layout.UsingLayout, data) + } + + layout.Write(u, compiled) + return +} diff --git a/internal/sqladapter/exql/join_test.go b/internal/sqladapter/exql/join_test.go new file mode 100644 index 0000000..65ce6aa --- /dev/null +++ b/internal/sqladapter/exql/join_test.go @@ -0,0 +1,221 @@ +package exql + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOnAndRawOrAnd(t *testing.T) { + on := OnConditions( + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + ), + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + &Raw{Value: "city_id = 728"}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "age"}, Operator: ">", Value: NewValue(&Raw{Value: "18"})}, + &ColumnValue{Column: &Column{Name: "age"}, Operator: "<", Value: NewValue(&Raw{Value: "41"})}, + ), + ) + + s := mustTrim(on.Compile(defaultTemplate)) + assert.Equal(t, `ON (("id" > 8 AND "id" < 99) AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))`, s) +} + +func TestUsing(t *testing.T) { + using := UsingColumns( + &Column{Name: "country"}, + &Column{Name: "state"}, + ) + + s := mustTrim(using.Compile(defaultTemplate)) + assert.Equal(t, `USING ("country", "state")`, s) +} + +func TestJoinOn(t *testing.T) { + join := JoinConditions( + &Join{ + Table: TableWithName("countries c"), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.id"}), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.code"}), + }, + ), + }, + ) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `JOIN "countries" AS "c" ON ("p"."country_id" = "a"."id" AND "p"."country_code" = "a"."code")`, s) +} + +func TestInnerJoinOn(t *testing.T) { + join := JoinConditions(&Join{ + Type: "INNER", + Table: TableWithName("countries c"), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(ColumnWithName("a.id")), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(ColumnWithName("a.code")), + }, + ), + }) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `INNER JOIN "countries" AS "c" ON ("p"."country_id" = "a"."id" AND "p"."country_code" = "a"."code")`, s) +} + +func TestLeftJoinUsing(t *testing.T) { + join := JoinConditions(&Join{ + Type: "LEFT", + Table: TableWithName("countries"), + Using: UsingColumns(ColumnWithName("name")), + }) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `LEFT JOIN "countries" USING ("name")`, s) +} + +func TestNaturalJoinOn(t *testing.T) { + join := JoinConditions(&Join{ + Table: TableWithName("countries"), + }) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `NATURAL JOIN "countries"`, s) +} + +func TestNaturalInnerJoinOn(t *testing.T) { + join := JoinConditions(&Join{ + Type: "INNER", + Table: TableWithName("countries"), + }) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `NATURAL INNER JOIN "countries"`, s) +} + +func TestCrossJoin(t *testing.T) { + join := JoinConditions(&Join{ + Type: "CROSS", + Table: TableWithName("countries"), + }) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `CROSS JOIN "countries"`, s) +} + +func TestMultipleJoins(t *testing.T) { + join := JoinConditions(&Join{ + Type: "LEFT", + Table: TableWithName("countries"), + }, &Join{ + Table: TableWithName("cities"), + }) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `NATURAL LEFT JOIN "countries" NATURAL JOIN "cities"`, s) +} + +func BenchmarkJoin(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinConditions(&Join{ + Table: TableWithName("countries c"), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.id"}), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.code"}), + }, + ), + }) + } +} + +func BenchmarkCompileJoin(b *testing.B) { + j := JoinConditions(&Join{ + Table: TableWithName("countries c"), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.id"}), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.code"}), + }, + ), + }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = j.Compile(defaultTemplate) + } +} + +func BenchmarkCompileJoinNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + j := JoinConditions(&Join{ + Table: TableWithName("countries c"), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.id"}), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.code"}), + }, + ), + }) + _, _ = j.Compile(defaultTemplate) + } +} + +func BenchmarkCompileJoinNoCache2(b *testing.B) { + for i := 0; i < b.N; i++ { + j := JoinConditions(&Join{ + Table: TableWithName(fmt.Sprintf("countries c%d", i)), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.id"}), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.code"}), + }, + ), + }) + _, _ = j.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/order_by.go b/internal/sqladapter/exql/order_by.go new file mode 100644 index 0000000..d41a48a --- /dev/null +++ b/internal/sqladapter/exql/order_by.go @@ -0,0 +1,175 @@ +package exql + +import ( + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// Order represents the order in which SQL results are sorted. +type Order uint8 + +// Possible values for Order +const ( + Order_Default Order = iota + + Order_Ascendent + Order_Descendent +) + +func (o Order) Hash() uint64 { + return cache.NewHash(FragmentType_Order, uint8(o)) +} + +// SortColumn represents the column-order relation in an ORDER BY clause. +type SortColumn struct { + Column Fragment + Order +} + +var _ = Fragment(&SortColumn{}) + +type sortColumnT struct { + Column string + Order string +} + +var _ = Fragment(&SortColumn{}) + +// SortColumns represents the columns in an ORDER BY clause. +type SortColumns struct { + Columns []Fragment +} + +var _ = Fragment(&SortColumns{}) + +// OrderBy represents an ORDER BY clause. +type OrderBy struct { + SortColumns Fragment +} + +var _ = Fragment(&OrderBy{}) + +type orderByT struct { + SortColumns string +} + +// JoinSortColumns creates and returns an array of column-order relations. +func JoinSortColumns(values ...Fragment) *SortColumns { + return &SortColumns{Columns: values} +} + +// JoinWithOrderBy creates an returns an OrderBy using the given SortColumns. +func JoinWithOrderBy(sc *SortColumns) *OrderBy { + return &OrderBy{SortColumns: sc} +} + +// Hash returns a unique identifier for the struct. +func (s *SortColumn) Hash() uint64 { + if s == nil { + return cache.NewHash(FragmentType_SortColumn, nil) + } + return cache.NewHash(FragmentType_SortColumn, s.Column, s.Order) +} + +// Compile transforms the SortColumn into an equivalent SQL representation. +func (s *SortColumn) Compile(layout *Template) (compiled string, err error) { + + if c, ok := layout.Read(s); ok { + return c, nil + } + + column, err := s.Column.Compile(layout) + if err != nil { + return "", err + } + + orderBy, err := s.Order.Compile(layout) + if err != nil { + return "", err + } + + data := sortColumnT{Column: column, Order: orderBy} + + compiled = layout.MustCompile(layout.SortByColumnLayout, data) + + layout.Write(s, compiled) + + return +} + +// Hash returns a unique identifier for the struct. +func (s *SortColumns) Hash() uint64 { + if s == nil { + return cache.NewHash(FragmentType_SortColumns, nil) + } + h := cache.InitHash(FragmentType_SortColumns) + for i := range s.Columns { + h = cache.AddToHash(h, s.Columns[i]) + } + return h +} + +// Compile transforms the SortColumns into an equivalent SQL representation. +func (s *SortColumns) Compile(layout *Template) (compiled string, err error) { + if z, ok := layout.Read(s); ok { + return z, nil + } + + z := make([]string, len(s.Columns)) + + for i := range s.Columns { + z[i], err = s.Columns[i].Compile(layout) + if err != nil { + return "", err + } + } + + compiled = strings.Join(z, layout.IdentifierSeparator) + + layout.Write(s, compiled) + + return +} + +// Hash returns a unique identifier for the struct. +func (s *OrderBy) Hash() uint64 { + if s == nil { + return cache.NewHash(FragmentType_OrderBy, nil) + } + return cache.NewHash(FragmentType_OrderBy, s.SortColumns) +} + +// Compile transforms the SortColumn into an equivalent SQL representation. +func (s *OrderBy) Compile(layout *Template) (compiled string, err error) { + if z, ok := layout.Read(s); ok { + return z, nil + } + + if s.SortColumns != nil { + sortColumns, err := s.SortColumns.Compile(layout) + if err != nil { + return "", err + } + + data := orderByT{ + SortColumns: sortColumns, + } + compiled = layout.MustCompile(layout.OrderByLayout, data) + } + + layout.Write(s, compiled) + + return +} + +// Compile transforms the SortColumn into an equivalent SQL representation. +func (s Order) Compile(layout *Template) (string, error) { + switch s { + case Order_Ascendent: + return layout.AscKeyword, nil + case Order_Descendent: + return layout.DescKeyword, nil + } + return "", nil +} diff --git a/internal/sqladapter/exql/order_by_test.go b/internal/sqladapter/exql/order_by_test.go new file mode 100644 index 0000000..f214f86 --- /dev/null +++ b/internal/sqladapter/exql/order_by_test.go @@ -0,0 +1,154 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOrderBy(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + + s := mustTrim(o.Compile(defaultTemplate)) + assert.Equal(t, `ORDER BY "foo"`, s) +} + +func TestOrderByRaw(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Raw{Value: "CASE WHEN id IN ? THEN 0 ELSE 1 END"}}, + ), + ) + + s := mustTrim(o.Compile(defaultTemplate)) + assert.Equal(t, `ORDER BY CASE WHEN id IN ? THEN 0 ELSE 1 END`, s) +} + +func TestOrderByDesc(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Descendent}, + ), + ) + + s := mustTrim(o.Compile(defaultTemplate)) + assert.Equal(t, `ORDER BY "foo" DESC`, s) +} + +func BenchmarkOrderBy(b *testing.B) { + for i := 0; i < b.N; i++ { + JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + } +} + +func BenchmarkOrderByHash(b *testing.B) { + o := OrderBy{ + SortColumns: JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + o.Hash() + } +} + +func BenchmarkCompileOrderByCompile(b *testing.B) { + o := OrderBy{ + SortColumns: JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderByCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + _, _ = o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderCompile(b *testing.B) { + o := Order_Descendent + for i := 0; i < b.N; i++ { + _, _ = o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + o := Order_Descendent + _, _ = o.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnHash(b *testing.B) { + s := &SortColumn{Column: &Column{Name: "foo"}} + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Hash() + } +} + +func BenchmarkSortColumnCompile(b *testing.B) { + s := &SortColumn{Column: &Column{Name: "foo"}} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + s := &SortColumn{Column: &Column{Name: "foo"}} + _, _ = s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnsHash(b *testing.B) { + s := JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Hash() + } +} + +func BenchmarkSortColumnsCompile(b *testing.B) { + s := JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnsCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + s := JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + _, _ = s.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/raw.go b/internal/sqladapter/exql/raw.go new file mode 100644 index 0000000..808f6e6 --- /dev/null +++ b/internal/sqladapter/exql/raw.go @@ -0,0 +1,48 @@ +package exql + +import ( + "fmt" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +var ( + _ = fmt.Stringer(&Raw{}) +) + +// Raw represents a value that is meant to be used in a query without escaping. +type Raw struct { + Value string +} + +func NewRawValue(v interface{}) (*Raw, error) { + switch t := v.(type) { + case string: + return &Raw{Value: t}, nil + case int, uint, int64, uint64, int32, uint32, int16, uint16: + return &Raw{Value: fmt.Sprintf("%d", t)}, nil + case fmt.Stringer: + return &Raw{Value: t.String()}, nil + } + return nil, fmt.Errorf("unexpected type: %T", v) +} + +// Hash returns a unique identifier for the struct. +func (r *Raw) Hash() uint64 { + if r == nil { + return cache.NewHash(FragmentType_Raw, nil) + } + return cache.NewHash(FragmentType_Raw, r.Value) +} + +// Compile returns the raw value. +func (r *Raw) Compile(*Template) (string, error) { + return r.Value, nil +} + +// String returns the raw value. +func (r *Raw) String() string { + return r.Value +} + +var _ = Fragment(&Raw{}) diff --git a/internal/sqladapter/exql/raw_test.go b/internal/sqladapter/exql/raw_test.go new file mode 100644 index 0000000..66e38b1 --- /dev/null +++ b/internal/sqladapter/exql/raw_test.go @@ -0,0 +1,51 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRawString(t *testing.T) { + raw := &Raw{Value: "foo"} + s, err := raw.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `foo`, s) +} + +func TestRawCompile(t *testing.T) { + raw := &Raw{Value: "foo"} + s, err := raw.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `foo`, s) +} + +func BenchmarkRawCreate(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = Raw{Value: "foo"} + } +} + +func BenchmarkRawString(b *testing.B) { + raw := &Raw{Value: "foo"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = raw.String() + } +} + +func BenchmarkRawCompile(b *testing.B) { + raw := &Raw{Value: "foo"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = raw.Compile(defaultTemplate) + } +} + +func BenchmarkRawHash(b *testing.B) { + raw := &Raw{Value: "foo"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + raw.Hash() + } +} diff --git a/internal/sqladapter/exql/returning.go b/internal/sqladapter/exql/returning.go new file mode 100644 index 0000000..9b882dc --- /dev/null +++ b/internal/sqladapter/exql/returning.go @@ -0,0 +1,41 @@ +package exql + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// Returning represents a RETURNING clause. +type Returning struct { + *Columns +} + +// Hash returns a unique identifier for the struct. +func (r *Returning) Hash() uint64 { + if r == nil { + return cache.NewHash(FragmentType_Returning, nil) + } + return cache.NewHash(FragmentType_Returning, r.Columns) +} + +var _ = Fragment(&Returning{}) + +// ReturningColumns creates and returns an array of Column. +func ReturningColumns(columns ...Fragment) *Returning { + return &Returning{Columns: &Columns{Columns: columns}} +} + +// Compile transforms the clause into its equivalent SQL representation. +func (r *Returning) Compile(layout *Template) (compiled string, err error) { + if z, ok := layout.Read(r); ok { + return z, nil + } + + compiled, err = r.Columns.Compile(layout) + if err != nil { + return "", err + } + + layout.Write(r, compiled) + + return +} diff --git a/internal/sqladapter/exql/statement.go b/internal/sqladapter/exql/statement.go new file mode 100644 index 0000000..4351682 --- /dev/null +++ b/internal/sqladapter/exql/statement.go @@ -0,0 +1,132 @@ +package exql + +import ( + "errors" + "reflect" + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +var errUnknownTemplateType = errors.New("Unknown template type") + +// represents different kinds of SQL statements. +type Statement struct { + Type + Table Fragment + Database Fragment + Columns Fragment + Values Fragment + Distinct bool + ColumnValues Fragment + OrderBy Fragment + GroupBy Fragment + Joins Fragment + Where Fragment + Returning Fragment + + Limit + Offset + + SQL string + + amendFn func(string) string +} + +func (layout *Template) doCompile(c Fragment) (string, error) { + if c != nil && !reflect.ValueOf(c).IsNil() { + return c.Compile(layout) + } + return "", nil +} + +// Hash returns a unique identifier for the struct. +func (s *Statement) Hash() uint64 { + if s == nil { + return cache.NewHash(FragmentType_Statement, nil) + } + return cache.NewHash( + FragmentType_Statement, + s.Type, + s.Table, + s.Database, + s.Columns, + s.Values, + s.Distinct, + s.ColumnValues, + s.OrderBy, + s.GroupBy, + s.Joins, + s.Where, + s.Returning, + s.Limit, + s.Offset, + s.SQL, + ) +} + +func (s *Statement) SetAmendment(amendFn func(string) string) { + s.amendFn = amendFn +} + +func (s *Statement) Amend(in string) string { + if s.amendFn == nil { + return in + } + return s.amendFn(in) +} + +func (s *Statement) template(layout *Template) (string, error) { + switch s.Type { + case Truncate: + return layout.TruncateLayout, nil + case DropTable: + return layout.DropTableLayout, nil + case DropDatabase: + return layout.DropDatabaseLayout, nil + case Count: + return layout.CountLayout, nil + case Select: + return layout.SelectLayout, nil + case Delete: + return layout.DeleteLayout, nil + case Update: + return layout.UpdateLayout, nil + case Insert: + return layout.InsertLayout, nil + default: + return "", errUnknownTemplateType + } +} + +// Compile transforms the Statement into an equivalent SQL query. +func (s *Statement) Compile(layout *Template) (compiled string, err error) { + if s.Type == SQL { + // No need to hit the cache. + return s.SQL, nil + } + + if z, ok := layout.Read(s); ok { + return s.Amend(z), nil + } + + tpl, err := s.template(layout) + if err != nil { + return "", err + } + + compiled = layout.MustCompile(tpl, s) + + compiled = strings.TrimSpace(compiled) + layout.Write(s, compiled) + + return s.Amend(compiled), nil +} + +// RawSQL represents a raw SQL statement. +func RawSQL(s string) *Statement { + return &Statement{ + Type: SQL, + SQL: s, + } +} diff --git a/internal/sqladapter/exql/statement_test.go b/internal/sqladapter/exql/statement_test.go new file mode 100644 index 0000000..28e726a --- /dev/null +++ b/internal/sqladapter/exql/statement_test.go @@ -0,0 +1,703 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTruncateTable(t *testing.T) { + stmt := Statement{ + Type: Truncate, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `TRUNCATE TABLE "table_name"`, s) +} + +func TestDropTable(t *testing.T) { + stmt := Statement{ + Type: DropTable, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `DROP TABLE "table_name"`, s) +} + +func TestDropDatabase(t *testing.T) { + stmt := Statement{ + Type: DropDatabase, + Database: &Database{Name: "table_name"}, + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `DROP DATABASE "table_name"`, s) +} + +func TestCount(t *testing.T) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT COUNT(1) AS _t FROM "table_name"`, s) +} + +func TestCountRelation(t *testing.T) { + stmt := Statement{ + Type: Count, + Table: TableWithName("information_schema.tables"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT COUNT(1) AS _t FROM "information_schema"."tables"`, s) +} + +func TestCountWhere(t *testing.T) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: &Raw{Value: "7"}}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT COUNT(1) AS _t FROM "table_name" WHERE ("a" = 7)`, s) +} + +func TestSelectStarFrom(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table_name"`, s) +} + +func TestSelectStarFromAlias(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table"."name" AS "foo"`, s) +} + +func TestSelectStarFromRawWhere(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + Where: WhereConditions( + &Raw{Value: "foo.id = bar.foo_id"}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id)`, s) + } + + { + stmt := Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + Where: WhereConditions( + &Raw{Value: "foo.id = bar.foo_id"}, + &Raw{Value: "baz.id = exp.baz_id"}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id AND baz.id = exp.baz_id)`, s) + } +} + +func TestSelectStarFromMany(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("first.table AS foo, second.table as BAR, third.table aS baz"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"`, s) +} + +func TestSelectTableStarFromMany(t *testing.T) { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo.name"}, + &Column{Name: "BAR.*"}, + &Column{Name: "baz.last_name"}, + ), + Table: TableWithName("first.table AS foo, second.table as BAR, third.table aS baz"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo"."name", "BAR".*, "baz"."last_name" FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"`, s) +} + +func TestSelectArtistNameFrom(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("artist"), + Columns: JoinColumns( + &Column{Name: "artist.name"}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "artist"."name" FROM "artist"`, s) +} + +func TestSelectJoin(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("artist a"), + Columns: JoinColumns( + &Column{Name: "a.name"}, + ), + Joins: JoinConditions(&Join{ + Table: TableWithName("books b"), + On: OnConditions( + &ColumnValue{ + Column: ColumnWithName("b.author_id"), + Operator: `=`, + Value: NewValue(ColumnWithName("a.id")), + }, + ), + }), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "a"."name" FROM "artist" AS "a" JOIN "books" AS "b" ON ("b"."author_id" = "a"."id")`, s) +} + +func TestSelectJoinUsing(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("artist a"), + Columns: JoinColumns( + &Column{Name: "a.name"}, + ), + Joins: JoinConditions(&Join{ + Table: TableWithName("books b"), + Using: UsingColumns( + ColumnWithName("artist_id"), + ColumnWithName("country"), + ), + }), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "a"."name" FROM "artist" AS "a" JOIN "books" AS "b" USING ("artist_id", "country")`, s) +} + +func TestSelectUnfinishedJoin(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("artist a"), + Columns: JoinColumns( + &Column{Name: "a.name"}, + ), + Joins: JoinConditions(&Join{}), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "a"."name" FROM "artist" AS "a"`, s) +} + +func TestSelectNaturalJoin(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("artist"), + Joins: JoinConditions(&Join{ + Table: TableWithName("books"), + }), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "artist" NATURAL JOIN "books"`, s) +} + +func TestSelectRawFrom(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName(`artist`), + Columns: JoinColumns( + &Column{Name: `artist.name`}, + &Column{Name: &Raw{Value: `CONCAT(artist.name, " ", artist.last_name)`}}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "artist"."name", CONCAT(artist.name, " ", artist.last_name) FROM "artist"`, s) +} + +func TestSelectFieldsFrom(t *testing.T) { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name"`, s) +} + +func TestSelectFieldsFromWithLimitOffset(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Limit: 42, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Offset: 17, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" OFFSET 17`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Limit: 42, + Offset: 17, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42 OFFSET 17`, s) + } +} + +func TestStatementGroupBy(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + GroupBy: GroupByColumns( + &Column{Name: "foo"}, + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo"`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + GroupBy: GroupByColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo", "bar"`, s) + } +} + +func TestSelectFieldsFromWithOrderBy(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo"`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Ascendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" ASC`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Descendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Descendent}, + &SortColumn{Column: &Column{Name: "bar"}, Order: Order_Ascendent}, + &SortColumn{Column: &Column{Name: "baz"}, Order: Order_Descendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC, "bar" ASC, "baz" DESC`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: &Raw{Value: "FOO()"}}, Order: Order_Descendent}, + &SortColumn{Column: &Column{Name: &Raw{Value: "BAR()"}}, Order: Order_Ascendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY FOO() DESC, BAR() ASC`, s) + } +} + +func TestSelectFieldsFromWhere(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99')`, s) + } +} + +func TestSelectFieldsFromWhereLimitOffset(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + Limit: 10, + Offset: 23, + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99') LIMIT 10 OFFSET 23`, s) + } +} + +func TestDelete(t *testing.T) { + stmt := Statement{ + Type: Delete, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `DELETE FROM "table_name" WHERE ("baz" = '99')`, s) +} + +func TestUpdate(t *testing.T) { + { + stmt := Statement{ + Type: Update, + Table: TableWithName("table_name"), + ColumnValues: JoinColumnValues( + &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, + ), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `UPDATE "table_name" SET "foo" = '76' WHERE ("baz" = '99')`, s) + } + + { + stmt := Statement{ + Type: Update, + Table: TableWithName("table_name"), + ColumnValues: JoinColumnValues( + &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, + &ColumnValue{Column: &Column{Name: "bar"}, Operator: "=", Value: NewValue(&Raw{Value: "88"})}, + ), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `UPDATE "table_name" SET "foo" = '76', "bar" = 88 WHERE ("baz" = '99')`, s) + } +} + +func TestInsert(t *testing.T) { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: NewValueGroup( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: &Raw{Value: "3"}}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3)`, s) +} + +func TestInsertMultiple(t *testing.T) { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: JoinValueGroups( + NewValueGroup( + NewValue("1"), + NewValue("2"), + NewValue(&Raw{Value: "3"}), + ), + NewValueGroup( + NewValue(&Raw{Value: "4"}), + NewValue(&Raw{Value: "5"}), + NewValue(&Raw{Value: "6"}), + ), + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3), (4, 5, 6)`, s) +} + +func TestInsertReturning(t *testing.T) { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Returning: ReturningColumns( + ColumnWithName("id"), + ), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: NewValueGroup( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: &Raw{Value: "3"}}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3) RETURNING "id"`, s) +} + +func TestRawSQLStatement(t *testing.T) { + stmt := RawSQL(`SELECT * FROM "foo" ORDER BY "bar"`) + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "foo" ORDER BY "bar"`, s) +} + +func BenchmarkStatementSimpleQuery(b *testing.B) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(&Raw{Value: "7"})}, + ), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementSimpleQueryHash(b *testing.B) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(&Raw{Value: "7"})}, + ), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = stmt.Hash() + } +} + +func BenchmarkStatementSimpleQueryNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(&Raw{Value: "7"})}, + ), + } + _, _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementComplexQuery(b *testing.B) { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: NewValueGroup( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: &Raw{Value: "3"}}, + ), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementComplexQueryNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: NewValueGroup( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: &Raw{Value: "3"}}, + ), + } + _, _ = stmt.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/table.go b/internal/sqladapter/exql/table.go new file mode 100644 index 0000000..b831b45 --- /dev/null +++ b/internal/sqladapter/exql/table.go @@ -0,0 +1,98 @@ +package exql + +import ( + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +type tableT struct { + Name string + Alias string +} + +// Table struct represents a SQL table. +type Table struct { + Name interface{} +} + +var _ = Fragment(&Table{}) + +func quotedTableName(layout *Template, input string) string { + input = trimString(input) + + // chunks := reAliasSeparator.Split(input, 2) + chunks := separateByAS(input) + + if len(chunks) == 1 { + // chunks = reSpaceSeparator.Split(input, 2) + chunks = separateBySpace(input) + } + + name := chunks[0] + + nameChunks := strings.SplitN(name, layout.ColumnSeparator, 2) + + for i := range nameChunks { + // nameChunks[i] = strings.TrimSpace(nameChunks[i]) + nameChunks[i] = trimString(nameChunks[i]) + nameChunks[i] = layout.MustCompile(layout.IdentifierQuote, Raw{Value: nameChunks[i]}) + } + + name = strings.Join(nameChunks, layout.ColumnSeparator) + + var alias string + + if len(chunks) > 1 { + // alias = strings.TrimSpace(chunks[1]) + alias = trimString(chunks[1]) + alias = layout.MustCompile(layout.IdentifierQuote, Raw{Value: alias}) + } + + return layout.MustCompile(layout.TableAliasLayout, tableT{name, alias}) +} + +// TableWithName creates an returns a Table with the given name. +func TableWithName(name string) *Table { + return &Table{Name: name} +} + +// Hash returns a string hash of the table value. +func (t *Table) Hash() uint64 { + if t == nil { + return cache.NewHash(FragmentType_Table, nil) + } + return cache.NewHash(FragmentType_Table, t.Name) +} + +// Compile transforms a table struct into a SQL chunk. +func (t *Table) Compile(layout *Template) (compiled string, err error) { + + if z, ok := layout.Read(t); ok { + return z, nil + } + + switch value := t.Name.(type) { + case string: + if t.Name == "" { + return + } + + // Splitting tables by a comma + parts := separateByComma(value) + + l := len(parts) + + for i := 0; i < l; i++ { + parts[i] = quotedTableName(layout, parts[i]) + } + + compiled = strings.Join(parts, layout.IdentifierSeparator) + case Raw: + compiled = value.String() + } + + layout.Write(t, compiled) + + return +} diff --git a/internal/sqladapter/exql/table_test.go b/internal/sqladapter/exql/table_test.go new file mode 100644 index 0000000..08bc825 --- /dev/null +++ b/internal/sqladapter/exql/table_test.go @@ -0,0 +1,82 @@ +package exql + +import ( + "github.com/stretchr/testify/assert" + + "testing" +) + +func TestTableSimple(t *testing.T) { + table := TableWithName("artist") + assert.Equal(t, `"artist"`, mustTrim(table.Compile(defaultTemplate))) +} + +func TestTableCompound(t *testing.T) { + table := TableWithName("artist.foo") + assert.Equal(t, `"artist"."foo"`, mustTrim(table.Compile(defaultTemplate))) +} + +func TestTableCompoundAlias(t *testing.T) { + table := TableWithName("artist.foo AS baz") + + assert.Equal(t, `"artist"."foo" AS "baz"`, mustTrim(table.Compile(defaultTemplate))) +} + +func TestTableImplicitAlias(t *testing.T) { + table := TableWithName("artist.foo baz") + + assert.Equal(t, `"artist"."foo" AS "baz"`, mustTrim(table.Compile(defaultTemplate))) +} + +func TestTableMultiple(t *testing.T) { + table := TableWithName("artist.foo, artist.bar, artist.baz") + + assert.Equal(t, `"artist"."foo", "artist"."bar", "artist"."baz"`, mustTrim(table.Compile(defaultTemplate))) +} + +func TestTableMultipleAlias(t *testing.T) { + table := TableWithName("artist.foo AS foo, artist.bar as bar, artist.baz As baz") + + assert.Equal(t, `"artist"."foo" AS "foo", "artist"."bar" AS "bar", "artist"."baz" AS "baz"`, mustTrim(table.Compile(defaultTemplate))) +} + +func TestTableMinimal(t *testing.T) { + table := TableWithName("a") + + assert.Equal(t, `"a"`, mustTrim(table.Compile(defaultTemplate))) +} + +func TestTableEmpty(t *testing.T) { + table := TableWithName("") + + assert.Equal(t, "", mustTrim(table.Compile(defaultTemplate))) +} + +func BenchmarkTableWithName(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = TableWithName("foo") + } +} + +func BenchmarkTableHash(b *testing.B) { + t := TableWithName("name") + b.ResetTimer() + for i := 0; i < b.N; i++ { + t.Hash() + } +} + +func BenchmarkTableCompile(b *testing.B) { + t := TableWithName("name") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = t.Compile(defaultTemplate) + } +} + +func BenchmarkTableCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + t := TableWithName("name") + _, _ = t.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/template.go b/internal/sqladapter/exql/template.go new file mode 100644 index 0000000..4a148a2 --- /dev/null +++ b/internal/sqladapter/exql/template.go @@ -0,0 +1,148 @@ +package exql + +import ( + "bytes" + "reflect" + "sync" + "text/template" + + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// Type is the type of SQL query the statement represents. +type Type uint8 + +// Values for Type. +const ( + NoOp Type = iota + + Truncate + DropTable + DropDatabase + Count + Insert + Select + Update + Delete + + SQL +) + +func (t Type) Hash() uint64 { + return cache.NewHash(FragmentType_StatementType, uint8(t)) +} + +type ( + // Limit represents the SQL limit in a query. + Limit int64 + // Offset represents the SQL offset in a query. + Offset int64 +) + +func (t Limit) Hash() uint64 { + return cache.NewHash(FragmentType_Limit, uint64(t)) +} + +func (t Offset) Hash() uint64 { + return cache.NewHash(FragmentType_Offset, uint64(t)) +} + +// Template is an SQL template. +type Template struct { + AndKeyword string + AscKeyword string + AssignmentOperator string + ClauseGroup string + ClauseOperator string + ColumnAliasLayout string + ColumnSeparator string + ColumnValue string + CountLayout string + DeleteLayout string + DescKeyword string + DropDatabaseLayout string + DropTableLayout string + GroupByLayout string + IdentifierQuote string + IdentifierSeparator string + InsertLayout string + JoinLayout string + OnLayout string + OrKeyword string + OrderByLayout string + SelectLayout string + SortByColumnLayout string + TableAliasLayout string + TruncateLayout string + UpdateLayout string + UsingLayout string + ValueQuote string + ValueSeparator string + WhereLayout string + + ComparisonOperator map[adapter.ComparisonOperator]string + + templateMutex sync.RWMutex + templateMap map[string]*template.Template + + *cache.Cache +} + +func (layout *Template) MustCompile(templateText string, data interface{}) string { + var b bytes.Buffer + + v, ok := layout.getTemplate(templateText) + if !ok { + v = template. + Must(template.New(""). + Funcs(map[string]interface{}{ + "defined": func(in Fragment) bool { + if in == nil || reflect.ValueOf(in).IsNil() { + return false + } + if check, ok := in.(hasIsEmpty); ok { + if check.IsEmpty() { + return false + } + } + return true + }, + "compile": func(in Fragment) (string, error) { + s, err := layout.doCompile(in) + if err != nil { + return "", err + } + return s, nil + }, + }). + Parse(templateText)) + + layout.setTemplate(templateText, v) + } + + if err := v.Execute(&b, data); err != nil { + panic("There was an error compiling the following template:\n" + templateText + "\nError was: " + err.Error()) + } + + return b.String() +} + +func (t *Template) getTemplate(k string) (*template.Template, bool) { + t.templateMutex.RLock() + defer t.templateMutex.RUnlock() + + if t.templateMap == nil { + t.templateMap = make(map[string]*template.Template) + } + + v, ok := t.templateMap[k] + return v, ok +} + +func (t *Template) setTemplate(k string, v *template.Template) { + t.templateMutex.Lock() + defer t.templateMutex.Unlock() + + t.templateMap[k] = v +} diff --git a/internal/sqladapter/exql/types.go b/internal/sqladapter/exql/types.go new file mode 100644 index 0000000..d6ecca9 --- /dev/null +++ b/internal/sqladapter/exql/types.go @@ -0,0 +1,35 @@ +package exql + +const ( + FragmentType_None uint64 = iota + 713910251627 + + FragmentType_And + FragmentType_Column + FragmentType_ColumnValue + FragmentType_ColumnValues + FragmentType_Columns + FragmentType_Database + FragmentType_GroupBy + FragmentType_Join + FragmentType_Joins + FragmentType_Nil + FragmentType_Or + FragmentType_Limit + FragmentType_Offset + FragmentType_OrderBy + FragmentType_Order + FragmentType_Raw + FragmentType_Returning + FragmentType_SortBy + FragmentType_SortColumn + FragmentType_SortColumns + FragmentType_Statement + FragmentType_StatementType + FragmentType_Table + FragmentType_Value + FragmentType_On + FragmentType_Using + FragmentType_ValueGroups + FragmentType_Values + FragmentType_Where +) diff --git a/internal/sqladapter/exql/utilities.go b/internal/sqladapter/exql/utilities.go new file mode 100644 index 0000000..972ebb4 --- /dev/null +++ b/internal/sqladapter/exql/utilities.go @@ -0,0 +1,151 @@ +package exql + +import ( + "strings" +) + +// isBlankSymbol returns true if the given byte is either space, tab, carriage +// return or newline. +func isBlankSymbol(in byte) bool { + return in == ' ' || in == '\t' || in == '\r' || in == '\n' +} + +// trimString returns a slice of s with a leading and trailing blank symbols +// (as defined by isBlankSymbol) removed. +func trimString(s string) string { + + // This conversion is rather slow. + // return string(trimBytes([]byte(s))) + + start, end := 0, len(s)-1 + + if end < start { + return "" + } + + for isBlankSymbol(s[start]) { + start++ + if start >= end { + return "" + } + } + + for isBlankSymbol(s[end]) { + end-- + } + + return s[start : end+1] +} + +// trimBytes returns a slice of s with a leading and trailing blank symbols (as +// defined by isBlankSymbol) removed. +func trimBytes(s []byte) []byte { + + start, end := 0, len(s)-1 + + if end < start { + return []byte{} + } + + for isBlankSymbol(s[start]) { + start++ + if start >= end { + return []byte{} + } + } + + for isBlankSymbol(s[end]) { + end-- + } + + return s[start : end+1] +} + +/* +// Separates by a comma, ignoring spaces too. +// This was slower than strings.Split. +func separateByComma(in string) (out []string) { + + out = []string{} + + start, lim := 0, len(in)-1 + + for start < lim { + var end int + + for end = start; end <= lim; end++ { + // Is a comma? + if in[end] == ',' { + break + } + } + + out = append(out, trimString(in[start:end])) + + start = end + 1 + } + + return +} +*/ + +// Separates by a comma, ignoring spaces too. +func separateByComma(in string) (out []string) { + out = strings.Split(in, ",") + for i := range out { + out[i] = trimString(out[i]) + } + return +} + +// Separates by spaces, ignoring spaces too. +func separateBySpace(in string) (out []string) { + if len(in) == 0 { + return []string{""} + } + + pre := strings.Split(in, " ") + out = make([]string, 0, len(pre)) + + for i := range pre { + pre[i] = trimString(pre[i]) + if pre[i] != "" { + out = append(out, pre[i]) + } + } + + return +} + +func separateByAS(in string) (out []string) { + out = []string{} + + if len(in) < 6 { + // The minimum expression with the AS keyword is "x AS y", 6 chars. + return []string{in} + } + + start, lim := 0, len(in)-1 + + for start <= lim { + var end int + + for end = start; end <= lim; end++ { + if end > 3 && isBlankSymbol(in[end]) && isBlankSymbol(in[end-3]) { + if (in[end-1] == 's' || in[end-1] == 'S') && (in[end-2] == 'a' || in[end-2] == 'A') { + break + } + } + } + + if end < lim { + out = append(out, trimString(in[start:end-3])) + } else { + out = append(out, trimString(in[start:end])) + } + + start = end + 1 + } + + return +} diff --git a/internal/sqladapter/exql/utilities_test.go b/internal/sqladapter/exql/utilities_test.go new file mode 100644 index 0000000..9dcbde3 --- /dev/null +++ b/internal/sqladapter/exql/utilities_test.go @@ -0,0 +1,211 @@ +package exql + +import ( + "bytes" + "regexp" + "strings" + "testing" + "unicode" + + "github.com/stretchr/testify/assert" +) + +const ( + blankSymbol = ' ' + stringWithCommas = "Hello,,World!,Enjoy" + stringWithSpaces = " Hello World! Enjoy" + stringWithASKeyword = "table.Name AS myTableAlias" +) + +var ( + bytesWithLeadingBlanks = []byte(" Hello world! ") + stringWithLeadingBlanks = string(bytesWithLeadingBlanks) +) + +var ( + reInvisible = regexp.MustCompile(`[\t\n\r]`) + reSpace = regexp.MustCompile(`\s+`) +) + +func mustTrim(a string, err error) string { + if err != nil { + panic(err.Error()) + } + a = reInvisible.ReplaceAllString(strings.TrimSpace(a), " ") + a = reSpace.ReplaceAllString(strings.TrimSpace(a), " ") + return a +} + +func TestUtilIsBlankSymbol(t *testing.T) { + assert.True(t, isBlankSymbol(' ')) + assert.True(t, isBlankSymbol('\n')) + assert.True(t, isBlankSymbol('\t')) + assert.True(t, isBlankSymbol('\r')) + assert.False(t, isBlankSymbol('x')) +} + +func TestUtilTrimBytes(t *testing.T) { + var trimmed []byte + + trimmed = trimBytes([]byte(" \t\nHello World! \n")) + assert.Equal(t, "Hello World!", string(trimmed)) + + trimmed = trimBytes([]byte("Nope")) + assert.Equal(t, "Nope", string(trimmed)) + + trimmed = trimBytes([]byte("")) + assert.Equal(t, "", string(trimmed)) + + trimmed = trimBytes([]byte(" ")) + assert.Equal(t, "", string(trimmed)) + + trimmed = trimBytes(nil) + assert.Equal(t, "", string(trimmed)) +} + +func TestUtilSeparateByComma(t *testing.T) { + chunks := separateByComma("Hello,,World!,Enjoy") + assert.Equal(t, 4, len(chunks)) + + assert.Equal(t, "Hello", chunks[0]) + assert.Equal(t, "", chunks[1]) + assert.Equal(t, "World!", chunks[2]) + assert.Equal(t, "Enjoy", chunks[3]) +} + +func TestUtilSeparateBySpace(t *testing.T) { + chunks := separateBySpace(" Hello World! Enjoy") + assert.Equal(t, 3, len(chunks)) + + assert.Equal(t, "Hello", chunks[0]) + assert.Equal(t, "World!", chunks[1]) + assert.Equal(t, "Enjoy", chunks[2]) +} + +func TestUtilSeparateByAS(t *testing.T) { + var chunks []string + + var tests = []string{ + `table.Name AS myTableAlias`, + `table.Name AS myTableAlias`, + "table.Name\tAS\r\nmyTableAlias", + } + + for _, test := range tests { + chunks = separateByAS(test) + assert.Len(t, chunks, 2) + + assert.Equal(t, "table.Name", chunks[0]) + assert.Equal(t, "myTableAlias", chunks[1]) + } + + // Single character. + chunks = separateByAS("a") + assert.Len(t, chunks, 1) + assert.Equal(t, "a", chunks[0]) + + // Empty name + chunks = separateByAS("") + assert.Len(t, chunks, 1) + assert.Equal(t, "", chunks[0]) + + // Single name + chunks = separateByAS(" A Single Table ") + assert.Len(t, chunks, 1) + assert.Equal(t, "A Single Table", chunks[0]) + + // Minimal expression. + chunks = separateByAS("a AS b") + assert.Len(t, chunks, 2) + assert.Equal(t, "a", chunks[0]) + assert.Equal(t, "b", chunks[1]) + + // Minimal expression with spaces. + chunks = separateByAS(" a AS b ") + assert.Len(t, chunks, 2) + assert.Equal(t, "a", chunks[0]) + assert.Equal(t, "b", chunks[1]) + + // Minimal expression + 1 with spaces. + chunks = separateByAS(" a AS bb ") + assert.Len(t, chunks, 2) + assert.Equal(t, "a", chunks[0]) + assert.Equal(t, "bb", chunks[1]) +} + +func BenchmarkUtilIsBlankSymbol(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = isBlankSymbol(blankSymbol) + } +} + +func BenchmarkUtilStdlibIsBlankSymbol(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = unicode.IsSpace(blankSymbol) + } +} + +func BenchmarkUtilTrimBytes(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = trimBytes(bytesWithLeadingBlanks) + } +} +func BenchmarkUtilStdlibBytesTrimSpace(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = bytes.TrimSpace(bytesWithLeadingBlanks) + } +} + +func BenchmarkUtilTrimString(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = trimString(stringWithLeadingBlanks) + } +} + +func BenchmarkUtilStdlibStringsTrimSpace(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = strings.TrimSpace(stringWithLeadingBlanks) + } +} + +func BenchmarkUtilSeparateByComma(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = separateByComma(stringWithCommas) + } +} + +func BenchmarkUtilRegExpSeparateByComma(b *testing.B) { + sep := regexp.MustCompile(`\s*?,\s*?`) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = sep.Split(stringWithCommas, -1) + } +} + +func BenchmarkUtilSeparateBySpace(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = separateBySpace(stringWithSpaces) + } +} + +func BenchmarkUtilRegExpSeparateBySpace(b *testing.B) { + sep := regexp.MustCompile(`\s+`) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = sep.Split(stringWithSpaces, -1) + } +} + +func BenchmarkUtilSeparateByAS(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = separateByAS(stringWithASKeyword) + } +} + +func BenchmarkUtilRegExpSeparateByAS(b *testing.B) { + sep := regexp.MustCompile(`(?i:\s+AS\s+)`) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = sep.Split(stringWithASKeyword, -1) + } +} diff --git a/internal/sqladapter/exql/value.go b/internal/sqladapter/exql/value.go new file mode 100644 index 0000000..6190235 --- /dev/null +++ b/internal/sqladapter/exql/value.go @@ -0,0 +1,166 @@ +package exql + +import ( + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// ValueGroups represents an array of value groups. +type ValueGroups struct { + Values []*Values +} + +func (vg *ValueGroups) IsEmpty() bool { + if vg == nil || len(vg.Values) < 1 { + return true + } + for i := range vg.Values { + if !vg.Values[i].IsEmpty() { + return false + } + } + return true +} + +var _ = Fragment(&ValueGroups{}) + +// Values represents an array of Value. +type Values struct { + Values []Fragment +} + +func (vs *Values) IsEmpty() bool { + if vs == nil || len(vs.Values) < 1 { + return true + } + return false +} + +// NewValueGroup creates and returns an array of values. +func NewValueGroup(v ...Fragment) *Values { + return &Values{Values: v} +} + +var _ = Fragment(&Values{}) + +// Value represents an escaped SQL value. +type Value struct { + V interface{} +} + +var _ = Fragment(&Value{}) + +// NewValue creates and returns a Value. +func NewValue(v interface{}) *Value { + return &Value{V: v} +} + +// Hash returns a unique identifier for the struct. +func (v *Value) Hash() uint64 { + if v == nil { + return cache.NewHash(FragmentType_Value, nil) + } + return cache.NewHash(FragmentType_Value, v.V) +} + +// Compile transforms the Value into an equivalent SQL representation. +func (v *Value) Compile(layout *Template) (compiled string, err error) { + if z, ok := layout.Read(v); ok { + return z, nil + } + + switch value := v.V.(type) { + case compilable: + compiled, err = value.Compile(layout) + if err != nil { + return "", err + } + default: + value, err := NewRawValue(v.V) + if err != nil { + return "", err + } + compiled = layout.MustCompile( + layout.ValueQuote, + value, + ) + } + + layout.Write(v, compiled) + return +} + +// Hash returns a unique identifier for the struct. +func (vs *Values) Hash() uint64 { + if vs == nil { + return cache.NewHash(FragmentType_Values, nil) + } + h := cache.InitHash(FragmentType_Values) + for i := range vs.Values { + h = cache.AddToHash(h, vs.Values[i]) + } + return h +} + +// Compile transforms the Values into an equivalent SQL representation. +func (vs *Values) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(vs); ok { + return c, nil + } + + l := len(vs.Values) + if l > 0 { + chunks := make([]string, 0, l) + for i := 0; i < l; i++ { + chunk, err := vs.Values[i].Compile(layout) + if err != nil { + return "", err + } + chunks = append(chunks, chunk) + } + compiled = layout.MustCompile(layout.ClauseGroup, strings.Join(chunks, layout.ValueSeparator)) + } + layout.Write(vs, compiled) + return +} + +// Hash returns a unique identifier for the struct. +func (vg *ValueGroups) Hash() uint64 { + if vg == nil { + return cache.NewHash(FragmentType_ValueGroups, nil) + } + h := cache.InitHash(FragmentType_ValueGroups) + for i := range vg.Values { + h = cache.AddToHash(h, vg.Values[i]) + } + return h +} + +// Compile transforms the ValueGroups into an equivalent SQL representation. +func (vg *ValueGroups) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(vg); ok { + return c, nil + } + + l := len(vg.Values) + if l > 0 { + chunks := make([]string, 0, l) + for i := 0; i < l; i++ { + chunk, err := vg.Values[i].Compile(layout) + if err != nil { + return "", err + } + chunks = append(chunks, chunk) + } + compiled = strings.Join(chunks, layout.ValueSeparator) + } + + layout.Write(vg, compiled) + return +} + +// JoinValueGroups creates a new *ValueGroups object. +func JoinValueGroups(values ...*Values) *ValueGroups { + return &ValueGroups{Values: values} +} diff --git a/internal/sqladapter/exql/value_test.go b/internal/sqladapter/exql/value_test.go new file mode 100644 index 0000000..a0269b6 --- /dev/null +++ b/internal/sqladapter/exql/value_test.go @@ -0,0 +1,130 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestValue(t *testing.T) { + val := NewValue(1) + + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `'1'`, s) + + val = NewValue(&Raw{Value: "NOW()"}) + + s, err = val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `NOW()`, s) +} + +func TestSameRawValue(t *testing.T) { + { + val := NewValue(&Raw{Value: `"1"`}) + + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"1"`, s) + } + { + val := NewValue(&Raw{Value: `'1'`}) + + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `'1'`, s) + } + { + val := NewValue(&Raw{Value: `1`}) + + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `1`, s) + } + { + val := NewValue("1") + + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `'1'`, s) + } + { + val := NewValue(1) + + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `'1'`, s) + } +} + +func TestValues(t *testing.T) { + val := NewValueGroup( + &Value{V: &Raw{Value: "1"}}, + &Value{V: &Raw{Value: "2"}}, + &Value{V: "3"}, + ) + + s, err := val.Compile(defaultTemplate) + assert.NoError(t, err) + + assert.Equal(t, `(1, 2, '3')`, s) +} + +func BenchmarkValue(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = NewValue("a") + } +} + +func BenchmarkValueHash(b *testing.B) { + v := NewValue("a") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = v.Hash() + } +} + +func BenchmarkValueCompile(b *testing.B) { + v := NewValue("a") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = v.Compile(defaultTemplate) + } +} + +func BenchmarkValueCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + v := NewValue("a") + _, _ = v.Compile(defaultTemplate) + } +} + +func BenchmarkValues(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = NewValueGroup(NewValue("a"), NewValue("b")) + } +} + +func BenchmarkValuesHash(b *testing.B) { + vs := NewValueGroup(NewValue("a"), NewValue("b")) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = vs.Hash() + } +} + +func BenchmarkValuesCompile(b *testing.B) { + vs := NewValueGroup(NewValue("a"), NewValue("b")) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = vs.Compile(defaultTemplate) + } +} + +func BenchmarkValuesCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + vs := NewValueGroup(NewValue("a"), NewValue("b")) + _, _ = vs.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/where.go b/internal/sqladapter/exql/where.go new file mode 100644 index 0000000..3cc8985 --- /dev/null +++ b/internal/sqladapter/exql/where.go @@ -0,0 +1,149 @@ +package exql + +import ( + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// Or represents an SQL OR operator. +type Or Where + +// And represents an SQL AND operator. +type And Where + +// Where represents an SQL WHERE clause. +type Where struct { + Conditions []Fragment +} + +var _ = Fragment(&Where{}) + +type conds struct { + Conds string +} + +// WhereConditions creates and retuens a new Where. +func WhereConditions(conditions ...Fragment) *Where { + return &Where{Conditions: conditions} +} + +// JoinWithOr creates and returns a new Or. +func JoinWithOr(conditions ...Fragment) *Or { + return &Or{Conditions: conditions} +} + +// JoinWithAnd creates and returns a new And. +func JoinWithAnd(conditions ...Fragment) *And { + return &And{Conditions: conditions} +} + +// Hash returns a unique identifier for the struct. +func (w *Where) Hash() uint64 { + if w == nil { + return cache.NewHash(FragmentType_Where, nil) + } + h := cache.InitHash(FragmentType_Where) + for i := range w.Conditions { + h = cache.AddToHash(h, w.Conditions[i]) + } + return h +} + +// Appends adds the conditions to the ones that already exist. +func (w *Where) Append(a *Where) *Where { + if a != nil { + w.Conditions = append(w.Conditions, a.Conditions...) + } + return w +} + +// Hash returns a unique identifier. +func (o *Or) Hash() uint64 { + if o == nil { + return cache.NewHash(FragmentType_Or, nil) + } + return cache.NewHash(FragmentType_Or, (*Where)(o)) +} + +// Hash returns a unique identifier. +func (a *And) Hash() uint64 { + if a == nil { + return cache.NewHash(FragmentType_And, nil) + } + return cache.NewHash(FragmentType_And, (*Where)(a)) +} + +// Compile transforms the Or into an equivalent SQL representation. +func (o *Or) Compile(layout *Template) (compiled string, err error) { + if z, ok := layout.Read(o); ok { + return z, nil + } + + compiled, err = groupCondition(layout, o.Conditions, layout.MustCompile(layout.ClauseOperator, layout.OrKeyword)) + if err != nil { + return "", err + } + + layout.Write(o, compiled) + + return +} + +// Compile transforms the And into an equivalent SQL representation. +func (a *And) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(a); ok { + return c, nil + } + + compiled, err = groupCondition(layout, a.Conditions, layout.MustCompile(layout.ClauseOperator, layout.AndKeyword)) + if err != nil { + return "", err + } + + layout.Write(a, compiled) + + return +} + +// Compile transforms the Where into an equivalent SQL representation. +func (w *Where) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(w); ok { + return c, nil + } + + grouped, err := groupCondition(layout, w.Conditions, layout.MustCompile(layout.ClauseOperator, layout.AndKeyword)) + if err != nil { + return "", err + } + + if grouped != "" { + compiled = layout.MustCompile(layout.WhereLayout, conds{grouped}) + } + + layout.Write(w, compiled) + + return +} + +func groupCondition(layout *Template, terms []Fragment, joinKeyword string) (string, error) { + l := len(terms) + + chunks := make([]string, 0, l) + + if l > 0 { + for i := 0; i < l; i++ { + chunk, err := terms[i].Compile(layout) + if err != nil { + return "", err + } + chunks = append(chunks, chunk) + } + } + + if len(chunks) > 0 { + return layout.MustCompile(layout.ClauseGroup, strings.Join(chunks, joinKeyword)), nil + } + + return "", nil +} diff --git a/internal/sqladapter/exql/where_test.go b/internal/sqladapter/exql/where_test.go new file mode 100644 index 0000000..16e301e --- /dev/null +++ b/internal/sqladapter/exql/where_test.go @@ -0,0 +1,127 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWhereAnd(t *testing.T) { + and := JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + ) + + s, err := and.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `("id" > 8 AND "id" < 99 AND "name" = 'John')`, s) +} + +func TestWhereOr(t *testing.T) { + or := JoinWithOr( + &ColumnValue{Column: &Column{Name: "id"}, Operator: "=", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "=", Value: NewValue(&Raw{Value: "99"})}, + ) + + s, err := or.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `("id" = 8 OR "id" = 99)`, s) +} + +func TestWhereAndOr(t *testing.T) { + and := JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + ) + + s, err := and.Compile(defaultTemplate) + assert.NoError(t, err) + + assert.Equal(t, `("id" > 8 AND "id" < 99 AND "name" = 'John' AND ("last_name" = 'Smith' OR "last_name" = 'Reyes'))`, s) +} + +func TestWhereAndRawOrAnd(t *testing.T) { + { + where := WhereConditions( + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(2)}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "77"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(1)}, + ), + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + &Raw{Value: "city_id = 728"}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "age"}, Operator: ">", Value: NewValue(&Raw{Value: "18"})}, + &ColumnValue{Column: &Column{Name: "age"}, Operator: "<", Value: NewValue(&Raw{Value: "41"})}, + ), + ) + + assert.Equal(t, + `WHERE (("id" > '2' AND "id" < 77 AND "id" < '1') AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))`, + mustTrim(where.Compile(defaultTemplate)), + ) + } + + { + where := WhereConditions( + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(1)}, + ), + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + &Raw{Value: "city_id = 728"}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "age"}, Operator: ">", Value: NewValue(&Raw{Value: "18"})}, + &ColumnValue{Column: &Column{Name: "age"}, Operator: "<", Value: NewValue(&Raw{Value: "41"})}, + ), + ) + + assert.Equal(t, + `WHERE (("id" > 8 AND "id" > 8 AND "id" < 99 AND "id" < '1') AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))`, + mustTrim(where.Compile(defaultTemplate)), + ) + } +} + +func BenchmarkWhere(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + } +} + +func BenchmarkCompileWhere(b *testing.B) { + w := WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = w.Compile(defaultTemplate) + } +} + +func BenchmarkCompileWhereNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + w := WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + _, _ = w.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/hash.go b/internal/sqladapter/hash.go new file mode 100644 index 0000000..4d75491 --- /dev/null +++ b/internal/sqladapter/hash.go @@ -0,0 +1,8 @@ +package sqladapter + +const ( + hashTypeNone = iota + 345065139389 + + hashTypeCollection + hashTypePrimaryKeys +) diff --git a/internal/sqladapter/record.go b/internal/sqladapter/record.go new file mode 100644 index 0000000..35f568e --- /dev/null +++ b/internal/sqladapter/record.go @@ -0,0 +1,122 @@ +package sqladapter + +import ( + "reflect" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +func recordID(store mydb.Store, record mydb.Record) (mydb.Cond, error) { + if record == nil { + return nil, mydb.ErrNilRecord + } + + if hasConstraints, ok := record.(mydb.HasConstraints); ok { + return hasConstraints.Constraints(), nil + } + + id := mydb.Cond{} + + keys, fields, err := recordPrimaryKeyFieldValues(store, record) + if err != nil { + return nil, err + } + for i := range fields { + if fields[i] == reflect.Zero(reflect.TypeOf(fields[i])).Interface() { + return nil, mydb.ErrRecordIDIsZero + } + id[keys[i]] = fields[i] + } + if len(id) < 1 { + return nil, mydb.ErrRecordIDIsZero + } + + return id, nil +} + +func recordPrimaryKeyFieldValues(store mydb.Store, record mydb.Record) ([]string, []interface{}, error) { + sess := store.Session() + + pKeys, err := sess.(Session).PrimaryKeys(store.Name()) + if err != nil { + return nil, nil, err + } + + fields := sqlbuilder.Mapper.FieldsByName(reflect.ValueOf(record), pKeys) + + values := make([]interface{}, 0, len(fields)) + for i := range fields { + if fields[i].IsValid() { + values = append(values, fields[i].Interface()) + } + } + + return pKeys, values, nil +} + +func recordCreate(store mydb.Store, record mydb.Record) error { + sess := store.Session() + + if validator, ok := record.(mydb.Validator); ok { + if err := validator.Validate(); err != nil { + return err + } + } + + if hook, ok := record.(mydb.BeforeCreateHook); ok { + if err := hook.BeforeCreate(sess); err != nil { + return err + } + } + + if creator, ok := store.(mydb.StoreCreator); ok { + if err := creator.Create(record); err != nil { + return err + } + } else { + if err := store.InsertReturning(record); err != nil { + return err + } + } + + if hook, ok := record.(mydb.AfterCreateHook); ok { + if err := hook.AfterCreate(sess); err != nil { + return err + } + } + return nil +} + +func recordUpdate(store mydb.Store, record mydb.Record) error { + sess := store.Session() + + if validator, ok := record.(mydb.Validator); ok { + if err := validator.Validate(); err != nil { + return err + } + } + + if hook, ok := record.(mydb.BeforeUpdateHook); ok { + if err := hook.BeforeUpdate(sess); err != nil { + return err + } + } + + if updater, ok := store.(mydb.StoreUpdater); ok { + if err := updater.Update(record); err != nil { + return err + } + } else { + if err := record.Store(sess).UpdateReturning(record); err != nil { + return err + } + } + + if hook, ok := record.(mydb.AfterUpdateHook); ok { + if err := hook.AfterUpdate(sess); err != nil { + return err + } + } + return nil +} diff --git a/internal/sqladapter/result.go b/internal/sqladapter/result.go new file mode 100644 index 0000000..1f0ead6 --- /dev/null +++ b/internal/sqladapter/result.go @@ -0,0 +1,498 @@ +package sqladapter + +import ( + "errors" + "sync" + "sync/atomic" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/immutable" +) + +type Result struct { + builder mydb.SQL + + err atomic.Value + + iter mydb.Iterator + iterMu sync.Mutex + + prev *Result + fn func(*result) error +} + +// result represents a delimited set of items bound by a condition. +type result struct { + table string + limit int + offset int + + pageSize uint + pageNumber uint + + cursorColumn string + nextPageCursorValue interface{} + prevPageCursorValue interface{} + + fields []interface{} + orderBy []interface{} + groupBy []interface{} + conds [][]interface{} +} + +func filter(conds []interface{}) []interface{} { + return conds +} + +// NewResult creates and Results a new Result set on the given table, this set +// is limited by the given exql.Where conditions. +func NewResult(builder mydb.SQL, table string, conds []interface{}) *Result { + r := &Result{ + builder: builder, + } + return r.from(table).where(conds) +} + +func (r *Result) frame(fn func(*result) error) *Result { + return &Result{err: r.err, prev: r, fn: fn} +} + +func (r *Result) SQL() mydb.SQL { + if r.prev == nil { + return r.builder + } + return r.prev.SQL() +} + +func (r *Result) from(table string) *Result { + return r.frame(func(res *result) error { + res.table = table + return nil + }) +} + +func (r *Result) where(conds []interface{}) *Result { + return r.frame(func(res *result) error { + res.conds = [][]interface{}{conds} + return nil + }) +} + +func (r *Result) setErr(err error) { + if err == nil { + return + } + r.err.Store(err) +} + +// Err returns the last error that has happened with the result set, +// nil otherwise +func (r *Result) Err() error { + if errV := r.err.Load(); errV != nil { + return errV.(error) + } + return nil +} + +// Where sets conditions for the result set. +func (r *Result) Where(conds ...interface{}) mydb.Result { + return r.where(conds) +} + +// And adds more conditions on top of the existing ones. +func (r *Result) And(conds ...interface{}) mydb.Result { + return r.frame(func(res *result) error { + res.conds = append(res.conds, conds) + return nil + }) +} + +// Limit determines the maximum limit of Results to be returned. +func (r *Result) Limit(n int) mydb.Result { + return r.frame(func(res *result) error { + res.limit = n + return nil + }) +} + +func (r *Result) Paginate(pageSize uint) mydb.Result { + return r.frame(func(res *result) error { + res.pageSize = pageSize + return nil + }) +} + +func (r *Result) Page(pageNumber uint) mydb.Result { + return r.frame(func(res *result) error { + res.pageNumber = pageNumber + res.nextPageCursorValue = nil + res.prevPageCursorValue = nil + return nil + }) +} + +func (r *Result) Cursor(cursorColumn string) mydb.Result { + return r.frame(func(res *result) error { + res.cursorColumn = cursorColumn + return nil + }) +} + +func (r *Result) NextPage(cursorValue interface{}) mydb.Result { + return r.frame(func(res *result) error { + res.nextPageCursorValue = cursorValue + res.prevPageCursorValue = nil + return nil + }) +} + +func (r *Result) PrevPage(cursorValue interface{}) mydb.Result { + return r.frame(func(res *result) error { + res.nextPageCursorValue = nil + res.prevPageCursorValue = cursorValue + return nil + }) +} + +// Offset determines how many documents will be skipped before starting to grab +// Results. +func (r *Result) Offset(n int) mydb.Result { + return r.frame(func(res *result) error { + res.offset = n + return nil + }) +} + +// GroupBy is used to group Results that have the same value in the same column +// or columns. +func (r *Result) GroupBy(fields ...interface{}) mydb.Result { + return r.frame(func(res *result) error { + res.groupBy = fields + return nil + }) +} + +// OrderBy determines sorting of Results according to the provided names. Fields +// may be prefixed by - (minus) which means descending order, ascending order +// would be used otherwise. +func (r *Result) OrderBy(fields ...interface{}) mydb.Result { + return r.frame(func(res *result) error { + res.orderBy = fields + return nil + }) +} + +// Select determines which fields to return. +func (r *Result) Select(fields ...interface{}) mydb.Result { + return r.frame(func(res *result) error { + res.fields = fields + return nil + }) +} + +// String satisfies fmt.Stringer +func (r *Result) String() string { + query, err := r.Paginator() + if err != nil { + panic(err.Error()) + } + return query.String() +} + +// All dumps all Results into a pointer to an slice of structs or maps. +func (r *Result) All(dst interface{}) error { + query, err := r.Paginator() + if err != nil { + r.setErr(err) + return err + } + err = query.Iterator().All(dst) + r.setErr(err) + return err +} + +// One fetches only one Result from the set. +func (r *Result) One(dst interface{}) error { + one := r.Limit(1).(*Result) + query, err := one.Paginator() + if err != nil { + r.setErr(err) + return err + } + + err = query.Iterator().One(dst) + r.setErr(err) + return err +} + +// Next fetches the next Result from the set. +func (r *Result) Next(dst interface{}) bool { + r.iterMu.Lock() + defer r.iterMu.Unlock() + + if r.iter == nil { + query, err := r.Paginator() + if err != nil { + r.setErr(err) + return false + } + r.iter = query.Iterator() + } + + if r.iter.Next(dst) { + return true + } + + if err := r.iter.Err(); !errors.Is(err, mydb.ErrNoMoreRows) { + r.setErr(err) + return false + } + + return false +} + +// Delete deletes all matching items from the collection. +func (r *Result) Delete() error { + query, err := r.buildDelete() + if err != nil { + r.setErr(err) + return err + } + + _, err = query.Exec() + r.setErr(err) + return err +} + +// Close closes the Result set. +func (r *Result) Close() error { + if r.iter != nil { + err := r.iter.Close() + r.setErr(err) + return err + } + return nil +} + +// Update updates matching items from the collection with values of the given +// map or struct. +func (r *Result) Update(values interface{}) error { + query, err := r.buildUpdate(values) + if err != nil { + r.setErr(err) + return err + } + + _, err = query.Exec() + r.setErr(err) + return err +} + +func (r *Result) TotalPages() (uint, error) { + query, err := r.Paginator() + if err != nil { + r.setErr(err) + return 0, err + } + + total, err := query.TotalPages() + if err != nil { + r.setErr(err) + return 0, err + } + + return total, nil +} + +func (r *Result) TotalEntries() (uint64, error) { + query, err := r.Paginator() + if err != nil { + r.setErr(err) + return 0, err + } + + total, err := query.TotalEntries() + if err != nil { + r.setErr(err) + return 0, err + } + + return total, nil +} + +// Exists returns true if at least one item on the collection exists. +func (r *Result) Exists() (bool, error) { + query, err := r.buildCount() + if err != nil { + r.setErr(err) + return false, err + } + + query = query.Limit(1) + + value := struct { + Exists uint64 `db:"_t"` + }{} + + if err := query.One(&value); err != nil { + if errors.Is(err, mydb.ErrNoMoreRows) { + return false, nil + } + r.setErr(err) + return false, err + } + + if value.Exists > 0 { + return true, nil + } + + return false, nil +} + +// Count counts the elements on the set. +func (r *Result) Count() (uint64, error) { + query, err := r.buildCount() + if err != nil { + r.setErr(err) + return 0, err + } + + counter := struct { + Count uint64 `db:"_t"` + }{} + if err := query.One(&counter); err != nil { + if errors.Is(err, mydb.ErrNoMoreRows) { + return 0, nil + } + r.setErr(err) + return 0, err + } + + return counter.Count, nil +} + +func (r *Result) Paginator() (mydb.Paginator, error) { + if err := r.Err(); err != nil { + return nil, err + } + + res, err := r.fastForward() + if err != nil { + return nil, err + } + + sel := r.SQL().Select(res.fields...). + From(res.table). + Limit(res.limit). + Offset(res.offset). + GroupBy(res.groupBy...). + OrderBy(res.orderBy...) + + for i := range res.conds { + sel = sel.And(filter(res.conds[i])...) + } + + pag := sel.Paginate(res.pageSize). + Page(res.pageNumber). + Cursor(res.cursorColumn) + + if res.nextPageCursorValue != nil { + pag = pag.NextPage(res.nextPageCursorValue) + } + + if res.prevPageCursorValue != nil { + pag = pag.PrevPage(res.prevPageCursorValue) + } + + return pag, nil +} + +func (r *Result) buildDelete() (mydb.Deleter, error) { + if err := r.Err(); err != nil { + return nil, err + } + + res, err := r.fastForward() + if err != nil { + return nil, err + } + + del := r.SQL().DeleteFrom(res.table). + Limit(res.limit) + + for i := range res.conds { + del = del.And(filter(res.conds[i])...) + } + + return del, nil +} + +func (r *Result) buildUpdate(values interface{}) (mydb.Updater, error) { + if err := r.Err(); err != nil { + return nil, err + } + + res, err := r.fastForward() + if err != nil { + return nil, err + } + + upd := r.SQL().Update(res.table). + Set(values). + Limit(res.limit) + + for i := range res.conds { + upd = upd.And(filter(res.conds[i])...) + } + + return upd, nil +} + +func (r *Result) buildCount() (mydb.Selector, error) { + if err := r.Err(); err != nil { + return nil, err + } + + res, err := r.fastForward() + if err != nil { + return nil, err + } + + sel := r.SQL().Select(mydb.Raw("count(1) AS _t")). + From(res.table). + GroupBy(res.groupBy...) + + for i := range res.conds { + sel = sel.And(filter(res.conds[i])...) + } + + return sel, nil +} + +func (r *Result) Prev() immutable.Immutable { + if r == nil { + return nil + } + return r.prev +} + +func (r *Result) Fn(in interface{}) error { + if r.fn == nil { + return nil + } + return r.fn(in.(*result)) +} + +func (r *Result) Base() interface{} { + return &result{} +} + +func (r *Result) fastForward() (*result, error) { + ff, err := immutable.FastForward(r) + if err != nil { + return nil, err + } + return ff.(*result), nil +} + +var _ = immutable.Immutable(&Result{}) diff --git a/internal/sqladapter/session.go b/internal/sqladapter/session.go new file mode 100644 index 0000000..e973e6b --- /dev/null +++ b/internal/sqladapter/session.go @@ -0,0 +1,1106 @@ +package sqladapter + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "math" + "reflect" + "strconv" + "sync" + "sync/atomic" + "time" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/cache" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/compat" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +var ( + lastSessID uint64 + lastTxID uint64 +) + +var ( + slowQueryThreshold = time.Millisecond * 200 + retryTransactionWaitTime = time.Millisecond * 10 + retryTransactionMaxWaitTime = time.Second * 1 +) + +// hasCleanUp is implemented by structs that have a clean up routine that needs +// to be called before Close(). +type hasCleanUp interface { + CleanUp() error +} + +// statementExecer allows the adapter to have its own exec statement. +type statementExecer interface { + StatementExec(sess Session, ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// statementCompiler transforms an internal statement into a format +// database/sql can understand. +type statementCompiler interface { + CompileStatement(sess Session, stmt *exql.Statement, args []interface{}) (string, []interface{}, error) +} + +// sessValueConverter converts values before being passed to the underlying driver. +type sessValueConverter interface { + ConvertValue(in interface{}) interface{} +} + +// sessValueConverterContext converts values before being passed to the underlying driver. +type sessValueConverterContext interface { + ConvertValueContext(ctx context.Context, in interface{}) interface{} +} + +// valueConverter converts values before being passed to the underlying driver. +type valueConverter interface { + ConvertValue(in interface{}) interface { + sql.Scanner + driver.Valuer + } +} + +// errorConverter converts an error value from the underlying driver into +// something different. +type errorConverter interface { + Err(errIn error) (errOut error) +} + +// AdapterSession defines methods to be implemented by SQL adapters. +type AdapterSession interface { + Template() *exql.Template + + NewCollection() CollectionAdapter + + // Open opens a new connection + OpenDSN(sess Session, dsn string) (*sql.DB, error) + + // Collections returns a list of non-system tables from the database. + Collections(sess Session) ([]string, error) + + // TableExists returns an error if the given table does not exist. + TableExists(sess Session, name string) error + + // LookupName returns the name of the database. + LookupName(sess Session) (string, error) + + // PrimaryKeys returns all primary keys on the table. + PrimaryKeys(sess Session, name string) ([]string, error) +} + +// Session satisfies mydb.Session. +type Session interface { + SQL() mydb.SQL + + // PrimaryKeys returns all primary keys on the table. + PrimaryKeys(tableName string) ([]string, error) + + // Collections returns a list of references to all collections in the + // database. + Collections() ([]mydb.Collection, error) + + // Name returns the name of the database. + Name() string + + // Close closes the database session + Close() error + + // Ping checks if the database server is reachable. + Ping() error + + // Reset clears all caches the session is using + Reset() + + // Collection returns a new collection. + Collection(string) mydb.Collection + + // ConnectionURL returns the ConnectionURL that was used to create the + // Session. + ConnectionURL() mydb.ConnectionURL + + // Open attempts to establish a connection to the database server. + Open() error + + // TableExists returns an error if the table doesn't exists. + TableExists(name string) error + + // Driver returns the underlying driver the session is using + Driver() interface{} + + Save(mydb.Record) error + + Get(mydb.Record, interface{}) error + + Delete(mydb.Record) error + + // WaitForConnection attempts to run the given connection function a fixed + // number of times before failing. + WaitForConnection(func() error) error + + // BindDB sets the *sql.DB the session will use. + BindDB(*sql.DB) error + + // Session returns the *sql.DB the session is using. + DB() *sql.DB + + // BindTx binds a transaction to the current session. + BindTx(context.Context, *sql.Tx) error + + // Returns the current transaction the session is using. + Transaction() *sql.Tx + + // NewClone clones the database using the given AdapterSession as base. + NewClone(AdapterSession, bool) (Session, error) + + // Context returns the default context the session is using. + Context() context.Context + + // SetContext sets the default context for the session. + SetContext(context.Context) + + NewTransaction(ctx context.Context, opts *sql.TxOptions) (Session, error) + + Tx(fn func(sess mydb.Session) error) error + + TxContext(ctx context.Context, fn func(sess mydb.Session) error, opts *sql.TxOptions) error + + WithContext(context.Context) mydb.Session + + IsTransaction() bool + + Commit() error + + Rollback() error + + mydb.Settings +} + +// NewTx wraps a *sql.Tx and returns a Tx. +func NewTx(adapter AdapterSession, tx *sql.Tx) (Session, error) { + sessTx := &sessionWithContext{ + session: &session{ + Settings: mydb.DefaultSettings, + + sqlTx: tx, + adapter: adapter, + cachedPKs: cache.NewCache(), + cachedCollections: cache.NewCache(), + cachedStatements: cache.NewCache(), + }, + ctx: context.Background(), + } + return sessTx, nil +} + +// NewSession creates a new Session. +func NewSession(connURL mydb.ConnectionURL, adapter AdapterSession) Session { + sess := &sessionWithContext{ + session: &session{ + Settings: mydb.DefaultSettings, + + connURL: connURL, + adapter: adapter, + cachedPKs: cache.NewCache(), + cachedCollections: cache.NewCache(), + cachedStatements: cache.NewCache(), + }, + ctx: context.Background(), + } + return sess +} + +type session struct { + mydb.Settings + + adapter AdapterSession + + connURL mydb.ConnectionURL + + builder mydb.SQL + + lookupNameOnce sync.Once + name string + + mu sync.Mutex // guards ctx, txOptions + txOptions *sql.TxOptions + + sqlDBMu sync.Mutex // guards sess, baseTx + + sqlDB *sql.DB + sqlTx *sql.Tx + + sessID uint64 + txID uint64 + + cacheMu sync.Mutex // guards cachedStatements and cachedCollections + cachedPKs *cache.Cache + cachedStatements *cache.Cache + cachedCollections *cache.Cache + + template *exql.Template +} + +type sessionWithContext struct { + *session + + ctx context.Context +} + +func (sess *sessionWithContext) WithContext(ctx context.Context) mydb.Session { + if ctx == nil { + panic("nil context") + } + newSess := &sessionWithContext{ + session: sess.session, + ctx: ctx, + } + return newSess +} + +func (sess *sessionWithContext) Tx(fn func(sess mydb.Session) error) error { + return TxContext(sess.Context(), sess, fn, nil) +} + +func (sess *sessionWithContext) TxContext(ctx context.Context, fn func(sess mydb.Session) error, opts *sql.TxOptions) error { + return TxContext(ctx, sess, fn, opts) +} + +func (sess *sessionWithContext) SQL() mydb.SQL { + return sqlbuilder.WithSession( + sess, + sess.adapter.Template(), + ) +} + +func (sess *sessionWithContext) Err(errIn error) (errOur error) { + if convertError, ok := sess.adapter.(errorConverter); ok { + return convertError.Err(errIn) + } + return errIn +} + +func (sess *sessionWithContext) PrimaryKeys(tableName string) ([]string, error) { + h := cache.NewHashable(hashTypePrimaryKeys, tableName) + + cachedPK, ok := sess.cachedPKs.ReadRaw(h) + if ok { + return cachedPK.([]string), nil + } + + pk, err := sess.adapter.PrimaryKeys(sess, tableName) + if err != nil { + return nil, err + } + + sess.cachedPKs.Write(h, pk) + return pk, nil +} + +func (sess *sessionWithContext) TableExists(name string) error { + return sess.adapter.TableExists(sess, name) +} + +func (sess *sessionWithContext) NewTransaction(ctx context.Context, opts *sql.TxOptions) (Session, error) { + if ctx == nil { + ctx = context.Background() + } + clone, err := sess.NewClone(sess.adapter, false) + if err != nil { + return nil, err + } + + connFn := func() error { + sqlTx, err := compat.BeginTx(clone.DB(), clone.Context(), opts) + if err == nil { + return clone.BindTx(ctx, sqlTx) + } + return err + } + + if err := clone.WaitForConnection(connFn); err != nil { + return nil, err + } + + return clone, nil +} + +func (sess *sessionWithContext) Collections() ([]mydb.Collection, error) { + names, err := sess.adapter.Collections(sess) + if err != nil { + return nil, err + } + + collections := make([]mydb.Collection, 0, len(names)) + for i := range names { + collections = append(collections, sess.Collection(names[i])) + } + + return collections, nil +} + +func (sess *sessionWithContext) ConnectionURL() mydb.ConnectionURL { + return sess.connURL +} + +func (sess *sessionWithContext) Open() error { + var sqlDB *sql.DB + var err error + + connFn := func() error { + sqlDB, err = sess.adapter.OpenDSN(sess, sess.connURL.String()) + if err != nil { + return err + } + + sqlmydb.SetConnMaxLifetime(sess.ConnMaxLifetime()) + sqlmydb.SetConnMaxIdleTime(sess.ConnMaxIdleTime()) + sqlmydb.SetMaxIdleConns(sess.MaxIdleConns()) + sqlmydb.SetMaxOpenConns(sess.MaxOpenConns()) + return nil + } + + if err := sess.WaitForConnection(connFn); err != nil { + return err + } + + return sess.BindDB(sqlDB) +} + +func (sess *sessionWithContext) Get(record mydb.Record, id interface{}) error { + store := record.Store(sess) + if getter, ok := store.(mydb.StoreGetter); ok { + return getter.Get(record, id) + } + return store.Find(id).One(record) +} + +func (sess *sessionWithContext) Save(record mydb.Record) error { + if record == nil { + return mydb.ErrNilRecord + } + + if reflect.TypeOf(record).Kind() != reflect.Ptr { + return mydb.ErrExpectingPointerToStruct + } + + store := record.Store(sess) + if saver, ok := store.(mydb.StoreSaver); ok { + return saver.Save(record) + } + + id := mydb.Cond{} + keys, values, err := recordPrimaryKeyFieldValues(store, record) + if err != nil { + return err + } + for i := range values { + if values[i] != reflect.Zero(reflect.TypeOf(values[i])).Interface() { + id[keys[i]] = values[i] + } + } + + if len(id) > 0 && len(id) == len(values) { + // check if record exists before updating it + exists, _ := store.Find(id).Count() + if exists > 0 { + return recordUpdate(store, record) + } + } + + return recordCreate(store, record) +} + +func (sess *sessionWithContext) Delete(record mydb.Record) error { + if record == nil { + return mydb.ErrNilRecord + } + + if reflect.TypeOf(record).Kind() != reflect.Ptr { + return mydb.ErrExpectingPointerToStruct + } + + store := record.Store(sess) + + if hook, ok := record.(mydb.BeforeDeleteHook); ok { + if err := hook.BeforeDelete(sess); err != nil { + return err + } + } + + if deleter, ok := store.(mydb.StoreDeleter); ok { + if err := deleter.Delete(record); err != nil { + return err + } + } else { + conds, err := recordID(store, record) + if err != nil { + return err + } + if err := store.Find(conds).Delete(); err != nil { + return err + } + } + + if hook, ok := record.(mydb.AfterDeleteHook); ok { + if err := hook.AfterDelete(sess); err != nil { + return err + } + } + + return nil +} + +func (sess *sessionWithContext) DB() *sql.DB { + return sess.sqlDB +} + +func (sess *sessionWithContext) SetContext(ctx context.Context) { + sess.mu.Lock() + sess.ctx = ctx + sess.mu.Unlock() +} + +func (sess *sessionWithContext) Context() context.Context { + return sess.ctx +} + +func (sess *sessionWithContext) SetTxOptions(txOptions sql.TxOptions) { + sess.mu.Lock() + sess.txOptions = &txOptions + sess.mu.Unlock() +} + +func (sess *sessionWithContext) TxOptions() *sql.TxOptions { + sess.mu.Lock() + defer sess.mu.Unlock() + if sess.txOptions == nil { + return nil + } + return sess.txOptions +} + +func (sess *sessionWithContext) BindTx(ctx context.Context, tx *sql.Tx) error { + sess.sqlDBMu.Lock() + defer sess.sqlDBMu.Unlock() + + sess.sqlTx = tx + sess.SetContext(ctx) + + sess.txID = newBaseTxID() + + return nil +} + +func (sess *sessionWithContext) Commit() error { + if sess.sqlTx != nil { + return sess.sqlTx.Commit() + } + return mydb.ErrNotWithinTransaction +} + +func (sess *sessionWithContext) Rollback() error { + if sess.sqlTx != nil { + return sess.sqlTx.Rollback() + } + return mydb.ErrNotWithinTransaction +} + +func (sess *sessionWithContext) IsTransaction() bool { + return sess.sqlTx != nil +} + +func (sess *sessionWithContext) Transaction() *sql.Tx { + return sess.sqlTx +} + +func (sess *sessionWithContext) Name() string { + sess.lookupNameOnce.Do(func() { + if sess.name == "" { + sess.name, _ = sess.adapter.LookupName(sess) + } + }) + + return sess.name +} + +func (sess *sessionWithContext) BindDB(sqlDB *sql.DB) error { + + sess.sqlDBMu.Lock() + sess.sqlDB = sqlDB + sess.sqlDBMu.Unlock() + + if err := sess.Ping(); err != nil { + return err + } + + sess.sessID = newSessionID() + name, err := sess.adapter.LookupName(sess) + if err != nil { + return err + } + sess.name = name + + return nil +} + +func (sess *sessionWithContext) Ping() error { + if sess.sqlDB != nil { + return sess.sqlmydb.Ping() + } + return mydb.ErrNotConnected +} + +func (sess *sessionWithContext) SetConnMaxLifetime(t time.Duration) { + sess.Settings.SetConnMaxLifetime(t) + if sessDB := sess.DB(); sessDB != nil { + sessmydb.SetConnMaxLifetime(sess.Settings.ConnMaxLifetime()) + } +} + +func (sess *sessionWithContext) SetConnMaxIdleTime(t time.Duration) { + sess.Settings.SetConnMaxIdleTime(t) + if sessDB := sess.DB(); sessDB != nil { + sessmydb.SetConnMaxIdleTime(sess.Settings.ConnMaxIdleTime()) + } +} + +func (sess *sessionWithContext) SetMaxIdleConns(n int) { + sess.Settings.SetMaxIdleConns(n) + if sessDB := sess.DB(); sessDB != nil { + sessmydb.SetMaxIdleConns(sess.Settings.MaxIdleConns()) + } +} + +func (sess *sessionWithContext) SetMaxOpenConns(n int) { + sess.Settings.SetMaxOpenConns(n) + if sessDB := sess.DB(); sessDB != nil { + sessmydb.SetMaxOpenConns(sess.Settings.MaxOpenConns()) + } +} + +// Reset removes all caches. +func (sess *sessionWithContext) Reset() { + sess.cacheMu.Lock() + defer sess.cacheMu.Unlock() + + sess.cachedPKs.Clear() + sess.cachedCollections.Clear() + sess.cachedStatements.Clear() + + if sess.template != nil { + sess.template.Cache.Clear() + } +} + +func (sess *sessionWithContext) NewClone(adapter AdapterSession, checkConn bool) (Session, error) { + + newSess := NewSession(sess.connURL, adapter).(*sessionWithContext) + + newSess.name = sess.name + newSess.sqlDB = sess.sqlDB + newSess.cachedPKs = sess.cachedPKs + + if checkConn { + if err := newSess.Ping(); err != nil { + // Retry once if ping fails. + return sess.NewClone(adapter, false) + } + } + + newSess.sessID = newSessionID() + + // New transaction should inherit parent settings + copySettings(sess, newSess) + + return newSess, nil +} + +func (sess *sessionWithContext) Close() error { + defer func() { + sess.sqlDBMu.Lock() + sess.sqlDB = nil + sess.sqlTx = nil + sess.sqlDBMu.Unlock() + }() + + if sess.sqlDB == nil { + return nil + } + + sess.cachedCollections.Clear() + sess.cachedStatements.Clear() // Closes prepared statements as well. + + if !sess.IsTransaction() { + if cleaner, ok := sess.adapter.(hasCleanUp); ok { + if err := cleaner.CleanUp(); err != nil { + return err + } + } + // Not within a transaction. + return sess.sqlmydb.Close() + } + + return nil +} + +func (sess *sessionWithContext) Collection(name string) mydb.Collection { + sess.cacheMu.Lock() + defer sess.cacheMu.Unlock() + + h := cache.NewHashable(hashTypeCollection, name) + + col, ok := sess.cachedCollections.ReadRaw(h) + if !ok { + col = newCollection(name, sess.adapter.NewCollection()) + sess.cachedCollections.Write(h, col) + } + + return &collectionWithSession{ + collection: col.(*collection), + session: sess, + } +} + +func queryLog(status *mydb.QueryStatus) { + diff := status.End.Sub(status.Start) + + slowQuery := false + if diff >= slowQueryThreshold { + status.Err = mydb.ErrWarnSlowQuery + slowQuery = true + } + + if status.Err != nil || slowQuery { + mydb.LC().Warn(status) + return + } + + mydb.LC().Debug(status) +} + +func (sess *sessionWithContext) StatementPrepare(ctx context.Context, stmt *exql.Statement) (sqlStmt *sql.Stmt, err error) { + var query string + + defer func(start time.Time) { + queryLog(&mydb.QueryStatus{ + TxID: sess.txID, + SessID: sess.sessID, + RawQuery: query, + Err: err, + Start: start, + End: time.Now(), + Context: ctx, + }) + }(time.Now()) + + query, _, err = sess.compileStatement(stmt, nil) + if err != nil { + return nil, err + } + + tx := sess.Transaction() + if tx != nil { + sqlStmt, err = compat.PrepareContext(tx, ctx, query) + return + } + + sqlStmt, err = compat.PrepareContext(sess.sqlDB, ctx, query) + return +} + +func (sess *sessionWithContext) ConvertValue(value interface{}) interface{} { + if scannerValuer, ok := value.(sqlbuilder.ScannerValuer); ok { + return scannerValuer + } + + dv := reflect.Indirect(reflect.ValueOf(value)) + if dv.IsValid() { + if converter, ok := dv.Interface().(valueConverter); ok { + return converter.ConvertValue(dv.Interface()) + } + } + + if converter, ok := sess.adapter.(sessValueConverterContext); ok { + return converter.ConvertValueContext(sess.Context(), value) + } + + if converter, ok := sess.adapter.(sessValueConverter); ok { + return converter.ConvertValue(value) + } + + return value +} + +func (sess *sessionWithContext) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (res sql.Result, err error) { + var query string + + defer func(start time.Time) { + status := mydb.QueryStatus{ + TxID: sess.txID, + SessID: sess.sessID, + RawQuery: query, + Args: args, + Err: err, + Start: start, + End: time.Now(), + Context: ctx, + } + + if res != nil { + if rowsAffected, err := res.RowsAffected(); err == nil { + status.RowsAffected = &rowsAffected + } + + if lastInsertID, err := res.LastInsertId(); err == nil { + status.LastInsertID = &lastInsertID + } + } + + queryLog(&status) + }(time.Now()) + + if execer, ok := sess.adapter.(statementExecer); ok { + query, args, err = sess.compileStatement(stmt, args) + if err != nil { + return nil, err + } + res, err = execer.StatementExec(sess, ctx, query, args...) + return + } + + tx := sess.Transaction() + if sess.Settings.PreparedStatementCacheEnabled() && tx == nil { + var p *Stmt + if p, query, args, err = sess.prepareStatement(ctx, stmt, args); err != nil { + return nil, err + } + defer p.Close() + + res, err = compat.PreparedExecContext(p, ctx, args) + return + } + + query, args, err = sess.compileStatement(stmt, args) + if err != nil { + return nil, err + } + + if tx != nil { + res, err = compat.ExecContext(tx, ctx, query, args) + return + } + + res, err = compat.ExecContext(sess.sqlDB, ctx, query, args) + return +} + +// StatementQuery compiles and executes a statement that returns rows. +func (sess *sessionWithContext) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (rows *sql.Rows, err error) { + var query string + + defer func(start time.Time) { + status := mydb.QueryStatus{ + TxID: sess.txID, + SessID: sess.sessID, + RawQuery: query, + Args: args, + Err: err, + Start: start, + End: time.Now(), + Context: ctx, + } + queryLog(&status) + }(time.Now()) + + tx := sess.Transaction() + + if sess.Settings.PreparedStatementCacheEnabled() && tx == nil { + var p *Stmt + if p, query, args, err = sess.prepareStatement(ctx, stmt, args); err != nil { + return nil, err + } + defer p.Close() + + rows, err = compat.PreparedQueryContext(p, ctx, args) + return + } + + query, args, err = sess.compileStatement(stmt, args) + if err != nil { + return nil, err + } + if tx != nil { + rows, err = compat.QueryContext(tx, ctx, query, args) + return + } + + rows, err = compat.QueryContext(sess.sqlDB, ctx, query, args) + return +} + +// StatementQueryRow compiles and executes a statement that returns at most one +// row. +func (sess *sessionWithContext) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (row *sql.Row, err error) { + var query string + + defer func(start time.Time) { + status := mydb.QueryStatus{ + TxID: sess.txID, + SessID: sess.sessID, + RawQuery: query, + Args: args, + Err: err, + Start: start, + End: time.Now(), + Context: ctx, + } + queryLog(&status) + }(time.Now()) + + tx := sess.Transaction() + + if sess.Settings.PreparedStatementCacheEnabled() && tx == nil { + var p *Stmt + if p, query, args, err = sess.prepareStatement(ctx, stmt, args); err != nil { + return nil, err + } + defer p.Close() + + row = compat.PreparedQueryRowContext(p, ctx, args) + return + } + + query, args, err = sess.compileStatement(stmt, args) + if err != nil { + return nil, err + } + if tx != nil { + row = compat.QueryRowContext(tx, ctx, query, args) + return + } + + row = compat.QueryRowContext(sess.sqlDB, ctx, query, args) + return +} + +// Driver returns the underlying *sql.DB or *sql.Tx instance. +func (sess *sessionWithContext) Driver() interface{} { + if sess.sqlTx != nil { + return sess.sqlTx + } + return sess.sqlDB +} + +// compileStatement compiles the given statement into a string. +func (sess *sessionWithContext) compileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}, error) { + for i := range args { + args[i] = sess.ConvertValue(args[i]) + } + if statementCompiler, ok := sess.adapter.(statementCompiler); ok { + return statementCompiler.CompileStatement(sess, stmt, args) + } + + compiled, err := stmt.Compile(sess.adapter.Template()) + if err != nil { + return "", nil, err + } + query, args := sqlbuilder.Preprocess(compiled, args) + return query, args, nil +} + +// prepareStatement compiles a query and tries to use previously generated +// statement. +func (sess *sessionWithContext) prepareStatement(ctx context.Context, stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) { + sess.sqlDBMu.Lock() + defer sess.sqlDBMu.Unlock() + + sqlDB, tx := sess.sqlDB, sess.Transaction() + if sqlDB == nil && tx == nil { + return nil, "", nil, mydb.ErrNotConnected + } + + pc, ok := sess.cachedStatements.ReadRaw(stmt) + if ok { + // The statement was cachesess. + ps, err := pc.(*Stmt).Open() + if err == nil { + _, args, err = sess.compileStatement(stmt, args) + if err != nil { + return nil, "", nil, err + } + return ps, ps.query, args, nil + } + } + + query, args, err := sess.compileStatement(stmt, args) + if err != nil { + return nil, "", nil, err + } + sqlStmt, err := func(query *string) (*sql.Stmt, error) { + if tx != nil { + return compat.PrepareContext(tx, ctx, *query) + } + return compat.PrepareContext(sess.sqlDB, ctx, *query) + }(&query) + if err != nil { + return nil, "", nil, err + } + + p, err := NewStatement(sqlStmt, query).Open() + if err != nil { + return nil, query, args, err + } + sess.cachedStatements.Write(stmt, p) + return p, p.query, args, nil +} + +var waitForConnMu sync.Mutex + +// WaitForConnection tries to execute the given connectFn function, if +// connectFn returns an error, then WaitForConnection will keep trying until +// connectFn returns nil. Maximum waiting time is 5s after having acquired the +// lock. +func (sess *sessionWithContext) WaitForConnection(connectFn func() error) error { + // This lock ensures first-come, first-served and prevents opening too many + // file descriptors. + waitForConnMu.Lock() + defer waitForConnMu.Unlock() + + // Minimum waiting time. + waitTime := time.Millisecond * 10 + + // Waitig 5 seconds for a successful connection. + for timeStart := time.Now(); time.Since(timeStart) < time.Second*5; { + err := connectFn() + if err == nil { + return nil // Connected! + } + + // Only attempt to reconnect if the error is too many clients. + if sess.Err(err) == mydb.ErrTooManyClients { + // Sleep and try again if, and only if, the server replied with a "too + // many clients" error. + time.Sleep(waitTime) + if waitTime < time.Millisecond*500 { + // Wait a bit more next time. + waitTime = waitTime * 2 + } + continue + } + + // Return any other error immediately. + return err + } + + return mydb.ErrGivingUpTryingToConnect +} + +// ReplaceWithDollarSign turns a SQL statament with '?' placeholders into +// dollar placeholders, like $1, $2, ..., $n +func ReplaceWithDollarSign(buf []byte) []byte { + z := bytes.Count(buf, []byte{'?'}) + // the capacity is a quick estimation of the total memory required, this + // reduces reallocations + out := make([]byte, 0, len(buf)+z*3) + + var i, k = 0, 1 + for i < len(buf) { + if buf[i] == '?' { + out = append(out, buf[:i]...) + buf = buf[i+1:] + i = 0 + + if len(buf) > 0 && buf[0] == '?' { + out = append(out, '?') + buf = buf[1:] + continue + } + + out = append(out, '$') + out = append(out, []byte(strconv.Itoa(k))...) + k = k + 1 + continue + } + i = i + 1 + } + + out = append(out, buf[:len(buf)]...) + buf = nil + + return out +} + +func copySettings(from Session, into Session) { + into.SetPreparedStatementCache(from.PreparedStatementCacheEnabled()) + into.SetConnMaxLifetime(from.ConnMaxLifetime()) + into.SetConnMaxIdleTime(from.ConnMaxIdleTime()) + into.SetMaxIdleConns(from.MaxIdleConns()) + into.SetMaxOpenConns(from.MaxOpenConns()) +} + +func newSessionID() uint64 { + if atomic.LoadUint64(&lastSessID) == math.MaxUint64 { + atomic.StoreUint64(&lastSessID, 0) + return 0 + } + return atomic.AddUint64(&lastSessID, 1) +} + +func newBaseTxID() uint64 { + if atomic.LoadUint64(&lastTxID) == math.MaxUint64 { + atomic.StoreUint64(&lastTxID, 0) + return 0 + } + return atomic.AddUint64(&lastTxID, 1) +} + +// TxContext creates a transaction context and runs fn within it. +func TxContext(ctx context.Context, sess mydb.Session, fn func(tx mydb.Session) error, opts *sql.TxOptions) error { + txFn := func(sess mydb.Session) error { + tx, err := sess.(Session).NewTransaction(ctx, opts) + if err != nil { + return err + } + defer tx.Close() + + if err := fn(tx); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("%v: %w", rollbackErr, err) + } + return err + } + return tx.Commit() + } + + retryTime := retryTransactionWaitTime + + var txErr error + for i := 0; i < sess.MaxTransactionRetries(); i++ { + txErr = sess.(*sessionWithContext).Err(txFn(sess)) + if txErr == nil { + return nil + } + if errors.Is(txErr, mydb.ErrTransactionAborted) { + time.Sleep(retryTime) + + retryTime = retryTime * 2 + if retryTime > retryTransactionMaxWaitTime { + retryTime = retryTransactionMaxWaitTime + } + + continue + } + return txErr + } + + return fmt.Errorf("db: giving up trying to commit transaction: %w", txErr) +} + +var _ = mydb.Session(&sessionWithContext{}) diff --git a/internal/sqladapter/sqladapter.go b/internal/sqladapter/sqladapter.go new file mode 100644 index 0000000..a567e2e --- /dev/null +++ b/internal/sqladapter/sqladapter.go @@ -0,0 +1,62 @@ +// Package sqladapter provides common logic for SQL adapters. +package sqladapter + +import ( + "database/sql" + "database/sql/driver" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// IsKeyValue reports whether v is a valid value for a primary key that can be +// used with Find(pKey). +func IsKeyValue(v interface{}) bool { + if v == nil { + return true + } + switch v.(type) { + case int64, int, uint, uint64, + []int64, []int, []uint, []uint64, + []byte, []string, + []interface{}, + driver.Valuer: + return true + } + return false +} + +type sqlAdapterWrapper struct { + adapter AdapterSession +} + +func (w *sqlAdapterWrapper) OpenDSN(dsn mydb.ConnectionURL) (mydb.Session, error) { + sess := NewSession(dsn, w.adapter) + if err := sess.Open(); err != nil { + return nil, err + } + return sess, nil +} + +func (w *sqlAdapterWrapper) NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { + tx, err := NewTx(w.adapter, sqlTx) + if err != nil { + return nil, err + } + return tx, nil +} + +func (w *sqlAdapterWrapper) New(sqlDB *sql.DB) (mydb.Session, error) { + sess := NewSession(nil, w.adapter) + if err := sess.BindDB(sqlDB); err != nil { + return nil, err + } + return sess, nil +} + +// RegisterAdapter registers a new SQL adapter. +func RegisterAdapter(name string, adapter AdapterSession) sqlbuilder.Adapter { + z := &sqlAdapterWrapper{adapter} + mydb.RegisterAdapter(name, sqlbuilder.NewCompatAdapter(z)) + return z +} diff --git a/internal/sqladapter/sqladapter_test.go b/internal/sqladapter/sqladapter_test.go new file mode 100644 index 0000000..99240d0 --- /dev/null +++ b/internal/sqladapter/sqladapter_test.go @@ -0,0 +1,45 @@ +package sqladapter + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb" + "github.com/stretchr/testify/assert" +) + +var ( + _ mydb.Collection = &collectionWithSession{} + _ Collection = &collectionWithSession{} +) + +func TestReplaceWithDollarSign(t *testing.T) { + tests := []struct { + in string + out string + }{ + { + `SELECT ?`, + `SELECT $1`, + }, + { + `SELECT ? FROM ? WHERE ?`, + `SELECT $1 FROM $2 WHERE $3`, + }, + { + `SELECT ?? FROM ? WHERE ??`, + `SELECT ? FROM $1 WHERE ?`, + }, + { + `SELECT ??? FROM ? WHERE ??`, + `SELECT ?$1 FROM $2 WHERE ?`, + }, + { + `SELECT ??? FROM ? WHERE ????`, + `SELECT ?$1 FROM $2 WHERE ??`, + }, + } + + for _, test := range tests { + assert.Equal(t, []byte(test.out), ReplaceWithDollarSign([]byte(test.in))) + } +} diff --git a/internal/sqladapter/statement.go b/internal/sqladapter/statement.go new file mode 100644 index 0000000..0b18ebd --- /dev/null +++ b/internal/sqladapter/statement.go @@ -0,0 +1,85 @@ +package sqladapter + +import ( + "database/sql" + "errors" + "sync" + "sync/atomic" +) + +var ( + activeStatements int64 +) + +// Stmt represents a *sql.Stmt that is cached and provides the +// OnEvict method to allow it to clean after itself. +type Stmt struct { + *sql.Stmt + + query string + mu sync.Mutex + + count int64 + dead bool +} + +// NewStatement creates an returns an opened statement +func NewStatement(stmt *sql.Stmt, query string) *Stmt { + s := &Stmt{ + Stmt: stmt, + query: query, + } + atomic.AddInt64(&activeStatements, 1) + return s +} + +// Open marks the statement as in-use +func (c *Stmt) Open() (*Stmt, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.dead { + return nil, errors.New("statement is dead") + } + + c.count++ + return c, nil +} + +// Close closes the underlying statement if no other go-routine is using it. +func (c *Stmt) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + c.count-- + + return c.checkClose() +} + +func (c *Stmt) checkClose() error { + if c.dead && c.count == 0 { + // Statement is dead and we can close it for real. + err := c.Stmt.Close() + if err != nil { + return err + } + // Reduce active statements counter. + atomic.AddInt64(&activeStatements, -1) + } + return nil +} + +// OnEvict marks the statement as ready to be cleaned up. +func (c *Stmt) OnEvict() { + c.mu.Lock() + defer c.mu.Unlock() + + c.dead = true + c.checkClose() +} + +// NumActiveStatements returns the global number of prepared statements in use +// at any point. +func NumActiveStatements() int64 { + return atomic.LoadInt64(&activeStatements) +} diff --git a/internal/sqlbuilder/batch.go b/internal/sqlbuilder/batch.go new file mode 100644 index 0000000..c988e04 --- /dev/null +++ b/internal/sqlbuilder/batch.go @@ -0,0 +1,84 @@ +package sqlbuilder + +import git.hexq.cn/tiglog/mydb + +// BatchInserter provides a helper that can be used to do massive insertions in +// batches. +type BatchInserter struct { + inserter *inserter + size int + values chan []interface{} + err error +} + +func newBatchInserter(inserter *inserter, size int) *BatchInserter { + if size < 1 { + size = 1 + } + b := &BatchInserter{ + inserter: inserter, + size: size, + values: make(chan []interface{}, size), + } + return b +} + +// Values pushes column values to be inserted as part of the batch. +func (b *BatchInserter) Values(values ...interface{}) mydb.BatchInserter { + b.values <- values + return b +} + +func (b *BatchInserter) nextQuery() *inserter { + ins := &inserter{} + *ins = *b.inserter + i := 0 + for values := range b.values { + i++ + ins = ins.Values(values...).(*inserter) + if i == b.size { + break + } + } + if i == 0 { + return nil + } + return ins +} + +// NextResult is useful when using PostgreSQL and Returning(), it dumps the +// next slice of results to dst, which can mean having the IDs of all inserted +// elements in the batch. +func (b *BatchInserter) NextResult(dst interface{}) bool { + clone := b.nextQuery() + if clone == nil { + return false + } + b.err = clone.Iterator().All(dst) + return (b.err == nil) +} + +// Done means that no more elements are going to be added. +func (b *BatchInserter) Done() { + close(b.values) +} + +// Wait blocks until the whole batch is executed. +func (b *BatchInserter) Wait() error { + for { + q := b.nextQuery() + if q == nil { + break + } + if _, err := q.Exec(); err != nil { + b.err = err + break + } + } + return b.Err() +} + +// Err returns any error while executing the batch. +func (b *BatchInserter) Err() error { + return b.err +} diff --git a/internal/sqlbuilder/builder.go b/internal/sqlbuilder/builder.go new file mode 100644 index 0000000..a90ac74 --- /dev/null +++ b/internal/sqlbuilder/builder.go @@ -0,0 +1,611 @@ +// Package sqlbuilder provides tools for building custom SQL queries. +package sqlbuilder + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/reflectx" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/compat" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +// MapOptions represents options for the mapper. +type MapOptions struct { + IncludeZeroed bool + IncludeNil bool +} + +var defaultMapOptions = MapOptions{ + IncludeZeroed: false, + IncludeNil: false, +} + +type hasPaginator interface { + Paginator() (mydb.Paginator, error) +} + +type isCompilable interface { + Compile() (string, error) + Arguments() []interface{} +} + +type hasIsZero interface { + IsZero() bool +} + +type iterator struct { + sess exprDB + cursor *sql.Rows // This is the main query cursor. It starts as a nil value. + err error +} + +type fieldValue struct { + fields []string + values []interface{} +} + +var ( + sqlPlaceholder = &exql.Raw{Value: `?`} +) + +var ( + errDeprecatedJSONBTag = errors.New(`Tag "jsonb" is deprecated. See "PostgreSQL: jsonb tag" at https://github.com/upper/db/releases/tag/v3.4.0`) +) + +type exprDB interface { + StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error) + StatementPrepare(ctx context.Context, stmt *exql.Statement) (*sql.Stmt, error) + StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) + StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error) + + Context() context.Context +} + +type sqlBuilder struct { + sess exprDB + t *templateWithUtils +} + +// WithSession returns a query builder that is bound to the given database session. +func WithSession(sess interface{}, t *exql.Template) mydb.SQL { + if sqlDB, ok := sess.(*sql.DB); ok { + sess = sqlDB + } + return &sqlBuilder{ + sess: sess.(exprDB), // Let it panic, it will show the developer an informative error. + t: newTemplateWithUtils(t), + } +} + +// WithTemplate returns a builder that is based on the given template. +func WithTemplate(t *exql.Template) mydb.SQL { + return &sqlBuilder{ + t: newTemplateWithUtils(t), + } +} + +func (b *sqlBuilder) NewIteratorContext(ctx context.Context, rows *sql.Rows) mydb.Iterator { + return &iterator{b.sess, rows, nil} +} + +func (b *sqlBuilder) NewIterator(rows *sql.Rows) mydb.Iterator { + return b.NewIteratorContext(b.sess.Context(), rows) +} + +func (b *sqlBuilder) Iterator(query interface{}, args ...interface{}) mydb.Iterator { + return b.IteratorContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) IteratorContext(ctx context.Context, query interface{}, args ...interface{}) mydb.Iterator { + rows, err := b.QueryContext(ctx, query, args...) + return &iterator{b.sess, rows, err} +} + +func (b *sqlBuilder) Prepare(query interface{}) (*sql.Stmt, error) { + return b.PrepareContext(b.sess.Context(), query) +} + +func (b *sqlBuilder) PrepareContext(ctx context.Context, query interface{}) (*sql.Stmt, error) { + switch q := query.(type) { + case *exql.Statement: + return b.sess.StatementPrepare(ctx, q) + case string: + return b.sess.StatementPrepare(ctx, exql.RawSQL(q)) + case *adapter.RawExpr: + return b.PrepareContext(ctx, q.Raw()) + default: + return nil, fmt.Errorf("unsupported query type %T", query) + } +} + +func (b *sqlBuilder) Exec(query interface{}, args ...interface{}) (sql.Result, error) { + return b.ExecContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error) { + switch q := query.(type) { + case *exql.Statement: + return b.sess.StatementExec(ctx, q, args...) + case string: + return b.sess.StatementExec(ctx, exql.RawSQL(q), args...) + case *adapter.RawExpr: + return b.ExecContext(ctx, q.Raw(), q.Arguments()...) + default: + return nil, fmt.Errorf("unsupported query type %T", query) + } +} + +func (b *sqlBuilder) Query(query interface{}, args ...interface{}) (*sql.Rows, error) { + return b.QueryContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) QueryContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Rows, error) { + switch q := query.(type) { + case *exql.Statement: + return b.sess.StatementQuery(ctx, q, args...) + case string: + return b.sess.StatementQuery(ctx, exql.RawSQL(q), args...) + case *adapter.RawExpr: + return b.QueryContext(ctx, q.Raw(), q.Arguments()...) + default: + return nil, fmt.Errorf("unsupported query type %T", query) + } +} + +func (b *sqlBuilder) QueryRow(query interface{}, args ...interface{}) (*sql.Row, error) { + return b.QueryRowContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) QueryRowContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Row, error) { + switch q := query.(type) { + case *exql.Statement: + return b.sess.StatementQueryRow(ctx, q, args...) + case string: + return b.sess.StatementQueryRow(ctx, exql.RawSQL(q), args...) + case *adapter.RawExpr: + return b.QueryRowContext(ctx, q.Raw(), q.Arguments()...) + default: + return nil, fmt.Errorf("unsupported query type %T", query) + } +} + +func (b *sqlBuilder) SelectFrom(table ...interface{}) mydb.Selector { + qs := &selector{ + builder: b, + } + return qs.From(table...) +} + +func (b *sqlBuilder) Select(columns ...interface{}) mydb.Selector { + qs := &selector{ + builder: b, + } + return qs.Columns(columns...) +} + +func (b *sqlBuilder) InsertInto(table string) mydb.Inserter { + qi := &inserter{ + builder: b, + } + return qi.Into(table) +} + +func (b *sqlBuilder) DeleteFrom(table string) mydb.Deleter { + qd := &deleter{ + builder: b, + } + return qd.setTable(table) +} + +func (b *sqlBuilder) Update(table string) mydb.Updater { + qu := &updater{ + builder: b, + } + return qu.setTable(table) +} + +// Map receives a pointer to map or struct and maps it to columns and values. +func Map(item interface{}, options *MapOptions) ([]string, []interface{}, error) { + var fv fieldValue + if options == nil { + options = &defaultMapOptions + } + + itemV := reflect.ValueOf(item) + if !itemV.IsValid() { + return nil, nil, nil + } + + itemT := itemV.Type() + + if itemT.Kind() == reflect.Ptr { + // Single dereference. Just in case the user passes a pointer to struct + // instead of a struct. + item = itemV.Elem().Interface() + itemV = reflect.ValueOf(item) + itemT = itemV.Type() + } + + switch itemT.Kind() { + case reflect.Struct: + fieldMap := Mapper.TypeMap(itemT).Names + nfields := len(fieldMap) + + fv.values = make([]interface{}, 0, nfields) + fv.fields = make([]string, 0, nfields) + + for _, fi := range fieldMap { + + // Check for deprecated JSONB tag + if _, hasJSONBTag := fi.Options["jsonb"]; hasJSONBTag { + return nil, nil, errDeprecatedJSONBTag + } + + // Field options + _, tagOmitEmpty := fi.Options["omitempty"] + + fld := reflectx.FieldByIndexesReadOnly(itemV, fi.Index) + if fld.Kind() == reflect.Ptr && fld.IsNil() { + if tagOmitEmpty && !options.IncludeNil { + continue + } + fv.fields = append(fv.fields, fi.Name) + if tagOmitEmpty { + fv.values = append(fv.values, sqlDefault) + } else { + fv.values = append(fv.values, nil) + } + continue + } + + value := fld.Interface() + + isZero := false + if t, ok := fld.Interface().(hasIsZero); ok { + if t.IsZero() { + isZero = true + } + } else if fld.Kind() == reflect.Array || fld.Kind() == reflect.Slice { + if fld.Len() == 0 { + isZero = true + } + } else if reflect.DeepEqual(fi.Zero.Interface(), value) { + isZero = true + } + + if isZero && tagOmitEmpty && !options.IncludeZeroed { + continue + } + + fv.fields = append(fv.fields, fi.Name) + v, err := marshal(value) + if err != nil { + return nil, nil, err + } + if isZero && tagOmitEmpty { + v = sqlDefault + } + fv.values = append(fv.values, v) + } + + case reflect.Map: + nfields := itemV.Len() + fv.values = make([]interface{}, nfields) + fv.fields = make([]string, nfields) + mkeys := itemV.MapKeys() + + for i, keyV := range mkeys { + valv := itemV.MapIndex(keyV) + fv.fields[i] = fmt.Sprintf("%v", keyV.Interface()) + + v, err := marshal(valv.Interface()) + if err != nil { + return nil, nil, err + } + + fv.values[i] = v + } + default: + return nil, nil, ErrExpectingPointerToEitherMapOrStruct + } + + sort.Sort(&fv) + + return fv.fields, fv.values, nil +} + +func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, error) { + f := make([]exql.Fragment, len(columns)) + args := []interface{}{} + + for i := range columns { + switch v := columns[i].(type) { + case hasPaginator: + p, err := v.Paginator() + if err != nil { + return nil, nil, err + } + + q, a := Preprocess(p.String(), p.Arguments()) + + f[i] = &exql.Raw{Value: "(" + q + ")"} + args = append(args, a...) + case isCompilable: + c, err := v.Compile() + if err != nil { + return nil, nil, err + } + q, a := Preprocess(c, v.Arguments()) + if _, ok := v.(mydb.Selector); ok { + q = "(" + q + ")" + } + f[i] = &exql.Raw{Value: q} + args = append(args, a...) + case *adapter.FuncExpr: + fnName, fnArgs := v.Name(), v.Arguments() + if len(fnArgs) == 0 { + fnName = fnName + "()" + } else { + fnName = fnName + "(?" + strings.Repeat(", ?", len(fnArgs)-1) + ")" + } + fnName, fnArgs = Preprocess(fnName, fnArgs) + f[i] = &exql.Raw{Value: fnName} + args = append(args, fnArgs...) + case *adapter.RawExpr: + q, a := Preprocess(v.Raw(), v.Arguments()) + f[i] = &exql.Raw{Value: q} + args = append(args, a...) + case exql.Fragment: + f[i] = v + case string: + f[i] = exql.ColumnWithName(v) + case fmt.Stringer: + f[i] = exql.ColumnWithName(v.String()) + default: + var err error + f[i], err = exql.NewRawValue(columns[i]) + if err != nil { + return nil, nil, fmt.Errorf("unexpected argument type %T for Select() argument: %w", v, err) + } + } + } + return f, args, nil +} + +func prepareQueryForDisplay(in string) string { + out := make([]byte, 0, len(in)) + + offset := 0 + whitespace := true + placeholders := 1 + + for i := 0; i < len(in); i++ { + if in[i] == ' ' || in[i] == '\r' || in[i] == '\n' || in[i] == '\t' { + if whitespace { + offset = i + } else { + whitespace = true + out = append(out, in[offset:i]...) + offset = i + } + continue + } + if whitespace { + whitespace = false + if len(out) > 0 { + out = append(out, ' ') + } + offset = i + } + if in[i] == '?' { + out = append(out, in[offset:i]...) + offset = i + 1 + + out = append(out, '$') + out = append(out, strconv.Itoa(placeholders)...) + placeholders++ + } + } + if !whitespace { + out = append(out, in[offset:len(in)]...) + } + return string(out) +} + +func (iter *iterator) NextScan(dst ...interface{}) error { + if ok := iter.Next(); ok { + return iter.Scan(dst...) + } + if err := iter.Err(); err != nil { + return err + } + return mydb.ErrNoMoreRows +} + +func (iter *iterator) ScanOne(dst ...interface{}) error { + defer iter.Close() + return iter.NextScan(dst...) +} + +func (iter *iterator) Scan(dst ...interface{}) error { + if err := iter.Err(); err != nil { + return err + } + return iter.cursor.Scan(dst...) +} + +func (iter *iterator) setErr(err error) error { + iter.err = err + return iter.err +} + +func (iter *iterator) One(dst interface{}) error { + if err := iter.Err(); err != nil { + return err + } + defer iter.Close() + return iter.setErr(iter.next(dst)) +} + +func (iter *iterator) All(dst interface{}) error { + if err := iter.Err(); err != nil { + return err + } + defer iter.Close() + + // Fetching all results within the cursor. + if err := fetchRows(iter, dst); err != nil { + return iter.setErr(err) + } + + return nil +} + +func (iter *iterator) Err() (err error) { + return iter.err +} + +func (iter *iterator) Next(dst ...interface{}) bool { + if err := iter.Err(); err != nil { + return false + } + + if err := iter.next(dst...); err != nil { + // ignore mydb.ErrNoMoreRows, just break. + if !errors.Is(err, mydb.ErrNoMoreRows) { + _ = iter.setErr(err) + } + return false + } + + return true +} + +func (iter *iterator) next(dst ...interface{}) error { + if iter.cursor == nil { + return iter.setErr(mydb.ErrNoMoreRows) + } + + switch len(dst) { + case 0: + if ok := iter.cursor.Next(); !ok { + defer iter.Close() + err := iter.cursor.Err() + if err == nil { + err = mydb.ErrNoMoreRows + } + return err + } + return nil + case 1: + if err := fetchRow(iter, dst[0]); err != nil { + defer iter.Close() + return err + } + return nil + } + + return errors.New("Next does not currently supports more than one parameters") +} + +func (iter *iterator) Close() (err error) { + if iter.cursor != nil { + err = iter.cursor.Close() + iter.cursor = nil + } + return err +} + +func marshal(v interface{}) (interface{}, error) { + if m, isMarshaler := v.(mydb.Marshaler); isMarshaler { + var err error + if v, err = m.MarshalDB(); err != nil { + return nil, err + } + } + return v, nil +} + +func (fv *fieldValue) Len() int { + return len(fv.fields) +} + +func (fv *fieldValue) Swap(i, j int) { + fv.fields[i], fv.fields[j] = fv.fields[j], fv.fields[i] + fv.values[i], fv.values[j] = fv.values[j], fv.values[i] +} + +func (fv *fieldValue) Less(i, j int) bool { + return fv.fields[i] < fv.fields[j] +} + +type exprProxy struct { + db *sql.DB + t *exql.Template +} + +func (p *exprProxy) Context() context.Context { + return context.Background() +} + +func (p *exprProxy) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error) { + s, err := stmt.Compile(p.t) + if err != nil { + return nil, err + } + return compat.ExecContext(p.db, ctx, s, args) +} + +func (p *exprProxy) StatementPrepare(ctx context.Context, stmt *exql.Statement) (*sql.Stmt, error) { + s, err := stmt.Compile(p.t) + if err != nil { + return nil, err + } + return compat.PrepareContext(p.db, ctx, s) +} + +func (p *exprProxy) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) { + s, err := stmt.Compile(p.t) + if err != nil { + return nil, err + } + return compat.QueryContext(p.db, ctx, s, args) +} + +func (p *exprProxy) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error) { + s, err := stmt.Compile(p.t) + if err != nil { + return nil, err + } + return compat.QueryRowContext(p.db, ctx, s, args), nil +} + +var ( + _ = mydb.SQL(&sqlBuilder{}) + _ = exprDB(&exprProxy{}) +) + +func joinArguments(args ...[]interface{}) []interface{} { + total := 0 + for i := range args { + total += len(args[i]) + } + if total == 0 { + return nil + } + + flatten := make([]interface{}, 0, total) + for i := range args { + flatten = append(flatten, args[i]...) + } + return flatten +} diff --git a/internal/sqlbuilder/builder_test.go b/internal/sqlbuilder/builder_test.go new file mode 100644 index 0000000..4bb843b --- /dev/null +++ b/internal/sqlbuilder/builder_test.go @@ -0,0 +1,1510 @@ +package sqlbuilder + +import ( + "fmt" + "regexp" + "strings" + "testing" + + "git.hexq.cn/tiglog/mydb" + "github.com/stretchr/testify/assert" +) + +var ( + reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) +) + +func TestSelect(t *testing.T) { + + b := &sqlBuilder{t: newTemplateWithUtils(&testTemplate)} + assert := assert.New(t) + + assert.Equal( + `SELECT DATE()`, + b.Select(mydb.Func("DATE")).String(), + ) + + assert.Equal( + `SELECT DATE() FOR UPDATE`, + b.Select(mydb.Func("DATE")).Amend(func(query string) string { + return query + " FOR UPDATE" + }).String(), + ) + + assert.Equal( + `SELECT * FROM "artist"`, + b.SelectFrom("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT "bcolor" FROM "artist"`, + b.Select().Distinct("bcolor").From("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT * FROM "artist"`, + b.Select().Distinct().From("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT ON("col1"), "col2" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON("col1")`), "col2").From("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT ON("col1") AS col1, "col2" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON("col1") AS col1`)).Distinct("col2").From("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT ON("col1") AS col1, "col2", "col3", "col4", "col5" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON("col1") AS col1`)).Columns("col2", "col3").Distinct("col4", "col5").From("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT ON(SELECT foo FROM bar) col1, "col2", "col3", "col4", "col5" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON(?) col1`, mydb.Raw(`SELECT foo FROM bar`))).Columns("col2", "col3").Distinct("col4", "col5").From("artist").String(), + ) + + { + q0 := b.Select("foo").From("bar") + assert.Equal( + `SELECT DISTINCT ON (SELECT "foo" FROM "bar") col1, "col2", "col3", "col4", "col5" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON ? col1`, q0)).Columns("col2", "col3").Distinct("col4", "col5").From("artist").String(), + ) + } + + assert.Equal( + `SELECT DISTINCT ON (SELECT foo FROM bar, SELECT baz from qux) col1, "col2", "col3", "col4", "col5" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON ? col1`, []interface{}{mydb.Raw(`SELECT foo FROM bar`), mydb.Raw(`SELECT baz from qux`)})).Columns("col2", "col3").Distinct("col4", "col5").From("artist").String(), + ) + + { + q := b.Select(). + Distinct(mydb.Raw(`ON ? col1`, []*mydb.RawExpr{mydb.Raw(`SELECT foo FROM bar WHERE id = ?`, 1), mydb.Raw(`SELECT baz from qux WHERE id = 2`)})). + Columns("col2", "col3"). + Distinct("col4", "col5").From("artist"). + Where("id", 3) + assert.Equal( + `SELECT DISTINCT ON (SELECT foo FROM bar WHERE id = $1, SELECT baz from qux WHERE id = 2) col1, "col2", "col3", "col4", "col5" FROM "artist" WHERE ("id" = $2)`, + q.String(), + ) + + assert.Equal( + []interface{}{1, 3}, + q.Arguments(), + ) + } + + { + rawCase := mydb.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{1000, 2000}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1, $2) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000, 2000}, + sel.Arguments(), + ) + } + + { + rawCase := mydb.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{1000}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000}, + sel.Arguments(), + ) + } + + { + rawCase := mydb.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN (NULL) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + rawCase := mydb.Raw("CASE WHEN id IN (NULL) THEN 0 ELSE 1 END") + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN (NULL) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + rawCase.Arguments(), + ) + } + + { + rawCase := mydb.Raw("CASE WHEN id IN (?, ?) THEN 0 ELSE 1 END", 1000, 2000) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1, $2) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000, 2000}, + rawCase.Arguments(), + ) + } + + { + sel := b.Select(mydb.Func("DISTINCT", "name")).From("artist") + assert.Equal( + `SELECT DISTINCT($1) FROM "artist"`, + sel.String(), + ) + assert.Equal( + []interface{}{"name"}, + sel.Arguments(), + ) + } + + assert.Equal( + `SELECT * FROM "artist" WHERE (1 = $1)`, + b.Select().From("artist").Where(mydb.Cond{1: 1}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (1 = ANY($1))`, + b.Select().From("artist").Where(mydb.Cond{1: mydb.Func("ANY", "name")}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (1 = ANY(column))`, + b.Select().From("artist").Where(mydb.Cond{1: mydb.Func("ANY", mydb.Raw("column"))}).String(), + ) + + { + q := b.Select().From("artist").Where(mydb.Cond{"id NOT IN": []int{0, -1}}) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" NOT IN ($1, $2))`, + q.String(), + ) + assert.Equal( + []interface{}{0, -1}, + q.Arguments(), + ) + } + + { + q := b.Select().From("artist").Where(mydb.Cond{"id NOT IN": []int{-1}}) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" NOT IN ($1))`, + q.String(), + ) + assert.Equal( + []interface{}{-1}, + q.Arguments(), + ) + } + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN ($1, $2))`, + b.Select().From("artist").Where(mydb.Cond{"id IN": []int{0, -1}}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (("id" = $1 OR "id" = $2 OR "id" = $3))`, + b.Select().From("artist").Where( + mydb.Or( + mydb.Cond{"id": 1}, + mydb.Cond{"id": 2}, + mydb.Cond{"id": 3}, + ), + ).String(), + ) + + { + q := b.Select().From("artist").Where( + mydb.Or( + mydb.And(mydb.Cond{"a": 1}, mydb.Cond{"b": 2}, mydb.Cond{"c": 3}), + mydb.And(mydb.Cond{"d": 1}, mydb.Cond{"e": 2}, mydb.Cond{"f": 3}), + ), + ) + assert.Equal( + `SELECT * FROM "artist" WHERE ((("a" = $1 AND "b" = $2 AND "c" = $3) OR ("d" = $4 AND "e" = $5 AND "f" = $6)))`, + q.String(), + ) + assert.Equal( + []interface{}{1, 2, 3, 1, 2, 3}, + q.Arguments(), + ) + } + + { + q := b.Select().From("artist").Where( + mydb.Or( + mydb.And(mydb.Cond{"a": 1, "b": 2, "c": 3}), + mydb.And(mydb.Cond{"f": 6, "e": 5, "d": 4}), + ), + ) + assert.Equal( + `SELECT * FROM "artist" WHERE ((("a" = $1 AND "b" = $2 AND "c" = $3) OR ("d" = $4 AND "e" = $5 AND "f" = $6)))`, + q.String(), + ) + assert.Equal( + []interface{}{1, 2, 3, 4, 5, 6}, + q.Arguments(), + ) + } + + assert.Equal( + `SELECT * FROM "artist" WHERE ((("id" = $1 OR "id" = $2 OR "id" IS NULL) OR ("name" = $3 OR "name" = $4)))`, + b.Select().From("artist").Where( + mydb.Or( + mydb.Or( + mydb.Cond{"id": 1}, + mydb.Cond{"id": 2}, + mydb.Cond{"id IS": nil}, + ), + mydb.Or( + mydb.Cond{"name": "John"}, + mydb.Cond{"name": "Peter"}, + ), + ), + ).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ((("id" = $1 OR "id" = $2 OR "id" = $3 OR "id" = $4) AND ("name" = $5 AND "last_name" = $6) AND "age" > $7))`, + b.Select().From("artist").Where( + mydb.And( + mydb.Or( + mydb.Cond{"id": 1}, + mydb.Cond{"id": 2}, + mydb.Cond{"id": 3}, + ).Or( + mydb.Cond{"id": 4}, + ), + mydb.Or(), + mydb.And( + mydb.Cond{"name": "John"}, + mydb.Cond{"last_name": "Smith"}, + ), + mydb.And(), + ).And( + mydb.Cond{"age >": "20"}, + ), + ).String(), + ) + + assert.Equal( + `SELECT * FROM "artist"`, + b.Select().From("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("name DESC").String(), + ) + + { + sel := b.Select().From("artist").OrderBy(mydb.Raw("id = ?", 1), "name DESC") + assert.Equal( + `SELECT * FROM "artist" ORDER BY id = $1 , "name" DESC`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.Select().From("artist").OrderBy(mydb.Func("RAND")) + assert.Equal( + `SELECT * FROM "artist" ORDER BY RAND()`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + assert.Equal( + `SELECT * FROM "artist" ORDER BY RAND()`, + b.Select().From("artist").OrderBy(mydb.Raw("RAND()")).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("-name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name ASC").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" OFFSET 5`, + b.Select().From("artist").Limit(-1).Offset(5).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT 1 OFFSET 5`, + b.Select().From("artist").Limit(1).Offset(5).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist"`, + b.Select("id").From("artist").String(), + ) + + assert.Equal( + `SELECT "id", "name" FROM "artist"`, + b.Select("id", "name").From("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("name" = $1)`, + b.SelectFrom("artist").Where("name", "Haruki").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name LIKE $1)`, + b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" WHERE (name LIKE $1 OR name LIKE $2)`, + b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" > $1)`, + b.SelectFrom("artist").Where("id >", 2).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (id <= 2 AND name != $1)`, + b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN ($1, $2, $3, $4))`, + b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (id IN ($1, $2, $3, $4) AND foo = $5 AND bar IN ($6, $7, $8))`, + b.SelectFrom("artist").Where("id IN ? AND foo = ? AND bar IN ?", []int{1, 9, 8, 7}, 28, []int{1, 2, 3}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name IS NOT NULL)`, + b.SelectFrom("artist").Where("name IS NOT NULL").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a", "publication" AS "p" WHERE (p.author_id = a.id) LIMIT 1`, + b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" NATURAL JOIN "publication"`, + b.Select("id").From("artist").Join("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE ("a"."id" = $1) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1`, + b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3 AND a.sub_id = $4) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Where("a.sub_id = ?", 3).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3 AND a.id = $4) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).And("a.id = ?", 3).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" LEFT JOIN "publication" AS "p1" ON (p1.id = a.id) RIGHT JOIN "publication" AS "p2" ON (p2.id = a.id)`, + b.SelectFrom("artist a"). + LeftJoin("publication p1").On("p1.id = a.id"). + RightJoin("publication p2").On("p2.id = a.id"). + String(), + ) + + assert.Equal( + `SELECT * FROM "artist" CROSS JOIN "publication"`, + b.SelectFrom("artist").CrossJoin("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" USING ("id")`, + b.SelectFrom("artist").Join("publication").Using("id").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IS NULL)`, + b.SelectFrom("artist").Where(mydb.Cond{"id": nil}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN (NULL))`, + b.SelectFrom("artist").Where(mydb.Cond{"id": []int64{}}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN ($1))`, + b.SelectFrom("artist").Where(mydb.Cond{"id": []int64{0}}).String(), + ) + + assert.Equal( + `SELECT COUNT(*) AS total FROM "user" AS "u" JOIN (SELECT DISTINCT user_id FROM user_profile) AS up ON (u.id = up.user_id)`, + b.Select(mydb.Raw(`COUNT(*) AS total`)).From("user u").Join(mydb.Raw("(SELECT DISTINCT user_id FROM user_profile) AS up")).On("u.id = up.user_id").String(), + ) + + { + q0 := b.Select("user_id").Distinct().From("user_profile") + + assert.Equal( + `SELECT COUNT(*) AS total FROM "user" AS "u" JOIN (SELECT DISTINCT "user_id" FROM "user_profile") AS up ON (u.id = up.user_id)`, + b.Select(mydb.Raw(`COUNT(*) AS total`)).From("user u").Join(mydb.Raw("? AS up", q0)).On("u.id = up.user_id").String(), + ) + } + + { + q0 := b.Select("user_id").Distinct().From("user_profile").Where("t", []int{1, 2, 4, 5}) + + assert.Equal( + []interface{}{1, 2, 4, 5}, + q0.Arguments(), + ) + + q1 := b.Select(mydb.Raw(`COUNT(*) AS total`)).From("user u").Join(mydb.Raw("? AS up", q0)).On("u.id = up.user_id AND foo = ?", 8) + + assert.Equal( + `SELECT COUNT(*) AS total FROM "user" AS "u" JOIN (SELECT DISTINCT "user_id" FROM "user_profile" WHERE ("t" IN ($1, $2, $3, $4))) AS up ON (u.id = up.user_id AND foo = $5)`, + q1.String(), + ) + + assert.Equal( + []interface{}{1, 2, 4, 5, 8}, + q1.Arguments(), + ) + } + + assert.Equal( + `SELECT DATE()`, + b.Select(mydb.Raw("DATE()")).String(), + ) + + { + sel := b.Select(mydb.Raw("CONCAT(?, ?)", "foo", "bar")) + assert.Equal( + `SELECT CONCAT($1, $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{"foo", "bar"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{"bar": mydb.Raw("1")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = 1)`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{mydb.Raw("1"): 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE (1 = $1)`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{mydb.Raw("1"): mydb.Raw("1")}) + assert.Equal( + `SELECT * FROM "foo" WHERE (1 = 1)`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Raw("1 = 1")) + assert.Equal( + `SELECT * FROM "foo" WHERE (1 = 1)`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{"bar": 1}, mydb.Cond{"baz": mydb.Raw("CONCAT(?, ?)", "foo", "bar")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = CONCAT($2, $3))`, + sel.String(), + ) + assert.Equal( + []interface{}{1, "foo", "bar"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{"bar": 1}, mydb.Raw("? = ANY(col)", "name")) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND $2 = ANY(col))`, + sel.String(), + ) + assert.Equal( + []interface{}{1, "name"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{"bar": 1}, mydb.Cond{"name": mydb.Raw("ANY(col)")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "name" = ANY(col))`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{"bar": 1}, mydb.Cond{mydb.Raw("CONCAT(?, ?)", "a", "b"): mydb.Raw("ANY(col)")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND CONCAT($2, $3) = ANY(col))`, + sel.String(), + ) + assert.Equal( + []interface{}{1, "a", "b"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("bar", 2).And(mydb.Cond{"baz": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{2, 1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").And(mydb.Cond{"bar": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1)`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("bar", 2).And(mydb.Cond{"baz": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{2, 1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("bar", 2).Where(mydb.Cond{"baz": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{2, 1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Raw("bar->'baz' = ?", true)) + assert.Equal( + `SELECT * FROM "foo" WHERE (bar->'baz' = $1)`, + sel.String(), + ) + assert.Equal( + []interface{}{true}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{}).And(mydb.Cond{}) + assert.Equal( + `SELECT * FROM "foo"`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("bar = 1").And(mydb.Or( + mydb.Raw("fieldA ILIKE ?", `%a%`), + mydb.Raw("fieldB ILIKE ?", `%b%`), + )) + assert.Equal( + `SELECT * FROM "foo" WHERE (bar = 1 AND (fieldA ILIKE $1 OR fieldB ILIKE $2))`, + sel.String(), + ) + assert.Equal( + []interface{}{`%a%`, `%b%`}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("group_id", 1).And("user_id", 2) + assert.Equal( + `SELECT * FROM "foo" WHERE ("group_id" = $1 AND "user_id" = $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{1, 2}, + sel.Arguments(), + ) + } + + { + s := `SUM(CASE WHEN foo in ? THEN 1 ELSE 0 END) AS _sum` + sel := b.Select("c1").Columns(mydb.Raw(s, []int{5, 4, 3, 2})).From("foo").Where("bar = ?", 1) + assert.Equal( + `SELECT "c1", SUM(CASE WHEN foo in ($1, $2, $3, $4) THEN 1 ELSE 0 END) AS _sum FROM "foo" WHERE (bar = $5)`, + sel.String(), + ) + assert.Equal( + []interface{}{5, 4, 3, 2, 1}, + sel.Arguments(), + ) + } + + { + s := `SUM(CASE WHEN foo in ? THEN 1 ELSE 0 END) AS _sum` + sel := b.Select("c1").Columns(mydb.Raw(s, []int{5, 4, 3, 2})).From("foo").Where("bar = ?", 1) + sel2 := b.SelectFrom(sel).As("subquery").Where(mydb.Cond{"foo": "bar"}).OrderBy("subquery.seq") + assert.Equal( + `SELECT * FROM (SELECT "c1", SUM(CASE WHEN foo in ($1, $2, $3, $4) THEN 1 ELSE 0 END) AS _sum FROM "foo" WHERE (bar = $5)) AS "subquery" WHERE ("foo" = $6) ORDER BY "subquery"."seq" ASC`, + sel2.String(), + ) + assert.Equal( + []interface{}{5, 4, 3, 2, 1, "bar"}, + sel2.Arguments(), + ) + } + + { + series := b.Select( + mydb.Raw("start + interval ? - interval '1s' AS end", "1 day"), + ).From( + b.Select( + mydb.Raw("generate_series(?::timestamp, ?::timestamp, ?::interval) AS start", 1, 2, 3), + ), + ).As("series") + + assert.Equal( + []interface{}{"1 day", 1, 2, 3}, + series.Arguments(), + ) + + distinct := b.Select().Distinct( + mydb.Raw(`ON(dt.email) SUBSTRING(email,(POSITION('@' in email) + 1),252) AS email_domain`), + "dt.event_type AS event_type", + mydb.Raw("count(dt.*) AS count"), + "intervals.start AS start", + "intervals.end AS start", + ).From("email_events AS dt"). + RightJoin("intervals").On("dt.ts BETWEEN intervals.stast AND intervals.END AND dt.hub_id = ? AND dt.object_id = ?", 67, 68). + GroupBy("email_domain", "event_type", "start"). + OrderBy("email", "start", "event_type") + + sq, args := Preprocess( + `WITH intervals AS ? ?`, + []interface{}{ + series, + distinct, + }, + ) + + assert.Equal( + stripWhitespace(` + WITH intervals AS (SELECT start + interval ? - interval '1s' AS end FROM (SELECT generate_series(?::timestamp, ?::timestamp, ?::interval) AS start) AS "series") + (SELECT DISTINCT ON(dt.email) SUBSTRING(email,(POSITION('@' in email) + 1),252) AS email_domain, "dt"."event_type" AS "event_type", count(dt.*) AS count, "intervals"."start" AS "start", "intervals"."end" AS "start" + FROM "email_events" AS "dt" + RIGHT JOIN "intervals" ON (dt.ts BETWEEN intervals.stast AND intervals.END AND dt.hub_id = ? AND dt.object_id = ?) + GROUP BY "email_domain", "event_type", "start" + ORDER BY "email" ASC, "start" ASC, "event_type" ASC)`), + stripWhitespace(sq), + ) + + assert.Equal( + []interface{}{"1 day", 1, 2, 3, 67, 68}, + args, + ) + } + + { + sq := b. + Select("user_id"). + From("user_access"). + Where(mydb.Cond{"hub_id": 3}) + + // Don't reassign + _ = sq.And(mydb.Cond{"role": []int{1, 2}}) + + assert.Equal( + `SELECT "user_id" FROM "user_access" WHERE ("hub_id" = $1)`, + sq.String(), + ) + + assert.Equal( + []interface{}{3}, + sq.Arguments(), + ) + + // Reassign + sq = sq.And(mydb.Cond{"role": []int{1, 2}}) + + assert.Equal( + `SELECT "user_id" FROM "user_access" WHERE ("hub_id" = $1 AND "role" IN ($2, $3))`, + sq.String(), + ) + + assert.Equal( + []interface{}{3, 1, 2}, + sq.Arguments(), + ) + + cond := mydb.Or( + mydb.Raw("a.id IN ?", sq), + ) + + cond = cond.Or(mydb.Cond{"ml.mailing_list_id": []int{4, 5, 6}}) + + sel := b. + Select(mydb.Raw("DISTINCT ON(a.id) a.id"), mydb.Raw("COALESCE(NULLIF(ml.name,''), a.name) as name"), "a.email"). + From("mailing_list_recipients ml"). + FullJoin("accounts a").On("a.id = ml.user_id"). + Where(cond) + + search := "word" + sel = sel.And(mydb.Or( + mydb.Raw("COALESCE(NULLIF(ml.name,''), a.name) ILIKE ?", fmt.Sprintf("%%%s%%", search)), + mydb.Cond{"a.email ILIKE": fmt.Sprintf("%%%s%%", search)}, + )) + + assert.Equal( + `SELECT DISTINCT ON(a.id) a.id, COALESCE(NULLIF(ml.name,''), a.name) as name, "a"."email" FROM "mailing_list_recipients" AS "ml" FULL JOIN "accounts" AS "a" ON (a.id = ml.user_id) WHERE ((a.id IN (SELECT "user_id" FROM "user_access" WHERE ("hub_id" = $1 AND "role" IN ($2, $3))) OR "ml"."mailing_list_id" IN ($4, $5, $6)) AND (COALESCE(NULLIF(ml.name,''), a.name) ILIKE $7 OR "a"."email" ILIKE $8))`, + sel.String(), + ) + + assert.Equal( + []interface{}{3, 1, 2, 4, 5, 6, `%word%`, `%word%`}, + sel.Arguments(), + ) + + { + sel := b.Select(mydb.Func("FOO", "A", "B", 1)).From("accounts").Where(mydb.Cond{"time": mydb.Func("FUNCTION", "20170103", "YYYYMMDD", 1, "E")}) + + assert.Equal( + `SELECT FOO($1, $2, $3) FROM "accounts" WHERE ("time" = FUNCTION($4, $5, $6, $7))`, + sel.String(), + ) + assert.Equal( + []interface{}{"A", "B", 1, "20170103", "YYYYMMDD", 1, "E"}, + sel.Arguments(), + ) + } + + { + + sel := b.Select(mydb.Func("FOO", "A", "B", 1)).From("accounts").Where(mydb.Cond{mydb.Func("FUNCTION", "20170103", "YYYYMMDD", 1, "E"): mydb.Func("FOO", 1)}) + + assert.Equal( + `SELECT FOO($1, $2, $3) FROM "accounts" WHERE (FUNCTION($4, $5, $6, $7) = FOO($8))`, + sel.String(), + ) + assert.Equal( + []interface{}{"A", "B", 1, "20170103", "YYYYMMDD", 1, "E", 1}, + sel.Arguments(), + ) + } + } +} + +func TestInsert(t *testing.T) { + b := &sqlBuilder{t: newTemplateWithUtils(&testTemplate)} + assert := assert.New(t) + + assert.Equal( + `INSERT INTO "artist" VALUES ($1, $2), ($3, $4), ($5, $6)`, + b.InsertInto("artist"). + Values(10, "Ryuichi Sakamoto"). + Values(11, "Alondra de la Parra"). + Values(12, "Haruki Murakami"). + String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Amend(func(query string) string { + return query + ` RETURNING "id"` + }).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(struct { + ID int `db:"id"` + Name string `db:"name"` + }{12, "Chavela Vargas"}).String(), + ) + + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2), ($3, $4), ($5, $6)`, + b.InsertInto("artist"). + Values(artistStruct{12, "Chavela Vargas"}). + Values(artistStruct{13, "Alondra de la Parra"}). + Values(artistStruct{14, "Haruki Murakami"}). + String(), + ) + } + + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + q := b.InsertInto("artist"). + Values(artistStruct{0, ""}). + Values(artistStruct{12, "Chavela Vargas"}). + Values(artistStruct{0, "Alondra de la Parra"}). + Values(artistStruct{14, ""}). + Values(artistStruct{0, ""}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES (DEFAULT, DEFAULT), ($1, $2), (DEFAULT, $3), ($4, DEFAULT), (DEFAULT, DEFAULT)`, + q.String(), + ) + + assert.Equal( + []interface{}{12, "Chavela Vargas", "Alondra de la Parra", 14}, + q.Arguments(), + ) + } + + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + assert.Equal( + `INSERT INTO "artist" ("name") VALUES ($1)`, + b.InsertInto("artist"). + Values(artistStruct{Name: "Chavela Vargas"}). + String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id") VALUES ($1)`, + b.InsertInto("artist"). + Values(artistStruct{ID: 1}). + String(), + ) + } + + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + { + q := b.InsertInto("artist").Values(artistStruct{Name: "Chavela Vargas"}) + + assert.Equal( + `INSERT INTO "artist" ("name") VALUES ($1)`, + q.String(), + ) + assert.Equal( + []interface{}{"Chavela Vargas"}, + q.Arguments(), + ) + } + + { + q := b.InsertInto("artist").Values(artistStruct{Name: "Chavela Vargas"}).Values(artistStruct{Name: "Alondra de la Parra"}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES (DEFAULT, $1), (DEFAULT, $2)`, + q.String(), + ) + assert.Equal( + []interface{}{"Chavela Vargas", "Alondra de la Parra"}, + q.Arguments(), + ) + } + + { + q := b.InsertInto("artist").Values(artistStruct{ID: 1}) + + assert.Equal( + `INSERT INTO "artist" ("id") VALUES ($1)`, + q.String(), + ) + + assert.Equal( + []interface{}{1}, + q.Arguments(), + ) + } + + { + q := b.InsertInto("artist").Values(artistStruct{ID: 1}).Values(artistStruct{ID: 2}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, DEFAULT), ($2, DEFAULT)`, + q.String(), + ) + + assert.Equal( + []interface{}{1, 2}, + q.Arguments(), + ) + } + + } + + { + intRef := func(i int) *int { + if i == 0 { + return nil + } + return &i + } + + strRef := func(s string) *string { + if s == "" { + return nil + } + return &s + } + + type artistStruct struct { + ID *int `db:"id,omitempty"` + Name *string `db:"name,omitempty"` + } + + q := b.InsertInto("artist"). + Values(artistStruct{intRef(0), strRef("")}). + Values(artistStruct{intRef(12), strRef("Chavela Vargas")}). + Values(artistStruct{intRef(0), strRef("Alondra de la Parra")}). + Values(artistStruct{intRef(14), strRef("")}). + Values(artistStruct{intRef(0), strRef("")}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES (DEFAULT, DEFAULT), ($1, $2), (DEFAULT, $3), ($4, DEFAULT), (DEFAULT, DEFAULT)`, + q.String(), + ) + + assert.Equal( + []interface{}{intRef(12), strRef("Chavela Vargas"), strRef("Alondra de la Parra"), intRef(14)}, + q.Arguments(), + ) + } + + assert.Equal( + `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, + b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(), + ) + + assert.Equal( + `INSERT INTO "artist" VALUES (default)`, + b.InsertInto("artist").String(), + ) +} + +func TestUpdate(t *testing.T) { + b := &sqlBuilder{t: newTemplateWithUtils(&testTemplate)} + assert := assert.New(t) + + assert.Equal( + `UPDATE "artist" SET "name" = $1`, + b.Update("artist").Set("name", "Artist").String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 RETURNING "name"`, + b.Update("artist").Set("name", "Artist").Amend(func(query string) string { + return query + ` RETURNING "name"` + }).String(), + ) + + { + idSlice := []int64{8, 7, 6} + q := b.Update("artist").Set(mydb.Cond{"some_column": 10}).Where(mydb.Cond{"id": 1}, mydb.Cond{"another_val": idSlice}) + assert.Equal( + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN ($3, $4, $5))`, + q.String(), + ) + assert.Equal( + []interface{}{10, 1, int64(8), int64(7), int64(6)}, + q.Arguments(), + ) + } + + { + idSlice := []int64{} + q := b.Update("artist").Set(mydb.Cond{"some_column": 10}).Where(mydb.Cond{"id": 1}, mydb.Cond{"another_val": idSlice}) + assert.Equal( + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN (NULL))`, + q.String(), + ) + assert.Equal( + []interface{}{10, 1}, + q.Arguments(), + ) + } + + { + idSlice := []int64{} + q := b.Update("artist").Where(mydb.Cond{"id": 1}, mydb.Cond{"another_val": idSlice}).Set(mydb.Cond{"some_column": 10}) + assert.Equal( + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN (NULL))`, + q.String(), + ) + assert.Equal( + []interface{}{10, 1}, + q.Arguments(), + ) + } + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Where(mydb.Cond{"id <": 5}).Set(struct { + Nombre string `db:"name"` + }{"Artist"}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 || ' ' || $2 || id, "id" = id + $3 WHERE (id > $4)`, + b.Update("artist").Set( + "name = ? || ' ' || ? || id", "Artist", "#", + "id = id + ?", 10, + ).Where("id > ?", 0).String(), + ) + + { + q := b.Update("posts").Set("column = ?", "foo") + + assert.Equal( + `UPDATE "posts" SET "column" = $1`, + q.String(), + ) + + assert.Equal( + []interface{}{"foo"}, + q.Arguments(), + ) + } + + { + q := b.Update("posts").Set(mydb.Raw("column = ?", "foo")) + + assert.Equal( + `UPDATE "posts" SET column = $1`, + q.String(), + ) + + assert.Equal( + []interface{}{"foo"}, + q.Arguments(), + ) + } + + { + q := b.Update("posts").Set("foo = bar") + + assert.Equal( + []interface{}(nil), + q.Arguments(), + ) + + assert.Equal( + `UPDATE "posts" SET "foo" = bar`, + q.String(), + ) + } + + { + q := b.Update("posts").Set( + mydb.Cond{"tags": mydb.Raw("array_remove(tags, ?)", "foo")}, + ).Where(mydb.Raw("hub_id = ? AND ? = ANY(tags) AND ? = ANY(tags)", 1, "bar", "baz")) + + assert.Equal( + `UPDATE "posts" SET "tags" = array_remove(tags, $1) WHERE (hub_id = $2 AND $3 = ANY(tags) AND $4 = ANY(tags))`, + q.String(), + ) + + assert.Equal( + []interface{}{"foo", 1, "bar", "baz"}, + q.Arguments(), + ) + } +} + +func TestDelete(t *testing.T) { + bt := WithTemplate(&testTemplate) + assert := assert.New(t) + + assert.Equal( + `DELETE FROM "artist" WHERE (name = $1)`, + bt.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").String(), + ) + + assert.Equal( + `DELETE FROM "artist" WHERE (name = $1) RETURNING 1`, + bt.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Amend(func(query string) string { + return fmt.Sprintf("%s RETURNING 1", query) + }).String(), + ) + + assert.Equal( + `DELETE FROM "artist" WHERE (id > 5)`, + bt.DeleteFrom("artist").Where("id > 5").String(), + ) +} + +func TestPaginate(t *testing.T) { + b := &sqlBuilder{t: newTemplateWithUtils(&testTemplate)} + assert := assert.New(t) + + // Limit, offset + assert.Equal( + `SELECT * FROM "artist" LIMIT 10`, + b.Select().From("artist").Paginate(10).Page(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT 10 OFFSET 10`, + b.Select().From("artist").Paginate(10).Page(2).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT 5 OFFSET 110`, + b.Select().From("artist").Paginate(5).Page(23).String(), + ) + + // Cursor + assert.Equal( + `SELECT * FROM "artist" ORDER BY "id" ASC LIMIT 10`, + b.Select().From("artist").Paginate(10).Cursor("id").String(), + ) + + { + q := b.Select().From("artist").Paginate(10).Cursor("id").NextPage(3) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" > $1) ORDER BY "id" ASC LIMIT 10`, + q.String(), + ) + assert.Equal( + []interface{}{3}, + q.Arguments(), + ) + } + + { + q := b.Select().From("artist").Paginate(10).Cursor("id").PrevPage(30) + assert.Equal( + `SELECT * FROM (SELECT * FROM "artist" WHERE ("id" < $1) ORDER BY "id" DESC LIMIT 10) AS p0 ORDER BY "id" ASC`, + q.String(), + ) + assert.Equal( + []interface{}{30}, + q.Arguments(), + ) + } + + // Cursor reversed + assert.Equal( + `SELECT * FROM "artist" ORDER BY "id" DESC LIMIT 10`, + b.Select().From("artist").Paginate(10).Cursor("-id").String(), + ) + + { + q := b.Select().From("artist").Paginate(10).Cursor("-id").NextPage(3) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" < $1) ORDER BY "id" DESC LIMIT 10`, + q.String(), + ) + assert.Equal( + []interface{}{3}, + q.Arguments(), + ) + } + + { + q := b.Select().From("artist").Paginate(10).Cursor("-id").PrevPage(30) + assert.Equal( + `SELECT * FROM (SELECT * FROM "artist" WHERE ("id" > $1) ORDER BY "id" ASC LIMIT 10) AS p0 ORDER BY "id" DESC`, + q.String(), + ) + assert.Equal( + []interface{}{30}, + q.Arguments(), + ) + } +} + +func BenchmarkDelete1(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Limit(1).String() + } +} + +func BenchmarkDelete2(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.DeleteFrom("artist").Where("id > 5").String() + } +} + +func BenchmarkInsert1(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.InsertInto("artist").Values(10, "Ryuichi Sakamoto").Values(11, "Alondra de la Parra").Values(12, "Haruki Murakami").String() + } +} + +func BenchmarkInsert2(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String() + } +} + +func BenchmarkInsert3(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String() + } +} + +func BenchmarkInsert4(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String() + } +} + +func BenchmarkInsert5(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.InsertInto("artist").Values(struct { + ID int `db:"id"` + Name string `db:"name"` + }{12, "Chavela Vargas"}).String() + } +} + +func BenchmarkSelect1(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Select().From("artist").OrderBy("name DESC").String() + } +} + +func BenchmarkSelect2(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String() + } +} + +func BenchmarkSelect3(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String() + } +} + +func BenchmarkSelect4(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String() + } +} + +func BenchmarkSelect5(b *testing.B) { + t := WithTemplate(&testTemplate) + b.ResetTimer() + for n := 0; n < b.N; n++ { + _ = t.SelectFrom("artist a"). + LeftJoin("publication p1").On("p1.id = a.id"). + RightJoin("publication p2").On("p2.id = a.id"). + String() + } +} + +func BenchmarkUpdate1(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Update("artist").Set("name", "Artist").String() + } +} + +func BenchmarkUpdate2(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String() + } +} + +func BenchmarkUpdate3(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String() + } +} + +func BenchmarkUpdate4(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String() + } +} + +func BenchmarkUpdate5(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Update("artist").Set( + "name = ? || ' ' || ? || id", "Artist", "#", + "id = id + ?", 10, + ).Where("id > ?", 0).String() + } +} + +func stripWhitespace(in string) string { + q := reInvisibleChars.ReplaceAllString(in, ` `) + return strings.TrimSpace(q) +} diff --git a/internal/sqlbuilder/comparison.go b/internal/sqlbuilder/comparison.go new file mode 100644 index 0000000..915aa8a --- /dev/null +++ b/internal/sqlbuilder/comparison.go @@ -0,0 +1,122 @@ +package sqlbuilder + +import ( + "fmt" + "strings" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +var comparisonOperators = map[adapter.ComparisonOperator]string{ + adapter.ComparisonOperatorEqual: "=", + adapter.ComparisonOperatorNotEqual: "!=", + + adapter.ComparisonOperatorLessThan: "<", + adapter.ComparisonOperatorGreaterThan: ">", + + adapter.ComparisonOperatorLessThanOrEqualTo: "<=", + adapter.ComparisonOperatorGreaterThanOrEqualTo: ">=", + + adapter.ComparisonOperatorBetween: "BETWEEN", + adapter.ComparisonOperatorNotBetween: "NOT BETWEEN", + + adapter.ComparisonOperatorIn: "IN", + adapter.ComparisonOperatorNotIn: "NOT IN", + + adapter.ComparisonOperatorIs: "IS", + adapter.ComparisonOperatorIsNot: "IS NOT", + + adapter.ComparisonOperatorLike: "LIKE", + adapter.ComparisonOperatorNotLike: "NOT LIKE", + + adapter.ComparisonOperatorRegExp: "REGEXP", + adapter.ComparisonOperatorNotRegExp: "NOT REGEXP", +} + +type operatorWrapper struct { + tu *templateWithUtils + cv *exql.ColumnValue + + op *adapter.Comparison + v interface{} +} + +func (ow *operatorWrapper) cmp() *adapter.Comparison { + if ow.op != nil { + return ow.op + } + + if ow.cv.Operator != "" { + return mydb.Op(ow.cv.Operator, ow.v).Comparison + } + + if ow.v == nil { + return mydb.Is(nil).Comparison + } + + args, isSlice := toInterfaceArguments(ow.v) + if isSlice { + return mydb.In(args...).Comparison + } + + return mydb.Eq(ow.v).Comparison +} + +func (ow *operatorWrapper) preprocess() (string, []interface{}) { + placeholder := "?" + + column, err := ow.cv.Column.Compile(ow.tu.Template) + if err != nil { + panic(fmt.Sprintf("could not compile column: %v", err.Error())) + } + + c := ow.cmp() + + op := ow.tu.comparisonOperatorMapper(c.Operator()) + + var args []interface{} + + switch c.Operator() { + case adapter.ComparisonOperatorNone: + panic("no operator given") + case adapter.ComparisonOperatorCustom: + op = c.CustomOperator() + case adapter.ComparisonOperatorIn, adapter.ComparisonOperatorNotIn: + values := c.Value().([]interface{}) + if len(values) < 1 { + placeholder, args = "(NULL)", []interface{}{} + break + } + placeholder, args = "(?"+strings.Repeat(", ?", len(values)-1)+")", values + case adapter.ComparisonOperatorIs, adapter.ComparisonOperatorIsNot: + switch c.Value() { + case nil: + placeholder, args = "NULL", []interface{}{} + case false: + placeholder, args = "FALSE", []interface{}{} + case true: + placeholder, args = "TRUE", []interface{}{} + } + case adapter.ComparisonOperatorBetween, adapter.ComparisonOperatorNotBetween: + values := c.Value().([]interface{}) + placeholder, args = "? AND ?", []interface{}{values[0], values[1]} + case adapter.ComparisonOperatorEqual: + v := c.Value() + if b, ok := v.([]byte); ok { + v = string(b) + } + args = []interface{}{v} + } + + if args == nil { + args = []interface{}{c.Value()} + } + + if strings.Contains(op, ":column") { + return strings.Replace(op, ":column", column, -1), args + } + + return column + " " + op + " " + placeholder, args +} diff --git a/internal/sqlbuilder/convert.go b/internal/sqlbuilder/convert.go new file mode 100644 index 0000000..5036465 --- /dev/null +++ b/internal/sqlbuilder/convert.go @@ -0,0 +1,166 @@ +package sqlbuilder + +import ( + "bytes" + "database/sql/driver" + "reflect" + + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +var ( + sqlDefault = &exql.Raw{Value: "DEFAULT"} +) + +func expandQuery(in []byte, inArgs []interface{}) ([]byte, []interface{}) { + out := make([]byte, 0, len(in)) + outArgs := make([]interface{}, 0, len(inArgs)) + + i := 0 + for i < len(in) && len(inArgs) > 0 { + if in[i] == '?' { + out = append(out, in[:i]...) + in = in[i+1:] + i = 0 + + replace, replaceArgs := expandArgument(inArgs[0]) + inArgs = inArgs[1:] + + if len(replace) > 0 { + replace, replaceArgs = expandQuery(replace, replaceArgs) + out = append(out, replace...) + } else { + out = append(out, '?') + } + + outArgs = append(outArgs, replaceArgs...) + continue + } + i = i + 1 + } + + if len(out) < 1 { + return in, inArgs + } + + out = append(out, in[:len(in)]...) + in = nil + + outArgs = append(outArgs, inArgs[:len(inArgs)]...) + inArgs = nil + + return out, outArgs +} + +func expandArgument(arg interface{}) ([]byte, []interface{}) { + values, isSlice := toInterfaceArguments(arg) + + if isSlice { + if len(values) == 0 { + return []byte("(NULL)"), nil + } + buf := bytes.Repeat([]byte(" ?,"), len(values)) + buf[0] = '(' + buf[len(buf)-1] = ')' + return buf, values + } + + if len(values) == 1 { + switch t := arg.(type) { + case *adapter.RawExpr: + return expandQuery([]byte(t.Raw()), t.Arguments()) + case hasPaginator: + p, err := t.Paginator() + if err == nil { + return append([]byte{'('}, append([]byte(p.String()), ')')...), p.Arguments() + } + panic(err.Error()) + case isCompilable: + s, err := t.Compile() + if err == nil { + return append([]byte{'('}, append([]byte(s), ')')...), t.Arguments() + } + panic(err.Error()) + } + } else if len(values) == 0 { + return []byte("NULL"), nil + } + + return nil, []interface{}{arg} +} + +// toInterfaceArguments converts the given value into an array of interfaces. +func toInterfaceArguments(value interface{}) (args []interface{}, isSlice bool) { + if value == nil { + return nil, false + } + + switch t := value.(type) { + case driver.Valuer: + return []interface{}{t}, false + } + + v := reflect.ValueOf(value) + if v.Type().Kind() == reflect.Slice { + var i, total int + + // Byte slice gets transformed into a string. + if v.Type().Elem().Kind() == reflect.Uint8 { + return []interface{}{string(v.Bytes())}, false + } + + total = v.Len() + args = make([]interface{}, total) + for i = 0; i < total; i++ { + args[i] = v.Index(i).Interface() + } + return args, true + } + + return []interface{}{value}, false +} + +// toColumnsValuesAndArguments maps the given columnNames and columnValues into +// expr's Columns and Values, it also extracts and returns query arguments. +func toColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*exql.Columns, *exql.Values, []interface{}, error) { + var arguments []interface{} + + columns := new(exql.Columns) + + columns.Columns = make([]exql.Fragment, 0, len(columnNames)) + for i := range columnNames { + columns.Columns = append(columns.Columns, exql.ColumnWithName(columnNames[i])) + } + + values := new(exql.Values) + + arguments = make([]interface{}, 0, len(columnValues)) + values.Values = make([]exql.Fragment, 0, len(columnValues)) + + for i := range columnValues { + switch v := columnValues[i].(type) { + case *exql.Raw, exql.Raw: + values.Values = append(values.Values, sqlDefault) + case *exql.Value: + // Adding value. + values.Values = append(values.Values, v) + case exql.Value: + // Adding value. + values.Values = append(values.Values, &v) + default: + // Adding both value and placeholder. + values.Values = append(values.Values, sqlPlaceholder) + arguments = append(arguments, v) + } + } + + return columns, values, arguments, nil +} + +// Preprocess expands arguments that needs to be expanded and compiles a query +// into a single string. +func Preprocess(in string, args []interface{}) (string, []interface{}) { + b, args := expandQuery([]byte(in), args) + return string(b), args +} diff --git a/internal/sqlbuilder/custom_types.go b/internal/sqlbuilder/custom_types.go new file mode 100644 index 0000000..9b5e0cf --- /dev/null +++ b/internal/sqlbuilder/custom_types.go @@ -0,0 +1,11 @@ +package sqlbuilder + +import ( + "database/sql" + "database/sql/driver" +) + +type ScannerValuer interface { + sql.Scanner + driver.Valuer +} diff --git a/internal/sqlbuilder/delete.go b/internal/sqlbuilder/delete.go new file mode 100644 index 0000000..396be02 --- /dev/null +++ b/internal/sqlbuilder/delete.go @@ -0,0 +1,195 @@ +package sqlbuilder + +import ( + "context" + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/immutable" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +type deleterQuery struct { + table string + limit int + + where *exql.Where + whereArgs []interface{} + + amendFn func(string) string +} + +func (dq *deleterQuery) and(b *sqlBuilder, terms ...interface{}) error { + where, whereArgs := b.t.toWhereWithArguments(terms) + + if dq.where == nil { + dq.where, dq.whereArgs = &exql.Where{}, []interface{}{} + } + dq.where.Append(&where) + dq.whereArgs = append(dq.whereArgs, whereArgs...) + + return nil +} + +func (dq *deleterQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Delete, + Table: exql.TableWithName(dq.table), + } + + if dq.where != nil { + stmt.Where = dq.where + } + + if dq.limit != 0 { + stmt.Limit = exql.Limit(dq.limit) + } + + stmt.SetAmendment(dq.amendFn) + + return stmt +} + +type deleter struct { + builder *sqlBuilder + + fn func(*deleterQuery) error + prev *deleter +} + +var _ = immutable.Immutable(&deleter{}) + +func (del *deleter) SQL() *sqlBuilder { + if del.prev == nil { + return del.builder + } + return del.prev.SQL() +} + +func (del *deleter) template() *exql.Template { + return del.SQL().t.Template +} + +func (del *deleter) String() string { + s, err := del.Compile() + if err != nil { + panic(err.Error()) + } + return prepareQueryForDisplay(s) +} + +func (del *deleter) setTable(table string) *deleter { + return del.frame(func(uq *deleterQuery) error { + uq.table = table + return nil + }) +} + +func (del *deleter) frame(fn func(*deleterQuery) error) *deleter { + return &deleter{prev: del, fn: fn} +} + +func (del *deleter) Where(terms ...interface{}) mydb.Deleter { + return del.frame(func(dq *deleterQuery) error { + dq.where, dq.whereArgs = &exql.Where{}, []interface{}{} + return dq.and(del.SQL(), terms...) + }) +} + +func (del *deleter) And(terms ...interface{}) mydb.Deleter { + return del.frame(func(dq *deleterQuery) error { + return dq.and(del.SQL(), terms...) + }) +} + +func (del *deleter) Limit(limit int) mydb.Deleter { + return del.frame(func(dq *deleterQuery) error { + dq.limit = limit + return nil + }) +} + +func (del *deleter) Amend(fn func(string) string) mydb.Deleter { + return del.frame(func(dq *deleterQuery) error { + dq.amendFn = fn + return nil + }) +} + +func (dq *deleterQuery) arguments() []interface{} { + return joinArguments(dq.whereArgs) +} + +func (del *deleter) Arguments() []interface{} { + dq, err := del.build() + if err != nil { + return nil + } + return dq.arguments() +} + +func (del *deleter) Prepare() (*sql.Stmt, error) { + return del.PrepareContext(del.SQL().sess.Context()) +} + +func (del *deleter) PrepareContext(ctx context.Context) (*sql.Stmt, error) { + dq, err := del.build() + if err != nil { + return nil, err + } + return del.SQL().sess.StatementPrepare(ctx, dq.statement()) +} + +func (del *deleter) Exec() (sql.Result, error) { + return del.ExecContext(del.SQL().sess.Context()) +} + +func (del *deleter) ExecContext(ctx context.Context) (sql.Result, error) { + dq, err := del.build() + if err != nil { + return nil, err + } + return del.SQL().sess.StatementExec(ctx, dq.statement(), dq.arguments()...) +} + +func (del *deleter) statement() (*exql.Statement, error) { + iq, err := del.build() + if err != nil { + return nil, err + } + return iq.statement(), nil +} + +func (del *deleter) build() (*deleterQuery, error) { + dq, err := immutable.FastForward(del) + if err != nil { + return nil, err + } + return dq.(*deleterQuery), nil +} + +func (del *deleter) Compile() (string, error) { + s, err := del.statement() + if err != nil { + return "", err + } + return s.Compile(del.template()) +} + +func (del *deleter) Prev() immutable.Immutable { + if del == nil { + return nil + } + return del.prev +} + +func (del *deleter) Fn(in interface{}) error { + if del.fn == nil { + return nil + } + return del.fn(in.(*deleterQuery)) +} + +func (del *deleter) Base() interface{} { + return &deleterQuery{} +} diff --git a/internal/sqlbuilder/errors.go b/internal/sqlbuilder/errors.go new file mode 100644 index 0000000..5c5a723 --- /dev/null +++ b/internal/sqlbuilder/errors.go @@ -0,0 +1,14 @@ +package sqlbuilder + +import ( + "errors" +) + +// Common error messages. +var ( + ErrExpectingPointer = errors.New(`argument must be an address`) + ErrExpectingSlicePointer = errors.New(`argument must be a slice address`) + ErrExpectingSliceMapStruct = errors.New(`argument must be a slice address of maps or structs`) + ErrExpectingMapOrStruct = errors.New(`argument must be either a map or a struct`) + ErrExpectingPointerToEitherMapOrStruct = errors.New(`expecting a pointer to either a map or a struct`) +) diff --git a/internal/sqlbuilder/fetch.go b/internal/sqlbuilder/fetch.go new file mode 100644 index 0000000..6f686b4 --- /dev/null +++ b/internal/sqlbuilder/fetch.go @@ -0,0 +1,234 @@ +package sqlbuilder + +import ( + "reflect" + + "database/sql" + "database/sql/driver" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/reflectx" +) + +type sessValueConverter interface { + ConvertValue(interface{}) interface{} +} + +type valueConverter interface { + ConvertValue(in interface{}) (out interface { + sql.Scanner + driver.Valuer + }) +} + +var Mapper = reflectx.NewMapper("db") + +// fetchRow receives a *sql.Rows value and tries to map all the rows into a +// single struct given by the pointer `dst`. +func fetchRow(iter *iterator, dst interface{}) error { + var columns []string + var err error + + rows := iter.cursor + + dstv := reflect.ValueOf(dst) + + if dstv.IsNil() || dstv.Kind() != reflect.Ptr { + return ErrExpectingPointer + } + + itemV := dstv.Elem() + + if columns, err = rows.Columns(); err != nil { + return err + } + + reset(dst) + + next := rows.Next() + + if !next { + if err = rows.Err(); err != nil { + return err + } + return mydb.ErrNoMoreRows + } + + itemT := itemV.Type() + item, err := fetchResult(iter, itemT, columns) + if err != nil { + return err + } + + if itemT.Kind() == reflect.Ptr { + itemV.Set(item) + } else { + itemV.Set(reflect.Indirect(item)) + } + + return nil +} + +// fetchRows receives a *sql.Rows value and tries to map all the rows into a +// slice of structs given by the pointer `dst`. +func fetchRows(iter *iterator, dst interface{}) error { + var err error + rows := iter.cursor + defer rows.Close() + + // Destination. + dstv := reflect.ValueOf(dst) + + if dstv.IsNil() || dstv.Kind() != reflect.Ptr { + return ErrExpectingPointer + } + + if dstv.Elem().Kind() != reflect.Slice { + return ErrExpectingSlicePointer + } + + if dstv.Kind() != reflect.Ptr || dstv.Elem().Kind() != reflect.Slice || dstv.IsNil() { + return ErrExpectingSliceMapStruct + } + + var columns []string + if columns, err = rows.Columns(); err != nil { + return err + } + + slicev := dstv.Elem() + itemT := slicev.Type().Elem() + + reset(dst) + + for rows.Next() { + item, err := fetchResult(iter, itemT, columns) + if err != nil { + return err + } + if itemT.Kind() == reflect.Ptr { + slicev = reflect.Append(slicev, item) + } else { + slicev = reflect.Append(slicev, reflect.Indirect(item)) + } + } + + dstv.Elem().Set(slicev) + + return rows.Err() +} + +func fetchResult(iter *iterator, itemT reflect.Type, columns []string) (reflect.Value, error) { + + var item reflect.Value + var err error + rows := iter.cursor + + objT := itemT + + switch objT.Kind() { + case reflect.Map: + item = reflect.MakeMap(objT) + case reflect.Struct: + item = reflect.New(objT) + case reflect.Ptr: + objT = itemT.Elem() + if objT.Kind() != reflect.Struct { + return item, ErrExpectingMapOrStruct + } + item = reflect.New(objT) + default: + return item, ErrExpectingMapOrStruct + } + + switch objT.Kind() { + case reflect.Struct: + + values := make([]interface{}, len(columns)) + typeMap := Mapper.TypeMap(itemT) + fieldMap := typeMap.Names + + for i, k := range columns { + fi, ok := fieldMap[k] + if !ok { + values[i] = new(interface{}) + continue + } + + // Check for deprecated jsonb tag. + if _, hasJSONBTag := fi.Options["jsonb"]; hasJSONBTag { + return item, errDeprecatedJSONBTag + } + + f := reflectx.FieldByIndexes(item, fi.Index) + + // TODO: type switch + scanner + + if w, ok := f.Interface().(valueConverter); ok { + wrapper := w.ConvertValue(f.Addr().Interface()) + z := reflect.ValueOf(wrapper) + values[i] = z.Interface() + continue + } else { + values[i] = f.Addr().Interface() + } + + if unmarshaler, ok := values[i].(mydb.Unmarshaler); ok { + values[i] = scanner{unmarshaler} + continue + } + + if converter, ok := iter.sess.(sessValueConverter); ok { + values[i] = converter.ConvertValue(values[i]) + continue + } + } + + if err = rows.Scan(values...); err != nil { + return item, err + } + + case reflect.Map: + + columns, err := rows.Columns() + if err != nil { + return item, err + } + + values := make([]interface{}, len(columns)) + for i := range values { + if itemT.Elem().Kind() == reflect.Interface { + values[i] = new(interface{}) + } else { + values[i] = reflect.New(itemT.Elem()).Interface() + } + } + + if err = rows.Scan(values...); err != nil { + return item, err + } + + for i, column := range columns { + item.SetMapIndex(reflect.ValueOf(column), reflect.Indirect(reflect.ValueOf(values[i]))) + } + } + + return item, nil +} + +func reset(data interface{}) { + // Resetting element. + v := reflect.ValueOf(data).Elem() + t := v.Type() + + var z reflect.Value + + switch v.Kind() { + case reflect.Slice: + z = reflect.MakeSlice(t, 0, v.Cap()) + default: + z = reflect.Zero(t) + } + + v.Set(z) +} diff --git a/internal/sqlbuilder/insert.go b/internal/sqlbuilder/insert.go new file mode 100644 index 0000000..130264c --- /dev/null +++ b/internal/sqlbuilder/insert.go @@ -0,0 +1,285 @@ +package sqlbuilder + +import ( + "context" + "database/sql" + "errors" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/immutable" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +type inserterQuery struct { + table string + enqueuedValues [][]interface{} + returning []exql.Fragment + columns []exql.Fragment + values []*exql.Values + arguments []interface{} + amendFn func(string) string +} + +func (iq *inserterQuery) processValues() ([]*exql.Values, []interface{}, error) { + var values []*exql.Values + var arguments []interface{} + + var mapOptions *MapOptions + if len(iq.enqueuedValues) > 1 { + mapOptions = &MapOptions{IncludeZeroed: true, IncludeNil: true} + } + + for _, enqueuedValue := range iq.enqueuedValues { + if len(enqueuedValue) == 1 { + // If and only if we passed one argument to Values. + ff, vv, err := Map(enqueuedValue[0], mapOptions) + + if err == nil { + // If we didn't have any problem with mapping we can convert it into + // columns and values. + columns, vals, args, _ := toColumnsValuesAndArguments(ff, vv) + + values, arguments = append(values, vals), append(arguments, args...) + + if len(iq.columns) == 0 { + iq.columns = append(iq.columns, columns.Columns...) + } + continue + } + + // The only error we can expect without exiting is this argument not + // being a map or struct, in which case we can continue. + if !errors.Is(err, ErrExpectingPointerToEitherMapOrStruct) { + return nil, nil, err + } + } + + if len(iq.columns) == 0 || len(enqueuedValue) == len(iq.columns) { + arguments = append(arguments, enqueuedValue...) + + l := len(enqueuedValue) + placeholders := make([]exql.Fragment, l) + for i := 0; i < l; i++ { + placeholders[i] = sqlPlaceholder + } + values = append(values, exql.NewValueGroup(placeholders...)) + } + } + + return values, arguments, nil +} + +func (iq *inserterQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Insert, + Table: exql.TableWithName(iq.table), + } + + if len(iq.values) > 0 { + stmt.Values = exql.JoinValueGroups(iq.values...) + } + + if len(iq.columns) > 0 { + stmt.Columns = exql.JoinColumns(iq.columns...) + } + + if len(iq.returning) > 0 { + stmt.Returning = exql.ReturningColumns(iq.returning...) + } + + stmt.SetAmendment(iq.amendFn) + + return stmt +} + +type inserter struct { + builder *sqlBuilder + + fn func(*inserterQuery) error + prev *inserter +} + +var _ = immutable.Immutable(&inserter{}) + +func (ins *inserter) SQL() *sqlBuilder { + if ins.prev == nil { + return ins.builder + } + return ins.prev.SQL() +} + +func (ins *inserter) template() *exql.Template { + return ins.SQL().t.Template +} + +func (ins *inserter) String() string { + s, err := ins.Compile() + if err != nil { + panic(err.Error()) + } + return prepareQueryForDisplay(s) +} + +func (ins *inserter) frame(fn func(*inserterQuery) error) *inserter { + return &inserter{prev: ins, fn: fn} +} + +func (ins *inserter) Batch(n int) mydb.BatchInserter { + return newBatchInserter(ins, n) +} + +func (ins *inserter) Amend(fn func(string) string) mydb.Inserter { + return ins.frame(func(iq *inserterQuery) error { + iq.amendFn = fn + return nil + }) +} + +func (ins *inserter) Arguments() []interface{} { + iq, err := ins.build() + if err != nil { + return nil + } + return iq.arguments +} + +func (ins *inserter) Returning(columns ...string) mydb.Inserter { + return ins.frame(func(iq *inserterQuery) error { + columnsToFragments(&iq.returning, columns) + return nil + }) +} + +func (ins *inserter) Exec() (sql.Result, error) { + return ins.ExecContext(ins.SQL().sess.Context()) +} + +func (ins *inserter) ExecContext(ctx context.Context) (sql.Result, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return ins.SQL().sess.StatementExec(ctx, iq.statement(), iq.arguments...) +} + +func (ins *inserter) Prepare() (*sql.Stmt, error) { + return ins.PrepareContext(ins.SQL().sess.Context()) +} + +func (ins *inserter) PrepareContext(ctx context.Context) (*sql.Stmt, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return ins.SQL().sess.StatementPrepare(ctx, iq.statement()) +} + +func (ins *inserter) Query() (*sql.Rows, error) { + return ins.QueryContext(ins.SQL().sess.Context()) +} + +func (ins *inserter) QueryContext(ctx context.Context) (*sql.Rows, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return ins.SQL().sess.StatementQuery(ctx, iq.statement(), iq.arguments...) +} + +func (ins *inserter) QueryRow() (*sql.Row, error) { + return ins.QueryRowContext(ins.SQL().sess.Context()) +} + +func (ins *inserter) QueryRowContext(ctx context.Context) (*sql.Row, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return ins.SQL().sess.StatementQueryRow(ctx, iq.statement(), iq.arguments...) +} + +func (ins *inserter) Iterator() mydb.Iterator { + return ins.IteratorContext(ins.SQL().sess.Context()) +} + +func (ins *inserter) IteratorContext(ctx context.Context) mydb.Iterator { + rows, err := ins.QueryContext(ctx) + return &iterator{ins.SQL().sess, rows, err} +} + +func (ins *inserter) Into(table string) mydb.Inserter { + return ins.frame(func(iq *inserterQuery) error { + iq.table = table + return nil + }) +} + +func (ins *inserter) Columns(columns ...string) mydb.Inserter { + return ins.frame(func(iq *inserterQuery) error { + columnsToFragments(&iq.columns, columns) + return nil + }) +} + +func (ins *inserter) Values(values ...interface{}) mydb.Inserter { + return ins.frame(func(iq *inserterQuery) error { + iq.enqueuedValues = append(iq.enqueuedValues, values) + return nil + }) +} + +func (ins *inserter) statement() (*exql.Statement, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return iq.statement(), nil +} + +func (ins *inserter) build() (*inserterQuery, error) { + iq, err := immutable.FastForward(ins) + if err != nil { + return nil, err + } + ret := iq.(*inserterQuery) + ret.values, ret.arguments, err = ret.processValues() + if err != nil { + return nil, err + } + return ret, nil +} + +func (ins *inserter) Compile() (string, error) { + s, err := ins.statement() + if err != nil { + return "", err + } + return s.Compile(ins.template()) +} + +func (ins *inserter) Prev() immutable.Immutable { + if ins == nil { + return nil + } + return ins.prev +} + +func (ins *inserter) Fn(in interface{}) error { + if ins.fn == nil { + return nil + } + return ins.fn(in.(*inserterQuery)) +} + +func (ins *inserter) Base() interface{} { + return &inserterQuery{} +} + +func columnsToFragments(dst *[]exql.Fragment, columns []string) { + l := len(columns) + f := make([]exql.Fragment, l) + for i := 0; i < l; i++ { + f[i] = exql.ColumnWithName(columns[i]) + } + *dst = append(*dst, f...) +} diff --git a/internal/sqlbuilder/paginate.go b/internal/sqlbuilder/paginate.go new file mode 100644 index 0000000..6d75d82 --- /dev/null +++ b/internal/sqlbuilder/paginate.go @@ -0,0 +1,340 @@ +package sqlbuilder + +import ( + "context" + "database/sql" + "errors" + "math" + "strings" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/immutable" +) + +var ( + errMissingCursorColumn = errors.New("Missing cursor column") +) + +type paginatorQuery struct { + sel mydb.Selector + + cursorColumn string + cursorValue interface{} + cursorCond mydb.Cond + cursorReverseOrder bool + + pageSize uint + pageNumber uint +} + +func newPaginator(sel mydb.Selector, pageSize uint) mydb.Paginator { + pag := &paginator{} + return pag.frame(func(pq *paginatorQuery) error { + pq.pageSize = pageSize + pq.sel = sel + return nil + }).Page(1) +} + +func (pq *paginatorQuery) count() (uint64, error) { + var count uint64 + + row, err := pq.sel.(*selector).setColumns(mydb.Raw("count(1) AS _t")). + Limit(0). + Offset(0). + OrderBy(nil). + QueryRow() + if err != nil { + return 0, err + } + + err = row.Scan(&count) + if err != nil { + return 0, err + } + + return count, nil +} + +type paginator struct { + fn func(*paginatorQuery) error + prev *paginator +} + +var _ = immutable.Immutable(&paginator{}) + +func (pag *paginator) frame(fn func(*paginatorQuery) error) *paginator { + return &paginator{prev: pag, fn: fn} +} + +func (pag *paginator) Page(pageNumber uint) mydb.Paginator { + return pag.frame(func(pq *paginatorQuery) error { + if pageNumber < 1 { + pageNumber = 1 + } + pq.pageNumber = pageNumber + return nil + }) +} + +func (pag *paginator) Cursor(column string) mydb.Paginator { + return pag.frame(func(pq *paginatorQuery) error { + pq.cursorColumn = column + pq.cursorValue = nil + return nil + }) +} + +func (pag *paginator) NextPage(cursorValue interface{}) mydb.Paginator { + return pag.frame(func(pq *paginatorQuery) error { + if pq.cursorValue != nil && pq.cursorColumn == "" { + return errMissingCursorColumn + } + pq.cursorValue = cursorValue + pq.cursorReverseOrder = false + if strings.HasPrefix(pq.cursorColumn, "-") { + pq.cursorCond = mydb.Cond{ + pq.cursorColumn[1:]: mydb.Lt(cursorValue), + } + } else { + pq.cursorCond = mydb.Cond{ + pq.cursorColumn: mydb.Gt(cursorValue), + } + } + return nil + }) +} + +func (pag *paginator) PrevPage(cursorValue interface{}) mydb.Paginator { + return pag.frame(func(pq *paginatorQuery) error { + if pq.cursorValue != nil && pq.cursorColumn == "" { + return errMissingCursorColumn + } + pq.cursorValue = cursorValue + pq.cursorReverseOrder = true + if strings.HasPrefix(pq.cursorColumn, "-") { + pq.cursorCond = mydb.Cond{ + pq.cursorColumn[1:]: mydb.Gt(cursorValue), + } + } else { + pq.cursorCond = mydb.Cond{ + pq.cursorColumn: mydb.Lt(cursorValue), + } + } + return nil + }) +} + +func (pag *paginator) TotalPages() (uint, error) { + pq, err := pag.build() + if err != nil { + return 0, err + } + + count, err := pq.count() + if err != nil { + return 0, err + } + if count < 1 { + return 0, nil + } + + if pq.pageSize < 1 { + return 1, nil + } + + pages := uint(math.Ceil(float64(count) / float64(pq.pageSize))) + return pages, nil +} + +func (pag *paginator) All(dest interface{}) error { + pq, err := pag.buildWithCursor() + if err != nil { + return err + } + err = pq.sel.All(dest) + if err != nil { + return err + } + return nil +} + +func (pag *paginator) One(dest interface{}) error { + pq, err := pag.buildWithCursor() + if err != nil { + return err + } + return pq.sel.One(dest) +} + +func (pag *paginator) Iterator() mydb.Iterator { + pq, err := pag.buildWithCursor() + if err != nil { + sess := pq.sel.(*selector).SQL().sess + return &iterator{sess, nil, err} + } + return pq.sel.Iterator() +} + +func (pag *paginator) IteratorContext(ctx context.Context) mydb.Iterator { + pq, err := pag.buildWithCursor() + if err != nil { + sess := pq.sel.(*selector).SQL().sess + return &iterator{sess, nil, err} + } + return pq.sel.IteratorContext(ctx) +} + +func (pag *paginator) String() string { + pq, err := pag.buildWithCursor() + if err != nil { + panic(err.Error()) + } + return pq.sel.String() +} + +func (pag *paginator) Arguments() []interface{} { + pq, err := pag.buildWithCursor() + if err != nil { + return nil + } + return pq.sel.Arguments() +} + +func (pag *paginator) Compile() (string, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return "", err + } + return pq.sel.(*selector).Compile() +} + +func (pag *paginator) Query() (*sql.Rows, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return nil, err + } + return pq.sel.Query() +} + +func (pag *paginator) QueryContext(ctx context.Context) (*sql.Rows, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return nil, err + } + return pq.sel.QueryContext(ctx) +} + +func (pag *paginator) QueryRow() (*sql.Row, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return nil, err + } + return pq.sel.QueryRow() +} + +func (pag *paginator) QueryRowContext(ctx context.Context) (*sql.Row, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return nil, err + } + return pq.sel.QueryRowContext(ctx) +} + +func (pag *paginator) Prepare() (*sql.Stmt, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return nil, err + } + return pq.sel.Prepare() +} + +func (pag *paginator) PrepareContext(ctx context.Context) (*sql.Stmt, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return nil, err + } + return pq.sel.PrepareContext(ctx) +} + +func (pag *paginator) TotalEntries() (uint64, error) { + pq, err := pag.build() + if err != nil { + return 0, err + } + return pq.count() +} + +func (pag *paginator) build() (*paginatorQuery, error) { + pq, err := immutable.FastForward(pag) + if err != nil { + return nil, err + } + return pq.(*paginatorQuery), nil +} + +func (pag *paginator) buildWithCursor() (*paginatorQuery, error) { + pq, err := immutable.FastForward(pag) + if err != nil { + return nil, err + } + + pqq := pq.(*paginatorQuery) + + if pqq.cursorReverseOrder { + orderBy := pqq.cursorColumn + + if orderBy == "" { + return nil, errMissingCursorColumn + } + + if strings.HasPrefix(orderBy, "-") { + orderBy = orderBy[1:] + } else { + orderBy = "-" + orderBy + } + + pqq.sel = pqq.sel.OrderBy(orderBy) + } + + if pqq.pageSize > 0 { + pqq.sel = pqq.sel.Limit(int(pqq.pageSize)) + if pqq.pageNumber > 1 { + pqq.sel = pqq.sel.Offset(int(pqq.pageSize * (pqq.pageNumber - 1))) + } + } + + if pqq.cursorCond != nil { + pqq.sel = pqq.sel.Where(pqq.cursorCond).Offset(0) + } + + if pqq.cursorColumn != "" { + if pqq.cursorReverseOrder { + pqq.sel = pqq.sel.(*selector).SQL(). + SelectFrom(mydb.Raw("? AS p0", pqq.sel)). + OrderBy(pqq.cursorColumn) + } else { + pqq.sel = pqq.sel.OrderBy(pqq.cursorColumn) + } + } + + return pqq, nil +} + +func (pag *paginator) Prev() immutable.Immutable { + if pag == nil { + return nil + } + return pag.prev +} + +func (pag *paginator) Fn(in interface{}) error { + if pag.fn == nil { + return nil + } + return pag.fn(in.(*paginatorQuery)) +} + +func (pag *paginator) Base() interface{} { + return &paginatorQuery{} +} diff --git a/internal/sqlbuilder/placeholder_test.go b/internal/sqlbuilder/placeholder_test.go new file mode 100644 index 0000000..bbea17c --- /dev/null +++ b/internal/sqlbuilder/placeholder_test.go @@ -0,0 +1,146 @@ +package sqlbuilder + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb" + "github.com/stretchr/testify/assert" +) + +func TestPrepareForDisplay(t *testing.T) { + samples := []struct { + In string + Out string + }{ + { + In: "12345", + Out: "12345", + }, + { + In: "\r\n\t12345", + Out: "12345", + }, + { + In: "12345\r\n\t", + Out: "12345", + }, + { + In: "\r\n\t1\r2\n3\t4\r5\r\n\t", + Out: "1 2 3 4 5", + }, + { + In: "\r\n \t 1\r 2\n 3\t 4\r 5\r \n\t", + Out: "1 2 3 4 5", + }, + { + In: "\r\n \t 11\r 22\n 33\t 44 \r 55", + Out: "11 22 33 44 55", + }, + { + In: "11\r 22\n 33\t 44 \r 55", + Out: "11 22 33 44 55", + }, + { + In: "1 2 3 4 5", + Out: "1 2 3 4 5", + }, + { + In: "?", + Out: "$1", + }, + { + In: "? ?", + Out: "$1 $2", + }, + { + In: "? ? ?", + Out: "$1 $2 $3", + }, + { + In: " ? ? ? ", + Out: "$1 $2 $3", + }, + { + In: "???", + Out: "$1$2$3", + }, + } + for _, sample := range samples { + assert.Equal(t, sample.Out, prepareQueryForDisplay(sample.In)) + } +} + +func TestPlaceholderSimple(t *testing.T) { + { + ret, _ := Preprocess("?", []interface{}{1}) + assert.Equal(t, "?", ret) + } + { + ret, _ := Preprocess("?", nil) + assert.Equal(t, "?", ret) + } +} + +func TestPlaceholderMany(t *testing.T) { + { + ret, _ := Preprocess("?, ?, ?", []interface{}{1, 2, 3}) + assert.Equal(t, "?, ?, ?", ret) + } +} + +func TestPlaceholderArray(t *testing.T) { + { + ret, _ := Preprocess("?, ?, ?", []interface{}{1, 2, []interface{}{3, 4, 5}}) + assert.Equal(t, "?, ?, (?, ?, ?)", ret) + } + + { + ret, _ := Preprocess("?, ?, ?", []interface{}{[]interface{}{1, 2, 3}, 4, 5}) + assert.Equal(t, "(?, ?, ?), ?, ?", ret) + } + + { + ret, _ := Preprocess("?, ?, ?", []interface{}{1, []interface{}{2, 3, 4}, 5}) + assert.Equal(t, "?, (?, ?, ?), ?", ret) + } + + { + ret, _ := Preprocess("???", []interface{}{1, []interface{}{2, 3, 4}, 5}) + assert.Equal(t, "?(?, ?, ?)?", ret) + } + + { + ret, _ := Preprocess("??", []interface{}{[]interface{}{1, 2, 3}, []interface{}{}, []interface{}{4, 5}, []interface{}{}}) + assert.Equal(t, "(?, ?, ?)(NULL)", ret) + } +} + +func TestPlaceholderArguments(t *testing.T) { + { + _, args := Preprocess("?, ?, ?", []interface{}{1, 2, []interface{}{3, 4, 5}}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := Preprocess("?, ?, ?", []interface{}{1, []interface{}{2, 3, 4}, 5}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := Preprocess("?, ?, ?", []interface{}{[]interface{}{1, 2, 3}, 4, 5}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := Preprocess("?, ?", []interface{}{[]interface{}{1, 2, 3}, []interface{}{4, 5}}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } +} + +func TestPlaceholderReplace(t *testing.T) { + { + ret, args := Preprocess("?, ?, ?", []interface{}{1, mydb.Raw("foo"), 3}) + assert.Equal(t, "?, foo, ?", ret) + assert.Equal(t, []interface{}{1, 3}, args) + } +} diff --git a/internal/sqlbuilder/scanner.go b/internal/sqlbuilder/scanner.go new file mode 100644 index 0000000..228a8e0 --- /dev/null +++ b/internal/sqlbuilder/scanner.go @@ -0,0 +1,17 @@ +package sqlbuilder + +import ( + "database/sql" + + "git.hexq.cn/tiglog/mydb" +) + +type scanner struct { + v mydb.Unmarshaler +} + +func (u scanner) Scan(v interface{}) error { + return u.v.UnmarshalDB(v) +} + +var _ sql.Scanner = scanner{} diff --git a/internal/sqlbuilder/select.go b/internal/sqlbuilder/select.go new file mode 100644 index 0000000..cbf90b0 --- /dev/null +++ b/internal/sqlbuilder/select.go @@ -0,0 +1,524 @@ +package sqlbuilder + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/immutable" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +type selectorQuery struct { + table *exql.Columns + tableArgs []interface{} + + distinct bool + + where *exql.Where + whereArgs []interface{} + + groupBy *exql.GroupBy + groupByArgs []interface{} + + orderBy *exql.OrderBy + orderByArgs []interface{} + + limit exql.Limit + offset exql.Offset + + columns *exql.Columns + columnsArgs []interface{} + + joins []*exql.Join + joinsArgs []interface{} + + amendFn func(string) string +} + +func (sq *selectorQuery) and(b *sqlBuilder, terms ...interface{}) error { + where, whereArgs := b.t.toWhereWithArguments(terms) + + if sq.where == nil { + sq.where, sq.whereArgs = &exql.Where{}, []interface{}{} + } + sq.where.Append(&where) + sq.whereArgs = append(sq.whereArgs, whereArgs...) + + return nil +} + +func (sq *selectorQuery) arguments() []interface{} { + return joinArguments( + sq.columnsArgs, + sq.tableArgs, + sq.joinsArgs, + sq.whereArgs, + sq.groupByArgs, + sq.orderByArgs, + ) +} + +func (sq *selectorQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Select, + Table: sq.table, + Columns: sq.columns, + Distinct: sq.distinct, + Limit: sq.limit, + Offset: sq.offset, + Where: sq.where, + OrderBy: sq.orderBy, + GroupBy: sq.groupBy, + } + + if len(sq.joins) > 0 { + stmt.Joins = exql.JoinConditions(sq.joins...) + } + + stmt.SetAmendment(sq.amendFn) + + return stmt +} + +func (sq *selectorQuery) pushJoin(t string, tables []interface{}) error { + fragments, args, err := columnFragments(tables) + if err != nil { + return err + } + + if sq.joins == nil { + sq.joins = []*exql.Join{} + } + sq.joins = append(sq.joins, + &exql.Join{ + Type: t, + Table: exql.JoinColumns(fragments...), + }, + ) + + sq.joinsArgs = append(sq.joinsArgs, args...) + + return nil +} + +type selector struct { + builder *sqlBuilder + + fn func(*selectorQuery) error + prev *selector +} + +var _ = immutable.Immutable(&selector{}) + +func (sel *selector) SQL() *sqlBuilder { + if sel.prev == nil { + return sel.builder + } + return sel.prev.SQL() +} + +func (sel *selector) String() string { + s, err := sel.Compile() + if err != nil { + panic(err.Error()) + } + return prepareQueryForDisplay(s) +} + +func (sel *selector) frame(fn func(*selectorQuery) error) *selector { + return &selector{prev: sel, fn: fn} +} + +func (sel *selector) clone() mydb.Selector { + return sel.frame(func(*selectorQuery) error { + return nil + }) +} + +func (sel *selector) From(tables ...interface{}) mydb.Selector { + return sel.frame( + func(sq *selectorQuery) error { + fragments, args, err := columnFragments(tables) + if err != nil { + return err + } + sq.table = exql.JoinColumns(fragments...) + sq.tableArgs = args + return nil + }, + ) +} + +func (sel *selector) setColumns(columns ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + sq.columns = nil + return sq.pushColumns(columns...) + }) +} + +func (sel *selector) Columns(columns ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushColumns(columns...) + }) +} + +func (sq *selectorQuery) pushColumns(columns ...interface{}) error { + f, args, err := columnFragments(columns) + if err != nil { + return err + } + + c := exql.JoinColumns(f...) + + if sq.columns != nil { + sq.columns.Append(c) + } else { + sq.columns = c + } + + sq.columnsArgs = append(sq.columnsArgs, args...) + return nil +} + +func (sel *selector) Distinct(exps ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + sq.distinct = true + return sq.pushColumns(exps...) + }) +} + +func (sel *selector) Where(terms ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + if len(terms) == 1 && terms[0] == nil { + sq.where, sq.whereArgs = &exql.Where{}, []interface{}{} + return nil + } + return sq.and(sel.SQL(), terms...) + }) +} + +func (sel *selector) And(terms ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.and(sel.SQL(), terms...) + }) +} + +func (sel *selector) Amend(fn func(string) string) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + sq.amendFn = fn + return nil + }) +} + +func (sel *selector) Arguments() []interface{} { + sq, err := sel.build() + if err != nil { + return nil + } + return sq.arguments() +} + +func (sel *selector) GroupBy(columns ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + fragments, args, err := columnFragments(columns) + if err != nil { + return err + } + + if fragments != nil { + sq.groupBy = exql.GroupByColumns(fragments...) + } + sq.groupByArgs = args + + return nil + }) +} + +func (sel *selector) OrderBy(columns ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + + if len(columns) == 1 && columns[0] == nil { + sq.orderBy = nil + sq.orderByArgs = nil + return nil + } + + var sortColumns exql.SortColumns + + for i := range columns { + var sort *exql.SortColumn + + switch value := columns[i].(type) { + case *adapter.RawExpr: + query, args := Preprocess(value.Raw(), value.Arguments()) + sort = &exql.SortColumn{ + Column: &exql.Raw{Value: query}, + } + sq.orderByArgs = append(sq.orderByArgs, args...) + case *adapter.FuncExpr: + fnName, fnArgs := value.Name(), value.Arguments() + if len(fnArgs) == 0 { + fnName = fnName + "()" + } else { + fnName = fnName + "(?" + strings.Repeat(", ?", len(fnArgs)-1) + ")" + } + fnName, fnArgs = Preprocess(fnName, fnArgs) + sort = &exql.SortColumn{ + Column: &exql.Raw{Value: fnName}, + } + sq.orderByArgs = append(sq.orderByArgs, fnArgs...) + case string: + if strings.HasPrefix(value, "-") { + sort = &exql.SortColumn{ + Column: exql.ColumnWithName(value[1:]), + Order: exql.Order_Descendent, + } + } else { + chunks := strings.SplitN(value, " ", 2) + + order := exql.Order_Ascendent + if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" { + order = exql.Order_Descendent + } + + sort = &exql.SortColumn{ + Column: exql.ColumnWithName(chunks[0]), + Order: order, + } + } + default: + return fmt.Errorf("Can't sort by type %T", value) + } + sortColumns.Columns = append(sortColumns.Columns, sort) + } + + sq.orderBy = &exql.OrderBy{ + SortColumns: &sortColumns, + } + return nil + }) +} + +func (sel *selector) Using(columns ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + + joins := len(sq.joins) + if joins == 0 { + return errors.New(`cannot use Using() without a preceding Join() expression`) + } + + lastJoin := sq.joins[joins-1] + if lastJoin.On != nil { + return errors.New(`cannot use Using() and On() with the same Join() expression`) + } + + fragments, args, err := columnFragments(columns) + if err != nil { + return err + } + + sq.joinsArgs = append(sq.joinsArgs, args...) + lastJoin.Using = exql.UsingColumns(fragments...) + + return nil + }) +} + +func (sel *selector) FullJoin(tables ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("FULL", tables) + }) +} + +func (sel *selector) CrossJoin(tables ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("CROSS", tables) + }) +} + +func (sel *selector) RightJoin(tables ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("RIGHT", tables) + }) +} + +func (sel *selector) LeftJoin(tables ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("LEFT", tables) + }) +} + +func (sel *selector) Join(tables ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("", tables) + }) +} + +func (sel *selector) On(terms ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + joins := len(sq.joins) + + if joins == 0 { + return errors.New(`cannot use On() without a preceding Join() expression`) + } + + lastJoin := sq.joins[joins-1] + if lastJoin.On != nil { + return errors.New(`cannot use Using() and On() with the same Join() expression`) + } + + w, a := sel.SQL().t.toWhereWithArguments(terms) + o := exql.On(w) + + lastJoin.On = &o + + sq.joinsArgs = append(sq.joinsArgs, a...) + + return nil + }) +} + +func (sel *selector) Limit(n int) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + if n < 0 { + n = 0 + } + sq.limit = exql.Limit(n) + return nil + }) +} + +func (sel *selector) Offset(n int) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + if n < 0 { + n = 0 + } + sq.offset = exql.Offset(n) + return nil + }) +} + +func (sel *selector) template() *exql.Template { + return sel.SQL().t.Template +} + +func (sel *selector) As(alias string) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + if sq.table == nil { + return errors.New("Cannot use As() without a preceding From() expression") + } + last := len(sq.table.Columns) - 1 + if raw, ok := sq.table.Columns[last].(*exql.Raw); ok { + compiled, err := exql.ColumnWithName(alias).Compile(sel.template()) + if err != nil { + return err + } + sq.table.Columns[last] = &exql.Raw{Value: raw.Value + " AS " + compiled} + } + return nil + }) +} + +func (sel *selector) statement() *exql.Statement { + sq, _ := sel.build() + return sq.statement() +} + +func (sel *selector) QueryRow() (*sql.Row, error) { + return sel.QueryRowContext(sel.SQL().sess.Context()) +} + +func (sel *selector) QueryRowContext(ctx context.Context) (*sql.Row, error) { + sq, err := sel.build() + if err != nil { + return nil, err + } + + return sel.SQL().sess.StatementQueryRow(ctx, sq.statement(), sq.arguments()...) +} + +func (sel *selector) Prepare() (*sql.Stmt, error) { + return sel.PrepareContext(sel.SQL().sess.Context()) +} + +func (sel *selector) PrepareContext(ctx context.Context) (*sql.Stmt, error) { + sq, err := sel.build() + if err != nil { + return nil, err + } + return sel.SQL().sess.StatementPrepare(ctx, sq.statement()) +} + +func (sel *selector) Query() (*sql.Rows, error) { + return sel.QueryContext(sel.SQL().sess.Context()) +} + +func (sel *selector) QueryContext(ctx context.Context) (*sql.Rows, error) { + sq, err := sel.build() + if err != nil { + return nil, err + } + return sel.SQL().sess.StatementQuery(ctx, sq.statement(), sq.arguments()...) +} + +func (sel *selector) Iterator() mydb.Iterator { + return sel.IteratorContext(sel.SQL().sess.Context()) +} + +func (sel *selector) IteratorContext(ctx context.Context) mydb.Iterator { + sess := sel.SQL().sess + sq, err := sel.build() + if err != nil { + return &iterator{sess, nil, err} + } + + rows, err := sess.StatementQuery(ctx, sq.statement(), sq.arguments()...) + return &iterator{sess, rows, err} +} + +func (sel *selector) Paginate(pageSize uint) mydb.Paginator { + return newPaginator(sel.clone(), pageSize) +} + +func (sel *selector) All(destSlice interface{}) error { + return sel.Iterator().All(destSlice) +} + +func (sel *selector) One(dest interface{}) error { + return sel.Iterator().One(dest) +} + +func (sel *selector) build() (*selectorQuery, error) { + sq, err := immutable.FastForward(sel) + if err != nil { + return nil, err + } + return sq.(*selectorQuery), nil +} + +func (sel *selector) Compile() (string, error) { + return sel.statement().Compile(sel.template()) +} + +func (sel *selector) Prev() immutable.Immutable { + if sel == nil { + return nil + } + return sel.prev +} + +func (sel *selector) Fn(in interface{}) error { + if sel.fn == nil { + return nil + } + return sel.fn(in.(*selectorQuery)) +} + +func (sel *selector) Base() interface{} { + return &selectorQuery{} +} diff --git a/internal/sqlbuilder/sqlbuilder.go b/internal/sqlbuilder/sqlbuilder.go new file mode 100644 index 0000000..7afa620 --- /dev/null +++ b/internal/sqlbuilder/sqlbuilder.go @@ -0,0 +1,40 @@ +package sqlbuilder + +import ( + "database/sql" + "fmt" + + "git.hexq.cn/tiglog/mydb" +) + +// Engine represents a SQL database engine. +type Engine interface { + mydb.Session + + mydb.SQL +} + +func lookupAdapter(adapterName string) (Adapter, error) { + adapter := mydb.LookupAdapter(adapterName) + if sqlAdapter, ok := adapter.(Adapter); ok { + return sqlAdapter, nil + } + return nil, fmt.Errorf("%w %q", mydb.ErrMissingAdapter, adapterName) +} + +func BindTx(adapterName string, tx *sql.Tx) (Tx, error) { + adapter, err := lookupAdapter(adapterName) + if err != nil { + return nil, err + } + return adapter.NewTx(tx) +} + +// Bind creates a binding between an adapter and a *sql.Tx or a *sql.mydb. +func BindDB(adapterName string, sess *sql.DB) (mydb.Session, error) { + adapter, err := lookupAdapter(adapterName) + if err != nil { + return nil, err + } + return adapter.New(sess) +} diff --git a/internal/sqlbuilder/template.go b/internal/sqlbuilder/template.go new file mode 100644 index 0000000..c27466d --- /dev/null +++ b/internal/sqlbuilder/template.go @@ -0,0 +1,332 @@ +package sqlbuilder + +import ( + "database/sql/driver" + "fmt" + "strings" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +type templateWithUtils struct { + *exql.Template +} + +func newTemplateWithUtils(template *exql.Template) *templateWithUtils { + return &templateWithUtils{template} +} + +func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []interface{}) { + switch t := in.(type) { + case *adapter.RawExpr: + return &exql.Raw{Value: t.Raw()}, t.Arguments() + case *adapter.FuncExpr: + fnName := t.Name() + fnArgs := []interface{}{} + args, _ := toInterfaceArguments(t.Arguments()) + fragments := []string{} + for i := range args { + frag, args := tu.PlaceholderValue(args[i]) + fragment, err := frag.Compile(tu.Template) + if err == nil { + fragments = append(fragments, fragment) + fnArgs = append(fnArgs, args...) + } + } + return &exql.Raw{Value: fnName + `(` + strings.Join(fragments, `, `) + `)`}, fnArgs + default: + return sqlPlaceholder, []interface{}{in} + } +} + +// toWhereWithArguments converts the given parameters into a exql.Where value. +func (tu *templateWithUtils) toWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) { + args = []interface{}{} + + switch t := term.(type) { + case []interface{}: + if len(t) > 0 { + if s, ok := t[0].(string); ok { + if strings.ContainsAny(s, "?") || len(t) == 1 { + s, args = Preprocess(s, t[1:]) + where.Conditions = []exql.Fragment{&exql.Raw{Value: s}} + } else { + var val interface{} + key := s + if len(t) > 2 { + val = t[1:] + } else { + val = t[1] + } + cv, v := tu.toColumnValues(adapter.NewConstraint(key, val)) + args = append(args, v...) + where.Conditions = append(where.Conditions, cv.ColumnValues...) + } + return + } + } + for i := range t { + w, v := tu.toWhereWithArguments(t[i]) + if len(w.Conditions) == 0 { + continue + } + args = append(args, v...) + where.Conditions = append(where.Conditions, w.Conditions...) + } + return + case *adapter.RawExpr: + r, v := Preprocess(t.Raw(), t.Arguments()) + where.Conditions = []exql.Fragment{&exql.Raw{Value: r}} + args = append(args, v...) + return + case adapter.Constraints: + for _, c := range t.Constraints() { + w, v := tu.toWhereWithArguments(c) + if len(w.Conditions) == 0 { + continue + } + args = append(args, v...) + where.Conditions = append(where.Conditions, w.Conditions...) + } + return + case adapter.LogicalExpr: + var cond exql.Where + + expressions := t.Expressions() + for i := range expressions { + w, v := tu.toWhereWithArguments(expressions[i]) + if len(w.Conditions) == 0 { + continue + } + args = append(args, v...) + cond.Conditions = append(cond.Conditions, w.Conditions...) + } + if len(cond.Conditions) < 1 { + return + } + + if len(cond.Conditions) <= 1 { + where.Conditions = append(where.Conditions, cond.Conditions...) + return where, args + } + + var frag exql.Fragment + switch t.Operator() { + case adapter.LogicalOperatorNone, adapter.LogicalOperatorAnd: + q := exql.And(cond) + frag = &q + case adapter.LogicalOperatorOr: + q := exql.Or(cond) + frag = &q + default: + panic(fmt.Sprintf("Unknown type %T", t)) + } + where.Conditions = append(where.Conditions, frag) + return + + case mydb.InsertResult: + return tu.toWhereWithArguments(t.ID()) + + case adapter.Constraint: + cv, v := tu.toColumnValues(t) + args = append(args, v...) + where.Conditions = append(where.Conditions, cv.ColumnValues...) + return where, args + } + + panic(fmt.Sprintf("Unknown condition type %T", term)) +} + +func (tu *templateWithUtils) comparisonOperatorMapper(t adapter.ComparisonOperator) string { + if t == adapter.ComparisonOperatorCustom { + return "" + } + if tu.ComparisonOperator != nil { + if op, ok := tu.ComparisonOperator[t]; ok { + return op + } + } + if op, ok := comparisonOperators[t]; ok { + return op + } + panic(fmt.Sprintf("unsupported comparison operator %v", t)) +} + +func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnValues, args []interface{}) { + args = []interface{}{} + + switch t := term.(type) { + case adapter.Constraint: + columnValue := exql.ColumnValue{} + + // TODO: Key and Value are similar. Can we refactor this? Maybe think about + // Left/Right rather than Key/Value. + + switch key := t.Key().(type) { + case string: + chunks := strings.SplitN(strings.TrimSpace(key), " ", 2) + columnValue.Column = exql.ColumnWithName(chunks[0]) + if len(chunks) > 1 { + columnValue.Operator = chunks[1] + } + case *adapter.RawExpr: + columnValue.Column = &exql.Raw{Value: key.Raw()} + args = append(args, key.Arguments()...) + case *mydb.FuncExpr: + fnName, fnArgs := key.Name(), key.Arguments() + if len(fnArgs) == 0 { + fnName = fnName + "()" + } else { + fnName = fnName + "(?" + strings.Repeat(", ?", len(fnArgs)-1) + ")" + } + fnName, fnArgs = Preprocess(fnName, fnArgs) + columnValue.Column = &exql.Raw{Value: fnName} + args = append(args, fnArgs...) + default: + columnValue.Column = &exql.Raw{Value: fmt.Sprintf("%v", key)} + } + + switch value := t.Value().(type) { + case *mydb.FuncExpr: + fnName, fnArgs := value.Name(), value.Arguments() + if len(fnArgs) == 0 { + // A function with no arguments. + fnName = fnName + "()" + } else { + // A function with one or more arguments. + fnName = fnName + "(?" + strings.Repeat(", ?", len(fnArgs)-1) + ")" + } + fnName, fnArgs = Preprocess(fnName, fnArgs) + columnValue.Value = &exql.Raw{Value: fnName} + args = append(args, fnArgs...) + case *mydb.RawExpr: + q, a := Preprocess(value.Raw(), value.Arguments()) + columnValue.Value = &exql.Raw{Value: q} + args = append(args, a...) + case driver.Valuer: + columnValue.Value = sqlPlaceholder + args = append(args, value) + case *mydb.Comparison: + wrapper := &operatorWrapper{ + tu: tu, + cv: &columnValue, + op: value.Comparison, + } + + q, a := wrapper.preprocess() + q, a = Preprocess(q, a) + + columnValue = exql.ColumnValue{ + Column: &exql.Raw{Value: q}, + } + if a != nil { + args = append(args, a...) + } + + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + return cv, args + default: + wrapper := &operatorWrapper{ + tu: tu, + cv: &columnValue, + v: value, + } + + q, a := wrapper.preprocess() + q, a = Preprocess(q, a) + + columnValue = exql.ColumnValue{ + Column: &exql.Raw{Value: q}, + } + if a != nil { + args = append(args, a...) + } + + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + return cv, args + } + + if columnValue.Operator == "" { + columnValue.Operator = tu.comparisonOperatorMapper(adapter.ComparisonOperatorEqual) + } + + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + return cv, args + + case *adapter.RawExpr: + columnValue := exql.ColumnValue{} + p, q := Preprocess(t.Raw(), t.Arguments()) + columnValue.Column = &exql.Raw{Value: p} + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + args = append(args, q...) + return cv, args + + case adapter.Constraints: + for _, constraint := range t.Constraints() { + p, q := tu.toColumnValues(constraint) + cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...) + args = append(args, q...) + } + return cv, args + } + + panic(fmt.Sprintf("Unknown term type %T.", term)) +} + +func (tu *templateWithUtils) setColumnValues(term interface{}) (cv exql.ColumnValues, args []interface{}) { + args = []interface{}{} + + switch t := term.(type) { + case []interface{}: + l := len(t) + for i := 0; i < l; i++ { + column, isString := t[i].(string) + + if !isString { + p, q := tu.setColumnValues(t[i]) + cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...) + args = append(args, q...) + continue + } + + if !strings.ContainsAny(column, tu.AssignmentOperator) { + column = column + " " + tu.AssignmentOperator + " ?" + } + + chunks := strings.SplitN(column, tu.AssignmentOperator, 2) + + column = chunks[0] + format := strings.TrimSpace(chunks[1]) + + columnValue := exql.ColumnValue{ + Column: exql.ColumnWithName(column), + Operator: tu.AssignmentOperator, + Value: &exql.Raw{Value: format}, + } + + ps := strings.Count(format, "?") + if i+ps < l { + for j := 0; j < ps; j++ { + args = append(args, t[i+j+1]) + } + i = i + ps + } else { + panic(fmt.Sprintf("Format string %q has more placeholders than given arguments.", format)) + } + + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + } + return cv, args + case *adapter.RawExpr: + columnValue := exql.ColumnValue{} + p, q := Preprocess(t.Raw(), t.Arguments()) + columnValue.Column = &exql.Raw{Value: p} + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + args = append(args, q...) + return cv, args + } + + panic(fmt.Sprintf("Unknown term type %T.", term)) +} diff --git a/internal/sqlbuilder/template_test.go b/internal/sqlbuilder/template_test.go new file mode 100644 index 0000000..f185afd --- /dev/null +++ b/internal/sqlbuilder/template_test.go @@ -0,0 +1,192 @@ +package sqlbuilder + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +const ( + defaultColumnSeparator = `.` + defaultIdentifierSeparator = `, ` + defaultIdentifierQuote = `"{{.Value}}"` + defaultValueSeparator = `, ` + defaultValueQuote = `'{{.}}'` + defaultAndKeyword = `AND` + defaultOrKeyword = `OR` + defaultDescKeyword = `DESC` + defaultAscKeyword = `ASC` + defaultAssignmentOperator = `=` + defaultClauseGroup = `({{.}})` + defaultClauseOperator = ` {{.}} ` + defaultColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + defaultTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + defaultColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + defaultSortByColumnLayout = `{{.Column}} {{.Order}}` + + defaultOrderByLayout = ` + {{if .SortColumns}} + ORDER BY {{.SortColumns}} + {{end}} + ` + + defaultWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + + defaultUsingLayout = ` + {{if .Columns}} + USING ({{.Columns}}) + {{end}} + ` + + defaultJoinLayout = ` + {{if .Table}} + {{ if .On }} + {{.Type}} JOIN {{.Table}} + {{.On}} + {{ else if .Using }} + {{.Type}} JOIN {{.Table}} + {{.Using}} + {{ else if .Type | eq "CROSS" }} + {{.Type}} JOIN {{.Table}} + {{else}} + NATURAL {{.Type}} JOIN {{.Table}} + {{end}} + {{end}} + ` + + defaultOnLayout = ` + {{if .Conds}} + ON {{.Conds}} + {{end}} + ` + + defaultSelectLayout = ` + SELECT + {{if .Distinct}} + DISTINCT + {{end}} + + {{if defined .Columns}} + {{.Columns | compile}} + {{else}} + * + {{end}} + + {{if defined .Table}} + FROM {{.Table | compile}} + {{end}} + + {{.Joins | compile}} + + {{.Where | compile}} + + {{if defined .GroupBy}} + {{.GroupBy | compile}} + {{end}} + + {{.OrderBy | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + defaultDeleteLayout = ` + DELETE + FROM {{.Table | compile}} + {{.Where | compile}} + ` + defaultUpdateLayout = ` + UPDATE + {{.Table | compile}} + SET {{.ColumnValues | compile}} + {{.Where | compile}} + ` + + defaultCountLayout = ` + SELECT + COUNT(1) AS _t + FROM {{.Table | compile}} + {{.Where | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + + defaultInsertLayout = ` + INSERT INTO {{.Table | compile}} + {{if defined .Columns }}({{.Columns | compile}}){{end}} + VALUES + {{if defined .Values}} + {{.Values | compile}} + {{else}} + (default) + {{end}} + {{if defined .Returning}} + RETURNING {{.Returning | compile}} + {{end}} + ` + + defaultTruncateLayout = ` + TRUNCATE TABLE {{.Table | compile}} + ` + + defaultDropDatabaseLayout = ` + DROP DATABASE {{.Database | compile}} + ` + + defaultDropTableLayout = ` + DROP TABLE {{.Table | compile}} + ` + + defaultGroupByLayout = ` + {{if .GroupColumns}} + GROUP BY {{.GroupColumns}} + {{end}} + ` +) + +var testTemplate = exql.Template{ + ColumnSeparator: defaultColumnSeparator, + IdentifierSeparator: defaultIdentifierSeparator, + IdentifierQuote: defaultIdentifierQuote, + ValueSeparator: defaultValueSeparator, + ValueQuote: defaultValueQuote, + AndKeyword: defaultAndKeyword, + OrKeyword: defaultOrKeyword, + DescKeyword: defaultDescKeyword, + AscKeyword: defaultAscKeyword, + AssignmentOperator: defaultAssignmentOperator, + ClauseGroup: defaultClauseGroup, + ClauseOperator: defaultClauseOperator, + ColumnValue: defaultColumnValue, + TableAliasLayout: defaultTableAliasLayout, + ColumnAliasLayout: defaultColumnAliasLayout, + SortByColumnLayout: defaultSortByColumnLayout, + WhereLayout: defaultWhereLayout, + OnLayout: defaultOnLayout, + UsingLayout: defaultUsingLayout, + JoinLayout: defaultJoinLayout, + OrderByLayout: defaultOrderByLayout, + InsertLayout: defaultInsertLayout, + SelectLayout: defaultSelectLayout, + UpdateLayout: defaultUpdateLayout, + DeleteLayout: defaultDeleteLayout, + TruncateLayout: defaultTruncateLayout, + DropDatabaseLayout: defaultDropDatabaseLayout, + DropTableLayout: defaultDropTableLayout, + CountLayout: defaultCountLayout, + GroupByLayout: defaultGroupByLayout, + Cache: cache.NewCache(), +} diff --git a/internal/sqlbuilder/update.go b/internal/sqlbuilder/update.go new file mode 100644 index 0000000..fb3ae30 --- /dev/null +++ b/internal/sqlbuilder/update.go @@ -0,0 +1,242 @@ +package sqlbuilder + +import ( + "context" + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/immutable" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +type updaterQuery struct { + table string + + columnValues *exql.ColumnValues + columnValuesArgs []interface{} + + limit int + + where *exql.Where + whereArgs []interface{} + + amendFn func(string) string +} + +func (uq *updaterQuery) and(b *sqlBuilder, terms ...interface{}) error { + where, whereArgs := b.t.toWhereWithArguments(terms) + + if uq.where == nil { + uq.where, uq.whereArgs = &exql.Where{}, []interface{}{} + } + uq.where.Append(&where) + uq.whereArgs = append(uq.whereArgs, whereArgs...) + + return nil +} + +func (uq *updaterQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Update, + Table: exql.TableWithName(uq.table), + ColumnValues: uq.columnValues, + } + + if uq.where != nil { + stmt.Where = uq.where + } + + if uq.limit != 0 { + stmt.Limit = exql.Limit(uq.limit) + } + + stmt.SetAmendment(uq.amendFn) + + return stmt +} + +func (uq *updaterQuery) arguments() []interface{} { + return joinArguments( + uq.columnValuesArgs, + uq.whereArgs, + ) +} + +type updater struct { + builder *sqlBuilder + + fn func(*updaterQuery) error + prev *updater +} + +var _ = immutable.Immutable(&updater{}) + +func (upd *updater) SQL() *sqlBuilder { + if upd.prev == nil { + return upd.builder + } + return upd.prev.SQL() +} + +func (upd *updater) template() *exql.Template { + return upd.SQL().t.Template +} + +func (upd *updater) String() string { + s, err := upd.Compile() + if err != nil { + panic(err.Error()) + } + return prepareQueryForDisplay(s) +} + +func (upd *updater) setTable(table string) *updater { + return upd.frame(func(uq *updaterQuery) error { + uq.table = table + return nil + }) +} + +func (upd *updater) frame(fn func(*updaterQuery) error) *updater { + return &updater{prev: upd, fn: fn} +} + +func (upd *updater) Set(terms ...interface{}) mydb.Updater { + return upd.frame(func(uq *updaterQuery) error { + if uq.columnValues == nil { + uq.columnValues = &exql.ColumnValues{} + } + + if len(terms) == 1 { + ff, vv, err := Map(terms[0], nil) + if err == nil && len(ff) > 0 { + cvs := make([]exql.Fragment, 0, len(ff)) + args := make([]interface{}, 0, len(vv)) + + for i := range ff { + cv := &exql.ColumnValue{ + Column: exql.ColumnWithName(ff[i]), + Operator: upd.SQL().t.AssignmentOperator, + } + + var localArgs []interface{} + cv.Value, localArgs = upd.SQL().t.PlaceholderValue(vv[i]) + + args = append(args, localArgs...) + cvs = append(cvs, cv) + } + + uq.columnValues.Insert(cvs...) + uq.columnValuesArgs = append(uq.columnValuesArgs, args...) + + return nil + } + } + + cv, arguments := upd.SQL().t.setColumnValues(terms) + uq.columnValues.Insert(cv.ColumnValues...) + uq.columnValuesArgs = append(uq.columnValuesArgs, arguments...) + return nil + }) +} + +func (upd *updater) Amend(fn func(string) string) mydb.Updater { + return upd.frame(func(uq *updaterQuery) error { + uq.amendFn = fn + return nil + }) +} + +func (upd *updater) Arguments() []interface{} { + uq, err := upd.build() + if err != nil { + return nil + } + return uq.arguments() +} + +func (upd *updater) Where(terms ...interface{}) mydb.Updater { + return upd.frame(func(uq *updaterQuery) error { + uq.where, uq.whereArgs = &exql.Where{}, []interface{}{} + return uq.and(upd.SQL(), terms...) + }) +} + +func (upd *updater) And(terms ...interface{}) mydb.Updater { + return upd.frame(func(uq *updaterQuery) error { + return uq.and(upd.SQL(), terms...) + }) +} + +func (upd *updater) Prepare() (*sql.Stmt, error) { + return upd.PrepareContext(upd.SQL().sess.Context()) +} + +func (upd *updater) PrepareContext(ctx context.Context) (*sql.Stmt, error) { + uq, err := upd.build() + if err != nil { + return nil, err + } + return upd.SQL().sess.StatementPrepare(ctx, uq.statement()) +} + +func (upd *updater) Exec() (sql.Result, error) { + return upd.ExecContext(upd.SQL().sess.Context()) +} + +func (upd *updater) ExecContext(ctx context.Context) (sql.Result, error) { + uq, err := upd.build() + if err != nil { + return nil, err + } + return upd.SQL().sess.StatementExec(ctx, uq.statement(), uq.arguments()...) +} + +func (upd *updater) Limit(limit int) mydb.Updater { + return upd.frame(func(uq *updaterQuery) error { + uq.limit = limit + return nil + }) +} + +func (upd *updater) statement() (*exql.Statement, error) { + iq, err := upd.build() + if err != nil { + return nil, err + } + return iq.statement(), nil +} + +func (upd *updater) build() (*updaterQuery, error) { + uq, err := immutable.FastForward(upd) + if err != nil { + return nil, err + } + return uq.(*updaterQuery), nil +} + +func (upd *updater) Compile() (string, error) { + s, err := upd.statement() + if err != nil { + return "", err + } + return s.Compile(upd.template()) +} + +func (upd *updater) Prev() immutable.Immutable { + if upd == nil { + return nil + } + return upd.prev +} + +func (upd *updater) Fn(in interface{}) error { + if upd.fn == nil { + return nil + } + return upd.fn(in.(*updaterQuery)) +} + +func (upd *updater) Base() interface{} { + return &updaterQuery{} +} diff --git a/internal/sqlbuilder/wrapper.go b/internal/sqlbuilder/wrapper.go new file mode 100644 index 0000000..6806f1e --- /dev/null +++ b/internal/sqlbuilder/wrapper.go @@ -0,0 +1,64 @@ +package sqlbuilder + +import ( + "database/sql" + + "git.hexq.cn/tiglog/mydb" +) + +// Tx represents a transaction on a SQL database. A transaction is like a +// regular Session except it has two extra methods: Commit and Rollback. +// +// A transaction needs to be committed (with Commit) to make changes permanent, +// changes can be discarded before committing by rolling back (with Rollback). +// After either committing or rolling back a transaction it can not longer be +// used and it's automatically closed. +type Tx interface { + // All mydb.Session methods are available on transaction sessions. They will + // run on the same transaction. + mydb.Session + + Commit() error + + Rollback() error +} + +// Adapter represents a SQL adapter. +type Adapter interface { + // New wraps an active *sql.DB session and returns a SQLBuilder database. The + // adapter needs to be imported to the blank namespace in order for it to be + // used here. + // + // This method is internally used by upper-db to create a builder backed by the + // given database. You may want to use your adapter's New function instead of + // this one. + New(*sql.DB) (mydb.Session, error) + + // NewTx wraps an active *sql.Tx transation and returns a SQLBuilder + // transaction. The adapter needs to be imported to the blank namespace in + // order for it to be used. + // + // This method is internally used by upper-db to create a builder backed by the + // given transaction. You may want to use your adapter's NewTx function + // instead of this one. + NewTx(*sql.Tx) (Tx, error) + + // Open opens a SQL database. + OpenDSN(mydb.ConnectionURL) (mydb.Session, error) +} + +type dbAdapter struct { + Adapter +} + +func (d *dbAdapter) Open(conn mydb.ConnectionURL) (mydb.Session, error) { + sess, err := d.Adapter.OpenDSN(conn) + if err != nil { + return nil, err + } + return sess.(mydb.Session), nil +} + +func NewCompatAdapter(adapter Adapter) mydb.Adapter { + return &dbAdapter{adapter} +} diff --git a/internal/testsuite/generic_suite.go b/internal/testsuite/generic_suite.go new file mode 100644 index 0000000..c6139f9 --- /dev/null +++ b/internal/testsuite/generic_suite.go @@ -0,0 +1,889 @@ +package testsuite + +import ( + "database/sql/driver" + "time" + + "git.hexq.cn/tiglog/mydb" + "github.com/stretchr/testify/suite" + "gopkg.in/mgo.v2/bson" +) + +type birthday struct { + Name string `db:"name"` + Born time.Time `db:"born"` + BornUT *unixTimestamp `db:"born_ut,omitempty"` + OmitMe bool `json:"omit_me" db:"-" bson:"-"` +} + +type fibonacci struct { + Input uint64 `db:"input"` + Output uint64 `db:"output"` + // Test for BSON option. + OmitMe bool `json:"omit_me" db:"omit_me,bson,omitempty" bson:"omit_me,omitempty"` +} + +type oddEven struct { + // Test for JSON option. + Input int `json:"input" db:"input"` + // Test for JSON option. + // The "bson" tag is required by mgo. + IsEven bool `json:"is_even" db:"is_even,json" bson:"is_even"` + OmitMe bool `json:"omit_me" db:"-" bson:"-"` +} + +// Struct that relies on explicit mapping. +type mapE struct { + ID uint `db:"id,omitempty" bson:"-"` + MongoID bson.ObjectId `db:"-" bson:"_id,omitempty"` + CaseTest string `db:"case_test" bson:"case_test"` +} + +// Struct that will fallback to default mapping. +type mapN struct { + ID uint `db:"id,omitempty"` + MongoID bson.ObjectId `db:"-" bson:"_id,omitempty"` + Case_TEST string `db:"case_test"` +} + +// Struct for testing marshalling. +type unixTimestamp struct { + // Time is handled internally as time.Time but saved as an (integer) unix + // timestamp. + value time.Time +} + +func (u unixTimestamp) Value() (driver.Value, error) { + return u.value.UTC().Unix(), nil +} + +func (u *unixTimestamp) Scan(v interface{}) error { + var unixTime int64 + + switch t := v.(type) { + case int64: + unixTime = t + case nil: + return nil + default: + return mydb.ErrUnsupportedValue + } + + t := time.Unix(unixTime, 0).In(time.UTC) + *u = unixTimestamp{t} + + return nil +} + +func newUnixTimestamp(t time.Time) *unixTimestamp { + return &unixTimestamp{t.UTC()} +} + +func even(i int) bool { + return i%2 == 0 +} + +func fib(i uint64) uint64 { + if i == 0 { + return 0 + } else if i == 1 { + return 1 + } + return fib(i-1) + fib(i-2) +} + +type GenericTestSuite struct { + suite.Suite + + Helper +} + +func (s *GenericTestSuite) AfterTest(suiteName, testName string) { + err := s.TearDown() + s.NoError(err) +} + +func (s *GenericTestSuite) BeforeTest(suiteName, testName string) { + err := s.TearUp() + s.NoError(err) +} + +func (s *GenericTestSuite) TestDatesAndUnicode() { + sess := s.Session() + + testTimeZone := time.Local + switch s.Adapter() { + case "mysql", "cockroachdb", "postgresql": + testTimeZone = defaultTimeLocation + case "sqlite", "ql", "mssql": + testTimeZone = time.UTC + } + + born := time.Date(1941, time.January, 5, 0, 0, 0, 0, testTimeZone) + + controlItem := birthday{ + Name: "Hayao Miyazaki", + Born: born, + BornUT: newUnixTimestamp(born), + } + + col := sess.Collection(`birthdays`) + + record, err := col.Insert(controlItem) + s.NoError(err) + s.NotZero(record.ID()) + + var res mydb.Result + switch s.Adapter() { + case "mongo": + res = col.Find(mydb.Cond{"_id": record.ID().(bson.ObjectId)}) + case "ql": + res = col.Find(mydb.Cond{"id()": record.ID()}) + default: + res = col.Find(mydb.Cond{"id": record.ID()}) + } + + var total uint64 + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(1), total) + + switch s.Adapter() { + case "mongo": + s.T().Skip() + } + + var testItem birthday + err = res.One(&testItem) + s.NoError(err) + + switch s.Adapter() { + case "sqlite", "ql", "mssql": + testItem.Born = testItem.Born.In(time.UTC) + } + s.Equal(controlItem.Born, testItem.Born) + + s.Equal(controlItem.BornUT, testItem.BornUT) + s.Equal(controlItem, testItem) + + var testItems []birthday + err = res.All(&testItems) + s.NoError(err) + s.NotZero(len(testItems)) + + for _, testItem = range testItems { + switch s.Adapter() { + case "sqlite", "ql", "mssql": + testItem.Born = testItem.Born.In(time.UTC) + } + s.Equal(controlItem, testItem) + } + + controlItem.Name = `宮崎駿` + err = res.Update(controlItem) + s.NoError(err) + + err = res.One(&testItem) + s.NoError(err) + + switch s.Adapter() { + case "sqlite", "ql", "mssql": + testItem.Born = testItem.Born.In(time.UTC) + } + + s.Equal(controlItem, testItem) + + err = res.Delete() + s.NoError(err) + + total, err = res.Count() + s.NoError(err) + s.Zero(total) + + err = res.Close() + s.NoError(err) +} + +func (s *GenericTestSuite) TestFibonacci() { + var err error + var res mydb.Result + var total uint64 + + sess := s.Session() + + col := sess.Collection("fibonacci") + + // Adding some items. + var i uint64 + for i = 0; i < 10; i++ { + item := fibonacci{Input: i, Output: fib(i)} + _, err = col.Insert(item) + s.NoError(err) + } + + // Testing sorting by function. + res = col.Find( + // 5, 6, 7, 3 + mydb.Or( + mydb.And( + mydb.Cond{"input": mydb.Gte(5)}, + mydb.Cond{"input": mydb.Lte(7)}, + ), + mydb.Cond{"input": mydb.Eq(3)}, + ), + ) + + // Testing sort by function. + switch s.Adapter() { + case "postgresql": + res = res.OrderBy(mydb.Raw("RANDOM()")) + case "sqlite": + res = res.OrderBy(mydb.Raw("RANDOM()")) + case "mysql": + res = res.OrderBy(mydb.Raw("RAND()")) + case "mssql": + res = res.OrderBy(mydb.Raw("NEWID()")) + } + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + // Find() with IN/$in + res = col.Find(mydb.Cond{"input IN": []int{3, 5, 6, 7}}).OrderBy("input") + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + res = res.Offset(1).Limit(2) + + var item fibonacci + for res.Next(&item) { + switch item.Input { + case 5: + case 6: + s.Equal(fib(item.Input), item.Output) + default: + s.T().Errorf(`Unexpected item: %v.`, item) + } + } + s.NoError(res.Err()) + + // Find() with range + res = col.Find( + // 5, 6, 7, 3 + mydb.Or( + mydb.And( + mydb.Cond{"input >=": 5}, + mydb.Cond{"input <=": 7}, + ), + mydb.Cond{"input": 3}, + ), + ).OrderBy("-input") + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + // Skipping. + res = res.Offset(1).Limit(2) + + var item2 fibonacci + for res.Next(&item2) { + switch item2.Input { + case 5: + case 6: + s.Equal(fib(item2.Input), item2.Output) + default: + s.T().Errorf(`Unexpected item: %v.`, item2) + } + } + err = res.Err() + s.NoError(err) + + err = res.Delete() + s.NoError(err) + + { + total, err := res.Count() + s.NoError(err) + s.Zero(total) + } + + // Find() with no arguments. + res = col.Find() + { + total, err := res.Count() + s.NoError(err) + s.Equal(uint64(6), total) + } + + // Skipping mongodb as the results of this are not defined there. + if s.Adapter() != `mongo` { + // Find() with empty mydb.Cond. + { + total, err := col.Find(mydb.Cond{}).Count() + s.NoError(err) + s.Equal(uint64(6), total) + } + + // Find() with empty expression + { + total, err := col.Find(mydb.Or(mydb.And(mydb.Cond{}, mydb.Cond{}), mydb.Or(mydb.Cond{}))).Count() + s.NoError(err) + s.Equal(uint64(6), total) + } + + // Find() with explicit IS NULL + { + total, err := col.Find(mydb.Cond{"input IS": nil}).Count() + s.NoError(err) + s.Equal(uint64(0), total) + } + + // Find() with implicit IS NULL + { + total, err := col.Find(mydb.Cond{"input": nil}).Count() + s.NoError(err) + s.Equal(uint64(0), total) + } + + // Find() with explicit = NULL + { + total, err := col.Find(mydb.Cond{"input =": nil}).Count() + s.NoError(err) + s.Equal(uint64(0), total) + } + + // Find() with implicit IN + { + total, err := col.Find(mydb.Cond{"input": []int{1, 2, 3, 4}}).Count() + s.NoError(err) + s.Equal(uint64(3), total) + } + + // Find() with implicit NOT IN + { + total, err := col.Find(mydb.Cond{"input NOT IN": []int{1, 2, 3, 4}}).Count() + s.NoError(err) + s.Equal(uint64(3), total) + } + } + + var items []fibonacci + err = res.All(&items) + s.NoError(err) + + for _, item := range items { + switch item.Input { + case 0: + case 1: + case 2: + case 4: + case 8: + case 9: + s.Equal(fib(item.Input), item.Output) + default: + s.T().Errorf(`Unexpected item: %v`, item) + } + } + + err = res.Close() + s.NoError(err) +} + +func (s *GenericTestSuite) TestOddEven() { + sess := s.Session() + + col := sess.Collection("is_even") + + // Adding some items. + var i int + for i = 1; i < 100; i++ { + item := oddEven{Input: i, IsEven: even(i)} + _, err := col.Insert(item) + s.NoError(err) + } + + // Retrieving items + res := col.Find(mydb.Cond{"is_even": true}) + + var item oddEven + for res.Next(&item) { + s.Zero(item.Input % 2) + } + + err := res.Err() + s.NoError(err) + + err = res.Delete() + s.NoError(err) + + // Testing named inputs (using tags). + res = col.Find() + + var item2 struct { + Value uint `db:"input" bson:"input"` // The "bson" tag is required by mgo. + } + for res.Next(&item2) { + s.NotZero(item2.Value % 2) + } + err = res.Err() + s.NoError(err) + + // Testing inline tag. + res = col.Find() + + var item3 struct { + OddEven oddEven `db:",inline" bson:",inline"` + } + for res.Next(&item3) { + s.NotZero(item3.OddEven.Input % 2) + s.NotZero(item3.OddEven.Input) + } + err = res.Err() + s.NoError(err) + + // Testing inline tag. + type OddEven oddEven + res = col.Find() + + var item31 struct { + OddEven `db:",inline" bson:",inline"` + } + for res.Next(&item31) { + s.NotZero(item31.Input % 2) + s.NotZero(item31.Input) + } + s.NoError(res.Err()) + + // Testing omision tag. + res = col.Find() + + var item4 struct { + Value uint `db:"-"` + } + for res.Next(&item4) { + s.Zero(item4.Value) + } + s.NoError(res.Err()) +} + +func (s *GenericTestSuite) TestExplicitAndDefaultMapping() { + var err error + var res mydb.Result + + var testE mapE + var testN mapN + + sess := s.Session() + + col := sess.Collection("CaSe_TesT") + + if err = col.Truncate(); err != nil { + if s.Adapter() != "mongo" { + s.NoError(err) + } + } + + // Testing explicit mapping. + testE = mapE{ + CaseTest: "Hello!", + } + + _, err = col.Insert(testE) + s.NoError(err) + + res = col.Find(mydb.Cond{"case_test": "Hello!"}) + if s.Adapter() == "ql" { + res = res.Select("id() as id", "case_test") + } + + err = res.One(&testE) + s.NoError(err) + + if s.Adapter() == "mongo" { + s.True(testE.MongoID.Valid()) + } else { + s.NotZero(testE.ID) + } + + // Testing default mapping. + testN = mapN{ + Case_TEST: "World!", + } + + _, err = col.Insert(testN) + s.NoError(err) + + if s.Adapter() == `mongo` { + res = col.Find(mydb.Cond{"case_test": "World!"}) + } else { + res = col.Find(mydb.Cond{"case_test": "World!"}) + } + + if s.Adapter() == `ql` { + res = res.Select(`id() as id`, `case_test`) + } + + err = res.One(&testN) + s.NoError(err) + + if s.Adapter() == `mongo` { + s.True(testN.MongoID.Valid()) + } else { + s.NotZero(testN.ID) + } +} + +func (s *GenericTestSuite) TestComparisonOperators() { + sess := s.Session() + + birthdays := sess.Collection("birthdays") + err := birthdays.Truncate() + if err != nil { + if s.Adapter() != "mongo" { + s.NoError(err) + } + } + + // Insert data for testing + birthdaysDataset := []birthday{ + { + Name: "Marie Smith", + Born: time.Date(1956, time.August, 5, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Peter", + Born: time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Eve Smith", + Born: time.Date(1911, time.February, 8, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Alex López", + Born: time.Date(2001, time.May, 5, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Rose Smith", + Born: time.Date(1944, time.December, 9, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Daria López", + Born: time.Date(1923, time.March, 23, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "", + Born: time.Date(1945, time.December, 1, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Colin", + Born: time.Date(2010, time.May, 6, 0, 0, 0, 0, defaultTimeLocation), + }, + } + for _, birthday := range birthdaysDataset { + _, err := birthdays.Insert(birthday) + s.NoError(err) + } + + // Test: equal + { + var item birthday + err := birthdays.Find(mydb.Cond{ + "name": mydb.Eq("Colin"), + }).One(&item) + s.NoError(err) + s.NotNil(item) + + s.Equal("Colin", item.Name) + } + + // Test: not equal + { + var item birthday + err := birthdays.Find(mydb.Cond{ + "name": mydb.NotEq("Colin"), + }).One(&item) + s.NoError(err) + s.NotNil(item) + + s.NotEqual("Colin", item.Name) + } + + // Test: greater than + { + var items []birthday + ref := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation) + err := birthdays.Find(mydb.Cond{ + "born": mydb.Gt(ref), + }).All(&items) + s.NoError(err) + s.NotZero(len(items)) + s.Equal(2, len(items)) + for _, item := range items { + s.True(item.Born.After(ref)) + } + } + + // Test: less than + { + var items []birthday + ref := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation) + err := birthdays.Find(mydb.Cond{ + "born": mydb.Lt(ref), + }).All(&items) + s.NoError(err) + s.NotZero(len(items)) + s.Equal(5, len(items)) + for _, item := range items { + s.True(item.Born.Before(ref)) + } + } + + // Test: greater than or equal to + { + var items []birthday + ref := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation) + err := birthdays.Find(mydb.Cond{ + "born": mydb.Gte(ref), + }).All(&items) + s.NoError(err) + s.NotZero(len(items)) + s.Equal(3, len(items)) + for _, item := range items { + s.True(item.Born.After(ref) || item.Born.Equal(ref)) + } + } + + // Test: less than or equal to + { + var items []birthday + ref := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation) + err := birthdays.Find(mydb.Cond{ + "born": mydb.Lte(ref), + }).All(&items) + s.NoError(err) + s.NotZero(len(items)) + s.Equal(6, len(items)) + for _, item := range items { + s.True(item.Born.Before(ref) || item.Born.Equal(ref)) + } + } + + // Test: between + { + var items []birthday + dateA := time.Date(1911, time.February, 8, 0, 0, 0, 0, defaultTimeLocation) + dateB := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation) + err := birthdays.Find(mydb.Cond{ + "born": mydb.Between(dateA, dateB), + }).All(&items) + s.NoError(err) + s.Equal(6, len(items)) + for _, item := range items { + s.True(item.Born.After(dateA) || item.Born.Equal(dateA)) + s.True(item.Born.Before(dateB) || item.Born.Equal(dateB)) + } + } + + // Test: not between + { + var items []birthday + dateA := time.Date(1911, time.February, 8, 0, 0, 0, 0, defaultTimeLocation) + dateB := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation) + err := birthdays.Find(mydb.Cond{ + "born": mydb.NotBetween(dateA, dateB), + }).All(&items) + s.NoError(err) + s.Equal(2, len(items)) + for _, item := range items { + s.False(item.Born.Before(dateA) || item.Born.Equal(dateA)) + s.False(item.Born.Before(dateB) || item.Born.Equal(dateB)) + } + } + + // Test: in + { + var items []birthday + names := []interface{}{"Peter", "Eve Smith", "Daria López", "Alex López"} + err := birthdays.Find(mydb.Cond{ + "name": mydb.In(names...), + }).All(&items) + s.NoError(err) + s.Equal(4, len(items)) + for _, item := range items { + inArray := false + for _, name := range names { + if name == item.Name { + inArray = true + } + } + s.True(inArray) + } + } + + // Test: not in + { + var items []birthday + names := []interface{}{"Peter", "Eve Smith", "Daria López", "Alex López"} + err := birthdays.Find(mydb.Cond{ + "name": mydb.NotIn(names...), + }).All(&items) + s.NoError(err) + s.Equal(4, len(items)) + for _, item := range items { + inArray := false + for _, name := range names { + if name == item.Name { + inArray = true + } + } + s.False(inArray) + } + } + + // Test: not in + { + var items []birthday + names := []interface{}{"Peter", "Eve Smith", "Daria López", "Alex López"} + err := birthdays.Find(mydb.Cond{ + "name": mydb.NotIn(names...), + }).All(&items) + s.NoError(err) + s.Equal(4, len(items)) + for _, item := range items { + inArray := false + for _, name := range names { + if name == item.Name { + inArray = true + } + } + s.False(inArray) + } + } + + // Test: is and is not + { + var items []birthday + err := birthdays.Find(mydb.And( + mydb.Cond{"name": mydb.Is(nil)}, + mydb.Cond{"name": mydb.IsNot(nil)}, + )).All(&items) + s.NoError(err) + s.Equal(0, len(items)) + } + + // Test: is nil + { + var items []birthday + err := birthdays.Find(mydb.And( + mydb.Cond{"born_ut": mydb.IsNull()}, + )).All(&items) + s.NoError(err) + s.Equal(8, len(items)) + } + + // Test: like and not like + { + var items []birthday + var q mydb.Result + + switch s.Adapter() { + case "ql", "mongo": + q = birthdays.Find(mydb.And( + mydb.Cond{"name": mydb.Like(".*ari.*")}, + mydb.Cond{"name": mydb.NotLike(".*Smith")}, + )) + default: + q = birthdays.Find(mydb.And( + mydb.Cond{"name": mydb.Like("%ari%")}, + mydb.Cond{"name": mydb.NotLike("%Smith")}, + )) + } + + err := q.All(&items) + s.NoError(err) + s.Equal(1, len(items)) + + s.Equal("Daria López", items[0].Name) + } + + if s.Adapter() != "sqlite" && s.Adapter() != "mssql" { + // Test: regexp + { + var items []birthday + err := birthdays.Find(mydb.And( + mydb.Cond{"name": mydb.RegExp("^[D|C|M]")}, + )).OrderBy("name").All(&items) + s.NoError(err) + s.Equal(3, len(items)) + + s.Equal("Colin", items[0].Name) + s.Equal("Daria López", items[1].Name) + s.Equal("Marie Smith", items[2].Name) + } + + // Test: not regexp + { + var items []birthday + names := []string{"Daria López", "Colin", "Marie Smith"} + err := birthdays.Find(mydb.And( + mydb.Cond{"name": mydb.NotRegExp("^[D|C|M]")}, + )).OrderBy("name").All(&items) + s.NoError(err) + s.Equal(5, len(items)) + + for _, item := range items { + for _, name := range names { + s.NotEqual(item.Name, name) + } + } + } + } + + // Test: after + { + ref := time.Date(1944, time.December, 9, 0, 0, 0, 0, defaultTimeLocation) + var items []birthday + err := birthdays.Find(mydb.Cond{ + "born": mydb.After(ref), + }).All(&items) + s.NoError(err) + s.Equal(5, len(items)) + } + + // Test: on or after + { + ref := time.Date(1944, time.December, 9, 0, 0, 0, 0, defaultTimeLocation) + var items []birthday + err := birthdays.Find(mydb.Cond{ + "born": mydb.OnOrAfter(ref), + }).All(&items) + s.NoError(err) + s.Equal(6, len(items)) + } + + // Test: before + { + ref := time.Date(1944, time.December, 9, 0, 0, 0, 0, defaultTimeLocation) + var items []birthday + err := birthdays.Find(mydb.Cond{ + "born": mydb.Before(ref), + }).All(&items) + s.NoError(err) + s.Equal(2, len(items)) + } + + // Test: on or before + { + ref := time.Date(1944, time.December, 9, 0, 0, 0, 0, defaultTimeLocation) + var items []birthday + err := birthdays.Find(mydb.Cond{ + "born": mydb.OnOrBefore(ref), + }).All(&items) + s.NoError(err) + s.Equal(3, len(items)) + } +} diff --git a/internal/testsuite/record_suite.go b/internal/testsuite/record_suite.go new file mode 100644 index 0000000..89cddc4 --- /dev/null +++ b/internal/testsuite/record_suite.go @@ -0,0 +1,428 @@ +package testsuite + +import ( + "context" + "database/sql" + "fmt" + "time" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" + "github.com/stretchr/testify/suite" +) + +type AccountsStore struct { + mydb.Collection +} + +type UsersStore struct { + mydb.Collection +} + +type LogsStore struct { + mydb.Collection +} + +func Accounts(sess mydb.Session) mydb.Store { + return &AccountsStore{sess.Collection("accounts")} +} + +func Users(sess mydb.Session) *UsersStore { + return &UsersStore{sess.Collection("users")} +} + +func Logs(sess mydb.Session) *LogsStore { + return &LogsStore{sess.Collection("logs")} +} + +type Log struct { + ID uint64 `db:"id,omitempty"` + Message string `db:"message"` +} + +func (*Log) Store(sess mydb.Session) mydb.Store { + return Logs(sess) +} + +var _ = mydb.Store(&LogsStore{}) + +type Account struct { + ID uint64 `db:"id,omitempty"` + Name string `db:"name"` + Disabled bool `db:"disabled"` + CreatedAt *time.Time `db:"created_at,omitempty"` +} + +func (*Account) Store(sess mydb.Session) mydb.Store { + return Accounts(sess) +} + +func (account *Account) AfterCreate(sess mydb.Session) error { + message := fmt.Sprintf("Account %q was created.", account.Name) + return sess.Save(&Log{Message: message}) +} + +type User struct { + ID uint64 `db:"id,omitempty"` + AccountID uint64 `db:"account_id"` + Username string `db:"username"` +} + +func (user *User) AfterCreate(sess mydb.Session) error { + message := fmt.Sprintf("User %q was created.", user.Username) + return sess.Save(&Log{Message: message}) +} + +func (*User) Store(sess mydb.Session) mydb.Store { + return Users(sess) +} + +type RecordTestSuite struct { + suite.Suite + Helper +} + +func (s *RecordTestSuite) AfterTest(suiteName, testName string) { + err := s.TearDown() + s.NoError(err) +} + +func (s *RecordTestSuite) BeforeTest(suiteName, testName string) { + err := s.TearUp() + s.NoError(err) + + sess := s.Helper.Session() + + cols, err := sess.Collections() + s.NoError(err) + + for i := range cols { + err = cols[i].Truncate() + s.NoError(err) + } +} + +func (s *RecordTestSuite) TestFindOne() { + var err error + sess := s.Session() + + user := User{Username: "jose"} + err = sess.Save(&user) + s.NoError(err) + + s.NotZero(user.ID) + userID := user.ID + + user = User{} + err = Users(sess).Find(userID).One(&user) + s.NoError(err) + s.Equal("jose", user.Username) + + user = User{} + err = sess.Get(&user, mydb.Cond{"username": "jose"}) + s.NoError(err) + s.Equal("jose", user.Username) + + user.Username = "Catalina" + err = sess.Save(&user) + s.NoError(err) + + user = User{} + err = sess.Get(&user, userID) + s.NoError(err) + s.Equal("Catalina", user.Username) + + err = sess.Delete(&user) + s.NoError(err) + + err = sess.Get(&user, userID) + s.Error(err) + + err = sess.Collection("users"). + Find(userID). + One(&user) + s.Error(err) +} + +func (s *RecordTestSuite) TestAccounts() { + sess := s.Session() + + user := User{Username: "peter"} + + err := sess.Save(&user) + s.NoError(err) + + user = User{Username: "peter"} + err = sess.Save(&user) + s.Error(err, "username should be unique") + + account1 := Account{Name: "skywalker"} + err = sess.Save(&account1) + s.NoError(err) + + account2 := Account{} + err = sess.Get(&account2, account1.ID) + + s.NoError(err) + s.Equal(account1.Name, account2.Name) + + var account3 Account + err = sess.Get(&account3, account1.ID) + + s.NoError(err) + s.Equal(account1.Name, account3.Name) + + var a Account + err = sess.Get(&a, account1.ID) + s.NoError(err) + s.NotNil(a) + + account1.Disabled = true + err = sess.Save(&account1) + s.NoError(err) + + count, err := Accounts(sess).Count() + s.NoError(err) + s.Equal(uint64(1), count) + + err = sess.Delete(&account1) + s.NoError(err) + + count, err = Accounts(sess).Find().Count() + s.NoError(err) + s.Zero(count) +} + +func (s *RecordTestSuite) TestDelete() { + sess := s.Session() + + account := Account{Name: "Pressly"} + err := sess.Save(&account) + s.NoError(err) + s.NotZero(account.ID) + + // Delete by query -- without callbacks + err = Accounts(sess). + Find(account.ID). + Delete() + s.NoError(err) + + count, err := Accounts(sess).Find(account.ID).Count() + s.Zero(count) + s.NoError(err) +} + +func (s *RecordTestSuite) TestSlices() { + sess := s.Session() + + err := sess.Save(&Account{Name: "Apple"}) + s.NoError(err) + + err = sess.Save(&Account{Name: "Google"}) + s.NoError(err) + + var accounts []*Account + err = Accounts(sess). + Find(mydb.Cond{}). + All(&accounts) + s.NoError(err) + s.Len(accounts, 2) +} + +func (s *RecordTestSuite) TestSelectOnlyIDs() { + sess := s.Session() + + err := sess.Save(&Account{Name: "Apple"}) + s.NoError(err) + + err = sess.Save(&Account{Name: "Google"}) + s.NoError(err) + + var ids []struct { + Id int64 `db:"id"` + } + + err = Accounts(sess). + Find(). + Select("id").All(&ids) + s.NoError(err) + s.Len(ids, 2) + s.NotEmpty(ids[0]) +} + +func (s *RecordTestSuite) TestTx() { + sess := s.Session() + + user := User{Username: "peter"} + err := sess.Save(&user) + s.NoError(err) + + // This transaction should fail because user is a UNIQUE value and we already + // have a "peter". + err = sess.Tx(func(tx mydb.Session) error { + return tx.Save(&User{Username: "peter"}) + }) + s.Error(err) + + // This transaction should fail because user is a UNIQUE value and we already + // have a "peter". + err = sess.Tx(func(tx mydb.Session) error { + return tx.Save(&User{Username: "peter"}) + }) + s.Error(err) + + // This transaction will have no errors, but we'll produce one in order for + // it to rollback at the last moment. + err = sess.Tx(func(tx mydb.Session) error { + if err := tx.Save(&User{Username: "Joe"}); err != nil { + return err + } + + if err := tx.Save(&User{Username: "Cool"}); err != nil { + return err + } + + return fmt.Errorf("Rolling back for no reason.") + }) + s.Error(err) + + // Attempt to add two new unique values, if the transaction above had not + // been rolled back this transaction will fail. + err = sess.Tx(func(tx mydb.Session) error { + if err := tx.Save(&User{Username: "Joe"}); err != nil { + return err + } + + if err := tx.Save(&User{Username: "Cool"}); err != nil { + return err + } + + return nil + }) + s.NoError(err) + + // If the transaction above was successful, this one will fail. + err = sess.Tx(func(tx mydb.Session) error { + if err := tx.Save(&User{Username: "Joe"}); err != nil { + return err + } + + if err := tx.Save(&User{Username: "Cool"}); err != nil { + return err + } + + return nil + }) + s.Error(err) +} + +func (s *RecordTestSuite) TestInheritedTx() { + sess := s.Session() + + sqlDB := sess.Driver().(*sql.DB) + + user := User{Username: "peter"} + err := sess.Save(&user) + s.NoError(err) + + // Create a transaction + sqlTx, err := sqlmydb.Begin() + s.NoError(err) + + // And pass that transaction to upper/db, this whole session is a transaction. + upperTx, err := sqlbuilder.BindTx(s.Adapter(), sqlTx) + s.NoError(err) + + // Should fail because user is a UNIQUE value and we already have a "peter". + err = upperTx.Save(&User{Username: "peter"}) + s.Error(err) + + // The transaction is controlled outside upper/mydb. + err = sqlTx.Rollback() + s.NoError(err) + + // The sqlTx is worthless now. + err = upperTx.Save(&User{Username: "peter-2"}) + s.Error(err) + + // But we can create a new one. + sqlTx, err = sqlmydb.Begin() + s.NoError(err) + s.NotNil(sqlTx) + + // And create another session. + upperTx, err = sqlbuilder.BindTx(s.Adapter(), sqlTx) + s.NoError(err) + + // Adding two new values. + err = upperTx.Save(&User{Username: "Joe-2"}) + s.NoError(err) + + err = upperTx.Save(&User{Username: "Cool-2"}) + s.NoError(err) + + // And a value that is going to be rolled back. + err = upperTx.Save(&Account{Name: "Rolled back"}) + s.NoError(err) + + // This session happens to be a transaction, let's rollback everything. + err = sqlTx.Rollback() + s.NoError(err) + + // Start again. + sqlTx, err = sqlmydb.Begin() + s.NoError(err) + + tx, err := sqlbuilder.BindTx(s.Adapter(), sqlTx) + s.NoError(err) + + // Attempt to add two unique values. + err = tx.Save(&User{Username: "Joe-2"}) + s.NoError(err) + + err = tx.Save(&User{Username: "Cool-2"}) + s.NoError(err) + + // And a value that is going to be commited. + err = tx.Save(&Account{Name: "Commited!"}) + s.NoError(err) + + // Yes, commit them. + err = sqlTx.Commit() + s.NoError(err) +} + +func (s *RecordTestSuite) TestUnknownCollection() { + var err error + sess := s.Session() + + err = sess.Save(nil) + s.Error(err) + + _, err = sess.Collection("users").Insert(&User{Username: "Foo"}) + s.NoError(err) +} + +func (s *RecordTestSuite) TestContextCanceled() { + var err error + + sess := s.Session() + + err = sess.Collection("users").Truncate() + s.NoError(err) + + { + ctx, cancelFn := context.WithTimeout(context.Background(), time.Minute) + canceledSess := sess.WithContext(ctx) + + cancelFn() + + user := User{Username: "foo"} + err = canceledSess.Save(&user) + s.Error(err) + + c, err := sess.Collection("users").Count() + s.NoError(err) + s.Equal(uint64(0), c) + } +} diff --git a/internal/testsuite/sql_suite.go b/internal/testsuite/sql_suite.go new file mode 100644 index 0000000..b19a1e6 --- /dev/null +++ b/internal/testsuite/sql_suite.go @@ -0,0 +1,1974 @@ +package testsuite + +import ( + "context" + "database/sql" + "errors" + "fmt" + "math/rand" + "strconv" + "strings" + "sync" + "time" + + "git.hexq.cn/tiglog/mydb" + detectrace "github.com/ipfs/go-detect-race" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/suite" +) + +type artistType struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` +} + +type itemWithCompoundKey struct { + Code string `db:"code"` + UserID string `db:"user_id"` + SomeVal string `db:"some_val"` +} + +type customType struct { + Val []byte +} + +type artistWithCustomType struct { + Custom customType `db:"name"` +} + +func (f customType) String() string { + return fmt.Sprintf("foo: %s", string(f.Val)) +} + +func (f customType) MarshalDB() (interface{}, error) { + return f.String(), nil +} + +func (f *customType) UnmarshalDB(in interface{}) error { + switch t := in.(type) { + case []byte: + f.Val = t + case string: + f.Val = []byte(t) + } + return nil +} + +var ( + _ = mydb.Marshaler(&customType{}) + _ = mydb.Unmarshaler(&customType{}) +) + +type SQLTestSuite struct { + suite.Suite + + Helper +} + +func (s *SQLTestSuite) AfterTest(suiteName, testName string) { + err := s.TearDown() + s.NoError(err) +} + +func (s *SQLTestSuite) BeforeTest(suiteName, testName string) { + err := s.TearUp() + s.NoError(err) + + sess := s.Session() + + // Creating test data + artist := sess.Collection("artist") + + artistNames := []string{"Ozzie", "Flea", "Slash", "Chrono"} + for _, artistName := range artistNames { + _, err := artist.Insert(map[string]string{ + "name": artistName, + }) + s.NoError(err) + } +} + +func (s *SQLTestSuite) TestPreparedStatementsCache() { + sess := s.Session() + + sess.SetPreparedStatementCache(true) + defer sess.SetPreparedStatementCache(false) + + var tMu sync.Mutex + tFatal := func(err error) { + tMu.Lock() + defer tMu.Unlock() + + s.T().Errorf("tmu: %v", err) + } + + // This limit was chosen because, by default, MySQL accepts 16k statements + // and dies. See https://github.com/upper/db/issues/287 + limit := 20000 + + if detectrace.WithRace() { + // When running this test under the Go race detector we quickly reach the limit + // of 8128 alive goroutines it can handle, so we set it to a safer number. + // + // Note that in order to fully stress this feature you'll have to run this + // test without the race detector. + limit = 100 + } + + var wg sync.WaitGroup + + for i := 0; i < limit; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + + // This query is different on each iteration and generates a new + // prepared statement everytime it's called. + res := sess.Collection("artist").Find().Select(mydb.Raw(fmt.Sprintf("count(%d) AS c", i))) + + var count map[string]uint64 + err := res.One(&count) + if err != nil { + tFatal(err) + } + }(i) + } + wg.Wait() + + // Concurrent Insert can open many connections on MySQL / PostgreSQL, this + // sets a limit on them. + sess.SetMaxOpenConns(90) + + switch s.Adapter() { + case "ql": + limit = 1000 + } + + for i := 0; i < limit; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + // The same prepared query on every iteration. + _, err := sess.Collection("artist").Insert(artistType{ + Name: fmt.Sprintf("artist-%d", i), + }) + if err != nil { + tFatal(err) + } + }(i) + } + wg.Wait() + + // Insert returning creates a transaction. + for i := 0; i < limit; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + // The same prepared query on every iteration. + artist := artistType{ + Name: fmt.Sprintf("artist-%d", i), + } + err := sess.Collection("artist").InsertReturning(&artist) + if err != nil { + tFatal(err) + } + }(i) + } + wg.Wait() + + // Removing the limit. + sess.SetMaxOpenConns(0) +} + +func (s *SQLTestSuite) TestTruncateAllCollections() { + sess := s.Session() + + collections, err := sess.Collections() + s.NoError(err) + s.True(len(collections) > 0) + + for _, col := range collections { + if ok, _ := col.Exists(); ok { + if err = col.Truncate(); err != nil { + s.NoError(err) + } + } + } +} + +func (s *SQLTestSuite) TestQueryLogger() { + logLevel := mydb.LC().Level() + + mydb.LC().SetLogger(logrus.New()) + mydb.LC().SetLevel(mydb.LogLevelDebug) + + defer func() { + mydb.LC().SetLogger(nil) + mydb.LC().SetLevel(logLevel) + }() + + sess := s.Session() + + _, err := sess.Collection("artist").Find().Count() + s.Equal(nil, err) + + _, err = sess.Collection("artist_x").Find().Count() + s.NotEqual(nil, err) +} + +func (s *SQLTestSuite) TestExpectCursorError() { + sess := s.Session() + + artist := sess.Collection("artist") + + res := artist.Find(-1) + c, err := res.Count() + s.Equal(uint64(0), c) + s.NoError(err) + + var item map[string]interface{} + err = res.One(&item) + s.Error(err) +} + +func (s *SQLTestSuite) TestInsertDefault() { + if s.Adapter() == "ql" { + s.T().Skip("Currently not supported.") + } + + sess := s.Session() + + artist := sess.Collection("artist") + + err := artist.Truncate() + s.NoError(err) + + id, err := artist.Insert(&artistType{}) + s.NoError(err) + s.NotNil(id) + + err = artist.Truncate() + s.NoError(err) + + id, err = artist.Insert(nil) + s.NoError(err) + s.NotNil(id) +} + +func (s *SQLTestSuite) TestInsertReturning() { + sess := s.Session() + + artist := sess.Collection("artist") + + err := artist.Truncate() + s.NoError(err) + + itemMap := map[string]string{ + "name": "Ozzie", + } + s.Zero(itemMap["id"], "Must be zero before inserting") + err = artist.InsertReturning(&itemMap) + s.NoError(err) + s.NotZero(itemMap["id"], "Must not be zero after inserting") + + itemStruct := struct { + ID int `db:"id,omitempty"` + Name string `db:"name"` + }{ + 0, + "Flea", + } + s.Zero(itemStruct.ID, "Must be zero before inserting") + err = artist.InsertReturning(&itemStruct) + s.NoError(err) + s.NotZero(itemStruct.ID, "Must not be zero after inserting") + + count, err := artist.Find().Count() + s.NoError(err) + s.Equal(uint64(2), count, "Expecting 2 elements") + + itemStruct2 := struct { + ID int `db:"id,omitempty"` + Name string `db:"name"` + }{ + 0, + "Slash", + } + s.Zero(itemStruct2.ID, "Must be zero before inserting") + err = artist.InsertReturning(itemStruct2) + s.Error(err, "Should not happen, using a pointer should be enforced") + s.Zero(itemStruct2.ID, "Must still be zero because there was no insertion") + + itemMap2 := map[string]string{ + "name": "Janus", + } + s.Zero(itemMap2["id"], "Must be zero before inserting") + err = artist.InsertReturning(itemMap2) + s.Error(err, "Should not happen, using a pointer should be enforced") + s.Zero(itemMap2["id"], "Must still be zero because there was no insertion") + + // Counting elements, must be exactly 2 elements. + count, err = artist.Find().Count() + s.NoError(err) + s.Equal(uint64(2), count, "Expecting 2 elements") +} + +func (s *SQLTestSuite) TestInsertReturningWithinTransaction() { + sess := s.Session() + + err := sess.Collection("artist").Truncate() + s.NoError(err) + + err = sess.Tx(func(tx mydb.Session) error { + artist := tx.Collection("artist") + + itemMap := map[string]string{ + "name": "Ozzie", + } + s.Zero(itemMap["id"], "Must be zero before inserting") + err = artist.InsertReturning(&itemMap) + s.NoError(err) + s.NotZero(itemMap["id"], "Must not be zero after inserting") + + itemStruct := struct { + ID int `db:"id,omitempty"` + Name string `db:"name"` + }{ + 0, + "Flea", + } + s.Zero(itemStruct.ID, "Must be zero before inserting") + err = artist.InsertReturning(&itemStruct) + s.NoError(err) + s.NotZero(itemStruct.ID, "Must not be zero after inserting") + + count, err := artist.Find().Count() + s.NoError(err) + s.Equal(uint64(2), count, "Expecting 2 elements") + + itemStruct2 := struct { + ID int `db:"id,omitempty"` + Name string `db:"name"` + }{ + 0, + "Slash", + } + s.Zero(itemStruct2.ID, "Must be zero before inserting") + err = artist.InsertReturning(itemStruct2) + s.Error(err, "Should not happen, using a pointer should be enforced") + s.Zero(itemStruct2.ID, "Must still be zero because there was no insertion") + + itemMap2 := map[string]string{ + "name": "Janus", + } + s.Zero(itemMap2["id"], "Must be zero before inserting") + err = artist.InsertReturning(itemMap2) + s.Error(err, "Should not happen, using a pointer should be enforced") + s.Zero(itemMap2["id"], "Must still be zero because there was no insertion") + + // Counting elements, must be exactly 2 elements. + count, err = artist.Find().Count() + s.NoError(err) + s.Equal(uint64(2), count, "Expecting 2 elements") + + return fmt.Errorf("rolling back for no reason") + }) + s.Error(err) + + // Expecting no elements. + count, err := sess.Collection("artist").Find().Count() + s.NoError(err) + s.Equal(uint64(0), count, "Expecting 0 elements, everything was rolled back!") +} + +func (s *SQLTestSuite) TestInsertIntoArtistsTable() { + sess := s.Session() + + artist := sess.Collection("artist") + + err := artist.Truncate() + s.NoError(err) + + itemMap := map[string]string{ + "name": "Ozzie", + } + + record, err := artist.Insert(itemMap) + s.NoError(err) + s.NotNil(record) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + // Attempt to append a struct. + itemStruct := struct { + Name string `db:"name"` + }{ + "Flea", + } + + record, err = artist.Insert(itemStruct) + s.NoError(err) + s.NotNil(record) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + // Attempt to append a tagged struct. + itemStruct2 := struct { + ArtistName string `db:"name"` + }{ + "Slash", + } + + record, err = artist.Insert(&itemStruct2) + s.NoError(err) + s.NotNil(record) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + itemStruct3 := artistType{ + Name: "Janus", + } + record, err = artist.Insert(&itemStruct3) + s.NoError(err) + if s.Adapter() != "ql" { + s.NotZero(record) // QL always inserts an ID. + } + + // Counting elements, must be exactly 4 elements. + count, err := artist.Find().Count() + s.NoError(err) + s.Equal(uint64(4), count) + + count, err = artist.Find(mydb.Cond{"name": mydb.Eq("Ozzie")}).Count() + s.NoError(err) + s.Equal(uint64(1), count) + + count, err = artist.Find("name", "Ozzie").And("name", "Flea").Count() + s.NoError(err) + s.Equal(uint64(0), count) + + count, err = artist.Find(mydb.Or(mydb.Cond{"name": "Ozzie"}, mydb.Cond{"name": "Flea"})).Count() + s.NoError(err) + s.Equal(uint64(2), count) + + count, err = artist.Find(mydb.And(mydb.Cond{"name": "Ozzie"}, mydb.Cond{"name": "Flea"})).Count() + s.NoError(err) + s.Equal(uint64(0), count) + + count, err = artist.Find(mydb.Cond{"name": "Ozzie"}).And(mydb.Cond{"name": "Flea"}).Count() + s.NoError(err) + s.Equal(uint64(0), count) +} + +func (s *SQLTestSuite) TestQueryNonExistentCollection() { + sess := s.Session() + + count, err := sess.Collection("doesnotexist").Find().Count() + s.Error(err) + s.Zero(count) +} + +func (s *SQLTestSuite) TestGetOneResult() { + sess := s.Session() + + artist := sess.Collection("artist") + + for i := 0; i < 5; i++ { + _, err := artist.Insert(map[string]string{ + "name": fmt.Sprintf("Artist %d", i), + }) + s.NoError(err) + } + + // Fetching one struct. + var someArtist artistType + err := artist.Find().Limit(1).One(&someArtist) + s.NoError(err) + + s.NotZero(someArtist.Name) + if s.Adapter() != "ql" { + s.NotZero(someArtist.ID) + } + + // Fetching a pointer to a pointer. + var someArtistObj *artistType + err = artist.Find().Limit(1).One(&someArtistObj) + s.NoError(err) + s.NotZero(someArtist.Name) + if s.Adapter() != "ql" { + s.NotZero(someArtist.ID) + } +} + +func (s *SQLTestSuite) TestGetWithOffset() { + sess := s.Session() + + artist := sess.Collection("artist") + + // Fetching one struct. + var artists []artistType + err := artist.Find().Offset(1).All(&artists) + s.NoError(err) + + s.Equal(3, len(artists)) +} + +func (s *SQLTestSuite) TestGetResultsOneByOne() { + sess := s.Session() + + artist := sess.Collection("artist") + + rowMap := map[string]interface{}{} + + res := artist.Find() + + err := res.Err() + s.NoError(err) + + for res.Next(&rowMap) { + s.NotZero(rowMap["id"]) + s.NotZero(rowMap["name"]) + } + err = res.Err() + s.NoError(err) + + err = res.Close() + s.NoError(err) + + // Dumping into a tagged struct. + rowStruct2 := struct { + Value1 int64 `db:"id"` + Value2 string `db:"name"` + }{} + + res = artist.Find() + + for res.Next(&rowStruct2) { + s.NotZero(rowStruct2.Value1) + s.NotZero(rowStruct2.Value2) + } + err = res.Err() + s.NoError(err) + + err = res.Close() + s.NoError(err) + + // Dumping into a slice of maps. + allRowsMap := []map[string]interface{}{} + + res = artist.Find() + + err = res.All(&allRowsMap) + s.NoError(err) + s.Equal(4, len(allRowsMap)) + + for _, singleRowMap := range allRowsMap { + if fmt.Sprintf("%d", singleRowMap["id"]) == "0" { + s.T().Errorf("Expecting a not null ID.") + } + } + + // Dumping into a slice of structs. + allRowsStruct := []struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + }{} + + res = artist.Find() + + if err = res.All(&allRowsStruct); err != nil { + s.T().Errorf("%v", err) + } + + s.Equal(4, len(allRowsStruct)) + + for _, singleRowStruct := range allRowsStruct { + s.NotZero(singleRowStruct.ID) + } + + // Dumping into a slice of tagged structs. + allRowsStruct2 := []struct { + Value1 int64 `db:"id"` + Value2 string `db:"name"` + }{} + + res = artist.Find() + + err = res.All(&allRowsStruct2) + s.NoError(err) + + s.Equal(4, len(allRowsStruct2)) + + for _, singleRowStruct := range allRowsStruct2 { + s.NotZero(singleRowStruct.Value1) + } +} + +func (s *SQLTestSuite) TestGetAllResults() { + sess := s.Session() + + artist := sess.Collection("artist") + + total, err := artist.Find().Count() + s.NoError(err) + s.NotZero(total) + + // Fetching all artists into struct + artists := []artistType{} + + res := artist.Find() + + err = res.All(&artists) + s.NoError(err) + s.Equal(len(artists), int(total)) + + s.NotZero(artists[0].Name) + s.NotZero(artists[0].ID) + + // Fetching all artists into struct pointers + artistObjs := []*artistType{} + res = artist.Find() + + err = res.All(&artistObjs) + s.NoError(err) + s.Equal(len(artistObjs), int(total)) + + s.NotZero(artistObjs[0].Name) + s.NotZero(artistObjs[0].ID) +} + +func (s *SQLTestSuite) TestInlineStructs() { + type reviewTypeDetails struct { + Name string `db:"name"` + Comments string `db:"comments"` + Created time.Time `db:"created"` + } + + type reviewType struct { + ID int64 `db:"id,omitempty"` + PublicationID int64 `db:"publication_id"` + Details reviewTypeDetails `db:",inline"` + } + + sess := s.Session() + + review := sess.Collection("review") + + err := review.Truncate() + s.NoError(err) + + rec := reviewType{ + PublicationID: 123, + Details: reviewTypeDetails{ + Name: "..name..", + Comments: "..comments..", + }, + } + + testTimeZone := time.UTC + switch s.Adapter() { + case "mysql": // MySQL uses a global time zone + testTimeZone = defaultTimeLocation + } + + createdAt := time.Date(2016, time.January, 1, 2, 3, 4, 0, testTimeZone) + rec.Details.Created = createdAt + + record, err := review.Insert(rec) + s.NoError(err) + s.NotZero(record.ID().(int64)) + + rec.ID = record.ID().(int64) + + var recChk reviewType + res := review.Find() + + err = res.One(&recChk) + s.NoError(err) + + s.Equal(rec, recChk) +} + +func (s *SQLTestSuite) TestUpdate() { + sess := s.Session() + + artist := sess.Collection("artist") + + _, err := artist.Insert(map[string]string{ + "name": "Ozzie", + }) + s.NoError(err) + + // Defining destination struct + value := struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + }{} + + // Getting the first artist. + cond := mydb.Cond{"id !=": mydb.NotEq(0)} + if s.Adapter() == "ql" { + cond = mydb.Cond{"id() !=": 0} + } + res := artist.Find(cond).Limit(1) + + err = res.One(&value) + s.NoError(err) + + res = artist.Find(value.ID) + + // Updating set with a map + rowMap := map[string]interface{}{ + "name": strings.ToUpper(value.Name), + } + + err = res.Update(rowMap) + s.NoError(err) + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying. + s.Equal(value.Name, rowMap["name"]) + + if s.Adapter() != "ql" { + + // Updating using raw + if err = res.Update(map[string]interface{}{"name": mydb.Raw("LOWER(name)")}); err != nil { + s.T().Errorf("%v", err) + } + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying. + s.Equal(value.Name, strings.ToLower(rowMap["name"].(string))) + + // Updating using raw + if err = res.Update(struct { + Name *mydb.RawExpr `db:"name"` + }{mydb.Raw(`UPPER(name)`)}); err != nil { + s.T().Errorf("%v", err) + } + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying. + s.Equal(value.Name, strings.ToUpper(rowMap["name"].(string))) + + // Updating using raw + if err = res.Update(struct { + Name *mydb.FuncExpr `db:"name"` + }{mydb.Func("LOWER", mydb.Raw("name"))}); err != nil { + s.T().Errorf("%v", err) + } + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying. + s.Equal(value.Name, strings.ToLower(rowMap["name"].(string))) + } + + // Updating set with a struct + rowStruct := struct { + Name string `db:"name"` + }{strings.ToLower(value.Name)} + + err = res.Update(rowStruct) + s.NoError(err) + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying + s.Equal(value.Name, rowStruct.Name) + + // Updating set with a tagged struct + rowStruct2 := struct { + Value1 string `db:"name"` + }{"john"} + + err = res.Update(rowStruct2) + s.NoError(err) + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying + s.Equal(value.Name, rowStruct2.Value1) + + // Updating set with a tagged object + rowStruct3 := &struct { + Value1 string `db:"name"` + }{"anderson"} + + err = res.Update(rowStruct3) + s.NoError(err) + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying + s.Equal(value.Name, rowStruct3.Value1) +} + +func (s *SQLTestSuite) TestFunction() { + sess := s.Session() + + rowStruct := struct { + ID int64 + Name string + }{} + + artist := sess.Collection("artist") + + cond := mydb.Cond{"id NOT IN": []int{0, -1}} + if s.Adapter() == "ql" { + cond = mydb.Cond{"id() NOT IN": []int{0, -1}} + } + res := artist.Find(cond) + + err := res.One(&rowStruct) + s.NoError(err) + + total, err := res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + // Testing conditions + cond = mydb.Cond{"id NOT IN": []interface{}{0, -1}} + if s.Adapter() == "ql" { + cond = mydb.Cond{"id() NOT IN": []interface{}{0, -1}} + } + res = artist.Find(cond) + + err = res.One(&rowStruct) + s.NoError(err) + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + res = artist.Find().Select("name") + + var rowMap map[string]interface{} + err = res.One(&rowMap) + s.NoError(err) + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + res = artist.Find().Select("name") + + err = res.One(&rowMap) + s.NoError(err) + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) +} + +func (s *SQLTestSuite) TestNullableFields() { + sess := s.Session() + + type testType struct { + ID int64 `db:"id,omitempty"` + NullStringTest sql.NullString `db:"_string"` + NullInt64Test sql.NullInt64 `db:"_int64"` + NullFloat64Test sql.NullFloat64 `db:"_float64"` + NullBoolTest sql.NullBool `db:"_bool"` + } + + col := sess.Collection(`data_types`) + + err := col.Truncate() + s.NoError(err) + + // Testing insertion of invalid nulls. + test := testType{ + NullStringTest: sql.NullString{String: "", Valid: false}, + NullInt64Test: sql.NullInt64{Int64: 0, Valid: false}, + NullFloat64Test: sql.NullFloat64{Float64: 0.0, Valid: false}, + NullBoolTest: sql.NullBool{Bool: false, Valid: false}, + } + + id, err := col.Insert(testType{}) + s.NoError(err) + + // Testing fetching of invalid nulls. + err = col.Find(id).One(&test) + s.NoError(err) + + s.False(test.NullInt64Test.Valid) + s.False(test.NullFloat64Test.Valid) + s.False(test.NullBoolTest.Valid) + + // Testing insertion of valid nulls. + test = testType{ + NullStringTest: sql.NullString{String: "", Valid: true}, + NullInt64Test: sql.NullInt64{Int64: 0, Valid: true}, + NullFloat64Test: sql.NullFloat64{Float64: 0.0, Valid: true}, + NullBoolTest: sql.NullBool{Bool: false, Valid: true}, + } + + id, err = col.Insert(test) + s.NoError(err) + + // Testing fetching of valid nulls. + err = col.Find(id).One(&test) + s.NoError(err) + + s.True(test.NullInt64Test.Valid) + s.True(test.NullBoolTest.Valid) + s.True(test.NullStringTest.Valid) +} + +func (s *SQLTestSuite) TestGroup() { + sess := s.Session() + + type statsType struct { + Numeric int `db:"numeric"` + Value int `db:"value"` + } + + stats := sess.Collection("stats_test") + + err := stats.Truncate() + s.NoError(err) + + // Adding row append. + for i := 0; i < 100; i++ { + numeric, value := rand.Intn(5), rand.Intn(100) + _, err := stats.Insert(statsType{numeric, value}) + s.NoError(err) + } + + // Testing GROUP BY + res := stats.Find().Select( + "numeric", + mydb.Raw("count(1) AS counter"), + mydb.Raw("sum(value) AS total"), + ).GroupBy("numeric") + + var results []map[string]interface{} + + err = res.All(&results) + s.NoError(err) + + s.Equal(5, len(results)) +} + +func (s *SQLTestSuite) TestInsertAndDelete() { + sess := s.Session() + + artist := sess.Collection("artist") + res := artist.Find() + + total, err := res.Count() + s.NoError(err) + s.Greater(total, uint64(0)) + + err = res.Delete() + s.NoError(err) + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(0), total) +} + +func (s *SQLTestSuite) TestCompositeKeys() { + if s.Adapter() == "ql" { + s.T().Skip("Currently not supported.") + } + + sess := s.Session() + + compositeKeys := sess.Collection("composite_keys") + + { + n := rand.Intn(100000) + + item := itemWithCompoundKey{ + "ABCDEF", + strconv.Itoa(n), + "Some value", + } + + id, err := compositeKeys.Insert(&item) + s.NoError(err) + s.NotZero(id) + + var item2 itemWithCompoundKey + s.NotEqual(item2.SomeVal, item.SomeVal) + + // Finding by ID + err = compositeKeys.Find(id).One(&item2) + s.NoError(err) + + s.Equal(item2.SomeVal, item.SomeVal) + } + + { + n := rand.Intn(100000) + + item := itemWithCompoundKey{ + "ABCDEF", + strconv.Itoa(n), + "Some value", + } + + err := compositeKeys.InsertReturning(&item) + s.NoError(err) + } +} + +// Attempts to test database transactions. +func (s *SQLTestSuite) TestTransactionsAndRollback() { + if s.Adapter() == "ql" { + s.T().Skip("Currently not supported.") + } + + sess := s.Session() + + err := sess.Tx(func(tx mydb.Session) error { + artist := tx.Collection("artist") + err := artist.Truncate() + s.NoError(err) + + _, err = artist.Insert(artistType{1, "First"}) + s.NoError(err) + + return nil + }) + s.NoError(err) + + err = sess.Tx(func(tx mydb.Session) error { + artist := tx.Collection("artist") + + _, err = artist.Insert(artistType{2, "Second"}) + s.NoError(err) + + // Won't fail. + _, err = artist.Insert(artistType{3, "Third"}) + s.NoError(err) + + // Will fail. + _, err = artist.Insert(artistType{1, "Duplicated"}) + s.Error(err) + + return err + }) + s.Error(err) + + // Let's verify we still have one element. + artist := sess.Collection("artist") + + count, err := artist.Find().Count() + s.NoError(err) + s.Equal(uint64(1), count) + + err = sess.Tx(func(tx mydb.Session) error { + artist := tx.Collection("artist") + + // Won't fail. + _, err = artist.Insert(artistType{2, "Second"}) + s.NoError(err) + + // Won't fail. + _, err = artist.Insert(artistType{3, "Third"}) + s.NoError(err) + + return fmt.Errorf("rollback for no reason") + }) + s.Error(err) + + // Let's verify we still have one element. + artist = sess.Collection("artist") + + count, err = artist.Find().Count() + s.NoError(err) + s.Equal(uint64(1), count) + + // Attempt to add some rows. + err = sess.Tx(func(tx mydb.Session) error { + artist = tx.Collection("artist") + + // Won't fail. + _, err = artist.Insert(artistType{2, "Second"}) + s.NoError(err) + + // Won't fail. + _, err = artist.Insert(artistType{3, "Third"}) + s.NoError(err) + + return nil + }) + s.NoError(err) + + // Let's verify we have 3 rows. + artist = sess.Collection("artist") + + count, err = artist.Find().Count() + s.NoError(err) + s.Equal(uint64(3), count) +} + +func (s *SQLTestSuite) TestDataTypes() { + if s.Adapter() == "ql" { + s.T().Skip("Currently not supported.") + } + + type testValuesStruct struct { + Uint uint `db:"_uint"` + Uint8 uint8 `db:"_uint8"` + Uint16 uint16 `db:"_uint16"` + Uint32 uint32 `db:"_uint32"` + Uint64 uint64 `db:"_uint64"` + + Int int `db:"_int"` + Int8 int8 `db:"_int8"` + Int16 int16 `db:"_int16"` + Int32 int32 `db:"_int32"` + Int64 int64 `db:"_int64"` + + Float32 float32 `db:"_float32"` + Float64 float64 `db:"_float64"` + + Bool bool `db:"_bool"` + String string `db:"_string"` + Blob []byte `db:"_blob"` + + Date time.Time `db:"_date"` + DateN *time.Time `db:"_nildate"` + DateP *time.Time `db:"_ptrdate"` + DateD *time.Time `db:"_defaultdate,omitempty"` + Time int64 `db:"_time"` + } + + sess := s.Session() + + // Getting a pointer to the "data_types" collection. + dataTypes := sess.Collection("data_types") + + // Removing all data. + err := dataTypes.Truncate() + s.NoError(err) + + testTimeZone := time.Local + switch s.Adapter() { + case "mysql", "postgresql": // MySQL uses a global time zone + testTimeZone = defaultTimeLocation + } + + ts := time.Date(2011, 7, 28, 1, 2, 3, 0, testTimeZone) + tnz := ts.In(time.UTC) + + switch s.Adapter() { + case "mysql": + // MySQL uses a global timezone + tnz = ts.In(defaultTimeLocation) + } + + testValues := testValuesStruct{ + 1, 1, 1, 1, 1, + -1, -1, -1, -1, -1, + + 1.337, 1.337, + + true, + "Hello world!", + []byte("Hello world!"), + + ts, // Date + nil, // DateN + &tnz, // DateP + nil, // DateD + int64(time.Second * time.Duration(7331)), + } + id, err := dataTypes.Insert(testValues) + s.NoError(err) + s.NotNil(id) + + // Defining our set. + cond := mydb.Cond{"id": id} + if s.Adapter() == "ql" { + cond = mydb.Cond{"id()": id} + } + res := dataTypes.Find(cond) + + count, err := res.Count() + s.NoError(err) + s.NotZero(count) + + // Trying to dump the subject into an empty structure of the same type. + var item testValuesStruct + + err = res.One(&item) + s.NoError(err) + + s.NotNil(item.DateD) + s.NotNil(item.Date) + + // Copy the default date (this value is set by the database) + testValues.DateD = item.DateD + item.Date = item.Date.In(testTimeZone) + + s.Equal(testValues.Date, item.Date) + s.Equal(testValues.DateN, item.DateN) + s.Equal(testValues.DateP, item.DateP) + s.Equal(testValues.DateD, item.DateD) + + // The original value and the test subject must match. + s.Equal(testValues, item) +} + +func (s *SQLTestSuite) TestUpdateWithNullColumn() { + sess := s.Session() + + artist := sess.Collection("artist") + err := artist.Truncate() + s.NoError(err) + + type Artist struct { + ID int64 `db:"id,omitempty"` + Name *string `db:"name"` + } + + name := "José" + id, err := artist.Insert(Artist{0, &name}) + s.NoError(err) + + var item Artist + err = artist.Find(id).One(&item) + s.NoError(err) + + s.NotEqual(nil, item.Name) + s.Equal(name, *item.Name) + + err = artist.Find(id).Update(Artist{Name: nil}) + s.NoError(err) + + var item2 Artist + err = artist.Find(id).One(&item2) + s.NoError(err) + + s.Equal((*string)(nil), item2.Name) +} + +func (s *SQLTestSuite) TestBatchInsert() { + sess := s.Session() + + for batchSize := 0; batchSize < 17; batchSize++ { + err := sess.Collection("artist").Truncate() + s.NoError(err) + + q := sess.SQL().InsertInto("artist").Columns("name") + + switch s.Adapter() { + case "postgresql", "cockroachdb": + q = q.Amend(func(query string) string { + return query + ` ON CONFLICT DO NOTHING` + }) + } + + batch := q.Batch(batchSize) + + totalItems := int(rand.Int31n(21)) + + go func() { + defer batch.Done() + for i := 0; i < totalItems; i++ { + batch.Values(fmt.Sprintf("artist-%d", i)) + } + }() + + err = batch.Wait() + s.NoError(err) + s.NoError(batch.Err()) + + c, err := sess.Collection("artist").Find().Count() + s.NoError(err) + s.Equal(uint64(totalItems), c) + + for i := 0; i < totalItems; i++ { + c, err := sess.Collection("artist").Find(mydb.Cond{"name": fmt.Sprintf("artist-%d", i)}).Count() + s.NoError(err) + s.Equal(uint64(1), c) + } + } +} + +func (s *SQLTestSuite) TestBatchInsertNoColumns() { + sess := s.Session() + + for batchSize := 0; batchSize < 17; batchSize++ { + err := sess.Collection("artist").Truncate() + s.NoError(err) + + batch := sess.SQL().InsertInto("artist").Batch(batchSize) + + totalItems := int(rand.Int31n(21)) + + go func() { + defer batch.Done() + for i := 0; i < totalItems; i++ { + value := struct { + Name string `db:"name"` + }{fmt.Sprintf("artist-%d", i)} + batch.Values(value) + } + }() + + err = batch.Wait() + s.NoError(err) + s.NoError(batch.Err()) + + c, err := sess.Collection("artist").Find().Count() + s.NoError(err) + s.Equal(uint64(totalItems), c) + + for i := 0; i < totalItems; i++ { + c, err := sess.Collection("artist").Find(mydb.Cond{"name": fmt.Sprintf("artist-%d", i)}).Count() + s.NoError(err) + s.Equal(uint64(1), c) + } + } +} + +func (s *SQLTestSuite) TestBatchInsertReturningKeys() { + switch s.Adapter() { + case "postgresql", "cockroachdb": + // pass + default: + s.T().Skip("Currently not supported.") + return + } + + sess := s.Session() + + err := sess.Collection("artist").Truncate() + s.NoError(err) + + batchSize, totalItems := 7, 12 + + batch := sess.SQL().InsertInto("artist").Columns("name").Returning("id").Batch(batchSize) + + go func() { + defer batch.Done() + for i := 0; i < totalItems; i++ { + batch.Values(fmt.Sprintf("artist-%d", i)) + } + }() + + var keyMap []struct { + ID int `db:"id"` + } + for batch.NextResult(&keyMap) { + // Each insertion must produce new keys. + s.True(len(keyMap) > 0) + s.True(len(keyMap) <= batchSize) + + // Find the elements we've just inserted + keys := make([]int, 0, len(keyMap)) + for i := range keyMap { + keys = append(keys, keyMap[i].ID) + } + + // Make sure count matches. + c, err := sess.Collection("artist").Find(mydb.Cond{"id": keys}).Count() + s.NoError(err) + s.Equal(uint64(len(keyMap)), c) + } + s.NoError(batch.Err()) + + // Count all new elements + c, err := sess.Collection("artist").Find().Count() + s.NoError(err) + s.Equal(uint64(totalItems), c) +} + +func (s *SQLTestSuite) TestPaginator() { + sess := s.Session() + + err := sess.Collection("artist").Truncate() + s.NoError(err) + + batch := sess.SQL().InsertInto("artist").Batch(100) + + go func() { + defer batch.Done() + for i := 0; i < 999; i++ { + value := struct { + Name string `db:"name"` + }{fmt.Sprintf("artist-%d", i)} + batch.Values(value) + } + }() + + err = batch.Wait() + s.NoError(err) + s.NoError(batch.Err()) + + q := sess.SQL().SelectFrom("artist") + if s.Adapter() == "ql" { + q = sess.SQL().SelectFrom(sess.SQL().Select("id() AS id", "name").From("artist")) + } + + const pageSize = 13 + cursorColumn := "id" + + paginator := q.Paginate(pageSize) + + var zerothPage []artistType + err = paginator.Page(0).All(&zerothPage) + s.NoError(err) + s.Equal(pageSize, len(zerothPage)) + + var firstPage []artistType + err = paginator.Page(1).All(&firstPage) + s.NoError(err) + s.Equal(pageSize, len(firstPage)) + + s.Equal(zerothPage, firstPage) + + var secondPage []artistType + err = paginator.Page(2).All(&secondPage) + s.NoError(err) + s.Equal(pageSize, len(secondPage)) + + totalPages, err := paginator.TotalPages() + s.NoError(err) + s.NotZero(totalPages) + s.Equal(uint(77), totalPages) + + totalEntries, err := paginator.TotalEntries() + s.NoError(err) + s.NotZero(totalEntries) + s.Equal(uint64(999), totalEntries) + + var lastPage []artistType + err = paginator.Page(totalPages).All(&lastPage) + s.NoError(err) + s.Equal(11, len(lastPage)) + + var beyondLastPage []artistType + err = paginator.Page(totalPages + 1).All(&beyondLastPage) + s.NoError(err) + s.Equal(0, len(beyondLastPage)) + + var hundredthPage []artistType + err = paginator.Page(100).All(&hundredthPage) + s.NoError(err) + s.Equal(0, len(hundredthPage)) + + for i := uint(0); i < totalPages; i++ { + current := paginator.Page(i + 1) + + var items []artistType + err := current.All(&items) + if err != nil { + s.T().Errorf("%v", err) + } + s.NoError(err) + if len(items) < 1 { + s.Equal(totalPages+1, i) + break + } + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", int64(pageSize*int(i)+j)), items[j].Name) + } + } + + paginator = paginator.Cursor(cursorColumn) + { + current := paginator.Page(1) + for i := 0; ; i++ { + var items []artistType + err := current.All(&items) + if err != nil { + s.T().Errorf("%v", err) + } + if len(items) < 1 { + s.Equal(int(totalPages), i) + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", int64(pageSize*int(i)+j)), items[j].Name) + } + current = current.NextPage(items[len(items)-1].ID) + } + } + + { + current := paginator.Page(totalPages) + for i := totalPages; ; i-- { + var items []artistType + + err := current.All(&items) + s.NoError(err) + + if len(items) < 1 { + s.Equal(uint(0), i) + break + } + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", pageSize*int(i-1)+j), items[j].Name) + } + + current = current.PrevPage(items[0].ID) + } + } + + if s.Adapter() == "ql" { + s.T().Skip("Unsupported, see https://github.com/cznic/ql/issues/182") + return + } + + { + result := sess.Collection("artist").Find() + + fifteenResults := 15 + resultPaginator := result.Paginate(uint(fifteenResults)) + + count, err := resultPaginator.TotalPages() + s.Equal(uint(67), count) + s.NoError(err) + + var items []artistType + fifthPage := 5 + err = resultPaginator.Page(uint(fifthPage)).All(&items) + s.NoError(err) + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", int(fifteenResults)*(fifthPage-1)+j), items[j].Name) + } + + resultPaginator = resultPaginator.Cursor(cursorColumn).Page(1) + for i := 0; ; i++ { + var items []artistType + + err = resultPaginator.All(&items) + s.NoError(err) + + if len(items) < 1 { + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", fifteenResults*i+j), items[j].Name) + } + resultPaginator = resultPaginator.NextPage(items[len(items)-1].ID) + } + + resultPaginator = resultPaginator.Cursor(cursorColumn).Page(count) + for i := count; ; i-- { + var items []artistType + + err = resultPaginator.All(&items) + s.NoError(err) + + if len(items) < 1 { + s.Equal(uint(0), i) + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", fifteenResults*(int(i)-1)+j), items[j].Name) + } + resultPaginator = resultPaginator.PrevPage(items[0].ID) + } + } + + { + // Testing page size 0. + paginator := q.Paginate(0) + + totalPages, err := paginator.TotalPages() + s.NoError(err) + s.Equal(uint(1), totalPages) + + totalEntries, err := paginator.TotalEntries() + s.NoError(err) + s.Equal(uint64(999), totalEntries) + + var allItems []artistType + err = paginator.Page(0).All(&allItems) + s.NoError(err) + s.Equal(totalEntries, uint64(len(allItems))) + + } +} + +func (s *SQLTestSuite) TestPaginator_Issue607() { + sess := s.Session() + + err := sess.Collection("artist").Truncate() + s.NoError(err) + + // Add first batch + { + batch := sess.SQL().InsertInto("artist").Batch(50) + + go func() { + defer batch.Done() + for i := 0; i < 49; i++ { + value := struct { + Name string `db:"name"` + }{fmt.Sprintf("artist-1.%d", i)} + batch.Values(value) + } + }() + + err = batch.Wait() + s.NoError(err) + s.NoError(batch.Err()) + } + + artists := []*artistType{} + paginator := sess.SQL().Select("name").From("artist").Paginate(10) + + err = paginator.Page(1).All(&artists) + s.NoError(err) + + { + totalPages, err := paginator.TotalPages() + s.NoError(err) + s.NotZero(totalPages) + s.Equal(uint(5), totalPages) + } + + // Add second batch + { + batch := sess.SQL().InsertInto("artist").Batch(50) + + go func() { + defer batch.Done() + for i := 0; i < 49; i++ { + value := struct { + Name string `db:"name"` + }{fmt.Sprintf("artist-2.%d", i)} + batch.Values(value) + } + }() + + err = batch.Wait() + s.NoError(err) + s.NoError(batch.Err()) + } + + { + totalPages, err := paginator.TotalPages() + s.NoError(err) + s.NotZero(totalPages) + s.Equal(uint(10), totalPages, "expect number of pages to change") + } + + artists = []*artistType{} + + cond := mydb.Cond{"name": mydb.Like("artist-1.%")} + if s.Adapter() == "ql" { + cond = mydb.Cond{"name": mydb.Like("artist-1.")} + } + + paginator = sess.SQL().Select("name").From("artist").Where(cond).Paginate(10) + + err = paginator.Page(1).All(&artists) + s.NoError(err) + + { + totalPages, err := paginator.TotalPages() + s.NoError(err) + s.NotZero(totalPages) + s.Equal(uint(5), totalPages, "expect same 5 pages from the first batch") + } + +} + +func (s *SQLTestSuite) TestSession() { + sess := s.Session() + + var all []map[string]interface{} + + err := sess.Collection("artist").Truncate() + s.NoError(err) + + _, err = sess.SQL().InsertInto("artist").Values(struct { + Name string `db:"name"` + }{"Rinko Kikuchi"}).Exec() + s.NoError(err) + + // Using explicit iterator. + iter := sess.SQL().SelectFrom("artist").Iterator() + err = iter.All(&all) + + s.NoError(err) + s.NotZero(all) + + // Using explicit iterator to fetch one item. + var item map[string]interface{} + iter = sess.SQL().SelectFrom("artist").Iterator() + err = iter.One(&item) + + s.NoError(err) + s.NotZero(item) + + // Using explicit iterator and NextScan. + iter = sess.SQL().SelectFrom("artist").Iterator() + var id int + var name string + + if s.Adapter() == "ql" { + err = iter.NextScan(&name) + id = 1 + } else { + err = iter.NextScan(&id, &name) + } + + s.NoError(err) + s.NotZero(id) + s.NotEmpty(name) + s.NoError(iter.Close()) + + err = iter.NextScan(&id, &name) + s.Error(err) + + // Using explicit iterator and ScanOne. + iter = sess.SQL().SelectFrom("artist").Iterator() + id, name = 0, "" + if s.Adapter() == "ql" { + err = iter.ScanOne(&name) + id = 1 + } else { + err = iter.ScanOne(&id, &name) + } + + s.NoError(err) + s.NotZero(id) + s.NotEmpty(name) + + err = iter.ScanOne(&id, &name) + s.Error(err) + + // Using explicit iterator and Next. + iter = sess.SQL().SelectFrom("artist").Iterator() + + var artist map[string]interface{} + for iter.Next(&artist) { + if s.Adapter() != "ql" { + s.NotZero(artist["id"]) + } + s.NotEmpty(artist["name"]) + } + // We should not have any error after finishing successfully exiting a Next() loop. + s.Empty(iter.Err()) + + for i := 0; i < 5; i++ { + // But we'll get errors if we attempt to continue using Next(). + s.False(iter.Next(&artist)) + s.Error(iter.Err()) + } + + // Using implicit iterator. + q := sess.SQL().SelectFrom("artist") + err = q.All(&all) + + s.NoError(err) + s.NotZero(all) + + err = sess.Tx(func(tx mydb.Session) error { + q := tx.SQL().SelectFrom("artist") + s.NotZero(iter) + + err = q.All(&all) + s.NoError(err) + s.NotZero(all) + + return nil + }) + + s.NoError(err) +} + +func (s *SQLTestSuite) TestExhaustConnectionPool() { + if s.Adapter() == "ql" { + s.T().Skip("Currently not supported.") + return + } + + sess := s.Session() + errRolledBack := errors.New("rolled back") + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + s.T().Logf("Tx %d: Pending", i) + + wg.Add(1) + go func(wg *sync.WaitGroup, i int) { + defer wg.Done() + + // Requesting a new transaction session. + start := time.Now() + s.T().Logf("Tx: %d: NewTx", i) + + expectError := false + if i%2 == 1 { + expectError = true + } + + err := sess.Tx(func(tx mydb.Session) error { + s.T().Logf("Tx %d: OK (time to connect: %v)", i, time.Since(start)) + // Let's suppose that we do a bunch of complex stuff and that the + // transaction lasts 3 seconds. + time.Sleep(time.Second * 3) + + if expectError { + if _, err := tx.SQL().DeleteFrom("artist").Exec(); err != nil { + return err + } + return errRolledBack + } + + var account map[string]interface{} + if err := tx.Collection("artist").Find().One(&account); err != nil { + return err + } + return nil + }) + if expectError { + s.Error(err) + s.True(errors.Is(err, errRolledBack)) + } else { + s.NoError(err) + } + }(&wg, i) + } + + wg.Wait() +} + +func (s *SQLTestSuite) TestCustomType() { + // See https://github.com/upper/db/issues/332 + sess := s.Session() + + artist := sess.Collection("artist") + + err := artist.Truncate() + s.NoError(err) + + id, err := artist.Insert(artistWithCustomType{ + Custom: customType{Val: []byte("some name")}, + }) + s.NoError(err) + s.NotNil(id) + + var bar artistWithCustomType + err = artist.Find(id).One(&bar) + s.NoError(err) + + s.Equal("foo: some name", string(bar.Custom.Val)) +} + +func (s *SQLTestSuite) Test_Issue565() { + s.Session().Collection("birthdays").Insert(&birthday{ + Name: "Lucy", + Born: time.Now(), + }) + + parentCtx := context.WithValue(s.Session().Context(), "carry", 1) + s.NotZero(parentCtx.Value("carry")) + + { + ctx, cancel := context.WithTimeout(parentCtx, time.Nanosecond) + defer cancel() + + sess := s.Session() + + sess = sess.WithContext(ctx) + + var result birthday + err := sess.Collection("birthdays").Find().Select("name").One(&result) + + s.Error(err) + s.Zero(result.Name) + + s.NotZero(ctx.Value("carry")) + } + + { + ctx, cancel := context.WithTimeout(parentCtx, time.Second*10) + cancel() // cancel before passing + + sess := s.Session().WithContext(ctx) + + var result birthday + err := sess.Collection("birthdays").Find().Select("name").One(&result) + + s.Error(err) + s.Zero(result.Name) + + s.NotZero(ctx.Value("carry")) + } + + { + ctx, cancel := context.WithTimeout(parentCtx, time.Second) + defer cancel() + + sess := s.Session().WithContext(ctx) + + var result birthday + err := sess.Collection("birthdays").Find().Select("name").One(&result) + + s.NoError(err) + s.NotZero(result.Name) + + s.NotZero(ctx.Value("carry")) + } +} + +func (s *SQLTestSuite) TestSelectFromSubquery() { + sess := s.Session() + + { + var artists []artistType + q := sess.SQL().SelectFrom( + sess.SQL().SelectFrom("artist").Where(mydb.Cond{ + "name": mydb.IsNotNull(), + }), + ).As("_q") + err := q.All(&artists) + s.NoError(err) + + s.NotZero(len(artists)) + } + + { + var artists []artistType + q := sess.SQL().SelectFrom( + sess.Collection("artist").Find(mydb.Cond{ + "name": mydb.IsNotNull(), + }), + ).As("_q") + err := q.All(&artists) + s.NoError(err) + + s.NotZero(len(artists)) + } + +} diff --git a/internal/testsuite/suite.go b/internal/testsuite/suite.go new file mode 100644 index 0000000..52ae956 --- /dev/null +++ b/internal/testsuite/suite.go @@ -0,0 +1,37 @@ +package testsuite + +import ( + "time" + + "git.hexq.cn/tiglog/mydb" + "github.com/stretchr/testify/suite" +) + +const TimeZone = "Canada/Eastern" + +var defaultTimeLocation, _ = time.LoadLocation(TimeZone) + +type Helper interface { + Session() mydb.Session + + Adapter() string + + TearUp() error + TearDown() error +} + +type Suite struct { + suite.Suite + + Helper +} + +func (s *Suite) AfterTest(suiteName, testName string) { + err := s.TearDown() + s.NoError(err) +} + +func (s *Suite) BeforeTest(suiteName, testName string) { + err := s.TearUp() + s.NoError(err) +} diff --git a/intersection.go b/intersection.go new file mode 100644 index 0000000..dda67a0 --- /dev/null +++ b/intersection.go @@ -0,0 +1,50 @@ +package mydb + +import "git.hexq.cn/tiglog/mydb/internal/adapter" + +// AndExpr represents an expression joined by a logical conjuction (AND). +type AndExpr struct { + *adapter.LogicalExprGroup +} + +// And adds more expressions to the group. +func (a *AndExpr) And(andConds ...LogicalExpr) *AndExpr { + var fn func(*[]LogicalExpr) error + if len(andConds) > 0 { + fn = func(in *[]LogicalExpr) error { + *in = append(*in, andConds...) + return nil + } + } + return &AndExpr{a.LogicalExprGroup.Frame(fn)} +} + +// Empty returns false if the expressions has zero conditions. +func (a *AndExpr) Empty() bool { + return a.LogicalExprGroup.Empty() +} + +// And joins conditions under logical conjunction. Conditions can be +// represented by `db.Cond{}`, `db.Or()` or `db.And()`. +// +// Examples: +// +// // name = "Peter" AND last_name = "Parker" +// db.And( +// db.Cond{"name": "Peter"}, +// db.Cond{"last_name": "Parker "}, +// ) +// +// // (name = "Peter" OR name = "Mickey") AND last_name = "Mouse" +// db.And( +// db.Or( +// db.Cond{"name": "Peter"}, +// db.Cond{"name": "Mickey"}, +// ), +// db.Cond{"last_name": "Mouse"}, +// ) +func And(conds ...LogicalExpr) *AndExpr { + return &AndExpr{adapter.NewLogicalExprGroup(adapter.LogicalOperatorAnd, conds...)} +} + +var _ = adapter.LogicalExpr(&AndExpr{}) diff --git a/iterator.go b/iterator.go new file mode 100644 index 0000000..c833bba --- /dev/null +++ b/iterator.go @@ -0,0 +1,26 @@ +package mydb + +// Iterator provides methods for iterating over query results. +type Iterator interface { + // ResultMapper provides methods to retrieve and map results. + ResultMapper + + // Scan dumps the current result into the given pointer variable pointers. + Scan(dest ...interface{}) error + + // NextScan advances the iterator and performs Scan. + NextScan(dest ...interface{}) error + + // ScanOne advances the iterator, performs Scan and closes the iterator. + ScanOne(dest ...interface{}) error + + // Next dumps the current element into the given destination, which could be + // a pointer to either a map or a struct. + Next(dest ...interface{}) bool + + // Err returns the last error produced by the cursor. + Err() error + + // Close closes the iterator and frees up the cursor. + Close() error +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..cdb9f2e --- /dev/null +++ b/logger.go @@ -0,0 +1,349 @@ +package mydb + +import ( + "context" + "fmt" + "log" + "os" + "regexp" + "runtime" + "strings" + "time" +) + +const ( + fmtLogSessID = `Session ID: %05d` + fmtLogTxID = `Transaction ID: %05d` + fmtLogQuery = `Query: %s` + fmtLogArgs = `Arguments: %#v` + fmtLogRowsAffected = `Rows affected: %d` + fmtLogLastInsertID = `Last insert ID: %d` + fmtLogError = `Error: %v` + fmtLogStack = `Stack: %v` + fmtLogTimeTaken = `Time taken: %0.5fs` + fmtLogContext = `Context: %v` +) + +const ( + maxFrames = 30 + skipFrames = 3 +) + +var ( + reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) +) + +// LogLevel represents a verbosity level for logs +type LogLevel int8 + +// Log levels +const ( + LogLevelTrace LogLevel = -1 + + LogLevelDebug LogLevel = iota + LogLevelInfo + LogLevelWarn + LogLevelError + LogLevelFatal + LogLevelPanic +) + +var logLevels = map[LogLevel]string{ + LogLevelTrace: "TRACE", + LogLevelDebug: "DEBUG", + LogLevelInfo: "INFO", + LogLevelWarn: "WARNING", + LogLevelError: "ERROR", + LogLevelFatal: "FATAL", + LogLevelPanic: "PANIC", +} + +func (ll LogLevel) String() string { + return logLevels[ll] +} + +const ( + defaultLogLevel LogLevel = LogLevelWarn +) + +var defaultLogger Logger = log.New(os.Stdout, "", log.LstdFlags) + +// Logger represents a logging interface that is compatible with the standard +// "log" and with many other logging libraries. +type Logger interface { + Fatal(v ...interface{}) + Fatalf(format string, v ...interface{}) + + Print(v ...interface{}) + Printf(format string, v ...interface{}) + + Panic(v ...interface{}) + Panicf(format string, v ...interface{}) +} + +// LoggingCollector provides different methods for collecting and classifying +// log messages. +type LoggingCollector interface { + Enabled(LogLevel) bool + + Level() LogLevel + + SetLogger(Logger) + SetLevel(LogLevel) + + Trace(v ...interface{}) + Tracef(format string, v ...interface{}) + + Debug(v ...interface{}) + Debugf(format string, v ...interface{}) + + Info(v ...interface{}) + Infof(format string, v ...interface{}) + + Warn(v ...interface{}) + Warnf(format string, v ...interface{}) + + Error(v ...interface{}) + Errorf(format string, v ...interface{}) + + Fatal(v ...interface{}) + Fatalf(format string, v ...interface{}) + + Panic(v ...interface{}) + Panicf(format string, v ...interface{}) +} + +type loggingCollector struct { + level LogLevel + logger Logger +} + +func (c *loggingCollector) Enabled(level LogLevel) bool { + return level >= c.level +} + +func (c *loggingCollector) SetLevel(level LogLevel) { + c.level = level +} + +func (c *loggingCollector) Level() LogLevel { + return c.level +} + +func (c *loggingCollector) Logger() Logger { + if c.logger == nil { + return defaultLogger + } + return c.logger +} + +func (c *loggingCollector) SetLogger(logger Logger) { + c.logger = logger +} + +func (c *loggingCollector) logf(level LogLevel, f string, v ...interface{}) { + if level >= LogLevelPanic { + c.Logger().Panicf(f, v...) + } + if level >= LogLevelFatal { + c.Logger().Fatalf(f, v...) + } + if c.Enabled(level) { + c.Logger().Printf(f, v...) + } +} + +func (c *loggingCollector) log(level LogLevel, v ...interface{}) { + if level >= LogLevelPanic { + c.Logger().Panic(v...) + } + if level >= LogLevelFatal { + c.Logger().Fatal(v...) + } + if c.Enabled(level) { + c.Logger().Print(v...) + } +} + +func (c *loggingCollector) Debugf(format string, v ...interface{}) { + c.logf(LogLevelDebug, format, v...) +} +func (c *loggingCollector) Debug(v ...interface{}) { + c.log(LogLevelDebug, v...) +} + +func (c *loggingCollector) Tracef(format string, v ...interface{}) { + c.logf(LogLevelTrace, format, v...) +} +func (c *loggingCollector) Trace(v ...interface{}) { + c.log(LogLevelDebug, v...) +} + +func (c *loggingCollector) Infof(format string, v ...interface{}) { + c.logf(LogLevelInfo, format, v...) +} +func (c *loggingCollector) Info(v ...interface{}) { + c.log(LogLevelInfo, v...) +} + +func (c *loggingCollector) Warnf(format string, v ...interface{}) { + c.logf(LogLevelWarn, format, v...) +} +func (c *loggingCollector) Warn(v ...interface{}) { + c.log(LogLevelWarn, v...) +} + +func (c *loggingCollector) Errorf(format string, v ...interface{}) { + c.logf(LogLevelError, format, v...) +} +func (c *loggingCollector) Error(v ...interface{}) { + c.log(LogLevelError, v...) +} + +func (c *loggingCollector) Fatalf(format string, v ...interface{}) { + c.logf(LogLevelFatal, format, v...) +} +func (c *loggingCollector) Fatal(v ...interface{}) { + c.log(LogLevelFatal, v...) +} + +func (c *loggingCollector) Panicf(format string, v ...interface{}) { + c.logf(LogLevelPanic, format, v...) +} +func (c *loggingCollector) Panic(v ...interface{}) { + c.log(LogLevelPanic, v...) +} + +var defaultLoggingCollector LoggingCollector = &loggingCollector{ + level: defaultLogLevel, + logger: defaultLogger, +} + +// QueryStatus represents the status of a query after being executed. +type QueryStatus struct { + SessID uint64 + TxID uint64 + + RowsAffected *int64 + LastInsertID *int64 + + RawQuery string + Args []interface{} + + Err error + + Start time.Time + End time.Time + + Context context.Context +} + +func (q *QueryStatus) Query() string { + query := reInvisibleChars.ReplaceAllString(q.RawQuery, " ") + query = strings.TrimSpace(query) + return query +} + +func (q *QueryStatus) Stack() []string { + frames := collectFrames() + lines := make([]string, 0, len(frames)) + + for _, frame := range frames { + lines = append(lines, fmt.Sprintf("%s@%s:%d", frame.Function, frame.File, frame.Line)) + } + return lines +} + +// String returns a formatted log message. +func (q *QueryStatus) String() string { + lines := make([]string, 0, 8) + + if q.SessID > 0 { + lines = append(lines, fmt.Sprintf(fmtLogSessID, q.SessID)) + } + + if q.TxID > 0 { + lines = append(lines, fmt.Sprintf(fmtLogTxID, q.TxID)) + } + + if query := q.RawQuery; query != "" { + lines = append(lines, fmt.Sprintf(fmtLogQuery, q.Query())) + } + + if len(q.Args) > 0 { + lines = append(lines, fmt.Sprintf(fmtLogArgs, q.Args)) + } + + if stack := q.Stack(); len(stack) > 0 { + lines = append(lines, fmt.Sprintf(fmtLogStack, "\n\t"+strings.Join(stack, "\n\t"))) + } + + if q.RowsAffected != nil { + lines = append(lines, fmt.Sprintf(fmtLogRowsAffected, *q.RowsAffected)) + } + if q.LastInsertID != nil { + lines = append(lines, fmt.Sprintf(fmtLogLastInsertID, *q.LastInsertID)) + } + + if q.Err != nil { + lines = append(lines, fmt.Sprintf(fmtLogError, q.Err)) + } + + lines = append(lines, fmt.Sprintf(fmtLogTimeTaken, float64(q.End.UnixNano()-q.Start.UnixNano())/float64(1e9))) + + if q.Context != nil { + lines = append(lines, fmt.Sprintf(fmtLogContext, q.Context)) + } + + return "\t" + strings.Replace(strings.Join(lines, "\n"), "\n", "\n\t", -1) + "\n\n" +} + +// LC returns the logging collector. +func LC() LoggingCollector { + return defaultLoggingCollector +} + +func init() { + if logLevel := strings.ToUpper(os.Getenv("UPPER_DB_LOG")); logLevel != "" { + for ll := range logLevels { + if ll.String() == logLevel { + LC().SetLevel(ll) + break + } + } + } +} + +func collectFrames() []runtime.Frame { + pc := make([]uintptr, maxFrames) + n := runtime.Callers(skipFrames, pc) + if n == 0 { + return nil + } + + pc = pc[:n] + frames := runtime.CallersFrames(pc) + + collectedFrames := make([]runtime.Frame, 0, maxFrames) + discardedFrames := make([]runtime.Frame, 0, maxFrames) + for { + frame, more := frames.Next() + + // collect all frames except those from upper/db and runtime stack + if (strings.Contains(frame.Function, "upper/db") || strings.Contains(frame.Function, "/go/src/")) && !strings.Contains(frame.Function, "test") { + discardedFrames = append(discardedFrames, frame) + } else { + collectedFrames = append(collectedFrames, frame) + } + + if !more { + break + } + } + + if len(collectedFrames) < 1 { + return discardedFrames + } + + return collectedFrames +} diff --git a/logger_test.go b/logger_test.go new file mode 100644 index 0000000..5295d97 --- /dev/null +++ b/logger_test.go @@ -0,0 +1,11 @@ +package mydb + +import ( + "errors" + "testing" +) + +func TestLogger(t *testing.T) { + err := errors.New("fake error") + LC().Error(err) +} diff --git a/marshal.go b/marshal.go new file mode 100644 index 0000000..b9dd64a --- /dev/null +++ b/marshal.go @@ -0,0 +1,16 @@ +package mydb + +// Marshaler is the interface implemented by struct fields that can transform +// themselves into values to be stored in a database. +type Marshaler interface { + // MarshalDB returns the internal database representation of the Go value. + MarshalDB() (interface{}, error) +} + +// Unmarshaler is the interface implemented by struct fields that can transform +// themselves from database values into Go values. +type Unmarshaler interface { + // UnmarshalDB receives an internal database representation of a value and + // transforms it into a Go value. + UnmarshalDB(interface{}) error +} diff --git a/mydb.go b/mydb.go new file mode 100644 index 0000000..2942fe8 --- /dev/null +++ b/mydb.go @@ -0,0 +1,50 @@ +// Package db (or tiglog/mydb) provides an agnostic data access layer to work with +// different databases. +// +// Install tiglog/mydb: +// +// go get git.hexq.cn/tiglog/mydb +// +// Usage +// +// package main +// +// import ( +// "log" +// +// "git.hexq.cn/tiglog/mydb/adapter/postgresql" // Imports the postgresql adapter. +// ) +// +// var settings = postgresql.ConnectionURL{ +// Database: `booktown`, +// Host: `demo.upper.io`, +// User: `demouser`, +// Password: `demop4ss`, +// } +// +// // Book represents a book. +// type Book struct { +// ID uint `db:"id"` +// Title string `db:"title"` +// AuthorID uint `db:"author_id"` +// SubjectID uint `db:"subject_id"` +// } +// +// func main() { +// sess, err := postgresql.Open(settings) +// if err != nil { +// log.Fatal(err) +// } +// defer sess.Close() +// +// var books []Book +// if err := sess.Collection("books").Find().OrderBy("title").All(&books); err != nil { +// log.Fatal(err) +// } +// +// log.Println("Books:") +// for _, book := range books { +// log.Printf("%q (ID: %d)\n", book.Title, book.ID) +// } +// } +package mydb diff --git a/raw.go b/raw.go new file mode 100644 index 0000000..44164fd --- /dev/null +++ b/raw.go @@ -0,0 +1,17 @@ +package mydb + +import "git.hexq.cn/tiglog/mydb/internal/adapter" + +// RawExpr represents a raw (non-filtered) expression. +type RawExpr = adapter.RawExpr + +// Raw marks chunks of data as protected, so they pass directly to the query +// without any filtering. Use with care. +// +// Example: +// +// // SOUNDEX('Hello') +// Raw("SOUNDEX('Hello')") +func Raw(value string, args ...interface{}) *RawExpr { + return adapter.NewRawExpr(value, args) +} diff --git a/readme.adoc b/readme.adoc new file mode 100644 index 0000000..0c91d29 --- /dev/null +++ b/readme.adoc @@ -0,0 +1,18 @@ += 说明 +:author: tiglog +:experimental: +:toc: left +:toclevels: 3 +:toc-title: 目录 +:sectnums: +:icons: font +:!webfonts: +:autofit-option: +:source-highlighter: rouge +:rouge-style: github +:source-linenums-option: +:revdate: 2023-09-18 +:imagesdir: ./img + + +基于 `upper/db` 改造。 diff --git a/record.go b/record.go new file mode 100644 index 0000000..20f3ead --- /dev/null +++ b/record.go @@ -0,0 +1,62 @@ +package mydb + +// Record is the equivalence between concrete database schemas and Go values. +type Record interface { + Store(sess Session) Store +} + +// HasConstraints is an interface for records that defines a Constraints method +// that returns the record's own constraints. +type HasConstraints interface { + Constraints() Cond +} + +// Validator is an interface for records that defines an (optional) Validate +// method that is called before persisting a record (creating or updating). If +// Validate returns an error the current operation is cancelled and rolled +// back. +type Validator interface { + Validate() error +} + +// BeforeCreateHook is an interface for records that defines an BeforeCreate +// method that is called before creating a record. If BeforeCreate returns an +// error the create process is cancelled and rolled back. +type BeforeCreateHook interface { + BeforeCreate(Session) error +} + +// AfterCreateHook is an interface for records that defines an AfterCreate +// method that is called after creating a record. If AfterCreate returns an +// error the create process is cancelled and rolled back. +type AfterCreateHook interface { + AfterCreate(Session) error +} + +// BeforeUpdateHook is an interface for records that defines a BeforeUpdate +// method that is called before updating a record. If BeforeUpdate returns an +// error the update process is cancelled and rolled back. +type BeforeUpdateHook interface { + BeforeUpdate(Session) error +} + +// AfterUpdateHook is an interface for records that defines an AfterUpdate +// method that is called after updating a record. If AfterUpdate returns an +// error the update process is cancelled and rolled back. +type AfterUpdateHook interface { + AfterUpdate(Session) error +} + +// BeforeDeleteHook is an interface for records that defines a BeforeDelete +// method that is called before removing a record. If BeforeDelete returns an +// error the delete process is cancelled and rolled back. +type BeforeDeleteHook interface { + BeforeDelete(Session) error +} + +// AfterDeleteHook is an interface for records that defines a AfterDelete +// method that is called after removing a record. If AfterDelete returns an +// error the delete process is cancelled and rolled back. +type AfterDeleteHook interface { + AfterDelete(Session) error +} diff --git a/result.go b/result.go new file mode 100644 index 0000000..6b962a1 --- /dev/null +++ b/result.go @@ -0,0 +1,193 @@ +package mydb + +import ( + "database/sql/driver" +) + +// Result is an interface that defines methods for result sets. +type Result interface { + + // String returns the SQL statement to be used in the query. + String() string + + // Limit defines the maximum number of results for this set. It only has + // effect on `One()`, `All()` and `Next()`. A negative limit cancels any + // previous limit settings. + Limit(int) Result + + // Offset ignores the first n results. It only has effect on `One()`, `All()` + // and `Next()`. A negative offset cancels any previous offset settings. + Offset(int) Result + + // OrderBy receives one or more field names that define the order in which + // elements will be returned in a query, field names may be prefixed with a + // minus sign (-) indicating descending order, ascending order will be used + // otherwise. + OrderBy(...interface{}) Result + + // Select defines specific columns to be fetched on every column in the + // result set. + Select(...interface{}) Result + + // And adds more filtering conditions on top of the existing constraints. + // + // res := col.Find(...).And(...) + And(...interface{}) Result + + // GroupBy is used to group results that have the same value in the same column + // or columns. + GroupBy(...interface{}) Result + + // Delete deletes all items within the result set. `Offset()` and `Limit()` + // are not honoured by `Delete()`. + Delete() error + + // Update modifies all items within the result set. `Offset()` and `Limit()` + // are not honoured by `Update()`. + Update(interface{}) error + + // Count returns the number of items that match the set conditions. + // `Offset()` and `Limit()` are not honoured by `Count()` + Count() (uint64, error) + + // Exists returns true if at least one item on the collection exists. False + // otherwise. + Exists() (bool, error) + + // Next fetches the next result within the result set and dumps it into the + // given pointer to struct or pointer to map. You must call + // `Close()` after finishing using `Next()`. + Next(ptrToStruct interface{}) bool + + // Err returns the last error that has happened with the result set, nil + // otherwise. + Err() error + + // One fetches the first result within the result set and dumps it into the + // given pointer to struct or pointer to map. The result set is automatically + // closed after picking the element, so there is no need to call Close() + // after using One(). + One(ptrToStruct interface{}) error + + // All fetches all results within the result set and dumps them into the + // given pointer to slice of maps or structs. The result set is + // automatically closed, so there is no need to call Close() after + // using All(). + All(sliceOfStructs interface{}) error + + // Paginate splits the results of the query into pages containing pageSize + // items. When using pagination previous settings for `Limit()` and + // `Offset()` are ignored. Page numbering starts at 1. + // + // Use `Page()` to define the specific page to get results from. + // + // Example: + // + // r = q.Paginate(12) + // + // You can provide constraints an order settings when using pagination: + // + // Example: + // + // res := q.Where(conds).OrderBy("-id").Paginate(12) + // err := res.Page(4).All(&items) + Paginate(pageSize uint) Result + + // Page makes the result set return results only from the page identified by + // pageNumber. Page numbering starts from 1. + // + // Example: + // + // r = q.Paginate(12).Page(4) + Page(pageNumber uint) Result + + // Cursor defines the column that is going to be taken as basis for + // cursor-based pagination. + // + // Example: + // + // a = q.Paginate(10).Cursor("id") + // b = q.Paginate(12).Cursor("-id") + // + // You can set "" as cursorColumn to disable cursors. + Cursor(cursorColumn string) Result + + // NextPage returns the next results page according to the cursor. It expects + // a cursorValue, which is the value the cursor column had on the last item + // of the current result set (lower bound). + // + // Example: + // + // cursor = q.Paginate(12).Cursor("id") + // res = cursor.NextPage(items[len(items)-1].ID) + // + // Note that `NextPage()` requires a cursor, any column with an absolute + // order (given two values one always precedes the other) can be a cursor. + // + // You can define the pagination order and add constraints to your result: + // + // cursor = q.Where(...).OrderBy("id").Paginate(10).Cursor("id") + // res = cursor.NextPage(lowerBound) + NextPage(cursorValue interface{}) Result + + // PrevPage returns the previous results page according to the cursor. It + // expects a cursorValue, which is the value the cursor column had on the + // fist item of the current result set. + // + // Example: + // + // current = current.PrevPage(items[0].ID) + // + // Note that PrevPage requires a cursor, any column with an absolute order + // (given two values one always precedes the other) can be a cursor. + // + // You can define the pagination order and add constraints to your result: + // + // cursor = q.Where(...).OrderBy("id").Paginate(10).Cursor("id") + // res = cursor.PrevPage(upperBound) + PrevPage(cursorValue interface{}) Result + + // TotalPages returns the total number of pages the result set could produce. + // If no pagination parameters have been set this value equals 1. + TotalPages() (uint, error) + + // TotalEntries returns the total number of matching items in the result set. + TotalEntries() (uint64, error) + + // Close closes the result set and frees all locked resources. + Close() error +} + +// InsertResult provides infomation about an insert operation. +type InsertResult interface { + // ID returns the ID of the newly inserted record. + ID() ID +} + +type insertResult struct { + id interface{} +} + +func (r *insertResult) ID() ID { + return r.id +} + +// ConstraintValue satisfies adapter.ConstraintValuer +func (r *insertResult) ConstraintValue() interface{} { + return r.id +} + +// Value satisfies driver.Valuer +func (r *insertResult) Value() (driver.Value, error) { + return r.id, nil +} + +// NewInsertResult creates an InsertResult +func NewInsertResult(id interface{}) InsertResult { + return &insertResult{id: id} +} + +// ID represents a record ID +type ID interface{} + +var _ = driver.Valuer(&insertResult{}) diff --git a/session.go b/session.go new file mode 100644 index 0000000..a49192a --- /dev/null +++ b/session.go @@ -0,0 +1,78 @@ +package mydb + +import ( + "context" + "database/sql" +) + +// Session is an interface that defines methods for database adapters. +type Session interface { + // ConnectionURL returns the DSN that was used to set up the adapter. + ConnectionURL() ConnectionURL + + // Name returns the name of the database. + Name() string + + // Ping returns an error if the DBMS could not be reached. + Ping() error + + // Collection receives a table name and returns a collection reference. The + // information retrieved from a collection is cached. + Collection(name string) Collection + + // Collections returns a collection reference of all non system tables on the + // database. + Collections() ([]Collection, error) + + // Save creates or updates a record. + Save(record Record) error + + // Get retrieves a record that matches the given condition. + Get(record Record, cond interface{}) error + + // Delete deletes a record. + Delete(record Record) error + + // Reset resets all the caching mechanisms the adapter is using. + Reset() + + // Close terminates the currently active connection to the DBMS and clears + // all caches. + Close() error + + // Driver returns the underlying driver of the adapter as an interface. + // + // In order to actually use the driver, the `interface{}` value needs to be + // casted into the appropriate type. + // + // Example: + // internalSQLDriver := sess.Driver().(*sql.DB) + Driver() interface{} + + // SQL returns a special interface for SQL databases. + SQL() SQL + + // Tx creates a transaction block on the default database context and passes + // it to the function fn. If fn returns no error the transaction is commited, + // else the transaction is rolled back. After being commited or rolled back + // the transaction is closed automatically. + Tx(fn func(sess Session) error) error + + // TxContext creates a transaction block on the given context and passes it to + // the function fn. If fn returns no error the transaction is commited, else + // the transaction is rolled back. After being commited or rolled back the + // transaction is closed automatically. + TxContext(ctx context.Context, fn func(sess Session) error, opts *sql.TxOptions) error + + // Context returns the context used as default for queries on this session + // and for new transactions. If no context has been set, a default + // context.Background() is returned. + Context() context.Context + + // WithContext returns the same session on a different default context. The + // session is identical to the original one in all ways except for the + // context. + WithContext(ctx context.Context) Session + + Settings +} diff --git a/settings.go b/settings.go new file mode 100644 index 0000000..3bb2172 --- /dev/null +++ b/settings.go @@ -0,0 +1,179 @@ +package mydb + +import ( + "sync" + "sync/atomic" + "time" +) + +// Settings defines methods to get or set configuration values. +type Settings interface { + // SetPreparedStatementCache enables or disables the prepared statement + // cache. + SetPreparedStatementCache(bool) + + // PreparedStatementCacheEnabled returns true if the prepared statement cache + // is enabled, false otherwise. + PreparedStatementCacheEnabled() bool + + // SetConnMaxLifetime sets the default maximum amount of time a connection + // may be reused. + SetConnMaxLifetime(time.Duration) + + // ConnMaxLifetime returns the default maximum amount of time a connection + // may be reused. + ConnMaxLifetime() time.Duration + + // SetConnMaxIdleTime sets the default maximum amount of time a connection + // may remain idle. + SetConnMaxIdleTime(time.Duration) + + // ConnMaxIdleTime returns the default maximum amount of time a connection + // may remain idle. + ConnMaxIdleTime() time.Duration + + // SetMaxIdleConns sets the default maximum number of connections in the idle + // connection pool. + SetMaxIdleConns(int) + + // MaxIdleConns returns the default maximum number of connections in the idle + // connection pool. + MaxIdleConns() int + + // SetMaxOpenConns sets the default maximum number of open connections to the + // database. + SetMaxOpenConns(int) + + // MaxOpenConns returns the default maximum number of open connections to the + // database. + MaxOpenConns() int + + // SetMaxTransactionRetries sets the number of times a transaction can + // be retried. + SetMaxTransactionRetries(int) + + // MaxTransactionRetries returns the maximum number of times a + // transaction can be retried. + MaxTransactionRetries() int +} + +type settings struct { + sync.RWMutex + + preparedStatementCacheEnabled uint32 + + connMaxLifetime time.Duration + connMaxIdleTime time.Duration + maxOpenConns int + maxIdleConns int + + maxTransactionRetries int +} + +func (c *settings) binaryOption(opt *uint32) bool { + return atomic.LoadUint32(opt) == 1 +} + +func (c *settings) setBinaryOption(opt *uint32, value bool) { + if value { + atomic.StoreUint32(opt, 1) + return + } + atomic.StoreUint32(opt, 0) +} + +func (c *settings) SetPreparedStatementCache(value bool) { + c.setBinaryOption(&c.preparedStatementCacheEnabled, value) +} + +func (c *settings) PreparedStatementCacheEnabled() bool { + return c.binaryOption(&c.preparedStatementCacheEnabled) +} + +func (c *settings) SetConnMaxLifetime(t time.Duration) { + c.Lock() + c.connMaxLifetime = t + c.Unlock() +} + +func (c *settings) ConnMaxLifetime() time.Duration { + c.RLock() + defer c.RUnlock() + return c.connMaxLifetime +} + +func (c *settings) SetConnMaxIdleTime(t time.Duration) { + c.Lock() + c.connMaxIdleTime = t + c.Unlock() +} + +func (c *settings) ConnMaxIdleTime() time.Duration { + c.RLock() + defer c.RUnlock() + return c.connMaxIdleTime +} + +func (c *settings) SetMaxIdleConns(n int) { + c.Lock() + c.maxIdleConns = n + c.Unlock() +} + +func (c *settings) MaxIdleConns() int { + c.RLock() + defer c.RUnlock() + return c.maxIdleConns +} + +func (c *settings) SetMaxTransactionRetries(n int) { + c.Lock() + c.maxTransactionRetries = n + c.Unlock() +} + +func (c *settings) MaxTransactionRetries() int { + c.RLock() + defer c.RUnlock() + if c.maxTransactionRetries < 1 { + return 1 + } + return c.maxTransactionRetries +} + +func (c *settings) SetMaxOpenConns(n int) { + c.Lock() + c.maxOpenConns = n + c.Unlock() +} + +func (c *settings) MaxOpenConns() int { + c.RLock() + defer c.RUnlock() + return c.maxOpenConns +} + +// NewSettings returns a new settings value prefilled with the current default +// settings. +func NewSettings() Settings { + def := DefaultSettings.(*settings) + return &settings{ + preparedStatementCacheEnabled: def.preparedStatementCacheEnabled, + connMaxLifetime: def.connMaxLifetime, + connMaxIdleTime: def.connMaxIdleTime, + maxIdleConns: def.maxIdleConns, + maxOpenConns: def.maxOpenConns, + maxTransactionRetries: def.maxTransactionRetries, + } +} + +// DefaultSettings provides default global configuration settings for database +// sessions. +var DefaultSettings Settings = &settings{ + preparedStatementCacheEnabled: 0, + connMaxLifetime: time.Duration(0), + connMaxIdleTime: time.Duration(0), + maxIdleConns: 10, + maxOpenConns: 0, + maxTransactionRetries: 1, +} diff --git a/sql.go b/sql.go new file mode 100644 index 0000000..f6934a1 --- /dev/null +++ b/sql.go @@ -0,0 +1,190 @@ +package mydb + +import ( + "context" + "database/sql" +) + +// SQL defines methods that can be used to build a SQL query with chainable +// method calls. +// +// Queries are immutable, so every call to any method will return a new +// pointer, if you want to build a query using variables you need to reassign +// them, like this: +// +// a = builder.Select("name").From("foo") // "a" is created +// +// a.Where(...) // No effect, the value returned from Where is ignored. +// +// a = a.Where(...) // "a" is reassigned and points to a different address. +type SQL interface { + + // Select initializes and returns a Selector, it accepts column names as + // parameters. + // + // The returned Selector does not initially point to any table, a call to + // From() is required after Select() to complete a valid query. + // + // Example: + // + // q := sqlbuilder.Select("first_name", "last_name").From("people").Where(...) + Select(columns ...interface{}) Selector + + // SelectFrom creates a Selector that selects all columns (like SELECT *) + // from the given table. + // + // Example: + // + // q := sqlbuilder.SelectFrom("people").Where(...) + SelectFrom(table ...interface{}) Selector + + // InsertInto prepares and returns an Inserter targeted at the given table. + // + // Example: + // + // q := sqlbuilder.InsertInto("books").Columns(...).Values(...) + InsertInto(table string) Inserter + + // DeleteFrom prepares a Deleter targeted at the given table. + // + // Example: + // + // q := sqlbuilder.DeleteFrom("tasks").Where(...) + DeleteFrom(table string) Deleter + + // Update prepares and returns an Updater targeted at the given table. + // + // Example: + // + // q := sqlbuilder.Update("profile").Set(...).Where(...) + Update(table string) Updater + + // Exec executes a SQL query that does not return any rows, like sql.Exec. + // Queries can be either strings or upper-db statements. + // + // Example: + // + // sqlbuilder.Exec(`INSERT INTO books (title) VALUES("La Ciudad y los Perros")`) + Exec(query interface{}, args ...interface{}) (sql.Result, error) + + // ExecContext executes a SQL query that does not return any rows, like sql.ExecContext. + // Queries can be either strings or upper-db statements. + // + // Example: + // + // sqlbuilder.ExecContext(ctx, `INSERT INTO books (title) VALUES(?)`, "La Ciudad y los Perros") + ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error) + + // Prepare creates a prepared statement for later queries or executions. The + // caller must call the statement's Close method when the statement is no + // longer needed. + Prepare(query interface{}) (*sql.Stmt, error) + + // Prepare creates a prepared statement on the guiven context for later + // queries or executions. The caller must call the statement's Close method + // when the statement is no longer needed. + PrepareContext(ctx context.Context, query interface{}) (*sql.Stmt, error) + + // Query executes a SQL query that returns rows, like sql.Query. Queries can + // be either strings or upper-db statements. + // + // Example: + // + // sqlbuilder.Query(`SELECT * FROM people WHERE name = "Mateo"`) + Query(query interface{}, args ...interface{}) (*sql.Rows, error) + + // QueryContext executes a SQL query that returns rows, like + // sql.QueryContext. Queries can be either strings or upper-db statements. + // + // Example: + // + // sqlbuilder.QueryContext(ctx, `SELECT * FROM people WHERE name = ?`, "Mateo") + QueryContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Rows, error) + + // QueryRow executes a SQL query that returns one row, like sql.QueryRow. + // Queries can be either strings or upper-db statements. + // + // Example: + // + // sqlbuilder.QueryRow(`SELECT * FROM people WHERE name = "Haruki" AND last_name = "Murakami" LIMIT 1`) + QueryRow(query interface{}, args ...interface{}) (*sql.Row, error) + + // QueryRowContext executes a SQL query that returns one row, like + // sql.QueryRowContext. Queries can be either strings or upper-db statements. + // + // Example: + // + // sqlbuilder.QueryRowContext(ctx, `SELECT * FROM people WHERE name = "Haruki" AND last_name = "Murakami" LIMIT 1`) + QueryRowContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Row, error) + + // Iterator executes a SQL query that returns rows and creates an Iterator + // with it. + // + // Example: + // + // sqlbuilder.Iterator(`SELECT * FROM people WHERE name LIKE "M%"`) + Iterator(query interface{}, args ...interface{}) Iterator + + // IteratorContext executes a SQL query that returns rows and creates an Iterator + // with it. + // + // Example: + // + // sqlbuilder.IteratorContext(ctx, `SELECT * FROM people WHERE name LIKE "M%"`) + IteratorContext(ctx context.Context, query interface{}, args ...interface{}) Iterator + + // NewIterator converts a *sql.Rows value into an Iterator. + NewIterator(rows *sql.Rows) Iterator + + // NewIteratorContext converts a *sql.Rows value into an Iterator. + NewIteratorContext(ctx context.Context, rows *sql.Rows) Iterator +} + +// SQLExecer provides methods for executing statements that do not return +// results. +type SQLExecer interface { + // Exec executes a statement and returns sql.Result. + Exec() (sql.Result, error) + + // ExecContext executes a statement and returns sql.Result. + ExecContext(context.Context) (sql.Result, error) +} + +// SQLPreparer provides the Prepare and PrepareContext methods for creating +// prepared statements. +type SQLPreparer interface { + // Prepare creates a prepared statement. + Prepare() (*sql.Stmt, error) + + // PrepareContext creates a prepared statement. + PrepareContext(context.Context) (*sql.Stmt, error) +} + +// SQLGetter provides methods for executing statements that return results. +type SQLGetter interface { + // Query returns *sql.Rows. + Query() (*sql.Rows, error) + + // QueryContext returns *sql.Rows. + QueryContext(context.Context) (*sql.Rows, error) + + // QueryRow returns only one row. + QueryRow() (*sql.Row, error) + + // QueryRowContext returns only one row. + QueryRowContext(ctx context.Context) (*sql.Row, error) +} + +// SQLEngine represents a SQL engine that can execute SQL queries. This is +// compatible with *sql.DB. +type SQLEngine interface { + Exec(string, ...interface{}) (sql.Result, error) + Prepare(string) (*sql.Stmt, error) + Query(string, ...interface{}) (*sql.Rows, error) + QueryRow(string, ...interface{}) *sql.Row + + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + PrepareContext(context.Context, string) (*sql.Stmt, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} diff --git a/store.go b/store.go new file mode 100644 index 0000000..4225df2 --- /dev/null +++ b/store.go @@ -0,0 +1,36 @@ +package mydb + +// Store represents a data store. +type Store interface { + Collection +} + +// StoreSaver is an interface for data stores that defines a Save method that +// has the task of persisting a record. +type StoreSaver interface { + Save(record Record) error +} + +// StoreCreator is an interface for data stores that defines a Create method +// that has the task of creating a new record. +type StoreCreator interface { + Create(record Record) error +} + +// StoreDeleter is an interface for data stores that defines a Delete method +// that has the task of removing a record. +type StoreDeleter interface { + Delete(record Record) error +} + +// StoreUpdater is an interface for data stores that defines a Update method +// that has the task of updating a record. +type StoreUpdater interface { + Update(record Record) error +} + +// StoreGetter is an interface for data stores that defines a Get method that +// has the task of retrieving a record. +type StoreGetter interface { + Get(record Record, id interface{}) error +} diff --git a/union.go b/union.go new file mode 100644 index 0000000..d97b3a9 --- /dev/null +++ b/union.go @@ -0,0 +1,41 @@ +package mydb + +import "git.hexq.cn/tiglog/mydb/internal/adapter" + +// OrExpr represents a logical expression joined by logical disjunction (OR). +type OrExpr struct { + *adapter.LogicalExprGroup +} + +// Or adds more expressions to the group. +func (o *OrExpr) Or(orConds ...LogicalExpr) *OrExpr { + var fn func(*[]LogicalExpr) error + if len(orConds) > 0 { + fn = func(in *[]LogicalExpr) error { + *in = append(*in, orConds...) + return nil + } + } + return &OrExpr{o.LogicalExprGroup.Frame(fn)} +} + +// Empty returns false if the expressions has zero conditions. +func (o *OrExpr) Empty() bool { + return o.LogicalExprGroup.Empty() +} + +// Or joins conditions under logical disjunction. Conditions can be represented +// by `db.Cond{}`, `db.Or()` or `db.And()`. +// +// Example: +// +// // year = 2012 OR year = 1987 +// db.Or( +// db.Cond{"year": 2012}, +// db.Cond{"year": 1987}, +// ) +func Or(conds ...LogicalExpr) *OrExpr { + return &OrExpr{adapter.NewLogicalExprGroup(adapter.LogicalOperatorOr, defaultJoin(conds...)...)} +} + +var _ = adapter.LogicalExpr(&OrExpr{})