From 4ac92c26b052d2a21a05e9d83096a9ce1d3afbaf Mon Sep 17 00:00:00 2001 From: tiglog Date: Mon, 18 Sep 2023 15:15:42 +0800 Subject: [PATCH] first commit --- .gitignore | 4 + Makefile | 39 + adapter.go | 54 + adapter/mongo/Makefile | 43 + adapter/mongo/README.md | 4 + adapter/mongo/collection.go | 346 +++ adapter/mongo/connection.go | 98 + adapter/mongo/connection_test.go | 114 + adapter/mongo/database.go | 245 ++ adapter/mongo/docker-compose.yml | 13 + adapter/mongo/generic_test.go | 20 + adapter/mongo/helper_test.go | 77 + adapter/mongo/mongo_test.go | 754 +++++++ adapter/mongo/result.go | 565 +++++ adapter/mysql/Makefile | 43 + adapter/mysql/README.md | 5 + adapter/mysql/collection.go | 56 + adapter/mysql/connection.go | 244 ++ adapter/mysql/connection_test.go | 117 + adapter/mysql/custom_types.go | 151 ++ adapter/mysql/database.go | 168 ++ adapter/mysql/docker-compose.yml | 14 + adapter/mysql/generic_test.go | 20 + adapter/mysql/helper_test.go | 276 +++ adapter/mysql/mysql.go | 30 + adapter/mysql/mysql_test.go | 379 ++++ adapter/mysql/record_test.go | 20 + adapter/mysql/sql_test.go | 20 + adapter/mysql/template.go | 198 ++ adapter/mysql/template_test.go | 269 +++ adapter/postgresql/Makefile | 44 + adapter/postgresql/README.md | 5 + adapter/postgresql/collection.go | 50 + adapter/postgresql/connection.go | 289 +++ adapter/postgresql/connection_pgx.go | 73 + adapter/postgresql/connection_pgx_test.go | 108 + adapter/postgresql/connection_pq.go | 70 + adapter/postgresql/connection_pq_test.go | 108 + adapter/postgresql/custom_types.go | 126 ++ adapter/postgresql/custom_types_pgx.go | 286 +++ adapter/postgresql/custom_types_pq.go | 249 +++ adapter/postgresql/custom_types_test.go | 105 + adapter/postgresql/database.go | 180 ++ adapter/postgresql/database_pgx.go | 26 + adapter/postgresql/database_pq.go | 26 + adapter/postgresql/docker-compose.yml | 13 + adapter/postgresql/generic_test.go | 20 + adapter/postgresql/helper_test.go | 321 +++ adapter/postgresql/postgresql.go | 30 + adapter/postgresql/postgresql_test.go | 1404 ++++++++++++ adapter/postgresql/record_test.go | 20 + adapter/postgresql/sql_test.go | 20 + adapter/postgresql/template.go | 189 ++ adapter/postgresql/template_test.go | 262 +++ adapter/sqlite/Makefile | 27 + adapter/sqlite/README.md | 4 + adapter/sqlite/collection.go | 49 + adapter/sqlite/connection.go | 89 + adapter/sqlite/connection_test.go | 88 + adapter/sqlite/database.go | 168 ++ adapter/sqlite/generic_test.go | 20 + adapter/sqlite/helper_test.go | 170 ++ adapter/sqlite/record_test.go | 20 + adapter/sqlite/sql_test.go | 20 + adapter/sqlite/sqlite.go | 30 + adapter/sqlite/sqlite_test.go | 55 + adapter/sqlite/template.go | 187 ++ adapter/sqlite/template_test.go | 246 ++ clauses.go | 468 ++++ collection.go | 45 + comparison.go | 158 ++ comparison_test.go | 111 + cond.go | 109 + cond_test.go | 69 + connection_url.go | 8 + errors.go | 42 + errors_test.go | 14 + function.go | 25 + function_test.go | 51 + go.mod | 33 + go.sum | 226 ++ internal/adapter/comparison.go | 60 + internal/adapter/constraint.go | 51 + internal/adapter/func.go | 18 + internal/adapter/logical_expr.go | 100 + internal/adapter/raw.go | 49 + internal/cache/cache.go | 113 + internal/cache/cache_test.go | 97 + internal/cache/hash.go | 109 + internal/cache/interface.go | 13 + internal/immutable/immutable.go | 28 + internal/reflectx/LICENSE | 23 + internal/reflectx/README.md | 17 + internal/reflectx/reflect.go | 404 ++++ internal/reflectx/reflect_test.go | 587 +++++ internal/sqladapter/collection.go | 369 +++ 
internal/sqladapter/compat/query.go | 72 + internal/sqladapter/compat/query_go18.go | 72 + internal/sqladapter/exql/column.go | 83 + internal/sqladapter/exql/column_test.go | 88 + internal/sqladapter/exql/column_value.go | 112 + internal/sqladapter/exql/column_value_test.go | 115 + internal/sqladapter/exql/columns.go | 83 + internal/sqladapter/exql/columns_test.go | 72 + internal/sqladapter/exql/database.go | 37 + internal/sqladapter/exql/database_test.go | 45 + internal/sqladapter/exql/default.go | 192 ++ internal/sqladapter/exql/errors.go | 5 + internal/sqladapter/exql/group_by.go | 60 + internal/sqladapter/exql/group_by_test.go | 71 + internal/sqladapter/exql/interfaces.go | 20 + internal/sqladapter/exql/join.go | 195 ++ internal/sqladapter/exql/join_test.go | 221 ++ internal/sqladapter/exql/order_by.go | 175 ++ internal/sqladapter/exql/order_by_test.go | 154 ++ internal/sqladapter/exql/raw.go | 48 + internal/sqladapter/exql/raw_test.go | 51 + internal/sqladapter/exql/returning.go | 41 + internal/sqladapter/exql/statement.go | 132 ++ internal/sqladapter/exql/statement_test.go | 703 ++++++ internal/sqladapter/exql/table.go | 98 + internal/sqladapter/exql/table_test.go | 82 + internal/sqladapter/exql/template.go | 148 ++ internal/sqladapter/exql/types.go | 35 + internal/sqladapter/exql/utilities.go | 151 ++ internal/sqladapter/exql/utilities_test.go | 211 ++ internal/sqladapter/exql/value.go | 166 ++ internal/sqladapter/exql/value_test.go | 130 ++ internal/sqladapter/exql/where.go | 149 ++ internal/sqladapter/exql/where_test.go | 127 ++ internal/sqladapter/hash.go | 8 + internal/sqladapter/record.go | 122 + internal/sqladapter/result.go | 498 +++++ internal/sqladapter/session.go | 1106 +++++++++ internal/sqladapter/sqladapter.go | 62 + internal/sqladapter/sqladapter_test.go | 45 + internal/sqladapter/statement.go | 85 + internal/sqlbuilder/batch.go | 84 + internal/sqlbuilder/builder.go | 611 +++++ internal/sqlbuilder/builder_test.go | 1510 +++++++++++++ internal/sqlbuilder/comparison.go | 122 + internal/sqlbuilder/convert.go | 166 ++ internal/sqlbuilder/custom_types.go | 11 + internal/sqlbuilder/delete.go | 195 ++ internal/sqlbuilder/errors.go | 14 + internal/sqlbuilder/fetch.go | 234 ++ internal/sqlbuilder/insert.go | 285 +++ internal/sqlbuilder/paginate.go | 340 +++ internal/sqlbuilder/placeholder_test.go | 146 ++ internal/sqlbuilder/scanner.go | 17 + internal/sqlbuilder/select.go | 524 +++++ internal/sqlbuilder/sqlbuilder.go | 40 + internal/sqlbuilder/template.go | 332 +++ internal/sqlbuilder/template_test.go | 192 ++ internal/sqlbuilder/update.go | 242 ++ internal/sqlbuilder/wrapper.go | 64 + internal/testsuite/generic_suite.go | 889 ++++++++ internal/testsuite/record_suite.go | 428 ++++ internal/testsuite/sql_suite.go | 1974 +++++++++++++++++ internal/testsuite/suite.go | 37 + intersection.go | 50 + iterator.go | 26 + logger.go | 349 +++ logger_test.go | 11 + marshal.go | 16 + mydb.go | 50 + raw.go | 17 + readme.adoc | 18 + record.go | 62 + result.go | 193 ++ session.go | 78 + settings.go | 179 ++ sql.go | 190 ++ store.go | 36 + union.go | 41 + 175 files changed, 28823 insertions(+) create mode 100644 .gitignore create mode 100644 Makefile create mode 100644 adapter.go create mode 100644 adapter/mongo/Makefile create mode 100644 adapter/mongo/README.md create mode 100644 adapter/mongo/collection.go create mode 100644 adapter/mongo/connection.go create mode 100644 adapter/mongo/connection_test.go create mode 100644 adapter/mongo/database.go create mode 100644 
adapter/mongo/docker-compose.yml create mode 100644 adapter/mongo/generic_test.go create mode 100644 adapter/mongo/helper_test.go create mode 100644 adapter/mongo/mongo_test.go create mode 100644 adapter/mongo/result.go create mode 100644 adapter/mysql/Makefile create mode 100644 adapter/mysql/README.md create mode 100644 adapter/mysql/collection.go create mode 100644 adapter/mysql/connection.go create mode 100644 adapter/mysql/connection_test.go create mode 100644 adapter/mysql/custom_types.go create mode 100644 adapter/mysql/database.go create mode 100644 adapter/mysql/docker-compose.yml create mode 100644 adapter/mysql/generic_test.go create mode 100644 adapter/mysql/helper_test.go create mode 100644 adapter/mysql/mysql.go create mode 100644 adapter/mysql/mysql_test.go create mode 100644 adapter/mysql/record_test.go create mode 100644 adapter/mysql/sql_test.go create mode 100644 adapter/mysql/template.go create mode 100644 adapter/mysql/template_test.go create mode 100644 adapter/postgresql/Makefile create mode 100644 adapter/postgresql/README.md create mode 100644 adapter/postgresql/collection.go create mode 100644 adapter/postgresql/connection.go create mode 100644 adapter/postgresql/connection_pgx.go create mode 100644 adapter/postgresql/connection_pgx_test.go create mode 100644 adapter/postgresql/connection_pq.go create mode 100644 adapter/postgresql/connection_pq_test.go create mode 100644 adapter/postgresql/custom_types.go create mode 100644 adapter/postgresql/custom_types_pgx.go create mode 100644 adapter/postgresql/custom_types_pq.go create mode 100644 adapter/postgresql/custom_types_test.go create mode 100644 adapter/postgresql/database.go create mode 100644 adapter/postgresql/database_pgx.go create mode 100644 adapter/postgresql/database_pq.go create mode 100644 adapter/postgresql/docker-compose.yml create mode 100644 adapter/postgresql/generic_test.go create mode 100644 adapter/postgresql/helper_test.go create mode 100644 adapter/postgresql/postgresql.go create mode 100644 adapter/postgresql/postgresql_test.go create mode 100644 adapter/postgresql/record_test.go create mode 100644 adapter/postgresql/sql_test.go create mode 100644 adapter/postgresql/template.go create mode 100644 adapter/postgresql/template_test.go create mode 100644 adapter/sqlite/Makefile create mode 100644 adapter/sqlite/README.md create mode 100644 adapter/sqlite/collection.go create mode 100644 adapter/sqlite/connection.go create mode 100644 adapter/sqlite/connection_test.go create mode 100644 adapter/sqlite/database.go create mode 100644 adapter/sqlite/generic_test.go create mode 100644 adapter/sqlite/helper_test.go create mode 100644 adapter/sqlite/record_test.go create mode 100644 adapter/sqlite/sql_test.go create mode 100644 adapter/sqlite/sqlite.go create mode 100644 adapter/sqlite/sqlite_test.go create mode 100644 adapter/sqlite/template.go create mode 100644 adapter/sqlite/template_test.go create mode 100644 clauses.go create mode 100644 collection.go create mode 100644 comparison.go create mode 100644 comparison_test.go create mode 100644 cond.go create mode 100644 cond_test.go create mode 100644 connection_url.go create mode 100644 errors.go create mode 100644 errors_test.go create mode 100644 function.go create mode 100644 function_test.go create mode 100644 go.mod create mode 100644 go.sum create mode 100644 internal/adapter/comparison.go create mode 100644 internal/adapter/constraint.go create mode 100644 internal/adapter/func.go create mode 100644 internal/adapter/logical_expr.go create mode 
100644 internal/adapter/raw.go create mode 100644 internal/cache/cache.go create mode 100644 internal/cache/cache_test.go create mode 100644 internal/cache/hash.go create mode 100644 internal/cache/interface.go create mode 100644 internal/immutable/immutable.go create mode 100644 internal/reflectx/LICENSE create mode 100644 internal/reflectx/README.md create mode 100644 internal/reflectx/reflect.go create mode 100644 internal/reflectx/reflect_test.go create mode 100644 internal/sqladapter/collection.go create mode 100644 internal/sqladapter/compat/query.go create mode 100644 internal/sqladapter/compat/query_go18.go create mode 100644 internal/sqladapter/exql/column.go create mode 100644 internal/sqladapter/exql/column_test.go create mode 100644 internal/sqladapter/exql/column_value.go create mode 100644 internal/sqladapter/exql/column_value_test.go create mode 100644 internal/sqladapter/exql/columns.go create mode 100644 internal/sqladapter/exql/columns_test.go create mode 100644 internal/sqladapter/exql/database.go create mode 100644 internal/sqladapter/exql/database_test.go create mode 100644 internal/sqladapter/exql/default.go create mode 100644 internal/sqladapter/exql/errors.go create mode 100644 internal/sqladapter/exql/group_by.go create mode 100644 internal/sqladapter/exql/group_by_test.go create mode 100644 internal/sqladapter/exql/interfaces.go create mode 100644 internal/sqladapter/exql/join.go create mode 100644 internal/sqladapter/exql/join_test.go create mode 100644 internal/sqladapter/exql/order_by.go create mode 100644 internal/sqladapter/exql/order_by_test.go create mode 100644 internal/sqladapter/exql/raw.go create mode 100644 internal/sqladapter/exql/raw_test.go create mode 100644 internal/sqladapter/exql/returning.go create mode 100644 internal/sqladapter/exql/statement.go create mode 100644 internal/sqladapter/exql/statement_test.go create mode 100644 internal/sqladapter/exql/table.go create mode 100644 internal/sqladapter/exql/table_test.go create mode 100644 internal/sqladapter/exql/template.go create mode 100644 internal/sqladapter/exql/types.go create mode 100644 internal/sqladapter/exql/utilities.go create mode 100644 internal/sqladapter/exql/utilities_test.go create mode 100644 internal/sqladapter/exql/value.go create mode 100644 internal/sqladapter/exql/value_test.go create mode 100644 internal/sqladapter/exql/where.go create mode 100644 internal/sqladapter/exql/where_test.go create mode 100644 internal/sqladapter/hash.go create mode 100644 internal/sqladapter/record.go create mode 100644 internal/sqladapter/result.go create mode 100644 internal/sqladapter/session.go create mode 100644 internal/sqladapter/sqladapter.go create mode 100644 internal/sqladapter/sqladapter_test.go create mode 100644 internal/sqladapter/statement.go create mode 100644 internal/sqlbuilder/batch.go create mode 100644 internal/sqlbuilder/builder.go create mode 100644 internal/sqlbuilder/builder_test.go create mode 100644 internal/sqlbuilder/comparison.go create mode 100644 internal/sqlbuilder/convert.go create mode 100644 internal/sqlbuilder/custom_types.go create mode 100644 internal/sqlbuilder/delete.go create mode 100644 internal/sqlbuilder/errors.go create mode 100644 internal/sqlbuilder/fetch.go create mode 100644 internal/sqlbuilder/insert.go create mode 100644 internal/sqlbuilder/paginate.go create mode 100644 internal/sqlbuilder/placeholder_test.go create mode 100644 internal/sqlbuilder/scanner.go create mode 100644 internal/sqlbuilder/select.go create mode 100644 
internal/sqlbuilder/sqlbuilder.go
create mode 100644 internal/sqlbuilder/template.go
create mode 100644 internal/sqlbuilder/template_test.go
create mode 100644 internal/sqlbuilder/update.go
create mode 100644 internal/sqlbuilder/wrapper.go
create mode 100644 internal/testsuite/generic_suite.go
create mode 100644 internal/testsuite/record_suite.go
create mode 100644 internal/testsuite/sql_suite.go
create mode 100644 internal/testsuite/suite.go
create mode 100644 intersection.go
create mode 100644 iterator.go
create mode 100644 logger.go
create mode 100644 logger_test.go
create mode 100644 marshal.go
create mode 100644 mydb.go
create mode 100644 raw.go
create mode 100644 readme.adoc
create mode 100644 record.go
create mode 100644 result.go
create mode 100644 session.go
create mode 100644 settings.go
create mode 100644 sql.go
create mode 100644 store.go
create mode 100644 union.go
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..2946070
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,4 @@
+*.sw?
+*.db
+*.tmp
+generated_*.go
diff --git a/Makefile b/Makefile
new file mode 100644
index 0000000..c5a605a
--- /dev/null
+++ b/Makefile
@@ -0,0 +1,39 @@
+SHELL ?= /bin/bash
+
+PARALLEL_FLAGS ?= --halt-on-error 2 --jobs=2 -v -u
+
+TEST_FLAGS ?=
+
+UPPER_DB_LOG ?= WARN
+
+export TEST_FLAGS
+export PARALLEL_FLAGS
+export UPPER_DB_LOG
+
+test: go-test-internal test-adapters
+
+benchmark: go-benchmark-internal
+
+go-benchmark-%:
+	go test -v -benchtime=500ms -bench=. ./$*/...
+
+go-test-%:
+	go test -v ./$*/...
+
+test-adapters: \
+	test-adapter-postgresql \
+	# test-adapter-mysql \
+	# test-adapter-sqlite \
+	# test-adapter-mongo
+
+test-adapter-%:
+	($(MAKE) -C adapter/$* test-extended || exit 1)
+
+test-generic:
+	export TEST_FLAGS="-run TestGeneric"; \
+	$(MAKE) test-adapters
+
+goimports:
+	for FILE in $$(find -name "*.go" | grep -v vendor); do \
+		goimports -w $$FILE; \
+	done
diff --git a/adapter.go b/adapter.go
new file mode 100644
index 0000000..1ef65b2
--- /dev/null
+++ b/adapter.go
@@ -0,0 +1,54 @@
+package mydb
+
+import (
+	"fmt"
+	"sync"
+)
+
+var (
+	adapterMap   = make(map[string]Adapter)
+	adapterMapMu sync.RWMutex
+)
+
+// Adapter is the interface that must be satisfied by database adapters.
+type Adapter interface {
+	Open(ConnectionURL) (Session, error)
+}
+
+type missingAdapter struct {
+	name string
+}
+
+func (ma *missingAdapter) Open(ConnectionURL) (Session, error) {
+	return nil, fmt.Errorf("mydb: Missing adapter %q, did you forget to import it?", ma.name)
+}
+
+// RegisterAdapter registers a generic database adapter.
+func RegisterAdapter(name string, adapter Adapter) {
+	adapterMapMu.Lock()
+	defer adapterMapMu.Unlock()
+
+	if name == "" {
+		panic(`Missing adapter name`)
+	}
+	if _, ok := adapterMap[name]; ok {
+		panic(`db.RegisterAdapter() called twice for adapter: ` + name)
+	}
+	adapterMap[name] = adapter
+}
+
+// LookupAdapter returns a previously registered adapter by name.
+func LookupAdapter(name string) Adapter {
+	adapterMapMu.RLock()
+	defer adapterMapMu.RUnlock()
+
+	if adapter, ok := adapterMap[name]; ok {
+		return adapter
+	}
+	return &missingAdapter{name: name}
+}
+
+// Open attempts to establish a connection with a database.
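This registry is what each adapter package hooks into from its `init` function (the mongo adapter's `RegisterAdapter` call appears later in this patch), and the `Open` helper that follows resolves a name through `LookupAdapter` and delegates to the adapter. A minimal usage sketch against this API; the connection values are hypothetical:

```go
package main

import (
	"log"

	"git.hexq.cn/tiglog/mydb"
	// The adapter registers itself under the name "mongo" in its init
	// function; without this import, mydb.Open would hit the
	// missingAdapter error defined above.
	"git.hexq.cn/tiglog/mydb/adapter/mongo"
)

func main() {
	settings := mongo.ConnectionURL{
		Host:     "127.0.0.1:27017", // hypothetical server
		Database: "upperio",
	}

	// mydb.Open looks the adapter up by name and calls its Open method.
	sess, err := mydb.Open(mongo.Adapter, settings)
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()
}
```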
+func Open(adapterName string, settings ConnectionURL) (Session, error) { + return LookupAdapter(adapterName).Open(settings) +} diff --git a/adapter/mongo/Makefile b/adapter/mongo/Makefile new file mode 100644 index 0000000..7660d30 --- /dev/null +++ b/adapter/mongo/Makefile @@ -0,0 +1,43 @@ +SHELL ?= bash + +MONGO_VERSION ?= 4 +MONGO_SUPPORTED ?= $(MONGO_VERSION) 3 +PROJECT ?= upper_mongo_$(MONGO_VERSION) + +DB_HOST ?= 127.0.0.1 +DB_PORT ?= 27017 + +DB_NAME ?= admin +DB_USERNAME ?= upperio_user +DB_PASSWORD ?= upperio//s3cr37 + +TEST_FLAGS ?= +PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1 + +export MONGO_VERSION + +export DB_HOST +export DB_NAME +export DB_PASSWORD +export DB_PORT +export DB_USERNAME + +export TEST_FLAGS + +test: + go test -v -failfast -race -timeout 20m $(TEST_FLAGS) + +test-no-race: + go test -v -failfast $(TEST_FLAGS) + +server-up: server-down + docker-compose -p $(PROJECT) up -d && \ + sleep 10 + +server-down: + docker-compose -p $(PROJECT) down + +test-extended: + parallel $(PARALLEL_FLAGS) \ + "MONGO_VERSION={} DB_PORT=\$$((27017+{#})) $(MAKE) server-up test server-down" ::: \ + $(MONGO_SUPPORTED) diff --git a/adapter/mongo/README.md b/adapter/mongo/README.md new file mode 100644 index 0000000..cb56a27 --- /dev/null +++ b/adapter/mongo/README.md @@ -0,0 +1,4 @@ +# MongoDB adapter for upper/db + +Please read the full docs, acknowledgements and examples at +[https://upper.io/v4/adapter/mongo/](https://upper.io/v4/adapter/mongo/). diff --git a/adapter/mongo/collection.go b/adapter/mongo/collection.go new file mode 100644 index 0000000..0c27c06 --- /dev/null +++ b/adapter/mongo/collection.go @@ -0,0 +1,346 @@ +package mongo + +import ( + "fmt" + "strings" + "sync" + + "reflect" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/adapter" + mgo "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +// Collection represents a mongodb collection. +type Collection struct { + parent *Source + collection *mgo.Collection +} + +var ( + // idCache should be a struct if we're going to cache more than just + // _id field here + idCache = make(map[reflect.Type]string) + idCacheMutex sync.RWMutex +) + +// Find creates a result set with the given conditions. +func (col *Collection) Find(terms ...interface{}) mydb.Result { + fields := []string{"*"} + + conditions := col.compileQuery(terms...) 
+ + res := &result{} + res = res.frame(func(r *resultQuery) error { + r.c = col + r.conditions = conditions + r.fields = fields + return nil + }) + + return res +} + +var comparisonOperators = map[adapter.ComparisonOperator]string{ + adapter.ComparisonOperatorEqual: "$eq", + adapter.ComparisonOperatorNotEqual: "$ne", + + adapter.ComparisonOperatorLessThan: "$lt", + adapter.ComparisonOperatorGreaterThan: "$gt", + + adapter.ComparisonOperatorLessThanOrEqualTo: "$lte", + adapter.ComparisonOperatorGreaterThanOrEqualTo: "$gte", + + adapter.ComparisonOperatorIn: "$in", + adapter.ComparisonOperatorNotIn: "$nin", +} + +func compare(field string, cmp *adapter.Comparison) (string, interface{}) { + op := cmp.Operator() + value := cmp.Value() + + switch op { + case adapter.ComparisonOperatorEqual: + return field, value + case adapter.ComparisonOperatorBetween: + values := value.([]interface{}) + return field, bson.M{ + "$gte": values[0], + "$lte": values[1], + } + case adapter.ComparisonOperatorNotBetween: + values := value.([]interface{}) + return "$or", []bson.M{ + {field: bson.M{"$gt": values[1]}}, + {field: bson.M{"$lt": values[0]}}, + } + case adapter.ComparisonOperatorIs: + if value == nil { + return field, bson.M{"$exists": false} + } + return field, bson.M{"$eq": value} + case adapter.ComparisonOperatorIsNot: + if value == nil { + return field, bson.M{"$exists": true} + } + return field, bson.M{"$ne": value} + case adapter.ComparisonOperatorRegExp, adapter.ComparisonOperatorLike: + return field, bson.RegEx{Pattern: value.(string), Options: ""} + case adapter.ComparisonOperatorNotRegExp, adapter.ComparisonOperatorNotLike: + return field, bson.M{"$not": bson.RegEx{Pattern: value.(string), Options: ""}} + } + + if cmpOp, ok := comparisonOperators[op]; ok { + return field, bson.M{ + cmpOp: value, + } + } + + panic(fmt.Sprintf("Unsupported operator %v", op)) +} + +// compileStatement transforms conditions into something *mgo.Session can +// understand. +func compileStatement(cond mydb.Cond) bson.M { + conds := bson.M{} + + // Walking over conditions + for fieldI, value := range cond { + field := strings.TrimSpace(fmt.Sprintf("%v", fieldI)) + + if cmp, ok := value.(*mydb.Comparison); ok { + k, v := compare(field, cmp.Comparison) + conds[k] = v + continue + } + + var op string + chunks := strings.SplitN(field, ` `, 2) + + if len(chunks) > 1 { + switch chunks[1] { + case `IN`: + op = `$in` + case `NOT IN`: + op = `$nin` + case `>`: + op = `$gt` + case `<`: + op = `$lt` + case `<=`: + op = `$lte` + case `>=`: + op = `$gte` + default: + op = chunks[1] + } + } + field = chunks[0] + + if op == "" { + conds[field] = value + } else { + conds[field] = bson.M{op: value} + } + } + + return conds +} + +// compileConditions compiles terms into something *mgo.Session can +// understand. 
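As a reading aid for `compileStatement` above (and `compileConditions`, which follows), here is roughly how a few conditions translate into BSON documents. This is an in-package sketch, since both helpers are unexported; the expected outputs are read off the operator tables above:

```go
package mongo

import (
	"reflect"
	"testing"

	"git.hexq.cn/tiglog/mydb"
	"gopkg.in/mgo.v2/bson"
)

// TestCompileStatementSketch illustrates the translation performed by
// compileStatement; the cases mirror the operator handling above.
func TestCompileStatementSketch(t *testing.T) {
	cases := []struct {
		in   mydb.Cond
		want bson.M
	}{
		// A bare field name compares with equality.
		{mydb.Cond{"name": "Ozzie"}, bson.M{"name": "Ozzie"}},
		// A "field OP" key becomes a $-operator document.
		{mydb.Cond{"age >=": 21}, bson.M{"age": bson.M{"$gte": 21}}},
		{mydb.Cond{"_id IN": []int{1, 2}}, bson.M{"_id": bson.M{"$in": []int{1, 2}}}},
	}
	for _, c := range cases {
		if got := compileStatement(c.in); !reflect.DeepEqual(got, c.want) {
			t.Errorf("compileStatement(%v) = %v, want %v", c.in, got, c.want)
		}
	}
}
```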
+func (col *Collection) compileConditions(term interface{}) interface{} {
+
+	switch t := term.(type) {
+	case []interface{}:
+		values := []interface{}{}
+		for i := range t {
+			value := col.compileConditions(t[i])
+			if value != nil {
+				values = append(values, value)
+			}
+		}
+		if len(values) > 0 {
+			return values
+		}
+	case mydb.Cond:
+		return compileStatement(t)
+	case adapter.LogicalExpr:
+		values := []interface{}{}
+
+		for _, s := range t.Expressions() {
+			values = append(values, col.compileConditions(s))
+		}
+
+		var op string
+		switch t.Operator() {
+		case adapter.LogicalOperatorOr:
+			op = `$or`
+		default:
+			op = `$and`
+		}
+
+		return bson.M{op: values}
+	}
+	return nil
+}
+
+// compileQuery compiles terms into something that *mgo.Session can
+// understand.
+func (col *Collection) compileQuery(terms ...interface{}) interface{} {
+	compiled := col.compileConditions(terms)
+	if compiled == nil {
+		return nil
+	}
+
+	conditions := compiled.([]interface{})
+	if len(conditions) == 1 {
+		return conditions[0]
+	}
+	// this should be correct.
+	// query = map[string]interface{}{"$and": conditions}
+
+	// attempt to work around https://jira.mongodb.org/browse/SERVER-4572
+	mapped := map[string]interface{}{}
+	for _, v := range conditions {
+		for kk := range v.(map[string]interface{}) {
+			mapped[kk] = v.(map[string]interface{})[kk]
+		}
+	}
+
+	return mapped
+}
+
+// Name returns the name of the collection.
+func (col *Collection) Name() string {
+	return col.collection.Name
+}
+
+// Truncate removes all documents from the collection by dropping it.
+func (col *Collection) Truncate() error {
+	err := col.collection.DropCollection()
+
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+func (col *Collection) Session() mydb.Session {
+	return col.parent
+}
+
+func (col *Collection) Count() (uint64, error) {
+	return col.Find().Count()
+}
+
+func (col *Collection) InsertReturning(item interface{}) error {
+	return mydb.ErrUnsupported
+}
+
+func (col *Collection) UpdateReturning(item interface{}) error {
+	return mydb.ErrUnsupported
+}
+
+// Insert inserts a record (map or struct) into the collection.
+func (col *Collection) Insert(item interface{}) (mydb.InsertResult, error) {
+	var err error
+
+	id := getID(item)
+
+	if col.parent.versionAtLeast(2, 6, 0, 0) {
+		// this breaks MongoDB older than 2.6
+		if _, err = col.collection.Upsert(bson.M{"_id": id}, item); err != nil {
+			return nil, err
+		}
+	} else {
+		// Allocating a new ID.
+		if err = col.collection.Insert(bson.M{"_id": id}); err != nil {
+			return nil, err
+		}
+
+		// Now update the placeholder document with the data the user provided.
+		if err = col.collection.Update(bson.M{"_id": id}, item); err != nil {
+			// Cleanup allocated ID
+			if err := col.collection.Remove(bson.M{"_id": id}); err != nil {
+				return nil, err
+			}
+			return nil, err
+		}
+	}
+
+	return mydb.NewInsertResult(id), nil
+}
+
+// Exists returns true if the collection exists.
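One consequence of the SERVER-4572 workaround in `compileQuery` above is worth spelling out: multiple top-level conditions are merged key-by-key into a single document rather than wrapped in `{"$and": [...]}`, so two conditions on the same field cannot both survive the merge. A sketch of the flattening (hypothetical values; `compileQuery` uses plain `map[string]interface{}`, which `bson.M` aliases in shape):

```go
package main

import (
	"fmt"

	"gopkg.in/mgo.v2/bson"
)

func main() {
	// Two separate condition documents...
	conds := []bson.M{
		{"name": "Ozzie"},
		{"year": bson.M{"$gt": 1970}},
	}

	// ...are merged key-by-key, mirroring compileQuery's workaround.
	merged := bson.M{}
	for _, c := range conds {
		for k, v := range c {
			merged[k] = v
		}
	}
	fmt.Println(merged) // map[name:Ozzie year:map[$gt:1970]]

	// Caveat: had both documents constrained the same field, the later
	// value would silently win.
}
```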
+func (col *Collection) Exists() (bool, error) {
+	query := col.parent.database.C(`system.namespaces`).Find(map[string]string{`name`: fmt.Sprintf(`%s.%s`, col.parent.database.Name, col.collection.Name)})
+	count, err := query.Count()
+	return count > 0, err
+}
+
+// getID fetches the object's _id, or generates a new one if the object
+// doesn't have one or the one it has is invalid.
+func getID(item interface{}) interface{} {
+	v := reflect.ValueOf(item) // convert interface to Value
+	v = reflect.Indirect(v)    // convert pointers
+
+	switch v.Kind() {
+	case reflect.Map:
+		if inItem, ok := item.(map[string]interface{}); ok {
+			if id, ok := inItem["_id"]; ok {
+				bsonID, ok := id.(bson.ObjectId)
+				if ok {
+					return bsonID
+				}
+			}
+		}
+	case reflect.Struct:
+		t := v.Type()
+
+		idCacheMutex.RLock()
+		fieldName, found := idCache[t]
+		idCacheMutex.RUnlock()
+
+		if !found {
+			for n := 0; n < t.NumField(); n++ {
+				field := t.Field(n)
+				if field.PkgPath != "" {
+					continue // Private field
+				}
+
+				tag := field.Tag.Get("bson")
+				if tag == "" {
+					tag = field.Tag.Get("db")
+				}
+
+				if tag == "" {
+					continue
+				}
+
+				parts := strings.Split(tag, ",")
+
+				if parts[0] == "_id" {
+					fieldName = field.Name
+					// Writing to the cache requires the write lock, not RLock.
+					idCacheMutex.Lock()
+					idCache[t] = fieldName
+					idCacheMutex.Unlock()
+					break
+				}
+			}
+		}
+		if fieldName != "" {
+			if bsonID, ok := v.FieldByName(fieldName).Interface().(bson.ObjectId); ok {
+				if bsonID.Valid() {
+					return bsonID
+				}
+			} else {
+				return v.FieldByName(fieldName).Interface()
+			}
+		}
+	}
+
+	return bson.NewObjectId()
+}
diff --git a/adapter/mongo/connection.go b/adapter/mongo/connection.go
new file mode 100644
index 0000000..48377c4
--- /dev/null
+++ b/adapter/mongo/connection.go
@@ -0,0 +1,98 @@
+package mongo
+
+import (
+	"fmt"
+	"net/url"
+	"strings"
+)
+
+const connectionScheme = `mongodb`
+
+// ConnectionURL describes a MongoDB connection.
+type ConnectionURL struct {
+	User     string
+	Password string
+	Host     string
+	Database string
+	Options  map[string]string
+}
+
+func (c ConnectionURL) String() (s string) {
+
+	if c.Database == "" {
+		return ""
+	}
+
+	vv := url.Values{}
+
+	// Do we have any options?
+	if c.Options == nil {
+		c.Options = map[string]string{}
+	}
+
+	// Converting options into URL values.
+	for k, v := range c.Options {
+		vv.Set(k, v)
+	}
+
+	// Has user?
+	var userInfo *url.Userinfo
+
+	if c.User != "" {
+		if c.Password == "" {
+			userInfo = url.User(c.User)
+		} else {
+			userInfo = url.UserPassword(c.User, c.Password)
+		}
+	}
+
+	// Building URL.
+	u := url.URL{
+		Scheme:   connectionScheme,
+		Path:     c.Database,
+		Host:     c.Host,
+		User:     userInfo,
+		RawQuery: vv.Encode(),
+	}
+
+	return u.String()
+}
+
+// ParseURL parses s into a ConnectionURL struct.
+func ParseURL(s string) (conn ConnectionURL, err error) {
+	var u *url.URL
+
+	if !strings.HasPrefix(s, connectionScheme+"://") {
+		return conn, fmt.Errorf(`Expecting mongodb:// connection scheme.`)
+	}
+
+	if u, err = url.Parse(s); err != nil {
+		return conn, err
+	}
+
+	conn.Host = u.Host
+
+	// Stripping the leading "/" from the path.
+	conn.Database = strings.Trim(u.Path, "/")
+
+	// Adding user / password.
+	if u.User != nil {
+		conn.User = u.User.Username()
+		conn.Password, _ = u.User.Password()
+	}
+
+	// Adding options.
+	conn.Options = map[string]string{}
+
+	var vv url.Values
+
+	if vv, err = url.ParseQuery(u.RawQuery); err != nil {
+		return conn, err
+	}
+
+	for k := range vv {
+		conn.Options[k] = vv.Get(k)
+	}
+
+	return conn, err
+}
diff --git a/adapter/mongo/connection_test.go b/adapter/mongo/connection_test.go
new file mode 100644
index 0000000..e60b323
--- /dev/null
+++ b/adapter/mongo/connection_test.go
@@ -0,0 +1,114 @@
+package mongo
+
+import (
+	"testing"
+)
+
+func TestConnectionURL(t *testing.T) {
+
+	c := ConnectionURL{}
+
+	// Without a database name, the connection string is empty.
+	if c.String() != "" {
+		t.Fatal(`Expecting default connection string to be empty, got:`, c.String())
+	}
+
+	// Adding a database name.
+	c.Database = "myfilename"
+
+	if c.String() != "mongodb://myfilename" {
+		t.Fatal(`Test failed, got:`, c.String())
+	}
+
+	// Adding an option.
+	c.Options = map[string]string{
+		"cache": "foobar",
+		"mode":  "ro",
+	}
+
+	// Adding username and password
+	c.User = "user"
+	c.Password = "pass"
+
+	// Setting host.
+	c.Host = "localhost"
+
+	if c.String() != "mongodb://user:pass@localhost/myfilename?cache=foobar&mode=ro" {
+		t.Fatal(`Test failed, got:`, c.String())
+	}
+
+	// Setting host and port.
+	c.Host = "localhost:27017"
+
+	if c.String() != "mongodb://user:pass@localhost:27017/myfilename?cache=foobar&mode=ro" {
+		t.Fatal(`Test failed, got:`, c.String())
+	}
+
+	// Setting cluster.
+	c.Host = "localhost,1.2.3.4,example.org:1234"
+
+	if c.String() != "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/myfilename?cache=foobar&mode=ro" {
+		t.Fatal(`Test failed, got:`, c.String())
+	}
+
+	// Setting another database.
+	c.Database = "another_database"
+
+	if c.String() != "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/another_database?cache=foobar&mode=ro" {
+		t.Fatal(`Test failed, got:`, c.String())
+	}
+
+}
+
+func TestParseConnectionURL(t *testing.T) {
+	var u ConnectionURL
+	var s string
+	var err error
+
+	s = "mongodb:///mydatabase"
+
+	if u, err = ParseURL(s); err != nil {
+		t.Fatal(err)
+	}
+
+	if u.Database != "mydatabase" {
+		t.Fatal("Failed to parse database.")
+	}
+
+	s = "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/another_database?cache=foobar&mode=ro"
+
+	if u, err = ParseURL(s); err != nil {
+		t.Fatal(err)
+	}
+
+	if u.Database != "another_database" {
+		t.Fatal("Failed to get database.")
+	}
+
+	if u.Options["cache"] != "foobar" {
+		t.Fatal("Expecting option.")
+	}
+
+	if u.Options["mode"] != "ro" {
+		t.Fatal("Expecting option.")
+	}
+
+	if u.User != "user" {
+		t.Fatal("Expecting user.")
+	}
+
+	if u.Password != "pass" {
+		t.Fatal("Expecting password.")
+	}
+
+	if u.Host != "localhost,1.2.3.4,example.org:1234" {
+		t.Fatal("Expecting host.")
+	}
+
+	s = "http://example.org"
+
+	if _, err = ParseURL(s); err == nil {
+		t.Fatal("Expecting error.")
+	}
+
+}
diff --git a/adapter/mongo/database.go b/adapter/mongo/database.go
new file mode 100644
index 0000000..e2cdb50
--- /dev/null
+++ b/adapter/mongo/database.go
@@ -0,0 +1,245 @@
+// Package mongo wraps the gopkg.in/mgo.v2 MongoDB driver. See
+// https://github.com/upper/db/adapter/mongo for documentation, particularities and usage
+// examples.
+package mongo
+
+import (
+	"context"
+	"database/sql"
+	"strings"
+	"sync"
+	"time"
+
+	"git.hexq.cn/tiglog/mydb"
+	mgo "gopkg.in/mgo.v2"
+)
+
+// Adapter holds the name of the mongodb adapter.
+const Adapter = `mongo`
+
+var connTimeout = time.Second * 5
+
+// Source represents a MongoDB database.
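Before the `Source` type below, a quick round-trip of the `ConnectionURL` logic above; a sketch whose expected strings follow the same pattern the tests assert, with hypothetical values:

```go
package main

import (
	"fmt"
	"log"

	"git.hexq.cn/tiglog/mydb/adapter/mongo"
)

func main() {
	settings := mongo.ConnectionURL{
		User:     "user",
		Password: "pass",
		Host:     "localhost:27017",
		Database: "mydatabase",
		Options:  map[string]string{"mode": "ro"},
	}

	s := settings.String()
	fmt.Println(s) // mongodb://user:pass@localhost:27017/mydatabase?mode=ro

	// ParseURL reverses String for well-formed mongodb:// URLs.
	parsed, err := mongo.ParseURL(s)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(parsed.Database == settings.Database) // true
}
```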
+type Source struct {
+	mydb.Settings
+
+	ctx context.Context
+
+	name          string
+	connURL       mydb.ConnectionURL
+	session       *mgo.Session
+	database      *mgo.Database
+	version       []int
+	collections   map[string]*Collection
+	collectionsMu sync.Mutex
+}
+
+type mongoAdapter struct {
+}
+
+func (mongoAdapter) Open(dsn mydb.ConnectionURL) (mydb.Session, error) {
+	return Open(dsn)
+}
+
+func init() {
+	mydb.RegisterAdapter(Adapter, mydb.Adapter(&mongoAdapter{}))
+}
+
+// Open establishes a new connection to a MongoDB server.
+func Open(settings mydb.ConnectionURL) (mydb.Session, error) {
+	d := &Source{Settings: mydb.NewSettings(), ctx: context.Background()}
+	if err := d.Open(settings); err != nil {
+		return nil, err
+	}
+	return d, nil
+}
+
+func (s *Source) TxContext(context.Context, func(tx mydb.Session) error, *sql.TxOptions) error {
+	return mydb.ErrNotSupportedByAdapter
+}
+
+func (s *Source) Tx(func(mydb.Session) error) error {
+	return mydb.ErrNotSupportedByAdapter
+}
+
+func (s *Source) SQL() mydb.SQL {
+	// Not supported
+	panic("sql builder is not supported by mongodb")
+}
+
+func (s *Source) ConnectionURL() mydb.ConnectionURL {
+	return s.connURL
+}
+
+// SetConnMaxLifetime is not supported.
+func (s *Source) SetConnMaxLifetime(time.Duration) {
+	s.Settings.SetConnMaxLifetime(time.Duration(0))
+}
+
+// SetMaxIdleConns is not supported.
+func (s *Source) SetMaxIdleConns(int) {
+	s.Settings.SetMaxIdleConns(0)
+}
+
+// SetMaxOpenConns is not supported.
+func (s *Source) SetMaxOpenConns(int) {
+	s.Settings.SetMaxOpenConns(0)
+}
+
+// Name returns the name of the database.
+func (s *Source) Name() string {
+	return s.name
+}
+
+// Open attempts to connect to the database.
+func (s *Source) Open(connURL mydb.ConnectionURL) error {
+	s.connURL = connURL
+	return s.open()
+}
+
+// Clone returns a cloned mydb.Session with its own copy of the underlying
+// mgo session.
+func (s *Source) Clone() (mydb.Session, error) {
+	newSession := s.session.Copy()
+	clone := &Source{
+		Settings: mydb.NewSettings(),
+
+		name:        s.name,
+		connURL:     s.connURL,
+		session:     newSession,
+		database:    newSession.DB(s.database.Name),
+		version:     s.version,
+		collections: map[string]*Collection{},
+	}
+	return clone, nil
+}
+
+// Ping checks whether a connection to the database is still alive by pinging
+// it, establishing a connection if necessary.
+func (s *Source) Ping() error {
+	return s.session.Ping()
+}
+
+func (s *Source) Reset() {
+	s.collectionsMu.Lock()
+	defer s.collectionsMu.Unlock()
+	s.collections = make(map[string]*Collection)
+}
+
+// Driver returns the underlying *mgo.Session instance.
+func (s *Source) Driver() interface{} {
+	return s.session
+}
+
+func (s *Source) open() error {
+	var err error
+
+	if s.session, err = mgo.DialWithTimeout(s.connURL.String(), connTimeout); err != nil {
+		return err
+	}
+
+	s.collections = map[string]*Collection{}
+	s.database = s.session.DB("")
+
+	return nil
+}
+
+// Close terminates the current database session.
+func (s *Source) Close() error {
+	if s.session != nil {
+		s.session.Close()
+	}
+	return nil
+}
+
+// Collections returns a list of non-system collections in the database.
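The lifecycle pieces above compose as in this sketch (connection values hypothetical), including the `Driver` escape hatch that this patch's own test helper uses to reach the raw `*mgo.Session`; `Collections` follows below:

```go
package main

import (
	"log"

	"git.hexq.cn/tiglog/mydb/adapter/mongo"
	mgo "gopkg.in/mgo.v2"
)

func main() {
	sess, err := mongo.Open(mongo.ConnectionURL{
		Host:     "127.0.0.1:27017", // hypothetical server
		Database: "upperio",
	})
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()

	// Ping re-checks the connection established by Open.
	if err := sess.Ping(); err != nil {
		log.Fatal(err)
	}

	// Driver exposes the underlying mgo session for anything the adapter
	// does not wrap (transactions and SQL() are unsupported here).
	mgod, ok := sess.Driver().(*mgo.Session)
	if !ok {
		log.Fatal("expecting *mgo.Session")
	}
	_ = mgod
}
```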
+func (s *Source) Collections() (cols []mydb.Collection, err error) {
+	var rawcols []string
+	var col string
+
+	if rawcols, err = s.database.CollectionNames(); err != nil {
+		return nil, err
+	}
+
+	cols = make([]mydb.Collection, 0, len(rawcols))
+
+	for _, col = range rawcols {
+		if !strings.HasPrefix(col, "system.") {
+			cols = append(cols, s.Collection(col))
+		}
+	}
+
+	return cols, nil
+}
+
+func (s *Source) Delete(mydb.Record) error {
+	return mydb.ErrNotImplemented
+}
+
+func (s *Source) Get(mydb.Record, interface{}) error {
+	return mydb.ErrNotImplemented
+}
+
+func (s *Source) Save(mydb.Record) error {
+	return mydb.ErrNotImplemented
+}
+
+func (s *Source) Context() context.Context {
+	return s.ctx
+}
+
+func (s *Source) WithContext(ctx context.Context) mydb.Session {
+	return &Source{
+		ctx:      ctx,
+		Settings: s.Settings,
+		name:     s.name,
+		connURL:  s.connURL,
+		session:  s.session,
+		database: s.database,
+		version:  s.version,
+		// The collection cache must be initialized here: Collection()
+		// writes to this map, and a write to a nil map panics.
+		collections: map[string]*Collection{},
+	}
+}
+
+// Collection returns a collection by name.
+func (s *Source) Collection(name string) mydb.Collection {
+	s.collectionsMu.Lock()
+	defer s.collectionsMu.Unlock()
+
+	var col *Collection
+	var ok bool
+
+	if col, ok = s.collections[name]; !ok {
+		col = &Collection{
+			parent:     s,
+			collection: s.database.C(name),
+		}
+		s.collections[name] = col
+	}
+
+	return col
+}
+
+func (s *Source) versionAtLeast(version ...int) bool {
+	// only fetch this once - it makes a db call
+	if len(s.version) == 0 {
+		buildInfo, err := s.database.Session.BuildInfo()
+		if err != nil {
+			return false
+		}
+		s.version = buildInfo.VersionArray
+	}
+
+	// Check major version first
+	if s.version[0] > version[0] {
+		return true
+	}
+
+	for i := range version {
+		if i == len(s.version) {
+			return false
+		}
+		if s.version[i] < version[i] {
+			return false
+		}
+	}
+	return true
+}
diff --git a/adapter/mongo/docker-compose.yml b/adapter/mongo/docker-compose.yml
new file mode 100644
index 0000000..af27525
--- /dev/null
+++ b/adapter/mongo/docker-compose.yml
@@ -0,0 +1,13 @@
+version: '3'
+
+services:
+
+  server:
+    image: mongo:${MONGO_VERSION:-3}
+    environment:
+      MONGO_INITDB_ROOT_USERNAME: ${DB_USERNAME:-upperio_user}
+      MONGO_INITDB_ROOT_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37}
+      MONGO_INITDB_DATABASE: ${DB_NAME:-upperio}
+    ports:
+      - '${BIND_HOST:-127.0.0.1}:${DB_PORT:-27017}:27017'
+
diff --git a/adapter/mongo/generic_test.go b/adapter/mongo/generic_test.go
new file mode 100644
index 0000000..13840af
--- /dev/null
+++ b/adapter/mongo/generic_test.go
@@ -0,0 +1,20 @@
+package mongo
+
+import (
+	"testing"
+
+	"git.hexq.cn/tiglog/mydb/internal/testsuite"
+	"github.com/stretchr/testify/suite"
+)
+
+type GenericTests struct {
+	testsuite.GenericTestSuite
+}
+
+func (s *GenericTests) SetupSuite() {
+	s.Helper = &Helper{}
+}
+
+func TestGeneric(t *testing.T) {
+	suite.Run(t, &GenericTests{})
+}
diff --git a/adapter/mongo/helper_test.go b/adapter/mongo/helper_test.go
new file mode 100644
index 0000000..6db7323
--- /dev/null
+++ b/adapter/mongo/helper_test.go
@@ -0,0 +1,77 @@
+package mongo
+
+import (
+	"fmt"
+	"os"
+
+	"git.hexq.cn/tiglog/mydb"
+	"git.hexq.cn/tiglog/mydb/internal/testsuite"
+	mgo "gopkg.in/mgo.v2"
+)
+
+var settings = ConnectionURL{
+	Database: os.Getenv("DB_NAME"),
+	User:     os.Getenv("DB_USERNAME"),
+	Password: os.Getenv("DB_PASSWORD"),
+	Host:     os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"),
+}
+
+type Helper struct {
+	sess mydb.Session
+}
+
+func (h *Helper) Session() mydb.Session {
+	return h.sess
+}
+
+func (h *Helper) Adapter() string {
+	return "mongo"
+}
+
+func
(h *Helper) TearDown() error { + return h.sess.Close() +} + +func (h *Helper) TearUp() error { + var err error + + h.sess, err = Open(settings) + if err != nil { + return err + } + + mgod, ok := h.sess.Driver().(*mgo.Session) + if !ok { + panic("expecting mgo.Session") + } + + var col *mgo.Collection + col = mgod.DB(settings.Database).C("birthdays") + _ = col.DropCollection() + + col = mgod.DB(settings.Database).C("fibonacci") + _ = col.DropCollection() + + col = mgod.DB(settings.Database).C("is_even") + _ = col.DropCollection() + + col = mgod.DB(settings.Database).C("CaSe_TesT") + _ = col.DropCollection() + + // Getting a pointer to the "artist" collection. + artist := h.sess.Collection("artist") + + _ = artist.Truncate() + for i := 0; i < 999; i++ { + _, err = artist.Insert(artistType{ + Name: fmt.Sprintf("artist-%d", i), + }) + if err != nil { + return err + } + } + + return nil +} + +var _ testsuite.Helper = &Helper{} diff --git a/adapter/mongo/mongo_test.go b/adapter/mongo/mongo_test.go new file mode 100644 index 0000000..8e0f42e --- /dev/null +++ b/adapter/mongo/mongo_test.go @@ -0,0 +1,754 @@ +// Tests for the mongodb adapter. +package mongo + +import ( + "fmt" + "log" + "math/rand" + "strings" + "testing" + "time" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" + "gopkg.in/mgo.v2/bson" +) + +type artistType struct { + ID bson.ObjectId `bson:"_id,omitempty"` + Name string `bson:"name"` +} + +// Structure for testing conversions and datatypes. +type testValuesStruct struct { + Uint uint `bson:"_uint"` + Uint8 uint8 `bson:"_uint8"` + Uint16 uint16 `bson:"_uint16"` + Uint32 uint32 `bson:"_uint32"` + Uint64 uint64 `bson:"_uint64"` + + Int int `bson:"_int"` + Int8 int8 `bson:"_int8"` + Int16 int16 `bson:"_int16"` + Int32 int32 `bson:"_int32"` + Int64 int64 `bson:"_int64"` + + Float32 float32 `bson:"_float32"` + Float64 float64 `bson:"_float64"` + + Bool bool `bson:"_bool"` + String string `bson:"_string"` + + Date time.Time `bson:"_date"` + DateN *time.Time `bson:"_nildate"` + DateP *time.Time `bson:"_ptrdate"` + Time time.Duration `bson:"_time"` +} + +var testValues testValuesStruct + +func init() { + t := time.Date(2012, 7, 28, 1, 2, 3, 0, time.Local) + + testValues = testValuesStruct{ + 1, 1, 1, 1, 1, + -1, -1, -1, -1, -1, + 1.337, 1.337, + true, + "Hello world!", + t, + nil, + &t, + time.Second * time.Duration(7331), + } +} + +type AdapterTests struct { + testsuite.Suite +} + +func (s *AdapterTests) SetupSuite() { + s.Helper = &Helper{} +} + +func (s *AdapterTests) TestOpenWithWrongData() { + var err error + var rightSettings, wrongSettings ConnectionURL + + // Attempt to open with safe settings. + rightSettings = ConnectionURL{ + Database: settings.Database, + Host: settings.Host, + User: settings.User, + Password: settings.Password, + } + + // Attempt to open an empty database. + _, err = Open(rightSettings) + s.NoError(err) + + // Attempt to open with wrong password. + wrongSettings = ConnectionURL{ + Database: settings.Database, + Host: settings.Host, + User: settings.User, + Password: "fail", + } + + _, err = Open(wrongSettings) + s.Error(err) + + // Attempt to open with wrong database. + wrongSettings = ConnectionURL{ + Database: "fail", + Host: settings.Host, + User: settings.User, + Password: settings.Password, + } + + _, err = Open(wrongSettings) + s.Error(err) + + // Attempt to open with wrong username. 
+	wrongSettings = ConnectionURL{
+		Database: settings.Database,
+		Host:     settings.Host,
+		User:     "fail",
+		Password: settings.Password,
+	}
+
+	_, err = Open(wrongSettings)
+	s.Error(err)
+}
+
+func (s *AdapterTests) TestTruncate() {
+	// Opening database.
+	sess, err := Open(settings)
+	s.NoError(err)
+
+	// We should close the database when it's no longer in use.
+	defer sess.Close()
+
+	// Getting a list of all collections in this database.
+	collections, err := sess.Collections()
+	s.NoError(err)
+
+	for _, col := range collections {
+		// The collection may or may not exist.
+		if ok, _ := col.Exists(); ok {
+			// Truncating the collection if it exists.
+			err = col.Truncate()
+			s.NoError(err)
+		}
+	}
+}
+
+func (s *AdapterTests) TestInsert() {
+	// Opening database.
+	sess, err := Open(settings)
+	s.NoError(err)
+
+	// We should close the database when it's no longer in use.
+	defer sess.Close()
+
+	// Getting a pointer to the "artist" collection.
+	artist := sess.Collection("artist")
+	_ = artist.Truncate()
+
+	// Inserting a map.
+	record, err := artist.Insert(map[string]string{
+		"name": "Ozzie",
+	})
+	s.NoError(err)
+
+	id := record.ID()
+	s.NotZero(record.ID())
+
+	_, ok := id.(bson.ObjectId)
+	s.True(ok)
+
+	s.True(id.(bson.ObjectId).Valid())
+
+	// Inserting a struct.
+	record, err = artist.Insert(struct {
+		Name string
+	}{
+		"Flea",
+	})
+	s.NoError(err)
+
+	id = record.ID()
+	s.NotZero(id)
+
+	_, ok = id.(bson.ObjectId)
+	s.True(ok)
+	s.True(id.(bson.ObjectId).Valid())
+
+	// Inserting a struct (using tags to specify the field name).
+	record, err = artist.Insert(struct {
+		ArtistName string `bson:"name"`
+	}{
+		"Slash",
+	})
+	s.NoError(err)
+
+	id = record.ID()
+	s.NotNil(id)
+
+	_, ok = id.(bson.ObjectId)
+
+	s.True(ok)
+	s.True(id.(bson.ObjectId).Valid())
+
+	// Inserting a pointer to a struct
+	record, err = artist.Insert(&struct {
+		ArtistName string `bson:"name"`
+	}{
+		"Metallica",
+	})
+	s.NoError(err)
+
+	id = record.ID()
+	s.NotNil(id)
+
+	_, ok = id.(bson.ObjectId)
+	s.True(ok)
+	s.True(id.(bson.ObjectId).Valid())
+
+	// Inserting a pointer to a map
+	record, err = artist.Insert(&map[string]string{
+		"name": "Freddie",
+	})
+	s.NoError(err)
+
+	id = record.ID()
+	s.NotNil(id)
+	s.NotZero(id)
+
+	_, ok = id.(bson.ObjectId)
+	s.True(ok)
+
+	s.True(id.(bson.ObjectId).Valid())
+
+	// Counting elements, must be exactly 5 elements.
+	total, err := artist.Find().Count()
+	s.NoError(err)
+	s.Equal(uint64(5), total)
+}
+
+func (s *AdapterTests) TestGetNonExistentRow_Issue426() {
+	// Opening database.
+	sess, err := Open(settings)
+	s.NoError(err)
+
+	defer sess.Close()
+
+	artist := sess.Collection("artist")
+
+	var one artistType
+	err = artist.Find(mydb.Cond{"name": "nothing"}).One(&one)
+
+	s.NotZero(err)
+	s.Equal(mydb.ErrNoMoreRows, err)
+
+	var all []artistType
+	err = artist.Find(mydb.Cond{"name": "nothing"}).All(&all)
+
+	s.Zero(err, "All should not return mgo.ErrNotFound")
+	s.Equal(0, len(all))
+}
+
+func (s *AdapterTests) TestResultCount() {
+	var err error
+	var res mydb.Result
+
+	// Opening database.
+	sess, err := Open(settings)
+	s.NoError(err)
+
+	defer sess.Close()
+
+	// We should close the database when it's no longer in use.
+	artist := sess.Collection("artist")
+
+	res = artist.Find()
+
+	// Counting all the matching rows.
+ total, err := res.Count() + s.NoError(err) + s.NotZero(total) +} + +func (s *AdapterTests) TestGroup() { + var stats mydb.Collection + + sess, err := Open(settings) + s.NoError(err) + + type statsT struct { + Numeric int `db:"numeric" bson:"numeric"` + Value int `db:"value" bson:"value"` + } + + defer sess.Close() + + stats = sess.Collection("statsTest") + + // Truncating table. + _ = stats.Truncate() + + // Adding row append. + for i := 0; i < 1000; i++ { + numeric, value := rand.Intn(10), rand.Intn(100) + _, err = stats.Insert(statsT{numeric, value}) + s.NoError(err) + } + + // mydb.statsTest.group({key: {numeric: true}, initial: {sum: 0}, reduce: function(doc, prev) { prev.sum += 1}}); + + // Testing GROUP BY + res := stats.Find().GroupBy(bson.M{ + "key": bson.M{"numeric": true}, + "initial": bson.M{"sum": 0}, + "reduce": `function(doc, prev) { prev.sum += 1}`, + }) + + var results []map[string]interface{} + + err = res.All(&results) + s.Equal(mydb.ErrUnsupported, err) +} + +func (s *AdapterTests) TestResultNonExistentCount() { + sess, err := Open(settings) + s.NoError(err) + + defer sess.Close() + + total, err := sess.Collection("notartist").Find().Count() + s.NoError(err) + s.Zero(total) +} + +func (s *AdapterTests) TestResultFetch() { + + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + artist := sess.Collection("artist") + + // Testing map + res := artist.Find() + + rowM := map[string]interface{}{} + + for res.Next(&rowM) { + s.NotZero(rowM["_id"]) + + _, ok := rowM["_id"].(bson.ObjectId) + s.True(ok) + + s.True(rowM["_id"].(bson.ObjectId).Valid()) + + name, ok := rowM["name"].(string) + s.True(ok) + s.NotZero(name) + } + + err = res.Close() + s.NoError(err) + + // Testing struct + rowS := struct { + ID bson.ObjectId `bson:"_id"` + Name string `bson:"name"` + }{} + + res = artist.Find() + + for res.Next(&rowS) { + s.True(rowS.ID.Valid()) + s.NotZero(rowS.Name) + } + + err = res.Close() + s.NoError(err) + + // Testing tagged struct + rowT := struct { + Value1 bson.ObjectId `bson:"_id"` + Value2 string `bson:"name"` + }{} + + res = artist.Find() + + for res.Next(&rowT) { + s.True(rowT.Value1.Valid()) + s.NotZero(rowT.Value2) + } + + err = res.Close() + s.NoError(err) + + // Testing Result.All() with a slice of maps. + res = artist.Find() + + allRowsM := []map[string]interface{}{} + err = res.All(&allRowsM) + s.NoError(err) + + for _, singleRowM := range allRowsM { + s.NotZero(singleRowM["_id"]) + } + + // Testing Result.All() with a slice of structs. + res = artist.Find() + + allRowsS := []struct { + ID bson.ObjectId `bson:"_id"` + Name string + }{} + err = res.All(&allRowsS) + s.NoError(err) + + for _, singleRowS := range allRowsS { + s.True(singleRowS.ID.Valid()) + } + + // Testing Result.All() with a slice of tagged structs. + res = artist.Find() + + allRowsT := []struct { + Value1 bson.ObjectId `bson:"_id"` + Value2 string `bson:"name"` + }{} + err = res.All(&allRowsT) + s.NoError(err) + + for _, singleRowT := range allRowsT { + s.True(singleRowT.Value1.Valid()) + } +} + +func (s *AdapterTests) TestUpdate() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist := sess.Collection("artist") + + // Value + value := struct { + ID bson.ObjectId `bson:"_id"` + Name string + }{} + + // Getting the first artist. 
+ res := artist.Find(mydb.Cond{"_id": mydb.NotEq(nil)}).Limit(1) + + err = res.One(&value) + s.NoError(err) + + // Updating with a map + rowM := map[string]interface{}{ + "name": strings.ToUpper(value.Name), + } + + err = res.Update(rowM) + s.NoError(err) + + err = res.One(&value) + s.NoError(err) + + s.Equal(value.Name, rowM["name"]) + + // Updating with a struct + rowS := struct { + Name string + }{strings.ToLower(value.Name)} + + err = res.Update(rowS) + s.NoError(err) + + err = res.One(&value) + s.NoError(err) + + s.Equal(value.Name, rowS.Name) + + // Updating with a tagged struct + rowT := struct { + Value1 string `bson:"name"` + }{strings.Replace(value.Name, "z", "Z", -1)} + + err = res.Update(rowT) + s.NoError(err) + + err = res.One(&value) + s.NoError(err) + + s.Equal(value.Name, rowT.Value1) +} + +func (s *AdapterTests) TestOperators() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist := sess.Collection("artist") + + rowS := struct { + ID uint64 + Name string + }{} + + res := artist.Find(mydb.Cond{"_id": mydb.NotIn(0, -1)}) + + err = res.One(&rowS) + s.NoError(err) + + err = res.Close() + s.NoError(err) +} + +func (s *AdapterTests) TestDelete() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. + artist := sess.Collection("artist") + + // Getting the first artist. + res := artist.Find(mydb.Cond{"_id": mydb.NotEq(nil)}).Limit(1) + + var first struct { + ID bson.ObjectId `bson:"_id"` + } + + err = res.One(&first) + s.NoError(err) + + res = artist.Find(mydb.Cond{"_id": mydb.Eq(first.ID)}) + + // Trying to remove the row. + err = res.Delete() + s.NoError(err) +} + +func (s *AdapterTests) TestDataTypes() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "data_types" collection. + dataTypes := sess.Collection("data_types") + + // Inserting our test subject. + record, err := dataTypes.Insert(testValues) + s.NoError(err) + + id := record.ID() + s.NotZero(id) + + // Trying to get the same subject we added. + res := dataTypes.Find(mydb.Cond{"_id": mydb.Eq(id)}) + + exists, err := res.Count() + s.NoError(err) + s.NotZero(exists) + + // Trying to dump the subject into an empty structure of the same type. + var item testValuesStruct + err = res.One(&item) + s.NoError(err) + + // The original value and the test subject must match. + s.Equal(testValues, item) +} + +func (s *AdapterTests) TestPaginator() { + // Opening database. + sess, err := Open(settings) + s.NoError(err) + + // We should close the database when it's no longer in use. + defer sess.Close() + + // Getting a pointer to the "artist" collection. 
+ artist := sess.Collection("artist") + + err = artist.Truncate() + s.NoError(err) + + for i := 0; i < 999; i++ { + _, err = artist.Insert(artistType{ + Name: fmt.Sprintf("artist-%d", i), + }) + s.NoError(err) + } + + q := sess.Collection("artist").Find().Paginate(15) + paginator := q.Paginate(13) + + var zerothPage []artistType + err = paginator.Page(0).All(&zerothPage) + s.NoError(err) + s.Equal(13, len(zerothPage)) + + var secondPage []artistType + err = paginator.Page(2).All(&secondPage) + s.NoError(err) + s.Equal(13, len(secondPage)) + + tp, err := paginator.TotalPages() + s.NoError(err) + s.NotZero(tp) + s.Equal(uint(77), tp) + + ti, err := paginator.TotalEntries() + s.NoError(err) + s.NotZero(ti) + s.Equal(uint64(999), ti) + + var seventySixthPage []artistType + err = paginator.Page(76).All(&seventySixthPage) + s.NoError(err) + s.Equal(11, len(seventySixthPage)) + + var seventySeventhPage []artistType + err = paginator.Page(77).All(&seventySeventhPage) + s.NoError(err) + s.Equal(0, len(seventySeventhPage)) + + var hundredthPage []artistType + err = paginator.Page(100).All(&hundredthPage) + s.NoError(err) + s.Equal(0, len(hundredthPage)) + + for i := uint(0); i < tp; i++ { + current := paginator.Page(i) + + var items []artistType + err := current.All(&items) + s.NoError(err) + if len(items) < 1 { + break + } + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", int64(13*int(i)+j)), items[j].Name) + } + } + + paginator = paginator.Cursor("_id") + { + current := paginator.Page(0) + for i := 0; ; i++ { + var items []artistType + err := current.All(&items) + s.NoError(err) + + if len(items) < 1 { + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", int64(13*int(i)+j)), items[j].Name) + } + current = current.NextPage(items[len(items)-1].ID) + } + } + + { + log.Printf("Page 76") + current := paginator.Page(76) + for i := 76; ; i-- { + var items []artistType + + err := current.All(&items) + s.NoError(err) + + if len(items) < 1 { + s.Equal(0, len(items)) + break + } + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", 13*int(i)+j), items[j].Name) + } + + current = current.PrevPage(items[0].ID) + } + } + + { + resultPaginator := sess.Collection("artist").Find().Paginate(15) + + count, err := resultPaginator.TotalPages() + s.Equal(uint(67), count) + s.NoError(err) + + var items []artistType + err = resultPaginator.Page(5).All(&items) + s.NoError(err) + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", 15*5+j), items[j].Name) + } + + resultPaginator = resultPaginator.Cursor("_id").Page(0) + for i := 0; ; i++ { + var items []artistType + + err = resultPaginator.All(&items) + s.NoError(err) + + if len(items) < 1 { + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", 15*i+j), items[j].Name) + } + resultPaginator = resultPaginator.NextPage(items[len(items)-1].ID) + } + + resultPaginator = resultPaginator.Cursor("_id").Page(66) + for i := 66; ; i-- { + var items []artistType + + err = resultPaginator.All(&items) + s.NoError(err) + + if len(items) < 1 { + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", 15*i+j), items[j].Name) + } + resultPaginator = resultPaginator.PrevPage(items[0].ID) + } + } +} + +func TestAdapter(t *testing.T) { + suite.Run(t, &AdapterTests{}) +} diff --git a/adapter/mongo/result.go b/adapter/mongo/result.go new file mode 100644 index 0000000..f19d9ef --- /dev/null +++ b/adapter/mongo/result.go @@ -0,0 +1,565 @@ +package mongo 
+ +import ( + "errors" + "fmt" + "math" + "strings" + "sync" + "time" + + "encoding/json" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/immutable" + mgo "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +type resultQuery struct { + c *Collection + + fields []string + limit int + offset int + sort []string + conditions interface{} + groupBy []interface{} + + pageSize uint + pageNumber uint + cursorColumn string + cursorValue interface{} + cursorCond mydb.Cond + cursorReverseOrder bool +} + +type result struct { + iter *mgo.Iter + err error + errMu sync.Mutex + + fn func(*resultQuery) error + prev *result +} + +var _ = immutable.Immutable(&result{}) + +func (res *result) frame(fn func(*resultQuery) error) *result { + return &result{prev: res, fn: fn} +} + +func (r *resultQuery) and(terms ...interface{}) error { + if r.conditions == nil { + return r.where(terms...) + } + + r.conditions = map[string]interface{}{ + "$and": []interface{}{ + r.conditions, + r.c.compileQuery(terms...), + }, + } + return nil +} + +func (r *resultQuery) where(terms ...interface{}) error { + r.conditions = r.c.compileQuery(terms...) + return nil +} + +func (res *result) And(terms ...interface{}) mydb.Result { + return res.frame(func(r *resultQuery) error { + return r.and(terms...) + }) +} + +func (res *result) Where(terms ...interface{}) mydb.Result { + return res.frame(func(r *resultQuery) error { + return r.where(terms...) + }) +} + +func (res *result) Paginate(pageSize uint) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.pageSize = pageSize + return nil + }) +} + +func (res *result) Page(pageNumber uint) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.pageNumber = pageNumber + return nil + }) +} + +func (res *result) Cursor(cursorColumn string) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.cursorColumn = cursorColumn + return nil + }) +} + +func (res *result) NextPage(cursorValue interface{}) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.cursorValue = cursorValue + r.cursorReverseOrder = false + r.cursorCond = mydb.Cond{ + r.cursorColumn: bson.M{"$gt": cursorValue}, + } + return nil + }) +} + +func (res *result) PrevPage(cursorValue interface{}) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.cursorValue = cursorValue + r.cursorReverseOrder = true + r.cursorCond = mydb.Cond{ + r.cursorColumn: bson.M{"$lt": cursorValue}, + } + return nil + }) +} + +func (res *result) TotalEntries() (uint64, error) { + return res.Count() +} + +func (res *result) TotalPages() (uint, error) { + count, err := res.Count() + if err != nil { + return 0, err + } + + rq, err := res.build() + if err != nil { + return 0, err + } + + if rq.pageSize < 1 { + return 1, nil + } + + total := uint(math.Ceil(float64(count) / float64(rq.pageSize))) + return total, nil +} + +// Limit determines the maximum limit of results to be returned. +func (res *result) Limit(n int) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.limit = n + return nil + }) +} + +// Offset determines how many documents will be skipped before starting to grab +// results. +func (res *result) Offset(n int) mydb.Result { + return res.frame(func(r *resultQuery) error { + r.offset = n + return nil + }) +} + +// OrderBy determines sorting of results according to the provided names. Fields +// may be prefixed by - (minus) which means descending order, ascending order +// would be used otherwise. 
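The paginator frames above (Paginate, Page, Cursor, NextPage, PrevPage) compose as in this sketch, placed here before the remaining result methods; the server address is hypothetical and the `artist` collection mirrors the one this patch's tests seed:

```go
package main

import (
	"fmt"
	"log"

	"git.hexq.cn/tiglog/mydb/adapter/mongo"
	"gopkg.in/mgo.v2/bson"
)

type artist struct {
	ID   bson.ObjectId `bson:"_id,omitempty"`
	Name string        `bson:"name"`
}

func main() {
	sess, err := mongo.Open(mongo.ConnectionURL{
		Host:     "127.0.0.1:27017", // hypothetical server
		Database: "upperio",
	})
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()

	// Number-based pages: 13 items per page, zero-indexed page numbers.
	res := sess.Collection("artist").Find().Paginate(13)

	var page []artist
	if err := res.Page(2).All(&page); err != nil {
		log.Fatal(err)
	}

	// Cursor-based pages: sort by _id and resume after the last item seen,
	// which is what NextPage's {"$gt": cursorValue} condition implements.
	cur := res.Cursor("_id")
	for {
		var items []artist
		if err := cur.All(&items); err != nil {
			log.Fatal(err)
		}
		if len(items) == 0 {
			break
		}
		fmt.Println(len(items), "items")
		cur = cur.NextPage(items[len(items)-1].ID)
	}
}
```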
+func (res *result) OrderBy(fields ...interface{}) mydb.Result {
+	return res.frame(func(r *resultQuery) error {
+		ss := make([]string, len(fields))
+		for i, field := range fields {
+			ss[i] = fmt.Sprintf(`%v`, field)
+		}
+		r.sort = ss
+		return nil
+	})
+}
+
+// String satisfies fmt.Stringer
+func (res *result) String() string {
+	return ""
+}
+
+// Select marks the specific fields the user wants to retrieve.
+func (res *result) Select(fields ...interface{}) mydb.Result {
+	return res.frame(func(r *resultQuery) error {
+		fieldslen := len(fields)
+		r.fields = make([]string, 0, fieldslen)
+		for i := 0; i < fieldslen; i++ {
+			r.fields = append(r.fields, fmt.Sprintf(`%v`, fields[i]))
+		}
+		return nil
+	})
+}
+
+// All dumps all results into a pointer to a slice of structs or maps.
+func (res *result) All(dst interface{}) error {
+	rq, err := res.build()
+	if err != nil {
+		return err
+	}
+
+	q, err := rq.query()
+	if err != nil {
+		return err
+	}
+
+	defer func(start time.Time) {
+		queryLog(&mydb.QueryStatus{
+			RawQuery: rq.debugQuery("Find.All"),
+			Err:      err,
+			Start:    start,
+			End:      time.Now(),
+		})
+	}(time.Now())
+
+	err = q.All(dst)
+	if errors.Is(err, mgo.ErrNotFound) {
+		return mydb.ErrNoMoreRows
+	}
+	return err
+}
+
+// GroupBy is used to group results that have the same value in the same column
+// or columns.
+func (res *result) GroupBy(fields ...interface{}) mydb.Result {
+	return res.frame(func(r *resultQuery) error {
+		r.groupBy = fields
+		return nil
+	})
+}
+
+// One fetches only one result from the resultset.
+func (res *result) One(dst interface{}) error {
+	rq, err := res.build()
+	if err != nil {
+		return err
+	}
+
+	q, err := rq.query()
+	if err != nil {
+		return err
+	}
+
+	defer func(start time.Time) {
+		queryLog(&mydb.QueryStatus{
+			RawQuery: rq.debugQuery("Find.One"),
+			Err:      err,
+			Start:    start,
+			End:      time.Now(),
+		})
+	}(time.Now())
+
+	err = q.One(dst)
+	if errors.Is(err, mgo.ErrNotFound) {
+		return mydb.ErrNoMoreRows
+	}
+	return err
+}
+
+func (res *result) Err() error {
+	res.errMu.Lock()
+	defer res.errMu.Unlock()
+
+	return res.err
+}
+
+func (res *result) setErr(err error) {
+	res.errMu.Lock()
+	defer res.errMu.Unlock()
+	res.err = err
+}
+
+func (res *result) Next(dst interface{}) bool {
+	if res.iter == nil {
+		rq, err := res.build()
+		if err != nil {
+			return false
+		}
+
+		q, err := rq.query()
+		if err != nil {
+			return false
+		}
+
+		defer func(start time.Time) {
+			queryLog(&mydb.QueryStatus{
+				RawQuery: rq.debugQuery("Find.Next"),
+				Err:      err,
+				Start:    start,
+				End:      time.Now(),
+			})
+		}(time.Now())
+
+		res.iter = q.Iter()
+	}
+
+	if !res.iter.Next(dst) {
+		res.setErr(res.iter.Err())
+		return false
+	}
+
+	return true
+}
+
+// Delete removes the matching items from the collection.
+func (res *result) Delete() error {
+	rq, err := res.build()
+	if err != nil {
+		return err
+	}
+
+	defer func(start time.Time) {
+		queryLog(&mydb.QueryStatus{
+			RawQuery: rq.debugQuery("Remove"),
+			Err:      err,
+			Start:    start,
+			End:      time.Now(),
+		})
+	}(time.Now())
+
+	_, err = rq.c.collection.RemoveAll(rq.conditions)
+	if err != nil {
+		return err
+	}
+
+	return nil
+}
+
+// Close closes the result set.
+func (r *result) Close() error {
+	var err error
+	if r.iter != nil {
+		err = r.iter.Close()
+		r.iter = nil
+	}
+	return err
+}
+
+// Update modifies matching items from the collection with values from the
+// given map or struct.
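+//
+// Update wraps src in a MongoDB {"$set": src} modifier, so only the given
+// fields are changed. A minimal sketch (collection and fields are
+// illustrative):
+//
+//	err := col.Find(mydb.Cond{"name": "old"}).Update(map[string]interface{}{
+//		"name": "new",
+//	})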
+func (res *result) Update(src interface{}) (err error) { + updateSet := map[string]interface{}{"$set": src} + + rq, err := res.build() + if err != nil { + return err + } + + defer func(start time.Time) { + queryLog(&mydb.QueryStatus{ + RawQuery: rq.debugQuery("Update"), + Err: err, + Start: start, + End: time.Now(), + }) + }(time.Now()) + + _, err = rq.c.collection.UpdateAll(rq.conditions, updateSet) + if err != nil { + return err + } + return nil +} + +func (res *result) build() (*resultQuery, error) { + rqi, err := immutable.FastForward(res) + if err != nil { + return nil, err + } + + rq := rqi.(*resultQuery) + if !rq.cursorCond.Empty() { + if err := rq.and(rq.cursorCond); err != nil { + return nil, err + } + } + + if rq.cursorColumn != "" { + if rq.cursorReverseOrder { + rq.sort = append(rq.sort, "-"+rq.cursorColumn) + } else { + rq.sort = append(rq.sort, rq.cursorColumn) + } + } + return rq, nil +} + +// query executes a mgo query. +func (r *resultQuery) query() (*mgo.Query, error) { + if len(r.groupBy) > 0 { + return nil, mydb.ErrUnsupported + } + + q := r.c.collection.Find(r.conditions) + + if r.pageSize > 0 { + r.offset = int(r.pageSize * r.pageNumber) + r.limit = int(r.pageSize) + } + + if r.offset > 0 { + q.Skip(r.offset) + } + + if r.limit > 0 { + q.Limit(r.limit) + } + + if len(r.sort) > 0 { + q.Sort(r.sort...) + } + + selectedFields := bson.M{} + if len(r.fields) > 0 { + for _, field := range r.fields { + if field == `*` { + break + } + selectedFields[field] = true + } + } + + if r.cursorReverseOrder { + ids := make([]bson.ObjectId, 0, r.limit) + + iter := q.Select(bson.M{"_id": true}).Iter() + defer iter.Close() + + var item map[string]bson.ObjectId + for iter.Next(&item) { + ids = append(ids, item["_id"]) + } + + r.conditions = bson.M{"_id": bson.M{"$in": ids}} + + q = r.c.collection.Find(r.conditions) + } + + if len(selectedFields) > 0 { + q.Select(selectedFields) + } + + return q, nil +} + +func (res *result) Exists() (bool, error) { + total, err := res.Count() + if err != nil { + return false, err + } + if total > 0 { + return true, nil + } + return false, nil +} + +// Count counts matching elements. 
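+//
+// A minimal sketch (collection and condition are illustrative):
+//
+//	n, err := col.Find(mydb.Cond{"name": "x"}).Count() // number of matches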
+func (res *result) Count() (total uint64, err error) { + rq, err := res.build() + if err != nil { + return 0, err + } + + defer func(start time.Time) { + queryLog(&mydb.QueryStatus{ + RawQuery: rq.debugQuery("Find.Count"), + Err: err, + Start: start, + End: time.Now(), + }) + }(time.Now()) + + q := rq.c.collection.Find(rq.conditions) + + var c int + c, err = q.Count() + + return uint64(c), err +} + +func (res *result) Prev() immutable.Immutable { + if res == nil { + return nil + } + return res.prev +} + +func (res *result) Fn(in interface{}) error { + if res.fn == nil { + return nil + } + return res.fn(in.(*resultQuery)) +} + +func (res *result) Base() interface{} { + return &resultQuery{} +} + +func (r *resultQuery) debugQuery(action string) string { + query := fmt.Sprintf("mydb.%s.%s", r.c.collection.Name, action) + + if r.conditions != nil { + query = fmt.Sprintf("%s.conds(%v)", query, r.conditions) + } + if r.limit > 0 { + query = fmt.Sprintf("%s.limit(%d)", query, r.limit) + } + if r.offset > 0 { + query = fmt.Sprintf("%s.offset(%d)", query, r.offset) + } + if len(r.fields) > 0 { + selectedFields := bson.M{} + for _, field := range r.fields { + if field == `*` { + break + } + selectedFields[field] = true + } + if len(selectedFields) > 0 { + query = fmt.Sprintf("%s.select(%v)", query, selectedFields) + } + } + if len(r.groupBy) > 0 { + escaped := make([]string, len(r.groupBy)) + for i := range r.groupBy { + escaped[i] = string(mustJSON(r.groupBy[i])) + } + query = fmt.Sprintf("%s.groupBy(%v)", query, strings.Join(escaped, ", ")) + } + if len(r.sort) > 0 { + escaped := make([]string, len(r.sort)) + for i := range r.sort { + escaped[i] = string(mustJSON(r.sort[i])) + } + query = fmt.Sprintf("%s.sort(%s)", query, strings.Join(escaped, ", ")) + } + return query +} + +func mustJSON(in interface{}) (out []byte) { + out, err := json.Marshal(in) + if err != nil { + panic(err) + } + return out +} + +func queryLog(status *mydb.QueryStatus) { + diff := status.End.Sub(status.Start) + + slowQuery := false + if diff >= time.Millisecond*100 { + status.Err = mydb.ErrWarnSlowQuery + slowQuery = true + } + + if status.Err != nil || slowQuery { + mydb.LC().Warn(status) + return + } + + mydb.LC().Debug(status) +} diff --git a/adapter/mysql/Makefile b/adapter/mysql/Makefile new file mode 100644 index 0000000..f3a8759 --- /dev/null +++ b/adapter/mysql/Makefile @@ -0,0 +1,43 @@ +SHELL ?= bash + +MYSQL_VERSION ?= 8.1 +MYSQL_SUPPORTED ?= $(MYSQL_VERSION) 5.7 +PROJECT ?= upper_mysql_$(MYSQL_VERSION) + +DB_HOST ?= 127.0.0.1 +DB_PORT ?= 3306 + +DB_NAME ?= upperio +DB_USERNAME ?= upperio_user +DB_PASSWORD ?= upperio//s3cr37 + +TEST_FLAGS ?= +PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1 + +export MYSQL_VERSION + +export DB_HOST +export DB_NAME +export DB_PASSWORD +export DB_PORT +export DB_USERNAME + +export TEST_FLAGS + +test: + go test -v -failfast -race -timeout 20m $(TEST_FLAGS) + +test-no-race: + go test -v -failfast $(TEST_FLAGS) + +server-up: server-down + docker-compose -p $(PROJECT) up -d && \ + sleep 15 + +server-down: + docker-compose -p $(PROJECT) down + +test-extended: + parallel $(PARALLEL_FLAGS) \ + "MYSQL_VERSION={} DB_PORT=\$$((3306+{#})) $(MAKE) server-up test server-down" ::: \ + $(MYSQL_SUPPORTED) diff --git a/adapter/mysql/README.md b/adapter/mysql/README.md new file mode 100644 index 0000000..f427fee --- /dev/null +++ b/adapter/mysql/README.md @@ -0,0 +1,5 @@ +# MySQL adapter for upper/db + +Please read the full docs, acknowledgements and examples at 
+[https://upper.io/v4/adapter/mysql/](https://upper.io/v4/adapter/mysql/). + diff --git a/adapter/mysql/collection.go b/adapter/mysql/collection.go new file mode 100644 index 0000000..57a5275 --- /dev/null +++ b/adapter/mysql/collection.go @@ -0,0 +1,56 @@ +package mysql + +import ( + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +type collectionAdapter struct { +} + +func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) { + columnNames, columnValues, err := sqlbuilder.Map(item, nil) + if err != nil { + return nil, err + } + + pKey, err := col.PrimaryKeys() + if err != nil { + return nil, err + } + + q := col.SQL().InsertInto(col.Name()). + Columns(columnNames...). + Values(columnValues...) + + res, err := q.Exec() + if err != nil { + return nil, err + } + + lastID, err := res.LastInsertId() + if err == nil && len(pKey) <= 1 { + return lastID, nil + } + + keyMap := mydb.Cond{} + for i := range columnNames { + for j := 0; j < len(pKey); j++ { + if pKey[j] == columnNames[i] { + keyMap[pKey[j]] = columnValues[i] + } + } + } + + // There was an auto column among primary keys, let's search for it. + if lastID > 0 { + for j := 0; j < len(pKey); j++ { + if keyMap[pKey[j]] == nil { + keyMap[pKey[j]] = lastID + } + } + } + + return keyMap, nil +} diff --git a/adapter/mysql/connection.go b/adapter/mysql/connection.go new file mode 100644 index 0000000..b9e6237 --- /dev/null +++ b/adapter/mysql/connection.go @@ -0,0 +1,244 @@ +package mysql + +import ( + "errors" + "fmt" + "net" + "net/url" + "strings" +) + +// From https://github.com/go-sql-driver/mysql/blob/master/utils.go +var ( + errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?") + errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)") + errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name") +) + +// From https://github.com/go-sql-driver/mysql/blob/master/utils.go +type config struct { + user string + passwd string + net string + addr string + dbname string + params map[string]string +} + +// ConnectionURL implements a MySQL connection struct. +type ConnectionURL struct { + User string + Password string + Database string + Host string + Socket string + Options map[string]string +} + +func (c ConnectionURL) String() (s string) { + + if c.Database == "" { + return "" + } + + // Adding username. + if c.User != "" { + s = s + c.User + // Adding password. + if c.Password != "" { + s = s + ":" + c.Password + } + s = s + "@" + } + + // Adding protocol and address + if c.Socket != "" { + s = s + fmt.Sprintf("unix(%s)", c.Socket) + } else if c.Host != "" { + host, port, err := net.SplitHostPort(c.Host) + if err != nil { + host = c.Host + port = "3306" + } + s = s + fmt.Sprintf("tcp(%s:%s)", host, port) + } + + // Adding database + s = s + "/" + c.Database + + // Do we have any options? + if c.Options == nil { + c.Options = map[string]string{} + } + + // Default options. + if _, ok := c.Options["charset"]; !ok { + c.Options["charset"] = "utf8" + } + + if _, ok := c.Options["parseTime"]; !ok { + c.Options["parseTime"] = "true" + } + + // Converting options into URL values. + vv := url.Values{} + + for k, v := range c.Options { + vv.Set(k, v) + } + + // Inserting options. + if p := vv.Encode(); p != "" { + s = s + "?" + p + } + + return s +} + +// ParseURL parses s into a ConnectionURL struct. 
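+//
+// Illustrative sketch (credentials and address are hypothetical):
+//
+//	conn, err := ParseURL(`user:pass@tcp(127.0.0.1:3306)/mydbname?charset=utf8`)
+//	// conn.User == "user", conn.Host == "127.0.0.1:3306",
+//	// conn.Database == "mydbname", conn.Options["charset"] == "utf8"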
+func ParseURL(s string) (conn ConnectionURL, err error) {
+	var cfg *config
+
+	if cfg, err = parseDSN(s); err != nil {
+		return
+	}
+
+	conn.User = cfg.user
+	conn.Password = cfg.passwd
+
+	if cfg.net == "unix" {
+		conn.Socket = cfg.addr
+	} else if cfg.net == "tcp" {
+		conn.Host = cfg.addr
+	}
+
+	conn.Database = cfg.dbname
+
+	conn.Options = map[string]string{}
+
+	for k, v := range cfg.params {
+		conn.Options[k] = v
+	}
+
+	return
+}
+
+// from https://github.com/go-sql-driver/mysql/blob/master/utils.go
+// parseDSN parses the DSN string to a config
+func parseDSN(dsn string) (cfg *config, err error) {
+	// New config with some default values
+	cfg = &config{}
+
+	// TODO: use strings.IndexByte when we can depend on Go 1.2
+
+	// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
+	// Find the last '/' (since the password or the net addr might contain a '/')
+	foundSlash := false
+	for i := len(dsn) - 1; i >= 0; i-- {
+		if dsn[i] == '/' {
+			foundSlash = true
+			var j, k int
+
+			// left part is empty if i <= 0
+			if i > 0 {
+				// [username[:password]@][protocol[(address)]]
+				// Find the last '@' in dsn[:i]
+				for j = i; j >= 0; j-- {
+					if dsn[j] == '@' {
+						// username[:password]
+						// Find the first ':' in dsn[:j]
+						for k = 0; k < j; k++ {
+							if dsn[k] == ':' {
+								cfg.passwd = dsn[k+1 : j]
+								break
+							}
+						}
+						cfg.user = dsn[:k]
+
+						break
+					}
+				}
+
+				// [protocol[(address)]]
+				// Find the first '(' in dsn[j+1:i]
+				for k = j + 1; k < i; k++ {
+					if dsn[k] == '(' {
+						// dsn[i-1] must be == ')' if an address is specified
+						if dsn[i-1] != ')' {
+							if strings.ContainsRune(dsn[k+1:i], ')') {
+								return nil, errInvalidDSNUnescaped
+							}
+							return nil, errInvalidDSNAddr
+						}
+						cfg.addr = dsn[k+1 : i-1]
+						break
+					}
+				}
+				cfg.net = dsn[j+1 : k]
+			}
+
+			// dbname[?param1=value1&...&paramN=valueN]
+			// Find the first '?' in dsn[i+1:]
+			for j = i + 1; j < len(dsn); j++ {
+				if dsn[j] == '?' {
+					if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
+						return
+					}
+					break
+				}
+			}
+			cfg.dbname = dsn[i+1 : j]
+
+			break
+		}
+	}
+
+	if !foundSlash && len(dsn) > 0 {
+		return nil, errInvalidDSNNoSlash
+	}
+
+	// Set default network if empty
+	if cfg.net == "" {
+		cfg.net = "tcp"
+	}
+
+	// Set default address if empty
+	if cfg.addr == "" {
+		switch cfg.net {
+		case "tcp":
+			cfg.addr = "127.0.0.1:3306"
+		case "unix":
+			cfg.addr = "/tmp/mysql.sock"
+		default:
+			return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
+		}
+
+	}
+
+	return
+}
+
+// From https://github.com/go-sql-driver/mysql/blob/master/utils.go
+// parseDSNParams parses the DSN "query string"
+// Values must be url.QueryEscape'ed
+func parseDSNParams(cfg *config, params string) (err error) {
+	for _, v := range strings.Split(params, "&") {
+		param := strings.SplitN(v, "=", 2)
+		if len(param) != 2 {
+			continue
+		}
+
+		value := param[1]
+
+		// lazy init
+		if cfg.params == nil {
+			cfg.params = make(map[string]string)
+		}
+
+		if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
+			return
+		}
+	}
+
+	return
+}
diff --git a/adapter/mysql/connection_test.go b/adapter/mysql/connection_test.go
new file mode 100644
index 0000000..172856f
--- /dev/null
+++ b/adapter/mysql/connection_test.go
@@ -0,0 +1,117 @@
+package mysql
+
+import (
+	"testing"
+)
+
+func TestConnectionURL(t *testing.T) {
+
+	c := ConnectionURL{}
+
+	// Zero value equals an empty string.
+	if c.String() != "" {
+		t.Fatal(`Expecting default connection string to be empty, got:`, c.String())
+	}
+
+	// Adding a database name.
+	c.Database = "mydbname"
+
+	if c.String() != "/mydbname?charset=utf8&parseTime=true" {
+		t.Fatal(`Test failed, got:`, c.String())
+	}
+
+	// Adding an option.
+	c.Options = map[string]string{
+		"charset": "utf8mb4,utf8",
+		"sys_var": "esc@ped",
+	}
+
+	if c.String() != "/mydbname?charset=utf8mb4%2Cutf8&parseTime=true&sys_var=esc%40ped" {
+		t.Fatal(`Test failed, got:`, c.String())
+	}
+
+	// Setting default options
+	c.Options = nil
+
+	// Setting user and password.
+	c.User = "user"
+	c.Password = "pass"
+
+	if c.String() != `user:pass@/mydbname?charset=utf8&parseTime=true` {
+		t.Fatal(`Test failed, got:`, c.String())
+	}
+
+	// Setting host.
+	c.Host = "1.2.3.4:3306"
+
+	if c.String() != `user:pass@tcp(1.2.3.4:3306)/mydbname?charset=utf8&parseTime=true` {
+		t.Fatal(`Test failed, got:`, c.String())
+	}
+
+	// Setting socket.
+	c.Socket = "/path/to/socket"
+
+	if c.String() != `user:pass@unix(/path/to/socket)/mydbname?charset=utf8&parseTime=true` {
+		t.Fatal(`Test failed, got:`, c.String())
+	}
+
+}
+
+func TestParseConnectionURL(t *testing.T) {
+	var u ConnectionURL
+	var s string
+	var err error
+
+	s = "user:pass@unix(/path/to/socket)/mydbname?charset=utf8"
+
+	if u, err = ParseURL(s); err != nil {
+		t.Fatal(err)
+	}
+
+	if u.User != "user" {
+		t.Fatal("Expecting username.")
+	}
+
+	if u.Password != "pass" {
+		t.Fatal("Expecting password.")
+	}
+
+	if u.Socket != "/path/to/socket" {
+		t.Fatal("Expecting socket.")
+	}
+
+	if u.Database != "mydbname" {
+		t.Fatal("Expecting database.")
+	}
+
+	if u.Options["charset"] != "utf8" {
+		t.Fatal("Expecting charset.")
+	}
+
+	s = "user:pass@tcp(1.2.3.4:5678)/mydbname?charset=utf8"
+
+	if u, err = ParseURL(s); err != nil {
+		t.Fatal(err)
+	}
+
+	if u.User != "user" {
+		t.Fatal("Expecting username.")
+	}
+
+	if u.Password != "pass" {
+		t.Fatal("Expecting password.")
+	}
+
+	if u.Host != "1.2.3.4:5678" {
+		t.Fatal("Expecting host.")
+	}
+
+	if u.Database != "mydbname" {
+		t.Fatal("Expecting database.")
+	}
+
+	if u.Options["charset"] != "utf8" {
+		t.Fatal("Expecting charset.")
+	}
+
+}
diff --git a/adapter/mysql/custom_types.go b/adapter/mysql/custom_types.go
new file mode 100644
index 0000000..034f354
--- /dev/null
+++ b/adapter/mysql/custom_types.go
@@ -0,0 +1,151 @@
+package mysql
+
+import (
+	"database/sql"
+	"database/sql/driver"
+	"encoding/json"
+	"errors"
+	"reflect"
+
+	"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
+)
+
+// JSON represents a MySQL JSON value:
+// https://dev.mysql.com/doc/refman/8.0/en/json.html. JSON
+// satisfies sqlbuilder.ScannerValuer.
+type JSON struct {
+	V interface{}
+}
+
+// MarshalJSON encodes the wrapper value as JSON.
+func (j JSON) MarshalJSON() ([]byte, error) {
+	return json.Marshal(j.V)
+}
+
+// UnmarshalJSON decodes the given JSON into the wrapped value.
+func (j *JSON) UnmarshalJSON(b []byte) error {
+	var v interface{}
+	if err := json.Unmarshal(b, &v); err != nil {
+		return err
+	}
+	j.V = v
+	return nil
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (j *JSON) Scan(src interface{}) error {
+	if j.V == nil {
+		return nil
+	}
+	if src == nil {
+		dv := reflect.Indirect(reflect.ValueOf(j.V))
+		dv.Set(reflect.Zero(dv.Type()))
+		return nil
+	}
+	b, ok := src.([]byte)
+	if !ok {
+		return errors.New("Scan source was not []byte")
+	}
+
+	if err := json.Unmarshal(b, j.V); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Value satisfies the driver.Valuer interface.
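+//
+// A minimal round-trip sketch using the helpers below (values are
+// illustrative):
+//
+//	v, _ := JSON{V: map[string]interface{}{"a": 1}}.Value() // `{"a":1}`
+//	var out map[string]interface{}
+//	_ = ScanJSON(&out, []byte(v.(string))) // out["a"] == float64(1)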
+func (j JSON) Value() (driver.Value, error) { + if j.V == nil { + return nil, nil + } + if v, ok := j.V.(json.RawMessage); ok { + return string(v), nil + } + b, err := json.Marshal(j.V) + if err != nil { + return nil, err + } + return string(b), nil +} + +// JSONMap represents a map of interfaces with string keys +// (`map[string]interface{}`) that is compatible with MySQL's JSON type. +// JSONMap satisfies sqlbuilder.ScannerValuer. +type JSONMap map[string]interface{} + +// Value satisfies the driver.Valuer interface. +func (m JSONMap) Value() (driver.Value, error) { + return JSONValue(m) +} + +// Scan satisfies the sql.Scanner interface. +func (m *JSONMap) Scan(src interface{}) error { + *m = map[string]interface{}(nil) + return ScanJSON(m, src) +} + +// JSONArray represents an array of any type (`[]interface{}`) that is +// compatible with MySQL's JSON type. JSONArray satisfies +// sqlbuilder.ScannerValuer. +type JSONArray []interface{} + +// Value satisfies the driver.Valuer interface. +func (a JSONArray) Value() (driver.Value, error) { + return JSONValue(a) +} + +// Scan satisfies the sql.Scanner interface. +func (a *JSONArray) Scan(src interface{}) error { + return ScanJSON(a, src) +} + +// JSONValue takes an interface and provides a driver.Value that can be +// stored as a JSON column. +func JSONValue(i interface{}) (driver.Value, error) { + v := JSON{i} + return v.Value() +} + +// ScanJSON decodes a JSON byte stream into the passed dst value. +func ScanJSON(dst interface{}, src interface{}) error { + v := JSON{dst} + return v.Scan(src) +} + +// EncodeJSON is deprecated and going to be removed. Use ScanJSON instead. +func EncodeJSON(i interface{}) (driver.Value, error) { + return JSONValue(i) +} + +// DecodeJSON is deprecated and going to be removed. Use JSONValue instead. +func DecodeJSON(dst interface{}, src interface{}) error { + return ScanJSON(dst, src) +} + +// JSONConverter provides a helper method WrapValue that satisfies +// sqlbuilder.ValueWrapper, can be used to encode Go structs into JSON +// MySQL types and vice versa. +// +// Example: +// +// type MyCustomStruct struct { +// ID int64 `db:"id" json:"id"` +// Name string `db:"name" json:"name"` +// ... +// mysql.JSONConverter +// } +type JSONConverter struct{} + +func (*JSONConverter) ConvertValue(in interface{}) interface { + sql.Scanner + driver.Valuer +} { + return &JSON{in} +} + +// Type checks. +var ( + _ sqlbuilder.ScannerValuer = &JSONMap{} + _ sqlbuilder.ScannerValuer = &JSONArray{} + _ sqlbuilder.ScannerValuer = &JSON{} +) diff --git a/adapter/mysql/database.go b/adapter/mysql/database.go new file mode 100644 index 0000000..e601eb7 --- /dev/null +++ b/adapter/mysql/database.go @@ -0,0 +1,168 @@ +// Package mysql wraps the github.com/go-sql-driver/mysql MySQL driver. See +// https://github.com/upper/db/adapter/mysql for documentation, particularities and usage +// examples. +package mysql + +import ( + "reflect" + "strings" + + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" + _ "github.com/go-sql-driver/mysql" // MySQL driver. +) + +// database is the actual implementation of Database +type database struct { +} + +func (*database) Template() *exql.Template { + return template +} + +func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) { + return sql.Open("mysql", dsn) +} + +func (*database) Collections(sess sqladapter.Session) (collections []string, err error) { + q := sess.SQL(). 
+ Select("table_name"). + From("information_schema.tables"). + Where("table_schema = ?", sess.Name()) + + iter := q.Iterator() + defer iter.Close() + + for iter.Next() { + var tableName string + if err := iter.Scan(&tableName); err != nil { + return nil, err + } + collections = append(collections, tableName) + } + if err := iter.Err(); err != nil { + return nil, err + } + + return collections, nil +} + +func (d *database) ConvertValue(in interface{}) interface{} { + switch v := in.(type) { + case *map[string]interface{}: + return (*JSONMap)(v) + + case map[string]interface{}: + return (*JSONMap)(&v) + } + + dv := reflect.ValueOf(in) + if dv.IsValid() { + if dv.Type().Kind() == reflect.Ptr { + dv = dv.Elem() + } + + switch dv.Kind() { + case reflect.Map: + if reflect.TypeOf(in).Kind() == reflect.Ptr { + w := reflect.ValueOf(in) + z := reflect.New(w.Elem().Type()) + w.Elem().Set(z.Elem()) + } + return &JSON{in} + case reflect.Slice: + return &JSON{in} + } + } + + return in +} + +func (*database) Err(err error) error { + if err != nil { + // This error is not exported so we have to check it by its string value. + s := err.Error() + if strings.Contains(s, `many connections`) { + return mydb.ErrTooManyClients + } + } + return err +} + +func (*database) NewCollection() sqladapter.CollectionAdapter { + return &collectionAdapter{} +} + +func (*database) LookupName(sess sqladapter.Session) (string, error) { + q := sess.SQL(). + Select(mydb.Raw("DATABASE() AS name")) + + iter := q.Iterator() + defer iter.Close() + + if iter.Next() { + var name string + if err := iter.Scan(&name); err != nil { + return "", err + } + return name, nil + } + + return "", iter.Err() +} + +func (*database) TableExists(sess sqladapter.Session, name string) error { + q := sess.SQL(). + Select("table_name"). + From("information_schema.tables"). + Where("table_schema = ? AND table_name = ?", sess.Name(), name) + + iter := q.Iterator() + defer iter.Close() + + if iter.Next() { + var name string + if err := iter.Scan(&name); err != nil { + return err + } + return nil + } + if err := iter.Err(); err != nil { + return err + } + + return mydb.ErrCollectionDoesNotExist +} + +func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) { + q := sess.SQL(). + Select("k.column_name"). + From("information_schema.key_column_usage AS k"). + Where(` + k.constraint_name = 'PRIMARY' + AND k.table_schema = ? + AND k.table_name = ? + `, sess.Name(), tableName). 
+ OrderBy("k.ordinal_position") + + iter := q.Iterator() + defer iter.Close() + + pk := []string{} + + for iter.Next() { + var k string + if err := iter.Scan(&k); err != nil { + return nil, err + } + pk = append(pk, k) + } + if err := iter.Err(); err != nil { + return nil, err + } + + return pk, nil +} diff --git a/adapter/mysql/docker-compose.yml b/adapter/mysql/docker-compose.yml new file mode 100644 index 0000000..18ab349 --- /dev/null +++ b/adapter/mysql/docker-compose.yml @@ -0,0 +1,14 @@ +version: '3' + +services: + + server: + image: mysql:${MYSQL_VERSION:-5} + environment: + MYSQL_USER: ${DB_USERNAME:-upperio_user} + MYSQL_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37} + MYSQL_ALLOW_EMPTY_PASSWORD: 1 + MYSQL_DATABASE: ${DB_NAME:-upperio} + ports: + - '${DB_HOST:-127.0.0.1}:${DB_PORT:-3306}:3306' + diff --git a/adapter/mysql/generic_test.go b/adapter/mysql/generic_test.go new file mode 100644 index 0000000..db7e430 --- /dev/null +++ b/adapter/mysql/generic_test.go @@ -0,0 +1,20 @@ +package mysql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type GenericTests struct { + testsuite.GenericTestSuite +} + +func (s *GenericTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestGeneric(t *testing.T) { + suite.Run(t, &GenericTests{}) +} diff --git a/adapter/mysql/helper_test.go b/adapter/mysql/helper_test.go new file mode 100644 index 0000000..a14b897 --- /dev/null +++ b/adapter/mysql/helper_test.go @@ -0,0 +1,276 @@ +package mysql + +import ( + "database/sql" + "fmt" + "os" + "time" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/testsuite" +) + +var settings = ConnectionURL{ + Database: os.Getenv("DB_NAME"), + User: os.Getenv("DB_USERNAME"), + Password: os.Getenv("DB_PASSWORD"), + Host: os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"), + Options: map[string]string{ + // See https://github.com/go-sql-driver/mysql/issues/9 + "parseTime": "true", + // Might require you to use mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql -u root -p mysql + "time_zone": fmt.Sprintf(`'%s'`, testsuite.TimeZone), + "loc": testsuite.TimeZone, + }, +} + +type Helper struct { + sess mydb.Session +} + +func cleanUp(sess mydb.Session) error { + if activeStatements := sqladapter.NumActiveStatements(); activeStatements > 128 { + return fmt.Errorf("Expecting active statements to be at most 128, got %d", activeStatements) + } + + sess.Reset() + + if activeStatements := sqladapter.NumActiveStatements(); activeStatements != 0 { + return fmt.Errorf("Expecting active statements to be 0, got %d", activeStatements) + } + + var err error + var stats map[string]int + for i := 0; i < 10; i++ { + stats, err = getStats(sess) + if err != nil { + return err + } + + if stats["Prepared_stmt_count"] != 0 { + time.Sleep(time.Millisecond * 200) // Sometimes it takes a bit to clean prepared statements + err = fmt.Errorf(`Expecting "Prepared_stmt_count" to be 0, got %d`, stats["Prepared_stmt_count"]) + continue + } + break + } + + return err +} + +func getStats(sess mydb.Session) (map[string]int, error) { + stats := make(map[string]int) + + res, err := sess.Driver().(*sql.DB).Query(`SHOW GLOBAL STATUS LIKE '%stmt%'`) + if err != nil { + return nil, err + } + var result struct { + VariableName string `db:"Variable_name"` + Value int `db:"Value"` + } + + iter := sess.SQL().NewIterator(res) + for iter.Next(&result) { + stats[result.VariableName] = result.Value + } + + return stats, nil +} + +func (h 
*Helper) Session() mydb.Session { + return h.sess +} + +func (h *Helper) Adapter() string { + return "mysql" +} + +func (h *Helper) TearDown() error { + if err := cleanUp(h.sess); err != nil { + return err + } + + return h.sess.Close() +} + +func (h *Helper) TearUp() error { + var err error + + h.sess, err = Open(settings) + if err != nil { + return err + } + + batch := []string{ + `DROP TABLE IF EXISTS artist`, + `CREATE TABLE artist ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY(id), + name VARCHAR(60) + )`, + + `DROP TABLE IF EXISTS publication`, + `CREATE TABLE publication ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY(id), + title VARCHAR(80), + author_id BIGINT(20) + )`, + + `DROP TABLE IF EXISTS review`, + `CREATE TABLE review ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY(id), + publication_id BIGINT(20), + name VARCHAR(80), + comments TEXT, + created DATETIME NOT NULL + )`, + + `DROP TABLE IF EXISTS data_types`, + `CREATE TABLE data_types ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, + PRIMARY KEY(id), + _uint INT(10) UNSIGNED DEFAULT 0, + _uint8 INT(10) UNSIGNED DEFAULT 0, + _uint16 INT(10) UNSIGNED DEFAULT 0, + _uint32 INT(10) UNSIGNED DEFAULT 0, + _uint64 INT(10) UNSIGNED DEFAULT 0, + _int INT(10) DEFAULT 0, + _int8 INT(10) DEFAULT 0, + _int16 INT(10) DEFAULT 0, + _int32 INT(10) DEFAULT 0, + _int64 INT(10) DEFAULT 0, + _float32 DECIMAL(10,6), + _float64 DECIMAL(10,6), + _bool TINYINT(1), + _string text, + _blob blob, + _date TIMESTAMP NULL, + _nildate DATETIME NULL, + _ptrdate DATETIME NULL, + _defaultdate TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP, + _time BIGINT UNSIGNED NOT NULL DEFAULT 0 + )`, + + `DROP TABLE IF EXISTS stats_test`, + `CREATE TABLE stats_test ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id), + ` + "`numeric`" + ` INT(10), + ` + "`value`" + ` INT(10) + )`, + + `DROP TABLE IF EXISTS composite_keys`, + `CREATE TABLE composite_keys ( + code VARCHAR(255) default '', + user_id VARCHAR(255) default '', + some_val VARCHAR(255) default '', + primary key (code, user_id) + )`, + + `DROP TABLE IF EXISTS admin`, + `CREATE TABLE admin ( + ID int(11) NOT NULL AUTO_INCREMENT, + Accounts varchar(255) DEFAULT '', + LoginPassWord varchar(255) DEFAULT '', + Date TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6), + PRIMARY KEY (ID,Date) + ) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8`, + + `DROP TABLE IF EXISTS my_types`, + + `CREATE TABLE my_types (id int(11) NOT NULL AUTO_INCREMENT, PRIMARY KEY(id) + , json_map JSON + , json_map_ptr JSON + + , auto_json_map JSON + , auto_json_map_string JSON + , auto_json_map_integer JSON + + , json_object JSON + , json_array JSON + + , custom_json_object JSON + , auto_custom_json_object JSON + + , custom_json_object_ptr JSON + , auto_custom_json_object_ptr JSON + + , custom_json_object_array JSON + , auto_custom_json_object_array JSON + , auto_custom_json_object_map JSON + + , integer_compat_value_json_array JSON + , string_compat_value_json_array JSON + , uinteger_compat_value_json_array JSON + + )`, + + `DROP TABLE IF EXISTS ` + "`" + `birthdays` + "`" + ``, + `CREATE TABLE ` + "`" + `birthdays` + "`" + ` ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id), + name VARCHAR(50), + born DATE, + born_ut BIGINT(20) SIGNED + ) CHARSET=utf8`, + + `DROP TABLE IF EXISTS ` + "`" + `fibonacci` + "`" + ``, + `CREATE TABLE ` + "`" + `fibonacci` + "`" + ` ( + id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id), + input BIGINT(20) 
UNSIGNED NOT NULL,
+			output BIGINT(20) UNSIGNED NOT NULL
+		) CHARSET=utf8`,

+		`DROP TABLE IF EXISTS ` + "`" + `is_even` + "`" + ``,
+		`CREATE TABLE ` + "`" + `is_even` + "`" + ` (
+			input BIGINT(20) UNSIGNED NOT NULL,
+			is_even TINYINT(1)
+		) CHARSET=utf8`,

+		`DROP TABLE IF EXISTS ` + "`" + `CaSe_TesT` + "`" + ``,
+		`CREATE TABLE ` + "`" + `CaSe_TesT` + "`" + ` (
+			id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id),
+			case_test VARCHAR(60)
+		) CHARSET=utf8`,

+		`DROP TABLE IF EXISTS accounts`,

+		`CREATE TABLE accounts (
+			id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+			PRIMARY KEY(id),
+			name varchar(255),
+			disabled BOOL,
+			created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
+		)`,

+		`DROP TABLE IF EXISTS users`,

+		`CREATE TABLE users (
+			id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+			PRIMARY KEY(id),
+			account_id BIGINT(20),
+			username varchar(255) UNIQUE
+		)`,

+		`DROP TABLE IF EXISTS logs`,

+		`CREATE TABLE logs (
+			id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
+			PRIMARY KEY(id),
+			message VARCHAR(255)
+		)`,
+	}
+
+	for _, query := range batch {
+		driver := h.sess.Driver().(*sql.DB)
+		if _, err := driver.Exec(query); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+var _ testsuite.Helper = &Helper{}
diff --git a/adapter/mysql/mysql.go b/adapter/mysql/mysql.go
new file mode 100644
index 0000000..00a3876
--- /dev/null
+++ b/adapter/mysql/mysql.go
@@ -0,0 +1,30 @@
+package mysql
+
+import (
+	"database/sql"
+
+	"git.hexq.cn/tiglog/mydb"
+	"git.hexq.cn/tiglog/mydb/internal/sqladapter"
+	"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
+)
+
+// Adapter is the public name of the adapter.
+const Adapter = `mysql`
+
+var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{})
+
+// Open establishes a connection to the database server and returns a
+// mydb.Session instance.
+func Open(connURL mydb.ConnectionURL) (mydb.Session, error) {
+	return registeredAdapter.OpenDSN(connURL)
+}
+
+// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value.
+func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) {
+	return registeredAdapter.NewTx(sqlTx)
+}
+
+// New creates a mydb.Session instance by wrapping a *sql.DB value.
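+//
+// Sketch (assumes an already-opened *sql.DB handle; the DSN is hypothetical):
+//
+//	sqlDB, _ := sql.Open("mysql", "user:pass@tcp(127.0.0.1:3306)/mydbname")
+//	sess, err := New(sqlDB)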
+func New(sqlDB *sql.DB) (mydb.Session, error) { + return registeredAdapter.New(sqlDB) +} diff --git a/adapter/mysql/mysql_test.go b/adapter/mysql/mysql_test.go new file mode 100644 index 0000000..7e0d4ff --- /dev/null +++ b/adapter/mysql/mysql_test.go @@ -0,0 +1,379 @@ +package mysql + +import ( + "database/sql" + "database/sql/driver" + "fmt" + "math/rand" + "strconv" + "testing" + "time" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type int64Compat int64 + +type uintCompat uint + +type stringCompat string + +type uintCompatArray []uintCompat + +func (u *int64Compat) Scan(src interface{}) error { + if src != nil { + switch v := src.(type) { + case int64: + *u = int64Compat((src).(int64)) + case []byte: + i, err := strconv.ParseInt(string(v), 10, 64) + if err != nil { + return err + } + *u = int64Compat(i) + default: + panic(fmt.Sprintf("expected type %T", src)) + } + } + return nil +} + +type customJSON struct { + N string `json:"name"` + V float64 `json:"value"` +} + +func (c customJSON) Value() (driver.Value, error) { + return JSONValue(c) +} + +func (c *customJSON) Scan(src interface{}) error { + return ScanJSON(c, src) +} + +type autoCustomJSON struct { + N string `json:"name"` + V float64 `json:"value"` + + *JSONConverter +} + +var ( + _ = driver.Valuer(&customJSON{}) + _ = sql.Scanner(&customJSON{}) +) + +type AdapterTests struct { + testsuite.Suite +} + +func (s *AdapterTests) SetupSuite() { + s.Helper = &Helper{} +} + +func (s *AdapterTests) TestInsertReturningCompositeKey_Issue383() { + sess := s.Session() + + type Admin struct { + ID int `db:"ID,omitempty"` + Accounts string `db:"Accounts"` + LoginPassWord string `db:"LoginPassWord"` + Date time.Time `db:"Date"` + } + + dateNow := time.Now() + + a := Admin{ + Accounts: "admin", + LoginPassWord: "E10ADC3949BA59ABBE56E057F20F883E", + Date: dateNow, + } + + adminCollection := sess.Collection("admin") + err := adminCollection.InsertReturning(&a) + s.NoError(err) + + s.NotZero(a.ID) + s.NotZero(a.Date) + s.Equal("admin", a.Accounts) + s.Equal("E10ADC3949BA59ABBE56E057F20F883E", a.LoginPassWord) + + b := Admin{ + Accounts: "admin2", + LoginPassWord: "E10ADC3949BA59ABBE56E057F20F883E", + Date: dateNow, + } + + err = adminCollection.InsertReturning(&b) + s.NoError(err) + + s.NotZero(b.ID) + s.NotZero(b.Date) + s.Equal("admin2", b.Accounts) + s.Equal("E10ADC3949BA59ABBE56E057F20F883E", a.LoginPassWord) +} + +func (s *AdapterTests) TestIssue469_BadConnection() { + var err error + sess := s.Session() + + // Ask the MySQL server to disconnect sessions that remain inactive for more + // than 1 second. + _, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`) + s.NoError(err) + + // Remain inactive for 2 seconds. + time.Sleep(time.Second * 2) + + // A query should start a new connection, even if the server disconnected us. + _, err = sess.Collection("artist").Find().Count() + s.NoError(err) + + // This is a new session, ask the MySQL server to disconnect sessions that + // remain inactive for more than 1 second. + _, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`) + s.NoError(err) + + // Remain inactive for 2 seconds. + time.Sleep(time.Second * 2) + + // At this point the server should have disconnected us. Let's try to create + // a transaction anyway. 
+ err = sess.Tx(func(sess mydb.Session) error { + var err error + + _, err = sess.Collection("artist").Find().Count() + if err != nil { + return err + } + return nil + }) + s.NoError(err) + + // This is a new session, ask the MySQL server to disconnect sessions that + // remain inactive for more than 1 second. + _, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`) + s.NoError(err) + + err = sess.Tx(func(sess mydb.Session) error { + var err error + + // This query should succeed. + _, err = sess.Collection("artist").Find().Count() + if err != nil { + panic(err.Error()) + } + + // Remain inactive for 2 seconds. + time.Sleep(time.Second * 2) + + // This query should fail because the server disconnected us in the middle + // of a transaction. + _, err = sess.Collection("artist").Find().Count() + if err != nil { + return err + } + + return nil + }) + + s.Error(err, "Expecting an error (can't recover from this)") +} + +func (s *AdapterTests) TestMySQLTypes() { + sess := s.Session() + + type MyType struct { + ID int64 `db:"id,omitempty"` + + JSONMap JSONMap `db:"json_map"` + + JSONObject JSONMap `db:"json_object"` + JSONArray JSONArray `db:"json_array"` + + CustomJSONObject customJSON `db:"custom_json_object"` + AutoCustomJSONObject autoCustomJSON `db:"auto_custom_json_object"` + + CustomJSONObjectPtr *customJSON `db:"custom_json_object_ptr,omitempty"` + AutoCustomJSONObjectPtr *autoCustomJSON `db:"auto_custom_json_object_ptr,omitempty"` + + AutoCustomJSONObjectArray []autoCustomJSON `db:"auto_custom_json_object_array"` + AutoCustomJSONObjectMap map[string]autoCustomJSON `db:"auto_custom_json_object_map"` + + Int64CompatValueJSONArray []int64Compat `db:"integer_compat_value_json_array"` + UIntCompatValueJSONArray uintCompatArray `db:"uinteger_compat_value_json_array"` + StringCompatValueJSONArray []stringCompat `db:"string_compat_value_json_array"` + } + + origMyTypeTests := []MyType{ + MyType{ + Int64CompatValueJSONArray: []int64Compat{1, -2, 3, -4}, + UIntCompatValueJSONArray: []uintCompat{1, 2, 3, 4}, + StringCompatValueJSONArray: []stringCompat{"a", "b", "", "c"}, + }, + MyType{ + Int64CompatValueJSONArray: []int64Compat(nil), + UIntCompatValueJSONArray: []uintCompat(nil), + StringCompatValueJSONArray: []stringCompat(nil), + }, + MyType{ + AutoCustomJSONObjectArray: []autoCustomJSON{ + autoCustomJSON{ + N: "Hello", + }, + autoCustomJSON{ + N: "World", + }, + }, + AutoCustomJSONObjectMap: map[string]autoCustomJSON{ + "a": autoCustomJSON{ + N: "Hello", + }, + "b": autoCustomJSON{ + N: "World", + }, + }, + JSONArray: JSONArray{float64(1), float64(2), float64(3), float64(4)}, + }, + MyType{ + JSONArray: JSONArray{}, + }, + MyType{ + JSONArray: JSONArray(nil), + }, + MyType{}, + MyType{ + CustomJSONObject: customJSON{ + N: "Hello", + }, + AutoCustomJSONObject: autoCustomJSON{ + N: "World", + }, + }, + MyType{ + CustomJSONObject: customJSON{}, + AutoCustomJSONObject: autoCustomJSON{}, + }, + MyType{ + CustomJSONObject: customJSON{ + N: "Hello 1", + }, + AutoCustomJSONObject: autoCustomJSON{ + N: "World 2", + }, + }, + MyType{ + CustomJSONObjectPtr: nil, + AutoCustomJSONObjectPtr: nil, + }, + MyType{ + CustomJSONObjectPtr: &customJSON{}, + AutoCustomJSONObjectPtr: &autoCustomJSON{}, + }, + MyType{ + CustomJSONObjectPtr: &customJSON{ + N: "Hello 3", + }, + AutoCustomJSONObjectPtr: &autoCustomJSON{ + N: "World 4", + }, + }, + MyType{}, + MyType{ + CustomJSONObject: customJSON{ + V: 4.4, + }, + }, + MyType{ + CustomJSONObject: customJSON{}, + }, + MyType{ + CustomJSONObject: customJSON{ + N: "Peter", 
+ V: 5.56, + }, + }, + } + + for i := 0; i < 100; i++ { + + myTypeTests := make([]MyType, len(origMyTypeTests)) + perm := rand.Perm(len(origMyTypeTests)) + for i, v := range perm { + myTypeTests[v] = origMyTypeTests[i] + } + + for i := range myTypeTests { + result, err := sess.Collection("my_types").Insert(myTypeTests[i]) + s.NoError(err) + + var actual MyType + err = sess.Collection("my_types").Find(result).One(&actual) + s.NoError(err) + + expected := myTypeTests[i] + expected.ID = result.ID().(int64) + s.Equal(expected, actual) + } + + for i := range myTypeTests { + res, err := sess.SQL().InsertInto("my_types").Values(myTypeTests[i]).Exec() + s.NoError(err) + + id, err := res.LastInsertId() + s.NoError(err) + s.NotEqual(0, id) + + var actual MyType + err = sess.Collection("my_types").Find(id).One(&actual) + s.NoError(err) + + expected := myTypeTests[i] + expected.ID = id + + s.Equal(expected, actual) + + var actual2 MyType + err = sess.SQL().SelectFrom("my_types").Where("id = ?", id).One(&actual2) + s.NoError(err) + s.Equal(expected, actual2) + } + + inserter := sess.SQL().InsertInto("my_types") + for i := range myTypeTests { + inserter = inserter.Values(myTypeTests[i]) + } + _, err := inserter.Exec() + s.NoError(err) + + err = sess.Collection("my_types").Truncate() + s.NoError(err) + + batch := sess.SQL().InsertInto("my_types").Batch(50) + go func() { + defer batch.Done() + for i := range myTypeTests { + batch.Values(myTypeTests[i]) + } + }() + + err = batch.Wait() + s.NoError(err) + + var values []MyType + err = sess.SQL().SelectFrom("my_types").All(&values) + s.NoError(err) + + for i := range values { + expected := myTypeTests[i] + expected.ID = values[i].ID + s.Equal(expected, values[i]) + } + } +} + +func TestAdapter(t *testing.T) { + suite.Run(t, &AdapterTests{}) +} diff --git a/adapter/mysql/record_test.go b/adapter/mysql/record_test.go new file mode 100644 index 0000000..7eb597c --- /dev/null +++ b/adapter/mysql/record_test.go @@ -0,0 +1,20 @@ +package mysql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type RecordTests struct { + testsuite.RecordTestSuite +} + +func (s *RecordTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestRecord(t *testing.T) { + suite.Run(t, &RecordTests{}) +} diff --git a/adapter/mysql/sql_test.go b/adapter/mysql/sql_test.go new file mode 100644 index 0000000..5605532 --- /dev/null +++ b/adapter/mysql/sql_test.go @@ -0,0 +1,20 @@ +package mysql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type SQLTests struct { + testsuite.SQLTestSuite +} + +func (s *SQLTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestSQL(t *testing.T) { + suite.Run(t, &SQLTests{}) +} diff --git a/adapter/mysql/template.go b/adapter/mysql/template.go new file mode 100644 index 0000000..4bd33a6 --- /dev/null +++ b/adapter/mysql/template.go @@ -0,0 +1,198 @@ +package mysql + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +const ( + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = "`{{.Value}}`" + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} 
{{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` + {{if .SortColumns}} + ORDER BY {{.SortColumns}} + {{end}} + ` + + adapterWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + + adapterUsingLayout = ` + {{if .Columns}} + USING ({{.Columns}}) + {{end}} + ` + + adapterJoinLayout = ` + {{if .Table}} + {{ if .On }} + {{.Type}} JOIN {{.Table}} + {{.On}} + {{ else if .Using }} + {{.Type}} JOIN {{.Table}} + {{.Using}} + {{ else if .Type | eq "CROSS" }} + {{.Type}} JOIN {{.Table}} + {{else}} + NATURAL {{.Type}} JOIN {{.Table}} + {{end}} + {{end}} + ` + + adapterOnLayout = ` + {{if .Conds}} + ON {{.Conds}} + {{end}} + ` + + adapterSelectLayout = ` + SELECT + {{if .Distinct}} + DISTINCT + {{end}} + + {{if defined .Columns}} + {{.Columns | compile}} + {{else}} + * + {{end}} + + {{if defined .Table}} + FROM {{.Table | compile}} + {{end}} + + {{.Joins | compile}} + + {{.Where | compile}} + + {{if defined .GroupBy}} + {{.GroupBy | compile}} + {{end}} + + {{.OrderBy | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + ` + + // The argument for LIMIT when only OFFSET is specified is a pretty odd magic + // number; this comes directly from MySQL's manual, see: + // https://dev.mysql.com/doc/refman/5.7/en/select.html + // + // "To retrieve all rows from a certain offset up to the end of the result + // set, you can use some large number for the second parameter. This + // statement retrieves all rows from the 96th row to the last: + // SELECT * FROM tbl LIMIT 95,18446744073709551615; " + // + // ¯\_(ツ)_/¯ + ` + {{if .Offset}} + {{if not .Limit}} + LIMIT 18446744073709551615 + {{end}} + OFFSET {{.Offset}} + {{end}} + ` + adapterDeleteLayout = ` + DELETE + FROM {{.Table | compile}} + {{.Where | compile}} + ` + adapterUpdateLayout = ` + UPDATE + {{.Table | compile}} + SET {{.ColumnValues | compile}} + {{.Where | compile}} + ` + + adapterSelectCountLayout = ` + SELECT + COUNT(1) AS _t + FROM {{.Table | compile}} + {{.Where | compile}} + ` + + adapterInsertLayout = ` + INSERT INTO {{.Table | compile}} + {{if defined .Columns}}({{.Columns | compile}}){{end}} + VALUES + {{if defined .Values}} + {{.Values | compile}} + {{else}} + () + {{end}} + {{if defined .Returning}} + RETURNING {{.Returning | compile}} + {{end}} + ` + + adapterTruncateLayout = ` + TRUNCATE TABLE {{.Table | compile}} + ` + + adapterDropDatabaseLayout = ` + DROP DATABASE {{.Database | compile}} + ` + + adapterDropTableLayout = ` + DROP TABLE {{.Table | compile}} + ` + + adapterGroupByLayout = ` + {{if .GroupColumns}} + GROUP BY {{.GroupColumns}} + {{end}} + ` +) + +var template = &exql.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + JoinLayout: adapterJoinLayout, + OnLayout: adapterOnLayout, + 
UsingLayout: adapterUsingLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), +} diff --git a/adapter/mysql/template_test.go b/adapter/mysql/template_test.go new file mode 100644 index 0000000..4bd5ac1 --- /dev/null +++ b/adapter/mysql/template_test.go @@ -0,0 +1,269 @@ +package mysql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" + "github.com/stretchr/testify/assert" +) + +func TestTemplateSelect(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + "SELECT * FROM `artist`", + b.SelectFrom("artist").String(), + ) + + assert.Equal( + "SELECT * FROM `artist`", + b.Select().From("artist").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` ORDER BY `name` DESC", + b.Select().From("artist").OrderBy("name DESC").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` ORDER BY `name` DESC", + b.Select().From("artist").OrderBy("-name").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` ORDER BY `name` ASC", + b.Select().From("artist").OrderBy("name").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` ORDER BY `name` ASC", + b.Select().From("artist").OrderBy("name ASC").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` LIMIT 18446744073709551615 OFFSET 5", + b.Select().From("artist").Limit(-1).Offset(5).String(), + ) + + assert.Equal( + "SELECT `id` FROM `artist`", + b.Select("id").From("artist").String(), + ) + + assert.Equal( + "SELECT `id`, `name` FROM `artist`", + b.Select("id", "name").From("artist").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`name` = $1)", + b.SelectFrom("artist").Where("name", "Haruki").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (name LIKE $1)", + b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(), + ) + + assert.Equal( + "SELECT `id` FROM `artist` WHERE (name LIKE $1 OR name LIKE $2)", + b.Select("id").From("artist").Where(`name LIKE ? 
OR name LIKE ?`, `%Miya%`, `F%`).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` > $1)", + b.SelectFrom("artist").Where("id >", 2).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (id <= 2 AND name != $1)", + b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` IN ($1, $2, $3, $4))", + b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (name IS NOT NULL)", + b.SelectFrom("artist").Where("name IS NOT NULL").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` AS `a`, `publication` AS `p` WHERE (p.author_id = a.id) LIMIT 1", + b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + "SELECT `id` FROM `artist` NATURAL JOIN `publication`", + b.Select("id").From("artist").Join("publication").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.author_id = a.id) LIMIT 1", + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.author_id = a.id) WHERE (`a`.`id` = $1) LIMIT 1", + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` JOIN `publication` AS `p` ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1", + b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1", + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` AS `a` LEFT JOIN `publication` AS `p1` ON (p1.id = a.id) RIGHT JOIN `publication` AS `p2` ON (p2.id = a.id)", + b.SelectFrom("artist a"). + LeftJoin("publication p1").On("p1.id = a.id"). + RightJoin("publication p2").On("p2.id = a.id"). 
+ String(), + ) + + assert.Equal( + "SELECT * FROM `artist` CROSS JOIN `publication`", + b.SelectFrom("artist").CrossJoin("publication").String(), + ) + + assert.Equal( + "SELECT * FROM `artist` JOIN `publication` USING (`id`)", + b.SelectFrom("artist").Join("publication").Using("id").String(), + ) + + assert.Equal( + "SELECT DATE()", + b.Select(mydb.Raw("DATE()")).String(), + ) + + // Issue #408 + { + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)", + b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id IN": []int{1, 2}}).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` = $1 AND `name` LIKE $2)", + b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": []byte{1, 2}}).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)", + b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": mydb.In(1, 2)}).String(), + ) + + assert.Equal( + "SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)", + b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": mydb.AnyOf([]int{1, 2})}).String(), + ) + } +} + +func TestTemplateInsert(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + "INSERT INTO `artist` VALUES ($1, $2), ($3, $4), ($5, $6)", + b.InsertInto("artist"). + Values(10, "Ryuichi Sakamoto"). + Values(11, "Alondra de la Parra"). + Values(12, "Haruki Murakami"). + String(), + ) + + assert.Equal( + "INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)", + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(), + ) + + assert.Equal( + "INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2) RETURNING `id`", + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(), + ) + + assert.Equal( + "INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)", + b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(), + ) + + assert.Equal( + "INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)", + b.InsertInto("artist").Values(struct { + ID int `db:"id"` + Name string `db:"name"` + }{12, "Chavela Vargas"}).String(), + ) + + assert.Equal( + "INSERT INTO `artist` (`name`, `id`) VALUES ($1, $2)", + b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(), + ) +} + +func TestTemplateUpdate(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + "UPDATE `artist` SET `name` = $1", + b.Update("artist").Set("name", "Artist").String(), + ) + + assert.Equal( + "UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)", + b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(), + ) + + assert.Equal( + "UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)", + b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + "UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)", + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + "UPDATE `artist` SET `name` = $1, `last_name` = $2 WHERE (`id` < $3)", + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + "UPDATE `artist` SET `name` = $1 || ' ' || $2 || id, `id` = id + $3 WHERE (id > $4)", + 
b.Update("artist").Set(
+			"name = ? || ' ' || ? || id", "Artist", "#",
+			"id = id + ?", 10,
+		).Where("id > ?", 0).String(),
+	)
+}
+
+func TestTemplateDelete(t *testing.T) {
+	b := sqlbuilder.WithTemplate(template)
+	assert := assert.New(t)
+
+	assert.Equal(
+		"DELETE FROM `artist` WHERE (name = $1)",
+		b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").String(),
+	)
+
+	assert.Equal(
+		"DELETE FROM `artist` WHERE (id > 5)",
+		b.DeleteFrom("artist").Where("id > 5").String(),
+	)
+}
diff --git a/adapter/postgresql/Makefile b/adapter/postgresql/Makefile
new file mode 100644
index 0000000..c9a0ff2
--- /dev/null
+++ b/adapter/postgresql/Makefile
@@ -0,0 +1,44 @@
+SHELL ?= bash
+
+POSTGRES_VERSION ?= 15-alpine
+POSTGRES_SUPPORTED ?= $(POSTGRES_VERSION) 14-alpine 13-alpine 12-alpine
+
+PROJECT ?= upper_postgres_$(POSTGRES_VERSION)
+
+DB_HOST ?= 127.0.0.1
+DB_PORT ?= 5432
+
+DB_NAME ?= upperio
+DB_USERNAME ?= upperio_user
+DB_PASSWORD ?= upperio//s3cr37
+
+TEST_FLAGS ?=
+PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1
+
+export POSTGRES_VERSION
+
+export DB_HOST
+export DB_NAME
+export DB_PASSWORD
+export DB_PORT
+export DB_USERNAME
+
+export TEST_FLAGS
+
+test:
+	go test -v -failfast -race -timeout 20m $(TEST_FLAGS)
+
+test-no-race:
+	go test -v -failfast $(TEST_FLAGS)
+
+server-up: server-down
+	docker-compose -p $(PROJECT) up -d && \
+	sleep 10
+
+server-down:
+	docker-compose -p $(PROJECT) down
+
+test-extended:
+	parallel $(PARALLEL_FLAGS) \
+		"POSTGRES_VERSION={} DB_PORT=\$$((5432+{#})) $(MAKE) server-up test server-down" ::: \
+		$(POSTGRES_SUPPORTED)
diff --git a/adapter/postgresql/README.md b/adapter/postgresql/README.md
new file mode 100644
index 0000000..7e72601
--- /dev/null
+++ b/adapter/postgresql/README.md
@@ -0,0 +1,5 @@
+# PostgreSQL adapter for upper/db
+
+Please read the full docs, acknowledgements and examples at
+[https://upper.io/v4/adapter/postgresql/](https://upper.io/v4/adapter/postgresql/).
+
diff --git a/adapter/postgresql/collection.go b/adapter/postgresql/collection.go
new file mode 100644
index 0000000..df41dd1
--- /dev/null
+++ b/adapter/postgresql/collection.go
@@ -0,0 +1,50 @@
+package postgresql
+
+import (
+	"git.hexq.cn/tiglog/mydb"
+	"git.hexq.cn/tiglog/mydb/internal/sqladapter"
+)
+
+type collectionAdapter struct {
+}
+
+func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) {
+	pKey, err := col.PrimaryKeys()
+	if err != nil {
+		return nil, err
+	}
+
+	q := col.SQL().InsertInto(col.Name()).Values(item)
+
+	if len(pKey) == 0 {
+		// There is no primary key.
+		res, err := q.Exec()
+		if err != nil {
+			return nil, err
+		}
+
+		// Attempt to use LastInsertId() (it probably won't work, but since
+		// Exec() succeeded we can safely ignore the error from LastInsertId()).
+		lastID, err := res.LastInsertId()
+		if err != nil {
+			return nil, nil
+		}
+		return lastID, nil
+	}
+
+	// Ask the database to return the primary key after insertion.
+	q = q.Returning(pKey...)
+
+	var keyMap mydb.Cond
+	if err := q.Iterator().One(&keyMap); err != nil {
+		return nil, err
+	}
+
+	// If the primary key is a single column, return its value directly.
+	if len(keyMap) == 1 {
+		return keyMap[pKey[0]], nil
+	}
+
+	// This was a compound key, so return the whole map.
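+	// As a hypothetical illustration (the column names are made up):
+	// inserting into a table whose primary key is (code, user_id) would
+	// return something like mydb.Cond{"code": ..., "user_id": ...}.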
+ return keyMap, nil +} diff --git a/adapter/postgresql/connection.go b/adapter/postgresql/connection.go new file mode 100644 index 0000000..2c2e4a7 --- /dev/null +++ b/adapter/postgresql/connection.go @@ -0,0 +1,289 @@ +package postgresql + +import ( + "fmt" + "net" + "net/url" + "sort" + "strings" + "time" + "unicode" +) + +// scanner implements a tokenizer for libpq-style option strings. +type scanner struct { + s []rune + i int +} + +// Next returns the next rune. It returns 0, false if the end of the text has +// been reached. +func (s *scanner) Next() (rune, bool) { + if s.i >= len(s.s) { + return 0, false + } + r := s.s[s.i] + s.i++ + return r, true +} + +// SkipSpaces returns the next non-whitespace rune. It returns 0, false if the +// end of the text has been reached. +func (s *scanner) SkipSpaces() (rune, bool) { + r, ok := s.Next() + for unicode.IsSpace(r) && ok { + r, ok = s.Next() + } + return r, ok +} + +type values map[string]string + +func (vs values) Set(k, v string) { + vs[k] = v +} + +func (vs values) Get(k string) (v string) { + return vs[k] +} + +func (vs values) Isset(k string) bool { + _, ok := vs[k] + return ok +} + +// ConnectionURL represents a parsed PostgreSQL connection URL. +// +// You can use a ConnectionURL struct as an argument for Open: +// +// var settings = postgresql.ConnectionURL{ +// Host: "localhost", // PostgreSQL server IP or name. +// Database: "peanuts", // Database name. +// User: "cbrown", // Optional user name. +// Password: "snoopy", // Optional user password. +// } +// +// sess, err = postgresql.Open(settings) +// +// If you already have a valid DSN, you can use ParseURL to convert it into +// a ConnectionURL before passing it to Open. +type ConnectionURL struct { + User string + Password string + Host string + Socket string + Database string + Options map[string]string + + timezone *time.Location +} + +var escaper = strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`) + +// ParseURL parses the given DSN into a ConnectionURL struct. +// A typical PostgreSQL connection URL looks like: +// +// postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full +func ParseURL(s string) (u *ConnectionURL, err error) { + o := make(values) + + if strings.HasPrefix(s, "postgres://") || strings.HasPrefix(s, "postgresql://") { + s, err = parseURL(s) + if err != nil { + return u, err + } + } + + if err := parseOpts(s, o); err != nil { + return u, err + } + u = &ConnectionURL{} + + u.User = o.Get("user") + u.Password = o.Get("password") + + h := o.Get("host") + p := o.Get("port") + + if strings.HasPrefix(h, "/") { + u.Socket = h + } else { + if p == "" { + u.Host = h + } else { + u.Host = fmt.Sprintf("%s:%s", h, p) + } + } + + u.Database = o.Get("dbname") + + u.Options = make(map[string]string) + + for k := range o { + switch k { + case "user", "password", "host", "port", "dbname": + // Skip + default: + u.Options[k] = o[k] + } + } + + if timezone, ok := u.Options["timezone"]; ok { + u.timezone, _ = time.LoadLocation(timezone) + } + + return u, err +} + +// parseOpts parses the options from name and adds them to the values. 
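+//
+// For example, the option string `user=bob password=secret host=1.2.3.4`
+// (illustrative values only) yields the entries "user", "password" and
+// "host" with those respective values.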
+//
+// The parsing code is based on conninfo_parse from libpq's fe-connect.c
+func parseOpts(name string, o values) error {
+	s := newScanner(name)
+
+	for {
+		var (
+			keyRunes, valRunes []rune
+			r                  rune
+			ok                 bool
+		)
+
+		if r, ok = s.SkipSpaces(); !ok {
+			break
+		}
+
+		// Scan the key
+		for !unicode.IsSpace(r) && r != '=' {
+			keyRunes = append(keyRunes, r)
+			if r, ok = s.Next(); !ok {
+				break
+			}
+		}
+
+		// Skip any whitespace if we're not at the = yet
+		if r != '=' {
+			r, ok = s.SkipSpaces()
+		}
+
+		// The current character should be =
+		if r != '=' || !ok {
+			return fmt.Errorf(`missing "=" after %q in connection info string`, string(keyRunes))
+		}
+
+		// Skip any whitespace after the =
+		if r, ok = s.SkipSpaces(); !ok {
+			// If we reach the end here, the last value is just an empty string as per libpq.
+			o.Set(string(keyRunes), "")
+			break
+		}
+
+		if r != '\'' {
+			for !unicode.IsSpace(r) {
+				if r == '\\' {
+					if r, ok = s.Next(); !ok {
+						return fmt.Errorf(`missing character after backslash`)
+					}
+				}
+				valRunes = append(valRunes, r)
+
+				if r, ok = s.Next(); !ok {
+					break
+				}
+			}
+		} else {
+		quote:
+			for {
+				if r, ok = s.Next(); !ok {
+					return fmt.Errorf(`unterminated quoted string literal in connection string`)
+				}
+				switch r {
+				case '\'':
+					break quote
+				case '\\':
+					r, _ = s.Next()
+					fallthrough
+				default:
+					valRunes = append(valRunes, r)
+				}
+			}
+		}
+
+		o.Set(string(keyRunes), string(valRunes))
+	}
+
+	return nil
+}
+
+// newScanner returns a new scanner initialized with the option string s.
+func newScanner(s string) *scanner {
+	return &scanner{[]rune(s), 0}
+}
+
+// parseURL converts a URL into a connection string for driver.Open.
+//
+// Clients no longer need to call it directly, since supplying a URL as a
+// connection string to sql.Open() is also supported:
+//
+//	sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full")
+//
+// (In lib/pq, where this function comes from, it is exported as ParseURL and
+// kept for backwards-compatibility; here it is unexported.)
+//
+// Example:
+//
+//	"postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
+//
+// converts to:
+//
+//	"user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full"
+//
+// A minimal example:
+//
+//	"postgres://"
+//
+// This will be blank, causing driver.Open to use all of the defaults.
+//
+// NOTE: vendored/copied from github.com/lib/pq
+func parseURL(uri string) (string, error) {
+	u, err := url.Parse(uri)
+	if err != nil {
+		return "", err
+	}
+
+	if u.Scheme != "postgres" && u.Scheme != "postgresql" {
+		return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
+	}
+
+	var kvs []string
+	escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
+	accrue := func(k, v string) {
+		if v != "" {
+			kvs = append(kvs, k+"="+escaper.Replace(v))
+		}
+	}
+
+	if u.User != nil {
+		v := u.User.Username()
+		accrue("user", v)
+
+		v, _ = u.User.Password()
+		accrue("password", v)
+	}
+
+	if host, port, err := net.SplitHostPort(u.Host); err != nil {
+		accrue("host", u.Host)
+	} else {
+		accrue("host", host)
+		accrue("port", port)
+	}
+
+	if u.Path != "" {
+		accrue("dbname", u.Path[1:])
+	}
+
+	q := u.Query()
+	for k := range q {
+		accrue(k, q.Get(k))
+	}
+
+	sort.Strings(kvs) // Makes testing easier (not a performance concern)
+	return strings.Join(kvs, " "), nil
+}
diff --git a/adapter/postgresql/connection_pgx.go b/adapter/postgresql/connection_pgx.go
new file mode 100644
index 0000000..8cecdb7
--- /dev/null
+++ b/adapter/postgresql/connection_pgx.go
@@ -0,0 +1,73 @@
+//go:build !pq
+// +build !pq
+
+package postgresql
+
+import (
+	"net"
+	"sort"
+	"strings"
+)
+
+// String reassembles the parsed PostgreSQL connection URL into a valid DSN.
+func (c ConnectionURL) String() (s string) {
+	u := []string{}
+
+	// TODO: This surely needs some sort of escaping.
+	if c.User != "" {
+		u = append(u, "user="+escaper.Replace(c.User))
+	}
+
+	if c.Password != "" {
+		u = append(u, "password="+escaper.Replace(c.Password))
+	}
+
+	if c.Host != "" {
+		host, port, err := net.SplitHostPort(c.Host)
+		if err == nil {
+			if host == "" {
+				host = "127.0.0.1"
+			}
+			if port == "" {
+				port = "5432"
+			}
+			u = append(u, "host="+escaper.Replace(host))
+			u = append(u, "port="+escaper.Replace(port))
+		} else {
+			u = append(u, "host="+escaper.Replace(c.Host))
+		}
+	}
+
+	if c.Socket != "" {
+		u = append(u, "host="+escaper.Replace(c.Socket))
+	}
+
+	if c.Database != "" {
+		u = append(u, "dbname="+escaper.Replace(c.Database))
+	}
+
+	// Is there actually any connection data?
+	if len(u) == 0 {
+		return ""
+	}
+
+	if c.Options == nil {
+		c.Options = map[string]string{}
+	}
+
+	// If not present, SSL mode is assumed "prefer".
+	if sslMode, ok := c.Options["sslmode"]; !ok || sslMode == "" {
+		c.Options["sslmode"] = "prefer"
+	}
+
+	// The prepared statement cache is disabled by default.
+	c.Options["statement_cache_capacity"] = "0"
+
+	for k, v := range c.Options {
+		u = append(u, escaper.Replace(k)+"="+escaper.Replace(v))
+	}
+
+	sort.Strings(u)
+
+	return strings.Join(u, " ")
+}
diff --git a/adapter/postgresql/connection_pgx_test.go b/adapter/postgresql/connection_pgx_test.go
new file mode 100644
index 0000000..380c9db
--- /dev/null
+++ b/adapter/postgresql/connection_pgx_test.go
@@ -0,0 +1,108 @@
+//go:build !pq
+// +build !pq
+
+package postgresql
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestConnectionURL(t *testing.T) {
+	c := ConnectionURL{}
+
+	// Default connection string is empty.
+ assert.Equal(t, "", c.String(), "Expecting default connectiong string to be empty") + + // Adding a host with port. + c.Host = "localhost:1234" + assert.Equal(t, "host=localhost port=1234 sslmode=prefer statement_cache_capacity=0", c.String()) + + // Adding a host. + c.Host = "localhost" + assert.Equal(t, "host=localhost sslmode=prefer statement_cache_capacity=0", c.String()) + + // Adding a username. + c.User = "Anakin" + assert.Equal(t, `host=localhost sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String()) + + // Adding a password with special characters. + c.Password = "Some Sort of ' Password" + assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String()) + + // Adding a port. + c.Host = "localhost:1234" + assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String()) + + // Adding a database. + c.Database = "MyDatabase" + assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String()) + + // Adding options. + c.Options = map[string]string{ + "sslmode": "verify-full", + } + assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=verify-full statement_cache_capacity=0 user=Anakin`, c.String()) +} + +func TestParseConnectionURL(t *testing.T) { + + { + s := "postgres://anakin:skywalker@localhost/jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "localhost", u.Host) + assert.Equal(t, "jedis", u.Database) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + // case with port + s := "postgres://anakin:skywalker@localhost:1234/jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "jedis", u.Database) + assert.Equal(t, "localhost:1234", u.Host) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + s := "postgres://anakin:skywalker@localhost/jedis?sslmode=verify-full" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "verify-full", u.Options["sslmode"]) + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "jedis", u.Database) + assert.Equal(t, "localhost", u.Host) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "verify-full", u.Options["sslmode"]) + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full timezone=UTC" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, 2, len(u.Options), "Expecting exactly two options") + assert.Equal(t, "verify-full", u.Options["sslmode"]) + assert.Equal(t, "UTC", u.Options["timezone"]) + } +} diff --git a/adapter/postgresql/connection_pq.go b/adapter/postgresql/connection_pq.go new file mode 100644 index 0000000..fc3bddd --- /dev/null +++ b/adapter/postgresql/connection_pq.go @@ -0,0 +1,70 @@ +//go:build pq +// +build pq + +package postgresql + +import ( + "net" + 
"sort" + "strings" +) + +// String reassembles the parsed PostgreSQL connection URL into a valid DSN. +func (c ConnectionURL) String() (s string) { + u := []string{} + + // TODO: This surely needs some sort of escaping. + if c.User != "" { + u = append(u, "user="+escaper.Replace(c.User)) + } + + if c.Password != "" { + u = append(u, "password="+escaper.Replace(c.Password)) + } + + if c.Host != "" { + host, port, err := net.SplitHostPort(c.Host) + if err == nil { + if host == "" { + host = "127.0.0.1" + } + if port == "" { + port = "5432" + } + u = append(u, "host="+escaper.Replace(host)) + u = append(u, "port="+escaper.Replace(port)) + } else { + u = append(u, "host="+escaper.Replace(c.Host)) + } + } + + if c.Socket != "" { + u = append(u, "host="+escaper.Replace(c.Socket)) + } + + if c.Database != "" { + u = append(u, "dbname="+escaper.Replace(c.Database)) + } + + // Is there actually any connection data? + if len(u) == 0 { + return "" + } + + if c.Options == nil { + c.Options = map[string]string{} + } + + // If not present, SSL mode is assumed "prefer". + if sslMode, ok := c.Options["sslmode"]; !ok || sslMode == "" { + c.Options["sslmode"] = "prefer" + } + + for k, v := range c.Options { + u = append(u, escaper.Replace(k)+"="+escaper.Replace(v)) + } + + sort.Strings(u) + + return strings.Join(u, " ") +} diff --git a/adapter/postgresql/connection_pq_test.go b/adapter/postgresql/connection_pq_test.go new file mode 100644 index 0000000..7646a7b --- /dev/null +++ b/adapter/postgresql/connection_pq_test.go @@ -0,0 +1,108 @@ +//go:build pq +// +build pq + +package postgresql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestConnectionURL(t *testing.T) { + c := ConnectionURL{} + + // Default connection string is empty. + assert.Equal(t, "", c.String(), "Expecting default connectiong string to be empty") + + // Adding a host with port. + c.Host = "localhost:1234" + assert.Equal(t, "host=localhost port=1234 sslmode=prefer", c.String()) + + // Adding a host. + c.Host = "localhost" + assert.Equal(t, "host=localhost sslmode=prefer", c.String()) + + // Adding a username. + c.User = "Anakin" + assert.Equal(t, `host=localhost sslmode=prefer user=Anakin`, c.String()) + + // Adding a password with special characters. + c.Password = "Some Sort of ' Password" + assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password sslmode=prefer user=Anakin`, c.String()) + + // Adding a port. + c.Host = "localhost:1234" + assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer user=Anakin`, c.String()) + + // Adding a database. + c.Database = "MyDatabase" + assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer user=Anakin`, c.String()) + + // Adding options. 
+ c.Options = map[string]string{ + "sslmode": "verify-full", + } + assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=verify-full user=Anakin`, c.String()) +} + +func TestParseConnectionURL(t *testing.T) { + + { + s := "postgres://anakin:skywalker@localhost/jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "localhost", u.Host) + assert.Equal(t, "jedis", u.Database) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + // case with port + s := "postgres://anakin:skywalker@localhost:1234/jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "jedis", u.Database) + assert.Equal(t, "localhost:1234", u.Host) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + s := "postgres://anakin:skywalker@localhost/jedis?sslmode=verify-full" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "verify-full", u.Options["sslmode"]) + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "anakin", u.User) + assert.Equal(t, "skywalker", u.Password) + assert.Equal(t, "jedis", u.Database) + assert.Equal(t, "localhost", u.Host) + assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.") + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, "verify-full", u.Options["sslmode"]) + } + + { + s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full timezone=UTC" + u, err := ParseURL(s) + assert.NoError(t, err) + assert.Equal(t, 2, len(u.Options), "Expecting exactly two options") + assert.Equal(t, "verify-full", u.Options["sslmode"]) + assert.Equal(t, "UTC", u.Options["timezone"]) + } +} diff --git a/adapter/postgresql/custom_types.go b/adapter/postgresql/custom_types.go new file mode 100644 index 0000000..7b11668 --- /dev/null +++ b/adapter/postgresql/custom_types.go @@ -0,0 +1,126 @@ +package postgresql + +import ( + "context" + "database/sql" + "database/sql/driver" + "time" + + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// JSONBMap represents a map of interfaces with string keys +// (`map[string]interface{}`) that is compatible with PostgreSQL's JSONB type. +// JSONBMap satisfies sqlbuilder.ScannerValuer. +type JSONBMap map[string]interface{} + +// Value satisfies the driver.Valuer interface. +func (m JSONBMap) Value() (driver.Value, error) { + return JSONBValue(m) +} + +// Scan satisfies the sql.Scanner interface. +func (m *JSONBMap) Scan(src interface{}) error { + *m = map[string]interface{}(nil) + return ScanJSONB(m, src) +} + +// JSONBArray represents an array of any type (`[]interface{}`) that is +// compatible with PostgreSQL's JSONB type. JSONBArray satisfies +// sqlbuilder.ScannerValuer. +type JSONBArray []interface{} + +// Value satisfies the driver.Valuer interface. +func (a JSONBArray) Value() (driver.Value, error) { + return JSONBValue(a) +} + +// Scan satisfies the sql.Scanner interface. +func (a *JSONBArray) Scan(src interface{}) error { + return ScanJSONB(a, src) +} + +// JSONBValue takes an interface and provides a driver.Value that can be +// stored as a JSONB column. 
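+//
+// A minimal sketch of how a custom type can be wired through JSONBValue and
+// ScanJSONB (the Settings type below is illustrative, not part of this
+// package):
+//
+//	type Settings struct {
+//		Theme string `json:"theme"`
+//	}
+//
+//	// Value satisfies driver.Valuer.
+//	func (s Settings) Value() (driver.Value, error) {
+//		return JSONBValue(s)
+//	}
+//
+//	// Scan satisfies sql.Scanner.
+//	func (s *Settings) Scan(src interface{}) error {
+//		return ScanJSONB(s, src)
+//	}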
+func JSONBValue(i interface{}) (driver.Value, error) {
+	v := JSONB{i}
+	return v.Value()
+}
+
+// ScanJSONB decodes a JSON byte stream into the passed dst value.
+func ScanJSONB(dst interface{}, src interface{}) error {
+	v := JSONB{dst}
+	return v.Scan(src)
+}
+
+// JSONBConverter wraps values so they are stored and retrieved as JSONB.
+type JSONBConverter struct {
+}
+
+func (*JSONBConverter) ConvertValue(in interface{}) interface {
+	sql.Scanner
+	driver.Valuer
+} {
+	return &JSONB{in}
+}
+
+type timeWrapper struct {
+	v   **time.Time
+	loc *time.Location
+}
+
+func (t timeWrapper) Value() (driver.Value, error) {
+	if *t.v != nil {
+		return **t.v, nil
+	}
+	return nil, nil
+}
+
+func (t *timeWrapper) Scan(src interface{}) error {
+	if src == nil {
+		nilTime := (*time.Time)(nil)
+		if t.v == nil {
+			t.v = &nilTime
+		} else {
+			*(t.v) = nilTime
+		}
+		return nil
+	}
+	tz := src.(time.Time)
+	if t.loc != nil && (tz.Location() == time.Local) {
+		tz = tz.In(t.loc)
+	}
+	if tz.Location().String() == "" {
+		tz = tz.In(time.UTC)
+	}
+	if *(t.v) == nil {
+		*(t.v) = &tz
+	} else {
+		**t.v = tz
+	}
+	return nil
+}
+
+func (d *database) ConvertValueContext(ctx context.Context, in interface{}) interface{} {
+	tz, _ := ctx.Value("timezone").(*time.Location)
+
+	switch v := in.(type) {
+	case *time.Time:
+		return &timeWrapper{&v, tz}
+	case **time.Time:
+		return &timeWrapper{v, tz}
+	}
+
+	return d.ConvertValue(in)
+}
+
+// Type checks.
+var (
+	_ sqlbuilder.ScannerValuer = &StringArray{}
+	_ sqlbuilder.ScannerValuer = &Int64Array{}
+	_ sqlbuilder.ScannerValuer = &Float64Array{}
+	_ sqlbuilder.ScannerValuer = &Float32Array{}
+	_ sqlbuilder.ScannerValuer = &BoolArray{}
+	_ sqlbuilder.ScannerValuer = &JSONBMap{}
+	_ sqlbuilder.ScannerValuer = &JSONBArray{}
+	_ sqlbuilder.ScannerValuer = &JSONB{}
+)
diff --git a/adapter/postgresql/custom_types_pgx.go b/adapter/postgresql/custom_types_pgx.go
new file mode 100644
index 0000000..7f5324c
--- /dev/null
+++ b/adapter/postgresql/custom_types_pgx.go
@@ -0,0 +1,286 @@
+//go:build !pq
+// +build !pq
+
+package postgresql
+
+import (
+	"database/sql/driver"
+
+	"github.com/jackc/pgtype"
+)
+
+// JSONB represents a PostgreSQL JSONB value:
+// https://www.postgresql.org/docs/9.6/static/datatype-json.html. JSONB
+// satisfies sqlbuilder.ScannerValuer.
+type JSONB struct {
+	Data interface{}
+}
+
+// MarshalJSON encodes the wrapper value as JSON.
+func (j JSONB) MarshalJSON() ([]byte, error) {
+	t := &pgtype.JSONB{}
+	if err := t.Set(j.Data); err != nil {
+		return nil, err
+	}
+	return t.MarshalJSON()
+}
+
+// UnmarshalJSON decodes the given JSON into the wrapped value.
+func (j *JSONB) UnmarshalJSON(b []byte) error {
+	t := &pgtype.JSONB{}
+	if err := t.UnmarshalJSON(b); err != nil {
+		return err
+	}
+	if j.Data == nil {
+		j.Data = t.Get()
+		return nil
+	}
+	if err := t.AssignTo(&j.Data); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (j *JSONB) Scan(src interface{}) error {
+	t := &pgtype.JSONB{}
+	if err := t.Scan(src); err != nil {
+		return err
+	}
+	if j.Data == nil {
+		j.Data = t.Get()
+		return nil
+	}
+	if err := t.AssignTo(j.Data); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Value satisfies the driver.Valuer interface.
+func (j JSONB) Value() (driver.Value, error) {
+	t := &pgtype.JSONB{}
+	if err := t.Set(j.Data); err != nil {
+		return nil, err
+	}
+	return t.Value()
+}
+
+// StringArray represents a one-dimensional array of strings (`[]string{}`)
+// that is compatible with PostgreSQL's text array (`text[]`). StringArray
+// satisfies sqlbuilder.ScannerValuer.
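+//
+// For example, a column declared as `tags varchar(64)[]` (as in the test
+// schema) can be mapped onto a struct field; the surrounding struct here is
+// illustrative:
+//
+//	type Item struct {
+//		ID   int64       `db:"id,omitempty"`
+//		Tags StringArray `db:"tags"`
+//	}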
+type StringArray []string
+
+// Value satisfies the driver.Valuer interface.
+func (a StringArray) Value() (driver.Value, error) {
+	t := pgtype.TextArray{}
+	if err := t.Set(a); err != nil {
+		return nil, err
+	}
+	return t.Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (sa *StringArray) Scan(src interface{}) error {
+	d := []string{}
+	t := pgtype.TextArray{}
+	if err := t.Scan(src); err != nil {
+		return err
+	}
+	if err := t.AssignTo(&d); err != nil {
+		return err
+	}
+	*sa = StringArray(d)
+	return nil
+}
+
+// Bytea represents a byte slice (`[]byte`) that is compatible with
+// PostgreSQL's bytea type.
+type Bytea []byte
+
+// Value satisfies the driver.Valuer interface.
+func (b Bytea) Value() (driver.Value, error) {
+	t := pgtype.Bytea{Bytes: b}
+	if err := t.Set(b); err != nil {
+		return nil, err
+	}
+	return t.Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (b *Bytea) Scan(src interface{}) error {
+	d := []byte{}
+	t := pgtype.Bytea{}
+	if err := t.Scan(src); err != nil {
+		return err
+	}
+	if err := t.AssignTo(&d); err != nil {
+		return err
+	}
+	*b = Bytea(d)
+	return nil
+}
+
+// ByteaArray represents a one-dimensional array of byte slices
+// (`[][]byte{}`) that is compatible with PostgreSQL's bytea array
+// (`bytea[]`). ByteaArray satisfies sqlbuilder.ScannerValuer.
+type ByteaArray [][]byte
+
+// Value satisfies the driver.Valuer interface.
+func (a ByteaArray) Value() (driver.Value, error) {
+	t := pgtype.ByteaArray{}
+	if err := t.Set(a); err != nil {
+		return nil, err
+	}
+	return t.Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (ba *ByteaArray) Scan(src interface{}) error {
+	d := [][]byte{}
+	t := pgtype.ByteaArray{}
+	if err := t.Scan(src); err != nil {
+		return err
+	}
+	if err := t.AssignTo(&d); err != nil {
+		return err
+	}
+	*ba = ByteaArray(d)
+	return nil
+}
+
+// Int64Array represents a one-dimensional array of int64s (`[]int64{}`) that
+// is compatible with PostgreSQL's bigint array (`bigint[]`). Int64Array
+// satisfies sqlbuilder.ScannerValuer.
+type Int64Array []int64
+
+// Value satisfies the driver.Valuer interface.
+func (i64a Int64Array) Value() (driver.Value, error) {
+	t := pgtype.Int8Array{}
+	if err := t.Set(i64a); err != nil {
+		return nil, err
+	}
+	return t.Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (i64a *Int64Array) Scan(src interface{}) error {
+	d := []int64{}
+	t := pgtype.Int8Array{}
+	if err := t.Scan(src); err != nil {
+		return err
+	}
+	if err := t.AssignTo(&d); err != nil {
+		return err
+	}
+	*i64a = Int64Array(d)
+	return nil
+}
+
+// Int32Array represents a one-dimensional array of int32s (`[]int32{}`) that
+// is compatible with PostgreSQL's integer array (`integer[]`). Int32Array
+// satisfies sqlbuilder.ScannerValuer.
+type Int32Array []int32
+
+// Value satisfies the driver.Valuer interface.
+func (i32a Int32Array) Value() (driver.Value, error) {
+	t := pgtype.Int4Array{}
+	if err := t.Set(i32a); err != nil {
+		return nil, err
+	}
+	return t.Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (i32a *Int32Array) Scan(src interface{}) error {
+	d := []int32{}
+	t := pgtype.Int4Array{}
+	if err := t.Scan(src); err != nil {
+		return err
+	}
+	if err := t.AssignTo(&d); err != nil {
+		return err
+	}
+	*i32a = Int32Array(d)
+	return nil
+}
+
+// Float64Array represents a one-dimensional array of float64s (`[]float64{}`)
+// that is compatible with PostgreSQL's double precision array (`double
+// precision[]`). Float64Array satisfies sqlbuilder.ScannerValuer.
+type Float64Array []float64
+
+// Value satisfies the driver.Valuer interface.
+func (f64a Float64Array) Value() (driver.Value, error) {
+	t := pgtype.Float8Array{}
+	if err := t.Set(f64a); err != nil {
+		return nil, err
+	}
+	return t.Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (f64a *Float64Array) Scan(src interface{}) error {
+	d := []float64{}
+	t := pgtype.Float8Array{}
+	if err := t.Scan(src); err != nil {
+		return err
+	}
+	if err := t.AssignTo(&d); err != nil {
+		return err
+	}
+	*f64a = Float64Array(d)
+	return nil
+}
+
+// Float32Array represents a one-dimensional array of float32s (`[]float32{}`)
+// that is compatible with PostgreSQL's double precision array (`double
+// precision[]`). Float32Array satisfies sqlbuilder.ScannerValuer.
+type Float32Array []float32
+
+// Value satisfies the driver.Valuer interface.
+func (f32a Float32Array) Value() (driver.Value, error) {
+	t := pgtype.Float8Array{}
+	if err := t.Set(f32a); err != nil {
+		return nil, err
+	}
+	return t.Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (f32a *Float32Array) Scan(src interface{}) error {
+	d := []float32{}
+	t := pgtype.Float8Array{}
+	if err := t.Scan(src); err != nil {
+		return err
+	}
+	if err := t.AssignTo(&d); err != nil {
+		return err
+	}
+	*f32a = Float32Array(d)
+	return nil
+}
+
+// BoolArray represents a one-dimensional array of booleans (`[]bool{}`) that
+// is compatible with PostgreSQL's boolean array (`boolean[]`). BoolArray
+// satisfies sqlbuilder.ScannerValuer.
+type BoolArray []bool
+
+// Value satisfies the driver.Valuer interface.
+func (ba BoolArray) Value() (driver.Value, error) {
+	t := pgtype.BoolArray{}
+	if err := t.Set(ba); err != nil {
+		return nil, err
+	}
+	return t.Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (ba *BoolArray) Scan(src interface{}) error {
+	d := []bool{}
+	t := pgtype.BoolArray{}
+	if err := t.Scan(src); err != nil {
+		return err
+	}
+	if err := t.AssignTo(&d); err != nil {
+		return err
+	}
+	*ba = BoolArray(d)
+	return nil
+}
diff --git a/adapter/postgresql/custom_types_pq.go b/adapter/postgresql/custom_types_pq.go
new file mode 100644
index 0000000..f93ca4c
--- /dev/null
+++ b/adapter/postgresql/custom_types_pq.go
@@ -0,0 +1,249 @@
+//go:build pq
+// +build pq
+
+package postgresql
+
+import (
+	"bytes"
+	"database/sql/driver"
+	"encoding/hex"
+	"encoding/json"
+	"errors"
+	"fmt"
+	"reflect"
+	"strconv"
+	"time"
+
+	"github.com/lib/pq"
+)
+
+// JSONB represents a PostgreSQL JSONB value:
+// https://www.postgresql.org/docs/9.6/static/datatype-json.html. JSONB
+// satisfies sqlbuilder.ScannerValuer.
+type JSONB struct {
+	Data interface{}
+}
+
+// MarshalJSON encodes the wrapper value as JSON.
+func (j JSONB) MarshalJSON() ([]byte, error) {
+	return json.Marshal(j.Data)
+}
+
+// UnmarshalJSON decodes the given JSON into the wrapped value.
+func (j *JSONB) UnmarshalJSON(b []byte) error {
+	var v interface{}
+	if err := json.Unmarshal(b, &v); err != nil {
+		return err
+	}
+	j.Data = v
+	return nil
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (j *JSONB) Scan(src interface{}) error {
+	if j.Data == nil {
+		return nil
+	}
+	if src == nil {
+		dv := reflect.Indirect(reflect.ValueOf(j.Data))
+		dv.Set(reflect.Zero(dv.Type()))
+		return nil
+	}
+
+	b, ok := src.([]byte)
+	if !ok {
+		return errors.New("Scan source was not []byte")
+	}
+
+	if err := json.Unmarshal(b, j.Data); err != nil {
+		return err
+	}
+	return nil
+}
+
+// Value satisfies the driver.Valuer interface.
+func (j JSONB) Value() (driver.Value, error) {
+	// See https://github.com/lib/pq/issues/528#issuecomment-257197239 on why
+	// we return string instead of []byte.
+	if j.Data == nil {
+		return nil, nil
+	}
+	if v, ok := j.Data.(json.RawMessage); ok {
+		return string(v), nil
+	}
+	b, err := json.Marshal(j.Data)
+	if err != nil {
+		return nil, err
+	}
+	return string(b), nil
+}
+
+// StringArray represents a one-dimensional array of strings (`[]string{}`)
+// that is compatible with PostgreSQL's text array (`text[]`). StringArray
+// satisfies sqlbuilder.ScannerValuer.
+type StringArray pq.StringArray
+
+// Value satisfies the driver.Valuer interface.
+func (a StringArray) Value() (driver.Value, error) {
+	return pq.StringArray(a).Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (a *StringArray) Scan(src interface{}) error {
+	s := pq.StringArray(*a)
+	if err := s.Scan(src); err != nil {
+		return err
+	}
+	*a = StringArray(s)
+	return nil
+}
+
+// Int64Array represents a one-dimensional array of int64s (`[]int64{}`) that
+// is compatible with PostgreSQL's bigint array (`bigint[]`). Int64Array
+// satisfies sqlbuilder.ScannerValuer.
+type Int64Array pq.Int64Array
+
+// Value satisfies the driver.Valuer interface.
+func (i Int64Array) Value() (driver.Value, error) {
+	return pq.Int64Array(i).Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (i *Int64Array) Scan(src interface{}) error {
+	s := pq.Int64Array(*i)
+	if err := s.Scan(src); err != nil {
+		return err
+	}
+	*i = Int64Array(s)
+	return nil
+}
+
+// Float64Array represents a one-dimensional array of float64s (`[]float64{}`)
+// that is compatible with PostgreSQL's double precision array (`double
+// precision[]`). Float64Array satisfies sqlbuilder.ScannerValuer.
+type Float64Array pq.Float64Array
+
+// Value satisfies the driver.Valuer interface.
+func (f Float64Array) Value() (driver.Value, error) {
+	return pq.Float64Array(f).Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (f *Float64Array) Scan(src interface{}) error {
+	s := pq.Float64Array(*f)
+	if err := s.Scan(src); err != nil {
+		return err
+	}
+	*f = Float64Array(s)
+	return nil
+}
+
+// Float32Array represents a one-dimensional array of float32s (`[]float32{}`)
+// that is compatible with PostgreSQL's real array (`real[]`). Float32Array
+// satisfies sqlbuilder.ScannerValuer.
+type Float32Array pq.Float32Array
+
+// Value satisfies the driver.Valuer interface.
+func (f Float32Array) Value() (driver.Value, error) {
+	return pq.Float32Array(f).Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (f *Float32Array) Scan(src interface{}) error {
+	s := pq.Float32Array(*f)
+	if err := s.Scan(src); err != nil {
+		return err
+	}
+	*f = Float32Array(s)
+	return nil
+}
+
+// BoolArray represents a one-dimensional array of booleans (`[]bool{}`) that
+// is compatible with PostgreSQL's boolean array (`boolean[]`). BoolArray
+// satisfies sqlbuilder.ScannerValuer.
+type BoolArray pq.BoolArray
+
+// Value satisfies the driver.Valuer interface.
+func (b BoolArray) Value() (driver.Value, error) {
+	return pq.BoolArray(b).Value()
+}
+
+// Scan satisfies the sql.Scanner interface.
+func (b *BoolArray) Scan(src interface{}) error {
+	s := pq.BoolArray(*b)
+	if err := s.Scan(src); err != nil {
+		return err
+	}
+	*b = BoolArray(s)
+	return nil
+}
+
+// Bytea represents a byte slice (`[]byte`) that is compatible with
+// PostgreSQL's bytea type.
+type Bytea []byte
+
+// Scan satisfies the sql.Scanner interface.
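+// As a worked example (assuming the server's default bytea_output=hex
+// setting), a stored value of []byte("hello") arrives as the text
+// `\x68656c6c6f` and decodes back to []byte("hello") via parseBytea below.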
+func (b *Bytea) Scan(src interface{}) error { + decoded, err := parseBytea(src.([]byte)) + if err != nil { + return err + } + if len(decoded) < 1 { + *b = nil + return nil + } + (*b) = make(Bytea, len(decoded)) + for i := range decoded { + (*b)[i] = decoded[i] + } + return nil +} + +type Time time.Time + +// Parse a bytea value received from the server. Both "hex" and the legacy +// "escape" format are supported. +func parseBytea(s []byte) (result []byte, err error) { + if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) { + // bytea_output = hex + s = s[2:] // trim off leading "\\x" + result = make([]byte, hex.DecodedLen(len(s))) + _, err := hex.Decode(result, s) + if err != nil { + return nil, err + } + } else { + // bytea_output = escape + for len(s) > 0 { + if s[0] == '\\' { + // escaped '\\' + if len(s) >= 2 && s[1] == '\\' { + result = append(result, '\\') + s = s[2:] + continue + } + + // '\\' followed by an octal number + if len(s) < 4 { + return nil, fmt.Errorf("invalid bytea sequence %v", s) + } + r, err := strconv.ParseInt(string(s[1:4]), 8, 9) + if err != nil { + return nil, fmt.Errorf("could not parse bytea value: %s", err.Error()) + } + result = append(result, byte(r)) + s = s[4:] + } else { + // We hit an unescaped, raw byte. Try to read in as many as + // possible in one go. + i := bytes.IndexByte(s, '\\') + if i == -1 { + result = append(result, s...) + break + } + result = append(result, s[:i]...) + s = s[i:] + } + } + } + + return result, nil +} diff --git a/adapter/postgresql/custom_types_test.go b/adapter/postgresql/custom_types_test.go new file mode 100644 index 0000000..8624fda --- /dev/null +++ b/adapter/postgresql/custom_types_test.go @@ -0,0 +1,105 @@ +package postgresql + +import ( + "encoding/json" + "testing" + + "github.com/stretchr/testify/assert" +) + +type testStruct struct { + X int `json:"x"` + Z string `json:"z"` + V interface{} `json:"v"` +} + +func TestScanJSONB(t *testing.T) { + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello", "v": 1}`)) + assert.NoError(t, err) + assert.Equal(t, "Hello", a.Z) + assert.Equal(t, float64(1), a.V) + assert.Equal(t, 5, a.X) + } + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello", "v": null}`)) + assert.NoError(t, err) + assert.Equal(t, "Hello", a.Z) + assert.Equal(t, nil, a.V) + assert.Equal(t, 5, a.X) + } + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello"}`)) + assert.NoError(t, err) + assert.Equal(t, "Hello", a.Z) + assert.Equal(t, nil, a.V) + assert.Equal(t, 5, a.X) + } + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{"v": "Hello"}`)) + assert.NoError(t, err) + assert.Equal(t, "Hello", a.V) + } + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{"v": true}`)) + assert.NoError(t, err) + assert.Equal(t, true, a.V) + } + { + a := testStruct{} + err := ScanJSONB(&a, []byte(`{}`)) + assert.NoError(t, err) + assert.Equal(t, nil, a.V) + } + { + var a []byte + err := ScanJSONB(&a, []byte(`{"age":[{"\u003e":"1h"}]}`)) + assert.NoError(t, err) + assert.Equal(t, `{"age":[{"\u003e":"1h"}]}`, string(a)) + } + { + var a json.RawMessage + err := ScanJSONB(&a, []byte(`{"age":[{"\u003e":"1h"}]}`)) + assert.NoError(t, err) + assert.Equal(t, `{"age":[{"\u003e":"1h"}]}`, string(a)) + } + { + var a json.RawMessage + err := ScanJSONB(&a, []byte("{\"age\":[{\"\u003e\":\"1h\"}]}")) + assert.NoError(t, err) + assert.Equal(t, `{"age":[{">":"1h"}]}`, string(a)) + } + { + a := []*testStruct{} + err := json.Unmarshal([]byte(`[{}]`), &a) + assert.NoError(t, 
err)
+		assert.Equal(t, 1, len(a))
+		assert.Nil(t, a[0].V)
+	}
+	{
+		a := []*testStruct{}
+		err := json.Unmarshal([]byte(`[{"v": true}]`), &a)
+		assert.NoError(t, err)
+		assert.Equal(t, 1, len(a))
+		assert.Equal(t, true, a[0].V)
+	}
+	{
+		a := []*testStruct{}
+		err := json.Unmarshal([]byte(`[{"v": null}]`), &a)
+		assert.NoError(t, err)
+		assert.Equal(t, 1, len(a))
+		assert.Nil(t, a[0].V)
+	}
+	{
+		a := []*testStruct{}
+		err := json.Unmarshal([]byte(`[{"v": 12.34}]`), &a)
+		assert.NoError(t, err)
+		assert.Equal(t, 1, len(a))
+		assert.Equal(t, 12.34, a[0].V)
+	}
+}
diff --git a/adapter/postgresql/database.go b/adapter/postgresql/database.go
new file mode 100644
index 0000000..f718e13
--- /dev/null
+++ b/adapter/postgresql/database.go
@@ -0,0 +1,180 @@
+// Package postgresql provides an adapter for PostgreSQL.
+// See https://github.com/upper/db/adapter/postgresql for documentation,
+// particularities and usage examples.
+package postgresql
+
+import (
+	"fmt"
+	"strings"
+
+	"git.hexq.cn/tiglog/mydb"
+	"git.hexq.cn/tiglog/mydb/internal/sqladapter"
+	"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
+	"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
+)
+
+type database struct {
+}
+
+func (*database) Template() *exql.Template {
+	return template
+}
+
+func (*database) Collections(sess sqladapter.Session) (collections []string, err error) {
+	q := sess.SQL().
+		Select("table_name").
+		From("information_schema.tables").
+		Where("table_schema = ?", "public")
+
+	iter := q.Iterator()
+	defer iter.Close()
+
+	for iter.Next() {
+		var name string
+		if err := iter.Scan(&name); err != nil {
+			return nil, err
+		}
+		collections = append(collections, name)
+	}
+	if err := iter.Err(); err != nil {
+		return nil, err
+	}
+
+	return collections, nil
+}
+
+func (*database) ConvertValue(in interface{}) interface{} {
+	switch v := in.(type) {
+	case *[]int64:
+		return (*Int64Array)(v)
+	case *[]string:
+		return (*StringArray)(v)
+	case *[]float64:
+		return (*Float64Array)(v)
+	case *[]bool:
+		return (*BoolArray)(v)
+	case *map[string]interface{}:
+		return (*JSONBMap)(v)
+
+	case []int64:
+		return (*Int64Array)(&v)
+	case []string:
+		return (*StringArray)(&v)
+	case []float64:
+		return (*Float64Array)(&v)
+	case []bool:
+		return (*BoolArray)(&v)
+	case map[string]interface{}:
+		return (*JSONBMap)(&v)
+
+	}
+	return in
+}
+
+func (*database) CompileStatement(sess sqladapter.Session, stmt *exql.Statement, args []interface{}) (string, []interface{}, error) {
+	compiled, err := stmt.Compile(template)
+	if err != nil {
+		return "", nil, err
+	}
+
+	query, args := sqlbuilder.Preprocess(compiled, args)
+	query = string(sqladapter.ReplaceWithDollarSign([]byte(query)))
+	return query, args, nil
+}
+
+func (*database) Err(err error) error {
+	if err != nil {
+		s := err.Error()
+		// These errors are not exported, so we have to check them by their
+		// string value.
+		if strings.Contains(s, `too many clients`) || strings.Contains(s, `remaining connection slots are reserved`) || strings.Contains(s, `too many open`) {
+			return mydb.ErrTooManyClients
+		}
+	}
+	return err
+}
+
+func (*database) NewCollection() sqladapter.CollectionAdapter {
+	return &collectionAdapter{}
+}
+
+func (*database) LookupName(sess sqladapter.Session) (string, error) {
+	q := sess.SQL().
+ Select(mydb.Raw("CURRENT_DATABASE() AS name")) + + iter := q.Iterator() + defer iter.Close() + + if iter.Next() { + var name string + if err := iter.Scan(&name); err != nil { + return "", err + } + return name, nil + } + + return "", iter.Err() +} + +func (*database) TableExists(sess sqladapter.Session, name string) error { + q := sess.SQL(). + Select("table_name"). + From("information_schema.tables"). + Where("table_catalog = ? AND table_name = ?", sess.Name(), name) + + iter := q.Iterator() + defer iter.Close() + + if iter.Next() { + var name string + if err := iter.Scan(&name); err != nil { + return err + } + return nil + } + if err := iter.Err(); err != nil { + return err + } + + return mydb.ErrCollectionDoesNotExist +} + +func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) { + q := sess.SQL(). + Select("pg_attribute.attname AS pkey"). + From("pg_index", "pg_class", "pg_attribute"). + Where(` + pg_class.oid = '` + quotedTableName(tableName) + `'::regclass + AND indrelid = pg_class.oid + AND pg_attribute.attrelid = pg_class.oid + AND pg_attribute.attnum = ANY(pg_index.indkey) + AND indisprimary + `).OrderBy("pkey") + + iter := q.Iterator() + defer iter.Close() + + pk := []string{} + + for iter.Next() { + var k string + if err := iter.Scan(&k); err != nil { + return nil, err + } + pk = append(pk, k) + } + if err := iter.Err(); err != nil { + return nil, err + } + + return pk, nil +} + +// quotedTableName returns a valid regclass name for both regular tables and +// for schemas. +func quotedTableName(s string) string { + chunks := strings.Split(s, ".") + for i := range chunks { + chunks[i] = fmt.Sprintf("%q", chunks[i]) + } + return strings.Join(chunks, ".") +} diff --git a/adapter/postgresql/database_pgx.go b/adapter/postgresql/database_pgx.go new file mode 100644 index 0000000..a40aee3 --- /dev/null +++ b/adapter/postgresql/database_pgx.go @@ -0,0 +1,26 @@ +//go:build !pq +// +build !pq + +package postgresql + +import ( + "context" + "database/sql" + "time" + + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + _ "github.com/jackc/pgx/v4/stdlib" +) + +func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) { + connURL, err := ParseURL(dsn) + if err != nil { + return nil, err + } + if tz := connURL.Options["timezone"]; tz != "" { + loc, _ := time.LoadLocation(tz) + ctx := context.WithValue(sess.Context(), "timezone", loc) + sess.SetContext(ctx) + } + return sql.Open("pgx", dsn) +} diff --git a/adapter/postgresql/database_pq.go b/adapter/postgresql/database_pq.go new file mode 100644 index 0000000..da3b904 --- /dev/null +++ b/adapter/postgresql/database_pq.go @@ -0,0 +1,26 @@ +//go:build pq +// +build pq + +package postgresql + +import ( + "context" + "database/sql" + "time" + + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + _ "github.com/lib/pq" +) + +func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) { + connURL, err := ParseURL(dsn) + if err != nil { + return nil, err + } + if tz := connURL.Options["timezone"]; tz != "" { + loc, _ := time.LoadLocation(tz) + ctx := context.WithValue(sess.Context(), "timezone", loc) + sess.SetContext(ctx) + } + return sql.Open("postgres", dsn) +} diff --git a/adapter/postgresql/docker-compose.yml b/adapter/postgresql/docker-compose.yml new file mode 100644 index 0000000..4f4884a --- /dev/null +++ b/adapter/postgresql/docker-compose.yml @@ -0,0 +1,13 @@ +version: '3' + +services: + + server: + image: postgres:${POSTGRES_VERSION:-11} + environment: + POSTGRES_USER: 
${DB_USERNAME:-upperio_user} + POSTGRES_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37} + POSTGRES_DB: ${DB_NAME:-upperio} + ports: + - '${DB_HOST:-127.0.0.1}:${DB_PORT:-5432}:5432' + diff --git a/adapter/postgresql/generic_test.go b/adapter/postgresql/generic_test.go new file mode 100644 index 0000000..feec428 --- /dev/null +++ b/adapter/postgresql/generic_test.go @@ -0,0 +1,20 @@ +package postgresql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type GenericTests struct { + testsuite.GenericTestSuite +} + +func (s *GenericTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestGeneric(t *testing.T) { + suite.Run(t, &GenericTests{}) +} diff --git a/adapter/postgresql/helper_test.go b/adapter/postgresql/helper_test.go new file mode 100644 index 0000000..acb490e --- /dev/null +++ b/adapter/postgresql/helper_test.go @@ -0,0 +1,321 @@ +package postgresql + +import ( + "database/sql" + "fmt" + "os" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/testsuite" +) + +var settings = ConnectionURL{ + Database: os.Getenv("DB_NAME"), + User: os.Getenv("DB_USERNAME"), + Password: os.Getenv("DB_PASSWORD"), + Host: os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"), + Options: map[string]string{ + "timezone": testsuite.TimeZone, + }, +} + +const preparedStatementsKey = "pg_prepared_statements_count" + +type Helper struct { + sess mydb.Session +} + +func cleanUp(sess mydb.Session) error { + if activeStatements := sqladapter.NumActiveStatements(); activeStatements > 128 { + return fmt.Errorf("Expecting active statements to be less than 128, got %d", activeStatements) + } + + sess.Reset() + + stats, err := getStats(sess) + if err != nil { + return err + } + + if stats[preparedStatementsKey] != 0 { + return fmt.Errorf(`Expecting %q to be 0, got %d`, preparedStatementsKey, stats[preparedStatementsKey]) + } + + return nil +} + +func getStats(sess mydb.Session) (map[string]int, error) { + stats := make(map[string]int) + + row := sess.Driver().(*sql.DB).QueryRow(`SELECT count(1) AS value FROM pg_prepared_statements`) + + var value int + err := row.Scan(&value) + if err != nil { + return nil, err + } + + stats[preparedStatementsKey] = value + + return stats, nil +} + +func (h *Helper) Session() mydb.Session { + return h.sess +} + +func (h *Helper) Adapter() string { + return Adapter +} + +func (h *Helper) TearDown() error { + if err := cleanUp(h.sess); err != nil { + return err + } + + return h.sess.Close() +} + +func (h *Helper) TearUp() error { + var err error + + h.sess, err = Open(settings) + if err != nil { + return err + } + + batch := []string{ + `DROP TABLE IF EXISTS artist`, + `CREATE TABLE artist ( + id serial primary key, + name varchar(60) + )`, + + `DROP TABLE IF EXISTS publication`, + `CREATE TABLE publication ( + id serial primary key, + title varchar(80), + author_id integer + )`, + + `DROP TABLE IF EXISTS review`, + `CREATE TABLE review ( + id serial primary key, + publication_id integer, + name varchar(80), + comments text, + created timestamp without time zone + )`, + + `DROP TABLE IF EXISTS data_types`, + `CREATE TABLE data_types ( + id serial primary key, + _uint integer, + _uint8 integer, + _uint16 integer, + _uint32 integer, + _uint64 integer, + _int integer, + _int8 integer, + _int16 integer, + _int32 integer, + _int64 integer, + _float32 numeric(10,6), + _float64 numeric(10,6), + _bool boolean, + _string text, + _blob bytea, + _date timestamp with time zone, 
+ _nildate timestamp without time zone null, + _ptrdate timestamp without time zone, + _defaultdate timestamp without time zone DEFAULT now(), + _time bigint + )`, + + `DROP TABLE IF EXISTS stats_test`, + `CREATE TABLE stats_test ( + id serial primary key, + numeric integer, + value integer + )`, + + `DROP TABLE IF EXISTS composite_keys`, + `CREATE TABLE composite_keys ( + code varchar(255) default '', + user_id varchar(255) default '', + some_val varchar(255) default '', + primary key (code, user_id) + )`, + + `DROP TABLE IF EXISTS option_types`, + `CREATE TABLE option_types ( + id serial primary key, + name varchar(255) default '', + tags varchar(64)[], + settings jsonb + )`, + + `DROP TABLE IF EXISTS test_schema.test`, + `DROP SCHEMA IF EXISTS test_schema`, + + `CREATE SCHEMA test_schema`, + `CREATE TABLE test_schema.test (id integer)`, + + `DROP TABLE IF EXISTS pg_types`, + `CREATE TABLE pg_types (id serial primary key + , uint8_value smallint + , uint8_value_array bytea + + , int64_value smallint + , int64_value_array smallint[] + + , integer_array integer[] + , string_array text[] + , jsonb_map jsonb + , raw_jsonb_map jsonb + , raw_jsonb_text jsonb + + , integer_array_ptr integer[] + , string_array_ptr text[] + , jsonb_map_ptr jsonb + + , auto_integer_array integer[] + , auto_string_array text[] + , auto_jsonb_map jsonb + , auto_jsonb_map_string jsonb + , auto_jsonb_map_integer jsonb + + , jsonb_object jsonb + , jsonb_array jsonb + + , custom_jsonb_object jsonb + , auto_custom_jsonb_object jsonb + + , custom_jsonb_object_ptr jsonb + , auto_custom_jsonb_object_ptr jsonb + + , custom_jsonb_object_array jsonb + , auto_custom_jsonb_object_array jsonb + , auto_custom_jsonb_object_map jsonb + + , string_value varchar(255) + , integer_value int + , varchar_value varchar(64) + , decimal_value decimal + + , integer_compat_value int + , uinteger_compat_value int + , string_compat_value text + + , integer_compat_value_jsonb_array jsonb + , string_compat_value_jsonb_array jsonb + , uinteger_compat_value_jsonb_array jsonb + + , string_value_ptr varchar(255) + , integer_value_ptr int + , varchar_value_ptr varchar(64) + , decimal_value_ptr decimal + + , uuid_value_string UUID + + )`, + + `DROP TABLE IF EXISTS issue_370`, + `CREATE TABLE issue_370 ( + id UUID PRIMARY KEY, + name VARCHAR(25) + )`, + + `CREATE EXTENSION IF NOT EXISTS "uuid-ossp"`, + + `DROP TABLE IF EXISTS issue_602_organizations`, + `CREATE TABLE issue_602_organizations ( + name character varying(256) NOT NULL, + created_at timestamp without time zone DEFAULT now() NOT NULL, + updated_at timestamp without time zone DEFAULT now() NOT NULL, + id uuid DEFAULT public.uuid_generate_v4() NOT NULL + )`, + + `ALTER TABLE ONLY issue_602_organizations ADD CONSTRAINT issue_602_organizations_pkey PRIMARY KEY (id)`, + + `DROP TABLE IF EXISTS issue_370_2`, + `CREATE TABLE issue_370_2 ( + id INTEGER[3] PRIMARY KEY, + name VARCHAR(25) + )`, + + `DROP TABLE IF EXISTS varchar_primary_key`, + `CREATE TABLE varchar_primary_key ( + address VARCHAR(42) PRIMARY KEY NOT NULL, + name VARCHAR(25) + )`, + + `DROP TABLE IF EXISTS "birthdays"`, + `CREATE TABLE "birthdays" ( + "id" serial primary key, + "name" CHARACTER VARYING(50), + "born" TIMESTAMP WITH TIME ZONE, + "born_ut" INT + )`, + + `DROP TABLE IF EXISTS "fibonacci"`, + `CREATE TABLE "fibonacci" ( + "id" serial primary key, + "input" NUMERIC, + "output" NUMERIC + )`, + + `DROP TABLE IF EXISTS "is_even"`, + `CREATE TABLE "is_even" ( + "input" NUMERIC, + "is_even" BOOL + )`, + + `DROP TABLE IF EXISTS 
"CaSe_TesT"`, + `CREATE TABLE "CaSe_TesT" ( + "id" SERIAL PRIMARY KEY, + "case_test" VARCHAR(60) + )`, + + `DROP TABLE IF EXISTS accounts`, + `CREATE TABLE accounts ( + id serial primary key, + name varchar(255), + disabled boolean, + created_at timestamp with time zone + )`, + + `DROP TABLE IF EXISTS users`, + `CREATE TABLE users ( + id serial primary key, + account_id integer, + username varchar(255) UNIQUE + )`, + + `DROP TABLE IF EXISTS logs`, + `CREATE TABLE logs ( + id serial primary key, + message VARCHAR + )`, + } + + driver := h.sess.Driver().(*sql.DB) + tx, err := driver.Begin() + if err != nil { + return err + } + for _, query := range batch { + if _, err := tx.Exec(query); err != nil { + _ = tx.Rollback() + return err + } + } + if err := tx.Commit(); err != nil { + return err + } + + return nil +} + +var _ testsuite.Helper = &Helper{} diff --git a/adapter/postgresql/postgresql.go b/adapter/postgresql/postgresql.go new file mode 100644 index 0000000..2e770a8 --- /dev/null +++ b/adapter/postgresql/postgresql.go @@ -0,0 +1,30 @@ +package postgresql + +import ( + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// Adapter is the internal name of the adapter. +const Adapter = "postgresql" + +var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{}) + +// Open establishes a connection to the database server and returns a +// sqlbuilder.Session instance (which is compatible with mydb.Session). +func Open(connURL mydb.ConnectionURL) (mydb.Session, error) { + return registeredAdapter.OpenDSN(connURL) +} + +// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value. +func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { + return registeredAdapter.NewTx(sqlTx) +} + +// New creates a sqlbuilder.Sesion instance by wrapping a *sql.DB value. 
+func New(sqlDB *sql.DB) (mydb.Session, error) {
+	return registeredAdapter.New(sqlDB)
+}
diff --git a/adapter/postgresql/postgresql_test.go b/adapter/postgresql/postgresql_test.go
new file mode 100644
index 0000000..b40f6ef
--- /dev/null
+++ b/adapter/postgresql/postgresql_test.go
@@ -0,0 +1,1404 @@
+package postgresql
+
+import (
+	"context"
+	"database/sql"
+	"database/sql/driver"
+	"encoding/json"
+	"fmt"
+	"math/rand"
+	"strings"
+	"sync"
+	"testing"
+	"time"
+
+	"git.hexq.cn/tiglog/mydb"
+	"git.hexq.cn/tiglog/mydb/internal/testsuite"
+	"github.com/google/uuid"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/suite"
+)
+
+type customJSONBObjectArray []customJSONB
+
+func (customJSONBObjectArray) ConvertValue(in interface{}) interface {
+	sql.Scanner
+	driver.Valuer
+} {
+	return &JSONB{in}
+}
+
+type customJSONBObjectMap map[string]customJSONB
+
+func (c customJSONBObjectMap) Value() (driver.Value, error) {
+	return JSONBValue(c)
+}
+
+func (c *customJSONBObjectMap) Scan(src interface{}) error {
+	return ScanJSONB(c, src)
+}
+
+type customJSONB struct {
+	N string  `json:"name"`
+	V float64 `json:"value"`
+
+	*JSONBConverter
+}
+
+type int64Compat int64
+
+type uintCompat uint
+
+type stringCompat string
+
+type uint8Compat uint8
+
+type uint8CompatArray []uint8Compat
+
+func (ua uint8CompatArray) Value() (driver.Value, error) {
+	v := make([]byte, len(ua))
+	for i := range ua {
+		v[i] = byte(ua[i])
+	}
+	return v, nil
+}
+
+func (ua *uint8CompatArray) Scan(src interface{}) error {
+	decoded := Bytea{}
+	if err := decoded.Scan(src); err != nil {
+		return err
+	}
+	if len(decoded) < 1 {
+		*ua = nil
+		return nil
+	}
+	*ua = make([]uint8Compat, len(decoded))
+	for i := range decoded {
+		(*ua)[i] = uint8Compat(decoded[i])
+	}
+	return nil
+}
+
+type int64CompatArray []int64Compat
+
+func (i64a int64CompatArray) Value() (driver.Value, error) {
+	v := make(Int64Array, len(i64a))
+	for i := range i64a {
+		v[i] = int64(i64a[i])
+	}
+	return v.Value()
+}
+
+func (i64a *int64CompatArray) Scan(src interface{}) error {
+	s := Int64Array{}
+	if err := s.Scan(src); err != nil {
+		return err
+	}
+	dst := make([]int64Compat, len(s))
+	for i := range s {
+		dst[i] = int64Compat(s[i])
+	}
+	if len(dst) < 1 {
+		return nil
+	}
+	*i64a = dst
+	return nil
+}
+
+type uintCompatArray []uintCompat
+
+type AdapterTests struct {
+	testsuite.Suite
+}
+
+func (s *AdapterTests) SetupSuite() {
+	s.Helper = &Helper{}
+}
+
+func (s *AdapterTests) Test_Issue469_BadConnection() {
+	sess := s.Session()
+
+	// Ask the PostgreSQL server to disconnect sessions that remain inactive
+	// for more than 1 second.
+	_, err := sess.SQL().Exec(`SET SESSION idle_in_transaction_session_timeout=1000`)
+	s.NoError(err)
+
+	// Remain inactive for 2 seconds.
+	time.Sleep(time.Second * 2)
+
+	// A query should start a new connection, even if the server disconnected us.
+	_, err = sess.Collection("artist").Find().Count()
+	s.NoError(err)
+
+	// This is a new session; ask the PostgreSQL server to disconnect sessions
+	// that remain inactive for more than 1 second.
+	_, err = sess.SQL().Exec(`SET SESSION idle_in_transaction_session_timeout=1000`)
+	s.NoError(err)
+
+	// Remain inactive for 2 seconds.
+	time.Sleep(time.Second * 2)
+
+	// At this point the server should have disconnected us. Let's try to
+	// create a transaction anyway.
+ err = sess.Tx(func(sess mydb.Session) error { + var err error + + _, err = sess.Collection("artist").Find().Count() + if err != nil { + return err + } + return nil + }) + s.NoError(err) + + // This is a new session, ask the PostgreSQL server to disconnect sessions that + // remain inactive for more than 1 second. + _, err = sess.SQL().Exec(`SET SESSION idle_in_transaction_session_timeout=1000`) + s.NoError(err) + + err = sess.Tx(func(sess mydb.Session) error { + var err error + + // This query should succeed. + _, err = sess.Collection("artist").Find().Count() + if err != nil { + panic(err.Error()) + } + + // Remain inactive for 2 seconds. + time.Sleep(time.Second * 2) + + // This query should fail because the server disconnected us in the middle + // of a transaction. + _, err = sess.Collection("artist").Find().Count() + if err != nil { + return err + } + + return nil + }) + + s.Error(err, "Expecting an error (can't recover from this)") +} + +func testPostgreSQLTypes(t *testing.T, sess mydb.Session) { + type PGTypeInline struct { + IntegerArrayPtr *Int64Array `db:"integer_array_ptr,omitempty"` + StringArrayPtr *StringArray `db:"string_array_ptr,omitempty"` + JSONBMapPtr *JSONBMap `db:"jsonb_map_ptr,omitempty"` + } + + type PGTypeAutoInline struct { + AutoIntegerArray []int64 `db:"auto_integer_array"` + AutoStringArray []string `db:"auto_string_array"` + AutoJSONBMap map[string]interface{} `db:"auto_jsonb_map"` + AutoJSONBMapString map[string]interface{} `db:"auto_jsonb_map_string"` + AutoJSONBMapInteger map[string]interface{} `db:"auto_jsonb_map_integer"` + } + + type PGType struct { + ID int64 `db:"id,omitempty"` + + UInt8Value uint8Compat `db:"uint8_value"` + UInt8ValueArray uint8CompatArray `db:"uint8_value_array"` + + Int64Value int64Compat `db:"int64_value"` + Int64ValueArray *int64CompatArray `db:"int64_value_array"` + + IntegerArray Int64Array `db:"integer_array"` + StringArray StringArray `db:"string_array,stringarray"` + JSONBMap JSONBMap `db:"jsonb_map"` + + RawJSONBMap *json.RawMessage `db:"raw_jsonb_map,omitempty"` + RawJSONBText *json.RawMessage `db:"raw_jsonb_text,omitempty"` + + PGTypeInline `db:",inline"` + + PGTypeAutoInline `db:",inline"` + + JSONBObject JSONB `db:"jsonb_object"` + JSONBArray JSONBArray `db:"jsonb_array"` + + CustomJSONBObject customJSONB `db:"custom_jsonb_object"` + AutoCustomJSONBObject customJSONB `db:"auto_custom_jsonb_object"` + + CustomJSONBObjectPtr *customJSONB `db:"custom_jsonb_object_ptr,omitempty"` + AutoCustomJSONBObjectPtr *customJSONB `db:"auto_custom_jsonb_object_ptr,omitempty"` + + AutoCustomJSONBObjectArray *customJSONBObjectArray `db:"auto_custom_jsonb_object_array"` + AutoCustomJSONBObjectMap *customJSONBObjectMap `db:"auto_custom_jsonb_object_map"` + + StringValue string `db:"string_value"` + IntegerValue int64 `db:"integer_value"` + VarcharValue string `db:"varchar_value"` + DecimalValue float64 `db:"decimal_value"` + + Int64CompatValue int64Compat `db:"integer_compat_value"` + UIntCompatValue uintCompat `db:"uinteger_compat_value"` + StringCompatValue stringCompat `db:"string_compat_value"` + + Int64CompatValueJSONBArray JSONBArray `db:"integer_compat_value_jsonb_array"` + UIntCompatValueJSONBArray JSONBArray `db:"uinteger_compat_value_jsonb_array"` + StringCompatValueJSONBArray JSONBArray `db:"string_compat_value_jsonb_array"` + + StringValuePtr *string `db:"string_value_ptr,omitempty"` + IntegerValuePtr *int64 `db:"integer_value_ptr,omitempty"` + VarcharValuePtr *string `db:"varchar_value_ptr,omitempty"` + DecimalValuePtr 
*float64 `db:"decimal_value_ptr,omitempty"` + + UUIDValueString *string `db:"uuid_value_string,omitempty"` + } + + integerValue := int64(10) + stringValue := string("ten") + decimalValue := float64(10.0) + + uuidStringValue := "52356d08-6a16-4839-9224-75f0a547e13c" + + integerArrayValue := Int64Array{1, 2, 3, 4} + stringArrayValue := StringArray{"a", "b", "c"} + jsonbMapValue := JSONBMap{"Hello": "World"} + rawJSONBMap := json.RawMessage(`{"foo": "bar"}`) + rawJSONBText := json.RawMessage(`{"age": [{">": "1h"}]}`) + + testValue := "Hello world!" + + origPgTypeTests := []PGType{ + PGType{ + UUIDValueString: &uuidStringValue, + }, + PGType{ + UInt8Value: 7, + UInt8ValueArray: uint8CompatArray{1, 2, 3, 4, 5, 6}, + }, + PGType{ + Int64Value: -1, + Int64ValueArray: &int64CompatArray{1, 2, 3, -4, 5, 6}, + }, + PGType{ + UInt8Value: 1, + UInt8ValueArray: uint8CompatArray{1, 2, 3, 4, 5, 6}, + }, + PGType{ + Int64Value: 1, + Int64ValueArray: &int64CompatArray{7, 7, 7}, + }, + PGType{ + Int64Value: 1, + }, + PGType{ + Int64Value: 99, + Int64ValueArray: nil, + }, + PGType{ + Int64CompatValue: -5, + UIntCompatValue: 3, + StringCompatValue: "abc", + }, + PGType{ + Int64CompatValueJSONBArray: JSONBArray{1.0, -2.0, 3.0, -4.0}, + UIntCompatValueJSONBArray: JSONBArray{1.0, 2.0, 3.0, 4.0}, + StringCompatValueJSONBArray: JSONBArray{"a", "b", "", "c"}, + }, + PGType{ + Int64CompatValueJSONBArray: JSONBArray(nil), + UIntCompatValueJSONBArray: JSONBArray(nil), + StringCompatValueJSONBArray: JSONBArray(nil), + }, + PGType{ + IntegerValuePtr: &integerValue, + StringValuePtr: &stringValue, + DecimalValuePtr: &decimalValue, + PGTypeAutoInline: PGTypeAutoInline{ + AutoJSONBMapString: map[string]interface{}{"a": "x", "b": "67"}, + AutoJSONBMapInteger: map[string]interface{}{"a": 12.0, "b": 13.0}, + }, + }, + PGType{ + RawJSONBMap: &rawJSONBMap, + RawJSONBText: &rawJSONBText, + }, + PGType{ + IntegerValue: integerValue, + StringValue: stringValue, + DecimalValue: decimalValue, + }, + PGType{ + IntegerArray: []int64{1, 2, 3, 4}, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoIntegerArray: Int64Array{1, 2, 3, 4}, + AutoStringArray: nil, + }, + }, + PGType{ + AutoCustomJSONBObjectArray: &customJSONBObjectArray{ + customJSONB{ + N: "Hello", + }, + customJSONB{ + N: "World", + }, + }, + AutoCustomJSONBObjectMap: &customJSONBObjectMap{ + "a": customJSONB{ + N: "Hello", + }, + "b": customJSONB{ + N: "World", + }, + }, + PGTypeAutoInline: PGTypeAutoInline{ + AutoJSONBMap: map[string]interface{}{ + "Hello": "world", + "Roses": "red", + }, + }, + JSONBArray: JSONBArray{float64(1), float64(2), float64(3), float64(4)}, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoIntegerArray: nil, + }, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoJSONBMap: map[string]interface{}{}, + }, + JSONBArray: JSONBArray{}, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoJSONBMap: map[string]interface{}(nil), + }, + JSONBArray: JSONBArray(nil), + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoStringArray: []string{"aaa", "bbb", "ccc"}, + }, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoStringArray: nil, + }, + }, + PGType{ + PGTypeAutoInline: PGTypeAutoInline{ + AutoJSONBMap: map[string]interface{}{"hello": "world!"}, + }, + }, + PGType{ + IntegerArray: []int64{1, 2, 3, 4}, + StringArray: []string{"a", "boo", "bar"}, + }, + PGType{ + StringValue: stringValue, + DecimalValue: decimalValue, + }, + PGType{ + IntegerArray: []int64{}, + }, + PGType{ + StringArray: []string{}, + }, + 
PGType{ + IntegerArray: []int64{}, + StringArray: []string{}, + }, + PGType{}, + PGType{ + IntegerArray: []int64{1}, + StringArray: []string{"a"}, + }, + PGType{ + PGTypeInline: PGTypeInline{ + IntegerArrayPtr: &integerArrayValue, + StringArrayPtr: &stringArrayValue, + JSONBMapPtr: &jsonbMapValue, + }, + }, + PGType{ + IntegerArray: []int64{0, 0, 0, 0}, + StringValue: testValue, + CustomJSONBObject: customJSONB{ + N: "Hello", + }, + AutoCustomJSONBObject: customJSONB{ + N: "World", + }, + StringArray: []string{"", "", "", ``, `""`}, + }, + PGType{ + CustomJSONBObject: customJSONB{}, + AutoCustomJSONBObject: customJSONB{}, + }, + PGType{ + CustomJSONBObject: customJSONB{ + N: "Hello 1", + }, + AutoCustomJSONBObject: customJSONB{ + N: "World 2", + }, + }, + PGType{ + CustomJSONBObjectPtr: nil, + AutoCustomJSONBObjectPtr: nil, + }, + PGType{ + CustomJSONBObjectPtr: &customJSONB{}, + AutoCustomJSONBObjectPtr: &customJSONB{}, + }, + PGType{ + CustomJSONBObjectPtr: &customJSONB{ + N: "Hello 3", + }, + AutoCustomJSONBObjectPtr: &customJSONB{ + N: "World 4", + }, + }, + PGType{ + StringValue: testValue, + }, + PGType{ + IntegerValue: integerValue, + IntegerValuePtr: &integerValue, + CustomJSONBObject: customJSONB{ + V: 4.4, + }, + }, + PGType{ + StringArray: []string{"a", "boo", "bar"}, + }, + PGType{ + StringArray: []string{"a", "boo", "bar", `""`}, + CustomJSONBObject: customJSONB{}, + }, + PGType{ + IntegerArray: []int64{0}, + StringArray: []string{""}, + }, + PGType{ + CustomJSONBObject: customJSONB{ + N: "Peter", + V: 5.56, + }, + }, + } + + for i := 0; i < 100; i++ { + pgTypeTests := make([]PGType, len(origPgTypeTests)) + perm := rand.Perm(len(origPgTypeTests)) + for i, v := range perm { + pgTypeTests[v] = origPgTypeTests[i] + } + + for i := range pgTypeTests { + record, err := sess.Collection("pg_types").Insert(pgTypeTests[i]) + assert.NoError(t, err) + + var actual PGType + err = sess.Collection("pg_types").Find(record.ID()).One(&actual) + assert.NoError(t, err) + + expected := pgTypeTests[i] + expected.ID = record.ID().(int64) + assert.Equal(t, expected, actual) + } + + for i := range pgTypeTests { + row, err := sess.SQL().InsertInto("pg_types").Values(pgTypeTests[i]).Returning("id").QueryRow() + assert.NoError(t, err) + + var id int64 + err = row.Scan(&id) + assert.NoError(t, err) + + var actual PGType + err = sess.Collection("pg_types").Find(id).One(&actual) + assert.NoError(t, err) + + expected := pgTypeTests[i] + expected.ID = id + + assert.Equal(t, expected, actual) + + var actual2 PGType + err = sess.SQL().SelectFrom("pg_types").Where("id = ?", id).One(&actual2) + assert.NoError(t, err) + assert.Equal(t, expected, actual2) + } + + inserter := sess.SQL().InsertInto("pg_types") + for i := range pgTypeTests { + inserter = inserter.Values(pgTypeTests[i]) + } + _, err := inserter.Exec() + assert.NoError(t, err) + + err = sess.Collection("pg_types").Truncate() + assert.NoError(t, err) + + batch := sess.SQL().InsertInto("pg_types").Batch(50) + go func() { + defer batch.Done() + for i := range pgTypeTests { + batch.Values(pgTypeTests[i]) + } + }() + + err = batch.Wait() + assert.NoError(t, err) + + var values []PGType + err = sess.SQL().SelectFrom("pg_types").All(&values) + assert.NoError(t, err) + + for i := range values { + expected := pgTypeTests[i] + expected.ID = values[i].ID + + assert.Equal(t, expected, values[i]) + } + } +} + +func (s *AdapterTests) TestOptionTypes() { + sess := s.Session() + + optionTypes := sess.Collection("option_types") + err := optionTypes.Truncate() + 
s.NoError(err) + + // TODO: let's do some benchmarking on these auto-wrapped option types. + + // TODO: add nullable jsonb field mapped to a []string + + // A struct with wrapped option types defined in the struct tags + // for postgres string array and jsonb types. + type optionType struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags []string `db:"tags"` + Settings map[string]interface{} `db:"settings"` + } + + // Item 1 + item1 := optionType{ + Name: "Food", + Tags: []string{"toronto", "pizza"}, + Settings: map[string]interface{}{"a": 1, "b": 2}, + } + + record, err := optionTypes.Insert(item1) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item1Chk optionType + err = optionTypes.Find(record).One(&item1Chk) + s.NoError(err) + + s.Equal(float64(1), item1Chk.Settings["a"]) + s.Equal("toronto", item1Chk.Tags[0]) + + // Item 1 B + item1b := &optionType{ + Name: "Golang", + Tags: []string{"love", "it"}, + Settings: map[string]interface{}{"go": 1, "lang": 2}, + } + + record, err = optionTypes.Insert(item1b) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item1bChk optionType + err = optionTypes.Find(record).One(&item1bChk) + s.NoError(err) + + s.Equal(float64(1), item1bChk.Settings["go"]) + s.Equal("love", item1bChk.Tags[0]) + + // Item 1 C + item1c := &optionType{ + Name: "Sup", Tags: []string{}, Settings: map[string]interface{}{}, + } + + record, err = optionTypes.Insert(item1c) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item1cChk optionType + err = optionTypes.Find(record).One(&item1cChk) + s.NoError(err) + + s.Zero(len(item1cChk.Tags)) + s.Zero(len(item1cChk.Settings)) + + // An option type with a pointer jsonb field + type optionType2 struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags StringArray `db:"tags"` + Settings *JSONBMap `db:"settings"` + } + + item2 := optionType2{ + Name: "JS", Tags: []string{"hi", "bye"}, Settings: nil, + } + + record, err = optionTypes.Insert(item2) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item2Chk optionType2 + res := optionTypes.Find(record) + err = res.One(&item2Chk) + s.NoError(err) + + s.Equal(record.ID().(int64), item2Chk.ID) + + s.Equal(item2Chk.Name, item2.Name) + + s.Equal(item2Chk.Tags[0], item2.Tags[0]) + s.Equal(len(item2Chk.Tags), len(item2.Tags)) + + // Update the value + m := JSONBMap{} + m["lang"] = "javascript" + m["num"] = 31337 + item2.Settings = &m + err = res.Update(item2) + s.NoError(err) + + err = res.One(&item2Chk) + s.NoError(err) + + s.Equal(float64(31337), (*item2Chk.Settings)["num"].(float64)) + + s.Equal("javascript", (*item2Chk.Settings)["lang"]) + + // An option type with a pointer string array field + type optionType3 struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags *StringArray `db:"tags"` + Settings JSONBMap `db:"settings"` + } + + item3 := optionType3{ + Name: "Julia", + Tags: nil, + Settings: JSONBMap{"girl": true, "lang": true}, + } + + record, err = optionTypes.Insert(item3) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item3Chk optionType2 + err = optionTypes.Find(record).One(&item3Chk) + s.NoError(err) +} + +type Settings struct { + Name string `json:"name"` + Num int64 `json:"num"` +} + +func (s
*Settings) Scan(src interface{}) error { + return ScanJSONB(s, src) +} +func (s Settings) Value() (driver.Value, error) { + return JSONBValue(s) +} + +func (s *AdapterTests) TestOptionTypeJsonbStruct() { + sess := s.Session() + + optionTypes := sess.Collection("option_types") + + err := optionTypes.Truncate() + s.NoError(err) + + type OptionType struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + Tags StringArray `db:"tags"` + Settings Settings `db:"settings"` + } + + item1 := &OptionType{ + Name: "Hi", + Tags: []string{"aah", "ok"}, + Settings: Settings{Name: "a", Num: 123}, + } + + record, err := optionTypes.Insert(item1) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var item1Chk OptionType + err = optionTypes.Find(record).One(&item1Chk) + s.NoError(err) + + s.Equal(2, len(item1Chk.Tags)) + s.Equal("aah", item1Chk.Tags[0]) + s.Equal("a", item1Chk.Settings.Name) + s.Equal(int64(123), item1Chk.Settings.Num) +} + +func (s *AdapterTests) TestSchemaCollection() { + sess := s.Session() + + col := sess.Collection("test_schema.test") + _, err := col.Insert(map[string]int{"id": 9}) + s.Equal(nil, err) + + var dump []map[string]int + err = col.Find().All(&dump) + s.Nil(err) + s.Equal(1, len(dump)) + s.Equal(9, dump[0]["id"]) +} + +func (s *AdapterTests) Test_Issue340_MaxOpenConns() { + sess := s.Session() + + sess.SetMaxOpenConns(5) + + var wg sync.WaitGroup + for i := 0; i < 30; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + + _, err := sess.SQL().Exec(fmt.Sprintf(`SELECT pg_sleep(1.%d)`, i)) + if err != nil { + s.T().Errorf("%v", err) + } + }(i) + } + + wg.Wait() + + sess.SetMaxOpenConns(0) +} + +func (s *AdapterTests) Test_Issue370_InsertUUID() { + sess := s.Session() + + { + type itemT struct { + ID *uuid.UUID `db:"id"` + Name string `db:"name"` + } + + newUUID := uuid.New() + + item1 := itemT{ + ID: &newUUID, + Name: "Jonny", + } + + col := sess.Collection("issue_370") + err := col.Truncate() + s.NoError(err) + + err = col.InsertReturning(&item1) + s.NoError(err) + + var item2 itemT + err = col.Find(item1.ID).One(&item2) + s.NoError(err) + s.Equal(item1.Name, item2.Name) + + var item3 itemT + err = col.Find(mydb.Cond{"id": item1.ID}).One(&item3) + s.NoError(err) + s.Equal(item1.Name, item3.Name) + } + + { + type itemT struct { + ID uuid.UUID `db:"id"` + Name string `db:"name"` + } + + item1 := itemT{ + ID: uuid.New(), + Name: "Jonny", + } + + col := sess.Collection("issue_370") + err := col.Truncate() + s.NoError(err) + + err = col.InsertReturning(&item1) + s.NoError(err) + + var item2 itemT + err = col.Find(item1.ID).One(&item2) + s.NoError(err) + s.Equal(item1.Name, item2.Name) + + var item3 itemT + err = col.Find(mydb.Cond{"id": item1.ID}).One(&item3) + s.NoError(err) + s.Equal(item1.Name, item3.Name) + } + + { + type itemT struct { + ID Int64Array `db:"id"` + Name string `db:"name"` + } + + item1 := itemT{ + ID: Int64Array{1, 2, 3}, + Name: "Vojtech", + } + + col := sess.Collection("issue_370_2") + err := col.Truncate() + s.NoError(err) + + err = col.InsertReturning(&item1) + s.NoError(err) + + var item2 itemT + err = col.Find(item1.ID).One(&item2) + s.NoError(err) + s.Equal(item1.Name, item2.Name) + + var item3 itemT + err = col.Find(mydb.Cond{"id": item1.ID}).One(&item3) + s.NoError(err) + s.Equal(item1.Name, item3.Name) + } +} + +type issue602Organization struct { + ID string `json:"id" db:"id,omitempty"` + Name string `json:"name" db:"name"` + CreatedAt time.Time `json:"created_at,omitempty" 
db:"created_at,omitempty"` + UpdatedAt time.Time `json:"updated_at,omitempty" db:"updated_at,omitempty"` +} + +type issue602OrganizationStore struct { + mydb.Store +} + +func (r *issue602Organization) BeforeUpdate(mydb.Session) error { + return nil +} + +func (r *issue602Organization) Store(sess mydb.Session) mydb.Store { + return issue602OrganizationStore{sess.Collection("issue_602_organizations")} +} + +var _ interface { + mydb.Record + mydb.BeforeUpdateHook +} = &issue602Organization{} + +func (s *AdapterTests) Test_Issue602_IncorrectBinaryFormat() { + settingsWithBinaryMode := ConnectionURL{ + Database: settings.Database, + User: settings.User, + Password: settings.Password, + Host: settings.Host, + Options: map[string]string{ + "timezone": testsuite.TimeZone, + //"binary_parameters": "yes", + }, + } + + sess, err := Open(settingsWithBinaryMode) + if err != nil { + s.T().Errorf("%v", err) + } + + { + item := issue602Organization{ + Name: "Jonny", + } + + col := sess.Collection("issue_602_organizations") + err := col.Truncate() + s.NoError(err) + + err = sess.Save(&item) + s.NoError(err) + } + + { + item := issue602Organization{ + Name: "Jonny", + } + + col := sess.Collection("issue_602_organizations") + err := col.Truncate() + s.NoError(err) + + err = col.InsertReturning(&item) + s.NoError(err) + } + + { + newUUID := uuid.New() + + item := issue602Organization{ + ID: newUUID.String(), + Name: "Jonny", + } + + col := sess.Collection("issue_602_organizations") + err := col.Truncate() + s.NoError(err) + + id, err := col.Insert(item) + s.NoError(err) + s.NotZero(id) + } + + { + item := issue602Organization{ + Name: "Jonny", + } + + col := sess.Collection("issue_602_organizations") + err := col.Truncate() + s.NoError(err) + + err = sess.Save(&item) + s.NoError(err) + } +} + +func (s *AdapterTests) TestInsertVarcharPrimaryKey() { + sess := s.Session() + + { + type itemT struct { + Address string `db:"address"` + Name string `db:"name"` + } + + item1 := itemT{ + Address: "1234", + Name: "Jonny", + } + + col := sess.Collection("varchar_primary_key") + err := col.Truncate() + s.NoError(err) + + err = col.InsertReturning(&item1) + s.NoError(err) + + var item2 itemT + err = col.Find(mydb.Cond{"address": item1.Address}).One(&item2) + s.NoError(err) + s.Equal(item1.Name, item2.Name) + + var item3 itemT + err = col.Find(mydb.Cond{"address": item1.Address}).One(&item3) + s.NoError(err) + s.Equal(item1.Name, item3.Name) + } +} + +func (s *AdapterTests) Test_Issue409_TxOptions() { + sess := s.Session() + + { + err := sess.TxContext(context.Background(), func(tx mydb.Session) error { + col := tx.Collection("publication") + + row := map[string]interface{}{ + "title": "foo", + "author_id": 1, + } + err := col.InsertReturning(&row) + s.Error(err) + + return err + }, &sql.TxOptions{ + ReadOnly: true, + }) + s.Error(err) + s.True(strings.Contains(err.Error(), "read-only transaction")) + } +} + +func (s *AdapterTests) TestEscapeQuestionMark() { + sess := s.Session() + + var val bool + + { + res, err := sess.SQL().QueryRow(`SELECT '{"mykey":["val1", "val2"]}'::jsonb->'mykey' ?? ?`, "val2") + s.NoError(err) + + err = res.Scan(&val) + s.NoError(err) + s.Equal(true, val) + } + + { + res, err := sess.SQL().QueryRow(`SELECT ?::jsonb->'mykey' ?? ?`, `{"mykey":["val1", "val2"]}`, `val2`) + s.NoError(err) + + err = res.Scan(&val) + s.NoError(err) + s.Equal(true, val) + } + + { + res, err := sess.SQL().QueryRow(`SELECT ?::jsonb->? ?? 
?`, `{"mykey":["val1", "val2"]}`, `mykey`, `val2`) + s.NoError(err) + + err = res.Scan(&val) + s.NoError(err) + s.Equal(true, val) + } +} + +func (s *AdapterTests) Test_Issue391_TextMode() { + testPostgreSQLTypes(s.T(), s.Session()) +} + +func (s *AdapterTests) Test_Issue391_BinaryMode() { + settingsWithBinaryMode := ConnectionURL{ + Database: settings.Database, + User: settings.User, + Password: settings.Password, + Host: settings.Host, + Options: map[string]string{ + "timezone": testsuite.TimeZone, + //"binary_parameters": "yes", + }, + } + + sess, err := Open(settingsWithBinaryMode) + if err != nil { + s.T().Errorf("%v", err) + } + defer sess.Close() + + testPostgreSQLTypes(s.T(), sess) +} + +func (s *AdapterTests) TestStringAndInt64Array() { + sess := s.Session() + driver := sess.Driver().(*sql.DB) + + defer func() { + _, _ = driver.Exec(`DROP TABLE IF EXISTS array_types`) + }() + + if _, err := driver.Exec(` + CREATE TABLE array_types ( + id serial primary key, + integers bigint[] DEFAULT NULL, + strings varchar(64)[] + )`); err != nil { + s.NoError(err) + } + + arrayTypes := sess.Collection("array_types") + err := arrayTypes.Truncate() + s.NoError(err) + + type arrayType struct { + ID int64 `db:"id,pk"` + Integers Int64Array `db:"integers"` + Strings StringArray `db:"strings"` + } + + tt := []arrayType{ + // Test nil arrays. + arrayType{ + ID: 1, + Integers: nil, + Strings: nil, + }, + + // Test empty arrays. + arrayType{ + ID: 2, + Integers: []int64{}, + Strings: []string{}, + }, + + // Test non-empty arrays. + arrayType{ + ID: 3, + Integers: []int64{1, 2, 3}, + Strings: []string{"1", "2", "3"}, + }, + } + + for _, item := range tt { + record, err := arrayTypes.Insert(item) + s.NoError(err) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + var itemCheck arrayType + err = arrayTypes.Find(record).One(&itemCheck) + s.NoError(err) + s.Len(itemCheck.Integers, len(item.Integers)) + s.Len(itemCheck.Strings, len(item.Strings)) + + s.Equal(item, itemCheck) + } +} + +func (s *AdapterTests) Test_Issue210() { + list := []string{ + `DROP TABLE IF EXISTS testing123`, + `DROP TABLE IF EXISTS hello`, + `CREATE TABLE IF NOT EXISTS testing123 ( + ID INT PRIMARY KEY NOT NULL, + NAME TEXT NOT NULL + ) + `, + `CREATE TABLE IF NOT EXISTS hello ( + ID INT PRIMARY KEY NOT NULL, + NAME TEXT NOT NULL + )`, + } + + sess := s.Session() + + err := sess.Tx(func(tx mydb.Session) error { + for i := range list { + _, err := tx.SQL().Exec(list[i]) + s.NoError(err) + if err != nil { + return err + } + } + return nil + }) + s.NoError(err) + + _, err = sess.Collection("testing123").Find().Count() + s.NoError(err) + + _, err = sess.Collection("hello").Find().Count() + s.NoError(err) +} + +func (s *AdapterTests) TestPreparedStatements() { + sess := s.Session() + + var val int + + { + stmt, err := sess.SQL().Prepare(`SELECT 1`) + s.NoError(err) + s.NotNil(stmt) + + q, err := stmt.Query() + s.NoError(err) + s.NotNil(q) + s.True(q.Next()) + + err = q.Scan(&val) + s.NoError(err) + + err = q.Close() + s.NoError(err) + + s.Equal(1, val) + + err = stmt.Close() + s.NoError(err) + } + + { + err := sess.Tx(func(tx mydb.Session) error { + stmt, err := tx.SQL().Prepare(`SELECT 2`) + s.NoError(err) + s.NotNil(stmt) + + q, err := stmt.Query() + s.NoError(err) + s.NotNil(q) + s.True(q.Next()) + + err = q.Scan(&val) + s.NoError(err) + + err = q.Close() + s.NoError(err) + + s.Equal(2, val) + + err = stmt.Close() + s.NoError(err) + + return nil + }) + s.NoError(err) + } + + { + stmt, err := 
sess.SQL().Select(3).Prepare() + s.NoError(err) + s.NotNil(stmt) + + q, err := stmt.Query() + s.NoError(err) + s.NotNil(q) + s.True(q.Next()) + + err = q.Scan(&val) + s.NoError(err) + + err = q.Close() + s.NoError(err) + + s.Equal(3, val) + + err = stmt.Close() + s.NoError(err) + } +} + +func (s *AdapterTests) TestNonTrivialSubqueries() { + sess := s.Session() + + // Creating test data + artist := sess.Collection("artist") + + artistNames := []string{"Ozzie", "Flea", "Slash", "Chrono"} + for _, artistName := range artistNames { + _, err := artist.Insert(map[string]string{ + "name": artistName, + }) + s.NoError(err) + } + + { + q, err := sess.SQL().Query(`WITH test AS (?) ?`, + sess.SQL().Select("id AS foo").From("artist"), + sess.SQL().Select("foo").From("test").Where("foo > ?", 0), + ) + + s.NoError(err) + s.NotNil(q) + + s.True(q.Next()) + + var number int + s.NoError(q.Scan(&number)) + + s.Equal(1, number) + s.NoError(q.Close()) + } + + { + builder := sess.SQL() + row, err := builder.QueryRow(`WITH test AS (?) ?`, + builder.Select("id AS foo").From("artist"), + builder.Select("foo").From("test").Where("foo > ?", 0), + ) + + s.NoError(err) + s.NotNil(row) + + var number int + s.NoError(row.Scan(&number)) + + s.Equal(1, number) + } + + { + res, err := sess.SQL().Exec( + `UPDATE artist a1 SET id = ?`, + sess.SQL().Select(mydb.Raw("id + 5")). + From("artist a2"). + Where("a2.id = a1.id"), + ) + + s.NoError(err) + s.NotNil(res) + } + + { + builder := sess.SQL() + + q, err := builder.Query(mydb.Raw(`WITH test AS (?) ?`, + builder.Select("id AS foo").From("artist"), + builder.Select("foo").From("test").Where("foo > ?", 0).OrderBy("foo"), + )) + + s.NoError(err) + s.NotNil(q) + + s.True(q.Next()) + + var number int + s.NoError(q.Scan(&number)) + + s.Equal(6, number) + s.NoError(q.Close()) + } + + { + res, err := sess.SQL().Exec(mydb.Raw(`UPDATE artist a1 SET id = ?`, + sess.SQL().Select(mydb.Raw("id + 7")).From("artist a2").Where("a2.id = a1.id"), + )) + + s.NoError(err) + s.NotNil(res) + } +} + +func (s *AdapterTests) Test_Issue601_ErrorCarrying() { + var items []interface{} + var err error + + sess := s.Session() + + _, err = sess.SQL().Exec(`DROP TABLE IF EXISTS issue_601`) + s.NoError(err) + + err = sess.Collection("issue_601").Find().All(&items) + s.Error(err) + + _, err = sess.SQL().Exec(`CREATE TABLE issue_601 (id INTEGER PRIMARY KEY)`) + s.NoError(err) + + err = sess.Collection("issue_601").Find().All(&items) + s.NoError(err) +} + +func TestAdapter(t *testing.T) { + suite.Run(t, &AdapterTests{}) +} diff --git a/adapter/postgresql/record_test.go b/adapter/postgresql/record_test.go new file mode 100644 index 0000000..a993c28 --- /dev/null +++ b/adapter/postgresql/record_test.go @@ -0,0 +1,20 @@ +package postgresql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type RecordTests struct { + testsuite.RecordTestSuite +} + +func (s *RecordTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestRecord(t *testing.T) { + suite.Run(t, &RecordTests{}) +} diff --git a/adapter/postgresql/sql_test.go b/adapter/postgresql/sql_test.go new file mode 100644 index 0000000..21eae9e --- /dev/null +++ b/adapter/postgresql/sql_test.go @@ -0,0 +1,20 @@ +package postgresql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type SQLTests struct { + testsuite.SQLTestSuite +} + +func (s *SQLTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestSQL(t *testing.T) { + 
suite.Run(t, &SQLTests{}) +} diff --git a/adapter/postgresql/template.go b/adapter/postgresql/template.go new file mode 100644 index 0000000..aa6eb69 --- /dev/null +++ b/adapter/postgresql/template.go @@ -0,0 +1,189 @@ +package postgresql + +import ( + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/cache" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +const ( + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `"{{.Value}}"` + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` + {{if .SortColumns}} + ORDER BY {{.SortColumns}} + {{end}} + ` + + adapterWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + + adapterUsingLayout = ` + {{if .Columns}} + USING ({{.Columns}}) + {{end}} + ` + + adapterJoinLayout = ` + {{if .Table}} + {{ if .On }} + {{.Type}} JOIN {{.Table}} + {{.On}} + {{ else if .Using }} + {{.Type}} JOIN {{.Table}} + {{.Using}} + {{ else if .Type | eq "CROSS" }} + {{.Type}} JOIN {{.Table}} + {{else}} + NATURAL {{.Type}} JOIN {{.Table}} + {{end}} + {{end}} + ` + + adapterOnLayout = ` + {{if .Conds}} + ON {{.Conds}} + {{end}} + ` + + adapterSelectLayout = ` + SELECT + {{if .Distinct}} + DISTINCT + {{end}} + + {{if defined .Columns}} + {{.Columns | compile}} + {{else}} + * + {{end}} + + {{if defined .Table}} + FROM {{.Table | compile}} + {{end}} + + {{.Joins | compile}} + + {{.Where | compile}} + + {{if defined .GroupBy}} + {{.GroupBy | compile}} + {{end}} + + {{.OrderBy | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + adapterDeleteLayout = ` + DELETE + FROM {{.Table | compile}} + {{.Where | compile}} + ` + adapterUpdateLayout = ` + UPDATE + {{.Table | compile}} + SET {{.ColumnValues | compile}} + {{.Where | compile}} + ` + + adapterSelectCountLayout = ` + SELECT + COUNT(1) AS _t + FROM {{.Table | compile}} + {{.Where | compile}} + ` + + adapterInsertLayout = ` + INSERT INTO {{.Table | compile}} + {{if defined .Columns}}({{.Columns | compile}}){{end}} + VALUES + {{if defined .Values}} + {{.Values | compile}} + {{else}} + (default) + {{end}} + {{if defined .Returning}} + RETURNING {{.Returning | compile}} + {{end}} + ` + + adapterTruncateLayout = ` + TRUNCATE TABLE {{.Table | compile}} RESTART IDENTITY + ` + + adapterDropDatabaseLayout = ` + DROP DATABASE {{.Database | compile}} + ` + + adapterDropTableLayout = ` + DROP TABLE {{.Table | compile}} + ` + + adapterGroupByLayout = ` + {{if .GroupColumns}} + GROUP BY {{.GroupColumns}} + {{end}} + ` +) + +var template = &exql.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + 
ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + JoinLayout: adapterJoinLayout, + OnLayout: adapterOnLayout, + UsingLayout: adapterUsingLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), + ComparisonOperator: map[adapter.ComparisonOperator]string{ + adapter.ComparisonOperatorRegExp: "~", + adapter.ComparisonOperatorNotRegExp: "!~", + }, +} diff --git a/adapter/postgresql/template_test.go b/adapter/postgresql/template_test.go new file mode 100644 index 0000000..1ac1fcb --- /dev/null +++ b/adapter/postgresql/template_test.go @@ -0,0 +1,262 @@ +package postgresql + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" + "github.com/stretchr/testify/assert" +) + +func TestTemplateSelect(t *testing.T) { + + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `SELECT * FROM "artist"`, + b.SelectFrom("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist"`, + b.Select().From("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("name DESC").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("-name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name ASC").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT 1 OFFSET 5`, + b.Select().From("artist").Limit(1).Offset(5).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT 1 OFFSET 5`, + b.Select().From("artist").Offset(5).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" OFFSET 5`, + b.Select().From("artist").Limit(-1).Offset(5).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" OFFSET 5`, + b.Select().From("artist").Offset(5).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist"`, + b.Select("id").From("artist").String(), + ) + + assert.Equal( + `SELECT "id", "name" FROM "artist"`, + b.Select("id", "name").From("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("name" = $1)`, + b.SelectFrom("artist").Where("name", "Haruki").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name LIKE $1)`, + b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" WHERE (name LIKE $1 OR name LIKE $2)`, + b.Select("id").From("artist").Where(`name LIKE ? 
OR name LIKE ?`, `%Miya%`, `F%`).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" > $1)`, + b.SelectFrom("artist").Where("id >", 2).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (id <= 2 AND name != $1)`, + b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN ($1, $2, $3, $4))`, + b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name IS NOT NULL)`, + b.SelectFrom("artist").Where("name IS NOT NULL").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a", "publication" AS "p" WHERE (p.author_id = a.id) LIMIT 1`, + b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" NATURAL JOIN "publication"`, + b.Select("id").From("artist").Join("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE ("a"."id" = $1) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1`, + b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" LEFT JOIN "publication" AS "p1" ON (p1.id = a.id) RIGHT JOIN "publication" AS "p2" ON (p2.id = a.id)`, + b.SelectFrom("artist a"). + LeftJoin("publication p1").On("p1.id = a.id"). + RightJoin("publication p2").On("p2.id = a.id"). + String(), + ) + + assert.Equal( + `SELECT * FROM "artist" CROSS JOIN "publication"`, + b.SelectFrom("artist").CrossJoin("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" USING ("id")`, + b.SelectFrom("artist").Join("publication").Using("id").String(), + ) + + assert.Equal( + `SELECT DATE()`, + b.Select(mydb.Raw("DATE()")).String(), + ) +} + +func TestTemplateInsert(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `INSERT INTO "artist" VALUES ($1, $2), ($3, $4), ($5, $6)`, + b.InsertInto("artist"). + Values(10, "Ryuichi Sakamoto"). + Values(11, "Alondra de la Parra"). + Values(12, "Haruki Murakami"). 
+ String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(struct { + ID int `db:"id"` + Name string `db:"name"` + }{12, "Chavela Vargas"}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, + b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(), + ) +} + +func TestTemplateUpdate(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `UPDATE "artist" SET "name" = $1`, + b.Update("artist").Set("name", "Artist").String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 || ' ' || $2 || id, "id" = id + $3 WHERE (id > $4)`, + b.Update("artist").Set( + "name = ? || ' ' || ? || id", "Artist", "#", + "id = id + ?", 10, + ).Where("id > ?", 0).String(), + ) +} + +func TestTemplateDelete(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `DELETE FROM "artist" WHERE (name = $1)`, + b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Limit(1).String(), + ) + + assert.Equal( + `DELETE FROM "artist" WHERE (id > 5)`, + b.DeleteFrom("artist").Where("id > 5").String(), + ) +} diff --git a/adapter/sqlite/Makefile b/adapter/sqlite/Makefile new file mode 100644 index 0000000..13bb432 --- /dev/null +++ b/adapter/sqlite/Makefile @@ -0,0 +1,27 @@ +SHELL ?= bash +DB_NAME ?= sqlite3-test.db +TEST_FLAGS ?= + +export DB_NAME + +export TEST_FLAGS + +build: + go build && go install + +require-client: + @if [ -z "$$(which sqlite3)" ]; then \ + echo 'Missing "sqlite3" command. Please install SQLite3 and try again.' 
&& \ + exit 1; \ + fi + +reset-db: require-client + rm -f $(DB_NAME) + +test: reset-db + go test -v -failfast -race -timeout 20m $(TEST_FLAGS) + +test-no-race: + go test -v -failfast $(TEST_FLAGS) + +test-extended: test diff --git a/adapter/sqlite/README.md b/adapter/sqlite/README.md new file mode 100644 index 0000000..5d73e1a --- /dev/null +++ b/adapter/sqlite/README.md @@ -0,0 +1,4 @@ +# SQLite adapter for upper/db + +Please read the full docs, acknowledgements and examples at +[https://upper.io/v4/adapter/sqlite/](https://upper.io/v4/adapter/sqlite/). diff --git a/adapter/sqlite/collection.go b/adapter/sqlite/collection.go new file mode 100644 index 0000000..ca0c8c2 --- /dev/null +++ b/adapter/sqlite/collection.go @@ -0,0 +1,49 @@ +package sqlite + +import ( + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +type collectionAdapter struct { +} + +func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) { + columnNames, columnValues, err := sqlbuilder.Map(item, nil) + if err != nil { + return nil, err + } + + pKey, err := col.PrimaryKeys() + if err != nil { + return nil, err + } + + q := col.SQL().InsertInto(col.Name()). + Columns(columnNames...). + Values(columnValues...) + + var res sql.Result + if res, err = q.Exec(); err != nil { + return nil, err + } + + if len(pKey) <= 1 { + return res.LastInsertId() + } + + keyMap := mydb.Cond{} + + for i := range columnNames { + for j := 0; j < len(pKey); j++ { + if pKey[j] == columnNames[i] { + keyMap[pKey[j]] = columnValues[i] + } + } + } + + return keyMap, nil +} diff --git a/adapter/sqlite/connection.go b/adapter/sqlite/connection.go new file mode 100644 index 0000000..ef6bf9a --- /dev/null +++ b/adapter/sqlite/connection.go @@ -0,0 +1,89 @@ +package sqlite + +import ( + "fmt" + "net/url" + "path/filepath" + "runtime" + "strings" +) + +const connectionScheme = `file` + +// ConnectionURL represents a SQLite connection URL. +type ConnectionURL struct { + Database string + Options map[string]string +} + +func (c ConnectionURL) String() (s string) { + vv := url.Values{} + + if c.Database == "" { + return "" + } + + // Did the user provide a full database path? + if !strings.HasPrefix(c.Database, "/") { + c.Database, _ = filepath.Abs(c.Database) + if runtime.GOOS == "windows" { + // Closes https://github.com/upper/db/issues/60 + c.Database = "/" + strings.Replace(c.Database, `\`, `/`, -1) + } + } + + // Do we have any options? + if c.Options == nil { + c.Options = map[string]string{} + } + + if _, ok := c.Options["_busy_timeout"]; !ok { + c.Options["_busy_timeout"] = "10000" + } + + // Converting options into URL values. + for k, v := range c.Options { + vv.Set(k, v) + } + + // Building URL. + u := url.URL{ + Scheme: connectionScheme, + Path: c.Database, + RawQuery: vv.Encode(), + } + + return u.String() +} + +// ParseURL parses s into a ConnectionURL struct.
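+// +// A minimal sketch of the expected input; the path and options below are +// placeholders that mirror the tests in this patch: +// +//	conn, err := ParseURL("file:///path/to/example.db?mode=ro&cache=shared") +//	if err != nil { +//		// handle the parse error +//	} +//	// conn.Database is now "/path/to/example.db" and conn.Options +//	// holds the "mode" and "cache" values.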
+func ParseURL(s string) (conn ConnectionURL, err error) { + var u *url.URL + + if !strings.HasPrefix(s, connectionScheme+"://") { + return conn, fmt.Errorf(`expecting file:// connection scheme`) + } + + if u, err = url.Parse(s); err != nil { + return conn, err + } + + conn.Database = u.Host + u.Path + conn.Options = map[string]string{} + + var vv url.Values + + if vv, err = url.ParseQuery(u.RawQuery); err != nil { + return conn, err + } + + for k := range vv { + conn.Options[k] = vv.Get(k) + } + + if _, ok := conn.Options["cache"]; !ok { + conn.Options["cache"] = "shared" + } + + return conn, err +} diff --git a/adapter/sqlite/connection_test.go b/adapter/sqlite/connection_test.go new file mode 100644 index 0000000..1f99f3d --- /dev/null +++ b/adapter/sqlite/connection_test.go @@ -0,0 +1,88 @@ +package sqlite + +import ( + "path/filepath" + "testing" +) + +func TestConnectionURL(t *testing.T) { + + c := ConnectionURL{} + + // The default connection string is empty. + if c.String() != "" { + t.Fatal(`Expecting default connection string to be empty, got:`, c.String()) + } + + // Adding a database name. + c.Database = "myfilename" + + absoluteName, _ := filepath.Abs(c.Database) + + if c.String() != "file://"+absoluteName+"?_busy_timeout=10000" { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Adding an option. + c.Options = map[string]string{ + "cache": "foobar", + "mode": "ro", + } + + if c.String() != "file://"+absoluteName+"?_busy_timeout=10000&cache=foobar&mode=ro" { + t.Fatal(`Test failed, got:`, c.String()) + } + + // Setting another database. + c.Database = "/another/database" + + if c.String() != `file:///another/database?_busy_timeout=10000&cache=foobar&mode=ro` { + t.Fatal(`Test failed, got:`, c.String()) + } + +} + +func TestParseConnectionURL(t *testing.T) { + var u ConnectionURL + var s string + var err error + + s = "file://mydatabase.db" + + if u, err = ParseURL(s); err != nil { + t.Fatal(err) + } + + if u.Database != "mydatabase.db" { + t.Fatal("Failed to parse database.") + } + + if u.Options["cache"] != "shared" { + t.Fatal("If not defined, cache should be shared by default.") + } + + s = "file:///path/to/my/database.db?_busy_timeout=10000&mode=ro&cache=foobar" + + if u, err = ParseURL(s); err != nil { + t.Fatal(err) + } + + if u.Database != "/path/to/my/database.db" { + t.Fatal("Failed to parse database.") + } + + if u.Options["cache"] != "foobar" { + t.Fatal("Expecting option.") + } + + if u.Options["mode"] != "ro" { + t.Fatal("Expecting option.") + } + + s = "http://example.org" + + if _, err = ParseURL(s); err == nil { + t.Fatal("Expecting error.") + } + +} diff --git a/adapter/sqlite/database.go b/adapter/sqlite/database.go new file mode 100644 index 0000000..dfbada3 --- /dev/null +++ b/adapter/sqlite/database.go @@ -0,0 +1,168 @@ +// Package sqlite wraps the github.com/mattn/go-sqlite3 SQLite driver. See +// https://github.com/upper/db/adapter/sqlite for documentation, particularities and +// usage examples. +package sqlite + +import ( + "context" + "database/sql" + "fmt" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/compat" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" + _ "github.com/mattn/go-sqlite3" // SQLite3 driver.
+) + +// database is the actual implementation of Database +type database struct { +} + +func (*database) Template() *exql.Template { + return template +} + +func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) { + return sql.Open("sqlite3", dsn) +} + +func (*database) Collections(sess sqladapter.Session) (collections []string, err error) { + q := sess.SQL(). + Select("tbl_name"). + From("sqlite_master"). + Where("type = ?", "table") + + iter := q.Iterator() + defer iter.Close() + + for iter.Next() { + var tableName string + if err := iter.Scan(&tableName); err != nil { + return nil, err + } + collections = append(collections, tableName) + } + if err := iter.Err(); err != nil { + return nil, err + } + + return collections, nil +} + +func (*database) StatementExec(sess sqladapter.Session, ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) { + if sess.Transaction() != nil { + return compat.ExecContext(sess.Driver().(*sql.Tx), ctx, query, args) + } + + sqlTx, err := compat.BeginTx(sess.Driver().(*sql.DB), ctx, nil) + if err != nil { + return nil, err + } + + if res, err = compat.ExecContext(sqlTx, ctx, query, args); err != nil { + _ = sqlTx.Rollback() + return nil, err + } + + if err = sqlTx.Commit(); err != nil { + return nil, err + } + + return res, err +} + +func (*database) NewCollection() sqladapter.CollectionAdapter { + return &collectionAdapter{} +} + +func (*database) LookupName(sess sqladapter.Session) (string, error) { + connURL := sess.ConnectionURL() + if connURL != nil { + connURL, err := ParseURL(connURL.String()) + if err != nil { + return "", err + } + return connURL.Database, nil + } + + // sess.ConnectionURL() is nil if using sqlite.New + rows, err := sess.SQL().Query(exql.RawSQL("PRAGMA database_list")) + if err != nil { + return "", err + } + dbInfo := struct { + Name string `db:"name"` + File string `db:"file"` + }{} + + if err := sess.SQL().NewIterator(rows).One(&dbInfo); err != nil { + return "", err + } + if dbInfo.File != "" { + return dbInfo.File, nil + } + // dbInfo.File is empty if in memory mode + return dbInfo.Name, nil +} + +func (*database) TableExists(sess sqladapter.Session, name string) error { + q := sess.SQL(). + Select("tbl_name"). + From("sqlite_master"). 
+ Where("type = 'table' AND tbl_name = ?", name) + + iter := q.Iterator() + defer iter.Close() + + if iter.Next() { + var name string + if err := iter.Scan(&name); err != nil { + return err + } + return nil + } + if err := iter.Err(); err != nil { + return err + } + + return mydb.ErrCollectionDoesNotExist +} + +func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) { + pk := make([]string, 0, 1) + + stmt := exql.RawSQL(fmt.Sprintf("PRAGMA TABLE_INFO('%s')", tableName)) + + rows, err := sess.SQL().Query(stmt) + if err != nil { + return nil, err + } + + columns := []struct { + Name string `db:"name"` + PK int `db:"pk"` + }{} + + if err := sess.SQL().NewIterator(rows).All(&columns); err != nil { + return nil, err + } + + maxValue := -1 + + for _, column := range columns { + if column.PK > 0 && column.PK > maxValue { + maxValue = column.PK + } + } + + if maxValue > 0 { + for _, column := range columns { + if column.PK > 0 { + pk = append(pk, column.Name) + } + } + } + + return pk, nil +} diff --git a/adapter/sqlite/generic_test.go b/adapter/sqlite/generic_test.go new file mode 100644 index 0000000..df65544 --- /dev/null +++ b/adapter/sqlite/generic_test.go @@ -0,0 +1,20 @@ +package sqlite + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type GenericTests struct { + testsuite.GenericTestSuite +} + +func (s *GenericTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestGeneric(t *testing.T) { + suite.Run(t, &GenericTests{}) +} diff --git a/adapter/sqlite/helper_test.go b/adapter/sqlite/helper_test.go new file mode 100644 index 0000000..8c699da --- /dev/null +++ b/adapter/sqlite/helper_test.go @@ -0,0 +1,170 @@ +package sqlite + +import ( + "database/sql" + "os" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/testsuite" +) + +var settings = ConnectionURL{ + Database: os.Getenv("DB_NAME"), +} + +type Helper struct { + sess mydb.Session +} + +func (h *Helper) Session() mydb.Session { + return h.sess +} + +func (h *Helper) Adapter() string { + return "sqlite" +} + +func (h *Helper) TearDown() error { + return h.sess.Close() +} + +func (h *Helper) TearUp() error { + var err error + + h.sess, err = Open(settings) + if err != nil { + return err + } + + batch := []string{ + `PRAGMA foreign_keys=OFF`, + + `BEGIN TRANSACTION`, + + `DROP TABLE IF EXISTS artist`, + `CREATE TABLE artist ( + id integer primary key, + name varchar(60) + )`, + + `DROP TABLE IF EXISTS publication`, + `CREATE TABLE publication ( + id integer primary key, + title varchar(80), + author_id integer + )`, + + `DROP TABLE IF EXISTS review`, + `CREATE TABLE review ( + id integer primary key, + publication_id integer, + name varchar(80), + comments text, + created datetime + )`, + + `DROP TABLE IF EXISTS data_types`, + `CREATE TABLE data_types ( + id integer primary key, + _uint integer, + _uintptr integer, + _uint8 integer, + _uint16 int, + _uint32 int, + _uint64 int, + _int integer, + _int8 integer, + _int16 integer, + _int32 integer, + _int64 integer, + _float32 real, + _float64 real, + _byte integer, + _rune integer, + _bool integer, + _string text, + _blob blob, + _date datetime, + _nildate datetime, + _ptrdate datetime, + _defaultdate datetime default current_timestamp, + _time text + )`, + + `DROP TABLE IF EXISTS stats_test`, + `CREATE TABLE stats_test ( + id integer primary key, + numeric integer, + value integer + )`, + + `DROP TABLE IF EXISTS composite_keys`, + `CREATE TABLE composite_keys ( + code 
VARCHAR(255) default '', + user_id VARCHAR(255) default '', + some_val VARCHAR(255) default '', + primary key (code, user_id) + )`, + + `DROP TABLE IF EXISTS "birthdays"`, + `CREATE TABLE "birthdays" ( + "id" INTEGER PRIMARY KEY, + "name" VARCHAR(50) DEFAULT NULL, + "born" DATETIME DEFAULT NULL, + "born_ut" INTEGER + )`, + + `DROP TABLE IF EXISTS "fibonacci"`, + `CREATE TABLE "fibonacci" ( + "id" INTEGER PRIMARY KEY, + "input" INTEGER, + "output" INTEGER + )`, + + `DROP TABLE IF EXISTS "is_even"`, + `CREATE TABLE "is_even" ( + "input" INTEGER, + "is_even" INTEGER + )`, + + `DROP TABLE IF EXISTS "CaSe_TesT"`, + `CREATE TABLE "CaSe_TesT" ( + "id" INTEGER PRIMARY KEY, + "case_test" VARCHAR + )`, + + `DROP TABLE IF EXISTS accounts`, + `CREATE TABLE accounts ( + id integer primary key, + name varchar, + disabled integer, + created_at datetime default current_timestamp + )`, + + `DROP TABLE IF EXISTS users`, + `CREATE TABLE users ( + id integer primary key, + account_id integer, + username varchar UNIQUE + )`, + + `DROP TABLE IF EXISTS logs`, + `CREATE TABLE logs ( + id integer primary key, + message VARCHAR + )`, + + `COMMIT`, + } + + driver := h.sess.Driver().(*sql.DB) + for _, query := range batch { + if _, err := driver.Exec(query); err != nil { + return err + } + } + + return nil +} + +var _ testsuite.Helper = &Helper{} diff --git a/adapter/sqlite/record_test.go b/adapter/sqlite/record_test.go new file mode 100644 index 0000000..7a193c3 --- /dev/null +++ b/adapter/sqlite/record_test.go @@ -0,0 +1,20 @@ +package sqlite + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type RecordTests struct { + testsuite.RecordTestSuite +} + +func (s *RecordTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestRecord(t *testing.T) { + suite.Run(t, &RecordTests{}) +} diff --git a/adapter/sqlite/sql_test.go b/adapter/sqlite/sql_test.go new file mode 100644 index 0000000..4725c88 --- /dev/null +++ b/adapter/sqlite/sql_test.go @@ -0,0 +1,20 @@ +package sqlite + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type SQLTests struct { + testsuite.SQLTestSuite +} + +func (s *SQLTests) SetupSuite() { + s.Helper = &Helper{} +} + +func TestSQL(t *testing.T) { + suite.Run(t, &SQLTests{}) +} diff --git a/adapter/sqlite/sqlite.go b/adapter/sqlite/sqlite.go new file mode 100644 index 0000000..519489d --- /dev/null +++ b/adapter/sqlite/sqlite.go @@ -0,0 +1,30 @@ +package sqlite + +import ( + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// Adapter is the public name of the adapter. +const Adapter = `sqlite` + +var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{}) + +// Open establishes a connection to the SQLite database and returns a +// mydb.Session instance. +func Open(connURL mydb.ConnectionURL) (mydb.Session, error) { + return registeredAdapter.OpenDSN(connURL) +} + +// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value. +func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { + return registeredAdapter.NewTx(sqlTx) +} + +// New creates a mydb.Session instance by wrapping a *sql.DB value.
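+// +// A minimal usage sketch with an in-memory database, as exercised by the +// adapter tests in this patch: +// +//	sqlDB, err := sql.Open("sqlite3", ":memory:") +//	if err != nil { +//		// handle the error +//	} +//	sess, err := New(sqlDB) +//	if err != nil { +//		// handle the error +//	} +//	defer sess.Close()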
+func New(sqlDB *sql.DB) (mydb.Session, error) { + return registeredAdapter.New(sqlDB) +} diff --git a/adapter/sqlite/sqlite_test.go b/adapter/sqlite/sqlite_test.go new file mode 100644 index 0000000..5e99ef5 --- /dev/null +++ b/adapter/sqlite/sqlite_test.go @@ -0,0 +1,55 @@ +package sqlite + +import ( + "path/filepath" + "testing" + + "database/sql" + + "git.hexq.cn/tiglog/mydb/internal/testsuite" + "github.com/stretchr/testify/suite" +) + +type AdapterTests struct { + testsuite.Suite +} + +func (s *AdapterTests) SetupSuite() { + s.Helper = &Helper{} +} + +func (s *AdapterTests) Test_Issue633_OpenSession() { + sess, err := Open(settings) + s.NoError(err) + defer sess.Close() + + absoluteName, _ := filepath.Abs(settings.Database) + s.Equal(absoluteName, sess.Name()) +} + +func (s *AdapterTests) Test_Issue633_NewAdapterWithFile() { + sqldb, err := sql.Open("sqlite3", settings.Database) + s.NoError(err) + + sess, err := New(sqldb) + s.NoError(err) + defer sess.Close() + + absoluteName, _ := filepath.Abs(settings.Database) + s.Equal(absoluteName, sess.Name()) +} + +func (s *AdapterTests) Test_Issue633_NewAdapterWithMemory() { + sqldb, err := sql.Open("sqlite3", ":memory:") + s.NoError(err) + + sess, err := New(sqldb) + s.NoError(err) + defer sess.Close() + + s.Equal("main", sess.Name()) +} + +func TestAdapter(t *testing.T) { + suite.Run(t, &AdapterTests{}) +} diff --git a/adapter/sqlite/template.go b/adapter/sqlite/template.go new file mode 100644 index 0000000..e046d82 --- /dev/null +++ b/adapter/sqlite/template.go @@ -0,0 +1,187 @@ +package sqlite + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +const ( + adapterColumnSeparator = `.` + adapterIdentifierSeparator = `, ` + adapterIdentifierQuote = `"{{.Value}}"` + adapterValueSeparator = `, ` + adapterValueQuote = `'{{.}}'` + adapterAndKeyword = `AND` + adapterOrKeyword = `OR` + adapterDescKeyword = `DESC` + adapterAscKeyword = `ASC` + adapterAssignmentOperator = `=` + adapterClauseGroup = `({{.}})` + adapterClauseOperator = ` {{.}} ` + adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + adapterSortByColumnLayout = `{{.Column}} {{.Order}}` + + adapterOrderByLayout = ` + {{if .SortColumns}} + ORDER BY {{.SortColumns}} + {{end}} + ` + + adapterWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + + adapterUsingLayout = ` + {{if .Columns}} + USING ({{.Columns}}) + {{end}} + ` + + adapterJoinLayout = ` + {{if .Table}} + {{ if .On }} + {{.Type}} JOIN {{.Table}} + {{.On}} + {{ else if .Using }} + {{.Type}} JOIN {{.Table}} + {{.Using}} + {{ else if .Type | eq "CROSS" }} + {{.Type}} JOIN {{.Table}} + {{else}} + NATURAL {{.Type}} JOIN {{.Table}} + {{end}} + {{end}} + ` + + adapterOnLayout = ` + {{if .Conds}} + ON {{.Conds}} + {{end}} + ` + + adapterSelectLayout = ` + SELECT + {{if .Distinct}} + DISTINCT + {{end}} + + {{if defined .Columns}} + {{.Columns | compile}} + {{else}} + * + {{end}} + + {{if defined .Table}} + FROM {{.Table | compile}} + {{end}} + + {{.Joins | compile}} + + {{.Where | compile}} + + {{if defined .GroupBy}} + {{.GroupBy | compile}} + {{end}} + + {{.OrderBy | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + {{if not .Limit}} + LIMIT -1 + {{end}} + OFFSET {{.Offset}} + {{end}} + ` + adapterDeleteLayout = ` + DELETE + FROM {{.Table | compile}} + {{.Where | compile}} + ` + 
adapterUpdateLayout = ` + UPDATE + {{.Table | compile}} + SET {{.ColumnValues | compile}} + {{.Where | compile}} + ` + + adapterSelectCountLayout = ` + SELECT + COUNT(1) AS _t + FROM {{.Table | compile}} + {{.Where | compile}} + ` + + adapterInsertLayout = ` + INSERT INTO {{.Table | compile}} + {{if .Columns }}({{.Columns | compile}}){{end}} + {{if defined .Values}} + VALUES + {{.Values | compile}} + {{else}} + DEFAULT VALUES + {{end}} + {{if defined .Returning}} + RETURNING {{.Returning | compile}} + {{end}} + ` + + adapterTruncateLayout = ` + DELETE FROM {{.Table | compile}} + ` + + adapterDropDatabaseLayout = ` + DROP DATABASE {{.Database | compile}} + ` + + adapterDropTableLayout = ` + DROP TABLE {{.Table | compile}} + ` + + adapterGroupByLayout = ` + {{if .GroupColumns}} + GROUP BY {{.GroupColumns}} + {{end}} + ` +) + +var template = &exql.Template{ + ColumnSeparator: adapterColumnSeparator, + IdentifierSeparator: adapterIdentifierSeparator, + IdentifierQuote: adapterIdentifierQuote, + ValueSeparator: adapterValueSeparator, + ValueQuote: adapterValueQuote, + AndKeyword: adapterAndKeyword, + OrKeyword: adapterOrKeyword, + DescKeyword: adapterDescKeyword, + AscKeyword: adapterAscKeyword, + AssignmentOperator: adapterAssignmentOperator, + ClauseGroup: adapterClauseGroup, + ClauseOperator: adapterClauseOperator, + ColumnValue: adapterColumnValue, + TableAliasLayout: adapterTableAliasLayout, + ColumnAliasLayout: adapterColumnAliasLayout, + SortByColumnLayout: adapterSortByColumnLayout, + WhereLayout: adapterWhereLayout, + JoinLayout: adapterJoinLayout, + OnLayout: adapterOnLayout, + UsingLayout: adapterUsingLayout, + OrderByLayout: adapterOrderByLayout, + InsertLayout: adapterInsertLayout, + SelectLayout: adapterSelectLayout, + UpdateLayout: adapterUpdateLayout, + DeleteLayout: adapterDeleteLayout, + TruncateLayout: adapterTruncateLayout, + DropDatabaseLayout: adapterDropDatabaseLayout, + DropTableLayout: adapterDropTableLayout, + CountLayout: adapterSelectCountLayout, + GroupByLayout: adapterGroupByLayout, + Cache: cache.NewCache(), +} diff --git a/adapter/sqlite/template_test.go b/adapter/sqlite/template_test.go new file mode 100644 index 0000000..f3a4793 --- /dev/null +++ b/adapter/sqlite/template_test.go @@ -0,0 +1,246 @@ +package sqlite + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" + "github.com/stretchr/testify/assert" +) + +func TestTemplateSelect(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `SELECT * FROM "artist"`, + b.SelectFrom("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist"`, + b.Select().From("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("name DESC").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("-name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name ASC").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT -1 OFFSET 5`, + b.Select().From("artist").Limit(-1).Offset(5).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist"`, + b.Select("id").From("artist").String(), + ) + + assert.Equal( + `SELECT "id", "name" FROM "artist"`, + b.Select("id", "name").From("artist").String(), + ) + + 
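+ // The cases below exercise WHERE, JOIN and raw expressions; identifiers are + // double-quoted per the template above, while value placeholders render as + // $1, $2, ... +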
assert.Equal( + `SELECT * FROM "artist" WHERE ("name" = $1)`, + b.SelectFrom("artist").Where("name", "Haruki").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name LIKE $1)`, + b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" WHERE (name LIKE $1 OR name LIKE $2)`, + b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" > $1)`, + b.SelectFrom("artist").Where("id >", 2).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (id <= 2 AND name != $1)`, + b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN ($1, $2, $3, $4))`, + b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name IS NOT NULL)`, + b.SelectFrom("artist").Where("name IS NOT NULL").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a", "publication" AS "p" WHERE (p.author_id = a.id) LIMIT 1`, + b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" NATURAL JOIN "publication"`, + b.Select("id").From("artist").Join("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE ("a"."id" = $1) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1`, + b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" LEFT JOIN "publication" AS "p1" ON (p1.id = a.id) RIGHT JOIN "publication" AS "p2" ON (p2.id = a.id)`, + b.SelectFrom("artist a"). + LeftJoin("publication p1").On("p1.id = a.id"). + RightJoin("publication p2").On("p2.id = a.id"). + String(), + ) + + assert.Equal( + `SELECT * FROM "artist" CROSS JOIN "publication"`, + b.SelectFrom("artist").CrossJoin("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" USING ("id")`, + b.SelectFrom("artist").Join("publication").Using("id").String(), + ) + + assert.Equal( + `SELECT DATE()`, + b.Select(mydb.Raw("DATE()")).String(), + ) +} + +func TestTemplateInsert(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `INSERT INTO "artist" VALUES ($1, $2), ($3, $4), ($5, $6)`, + b.InsertInto("artist"). + Values(10, "Ryuichi Sakamoto"). + Values(11, "Alondra de la Parra"). + Values(12, "Haruki Murakami"). 
+ String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(struct { + ID int `db:"id"` + Name string `db:"name"` + }{12, "Chavela Vargas"}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, + b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(), + ) +} + +func TestTemplateUpdate(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `UPDATE "artist" SET "name" = $1`, + b.Update("artist").Set("name", "Artist").String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 || ' ' || $2 || id, "id" = id + $3 WHERE (id > $4)`, + b.Update("artist").Set( + "name = ? || ' ' || ? || id", "Artist", "#", + "id = id + ?", 10, + ).Where("id > ?", 0).String(), + ) +} + +func TestTemplateDelete(t *testing.T) { + b := sqlbuilder.WithTemplate(template) + assert := assert.New(t) + + assert.Equal( + `DELETE FROM "artist" WHERE (name = $1)`, + b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").String(), + ) + + assert.Equal( + `DELETE FROM "artist" WHERE (id > 5)`, + b.DeleteFrom("artist").Where("id > 5").String(), + ) +} diff --git a/clauses.go b/clauses.go new file mode 100644 index 0000000..7e8d69d --- /dev/null +++ b/clauses.go @@ -0,0 +1,468 @@ +package mydb + +import ( + "context" + "fmt" +) + +// Selector represents a SELECT statement. +type Selector interface { + // Columns defines which columns to retrieve. + // + // You should call From() after Columns() if you want to query data from a + // specific table. + // + // s.Columns("name", "last_name").From(...) + // + // It is also possible to use an alias for the column; this could be handy if + // you plan to use the alias later. Use the "AS" keyword to denote an alias. + // + // s.Columns("name AS n") + // + // or the shortcut: + // + // s.Columns("name n") + // + // If you don't want the column to be escaped use the db.Raw + // function.
+ // + // s.Columns(db.Raw("MAX(id)")) + // + // The above statement is equivalent to: + // + // s.Columns(db.Func("MAX", "id")) + Columns(columns ...interface{}) Selector + + // From represents a FROM clause and is typically used after Columns(). + // + // FROM defines the table from which data is going to be retrieved. + // + // s.Columns(...).From("people") + // + // It is also possible to use an alias for the table; this could be handy if + // you plan to use the alias later: + // + // s.Columns(...).From("people AS p").Where("p.name = ?", ...) + // + // Or with the shortcut: + // + // s.Columns(...).From("people p").Where("p.name = ?", ...) + From(tables ...interface{}) Selector + + // Distinct represents a DISTINCT clause. + // + // DISTINCT is used to ask the database to return only values that are + // different. + Distinct(columns ...interface{}) Selector + + // As defines an alias for a table. + As(string) Selector + + // Where specifies the conditions that columns must match in order to be + // retrieved. + // + // Where accepts raw strings and fmt.Stringer to define conditions and + // interface{} to specify parameters. Be careful not to embed any parameters + // within the SQL part as that could lead to security problems. You can use + // the question mark (?) as a placeholder for parameters. + // + // s.Where("name = ?", "max") + // + // s.Where("name = ? AND last_name = ?", "Mary", "Doe") + // + // s.Where("last_name IS NULL") + // + // You can also use other types of parameters besides strings, like: + // + // s.Where("online = ? AND last_logged <= ?", true, time.Now()) + // + // and Where() will transform them into strings before feeding them to the + // database. + // + // When an unknown type is provided, Where() will first try to match it with + // the Marshaler interface, then with fmt.Stringer, and finally, if the + // argument does not satisfy any of those interfaces, Where() will use + // fmt.Sprintf("%v", arg) to transform the type into a string. + // + // Subsequent calls to Where() will overwrite previously set conditions; if + // you want the new conditions to be appended, use And() instead. + Where(conds ...interface{}) Selector + + // And appends more constraints to the WHERE clause without overwriting + // conditions that have been already set. + And(conds ...interface{}) Selector + + // GroupBy represents a GROUP BY statement. + // + // GROUP BY defines which columns should be used to aggregate and group + // results. + // + // s.GroupBy("country_id") + // + // GroupBy accepts more than one column: + // + // s.GroupBy("country_id", "city_id") + GroupBy(columns ...interface{}) Selector + + // Having(...interface{}) Selector + + // OrderBy represents an ORDER BY statement. + // + // ORDER BY is used to define which columns are going to be used to sort + // results. + // + // Use the column name to sort results in ascending order. + // + // // "last_name" ASC + // s.OrderBy("last_name") + // + // Prefix the column name with the minus sign (-) to sort results in + // descending order. + // + // // "last_name" DESC + // s.OrderBy("-last_name") + // + // If you would rather be very explicit, you can also use ASC and DESC. + // + // s.OrderBy("last_name ASC") + // + // s.OrderBy("last_name DESC", "name ASC") + OrderBy(columns ...interface{}) Selector + + // Join represents a JOIN statement. + // + // JOIN statements are used to define external tables that the user wants to + // include as part of the result.
+ // + // You can use the On() method after Join() to define the conditions of the + // join. + // + // s.Join("author").On("author.id = book.author_id") + // + // If you don't specify conditions for the join, a NATURAL JOIN will be used. + // + // On() accepts the same arguments as Where(). + // + // You can also use Using() after Join(). + // + // s.Join("employee").Using("department_id") + Join(table ...interface{}) Selector + + // FullJoin is like Join() but with FULL JOIN. + FullJoin(...interface{}) Selector + + // CrossJoin is like Join() but with CROSS JOIN. + CrossJoin(...interface{}) Selector + + // RightJoin is like Join() but with RIGHT JOIN. + RightJoin(...interface{}) Selector + + // LeftJoin is like Join() but with LEFT JOIN. + LeftJoin(...interface{}) Selector + + // Using represents the USING clause. + // + // USING is used to specify columns to join results. + // + // s.LeftJoin(...).Using("country_id") + Using(...interface{}) Selector + + // On represents the ON clause. + // + // ON is used to define conditions on a join. + // + // s.Join(...).On("b.author_id = a.id") + On(...interface{}) Selector + + // Limit represents the LIMIT parameter. + // + // LIMIT defines the maximum number of rows to return from the table. A + // negative limit cancels any previous limit settings. + // + // s.Limit(42) + Limit(int) Selector + + // Offset represents the OFFSET parameter. + // + // OFFSET defines how many results are going to be skipped before starting to + // return results. A negative offset cancels any previous offset settings. + // + // s.Offset(56) + Offset(int) Selector + + // Amend lets you alter the query's text just before sending it to the + // database server. + Amend(func(queryIn string) (queryOut string)) Selector + + // Paginate returns a paginator that can display paginated lists of items. + // Paginators ignore previous Offset and Limit settings. Page numbering + // starts at 1. + Paginate(uint) Paginator + + // Iterator provides methods to iterate over the results returned by the + // Selector. + Iterator() Iterator + + // IteratorContext provides methods to iterate over the results returned by + // the Selector. + IteratorContext(ctx context.Context) Iterator + + // SQLPreparer provides methods for creating prepared statements. + SQLPreparer + + // SQLGetter provides methods to compile and execute a query that returns + // results. + SQLGetter + + // ResultMapper provides methods to retrieve and map results. + ResultMapper + + // fmt.Stringer provides `String() string`; you can use `String()` to compile + // the `Selector` into a string. + fmt.Stringer + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} +} + +// Inserter represents an INSERT statement. +type Inserter interface { + // Columns represents the COLUMNS clause. + // + // COLUMNS defines the columns that we are going to provide values for. + // + // i.Columns("name", "last_name").Values(...) + Columns(...string) Inserter + + // Values represents the VALUES clause. + // + // VALUES defines the values of the columns. + // + // i.Columns(...).Values("María", "Méndez") + // + // i.Values(map[string]string{"name": "María"}) + Values(...interface{}) Inserter + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} + + // Returning represents a RETURNING clause. + // + // RETURNING specifies which columns should be returned after INSERT. + // + // RETURNING may not be supported by all SQL databases.
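+ // + // Example (illustrative): + // + // i.Columns("id", "name").Values(12, "Chavela Vargas").Returning("id")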
+ Returning(columns ...string) Inserter + + // Iterator provides methods to iterate over the results returned by the + // Inserter. This is only possible when using Returning(). + Iterator() Iterator + + // IteratorContext provides methods to iterate over the results returned by + // the Inserter. This is only possible when using Returning(). + IteratorContext(ctx context.Context) Iterator + + // Amend lets you alter the query's text just before sending it to the + // database server. + Amend(func(queryIn string) (queryOut string)) Inserter + + // Batch provides a BatchInserter that can be used to insert many elements at + // once by issuing several calls to Values(). It accepts a size parameter + // which defines the batch size. If size is < 1, the batch size is set to 1. + Batch(size int) BatchInserter + + // SQLExecer provides the Exec method. + SQLExecer + + // SQLPreparer provides methods for creating prepared statements. + SQLPreparer + + // SQLGetter provides methods to return query results from INSERT statements + // that support this feature (e.g., queries with Returning). + SQLGetter + + // fmt.Stringer provides `String() string`; you can use `String()` to compile + // the `Inserter` into a string. + fmt.Stringer +} + +// Deleter represents a DELETE statement. +type Deleter interface { + // Where represents the WHERE clause. + // + // See Selector.Where for documentation and usage examples. + Where(...interface{}) Deleter + + // And appends more constraints to the WHERE clause without overwriting + // conditions that have been already set. + And(conds ...interface{}) Deleter + + // Limit represents the LIMIT clause. + // + // See Selector.Limit for documentation and usage examples. + Limit(int) Deleter + + // Amend lets you alter the query's text just before sending it to the + // database server. + Amend(func(queryIn string) (queryOut string)) Deleter + + // SQLPreparer provides methods for creating prepared statements. + SQLPreparer + + // SQLExecer provides the Exec method. + SQLExecer + + // fmt.Stringer provides `String() string`; you can use `String()` to compile + // the `Deleter` into a string. + fmt.Stringer + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} +} + +// Updater represents an UPDATE statement. +type Updater interface { + // Set represents the SET clause. + Set(...interface{}) Updater + + // Where represents the WHERE clause. + // + // See Selector.Where for documentation and usage examples. + Where(...interface{}) Updater + + // And appends more constraints to the WHERE clause without overwriting + // conditions that have been already set. + And(conds ...interface{}) Updater + + // Limit represents the LIMIT parameter. + // + // See Selector.Limit for documentation and usage examples. + Limit(int) Updater + + // SQLPreparer provides methods for creating prepared statements. + SQLPreparer + + // SQLExecer provides the Exec method. + SQLExecer + + // fmt.Stringer provides `String() string`; you can use `String()` to compile + // the `Updater` into a string. + fmt.Stringer + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} + + // Amend lets you alter the query's text just before sending it to the + // database server. + Amend(func(queryIn string) (queryOut string)) Updater +} + +// Paginator provides tools for splitting the results of a query into chunks +// containing a fixed number of items. +type Paginator interface { + // Page sets the page number.
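+ // + // Example (illustrative; page numbering starts at 1): + // + // p = q.Paginate(10).Page(2)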
+ Page(uint) Paginator + + // Cursor defines the column that is going to be taken as the basis for + // cursor-based pagination. + // + // Example: + // + // a = q.Paginate(10).Cursor("id") + // b = q.Paginate(12).Cursor("-id") + // + // You can set "" as cursorColumn to disable cursors. + Cursor(cursorColumn string) Paginator + + // NextPage returns the next page according to the cursor. It expects a + // cursorValue, which is the value the cursor column has on the last item of + // the current result set (lower bound). + // + // Example: + // + // p = q.NextPage(items[len(items)-1].ID) + NextPage(cursorValue interface{}) Paginator + + // PrevPage returns the previous page according to the cursor. It expects a + // cursorValue, which is the value the cursor column has on the first item of + // the current result set (upper bound). + // + // Example: + // + // p = q.PrevPage(items[0].ID) + PrevPage(cursorValue interface{}) Paginator + + // TotalPages returns the total number of pages in the query. + TotalPages() (uint, error) + + // TotalEntries returns the total number of entries in the query. + TotalEntries() (uint64, error) + + // SQLPreparer provides methods for creating prepared statements. + SQLPreparer + + // SQLGetter provides methods to compile and execute a query that returns + // results. + SQLGetter + + // Iterator provides methods to iterate over the results returned by the + // Selector. + Iterator() Iterator + + // IteratorContext provides methods to iterate over the results returned by + // the Selector. + IteratorContext(ctx context.Context) Iterator + + // ResultMapper provides methods to retrieve and map results. + ResultMapper + + // fmt.Stringer provides `String() string`; you can use `String()` to compile + // the `Selector` into a string. + fmt.Stringer + + // Arguments returns the arguments that are prepared for this query. + Arguments() []interface{} +} + +// ResultMapper defines methods for a result mapper. +type ResultMapper interface { + // All dumps all the results into the given slice; All() expects a pointer to + // a slice of maps or structs. + // + // The behaviour of One() extends to each one of the results. + All(destSlice interface{}) error + + // One maps the row that is in the current query cursor into the + // given interface, which can be a pointer to either a map or a + // struct. + // + // If dest is a pointer to a map, each one of the columns will create a new + // map key and the values of the result will be set as values for the keys. + // + // Depending on the type of map key and value, the result columns and values + // may need to be transformed. + // + // If dest is a pointer to a struct, each one of the fields will be tested + // for a `db` tag which defines the column mapping. The value of the result + // will be set as the value of the field. + One(dest interface{}) error +} + +// BatchInserter provides an interface to perform massive insertions in +// batches. +type BatchInserter interface { + // Values pushes column values to be inserted as part of the batch. + Values(...interface{}) BatchInserter + + // NextResult dumps the next slice of results to dst, which can mean + // receiving the IDs of all inserted elements in the batch. + NextResult(dst interface{}) bool + + // Done signals that no more elements are going to be added. + Done() + + // Wait blocks until the whole batch is executed. + Wait() error + + // Err returns the last error that happened while executing the batch (or nil + // if no error happened).
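+ // + // A typical (illustrative) pattern is to drain NextResult and then check + // Err: + // + // for b.NextResult(&ids) { + // // consume ids + // } + // if err := b.Err(); err != nil { + // // handle the batch error + // }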
+ Err() error +} diff --git a/collection.go b/collection.go new file mode 100644 index 0000000..6ac09ec --- /dev/null +++ b/collection.go @@ -0,0 +1,45 @@ +package mydb + +// Collection defines methods to work with database tables or collections. +type Collection interface { + + // Name returns the name of the collection. + Name() string + + // Session returns the Session that was used to create the collection + // reference. + Session() Session + + // Find defines a new result set. + Find(...interface{}) Result + + // Count returns the total number of items in the collection. + Count() (uint64, error) + + // Insert inserts a new item into the collection; the item can be a map, a + // struct, or a pointer to either of them. If the call succeeds and if the + // collection has a primary key, Insert returns the ID of the newly added + // element as an `interface{}`. The underlying type of this ID depends on + // both the database adapter and the column storing the ID. The ID returned + // by Insert() could be passed directly to Find() to retrieve the newly + // added element. + Insert(interface{}) (InsertResult, error) + + // InsertReturning is like Insert() but it takes a pointer to a map or struct + // and, if the operation succeeds, updates it with data from the newly + // inserted row. If the database does not support transactions, this method + // returns db.ErrUnsupported. + InsertReturning(interface{}) error + + // UpdateReturning takes a pointer to a map or struct and tries to update the + // row the item is referring to. If the element is updated successfully, + // UpdateReturning will fetch the row and update the fields of the passed + // item. If the database does not support transactions, this method returns + // db.ErrUnsupported. + UpdateReturning(interface{}) error + + // Exists returns true if the collection exists, false otherwise. + Exists() (bool, error) + + // Truncate removes all elements from the collection. + Truncate() error +} diff --git a/comparison.go b/comparison.go new file mode 100644 index 0000000..d79113a --- /dev/null +++ b/comparison.go @@ -0,0 +1,158 @@ +package mydb + +import ( + "reflect" + "time" + + "git.hexq.cn/tiglog/mydb/internal/adapter" +) + +// Comparison represents a relationship between values. +type Comparison struct { + *adapter.Comparison +} + +// Gte is a comparison that means: is greater than or equal to value. +func Gte(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, value)} +} + +// Lte is a comparison that means: is less than or equal to value. +func Lte(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, value)} +} + +// Eq is a comparison that means: is equal to value. +func Eq(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorEqual, value)} +} + +// NotEq is a comparison that means: is not equal to value. +func NotEq(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotEqual, value)} +} + +// Gt is a comparison that means: is greater than value. +func Gt(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, value)} +} + +// Lt is a comparison that means: is less than value.
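+ // + // Example (illustrative): + // + // // "age" < 18 + // db.Cond{"age": db.Lt(18)}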
+func Lt(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, value)} +} + +// In is a comparison that means: is any of the values. +func In(value ...interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, toInterfaceArray(value))} +} + +// AnyOf is a comparison that means: is any of the values of the slice. +func AnyOf(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, toInterfaceArray(value))} +} + +// NotIn is a comparison that means: is none of the values. +func NotIn(value ...interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotIn, toInterfaceArray(value))} +} + +// NotAnyOf is a comparison that means: is none of the values of the slice. +func NotAnyOf(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotIn, toInterfaceArray(value))} +} + +// After is a comparison that means: is after the (time.Time) value. +func After(value time.Time) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, value)} +} + +// Before is a comparison that means: is before the (time.Time) value. +func Before(value time.Time) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, value)} +} + +// OnOrAfter is a comparison that means: is on or after the (time.Time) value. +func OnOrAfter(value time.Time) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, value)} +} + +// OnOrBefore is a comparison that means: is on or before the (time.Time) value. +func OnOrBefore(value time.Time) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, value)} +} + +// Between is a comparison that means: is between lowerBound and upperBound. +func Between(lowerBound interface{}, upperBound interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorBetween, []interface{}{lowerBound, upperBound})} +} + +// NotBetween is a comparison that means: is not between lowerBound and upperBound. +func NotBetween(lowerBound interface{}, upperBound interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotBetween, []interface{}{lowerBound, upperBound})} +} + +// Is is a comparison that means: is equivalent to nil, true, or false. +func Is(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, value)} +} + +// IsNot is a comparison that means: is not equivalent to nil, true, or false. +func IsNot(value interface{}) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, value)} +} + +// IsNull is a comparison that means: is equivalent to nil. +func IsNull() *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, nil)} +} + +// IsNotNull is a comparison that means: is not equivalent to nil. +func IsNotNull() *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, nil)} +} + +// Like is a comparison that checks whether the reference matches the wildcard +// value.
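+ // + // Example (illustrative): + // + // // name LIKE 'And%' + // db.Cond{"name": db.Like("And%")}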
+func Like(value string) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLike, value)} +} + +// NotLike is a comparison that checks whether the reference does not match the +// wildcard value. +func NotLike(value string) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotLike, value)} +} + +// RegExp is a comparison that checks whether the reference matches the regular +// expression. +func RegExp(value string) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorRegExp, value)} +} + +// NotRegExp is a comparison that checks whether the reference does not match +// the regular expression. +func NotRegExp(value string) *Comparison { + return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotRegExp, value)} +} + +// Op returns a custom comparison operator. +func Op(customOperator string, value interface{}) *Comparison { + return &Comparison{adapter.NewCustomComparisonOperator(customOperator, value)} +} + +func toInterfaceArray(value interface{}) []interface{} { + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Ptr: + return toInterfaceArray(rv.Elem().Interface()) + case reflect.Slice: + elems := rv.Len() + args := make([]interface{}, elems) + for i := 0; i < elems; i++ { + args[i] = rv.Index(i).Interface() + } + return args + } + return []interface{}{value} +} diff --git a/comparison_test.go b/comparison_test.go new file mode 100644 index 0000000..03ea789 --- /dev/null +++ b/comparison_test.go @@ -0,0 +1,111 @@ +package mydb + +import ( + "testing" + "time" + + "git.hexq.cn/tiglog/mydb/internal/adapter" + "github.com/stretchr/testify/assert" +) + +func TestComparison(t *testing.T) { + testTimeVal := time.Now() + + testCases := []struct { + expects *adapter.Comparison + result *Comparison + }{ + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, 1), + Gte(1), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, 22), + Lte(22), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorEqual, 6), + Eq(6), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorNotEqual, 67), + NotEq(67), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, 4), + Gt(4), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, 47), + Lt(47), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, []interface{}{1, 22, 34}), + In(1, 22, 34), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, testTimeVal), + After(testTimeVal), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, testTimeVal), + Before(testTimeVal), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, testTimeVal), + OnOrAfter(testTimeVal), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, testTimeVal), + OnOrBefore(testTimeVal), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorBetween, []interface{}{11, 35}), + Between(11, 35), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorNotBetween, []interface{}{11, 35}), + NotBetween(11, 35), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, 178), + Is(178), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, 32), + IsNot(32), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, 
nil), + IsNull(), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, nil), + IsNotNull(), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorLike, "%a%"), + Like("%a%"), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorNotLike, "%z%"), + NotLike("%z%"), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorRegExp, ".*"), + RegExp(".*"), + }, + { + adapter.NewComparisonOperator(adapter.ComparisonOperatorNotRegExp, ".*"), + NotRegExp(".*"), + }, + { + adapter.NewCustomComparisonOperator("~", 56), + Op("~", 56), + }, + } + + for i := range testCases { + assert.Equal(t, testCases[i].expects, testCases[i].result.Comparison) + } +} diff --git a/cond.go b/cond.go new file mode 100644 index 0000000..385c7e6 --- /dev/null +++ b/cond.go @@ -0,0 +1,109 @@ +package mydb + +import ( + "fmt" + "sort" + + "git.hexq.cn/tiglog/mydb/internal/adapter" +) + +// LogicalExpr represents an expression to be used in logical statements. +type LogicalExpr = adapter.LogicalExpr + +// LogicalOperator represents a logical operation. +type LogicalOperator = adapter.LogicalOperator + +// Cond is a map that defines conditions for a query. +// +// Each entry of the map represents a condition (a column-value relation bound +// by a comparison operator). The comparison can be specified after the column +// name; if no comparison operator is provided, the equality operator is used +// by default. +// +// Examples: +// +// // Age equals 18. +// db.Cond{"age": 18} +// +// // Age is greater than or equal to 18. +// db.Cond{"age >=": 18} +// +// // id is any of the values 1, 2 or 3. +// db.Cond{"id IN": []int{1, 2, 3}} +// +// // Age is lower than 18 (MongoDB syntax) +// db.Cond{"age $lt": 18} +// +// // age > 32 and age < 35 +// db.Cond{"age >": 32, "age <": 35} +type Cond map[interface{}]interface{} + +// Empty returns true if there are no conditions. +func (c Cond) Empty() bool { + for range c { + return false + } + return true +} + +// Constraints returns each one of the Cond map entries as a constraint. +func (c Cond) Constraints() []adapter.Constraint { + z := make([]adapter.Constraint, 0, len(c)) + for _, k := range c.keys() { + z = append(z, adapter.NewConstraint(k, c[k])) + } + return z +} + +// Operator returns the default logical operator (AND). +func (c Cond) Operator() LogicalOperator { + return adapter.DefaultLogicalOperator +} + +func (c Cond) keys() []interface{} { + keys := make(condKeys, 0, len(c)) + for k := range c { + keys = append(keys, k) + } + if len(c) > 1 { + sort.Sort(keys) + } + return keys +} + +// Expressions returns all the expressions contained in the condition.
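+// Each key-value pair is returned as its own single-entry Cond, which the +// caller can then combine under a logical operator (see defaultJoin below).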
+func (c Cond) Expressions() []LogicalExpr { + z := make([]LogicalExpr, 0, len(c)) + for _, k := range c.keys() { + z = append(z, Cond{k: c[k]}) + } + return z +} + +type condKeys []interface{} + +func (ck condKeys) Len() int { + return len(ck) +} + +func (ck condKeys) Less(i, j int) bool { + return fmt.Sprintf("%v", ck[i]) < fmt.Sprintf("%v", ck[j]) +} + +func (ck condKeys) Swap(i, j int) { + ck[i], ck[j] = ck[j], ck[i] +} + +// defaultJoin wraps every non-empty Cond with And() so that it can be combined +// with the other logical expressions. +func defaultJoin(in ...adapter.LogicalExpr) []adapter.LogicalExpr { + for i := range in { + cond, ok := in[i].(Cond) + if ok && !cond.Empty() { + in[i] = And(cond) + } + } + return in +} + +var ( + _ = LogicalExpr(Cond{}) +) diff --git a/cond_test.go b/cond_test.go new file mode 100644 index 0000000..9e4232a --- /dev/null +++ b/cond_test.go @@ -0,0 +1,69 @@ +package mydb + +import ( + "testing" +) + +func TestCond(t *testing.T) { + c := Cond{} + + if !c.Empty() { + t.Fatal("expected Cond to be empty") + } + + c = Cond{"id": 1} + if c.Empty() { + t.Fatal("expected Cond not to be empty") + } +} + +func TestCondAnd(t *testing.T) { + a := And() + + if !a.Empty() { + t.Fatal("expected And() to be empty") + } + + _ = a.And(Cond{"id": 1}) + + if !a.Empty() { + t.Fatal("expected a to remain empty: And() returns a new value") + } + + a = a.And(Cond{"name": "Ana"}) + + if a.Empty() { + t.Fatal("expected a not to be empty") + } + + a = a.And().And() + + if a.Empty() { + t.Fatal("expected a to remain non-empty") + } +} + +func TestCondOr(t *testing.T) { + a := Or() + + if !a.Empty() { + t.Fatal("expected Or() to be empty") + } + + _ = a.Or(Cond{"id": 1}) + + if !a.Empty() { + t.Fatal("expected a to remain empty: Or() returns a new value") + } + + a = a.Or(Cond{"name": "Ana"}) + + if a.Empty() { + t.Fatal("expected a not to be empty") + } + + a = a.Or().Or() + if a.Empty() { + t.Fatal("expected a to remain non-empty") + } +} diff --git a/connection_url.go b/connection_url.go new file mode 100644 index 0000000..20ff9ce --- /dev/null +++ b/connection_url.go @@ -0,0 +1,8 @@ +package mydb + +// ConnectionURL represents a data source name (DSN). +type ConnectionURL interface { + // String returns the connection string that is going to be passed to the + // adapter.
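+ // + // For example (illustrative), the sqlite adapter derives its DSN from a + // database file path; see adapter/sqlite/connection.go.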
+ String() string +} diff --git a/errors.go b/errors.go new file mode 100644 index 0000000..92abf8f --- /dev/null +++ b/errors.go @@ -0,0 +1,42 @@ +package mydb + +import ( + "errors" +) + +// Error messages +var ( + ErrMissingAdapter = errors.New(`mydb: missing adapter`) + ErrAlreadyWithinTransaction = errors.New(`mydb: already within a transaction`) + ErrCollectionDoesNotExist = errors.New(`mydb: collection does not exist`) + ErrExpectingNonNilModel = errors.New(`mydb: expecting non nil model`) + ErrExpectingPointerToStruct = errors.New(`mydb: expecting pointer to struct`) + ErrGivingUpTryingToConnect = errors.New(`mydb: giving up trying to connect: too many clients`) + ErrInvalidCollection = errors.New(`mydb: invalid collection`) + ErrMissingCollectionName = errors.New(`mydb: missing collection name`) + ErrMissingConditions = errors.New(`mydb: missing selector conditions`) + ErrMissingConnURL = errors.New(`mydb: missing DSN`) + ErrMissingDatabaseName = errors.New(`mydb: missing database name`) + ErrNoMoreRows = errors.New(`mydb: no more rows in this result set`) + ErrNotConnected = errors.New(`mydb: not connected to a database`) + ErrNotImplemented = errors.New(`mydb: call not implemented`) + ErrQueryIsPending = errors.New(`mydb: can't execute this instruction while the result set is still open`) + ErrQueryLimitParam = errors.New(`mydb: a query can accept only one limit parameter`) + ErrQueryOffsetParam = errors.New(`mydb: a query can accept only one offset parameter`) + ErrQuerySortParam = errors.New(`mydb: a query can accept only one order-by parameter`) + ErrSockerOrHost = errors.New(`mydb: you may connect either to a UNIX socket or a TCP address, but not both`) + ErrTooManyClients = errors.New(`mydb: can't connect to database server: too many clients`) + ErrUndefined = errors.New(`mydb: value is undefined`) + ErrUnknownConditionType = errors.New(`mydb: arguments of type %T can't be used as constraints`) + ErrUnsupported = errors.New(`mydb: action is not supported by the DBMS`) + ErrUnsupportedDestination = errors.New(`mydb: unsupported destination type`) + ErrUnsupportedType = errors.New(`mydb: type does not support marshaling`) + ErrUnsupportedValue = errors.New(`mydb: value does not support unmarshaling`) + ErrNilRecord = errors.New(`mydb: invalid item (nil)`) + ErrRecordIDIsZero = errors.New(`mydb: item ID is not defined`) + ErrMissingPrimaryKeys = errors.New(`mydb: collection %q has no primary keys`) + ErrWarnSlowQuery = errors.New(`mydb: slow query`) + ErrTransactionAborted = errors.New(`mydb: transaction was aborted`) + ErrNotWithinTransaction = errors.New(`mydb: not within transaction`) + ErrNotSupportedByAdapter = errors.New(`mydb: not supported by adapter`) +) diff --git a/errors_test.go b/errors_test.go new file mode 100644 index 0000000..d26e677 --- /dev/null +++ b/errors_test.go @@ -0,0 +1,14 @@ +package mydb + +import ( + "errors" + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestErrorWrap(t *testing.T) { + adapterFakeErr := fmt.Errorf("could not find item in %q: %w", "users", ErrCollectionDoesNotExist) + assert.True(t, errors.Is(adapterFakeErr, ErrCollectionDoesNotExist)) +} diff --git a/function.go b/function.go new file mode 100644 index 0000000..6f9daf5 --- /dev/null +++ b/function.go @@ -0,0 +1,25 @@ +package mydb + +import "git.hexq.cn/tiglog/mydb/internal/adapter" + +// FuncExpr represents functions. +type FuncExpr = adapter.FuncExpr + +// Func returns a database function expression. 
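+// The function name is not validated against any list of built-in SQL +// functions; function_test.go exercises a made-up HELLO function.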
+// +// Examples: +// +// // MOD(29, 9) +// db.Func("MOD", 29, 9) +// +// // CONCAT("foo", "bar") +// db.Func("CONCAT", "foo", "bar") +// +// // NOW() +// db.Func("NOW") +// +// // RTRIM("Hello ") +// db.Func("RTRIM", "Hello ") +func Func(name string, args ...interface{}) *FuncExpr { + return adapter.NewFuncExpr(name, args) +} diff --git a/function_test.go b/function_test.go new file mode 100644 index 0000000..b1f2ef9 --- /dev/null +++ b/function_test.go @@ -0,0 +1,51 @@ +package mydb + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestFunction(t *testing.T) { + { + fn := Func("MOD", 29, 9) + assert.Equal(t, "MOD", fn.Name()) + assert.Equal(t, []interface{}{29, 9}, fn.Arguments()) + } + + { + fn := Func("HELLO") + assert.Equal(t, "HELLO", fn.Name()) + assert.Equal(t, []interface{}(nil), fn.Arguments()) + } + + { + fn := Func("CONCAT", "a") + assert.Equal(t, "CONCAT", fn.Name()) + assert.Equal(t, []interface{}{"a"}, fn.Arguments()) + } + + { + fn := Func("CONCAT", "a", "b", "c") + assert.Equal(t, "CONCAT", fn.Name()) + assert.Equal(t, []interface{}{"a", "b", "c"}, fn.Arguments()) + } + + { + fn := Func("IN", []interface{}{"a", "b", "c"}) + assert.Equal(t, "IN", fn.Name()) + assert.Equal(t, []interface{}{[]interface{}{"a", "b", "c"}}, fn.Arguments()) + } + + { + fn := Func("IN", []interface{}{"a"}) + assert.Equal(t, "IN", fn.Name()) + assert.Equal(t, []interface{}{[]interface{}{"a"}}, fn.Arguments()) + } + + { + fn := Func("IN", []interface{}(nil)) + assert.Equal(t, "IN", fn.Name()) + assert.Equal(t, []interface{}{[]interface{}(nil)}, fn.Arguments()) + } +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..58d9f56 --- /dev/null +++ b/go.mod @@ -0,0 +1,33 @@ +module git.hexq.cn/tiglog/mydb + +go 1.20 + +require ( + github.com/go-sql-driver/mysql v1.7.1 + github.com/google/uuid v1.1.1 + github.com/ipfs/go-detect-race v0.0.1 + github.com/jackc/pgtype v1.14.0 + github.com/jackc/pgx/v4 v4.18.1 + github.com/lib/pq v1.10.9 + github.com/mattn/go-sqlite3 v1.14.17 + github.com/segmentio/fasthash v1.0.3 + github.com/sirupsen/logrus v1.9.3 + github.com/stretchr/testify v1.8.4 + gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 +) + +require ( + github.com/davecgh/go-spew v1.1.1 // indirect + github.com/jackc/chunkreader/v2 v2.0.1 // indirect + github.com/jackc/pgconn v1.14.1 // indirect + github.com/jackc/pgio v1.0.0 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgproto3/v2 v2.3.2 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + golang.org/x/crypto v0.12.0 // indirect + golang.org/x/sys v0.12.0 // indirect + golang.org/x/text v0.13.0 // indirect + gopkg.in/yaml.v2 v2.4.0 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..5e38d2e --- /dev/null +++ b/go.sum @@ -0,0 +1,226 @@ +github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= +github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc= +github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs= +github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I= +github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ= +github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/coreos/go-systemd 
v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4= +github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY= +github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY= +github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw= +github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= +github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= +github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/ipfs/go-detect-race v0.0.1 h1:qX/xay2W3E4Q1U7d9lNs1sU9nvguX0a7319XbyQ6cOk= +github.com/ipfs/go-detect-race v0.0.1/go.mod h1:8BNT7shDZPo99Q74BpGMK+4D8Mn4j46UU0LZ723meps= +github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo= +github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8= +github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk= +github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA= +github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE= +github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s= +github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o= +github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY= +github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI= +github.com/jackc/pgconn v1.14.0/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E= +github.com/jackc/pgconn v1.14.1 h1:smbxIaZA08n6YuxEX1sDyjV/qkbtUtkH20qLkR9MUR4= +github.com/jackc/pgconn v1.14.1/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E= +github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE= +github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8= +github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE= +github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc= +github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod 
h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA= +github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg= +github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM= +github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0= +github.com/jackc/pgproto3/v2 v2.3.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA= +github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg= +github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc= +github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw= +github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM= +github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw= +github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4= +github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y= +github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM= +github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc= +github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs= +github.com/jackc/pgx/v4 v4.18.1 h1:YP7G1KABtKpB5IHrO9vYwSrCOhs7p3uqhvhhQBptya0= +github.com/jackc/pgx/v4 v4.18.1/go.mod h1:FydWkUyadDmdNH/mHnGob881GawxeEm7TcMCzkb+qQE= +github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk= +github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod 
h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ= +github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc= +github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU= +github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM= +github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= +github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= +github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= +github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ= +github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU= +github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc= +github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0= +github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM= +github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY= +github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q= +github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= +github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ= +github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 
+github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= +github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q= +go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE= +go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ= +go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0= +go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4= +go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= +go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA= +go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q= +go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE= +golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= +golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I= +golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58= +golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk= +golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc= +golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod 
h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o= +golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= +golang.org/x/text 
v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k= +golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= +golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s= +gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 h1:VpOs+IwYnYBaFnrNAeB8UUWtL3vEUnzSCL1nVjPhqrw= +gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= diff --git a/internal/adapter/comparison.go 
b/internal/adapter/comparison.go new file mode 100644 index 0000000..1f63a20 --- /dev/null +++ b/internal/adapter/comparison.go @@ -0,0 +1,60 @@ +package adapter + +// ComparisonOperator is the base type for comparison operators. +type ComparisonOperator uint8 + +// Comparison operators +const ( + ComparisonOperatorNone ComparisonOperator = iota + ComparisonOperatorCustom + + ComparisonOperatorEqual + ComparisonOperatorNotEqual + + ComparisonOperatorLessThan + ComparisonOperatorGreaterThan + + ComparisonOperatorLessThanOrEqualTo + ComparisonOperatorGreaterThanOrEqualTo + + ComparisonOperatorBetween + ComparisonOperatorNotBetween + + ComparisonOperatorIn + ComparisonOperatorNotIn + + ComparisonOperatorIs + ComparisonOperatorIsNot + + ComparisonOperatorLike + ComparisonOperatorNotLike + + ComparisonOperatorRegExp + ComparisonOperatorNotRegExp +) + +type Comparison struct { + t ComparisonOperator + op string + v interface{} +} + +func (c *Comparison) CustomOperator() string { + return c.op +} + +func (c *Comparison) Operator() ComparisonOperator { + return c.t +} + +func (c *Comparison) Value() interface{} { + return c.v +} + +func NewComparisonOperator(t ComparisonOperator, v interface{}) *Comparison { + return &Comparison{t: t, v: v} +} + +func NewCustomComparisonOperator(op string, v interface{}) *Comparison { + return &Comparison{t: ComparisonOperatorCustom, op: op, v: v} +} diff --git a/internal/adapter/constraint.go b/internal/adapter/constraint.go new file mode 100644 index 0000000..84a63de --- /dev/null +++ b/internal/adapter/constraint.go @@ -0,0 +1,51 @@ +package adapter + +// ConstraintValuer allows constraints to use specific values of their own. +type ConstraintValuer interface { + ConstraintValue() interface{} +} + +// Constraint interface represents a single condition, like "a = 1", where `a` +// is the key and `1` is the value. This is an exported interface but it's +// rarely used directly; you may want to use the `db.Cond{}` map instead. +type Constraint interface { + // Key is the leftmost part of the constraint and usually contains a column + // name. + Key() interface{} + + // Value is the rightmost part of the constraint and usually contains a + // column value. + Value() interface{} +} + +// Constraints interface represents an array of constraints, like "a = 1, b = +// 2, c = 3". +type Constraints interface { + // Constraints returns an array of constraints. + Constraints() []Constraint +} + +type constraint struct { + k interface{} + v interface{} +} + +func (c constraint) Key() interface{} { + return c.k +} + +func (c constraint) Value() interface{} { + if constraintValuer, ok := c.v.(ConstraintValuer); ok { + return constraintValuer.ConstraintValue() + } + return c.v +} + +// NewConstraint creates a constraint.
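+// For example, NewConstraint("id", 1) models the condition `id = 1`; the +// column name here is purely illustrative.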
+func NewConstraint(key interface{}, value interface{}) Constraint { + return &constraint{k: key, v: value} +} + +var ( + _ = Constraint(&constraint{}) +) diff --git a/internal/adapter/func.go b/internal/adapter/func.go new file mode 100644 index 0000000..e6353ba --- /dev/null +++ b/internal/adapter/func.go @@ -0,0 +1,18 @@ +package adapter + +type FuncExpr struct { + name string + args []interface{} +} + +func (f *FuncExpr) Arguments() []interface{} { + return f.args +} + +func (f *FuncExpr) Name() string { + return f.name +} + +func NewFuncExpr(name string, args []interface{}) *FuncExpr { + return &FuncExpr{name: name, args: args} +} diff --git a/internal/adapter/logical_expr.go b/internal/adapter/logical_expr.go new file mode 100644 index 0000000..fce009a --- /dev/null +++ b/internal/adapter/logical_expr.go @@ -0,0 +1,100 @@ +package adapter + +import "git.hexq.cn/tiglog/mydb/internal/immutable" + +// LogicalExpr represents a group formed by one or more sentences joined by +// an Operator like "AND" or "OR". +type LogicalExpr interface { + // Expressions returns child sentences. + Expressions() []LogicalExpr + + // Operator returns the Operator that joins all the sentences in the group. + Operator() LogicalOperator + + // Empty returns true if the compound has zero children, false otherwise. + Empty() bool +} + +// LogicalOperator represents the operation on a compound statement. +type LogicalOperator uint + +// LogicalExpr Operators. +const ( + LogicalOperatorNone LogicalOperator = iota + LogicalOperatorAnd + LogicalOperatorOr +) + +const DefaultLogicalOperator = LogicalOperatorAnd + +type LogicalExprGroup struct { + op LogicalOperator + + prev *LogicalExprGroup + fn func(*[]LogicalExpr) error +} + +func NewLogicalExprGroup(op LogicalOperator, conds ...LogicalExpr) *LogicalExprGroup { + group := &LogicalExprGroup{op: op} + if len(conds) == 0 { + return group + } + return group.Frame(func(in *[]LogicalExpr) error { + *in = append(*in, conds...) + return nil + }) +} + +// Expressions returns each one of the group's conditions. +func (g *LogicalExprGroup) Expressions() []LogicalExpr { + conds, err := immutable.FastForward(g) + if err == nil { + return *(conds.(*[]LogicalExpr)) + } + return nil +} + +// Operator returns the group's logical operator; it panics if the operator +// was never defined. +func (g *LogicalExprGroup) Operator() LogicalOperator { + if g.op == LogicalOperatorNone { + panic("operator is not defined") + } + return g.op +} + +// Empty returns true if this condition has no elements. False otherwise. +func (g *LogicalExprGroup) Empty() bool { + if g.fn != nil { + return false + } + if g.prev != nil { + return g.prev.Empty() + } + return true +} + +func (g *LogicalExprGroup) Frame(fn func(*[]LogicalExpr) error) *LogicalExprGroup { + return &LogicalExprGroup{prev: g, op: g.op, fn: fn} +} + +func (g *LogicalExprGroup) Prev() immutable.Immutable { + if g == nil { + return nil + } + return g.prev +} + +func (g *LogicalExprGroup) Fn(in interface{}) error { + if g.fn == nil { + return nil + } + return g.fn(in.(*[]LogicalExpr)) +} + +func (g *LogicalExprGroup) Base() interface{} { + return &[]LogicalExpr{} +} + +var ( + _ = immutable.Immutable(&LogicalExprGroup{}) +) diff --git a/internal/adapter/raw.go b/internal/adapter/raw.go new file mode 100644 index 0000000..fc688d1 --- /dev/null +++ b/internal/adapter/raw.go @@ -0,0 +1,49 @@ +package adapter + +// RawExpr represents a value that can bypass SQL filters.
This is an +// exported type but it's rarely used directly; you may want to use the +// `db.Raw()` function instead. +type RawExpr struct { + value string + args *[]interface{} +} + +func (r *RawExpr) Arguments() []interface{} { + if r.args != nil { + return *r.args + } + return nil +} + +func (r RawExpr) Raw() string { + return r.value +} + +func (r RawExpr) String() string { + return r.Raw() +} + +// Expressions returns a logical expression. +func (r *RawExpr) Expressions() []LogicalExpr { + return []LogicalExpr{r} +} + +// Operator returns the default compound operator. +func (r RawExpr) Operator() LogicalOperator { + return LogicalOperatorNone +} + +// Empty returns true if this struct has no value. +func (r *RawExpr) Empty() bool { + return r.value == "" +} + +func NewRawExpr(value string, args []interface{}) *RawExpr { + r := &RawExpr{value: value, args: nil} + if len(args) > 0 { + r.args = &args + } + return r +} + +var _ = LogicalExpr(&RawExpr{}) diff --git a/internal/cache/cache.go b/internal/cache/cache.go new file mode 100644 index 0000000..a72a621 --- /dev/null +++ b/internal/cache/cache.go @@ -0,0 +1,113 @@ +package cache + +import ( + "container/list" + "errors" + "sync" +) + +const defaultCapacity = 128 + +// Cache holds a map of volatile key -> values. +type Cache struct { + keys *list.List + items map[uint64]*list.Element + mu sync.RWMutex + capacity int +} + +type cacheItem struct { + key uint64 + value interface{} +} + +// NewCacheWithCapacity initializes a new caching space with the given +// capacity. +func NewCacheWithCapacity(capacity int) (*Cache, error) { + if capacity < 1 { + return nil, errors.New("Capacity must be greater than zero.") + } + c := &Cache{ + capacity: capacity, + } + c.init() + return c, nil +} + +// NewCache initializes a new caching space with default settings. +func NewCache() *Cache { + c, err := NewCacheWithCapacity(defaultCapacity) + if err != nil { + panic(err.Error()) // Should never happen as we're not providing a negative defaultCapacity. + } + return c +} + +func (c *Cache) init() { + c.items = make(map[uint64]*list.Element) + c.keys = list.New() +} + +// Read attempts to retrieve a cached value as a string; if the value does not +// exist, it returns an empty string and false. +func (c *Cache) Read(h Hashable) (string, bool) { + if v, ok := c.ReadRaw(h); ok { + if s, ok := v.(string); ok { + return s, true + } + } + return "", false +} + +// ReadRaw attempts to retrieve a cached value as an interface{}; if the value +// does not exist, it returns nil and false. +func (c *Cache) ReadRaw(h Hashable) (interface{}, bool) { + c.mu.RLock() + defer c.mu.RUnlock() + + item, ok := c.items[h.Hash()] + if ok { + return item.Value.(*cacheItem).value, true + } + + return nil, false +} + +// Write stores a value in memory. If the value already exists, it's overwritten. +func (c *Cache) Write(h Hashable, value interface{}) { + c.mu.Lock() + defer c.mu.Unlock() + + key := h.Hash() + + if item, ok := c.items[key]; ok { + item.Value.(*cacheItem).value = value + c.keys.MoveToFront(item) + return + } + + c.items[key] = c.keys.PushFront(&cacheItem{key, value}) + + for c.keys.Len() > c.capacity { + item := c.keys.Remove(c.keys.Back()).(*cacheItem) + delete(c.items, item.key) + if p, ok := item.value.(HasOnEvict); ok { + p.OnEvict() + } + } +} + +// Clear generates a new memory space, leaving the old memory unreferenced, so +// it can be claimed by the garbage collector.
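+// +// A minimal usage sketch (k is any value implementing Hashable; names are +// illustrative): +// +// c := cache.NewCache() +// c.Write(k, "value") +// if s, ok := c.Read(k); ok { +// _ = s // cached copy +// } +// c.Clear() // evict all entries and release the old map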
+func (c *Cache) Clear() { + c.mu.Lock() + defer c.mu.Unlock() + + for _, item := range c.items { + if p, ok := item.Value.(*cacheItem).value.(HasOnEvict); ok { + p.OnEvict() + } + } + + c.init() +} diff --git a/internal/cache/cache_test.go b/internal/cache/cache_test.go new file mode 100644 index 0000000..11aeaef --- /dev/null +++ b/internal/cache/cache_test.go @@ -0,0 +1,97 @@ +package cache + +import ( + "fmt" + "hash/fnv" + "testing" +) + +var c *Cache + +type cacheableT struct { + Name string +} + +func (ct *cacheableT) Hash() uint64 { + s := fnv.New64() + // Write feeds Name into the hash state; Sum would only append the digest. + _, _ = s.Write([]byte(ct.Name)) + return s.Sum64() +} + +var ( + key = cacheableT{"foo"} + value = "bar" +) + +func TestNewCache(t *testing.T) { + c = NewCache() + if c == nil { + t.Fatal("Expecting a new cache object.") + } +} + +func TestCacheReadNonExistentValue(t *testing.T) { + if _, ok := c.Read(&key); ok { + t.Fatal("Expecting false.") + } +} + +func TestCacheWritingValue(t *testing.T) { + c.Write(&key, value) + c.Write(&key, value) +} + +func TestCacheReadExistentValue(t *testing.T) { + s, ok := c.Read(&key) + + if !ok { + t.Fatal("Expecting true.") + } + + if s != value { + t.Fatal("Expecting value.") + } +} + +func BenchmarkNewCache(b *testing.B) { + for i := 0; i < b.N; i++ { + NewCache() + } +} + +func BenchmarkNewCacheAndClear(b *testing.B) { + for i := 0; i < b.N; i++ { + c := NewCache() + c.Clear() + } +} + +func BenchmarkReadNonExistentValue(b *testing.B) { + z := NewCache() + for i := 0; i < b.N; i++ { + z.Read(&key) + } +} + +func BenchmarkWriteSameValue(b *testing.B) { + z := NewCache() + for i := 0; i < b.N; i++ { + z.Write(&key, value) + } +} + +func BenchmarkWriteNewValue(b *testing.B) { + z := NewCache() + for i := 0; i < b.N; i++ { + key := cacheableT{fmt.Sprintf("item-%d", i)} + z.Write(&key, value) + } +} + +func BenchmarkReadExistentValue(b *testing.B) { + z := NewCache() + z.Write(&key, value) + for i := 0; i < b.N; i++ { + z.Read(&key) + } +} diff --git a/internal/cache/hash.go b/internal/cache/hash.go new file mode 100644 index 0000000..4b866a9 --- /dev/null +++ b/internal/cache/hash.go @@ -0,0 +1,109 @@ +package cache + +import ( + "fmt" + + "github.com/segmentio/fasthash/fnv1a" +) + +const ( + hashTypeInt uint64 = 1 << iota + hashTypeSignedInt + hashTypeBool + hashTypeString + hashTypeHashable + hashTypeNil +) + +type hasher struct { + t uint64 + v interface{} +} + +func (h *hasher) Hash() uint64 { + return NewHash(h.t, h.v) +} + +func NewHashable(t uint64, v interface{}) Hashable { + return &hasher{t: t, v: v} +} + +func InitHash(t uint64) uint64 { + return fnv1a.AddUint64(fnv1a.Init64, t) +} + +func NewHash(t uint64, in ...interface{}) uint64 { + return AddToHash(InitHash(t), in...)
+} + +func AddToHash(h uint64, in ...interface{}) uint64 { + for i := range in { + if in[i] == nil { + continue + } + h = addToHash(h, in[i]) + } + return h +} + +func addToHash(h uint64, in interface{}) uint64 { + switch v := in.(type) { + case uint64: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), v) + case uint32: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case uint16: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case uint8: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case uint: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + case int64: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int32: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int16: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int8: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case int: + if v < 0 { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v)) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v)) + } + case bool: + if v { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 1) + } else { + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 2) + } + case string: + return fnv1a.AddString64(fnv1a.AddUint64(h, hashTypeString), v) + case Hashable: + if in == nil { + panic(fmt.Sprintf("could not hash nil element %T", in)) + } + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeHashable), v.Hash()) + case nil: + return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeNil), 0) + default: + panic(fmt.Sprintf("unsupported value type %T", in)) + } +} diff --git a/internal/cache/interface.go b/internal/cache/interface.go new file mode 100644 index 0000000..aee9ac7 --- /dev/null +++ b/internal/cache/interface.go @@ -0,0 +1,13 @@ +package cache + +// Hashable types must implement a method that returns a key. This key will be +// associated with a cached value. +type Hashable interface { + Hash() uint64 +} + +// HasOnEvict type is (optionally) implemented by cache objects to clean after +// themselves. +type HasOnEvict interface { + OnEvict() +} diff --git a/internal/immutable/immutable.go b/internal/immutable/immutable.go new file mode 100644 index 0000000..57d29ce --- /dev/null +++ b/internal/immutable/immutable.go @@ -0,0 +1,28 @@ +package immutable + +// Immutable represents an immutable chain. When passed to FastForward, Fn() +// is applied to every element of the chain; the first element of the chain +// is represented by Base(). +type Immutable interface { + // Prev is the previous element on a chain. + Prev() Immutable + // Fn is a function that is able to modify the passed element. + Fn(interface{}) error + // Base is the first element on a chain; there's no previous element before + // the Base element. + Base() interface{} +} + +// FastForward rebuilds the chain's value by starting from Base() and +// applying every Fn in order.
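+// +// A rough sketch of how a chain is consumed (LogicalExprGroup in +// internal/adapter is one concrete Immutable): +// +// out, err := immutable.FastForward(chain) // out starts as chain.Base() +// // every Fn between Base() and chain has now been applied, in order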
+func FastForward(curr Immutable) (interface{}, error) { + prev := curr.Prev() + if prev == nil { + return curr.Base(), nil + } + in, err := FastForward(prev) + if err != nil { + return nil, err + } + err = curr.Fn(in) + return in, err +} diff --git a/internal/reflectx/LICENSE b/internal/reflectx/LICENSE new file mode 100644 index 0000000..0d31edf --- /dev/null +++ b/internal/reflectx/LICENSE @@ -0,0 +1,23 @@ + Copyright (c) 2013, Jason Moiron + + Permission is hereby granted, free of charge, to any person + obtaining a copy of this software and associated documentation + files (the "Software"), to deal in the Software without + restriction, including without limitation the rights to use, + copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the + Software is furnished to do so, subject to the following + conditions: + + The above copyright notice and this permission notice shall be + included in all copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + OTHER DEALINGS IN THE SOFTWARE. + diff --git a/internal/reflectx/README.md b/internal/reflectx/README.md new file mode 100644 index 0000000..76f1b5d --- /dev/null +++ b/internal/reflectx/README.md @@ -0,0 +1,17 @@ +# reflectx + +The sqlx package has special reflect needs. In particular, it needs to: + +* be able to map a name to a field +* understand embedded structs +* understand mapping names to fields by a particular tag +* support user-specified name -> field mapping functions + +These behaviors mimic the behavior of the standard library marshalers and also the +behavior of standard Go accessors. + +The first two are amply taken care of by `reflect.Value.FieldByName`, and the third is +addressed by `reflect.Value.FieldByNameFunc`, but these don't quite understand struct +tags in the ways that are vital to most marshalers, and they are slow. + +This reflectx package extends reflect to achieve these goals. diff --git a/internal/reflectx/reflect.go b/internal/reflectx/reflect.go new file mode 100644 index 0000000..02df3da --- /dev/null +++ b/internal/reflectx/reflect.go @@ -0,0 +1,404 @@ +// Package reflectx implements extensions to the standard reflect lib suitable +// for implementing marshaling and unmarshaling packages. The main Mapper type +// allows for Go-compatible named attribute access, including accessing embedded +// struct attributes and the ability to use functions and struct tags to +// customize field names. +package reflectx + +import ( + "fmt" + "reflect" + "runtime" + "strings" + "sync" +) + +// A FieldInfo is a collection of metadata about a struct field. +type FieldInfo struct { + Index []int + Path string + Field reflect.StructField + Zero reflect.Value + Name string + Options map[string]string + Embedded bool + Children []*FieldInfo + Parent *FieldInfo +} + +// A StructMap is an index of field metadata for a struct. +type StructMap struct { + Tree *FieldInfo + Index []*FieldInfo + Paths map[string]*FieldInfo + Names map[string]*FieldInfo +} + +// GetByPath returns a *FieldInfo for a given string path.
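+// For example, with nested `db:"asset"` and `db:"title"` tags (illustrative +// names), GetByPath("asset.title") returns the metadata of the inner field.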
+func (f StructMap) GetByPath(path string) *FieldInfo { + return f.Paths[path] +} + +// GetByTraversal returns a *FieldInfo for a given integer path. It is +// analogous to reflect.FieldByIndex. +func (f StructMap) GetByTraversal(index []int) *FieldInfo { + if len(index) == 0 { + return nil + } + + tree := f.Tree + for _, i := range index { + if i >= len(tree.Children) || tree.Children[i] == nil { + return nil + } + tree = tree.Children[i] + } + return tree +} + +// Mapper is a general purpose mapper of names to struct fields. A Mapper +// behaves like most marshallers, optionally obeying a field tag for name +// mapping and a function to provide a basic mapping of fields to names. +type Mapper struct { + cache map[reflect.Type]*StructMap + tagName string + tagMapFunc func(string) string + mapFunc func(string) string + mutex sync.Mutex +} + +// NewMapper returns a new mapper which optionally obeys the field tag given +// by tagName. If tagName is the empty string, it is ignored. +func NewMapper(tagName string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + } +} + +// NewMapperTagFunc returns a new mapper which contains a mapper for field names +// AND a mapper for tag values. This is useful for tags like json which can +// have values like "name,omitempty". +func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + mapFunc: mapFunc, + tagMapFunc: tagMapFunc, + } +} + +// NewMapperFunc returns a new mapper which optionally obeys a field tag and +// a struct field name mapper func given by f. Tags will take precedence, but +// for any other field, the mapped name will be f(field.Name). +func NewMapperFunc(tagName string, f func(string) string) *Mapper { + return &Mapper{ + cache: make(map[reflect.Type]*StructMap), + tagName: tagName, + mapFunc: f, + } +} + +// TypeMap returns a mapping of field strings to int slices representing +// the traversal down the struct to reach the field. +func (m *Mapper) TypeMap(t reflect.Type) *StructMap { + m.mutex.Lock() + mapping, ok := m.cache[t] + if !ok { + mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc) + m.cache[t] = mapping + } + m.mutex.Unlock() + return mapping +} + +// FieldMap returns the mapper's mapping of field names to reflect values. Panics +// if v's Kind is not Struct, or v is not Indirectable to a struct kind. +func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + r := map[string]reflect.Value{} + tm := m.TypeMap(v.Type()) + for tagName, fi := range tm.Names { + r[tagName] = FieldByIndexes(v, fi.Index) + } + return r +} + +// ValidFieldMap returns the mapper's mapping of field names to reflect valid +// field values. Panics if v's Kind is not Struct, or v is not Indirectable to +// a struct kind. +func (m *Mapper) ValidFieldMap(v reflect.Value) map[string]reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + r := map[string]reflect.Value{} + tm := m.TypeMap(v.Type()) + for tagName, fi := range tm.Names { + v := ValidFieldByIndexes(v, fi.Index) + if v.IsValid() { + r[tagName] = v + } + } + return r +} + +// FieldByName returns a field by its mapped name as a reflect.Value. +// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind. +// Returns the original value if the name is not found.
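+// +// An illustrative sketch (hypothetical struct): +// +// type Row struct { +// Name string `db:"name"` +// } +// m := NewMapper("db") +// v := m.FieldByName(reflect.ValueOf(Row{Name: "x"}), "name") +// // v.String() == "x"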
+func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + tm := m.TypeMap(v.Type()) + fi, ok := tm.Names[name] + if !ok { + return v + } + return FieldByIndexes(v, fi.Index) +} + +// FieldsByName returns a slice of values corresponding to the slice of names +// for the value. Panics if v's Kind is not Struct or v is not Indirectable +// to a struct Kind. Returns zero Value for each name not found. +func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value { + v = reflect.Indirect(v) + mustBe(v, reflect.Struct) + + tm := m.TypeMap(v.Type()) + vals := make([]reflect.Value, 0, len(names)) + for _, name := range names { + fi, ok := tm.Names[name] + if !ok { + vals = append(vals, *new(reflect.Value)) + } else { + vals = append(vals, FieldByIndexes(v, fi.Index)) + } + } + return vals +} + +// TraversalsByName returns a slice of int slices which represent the struct +// traversals for each mapped name. Panics if t is not a struct or Indirectable +// to a struct. Returns empty int slice for each name not found. +func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int { + t = Deref(t) + mustBe(t, reflect.Struct) + tm := m.TypeMap(t) + + r := make([][]int, 0, len(names)) + for _, name := range names { + fi, ok := tm.Names[name] + if !ok { + r = append(r, []int{}) + } else { + r = append(r, fi.Index) + } + } + return r +} + +// FieldByIndexes returns a value for a particular struct traversal. +func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value { + for _, i := range indexes { + v = reflect.Indirect(v).Field(i) + // if this is a pointer, it's possible it is nil + if v.Kind() == reflect.Ptr && v.IsNil() { + alloc := reflect.New(Deref(v.Type())) + v.Set(alloc) + } + if v.Kind() == reflect.Map && v.IsNil() { + v.Set(reflect.MakeMap(v.Type())) + } + } + return v +} + +// ValidFieldByIndexes returns a value for a particular struct traversal. +func ValidFieldByIndexes(v reflect.Value, indexes []int) reflect.Value { + + for _, i := range indexes { + v = reflect.Indirect(v) + if !v.IsValid() { + return reflect.Value{} + } + v = v.Field(i) + // if this is a pointer, it's possible it is nil + if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Map) && v.IsNil() { + return reflect.Value{} + } + } + + return v +} + +// FieldByIndexesReadOnly returns a value for a particular struct traversal, +// but is not concerned with allocating nil pointers because the value is +// going to be used for reading and not setting. +func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value { + for _, i := range indexes { + v = reflect.Indirect(v).Field(i) + } + return v +} + +// Deref is Indirect for reflect.Types +func Deref(t reflect.Type) reflect.Type { + if t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +// -- helpers & utilities -- + +type kinder interface { + Kind() reflect.Kind +} + +// mustBe checks a value against a kind, panicking with a reflect.ValueError +// if the kind isn't that which is required.
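+// For example, mustBe(reflect.ValueOf(42), reflect.Struct) panics with a +// *reflect.ValueError naming the function that invoked mustBe.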
+func mustBe(v kinder, expected reflect.Kind) { + k := v.Kind() + if k != expected { + panic(&reflect.ValueError{Method: methodName(), Kind: k}) + } +} + +// methodName returns the name of the caller of the function calling methodName +func methodName() string { + pc, _, _, _ := runtime.Caller(2) + f := runtime.FuncForPC(pc) + if f == nil { + return "unknown method" + } + return f.Name() +} + +type typeQueue struct { + t reflect.Type + fi *FieldInfo + pp string // Parent path +} + +// A copying append that creates a new slice each time. +func apnd(is []int, i int) []int { + x := make([]int, len(is)+1) + copy(x, is) + x[len(x)-1] = i + return x +} + +// getMapping returns a mapping for the t type, using the tagName, mapFunc and +// tagMapFunc to determine the canonical names of fields. +func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) string) *StructMap { + m := []*FieldInfo{} + + root := &FieldInfo{} + queue := []typeQueue{} + queue = append(queue, typeQueue{Deref(t), root, ""}) + + for len(queue) != 0 { + // pop the first item off of the queue + tq := queue[0] + queue = queue[1:] + nChildren := 0 + if tq.t.Kind() == reflect.Struct { + nChildren = tq.t.NumField() + } + tq.fi.Children = make([]*FieldInfo, nChildren) + + // iterate through all of its fields + for fieldPos := 0; fieldPos < nChildren; fieldPos++ { + f := tq.t.Field(fieldPos) + + fi := FieldInfo{} + fi.Field = f + fi.Zero = reflect.New(f.Type).Elem() + fi.Options = map[string]string{} + + var tag, name string + if tagName != "" && strings.Contains(string(f.Tag), tagName+":") { + tag = f.Tag.Get(tagName) + name = tag + } else { + if mapFunc != nil { + name = mapFunc(f.Name) + } + } + + parts := strings.Split(name, ",") + if len(parts) > 1 { + name = parts[0] + for _, opt := range parts[1:] { + kv := strings.Split(opt, "=") + if len(kv) > 1 { + fi.Options[kv[0]] = kv[1] + } else { + fi.Options[kv[0]] = "" + } + } + } + + if tagMapFunc != nil { + tag = tagMapFunc(tag) + } + + fi.Name = name + + if tq.pp == "" || (tq.pp == "" && tag == "") { + fi.Path = fi.Name + } else { + fi.Path = fmt.Sprintf("%s.%s", tq.pp, fi.Name) + } + + // if the name is "-", disabled via a tag, skip it + if name == "-" { + continue + } + + // skip unexported fields + if len(f.PkgPath) != 0 && !f.Anonymous { + continue + } + + // bfs search of anonymous embedded structs + if f.Anonymous { + pp := tq.pp + if tag != "" { + pp = fi.Path + } + + fi.Embedded = true + fi.Index = apnd(tq.fi.Index, fieldPos) + nChildren := 0 + ft := Deref(f.Type) + if ft.Kind() == reflect.Struct { + nChildren = ft.NumField() + } + fi.Children = make([]*FieldInfo, nChildren) + queue = append(queue, typeQueue{Deref(f.Type), &fi, pp}) + } else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) { + fi.Index = apnd(tq.fi.Index, fieldPos) + fi.Children = make([]*FieldInfo, Deref(f.Type).NumField()) + queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path}) + } + + fi.Index = apnd(tq.fi.Index, fieldPos) + fi.Parent = tq.fi + tq.fi.Children[fieldPos] = &fi + m = append(m, &fi) + } + } + + flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}} + for _, fi := range flds.Index { + flds.Paths[fi.Path] = fi + if fi.Name != "" && !fi.Embedded { + flds.Names[fi.Path] = fi + } + } + + return flds +} diff --git a/internal/reflectx/reflect_test.go b/internal/reflectx/reflect_test.go new file mode 100644 index 0000000..8072244 --- /dev/null +++
b/internal/reflectx/reflect_test.go @@ -0,0 +1,587 @@ +package reflectx + +import ( + "reflect" + "strings" + "testing" +) + +func ival(v reflect.Value) int { + return v.Interface().(int) +} + +func TestBasic(t *testing.T) { + type Foo struct { + A int + B int + C int + } + + f := Foo{1, 2, 3} + fv := reflect.ValueOf(f) + m := NewMapperFunc("", func(s string) string { return s }) + + v := m.FieldByName(fv, "A") + if ival(v) != f.A { + t.Errorf("Expecting %d, got %d", f.A, ival(v)) + } + v = m.FieldByName(fv, "B") + if ival(v) != f.B { + t.Errorf("Expecting %d, got %d", f.B, ival(v)) + } + v = m.FieldByName(fv, "C") + if ival(v) != f.C { + t.Errorf("Expecting %d, got %d", f.C, ival(v)) + } +} + +func TestBasicEmbedded(t *testing.T) { + type Foo struct { + A int + } + + type Bar struct { + Foo // `db:""` is implied for an embedded struct + B int + C int `db:"-"` + } + + type Baz struct { + A int + Bar `db:"Bar"` + } + + m := NewMapperFunc("db", func(s string) string { return s }) + + z := Baz{} + z.A = 1 + z.B = 2 + z.C = 4 + z.Bar.Foo.A = 3 + + zv := reflect.ValueOf(z) + fields := m.TypeMap(reflect.TypeOf(z)) + + if len(fields.Index) != 5 { + t.Errorf("Expecting 5 fields") + } + + // for _, fi := range fields.Index { + // log.Println(fi) + // } + + v := m.FieldByName(zv, "A") + if ival(v) != z.A { + t.Errorf("Expecting %d, got %d", z.A, ival(v)) + } + v = m.FieldByName(zv, "Bar.B") + if ival(v) != z.Bar.B { + t.Errorf("Expecting %d, got %d", z.Bar.B, ival(v)) + } + v = m.FieldByName(zv, "Bar.A") + if ival(v) != z.Bar.Foo.A { + t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) + } + v = m.FieldByName(zv, "Bar.C") + if _, ok := v.Interface().(int); ok { + t.Errorf("Expecting Bar.C to not exist") + } + + fi := fields.GetByPath("Bar.C") + if fi != nil { + t.Errorf("Bar.C should not exist") + } +} + +func TestEmbeddedSimple(t *testing.T) { + type UUID [16]byte + type MyID struct { + UUID + } + type Item struct { + ID MyID + } + z := Item{} + + m := NewMapper("db") + m.TypeMap(reflect.TypeOf(z)) +} + +func TestBasicEmbeddedWithTags(t *testing.T) { + type Foo struct { + A int `db:"a"` + } + + type Bar struct { + Foo // `db:""` is implied for an embedded struct + B int `db:"b"` + } + + type Baz struct { + A int `db:"a"` + Bar // `db:""` is implied for an embedded struct + } + + m := NewMapper("db") + + z := Baz{} + z.A = 1 + z.B = 2 + z.Bar.Foo.A = 3 + + zv := reflect.ValueOf(z) + fields := m.TypeMap(reflect.TypeOf(z)) + + if len(fields.Index) != 5 { + t.Errorf("Expecting 5 fields") + } + + // for _, fi := range fields.index { + // log.Println(fi) + // } + + v := m.FieldByName(zv, "a") + if ival(v) != z.Bar.Foo.A { // the dominant field + t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v)) + } + v = m.FieldByName(zv, "b") + if ival(v) != z.B { + t.Errorf("Expecting %d, got %d", z.B, ival(v)) + } +} + +func TestFlatTags(t *testing.T) { + m := NewMapper("db") + + type Asset struct { + Title string `db:"title"` + } + type Post struct { + Author string `db:"author,required"` + Asset Asset `db:""` + } + // Post columns: (author title) + + post := Post{Author: "Joe", Asset: Asset{Title: "Hello"}} + pv := reflect.ValueOf(post) + + v := m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } + v = m.FieldByName(pv, "title") + if v.Interface().(string) != post.Asset.Title { + t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) + } +} + +func TestNestedStruct(t *testing.T) { + m =
NewMapper("db") + + type Details struct { + Active bool `db:"active"` + } + type Asset struct { + Title string `db:"title"` + Details Details `db:"details"` + } + type Post struct { + Author string `db:"author,required"` + Asset `db:"asset"` + } + // Post columns: (author asset.title asset.details.active) + + post := Post{ + Author: "Joe", + Asset: Asset{Title: "Hello", Details: Details{Active: true}}, + } + pv := reflect.ValueOf(post) + + v := m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } + v = m.FieldByName(pv, "title") + if _, ok := v.Interface().(string); ok { + t.Errorf("Expecting field to not exist") + } + v = m.FieldByName(pv, "asset.title") + if v.Interface().(string) != post.Asset.Title { + t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset.details.active") + if v.Interface().(bool) != post.Asset.Details.Active { + t.Errorf("Expecting %v, got %v", post.Asset.Details.Active, v.Interface().(bool)) + } +} + +func TestInlineStruct(t *testing.T) { + m := NewMapperTagFunc("db", strings.ToLower, nil) + + type Employee struct { + Name string + ID int + } + type Boss Employee + type person struct { + Employee `db:"employee"` + Boss `db:"boss"` + } + // employees columns: (employee.name employee.id boss.name boss.id) + + em := person{Employee: Employee{Name: "Joe", ID: 2}, Boss: Boss{Name: "Dick", ID: 1}} + ev := reflect.ValueOf(em) + + fields := m.TypeMap(reflect.TypeOf(em)) + if len(fields.Index) != 6 { + t.Errorf("Expecting 6 fields") + } + + v := m.FieldByName(ev, "employee.name") + if v.Interface().(string) != em.Employee.Name { + t.Errorf("Expecting %s, got %s", em.Employee.Name, v.Interface().(string)) + } + v = m.FieldByName(ev, "boss.id") + if ival(v) != em.Boss.ID { + t.Errorf("Expecting %v, got %v", em.Boss.ID, ival(v)) + } +} + +func TestFieldsEmbedded(t *testing.T) { + m := NewMapper("db") + + type Person struct { + Name string `db:"name"` + } + type Place struct { + Name string `db:"name"` + } + type Article struct { + Title string `db:"title"` + } + type PP struct { + Person `db:"person,required"` + Place `db:",someflag"` + Article `db:",required"` + } + // PP columns: (person.name name title) + + pp := PP{} + pp.Person.Name = "Peter" + pp.Place.Name = "Toronto" + pp.Article.Title = "Best city ever" + + fields := m.TypeMap(reflect.TypeOf(pp)) + // for i, f := range fields { + // log.Println(i, f) + // } + + ppv := reflect.ValueOf(pp) + + v := m.FieldByName(ppv, "person.name") + if v.Interface().(string) != pp.Person.Name { + t.Errorf("Expecting %s, got %s", pp.Person.Name, v.Interface().(string)) + } + + v = m.FieldByName(ppv, "name") + if v.Interface().(string) != pp.Place.Name { + t.Errorf("Expecting %s, got %s", pp.Place.Name, v.Interface().(string)) + } + + v = m.FieldByName(ppv, "title") + if v.Interface().(string) != pp.Article.Title { + t.Errorf("Expecting %s, got %s", pp.Article.Title, v.Interface().(string)) + } + + fi := fields.GetByPath("person") + if _, ok := fi.Options["required"]; !ok { + t.Errorf("Expecting required option to be set") + } + if !fi.Embedded { + t.Errorf("Expecting field to be embedded") + } + if len(fi.Index) != 1 || fi.Index[0] != 0 { + t.Errorf("Expecting index to be [0]") + } + + fi = fields.GetByPath("person.name") + if fi == nil { + t.Errorf("Expecting person.name to exist") + } + if fi.Path != "person.name" { + t.Errorf("Expecting %s, got %s", "person.name", fi.Path) + } + + fi = 
fields.GetByTraversal([]int{1, 0}) + if fi == nil { + t.Errorf("Expecting traversal to exist") + } + if fi.Path != "name" { + t.Errorf("Expecting %s, got %s", "name", fi.Path) + } + + fi = fields.GetByTraversal([]int{2}) + if fi == nil { + t.Errorf("Expecting traversal to exist") + } + if _, ok := fi.Options["required"]; !ok { + t.Errorf("Expecting required option to be set") + } + + trs := m.TraversalsByName(reflect.TypeOf(pp), []string{"person.name", "name", "title"}) + if !reflect.DeepEqual(trs, [][]int{{0, 0}, {1, 0}, {2, 0}}) { + t.Errorf("Expecting traversal: %v", trs) + } +} + +func TestPtrFields(t *testing.T) { + m := NewMapperTagFunc("db", strings.ToLower, nil) + type Asset struct { + Title string + } + type Post struct { + *Asset `db:"asset"` + Author string + } + + post := &Post{Author: "Joe", Asset: &Asset{Title: "Hiyo"}} + pv := reflect.ValueOf(post) + + fields := m.TypeMap(reflect.TypeOf(post)) + if len(fields.Index) != 3 { + t.Errorf("Expecting 3 fields") + } + + v := m.FieldByName(pv, "asset.title") + if v.Interface().(string) != post.Asset.Title { + t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } +} + +func TestNamedPtrFields(t *testing.T) { + m := NewMapperTagFunc("db", strings.ToLower, nil) + + type User struct { + Name string + } + + type Asset struct { + Title string + + Owner *User `db:"owner"` + } + type Post struct { + Author string + + Asset1 *Asset `db:"asset1"` + Asset2 *Asset `db:"asset2"` + } + + post := &Post{Author: "Joe", Asset1: &Asset{Title: "Hiyo", Owner: &User{"Username"}}} // Let Asset2 be nil + pv := reflect.ValueOf(post) + + fields := m.TypeMap(reflect.TypeOf(post)) + if len(fields.Index) != 9 { + t.Errorf("Expecting 9 fields") + } + + v := m.FieldByName(pv, "asset1.title") + if v.Interface().(string) != post.Asset1.Title { + t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset1.owner.name") + if v.Interface().(string) != post.Asset1.Owner.Name { + t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset2.title") + if v.Interface().(string) != post.Asset2.Title { + t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string)) + } + v = m.FieldByName(pv, "asset2.owner.name") + if v.Interface().(string) != post.Asset2.Owner.Name { + t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string)) + } + v = m.FieldByName(pv, "author") + if v.Interface().(string) != post.Author { + t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string)) + } +} + +func TestFieldMap(t *testing.T) { + type Foo struct { + A int + B int + C int + } + + f := Foo{1, 2, 3} + m := NewMapperFunc("db", strings.ToLower) + + fm := m.FieldMap(reflect.ValueOf(f)) + + if len(fm) != 3 { + t.Errorf("Expecting %d keys, got %d", 3, len(fm)) + } + if fm["a"].Interface().(int) != 1 { + t.Errorf("Expecting %d, got %d", 1, ival(fm["a"])) + } + if fm["b"].Interface().(int) != 2 { + t.Errorf("Expecting %d, got %d", 2, ival(fm["b"])) + } + if fm["c"].Interface().(int) != 3 { + t.Errorf("Expecting %d, got %d", 3, ival(fm["c"])) + } +} + +func TestTagNameMapping(t *testing.T) { + type Strategy struct { + StrategyID string `protobuf:"bytes,1,opt,name=strategy_id" json:"strategy_id,omitempty"` + StrategyName string + } + + m := NewMapperTagFunc("json",
strings.ToUpper, func(value string) string { + if strings.Contains(value, ",") { + return strings.Split(value, ",")[0] + } + return value + }) + strategy := Strategy{"1", "Alpah"} + mapping := m.TypeMap(reflect.TypeOf(strategy)) + + for _, key := range []string{"strategy_id", "STRATEGYNAME"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } +} + +func TestMapping(t *testing.T) { + type Person struct { + ID int + Name string + WearsGlasses bool `db:"wears_glasses"` + } + + m := NewMapperFunc("db", strings.ToLower) + p := Person{1, "Jason", true} + mapping := m.TypeMap(reflect.TypeOf(p)) + + for _, key := range []string{"id", "name", "wears_glasses"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } + + type SportsPerson struct { + Weight int + Age int + Person + } + s := SportsPerson{Weight: 100, Age: 30, Person: p} + mapping = m.TypeMap(reflect.TypeOf(s)) + for _, key := range []string{"id", "name", "wears_glasses", "weight", "age"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } + + type RugbyPlayer struct { + Position int + IsIntense bool `db:"is_intense"` + IsAllBlack bool `db:"-"` + SportsPerson + } + r := RugbyPlayer{12, true, false, s} + mapping = m.TypeMap(reflect.TypeOf(r)) + for _, key := range []string{"id", "name", "wears_glasses", "weight", "age", "position", "is_intense"} { + if fi := mapping.GetByPath(key); fi == nil { + t.Errorf("Expecting to find key %s in mapping but did not.", key) + } + } + + if fi := mapping.GetByPath("isallblack"); fi != nil { + t.Errorf("Expecting to ignore `IsAllBlack` field") + } +} + +type E1 struct { + A int +} +type E2 struct { + E1 + B int +} +type E3 struct { + E2 + C int +} +type E4 struct { + E3 + D int +} + +func BenchmarkFieldNameL1(b *testing.B) { + e4 := E4{D: 1} + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.FieldByName("D") + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldNameL4(b *testing.B) { + e4 := E4{} + e4.A = 1 + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.FieldByName("A") + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldPosL1(b *testing.B) { + e4 := E4{D: 1} + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.Field(1) + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldPosL4(b *testing.B) { + e4 := E4{} + e4.A = 1 + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := v.Field(0) + f = f.Field(0) + f = f.Field(0) + f = f.Field(0) + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} + +func BenchmarkFieldByIndexL4(b *testing.B) { + e4 := E4{} + e4.A = 1 + idx := []int{0, 0, 0, 0} + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(e4) + f := FieldByIndexes(v, idx) + if f.Interface().(int) != 1 { + b.Fatal("Wrong value.") + } + } +} diff --git a/internal/sqladapter/collection.go b/internal/sqladapter/collection.go new file mode 100644 index 0000000..01c492a --- /dev/null +++ b/internal/sqladapter/collection.go @@ -0,0 +1,369 @@ +package sqladapter + +import ( + "fmt" + "reflect" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// CollectionAdapter defines methods to be implemented by SQL adapters. 
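+// +// A minimal sketch of a conforming adapter (names are hypothetical): +// +// type myAdapter struct{} +// +// func (myAdapter) Insert(col Collection, item interface{}) (interface{}, error) { +// // build and execute an INSERT against col.Name(), then return the +// // newly generated ID (or nil if it cannot be determined) +// return nil, nil +// }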
+type CollectionAdapter interface { + // Insert prepares and executes an INSERT statement. When the item is + // successfully added, Insert returns a unique identifier of the newly added + // element (or nil if the unique identifier couldn't be determined). + Insert(Collection, interface{}) (interface{}, error) +} + +// Collection satisfies mydb.Collection. +type Collection interface { + // Insert inserts a new item into the collection. + Insert(interface{}) (mydb.InsertResult, error) + + // Name returns the name of the collection. + Name() string + + // Session returns the mydb.Session the collection belongs to. + Session() mydb.Session + + // Exists returns true if the collection exists, false otherwise. + Exists() (bool, error) + + // Find defines a new result set. + Find(conds ...interface{}) mydb.Result + + Count() (uint64, error) + + // Truncate removes all elements on the collection and resets the + // collection's IDs. + Truncate() error + + // InsertReturning inserts a new item into the collection and refreshes the + // item with actual data from the database. This is useful to get automatic + // values, such as timestamps, or IDs. + InsertReturning(item interface{}) error + + // UpdateReturning updates a record from the collection and refreshes the item + // with actual data from the database. This is useful to get automatic + // values, such as timestamps, or IDs. + UpdateReturning(item interface{}) error + + // PrimaryKeys returns the names of all primary keys in the table. + PrimaryKeys() ([]string, error) + + // SQL returns a mydb.SQL instance. + SQL() mydb.SQL +} + +type finder interface { + Find(Collection, *Result, ...interface{}) mydb.Result +} + +type condsFilter interface { + FilterConds(...interface{}) []interface{} +} + +// collection is the implementation of Collection. +type collection struct { + name string + adapter CollectionAdapter +} + +type collectionWithSession struct { + *collection + + session Session +} + +func newCollection(name string, adapter CollectionAdapter) *collection { + if adapter == nil { + panic("mydb: nil adapter") + } + return &collection{ + name: name, + adapter: adapter, + } +} + +func (c *collectionWithSession) SQL() mydb.SQL { + return c.session.SQL() +} + +func (c *collectionWithSession) Session() mydb.Session { + return c.session +} + +func (c *collectionWithSession) Name() string { + return c.name +} + +func (c *collectionWithSession) Count() (uint64, error) { + return c.Find().Count() +} + +func (c *collectionWithSession) Insert(item interface{}) (mydb.InsertResult, error) { + id, err := c.adapter.Insert(c, item) + if err != nil { + return nil, err + } + + return mydb.NewInsertResult(id), nil +} + +func (c *collectionWithSession) PrimaryKeys() ([]string, error) { + return c.session.PrimaryKeys(c.Name()) +} + +func (c *collectionWithSession) filterConds(conds ...interface{}) ([]interface{}, error) { + pk, err := c.PrimaryKeys() + if err != nil { + return nil, err + } + if len(conds) == 1 && len(pk) == 1 { + if id := conds[0]; IsKeyValue(id) { + conds[0] = mydb.Cond{pk[0]: mydb.Eq(id)} + } + } + if tr, ok := c.adapter.(condsFilter); ok { + return tr.FilterConds(conds...), nil + } + return conds, nil +} + +func (c *collectionWithSession) Find(conds ...interface{}) mydb.Result { + filteredConds, err := c.filterConds(conds...)
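+ // If primary-key discovery failed, return a Result that only carries the error.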
+	if err != nil {
+		res := &Result{}
+		res.setErr(err)
+		return res
+	}
+
+	res := NewResult(
+		c.session.SQL(),
+		c.Name(),
+		filteredConds,
+	)
+	if f, ok := c.adapter.(finder); ok {
+		return f.Find(c, res, conds...)
+	}
+	return res
+}
+
+func (c *collectionWithSession) Exists() (bool, error) {
+	if err := c.session.TableExists(c.Name()); err != nil {
+		return false, err
+	}
+	return true, nil
+}
+
+func (c *collectionWithSession) InsertReturning(item interface{}) error {
+	if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
+		return fmt.Errorf("expecting a pointer but got %T", item)
+	}
+
+	// Grab primary keys
+	pks, err := c.PrimaryKeys()
+	if err != nil {
+		return err
+	}
+
+	if len(pks) == 0 {
+		if ok, err := c.Exists(); !ok {
+			return err
+		}
+		return fmt.Errorf(mydb.ErrMissingPrimaryKeys.Error(), c.Name())
+	}
+
+	var tx Session
+	isTransaction := c.session.IsTransaction()
+	if isTransaction {
+		tx = c.session
+	} else {
+		var err error
+		tx, err = c.session.NewTransaction(c.session.Context(), nil)
+		if err != nil {
+			return err
+		}
+		defer tx.Close()
+	}
+
+	// Allocate a clone of item.
+	newItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface()
+	var newItemFieldMap map[string]reflect.Value
+
+	itemValue := reflect.ValueOf(item)
+
+	col := tx.Collection(c.Name())
+
+	// Insert item as is and grab the returning ID.
+	var newItemRes mydb.Result
+	id, err := col.Insert(item)
+	if err != nil {
+		goto cancel
+	}
+	if id == nil {
+		err = fmt.Errorf("InsertReturning: could not get a valid ID after inserting. Does the %q table have a primary key?", c.Name())
+		goto cancel
+	}
+
+	if len(pks) > 1 {
+		newItemRes = col.Find(id)
+	} else {
+		// We have one primary key; build an explicit mydb.Cond with it to
+		// prevent string keys from being treated as raw conditions.
+		newItemRes = col.Find(mydb.Cond{pks[0]: id}) // We already checked that pks is not empty, so pks[0] is defined.
+	}
+
+	// Fetch the row that was just inserted into newItem.
+	err = newItemRes.One(newItem)
+	if err != nil {
+		goto cancel
+	}
+
+	switch reflect.ValueOf(newItem).Elem().Kind() {
+	case reflect.Struct:
+		// Get valid fields from newItem to overwrite those that are on item.
+		newItemFieldMap = sqlbuilder.Mapper.ValidFieldMap(reflect.ValueOf(newItem))
+		for fieldName := range newItemFieldMap {
+			sqlbuilder.Mapper.FieldByName(itemValue, fieldName).Set(newItemFieldMap[fieldName])
+		}
+	case reflect.Map:
+		newItemV := reflect.ValueOf(newItem).Elem()
+		itemV := reflect.ValueOf(item)
+		if itemV.Kind() == reflect.Ptr {
+			itemV = itemV.Elem()
+		}
+		for _, keyV := range newItemV.MapKeys() {
+			itemV.SetMapIndex(keyV, newItemV.MapIndex(keyV))
+		}
+	default:
+		err = fmt.Errorf("InsertReturning: expecting a pointer to map or struct, got %T", newItem)
+		goto cancel
+	}
+
+	if !isTransaction {
+		// This is only executed if t.Session() was **not** a transaction and if
+		// sess was created with sess.NewTransaction().
+		return tx.Commit()
+	}
+
+	return err
+
+cancel:
+	// This goto label should only be used when we got an error within a
+	// transaction and we don't want to continue.
+
+	if !isTransaction {
+		// This is only executed if t.Session() was **not** a transaction and if
+		// sess was created with sess.NewTransaction().
+		_ = tx.Rollback()
+	}
+	return err
+}
+
+func (c *collectionWithSession) UpdateReturning(item interface{}) error {
+	if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
+		return fmt.Errorf("expecting a pointer but got %T", item)
+	}
+
+	// Grab primary keys
+	pks, err := c.PrimaryKeys()
+	if err != nil {
+		return err
+	}
+
+	if len(pks) == 0 {
+		if ok, err := c.Exists(); !ok {
+			return err
+		}
+		return fmt.Errorf(mydb.ErrMissingPrimaryKeys.Error(), c.Name())
+	}
+
+	var tx Session
+	isTransaction := c.session.IsTransaction()
+
+	if isTransaction {
+		tx = c.session
+	} else {
+		// Not within a transaction, let's create one.
+		var err error
+		tx, err = c.session.NewTransaction(c.session.Context(), nil)
+		if err != nil {
+			return err
+		}
+		defer tx.Close()
+	}
+
+	// Allocate a clone of item.
+	defaultItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface()
+	var defaultItemFieldMap map[string]reflect.Value
+
+	itemValue := reflect.ValueOf(item)
+
+	conds := mydb.Cond{}
+	for _, pk := range pks {
+		conds[pk] = mydb.Eq(sqlbuilder.Mapper.FieldByName(itemValue, pk).Interface())
+	}
+
+	col := tx.Collection(c.Name())
+
+	err = col.Find(conds).Update(item)
+	if err != nil {
+		goto cancel
+	}
+
+	if err = col.Find(conds).One(defaultItem); err != nil {
+		goto cancel
+	}
+
+	switch reflect.ValueOf(defaultItem).Elem().Kind() {
+	case reflect.Struct:
+		// Get valid fields from defaultItem to overwrite those that are on item.
+		defaultItemFieldMap = sqlbuilder.Mapper.ValidFieldMap(reflect.ValueOf(defaultItem))
+		for fieldName := range defaultItemFieldMap {
+			sqlbuilder.Mapper.FieldByName(itemValue, fieldName).Set(defaultItemFieldMap[fieldName])
+		}
+	case reflect.Map:
+		defaultItemV := reflect.ValueOf(defaultItem).Elem()
+		itemV := reflect.ValueOf(item)
+		if itemV.Kind() == reflect.Ptr {
+			itemV = itemV.Elem()
+		}
+		for _, keyV := range defaultItemV.MapKeys() {
+			itemV.SetMapIndex(keyV, defaultItemV.MapIndex(keyV))
+		}
+	default:
+		err = fmt.Errorf("UpdateReturning: expecting a pointer to map or struct, got %T", defaultItem)
+		goto cancel
+	}
+
+	if !isTransaction {
+		// This is only executed if t.Session() was **not** a transaction and if
+		// sess was created with sess.NewTransaction().
+		return tx.Commit()
+	}
+	return err
+
+cancel:
+	// This goto label should only be used when we got an error within a
+	// transaction and we don't want to continue.
+
+	if !isTransaction {
+		// This is only executed if t.Session() was **not** a transaction and if
+		// sess was created with sess.NewTransaction().
+		_ = tx.Rollback()
+	}
+	return err
+}
+
+func (c *collectionWithSession) Truncate() error {
+	stmt := exql.Statement{
+		Type:  exql.Truncate,
+		Table: exql.TableWithName(c.Name()),
+	}
+	if _, err := c.session.SQL().Exec(&stmt); err != nil {
+		return err
+	}
+	return nil
+}
diff --git a/internal/sqladapter/compat/query.go b/internal/sqladapter/compat/query.go
new file mode 100644
index 0000000..93cb8fc
--- /dev/null
+++ b/internal/sqladapter/compat/query.go
@@ -0,0 +1,72 @@
+// +build !go1.8
+
+package compat
+
+import (
+	"context"
+	"database/sql"
+)
+
+type PreparedExecer interface {
+	Exec(...interface{}) (sql.Result, error)
+}
+
+func PreparedExecContext(p PreparedExecer, ctx context.Context, args []interface{}) (sql.Result, error) {
+	return p.Exec(args...)
+}
+
+type Execer interface {
+	Exec(string, ...interface{}) (sql.Result, error)
+}
+
+func ExecContext(p Execer, ctx context.Context, query string, args []interface{}) (sql.Result, error) {
+	return p.Exec(query, args...)
+} + +type PreparedQueryer interface { + Query(...interface{}) (*sql.Rows, error) +} + +func PreparedQueryContext(p PreparedQueryer, ctx context.Context, args []interface{}) (*sql.Rows, error) { + return p.Query(args...) +} + +type Queryer interface { + Query(string, ...interface{}) (*sql.Rows, error) +} + +func QueryContext(p Queryer, ctx context.Context, query string, args []interface{}) (*sql.Rows, error) { + return p.Query(query, args...) +} + +type PreparedRowQueryer interface { + QueryRow(...interface{}) *sql.Row +} + +func PreparedQueryRowContext(p PreparedRowQueryer, ctx context.Context, args []interface{}) *sql.Row { + return p.QueryRow(args...) +} + +type RowQueryer interface { + QueryRow(string, ...interface{}) *sql.Row +} + +func QueryRowContext(p RowQueryer, ctx context.Context, query string, args []interface{}) *sql.Row { + return p.QueryRow(query, args...) +} + +type Preparer interface { + Prepare(string) (*sql.Stmt, error) +} + +func PrepareContext(p Preparer, ctx context.Context, query string) (*sql.Stmt, error) { + return p.Prepare(query) +} + +type TxStarter interface { + Begin() (*sql.Tx, error) +} + +func BeginTx(p TxStarter, ctx context.Context, opts interface{}) (*sql.Tx, error) { + return p.Begin() +} diff --git a/internal/sqladapter/compat/query_go18.go b/internal/sqladapter/compat/query_go18.go new file mode 100644 index 0000000..a3abbaf --- /dev/null +++ b/internal/sqladapter/compat/query_go18.go @@ -0,0 +1,72 @@ +// +build go1.8 + +package compat + +import ( + "context" + "database/sql" +) + +type PreparedExecer interface { + ExecContext(context.Context, ...interface{}) (sql.Result, error) +} + +func PreparedExecContext(p PreparedExecer, ctx context.Context, args []interface{}) (sql.Result, error) { + return p.ExecContext(ctx, args...) +} + +type Execer interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) +} + +func ExecContext(p Execer, ctx context.Context, query string, args []interface{}) (sql.Result, error) { + return p.ExecContext(ctx, query, args...) +} + +type PreparedQueryer interface { + QueryContext(context.Context, ...interface{}) (*sql.Rows, error) +} + +func PreparedQueryContext(p PreparedQueryer, ctx context.Context, args []interface{}) (*sql.Rows, error) { + return p.QueryContext(ctx, args...) +} + +type Queryer interface { + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) +} + +func QueryContext(p Queryer, ctx context.Context, query string, args []interface{}) (*sql.Rows, error) { + return p.QueryContext(ctx, query, args...) +} + +type PreparedRowQueryer interface { + QueryRowContext(context.Context, ...interface{}) *sql.Row +} + +func PreparedQueryRowContext(p PreparedRowQueryer, ctx context.Context, args []interface{}) *sql.Row { + return p.QueryRowContext(ctx, args...) +} + +type RowQueryer interface { + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +func QueryRowContext(p RowQueryer, ctx context.Context, query string, args []interface{}) *sql.Row { + return p.QueryRowContext(ctx, query, args...) 
+}
+
+type Preparer interface {
+	PrepareContext(context.Context, string) (*sql.Stmt, error)
+}
+
+func PrepareContext(p Preparer, ctx context.Context, query string) (*sql.Stmt, error) {
+	return p.PrepareContext(ctx, query)
+}
+
+type TxStarter interface {
+	BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error)
+}
+
+func BeginTx(p TxStarter, ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
+	return p.BeginTx(ctx, opts)
+}
diff --git a/internal/sqladapter/exql/column.go b/internal/sqladapter/exql/column.go
new file mode 100644
index 0000000..ec10825
--- /dev/null
+++ b/internal/sqladapter/exql/column.go
@@ -0,0 +1,83 @@
+package exql
+
+import (
+	"fmt"
+	"strings"
+
+	"git.hexq.cn/tiglog/mydb/internal/cache"
+)
+
+type columnWithAlias struct {
+	Name  string
+	Alias string
+}
+
+// Column represents a SQL column.
+type Column struct {
+	Name interface{}
+}
+
+var _ = Fragment(&Column{})
+
+// ColumnWithName creates and returns a Column with the given name.
+func ColumnWithName(name string) *Column {
+	return &Column{Name: name}
+}
+
+// Hash returns a unique identifier for the struct.
+func (c *Column) Hash() uint64 {
+	if c == nil {
+		return cache.NewHash(FragmentType_Column, nil)
+	}
+	return cache.NewHash(FragmentType_Column, c.Name)
+}
+
+// Compile transforms the Column into an equivalent SQL representation.
+func (c *Column) Compile(layout *Template) (compiled string, err error) {
+	if z, ok := layout.Read(c); ok {
+		return z, nil
+	}
+
+	var alias string
+	switch value := c.Name.(type) {
+	case string:
+		value = trimString(value)
+
+		chunks := separateByAS(value)
+		if len(chunks) == 1 {
+			chunks = separateBySpace(value)
+		}
+
+		name := chunks[0]
+		nameChunks := strings.SplitN(name, layout.ColumnSeparator, 2)
+
+		for i := range nameChunks {
+			nameChunks[i] = trimString(nameChunks[i])
+			if nameChunks[i] == "*" {
+				continue
+			}
+			nameChunks[i] = layout.MustCompile(layout.IdentifierQuote, Raw{Value: nameChunks[i]})
+		}
+
+		compiled = strings.Join(nameChunks, layout.ColumnSeparator)
+
+		if len(chunks) > 1 {
+			alias = trimString(chunks[1])
+			alias = layout.MustCompile(layout.IdentifierQuote, Raw{Value: alias})
+		}
+	case compilable:
+		compiled, err = value.Compile(layout)
+		if err != nil {
+			return "", err
+		}
+	default:
+		return "", fmt.Errorf(errExpectingHashableFmt, c.Name)
+	}
+
+	if alias != "" {
+		compiled = layout.MustCompile(layout.ColumnAliasLayout, columnWithAlias{compiled, alias})
+	}
+
+	layout.Write(c, compiled)
+	return
+}
diff --git a/internal/sqladapter/exql/column_test.go b/internal/sqladapter/exql/column_test.go
new file mode 100644
index 0000000..7706538
--- /dev/null
+++ b/internal/sqladapter/exql/column_test.go
@@ -0,0 +1,88 @@
+package exql
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestColumnString(t *testing.T) {
+	column := Column{Name: "role.name"}
+	s, err := column.Compile(defaultTemplate)
+	assert.NoError(t, err)
+	assert.Equal(t, `"role"."name"`, s)
+}
+
+func TestColumnAs(t *testing.T) {
+	column := Column{Name: "role.name as foo"}
+	s, err := column.Compile(defaultTemplate)
+	assert.NoError(t, err)
+	assert.Equal(t, `"role"."name" AS "foo"`, s)
+}
+
+func TestColumnImplicitAs(t *testing.T) {
+	column := Column{Name: "role.name foo"}
+	s, err := column.Compile(defaultTemplate)
+	assert.NoError(t, err)
+	assert.Equal(t, `"role"."name" AS "foo"`, s)
+}
+
+func TestColumnRaw(t *testing.T) {
+	column := Column{Name: &Raw{Value: "role.name As foo"}}
+	s, err := column.Compile(defaultTemplate)
+
assert.NoError(t, err) + assert.Equal(t, `role.name As foo`, s) +} + +func BenchmarkColumnWithName(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = ColumnWithName("a") + } +} + +func BenchmarkColumnHash(b *testing.B) { + c := Column{Name: "name"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkColumnCompile(b *testing.B) { + c := Column{Name: "name"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := Column{Name: "name"} + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithDotCompile(b *testing.B) { + c := Column{Name: "role.name"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithImplicitAsKeywordCompile(b *testing.B) { + c := Column{Name: "role.name foo"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnWithAsKeywordCompile(b *testing.B) { + c := Column{Name: "role.name AS foo"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/column_value.go b/internal/sqladapter/exql/column_value.go new file mode 100644 index 0000000..4d1b303 --- /dev/null +++ b/internal/sqladapter/exql/column_value.go @@ -0,0 +1,112 @@ +package exql + +import ( + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// ColumnValue represents a bundle between a column and a corresponding value. +type ColumnValue struct { + Column Fragment + Operator string + Value Fragment +} + +var _ = Fragment(&ColumnValue{}) + +type columnValueT struct { + Column string + Operator string + Value string +} + +// Hash returns a unique identifier for the struct. +func (c *ColumnValue) Hash() uint64 { + if c == nil { + return cache.NewHash(FragmentType_ColumnValue, nil) + } + return cache.NewHash(FragmentType_ColumnValue, c.Column, c.Operator, c.Value) +} + +// Compile transforms the ColumnValue into an equivalent SQL representation. +func (c *ColumnValue) Compile(layout *Template) (compiled string, err error) { + if z, ok := layout.Read(c); ok { + return z, nil + } + + column, err := c.Column.Compile(layout) + if err != nil { + return "", err + } + + data := columnValueT{ + Column: column, + Operator: c.Operator, + } + + if c.Value != nil { + data.Value, err = c.Value.Compile(layout) + if err != nil { + return "", err + } + } + + compiled = strings.TrimSpace(layout.MustCompile(layout.ColumnValue, data)) + + layout.Write(c, compiled) + + return +} + +// ColumnValues represents an array of ColumnValue +type ColumnValues struct { + ColumnValues []Fragment +} + +var _ = Fragment(&ColumnValues{}) + +// JoinColumnValues returns an array of ColumnValue +func JoinColumnValues(values ...Fragment) *ColumnValues { + return &ColumnValues{ColumnValues: values} +} + +// Insert adds a column to the columns array. +func (c *ColumnValues) Insert(values ...Fragment) *ColumnValues { + c.ColumnValues = append(c.ColumnValues, values...) + return c +} + +// Hash returns a unique identifier for the struct. +func (c *ColumnValues) Hash() uint64 { + h := cache.InitHash(FragmentType_ColumnValues) + for i := range c.ColumnValues { + h = cache.AddToHash(h, c.ColumnValues[i]) + } + return h +} + +// Compile transforms the ColumnValues into its SQL representation. 
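+//
+// For instance, under the default template a hypothetical pair of
+// column-value fragments compiles to a comma-separated list (mirroring the
+// cases in column_value_test.go):
+//
+//	cvs := JoinColumnValues(
+//		&ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)},
+//		&ColumnValue{Column: ColumnWithName("date"), Operator: "=", Value: &Raw{Value: "NOW()"}},
+//	)
+//	s, _ := cvs.Compile(defaultTemplate) // `"id" = '1', "date" = NOW()`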
+func (c *ColumnValues) Compile(layout *Template) (compiled string, err error) { + + if z, ok := layout.Read(c); ok { + return z, nil + } + + l := len(c.ColumnValues) + + out := make([]string, l) + + for i := range c.ColumnValues { + out[i], err = c.ColumnValues[i].Compile(layout) + if err != nil { + return "", err + } + } + + compiled = strings.TrimSpace(strings.Join(out, layout.IdentifierSeparator)) + + layout.Write(c, compiled) + + return +} diff --git a/internal/sqladapter/exql/column_value_test.go b/internal/sqladapter/exql/column_value_test.go new file mode 100644 index 0000000..33ecc36 --- /dev/null +++ b/internal/sqladapter/exql/column_value_test.go @@ -0,0 +1,115 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestColumnValue(t *testing.T) { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + s, err := cv.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"id" = '1'`, s) + + cv = &ColumnValue{Column: ColumnWithName("date"), Operator: "=", Value: &Raw{Value: "NOW()"}} + s, err = cv.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"date" = NOW()`, s) +} + +func TestColumnValues(t *testing.T) { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})}, + ) + + s, err := cvs.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"id" > '8', "other"."id" < 100, "name" = 'Haruki Murakami', "created" >= NOW(), "modified" <= NOW()`, s) +} + +func BenchmarkNewColumnValue(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = &ColumnValue{Column: ColumnWithName("a"), Operator: "=", Value: NewValue(Raw{Value: "7"})} + } +} + +func BenchmarkColumnValueHash(b *testing.B) { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + b.ResetTimer() + for i := 0; i < b.N; i++ { + cv.Hash() + } +} + +func BenchmarkColumnValueCompile(b *testing.B) { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = cv.Compile(defaultTemplate) + } +} + +func BenchmarkColumnValueCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + cv := &ColumnValue{Column: ColumnWithName("id"), Operator: "=", Value: NewValue(1)} + _, _ = cv.Compile(defaultTemplate) + } +} + +func BenchmarkJoinColumnValues(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + &ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(Raw{Value: "100"})}, + &ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")}, + &ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(Raw{Value: "NOW()"})}, + &ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(Raw{Value: "NOW()"})}, + ) + } +} + +func BenchmarkColumnValuesHash(b *testing.B) { + cvs := JoinColumnValues( + &ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)}, + 
&ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})},
+		&ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")},
+		&ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})},
+		&ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})},
+	)
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		cvs.Hash()
+	}
+}
+
+func BenchmarkColumnValuesCompile(b *testing.B) {
+	cvs := JoinColumnValues(
+		&ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)},
+		&ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})},
+		&ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")},
+		&ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})},
+		&ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})},
+	)
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _ = cvs.Compile(defaultTemplate)
+	}
+}
+
+func BenchmarkColumnValuesCompileNoCache(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		cvs := JoinColumnValues(
+			&ColumnValue{Column: ColumnWithName("id"), Operator: ">", Value: NewValue(8)},
+			&ColumnValue{Column: ColumnWithName("other.id"), Operator: "<", Value: NewValue(&Raw{Value: "100"})},
+			&ColumnValue{Column: ColumnWithName("name"), Operator: "=", Value: NewValue("Haruki Murakami")},
+			&ColumnValue{Column: ColumnWithName("created"), Operator: ">=", Value: NewValue(&Raw{Value: "NOW()"})},
+			&ColumnValue{Column: ColumnWithName("modified"), Operator: "<=", Value: NewValue(&Raw{Value: "NOW()"})},
+		)
+		_, _ = cvs.Compile(defaultTemplate)
+	}
}
diff --git a/internal/sqladapter/exql/columns.go b/internal/sqladapter/exql/columns.go
new file mode 100644
index 0000000..d060b4e
--- /dev/null
+++ b/internal/sqladapter/exql/columns.go
@@ -0,0 +1,83 @@
+package exql
+
+import (
+	"strings"
+
+	"git.hexq.cn/tiglog/mydb/internal/cache"
+)
+
+// Columns represents an array of Column.
+type Columns struct {
+	Columns []Fragment
+}
+
+var _ = Fragment(&Columns{})
+
+// Hash returns a unique identifier.
+func (c *Columns) Hash() uint64 {
+	if c == nil {
+		return cache.NewHash(FragmentType_Columns, nil)
+	}
+	h := cache.InitHash(FragmentType_Columns)
+	for i := range c.Columns {
+		h = cache.AddToHash(h, c.Columns[i])
+	}
+	return h
+}
+
+// JoinColumns creates and returns an array of Column.
+func JoinColumns(columns ...Fragment) *Columns {
+	return &Columns{Columns: columns}
+}
+
+// OnConditions creates and returns a new On.
+func OnConditions(conditions ...Fragment) *On {
+	return &On{Conditions: conditions}
+}
+
+// UsingColumns builds a Using from the given columns.
+func UsingColumns(columns ...Fragment) *Using {
+	return &Using{Columns: columns}
+}
+
+// Append adds the columns of a to the end of c.
+func (c *Columns) Append(a *Columns) *Columns {
+	c.Columns = append(c.Columns, a.Columns...)
+	return c
+}
+
+// IsEmpty reports whether c holds no columns.
+func (c *Columns) IsEmpty() bool {
+	if c == nil || len(c.Columns) < 1 {
+		return true
+	}
+	return false
+}
+
+// Compile transforms the Columns into an equivalent SQL representation.
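+//
+// A sketch of the expected output under the default template (mirroring
+// columns_test.go); note that an empty Columns compiles to `*`:
+//
+//	cols := JoinColumns(&Column{Name: "id"}, &Column{Name: "role.name"})
+//	s, _ := cols.Compile(defaultTemplate) // `"id", "role"."name"`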
+func (c *Columns) Compile(layout *Template) (compiled string, err error) { + if z, ok := layout.Read(c); ok { + return z, nil + } + + l := len(c.Columns) + + if l > 0 { + out := make([]string, l) + + for i := 0; i < l; i++ { + out[i], err = c.Columns[i].Compile(layout) + if err != nil { + return "", err + } + } + + compiled = strings.Join(out, layout.IdentifierSeparator) + } else { + compiled = "*" + } + + layout.Write(c, compiled) + + return +} diff --git a/internal/sqladapter/exql/columns_test.go b/internal/sqladapter/exql/columns_test.go new file mode 100644 index 0000000..39cbd5c --- /dev/null +++ b/internal/sqladapter/exql/columns_test.go @@ -0,0 +1,72 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestColumns(t *testing.T) { + columns := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + + s, err := columns.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"id", "customer", "service_id", "role"."name", "role"."id"`, s) +} + +func BenchmarkJoinColumns(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinColumns( + &Column{Name: "a"}, + &Column{Name: "b"}, + &Column{Name: "c"}, + ) + } +} + +func BenchmarkColumnsHash(b *testing.B) { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkColumnsCompile(b *testing.B) { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkColumnsCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := JoinColumns( + &Column{Name: "id"}, + &Column{Name: "customer"}, + &Column{Name: "service_id"}, + &Column{Name: "role.name"}, + &Column{Name: "role.id"}, + ) + _, _ = c.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/database.go b/internal/sqladapter/exql/database.go new file mode 100644 index 0000000..36bf37c --- /dev/null +++ b/internal/sqladapter/exql/database.go @@ -0,0 +1,37 @@ +package exql + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// Database represents a SQL database. +type Database struct { + Name string +} + +var _ = Fragment(&Database{}) + +// DatabaseWithName returns a Database with the given name. +func DatabaseWithName(name string) *Database { + return &Database{Name: name} +} + +// Hash returns a unique identifier for the struct. +func (d *Database) Hash() uint64 { + if d == nil { + return cache.NewHash(FragmentType_Database, nil) + } + return cache.NewHash(FragmentType_Database, d.Name) +} + +// Compile transforms the Database into an equivalent SQL representation. 
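+//
+// The name is simply wrapped in the template's identifier quote; for
+// example (as in database_test.go):
+//
+//	d := DatabaseWithName("name")
+//	s, _ := d.Compile(defaultTemplate) // `"name"`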
+func (d *Database) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(d); ok { + return c, nil + } + + compiled = layout.MustCompile(layout.IdentifierQuote, Raw{Value: d.Name}) + + layout.Write(d, compiled) + return +} diff --git a/internal/sqladapter/exql/database_test.go b/internal/sqladapter/exql/database_test.go new file mode 100644 index 0000000..aba55be --- /dev/null +++ b/internal/sqladapter/exql/database_test.go @@ -0,0 +1,45 @@ +package exql + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestDatabaseCompile(t *testing.T) { + column := Database{Name: "name"} + s, err := column.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `"name"`, s) +} + +func BenchmarkDatabaseHash(b *testing.B) { + c := Database{Name: "name"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.Hash() + } +} + +func BenchmarkDatabaseCompile(b *testing.B) { + c := Database{Name: "name"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkDatabaseCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + c := Database{Name: "name"} + _, _ = c.Compile(defaultTemplate) + } +} + +func BenchmarkDatabaseCompileNoCache2(b *testing.B) { + for i := 0; i < b.N; i++ { + c := Database{Name: strconv.Itoa(i)} + _, _ = c.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/default.go b/internal/sqladapter/exql/default.go new file mode 100644 index 0000000..12b4354 --- /dev/null +++ b/internal/sqladapter/exql/default.go @@ -0,0 +1,192 @@ +package exql + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +const ( + defaultColumnSeparator = `.` + defaultIdentifierSeparator = `, ` + defaultIdentifierQuote = `"{{.Value}}"` + defaultValueSeparator = `, ` + defaultValueQuote = `'{{.}}'` + defaultAndKeyword = `AND` + defaultOrKeyword = `OR` + defaultDescKeyword = `DESC` + defaultAscKeyword = `ASC` + defaultAssignmentOperator = `=` + defaultClauseGroup = `({{.}})` + defaultClauseOperator = ` {{.}} ` + defaultColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + defaultTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + defaultColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + defaultSortByColumnLayout = `{{.Column}} {{.Order}}` + + defaultOrderByLayout = ` + {{if .SortColumns}} + ORDER BY {{.SortColumns}} + {{end}} + ` + + defaultWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + + defaultUsingLayout = ` + {{if .Columns}} + USING ({{.Columns}}) + {{end}} + ` + + defaultJoinLayout = ` + {{if .Table}} + {{ if .On }} + {{.Type}} JOIN {{.Table}} + {{.On}} + {{ else if .Using }} + {{.Type}} JOIN {{.Table}} + {{.Using}} + {{ else if .Type | eq "CROSS" }} + {{.Type}} JOIN {{.Table}} + {{else}} + NATURAL {{.Type}} JOIN {{.Table}} + {{end}} + {{end}} + ` + + defaultOnLayout = ` + {{if .Conds}} + ON {{.Conds}} + {{end}} + ` + + defaultSelectLayout = ` + SELECT + {{if .Distinct}} + DISTINCT + {{end}} + + {{if .Columns}} + {{.Columns | compile}} + {{else}} + * + {{end}} + + {{if defined .Table}} + FROM {{.Table | compile}} + {{end}} + + {{.Joins | compile}} + + {{.Where | compile}} + + {{.GroupBy | compile}} + + {{.OrderBy | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + defaultDeleteLayout = ` + DELETE + FROM {{.Table | compile}} + {{.Where | compile}} + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + 
defaultUpdateLayout = ` + UPDATE + {{.Table | compile}} + SET {{.ColumnValues | compile}} + {{.Where | compile}} + ` + + defaultCountLayout = ` + SELECT + COUNT(1) AS _t + FROM {{.Table | compile}} + {{.Where | compile}} + + {{if .Limit}} + LIMIT {{.Limit | compile}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + + defaultInsertLayout = ` + INSERT INTO {{.Table | compile}} + {{if .Columns }}({{.Columns | compile}}){{end}} + VALUES + {{.Values | compile}} + {{if .Returning}} + RETURNING {{.Returning | compile}} + {{end}} + ` + + defaultTruncateLayout = ` + TRUNCATE TABLE {{.Table | compile}} + ` + + defaultDropDatabaseLayout = ` + DROP DATABASE {{.Database | compile}} + ` + + defaultDropTableLayout = ` + DROP TABLE {{.Table | compile}} + ` + + defaultGroupByLayout = ` + {{if .GroupColumns}} + GROUP BY {{.GroupColumns}} + {{end}} + ` +) + +var defaultTemplate = &Template{ + AndKeyword: defaultAndKeyword, + AscKeyword: defaultAscKeyword, + AssignmentOperator: defaultAssignmentOperator, + ClauseGroup: defaultClauseGroup, + ClauseOperator: defaultClauseOperator, + ColumnAliasLayout: defaultColumnAliasLayout, + ColumnSeparator: defaultColumnSeparator, + ColumnValue: defaultColumnValue, + CountLayout: defaultCountLayout, + DeleteLayout: defaultDeleteLayout, + DescKeyword: defaultDescKeyword, + DropDatabaseLayout: defaultDropDatabaseLayout, + DropTableLayout: defaultDropTableLayout, + GroupByLayout: defaultGroupByLayout, + IdentifierQuote: defaultIdentifierQuote, + IdentifierSeparator: defaultIdentifierSeparator, + InsertLayout: defaultInsertLayout, + JoinLayout: defaultJoinLayout, + OnLayout: defaultOnLayout, + OrKeyword: defaultOrKeyword, + OrderByLayout: defaultOrderByLayout, + SelectLayout: defaultSelectLayout, + SortByColumnLayout: defaultSortByColumnLayout, + TableAliasLayout: defaultTableAliasLayout, + TruncateLayout: defaultTruncateLayout, + UpdateLayout: defaultUpdateLayout, + UsingLayout: defaultUsingLayout, + ValueQuote: defaultValueQuote, + ValueSeparator: defaultValueSeparator, + WhereLayout: defaultWhereLayout, + + Cache: cache.NewCache(), +} diff --git a/internal/sqladapter/exql/errors.go b/internal/sqladapter/exql/errors.go new file mode 100644 index 0000000..b9c8b85 --- /dev/null +++ b/internal/sqladapter/exql/errors.go @@ -0,0 +1,5 @@ +package exql + +const ( + errExpectingHashableFmt = "expecting hashable value, got %T" +) diff --git a/internal/sqladapter/exql/group_by.go b/internal/sqladapter/exql/group_by.go new file mode 100644 index 0000000..6583d44 --- /dev/null +++ b/internal/sqladapter/exql/group_by.go @@ -0,0 +1,60 @@ +package exql + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// GroupBy represents a SQL's "group by" statement. +type GroupBy struct { + Columns Fragment +} + +var _ = Fragment(&GroupBy{}) + +type groupByT struct { + GroupColumns string +} + +// Hash returns a unique identifier. +func (g *GroupBy) Hash() uint64 { + if g == nil { + return cache.NewHash(FragmentType_GroupBy, nil) + } + return cache.NewHash(FragmentType_GroupBy, g.Columns) +} + +// GroupByColumns creates and returns a GroupBy with the given column. +func GroupByColumns(columns ...Fragment) *GroupBy { + return &GroupBy{Columns: JoinColumns(columns...)} +} + +func (g *GroupBy) IsEmpty() bool { + if g == nil || g.Columns == nil { + return true + } + return g.Columns.(hasIsEmpty).IsEmpty() +} + +// Compile transforms the GroupBy into an equivalent SQL representation. 
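+//
+// For example, under the default template (mirroring group_by_test.go; the
+// layout adds surrounding whitespace, trimmed here for readability):
+//
+//	g := GroupByColumns(&Column{Name: "id"}, &Column{Name: "customer"})
+//	s, _ := g.Compile(defaultTemplate) // GROUP BY "id", "customer"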
+func (g *GroupBy) Compile(layout *Template) (compiled string, err error) {
+
+	if c, ok := layout.Read(g); ok {
+		return c, nil
+	}
+
+	if g.Columns != nil {
+		columns, err := g.Columns.Compile(layout)
+		if err != nil {
+			return "", err
+		}
+
+		data := groupByT{
+			GroupColumns: columns,
+		}
+		compiled = layout.MustCompile(layout.GroupByLayout, data)
+	}
+
+	layout.Write(g, compiled)
+
+	return
+}
diff --git a/internal/sqladapter/exql/group_by_test.go b/internal/sqladapter/exql/group_by_test.go
new file mode 100644
index 0000000..cdc1e6f
--- /dev/null
+++ b/internal/sqladapter/exql/group_by_test.go
@@ -0,0 +1,71 @@
+package exql
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestGroupBy(t *testing.T) {
+	columns := GroupByColumns(
+		&Column{Name: "id"},
+		&Column{Name: "customer"},
+		&Column{Name: "service_id"},
+		&Column{Name: "role.name"},
+		&Column{Name: "role.id"},
+	)
+
+	s := mustTrim(columns.Compile(defaultTemplate))
+	assert.Equal(t, `GROUP BY "id", "customer", "service_id", "role"."name", "role"."id"`, s)
+}
+
+func BenchmarkGroupByColumns(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		_ = GroupByColumns(
+			&Column{Name: "a"},
+			&Column{Name: "b"},
+			&Column{Name: "c"},
+		)
+	}
+}
+
+func BenchmarkGroupByHash(b *testing.B) {
+	c := GroupByColumns(
+		&Column{Name: "id"},
+		&Column{Name: "customer"},
+		&Column{Name: "service_id"},
+		&Column{Name: "role.name"},
+		&Column{Name: "role.id"},
+	)
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		c.Hash()
+	}
+}
+
+func BenchmarkGroupByCompile(b *testing.B) {
+	c := GroupByColumns(
+		&Column{Name: "id"},
+		&Column{Name: "customer"},
+		&Column{Name: "service_id"},
+		&Column{Name: "role.name"},
+		&Column{Name: "role.id"},
+	)
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _ = c.Compile(defaultTemplate)
+	}
+}
+
+func BenchmarkGroupByCompileNoCache(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		c := GroupByColumns(
+			&Column{Name: "id"},
+			&Column{Name: "customer"},
+			&Column{Name: "service_id"},
+			&Column{Name: "role.name"},
+			&Column{Name: "role.id"},
+		)
+		_, _ = c.Compile(defaultTemplate)
+	}
+}
diff --git a/internal/sqladapter/exql/interfaces.go b/internal/sqladapter/exql/interfaces.go
new file mode 100644
index 0000000..efa34ed
--- /dev/null
+++ b/internal/sqladapter/exql/interfaces.go
@@ -0,0 +1,20 @@
+package exql
+
+import (
+	"git.hexq.cn/tiglog/mydb/internal/cache"
+)
+
+// Fragment is any type that can be both cached and compiled.
+type Fragment interface {
+	cache.Hashable
+
+	compilable
+}
+
+type compilable interface {
+	Compile(*Template) (string, error)
+}
+
+type hasIsEmpty interface {
+	IsEmpty() bool
+}
diff --git a/internal/sqladapter/exql/join.go b/internal/sqladapter/exql/join.go
new file mode 100644
index 0000000..eed8f5b
--- /dev/null
+++ b/internal/sqladapter/exql/join.go
@@ -0,0 +1,195 @@
+package exql
+
+import (
+	"strings"
+
+	"git.hexq.cn/tiglog/mydb/internal/cache"
+)
+
+type innerJoinT struct {
+	Type  string
+	Table string
+	On    string
+	Using string
+}
+
+// Joins represents the union of different join conditions.
+type Joins struct {
+	Conditions []Fragment
+}
+
+var _ = Fragment(&Joins{})
+
+// Hash returns a unique identifier for the struct.
+func (j *Joins) Hash() uint64 {
+	if j == nil {
+		return cache.NewHash(FragmentType_Joins, nil)
+	}
+	h := cache.InitHash(FragmentType_Joins)
+	for i := range j.Conditions {
+		h = cache.AddToHash(h, j.Conditions[i])
+	}
+	return h
+}
+
+// Compile transforms the Joins into an equivalent SQL representation.
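+//
+// The individual joins are compiled and concatenated with spaces; for
+// example (as in the TestMultipleJoins case of join_test.go, after
+// trimming):
+//
+//	j := JoinConditions(
+//		&Join{Type: "LEFT", Table: TableWithName("countries")},
+//		&Join{Table: TableWithName("cities")},
+//	)
+//	s, _ := j.Compile(defaultTemplate) // NATURAL LEFT JOIN "countries" NATURAL JOIN "cities"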
+func (j *Joins) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(j); ok { + return c, nil + } + + l := len(j.Conditions) + + chunks := make([]string, 0, l) + + if l > 0 { + for i := 0; i < l; i++ { + chunk, err := j.Conditions[i].Compile(layout) + if err != nil { + return "", err + } + chunks = append(chunks, chunk) + } + } + + compiled = strings.Join(chunks, " ") + + layout.Write(j, compiled) + + return +} + +// JoinConditions creates a Joins object. +func JoinConditions(joins ...*Join) *Joins { + fragments := make([]Fragment, len(joins)) + for i := range fragments { + fragments[i] = joins[i] + } + return &Joins{Conditions: fragments} +} + +// Join represents a generic JOIN statement. +type Join struct { + Type string + Table Fragment + On Fragment + Using Fragment +} + +var _ = Fragment(&Join{}) + +// Hash returns a unique identifier for the struct. +func (j *Join) Hash() uint64 { + if j == nil { + return cache.NewHash(FragmentType_Join, nil) + } + return cache.NewHash(FragmentType_Join, j.Type, j.Table, j.On, j.Using) +} + +// Compile transforms the Join into its equivalent SQL representation. +func (j *Join) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(j); ok { + return c, nil + } + + if j.Table == nil { + return "", nil + } + + table, err := j.Table.Compile(layout) + if err != nil { + return "", err + } + + on, err := layout.doCompile(j.On) + if err != nil { + return "", err + } + + using, err := layout.doCompile(j.Using) + if err != nil { + return "", err + } + + data := innerJoinT{ + Type: j.Type, + Table: table, + On: on, + Using: using, + } + + compiled = layout.MustCompile(layout.JoinLayout, data) + layout.Write(j, compiled) + return +} + +// On represents JOIN conditions. +type On Where + +var _ = Fragment(&On{}) + +func (o *On) Hash() uint64 { + if o == nil { + return cache.NewHash(FragmentType_On, nil) + } + return cache.NewHash(FragmentType_On, (*Where)(o)) +} + +// Compile transforms the On into an equivalent SQL representation. +func (o *On) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(o); ok { + return c, nil + } + + grouped, err := groupCondition(layout, o.Conditions, layout.MustCompile(layout.ClauseOperator, layout.AndKeyword)) + if err != nil { + return "", err + } + + if grouped != "" { + compiled = layout.MustCompile(layout.OnLayout, conds{grouped}) + } + + layout.Write(o, compiled) + return +} + +// Using represents a USING function. +type Using Columns + +var _ = Fragment(&Using{}) + +type usingT struct { + Columns string +} + +func (u *Using) Hash() uint64 { + if u == nil { + return cache.NewHash(FragmentType_Using, nil) + } + return cache.NewHash(FragmentType_Using, (*Columns)(u)) +} + +// Compile transforms the Using into an equivalent SQL representation. 
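+//
+// For example (as in the TestUsing case of join_test.go, after trimming):
+//
+//	u := UsingColumns(&Column{Name: "country"}, &Column{Name: "state"})
+//	s, _ := u.Compile(defaultTemplate) // USING ("country", "state")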
+func (u *Using) Compile(layout *Template) (compiled string, err error) { + if u == nil { + return "", nil + } + + if c, ok := layout.Read(u); ok { + return c, nil + } + + if len(u.Columns) > 0 { + c := Columns(*u) + columns, err := c.Compile(layout) + if err != nil { + return "", err + } + data := usingT{Columns: columns} + compiled = layout.MustCompile(layout.UsingLayout, data) + } + + layout.Write(u, compiled) + return +} diff --git a/internal/sqladapter/exql/join_test.go b/internal/sqladapter/exql/join_test.go new file mode 100644 index 0000000..65ce6aa --- /dev/null +++ b/internal/sqladapter/exql/join_test.go @@ -0,0 +1,221 @@ +package exql + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOnAndRawOrAnd(t *testing.T) { + on := OnConditions( + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + ), + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + &Raw{Value: "city_id = 728"}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "age"}, Operator: ">", Value: NewValue(&Raw{Value: "18"})}, + &ColumnValue{Column: &Column{Name: "age"}, Operator: "<", Value: NewValue(&Raw{Value: "41"})}, + ), + ) + + s := mustTrim(on.Compile(defaultTemplate)) + assert.Equal(t, `ON (("id" > 8 AND "id" < 99) AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))`, s) +} + +func TestUsing(t *testing.T) { + using := UsingColumns( + &Column{Name: "country"}, + &Column{Name: "state"}, + ) + + s := mustTrim(using.Compile(defaultTemplate)) + assert.Equal(t, `USING ("country", "state")`, s) +} + +func TestJoinOn(t *testing.T) { + join := JoinConditions( + &Join{ + Table: TableWithName("countries c"), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.id"}), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.code"}), + }, + ), + }, + ) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `JOIN "countries" AS "c" ON ("p"."country_id" = "a"."id" AND "p"."country_code" = "a"."code")`, s) +} + +func TestInnerJoinOn(t *testing.T) { + join := JoinConditions(&Join{ + Type: "INNER", + Table: TableWithName("countries c"), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(ColumnWithName("a.id")), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(ColumnWithName("a.code")), + }, + ), + }) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `INNER JOIN "countries" AS "c" ON ("p"."country_id" = "a"."id" AND "p"."country_code" = "a"."code")`, s) +} + +func TestLeftJoinUsing(t *testing.T) { + join := JoinConditions(&Join{ + Type: "LEFT", + Table: TableWithName("countries"), + Using: UsingColumns(ColumnWithName("name")), + }) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `LEFT JOIN "countries" USING ("name")`, s) +} + +func TestNaturalJoinOn(t *testing.T) { + join := JoinConditions(&Join{ + Table: TableWithName("countries"), + }) + + s := 
mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `NATURAL JOIN "countries"`, s) +} + +func TestNaturalInnerJoinOn(t *testing.T) { + join := JoinConditions(&Join{ + Type: "INNER", + Table: TableWithName("countries"), + }) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `NATURAL INNER JOIN "countries"`, s) +} + +func TestCrossJoin(t *testing.T) { + join := JoinConditions(&Join{ + Type: "CROSS", + Table: TableWithName("countries"), + }) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `CROSS JOIN "countries"`, s) +} + +func TestMultipleJoins(t *testing.T) { + join := JoinConditions(&Join{ + Type: "LEFT", + Table: TableWithName("countries"), + }, &Join{ + Table: TableWithName("cities"), + }) + + s := mustTrim(join.Compile(defaultTemplate)) + assert.Equal(t, `NATURAL LEFT JOIN "countries" NATURAL JOIN "cities"`, s) +} + +func BenchmarkJoin(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = JoinConditions(&Join{ + Table: TableWithName("countries c"), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.id"}), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.code"}), + }, + ), + }) + } +} + +func BenchmarkCompileJoin(b *testing.B) { + j := JoinConditions(&Join{ + Table: TableWithName("countries c"), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.id"}), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.code"}), + }, + ), + }) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = j.Compile(defaultTemplate) + } +} + +func BenchmarkCompileJoinNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + j := JoinConditions(&Join{ + Table: TableWithName("countries c"), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.id"}), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.code"}), + }, + ), + }) + _, _ = j.Compile(defaultTemplate) + } +} + +func BenchmarkCompileJoinNoCache2(b *testing.B) { + for i := 0; i < b.N; i++ { + j := JoinConditions(&Join{ + Table: TableWithName(fmt.Sprintf("countries c%d", i)), + On: OnConditions( + &ColumnValue{ + Column: &Column{Name: "p.country_id"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.id"}), + }, + &ColumnValue{ + Column: &Column{Name: "p.country_code"}, + Operator: "=", + Value: NewValue(&Column{Name: "a.code"}), + }, + ), + }) + _, _ = j.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/order_by.go b/internal/sqladapter/exql/order_by.go new file mode 100644 index 0000000..d41a48a --- /dev/null +++ b/internal/sqladapter/exql/order_by.go @@ -0,0 +1,175 @@ +package exql + +import ( + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// Order represents the order in which SQL results are sorted. +type Order uint8 + +// Possible values for Order +const ( + Order_Default Order = iota + + Order_Ascendent + Order_Descendent +) + +func (o Order) Hash() uint64 { + return cache.NewHash(FragmentType_Order, uint8(o)) +} + +// SortColumn represents the column-order relation in an ORDER BY clause. 
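+//
+// For example, a hypothetical descending sort on "foo" compiles to
+// `"foo" DESC` under the default template (see TestOrderByDesc in
+// order_by_test.go):
+//
+//	sc := &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Descendent}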
+type SortColumn struct {
+	Column Fragment
+	Order
+}
+
+var _ = Fragment(&SortColumn{})
+
+type sortColumnT struct {
+	Column string
+	Order  string
+}
+
+// SortColumns represents the columns in an ORDER BY clause.
+type SortColumns struct {
+	Columns []Fragment
+}
+
+var _ = Fragment(&SortColumns{})
+
+// OrderBy represents an ORDER BY clause.
+type OrderBy struct {
+	SortColumns Fragment
+}
+
+var _ = Fragment(&OrderBy{})
+
+type orderByT struct {
+	SortColumns string
+}
+
+// JoinSortColumns creates and returns an array of column-order relations.
+func JoinSortColumns(values ...Fragment) *SortColumns {
+	return &SortColumns{Columns: values}
+}
+
+// JoinWithOrderBy creates and returns an OrderBy using the given SortColumns.
+func JoinWithOrderBy(sc *SortColumns) *OrderBy {
+	return &OrderBy{SortColumns: sc}
+}
+
+// Hash returns a unique identifier for the struct.
+func (s *SortColumn) Hash() uint64 {
+	if s == nil {
+		return cache.NewHash(FragmentType_SortColumn, nil)
+	}
+	return cache.NewHash(FragmentType_SortColumn, s.Column, s.Order)
+}
+
+// Compile transforms the SortColumn into an equivalent SQL representation.
+func (s *SortColumn) Compile(layout *Template) (compiled string, err error) {
+
+	if c, ok := layout.Read(s); ok {
+		return c, nil
+	}
+
+	column, err := s.Column.Compile(layout)
+	if err != nil {
+		return "", err
+	}
+
+	orderBy, err := s.Order.Compile(layout)
+	if err != nil {
+		return "", err
+	}
+
+	data := sortColumnT{Column: column, Order: orderBy}
+
+	compiled = layout.MustCompile(layout.SortByColumnLayout, data)
+
+	layout.Write(s, compiled)
+
+	return
+}
+
+// Hash returns a unique identifier for the struct.
+func (s *SortColumns) Hash() uint64 {
+	if s == nil {
+		return cache.NewHash(FragmentType_SortColumns, nil)
+	}
+	h := cache.InitHash(FragmentType_SortColumns)
+	for i := range s.Columns {
+		h = cache.AddToHash(h, s.Columns[i])
+	}
+	return h
+}
+
+// Compile transforms the SortColumns into an equivalent SQL representation.
+func (s *SortColumns) Compile(layout *Template) (compiled string, err error) {
+	if z, ok := layout.Read(s); ok {
+		return z, nil
+	}
+
+	z := make([]string, len(s.Columns))
+
+	for i := range s.Columns {
+		z[i], err = s.Columns[i].Compile(layout)
+		if err != nil {
+			return "", err
+		}
+	}
+
+	compiled = strings.Join(z, layout.IdentifierSeparator)
+
+	layout.Write(s, compiled)
+
+	return
+}
+
+// Hash returns a unique identifier for the struct.
+func (s *OrderBy) Hash() uint64 {
+	if s == nil {
+		return cache.NewHash(FragmentType_OrderBy, nil)
+	}
+	return cache.NewHash(FragmentType_OrderBy, s.SortColumns)
+}
+
+// Compile transforms the OrderBy into an equivalent SQL representation.
+func (s *OrderBy) Compile(layout *Template) (compiled string, err error) {
+	if z, ok := layout.Read(s); ok {
+		return z, nil
+	}
+
+	if s.SortColumns != nil {
+		sortColumns, err := s.SortColumns.Compile(layout)
+		if err != nil {
+			return "", err
+		}
+
+		data := orderByT{
+			SortColumns: sortColumns,
+		}
+		compiled = layout.MustCompile(layout.OrderByLayout, data)
+	}
+
+	layout.Write(s, compiled)
+
+	return
+}
+
+// Compile transforms the Order into an equivalent SQL representation.
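+//
+// Order_Ascendent maps to the template's ASC keyword, Order_Descendent to
+// DESC, and Order_Default to an empty string; for example:
+//
+//	s, _ := Order_Descendent.Compile(defaultTemplate) // `DESC`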
+func (s Order) Compile(layout *Template) (string, error) { + switch s { + case Order_Ascendent: + return layout.AscKeyword, nil + case Order_Descendent: + return layout.DescKeyword, nil + } + return "", nil +} diff --git a/internal/sqladapter/exql/order_by_test.go b/internal/sqladapter/exql/order_by_test.go new file mode 100644 index 0000000..f214f86 --- /dev/null +++ b/internal/sqladapter/exql/order_by_test.go @@ -0,0 +1,154 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestOrderBy(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + + s := mustTrim(o.Compile(defaultTemplate)) + assert.Equal(t, `ORDER BY "foo"`, s) +} + +func TestOrderByRaw(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Raw{Value: "CASE WHEN id IN ? THEN 0 ELSE 1 END"}}, + ), + ) + + s := mustTrim(o.Compile(defaultTemplate)) + assert.Equal(t, `ORDER BY CASE WHEN id IN ? THEN 0 ELSE 1 END`, s) +} + +func TestOrderByDesc(t *testing.T) { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Descendent}, + ), + ) + + s := mustTrim(o.Compile(defaultTemplate)) + assert.Equal(t, `ORDER BY "foo" DESC`, s) +} + +func BenchmarkOrderBy(b *testing.B) { + for i := 0; i < b.N; i++ { + JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + } +} + +func BenchmarkOrderByHash(b *testing.B) { + o := OrderBy{ + SortColumns: JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + o.Hash() + } +} + +func BenchmarkCompileOrderByCompile(b *testing.B) { + o := OrderBy{ + SortColumns: JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + } + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderByCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + o := JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ) + _, _ = o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderCompile(b *testing.B) { + o := Order_Descendent + for i := 0; i < b.N; i++ { + _, _ = o.Compile(defaultTemplate) + } +} + +func BenchmarkCompileOrderCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + o := Order_Descendent + _, _ = o.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnHash(b *testing.B) { + s := &SortColumn{Column: &Column{Name: "foo"}} + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Hash() + } +} + +func BenchmarkSortColumnCompile(b *testing.B) { + s := &SortColumn{Column: &Column{Name: "foo"}} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + s := &SortColumn{Column: &Column{Name: "foo"}} + _, _ = s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnsHash(b *testing.B) { + s := JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + s.Hash() + } +} + +func BenchmarkSortColumnsCompile(b *testing.B) { + s := JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = s.Compile(defaultTemplate) + } +} + +func BenchmarkSortColumnsCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + s := 
JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + &SortColumn{Column: &Column{Name: "bar"}}, + ) + _, _ = s.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/raw.go b/internal/sqladapter/exql/raw.go new file mode 100644 index 0000000..808f6e6 --- /dev/null +++ b/internal/sqladapter/exql/raw.go @@ -0,0 +1,48 @@ +package exql + +import ( + "fmt" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +var ( + _ = fmt.Stringer(&Raw{}) +) + +// Raw represents a value that is meant to be used in a query without escaping. +type Raw struct { + Value string +} + +func NewRawValue(v interface{}) (*Raw, error) { + switch t := v.(type) { + case string: + return &Raw{Value: t}, nil + case int, uint, int64, uint64, int32, uint32, int16, uint16: + return &Raw{Value: fmt.Sprintf("%d", t)}, nil + case fmt.Stringer: + return &Raw{Value: t.String()}, nil + } + return nil, fmt.Errorf("unexpected type: %T", v) +} + +// Hash returns a unique identifier for the struct. +func (r *Raw) Hash() uint64 { + if r == nil { + return cache.NewHash(FragmentType_Raw, nil) + } + return cache.NewHash(FragmentType_Raw, r.Value) +} + +// Compile returns the raw value. +func (r *Raw) Compile(*Template) (string, error) { + return r.Value, nil +} + +// String returns the raw value. +func (r *Raw) String() string { + return r.Value +} + +var _ = Fragment(&Raw{}) diff --git a/internal/sqladapter/exql/raw_test.go b/internal/sqladapter/exql/raw_test.go new file mode 100644 index 0000000..66e38b1 --- /dev/null +++ b/internal/sqladapter/exql/raw_test.go @@ -0,0 +1,51 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestRawString(t *testing.T) { + raw := &Raw{Value: "foo"} + s, err := raw.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `foo`, s) +} + +func TestRawCompile(t *testing.T) { + raw := &Raw{Value: "foo"} + s, err := raw.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `foo`, s) +} + +func BenchmarkRawCreate(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = Raw{Value: "foo"} + } +} + +func BenchmarkRawString(b *testing.B) { + raw := &Raw{Value: "foo"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = raw.String() + } +} + +func BenchmarkRawCompile(b *testing.B) { + raw := &Raw{Value: "foo"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = raw.Compile(defaultTemplate) + } +} + +func BenchmarkRawHash(b *testing.B) { + raw := &Raw{Value: "foo"} + b.ResetTimer() + for i := 0; i < b.N; i++ { + raw.Hash() + } +} diff --git a/internal/sqladapter/exql/returning.go b/internal/sqladapter/exql/returning.go new file mode 100644 index 0000000..9b882dc --- /dev/null +++ b/internal/sqladapter/exql/returning.go @@ -0,0 +1,41 @@ +package exql + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// Returning represents a RETURNING clause. +type Returning struct { + *Columns +} + +// Hash returns a unique identifier for the struct. +func (r *Returning) Hash() uint64 { + if r == nil { + return cache.NewHash(FragmentType_Returning, nil) + } + return cache.NewHash(FragmentType_Returning, r.Columns) +} + +var _ = Fragment(&Returning{}) + +// ReturningColumns creates and returns an array of Column. +func ReturningColumns(columns ...Fragment) *Returning { + return &Returning{Columns: &Columns{Columns: columns}} +} + +// Compile transforms the clause into its equivalent SQL representation. 
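+//
+// Note that only the column list is produced here; the RETURNING keyword
+// itself comes from the statement layout (see defaultInsertLayout). For
+// example:
+//
+//	r := ReturningColumns(ColumnWithName("id"))
+//	s, _ := r.Compile(defaultTemplate) // `"id"`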
+func (r *Returning) Compile(layout *Template) (compiled string, err error) {
+	if z, ok := layout.Read(r); ok {
+		return z, nil
+	}
+
+	compiled, err = r.Columns.Compile(layout)
+	if err != nil {
+		return "", err
+	}
+
+	layout.Write(r, compiled)
+
+	return
+}
diff --git a/internal/sqladapter/exql/statement.go b/internal/sqladapter/exql/statement.go
new file mode 100644
index 0000000..4351682
--- /dev/null
+++ b/internal/sqladapter/exql/statement.go
@@ -0,0 +1,132 @@
+package exql
+
+import (
+	"errors"
+	"reflect"
+	"strings"
+
+	"git.hexq.cn/tiglog/mydb/internal/cache"
+)
+
+var errUnknownTemplateType = errors.New("unknown template type")
+
+// Statement represents different kinds of SQL statements.
+type Statement struct {
+	Type
+	Table        Fragment
+	Database     Fragment
+	Columns      Fragment
+	Values       Fragment
+	Distinct     bool
+	ColumnValues Fragment
+	OrderBy      Fragment
+	GroupBy      Fragment
+	Joins        Fragment
+	Where        Fragment
+	Returning    Fragment
+
+	Limit
+	Offset
+
+	SQL string
+
+	amendFn func(string) string
+}
+
+func (layout *Template) doCompile(c Fragment) (string, error) {
+	if c != nil && !reflect.ValueOf(c).IsNil() {
+		return c.Compile(layout)
+	}
+	return "", nil
+}
+
+// Hash returns a unique identifier for the struct.
+func (s *Statement) Hash() uint64 {
+	if s == nil {
+		return cache.NewHash(FragmentType_Statement, nil)
+	}
+	return cache.NewHash(
+		FragmentType_Statement,
+		s.Type,
+		s.Table,
+		s.Database,
+		s.Columns,
+		s.Values,
+		s.Distinct,
+		s.ColumnValues,
+		s.OrderBy,
+		s.GroupBy,
+		s.Joins,
+		s.Where,
+		s.Returning,
+		s.Limit,
+		s.Offset,
+		s.SQL,
+	)
+}
+
+func (s *Statement) SetAmendment(amendFn func(string) string) {
+	s.amendFn = amendFn
+}
+
+func (s *Statement) Amend(in string) string {
+	if s.amendFn == nil {
+		return in
+	}
+	return s.amendFn(in)
+}
+
+func (s *Statement) template(layout *Template) (string, error) {
+	switch s.Type {
+	case Truncate:
+		return layout.TruncateLayout, nil
+	case DropTable:
+		return layout.DropTableLayout, nil
+	case DropDatabase:
+		return layout.DropDatabaseLayout, nil
+	case Count:
+		return layout.CountLayout, nil
+	case Select:
+		return layout.SelectLayout, nil
+	case Delete:
+		return layout.DeleteLayout, nil
+	case Update:
+		return layout.UpdateLayout, nil
+	case Insert:
+		return layout.InsertLayout, nil
+	default:
+		return "", errUnknownTemplateType
+	}
+}
+
+// Compile transforms the Statement into an equivalent SQL query.
+func (s *Statement) Compile(layout *Template) (compiled string, err error) {
+	if s.Type == SQL {
+		// No need to hit the cache.
+		return s.SQL, nil
+	}
+
+	if z, ok := layout.Read(s); ok {
+		return s.Amend(z), nil
+	}
+
+	tpl, err := s.template(layout)
+	if err != nil {
+		return "", err
+	}
+
+	compiled = layout.MustCompile(tpl, s)
+
+	compiled = strings.TrimSpace(compiled)
+	layout.Write(s, compiled)
+
+	return s.Amend(compiled), nil
+}
+
+// RawSQL represents a raw SQL statement.
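+// Statements of type SQL bypass the template compiler and the fragment cache;
+// Compile returns the stored string verbatim. A minimal sketch, assuming the
+// defaultTemplate used by the tests:
+//
+//	stmt := RawSQL(`SELECT 1`)
+//	s, _ := stmt.Compile(defaultTemplate) // s == "SELECT 1"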
+func RawSQL(s string) *Statement { + return &Statement{ + Type: SQL, + SQL: s, + } +} diff --git a/internal/sqladapter/exql/statement_test.go b/internal/sqladapter/exql/statement_test.go new file mode 100644 index 0000000..28e726a --- /dev/null +++ b/internal/sqladapter/exql/statement_test.go @@ -0,0 +1,703 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestTruncateTable(t *testing.T) { + stmt := Statement{ + Type: Truncate, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `TRUNCATE TABLE "table_name"`, s) +} + +func TestDropTable(t *testing.T) { + stmt := Statement{ + Type: DropTable, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `DROP TABLE "table_name"`, s) +} + +func TestDropDatabase(t *testing.T) { + stmt := Statement{ + Type: DropDatabase, + Database: &Database{Name: "table_name"}, + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `DROP DATABASE "table_name"`, s) +} + +func TestCount(t *testing.T) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT COUNT(1) AS _t FROM "table_name"`, s) +} + +func TestCountRelation(t *testing.T) { + stmt := Statement{ + Type: Count, + Table: TableWithName("information_schema.tables"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT COUNT(1) AS _t FROM "information_schema"."tables"`, s) +} + +func TestCountWhere(t *testing.T) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: &Raw{Value: "7"}}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT COUNT(1) AS _t FROM "table_name" WHERE ("a" = 7)`, s) +} + +func TestSelectStarFrom(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table_name"`, s) +} + +func TestSelectStarFromAlias(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table"."name" AS "foo"`, s) +} + +func TestSelectStarFromRawWhere(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + Where: WhereConditions( + &Raw{Value: "foo.id = bar.foo_id"}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id)`, s) + } + + { + stmt := Statement{ + Type: Select, + Table: TableWithName("table.name AS foo"), + Where: WhereConditions( + &Raw{Value: "foo.id = bar.foo_id"}, + &Raw{Value: "baz.id = exp.baz_id"}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "table"."name" AS "foo" WHERE (foo.id = bar.foo_id AND baz.id = exp.baz_id)`, s) + } +} + +func TestSelectStarFromMany(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("first.table AS foo, second.table as BAR, third.table aS baz"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"`, s) +} + +func TestSelectTableStarFromMany(t *testing.T) { + stmt := Statement{ + Type: 
Select, + Columns: JoinColumns( + &Column{Name: "foo.name"}, + &Column{Name: "BAR.*"}, + &Column{Name: "baz.last_name"}, + ), + Table: TableWithName("first.table AS foo, second.table as BAR, third.table aS baz"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo"."name", "BAR".*, "baz"."last_name" FROM "first"."table" AS "foo", "second"."table" AS "BAR", "third"."table" AS "baz"`, s) +} + +func TestSelectArtistNameFrom(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("artist"), + Columns: JoinColumns( + &Column{Name: "artist.name"}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "artist"."name" FROM "artist"`, s) +} + +func TestSelectJoin(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("artist a"), + Columns: JoinColumns( + &Column{Name: "a.name"}, + ), + Joins: JoinConditions(&Join{ + Table: TableWithName("books b"), + On: OnConditions( + &ColumnValue{ + Column: ColumnWithName("b.author_id"), + Operator: `=`, + Value: NewValue(ColumnWithName("a.id")), + }, + ), + }), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "a"."name" FROM "artist" AS "a" JOIN "books" AS "b" ON ("b"."author_id" = "a"."id")`, s) +} + +func TestSelectJoinUsing(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("artist a"), + Columns: JoinColumns( + &Column{Name: "a.name"}, + ), + Joins: JoinConditions(&Join{ + Table: TableWithName("books b"), + Using: UsingColumns( + ColumnWithName("artist_id"), + ColumnWithName("country"), + ), + }), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "a"."name" FROM "artist" AS "a" JOIN "books" AS "b" USING ("artist_id", "country")`, s) +} + +func TestSelectUnfinishedJoin(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("artist a"), + Columns: JoinColumns( + &Column{Name: "a.name"}, + ), + Joins: JoinConditions(&Join{}), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "a"."name" FROM "artist" AS "a"`, s) +} + +func TestSelectNaturalJoin(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName("artist"), + Joins: JoinConditions(&Join{ + Table: TableWithName("books"), + }), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "artist" NATURAL JOIN "books"`, s) +} + +func TestSelectRawFrom(t *testing.T) { + stmt := Statement{ + Type: Select, + Table: TableWithName(`artist`), + Columns: JoinColumns( + &Column{Name: `artist.name`}, + &Column{Name: &Raw{Value: `CONCAT(artist.name, " ", artist.last_name)`}}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "artist"."name", CONCAT(artist.name, " ", artist.last_name) FROM "artist"`, s) +} + +func TestSelectFieldsFrom(t *testing.T) { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name"`, s) +} + +func TestSelectFieldsFromWithLimitOffset(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Limit: 42, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" 
LIMIT 42`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Offset: 17, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" OFFSET 17`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Limit: 42, + Offset: 17, + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" LIMIT 42 OFFSET 17`, s) + } +} + +func TestStatementGroupBy(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + GroupBy: GroupByColumns( + &Column{Name: "foo"}, + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo"`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + GroupBy: GroupByColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" GROUP BY "foo", "bar"`, s) + } +} + +func TestSelectFieldsFromWithOrderBy(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}}, + ), + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo"`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Ascendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" ASC`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Descendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC`, s) + } + + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: "foo"}, Order: Order_Descendent}, + &SortColumn{Column: &Column{Name: "bar"}, Order: Order_Ascendent}, + &SortColumn{Column: &Column{Name: "baz"}, Order: Order_Descendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY "foo" DESC, "bar" ASC, "baz" DESC`, s) + } + + { + 
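+		// Raw column names (SQL expressions such as FOO()) are expected to pass
+		// through ORDER BY without identifier quoting: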
stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + OrderBy: JoinWithOrderBy( + JoinSortColumns( + &SortColumn{Column: &Column{Name: &Raw{Value: "FOO()"}}, Order: Order_Descendent}, + &SortColumn{Column: &Column{Name: &Raw{Value: "BAR()"}}, Order: Order_Ascendent}, + ), + ), + Table: TableWithName("table_name"), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" ORDER BY FOO() DESC, BAR() ASC`, s) + } +} + +func TestSelectFieldsFromWhere(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99')`, s) + } +} + +func TestSelectFieldsFromWhereLimitOffset(t *testing.T) { + { + stmt := Statement{ + Type: Select, + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + Limit: 10, + Offset: 23, + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT "foo", "bar", "baz" FROM "table_name" WHERE ("baz" = '99') LIMIT 10 OFFSET 23`, s) + } +} + +func TestDelete(t *testing.T) { + stmt := Statement{ + Type: Delete, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `DELETE FROM "table_name" WHERE ("baz" = '99')`, s) +} + +func TestUpdate(t *testing.T) { + { + stmt := Statement{ + Type: Update, + Table: TableWithName("table_name"), + ColumnValues: JoinColumnValues( + &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, + ), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `UPDATE "table_name" SET "foo" = '76' WHERE ("baz" = '99')`, s) + } + + { + stmt := Statement{ + Type: Update, + Table: TableWithName("table_name"), + ColumnValues: JoinColumnValues( + &ColumnValue{Column: &Column{Name: "foo"}, Operator: "=", Value: NewValue(76)}, + &ColumnValue{Column: &Column{Name: "bar"}, Operator: "=", Value: NewValue(&Raw{Value: "88"})}, + ), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `UPDATE "table_name" SET "foo" = '76', "bar" = 88 WHERE ("baz" = '99')`, s) + } +} + +func TestInsert(t *testing.T) { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: NewValueGroup( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: &Raw{Value: "3"}}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3)`, s) +} + +func TestInsertMultiple(t *testing.T) { + stmt := Statement{ + Type: Insert, + Table: 
TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: JoinValueGroups( + NewValueGroup( + NewValue("1"), + NewValue("2"), + NewValue(&Raw{Value: "3"}), + ), + NewValueGroup( + NewValue(&Raw{Value: "4"}), + NewValue(&Raw{Value: "5"}), + NewValue(&Raw{Value: "6"}), + ), + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3), (4, 5, 6)`, s) +} + +func TestInsertReturning(t *testing.T) { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Returning: ReturningColumns( + ColumnWithName("id"), + ), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: NewValueGroup( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: &Raw{Value: "3"}}, + ), + } + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `INSERT INTO "table_name" ("foo", "bar", "baz") VALUES ('1', '2', 3) RETURNING "id"`, s) +} + +func TestRawSQLStatement(t *testing.T) { + stmt := RawSQL(`SELECT * FROM "foo" ORDER BY "bar"`) + + s := mustTrim(stmt.Compile(defaultTemplate)) + assert.Equal(t, `SELECT * FROM "foo" ORDER BY "bar"`, s) +} + +func BenchmarkStatementSimpleQuery(b *testing.B) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(&Raw{Value: "7"})}, + ), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementSimpleQueryHash(b *testing.B) { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(&Raw{Value: "7"})}, + ), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = stmt.Hash() + } +} + +func BenchmarkStatementSimpleQueryNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + stmt := Statement{ + Type: Count, + Table: TableWithName("table_name"), + Where: WhereConditions( + &ColumnValue{Column: &Column{Name: "a"}, Operator: "=", Value: NewValue(&Raw{Value: "7"})}, + ), + } + _, _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementComplexQuery(b *testing.B) { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: NewValueGroup( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: &Raw{Value: "3"}}, + ), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = stmt.Compile(defaultTemplate) + } +} + +func BenchmarkStatementComplexQueryNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + stmt := Statement{ + Type: Insert, + Table: TableWithName("table_name"), + Columns: JoinColumns( + &Column{Name: "foo"}, + &Column{Name: "bar"}, + &Column{Name: "baz"}, + ), + Values: NewValueGroup( + &Value{V: "1"}, + &Value{V: 2}, + &Value{V: &Raw{Value: "3"}}, + ), + } + _, _ = stmt.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/table.go b/internal/sqladapter/exql/table.go new file mode 100644 index 0000000..b831b45 --- /dev/null +++ b/internal/sqladapter/exql/table.go @@ -0,0 +1,98 @@ +package exql + +import ( + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +type tableT struct { + Name string + Alias string +} + +// Table struct represents a SQL table. 
+type Table struct {
+	Name interface{}
+}
+
+var _ = Fragment(&Table{})
+
+func quotedTableName(layout *Template, input string) string {
+	input = trimString(input)
+
+	// chunks := reAliasSeparator.Split(input, 2)
+	chunks := separateByAS(input)
+
+	if len(chunks) == 1 {
+		// chunks = reSpaceSeparator.Split(input, 2)
+		chunks = separateBySpace(input)
+	}
+
+	name := chunks[0]
+
+	nameChunks := strings.SplitN(name, layout.ColumnSeparator, 2)
+
+	for i := range nameChunks {
+		// nameChunks[i] = strings.TrimSpace(nameChunks[i])
+		nameChunks[i] = trimString(nameChunks[i])
+		nameChunks[i] = layout.MustCompile(layout.IdentifierQuote, Raw{Value: nameChunks[i]})
+	}
+
+	name = strings.Join(nameChunks, layout.ColumnSeparator)
+
+	var alias string
+
+	if len(chunks) > 1 {
+		// alias = strings.TrimSpace(chunks[1])
+		alias = trimString(chunks[1])
+		alias = layout.MustCompile(layout.IdentifierQuote, Raw{Value: alias})
+	}
+
+	return layout.MustCompile(layout.TableAliasLayout, tableT{name, alias})
+}
+
+// TableWithName creates and returns a Table with the given name.
+func TableWithName(name string) *Table {
+	return &Table{Name: name}
+}
+
+// Hash returns a unique identifier for the struct.
+func (t *Table) Hash() uint64 {
+	if t == nil {
+		return cache.NewHash(FragmentType_Table, nil)
+	}
+	return cache.NewHash(FragmentType_Table, t.Name)
+}
+
+// Compile transforms a table struct into a SQL chunk.
+func (t *Table) Compile(layout *Template) (compiled string, err error) {
+
+	if z, ok := layout.Read(t); ok {
+		return z, nil
+	}
+
+	switch value := t.Name.(type) {
+	case string:
+		if t.Name == "" {
+			return
+		}
+
+		// Splitting tables by a comma
+		parts := separateByComma(value)
+
+		l := len(parts)
+
+		for i := 0; i < l; i++ {
+			parts[i] = quotedTableName(layout, parts[i])
+		}
+
+		compiled = strings.Join(parts, layout.IdentifierSeparator)
+	case Raw:
+		compiled = value.String()
+	}
+
+	layout.Write(t, compiled)
+
+	return
+}
diff --git a/internal/sqladapter/exql/table_test.go b/internal/sqladapter/exql/table_test.go
new file mode 100644
index 0000000..08bc825
--- /dev/null
+++ b/internal/sqladapter/exql/table_test.go
@@ -0,0 +1,82 @@
+package exql
+
+import (
+	"github.com/stretchr/testify/assert"
+
+	"testing"
+)
+
+func TestTableSimple(t *testing.T) {
+	table := TableWithName("artist")
+	assert.Equal(t, `"artist"`, mustTrim(table.Compile(defaultTemplate)))
+}
+
+func TestTableCompound(t *testing.T) {
+	table := TableWithName("artist.foo")
+	assert.Equal(t, `"artist"."foo"`, mustTrim(table.Compile(defaultTemplate)))
+}
+
+func TestTableCompoundAlias(t *testing.T) {
+	table := TableWithName("artist.foo AS baz")
+
+	assert.Equal(t, `"artist"."foo" AS "baz"`, mustTrim(table.Compile(defaultTemplate)))
+}
+
+func TestTableImplicitAlias(t *testing.T) {
+	table := TableWithName("artist.foo baz")
+
+	assert.Equal(t, `"artist"."foo" AS "baz"`, mustTrim(table.Compile(defaultTemplate)))
+}
+
+func TestTableMultiple(t *testing.T) {
+	table := TableWithName("artist.foo, artist.bar, artist.baz")
+
+	assert.Equal(t, `"artist"."foo", "artist"."bar", "artist"."baz"`, mustTrim(table.Compile(defaultTemplate)))
+}
+
+func TestTableMultipleAlias(t *testing.T) {
+	table := TableWithName("artist.foo AS foo, artist.bar as bar, artist.baz As baz")
+
+	assert.Equal(t, `"artist"."foo" AS "foo", "artist"."bar" AS "bar", "artist"."baz" AS "baz"`, mustTrim(table.Compile(defaultTemplate)))
+}
+
+func TestTableMinimal(t *testing.T) {
+	table := TableWithName("a")
+
+	assert.Equal(t, `"a"`, mustTrim(table.Compile(defaultTemplate)))
+} + +func TestTableEmpty(t *testing.T) { + table := TableWithName("") + + assert.Equal(t, "", mustTrim(table.Compile(defaultTemplate))) +} + +func BenchmarkTableWithName(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = TableWithName("foo") + } +} + +func BenchmarkTableHash(b *testing.B) { + t := TableWithName("name") + b.ResetTimer() + for i := 0; i < b.N; i++ { + t.Hash() + } +} + +func BenchmarkTableCompile(b *testing.B) { + t := TableWithName("name") + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = t.Compile(defaultTemplate) + } +} + +func BenchmarkTableCompileNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + t := TableWithName("name") + _, _ = t.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/exql/template.go b/internal/sqladapter/exql/template.go new file mode 100644 index 0000000..4a148a2 --- /dev/null +++ b/internal/sqladapter/exql/template.go @@ -0,0 +1,148 @@ +package exql + +import ( + "bytes" + "reflect" + "sync" + "text/template" + + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// Type is the type of SQL query the statement represents. +type Type uint8 + +// Values for Type. +const ( + NoOp Type = iota + + Truncate + DropTable + DropDatabase + Count + Insert + Select + Update + Delete + + SQL +) + +func (t Type) Hash() uint64 { + return cache.NewHash(FragmentType_StatementType, uint8(t)) +} + +type ( + // Limit represents the SQL limit in a query. + Limit int64 + // Offset represents the SQL offset in a query. + Offset int64 +) + +func (t Limit) Hash() uint64 { + return cache.NewHash(FragmentType_Limit, uint64(t)) +} + +func (t Offset) Hash() uint64 { + return cache.NewHash(FragmentType_Offset, uint64(t)) +} + +// Template is an SQL template. +type Template struct { + AndKeyword string + AscKeyword string + AssignmentOperator string + ClauseGroup string + ClauseOperator string + ColumnAliasLayout string + ColumnSeparator string + ColumnValue string + CountLayout string + DeleteLayout string + DescKeyword string + DropDatabaseLayout string + DropTableLayout string + GroupByLayout string + IdentifierQuote string + IdentifierSeparator string + InsertLayout string + JoinLayout string + OnLayout string + OrKeyword string + OrderByLayout string + SelectLayout string + SortByColumnLayout string + TableAliasLayout string + TruncateLayout string + UpdateLayout string + UsingLayout string + ValueQuote string + ValueSeparator string + WhereLayout string + + ComparisonOperator map[adapter.ComparisonOperator]string + + templateMutex sync.RWMutex + templateMap map[string]*template.Template + + *cache.Cache +} + +func (layout *Template) MustCompile(templateText string, data interface{}) string { + var b bytes.Buffer + + v, ok := layout.getTemplate(templateText) + if !ok { + v = template. + Must(template.New(""). + Funcs(map[string]interface{}{ + "defined": func(in Fragment) bool { + if in == nil || reflect.ValueOf(in).IsNil() { + return false + } + if check, ok := in.(hasIsEmpty); ok { + if check.IsEmpty() { + return false + } + } + return true + }, + "compile": func(in Fragment) (string, error) { + s, err := layout.doCompile(in) + if err != nil { + return "", err + } + return s, nil + }, + }). 
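+				// Must panics if Parse fails, i.e. when the template text is malformed.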
+ Parse(templateText)) + + layout.setTemplate(templateText, v) + } + + if err := v.Execute(&b, data); err != nil { + panic("There was an error compiling the following template:\n" + templateText + "\nError was: " + err.Error()) + } + + return b.String() +} + +func (t *Template) getTemplate(k string) (*template.Template, bool) { + t.templateMutex.RLock() + defer t.templateMutex.RUnlock() + + if t.templateMap == nil { + t.templateMap = make(map[string]*template.Template) + } + + v, ok := t.templateMap[k] + return v, ok +} + +func (t *Template) setTemplate(k string, v *template.Template) { + t.templateMutex.Lock() + defer t.templateMutex.Unlock() + + t.templateMap[k] = v +} diff --git a/internal/sqladapter/exql/types.go b/internal/sqladapter/exql/types.go new file mode 100644 index 0000000..d6ecca9 --- /dev/null +++ b/internal/sqladapter/exql/types.go @@ -0,0 +1,35 @@ +package exql + +const ( + FragmentType_None uint64 = iota + 713910251627 + + FragmentType_And + FragmentType_Column + FragmentType_ColumnValue + FragmentType_ColumnValues + FragmentType_Columns + FragmentType_Database + FragmentType_GroupBy + FragmentType_Join + FragmentType_Joins + FragmentType_Nil + FragmentType_Or + FragmentType_Limit + FragmentType_Offset + FragmentType_OrderBy + FragmentType_Order + FragmentType_Raw + FragmentType_Returning + FragmentType_SortBy + FragmentType_SortColumn + FragmentType_SortColumns + FragmentType_Statement + FragmentType_StatementType + FragmentType_Table + FragmentType_Value + FragmentType_On + FragmentType_Using + FragmentType_ValueGroups + FragmentType_Values + FragmentType_Where +) diff --git a/internal/sqladapter/exql/utilities.go b/internal/sqladapter/exql/utilities.go new file mode 100644 index 0000000..972ebb4 --- /dev/null +++ b/internal/sqladapter/exql/utilities.go @@ -0,0 +1,151 @@ +package exql + +import ( + "strings" +) + +// isBlankSymbol returns true if the given byte is either space, tab, carriage +// return or newline. +func isBlankSymbol(in byte) bool { + return in == ' ' || in == '\t' || in == '\r' || in == '\n' +} + +// trimString returns a slice of s with a leading and trailing blank symbols +// (as defined by isBlankSymbol) removed. +func trimString(s string) string { + + // This conversion is rather slow. + // return string(trimBytes([]byte(s))) + + start, end := 0, len(s)-1 + + if end < start { + return "" + } + + for isBlankSymbol(s[start]) { + start++ + if start >= end { + return "" + } + } + + for isBlankSymbol(s[end]) { + end-- + } + + return s[start : end+1] +} + +// trimBytes returns a slice of s with a leading and trailing blank symbols (as +// defined by isBlankSymbol) removed. +func trimBytes(s []byte) []byte { + + start, end := 0, len(s)-1 + + if end < start { + return []byte{} + } + + for isBlankSymbol(s[start]) { + start++ + if start >= end { + return []byte{} + } + } + + for isBlankSymbol(s[end]) { + end-- + } + + return s[start : end+1] +} + +/* +// Separates by a comma, ignoring spaces too. +// This was slower than strings.Split. +func separateByComma(in string) (out []string) { + + out = []string{} + + start, lim := 0, len(in)-1 + + for start < lim { + var end int + + for end = start; end <= lim; end++ { + // Is a comma? + if in[end] == ',' { + break + } + } + + out = append(out, trimString(in[start:end])) + + start = end + 1 + } + + return +} +*/ + +// Separates by a comma, ignoring spaces too. 
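+// For example, "a, b ,c" yields ["a" "b" "c"].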
+func separateByComma(in string) (out []string) { + out = strings.Split(in, ",") + for i := range out { + out[i] = trimString(out[i]) + } + return +} + +// Separates by spaces, ignoring spaces too. +func separateBySpace(in string) (out []string) { + if len(in) == 0 { + return []string{""} + } + + pre := strings.Split(in, " ") + out = make([]string, 0, len(pre)) + + for i := range pre { + pre[i] = trimString(pre[i]) + if pre[i] != "" { + out = append(out, pre[i]) + } + } + + return +} + +func separateByAS(in string) (out []string) { + out = []string{} + + if len(in) < 6 { + // The minimum expression with the AS keyword is "x AS y", 6 chars. + return []string{in} + } + + start, lim := 0, len(in)-1 + + for start <= lim { + var end int + + for end = start; end <= lim; end++ { + if end > 3 && isBlankSymbol(in[end]) && isBlankSymbol(in[end-3]) { + if (in[end-1] == 's' || in[end-1] == 'S') && (in[end-2] == 'a' || in[end-2] == 'A') { + break + } + } + } + + if end < lim { + out = append(out, trimString(in[start:end-3])) + } else { + out = append(out, trimString(in[start:end])) + } + + start = end + 1 + } + + return +} diff --git a/internal/sqladapter/exql/utilities_test.go b/internal/sqladapter/exql/utilities_test.go new file mode 100644 index 0000000..9dcbde3 --- /dev/null +++ b/internal/sqladapter/exql/utilities_test.go @@ -0,0 +1,211 @@ +package exql + +import ( + "bytes" + "regexp" + "strings" + "testing" + "unicode" + + "github.com/stretchr/testify/assert" +) + +const ( + blankSymbol = ' ' + stringWithCommas = "Hello,,World!,Enjoy" + stringWithSpaces = " Hello World! Enjoy" + stringWithASKeyword = "table.Name AS myTableAlias" +) + +var ( + bytesWithLeadingBlanks = []byte(" Hello world! ") + stringWithLeadingBlanks = string(bytesWithLeadingBlanks) +) + +var ( + reInvisible = regexp.MustCompile(`[\t\n\r]`) + reSpace = regexp.MustCompile(`\s+`) +) + +func mustTrim(a string, err error) string { + if err != nil { + panic(err.Error()) + } + a = reInvisible.ReplaceAllString(strings.TrimSpace(a), " ") + a = reSpace.ReplaceAllString(strings.TrimSpace(a), " ") + return a +} + +func TestUtilIsBlankSymbol(t *testing.T) { + assert.True(t, isBlankSymbol(' ')) + assert.True(t, isBlankSymbol('\n')) + assert.True(t, isBlankSymbol('\t')) + assert.True(t, isBlankSymbol('\r')) + assert.False(t, isBlankSymbol('x')) +} + +func TestUtilTrimBytes(t *testing.T) { + var trimmed []byte + + trimmed = trimBytes([]byte(" \t\nHello World! \n")) + assert.Equal(t, "Hello World!", string(trimmed)) + + trimmed = trimBytes([]byte("Nope")) + assert.Equal(t, "Nope", string(trimmed)) + + trimmed = trimBytes([]byte("")) + assert.Equal(t, "", string(trimmed)) + + trimmed = trimBytes([]byte(" ")) + assert.Equal(t, "", string(trimmed)) + + trimmed = trimBytes(nil) + assert.Equal(t, "", string(trimmed)) +} + +func TestUtilSeparateByComma(t *testing.T) { + chunks := separateByComma("Hello,,World!,Enjoy") + assert.Equal(t, 4, len(chunks)) + + assert.Equal(t, "Hello", chunks[0]) + assert.Equal(t, "", chunks[1]) + assert.Equal(t, "World!", chunks[2]) + assert.Equal(t, "Enjoy", chunks[3]) +} + +func TestUtilSeparateBySpace(t *testing.T) { + chunks := separateBySpace(" Hello World! 
Enjoy") + assert.Equal(t, 3, len(chunks)) + + assert.Equal(t, "Hello", chunks[0]) + assert.Equal(t, "World!", chunks[1]) + assert.Equal(t, "Enjoy", chunks[2]) +} + +func TestUtilSeparateByAS(t *testing.T) { + var chunks []string + + var tests = []string{ + `table.Name AS myTableAlias`, + `table.Name AS myTableAlias`, + "table.Name\tAS\r\nmyTableAlias", + } + + for _, test := range tests { + chunks = separateByAS(test) + assert.Len(t, chunks, 2) + + assert.Equal(t, "table.Name", chunks[0]) + assert.Equal(t, "myTableAlias", chunks[1]) + } + + // Single character. + chunks = separateByAS("a") + assert.Len(t, chunks, 1) + assert.Equal(t, "a", chunks[0]) + + // Empty name + chunks = separateByAS("") + assert.Len(t, chunks, 1) + assert.Equal(t, "", chunks[0]) + + // Single name + chunks = separateByAS(" A Single Table ") + assert.Len(t, chunks, 1) + assert.Equal(t, "A Single Table", chunks[0]) + + // Minimal expression. + chunks = separateByAS("a AS b") + assert.Len(t, chunks, 2) + assert.Equal(t, "a", chunks[0]) + assert.Equal(t, "b", chunks[1]) + + // Minimal expression with spaces. + chunks = separateByAS(" a AS b ") + assert.Len(t, chunks, 2) + assert.Equal(t, "a", chunks[0]) + assert.Equal(t, "b", chunks[1]) + + // Minimal expression + 1 with spaces. + chunks = separateByAS(" a AS bb ") + assert.Len(t, chunks, 2) + assert.Equal(t, "a", chunks[0]) + assert.Equal(t, "bb", chunks[1]) +} + +func BenchmarkUtilIsBlankSymbol(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = isBlankSymbol(blankSymbol) + } +} + +func BenchmarkUtilStdlibIsBlankSymbol(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = unicode.IsSpace(blankSymbol) + } +} + +func BenchmarkUtilTrimBytes(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = trimBytes(bytesWithLeadingBlanks) + } +} +func BenchmarkUtilStdlibBytesTrimSpace(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = bytes.TrimSpace(bytesWithLeadingBlanks) + } +} + +func BenchmarkUtilTrimString(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = trimString(stringWithLeadingBlanks) + } +} + +func BenchmarkUtilStdlibStringsTrimSpace(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = strings.TrimSpace(stringWithLeadingBlanks) + } +} + +func BenchmarkUtilSeparateByComma(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = separateByComma(stringWithCommas) + } +} + +func BenchmarkUtilRegExpSeparateByComma(b *testing.B) { + sep := regexp.MustCompile(`\s*?,\s*?`) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = sep.Split(stringWithCommas, -1) + } +} + +func BenchmarkUtilSeparateBySpace(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = separateBySpace(stringWithSpaces) + } +} + +func BenchmarkUtilRegExpSeparateBySpace(b *testing.B) { + sep := regexp.MustCompile(`\s+`) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = sep.Split(stringWithSpaces, -1) + } +} + +func BenchmarkUtilSeparateByAS(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = separateByAS(stringWithASKeyword) + } +} + +func BenchmarkUtilRegExpSeparateByAS(b *testing.B) { + sep := regexp.MustCompile(`(?i:\s+AS\s+)`) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = sep.Split(stringWithASKeyword, -1) + } +} diff --git a/internal/sqladapter/exql/value.go b/internal/sqladapter/exql/value.go new file mode 100644 index 0000000..6190235 --- /dev/null +++ b/internal/sqladapter/exql/value.go @@ -0,0 +1,166 @@ +package exql + +import ( + "strings" + + "git.hexq.cn/tiglog/mydb/internal/cache" +) + +// ValueGroups represents an array of value groups. 
+type ValueGroups struct { + Values []*Values +} + +func (vg *ValueGroups) IsEmpty() bool { + if vg == nil || len(vg.Values) < 1 { + return true + } + for i := range vg.Values { + if !vg.Values[i].IsEmpty() { + return false + } + } + return true +} + +var _ = Fragment(&ValueGroups{}) + +// Values represents an array of Value. +type Values struct { + Values []Fragment +} + +func (vs *Values) IsEmpty() bool { + if vs == nil || len(vs.Values) < 1 { + return true + } + return false +} + +// NewValueGroup creates and returns an array of values. +func NewValueGroup(v ...Fragment) *Values { + return &Values{Values: v} +} + +var _ = Fragment(&Values{}) + +// Value represents an escaped SQL value. +type Value struct { + V interface{} +} + +var _ = Fragment(&Value{}) + +// NewValue creates and returns a Value. +func NewValue(v interface{}) *Value { + return &Value{V: v} +} + +// Hash returns a unique identifier for the struct. +func (v *Value) Hash() uint64 { + if v == nil { + return cache.NewHash(FragmentType_Value, nil) + } + return cache.NewHash(FragmentType_Value, v.V) +} + +// Compile transforms the Value into an equivalent SQL representation. +func (v *Value) Compile(layout *Template) (compiled string, err error) { + if z, ok := layout.Read(v); ok { + return z, nil + } + + switch value := v.V.(type) { + case compilable: + compiled, err = value.Compile(layout) + if err != nil { + return "", err + } + default: + value, err := NewRawValue(v.V) + if err != nil { + return "", err + } + compiled = layout.MustCompile( + layout.ValueQuote, + value, + ) + } + + layout.Write(v, compiled) + return +} + +// Hash returns a unique identifier for the struct. +func (vs *Values) Hash() uint64 { + if vs == nil { + return cache.NewHash(FragmentType_Values, nil) + } + h := cache.InitHash(FragmentType_Values) + for i := range vs.Values { + h = cache.AddToHash(h, vs.Values[i]) + } + return h +} + +// Compile transforms the Values into an equivalent SQL representation. +func (vs *Values) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(vs); ok { + return c, nil + } + + l := len(vs.Values) + if l > 0 { + chunks := make([]string, 0, l) + for i := 0; i < l; i++ { + chunk, err := vs.Values[i].Compile(layout) + if err != nil { + return "", err + } + chunks = append(chunks, chunk) + } + compiled = layout.MustCompile(layout.ClauseGroup, strings.Join(chunks, layout.ValueSeparator)) + } + layout.Write(vs, compiled) + return +} + +// Hash returns a unique identifier for the struct. +func (vg *ValueGroups) Hash() uint64 { + if vg == nil { + return cache.NewHash(FragmentType_ValueGroups, nil) + } + h := cache.InitHash(FragmentType_ValueGroups) + for i := range vg.Values { + h = cache.AddToHash(h, vg.Values[i]) + } + return h +} + +// Compile transforms the ValueGroups into an equivalent SQL representation. +func (vg *ValueGroups) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(vg); ok { + return c, nil + } + + l := len(vg.Values) + if l > 0 { + chunks := make([]string, 0, l) + for i := 0; i < l; i++ { + chunk, err := vg.Values[i].Compile(layout) + if err != nil { + return "", err + } + chunks = append(chunks, chunk) + } + compiled = strings.Join(chunks, layout.ValueSeparator) + } + + layout.Write(vg, compiled) + return +} + +// JoinValueGroups creates a new *ValueGroups object. 
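+// A minimal sketch, assuming the default template used by the tests:
+//
+//	vg := JoinValueGroups(
+//		NewValueGroup(NewValue(1), NewValue(2)),
+//		NewValueGroup(NewValue(3), NewValue(4)),
+//	)
+//	s, _ := vg.Compile(defaultTemplate) // ('1', '2'), ('3', '4')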
+func JoinValueGroups(values ...*Values) *ValueGroups {
+	return &ValueGroups{Values: values}
+}
diff --git a/internal/sqladapter/exql/value_test.go b/internal/sqladapter/exql/value_test.go
new file mode 100644
index 0000000..a0269b6
--- /dev/null
+++ b/internal/sqladapter/exql/value_test.go
@@ -0,0 +1,130 @@
+package exql
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestValue(t *testing.T) {
+	val := NewValue(1)
+
+	s, err := val.Compile(defaultTemplate)
+	assert.NoError(t, err)
+	assert.Equal(t, `'1'`, s)
+
+	val = NewValue(&Raw{Value: "NOW()"})
+
+	s, err = val.Compile(defaultTemplate)
+	assert.NoError(t, err)
+	assert.Equal(t, `NOW()`, s)
+}
+
+func TestSameRawValue(t *testing.T) {
+	{
+		val := NewValue(&Raw{Value: `"1"`})
+
+		s, err := val.Compile(defaultTemplate)
+		assert.NoError(t, err)
+		assert.Equal(t, `"1"`, s)
+	}
+	{
+		val := NewValue(&Raw{Value: `'1'`})
+
+		s, err := val.Compile(defaultTemplate)
+		assert.NoError(t, err)
+		assert.Equal(t, `'1'`, s)
+	}
+	{
+		val := NewValue(&Raw{Value: `1`})
+
+		s, err := val.Compile(defaultTemplate)
+		assert.NoError(t, err)
+		assert.Equal(t, `1`, s)
+	}
+	{
+		val := NewValue("1")
+
+		s, err := val.Compile(defaultTemplate)
+		assert.NoError(t, err)
+		assert.Equal(t, `'1'`, s)
+	}
+	{
+		val := NewValue(1)
+
+		s, err := val.Compile(defaultTemplate)
+		assert.NoError(t, err)
+		assert.Equal(t, `'1'`, s)
+	}
+}
+
+func TestValues(t *testing.T) {
+	val := NewValueGroup(
+		&Value{V: &Raw{Value: "1"}},
+		&Value{V: &Raw{Value: "2"}},
+		&Value{V: "3"},
+	)
+
+	s, err := val.Compile(defaultTemplate)
+	assert.NoError(t, err)
+
+	assert.Equal(t, `(1, 2, '3')`, s)
+}
+
+func BenchmarkValue(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		_ = NewValue("a")
+	}
+}
+
+func BenchmarkValueHash(b *testing.B) {
+	v := NewValue("a")
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_ = v.Hash()
+	}
+}
+
+func BenchmarkValueCompile(b *testing.B) {
+	v := NewValue("a")
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _ = v.Compile(defaultTemplate)
+	}
+}
+
+func BenchmarkValueCompileNoCache(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		v := NewValue("a")
+		_, _ = v.Compile(defaultTemplate)
+	}
+}
+
+func BenchmarkValues(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		_ = NewValueGroup(NewValue("a"), NewValue("b"))
+	}
+}
+
+func BenchmarkValuesHash(b *testing.B) {
+	vs := NewValueGroup(NewValue("a"), NewValue("b"))
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_ = vs.Hash()
+	}
+}
+
+func BenchmarkValuesCompile(b *testing.B) {
+	vs := NewValueGroup(NewValue("a"), NewValue("b"))
+	b.ResetTimer()
+	for i := 0; i < b.N; i++ {
+		_, _ = vs.Compile(defaultTemplate)
+	}
+}
+
+func BenchmarkValuesCompileNoCache(b *testing.B) {
+	for i := 0; i < b.N; i++ {
+		vs := NewValueGroup(NewValue("a"), NewValue("b"))
+		_, _ = vs.Compile(defaultTemplate)
+	}
+}
diff --git a/internal/sqladapter/exql/where.go b/internal/sqladapter/exql/where.go
new file mode 100644
index 0000000..3cc8985
--- /dev/null
+++ b/internal/sqladapter/exql/where.go
@@ -0,0 +1,149 @@
+package exql
+
+import (
+	"strings"
+
+	"git.hexq.cn/tiglog/mydb/internal/cache"
+)
+
+// Or represents an SQL OR operator.
+type Or Where
+
+// And represents an SQL AND operator.
+type And Where
+
+// Where represents an SQL WHERE clause.
+type Where struct {
+	Conditions []Fragment
+}
+
+var _ = Fragment(&Where{})
+
+type conds struct {
+	Conds string
+}
+
+// WhereConditions creates and returns a new Where.
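+// When compiled, the conditions are grouped in parentheses and joined with
+// the template's AND keyword.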
+func WhereConditions(conditions ...Fragment) *Where {
+	return &Where{Conditions: conditions}
+}
+
+// JoinWithOr creates and returns a new Or.
+func JoinWithOr(conditions ...Fragment) *Or {
+	return &Or{Conditions: conditions}
+}
+
+// JoinWithAnd creates and returns a new And.
+func JoinWithAnd(conditions ...Fragment) *And {
+	return &And{Conditions: conditions}
+}
+
+// Hash returns a unique identifier for the struct.
+func (w *Where) Hash() uint64 {
+	if w == nil {
+		return cache.NewHash(FragmentType_Where, nil)
+	}
+	h := cache.InitHash(FragmentType_Where)
+	for i := range w.Conditions {
+		h = cache.AddToHash(h, w.Conditions[i])
+	}
+	return h
+}
+
+// Append adds the conditions to the ones that already exist.
+func (w *Where) Append(a *Where) *Where {
+	if a != nil {
+		w.Conditions = append(w.Conditions, a.Conditions...)
+	}
+	return w
+}
+
+// Hash returns a unique identifier.
+func (o *Or) Hash() uint64 {
+	if o == nil {
+		return cache.NewHash(FragmentType_Or, nil)
+	}
+	return cache.NewHash(FragmentType_Or, (*Where)(o))
+}
+
+// Hash returns a unique identifier.
+func (a *And) Hash() uint64 {
+	if a == nil {
+		return cache.NewHash(FragmentType_And, nil)
+	}
+	return cache.NewHash(FragmentType_And, (*Where)(a))
+}
+
+// Compile transforms the Or into an equivalent SQL representation.
+func (o *Or) Compile(layout *Template) (compiled string, err error) {
+	if z, ok := layout.Read(o); ok {
+		return z, nil
+	}
+
+	compiled, err = groupCondition(layout, o.Conditions, layout.MustCompile(layout.ClauseOperator, layout.OrKeyword))
+	if err != nil {
+		return "", err
+	}
+
+	layout.Write(o, compiled)
+
+	return
+}
+
+// Compile transforms the And into an equivalent SQL representation.
+func (a *And) Compile(layout *Template) (compiled string, err error) {
+	if c, ok := layout.Read(a); ok {
+		return c, nil
+	}
+
+	compiled, err = groupCondition(layout, a.Conditions, layout.MustCompile(layout.ClauseOperator, layout.AndKeyword))
+	if err != nil {
+		return "", err
+	}
+
+	layout.Write(a, compiled)
+
+	return
+}
+
+// Compile transforms the Where into an equivalent SQL representation.
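+// An empty condition list compiles to an empty string, so an optional WHERE
+// clause can be rendered unconditionally.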
+func (w *Where) Compile(layout *Template) (compiled string, err error) { + if c, ok := layout.Read(w); ok { + return c, nil + } + + grouped, err := groupCondition(layout, w.Conditions, layout.MustCompile(layout.ClauseOperator, layout.AndKeyword)) + if err != nil { + return "", err + } + + if grouped != "" { + compiled = layout.MustCompile(layout.WhereLayout, conds{grouped}) + } + + layout.Write(w, compiled) + + return +} + +func groupCondition(layout *Template, terms []Fragment, joinKeyword string) (string, error) { + l := len(terms) + + chunks := make([]string, 0, l) + + if l > 0 { + for i := 0; i < l; i++ { + chunk, err := terms[i].Compile(layout) + if err != nil { + return "", err + } + chunks = append(chunks, chunk) + } + } + + if len(chunks) > 0 { + return layout.MustCompile(layout.ClauseGroup, strings.Join(chunks, joinKeyword)), nil + } + + return "", nil +} diff --git a/internal/sqladapter/exql/where_test.go b/internal/sqladapter/exql/where_test.go new file mode 100644 index 0000000..16e301e --- /dev/null +++ b/internal/sqladapter/exql/where_test.go @@ -0,0 +1,127 @@ +package exql + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestWhereAnd(t *testing.T) { + and := JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + ) + + s, err := and.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `("id" > 8 AND "id" < 99 AND "name" = 'John')`, s) +} + +func TestWhereOr(t *testing.T) { + or := JoinWithOr( + &ColumnValue{Column: &Column{Name: "id"}, Operator: "=", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "=", Value: NewValue(&Raw{Value: "99"})}, + ) + + s, err := or.Compile(defaultTemplate) + assert.NoError(t, err) + assert.Equal(t, `("id" = 8 OR "id" = 99)`, s) +} + +func TestWhereAndOr(t *testing.T) { + and := JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + ) + + s, err := and.Compile(defaultTemplate) + assert.NoError(t, err) + + assert.Equal(t, `("id" > 8 AND "id" < 99 AND "name" = 'John' AND ("last_name" = 'Smith' OR "last_name" = 'Reyes'))`, s) +} + +func TestWhereAndRawOrAnd(t *testing.T) { + { + where := WhereConditions( + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(2)}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "77"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(1)}, + ), + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + &Raw{Value: "city_id = 728"}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "age"}, Operator: ">", Value: NewValue(&Raw{Value: 
"18"})}, + &ColumnValue{Column: &Column{Name: "age"}, Operator: "<", Value: NewValue(&Raw{Value: "41"})}, + ), + ) + + assert.Equal(t, + `WHERE (("id" > '2' AND "id" < 77 AND "id" < '1') AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))`, + mustTrim(where.Compile(defaultTemplate)), + ) + } + + { + where := WhereConditions( + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: ">", Value: NewValue(&Raw{Value: "8"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(&Raw{Value: "99"})}, + &ColumnValue{Column: &Column{Name: "id"}, Operator: "<", Value: NewValue(1)}, + ), + &ColumnValue{Column: &Column{Name: "name"}, Operator: "=", Value: NewValue("John")}, + &Raw{Value: "city_id = 728"}, + JoinWithOr( + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Smith")}, + &ColumnValue{Column: &Column{Name: "last_name"}, Operator: "=", Value: NewValue("Reyes")}, + ), + JoinWithAnd( + &ColumnValue{Column: &Column{Name: "age"}, Operator: ">", Value: NewValue(&Raw{Value: "18"})}, + &ColumnValue{Column: &Column{Name: "age"}, Operator: "<", Value: NewValue(&Raw{Value: "41"})}, + ), + ) + + assert.Equal(t, + `WHERE (("id" > 8 AND "id" > 8 AND "id" < 99 AND "id" < '1') AND "name" = 'John' AND city_id = 728 AND ("last_name" = 'Smith' OR "last_name" = 'Reyes') AND ("age" > 18 AND "age" < 41))`, + mustTrim(where.Compile(defaultTemplate)), + ) + } +} + +func BenchmarkWhere(b *testing.B) { + for i := 0; i < b.N; i++ { + _ = WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + } +} + +func BenchmarkCompileWhere(b *testing.B) { + w := WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _, _ = w.Compile(defaultTemplate) + } +} + +func BenchmarkCompileWhereNoCache(b *testing.B) { + for i := 0; i < b.N; i++ { + w := WhereConditions( + &ColumnValue{Column: &Column{Name: "baz"}, Operator: "=", Value: NewValue(99)}, + ) + _, _ = w.Compile(defaultTemplate) + } +} diff --git a/internal/sqladapter/hash.go b/internal/sqladapter/hash.go new file mode 100644 index 0000000..4d75491 --- /dev/null +++ b/internal/sqladapter/hash.go @@ -0,0 +1,8 @@ +package sqladapter + +const ( + hashTypeNone = iota + 345065139389 + + hashTypeCollection + hashTypePrimaryKeys +) diff --git a/internal/sqladapter/record.go b/internal/sqladapter/record.go new file mode 100644 index 0000000..35f568e --- /dev/null +++ b/internal/sqladapter/record.go @@ -0,0 +1,122 @@ +package sqladapter + +import ( + "reflect" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +func recordID(store mydb.Store, record mydb.Record) (mydb.Cond, error) { + if record == nil { + return nil, mydb.ErrNilRecord + } + + if hasConstraints, ok := record.(mydb.HasConstraints); ok { + return hasConstraints.Constraints(), nil + } + + id := mydb.Cond{} + + keys, fields, err := recordPrimaryKeyFieldValues(store, record) + if err != nil { + return nil, err + } + for i := range fields { + if fields[i] == reflect.Zero(reflect.TypeOf(fields[i])).Interface() { + return nil, mydb.ErrRecordIDIsZero + } + id[keys[i]] = fields[i] + } + if len(id) < 1 { + return nil, mydb.ErrRecordIDIsZero + } + + return id, nil +} + +func recordPrimaryKeyFieldValues(store mydb.Store, record 
mydb.Record) ([]string, []interface{}, error) {
+	sess := store.Session()
+
+	pKeys, err := sess.(Session).PrimaryKeys(store.Name())
+	if err != nil {
+		return nil, nil, err
+	}
+
+	fields := sqlbuilder.Mapper.FieldsByName(reflect.ValueOf(record), pKeys)
+
+	values := make([]interface{}, 0, len(fields))
+	for i := range fields {
+		if fields[i].IsValid() {
+			values = append(values, fields[i].Interface())
+		}
+	}
+
+	return pKeys, values, nil
+}
+
+func recordCreate(store mydb.Store, record mydb.Record) error {
+	sess := store.Session()
+
+	if validator, ok := record.(mydb.Validator); ok {
+		if err := validator.Validate(); err != nil {
+			return err
+		}
+	}
+
+	if hook, ok := record.(mydb.BeforeCreateHook); ok {
+		if err := hook.BeforeCreate(sess); err != nil {
+			return err
+		}
+	}
+
+	if creator, ok := store.(mydb.StoreCreator); ok {
+		if err := creator.Create(record); err != nil {
+			return err
+		}
+	} else {
+		if err := store.InsertReturning(record); err != nil {
+			return err
+		}
+	}
+
+	if hook, ok := record.(mydb.AfterCreateHook); ok {
+		if err := hook.AfterCreate(sess); err != nil {
+			return err
+		}
+	}
+	return nil
+}
+
+func recordUpdate(store mydb.Store, record mydb.Record) error {
+	sess := store.Session()
+
+	if validator, ok := record.(mydb.Validator); ok {
+		if err := validator.Validate(); err != nil {
+			return err
+		}
+	}
+
+	if hook, ok := record.(mydb.BeforeUpdateHook); ok {
+		if err := hook.BeforeUpdate(sess); err != nil {
+			return err
+		}
+	}
+
+	if updater, ok := store.(mydb.StoreUpdater); ok {
+		if err := updater.Update(record); err != nil {
+			return err
+		}
+	} else {
+		if err := record.Store(sess).UpdateReturning(record); err != nil {
+			return err
+		}
+	}
+
+	if hook, ok := record.(mydb.AfterUpdateHook); ok {
+		if err := hook.AfterUpdate(sess); err != nil {
+			return err
+		}
+	}
+	return nil
+}
diff --git a/internal/sqladapter/result.go b/internal/sqladapter/result.go
new file mode 100644
index 0000000..1f0ead6
--- /dev/null
+++ b/internal/sqladapter/result.go
@@ -0,0 +1,498 @@
+package sqladapter
+
+import (
+	"errors"
+	"sync"
+	"sync/atomic"
+
+	"git.hexq.cn/tiglog/mydb"
+	"git.hexq.cn/tiglog/mydb/internal/immutable"
+)
+
+type Result struct {
+	builder mydb.SQL
+
+	err atomic.Value
+
+	iter   mydb.Iterator
+	iterMu sync.Mutex
+
+	prev *Result
+	fn   func(*result) error
+}
+
+// result represents a delimited set of items bound by a condition.
+type result struct {
+	table  string
+	limit  int
+	offset int
+
+	pageSize   uint
+	pageNumber uint
+
+	cursorColumn        string
+	nextPageCursorValue interface{}
+	prevPageCursorValue interface{}
+
+	fields  []interface{}
+	orderBy []interface{}
+	groupBy []interface{}
+	conds   [][]interface{}
+}
+
+func filter(conds []interface{}) []interface{} {
+	return conds
+}
+
+// NewResult creates and returns a new Result set on the given table; the set
+// is limited by the given conditions.
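+// A Result is immutable: modifiers such as Where, Limit and OrderBy push a
+// new frame onto the chain and return a derived *Result; the frames are
+// replayed (via fastForward) when a query is actually built.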
+func NewResult(builder mydb.SQL, table string, conds []interface{}) *Result {
+	r := &Result{
+		builder: builder,
+	}
+	return r.from(table).where(conds)
+}
+
+func (r *Result) frame(fn func(*result) error) *Result {
+	return &Result{err: r.err, prev: r, fn: fn}
+}
+
+func (r *Result) SQL() mydb.SQL {
+	if r.prev == nil {
+		return r.builder
+	}
+	return r.prev.SQL()
+}
+
+func (r *Result) from(table string) *Result {
+	return r.frame(func(res *result) error {
+		res.table = table
+		return nil
+	})
+}
+
+func (r *Result) where(conds []interface{}) *Result {
+	return r.frame(func(res *result) error {
+		res.conds = [][]interface{}{conds}
+		return nil
+	})
+}
+
+func (r *Result) setErr(err error) {
+	if err == nil {
+		return
+	}
+	r.err.Store(err)
+}
+
+// Err returns the last error that has happened with the result set, nil
+// otherwise.
+func (r *Result) Err() error {
+	if errV := r.err.Load(); errV != nil {
+		return errV.(error)
+	}
+	return nil
+}
+
+// Where sets conditions for the result set.
+func (r *Result) Where(conds ...interface{}) mydb.Result {
+	return r.where(conds)
+}
+
+// And adds more conditions on top of the existing ones.
+func (r *Result) And(conds ...interface{}) mydb.Result {
+	return r.frame(func(res *result) error {
+		res.conds = append(res.conds, conds)
+		return nil
+	})
+}
+
+// Limit sets the maximum number of results to be returned.
+func (r *Result) Limit(n int) mydb.Result {
+	return r.frame(func(res *result) error {
+		res.limit = n
+		return nil
+	})
+}
+
+func (r *Result) Paginate(pageSize uint) mydb.Result {
+	return r.frame(func(res *result) error {
+		res.pageSize = pageSize
+		return nil
+	})
+}
+
+func (r *Result) Page(pageNumber uint) mydb.Result {
+	return r.frame(func(res *result) error {
+		res.pageNumber = pageNumber
+		res.nextPageCursorValue = nil
+		res.prevPageCursorValue = nil
+		return nil
+	})
+}
+
+func (r *Result) Cursor(cursorColumn string) mydb.Result {
+	return r.frame(func(res *result) error {
+		res.cursorColumn = cursorColumn
+		return nil
+	})
+}
+
+func (r *Result) NextPage(cursorValue interface{}) mydb.Result {
+	return r.frame(func(res *result) error {
+		res.nextPageCursorValue = cursorValue
+		res.prevPageCursorValue = nil
+		return nil
+	})
+}
+
+func (r *Result) PrevPage(cursorValue interface{}) mydb.Result {
+	return r.frame(func(res *result) error {
+		res.nextPageCursorValue = nil
+		res.prevPageCursorValue = cursorValue
+		return nil
+	})
+}
+
+// Offset determines how many results will be skipped before starting to grab
+// them.
+func (r *Result) Offset(n int) mydb.Result {
+	return r.frame(func(res *result) error {
+		res.offset = n
+		return nil
+	})
+}
+
+// GroupBy is used to group results that have the same value in the same column
+// or columns.
+func (r *Result) GroupBy(fields ...interface{}) mydb.Result {
+	return r.frame(func(res *result) error {
+		res.groupBy = fields
+		return nil
+	})
+}
+
+// OrderBy determines sorting of results according to the provided names.
+// Fields may be prefixed by - (minus), which means descending order; ascending
+// order is used otherwise.
+func (r *Result) OrderBy(fields ...interface{}) mydb.Result {
+	return r.frame(func(res *result) error {
+		res.orderBy = fields
+		return nil
+	})
+}
+
+// Select determines which fields to return.
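+// The fields are handed to the underlying SQL builder as-is, so plain column
+// names, mydb.Raw fragments and other builder-supported expressions should
+// all work.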
+func (r *Result) Select(fields ...interface{}) mydb.Result { + return r.frame(func(res *result) error { + res.fields = fields + return nil + }) +} + +// String satisfies fmt.Stringer. +func (r *Result) String() string { + query, err := r.Paginator() + if err != nil { + panic(err.Error()) + } + return query.String() +} + +// All dumps all results into a pointer to a slice of structs or maps. +func (r *Result) All(dst interface{}) error { + query, err := r.Paginator() + if err != nil { + r.setErr(err) + return err + } + err = query.Iterator().All(dst) + r.setErr(err) + return err +} + +// One fetches only one result from the set. +func (r *Result) One(dst interface{}) error { + one := r.Limit(1).(*Result) + query, err := one.Paginator() + if err != nil { + r.setErr(err) + return err + } + + err = query.Iterator().One(dst) + r.setErr(err) + return err +} + +// Next fetches the next result from the set. +func (r *Result) Next(dst interface{}) bool { + r.iterMu.Lock() + defer r.iterMu.Unlock() + + if r.iter == nil { + query, err := r.Paginator() + if err != nil { + r.setErr(err) + return false + } + r.iter = query.Iterator() + } + + if r.iter.Next(dst) { + return true + } + + if err := r.iter.Err(); !errors.Is(err, mydb.ErrNoMoreRows) { + r.setErr(err) + return false + } + + return false +} + +// Delete deletes all matching items from the collection. +func (r *Result) Delete() error { + query, err := r.buildDelete() + if err != nil { + r.setErr(err) + return err + } + + _, err = query.Exec() + r.setErr(err) + return err +} + +// Close closes the result set. +func (r *Result) Close() error { + if r.iter != nil { + err := r.iter.Close() + r.setErr(err) + return err + } + return nil +} + +// Update updates matching items from the collection with values of the given +// map or struct. +func (r *Result) Update(values interface{}) error { + query, err := r.buildUpdate(values) + if err != nil { + r.setErr(err) + return err + } + + _, err = query.Exec() + r.setErr(err) + return err +} + +func (r *Result) TotalPages() (uint, error) { + query, err := r.Paginator() + if err != nil { + r.setErr(err) + return 0, err + } + + total, err := query.TotalPages() + if err != nil { + r.setErr(err) + return 0, err + } + + return total, nil +} + +func (r *Result) TotalEntries() (uint64, error) { + query, err := r.Paginator() + if err != nil { + r.setErr(err) + return 0, err + } + + total, err := query.TotalEntries() + if err != nil { + r.setErr(err) + return 0, err + } + + return total, nil +} + +// Exists returns true if at least one item in the result set exists. +func (r *Result) Exists() (bool, error) { + query, err := r.buildCount() + if err != nil { + r.setErr(err) + return false, err + } + + query = query.Limit(1) + + value := struct { + Exists uint64 `db:"_t"` + }{} + + if err := query.One(&value); err != nil { + if errors.Is(err, mydb.ErrNoMoreRows) { + return false, nil + } + r.setErr(err) + return false, err + } + + if value.Exists > 0 { + return true, nil + } + + return false, nil +} + +// Count returns the number of elements in the result set.
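+// +// For example (illustrative): +// +//   total, err := res.Where(mydb.Cond{"active": true}).Count()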
+func (r *Result) Count() (uint64, error) { + query, err := r.buildCount() + if err != nil { + r.setErr(err) + return 0, err + } + + counter := struct { + Count uint64 `db:"_t"` + }{} + if err := query.One(&counter); err != nil { + if errors.Is(err, mydb.ErrNoMoreRows) { + return 0, nil + } + r.setErr(err) + return 0, err + } + + return counter.Count, nil +} + +func (r *Result) Paginator() (mydb.Paginator, error) { + if err := r.Err(); err != nil { + return nil, err + } + + res, err := r.fastForward() + if err != nil { + return nil, err + } + + sel := r.SQL().Select(res.fields...). + From(res.table). + Limit(res.limit). + Offset(res.offset). + GroupBy(res.groupBy...). + OrderBy(res.orderBy...) + + for i := range res.conds { + sel = sel.And(filter(res.conds[i])...) + } + + pag := sel.Paginate(res.pageSize). + Page(res.pageNumber). + Cursor(res.cursorColumn) + + if res.nextPageCursorValue != nil { + pag = pag.NextPage(res.nextPageCursorValue) + } + + if res.prevPageCursorValue != nil { + pag = pag.PrevPage(res.prevPageCursorValue) + } + + return pag, nil +} + +func (r *Result) buildDelete() (mydb.Deleter, error) { + if err := r.Err(); err != nil { + return nil, err + } + + res, err := r.fastForward() + if err != nil { + return nil, err + } + + del := r.SQL().DeleteFrom(res.table). + Limit(res.limit) + + for i := range res.conds { + del = del.And(filter(res.conds[i])...) + } + + return del, nil +} + +func (r *Result) buildUpdate(values interface{}) (mydb.Updater, error) { + if err := r.Err(); err != nil { + return nil, err + } + + res, err := r.fastForward() + if err != nil { + return nil, err + } + + upd := r.SQL().Update(res.table). + Set(values). + Limit(res.limit) + + for i := range res.conds { + upd = upd.And(filter(res.conds[i])...) + } + + return upd, nil +} + +func (r *Result) buildCount() (mydb.Selector, error) { + if err := r.Err(); err != nil { + return nil, err + } + + res, err := r.fastForward() + if err != nil { + return nil, err + } + + sel := r.SQL().Select(mydb.Raw("count(1) AS _t")). + From(res.table). + GroupBy(res.groupBy...) + + for i := range res.conds { + sel = sel.And(filter(res.conds[i])...) 
+ } + + return sel, nil +} + +func (r *Result) Prev() immutable.Immutable { + if r == nil { + return nil + } + return r.prev +} + +func (r *Result) Fn(in interface{}) error { + if r.fn == nil { + return nil + } + return r.fn(in.(*result)) +} + +func (r *Result) Base() interface{} { + return &result{} +} + +func (r *Result) fastForward() (*result, error) { + ff, err := immutable.FastForward(r) + if err != nil { + return nil, err + } + return ff.(*result), nil +} + +var _ = immutable.Immutable(&Result{}) diff --git a/internal/sqladapter/session.go b/internal/sqladapter/session.go new file mode 100644 index 0000000..e973e6b --- /dev/null +++ b/internal/sqladapter/session.go @@ -0,0 +1,1106 @@ +package sqladapter + +import ( + "bytes" + "context" + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "math" + "reflect" + "strconv" + "sync" + "sync/atomic" + "time" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/cache" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/compat" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +var ( + lastSessID uint64 + lastTxID uint64 +) + +var ( + slowQueryThreshold = time.Millisecond * 200 + retryTransactionWaitTime = time.Millisecond * 10 + retryTransactionMaxWaitTime = time.Second * 1 +) + +// hasCleanUp is implemented by structs that have a cleanup routine that needs +// to be called before Close(). +type hasCleanUp interface { + CleanUp() error +} + +// statementExecer allows an adapter to provide its own statement execution. +type statementExecer interface { + StatementExec(sess Session, ctx context.Context, query string, args ...interface{}) (sql.Result, error) +} + +// statementCompiler transforms an internal statement into a format +// database/sql can understand. +type statementCompiler interface { + CompileStatement(sess Session, stmt *exql.Statement, args []interface{}) (string, []interface{}, error) +} + +// sessValueConverter converts values before being passed to the underlying driver. +type sessValueConverter interface { + ConvertValue(in interface{}) interface{} +} + +// sessValueConverterContext converts values before being passed to the underlying driver. +type sessValueConverterContext interface { + ConvertValueContext(ctx context.Context, in interface{}) interface{} +} + +// valueConverter converts values before being passed to the underlying driver. +type valueConverter interface { + ConvertValue(in interface{}) interface { + sql.Scanner + driver.Valuer + } +} + +// errorConverter converts an error value from the underlying driver into +// something different. +type errorConverter interface { + Err(errIn error) (errOut error) +} + +// AdapterSession defines methods to be implemented by SQL adapters. +type AdapterSession interface { + Template() *exql.Template + + NewCollection() CollectionAdapter + + // OpenDSN opens a new connection using the given DSN. + OpenDSN(sess Session, dsn string) (*sql.DB, error) + + // Collections returns a list of non-system tables from the database. + Collections(sess Session) ([]string, error) + + // TableExists returns an error if the given table does not exist. + TableExists(sess Session, name string) error + + // LookupName returns the name of the database. + LookupName(sess Session) (string, error) + + // PrimaryKeys returns all primary keys on the table. + PrimaryKeys(sess Session, name string) ([]string, error) +} + +// Session satisfies mydb.Session. +type Session interface { + SQL() mydb.SQL + + // PrimaryKeys returns all primary keys on the table.
+ PrimaryKeys(tableName string) ([]string, error) + + // Collections returns a list of references to all collections in the + // database. + Collections() ([]mydb.Collection, error) + + // Name returns the name of the database. + Name() string + + // Close closes the database session. + Close() error + + // Ping checks if the database server is reachable. + Ping() error + + // Reset clears all caches the session is using. + Reset() + + // Collection returns a new collection. + Collection(string) mydb.Collection + + // ConnectionURL returns the ConnectionURL that was used to create the + // Session. + ConnectionURL() mydb.ConnectionURL + + // Open attempts to establish a connection to the database server. + Open() error + + // TableExists returns an error if the table doesn't exist. + TableExists(name string) error + + // Driver returns the underlying driver the session is using. + Driver() interface{} + + Save(mydb.Record) error + + Get(mydb.Record, interface{}) error + + Delete(mydb.Record) error + + // WaitForConnection attempts to run the given connection function a fixed + // number of times before failing. + WaitForConnection(func() error) error + + // BindDB sets the *sql.DB the session will use. + BindDB(*sql.DB) error + + // DB returns the *sql.DB the session is using. + DB() *sql.DB + + // BindTx binds a transaction to the current session. + BindTx(context.Context, *sql.Tx) error + + // Transaction returns the current transaction the session is using. + Transaction() *sql.Tx + + // NewClone clones the database using the given AdapterSession as base. + NewClone(AdapterSession, bool) (Session, error) + + // Context returns the default context the session is using. + Context() context.Context + + // SetContext sets the default context for the session. + SetContext(context.Context) + + NewTransaction(ctx context.Context, opts *sql.TxOptions) (Session, error) + + Tx(fn func(sess mydb.Session) error) error + + TxContext(ctx context.Context, fn func(sess mydb.Session) error, opts *sql.TxOptions) error + + WithContext(context.Context) mydb.Session + + IsTransaction() bool + + Commit() error + + Rollback() error + + mydb.Settings +} + +// NewTx wraps a *sql.Tx and returns a transactional Session. +func NewTx(adapter AdapterSession, tx *sql.Tx) (Session, error) { + sessTx := &sessionWithContext{ + session: &session{ + Settings: mydb.DefaultSettings, + + sqlTx: tx, + adapter: adapter, + cachedPKs: cache.NewCache(), + cachedCollections: cache.NewCache(), + cachedStatements: cache.NewCache(), + }, + ctx: context.Background(), + } + return sessTx, nil +} + +// NewSession creates a new Session.
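+// +// Adapters typically pair NewSession with Open, as sqlAdapterWrapper.OpenDSN +// does (illustrative): +// +//   sess := NewSession(connURL, adapterSession) +//   if err := sess.Open(); err != nil { +//   	// handle the connection error +//   }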
+func NewSession(connURL mydb.ConnectionURL, adapter AdapterSession) Session { + sess := &sessionWithContext{ + session: &session{ + Settings: mydb.DefaultSettings, + + connURL: connURL, + adapter: adapter, + cachedPKs: cache.NewCache(), + cachedCollections: cache.NewCache(), + cachedStatements: cache.NewCache(), + }, + ctx: context.Background(), + } + return sess +} + +type session struct { + mydb.Settings + + adapter AdapterSession + + connURL mydb.ConnectionURL + + builder mydb.SQL + + lookupNameOnce sync.Once + name string + + mu sync.Mutex // guards ctx, txOptions + txOptions *sql.TxOptions + + sqlDBMu sync.Mutex // guards sqlDB, sqlTx + + sqlDB *sql.DB + sqlTx *sql.Tx + + sessID uint64 + txID uint64 + + cacheMu sync.Mutex // guards cachedStatements and cachedCollections + cachedPKs *cache.Cache + cachedStatements *cache.Cache + cachedCollections *cache.Cache + + template *exql.Template +} + +type sessionWithContext struct { + *session + + ctx context.Context +} + +func (sess *sessionWithContext) WithContext(ctx context.Context) mydb.Session { + if ctx == nil { + panic("nil context") + } + newSess := &sessionWithContext{ + session: sess.session, + ctx: ctx, + } + return newSess +} + +func (sess *sessionWithContext) Tx(fn func(sess mydb.Session) error) error { + return TxContext(sess.Context(), sess, fn, nil) +} + +func (sess *sessionWithContext) TxContext(ctx context.Context, fn func(sess mydb.Session) error, opts *sql.TxOptions) error { + return TxContext(ctx, sess, fn, opts) +} + +func (sess *sessionWithContext) SQL() mydb.SQL { + return sqlbuilder.WithSession( + sess, + sess.adapter.Template(), + ) +} + +func (sess *sessionWithContext) Err(errIn error) (errOut error) { + if convertError, ok := sess.adapter.(errorConverter); ok { + return convertError.Err(errIn) + } + return errIn +} + +func (sess *sessionWithContext) PrimaryKeys(tableName string) ([]string, error) { + h := cache.NewHashable(hashTypePrimaryKeys, tableName) + + cachedPK, ok := sess.cachedPKs.ReadRaw(h) + if ok { + return cachedPK.([]string), nil + } + + pk, err := sess.adapter.PrimaryKeys(sess, tableName) + if err != nil { + return nil, err + } + + sess.cachedPKs.Write(h, pk) + return pk, nil +} + +func (sess *sessionWithContext) TableExists(name string) error { + return sess.adapter.TableExists(sess, name) +} + +func (sess *sessionWithContext) NewTransaction(ctx context.Context, opts *sql.TxOptions) (Session, error) { + if ctx == nil { + ctx = context.Background() + } + clone, err := sess.NewClone(sess.adapter, false) + if err != nil { + return nil, err + } + + connFn := func() error { + sqlTx, err := compat.BeginTx(clone.DB(), clone.Context(), opts) + if err == nil { + return clone.BindTx(ctx, sqlTx) + } + return err + } + + if err := clone.WaitForConnection(connFn); err != nil { + return nil, err + } + + return clone, nil +} + +func (sess *sessionWithContext) Collections() ([]mydb.Collection, error) { + names, err := sess.adapter.Collections(sess) + if err != nil { + return nil, err + } + + collections := make([]mydb.Collection, 0, len(names)) + for i := range names { + collections = append(collections, sess.Collection(names[i])) + } + + return collections, nil +} + +func (sess *sessionWithContext) ConnectionURL() mydb.ConnectionURL { + return sess.connURL +} + +func (sess *sessionWithContext) Open() error { + var sqlDB *sql.DB + var err error + + connFn := func() error { + sqlDB, err = sess.adapter.OpenDSN(sess, sess.connURL.String()) + if err != nil { + return err + } + +
sqlDB.SetConnMaxLifetime(sess.ConnMaxLifetime()) + sqlDB.SetConnMaxIdleTime(sess.ConnMaxIdleTime()) + sqlDB.SetMaxIdleConns(sess.MaxIdleConns()) + sqlDB.SetMaxOpenConns(sess.MaxOpenConns()) + return nil + } + + if err := sess.WaitForConnection(connFn); err != nil { + return err + } + + return sess.BindDB(sqlDB) +} + +func (sess *sessionWithContext) Get(record mydb.Record, id interface{}) error { + store := record.Store(sess) + if getter, ok := store.(mydb.StoreGetter); ok { + return getter.Get(record, id) + } + return store.Find(id).One(record) +} + +func (sess *sessionWithContext) Save(record mydb.Record) error { + if record == nil { + return mydb.ErrNilRecord + } + + if reflect.TypeOf(record).Kind() != reflect.Ptr { + return mydb.ErrExpectingPointerToStruct + } + + store := record.Store(sess) + if saver, ok := store.(mydb.StoreSaver); ok { + return saver.Save(record) + } + + id := mydb.Cond{} + keys, values, err := recordPrimaryKeyFieldValues(store, record) + if err != nil { + return err + } + for i := range values { + if values[i] != reflect.Zero(reflect.TypeOf(values[i])).Interface() { + id[keys[i]] = values[i] + } + } + + if len(id) > 0 && len(id) == len(values) { + // Check if the record exists before updating it. + exists, _ := store.Find(id).Count() + if exists > 0 { + return recordUpdate(store, record) + } + } + + return recordCreate(store, record) +} + +func (sess *sessionWithContext) Delete(record mydb.Record) error { + if record == nil { + return mydb.ErrNilRecord + } + + if reflect.TypeOf(record).Kind() != reflect.Ptr { + return mydb.ErrExpectingPointerToStruct + } + + store := record.Store(sess) + + if hook, ok := record.(mydb.BeforeDeleteHook); ok { + if err := hook.BeforeDelete(sess); err != nil { + return err + } + } + + if deleter, ok := store.(mydb.StoreDeleter); ok { + if err := deleter.Delete(record); err != nil { + return err + } + } else { + conds, err := recordID(store, record) + if err != nil { + return err + } + if err := store.Find(conds).Delete(); err != nil { + return err + } + } + + if hook, ok := record.(mydb.AfterDeleteHook); ok { + if err := hook.AfterDelete(sess); err != nil { + return err + } + } + + return nil +} + +func (sess *sessionWithContext) DB() *sql.DB { + return sess.sqlDB +} + +func (sess *sessionWithContext) SetContext(ctx context.Context) { + sess.mu.Lock() + sess.ctx = ctx + sess.mu.Unlock() +} + +func (sess *sessionWithContext) Context() context.Context { + return sess.ctx +} + +func (sess *sessionWithContext) SetTxOptions(txOptions sql.TxOptions) { + sess.mu.Lock() + sess.txOptions = &txOptions + sess.mu.Unlock() +} + +func (sess *sessionWithContext) TxOptions() *sql.TxOptions { + sess.mu.Lock() + defer sess.mu.Unlock() + if sess.txOptions == nil { + return nil + } + return sess.txOptions +} + +func (sess *sessionWithContext) BindTx(ctx context.Context, tx *sql.Tx) error { + sess.sqlDBMu.Lock() + defer sess.sqlDBMu.Unlock() + + sess.sqlTx = tx + sess.SetContext(ctx) + + sess.txID = newBaseTxID() + + return nil +} + +func (sess *sessionWithContext) Commit() error { + if sess.sqlTx != nil { + return sess.sqlTx.Commit() + } + return mydb.ErrNotWithinTransaction +} + +func (sess *sessionWithContext) Rollback() error { + if sess.sqlTx != nil { + return sess.sqlTx.Rollback() + } + return mydb.ErrNotWithinTransaction +} + +func (sess *sessionWithContext) IsTransaction() bool { + return sess.sqlTx != nil +} + +func (sess *sessionWithContext) Transaction() *sql.Tx { + return sess.sqlTx +} + +func (sess *sessionWithContext) Name() string { +
sess.lookupNameOnce.Do(func() { + if sess.name == "" { + sess.name, _ = sess.adapter.LookupName(sess) + } + }) + + return sess.name +} + +func (sess *sessionWithContext) BindDB(sqlDB *sql.DB) error { + sess.sqlDBMu.Lock() + sess.sqlDB = sqlDB + sess.sqlDBMu.Unlock() + + if err := sess.Ping(); err != nil { + return err + } + + sess.sessID = newSessionID() + name, err := sess.adapter.LookupName(sess) + if err != nil { + return err + } + sess.name = name + + return nil +} + +func (sess *sessionWithContext) Ping() error { + if sess.sqlDB != nil { + return sess.sqlDB.Ping() + } + return mydb.ErrNotConnected +} + +func (sess *sessionWithContext) SetConnMaxLifetime(t time.Duration) { + sess.Settings.SetConnMaxLifetime(t) + if sessDB := sess.DB(); sessDB != nil { + sessDB.SetConnMaxLifetime(sess.Settings.ConnMaxLifetime()) + } +} + +func (sess *sessionWithContext) SetConnMaxIdleTime(t time.Duration) { + sess.Settings.SetConnMaxIdleTime(t) + if sessDB := sess.DB(); sessDB != nil { + sessDB.SetConnMaxIdleTime(sess.Settings.ConnMaxIdleTime()) + } +} + +func (sess *sessionWithContext) SetMaxIdleConns(n int) { + sess.Settings.SetMaxIdleConns(n) + if sessDB := sess.DB(); sessDB != nil { + sessDB.SetMaxIdleConns(sess.Settings.MaxIdleConns()) + } +} + +func (sess *sessionWithContext) SetMaxOpenConns(n int) { + sess.Settings.SetMaxOpenConns(n) + if sessDB := sess.DB(); sessDB != nil { + sessDB.SetMaxOpenConns(sess.Settings.MaxOpenConns()) + } +} + +// Reset removes all caches. +func (sess *sessionWithContext) Reset() { + sess.cacheMu.Lock() + defer sess.cacheMu.Unlock() + + sess.cachedPKs.Clear() + sess.cachedCollections.Clear() + sess.cachedStatements.Clear() + + if sess.template != nil { + sess.template.Cache.Clear() + } +} + +func (sess *sessionWithContext) NewClone(adapter AdapterSession, checkConn bool) (Session, error) { + newSess := NewSession(sess.connURL, adapter).(*sessionWithContext) + + newSess.name = sess.name + newSess.sqlDB = sess.sqlDB + newSess.cachedPKs = sess.cachedPKs + + if checkConn { + if err := newSess.Ping(); err != nil { + // Retry once if ping fails. + return sess.NewClone(adapter, false) + } + } + + newSess.sessID = newSessionID() + + // A new transaction should inherit the parent settings. + copySettings(sess, newSess) + + return newSess, nil +} + +func (sess *sessionWithContext) Close() error { + defer func() { + sess.sqlDBMu.Lock() + sess.sqlDB = nil + sess.sqlTx = nil + sess.sqlDBMu.Unlock() + }() + + if sess.sqlDB == nil { + return nil + } + + sess.cachedCollections.Clear() + sess.cachedStatements.Clear() // Closes prepared statements as well. + + if !sess.IsTransaction() { + if cleaner, ok := sess.adapter.(hasCleanUp); ok { + if err := cleaner.CleanUp(); err != nil { + return err + } + } + // Not within a transaction.
+ return sess.sqlDB.Close() + } + + return nil +} + +func (sess *sessionWithContext) Collection(name string) mydb.Collection { + sess.cacheMu.Lock() + defer sess.cacheMu.Unlock() + + h := cache.NewHashable(hashTypeCollection, name) + + col, ok := sess.cachedCollections.ReadRaw(h) + if !ok { + col = newCollection(name, sess.adapter.NewCollection()) + sess.cachedCollections.Write(h, col) + } + + return &collectionWithSession{ + collection: col.(*collection), + session: sess, + } +} + +func queryLog(status *mydb.QueryStatus) { + diff := status.End.Sub(status.Start) + + slowQuery := false + if diff >= slowQueryThreshold { + status.Err = mydb.ErrWarnSlowQuery + slowQuery = true + } + + if status.Err != nil || slowQuery { + mydb.LC().Warn(status) + return + } + + mydb.LC().Debug(status) +} + +func (sess *sessionWithContext) StatementPrepare(ctx context.Context, stmt *exql.Statement) (sqlStmt *sql.Stmt, err error) { + var query string + + defer func(start time.Time) { + queryLog(&mydb.QueryStatus{ + TxID: sess.txID, + SessID: sess.sessID, + RawQuery: query, + Err: err, + Start: start, + End: time.Now(), + Context: ctx, + }) + }(time.Now()) + + query, _, err = sess.compileStatement(stmt, nil) + if err != nil { + return nil, err + } + + tx := sess.Transaction() + if tx != nil { + sqlStmt, err = compat.PrepareContext(tx, ctx, query) + return + } + + sqlStmt, err = compat.PrepareContext(sess.sqlDB, ctx, query) + return +} + +func (sess *sessionWithContext) ConvertValue(value interface{}) interface{} { + if scannerValuer, ok := value.(sqlbuilder.ScannerValuer); ok { + return scannerValuer + } + + dv := reflect.Indirect(reflect.ValueOf(value)) + if dv.IsValid() { + if converter, ok := dv.Interface().(valueConverter); ok { + return converter.ConvertValue(dv.Interface()) + } + } + + if converter, ok := sess.adapter.(sessValueConverterContext); ok { + return converter.ConvertValueContext(sess.Context(), value) + } + + if converter, ok := sess.adapter.(sessValueConverter); ok { + return converter.ConvertValue(value) + } + + return value +} + +func (sess *sessionWithContext) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (res sql.Result, err error) { + var query string + + defer func(start time.Time) { + status := mydb.QueryStatus{ + TxID: sess.txID, + SessID: sess.sessID, + RawQuery: query, + Args: args, + Err: err, + Start: start, + End: time.Now(), + Context: ctx, + } + + if res != nil { + if rowsAffected, err := res.RowsAffected(); err == nil { + status.RowsAffected = &rowsAffected + } + + if lastInsertID, err := res.LastInsertId(); err == nil { + status.LastInsertID = &lastInsertID + } + } + + queryLog(&status) + }(time.Now()) + + if execer, ok := sess.adapter.(statementExecer); ok { + query, args, err = sess.compileStatement(stmt, args) + if err != nil { + return nil, err + } + res, err = execer.StatementExec(sess, ctx, query, args...) + return + } + + tx := sess.Transaction() + if sess.Settings.PreparedStatementCacheEnabled() && tx == nil { + var p *Stmt + if p, query, args, err = sess.prepareStatement(ctx, stmt, args); err != nil { + return nil, err + } + defer p.Close() + + res, err = compat.PreparedExecContext(p, ctx, args) + return + } + + query, args, err = sess.compileStatement(stmt, args) + if err != nil { + return nil, err + } + + if tx != nil { + res, err = compat.ExecContext(tx, ctx, query, args) + return + } + + res, err = compat.ExecContext(sess.sqlDB, ctx, query, args) + return +} + +// StatementQuery compiles and executes a statement that returns rows.
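+// +// When the session's prepared-statement cache is enabled and no transaction +// is active, the compiled query is prepared once and reused on later calls +// (a sketch; SetPreparedStatementCache comes from mydb.Settings): +// +//   sess.SetPreparedStatementCache(true) +//   rows, err := sess.StatementQuery(ctx, stmt) // prepared, cached, reused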
+func (sess *sessionWithContext) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (rows *sql.Rows, err error) { + var query string + + defer func(start time.Time) { + status := mydb.QueryStatus{ + TxID: sess.txID, + SessID: sess.sessID, + RawQuery: query, + Args: args, + Err: err, + Start: start, + End: time.Now(), + Context: ctx, + } + queryLog(&status) + }(time.Now()) + + tx := sess.Transaction() + + if sess.Settings.PreparedStatementCacheEnabled() && tx == nil { + var p *Stmt + if p, query, args, err = sess.prepareStatement(ctx, stmt, args); err != nil { + return nil, err + } + defer p.Close() + + rows, err = compat.PreparedQueryContext(p, ctx, args) + return + } + + query, args, err = sess.compileStatement(stmt, args) + if err != nil { + return nil, err + } + if tx != nil { + rows, err = compat.QueryContext(tx, ctx, query, args) + return + } + + rows, err = compat.QueryContext(sess.sqlDB, ctx, query, args) + return +} + +// StatementQueryRow compiles and executes a statement that returns at most one +// row. +func (sess *sessionWithContext) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (row *sql.Row, err error) { + var query string + + defer func(start time.Time) { + status := mydb.QueryStatus{ + TxID: sess.txID, + SessID: sess.sessID, + RawQuery: query, + Args: args, + Err: err, + Start: start, + End: time.Now(), + Context: ctx, + } + queryLog(&status) + }(time.Now()) + + tx := sess.Transaction() + + if sess.Settings.PreparedStatementCacheEnabled() && tx == nil { + var p *Stmt + if p, query, args, err = sess.prepareStatement(ctx, stmt, args); err != nil { + return nil, err + } + defer p.Close() + + row = compat.PreparedQueryRowContext(p, ctx, args) + return + } + + query, args, err = sess.compileStatement(stmt, args) + if err != nil { + return nil, err + } + if tx != nil { + row = compat.QueryRowContext(tx, ctx, query, args) + return + } + + row = compat.QueryRowContext(sess.sqlDB, ctx, query, args) + return +} + +// Driver returns the underlying *sql.DB or *sql.Tx instance. +func (sess *sessionWithContext) Driver() interface{} { + if sess.sqlTx != nil { + return sess.sqlTx + } + return sess.sqlDB +} + +// compileStatement compiles the given statement into a string. +func (sess *sessionWithContext) compileStatement(stmt *exql.Statement, args []interface{}) (string, []interface{}, error) { + for i := range args { + args[i] = sess.ConvertValue(args[i]) + } + if statementCompiler, ok := sess.adapter.(statementCompiler); ok { + return statementCompiler.CompileStatement(sess, stmt, args) + } + + compiled, err := stmt.Compile(sess.adapter.Template()) + if err != nil { + return "", nil, err + } + query, args := sqlbuilder.Preprocess(compiled, args) + return query, args, nil +} + +// prepareStatement compiles a query and tries to reuse a previously prepared +// statement. +func (sess *sessionWithContext) prepareStatement(ctx context.Context, stmt *exql.Statement, args []interface{}) (*Stmt, string, []interface{}, error) { + sess.sqlDBMu.Lock() + defer sess.sqlDBMu.Unlock() + + sqlDB, tx := sess.sqlDB, sess.Transaction() + if sqlDB == nil && tx == nil { + return nil, "", nil, mydb.ErrNotConnected + } + + pc, ok := sess.cachedStatements.ReadRaw(stmt) + if ok { + // The statement was cached.
+ ps, err := pc.(*Stmt).Open() + if err == nil { + _, args, err = sess.compileStatement(stmt, args) + if err != nil { + return nil, "", nil, err + } + return ps, ps.query, args, nil + } + } + + query, args, err := sess.compileStatement(stmt, args) + if err != nil { + return nil, "", nil, err + } + sqlStmt, err := func(query *string) (*sql.Stmt, error) { + if tx != nil { + return compat.PrepareContext(tx, ctx, *query) + } + return compat.PrepareContext(sess.sqlDB, ctx, *query) + }(&query) + if err != nil { + return nil, "", nil, err + } + + p, err := NewStatement(sqlStmt, query).Open() + if err != nil { + return nil, query, args, err + } + sess.cachedStatements.Write(stmt, p) + return p, p.query, args, nil +} + +var waitForConnMu sync.Mutex + +// WaitForConnection tries to execute the given connectFn function; if +// connectFn returns an error, WaitForConnection keeps trying until connectFn +// returns nil. The maximum waiting time is 5s after acquiring the lock. +func (sess *sessionWithContext) WaitForConnection(connectFn func() error) error { + // This lock ensures first-come, first-served and prevents opening too many + // file descriptors. + waitForConnMu.Lock() + defer waitForConnMu.Unlock() + + // Minimum waiting time. + waitTime := time.Millisecond * 10 + + // Wait up to 5 seconds for a successful connection. + for timeStart := time.Now(); time.Since(timeStart) < time.Second*5; { + err := connectFn() + if err == nil { + return nil // Connected! + } + + // Only attempt to reconnect if the error is too many clients. + if sess.Err(err) == mydb.ErrTooManyClients { + // Sleep and try again if, and only if, the server replied with a "too + // many clients" error. + time.Sleep(waitTime) + if waitTime < time.Millisecond*500 { + // Wait a bit more next time. + waitTime = waitTime * 2 + } + continue + } + + // Return any other error immediately. + return err + } + + return mydb.ErrGivingUpTryingToConnect +} + +// ReplaceWithDollarSign turns a SQL statement with '?' placeholders into +// dollar placeholders, like $1, $2, ..., $n. +func ReplaceWithDollarSign(buf []byte) []byte { + z := bytes.Count(buf, []byte{'?'}) + // The capacity is a quick estimation of the total memory required; this + // reduces reallocations. + out := make([]byte, 0, len(buf)+z*3) + + var i, k = 0, 1 + for i < len(buf) { + if buf[i] == '?' { + out = append(out, buf[:i]...) + buf = buf[i+1:] + i = 0 + + if len(buf) > 0 && buf[0] == '?' { + out = append(out, '?') + buf = buf[1:] + continue + } + + out = append(out, '$') + out = append(out, []byte(strconv.Itoa(k))...) + k = k + 1 + continue + } + i = i + 1 + } + + out = append(out, buf...) + + return out +} + +func copySettings(from Session, into Session) { + into.SetPreparedStatementCache(from.PreparedStatementCacheEnabled()) + into.SetConnMaxLifetime(from.ConnMaxLifetime()) + into.SetConnMaxIdleTime(from.ConnMaxIdleTime()) + into.SetMaxIdleConns(from.MaxIdleConns()) + into.SetMaxOpenConns(from.MaxOpenConns()) +} + +func newSessionID() uint64 { + if atomic.LoadUint64(&lastSessID) == math.MaxUint64 { + atomic.StoreUint64(&lastSessID, 0) + return 0 + } + return atomic.AddUint64(&lastSessID, 1) +} + +func newBaseTxID() uint64 { + if atomic.LoadUint64(&lastTxID) == math.MaxUint64 { + atomic.StoreUint64(&lastTxID, 0) + return 0 + } + return atomic.AddUint64(&lastTxID, 1) +} + +// TxContext creates a transaction context and runs fn within it.
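+// +// A minimal usage sketch (illustrative; assumes a connected mydb.Session +// named sess): +// +//   err := sess.Tx(func(tx mydb.Session) error { +//   	_, err := tx.SQL().Exec(`DELETE FROM "logs"`) +//   	return err +//   }) // aborted transactions are retried up to MaxTransactionRetries times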
+func TxContext(ctx context.Context, sess mydb.Session, fn func(tx mydb.Session) error, opts *sql.TxOptions) error { + txFn := func(sess mydb.Session) error { + tx, err := sess.(Session).NewTransaction(ctx, opts) + if err != nil { + return err + } + defer tx.Close() + + if err := fn(tx); err != nil { + if rollbackErr := tx.Rollback(); rollbackErr != nil { + return fmt.Errorf("%v: %w", rollbackErr, err) + } + return err + } + return tx.Commit() + } + + retryTime := retryTransactionWaitTime + + var txErr error + for i := 0; i < sess.MaxTransactionRetries(); i++ { + txErr = sess.(*sessionWithContext).Err(txFn(sess)) + if txErr == nil { + return nil + } + if errors.Is(txErr, mydb.ErrTransactionAborted) { + time.Sleep(retryTime) + + retryTime = retryTime * 2 + if retryTime > retryTransactionMaxWaitTime { + retryTime = retryTransactionMaxWaitTime + } + + continue + } + return txErr + } + + return fmt.Errorf("db: giving up trying to commit transaction: %w", txErr) +} + +var _ = mydb.Session(&sessionWithContext{}) diff --git a/internal/sqladapter/sqladapter.go b/internal/sqladapter/sqladapter.go new file mode 100644 index 0000000..a567e2e --- /dev/null +++ b/internal/sqladapter/sqladapter.go @@ -0,0 +1,62 @@ +// Package sqladapter provides common logic for SQL adapters. +package sqladapter + +import ( + "database/sql" + "database/sql/driver" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" +) + +// IsKeyValue reports whether v is a valid value for a primary key that can be +// used with Find(pKey). +func IsKeyValue(v interface{}) bool { + if v == nil { + return true + } + switch v.(type) { + case int64, int, uint, uint64, + []int64, []int, []uint, []uint64, + []byte, []string, + []interface{}, + driver.Valuer: + return true + } + return false +} + +type sqlAdapterWrapper struct { + adapter AdapterSession +} + +func (w *sqlAdapterWrapper) OpenDSN(dsn mydb.ConnectionURL) (mydb.Session, error) { + sess := NewSession(dsn, w.adapter) + if err := sess.Open(); err != nil { + return nil, err + } + return sess, nil +} + +func (w *sqlAdapterWrapper) NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) { + tx, err := NewTx(w.adapter, sqlTx) + if err != nil { + return nil, err + } + return tx, nil +} + +func (w *sqlAdapterWrapper) New(sqlDB *sql.DB) (mydb.Session, error) { + sess := NewSession(nil, w.adapter) + if err := sess.BindDB(sqlDB); err != nil { + return nil, err + } + return sess, nil +} + +// RegisterAdapter registers a new SQL adapter. +func RegisterAdapter(name string, adapter AdapterSession) sqlbuilder.Adapter { + z := &sqlAdapterWrapper{adapter} + mydb.RegisterAdapter(name, sqlbuilder.NewCompatAdapter(z)) + return z +} diff --git a/internal/sqladapter/sqladapter_test.go b/internal/sqladapter/sqladapter_test.go new file mode 100644 index 0000000..99240d0 --- /dev/null +++ b/internal/sqladapter/sqladapter_test.go @@ -0,0 +1,45 @@ +package sqladapter + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb" + "github.com/stretchr/testify/assert" +) + +var ( + _ mydb.Collection = &collectionWithSession{} + _ Collection = &collectionWithSession{} +) + +func TestReplaceWithDollarSign(t *testing.T) { + tests := []struct { + in string + out string + }{ + { + `SELECT ?`, + `SELECT $1`, + }, + { + `SELECT ? FROM ? WHERE ?`, + `SELECT $1 FROM $2 WHERE $3`, + }, + { + `SELECT ?? FROM ? WHERE ??`, + `SELECT ? FROM $1 WHERE ?`, + }, + { + `SELECT ??? FROM ? WHERE ??`, + `SELECT ?$1 FROM $2 WHERE ?`, + }, + { + `SELECT ??? FROM ? 
WHERE ????`, + `SELECT ?$1 FROM $2 WHERE ??`, + }, + } + + for _, test := range tests { + assert.Equal(t, []byte(test.out), ReplaceWithDollarSign([]byte(test.in))) + } +} diff --git a/internal/sqladapter/statement.go b/internal/sqladapter/statement.go new file mode 100644 index 0000000..0b18ebd --- /dev/null +++ b/internal/sqladapter/statement.go @@ -0,0 +1,85 @@ +package sqladapter + +import ( + "database/sql" + "errors" + "sync" + "sync/atomic" +) + +var ( + activeStatements int64 +) + +// Stmt represents a *sql.Stmt that is cached and provides the +// OnEvict method to allow it to clean up after itself. +type Stmt struct { + *sql.Stmt + + query string + mu sync.Mutex + + count int64 + dead bool +} + +// NewStatement creates and returns an opened statement. +func NewStatement(stmt *sql.Stmt, query string) *Stmt { + s := &Stmt{ + Stmt: stmt, + query: query, + } + atomic.AddInt64(&activeStatements, 1) + return s +} + +// Open marks the statement as in-use. +func (c *Stmt) Open() (*Stmt, error) { + c.mu.Lock() + defer c.mu.Unlock() + + if c.dead { + return nil, errors.New("statement is dead") + } + + c.count++ + return c, nil +} + +// Close closes the underlying statement if no other goroutine is using it. +func (c *Stmt) Close() error { + c.mu.Lock() + defer c.mu.Unlock() + + c.count-- + + return c.checkClose() +} + +func (c *Stmt) checkClose() error { + if c.dead && c.count == 0 { + // Statement is dead and we can close it for real. + err := c.Stmt.Close() + if err != nil { + return err + } + // Reduce active statements counter. + atomic.AddInt64(&activeStatements, -1) + } + return nil +} + +// OnEvict marks the statement as ready to be cleaned up. +func (c *Stmt) OnEvict() { + c.mu.Lock() + defer c.mu.Unlock() + + c.dead = true + c.checkClose() +} + +// NumActiveStatements returns the global number of prepared statements in use +// at any point. +func NumActiveStatements() int64 { + return atomic.LoadInt64(&activeStatements) +} diff --git a/internal/sqlbuilder/batch.go b/internal/sqlbuilder/batch.go new file mode 100644 index 0000000..c988e04 --- /dev/null +++ b/internal/sqlbuilder/batch.go @@ -0,0 +1,84 @@ +package sqlbuilder + +import "git.hexq.cn/tiglog/mydb" + +// BatchInserter provides a helper that can be used to do bulk insertions in +// batches. +type BatchInserter struct { + inserter *inserter + size int + values chan []interface{} + err error +} + +func newBatchInserter(inserter *inserter, size int) *BatchInserter { + if size < 1 { + size = 1 + } + b := &BatchInserter{ + inserter: inserter, + size: size, + values: make(chan []interface{}, size), + } + return b +} + +// Values pushes column values to be inserted as part of the batch. +func (b *BatchInserter) Values(values ...interface{}) mydb.BatchInserter { + b.values <- values + return b +} + +func (b *BatchInserter) nextQuery() *inserter { + ins := &inserter{} + *ins = *b.inserter + i := 0 + for values := range b.values { + i++ + ins = ins.Values(values...).(*inserter) + if i == b.size { + break + } + } + if i == 0 { + return nil + } + return ins +} + +// NextResult is useful when using PostgreSQL and Returning(); it dumps the +// next slice of results to dst, which can mean having the IDs of all inserted +// elements in the batch. +func (b *BatchInserter) NextResult(dst interface{}) bool { + clone := b.nextQuery() + if clone == nil { + return false + } + b.err = clone.Iterator().All(dst) + return (b.err == nil) +} + +// Done means that no more elements are going to be added.
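+// +// A typical producer/consumer flow (a sketch; assumes the public Inserter +// exposes a Batch method, as in the wrapper API): +// +//   batch := sess.SQL().InsertInto("data").Columns("k", "v").Batch(50) +//   go func() { +//   	defer batch.Done() +//   	for _, it := range items { +//   		batch.Values(it.K, it.V) +//   	} +//   }() +//   err := batch.Wait()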
+func (b *BatchInserter) Done() { + close(b.values) +} + +// Wait blocks until the whole batch is executed. +func (b *BatchInserter) Wait() error { + for { + q := b.nextQuery() + if q == nil { + break + } + if _, err := q.Exec(); err != nil { + b.err = err + break + } + } + return b.Err() +} + +// Err returns any error while executing the batch. +func (b *BatchInserter) Err() error { + return b.err +} diff --git a/internal/sqlbuilder/builder.go b/internal/sqlbuilder/builder.go new file mode 100644 index 0000000..a90ac74 --- /dev/null +++ b/internal/sqlbuilder/builder.go @@ -0,0 +1,611 @@ +// Package sqlbuilder provides tools for building custom SQL queries. +package sqlbuilder + +import ( + "context" + "database/sql" + "errors" + "fmt" + "reflect" + "sort" + "strconv" + "strings" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/reflectx" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/compat" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +// MapOptions represents options for the mapper. +type MapOptions struct { + IncludeZeroed bool + IncludeNil bool +} + +var defaultMapOptions = MapOptions{ + IncludeZeroed: false, + IncludeNil: false, +} + +type hasPaginator interface { + Paginator() (mydb.Paginator, error) +} + +type isCompilable interface { + Compile() (string, error) + Arguments() []interface{} +} + +type hasIsZero interface { + IsZero() bool +} + +type iterator struct { + sess exprDB + cursor *sql.Rows // This is the main query cursor. It starts as a nil value. + err error +} + +type fieldValue struct { + fields []string + values []interface{} +} + +var ( + sqlPlaceholder = &exql.Raw{Value: `?`} +) + +var ( + errDeprecatedJSONBTag = errors.New(`Tag "jsonb" is deprecated. See "PostgreSQL: jsonb tag" at https://github.com/upper/db/releases/tag/v3.4.0`) +) + +type exprDB interface { + StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error) + StatementPrepare(ctx context.Context, stmt *exql.Statement) (*sql.Stmt, error) + StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) + StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error) + + Context() context.Context +} + +type sqlBuilder struct { + sess exprDB + t *templateWithUtils +} + +// WithSession returns a query builder that is bound to the given database session. +func WithSession(sess interface{}, t *exql.Template) mydb.SQL { + if sqlDB, ok := sess.(*sql.DB); ok { + sess = sqlDB + } + return &sqlBuilder{ + sess: sess.(exprDB), // Let it panic, it will show the developer an informative error. + t: newTemplateWithUtils(t), + } +} + +// WithTemplate returns a builder that is based on the given template. +func WithTemplate(t *exql.Template) mydb.SQL { + return &sqlBuilder{ + t: newTemplateWithUtils(t), + } +} + +func (b *sqlBuilder) NewIteratorContext(ctx context.Context, rows *sql.Rows) mydb.Iterator { + return &iterator{b.sess, rows, nil} +} + +func (b *sqlBuilder) NewIterator(rows *sql.Rows) mydb.Iterator { + return b.NewIteratorContext(b.sess.Context(), rows) +} + +func (b *sqlBuilder) Iterator(query interface{}, args ...interface{}) mydb.Iterator { + return b.IteratorContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) IteratorContext(ctx context.Context, query interface{}, args ...interface{}) mydb.Iterator { + rows, err := b.QueryContext(ctx, query, args...) 
+ return &iterator{b.sess, rows, err} +} + +func (b *sqlBuilder) Prepare(query interface{}) (*sql.Stmt, error) { + return b.PrepareContext(b.sess.Context(), query) +} + +func (b *sqlBuilder) PrepareContext(ctx context.Context, query interface{}) (*sql.Stmt, error) { + switch q := query.(type) { + case *exql.Statement: + return b.sess.StatementPrepare(ctx, q) + case string: + return b.sess.StatementPrepare(ctx, exql.RawSQL(q)) + case *adapter.RawExpr: + return b.PrepareContext(ctx, q.Raw()) + default: + return nil, fmt.Errorf("unsupported query type %T", query) + } +} + +func (b *sqlBuilder) Exec(query interface{}, args ...interface{}) (sql.Result, error) { + return b.ExecContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error) { + switch q := query.(type) { + case *exql.Statement: + return b.sess.StatementExec(ctx, q, args...) + case string: + return b.sess.StatementExec(ctx, exql.RawSQL(q), args...) + case *adapter.RawExpr: + return b.ExecContext(ctx, q.Raw(), q.Arguments()...) + default: + return nil, fmt.Errorf("unsupported query type %T", query) + } +} + +func (b *sqlBuilder) Query(query interface{}, args ...interface{}) (*sql.Rows, error) { + return b.QueryContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) QueryContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Rows, error) { + switch q := query.(type) { + case *exql.Statement: + return b.sess.StatementQuery(ctx, q, args...) + case string: + return b.sess.StatementQuery(ctx, exql.RawSQL(q), args...) + case *adapter.RawExpr: + return b.QueryContext(ctx, q.Raw(), q.Arguments()...) + default: + return nil, fmt.Errorf("unsupported query type %T", query) + } +} + +func (b *sqlBuilder) QueryRow(query interface{}, args ...interface{}) (*sql.Row, error) { + return b.QueryRowContext(b.sess.Context(), query, args...) +} + +func (b *sqlBuilder) QueryRowContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Row, error) { + switch q := query.(type) { + case *exql.Statement: + return b.sess.StatementQueryRow(ctx, q, args...) + case string: + return b.sess.StatementQueryRow(ctx, exql.RawSQL(q), args...) + case *adapter.RawExpr: + return b.QueryRowContext(ctx, q.Raw(), q.Arguments()...) + default: + return nil, fmt.Errorf("unsupported query type %T", query) + } +} + +func (b *sqlBuilder) SelectFrom(table ...interface{}) mydb.Selector { + qs := &selector{ + builder: b, + } + return qs.From(table...) +} + +func (b *sqlBuilder) Select(columns ...interface{}) mydb.Selector { + qs := &selector{ + builder: b, + } + return qs.Columns(columns...) +} + +func (b *sqlBuilder) InsertInto(table string) mydb.Inserter { + qi := &inserter{ + builder: b, + } + return qi.Into(table) +} + +func (b *sqlBuilder) DeleteFrom(table string) mydb.Deleter { + qd := &deleter{ + builder: b, + } + return qd.setTable(table) +} + +func (b *sqlBuilder) Update(table string) mydb.Updater { + qu := &updater{ + builder: b, + } + return qu.setTable(table) +} + +// Map receives a pointer to map or struct and maps it to columns and values. +func Map(item interface{}, options *MapOptions) ([]string, []interface{}, error) { + var fv fieldValue + if options == nil { + options = &defaultMapOptions + } + + itemV := reflect.ValueOf(item) + if !itemV.IsValid() { + return nil, nil, nil + } + + itemT := itemV.Type() + + if itemT.Kind() == reflect.Ptr { + // Single dereference. 
Just in case the user passes a pointer to struct + // instead of a struct. + item = itemV.Elem().Interface() + itemV = reflect.ValueOf(item) + itemT = itemV.Type() + } + + switch itemT.Kind() { + case reflect.Struct: + fieldMap := Mapper.TypeMap(itemT).Names + nfields := len(fieldMap) + + fv.values = make([]interface{}, 0, nfields) + fv.fields = make([]string, 0, nfields) + + for _, fi := range fieldMap { + + // Check for deprecated JSONB tag + if _, hasJSONBTag := fi.Options["jsonb"]; hasJSONBTag { + return nil, nil, errDeprecatedJSONBTag + } + + // Field options + _, tagOmitEmpty := fi.Options["omitempty"] + + fld := reflectx.FieldByIndexesReadOnly(itemV, fi.Index) + if fld.Kind() == reflect.Ptr && fld.IsNil() { + if tagOmitEmpty && !options.IncludeNil { + continue + } + fv.fields = append(fv.fields, fi.Name) + if tagOmitEmpty { + fv.values = append(fv.values, sqlDefault) + } else { + fv.values = append(fv.values, nil) + } + continue + } + + value := fld.Interface() + + isZero := false + if t, ok := fld.Interface().(hasIsZero); ok { + if t.IsZero() { + isZero = true + } + } else if fld.Kind() == reflect.Array || fld.Kind() == reflect.Slice { + if fld.Len() == 0 { + isZero = true + } + } else if reflect.DeepEqual(fi.Zero.Interface(), value) { + isZero = true + } + + if isZero && tagOmitEmpty && !options.IncludeZeroed { + continue + } + + fv.fields = append(fv.fields, fi.Name) + v, err := marshal(value) + if err != nil { + return nil, nil, err + } + if isZero && tagOmitEmpty { + v = sqlDefault + } + fv.values = append(fv.values, v) + } + + case reflect.Map: + nfields := itemV.Len() + fv.values = make([]interface{}, nfields) + fv.fields = make([]string, nfields) + mkeys := itemV.MapKeys() + + for i, keyV := range mkeys { + valv := itemV.MapIndex(keyV) + fv.fields[i] = fmt.Sprintf("%v", keyV.Interface()) + + v, err := marshal(valv.Interface()) + if err != nil { + return nil, nil, err + } + + fv.values[i] = v + } + default: + return nil, nil, ErrExpectingPointerToEitherMapOrStruct + } + + sort.Sort(&fv) + + return fv.fields, fv.values, nil +} + +func columnFragments(columns []interface{}) ([]exql.Fragment, []interface{}, error) { + f := make([]exql.Fragment, len(columns)) + args := []interface{}{} + + for i := range columns { + switch v := columns[i].(type) { + case hasPaginator: + p, err := v.Paginator() + if err != nil { + return nil, nil, err + } + + q, a := Preprocess(p.String(), p.Arguments()) + + f[i] = &exql.Raw{Value: "(" + q + ")"} + args = append(args, a...) + case isCompilable: + c, err := v.Compile() + if err != nil { + return nil, nil, err + } + q, a := Preprocess(c, v.Arguments()) + if _, ok := v.(mydb.Selector); ok { + q = "(" + q + ")" + } + f[i] = &exql.Raw{Value: q} + args = append(args, a...) + case *adapter.FuncExpr: + fnName, fnArgs := v.Name(), v.Arguments() + if len(fnArgs) == 0 { + fnName = fnName + "()" + } else { + fnName = fnName + "(?" + strings.Repeat(", ?", len(fnArgs)-1) + ")" + } + fnName, fnArgs = Preprocess(fnName, fnArgs) + f[i] = &exql.Raw{Value: fnName} + args = append(args, fnArgs...) + case *adapter.RawExpr: + q, a := Preprocess(v.Raw(), v.Arguments()) + f[i] = &exql.Raw{Value: q} + args = append(args, a...) 
+ case exql.Fragment: + f[i] = v + case string: + f[i] = exql.ColumnWithName(v) + case fmt.Stringer: + f[i] = exql.ColumnWithName(v.String()) + default: + var err error + f[i], err = exql.NewRawValue(columns[i]) + if err != nil { + return nil, nil, fmt.Errorf("unexpected argument type %T for Select() argument: %w", v, err) + } + } + } + return f, args, nil +} + +func prepareQueryForDisplay(in string) string { + out := make([]byte, 0, len(in)) + + offset := 0 + whitespace := true + placeholders := 1 + + for i := 0; i < len(in); i++ { + if in[i] == ' ' || in[i] == '\r' || in[i] == '\n' || in[i] == '\t' { + if whitespace { + offset = i + } else { + whitespace = true + out = append(out, in[offset:i]...) + offset = i + } + continue + } + if whitespace { + whitespace = false + if len(out) > 0 { + out = append(out, ' ') + } + offset = i + } + if in[i] == '?' { + out = append(out, in[offset:i]...) + offset = i + 1 + + out = append(out, '$') + out = append(out, strconv.Itoa(placeholders)...) + placeholders++ + } + } + if !whitespace { + out = append(out, in[offset:len(in)]...) + } + return string(out) +} + +func (iter *iterator) NextScan(dst ...interface{}) error { + if ok := iter.Next(); ok { + return iter.Scan(dst...) + } + if err := iter.Err(); err != nil { + return err + } + return mydb.ErrNoMoreRows +} + +func (iter *iterator) ScanOne(dst ...interface{}) error { + defer iter.Close() + return iter.NextScan(dst...) +} + +func (iter *iterator) Scan(dst ...interface{}) error { + if err := iter.Err(); err != nil { + return err + } + return iter.cursor.Scan(dst...) +} + +func (iter *iterator) setErr(err error) error { + iter.err = err + return iter.err +} + +func (iter *iterator) One(dst interface{}) error { + if err := iter.Err(); err != nil { + return err + } + defer iter.Close() + return iter.setErr(iter.next(dst)) +} + +func (iter *iterator) All(dst interface{}) error { + if err := iter.Err(); err != nil { + return err + } + defer iter.Close() + + // Fetching all results within the cursor. + if err := fetchRows(iter, dst); err != nil { + return iter.setErr(err) + } + + return nil +} + +func (iter *iterator) Err() (err error) { + return iter.err +} + +func (iter *iterator) Next(dst ...interface{}) bool { + if err := iter.Err(); err != nil { + return false + } + + if err := iter.next(dst...); err != nil { + // ignore mydb.ErrNoMoreRows, just break. 
+ if !errors.Is(err, mydb.ErrNoMoreRows) { + _ = iter.setErr(err) + } + return false + } + + return true +} + +func (iter *iterator) next(dst ...interface{}) error { + if iter.cursor == nil { + return iter.setErr(mydb.ErrNoMoreRows) + } + + switch len(dst) { + case 0: + if ok := iter.cursor.Next(); !ok { + defer iter.Close() + err := iter.cursor.Err() + if err == nil { + err = mydb.ErrNoMoreRows + } + return err + } + return nil + case 1: + if err := fetchRow(iter, dst[0]); err != nil { + defer iter.Close() + return err + } + return nil + } + + return errors.New("Next does not currently support more than one parameter") +} + +func (iter *iterator) Close() (err error) { + if iter.cursor != nil { + err = iter.cursor.Close() + iter.cursor = nil + } + return err +} + +func marshal(v interface{}) (interface{}, error) { + if m, isMarshaler := v.(mydb.Marshaler); isMarshaler { + var err error + if v, err = m.MarshalDB(); err != nil { + return nil, err + } + } + return v, nil +} + +func (fv *fieldValue) Len() int { + return len(fv.fields) +} + +func (fv *fieldValue) Swap(i, j int) { + fv.fields[i], fv.fields[j] = fv.fields[j], fv.fields[i] + fv.values[i], fv.values[j] = fv.values[j], fv.values[i] +} + +func (fv *fieldValue) Less(i, j int) bool { + return fv.fields[i] < fv.fields[j] +} + +type exprProxy struct { + db *sql.DB + t *exql.Template +} + +func (p *exprProxy) Context() context.Context { + return context.Background() +} + +func (p *exprProxy) StatementExec(ctx context.Context, stmt *exql.Statement, args ...interface{}) (sql.Result, error) { + s, err := stmt.Compile(p.t) + if err != nil { + return nil, err + } + return compat.ExecContext(p.db, ctx, s, args) +} + +func (p *exprProxy) StatementPrepare(ctx context.Context, stmt *exql.Statement) (*sql.Stmt, error) { + s, err := stmt.Compile(p.t) + if err != nil { + return nil, err + } + return compat.PrepareContext(p.db, ctx, s) +} + +func (p *exprProxy) StatementQuery(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Rows, error) { + s, err := stmt.Compile(p.t) + if err != nil { + return nil, err + } + return compat.QueryContext(p.db, ctx, s, args) +} + +func (p *exprProxy) StatementQueryRow(ctx context.Context, stmt *exql.Statement, args ...interface{}) (*sql.Row, error) { + s, err := stmt.Compile(p.t) + if err != nil { + return nil, err + } + return compat.QueryRowContext(p.db, ctx, s, args), nil +} + +var ( + _ = mydb.SQL(&sqlBuilder{}) + _ = exprDB(&exprProxy{}) +) + +func joinArguments(args ...[]interface{}) []interface{} { + total := 0 + for i := range args { + total += len(args[i]) + } + if total == 0 { + return nil + } + + flatten := make([]interface{}, 0, total) + for i := range args { + flatten = append(flatten, args[i]...)
+ } + return flatten +} diff --git a/internal/sqlbuilder/builder_test.go b/internal/sqlbuilder/builder_test.go new file mode 100644 index 0000000..4bb843b --- /dev/null +++ b/internal/sqlbuilder/builder_test.go @@ -0,0 +1,1510 @@ +package sqlbuilder + +import ( + "fmt" + "regexp" + "strings" + "testing" + + "git.hexq.cn/tiglog/mydb" + "github.com/stretchr/testify/assert" +) + +var ( + reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) +) + +func TestSelect(t *testing.T) { + + b := &sqlBuilder{t: newTemplateWithUtils(&testTemplate)} + assert := assert.New(t) + + assert.Equal( + `SELECT DATE()`, + b.Select(mydb.Func("DATE")).String(), + ) + + assert.Equal( + `SELECT DATE() FOR UPDATE`, + b.Select(mydb.Func("DATE")).Amend(func(query string) string { + return query + " FOR UPDATE" + }).String(), + ) + + assert.Equal( + `SELECT * FROM "artist"`, + b.SelectFrom("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT "bcolor" FROM "artist"`, + b.Select().Distinct("bcolor").From("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT * FROM "artist"`, + b.Select().Distinct().From("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT ON("col1"), "col2" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON("col1")`), "col2").From("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT ON("col1") AS col1, "col2" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON("col1") AS col1`)).Distinct("col2").From("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT ON("col1") AS col1, "col2", "col3", "col4", "col5" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON("col1") AS col1`)).Columns("col2", "col3").Distinct("col4", "col5").From("artist").String(), + ) + + assert.Equal( + `SELECT DISTINCT ON(SELECT foo FROM bar) col1, "col2", "col3", "col4", "col5" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON(?) col1`, mydb.Raw(`SELECT foo FROM bar`))).Columns("col2", "col3").Distinct("col4", "col5").From("artist").String(), + ) + + { + q0 := b.Select("foo").From("bar") + assert.Equal( + `SELECT DISTINCT ON (SELECT "foo" FROM "bar") col1, "col2", "col3", "col4", "col5" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON ? col1`, q0)).Columns("col2", "col3").Distinct("col4", "col5").From("artist").String(), + ) + } + + assert.Equal( + `SELECT DISTINCT ON (SELECT foo FROM bar, SELECT baz from qux) col1, "col2", "col3", "col4", "col5" FROM "artist"`, + b.Select().Distinct(mydb.Raw(`ON ? col1`, []interface{}{mydb.Raw(`SELECT foo FROM bar`), mydb.Raw(`SELECT baz from qux`)})).Columns("col2", "col3").Distinct("col4", "col5").From("artist").String(), + ) + + { + q := b.Select(). + Distinct(mydb.Raw(`ON ? col1`, []*mydb.RawExpr{mydb.Raw(`SELECT foo FROM bar WHERE id = ?`, 1), mydb.Raw(`SELECT baz from qux WHERE id = 2`)})). + Columns("col2", "col3"). + Distinct("col4", "col5").From("artist"). + Where("id", 3) + assert.Equal( + `SELECT DISTINCT ON (SELECT foo FROM bar WHERE id = $1, SELECT baz from qux WHERE id = 2) col1, "col2", "col3", "col4", "col5" FROM "artist" WHERE ("id" = $2)`, + q.String(), + ) + + assert.Equal( + []interface{}{1, 3}, + q.Arguments(), + ) + } + + { + rawCase := mydb.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{1000, 2000}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1, $2) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000, 2000}, + sel.Arguments(), + ) + } + + { + rawCase := mydb.Raw("CASE WHEN id IN ? 
THEN 0 ELSE 1 END", []int{1000}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000}, + sel.Arguments(), + ) + } + + { + rawCase := mydb.Raw("CASE WHEN id IN ? THEN 0 ELSE 1 END", []int{}) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN (NULL) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + rawCase := mydb.Raw("CASE WHEN id IN (NULL) THEN 0 ELSE 1 END") + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN (NULL) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + rawCase.Arguments(), + ) + } + + { + rawCase := mydb.Raw("CASE WHEN id IN (?, ?) THEN 0 ELSE 1 END", 1000, 2000) + sel := b.SelectFrom("artist").OrderBy(rawCase) + assert.Equal( + `SELECT * FROM "artist" ORDER BY CASE WHEN id IN ($1, $2) THEN 0 ELSE 1 END`, + sel.String(), + ) + assert.Equal( + []interface{}{1000, 2000}, + rawCase.Arguments(), + ) + } + + { + sel := b.Select(mydb.Func("DISTINCT", "name")).From("artist") + assert.Equal( + `SELECT DISTINCT($1) FROM "artist"`, + sel.String(), + ) + assert.Equal( + []interface{}{"name"}, + sel.Arguments(), + ) + } + + assert.Equal( + `SELECT * FROM "artist" WHERE (1 = $1)`, + b.Select().From("artist").Where(mydb.Cond{1: 1}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (1 = ANY($1))`, + b.Select().From("artist").Where(mydb.Cond{1: mydb.Func("ANY", "name")}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (1 = ANY(column))`, + b.Select().From("artist").Where(mydb.Cond{1: mydb.Func("ANY", mydb.Raw("column"))}).String(), + ) + + { + q := b.Select().From("artist").Where(mydb.Cond{"id NOT IN": []int{0, -1}}) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" NOT IN ($1, $2))`, + q.String(), + ) + assert.Equal( + []interface{}{0, -1}, + q.Arguments(), + ) + } + + { + q := b.Select().From("artist").Where(mydb.Cond{"id NOT IN": []int{-1}}) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" NOT IN ($1))`, + q.String(), + ) + assert.Equal( + []interface{}{-1}, + q.Arguments(), + ) + } + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN ($1, $2))`, + b.Select().From("artist").Where(mydb.Cond{"id IN": []int{0, -1}}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (("id" = $1 OR "id" = $2 OR "id" = $3))`, + b.Select().From("artist").Where( + mydb.Or( + mydb.Cond{"id": 1}, + mydb.Cond{"id": 2}, + mydb.Cond{"id": 3}, + ), + ).String(), + ) + + { + q := b.Select().From("artist").Where( + mydb.Or( + mydb.And(mydb.Cond{"a": 1}, mydb.Cond{"b": 2}, mydb.Cond{"c": 3}), + mydb.And(mydb.Cond{"d": 1}, mydb.Cond{"e": 2}, mydb.Cond{"f": 3}), + ), + ) + assert.Equal( + `SELECT * FROM "artist" WHERE ((("a" = $1 AND "b" = $2 AND "c" = $3) OR ("d" = $4 AND "e" = $5 AND "f" = $6)))`, + q.String(), + ) + assert.Equal( + []interface{}{1, 2, 3, 1, 2, 3}, + q.Arguments(), + ) + } + + { + q := b.Select().From("artist").Where( + mydb.Or( + mydb.And(mydb.Cond{"a": 1, "b": 2, "c": 3}), + mydb.And(mydb.Cond{"f": 6, "e": 5, "d": 4}), + ), + ) + assert.Equal( + `SELECT * FROM "artist" WHERE ((("a" = $1 AND "b" = $2 AND "c" = $3) OR ("d" = $4 AND "e" = $5 AND "f" = $6)))`, + q.String(), + ) + assert.Equal( + []interface{}{1, 2, 3, 4, 5, 6}, + q.Arguments(), + ) + } + + assert.Equal( + `SELECT * FROM 
"artist" WHERE ((("id" = $1 OR "id" = $2 OR "id" IS NULL) OR ("name" = $3 OR "name" = $4)))`, + b.Select().From("artist").Where( + mydb.Or( + mydb.Or( + mydb.Cond{"id": 1}, + mydb.Cond{"id": 2}, + mydb.Cond{"id IS": nil}, + ), + mydb.Or( + mydb.Cond{"name": "John"}, + mydb.Cond{"name": "Peter"}, + ), + ), + ).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ((("id" = $1 OR "id" = $2 OR "id" = $3 OR "id" = $4) AND ("name" = $5 AND "last_name" = $6) AND "age" > $7))`, + b.Select().From("artist").Where( + mydb.And( + mydb.Or( + mydb.Cond{"id": 1}, + mydb.Cond{"id": 2}, + mydb.Cond{"id": 3}, + ).Or( + mydb.Cond{"id": 4}, + ), + mydb.Or(), + mydb.And( + mydb.Cond{"name": "John"}, + mydb.Cond{"last_name": "Smith"}, + ), + mydb.And(), + ).And( + mydb.Cond{"age >": "20"}, + ), + ).String(), + ) + + assert.Equal( + `SELECT * FROM "artist"`, + b.Select().From("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("name DESC").String(), + ) + + { + sel := b.Select().From("artist").OrderBy(mydb.Raw("id = ?", 1), "name DESC") + assert.Equal( + `SELECT * FROM "artist" ORDER BY id = $1 , "name" DESC`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.Select().From("artist").OrderBy(mydb.Func("RAND")) + assert.Equal( + `SELECT * FROM "artist" ORDER BY RAND()`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + assert.Equal( + `SELECT * FROM "artist" ORDER BY RAND()`, + b.Select().From("artist").OrderBy(mydb.Raw("RAND()")).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" DESC`, + b.Select().From("artist").OrderBy("-name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" ORDER BY "name" ASC`, + b.Select().From("artist").OrderBy("name ASC").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" OFFSET 5`, + b.Select().From("artist").Limit(-1).Offset(5).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT 1 OFFSET 5`, + b.Select().From("artist").Limit(1).Offset(5).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist"`, + b.Select("id").From("artist").String(), + ) + + assert.Equal( + `SELECT "id", "name" FROM "artist"`, + b.Select("id", "name").From("artist").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("name" = $1)`, + b.SelectFrom("artist").Where("name", "Haruki").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name LIKE $1)`, + b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" WHERE (name LIKE $1 OR name LIKE $2)`, + b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" > $1)`, + b.SelectFrom("artist").Where("id >", 2).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (id <= 2 AND name != $1)`, + b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN ($1, $2, $3, $4))`, + b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (id IN ($1, $2, $3, $4) AND foo = $5 AND bar IN ($6, $7, $8))`, + b.SelectFrom("artist").Where("id IN ? AND foo = ? 
AND bar IN ?", []int{1, 9, 8, 7}, 28, []int{1, 2, 3}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE (name IS NOT NULL)`, + b.SelectFrom("artist").Where("name IS NOT NULL").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a", "publication" AS "p" WHERE (p.author_id = a.id) LIMIT 1`, + b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT "id" FROM "artist" NATURAL JOIN "publication"`, + b.Select("id").From("artist").Join("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE ("a"."id" = $1) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1`, + b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3 AND a.sub_id = $4) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Where("a.sub_id = ?", 3).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3 AND a.id = $4) LIMIT 1`, + b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).And("a.id = ?", 3).Limit(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" AS "a" LEFT JOIN "publication" AS "p1" ON (p1.id = a.id) RIGHT JOIN "publication" AS "p2" ON (p2.id = a.id)`, + b.SelectFrom("artist a"). + LeftJoin("publication p1").On("p1.id = a.id"). + RightJoin("publication p2").On("p2.id = a.id"). 
+ String(), + ) + + assert.Equal( + `SELECT * FROM "artist" CROSS JOIN "publication"`, + b.SelectFrom("artist").CrossJoin("publication").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" JOIN "publication" USING ("id")`, + b.SelectFrom("artist").Join("publication").Using("id").String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IS NULL)`, + b.SelectFrom("artist").Where(mydb.Cond{"id": nil}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN (NULL))`, + b.SelectFrom("artist").Where(mydb.Cond{"id": []int64{}}).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" IN ($1))`, + b.SelectFrom("artist").Where(mydb.Cond{"id": []int64{0}}).String(), + ) + + assert.Equal( + `SELECT COUNT(*) AS total FROM "user" AS "u" JOIN (SELECT DISTINCT user_id FROM user_profile) AS up ON (u.id = up.user_id)`, + b.Select(mydb.Raw(`COUNT(*) AS total`)).From("user u").Join(mydb.Raw("(SELECT DISTINCT user_id FROM user_profile) AS up")).On("u.id = up.user_id").String(), + ) + + { + q0 := b.Select("user_id").Distinct().From("user_profile") + + assert.Equal( + `SELECT COUNT(*) AS total FROM "user" AS "u" JOIN (SELECT DISTINCT "user_id" FROM "user_profile") AS up ON (u.id = up.user_id)`, + b.Select(mydb.Raw(`COUNT(*) AS total`)).From("user u").Join(mydb.Raw("? AS up", q0)).On("u.id = up.user_id").String(), + ) + } + + { + q0 := b.Select("user_id").Distinct().From("user_profile").Where("t", []int{1, 2, 4, 5}) + + assert.Equal( + []interface{}{1, 2, 4, 5}, + q0.Arguments(), + ) + + q1 := b.Select(mydb.Raw(`COUNT(*) AS total`)).From("user u").Join(mydb.Raw("? AS up", q0)).On("u.id = up.user_id AND foo = ?", 8) + + assert.Equal( + `SELECT COUNT(*) AS total FROM "user" AS "u" JOIN (SELECT DISTINCT "user_id" FROM "user_profile" WHERE ("t" IN ($1, $2, $3, $4))) AS up ON (u.id = up.user_id AND foo = $5)`, + q1.String(), + ) + + assert.Equal( + []interface{}{1, 2, 4, 5, 8}, + q1.Arguments(), + ) + } + + assert.Equal( + `SELECT DATE()`, + b.Select(mydb.Raw("DATE()")).String(), + ) + + { + sel := b.Select(mydb.Raw("CONCAT(?, ?)", "foo", "bar")) + assert.Equal( + `SELECT CONCAT($1, $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{"foo", "bar"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{"bar": mydb.Raw("1")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = 1)`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{mydb.Raw("1"): 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE (1 = $1)`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{mydb.Raw("1"): mydb.Raw("1")}) + assert.Equal( + `SELECT * FROM "foo" WHERE (1 = 1)`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Raw("1 = 1")) + assert.Equal( + `SELECT * FROM "foo" WHERE (1 = 1)`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{"bar": 1}, mydb.Cond{"baz": mydb.Raw("CONCAT(?, ?)", "foo", "bar")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = CONCAT($2, $3))`, + sel.String(), + ) + assert.Equal( + []interface{}{1, "foo", "bar"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{"bar": 1}, mydb.Raw("? 
= ANY(col)", "name")) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND $2 = ANY(col))`, + sel.String(), + ) + assert.Equal( + []interface{}{1, "name"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{"bar": 1}, mydb.Cond{"name": mydb.Raw("ANY(col)")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "name" = ANY(col))`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{"bar": 1}, mydb.Cond{mydb.Raw("CONCAT(?, ?)", "a", "b"): mydb.Raw("ANY(col)")}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND CONCAT($2, $3) = ANY(col))`, + sel.String(), + ) + assert.Equal( + []interface{}{1, "a", "b"}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("bar", 2).And(mydb.Cond{"baz": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{2, 1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").And(mydb.Cond{"bar": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1)`, + sel.String(), + ) + assert.Equal( + []interface{}{1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("bar", 2).And(mydb.Cond{"baz": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{2, 1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("bar", 2).Where(mydb.Cond{"baz": 1}) + assert.Equal( + `SELECT * FROM "foo" WHERE ("bar" = $1 AND "baz" = $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{2, 1}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Raw("bar->'baz' = ?", true)) + assert.Equal( + `SELECT * FROM "foo" WHERE (bar->'baz' = $1)`, + sel.String(), + ) + assert.Equal( + []interface{}{true}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where(mydb.Cond{}).And(mydb.Cond{}) + assert.Equal( + `SELECT * FROM "foo"`, + sel.String(), + ) + assert.Equal( + []interface{}(nil), + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("bar = 1").And(mydb.Or( + mydb.Raw("fieldA ILIKE ?", `%a%`), + mydb.Raw("fieldB ILIKE ?", `%b%`), + )) + assert.Equal( + `SELECT * FROM "foo" WHERE (bar = 1 AND (fieldA ILIKE $1 OR fieldB ILIKE $2))`, + sel.String(), + ) + assert.Equal( + []interface{}{`%a%`, `%b%`}, + sel.Arguments(), + ) + } + + { + sel := b.SelectFrom("foo").Where("group_id", 1).And("user_id", 2) + assert.Equal( + `SELECT * FROM "foo" WHERE ("group_id" = $1 AND "user_id" = $2)`, + sel.String(), + ) + assert.Equal( + []interface{}{1, 2}, + sel.Arguments(), + ) + } + + { + s := `SUM(CASE WHEN foo in ? THEN 1 ELSE 0 END) AS _sum` + sel := b.Select("c1").Columns(mydb.Raw(s, []int{5, 4, 3, 2})).From("foo").Where("bar = ?", 1) + assert.Equal( + `SELECT "c1", SUM(CASE WHEN foo in ($1, $2, $3, $4) THEN 1 ELSE 0 END) AS _sum FROM "foo" WHERE (bar = $5)`, + sel.String(), + ) + assert.Equal( + []interface{}{5, 4, 3, 2, 1}, + sel.Arguments(), + ) + } + + { + s := `SUM(CASE WHEN foo in ? 
THEN 1 ELSE 0 END) AS _sum` + sel := b.Select("c1").Columns(mydb.Raw(s, []int{5, 4, 3, 2})).From("foo").Where("bar = ?", 1) + sel2 := b.SelectFrom(sel).As("subquery").Where(mydb.Cond{"foo": "bar"}).OrderBy("subquery.seq") + assert.Equal( + `SELECT * FROM (SELECT "c1", SUM(CASE WHEN foo in ($1, $2, $3, $4) THEN 1 ELSE 0 END) AS _sum FROM "foo" WHERE (bar = $5)) AS "subquery" WHERE ("foo" = $6) ORDER BY "subquery"."seq" ASC`, + sel2.String(), + ) + assert.Equal( + []interface{}{5, 4, 3, 2, 1, "bar"}, + sel2.Arguments(), + ) + } + + { + series := b.Select( + mydb.Raw("start + interval ? - interval '1s' AS end", "1 day"), + ).From( + b.Select( + mydb.Raw("generate_series(?::timestamp, ?::timestamp, ?::interval) AS start", 1, 2, 3), + ), + ).As("series") + + assert.Equal( + []interface{}{"1 day", 1, 2, 3}, + series.Arguments(), + ) + + distinct := b.Select().Distinct( + mydb.Raw(`ON(dt.email) SUBSTRING(email,(POSITION('@' in email) + 1),252) AS email_domain`), + "dt.event_type AS event_type", + mydb.Raw("count(dt.*) AS count"), + "intervals.start AS start", + "intervals.end AS start", + ).From("email_events AS dt"). + RightJoin("intervals").On("dt.ts BETWEEN intervals.stast AND intervals.END AND dt.hub_id = ? AND dt.object_id = ?", 67, 68). + GroupBy("email_domain", "event_type", "start"). + OrderBy("email", "start", "event_type") + + sq, args := Preprocess( + `WITH intervals AS ? ?`, + []interface{}{ + series, + distinct, + }, + ) + + assert.Equal( + stripWhitespace(` + WITH intervals AS (SELECT start + interval ? - interval '1s' AS end FROM (SELECT generate_series(?::timestamp, ?::timestamp, ?::interval) AS start) AS "series") + (SELECT DISTINCT ON(dt.email) SUBSTRING(email,(POSITION('@' in email) + 1),252) AS email_domain, "dt"."event_type" AS "event_type", count(dt.*) AS count, "intervals"."start" AS "start", "intervals"."end" AS "start" + FROM "email_events" AS "dt" + RIGHT JOIN "intervals" ON (dt.ts BETWEEN intervals.stast AND intervals.END AND dt.hub_id = ? AND dt.object_id = ?) + GROUP BY "email_domain", "event_type", "start" + ORDER BY "email" ASC, "start" ASC, "event_type" ASC)`), + stripWhitespace(sq), + ) + + assert.Equal( + []interface{}{"1 day", 1, 2, 3, 67, 68}, + args, + ) + } + + { + sq := b. + Select("user_id"). + From("user_access"). + Where(mydb.Cond{"hub_id": 3}) + + // Don't reassign + _ = sq.And(mydb.Cond{"role": []int{1, 2}}) + + assert.Equal( + `SELECT "user_id" FROM "user_access" WHERE ("hub_id" = $1)`, + sq.String(), + ) + + assert.Equal( + []interface{}{3}, + sq.Arguments(), + ) + + // Reassign + sq = sq.And(mydb.Cond{"role": []int{1, 2}}) + + assert.Equal( + `SELECT "user_id" FROM "user_access" WHERE ("hub_id" = $1 AND "role" IN ($2, $3))`, + sq.String(), + ) + + assert.Equal( + []interface{}{3, 1, 2}, + sq.Arguments(), + ) + + cond := mydb.Or( + mydb.Raw("a.id IN ?", sq), + ) + + cond = cond.Or(mydb.Cond{"ml.mailing_list_id": []int{4, 5, 6}}) + + sel := b. + Select(mydb.Raw("DISTINCT ON(a.id) a.id"), mydb.Raw("COALESCE(NULLIF(ml.name,''), a.name) as name"), "a.email"). + From("mailing_list_recipients ml"). + FullJoin("accounts a").On("a.id = ml.user_id"). 
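+ // cond (built above) embeds the "sq" selector as a subquery, so its
+ // arguments are flattened in ahead of the mailing-list IDs below.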
+ Where(cond) + + search := "word" + sel = sel.And(mydb.Or( + mydb.Raw("COALESCE(NULLIF(ml.name,''), a.name) ILIKE ?", fmt.Sprintf("%%%s%%", search)), + mydb.Cond{"a.email ILIKE": fmt.Sprintf("%%%s%%", search)}, + )) + + assert.Equal( + `SELECT DISTINCT ON(a.id) a.id, COALESCE(NULLIF(ml.name,''), a.name) as name, "a"."email" FROM "mailing_list_recipients" AS "ml" FULL JOIN "accounts" AS "a" ON (a.id = ml.user_id) WHERE ((a.id IN (SELECT "user_id" FROM "user_access" WHERE ("hub_id" = $1 AND "role" IN ($2, $3))) OR "ml"."mailing_list_id" IN ($4, $5, $6)) AND (COALESCE(NULLIF(ml.name,''), a.name) ILIKE $7 OR "a"."email" ILIKE $8))`, + sel.String(), + ) + + assert.Equal( + []interface{}{3, 1, 2, 4, 5, 6, `%word%`, `%word%`}, + sel.Arguments(), + ) + + { + sel := b.Select(mydb.Func("FOO", "A", "B", 1)).From("accounts").Where(mydb.Cond{"time": mydb.Func("FUNCTION", "20170103", "YYYYMMDD", 1, "E")}) + + assert.Equal( + `SELECT FOO($1, $2, $3) FROM "accounts" WHERE ("time" = FUNCTION($4, $5, $6, $7))`, + sel.String(), + ) + assert.Equal( + []interface{}{"A", "B", 1, "20170103", "YYYYMMDD", 1, "E"}, + sel.Arguments(), + ) + } + + { + + sel := b.Select(mydb.Func("FOO", "A", "B", 1)).From("accounts").Where(mydb.Cond{mydb.Func("FUNCTION", "20170103", "YYYYMMDD", 1, "E"): mydb.Func("FOO", 1)}) + + assert.Equal( + `SELECT FOO($1, $2, $3) FROM "accounts" WHERE (FUNCTION($4, $5, $6, $7) = FOO($8))`, + sel.String(), + ) + assert.Equal( + []interface{}{"A", "B", 1, "20170103", "YYYYMMDD", 1, "E", 1}, + sel.Arguments(), + ) + } + } +} + +func TestInsert(t *testing.T) { + b := &sqlBuilder{t: newTemplateWithUtils(&testTemplate)} + assert := assert.New(t) + + assert.Equal( + `INSERT INTO "artist" VALUES ($1, $2), ($3, $4), ($5, $6)`, + b.InsertInto("artist"). + Values(10, "Ryuichi Sakamoto"). + Values(11, "Alondra de la Parra"). + Values(12, "Haruki Murakami"). + String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`, + b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Amend(func(query string) string { + return query + ` RETURNING "id"` + }).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`, + b.InsertInto("artist").Values(struct { + ID int `db:"id"` + Name string `db:"name"` + }{12, "Chavela Vargas"}).String(), + ) + + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, $2), ($3, $4), ($5, $6)`, + b.InsertInto("artist"). + Values(artistStruct{12, "Chavela Vargas"}). + Values(artistStruct{13, "Alondra de la Parra"}). + Values(artistStruct{14, "Haruki Murakami"}). + String(), + ) + } + + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + q := b.InsertInto("artist"). + Values(artistStruct{0, ""}). + Values(artistStruct{12, "Chavela Vargas"}). 
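+ // With several Values groups, zero values on omitempty fields are kept
+ // and rendered as DEFAULT instead of being dropped (asserted below).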
+ Values(artistStruct{0, "Alondra de la Parra"}). + Values(artistStruct{14, ""}). + Values(artistStruct{0, ""}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES (DEFAULT, DEFAULT), ($1, $2), (DEFAULT, $3), ($4, DEFAULT), (DEFAULT, DEFAULT)`, + q.String(), + ) + + assert.Equal( + []interface{}{12, "Chavela Vargas", "Alondra de la Parra", 14}, + q.Arguments(), + ) + } + + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + assert.Equal( + `INSERT INTO "artist" ("name") VALUES ($1)`, + b.InsertInto("artist"). + Values(artistStruct{Name: "Chavela Vargas"}). + String(), + ) + + assert.Equal( + `INSERT INTO "artist" ("id") VALUES ($1)`, + b.InsertInto("artist"). + Values(artistStruct{ID: 1}). + String(), + ) + } + + { + type artistStruct struct { + ID int `db:"id,omitempty"` + Name string `db:"name,omitempty"` + } + + { + q := b.InsertInto("artist").Values(artistStruct{Name: "Chavela Vargas"}) + + assert.Equal( + `INSERT INTO "artist" ("name") VALUES ($1)`, + q.String(), + ) + assert.Equal( + []interface{}{"Chavela Vargas"}, + q.Arguments(), + ) + } + + { + q := b.InsertInto("artist").Values(artistStruct{Name: "Chavela Vargas"}).Values(artistStruct{Name: "Alondra de la Parra"}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES (DEFAULT, $1), (DEFAULT, $2)`, + q.String(), + ) + assert.Equal( + []interface{}{"Chavela Vargas", "Alondra de la Parra"}, + q.Arguments(), + ) + } + + { + q := b.InsertInto("artist").Values(artistStruct{ID: 1}) + + assert.Equal( + `INSERT INTO "artist" ("id") VALUES ($1)`, + q.String(), + ) + + assert.Equal( + []interface{}{1}, + q.Arguments(), + ) + } + + { + q := b.InsertInto("artist").Values(artistStruct{ID: 1}).Values(artistStruct{ID: 2}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES ($1, DEFAULT), ($2, DEFAULT)`, + q.String(), + ) + + assert.Equal( + []interface{}{1, 2}, + q.Arguments(), + ) + } + + } + + { + intRef := func(i int) *int { + if i == 0 { + return nil + } + return &i + } + + strRef := func(s string) *string { + if s == "" { + return nil + } + return &s + } + + type artistStruct struct { + ID *int `db:"id,omitempty"` + Name *string `db:"name,omitempty"` + } + + q := b.InsertInto("artist"). + Values(artistStruct{intRef(0), strRef("")}). + Values(artistStruct{intRef(12), strRef("Chavela Vargas")}). + Values(artistStruct{intRef(0), strRef("Alondra de la Parra")}). + Values(artistStruct{intRef(14), strRef("")}). 
+ Values(artistStruct{intRef(0), strRef("")}) + + assert.Equal( + `INSERT INTO "artist" ("id", "name") VALUES (DEFAULT, DEFAULT), ($1, $2), (DEFAULT, $3), ($4, DEFAULT), (DEFAULT, DEFAULT)`, + q.String(), + ) + + assert.Equal( + []interface{}{intRef(12), strRef("Chavela Vargas"), strRef("Alondra de la Parra"), intRef(14)}, + q.Arguments(), + ) + } + + assert.Equal( + `INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`, + b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(), + ) + + assert.Equal( + `INSERT INTO "artist" VALUES (default)`, + b.InsertInto("artist").String(), + ) +} + +func TestUpdate(t *testing.T) { + b := &sqlBuilder{t: newTemplateWithUtils(&testTemplate)} + assert := assert.New(t) + + assert.Equal( + `UPDATE "artist" SET "name" = $1`, + b.Update("artist").Set("name", "Artist").String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 RETURNING "name"`, + b.Update("artist").Set("name", "Artist").Amend(func(query string) string { + return query + ` RETURNING "name"` + }).String(), + ) + + { + idSlice := []int64{8, 7, 6} + q := b.Update("artist").Set(mydb.Cond{"some_column": 10}).Where(mydb.Cond{"id": 1}, mydb.Cond{"another_val": idSlice}) + assert.Equal( + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN ($3, $4, $5))`, + q.String(), + ) + assert.Equal( + []interface{}{10, 1, int64(8), int64(7), int64(6)}, + q.Arguments(), + ) + } + + { + idSlice := []int64{} + q := b.Update("artist").Set(mydb.Cond{"some_column": 10}).Where(mydb.Cond{"id": 1}, mydb.Cond{"another_val": idSlice}) + assert.Equal( + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN (NULL))`, + q.String(), + ) + assert.Equal( + []interface{}{10, 1}, + q.Arguments(), + ) + } + + { + idSlice := []int64{} + q := b.Update("artist").Where(mydb.Cond{"id": 1}, mydb.Cond{"another_val": idSlice}).Set(mydb.Cond{"some_column": 10}) + assert.Equal( + `UPDATE "artist" SET "some_column" = $1 WHERE ("id" = $2 AND "another_val" IN (NULL))`, + q.String(), + ) + assert.Equal( + []interface{}{10, 1}, + q.Arguments(), + ) + } + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`, + b.Update("artist").Where(mydb.Cond{"id <": 5}).Set(struct { + Nombre string `db:"name"` + }{"Artist"}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`, + b.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(), + ) + + assert.Equal( + `UPDATE "artist" SET "name" = $1 || ' ' || $2 || id, "id" = id + $3 WHERE (id > $4)`, + b.Update("artist").Set( + "name = ? || ' ' || ? 
|| id", "Artist", "#", + "id = id + ?", 10, + ).Where("id > ?", 0).String(), + ) + + { + q := b.Update("posts").Set("column = ?", "foo") + + assert.Equal( + `UPDATE "posts" SET "column" = $1`, + q.String(), + ) + + assert.Equal( + []interface{}{"foo"}, + q.Arguments(), + ) + } + + { + q := b.Update("posts").Set(mydb.Raw("column = ?", "foo")) + + assert.Equal( + `UPDATE "posts" SET column = $1`, + q.String(), + ) + + assert.Equal( + []interface{}{"foo"}, + q.Arguments(), + ) + } + + { + q := b.Update("posts").Set("foo = bar") + + assert.Equal( + []interface{}(nil), + q.Arguments(), + ) + + assert.Equal( + `UPDATE "posts" SET "foo" = bar`, + q.String(), + ) + } + + { + q := b.Update("posts").Set( + mydb.Cond{"tags": mydb.Raw("array_remove(tags, ?)", "foo")}, + ).Where(mydb.Raw("hub_id = ? AND ? = ANY(tags) AND ? = ANY(tags)", 1, "bar", "baz")) + + assert.Equal( + `UPDATE "posts" SET "tags" = array_remove(tags, $1) WHERE (hub_id = $2 AND $3 = ANY(tags) AND $4 = ANY(tags))`, + q.String(), + ) + + assert.Equal( + []interface{}{"foo", 1, "bar", "baz"}, + q.Arguments(), + ) + } +} + +func TestDelete(t *testing.T) { + bt := WithTemplate(&testTemplate) + assert := assert.New(t) + + assert.Equal( + `DELETE FROM "artist" WHERE (name = $1)`, + bt.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").String(), + ) + + assert.Equal( + `DELETE FROM "artist" WHERE (name = $1) RETURNING 1`, + bt.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Amend(func(query string) string { + return fmt.Sprintf("%s RETURNING 1", query) + }).String(), + ) + + assert.Equal( + `DELETE FROM "artist" WHERE (id > 5)`, + bt.DeleteFrom("artist").Where("id > 5").String(), + ) +} + +func TestPaginate(t *testing.T) { + b := &sqlBuilder{t: newTemplateWithUtils(&testTemplate)} + assert := assert.New(t) + + // Limit, offset + assert.Equal( + `SELECT * FROM "artist" LIMIT 10`, + b.Select().From("artist").Paginate(10).Page(1).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT 10 OFFSET 10`, + b.Select().From("artist").Paginate(10).Page(2).String(), + ) + + assert.Equal( + `SELECT * FROM "artist" LIMIT 5 OFFSET 110`, + b.Select().From("artist").Paginate(5).Page(23).String(), + ) + + // Cursor + assert.Equal( + `SELECT * FROM "artist" ORDER BY "id" ASC LIMIT 10`, + b.Select().From("artist").Paginate(10).Cursor("id").String(), + ) + + { + q := b.Select().From("artist").Paginate(10).Cursor("id").NextPage(3) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" > $1) ORDER BY "id" ASC LIMIT 10`, + q.String(), + ) + assert.Equal( + []interface{}{3}, + q.Arguments(), + ) + } + + { + q := b.Select().From("artist").Paginate(10).Cursor("id").PrevPage(30) + assert.Equal( + `SELECT * FROM (SELECT * FROM "artist" WHERE ("id" < $1) ORDER BY "id" DESC LIMIT 10) AS p0 ORDER BY "id" ASC`, + q.String(), + ) + assert.Equal( + []interface{}{30}, + q.Arguments(), + ) + } + + // Cursor reversed + assert.Equal( + `SELECT * FROM "artist" ORDER BY "id" DESC LIMIT 10`, + b.Select().From("artist").Paginate(10).Cursor("-id").String(), + ) + + { + q := b.Select().From("artist").Paginate(10).Cursor("-id").NextPage(3) + assert.Equal( + `SELECT * FROM "artist" WHERE ("id" < $1) ORDER BY "id" DESC LIMIT 10`, + q.String(), + ) + assert.Equal( + []interface{}{3}, + q.Arguments(), + ) + } + + { + q := b.Select().From("artist").Paginate(10).Cursor("-id").PrevPage(30) + assert.Equal( + `SELECT * FROM (SELECT * FROM "artist" WHERE ("id" > $1) ORDER BY "id" ASC LIMIT 10) AS p0 ORDER BY "id" DESC`, + q.String(), + ) + assert.Equal( + []interface{}{30}, 
+ q.Arguments(), + ) + } +} + +func BenchmarkDelete1(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Limit(1).String() + } +} + +func BenchmarkDelete2(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.DeleteFrom("artist").Where("id > 5").String() + } +} + +func BenchmarkInsert1(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.InsertInto("artist").Values(10, "Ryuichi Sakamoto").Values(11, "Alondra de la Parra").Values(12, "Haruki Murakami").String() + } +} + +func BenchmarkInsert2(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String() + } +} + +func BenchmarkInsert3(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String() + } +} + +func BenchmarkInsert4(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String() + } +} + +func BenchmarkInsert5(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.InsertInto("artist").Values(struct { + ID int `db:"id"` + Name string `db:"name"` + }{12, "Chavela Vargas"}).String() + } +} + +func BenchmarkSelect1(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Select().From("artist").OrderBy("name DESC").String() + } +} + +func BenchmarkSelect2(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String() + } +} + +func BenchmarkSelect3(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String() + } +} + +func BenchmarkSelect4(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String() + } +} + +func BenchmarkSelect5(b *testing.B) { + t := WithTemplate(&testTemplate) + b.ResetTimer() + for n := 0; n < b.N; n++ { + _ = t.SelectFrom("artist a"). + LeftJoin("publication p1").On("p1.id = a.id"). + RightJoin("publication p2").On("p2.id = a.id"). 
+ String() + } +} + +func BenchmarkUpdate1(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Update("artist").Set("name", "Artist").String() + } +} + +func BenchmarkUpdate2(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String() + } +} + +func BenchmarkUpdate3(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Update("artist").Set(struct { + Nombre string `db:"name"` + }{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String() + } +} + +func BenchmarkUpdate4(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String() + } +} + +func BenchmarkUpdate5(b *testing.B) { + bt := WithTemplate(&testTemplate) + for n := 0; n < b.N; n++ { + _ = bt.Update("artist").Set( + "name = ? || ' ' || ? || id", "Artist", "#", + "id = id + ?", 10, + ).Where("id > ?", 0).String() + } +} + +func stripWhitespace(in string) string { + q := reInvisibleChars.ReplaceAllString(in, ` `) + return strings.TrimSpace(q) +} diff --git a/internal/sqlbuilder/comparison.go b/internal/sqlbuilder/comparison.go new file mode 100644 index 0000000..915aa8a --- /dev/null +++ b/internal/sqlbuilder/comparison.go @@ -0,0 +1,122 @@ +package sqlbuilder + +import ( + "fmt" + "strings" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +var comparisonOperators = map[adapter.ComparisonOperator]string{ + adapter.ComparisonOperatorEqual: "=", + adapter.ComparisonOperatorNotEqual: "!=", + + adapter.ComparisonOperatorLessThan: "<", + adapter.ComparisonOperatorGreaterThan: ">", + + adapter.ComparisonOperatorLessThanOrEqualTo: "<=", + adapter.ComparisonOperatorGreaterThanOrEqualTo: ">=", + + adapter.ComparisonOperatorBetween: "BETWEEN", + adapter.ComparisonOperatorNotBetween: "NOT BETWEEN", + + adapter.ComparisonOperatorIn: "IN", + adapter.ComparisonOperatorNotIn: "NOT IN", + + adapter.ComparisonOperatorIs: "IS", + adapter.ComparisonOperatorIsNot: "IS NOT", + + adapter.ComparisonOperatorLike: "LIKE", + adapter.ComparisonOperatorNotLike: "NOT LIKE", + + adapter.ComparisonOperatorRegExp: "REGEXP", + adapter.ComparisonOperatorNotRegExp: "NOT REGEXP", +} + +type operatorWrapper struct { + tu *templateWithUtils + cv *exql.ColumnValue + + op *adapter.Comparison + v interface{} +} + +func (ow *operatorWrapper) cmp() *adapter.Comparison { + if ow.op != nil { + return ow.op + } + + if ow.cv.Operator != "" { + return mydb.Op(ow.cv.Operator, ow.v).Comparison + } + + if ow.v == nil { + return mydb.Is(nil).Comparison + } + + args, isSlice := toInterfaceArguments(ow.v) + if isSlice { + return mydb.In(args...).Comparison + } + + return mydb.Eq(ow.v).Comparison +} + +func (ow *operatorWrapper) preprocess() (string, []interface{}) { + placeholder := "?" 
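+
+ // A sketch of what this assembles: "<column> <operator> <placeholder>" plus
+ // its arguments. For example, mydb.In(1, 2) on column "id" compiles to
+ // `"id" IN (?, ?)` with args [1 2], and mydb.Is(nil) to `"id" IS NULL`
+ // with no args.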
+ + column, err := ow.cv.Column.Compile(ow.tu.Template) + if err != nil { + panic(fmt.Sprintf("could not compile column: %v", err.Error())) + } + + c := ow.cmp() + + op := ow.tu.comparisonOperatorMapper(c.Operator()) + + var args []interface{} + + switch c.Operator() { + case adapter.ComparisonOperatorNone: + panic("no operator given") + case adapter.ComparisonOperatorCustom: + op = c.CustomOperator() + case adapter.ComparisonOperatorIn, adapter.ComparisonOperatorNotIn: + values := c.Value().([]interface{}) + if len(values) < 1 { + placeholder, args = "(NULL)", []interface{}{} + break + } + placeholder, args = "(?"+strings.Repeat(", ?", len(values)-1)+")", values + case adapter.ComparisonOperatorIs, adapter.ComparisonOperatorIsNot: + switch c.Value() { + case nil: + placeholder, args = "NULL", []interface{}{} + case false: + placeholder, args = "FALSE", []interface{}{} + case true: + placeholder, args = "TRUE", []interface{}{} + } + case adapter.ComparisonOperatorBetween, adapter.ComparisonOperatorNotBetween: + values := c.Value().([]interface{}) + placeholder, args = "? AND ?", []interface{}{values[0], values[1]} + case adapter.ComparisonOperatorEqual: + v := c.Value() + if b, ok := v.([]byte); ok { + v = string(b) + } + args = []interface{}{v} + } + + if args == nil { + args = []interface{}{c.Value()} + } + + if strings.Contains(op, ":column") { + return strings.Replace(op, ":column", column, -1), args + } + + return column + " " + op + " " + placeholder, args +} diff --git a/internal/sqlbuilder/convert.go b/internal/sqlbuilder/convert.go new file mode 100644 index 0000000..5036465 --- /dev/null +++ b/internal/sqlbuilder/convert.go @@ -0,0 +1,166 @@ +package sqlbuilder + +import ( + "bytes" + "database/sql/driver" + "reflect" + + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +var ( + sqlDefault = &exql.Raw{Value: "DEFAULT"} +) + +func expandQuery(in []byte, inArgs []interface{}) ([]byte, []interface{}) { + out := make([]byte, 0, len(in)) + outArgs := make([]interface{}, 0, len(inArgs)) + + i := 0 + for i < len(in) && len(inArgs) > 0 { + if in[i] == '?' { + out = append(out, in[:i]...) + in = in[i+1:] + i = 0 + + replace, replaceArgs := expandArgument(inArgs[0]) + inArgs = inArgs[1:] + + if len(replace) > 0 { + replace, replaceArgs = expandQuery(replace, replaceArgs) + out = append(out, replace...) + } else { + out = append(out, '?') + } + + outArgs = append(outArgs, replaceArgs...) + continue + } + i = i + 1 + } + + if len(out) < 1 { + return in, inArgs + } + + out = append(out, in[:len(in)]...) + in = nil + + outArgs = append(outArgs, inArgs[:len(inArgs)]...) 
+ inArgs = nil
+
+ return out, outArgs
+}
+
+func expandArgument(arg interface{}) ([]byte, []interface{}) {
+ values, isSlice := toInterfaceArguments(arg)
+
+ if isSlice {
+ if len(values) == 0 {
+ return []byte("(NULL)"), nil
+ }
+ buf := bytes.Repeat([]byte(" ?,"), len(values))
+ buf[0] = '('
+ buf[len(buf)-1] = ')'
+ return buf, values
+ }
+
+ if len(values) == 1 {
+ switch t := arg.(type) {
+ case *adapter.RawExpr:
+ return expandQuery([]byte(t.Raw()), t.Arguments())
+ case hasPaginator:
+ p, err := t.Paginator()
+ if err == nil {
+ return append([]byte{'('}, append([]byte(p.String()), ')')...), p.Arguments()
+ }
+ panic(err.Error())
+ case isCompilable:
+ s, err := t.Compile()
+ if err == nil {
+ return append([]byte{'('}, append([]byte(s), ')')...), t.Arguments()
+ }
+ panic(err.Error())
+ }
+ } else if len(values) == 0 {
+ return []byte("NULL"), nil
+ }
+
+ return nil, []interface{}{arg}
+}
+
+// toInterfaceArguments converts the given value into a slice of interfaces.
+func toInterfaceArguments(value interface{}) (args []interface{}, isSlice bool) {
+ if value == nil {
+ return nil, false
+ }
+
+ switch t := value.(type) {
+ case driver.Valuer:
+ return []interface{}{t}, false
+ }
+
+ v := reflect.ValueOf(value)
+ if v.Type().Kind() == reflect.Slice {
+ var i, total int
+
+ // Byte slice gets transformed into a string.
+ if v.Type().Elem().Kind() == reflect.Uint8 {
+ return []interface{}{string(v.Bytes())}, false
+ }
+
+ total = v.Len()
+ args = make([]interface{}, total)
+ for i = 0; i < total; i++ {
+ args[i] = v.Index(i).Interface()
+ }
+ return args, true
+ }
+
+ return []interface{}{value}, false
+}
+
+// toColumnsValuesAndArguments maps the given columnNames and columnValues into
+// exql Columns and Values fragments; it also extracts and returns the query
+// arguments.
+func toColumnsValuesAndArguments(columnNames []string, columnValues []interface{}) (*exql.Columns, *exql.Values, []interface{}, error) {
+ var arguments []interface{}
+
+ columns := new(exql.Columns)
+
+ columns.Columns = make([]exql.Fragment, 0, len(columnNames))
+ for i := range columnNames {
+ columns.Columns = append(columns.Columns, exql.ColumnWithName(columnNames[i]))
+ }
+
+ values := new(exql.Values)
+
+ arguments = make([]interface{}, 0, len(columnValues))
+ values.Values = make([]exql.Fragment, 0, len(columnValues))
+
+ for i := range columnValues {
+ switch v := columnValues[i].(type) {
+ case *exql.Raw, exql.Raw:
+ values.Values = append(values.Values, sqlDefault)
+ case *exql.Value:
+ // Adding value.
+ values.Values = append(values.Values, v)
+ case exql.Value:
+ // Adding value.
+ values.Values = append(values.Values, &v)
+ default:
+ // Adding both value and placeholder.
+ values.Values = append(values.Values, sqlPlaceholder)
+ arguments = append(arguments, v)
+ }
+ }
+
+ return columns, values, arguments, nil
+}
+
+// Preprocess expands the arguments that need to be expanded and compiles the
+// query into a single string.
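+//
+// A sketch of the expansion (slices grow one placeholder per element):
+//
+//	sq, args := Preprocess(`id IN ? AND status = ?`, []interface{}{[]int{1, 2}, "ok"})
+//	// sq == `id IN (?, ?) AND status = ?`; args == []interface{}{1, 2, "ok"}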
+func Preprocess(in string, args []interface{}) (string, []interface{}) { + b, args := expandQuery([]byte(in), args) + return string(b), args +} diff --git a/internal/sqlbuilder/custom_types.go b/internal/sqlbuilder/custom_types.go new file mode 100644 index 0000000..9b5e0cf --- /dev/null +++ b/internal/sqlbuilder/custom_types.go @@ -0,0 +1,11 @@ +package sqlbuilder + +import ( + "database/sql" + "database/sql/driver" +) + +type ScannerValuer interface { + sql.Scanner + driver.Valuer +} diff --git a/internal/sqlbuilder/delete.go b/internal/sqlbuilder/delete.go new file mode 100644 index 0000000..396be02 --- /dev/null +++ b/internal/sqlbuilder/delete.go @@ -0,0 +1,195 @@ +package sqlbuilder + +import ( + "context" + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/immutable" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +type deleterQuery struct { + table string + limit int + + where *exql.Where + whereArgs []interface{} + + amendFn func(string) string +} + +func (dq *deleterQuery) and(b *sqlBuilder, terms ...interface{}) error { + where, whereArgs := b.t.toWhereWithArguments(terms) + + if dq.where == nil { + dq.where, dq.whereArgs = &exql.Where{}, []interface{}{} + } + dq.where.Append(&where) + dq.whereArgs = append(dq.whereArgs, whereArgs...) + + return nil +} + +func (dq *deleterQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Delete, + Table: exql.TableWithName(dq.table), + } + + if dq.where != nil { + stmt.Where = dq.where + } + + if dq.limit != 0 { + stmt.Limit = exql.Limit(dq.limit) + } + + stmt.SetAmendment(dq.amendFn) + + return stmt +} + +type deleter struct { + builder *sqlBuilder + + fn func(*deleterQuery) error + prev *deleter +} + +var _ = immutable.Immutable(&deleter{}) + +func (del *deleter) SQL() *sqlBuilder { + if del.prev == nil { + return del.builder + } + return del.prev.SQL() +} + +func (del *deleter) template() *exql.Template { + return del.SQL().t.Template +} + +func (del *deleter) String() string { + s, err := del.Compile() + if err != nil { + panic(err.Error()) + } + return prepareQueryForDisplay(s) +} + +func (del *deleter) setTable(table string) *deleter { + return del.frame(func(uq *deleterQuery) error { + uq.table = table + return nil + }) +} + +func (del *deleter) frame(fn func(*deleterQuery) error) *deleter { + return &deleter{prev: del, fn: fn} +} + +func (del *deleter) Where(terms ...interface{}) mydb.Deleter { + return del.frame(func(dq *deleterQuery) error { + dq.where, dq.whereArgs = &exql.Where{}, []interface{}{} + return dq.and(del.SQL(), terms...) + }) +} + +func (del *deleter) And(terms ...interface{}) mydb.Deleter { + return del.frame(func(dq *deleterQuery) error { + return dq.and(del.SQL(), terms...) 
+ })
+}
+
+func (del *deleter) Limit(limit int) mydb.Deleter {
+ return del.frame(func(dq *deleterQuery) error {
+ dq.limit = limit
+ return nil
+ })
+}
+
+func (del *deleter) Amend(fn func(string) string) mydb.Deleter {
+ return del.frame(func(dq *deleterQuery) error {
+ dq.amendFn = fn
+ return nil
+ })
+}
+
+func (dq *deleterQuery) arguments() []interface{} {
+ return joinArguments(dq.whereArgs)
+}
+
+func (del *deleter) Arguments() []interface{} {
+ dq, err := del.build()
+ if err != nil {
+ return nil
+ }
+ return dq.arguments()
+}
+
+func (del *deleter) Prepare() (*sql.Stmt, error) {
+ return del.PrepareContext(del.SQL().sess.Context())
+}
+
+func (del *deleter) PrepareContext(ctx context.Context) (*sql.Stmt, error) {
+ dq, err := del.build()
+ if err != nil {
+ return nil, err
+ }
+ return del.SQL().sess.StatementPrepare(ctx, dq.statement())
+}
+
+func (del *deleter) Exec() (sql.Result, error) {
+ return del.ExecContext(del.SQL().sess.Context())
+}
+
+func (del *deleter) ExecContext(ctx context.Context) (sql.Result, error) {
+ dq, err := del.build()
+ if err != nil {
+ return nil, err
+ }
+ return del.SQL().sess.StatementExec(ctx, dq.statement(), dq.arguments()...)
+}
+
+func (del *deleter) statement() (*exql.Statement, error) {
+ dq, err := del.build()
+ if err != nil {
+ return nil, err
+ }
+ return dq.statement(), nil
+}
+
+func (del *deleter) build() (*deleterQuery, error) {
+ dq, err := immutable.FastForward(del)
+ if err != nil {
+ return nil, err
+ }
+ return dq.(*deleterQuery), nil
+}
+
+func (del *deleter) Compile() (string, error) {
+ s, err := del.statement()
+ if err != nil {
+ return "", err
+ }
+ return s.Compile(del.template())
+}
+
+func (del *deleter) Prev() immutable.Immutable {
+ if del == nil {
+ return nil
+ }
+ return del.prev
+}
+
+func (del *deleter) Fn(in interface{}) error {
+ if del.fn == nil {
+ return nil
+ }
+ return del.fn(in.(*deleterQuery))
+}
+
+func (del *deleter) Base() interface{} {
+ return &deleterQuery{}
+}
diff --git a/internal/sqlbuilder/errors.go b/internal/sqlbuilder/errors.go
new file mode 100644
index 0000000..5c5a723
--- /dev/null
+++ b/internal/sqlbuilder/errors.go
@@ -0,0 +1,14 @@
+package sqlbuilder
+
+import (
+ "errors"
+)
+
+// Common error messages.
+var (
+ ErrExpectingPointer = errors.New(`argument must be an address`)
+ ErrExpectingSlicePointer = errors.New(`argument must be a slice address`)
+ ErrExpectingSliceMapStruct = errors.New(`argument must be a slice address of maps or structs`)
+ ErrExpectingMapOrStruct = errors.New(`argument must be either a map or a struct`)
+ ErrExpectingPointerToEitherMapOrStruct = errors.New(`expecting a pointer to either a map or a struct`)
+)
diff --git a/internal/sqlbuilder/fetch.go b/internal/sqlbuilder/fetch.go
new file mode 100644
index 0000000..6f686b4
--- /dev/null
+++ b/internal/sqlbuilder/fetch.go
@@ -0,0 +1,234 @@
+package sqlbuilder
+
+import (
+ "reflect"
+
+ "database/sql"
+ "database/sql/driver"
+
+ "git.hexq.cn/tiglog/mydb"
+ "git.hexq.cn/tiglog/mydb/internal/reflectx"
+)
+
+type sessValueConverter interface {
+ ConvertValue(interface{}) interface{}
+}
+
+type valueConverter interface {
+ ConvertValue(in interface{}) (out interface {
+ sql.Scanner
+ driver.Valuer
+ })
+}
+
+var Mapper = reflectx.NewMapper("db")
+
+// fetchRow reads the next row from the iterator's underlying *sql.Rows and
+// maps it into the single struct pointed to by `dst`.
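+//
+// A minimal usage sketch (hypothetical caller and type):
+//
+//	var a artist
+//	err := fetchRow(iter, &a) // scans the next row into `a`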
+func fetchRow(iter *iterator, dst interface{}) error {
+ var columns []string
+ var err error
+
+ rows := iter.cursor
+
+ dstv := reflect.ValueOf(dst)
+
+ if dstv.Kind() != reflect.Ptr || dstv.IsNil() {
+ return ErrExpectingPointer
+ }
+
+ itemV := dstv.Elem()
+
+ if columns, err = rows.Columns(); err != nil {
+ return err
+ }
+
+ reset(dst)
+
+ next := rows.Next()
+
+ if !next {
+ if err = rows.Err(); err != nil {
+ return err
+ }
+ return mydb.ErrNoMoreRows
+ }
+
+ itemT := itemV.Type()
+ item, err := fetchResult(iter, itemT, columns)
+ if err != nil {
+ return err
+ }
+
+ if itemT.Kind() == reflect.Ptr {
+ itemV.Set(item)
+ } else {
+ itemV.Set(reflect.Indirect(item))
+ }
+
+ return nil
+}
+
+// fetchRows reads all the remaining rows from the iterator's underlying
+// *sql.Rows and maps them into the slice pointed to by `dst`.
+func fetchRows(iter *iterator, dst interface{}) error {
+ var err error
+ rows := iter.cursor
+ defer rows.Close()
+
+ // Destination.
+ dstv := reflect.ValueOf(dst)
+
+ if dstv.Kind() != reflect.Ptr || dstv.IsNil() {
+ return ErrExpectingPointer
+ }
+
+ if dstv.Elem().Kind() != reflect.Slice {
+ return ErrExpectingSlicePointer
+ }
+
+ var columns []string
+ if columns, err = rows.Columns(); err != nil {
+ return err
+ }
+
+ slicev := dstv.Elem()
+ itemT := slicev.Type().Elem()
+
+ reset(dst)
+
+ for rows.Next() {
+ item, err := fetchResult(iter, itemT, columns)
+ if err != nil {
+ return err
+ }
+ if itemT.Kind() == reflect.Ptr {
+ slicev = reflect.Append(slicev, item)
+ } else {
+ slicev = reflect.Append(slicev, reflect.Indirect(item))
+ }
+ }
+
+ dstv.Elem().Set(slicev)
+
+ return rows.Err()
+}
+
+func fetchResult(iter *iterator, itemT reflect.Type, columns []string) (reflect.Value, error) {
+
+ var item reflect.Value
+ var err error
+ rows := iter.cursor
+
+ objT := itemT
+
+ switch objT.Kind() {
+ case reflect.Map:
+ item = reflect.MakeMap(objT)
+ case reflect.Struct:
+ item = reflect.New(objT)
+ case reflect.Ptr:
+ objT = itemT.Elem()
+ if objT.Kind() != reflect.Struct {
+ return item, ErrExpectingMapOrStruct
+ }
+ item = reflect.New(objT)
+ default:
+ return item, ErrExpectingMapOrStruct
+ }
+
+ switch objT.Kind() {
+ case reflect.Struct:
+
+ values := make([]interface{}, len(columns))
+ typeMap := Mapper.TypeMap(itemT)
+ fieldMap := typeMap.Names
+
+ for i, k := range columns {
+ fi, ok := fieldMap[k]
+ if !ok {
+ values[i] = new(interface{})
+ continue
+ }
+
+ // Check for deprecated jsonb tag.
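+ // (The old `jsonb` field option is assumed to be rejected on purpose
+ // here, with errDeprecatedJSONBTag defined elsewhere in this package.)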
+ if _, hasJSONBTag := fi.Options["jsonb"]; hasJSONBTag {
+ return item, errDeprecatedJSONBTag
+ }
+
+ f := reflectx.FieldByIndexes(item, fi.Index)
+
+ // TODO: type switch + scanner
+
+ if w, ok := f.Interface().(valueConverter); ok {
+ wrapper := w.ConvertValue(f.Addr().Interface())
+ z := reflect.ValueOf(wrapper)
+ values[i] = z.Interface()
+ continue
+ }
+
+ values[i] = f.Addr().Interface()
+
+ if unmarshaler, ok := values[i].(mydb.Unmarshaler); ok {
+ values[i] = scanner{unmarshaler}
+ continue
+ }
+
+ if converter, ok := iter.sess.(sessValueConverter); ok {
+ values[i] = converter.ConvertValue(values[i])
+ continue
+ }
+ }
+
+ if err = rows.Scan(values...); err != nil {
+ return item, err
+ }
+
+ case reflect.Map:
+
+ columns, err := rows.Columns()
+ if err != nil {
+ return item, err
+ }
+
+ values := make([]interface{}, len(columns))
+ for i := range values {
+ if itemT.Elem().Kind() == reflect.Interface {
+ values[i] = new(interface{})
+ } else {
+ values[i] = reflect.New(itemT.Elem()).Interface()
+ }
+ }
+
+ if err = rows.Scan(values...); err != nil {
+ return item, err
+ }
+
+ for i, column := range columns {
+ item.SetMapIndex(reflect.ValueOf(column), reflect.Indirect(reflect.ValueOf(values[i])))
+ }
+ }
+
+ return item, nil
+}
+
+func reset(data interface{}) {
+ // Resetting element.
+ v := reflect.ValueOf(data).Elem()
+ t := v.Type()
+
+ var z reflect.Value
+
+ switch v.Kind() {
+ case reflect.Slice:
+ z = reflect.MakeSlice(t, 0, v.Cap())
+ default:
+ z = reflect.Zero(t)
+ }
+
+ v.Set(z)
+}
diff --git a/internal/sqlbuilder/insert.go b/internal/sqlbuilder/insert.go
new file mode 100644
index 0000000..130264c
--- /dev/null
+++ b/internal/sqlbuilder/insert.go
@@ -0,0 +1,285 @@
+package sqlbuilder
+
+import (
+ "context"
+ "database/sql"
+ "errors"
+
+ "git.hexq.cn/tiglog/mydb"
+ "git.hexq.cn/tiglog/mydb/internal/immutable"
+ "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
+)
+
+type inserterQuery struct {
+ table string
+ enqueuedValues [][]interface{}
+ returning []exql.Fragment
+ columns []exql.Fragment
+ values []*exql.Values
+ arguments []interface{}
+ amendFn func(string) string
+}
+
+func (iq *inserterQuery) processValues() ([]*exql.Values, []interface{}, error) {
+ var values []*exql.Values
+ var arguments []interface{}
+
+ var mapOptions *MapOptions
+ if len(iq.enqueuedValues) > 1 {
+ mapOptions = &MapOptions{IncludeZeroed: true, IncludeNil: true}
+ }
+
+ for _, enqueuedValue := range iq.enqueuedValues {
+ if len(enqueuedValue) == 1 {
+ // If and only if we passed one argument to Values.
+ ff, vv, err := Map(enqueuedValue[0], mapOptions)
+
+ if err == nil {
+ // If we didn't have any problem with mapping, we can convert it
+ // into columns and values.
+ columns, vals, args, _ := toColumnsValuesAndArguments(ff, vv)
+
+ values, arguments = append(values, vals), append(arguments, args...)
+
+ if len(iq.columns) == 0 {
+ iq.columns = append(iq.columns, columns.Columns...)
+ }
+ continue
+ }
+
+ // The only error we can expect without exiting is this argument not
+ // being a map or struct, in which case we can continue.
+ if !errors.Is(err, ErrExpectingPointerToEitherMapOrStruct) {
+ return nil, nil, err
+ }
+ }
+
+ if len(iq.columns) == 0 || len(enqueuedValue) == len(iq.columns) {
+ arguments = append(arguments, enqueuedValue...)
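+
+ // Raw value groups get one "?" placeholder per value below, so e.g. a
+ // call like Values(10, "Ryuichi Sakamoto") compiles to the group (?, ?).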
+ + l := len(enqueuedValue) + placeholders := make([]exql.Fragment, l) + for i := 0; i < l; i++ { + placeholders[i] = sqlPlaceholder + } + values = append(values, exql.NewValueGroup(placeholders...)) + } + } + + return values, arguments, nil +} + +func (iq *inserterQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Insert, + Table: exql.TableWithName(iq.table), + } + + if len(iq.values) > 0 { + stmt.Values = exql.JoinValueGroups(iq.values...) + } + + if len(iq.columns) > 0 { + stmt.Columns = exql.JoinColumns(iq.columns...) + } + + if len(iq.returning) > 0 { + stmt.Returning = exql.ReturningColumns(iq.returning...) + } + + stmt.SetAmendment(iq.amendFn) + + return stmt +} + +type inserter struct { + builder *sqlBuilder + + fn func(*inserterQuery) error + prev *inserter +} + +var _ = immutable.Immutable(&inserter{}) + +func (ins *inserter) SQL() *sqlBuilder { + if ins.prev == nil { + return ins.builder + } + return ins.prev.SQL() +} + +func (ins *inserter) template() *exql.Template { + return ins.SQL().t.Template +} + +func (ins *inserter) String() string { + s, err := ins.Compile() + if err != nil { + panic(err.Error()) + } + return prepareQueryForDisplay(s) +} + +func (ins *inserter) frame(fn func(*inserterQuery) error) *inserter { + return &inserter{prev: ins, fn: fn} +} + +func (ins *inserter) Batch(n int) mydb.BatchInserter { + return newBatchInserter(ins, n) +} + +func (ins *inserter) Amend(fn func(string) string) mydb.Inserter { + return ins.frame(func(iq *inserterQuery) error { + iq.amendFn = fn + return nil + }) +} + +func (ins *inserter) Arguments() []interface{} { + iq, err := ins.build() + if err != nil { + return nil + } + return iq.arguments +} + +func (ins *inserter) Returning(columns ...string) mydb.Inserter { + return ins.frame(func(iq *inserterQuery) error { + columnsToFragments(&iq.returning, columns) + return nil + }) +} + +func (ins *inserter) Exec() (sql.Result, error) { + return ins.ExecContext(ins.SQL().sess.Context()) +} + +func (ins *inserter) ExecContext(ctx context.Context) (sql.Result, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return ins.SQL().sess.StatementExec(ctx, iq.statement(), iq.arguments...) +} + +func (ins *inserter) Prepare() (*sql.Stmt, error) { + return ins.PrepareContext(ins.SQL().sess.Context()) +} + +func (ins *inserter) PrepareContext(ctx context.Context) (*sql.Stmt, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return ins.SQL().sess.StatementPrepare(ctx, iq.statement()) +} + +func (ins *inserter) Query() (*sql.Rows, error) { + return ins.QueryContext(ins.SQL().sess.Context()) +} + +func (ins *inserter) QueryContext(ctx context.Context) (*sql.Rows, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return ins.SQL().sess.StatementQuery(ctx, iq.statement(), iq.arguments...) +} + +func (ins *inserter) QueryRow() (*sql.Row, error) { + return ins.QueryRowContext(ins.SQL().sess.Context()) +} + +func (ins *inserter) QueryRowContext(ctx context.Context) (*sql.Row, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return ins.SQL().sess.StatementQueryRow(ctx, iq.statement(), iq.arguments...) 
+} + +func (ins *inserter) Iterator() mydb.Iterator { + return ins.IteratorContext(ins.SQL().sess.Context()) +} + +func (ins *inserter) IteratorContext(ctx context.Context) mydb.Iterator { + rows, err := ins.QueryContext(ctx) + return &iterator{ins.SQL().sess, rows, err} +} + +func (ins *inserter) Into(table string) mydb.Inserter { + return ins.frame(func(iq *inserterQuery) error { + iq.table = table + return nil + }) +} + +func (ins *inserter) Columns(columns ...string) mydb.Inserter { + return ins.frame(func(iq *inserterQuery) error { + columnsToFragments(&iq.columns, columns) + return nil + }) +} + +func (ins *inserter) Values(values ...interface{}) mydb.Inserter { + return ins.frame(func(iq *inserterQuery) error { + iq.enqueuedValues = append(iq.enqueuedValues, values) + return nil + }) +} + +func (ins *inserter) statement() (*exql.Statement, error) { + iq, err := ins.build() + if err != nil { + return nil, err + } + return iq.statement(), nil +} + +func (ins *inserter) build() (*inserterQuery, error) { + iq, err := immutable.FastForward(ins) + if err != nil { + return nil, err + } + ret := iq.(*inserterQuery) + ret.values, ret.arguments, err = ret.processValues() + if err != nil { + return nil, err + } + return ret, nil +} + +func (ins *inserter) Compile() (string, error) { + s, err := ins.statement() + if err != nil { + return "", err + } + return s.Compile(ins.template()) +} + +func (ins *inserter) Prev() immutable.Immutable { + if ins == nil { + return nil + } + return ins.prev +} + +func (ins *inserter) Fn(in interface{}) error { + if ins.fn == nil { + return nil + } + return ins.fn(in.(*inserterQuery)) +} + +func (ins *inserter) Base() interface{} { + return &inserterQuery{} +} + +func columnsToFragments(dst *[]exql.Fragment, columns []string) { + l := len(columns) + f := make([]exql.Fragment, l) + for i := 0; i < l; i++ { + f[i] = exql.ColumnWithName(columns[i]) + } + *dst = append(*dst, f...) +} diff --git a/internal/sqlbuilder/paginate.go b/internal/sqlbuilder/paginate.go new file mode 100644 index 0000000..6d75d82 --- /dev/null +++ b/internal/sqlbuilder/paginate.go @@ -0,0 +1,340 @@ +package sqlbuilder + +import ( + "context" + "database/sql" + "errors" + "math" + "strings" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/immutable" +) + +var ( + errMissingCursorColumn = errors.New("Missing cursor column") +) + +type paginatorQuery struct { + sel mydb.Selector + + cursorColumn string + cursorValue interface{} + cursorCond mydb.Cond + cursorReverseOrder bool + + pageSize uint + pageNumber uint +} + +func newPaginator(sel mydb.Selector, pageSize uint) mydb.Paginator { + pag := &paginator{} + return pag.frame(func(pq *paginatorQuery) error { + pq.pageSize = pageSize + pq.sel = sel + return nil + }).Page(1) +} + +func (pq *paginatorQuery) count() (uint64, error) { + var count uint64 + + row, err := pq.sel.(*selector).setColumns(mydb.Raw("count(1) AS _t")). + Limit(0). + Offset(0). + OrderBy(nil). 
+		QueryRow()
+	if err != nil {
+		return 0, err
+	}
+
+	err = row.Scan(&count)
+	if err != nil {
+		return 0, err
+	}
+
+	return count, nil
+}
+
+type paginator struct {
+	fn   func(*paginatorQuery) error
+	prev *paginator
+}
+
+var _ = immutable.Immutable(&paginator{})
+
+func (pag *paginator) frame(fn func(*paginatorQuery) error) *paginator {
+	return &paginator{prev: pag, fn: fn}
+}
+
+func (pag *paginator) Page(pageNumber uint) mydb.Paginator {
+	return pag.frame(func(pq *paginatorQuery) error {
+		if pageNumber < 1 {
+			pageNumber = 1
+		}
+		pq.pageNumber = pageNumber
+		return nil
+	})
+}
+
+func (pag *paginator) Cursor(column string) mydb.Paginator {
+	return pag.frame(func(pq *paginatorQuery) error {
+		pq.cursorColumn = column
+		pq.cursorValue = nil
+		return nil
+	})
+}
+
+func (pag *paginator) NextPage(cursorValue interface{}) mydb.Paginator {
+	return pag.frame(func(pq *paginatorQuery) error {
+		if pq.cursorValue != nil && pq.cursorColumn == "" {
+			return errMissingCursorColumn
+		}
+		pq.cursorValue = cursorValue
+		pq.cursorReverseOrder = false
+		if strings.HasPrefix(pq.cursorColumn, "-") {
+			pq.cursorCond = mydb.Cond{
+				pq.cursorColumn[1:]: mydb.Lt(cursorValue),
+			}
+		} else {
+			pq.cursorCond = mydb.Cond{
+				pq.cursorColumn: mydb.Gt(cursorValue),
+			}
+		}
+		return nil
+	})
+}
+
+func (pag *paginator) PrevPage(cursorValue interface{}) mydb.Paginator {
+	return pag.frame(func(pq *paginatorQuery) error {
+		if pq.cursorValue != nil && pq.cursorColumn == "" {
+			return errMissingCursorColumn
+		}
+		pq.cursorValue = cursorValue
+		pq.cursorReverseOrder = true
+		if strings.HasPrefix(pq.cursorColumn, "-") {
+			pq.cursorCond = mydb.Cond{
+				pq.cursorColumn[1:]: mydb.Gt(cursorValue),
+			}
+		} else {
+			pq.cursorCond = mydb.Cond{
+				pq.cursorColumn: mydb.Lt(cursorValue),
+			}
+		}
+		return nil
+	})
+}
+
+func (pag *paginator) TotalPages() (uint, error) {
+	pq, err := pag.build()
+	if err != nil {
+		return 0, err
+	}
+
+	count, err := pq.count()
+	if err != nil {
+		return 0, err
+	}
+	if count < 1 {
+		return 0, nil
+	}
+
+	if pq.pageSize < 1 {
+		return 1, nil
+	}
+
+	pages := uint(math.Ceil(float64(count) / float64(pq.pageSize)))
+	return pages, nil
+}
+
+func (pag *paginator) All(dest interface{}) error {
+	pq, err := pag.buildWithCursor()
+	if err != nil {
+		return err
+	}
+	err = pq.sel.All(dest)
+	if err != nil {
+		return err
+	}
+	return nil
+}
+
+func (pag *paginator) One(dest interface{}) error {
+	pq, err := pag.buildWithCursor()
+	if err != nil {
+		return err
+	}
+	return pq.sel.One(dest)
+}
+
+func (pag *paginator) Iterator() mydb.Iterator {
+	pq, err := pag.buildWithCursor()
+	if err != nil {
+		// buildWithCursor returns a nil *paginatorQuery on failure, so there
+		// is no selector (or session) to borrow here; dereferencing pq.sel
+		// would panic. Return an iterator that only carries the error.
+		return &iterator{nil, nil, err}
+	}
+	return pq.sel.Iterator()
+}
+
+func (pag *paginator) IteratorContext(ctx context.Context) mydb.Iterator {
+	pq, err := pag.buildWithCursor()
+	if err != nil {
+		// See Iterator above: pq is nil when buildWithCursor fails.
+		return &iterator{nil, nil, err}
+	}
+	return pq.sel.IteratorContext(ctx)
+}
+
+func (pag *paginator) String() string {
+	pq, err := pag.buildWithCursor()
+	if err != nil {
+		panic(err.Error())
+	}
+	return pq.sel.String()
+}
+
+func (pag *paginator) Arguments() []interface{} {
+	pq, err := pag.buildWithCursor()
+	if err != nil {
+		return nil
+	}
+	return pq.sel.Arguments()
+}
+
+func (pag *paginator) Compile() (string, error) {
+	pq, err := pag.buildWithCursor()
+	if err != nil {
+		return "", err
+	}
+	return pq.sel.(*selector).Compile()
+}
+
+func (pag *paginator) Query() (*sql.Rows, error) {
+	pq, err := pag.buildWithCursor()
+	if
err != nil { + return nil, err + } + return pq.sel.Query() +} + +func (pag *paginator) QueryContext(ctx context.Context) (*sql.Rows, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return nil, err + } + return pq.sel.QueryContext(ctx) +} + +func (pag *paginator) QueryRow() (*sql.Row, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return nil, err + } + return pq.sel.QueryRow() +} + +func (pag *paginator) QueryRowContext(ctx context.Context) (*sql.Row, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return nil, err + } + return pq.sel.QueryRowContext(ctx) +} + +func (pag *paginator) Prepare() (*sql.Stmt, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return nil, err + } + return pq.sel.Prepare() +} + +func (pag *paginator) PrepareContext(ctx context.Context) (*sql.Stmt, error) { + pq, err := pag.buildWithCursor() + if err != nil { + return nil, err + } + return pq.sel.PrepareContext(ctx) +} + +func (pag *paginator) TotalEntries() (uint64, error) { + pq, err := pag.build() + if err != nil { + return 0, err + } + return pq.count() +} + +func (pag *paginator) build() (*paginatorQuery, error) { + pq, err := immutable.FastForward(pag) + if err != nil { + return nil, err + } + return pq.(*paginatorQuery), nil +} + +func (pag *paginator) buildWithCursor() (*paginatorQuery, error) { + pq, err := immutable.FastForward(pag) + if err != nil { + return nil, err + } + + pqq := pq.(*paginatorQuery) + + if pqq.cursorReverseOrder { + orderBy := pqq.cursorColumn + + if orderBy == "" { + return nil, errMissingCursorColumn + } + + if strings.HasPrefix(orderBy, "-") { + orderBy = orderBy[1:] + } else { + orderBy = "-" + orderBy + } + + pqq.sel = pqq.sel.OrderBy(orderBy) + } + + if pqq.pageSize > 0 { + pqq.sel = pqq.sel.Limit(int(pqq.pageSize)) + if pqq.pageNumber > 1 { + pqq.sel = pqq.sel.Offset(int(pqq.pageSize * (pqq.pageNumber - 1))) + } + } + + if pqq.cursorCond != nil { + pqq.sel = pqq.sel.Where(pqq.cursorCond).Offset(0) + } + + if pqq.cursorColumn != "" { + if pqq.cursorReverseOrder { + pqq.sel = pqq.sel.(*selector).SQL(). + SelectFrom(mydb.Raw("? AS p0", pqq.sel)). + OrderBy(pqq.cursorColumn) + } else { + pqq.sel = pqq.sel.OrderBy(pqq.cursorColumn) + } + } + + return pqq, nil +} + +func (pag *paginator) Prev() immutable.Immutable { + if pag == nil { + return nil + } + return pag.prev +} + +func (pag *paginator) Fn(in interface{}) error { + if pag.fn == nil { + return nil + } + return pag.fn(in.(*paginatorQuery)) +} + +func (pag *paginator) Base() interface{} { + return &paginatorQuery{} +} diff --git a/internal/sqlbuilder/placeholder_test.go b/internal/sqlbuilder/placeholder_test.go new file mode 100644 index 0000000..bbea17c --- /dev/null +++ b/internal/sqlbuilder/placeholder_test.go @@ -0,0 +1,146 @@ +package sqlbuilder + +import ( + "testing" + + "git.hexq.cn/tiglog/mydb" + "github.com/stretchr/testify/assert" +) + +func TestPrepareForDisplay(t *testing.T) { + samples := []struct { + In string + Out string + }{ + { + In: "12345", + Out: "12345", + }, + { + In: "\r\n\t12345", + Out: "12345", + }, + { + In: "12345\r\n\t", + Out: "12345", + }, + { + In: "\r\n\t1\r2\n3\t4\r5\r\n\t", + Out: "1 2 3 4 5", + }, + { + In: "\r\n \t 1\r 2\n 3\t 4\r 5\r \n\t", + Out: "1 2 3 4 5", + }, + { + In: "\r\n \t 11\r 22\n 33\t 44 \r 55", + Out: "11 22 33 44 55", + }, + { + In: "11\r 22\n 33\t 44 \r 55", + Out: "11 22 33 44 55", + }, + { + In: "1 2 3 4 5", + Out: "1 2 3 4 5", + }, + { + In: "?", + Out: "$1", + }, + { + In: "? 
?", + Out: "$1 $2", + }, + { + In: "? ? ?", + Out: "$1 $2 $3", + }, + { + In: " ? ? ? ", + Out: "$1 $2 $3", + }, + { + In: "???", + Out: "$1$2$3", + }, + } + for _, sample := range samples { + assert.Equal(t, sample.Out, prepareQueryForDisplay(sample.In)) + } +} + +func TestPlaceholderSimple(t *testing.T) { + { + ret, _ := Preprocess("?", []interface{}{1}) + assert.Equal(t, "?", ret) + } + { + ret, _ := Preprocess("?", nil) + assert.Equal(t, "?", ret) + } +} + +func TestPlaceholderMany(t *testing.T) { + { + ret, _ := Preprocess("?, ?, ?", []interface{}{1, 2, 3}) + assert.Equal(t, "?, ?, ?", ret) + } +} + +func TestPlaceholderArray(t *testing.T) { + { + ret, _ := Preprocess("?, ?, ?", []interface{}{1, 2, []interface{}{3, 4, 5}}) + assert.Equal(t, "?, ?, (?, ?, ?)", ret) + } + + { + ret, _ := Preprocess("?, ?, ?", []interface{}{[]interface{}{1, 2, 3}, 4, 5}) + assert.Equal(t, "(?, ?, ?), ?, ?", ret) + } + + { + ret, _ := Preprocess("?, ?, ?", []interface{}{1, []interface{}{2, 3, 4}, 5}) + assert.Equal(t, "?, (?, ?, ?), ?", ret) + } + + { + ret, _ := Preprocess("???", []interface{}{1, []interface{}{2, 3, 4}, 5}) + assert.Equal(t, "?(?, ?, ?)?", ret) + } + + { + ret, _ := Preprocess("??", []interface{}{[]interface{}{1, 2, 3}, []interface{}{}, []interface{}{4, 5}, []interface{}{}}) + assert.Equal(t, "(?, ?, ?)(NULL)", ret) + } +} + +func TestPlaceholderArguments(t *testing.T) { + { + _, args := Preprocess("?, ?, ?", []interface{}{1, 2, []interface{}{3, 4, 5}}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := Preprocess("?, ?, ?", []interface{}{1, []interface{}{2, 3, 4}, 5}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := Preprocess("?, ?, ?", []interface{}{[]interface{}{1, 2, 3}, 4, 5}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } + + { + _, args := Preprocess("?, ?", []interface{}{[]interface{}{1, 2, 3}, []interface{}{4, 5}}) + assert.Equal(t, []interface{}{1, 2, 3, 4, 5}, args) + } +} + +func TestPlaceholderReplace(t *testing.T) { + { + ret, args := Preprocess("?, ?, ?", []interface{}{1, mydb.Raw("foo"), 3}) + assert.Equal(t, "?, foo, ?", ret) + assert.Equal(t, []interface{}{1, 3}, args) + } +} diff --git a/internal/sqlbuilder/scanner.go b/internal/sqlbuilder/scanner.go new file mode 100644 index 0000000..228a8e0 --- /dev/null +++ b/internal/sqlbuilder/scanner.go @@ -0,0 +1,17 @@ +package sqlbuilder + +import ( + "database/sql" + + "git.hexq.cn/tiglog/mydb" +) + +type scanner struct { + v mydb.Unmarshaler +} + +func (u scanner) Scan(v interface{}) error { + return u.v.UnmarshalDB(v) +} + +var _ sql.Scanner = scanner{} diff --git a/internal/sqlbuilder/select.go b/internal/sqlbuilder/select.go new file mode 100644 index 0000000..cbf90b0 --- /dev/null +++ b/internal/sqlbuilder/select.go @@ -0,0 +1,524 @@ +package sqlbuilder + +import ( + "context" + "database/sql" + "errors" + "fmt" + "strings" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/adapter" + "git.hexq.cn/tiglog/mydb/internal/immutable" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +type selectorQuery struct { + table *exql.Columns + tableArgs []interface{} + + distinct bool + + where *exql.Where + whereArgs []interface{} + + groupBy *exql.GroupBy + groupByArgs []interface{} + + orderBy *exql.OrderBy + orderByArgs []interface{} + + limit exql.Limit + offset exql.Offset + + columns *exql.Columns + columnsArgs []interface{} + + joins []*exql.Join + joinsArgs []interface{} + + amendFn func(string) string +} + +func (sq *selectorQuery) and(b 
*sqlBuilder, terms ...interface{}) error { + where, whereArgs := b.t.toWhereWithArguments(terms) + + if sq.where == nil { + sq.where, sq.whereArgs = &exql.Where{}, []interface{}{} + } + sq.where.Append(&where) + sq.whereArgs = append(sq.whereArgs, whereArgs...) + + return nil +} + +func (sq *selectorQuery) arguments() []interface{} { + return joinArguments( + sq.columnsArgs, + sq.tableArgs, + sq.joinsArgs, + sq.whereArgs, + sq.groupByArgs, + sq.orderByArgs, + ) +} + +func (sq *selectorQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Select, + Table: sq.table, + Columns: sq.columns, + Distinct: sq.distinct, + Limit: sq.limit, + Offset: sq.offset, + Where: sq.where, + OrderBy: sq.orderBy, + GroupBy: sq.groupBy, + } + + if len(sq.joins) > 0 { + stmt.Joins = exql.JoinConditions(sq.joins...) + } + + stmt.SetAmendment(sq.amendFn) + + return stmt +} + +func (sq *selectorQuery) pushJoin(t string, tables []interface{}) error { + fragments, args, err := columnFragments(tables) + if err != nil { + return err + } + + if sq.joins == nil { + sq.joins = []*exql.Join{} + } + sq.joins = append(sq.joins, + &exql.Join{ + Type: t, + Table: exql.JoinColumns(fragments...), + }, + ) + + sq.joinsArgs = append(sq.joinsArgs, args...) + + return nil +} + +type selector struct { + builder *sqlBuilder + + fn func(*selectorQuery) error + prev *selector +} + +var _ = immutable.Immutable(&selector{}) + +func (sel *selector) SQL() *sqlBuilder { + if sel.prev == nil { + return sel.builder + } + return sel.prev.SQL() +} + +func (sel *selector) String() string { + s, err := sel.Compile() + if err != nil { + panic(err.Error()) + } + return prepareQueryForDisplay(s) +} + +func (sel *selector) frame(fn func(*selectorQuery) error) *selector { + return &selector{prev: sel, fn: fn} +} + +func (sel *selector) clone() mydb.Selector { + return sel.frame(func(*selectorQuery) error { + return nil + }) +} + +func (sel *selector) From(tables ...interface{}) mydb.Selector { + return sel.frame( + func(sq *selectorQuery) error { + fragments, args, err := columnFragments(tables) + if err != nil { + return err + } + sq.table = exql.JoinColumns(fragments...) + sq.tableArgs = args + return nil + }, + ) +} + +func (sel *selector) setColumns(columns ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + sq.columns = nil + return sq.pushColumns(columns...) + }) +} + +func (sel *selector) Columns(columns ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushColumns(columns...) + }) +} + +func (sq *selectorQuery) pushColumns(columns ...interface{}) error { + f, args, err := columnFragments(columns) + if err != nil { + return err + } + + c := exql.JoinColumns(f...) + + if sq.columns != nil { + sq.columns.Append(c) + } else { + sq.columns = c + } + + sq.columnsArgs = append(sq.columnsArgs, args...) + return nil +} + +func (sel *selector) Distinct(exps ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + sq.distinct = true + return sq.pushColumns(exps...) + }) +} + +func (sel *selector) Where(terms ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + if len(terms) == 1 && terms[0] == nil { + sq.where, sq.whereArgs = &exql.Where{}, []interface{}{} + return nil + } + return sq.and(sel.SQL(), terms...) + }) +} + +func (sel *selector) And(terms ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.and(sel.SQL(), terms...) 
+ }) +} + +func (sel *selector) Amend(fn func(string) string) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + sq.amendFn = fn + return nil + }) +} + +func (sel *selector) Arguments() []interface{} { + sq, err := sel.build() + if err != nil { + return nil + } + return sq.arguments() +} + +func (sel *selector) GroupBy(columns ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + fragments, args, err := columnFragments(columns) + if err != nil { + return err + } + + if fragments != nil { + sq.groupBy = exql.GroupByColumns(fragments...) + } + sq.groupByArgs = args + + return nil + }) +} + +func (sel *selector) OrderBy(columns ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + + if len(columns) == 1 && columns[0] == nil { + sq.orderBy = nil + sq.orderByArgs = nil + return nil + } + + var sortColumns exql.SortColumns + + for i := range columns { + var sort *exql.SortColumn + + switch value := columns[i].(type) { + case *adapter.RawExpr: + query, args := Preprocess(value.Raw(), value.Arguments()) + sort = &exql.SortColumn{ + Column: &exql.Raw{Value: query}, + } + sq.orderByArgs = append(sq.orderByArgs, args...) + case *adapter.FuncExpr: + fnName, fnArgs := value.Name(), value.Arguments() + if len(fnArgs) == 0 { + fnName = fnName + "()" + } else { + fnName = fnName + "(?" + strings.Repeat(", ?", len(fnArgs)-1) + ")" + } + fnName, fnArgs = Preprocess(fnName, fnArgs) + sort = &exql.SortColumn{ + Column: &exql.Raw{Value: fnName}, + } + sq.orderByArgs = append(sq.orderByArgs, fnArgs...) + case string: + if strings.HasPrefix(value, "-") { + sort = &exql.SortColumn{ + Column: exql.ColumnWithName(value[1:]), + Order: exql.Order_Descendent, + } + } else { + chunks := strings.SplitN(value, " ", 2) + + order := exql.Order_Ascendent + if len(chunks) > 1 && strings.ToUpper(chunks[1]) == "DESC" { + order = exql.Order_Descendent + } + + sort = &exql.SortColumn{ + Column: exql.ColumnWithName(chunks[0]), + Order: order, + } + } + default: + return fmt.Errorf("Can't sort by type %T", value) + } + sortColumns.Columns = append(sortColumns.Columns, sort) + } + + sq.orderBy = &exql.OrderBy{ + SortColumns: &sortColumns, + } + return nil + }) +} + +func (sel *selector) Using(columns ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + + joins := len(sq.joins) + if joins == 0 { + return errors.New(`cannot use Using() without a preceding Join() expression`) + } + + lastJoin := sq.joins[joins-1] + if lastJoin.On != nil { + return errors.New(`cannot use Using() and On() with the same Join() expression`) + } + + fragments, args, err := columnFragments(columns) + if err != nil { + return err + } + + sq.joinsArgs = append(sq.joinsArgs, args...) + lastJoin.Using = exql.UsingColumns(fragments...) 
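+		// Editorial sketch (uses the Join/Using builder methods defined in
+		// this file; table and column names are illustrative):
+		//
+		//	sel.Join("departments").Using("department_id")
+		//	// -> ... JOIN departments USING (department_id)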
+ + return nil + }) +} + +func (sel *selector) FullJoin(tables ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("FULL", tables) + }) +} + +func (sel *selector) CrossJoin(tables ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("CROSS", tables) + }) +} + +func (sel *selector) RightJoin(tables ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("RIGHT", tables) + }) +} + +func (sel *selector) LeftJoin(tables ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("LEFT", tables) + }) +} + +func (sel *selector) Join(tables ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + return sq.pushJoin("", tables) + }) +} + +func (sel *selector) On(terms ...interface{}) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + joins := len(sq.joins) + + if joins == 0 { + return errors.New(`cannot use On() without a preceding Join() expression`) + } + + lastJoin := sq.joins[joins-1] + if lastJoin.On != nil { + return errors.New(`cannot use Using() and On() with the same Join() expression`) + } + + w, a := sel.SQL().t.toWhereWithArguments(terms) + o := exql.On(w) + + lastJoin.On = &o + + sq.joinsArgs = append(sq.joinsArgs, a...) + + return nil + }) +} + +func (sel *selector) Limit(n int) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + if n < 0 { + n = 0 + } + sq.limit = exql.Limit(n) + return nil + }) +} + +func (sel *selector) Offset(n int) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + if n < 0 { + n = 0 + } + sq.offset = exql.Offset(n) + return nil + }) +} + +func (sel *selector) template() *exql.Template { + return sel.SQL().t.Template +} + +func (sel *selector) As(alias string) mydb.Selector { + return sel.frame(func(sq *selectorQuery) error { + if sq.table == nil { + return errors.New("Cannot use As() without a preceding From() expression") + } + last := len(sq.table.Columns) - 1 + if raw, ok := sq.table.Columns[last].(*exql.Raw); ok { + compiled, err := exql.ColumnWithName(alias).Compile(sel.template()) + if err != nil { + return err + } + sq.table.Columns[last] = &exql.Raw{Value: raw.Value + " AS " + compiled} + } + return nil + }) +} + +func (sel *selector) statement() *exql.Statement { + sq, _ := sel.build() + return sq.statement() +} + +func (sel *selector) QueryRow() (*sql.Row, error) { + return sel.QueryRowContext(sel.SQL().sess.Context()) +} + +func (sel *selector) QueryRowContext(ctx context.Context) (*sql.Row, error) { + sq, err := sel.build() + if err != nil { + return nil, err + } + + return sel.SQL().sess.StatementQueryRow(ctx, sq.statement(), sq.arguments()...) +} + +func (sel *selector) Prepare() (*sql.Stmt, error) { + return sel.PrepareContext(sel.SQL().sess.Context()) +} + +func (sel *selector) PrepareContext(ctx context.Context) (*sql.Stmt, error) { + sq, err := sel.build() + if err != nil { + return nil, err + } + return sel.SQL().sess.StatementPrepare(ctx, sq.statement()) +} + +func (sel *selector) Query() (*sql.Rows, error) { + return sel.QueryContext(sel.SQL().sess.Context()) +} + +func (sel *selector) QueryContext(ctx context.Context) (*sql.Rows, error) { + sq, err := sel.build() + if err != nil { + return nil, err + } + return sel.SQL().sess.StatementQuery(ctx, sq.statement(), sq.arguments()...) 
+}
+
+func (sel *selector) Iterator() mydb.Iterator {
+	return sel.IteratorContext(sel.SQL().sess.Context())
+}
+
+func (sel *selector) IteratorContext(ctx context.Context) mydb.Iterator {
+	sess := sel.SQL().sess
+	sq, err := sel.build()
+	if err != nil {
+		return &iterator{sess, nil, err}
+	}
+
+	rows, err := sess.StatementQuery(ctx, sq.statement(), sq.arguments()...)
+	return &iterator{sess, rows, err}
+}
+
+func (sel *selector) Paginate(pageSize uint) mydb.Paginator {
+	return newPaginator(sel.clone(), pageSize)
+}
+
+func (sel *selector) All(destSlice interface{}) error {
+	return sel.Iterator().All(destSlice)
+}
+
+func (sel *selector) One(dest interface{}) error {
+	return sel.Iterator().One(dest)
+}
+
+func (sel *selector) build() (*selectorQuery, error) {
+	sq, err := immutable.FastForward(sel)
+	if err != nil {
+		return nil, err
+	}
+	return sq.(*selectorQuery), nil
+}
+
+func (sel *selector) Compile() (string, error) {
+	return sel.statement().Compile(sel.template())
+}
+
+func (sel *selector) Prev() immutable.Immutable {
+	if sel == nil {
+		return nil
+	}
+	return sel.prev
+}
+
+func (sel *selector) Fn(in interface{}) error {
+	if sel.fn == nil {
+		return nil
+	}
+	return sel.fn(in.(*selectorQuery))
+}
+
+func (sel *selector) Base() interface{} {
+	return &selectorQuery{}
+}
diff --git a/internal/sqlbuilder/sqlbuilder.go b/internal/sqlbuilder/sqlbuilder.go
new file mode 100644
index 0000000..7afa620
--- /dev/null
+++ b/internal/sqlbuilder/sqlbuilder.go
@@ -0,0 +1,40 @@
+package sqlbuilder
+
+import (
+	"database/sql"
+	"fmt"
+
+	"git.hexq.cn/tiglog/mydb"
+)
+
+// Engine represents a SQL database engine.
+type Engine interface {
+	mydb.Session
+
+	mydb.SQL
+}
+
+func lookupAdapter(adapterName string) (Adapter, error) {
+	adapter := mydb.LookupAdapter(adapterName)
+	if sqlAdapter, ok := adapter.(Adapter); ok {
+		return sqlAdapter, nil
+	}
+	return nil, fmt.Errorf("%w %q", mydb.ErrMissingAdapter, adapterName)
+}
+
+// BindTx creates a binding between an adapter and a *sql.Tx.
+func BindTx(adapterName string, tx *sql.Tx) (Tx, error) {
+	adapter, err := lookupAdapter(adapterName)
+	if err != nil {
+		return nil, err
+	}
+	return adapter.NewTx(tx)
+}
+
+// BindDB creates a binding between an adapter and a *sql.DB.
+func BindDB(adapterName string, sess *sql.DB) (mydb.Session, error) {
+	adapter, err := lookupAdapter(adapterName)
+	if err != nil {
+		return nil, err
+	}
+	return adapter.New(sess)
+}
diff --git a/internal/sqlbuilder/template.go b/internal/sqlbuilder/template.go
new file mode 100644
index 0000000..c27466d
--- /dev/null
+++ b/internal/sqlbuilder/template.go
@@ -0,0 +1,332 @@
+package sqlbuilder
+
+import (
+	"database/sql/driver"
+	"fmt"
+	"strings"
+
+	"git.hexq.cn/tiglog/mydb"
+	"git.hexq.cn/tiglog/mydb/internal/adapter"
+	"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
+)
+
+type templateWithUtils struct {
+	*exql.Template
+}
+
+func newTemplateWithUtils(template *exql.Template) *templateWithUtils {
+	return &templateWithUtils{template}
+}
+
+func (tu *templateWithUtils) PlaceholderValue(in interface{}) (exql.Fragment, []interface{}) {
+	switch t := in.(type) {
+	case *adapter.RawExpr:
+		return &exql.Raw{Value: t.Raw()}, t.Arguments()
+	case *adapter.FuncExpr:
+		fnName := t.Name()
+		fnArgs := []interface{}{}
+		args, _ := toInterfaceArguments(t.Arguments())
+		fragments := []string{}
+		for i := range args {
+			frag, args := tu.PlaceholderValue(args[i])
+			fragment, err := frag.Compile(tu.Template)
+			if err == nil {
+				fragments = append(fragments, fragment)
+				fnArgs = append(fnArgs, args...)
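+				// Editorial sketch: each function argument is recursively
+				// turned into a fragment here, so a call such as
+				// mydb.Func("CONCAT", "a", "b") compiles to CONCAT(?, ?)
+				// with "a" and "b" collected into fnArgs.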
+ } + } + return &exql.Raw{Value: fnName + `(` + strings.Join(fragments, `, `) + `)`}, fnArgs + default: + return sqlPlaceholder, []interface{}{in} + } +} + +// toWhereWithArguments converts the given parameters into a exql.Where value. +func (tu *templateWithUtils) toWhereWithArguments(term interface{}) (where exql.Where, args []interface{}) { + args = []interface{}{} + + switch t := term.(type) { + case []interface{}: + if len(t) > 0 { + if s, ok := t[0].(string); ok { + if strings.ContainsAny(s, "?") || len(t) == 1 { + s, args = Preprocess(s, t[1:]) + where.Conditions = []exql.Fragment{&exql.Raw{Value: s}} + } else { + var val interface{} + key := s + if len(t) > 2 { + val = t[1:] + } else { + val = t[1] + } + cv, v := tu.toColumnValues(adapter.NewConstraint(key, val)) + args = append(args, v...) + where.Conditions = append(where.Conditions, cv.ColumnValues...) + } + return + } + } + for i := range t { + w, v := tu.toWhereWithArguments(t[i]) + if len(w.Conditions) == 0 { + continue + } + args = append(args, v...) + where.Conditions = append(where.Conditions, w.Conditions...) + } + return + case *adapter.RawExpr: + r, v := Preprocess(t.Raw(), t.Arguments()) + where.Conditions = []exql.Fragment{&exql.Raw{Value: r}} + args = append(args, v...) + return + case adapter.Constraints: + for _, c := range t.Constraints() { + w, v := tu.toWhereWithArguments(c) + if len(w.Conditions) == 0 { + continue + } + args = append(args, v...) + where.Conditions = append(where.Conditions, w.Conditions...) + } + return + case adapter.LogicalExpr: + var cond exql.Where + + expressions := t.Expressions() + for i := range expressions { + w, v := tu.toWhereWithArguments(expressions[i]) + if len(w.Conditions) == 0 { + continue + } + args = append(args, v...) + cond.Conditions = append(cond.Conditions, w.Conditions...) + } + if len(cond.Conditions) < 1 { + return + } + + if len(cond.Conditions) <= 1 { + where.Conditions = append(where.Conditions, cond.Conditions...) + return where, args + } + + var frag exql.Fragment + switch t.Operator() { + case adapter.LogicalOperatorNone, adapter.LogicalOperatorAnd: + q := exql.And(cond) + frag = &q + case adapter.LogicalOperatorOr: + q := exql.Or(cond) + frag = &q + default: + panic(fmt.Sprintf("Unknown type %T", t)) + } + where.Conditions = append(where.Conditions, frag) + return + + case mydb.InsertResult: + return tu.toWhereWithArguments(t.ID()) + + case adapter.Constraint: + cv, v := tu.toColumnValues(t) + args = append(args, v...) + where.Conditions = append(where.Conditions, cv.ColumnValues...) + return where, args + } + + panic(fmt.Sprintf("Unknown condition type %T", term)) +} + +func (tu *templateWithUtils) comparisonOperatorMapper(t adapter.ComparisonOperator) string { + if t == adapter.ComparisonOperatorCustom { + return "" + } + if tu.ComparisonOperator != nil { + if op, ok := tu.ComparisonOperator[t]; ok { + return op + } + } + if op, ok := comparisonOperators[t]; ok { + return op + } + panic(fmt.Sprintf("unsupported comparison operator %v", t)) +} + +func (tu *templateWithUtils) toColumnValues(term interface{}) (cv exql.ColumnValues, args []interface{}) { + args = []interface{}{} + + switch t := term.(type) { + case adapter.Constraint: + columnValue := exql.ColumnValue{} + + // TODO: Key and Value are similar. Can we refactor this? Maybe think about + // Left/Right rather than Key/Value. 
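+		// Editorial note: the key may embed an operator, which the code below
+		// splits off. For example, mydb.Cond{"age >": 18} yields Column "age",
+		// Operator ">", and a placeholder Value bound to 18, while a bare key
+		// such as mydb.Cond{"age": 18} falls through to the default equality
+		// operator further down.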
+ + switch key := t.Key().(type) { + case string: + chunks := strings.SplitN(strings.TrimSpace(key), " ", 2) + columnValue.Column = exql.ColumnWithName(chunks[0]) + if len(chunks) > 1 { + columnValue.Operator = chunks[1] + } + case *adapter.RawExpr: + columnValue.Column = &exql.Raw{Value: key.Raw()} + args = append(args, key.Arguments()...) + case *mydb.FuncExpr: + fnName, fnArgs := key.Name(), key.Arguments() + if len(fnArgs) == 0 { + fnName = fnName + "()" + } else { + fnName = fnName + "(?" + strings.Repeat(", ?", len(fnArgs)-1) + ")" + } + fnName, fnArgs = Preprocess(fnName, fnArgs) + columnValue.Column = &exql.Raw{Value: fnName} + args = append(args, fnArgs...) + default: + columnValue.Column = &exql.Raw{Value: fmt.Sprintf("%v", key)} + } + + switch value := t.Value().(type) { + case *mydb.FuncExpr: + fnName, fnArgs := value.Name(), value.Arguments() + if len(fnArgs) == 0 { + // A function with no arguments. + fnName = fnName + "()" + } else { + // A function with one or more arguments. + fnName = fnName + "(?" + strings.Repeat(", ?", len(fnArgs)-1) + ")" + } + fnName, fnArgs = Preprocess(fnName, fnArgs) + columnValue.Value = &exql.Raw{Value: fnName} + args = append(args, fnArgs...) + case *mydb.RawExpr: + q, a := Preprocess(value.Raw(), value.Arguments()) + columnValue.Value = &exql.Raw{Value: q} + args = append(args, a...) + case driver.Valuer: + columnValue.Value = sqlPlaceholder + args = append(args, value) + case *mydb.Comparison: + wrapper := &operatorWrapper{ + tu: tu, + cv: &columnValue, + op: value.Comparison, + } + + q, a := wrapper.preprocess() + q, a = Preprocess(q, a) + + columnValue = exql.ColumnValue{ + Column: &exql.Raw{Value: q}, + } + if a != nil { + args = append(args, a...) + } + + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + return cv, args + default: + wrapper := &operatorWrapper{ + tu: tu, + cv: &columnValue, + v: value, + } + + q, a := wrapper.preprocess() + q, a = Preprocess(q, a) + + columnValue = exql.ColumnValue{ + Column: &exql.Raw{Value: q}, + } + if a != nil { + args = append(args, a...) + } + + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + return cv, args + } + + if columnValue.Operator == "" { + columnValue.Operator = tu.comparisonOperatorMapper(adapter.ComparisonOperatorEqual) + } + + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + return cv, args + + case *adapter.RawExpr: + columnValue := exql.ColumnValue{} + p, q := Preprocess(t.Raw(), t.Arguments()) + columnValue.Column = &exql.Raw{Value: p} + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + args = append(args, q...) + return cv, args + + case adapter.Constraints: + for _, constraint := range t.Constraints() { + p, q := tu.toColumnValues(constraint) + cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...) + args = append(args, q...) + } + return cv, args + } + + panic(fmt.Sprintf("Unknown term type %T.", term)) +} + +func (tu *templateWithUtils) setColumnValues(term interface{}) (cv exql.ColumnValues, args []interface{}) { + args = []interface{}{} + + switch t := term.(type) { + case []interface{}: + l := len(t) + for i := 0; i < l; i++ { + column, isString := t[i].(string) + + if !isString { + p, q := tu.setColumnValues(t[i]) + cv.ColumnValues = append(cv.ColumnValues, p.ColumnValues...) + args = append(args, q...) + continue + } + + if !strings.ContainsAny(column, tu.AssignmentOperator) { + column = column + " " + tu.AssignmentOperator + " ?" 
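+			// Editorial note: a bare column name is expanded here into
+			// "<column> <assignment-op> ?", so Set("views", 10) behaves
+			// exactly like the explicit form Set("views = ?", 10).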
+ } + + chunks := strings.SplitN(column, tu.AssignmentOperator, 2) + + column = chunks[0] + format := strings.TrimSpace(chunks[1]) + + columnValue := exql.ColumnValue{ + Column: exql.ColumnWithName(column), + Operator: tu.AssignmentOperator, + Value: &exql.Raw{Value: format}, + } + + ps := strings.Count(format, "?") + if i+ps < l { + for j := 0; j < ps; j++ { + args = append(args, t[i+j+1]) + } + i = i + ps + } else { + panic(fmt.Sprintf("Format string %q has more placeholders than given arguments.", format)) + } + + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + } + return cv, args + case *adapter.RawExpr: + columnValue := exql.ColumnValue{} + p, q := Preprocess(t.Raw(), t.Arguments()) + columnValue.Column = &exql.Raw{Value: p} + cv.ColumnValues = append(cv.ColumnValues, &columnValue) + args = append(args, q...) + return cv, args + } + + panic(fmt.Sprintf("Unknown term type %T.", term)) +} diff --git a/internal/sqlbuilder/template_test.go b/internal/sqlbuilder/template_test.go new file mode 100644 index 0000000..f185afd --- /dev/null +++ b/internal/sqlbuilder/template_test.go @@ -0,0 +1,192 @@ +package sqlbuilder + +import ( + "git.hexq.cn/tiglog/mydb/internal/cache" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +const ( + defaultColumnSeparator = `.` + defaultIdentifierSeparator = `, ` + defaultIdentifierQuote = `"{{.Value}}"` + defaultValueSeparator = `, ` + defaultValueQuote = `'{{.}}'` + defaultAndKeyword = `AND` + defaultOrKeyword = `OR` + defaultDescKeyword = `DESC` + defaultAscKeyword = `ASC` + defaultAssignmentOperator = `=` + defaultClauseGroup = `({{.}})` + defaultClauseOperator = ` {{.}} ` + defaultColumnValue = `{{.Column}} {{.Operator}} {{.Value}}` + defaultTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + defaultColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}` + defaultSortByColumnLayout = `{{.Column}} {{.Order}}` + + defaultOrderByLayout = ` + {{if .SortColumns}} + ORDER BY {{.SortColumns}} + {{end}} + ` + + defaultWhereLayout = ` + {{if .Conds}} + WHERE {{.Conds}} + {{end}} + ` + + defaultUsingLayout = ` + {{if .Columns}} + USING ({{.Columns}}) + {{end}} + ` + + defaultJoinLayout = ` + {{if .Table}} + {{ if .On }} + {{.Type}} JOIN {{.Table}} + {{.On}} + {{ else if .Using }} + {{.Type}} JOIN {{.Table}} + {{.Using}} + {{ else if .Type | eq "CROSS" }} + {{.Type}} JOIN {{.Table}} + {{else}} + NATURAL {{.Type}} JOIN {{.Table}} + {{end}} + {{end}} + ` + + defaultOnLayout = ` + {{if .Conds}} + ON {{.Conds}} + {{end}} + ` + + defaultSelectLayout = ` + SELECT + {{if .Distinct}} + DISTINCT + {{end}} + + {{if defined .Columns}} + {{.Columns | compile}} + {{else}} + * + {{end}} + + {{if defined .Table}} + FROM {{.Table | compile}} + {{end}} + + {{.Joins | compile}} + + {{.Where | compile}} + + {{if defined .GroupBy}} + {{.GroupBy | compile}} + {{end}} + + {{.OrderBy | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + defaultDeleteLayout = ` + DELETE + FROM {{.Table | compile}} + {{.Where | compile}} + ` + defaultUpdateLayout = ` + UPDATE + {{.Table | compile}} + SET {{.ColumnValues | compile}} + {{.Where | compile}} + ` + + defaultCountLayout = ` + SELECT + COUNT(1) AS _t + FROM {{.Table | compile}} + {{.Where | compile}} + + {{if .Limit}} + LIMIT {{.Limit}} + {{end}} + + {{if .Offset}} + OFFSET {{.Offset}} + {{end}} + ` + + defaultInsertLayout = ` + INSERT INTO {{.Table | compile}} + {{if defined .Columns }}({{.Columns | compile}}){{end}} + VALUES + {{if defined 
.Values}} + {{.Values | compile}} + {{else}} + (default) + {{end}} + {{if defined .Returning}} + RETURNING {{.Returning | compile}} + {{end}} + ` + + defaultTruncateLayout = ` + TRUNCATE TABLE {{.Table | compile}} + ` + + defaultDropDatabaseLayout = ` + DROP DATABASE {{.Database | compile}} + ` + + defaultDropTableLayout = ` + DROP TABLE {{.Table | compile}} + ` + + defaultGroupByLayout = ` + {{if .GroupColumns}} + GROUP BY {{.GroupColumns}} + {{end}} + ` +) + +var testTemplate = exql.Template{ + ColumnSeparator: defaultColumnSeparator, + IdentifierSeparator: defaultIdentifierSeparator, + IdentifierQuote: defaultIdentifierQuote, + ValueSeparator: defaultValueSeparator, + ValueQuote: defaultValueQuote, + AndKeyword: defaultAndKeyword, + OrKeyword: defaultOrKeyword, + DescKeyword: defaultDescKeyword, + AscKeyword: defaultAscKeyword, + AssignmentOperator: defaultAssignmentOperator, + ClauseGroup: defaultClauseGroup, + ClauseOperator: defaultClauseOperator, + ColumnValue: defaultColumnValue, + TableAliasLayout: defaultTableAliasLayout, + ColumnAliasLayout: defaultColumnAliasLayout, + SortByColumnLayout: defaultSortByColumnLayout, + WhereLayout: defaultWhereLayout, + OnLayout: defaultOnLayout, + UsingLayout: defaultUsingLayout, + JoinLayout: defaultJoinLayout, + OrderByLayout: defaultOrderByLayout, + InsertLayout: defaultInsertLayout, + SelectLayout: defaultSelectLayout, + UpdateLayout: defaultUpdateLayout, + DeleteLayout: defaultDeleteLayout, + TruncateLayout: defaultTruncateLayout, + DropDatabaseLayout: defaultDropDatabaseLayout, + DropTableLayout: defaultDropTableLayout, + CountLayout: defaultCountLayout, + GroupByLayout: defaultGroupByLayout, + Cache: cache.NewCache(), +} diff --git a/internal/sqlbuilder/update.go b/internal/sqlbuilder/update.go new file mode 100644 index 0000000..fb3ae30 --- /dev/null +++ b/internal/sqlbuilder/update.go @@ -0,0 +1,242 @@ +package sqlbuilder + +import ( + "context" + "database/sql" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/immutable" + "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" +) + +type updaterQuery struct { + table string + + columnValues *exql.ColumnValues + columnValuesArgs []interface{} + + limit int + + where *exql.Where + whereArgs []interface{} + + amendFn func(string) string +} + +func (uq *updaterQuery) and(b *sqlBuilder, terms ...interface{}) error { + where, whereArgs := b.t.toWhereWithArguments(terms) + + if uq.where == nil { + uq.where, uq.whereArgs = &exql.Where{}, []interface{}{} + } + uq.where.Append(&where) + uq.whereArgs = append(uq.whereArgs, whereArgs...) 
+ + return nil +} + +func (uq *updaterQuery) statement() *exql.Statement { + stmt := &exql.Statement{ + Type: exql.Update, + Table: exql.TableWithName(uq.table), + ColumnValues: uq.columnValues, + } + + if uq.where != nil { + stmt.Where = uq.where + } + + if uq.limit != 0 { + stmt.Limit = exql.Limit(uq.limit) + } + + stmt.SetAmendment(uq.amendFn) + + return stmt +} + +func (uq *updaterQuery) arguments() []interface{} { + return joinArguments( + uq.columnValuesArgs, + uq.whereArgs, + ) +} + +type updater struct { + builder *sqlBuilder + + fn func(*updaterQuery) error + prev *updater +} + +var _ = immutable.Immutable(&updater{}) + +func (upd *updater) SQL() *sqlBuilder { + if upd.prev == nil { + return upd.builder + } + return upd.prev.SQL() +} + +func (upd *updater) template() *exql.Template { + return upd.SQL().t.Template +} + +func (upd *updater) String() string { + s, err := upd.Compile() + if err != nil { + panic(err.Error()) + } + return prepareQueryForDisplay(s) +} + +func (upd *updater) setTable(table string) *updater { + return upd.frame(func(uq *updaterQuery) error { + uq.table = table + return nil + }) +} + +func (upd *updater) frame(fn func(*updaterQuery) error) *updater { + return &updater{prev: upd, fn: fn} +} + +func (upd *updater) Set(terms ...interface{}) mydb.Updater { + return upd.frame(func(uq *updaterQuery) error { + if uq.columnValues == nil { + uq.columnValues = &exql.ColumnValues{} + } + + if len(terms) == 1 { + ff, vv, err := Map(terms[0], nil) + if err == nil && len(ff) > 0 { + cvs := make([]exql.Fragment, 0, len(ff)) + args := make([]interface{}, 0, len(vv)) + + for i := range ff { + cv := &exql.ColumnValue{ + Column: exql.ColumnWithName(ff[i]), + Operator: upd.SQL().t.AssignmentOperator, + } + + var localArgs []interface{} + cv.Value, localArgs = upd.SQL().t.PlaceholderValue(vv[i]) + + args = append(args, localArgs...) + cvs = append(cvs, cv) + } + + uq.columnValues.Insert(cvs...) + uq.columnValuesArgs = append(uq.columnValuesArgs, args...) + + return nil + } + } + + cv, arguments := upd.SQL().t.setColumnValues(terms) + uq.columnValues.Insert(cv.ColumnValues...) + uq.columnValuesArgs = append(uq.columnValuesArgs, arguments...) + return nil + }) +} + +func (upd *updater) Amend(fn func(string) string) mydb.Updater { + return upd.frame(func(uq *updaterQuery) error { + uq.amendFn = fn + return nil + }) +} + +func (upd *updater) Arguments() []interface{} { + uq, err := upd.build() + if err != nil { + return nil + } + return uq.arguments() +} + +func (upd *updater) Where(terms ...interface{}) mydb.Updater { + return upd.frame(func(uq *updaterQuery) error { + uq.where, uq.whereArgs = &exql.Where{}, []interface{}{} + return uq.and(upd.SQL(), terms...) + }) +} + +func (upd *updater) And(terms ...interface{}) mydb.Updater { + return upd.frame(func(uq *updaterQuery) error { + return uq.and(upd.SQL(), terms...) + }) +} + +func (upd *updater) Prepare() (*sql.Stmt, error) { + return upd.PrepareContext(upd.SQL().sess.Context()) +} + +func (upd *updater) PrepareContext(ctx context.Context) (*sql.Stmt, error) { + uq, err := upd.build() + if err != nil { + return nil, err + } + return upd.SQL().sess.StatementPrepare(ctx, uq.statement()) +} + +func (upd *updater) Exec() (sql.Result, error) { + return upd.ExecContext(upd.SQL().sess.Context()) +} + +func (upd *updater) ExecContext(ctx context.Context) (sql.Result, error) { + uq, err := upd.build() + if err != nil { + return nil, err + } + return upd.SQL().sess.StatementExec(ctx, uq.statement(), uq.arguments()...) 
+}
+
+func (upd *updater) Limit(limit int) mydb.Updater {
+	return upd.frame(func(uq *updaterQuery) error {
+		uq.limit = limit
+		return nil
+	})
+}
+
+func (upd *updater) statement() (*exql.Statement, error) {
+	iq, err := upd.build()
+	if err != nil {
+		return nil, err
+	}
+	return iq.statement(), nil
+}
+
+func (upd *updater) build() (*updaterQuery, error) {
+	uq, err := immutable.FastForward(upd)
+	if err != nil {
+		return nil, err
+	}
+	return uq.(*updaterQuery), nil
+}
+
+func (upd *updater) Compile() (string, error) {
+	s, err := upd.statement()
+	if err != nil {
+		return "", err
+	}
+	return s.Compile(upd.template())
+}
+
+func (upd *updater) Prev() immutable.Immutable {
+	if upd == nil {
+		return nil
+	}
+	return upd.prev
+}
+
+func (upd *updater) Fn(in interface{}) error {
+	if upd.fn == nil {
+		return nil
+	}
+	return upd.fn(in.(*updaterQuery))
+}
+
+func (upd *updater) Base() interface{} {
+	return &updaterQuery{}
+}
diff --git a/internal/sqlbuilder/wrapper.go b/internal/sqlbuilder/wrapper.go
new file mode 100644
index 0000000..6806f1e
--- /dev/null
+++ b/internal/sqlbuilder/wrapper.go
@@ -0,0 +1,64 @@
+package sqlbuilder
+
+import (
+	"database/sql"
+
+	"git.hexq.cn/tiglog/mydb"
+)
+
+// Tx represents a transaction on a SQL database. A transaction is like a
+// regular Session except it has two extra methods: Commit and Rollback.
+//
+// A transaction needs to be committed (with Commit) to make its changes
+// permanent; changes can be discarded before committing by rolling back
+// (with Rollback). After either committing or rolling back, the transaction
+// can no longer be used and is automatically closed.
+type Tx interface {
+	// All mydb.Session methods are available on transaction sessions. They will
+	// run on the same transaction.
+	mydb.Session
+
+	Commit() error
+
+	Rollback() error
+}
+
+// Adapter represents a SQL adapter.
+type Adapter interface {
+	// New wraps an active *sql.DB session and returns a SQLBuilder database. The
+	// adapter needs to be imported into the blank namespace in order for it to
+	// be used here.
+	//
+	// This method is internally used by upper-db to create a builder backed by
+	// the given database. You may want to use your adapter's New function
+	// instead of this one.
+	New(*sql.DB) (mydb.Session, error)
+
+	// NewTx wraps an active *sql.Tx transaction and returns a SQLBuilder
+	// transaction. The adapter needs to be imported into the blank namespace in
+	// order for it to be used.
+	//
+	// This method is internally used by upper-db to create a builder backed by
+	// the given transaction. You may want to use your adapter's NewTx function
+	// instead of this one.
+	NewTx(*sql.Tx) (Tx, error)
+
+	// OpenDSN opens a SQL database using the given connection settings.
+ OpenDSN(mydb.ConnectionURL) (mydb.Session, error) +} + +type dbAdapter struct { + Adapter +} + +func (d *dbAdapter) Open(conn mydb.ConnectionURL) (mydb.Session, error) { + sess, err := d.Adapter.OpenDSN(conn) + if err != nil { + return nil, err + } + return sess.(mydb.Session), nil +} + +func NewCompatAdapter(adapter Adapter) mydb.Adapter { + return &dbAdapter{adapter} +} diff --git a/internal/testsuite/generic_suite.go b/internal/testsuite/generic_suite.go new file mode 100644 index 0000000..c6139f9 --- /dev/null +++ b/internal/testsuite/generic_suite.go @@ -0,0 +1,889 @@ +package testsuite + +import ( + "database/sql/driver" + "time" + + "git.hexq.cn/tiglog/mydb" + "github.com/stretchr/testify/suite" + "gopkg.in/mgo.v2/bson" +) + +type birthday struct { + Name string `db:"name"` + Born time.Time `db:"born"` + BornUT *unixTimestamp `db:"born_ut,omitempty"` + OmitMe bool `json:"omit_me" db:"-" bson:"-"` +} + +type fibonacci struct { + Input uint64 `db:"input"` + Output uint64 `db:"output"` + // Test for BSON option. + OmitMe bool `json:"omit_me" db:"omit_me,bson,omitempty" bson:"omit_me,omitempty"` +} + +type oddEven struct { + // Test for JSON option. + Input int `json:"input" db:"input"` + // Test for JSON option. + // The "bson" tag is required by mgo. + IsEven bool `json:"is_even" db:"is_even,json" bson:"is_even"` + OmitMe bool `json:"omit_me" db:"-" bson:"-"` +} + +// Struct that relies on explicit mapping. +type mapE struct { + ID uint `db:"id,omitempty" bson:"-"` + MongoID bson.ObjectId `db:"-" bson:"_id,omitempty"` + CaseTest string `db:"case_test" bson:"case_test"` +} + +// Struct that will fallback to default mapping. +type mapN struct { + ID uint `db:"id,omitempty"` + MongoID bson.ObjectId `db:"-" bson:"_id,omitempty"` + Case_TEST string `db:"case_test"` +} + +// Struct for testing marshalling. +type unixTimestamp struct { + // Time is handled internally as time.Time but saved as an (integer) unix + // timestamp. 
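+	// (Editorial note: Value() below emits UTC seconds for the driver and
+	// Scan() reverses the conversion, so a value such as
+	// time.Date(2020, 1, 1, 0, 0, 0, 0, time.UTC) survives a round trip
+	// through an integer column unchanged, at one-second precision.)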
+ value time.Time +} + +func (u unixTimestamp) Value() (driver.Value, error) { + return u.value.UTC().Unix(), nil +} + +func (u *unixTimestamp) Scan(v interface{}) error { + var unixTime int64 + + switch t := v.(type) { + case int64: + unixTime = t + case nil: + return nil + default: + return mydb.ErrUnsupportedValue + } + + t := time.Unix(unixTime, 0).In(time.UTC) + *u = unixTimestamp{t} + + return nil +} + +func newUnixTimestamp(t time.Time) *unixTimestamp { + return &unixTimestamp{t.UTC()} +} + +func even(i int) bool { + return i%2 == 0 +} + +func fib(i uint64) uint64 { + if i == 0 { + return 0 + } else if i == 1 { + return 1 + } + return fib(i-1) + fib(i-2) +} + +type GenericTestSuite struct { + suite.Suite + + Helper +} + +func (s *GenericTestSuite) AfterTest(suiteName, testName string) { + err := s.TearDown() + s.NoError(err) +} + +func (s *GenericTestSuite) BeforeTest(suiteName, testName string) { + err := s.TearUp() + s.NoError(err) +} + +func (s *GenericTestSuite) TestDatesAndUnicode() { + sess := s.Session() + + testTimeZone := time.Local + switch s.Adapter() { + case "mysql", "cockroachdb", "postgresql": + testTimeZone = defaultTimeLocation + case "sqlite", "ql", "mssql": + testTimeZone = time.UTC + } + + born := time.Date(1941, time.January, 5, 0, 0, 0, 0, testTimeZone) + + controlItem := birthday{ + Name: "Hayao Miyazaki", + Born: born, + BornUT: newUnixTimestamp(born), + } + + col := sess.Collection(`birthdays`) + + record, err := col.Insert(controlItem) + s.NoError(err) + s.NotZero(record.ID()) + + var res mydb.Result + switch s.Adapter() { + case "mongo": + res = col.Find(mydb.Cond{"_id": record.ID().(bson.ObjectId)}) + case "ql": + res = col.Find(mydb.Cond{"id()": record.ID()}) + default: + res = col.Find(mydb.Cond{"id": record.ID()}) + } + + var total uint64 + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(1), total) + + switch s.Adapter() { + case "mongo": + s.T().Skip() + } + + var testItem birthday + err = res.One(&testItem) + s.NoError(err) + + switch s.Adapter() { + case "sqlite", "ql", "mssql": + testItem.Born = testItem.Born.In(time.UTC) + } + s.Equal(controlItem.Born, testItem.Born) + + s.Equal(controlItem.BornUT, testItem.BornUT) + s.Equal(controlItem, testItem) + + var testItems []birthday + err = res.All(&testItems) + s.NoError(err) + s.NotZero(len(testItems)) + + for _, testItem = range testItems { + switch s.Adapter() { + case "sqlite", "ql", "mssql": + testItem.Born = testItem.Born.In(time.UTC) + } + s.Equal(controlItem, testItem) + } + + controlItem.Name = `宮崎駿` + err = res.Update(controlItem) + s.NoError(err) + + err = res.One(&testItem) + s.NoError(err) + + switch s.Adapter() { + case "sqlite", "ql", "mssql": + testItem.Born = testItem.Born.In(time.UTC) + } + + s.Equal(controlItem, testItem) + + err = res.Delete() + s.NoError(err) + + total, err = res.Count() + s.NoError(err) + s.Zero(total) + + err = res.Close() + s.NoError(err) +} + +func (s *GenericTestSuite) TestFibonacci() { + var err error + var res mydb.Result + var total uint64 + + sess := s.Session() + + col := sess.Collection("fibonacci") + + // Adding some items. + var i uint64 + for i = 0; i < 10; i++ { + item := fibonacci{Input: i, Output: fib(i)} + _, err = col.Insert(item) + s.NoError(err) + } + + // Testing sorting by function. + res = col.Find( + // 5, 6, 7, 3 + mydb.Or( + mydb.And( + mydb.Cond{"input": mydb.Gte(5)}, + mydb.Cond{"input": mydb.Lte(7)}, + ), + mydb.Cond{"input": mydb.Eq(3)}, + ), + ) + + // Testing sort by function. 
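+	// (Editorial note: there is no portable random-order function in SQL,
+	// hence the per-adapter switch below; on PostgreSQL, for instance, the
+	// ordering compiles to ORDER BY RANDOM().)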
+ switch s.Adapter() { + case "postgresql": + res = res.OrderBy(mydb.Raw("RANDOM()")) + case "sqlite": + res = res.OrderBy(mydb.Raw("RANDOM()")) + case "mysql": + res = res.OrderBy(mydb.Raw("RAND()")) + case "mssql": + res = res.OrderBy(mydb.Raw("NEWID()")) + } + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + // Find() with IN/$in + res = col.Find(mydb.Cond{"input IN": []int{3, 5, 6, 7}}).OrderBy("input") + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + res = res.Offset(1).Limit(2) + + var item fibonacci + for res.Next(&item) { + switch item.Input { + case 5: + case 6: + s.Equal(fib(item.Input), item.Output) + default: + s.T().Errorf(`Unexpected item: %v.`, item) + } + } + s.NoError(res.Err()) + + // Find() with range + res = col.Find( + // 5, 6, 7, 3 + mydb.Or( + mydb.And( + mydb.Cond{"input >=": 5}, + mydb.Cond{"input <=": 7}, + ), + mydb.Cond{"input": 3}, + ), + ).OrderBy("-input") + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + // Skipping. + res = res.Offset(1).Limit(2) + + var item2 fibonacci + for res.Next(&item2) { + switch item2.Input { + case 5: + case 6: + s.Equal(fib(item2.Input), item2.Output) + default: + s.T().Errorf(`Unexpected item: %v.`, item2) + } + } + err = res.Err() + s.NoError(err) + + err = res.Delete() + s.NoError(err) + + { + total, err := res.Count() + s.NoError(err) + s.Zero(total) + } + + // Find() with no arguments. + res = col.Find() + { + total, err := res.Count() + s.NoError(err) + s.Equal(uint64(6), total) + } + + // Skipping mongodb as the results of this are not defined there. + if s.Adapter() != `mongo` { + // Find() with empty mydb.Cond. + { + total, err := col.Find(mydb.Cond{}).Count() + s.NoError(err) + s.Equal(uint64(6), total) + } + + // Find() with empty expression + { + total, err := col.Find(mydb.Or(mydb.And(mydb.Cond{}, mydb.Cond{}), mydb.Or(mydb.Cond{}))).Count() + s.NoError(err) + s.Equal(uint64(6), total) + } + + // Find() with explicit IS NULL + { + total, err := col.Find(mydb.Cond{"input IS": nil}).Count() + s.NoError(err) + s.Equal(uint64(0), total) + } + + // Find() with implicit IS NULL + { + total, err := col.Find(mydb.Cond{"input": nil}).Count() + s.NoError(err) + s.Equal(uint64(0), total) + } + + // Find() with explicit = NULL + { + total, err := col.Find(mydb.Cond{"input =": nil}).Count() + s.NoError(err) + s.Equal(uint64(0), total) + } + + // Find() with implicit IN + { + total, err := col.Find(mydb.Cond{"input": []int{1, 2, 3, 4}}).Count() + s.NoError(err) + s.Equal(uint64(3), total) + } + + // Find() with implicit NOT IN + { + total, err := col.Find(mydb.Cond{"input NOT IN": []int{1, 2, 3, 4}}).Count() + s.NoError(err) + s.Equal(uint64(3), total) + } + } + + var items []fibonacci + err = res.All(&items) + s.NoError(err) + + for _, item := range items { + switch item.Input { + case 0: + case 1: + case 2: + case 4: + case 8: + case 9: + s.Equal(fib(item.Input), item.Output) + default: + s.T().Errorf(`Unexpected item: %v`, item) + } + } + + err = res.Close() + s.NoError(err) +} + +func (s *GenericTestSuite) TestOddEven() { + sess := s.Session() + + col := sess.Collection("is_even") + + // Adding some items. 
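+	// (Editorial note: rows 1..99 are inserted with is_even precomputed by
+	// even(), so the queries below can assert parity from input alone.)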
+	var i int
+	for i = 1; i < 100; i++ {
+		item := oddEven{Input: i, IsEven: even(i)}
+		_, err := col.Insert(item)
+		s.NoError(err)
+	}
+
+	// Retrieving items
+	res := col.Find(mydb.Cond{"is_even": true})
+
+	var item oddEven
+	for res.Next(&item) {
+		s.Zero(item.Input % 2)
+	}
+
+	err := res.Err()
+	s.NoError(err)
+
+	err = res.Delete()
+	s.NoError(err)
+
+	// Testing named inputs (using tags).
+	res = col.Find()
+
+	var item2 struct {
+		Value uint `db:"input" bson:"input"` // The "bson" tag is required by mgo.
+	}
+	for res.Next(&item2) {
+		s.NotZero(item2.Value % 2)
+	}
+	err = res.Err()
+	s.NoError(err)
+
+	// Testing inline tag.
+	res = col.Find()
+
+	var item3 struct {
+		OddEven oddEven `db:",inline" bson:",inline"`
+	}
+	for res.Next(&item3) {
+		s.NotZero(item3.OddEven.Input % 2)
+		s.NotZero(item3.OddEven.Input)
+	}
+	err = res.Err()
+	s.NoError(err)
+
+	// Testing inline tag on an embedded type.
+	type OddEven oddEven
+	res = col.Find()
+
+	var item31 struct {
+		OddEven `db:",inline" bson:",inline"`
+	}
+	for res.Next(&item31) {
+		s.NotZero(item31.Input % 2)
+		s.NotZero(item31.Input)
+	}
+	s.NoError(res.Err())
+
+	// Testing omission tag.
+	res = col.Find()
+
+	var item4 struct {
+		Value uint `db:"-"`
+	}
+	for res.Next(&item4) {
+		s.Zero(item4.Value)
+	}
+	s.NoError(res.Err())
+}
+
+func (s *GenericTestSuite) TestExplicitAndDefaultMapping() {
+	var err error
+	var res mydb.Result
+
+	var testE mapE
+	var testN mapN
+
+	sess := s.Session()
+
+	col := sess.Collection("CaSe_TesT")
+
+	if err = col.Truncate(); err != nil {
+		if s.Adapter() != "mongo" {
+			s.NoError(err)
+		}
+	}
+
+	// Testing explicit mapping.
+	testE = mapE{
+		CaseTest: "Hello!",
+	}
+
+	_, err = col.Insert(testE)
+	s.NoError(err)
+
+	res = col.Find(mydb.Cond{"case_test": "Hello!"})
+	if s.Adapter() == "ql" {
+		res = res.Select("id() as id", "case_test")
+	}
+
+	err = res.One(&testE)
+	s.NoError(err)
+
+	if s.Adapter() == "mongo" {
+		s.True(testE.MongoID.Valid())
+	} else {
+		s.NotZero(testE.ID)
+	}
+
+	// Testing default mapping.
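+	// (Editorial note: mapN maps case_test explicitly for SQL via its db tag
+	// but, unlike mapE, carries no bson tag for Case_TEST, so the mongo
+	// adapter falls back to the driver's default field-name mapping.)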
+ testN = mapN{ + Case_TEST: "World!", + } + + _, err = col.Insert(testN) + s.NoError(err) + + if s.Adapter() == `mongo` { + res = col.Find(mydb.Cond{"case_test": "World!"}) + } else { + res = col.Find(mydb.Cond{"case_test": "World!"}) + } + + if s.Adapter() == `ql` { + res = res.Select(`id() as id`, `case_test`) + } + + err = res.One(&testN) + s.NoError(err) + + if s.Adapter() == `mongo` { + s.True(testN.MongoID.Valid()) + } else { + s.NotZero(testN.ID) + } +} + +func (s *GenericTestSuite) TestComparisonOperators() { + sess := s.Session() + + birthdays := sess.Collection("birthdays") + err := birthdays.Truncate() + if err != nil { + if s.Adapter() != "mongo" { + s.NoError(err) + } + } + + // Insert data for testing + birthdaysDataset := []birthday{ + { + Name: "Marie Smith", + Born: time.Date(1956, time.August, 5, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Peter", + Born: time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Eve Smith", + Born: time.Date(1911, time.February, 8, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Alex López", + Born: time.Date(2001, time.May, 5, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Rose Smith", + Born: time.Date(1944, time.December, 9, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Daria López", + Born: time.Date(1923, time.March, 23, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "", + Born: time.Date(1945, time.December, 1, 0, 0, 0, 0, defaultTimeLocation), + }, + { + Name: "Colin", + Born: time.Date(2010, time.May, 6, 0, 0, 0, 0, defaultTimeLocation), + }, + } + for _, birthday := range birthdaysDataset { + _, err := birthdays.Insert(birthday) + s.NoError(err) + } + + // Test: equal + { + var item birthday + err := birthdays.Find(mydb.Cond{ + "name": mydb.Eq("Colin"), + }).One(&item) + s.NoError(err) + s.NotNil(item) + + s.Equal("Colin", item.Name) + } + + // Test: not equal + { + var item birthday + err := birthdays.Find(mydb.Cond{ + "name": mydb.NotEq("Colin"), + }).One(&item) + s.NoError(err) + s.NotNil(item) + + s.NotEqual("Colin", item.Name) + } + + // Test: greater than + { + var items []birthday + ref := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation) + err := birthdays.Find(mydb.Cond{ + "born": mydb.Gt(ref), + }).All(&items) + s.NoError(err) + s.NotZero(len(items)) + s.Equal(2, len(items)) + for _, item := range items { + s.True(item.Born.After(ref)) + } + } + + // Test: less than + { + var items []birthday + ref := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation) + err := birthdays.Find(mydb.Cond{ + "born": mydb.Lt(ref), + }).All(&items) + s.NoError(err) + s.NotZero(len(items)) + s.Equal(5, len(items)) + for _, item := range items { + s.True(item.Born.Before(ref)) + } + } + + // Test: greater than or equal to + { + var items []birthday + ref := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation) + err := birthdays.Find(mydb.Cond{ + "born": mydb.Gte(ref), + }).All(&items) + s.NoError(err) + s.NotZero(len(items)) + s.Equal(3, len(items)) + for _, item := range items { + s.True(item.Born.After(ref) || item.Born.Equal(ref)) + } + } + + // Test: less than or equal to + { + var items []birthday + ref := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation) + err := birthdays.Find(mydb.Cond{ + "born": mydb.Lte(ref), + }).All(&items) + s.NoError(err) + s.NotZero(len(items)) + s.Equal(6, len(items)) + for _, item := range items { + s.True(item.Born.Before(ref) || item.Born.Equal(ref)) + } + } + + // Test: between + { + var items []birthday 
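+		// (Editorial note: mydb.Between is inclusive of both endpoints, which
+		// is why the assertions below accept Born equal to either bound.)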
+		dateA := time.Date(1911, time.February, 8, 0, 0, 0, 0, defaultTimeLocation)
+		dateB := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation)
+		err := birthdays.Find(mydb.Cond{
+			"born": mydb.Between(dateA, dateB),
+		}).All(&items)
+		s.NoError(err)
+		s.Equal(6, len(items))
+		for _, item := range items {
+			s.True(item.Born.After(dateA) || item.Born.Equal(dateA))
+			s.True(item.Born.Before(dateB) || item.Born.Equal(dateB))
+		}
+	}
+
+	// Test: not between
+	{
+		var items []birthday
+		dateA := time.Date(1911, time.February, 8, 0, 0, 0, 0, defaultTimeLocation)
+		dateB := time.Date(1967, time.July, 23, 0, 0, 0, 0, defaultTimeLocation)
+		err := birthdays.Find(mydb.Cond{
+			"born": mydb.NotBetween(dateA, dateB),
+		}).All(&items)
+		s.NoError(err)
+		s.Equal(2, len(items))
+		for _, item := range items {
+			s.False(item.Born.Before(dateA) || item.Born.Equal(dateA))
+			s.False(item.Born.Before(dateB) || item.Born.Equal(dateB))
+		}
+	}
+
+	// Test: in
+	{
+		var items []birthday
+		names := []interface{}{"Peter", "Eve Smith", "Daria López", "Alex López"}
+		err := birthdays.Find(mydb.Cond{
+			"name": mydb.In(names...),
+		}).All(&items)
+		s.NoError(err)
+		s.Equal(4, len(items))
+		for _, item := range items {
+			inArray := false
+			for _, name := range names {
+				if name == item.Name {
+					inArray = true
+				}
+			}
+			s.True(inArray)
+		}
+	}
+
+	// Test: not in
+	{
+		var items []birthday
+		names := []interface{}{"Peter", "Eve Smith", "Daria López", "Alex López"}
+		err := birthdays.Find(mydb.Cond{
+			"name": mydb.NotIn(names...),
+		}).All(&items)
+		s.NoError(err)
+		s.Equal(4, len(items))
+		for _, item := range items {
+			inArray := false
+			for _, name := range names {
+				if name == item.Name {
+					inArray = true
+				}
+			}
+			s.False(inArray)
+		}
+	}
+
+	// Test: is and is not
+	{
+		var items []birthday
+		err := birthdays.Find(mydb.And(
+			mydb.Cond{"name": mydb.Is(nil)},
+			mydb.Cond{"name": mydb.IsNot(nil)},
+		)).All(&items)
+		s.NoError(err)
+		s.Equal(0, len(items))
+	}
+
+	// Test: is nil
+	{
+		var items []birthday
+		err := birthdays.Find(mydb.And(
+			mydb.Cond{"born_ut": mydb.IsNull()},
+		)).All(&items)
+		s.NoError(err)
+		s.Equal(8, len(items))
+	}
+
+	// Test: like and not like
+	{
+		var items []birthday
+		var q mydb.Result
+
+		switch s.Adapter() {
+		case "ql", "mongo":
+			q = birthdays.Find(mydb.And(
+				mydb.Cond{"name": mydb.Like(".*ari.*")},
+				mydb.Cond{"name": mydb.NotLike(".*Smith")},
+			))
+		default:
+			q = birthdays.Find(mydb.And(
+				mydb.Cond{"name": mydb.Like("%ari%")},
+				mydb.Cond{"name": mydb.NotLike("%Smith")},
+			))
+		}
+
+		err := q.All(&items)
+		s.NoError(err)
+		s.Equal(1, len(items))
+
+		s.Equal("Daria López", items[0].Name)
+	}
+
+	if s.Adapter() != "sqlite" && s.Adapter() != "mssql" {
+		// Test: regexp
+		{
+			var items []birthday
+			err := birthdays.Find(mydb.And(
+				mydb.Cond{"name": mydb.RegExp("^[D|C|M]")},
+			)).OrderBy("name").All(&items)
+			s.NoError(err)
+			s.Equal(3, len(items))
+
+			s.Equal("Colin", items[0].Name)
+			s.Equal("Daria López", items[1].Name)
+			s.Equal("Marie Smith", items[2].Name)
+		}
+
+		// Test: not regexp
+		{
+			var items []birthday
+			names := []string{"Daria López", "Colin", "Marie Smith"}
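+
+			// NotRegExp keeps only the rows whose name does NOT match the
+			// pattern, so none of the three names above should appear below.
+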
err := birthdays.Find(mydb.And( + mydb.Cond{"name": mydb.NotRegExp("^[D|C|M]")}, + )).OrderBy("name").All(&items) + s.NoError(err) + s.Equal(5, len(items)) + + for _, item := range items { + for _, name := range names { + s.NotEqual(item.Name, name) + } + } + } + } + + // Test: after + { + ref := time.Date(1944, time.December, 9, 0, 0, 0, 0, defaultTimeLocation) + var items []birthday + err := birthdays.Find(mydb.Cond{ + "born": mydb.After(ref), + }).All(&items) + s.NoError(err) + s.Equal(5, len(items)) + } + + // Test: on or after + { + ref := time.Date(1944, time.December, 9, 0, 0, 0, 0, defaultTimeLocation) + var items []birthday + err := birthdays.Find(mydb.Cond{ + "born": mydb.OnOrAfter(ref), + }).All(&items) + s.NoError(err) + s.Equal(6, len(items)) + } + + // Test: before + { + ref := time.Date(1944, time.December, 9, 0, 0, 0, 0, defaultTimeLocation) + var items []birthday + err := birthdays.Find(mydb.Cond{ + "born": mydb.Before(ref), + }).All(&items) + s.NoError(err) + s.Equal(2, len(items)) + } + + // Test: on or before + { + ref := time.Date(1944, time.December, 9, 0, 0, 0, 0, defaultTimeLocation) + var items []birthday + err := birthdays.Find(mydb.Cond{ + "born": mydb.OnOrBefore(ref), + }).All(&items) + s.NoError(err) + s.Equal(3, len(items)) + } +} diff --git a/internal/testsuite/record_suite.go b/internal/testsuite/record_suite.go new file mode 100644 index 0000000..89cddc4 --- /dev/null +++ b/internal/testsuite/record_suite.go @@ -0,0 +1,428 @@ +package testsuite + +import ( + "context" + "database/sql" + "fmt" + "time" + + "git.hexq.cn/tiglog/mydb" + "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" + "github.com/stretchr/testify/suite" +) + +type AccountsStore struct { + mydb.Collection +} + +type UsersStore struct { + mydb.Collection +} + +type LogsStore struct { + mydb.Collection +} + +func Accounts(sess mydb.Session) mydb.Store { + return &AccountsStore{sess.Collection("accounts")} +} + +func Users(sess mydb.Session) *UsersStore { + return &UsersStore{sess.Collection("users")} +} + +func Logs(sess mydb.Session) *LogsStore { + return &LogsStore{sess.Collection("logs")} +} + +type Log struct { + ID uint64 `db:"id,omitempty"` + Message string `db:"message"` +} + +func (*Log) Store(sess mydb.Session) mydb.Store { + return Logs(sess) +} + +var _ = mydb.Store(&LogsStore{}) + +type Account struct { + ID uint64 `db:"id,omitempty"` + Name string `db:"name"` + Disabled bool `db:"disabled"` + CreatedAt *time.Time `db:"created_at,omitempty"` +} + +func (*Account) Store(sess mydb.Session) mydb.Store { + return Accounts(sess) +} + +func (account *Account) AfterCreate(sess mydb.Session) error { + message := fmt.Sprintf("Account %q was created.", account.Name) + return sess.Save(&Log{Message: message}) +} + +type User struct { + ID uint64 `db:"id,omitempty"` + AccountID uint64 `db:"account_id"` + Username string `db:"username"` +} + +func (user *User) AfterCreate(sess mydb.Session) error { + message := fmt.Sprintf("User %q was created.", user.Username) + return sess.Save(&Log{Message: message}) +} + +func (*User) Store(sess mydb.Session) mydb.Store { + return Users(sess) +} + +type RecordTestSuite struct { + suite.Suite + Helper +} + +func (s *RecordTestSuite) AfterTest(suiteName, testName string) { + err := s.TearDown() + s.NoError(err) +} + +func (s *RecordTestSuite) BeforeTest(suiteName, testName string) { + err := s.TearUp() + s.NoError(err) + + sess := s.Helper.Session() + + cols, err := sess.Collections() + s.NoError(err) + + for i := range cols { + err = cols[i].Truncate() + 
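		// Collections are emptied here so every record test starts from a
+		// clean slate before it runs.
+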
s.NoError(err) + } +} + +func (s *RecordTestSuite) TestFindOne() { + var err error + sess := s.Session() + + user := User{Username: "jose"} + err = sess.Save(&user) + s.NoError(err) + + s.NotZero(user.ID) + userID := user.ID + + user = User{} + err = Users(sess).Find(userID).One(&user) + s.NoError(err) + s.Equal("jose", user.Username) + + user = User{} + err = sess.Get(&user, mydb.Cond{"username": "jose"}) + s.NoError(err) + s.Equal("jose", user.Username) + + user.Username = "Catalina" + err = sess.Save(&user) + s.NoError(err) + + user = User{} + err = sess.Get(&user, userID) + s.NoError(err) + s.Equal("Catalina", user.Username) + + err = sess.Delete(&user) + s.NoError(err) + + err = sess.Get(&user, userID) + s.Error(err) + + err = sess.Collection("users"). + Find(userID). + One(&user) + s.Error(err) +} + +func (s *RecordTestSuite) TestAccounts() { + sess := s.Session() + + user := User{Username: "peter"} + + err := sess.Save(&user) + s.NoError(err) + + user = User{Username: "peter"} + err = sess.Save(&user) + s.Error(err, "username should be unique") + + account1 := Account{Name: "skywalker"} + err = sess.Save(&account1) + s.NoError(err) + + account2 := Account{} + err = sess.Get(&account2, account1.ID) + + s.NoError(err) + s.Equal(account1.Name, account2.Name) + + var account3 Account + err = sess.Get(&account3, account1.ID) + + s.NoError(err) + s.Equal(account1.Name, account3.Name) + + var a Account + err = sess.Get(&a, account1.ID) + s.NoError(err) + s.NotNil(a) + + account1.Disabled = true + err = sess.Save(&account1) + s.NoError(err) + + count, err := Accounts(sess).Count() + s.NoError(err) + s.Equal(uint64(1), count) + + err = sess.Delete(&account1) + s.NoError(err) + + count, err = Accounts(sess).Find().Count() + s.NoError(err) + s.Zero(count) +} + +func (s *RecordTestSuite) TestDelete() { + sess := s.Session() + + account := Account{Name: "Pressly"} + err := sess.Save(&account) + s.NoError(err) + s.NotZero(account.ID) + + // Delete by query -- without callbacks + err = Accounts(sess). + Find(account.ID). + Delete() + s.NoError(err) + + count, err := Accounts(sess).Find(account.ID).Count() + s.Zero(count) + s.NoError(err) +} + +func (s *RecordTestSuite) TestSlices() { + sess := s.Session() + + err := sess.Save(&Account{Name: "Apple"}) + s.NoError(err) + + err = sess.Save(&Account{Name: "Google"}) + s.NoError(err) + + var accounts []*Account + err = Accounts(sess). + Find(mydb.Cond{}). + All(&accounts) + s.NoError(err) + s.Len(accounts, 2) +} + +func (s *RecordTestSuite) TestSelectOnlyIDs() { + sess := s.Session() + + err := sess.Save(&Account{Name: "Apple"}) + s.NoError(err) + + err = sess.Save(&Account{Name: "Google"}) + s.NoError(err) + + var ids []struct { + Id int64 `db:"id"` + } + + err = Accounts(sess). + Find(). + Select("id").All(&ids) + s.NoError(err) + s.Len(ids, 2) + s.NotEmpty(ids[0]) +} + +func (s *RecordTestSuite) TestTx() { + sess := s.Session() + + user := User{Username: "peter"} + err := sess.Save(&user) + s.NoError(err) + + // This transaction should fail because user is a UNIQUE value and we already + // have a "peter". + err = sess.Tx(func(tx mydb.Session) error { + return tx.Save(&User{Username: "peter"}) + }) + s.Error(err) + + // This transaction should fail because user is a UNIQUE value and we already + // have a "peter". + err = sess.Tx(func(tx mydb.Session) error { + return tx.Save(&User{Username: "peter"}) + }) + s.Error(err) + + // This transaction will have no errors, but we'll produce one in order for + // it to rollback at the last moment. 
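+	// (Any non-nil error returned from the Tx callback makes the whole
+	// transaction roll back; returning nil commits it.)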
+	err = sess.Tx(func(tx mydb.Session) error {
+		if err := tx.Save(&User{Username: "Joe"}); err != nil {
+			return err
+		}
+
+		if err := tx.Save(&User{Username: "Cool"}); err != nil {
+			return err
+		}
+
+		return fmt.Errorf("rolling back for no reason")
+	})
+	s.Error(err)
+
+	// Attempt to add two new unique values; if the transaction above had not
+	// been rolled back, this transaction would fail.
+	err = sess.Tx(func(tx mydb.Session) error {
+		if err := tx.Save(&User{Username: "Joe"}); err != nil {
+			return err
+		}
+
+		if err := tx.Save(&User{Username: "Cool"}); err != nil {
+			return err
+		}
+
+		return nil
+	})
+	s.NoError(err)
+
+	// If the transaction above was successful, this one will fail.
+	err = sess.Tx(func(tx mydb.Session) error {
+		if err := tx.Save(&User{Username: "Joe"}); err != nil {
+			return err
+		}
+
+		if err := tx.Save(&User{Username: "Cool"}); err != nil {
+			return err
+		}
+
+		return nil
+	})
+	s.Error(err)
+}
+
+func (s *RecordTestSuite) TestInheritedTx() {
+	sess := s.Session()
+
+	sqlDB := sess.Driver().(*sql.DB)
+
+	user := User{Username: "peter"}
+	err := sess.Save(&user)
+	s.NoError(err)
+
+	// Create a transaction.
+	sqlTx, err := sqlDB.Begin()
+	s.NoError(err)
+
+	// And pass that transaction to mydb; the whole session becomes a
+	// transaction.
+	upperTx, err := sqlbuilder.BindTx(s.Adapter(), sqlTx)
+	s.NoError(err)
+
+	// Should fail because user is a UNIQUE value and we already have a "peter".
+	err = upperTx.Save(&User{Username: "peter"})
+	s.Error(err)
+
+	// The transaction is controlled outside mydb.
+	err = sqlTx.Rollback()
+	s.NoError(err)
+
+	// The sqlTx is worthless now.
+	err = upperTx.Save(&User{Username: "peter-2"})
+	s.Error(err)
+
+	// But we can create a new one.
+	sqlTx, err = sqlDB.Begin()
+	s.NoError(err)
+	s.NotNil(sqlTx)
+
+	// And create another session.
+	upperTx, err = sqlbuilder.BindTx(s.Adapter(), sqlTx)
+	s.NoError(err)
+
+	// Adding two new values.
+	err = upperTx.Save(&User{Username: "Joe-2"})
+	s.NoError(err)
+
+	err = upperTx.Save(&User{Username: "Cool-2"})
+	s.NoError(err)
+
+	// And a value that is going to be rolled back.
+	err = upperTx.Save(&Account{Name: "Rolled back"})
+	s.NoError(err)
+
+	// This session happens to be a transaction; let's roll back everything.
+	err = sqlTx.Rollback()
+	s.NoError(err)
+
+	// Start again.
+	sqlTx, err = sqlDB.Begin()
+	s.NoError(err)
+
+	tx, err := sqlbuilder.BindTx(s.Adapter(), sqlTx)
+	s.NoError(err)
+
+	// Attempt to add two unique values.
+	err = tx.Save(&User{Username: "Joe-2"})
+	s.NoError(err)
+
+	err = tx.Save(&User{Username: "Cool-2"})
+	s.NoError(err)
+
+	// And a value that is going to be committed.
+	err = tx.Save(&Account{Name: "Committed!"})
+	s.NoError(err)
+
+	// Yes, commit them.
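+	// Commit is issued on the raw *sql.Tx: the transaction lifecycle is
+	// managed outside of mydb here.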
+ err = sqlTx.Commit() + s.NoError(err) +} + +func (s *RecordTestSuite) TestUnknownCollection() { + var err error + sess := s.Session() + + err = sess.Save(nil) + s.Error(err) + + _, err = sess.Collection("users").Insert(&User{Username: "Foo"}) + s.NoError(err) +} + +func (s *RecordTestSuite) TestContextCanceled() { + var err error + + sess := s.Session() + + err = sess.Collection("users").Truncate() + s.NoError(err) + + { + ctx, cancelFn := context.WithTimeout(context.Background(), time.Minute) + canceledSess := sess.WithContext(ctx) + + cancelFn() + + user := User{Username: "foo"} + err = canceledSess.Save(&user) + s.Error(err) + + c, err := sess.Collection("users").Count() + s.NoError(err) + s.Equal(uint64(0), c) + } +} diff --git a/internal/testsuite/sql_suite.go b/internal/testsuite/sql_suite.go new file mode 100644 index 0000000..b19a1e6 --- /dev/null +++ b/internal/testsuite/sql_suite.go @@ -0,0 +1,1974 @@ +package testsuite + +import ( + "context" + "database/sql" + "errors" + "fmt" + "math/rand" + "strconv" + "strings" + "sync" + "time" + + "git.hexq.cn/tiglog/mydb" + detectrace "github.com/ipfs/go-detect-race" + "github.com/sirupsen/logrus" + "github.com/stretchr/testify/suite" +) + +type artistType struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` +} + +type itemWithCompoundKey struct { + Code string `db:"code"` + UserID string `db:"user_id"` + SomeVal string `db:"some_val"` +} + +type customType struct { + Val []byte +} + +type artistWithCustomType struct { + Custom customType `db:"name"` +} + +func (f customType) String() string { + return fmt.Sprintf("foo: %s", string(f.Val)) +} + +func (f customType) MarshalDB() (interface{}, error) { + return f.String(), nil +} + +func (f *customType) UnmarshalDB(in interface{}) error { + switch t := in.(type) { + case []byte: + f.Val = t + case string: + f.Val = []byte(t) + } + return nil +} + +var ( + _ = mydb.Marshaler(&customType{}) + _ = mydb.Unmarshaler(&customType{}) +) + +type SQLTestSuite struct { + suite.Suite + + Helper +} + +func (s *SQLTestSuite) AfterTest(suiteName, testName string) { + err := s.TearDown() + s.NoError(err) +} + +func (s *SQLTestSuite) BeforeTest(suiteName, testName string) { + err := s.TearUp() + s.NoError(err) + + sess := s.Session() + + // Creating test data + artist := sess.Collection("artist") + + artistNames := []string{"Ozzie", "Flea", "Slash", "Chrono"} + for _, artistName := range artistNames { + _, err := artist.Insert(map[string]string{ + "name": artistName, + }) + s.NoError(err) + } +} + +func (s *SQLTestSuite) TestPreparedStatementsCache() { + sess := s.Session() + + sess.SetPreparedStatementCache(true) + defer sess.SetPreparedStatementCache(false) + + var tMu sync.Mutex + tFatal := func(err error) { + tMu.Lock() + defer tMu.Unlock() + + s.T().Errorf("tmu: %v", err) + } + + // This limit was chosen because, by default, MySQL accepts 16k statements + // and dies. See https://github.com/upper/db/issues/287 + limit := 20000 + + if detectrace.WithRace() { + // When running this test under the Go race detector we quickly reach the limit + // of 8128 alive goroutines it can handle, so we set it to a safer number. + // + // Note that in order to fully stress this feature you'll have to run this + // test without the race detector. + limit = 100 + } + + var wg sync.WaitGroup + + for i := 0; i < limit; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + + // This query is different on each iteration and generates a new + // prepared statement everytime it's called. 
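+			// The statement cache must evict or reuse entries rather than let
+			// the number of server-side prepared statements grow past the 16k
+			// limit mentioned above.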
+ res := sess.Collection("artist").Find().Select(mydb.Raw(fmt.Sprintf("count(%d) AS c", i))) + + var count map[string]uint64 + err := res.One(&count) + if err != nil { + tFatal(err) + } + }(i) + } + wg.Wait() + + // Concurrent Insert can open many connections on MySQL / PostgreSQL, this + // sets a limit on them. + sess.SetMaxOpenConns(90) + + switch s.Adapter() { + case "ql": + limit = 1000 + } + + for i := 0; i < limit; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + // The same prepared query on every iteration. + _, err := sess.Collection("artist").Insert(artistType{ + Name: fmt.Sprintf("artist-%d", i), + }) + if err != nil { + tFatal(err) + } + }(i) + } + wg.Wait() + + // Insert returning creates a transaction. + for i := 0; i < limit; i++ { + wg.Add(1) + go func(i int) { + defer wg.Done() + // The same prepared query on every iteration. + artist := artistType{ + Name: fmt.Sprintf("artist-%d", i), + } + err := sess.Collection("artist").InsertReturning(&artist) + if err != nil { + tFatal(err) + } + }(i) + } + wg.Wait() + + // Removing the limit. + sess.SetMaxOpenConns(0) +} + +func (s *SQLTestSuite) TestTruncateAllCollections() { + sess := s.Session() + + collections, err := sess.Collections() + s.NoError(err) + s.True(len(collections) > 0) + + for _, col := range collections { + if ok, _ := col.Exists(); ok { + if err = col.Truncate(); err != nil { + s.NoError(err) + } + } + } +} + +func (s *SQLTestSuite) TestQueryLogger() { + logLevel := mydb.LC().Level() + + mydb.LC().SetLogger(logrus.New()) + mydb.LC().SetLevel(mydb.LogLevelDebug) + + defer func() { + mydb.LC().SetLogger(nil) + mydb.LC().SetLevel(logLevel) + }() + + sess := s.Session() + + _, err := sess.Collection("artist").Find().Count() + s.Equal(nil, err) + + _, err = sess.Collection("artist_x").Find().Count() + s.NotEqual(nil, err) +} + +func (s *SQLTestSuite) TestExpectCursorError() { + sess := s.Session() + + artist := sess.Collection("artist") + + res := artist.Find(-1) + c, err := res.Count() + s.Equal(uint64(0), c) + s.NoError(err) + + var item map[string]interface{} + err = res.One(&item) + s.Error(err) +} + +func (s *SQLTestSuite) TestInsertDefault() { + if s.Adapter() == "ql" { + s.T().Skip("Currently not supported.") + } + + sess := s.Session() + + artist := sess.Collection("artist") + + err := artist.Truncate() + s.NoError(err) + + id, err := artist.Insert(&artistType{}) + s.NoError(err) + s.NotNil(id) + + err = artist.Truncate() + s.NoError(err) + + id, err = artist.Insert(nil) + s.NoError(err) + s.NotNil(id) +} + +func (s *SQLTestSuite) TestInsertReturning() { + sess := s.Session() + + artist := sess.Collection("artist") + + err := artist.Truncate() + s.NoError(err) + + itemMap := map[string]string{ + "name": "Ozzie", + } + s.Zero(itemMap["id"], "Must be zero before inserting") + err = artist.InsertReturning(&itemMap) + s.NoError(err) + s.NotZero(itemMap["id"], "Must not be zero after inserting") + + itemStruct := struct { + ID int `db:"id,omitempty"` + Name string `db:"name"` + }{ + 0, + "Flea", + } + s.Zero(itemStruct.ID, "Must be zero before inserting") + err = artist.InsertReturning(&itemStruct) + s.NoError(err) + s.NotZero(itemStruct.ID, "Must not be zero after inserting") + + count, err := artist.Find().Count() + s.NoError(err) + s.Equal(uint64(2), count, "Expecting 2 elements") + + itemStruct2 := struct { + ID int `db:"id,omitempty"` + Name string `db:"name"` + }{ + 0, + "Slash", + } + s.Zero(itemStruct2.ID, "Must be zero before inserting") + err = artist.InsertReturning(itemStruct2) + s.Error(err, 
"Should not happen, using a pointer should be enforced") + s.Zero(itemStruct2.ID, "Must still be zero because there was no insertion") + + itemMap2 := map[string]string{ + "name": "Janus", + } + s.Zero(itemMap2["id"], "Must be zero before inserting") + err = artist.InsertReturning(itemMap2) + s.Error(err, "Should not happen, using a pointer should be enforced") + s.Zero(itemMap2["id"], "Must still be zero because there was no insertion") + + // Counting elements, must be exactly 2 elements. + count, err = artist.Find().Count() + s.NoError(err) + s.Equal(uint64(2), count, "Expecting 2 elements") +} + +func (s *SQLTestSuite) TestInsertReturningWithinTransaction() { + sess := s.Session() + + err := sess.Collection("artist").Truncate() + s.NoError(err) + + err = sess.Tx(func(tx mydb.Session) error { + artist := tx.Collection("artist") + + itemMap := map[string]string{ + "name": "Ozzie", + } + s.Zero(itemMap["id"], "Must be zero before inserting") + err = artist.InsertReturning(&itemMap) + s.NoError(err) + s.NotZero(itemMap["id"], "Must not be zero after inserting") + + itemStruct := struct { + ID int `db:"id,omitempty"` + Name string `db:"name"` + }{ + 0, + "Flea", + } + s.Zero(itemStruct.ID, "Must be zero before inserting") + err = artist.InsertReturning(&itemStruct) + s.NoError(err) + s.NotZero(itemStruct.ID, "Must not be zero after inserting") + + count, err := artist.Find().Count() + s.NoError(err) + s.Equal(uint64(2), count, "Expecting 2 elements") + + itemStruct2 := struct { + ID int `db:"id,omitempty"` + Name string `db:"name"` + }{ + 0, + "Slash", + } + s.Zero(itemStruct2.ID, "Must be zero before inserting") + err = artist.InsertReturning(itemStruct2) + s.Error(err, "Should not happen, using a pointer should be enforced") + s.Zero(itemStruct2.ID, "Must still be zero because there was no insertion") + + itemMap2 := map[string]string{ + "name": "Janus", + } + s.Zero(itemMap2["id"], "Must be zero before inserting") + err = artist.InsertReturning(itemMap2) + s.Error(err, "Should not happen, using a pointer should be enforced") + s.Zero(itemMap2["id"], "Must still be zero because there was no insertion") + + // Counting elements, must be exactly 2 elements. + count, err = artist.Find().Count() + s.NoError(err) + s.Equal(uint64(2), count, "Expecting 2 elements") + + return fmt.Errorf("rolling back for no reason") + }) + s.Error(err) + + // Expecting no elements. + count, err := sess.Collection("artist").Find().Count() + s.NoError(err) + s.Equal(uint64(0), count, "Expecting 0 elements, everything was rolled back!") +} + +func (s *SQLTestSuite) TestInsertIntoArtistsTable() { + sess := s.Session() + + artist := sess.Collection("artist") + + err := artist.Truncate() + s.NoError(err) + + itemMap := map[string]string{ + "name": "Ozzie", + } + + record, err := artist.Insert(itemMap) + s.NoError(err) + s.NotNil(record) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + // Attempt to append a struct. + itemStruct := struct { + Name string `db:"name"` + }{ + "Flea", + } + + record, err = artist.Insert(itemStruct) + s.NoError(err) + s.NotNil(record) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + // Attempt to append a tagged struct. 
+ itemStruct2 := struct { + ArtistName string `db:"name"` + }{ + "Slash", + } + + record, err = artist.Insert(&itemStruct2) + s.NoError(err) + s.NotNil(record) + + if pk, ok := record.ID().(int64); !ok || pk == 0 { + s.T().Errorf("Expecting an ID.") + } + + itemStruct3 := artistType{ + Name: "Janus", + } + record, err = artist.Insert(&itemStruct3) + s.NoError(err) + if s.Adapter() != "ql" { + s.NotZero(record) // QL always inserts an ID. + } + + // Counting elements, must be exactly 4 elements. + count, err := artist.Find().Count() + s.NoError(err) + s.Equal(uint64(4), count) + + count, err = artist.Find(mydb.Cond{"name": mydb.Eq("Ozzie")}).Count() + s.NoError(err) + s.Equal(uint64(1), count) + + count, err = artist.Find("name", "Ozzie").And("name", "Flea").Count() + s.NoError(err) + s.Equal(uint64(0), count) + + count, err = artist.Find(mydb.Or(mydb.Cond{"name": "Ozzie"}, mydb.Cond{"name": "Flea"})).Count() + s.NoError(err) + s.Equal(uint64(2), count) + + count, err = artist.Find(mydb.And(mydb.Cond{"name": "Ozzie"}, mydb.Cond{"name": "Flea"})).Count() + s.NoError(err) + s.Equal(uint64(0), count) + + count, err = artist.Find(mydb.Cond{"name": "Ozzie"}).And(mydb.Cond{"name": "Flea"}).Count() + s.NoError(err) + s.Equal(uint64(0), count) +} + +func (s *SQLTestSuite) TestQueryNonExistentCollection() { + sess := s.Session() + + count, err := sess.Collection("doesnotexist").Find().Count() + s.Error(err) + s.Zero(count) +} + +func (s *SQLTestSuite) TestGetOneResult() { + sess := s.Session() + + artist := sess.Collection("artist") + + for i := 0; i < 5; i++ { + _, err := artist.Insert(map[string]string{ + "name": fmt.Sprintf("Artist %d", i), + }) + s.NoError(err) + } + + // Fetching one struct. + var someArtist artistType + err := artist.Find().Limit(1).One(&someArtist) + s.NoError(err) + + s.NotZero(someArtist.Name) + if s.Adapter() != "ql" { + s.NotZero(someArtist.ID) + } + + // Fetching a pointer to a pointer. + var someArtistObj *artistType + err = artist.Find().Limit(1).One(&someArtistObj) + s.NoError(err) + s.NotZero(someArtist.Name) + if s.Adapter() != "ql" { + s.NotZero(someArtist.ID) + } +} + +func (s *SQLTestSuite) TestGetWithOffset() { + sess := s.Session() + + artist := sess.Collection("artist") + + // Fetching one struct. + var artists []artistType + err := artist.Find().Offset(1).All(&artists) + s.NoError(err) + + s.Equal(3, len(artists)) +} + +func (s *SQLTestSuite) TestGetResultsOneByOne() { + sess := s.Session() + + artist := sess.Collection("artist") + + rowMap := map[string]interface{}{} + + res := artist.Find() + + err := res.Err() + s.NoError(err) + + for res.Next(&rowMap) { + s.NotZero(rowMap["id"]) + s.NotZero(rowMap["name"]) + } + err = res.Err() + s.NoError(err) + + err = res.Close() + s.NoError(err) + + // Dumping into a tagged struct. + rowStruct2 := struct { + Value1 int64 `db:"id"` + Value2 string `db:"name"` + }{} + + res = artist.Find() + + for res.Next(&rowStruct2) { + s.NotZero(rowStruct2.Value1) + s.NotZero(rowStruct2.Value2) + } + err = res.Err() + s.NoError(err) + + err = res.Close() + s.NoError(err) + + // Dumping into a slice of maps. + allRowsMap := []map[string]interface{}{} + + res = artist.Find() + + err = res.All(&allRowsMap) + s.NoError(err) + s.Equal(4, len(allRowsMap)) + + for _, singleRowMap := range allRowsMap { + if fmt.Sprintf("%d", singleRowMap["id"]) == "0" { + s.T().Errorf("Expecting a not null ID.") + } + } + + // Dumping into a slice of structs. 
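+	// All() allocates one slice element per row and matches columns to struct
+	// fields through their db tags.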
+ allRowsStruct := []struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + }{} + + res = artist.Find() + + if err = res.All(&allRowsStruct); err != nil { + s.T().Errorf("%v", err) + } + + s.Equal(4, len(allRowsStruct)) + + for _, singleRowStruct := range allRowsStruct { + s.NotZero(singleRowStruct.ID) + } + + // Dumping into a slice of tagged structs. + allRowsStruct2 := []struct { + Value1 int64 `db:"id"` + Value2 string `db:"name"` + }{} + + res = artist.Find() + + err = res.All(&allRowsStruct2) + s.NoError(err) + + s.Equal(4, len(allRowsStruct2)) + + for _, singleRowStruct := range allRowsStruct2 { + s.NotZero(singleRowStruct.Value1) + } +} + +func (s *SQLTestSuite) TestGetAllResults() { + sess := s.Session() + + artist := sess.Collection("artist") + + total, err := artist.Find().Count() + s.NoError(err) + s.NotZero(total) + + // Fetching all artists into struct + artists := []artistType{} + + res := artist.Find() + + err = res.All(&artists) + s.NoError(err) + s.Equal(len(artists), int(total)) + + s.NotZero(artists[0].Name) + s.NotZero(artists[0].ID) + + // Fetching all artists into struct pointers + artistObjs := []*artistType{} + res = artist.Find() + + err = res.All(&artistObjs) + s.NoError(err) + s.Equal(len(artistObjs), int(total)) + + s.NotZero(artistObjs[0].Name) + s.NotZero(artistObjs[0].ID) +} + +func (s *SQLTestSuite) TestInlineStructs() { + type reviewTypeDetails struct { + Name string `db:"name"` + Comments string `db:"comments"` + Created time.Time `db:"created"` + } + + type reviewType struct { + ID int64 `db:"id,omitempty"` + PublicationID int64 `db:"publication_id"` + Details reviewTypeDetails `db:",inline"` + } + + sess := s.Session() + + review := sess.Collection("review") + + err := review.Truncate() + s.NoError(err) + + rec := reviewType{ + PublicationID: 123, + Details: reviewTypeDetails{ + Name: "..name..", + Comments: "..comments..", + }, + } + + testTimeZone := time.UTC + switch s.Adapter() { + case "mysql": // MySQL uses a global time zone + testTimeZone = defaultTimeLocation + } + + createdAt := time.Date(2016, time.January, 1, 2, 3, 4, 0, testTimeZone) + rec.Details.Created = createdAt + + record, err := review.Insert(rec) + s.NoError(err) + s.NotZero(record.ID().(int64)) + + rec.ID = record.ID().(int64) + + var recChk reviewType + res := review.Find() + + err = res.One(&recChk) + s.NoError(err) + + s.Equal(rec, recChk) +} + +func (s *SQLTestSuite) TestUpdate() { + sess := s.Session() + + artist := sess.Collection("artist") + + _, err := artist.Insert(map[string]string{ + "name": "Ozzie", + }) + s.NoError(err) + + // Defining destination struct + value := struct { + ID int64 `db:"id,omitempty"` + Name string `db:"name"` + }{} + + // Getting the first artist. + cond := mydb.Cond{"id !=": mydb.NotEq(0)} + if s.Adapter() == "ql" { + cond = mydb.Cond{"id() !=": 0} + } + res := artist.Find(cond).Limit(1) + + err = res.One(&value) + s.NoError(err) + + res = artist.Find(value.ID) + + // Updating set with a map + rowMap := map[string]interface{}{ + "name": strings.ToUpper(value.Name), + } + + err = res.Update(rowMap) + s.NoError(err) + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying. + s.Equal(value.Name, rowMap["name"]) + + if s.Adapter() != "ql" { + + // Updating using raw + if err = res.Update(map[string]interface{}{"name": mydb.Raw("LOWER(name)")}); err != nil { + s.T().Errorf("%v", err) + } + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying. 
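+		// (LOWER(name) ran server-side, so the fetched name must equal the
+		// lowercased copy kept in rowMap.)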
+ s.Equal(value.Name, strings.ToLower(rowMap["name"].(string))) + + // Updating using raw + if err = res.Update(struct { + Name *mydb.RawExpr `db:"name"` + }{mydb.Raw(`UPPER(name)`)}); err != nil { + s.T().Errorf("%v", err) + } + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying. + s.Equal(value.Name, strings.ToUpper(rowMap["name"].(string))) + + // Updating using raw + if err = res.Update(struct { + Name *mydb.FuncExpr `db:"name"` + }{mydb.Func("LOWER", mydb.Raw("name"))}); err != nil { + s.T().Errorf("%v", err) + } + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying. + s.Equal(value.Name, strings.ToLower(rowMap["name"].(string))) + } + + // Updating set with a struct + rowStruct := struct { + Name string `db:"name"` + }{strings.ToLower(value.Name)} + + err = res.Update(rowStruct) + s.NoError(err) + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying + s.Equal(value.Name, rowStruct.Name) + + // Updating set with a tagged struct + rowStruct2 := struct { + Value1 string `db:"name"` + }{"john"} + + err = res.Update(rowStruct2) + s.NoError(err) + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying + s.Equal(value.Name, rowStruct2.Value1) + + // Updating set with a tagged object + rowStruct3 := &struct { + Value1 string `db:"name"` + }{"anderson"} + + err = res.Update(rowStruct3) + s.NoError(err) + + // Pulling it again. + err = res.One(&value) + s.NoError(err) + + // Verifying + s.Equal(value.Name, rowStruct3.Value1) +} + +func (s *SQLTestSuite) TestFunction() { + sess := s.Session() + + rowStruct := struct { + ID int64 + Name string + }{} + + artist := sess.Collection("artist") + + cond := mydb.Cond{"id NOT IN": []int{0, -1}} + if s.Adapter() == "ql" { + cond = mydb.Cond{"id() NOT IN": []int{0, -1}} + } + res := artist.Find(cond) + + err := res.One(&rowStruct) + s.NoError(err) + + total, err := res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + // Testing conditions + cond = mydb.Cond{"id NOT IN": []interface{}{0, -1}} + if s.Adapter() == "ql" { + cond = mydb.Cond{"id() NOT IN": []interface{}{0, -1}} + } + res = artist.Find(cond) + + err = res.One(&rowStruct) + s.NoError(err) + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + res = artist.Find().Select("name") + + var rowMap map[string]interface{} + err = res.One(&rowMap) + s.NoError(err) + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) + + res = artist.Find().Select("name") + + err = res.One(&rowMap) + s.NoError(err) + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(4), total) +} + +func (s *SQLTestSuite) TestNullableFields() { + sess := s.Session() + + type testType struct { + ID int64 `db:"id,omitempty"` + NullStringTest sql.NullString `db:"_string"` + NullInt64Test sql.NullInt64 `db:"_int64"` + NullFloat64Test sql.NullFloat64 `db:"_float64"` + NullBoolTest sql.NullBool `db:"_bool"` + } + + col := sess.Collection(`data_types`) + + err := col.Truncate() + s.NoError(err) + + // Testing insertion of invalid nulls. + test := testType{ + NullStringTest: sql.NullString{String: "", Valid: false}, + NullInt64Test: sql.NullInt64{Int64: 0, Valid: false}, + NullFloat64Test: sql.NullFloat64{Float64: 0.0, Valid: false}, + NullBoolTest: sql.NullBool{Bool: false, Valid: false}, + } + + id, err := col.Insert(testType{}) + s.NoError(err) + + // Testing fetching of invalid nulls. 
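+	// (A sql.Null* value with Valid == false round-trips as SQL NULL.)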
+ err = col.Find(id).One(&test) + s.NoError(err) + + s.False(test.NullInt64Test.Valid) + s.False(test.NullFloat64Test.Valid) + s.False(test.NullBoolTest.Valid) + + // Testing insertion of valid nulls. + test = testType{ + NullStringTest: sql.NullString{String: "", Valid: true}, + NullInt64Test: sql.NullInt64{Int64: 0, Valid: true}, + NullFloat64Test: sql.NullFloat64{Float64: 0.0, Valid: true}, + NullBoolTest: sql.NullBool{Bool: false, Valid: true}, + } + + id, err = col.Insert(test) + s.NoError(err) + + // Testing fetching of valid nulls. + err = col.Find(id).One(&test) + s.NoError(err) + + s.True(test.NullInt64Test.Valid) + s.True(test.NullBoolTest.Valid) + s.True(test.NullStringTest.Valid) +} + +func (s *SQLTestSuite) TestGroup() { + sess := s.Session() + + type statsType struct { + Numeric int `db:"numeric"` + Value int `db:"value"` + } + + stats := sess.Collection("stats_test") + + err := stats.Truncate() + s.NoError(err) + + // Adding row append. + for i := 0; i < 100; i++ { + numeric, value := rand.Intn(5), rand.Intn(100) + _, err := stats.Insert(statsType{numeric, value}) + s.NoError(err) + } + + // Testing GROUP BY + res := stats.Find().Select( + "numeric", + mydb.Raw("count(1) AS counter"), + mydb.Raw("sum(value) AS total"), + ).GroupBy("numeric") + + var results []map[string]interface{} + + err = res.All(&results) + s.NoError(err) + + s.Equal(5, len(results)) +} + +func (s *SQLTestSuite) TestInsertAndDelete() { + sess := s.Session() + + artist := sess.Collection("artist") + res := artist.Find() + + total, err := res.Count() + s.NoError(err) + s.Greater(total, uint64(0)) + + err = res.Delete() + s.NoError(err) + + total, err = res.Count() + s.NoError(err) + s.Equal(uint64(0), total) +} + +func (s *SQLTestSuite) TestCompositeKeys() { + if s.Adapter() == "ql" { + s.T().Skip("Currently not supported.") + } + + sess := s.Session() + + compositeKeys := sess.Collection("composite_keys") + + { + n := rand.Intn(100000) + + item := itemWithCompoundKey{ + "ABCDEF", + strconv.Itoa(n), + "Some value", + } + + id, err := compositeKeys.Insert(&item) + s.NoError(err) + s.NotZero(id) + + var item2 itemWithCompoundKey + s.NotEqual(item2.SomeVal, item.SomeVal) + + // Finding by ID + err = compositeKeys.Find(id).One(&item2) + s.NoError(err) + + s.Equal(item2.SomeVal, item.SomeVal) + } + + { + n := rand.Intn(100000) + + item := itemWithCompoundKey{ + "ABCDEF", + strconv.Itoa(n), + "Some value", + } + + err := compositeKeys.InsertReturning(&item) + s.NoError(err) + } +} + +// Attempts to test database transactions. +func (s *SQLTestSuite) TestTransactionsAndRollback() { + if s.Adapter() == "ql" { + s.T().Skip("Currently not supported.") + } + + sess := s.Session() + + err := sess.Tx(func(tx mydb.Session) error { + artist := tx.Collection("artist") + err := artist.Truncate() + s.NoError(err) + + _, err = artist.Insert(artistType{1, "First"}) + s.NoError(err) + + return nil + }) + s.NoError(err) + + err = sess.Tx(func(tx mydb.Session) error { + artist := tx.Collection("artist") + + _, err = artist.Insert(artistType{2, "Second"}) + s.NoError(err) + + // Won't fail. + _, err = artist.Insert(artistType{3, "Third"}) + s.NoError(err) + + // Will fail. + _, err = artist.Insert(artistType{1, "Duplicated"}) + s.Error(err) + + return err + }) + s.Error(err) + + // Let's verify we still have one element. 
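+	// The duplicate-key error above rolled back the whole transaction,
+	// including the two inserts that had already succeeded inside it.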
+ artist := sess.Collection("artist") + + count, err := artist.Find().Count() + s.NoError(err) + s.Equal(uint64(1), count) + + err = sess.Tx(func(tx mydb.Session) error { + artist := tx.Collection("artist") + + // Won't fail. + _, err = artist.Insert(artistType{2, "Second"}) + s.NoError(err) + + // Won't fail. + _, err = artist.Insert(artistType{3, "Third"}) + s.NoError(err) + + return fmt.Errorf("rollback for no reason") + }) + s.Error(err) + + // Let's verify we still have one element. + artist = sess.Collection("artist") + + count, err = artist.Find().Count() + s.NoError(err) + s.Equal(uint64(1), count) + + // Attempt to add some rows. + err = sess.Tx(func(tx mydb.Session) error { + artist = tx.Collection("artist") + + // Won't fail. + _, err = artist.Insert(artistType{2, "Second"}) + s.NoError(err) + + // Won't fail. + _, err = artist.Insert(artistType{3, "Third"}) + s.NoError(err) + + return nil + }) + s.NoError(err) + + // Let's verify we have 3 rows. + artist = sess.Collection("artist") + + count, err = artist.Find().Count() + s.NoError(err) + s.Equal(uint64(3), count) +} + +func (s *SQLTestSuite) TestDataTypes() { + if s.Adapter() == "ql" { + s.T().Skip("Currently not supported.") + } + + type testValuesStruct struct { + Uint uint `db:"_uint"` + Uint8 uint8 `db:"_uint8"` + Uint16 uint16 `db:"_uint16"` + Uint32 uint32 `db:"_uint32"` + Uint64 uint64 `db:"_uint64"` + + Int int `db:"_int"` + Int8 int8 `db:"_int8"` + Int16 int16 `db:"_int16"` + Int32 int32 `db:"_int32"` + Int64 int64 `db:"_int64"` + + Float32 float32 `db:"_float32"` + Float64 float64 `db:"_float64"` + + Bool bool `db:"_bool"` + String string `db:"_string"` + Blob []byte `db:"_blob"` + + Date time.Time `db:"_date"` + DateN *time.Time `db:"_nildate"` + DateP *time.Time `db:"_ptrdate"` + DateD *time.Time `db:"_defaultdate,omitempty"` + Time int64 `db:"_time"` + } + + sess := s.Session() + + // Getting a pointer to the "data_types" collection. + dataTypes := sess.Collection("data_types") + + // Removing all data. + err := dataTypes.Truncate() + s.NoError(err) + + testTimeZone := time.Local + switch s.Adapter() { + case "mysql", "postgresql": // MySQL uses a global time zone + testTimeZone = defaultTimeLocation + } + + ts := time.Date(2011, 7, 28, 1, 2, 3, 0, testTimeZone) + tnz := ts.In(time.UTC) + + switch s.Adapter() { + case "mysql": + // MySQL uses a global timezone + tnz = ts.In(defaultTimeLocation) + } + + testValues := testValuesStruct{ + 1, 1, 1, 1, 1, + -1, -1, -1, -1, -1, + + 1.337, 1.337, + + true, + "Hello world!", + []byte("Hello world!"), + + ts, // Date + nil, // DateN + &tnz, // DateP + nil, // DateD + int64(time.Second * time.Duration(7331)), + } + id, err := dataTypes.Insert(testValues) + s.NoError(err) + s.NotNil(id) + + // Defining our set. + cond := mydb.Cond{"id": id} + if s.Adapter() == "ql" { + cond = mydb.Cond{"id()": id} + } + res := dataTypes.Find(cond) + + count, err := res.Count() + s.NoError(err) + s.NotZero(count) + + // Trying to dump the subject into an empty structure of the same type. + var item testValuesStruct + + err = res.One(&item) + s.NoError(err) + + s.NotNil(item.DateD) + s.NotNil(item.Date) + + // Copy the default date (this value is set by the database) + testValues.DateD = item.DateD + item.Date = item.Date.In(testTimeZone) + + s.Equal(testValues.Date, item.Date) + s.Equal(testValues.DateN, item.DateN) + s.Equal(testValues.DateP, item.DateP) + s.Equal(testValues.DateD, item.DateD) + + // The original value and the test subject must match. 
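+	// (DateD is compared against the fetched row because its value is
+	// assigned by the database, not by the test.)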
+ s.Equal(testValues, item) +} + +func (s *SQLTestSuite) TestUpdateWithNullColumn() { + sess := s.Session() + + artist := sess.Collection("artist") + err := artist.Truncate() + s.NoError(err) + + type Artist struct { + ID int64 `db:"id,omitempty"` + Name *string `db:"name"` + } + + name := "José" + id, err := artist.Insert(Artist{0, &name}) + s.NoError(err) + + var item Artist + err = artist.Find(id).One(&item) + s.NoError(err) + + s.NotEqual(nil, item.Name) + s.Equal(name, *item.Name) + + err = artist.Find(id).Update(Artist{Name: nil}) + s.NoError(err) + + var item2 Artist + err = artist.Find(id).One(&item2) + s.NoError(err) + + s.Equal((*string)(nil), item2.Name) +} + +func (s *SQLTestSuite) TestBatchInsert() { + sess := s.Session() + + for batchSize := 0; batchSize < 17; batchSize++ { + err := sess.Collection("artist").Truncate() + s.NoError(err) + + q := sess.SQL().InsertInto("artist").Columns("name") + + switch s.Adapter() { + case "postgresql", "cockroachdb": + q = q.Amend(func(query string) string { + return query + ` ON CONFLICT DO NOTHING` + }) + } + + batch := q.Batch(batchSize) + + totalItems := int(rand.Int31n(21)) + + go func() { + defer batch.Done() + for i := 0; i < totalItems; i++ { + batch.Values(fmt.Sprintf("artist-%d", i)) + } + }() + + err = batch.Wait() + s.NoError(err) + s.NoError(batch.Err()) + + c, err := sess.Collection("artist").Find().Count() + s.NoError(err) + s.Equal(uint64(totalItems), c) + + for i := 0; i < totalItems; i++ { + c, err := sess.Collection("artist").Find(mydb.Cond{"name": fmt.Sprintf("artist-%d", i)}).Count() + s.NoError(err) + s.Equal(uint64(1), c) + } + } +} + +func (s *SQLTestSuite) TestBatchInsertNoColumns() { + sess := s.Session() + + for batchSize := 0; batchSize < 17; batchSize++ { + err := sess.Collection("artist").Truncate() + s.NoError(err) + + batch := sess.SQL().InsertInto("artist").Batch(batchSize) + + totalItems := int(rand.Int31n(21)) + + go func() { + defer batch.Done() + for i := 0; i < totalItems; i++ { + value := struct { + Name string `db:"name"` + }{fmt.Sprintf("artist-%d", i)} + batch.Values(value) + } + }() + + err = batch.Wait() + s.NoError(err) + s.NoError(batch.Err()) + + c, err := sess.Collection("artist").Find().Count() + s.NoError(err) + s.Equal(uint64(totalItems), c) + + for i := 0; i < totalItems; i++ { + c, err := sess.Collection("artist").Find(mydb.Cond{"name": fmt.Sprintf("artist-%d", i)}).Count() + s.NoError(err) + s.Equal(uint64(1), c) + } + } +} + +func (s *SQLTestSuite) TestBatchInsertReturningKeys() { + switch s.Adapter() { + case "postgresql", "cockroachdb": + // pass + default: + s.T().Skip("Currently not supported.") + return + } + + sess := s.Session() + + err := sess.Collection("artist").Truncate() + s.NoError(err) + + batchSize, totalItems := 7, 12 + + batch := sess.SQL().InsertInto("artist").Columns("name").Returning("id").Batch(batchSize) + + go func() { + defer batch.Done() + for i := 0; i < totalItems; i++ { + batch.Values(fmt.Sprintf("artist-%d", i)) + } + }() + + var keyMap []struct { + ID int `db:"id"` + } + for batch.NextResult(&keyMap) { + // Each insertion must produce new keys. + s.True(len(keyMap) > 0) + s.True(len(keyMap) <= batchSize) + + // Find the elements we've just inserted + keys := make([]int, 0, len(keyMap)) + for i := range keyMap { + keys = append(keys, keyMap[i].ID) + } + + // Make sure count matches. 
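+		// (A slice value in mydb.Cond renders as an IN clause, e.g.
+		// "id IN (1, 2, 3)".)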
+ c, err := sess.Collection("artist").Find(mydb.Cond{"id": keys}).Count() + s.NoError(err) + s.Equal(uint64(len(keyMap)), c) + } + s.NoError(batch.Err()) + + // Count all new elements + c, err := sess.Collection("artist").Find().Count() + s.NoError(err) + s.Equal(uint64(totalItems), c) +} + +func (s *SQLTestSuite) TestPaginator() { + sess := s.Session() + + err := sess.Collection("artist").Truncate() + s.NoError(err) + + batch := sess.SQL().InsertInto("artist").Batch(100) + + go func() { + defer batch.Done() + for i := 0; i < 999; i++ { + value := struct { + Name string `db:"name"` + }{fmt.Sprintf("artist-%d", i)} + batch.Values(value) + } + }() + + err = batch.Wait() + s.NoError(err) + s.NoError(batch.Err()) + + q := sess.SQL().SelectFrom("artist") + if s.Adapter() == "ql" { + q = sess.SQL().SelectFrom(sess.SQL().Select("id() AS id", "name").From("artist")) + } + + const pageSize = 13 + cursorColumn := "id" + + paginator := q.Paginate(pageSize) + + var zerothPage []artistType + err = paginator.Page(0).All(&zerothPage) + s.NoError(err) + s.Equal(pageSize, len(zerothPage)) + + var firstPage []artistType + err = paginator.Page(1).All(&firstPage) + s.NoError(err) + s.Equal(pageSize, len(firstPage)) + + s.Equal(zerothPage, firstPage) + + var secondPage []artistType + err = paginator.Page(2).All(&secondPage) + s.NoError(err) + s.Equal(pageSize, len(secondPage)) + + totalPages, err := paginator.TotalPages() + s.NoError(err) + s.NotZero(totalPages) + s.Equal(uint(77), totalPages) + + totalEntries, err := paginator.TotalEntries() + s.NoError(err) + s.NotZero(totalEntries) + s.Equal(uint64(999), totalEntries) + + var lastPage []artistType + err = paginator.Page(totalPages).All(&lastPage) + s.NoError(err) + s.Equal(11, len(lastPage)) + + var beyondLastPage []artistType + err = paginator.Page(totalPages + 1).All(&beyondLastPage) + s.NoError(err) + s.Equal(0, len(beyondLastPage)) + + var hundredthPage []artistType + err = paginator.Page(100).All(&hundredthPage) + s.NoError(err) + s.Equal(0, len(hundredthPage)) + + for i := uint(0); i < totalPages; i++ { + current := paginator.Page(i + 1) + + var items []artistType + err := current.All(&items) + if err != nil { + s.T().Errorf("%v", err) + } + s.NoError(err) + if len(items) < 1 { + s.Equal(totalPages+1, i) + break + } + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", int64(pageSize*int(i)+j)), items[j].Name) + } + } + + paginator = paginator.Cursor(cursorColumn) + { + current := paginator.Page(1) + for i := 0; ; i++ { + var items []artistType + err := current.All(&items) + if err != nil { + s.T().Errorf("%v", err) + } + if len(items) < 1 { + s.Equal(int(totalPages), i) + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", int64(pageSize*int(i)+j)), items[j].Name) + } + current = current.NextPage(items[len(items)-1].ID) + } + } + + { + current := paginator.Page(totalPages) + for i := totalPages; ; i-- { + var items []artistType + + err := current.All(&items) + s.NoError(err) + + if len(items) < 1 { + s.Equal(uint(0), i) + break + } + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", pageSize*int(i-1)+j), items[j].Name) + } + + current = current.PrevPage(items[0].ID) + } + } + + if s.Adapter() == "ql" { + s.T().Skip("Unsupported, see https://github.com/cznic/ql/issues/182") + return + } + + { + result := sess.Collection("artist").Find() + + fifteenResults := 15 + resultPaginator := result.Paginate(uint(fifteenResults)) + + count, err := resultPaginator.TotalPages() + s.Equal(uint(67), 
count) + s.NoError(err) + + var items []artistType + fifthPage := 5 + err = resultPaginator.Page(uint(fifthPage)).All(&items) + s.NoError(err) + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", int(fifteenResults)*(fifthPage-1)+j), items[j].Name) + } + + resultPaginator = resultPaginator.Cursor(cursorColumn).Page(1) + for i := 0; ; i++ { + var items []artistType + + err = resultPaginator.All(&items) + s.NoError(err) + + if len(items) < 1 { + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", fifteenResults*i+j), items[j].Name) + } + resultPaginator = resultPaginator.NextPage(items[len(items)-1].ID) + } + + resultPaginator = resultPaginator.Cursor(cursorColumn).Page(count) + for i := count; ; i-- { + var items []artistType + + err = resultPaginator.All(&items) + s.NoError(err) + + if len(items) < 1 { + s.Equal(uint(0), i) + break + } + + for j := 0; j < len(items); j++ { + s.Equal(fmt.Sprintf("artist-%d", fifteenResults*(int(i)-1)+j), items[j].Name) + } + resultPaginator = resultPaginator.PrevPage(items[0].ID) + } + } + + { + // Testing page size 0. + paginator := q.Paginate(0) + + totalPages, err := paginator.TotalPages() + s.NoError(err) + s.Equal(uint(1), totalPages) + + totalEntries, err := paginator.TotalEntries() + s.NoError(err) + s.Equal(uint64(999), totalEntries) + + var allItems []artistType + err = paginator.Page(0).All(&allItems) + s.NoError(err) + s.Equal(totalEntries, uint64(len(allItems))) + + } +} + +func (s *SQLTestSuite) TestPaginator_Issue607() { + sess := s.Session() + + err := sess.Collection("artist").Truncate() + s.NoError(err) + + // Add first batch + { + batch := sess.SQL().InsertInto("artist").Batch(50) + + go func() { + defer batch.Done() + for i := 0; i < 49; i++ { + value := struct { + Name string `db:"name"` + }{fmt.Sprintf("artist-1.%d", i)} + batch.Values(value) + } + }() + + err = batch.Wait() + s.NoError(err) + s.NoError(batch.Err()) + } + + artists := []*artistType{} + paginator := sess.SQL().Select("name").From("artist").Paginate(10) + + err = paginator.Page(1).All(&artists) + s.NoError(err) + + { + totalPages, err := paginator.TotalPages() + s.NoError(err) + s.NotZero(totalPages) + s.Equal(uint(5), totalPages) + } + + // Add second batch + { + batch := sess.SQL().InsertInto("artist").Batch(50) + + go func() { + defer batch.Done() + for i := 0; i < 49; i++ { + value := struct { + Name string `db:"name"` + }{fmt.Sprintf("artist-2.%d", i)} + batch.Values(value) + } + }() + + err = batch.Wait() + s.NoError(err) + s.NoError(batch.Err()) + } + + { + totalPages, err := paginator.TotalPages() + s.NoError(err) + s.NotZero(totalPages) + s.Equal(uint(10), totalPages, "expect number of pages to change") + } + + artists = []*artistType{} + + cond := mydb.Cond{"name": mydb.Like("artist-1.%")} + if s.Adapter() == "ql" { + cond = mydb.Cond{"name": mydb.Like("artist-1.")} + } + + paginator = sess.SQL().Select("name").From("artist").Where(cond).Paginate(10) + + err = paginator.Page(1).All(&artists) + s.NoError(err) + + { + totalPages, err := paginator.TotalPages() + s.NoError(err) + s.NotZero(totalPages) + s.Equal(uint(5), totalPages, "expect same 5 pages from the first batch") + } + +} + +func (s *SQLTestSuite) TestSession() { + sess := s.Session() + + var all []map[string]interface{} + + err := sess.Collection("artist").Truncate() + s.NoError(err) + + _, err = sess.SQL().InsertInto("artist").Values(struct { + Name string `db:"name"` + }{"Rinko Kikuchi"}).Exec() + s.NoError(err) + + // Using explicit iterator. 
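+	// Iterator() hands the cursor to the caller, who must exhaust or close
+	// it; All() and One() manage the cursor internally.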
+ iter := sess.SQL().SelectFrom("artist").Iterator() + err = iter.All(&all) + + s.NoError(err) + s.NotZero(all) + + // Using explicit iterator to fetch one item. + var item map[string]interface{} + iter = sess.SQL().SelectFrom("artist").Iterator() + err = iter.One(&item) + + s.NoError(err) + s.NotZero(item) + + // Using explicit iterator and NextScan. + iter = sess.SQL().SelectFrom("artist").Iterator() + var id int + var name string + + if s.Adapter() == "ql" { + err = iter.NextScan(&name) + id = 1 + } else { + err = iter.NextScan(&id, &name) + } + + s.NoError(err) + s.NotZero(id) + s.NotEmpty(name) + s.NoError(iter.Close()) + + err = iter.NextScan(&id, &name) + s.Error(err) + + // Using explicit iterator and ScanOne. + iter = sess.SQL().SelectFrom("artist").Iterator() + id, name = 0, "" + if s.Adapter() == "ql" { + err = iter.ScanOne(&name) + id = 1 + } else { + err = iter.ScanOne(&id, &name) + } + + s.NoError(err) + s.NotZero(id) + s.NotEmpty(name) + + err = iter.ScanOne(&id, &name) + s.Error(err) + + // Using explicit iterator and Next. + iter = sess.SQL().SelectFrom("artist").Iterator() + + var artist map[string]interface{} + for iter.Next(&artist) { + if s.Adapter() != "ql" { + s.NotZero(artist["id"]) + } + s.NotEmpty(artist["name"]) + } + // We should not have any error after finishing successfully exiting a Next() loop. + s.Empty(iter.Err()) + + for i := 0; i < 5; i++ { + // But we'll get errors if we attempt to continue using Next(). + s.False(iter.Next(&artist)) + s.Error(iter.Err()) + } + + // Using implicit iterator. + q := sess.SQL().SelectFrom("artist") + err = q.All(&all) + + s.NoError(err) + s.NotZero(all) + + err = sess.Tx(func(tx mydb.Session) error { + q := tx.SQL().SelectFrom("artist") + s.NotZero(iter) + + err = q.All(&all) + s.NoError(err) + s.NotZero(all) + + return nil + }) + + s.NoError(err) +} + +func (s *SQLTestSuite) TestExhaustConnectionPool() { + if s.Adapter() == "ql" { + s.T().Skip("Currently not supported.") + return + } + + sess := s.Session() + errRolledBack := errors.New("rolled back") + + var wg sync.WaitGroup + for i := 0; i < 100; i++ { + s.T().Logf("Tx %d: Pending", i) + + wg.Add(1) + go func(wg *sync.WaitGroup, i int) { + defer wg.Done() + + // Requesting a new transaction session. + start := time.Now() + s.T().Logf("Tx: %d: NewTx", i) + + expectError := false + if i%2 == 1 { + expectError = true + } + + err := sess.Tx(func(tx mydb.Session) error { + s.T().Logf("Tx %d: OK (time to connect: %v)", i, time.Since(start)) + // Let's suppose that we do a bunch of complex stuff and that the + // transaction lasts 3 seconds. 
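+				// Keeping many transactions open at once forces the pool to
+				// queue waiters instead of failing when connections run out.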
+ time.Sleep(time.Second * 3) + + if expectError { + if _, err := tx.SQL().DeleteFrom("artist").Exec(); err != nil { + return err + } + return errRolledBack + } + + var account map[string]interface{} + if err := tx.Collection("artist").Find().One(&account); err != nil { + return err + } + return nil + }) + if expectError { + s.Error(err) + s.True(errors.Is(err, errRolledBack)) + } else { + s.NoError(err) + } + }(&wg, i) + } + + wg.Wait() +} + +func (s *SQLTestSuite) TestCustomType() { + // See https://github.com/upper/db/issues/332 + sess := s.Session() + + artist := sess.Collection("artist") + + err := artist.Truncate() + s.NoError(err) + + id, err := artist.Insert(artistWithCustomType{ + Custom: customType{Val: []byte("some name")}, + }) + s.NoError(err) + s.NotNil(id) + + var bar artistWithCustomType + err = artist.Find(id).One(&bar) + s.NoError(err) + + s.Equal("foo: some name", string(bar.Custom.Val)) +} + +func (s *SQLTestSuite) Test_Issue565() { + s.Session().Collection("birthdays").Insert(&birthday{ + Name: "Lucy", + Born: time.Now(), + }) + + parentCtx := context.WithValue(s.Session().Context(), "carry", 1) + s.NotZero(parentCtx.Value("carry")) + + { + ctx, cancel := context.WithTimeout(parentCtx, time.Nanosecond) + defer cancel() + + sess := s.Session() + + sess = sess.WithContext(ctx) + + var result birthday + err := sess.Collection("birthdays").Find().Select("name").One(&result) + + s.Error(err) + s.Zero(result.Name) + + s.NotZero(ctx.Value("carry")) + } + + { + ctx, cancel := context.WithTimeout(parentCtx, time.Second*10) + cancel() // cancel before passing + + sess := s.Session().WithContext(ctx) + + var result birthday + err := sess.Collection("birthdays").Find().Select("name").One(&result) + + s.Error(err) + s.Zero(result.Name) + + s.NotZero(ctx.Value("carry")) + } + + { + ctx, cancel := context.WithTimeout(parentCtx, time.Second) + defer cancel() + + sess := s.Session().WithContext(ctx) + + var result birthday + err := sess.Collection("birthdays").Find().Select("name").One(&result) + + s.NoError(err) + s.NotZero(result.Name) + + s.NotZero(ctx.Value("carry")) + } +} + +func (s *SQLTestSuite) TestSelectFromSubquery() { + sess := s.Session() + + { + var artists []artistType + q := sess.SQL().SelectFrom( + sess.SQL().SelectFrom("artist").Where(mydb.Cond{ + "name": mydb.IsNotNull(), + }), + ).As("_q") + err := q.All(&artists) + s.NoError(err) + + s.NotZero(len(artists)) + } + + { + var artists []artistType + q := sess.SQL().SelectFrom( + sess.Collection("artist").Find(mydb.Cond{ + "name": mydb.IsNotNull(), + }), + ).As("_q") + err := q.All(&artists) + s.NoError(err) + + s.NotZero(len(artists)) + } + +} diff --git a/internal/testsuite/suite.go b/internal/testsuite/suite.go new file mode 100644 index 0000000..52ae956 --- /dev/null +++ b/internal/testsuite/suite.go @@ -0,0 +1,37 @@ +package testsuite + +import ( + "time" + + "git.hexq.cn/tiglog/mydb" + "github.com/stretchr/testify/suite" +) + +const TimeZone = "Canada/Eastern" + +var defaultTimeLocation, _ = time.LoadLocation(TimeZone) + +type Helper interface { + Session() mydb.Session + + Adapter() string + + TearUp() error + TearDown() error +} + +type Suite struct { + suite.Suite + + Helper +} + +func (s *Suite) AfterTest(suiteName, testName string) { + err := s.TearDown() + s.NoError(err) +} + +func (s *Suite) BeforeTest(suiteName, testName string) { + err := s.TearUp() + s.NoError(err) +} diff --git a/intersection.go b/intersection.go new file mode 100644 index 0000000..dda67a0 --- /dev/null +++ b/intersection.go @@ -0,0 
+1,50 @@
+package mydb
+
+import "git.hexq.cn/tiglog/mydb/internal/adapter"
+
+// AndExpr represents an expression joined by a logical conjunction (AND).
+type AndExpr struct {
+	*adapter.LogicalExprGroup
+}
+
+// And adds more expressions to the group.
+func (a *AndExpr) And(andConds ...LogicalExpr) *AndExpr {
+	var fn func(*[]LogicalExpr) error
+	if len(andConds) > 0 {
+		fn = func(in *[]LogicalExpr) error {
+			*in = append(*in, andConds...)
+			return nil
+		}
+	}
+	return &AndExpr{a.LogicalExprGroup.Frame(fn)}
+}
+
+// Empty returns true if the group has no conditions.
+func (a *AndExpr) Empty() bool {
+	return a.LogicalExprGroup.Empty()
+}
+
+// And joins conditions under logical conjunction. Conditions can be
+// represented by `db.Cond{}`, `db.Or()` or `db.And()`.
+//
+// Examples:
+//
+//	// name = "Peter" AND last_name = "Parker"
+//	db.And(
+//		db.Cond{"name": "Peter"},
+//		db.Cond{"last_name": "Parker"},
+//	)
+//
+//	// (name = "Peter" OR name = "Mickey") AND last_name = "Mouse"
+//	db.And(
+//		db.Or(
+//			db.Cond{"name": "Peter"},
+//			db.Cond{"name": "Mickey"},
+//		),
+//		db.Cond{"last_name": "Mouse"},
+//	)
+func And(conds ...LogicalExpr) *AndExpr {
+	return &AndExpr{adapter.NewLogicalExprGroup(adapter.LogicalOperatorAnd, conds...)}
+}
+
+var _ = adapter.LogicalExpr(&AndExpr{})
diff --git a/iterator.go b/iterator.go
new file mode 100644
index 0000000..c833bba
--- /dev/null
+++ b/iterator.go
@@ -0,0 +1,26 @@
+package mydb
+
+// Iterator provides methods for iterating over query results.
+type Iterator interface {
+	// ResultMapper provides methods to retrieve and map results.
+	ResultMapper
+
+	// Scan copies the current result into the given list of pointers.
+	Scan(dest ...interface{}) error
+
+	// NextScan advances the iterator and performs Scan.
+	NextScan(dest ...interface{}) error
+
+	// ScanOne advances the iterator, performs Scan and closes the iterator.
+	ScanOne(dest ...interface{}) error
+
+	// Next dumps the current element into the given destination, which could be
+	// a pointer to either a map or a struct.
+	Next(dest ...interface{}) bool
+
+	// Err returns the last error produced by the cursor.
+	Err() error
+
+	// Close closes the iterator and frees up the cursor.
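+	//
+	// A minimal sketch of the intended Next/Err/Close flow (the "artist"
+	// table is borrowed from the test suite for illustration):
+	//
+	//	iter := sess.SQL().SelectFrom("artist").Iterator()
+	//	defer iter.Close()
+	//
+	//	var row map[string]interface{}
+	//	for iter.Next(&row) {
+	//		// ... use row ...
+	//	}
+	//	if err := iter.Err(); err != nil {
+	//		return err
+	//	}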
+ Close() error +} diff --git a/logger.go b/logger.go new file mode 100644 index 0000000..cdb9f2e --- /dev/null +++ b/logger.go @@ -0,0 +1,349 @@ +package mydb + +import ( + "context" + "fmt" + "log" + "os" + "regexp" + "runtime" + "strings" + "time" +) + +const ( + fmtLogSessID = `Session ID: %05d` + fmtLogTxID = `Transaction ID: %05d` + fmtLogQuery = `Query: %s` + fmtLogArgs = `Arguments: %#v` + fmtLogRowsAffected = `Rows affected: %d` + fmtLogLastInsertID = `Last insert ID: %d` + fmtLogError = `Error: %v` + fmtLogStack = `Stack: %v` + fmtLogTimeTaken = `Time taken: %0.5fs` + fmtLogContext = `Context: %v` +) + +const ( + maxFrames = 30 + skipFrames = 3 +) + +var ( + reInvisibleChars = regexp.MustCompile(`[\s\r\n\t]+`) +) + +// LogLevel represents a verbosity level for logs +type LogLevel int8 + +// Log levels +const ( + LogLevelTrace LogLevel = -1 + + LogLevelDebug LogLevel = iota + LogLevelInfo + LogLevelWarn + LogLevelError + LogLevelFatal + LogLevelPanic +) + +var logLevels = map[LogLevel]string{ + LogLevelTrace: "TRACE", + LogLevelDebug: "DEBUG", + LogLevelInfo: "INFO", + LogLevelWarn: "WARNING", + LogLevelError: "ERROR", + LogLevelFatal: "FATAL", + LogLevelPanic: "PANIC", +} + +func (ll LogLevel) String() string { + return logLevels[ll] +} + +const ( + defaultLogLevel LogLevel = LogLevelWarn +) + +var defaultLogger Logger = log.New(os.Stdout, "", log.LstdFlags) + +// Logger represents a logging interface that is compatible with the standard +// "log" and with many other logging libraries. +type Logger interface { + Fatal(v ...interface{}) + Fatalf(format string, v ...interface{}) + + Print(v ...interface{}) + Printf(format string, v ...interface{}) + + Panic(v ...interface{}) + Panicf(format string, v ...interface{}) +} + +// LoggingCollector provides different methods for collecting and classifying +// log messages. +type LoggingCollector interface { + Enabled(LogLevel) bool + + Level() LogLevel + + SetLogger(Logger) + SetLevel(LogLevel) + + Trace(v ...interface{}) + Tracef(format string, v ...interface{}) + + Debug(v ...interface{}) + Debugf(format string, v ...interface{}) + + Info(v ...interface{}) + Infof(format string, v ...interface{}) + + Warn(v ...interface{}) + Warnf(format string, v ...interface{}) + + Error(v ...interface{}) + Errorf(format string, v ...interface{}) + + Fatal(v ...interface{}) + Fatalf(format string, v ...interface{}) + + Panic(v ...interface{}) + Panicf(format string, v ...interface{}) +} + +type loggingCollector struct { + level LogLevel + logger Logger +} + +func (c *loggingCollector) Enabled(level LogLevel) bool { + return level >= c.level +} + +func (c *loggingCollector) SetLevel(level LogLevel) { + c.level = level +} + +func (c *loggingCollector) Level() LogLevel { + return c.level +} + +func (c *loggingCollector) Logger() Logger { + if c.logger == nil { + return defaultLogger + } + return c.logger +} + +func (c *loggingCollector) SetLogger(logger Logger) { + c.logger = logger +} + +func (c *loggingCollector) logf(level LogLevel, f string, v ...interface{}) { + if level >= LogLevelPanic { + c.Logger().Panicf(f, v...) + } + if level >= LogLevelFatal { + c.Logger().Fatalf(f, v...) + } + if c.Enabled(level) { + c.Logger().Printf(f, v...) + } +} + +func (c *loggingCollector) log(level LogLevel, v ...interface{}) { + if level >= LogLevelPanic { + c.Logger().Panic(v...) + } + if level >= LogLevelFatal { + c.Logger().Fatal(v...) + } + if c.Enabled(level) { + c.Logger().Print(v...) 
+	}
+}
+
+func (c *loggingCollector) Debugf(format string, v ...interface{}) {
+	c.logf(LogLevelDebug, format, v...)
+}
+func (c *loggingCollector) Debug(v ...interface{}) {
+	c.log(LogLevelDebug, v...)
+}
+
+func (c *loggingCollector) Tracef(format string, v ...interface{}) {
+	c.logf(LogLevelTrace, format, v...)
+}
+func (c *loggingCollector) Trace(v ...interface{}) {
+	c.log(LogLevelTrace, v...)
+}
+
+func (c *loggingCollector) Infof(format string, v ...interface{}) {
+	c.logf(LogLevelInfo, format, v...)
+}
+func (c *loggingCollector) Info(v ...interface{}) {
+	c.log(LogLevelInfo, v...)
+}
+
+func (c *loggingCollector) Warnf(format string, v ...interface{}) {
+	c.logf(LogLevelWarn, format, v...)
+}
+func (c *loggingCollector) Warn(v ...interface{}) {
+	c.log(LogLevelWarn, v...)
+}
+
+func (c *loggingCollector) Errorf(format string, v ...interface{}) {
+	c.logf(LogLevelError, format, v...)
+}
+func (c *loggingCollector) Error(v ...interface{}) {
+	c.log(LogLevelError, v...)
+}
+
+func (c *loggingCollector) Fatalf(format string, v ...interface{}) {
+	c.logf(LogLevelFatal, format, v...)
+}
+func (c *loggingCollector) Fatal(v ...interface{}) {
+	c.log(LogLevelFatal, v...)
+}
+
+func (c *loggingCollector) Panicf(format string, v ...interface{}) {
+	c.logf(LogLevelPanic, format, v...)
+}
+func (c *loggingCollector) Panic(v ...interface{}) {
+	c.log(LogLevelPanic, v...)
+}
+
+var defaultLoggingCollector LoggingCollector = &loggingCollector{
+	level:  defaultLogLevel,
+	logger: defaultLogger,
+}
+
+// QueryStatus represents the status of a query after being executed.
+type QueryStatus struct {
+	SessID uint64
+	TxID   uint64
+
+	RowsAffected *int64
+	LastInsertID *int64
+
+	RawQuery string
+	Args     []interface{}
+
+	Err error
+
+	Start time.Time
+	End   time.Time
+
+	Context context.Context
+}
+
+func (q *QueryStatus) Query() string {
+	query := reInvisibleChars.ReplaceAllString(q.RawQuery, " ")
+	query = strings.TrimSpace(query)
+	return query
+}
+
+func (q *QueryStatus) Stack() []string {
+	frames := collectFrames()
+	lines := make([]string, 0, len(frames))
+
+	for _, frame := range frames {
+		lines = append(lines, fmt.Sprintf("%s@%s:%d", frame.Function, frame.File, frame.Line))
+	}
+	return lines
+}
+
+// String returns a formatted log message.
+func (q *QueryStatus) String() string {
+	lines := make([]string, 0, 8)
+
+	if q.SessID > 0 {
+		lines = append(lines, fmt.Sprintf(fmtLogSessID, q.SessID))
+	}
+
+	if q.TxID > 0 {
+		lines = append(lines, fmt.Sprintf(fmtLogTxID, q.TxID))
+	}
+
+	if query := q.RawQuery; query != "" {
+		lines = append(lines, fmt.Sprintf(fmtLogQuery, q.Query()))
+	}
+
+	if len(q.Args) > 0 {
+		lines = append(lines, fmt.Sprintf(fmtLogArgs, q.Args))
+	}
+
+	if stack := q.Stack(); len(stack) > 0 {
+		lines = append(lines, fmt.Sprintf(fmtLogStack, "\n\t"+strings.Join(stack, "\n\t")))
+	}
+
+	if q.RowsAffected != nil {
+		lines = append(lines, fmt.Sprintf(fmtLogRowsAffected, *q.RowsAffected))
+	}
+	if q.LastInsertID != nil {
+		lines = append(lines, fmt.Sprintf(fmtLogLastInsertID, *q.LastInsertID))
+	}
+
+	if q.Err != nil {
+		lines = append(lines, fmt.Sprintf(fmtLogError, q.Err))
+	}
+
+	lines = append(lines, fmt.Sprintf(fmtLogTimeTaken, float64(q.End.UnixNano()-q.Start.UnixNano())/float64(1e9)))
+
+	if q.Context != nil {
+		lines = append(lines, fmt.Sprintf(fmtLogContext, q.Context))
+	}
+
+	return "\t" + strings.Replace(strings.Join(lines, "\n"), "\n", "\n\t", -1) + "\n\n"
+}
+
+// LC returns the logging collector.
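+//
+// A minimal usage sketch (the exact output depends on the configured Logger):
+//
+//	LC().SetLevel(LogLevelDebug)
+//	LC().Debugf("%s", queryStatus) // e.g. a *QueryStatus value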
+func LC() LoggingCollector { + return defaultLoggingCollector +} + +func init() { + if logLevel := strings.ToUpper(os.Getenv("UPPER_DB_LOG")); logLevel != "" { + for ll := range logLevels { + if ll.String() == logLevel { + LC().SetLevel(ll) + break + } + } + } +} + +func collectFrames() []runtime.Frame { + pc := make([]uintptr, maxFrames) + n := runtime.Callers(skipFrames, pc) + if n == 0 { + return nil + } + + pc = pc[:n] + frames := runtime.CallersFrames(pc) + + collectedFrames := make([]runtime.Frame, 0, maxFrames) + discardedFrames := make([]runtime.Frame, 0, maxFrames) + for { + frame, more := frames.Next() + + // collect all frames except those from upper/db and runtime stack + if (strings.Contains(frame.Function, "upper/db") || strings.Contains(frame.Function, "/go/src/")) && !strings.Contains(frame.Function, "test") { + discardedFrames = append(discardedFrames, frame) + } else { + collectedFrames = append(collectedFrames, frame) + } + + if !more { + break + } + } + + if len(collectedFrames) < 1 { + return discardedFrames + } + + return collectedFrames +} diff --git a/logger_test.go b/logger_test.go new file mode 100644 index 0000000..5295d97 --- /dev/null +++ b/logger_test.go @@ -0,0 +1,11 @@ +package mydb + +import ( + "errors" + "testing" +) + +func TestLogger(t *testing.T) { + err := errors.New("fake error") + LC().Error(err) +} diff --git a/marshal.go b/marshal.go new file mode 100644 index 0000000..b9dd64a --- /dev/null +++ b/marshal.go @@ -0,0 +1,16 @@ +package mydb + +// Marshaler is the interface implemented by struct fields that can transform +// themselves into values to be stored in a database. +type Marshaler interface { + // MarshalDB returns the internal database representation of the Go value. + MarshalDB() (interface{}, error) +} + +// Unmarshaler is the interface implemented by struct fields that can transform +// themselves from database values into Go values. +type Unmarshaler interface { + // UnmarshalDB receives an internal database representation of a value and + // transforms it into a Go value. + UnmarshalDB(interface{}) error +} diff --git a/mydb.go b/mydb.go new file mode 100644 index 0000000..2942fe8 --- /dev/null +++ b/mydb.go @@ -0,0 +1,50 @@ +// Package db (or tiglog/mydb) provides an agnostic data access layer to work with +// different databases. +// +// Install tiglog/mydb: +// +// go get git.hexq.cn/tiglog/mydb +// +// Usage +// +// package main +// +// import ( +// "log" +// +// "git.hexq.cn/tiglog/mydb/adapter/postgresql" // Imports the postgresql adapter. +// ) +// +// var settings = postgresql.ConnectionURL{ +// Database: `booktown`, +// Host: `demo.upper.io`, +// User: `demouser`, +// Password: `demop4ss`, +// } +// +// // Book represents a book. +// type Book struct { +// ID uint `db:"id"` +// Title string `db:"title"` +// AuthorID uint `db:"author_id"` +// SubjectID uint `db:"subject_id"` +// } +// +// func main() { +// sess, err := postgresql.Open(settings) +// if err != nil { +// log.Fatal(err) +// } +// defer sess.Close() +// +// var books []Book +// if err := sess.Collection("books").Find().OrderBy("title").All(&books); err != nil { +// log.Fatal(err) +// } +// +// log.Println("Books:") +// for _, book := range books { +// log.Printf("%q (ID: %d)\n", book.Title, book.ID) +// } +// } +package mydb diff --git a/raw.go b/raw.go new file mode 100644 index 0000000..44164fd --- /dev/null +++ b/raw.go @@ -0,0 +1,17 @@ +package mydb + +import "git.hexq.cn/tiglog/mydb/internal/adapter" + +// RawExpr represents a raw (non-filtered) expression. 
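+//
+// A raw expression may also carry bound arguments; a sketch (the SQL function
+// is illustrative only):
+//
+//	// SOUNDEX(?) with "Hello" bound as its argument
+//	Raw(`SOUNDEX(?)`, "Hello")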
+type RawExpr = adapter.RawExpr
+
+// Raw marks chunks of data as protected, so they pass directly to the query
+// without any filtering. Use with care.
+//
+// Example:
+//
+//	// SOUNDEX('Hello')
+//	Raw("SOUNDEX('Hello')")
func Raw(value string, args ...interface{}) *RawExpr {
+	return adapter.NewRawExpr(value, args)
+}
diff --git a/readme.adoc b/readme.adoc
new file mode 100644
index 0000000..0c91d29
--- /dev/null
+++ b/readme.adoc
@@ -0,0 +1,18 @@
+= About
+:author: tiglog
+:experimental:
+:toc: left
+:toclevels: 3
+:toc-title: Contents
+:sectnums:
+:icons: font
+:!webfonts:
+:autofit-option:
+:source-highlighter: rouge
+:rouge-style: github
+:source-linenums-option:
+:revdate: 2023-09-18
+:imagesdir: ./img
+
+
+Adapted from `upper/db`.
diff --git a/record.go b/record.go
new file mode 100644
index 0000000..20f3ead
--- /dev/null
+++ b/record.go
@@ -0,0 +1,62 @@
+package mydb
+
+// Record is the mapping between a concrete database schema and a Go value.
+type Record interface {
+	Store(sess Session) Store
+}
+
+// HasConstraints is an interface for records that defines a Constraints method
+// that returns the record's own constraints.
+type HasConstraints interface {
+	Constraints() Cond
+}
+
+// Validator is an interface for records that defines an (optional) Validate
+// method that is called before persisting a record (creating or updating). If
+// Validate returns an error the current operation is cancelled and rolled
+// back.
+type Validator interface {
+	Validate() error
+}
+
+// BeforeCreateHook is an interface for records that defines a BeforeCreate
+// method that is called before creating a record. If BeforeCreate returns an
+// error the create process is cancelled and rolled back.
+type BeforeCreateHook interface {
+	BeforeCreate(Session) error
+}
+
+// AfterCreateHook is an interface for records that defines an AfterCreate
+// method that is called after creating a record. If AfterCreate returns an
+// error the create process is cancelled and rolled back.
+type AfterCreateHook interface {
+	AfterCreate(Session) error
+}
+
+// BeforeUpdateHook is an interface for records that defines a BeforeUpdate
+// method that is called before updating a record. If BeforeUpdate returns an
+// error the update process is cancelled and rolled back.
+type BeforeUpdateHook interface {
+	BeforeUpdate(Session) error
+}
+
+// AfterUpdateHook is an interface for records that defines an AfterUpdate
+// method that is called after updating a record. If AfterUpdate returns an
+// error the update process is cancelled and rolled back.
+type AfterUpdateHook interface {
+	AfterUpdate(Session) error
+}
+
+// BeforeDeleteHook is an interface for records that defines a BeforeDelete
+// method that is called before removing a record. If BeforeDelete returns an
+// error the delete process is cancelled and rolled back.
+type BeforeDeleteHook interface {
+	BeforeDelete(Session) error
+}
+
+// AfterDeleteHook is an interface for records that defines an AfterDelete
+// method that is called after removing a record. If AfterDelete returns an
+// error the delete process is cancelled and rolled back.
+type AfterDeleteHook interface {
+	AfterDelete(Session) error
+}
diff --git a/result.go b/result.go
new file mode 100644
index 0000000..6b962a1
--- /dev/null
+++ b/result.go
@@ -0,0 +1,193 @@
+package mydb
+
+import (
+	"database/sql/driver"
+)
+
+// Result is an interface that defines methods for result sets.
+type Result interface {
+
+	// String returns the SQL statement to be used in the query.
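+	//
+	// A small sketch (collection name assumed for illustration):
+	//
+	//	res := sess.Collection("artist").Find()
+	//	log.Println(res.String()) // logs the SELECT statement behind res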
+	String() string
+
+	// Limit defines the maximum number of results for this set. It only has
+	// effect on `One()`, `All()` and `Next()`. A negative limit cancels any
+	// previous limit settings.
+	Limit(int) Result
+
+	// Offset ignores the first n results. It only has effect on `One()`, `All()`
+	// and `Next()`. A negative offset cancels any previous offset settings.
+	Offset(int) Result
+
+	// OrderBy receives one or more field names that define the order in which
+	// elements will be returned in a query. Field names may be prefixed with a
+	// minus sign (-) to indicate descending order; ascending order is used
+	// otherwise.
+	OrderBy(...interface{}) Result
+
+	// Select defines specific columns to be fetched on every row of the
+	// result set.
+	Select(...interface{}) Result
+
+	// And adds more filtering conditions on top of the existing constraints.
+	//
+	//	res := col.Find(...).And(...)
+	And(...interface{}) Result
+
+	// GroupBy is used to group results that have the same value in the same column
+	// or columns.
+	GroupBy(...interface{}) Result
+
+	// Delete deletes all items within the result set. `Offset()` and `Limit()`
+	// are not honoured by `Delete()`.
+	Delete() error
+
+	// Update modifies all items within the result set. `Offset()` and `Limit()`
+	// are not honoured by `Update()`.
+	Update(interface{}) error
+
+	// Count returns the number of items that match the set conditions.
+	// `Offset()` and `Limit()` are not honoured by `Count()`.
+	Count() (uint64, error)
+
+	// Exists returns true if at least one item in the result set exists, false
+	// otherwise.
+	Exists() (bool, error)
+
+	// Next fetches the next result within the result set and dumps it into the
+	// given pointer to struct or pointer to map. You must call
+	// `Close()` after you finish using `Next()`.
+	Next(ptrToStruct interface{}) bool
+
+	// Err returns the last error that has happened with the result set, nil
+	// otherwise.
+	Err() error
+
+	// One fetches the first result within the result set and dumps it into the
+	// given pointer to struct or pointer to map. The result set is automatically
+	// closed after picking the element, so there is no need to call Close()
+	// after using One().
+	One(ptrToStruct interface{}) error
+
+	// All fetches all results within the result set and dumps them into the
+	// given pointer to slice of maps or structs. The result set is
+	// automatically closed, so there is no need to call Close() after
+	// using All().
+	All(sliceOfStructs interface{}) error
+
+	// Paginate splits the results of the query into pages containing pageSize
+	// items. When using pagination previous settings for `Limit()` and
+	// `Offset()` are ignored. Page numbering starts at 1.
+	//
+	// Use `Page()` to define the specific page to get results from.
+	//
+	// Example:
+	//
+	//	r = q.Paginate(12)
+	//
+	// You can combine constraints and order settings with pagination:
+	//
+	// Example:
+	//
+	//	res := q.Where(conds).OrderBy("-id").Paginate(12)
+	//	err := res.Page(4).All(&items)
+	Paginate(pageSize uint) Result
+
+	// Page makes the result set return results only from the page identified by
+	// pageNumber. Page numbering starts from 1.
+	//
+	// Example:
+	//
+	//	r = q.Paginate(12).Page(4)
+	Page(pageNumber uint) Result
+
+	// Cursor defines the column that is going to be taken as basis for
+	// cursor-based pagination.
+	//
+	// Example:
+	//
+	//	a = q.Paginate(10).Cursor("id")
+	//	b = q.Paginate(12).Cursor("-id")
+	//
+	// You can set "" as cursorColumn to disable cursors.
+	Cursor(cursorColumn string) Result
+
+	// NextPage returns the next results page according to the cursor. It expects
+	// a cursorValue, which is the value the cursor column had on the last item
+	// of the current result set (lower bound).
+	//
+	// Example:
+	//
+	//	cursor = q.Paginate(12).Cursor("id")
+	//	res = cursor.NextPage(items[len(items)-1].ID)
+	//
+	// Note that `NextPage()` requires a cursor; any column with an absolute
+	// order (given two values one always precedes the other) can be used as a
+	// cursor.
+	//
+	// You can define the pagination order and add constraints to your result:
+	//
+	//	cursor = q.Where(...).OrderBy("id").Paginate(10).Cursor("id")
+	//	res = cursor.NextPage(lowerBound)
+	NextPage(cursorValue interface{}) Result
+
+	// PrevPage returns the previous results page according to the cursor. It
+	// expects a cursorValue, which is the value the cursor column had on the
+	// first item of the current result set.
+	//
+	// Example:
+	//
+	//	current = current.PrevPage(items[0].ID)
+	//
+	// Note that PrevPage requires a cursor; any column with an absolute order
+	// (given two values one always precedes the other) can be used as a cursor.
+	//
+	// You can define the pagination order and add constraints to your result:
+	//
+	//	cursor = q.Where(...).OrderBy("id").Paginate(10).Cursor("id")
+	//	res = cursor.PrevPage(upperBound)
+	PrevPage(cursorValue interface{}) Result
+
+	// TotalPages returns the total number of pages the result set could produce.
+	// If no pagination parameters have been set this value equals 1.
+	TotalPages() (uint, error)
+
+	// TotalEntries returns the total number of matching items in the result set.
+	TotalEntries() (uint64, error)
+
+	// Close closes the result set and frees all locked resources.
+	Close() error
+}
+
+// InsertResult provides information about an insert operation.
+type InsertResult interface {
+	// ID returns the ID of the newly inserted record.
+	ID() ID
+}
+
+type insertResult struct {
+	id interface{}
+}
+
+func (r *insertResult) ID() ID {
+	return r.id
+}
+
+// ConstraintValue satisfies adapter.ConstraintValuer.
+func (r *insertResult) ConstraintValue() interface{} {
+	return r.id
+}
+
+// Value satisfies driver.Valuer.
+func (r *insertResult) Value() (driver.Value, error) {
+	return r.id, nil
+}
+
+// NewInsertResult creates an InsertResult.
+func NewInsertResult(id interface{}) InsertResult {
+	return &insertResult{id: id}
+}
+
+// ID represents a record ID.
+type ID interface{}
+
+var _ = driver.Valuer(&insertResult{})
diff --git a/session.go b/session.go
new file mode 100644
index 0000000..a49192a
--- /dev/null
+++ b/session.go
@@ -0,0 +1,78 @@
+package mydb
+
+import (
+	"context"
+	"database/sql"
+)
+
+// Session is an interface that defines methods for database adapters.
+type Session interface {
+	// ConnectionURL returns the DSN that was used to set up the adapter.
+	ConnectionURL() ConnectionURL
+
+	// Name returns the name of the database.
+	Name() string
+
+	// Ping returns an error if the DBMS could not be reached.
+	Ping() error
+
+	// Collection receives a table name and returns a collection reference. The
+	// information retrieved from a collection is cached.
+	Collection(name string) Collection
+
+	// Collections returns collection references for all non-system tables in
+	// the database.
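+	//
+	// A minimal sketch (assuming Collection's Name accessor):
+	//
+	//	collections, err := sess.Collections()
+	//	if err != nil {
+	//		return err
+	//	}
+	//	for _, col := range collections {
+	//		log.Println(col.Name())
+	//	}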
+	Collections() ([]Collection, error)
+
+	// Save creates or updates a record.
+	Save(record Record) error
+
+	// Get retrieves a record that matches the given condition.
+	Get(record Record, cond interface{}) error
+
+	// Delete deletes a record.
+	Delete(record Record) error
+
+	// Reset resets all the caching mechanisms the adapter is using.
+	Reset()
+
+	// Close terminates the currently active connection to the DBMS and clears
+	// all caches.
+	Close() error
+
+	// Driver returns the underlying driver of the adapter as an interface.
+	//
+	// In order to actually use the driver, the `interface{}` value needs to be
+	// cast to the appropriate type.
+	//
+	// Example:
+	//	internalSQLDriver := sess.Driver().(*sql.DB)
+	Driver() interface{}
+
+	// SQL returns a special interface for SQL databases.
+	SQL() SQL
+
+	// Tx creates a transaction block on the default database context and passes
+	// it to the function fn. If fn returns no error the transaction is
+	// committed, else the transaction is rolled back. After being committed or
+	// rolled back the transaction is closed automatically.
+	Tx(fn func(sess Session) error) error
+
+	// TxContext creates a transaction block on the given context and passes it to
+	// the function fn. If fn returns no error the transaction is committed, else
+	// the transaction is rolled back. After being committed or rolled back the
+	// transaction is closed automatically.
+	TxContext(ctx context.Context, fn func(sess Session) error, opts *sql.TxOptions) error
+
+	// Context returns the context used as default for queries on this session
+	// and for new transactions. If no context has been set, a default
+	// context.Background() is returned.
+	Context() context.Context
+
+	// WithContext returns the same session on a different default context. The
+	// session is identical to the original one in all ways except for the
+	// context.
+	WithContext(ctx context.Context) Session
+
+	Settings
+}
diff --git a/settings.go b/settings.go
new file mode 100644
index 0000000..3bb2172
--- /dev/null
+++ b/settings.go
@@ -0,0 +1,179 @@
+package mydb
+
+import (
+	"sync"
+	"sync/atomic"
+	"time"
+)
+
+// Settings defines methods to get or set configuration values.
+type Settings interface {
+	// SetPreparedStatementCache enables or disables the prepared statement
+	// cache.
+	SetPreparedStatementCache(bool)
+
+	// PreparedStatementCacheEnabled returns true if the prepared statement cache
+	// is enabled, false otherwise.
+	PreparedStatementCacheEnabled() bool
+
+	// SetConnMaxLifetime sets the default maximum amount of time a connection
+	// may be reused.
+	SetConnMaxLifetime(time.Duration)
+
+	// ConnMaxLifetime returns the default maximum amount of time a connection
+	// may be reused.
+	ConnMaxLifetime() time.Duration
+
+	// SetConnMaxIdleTime sets the default maximum amount of time a connection
+	// may remain idle.
+	SetConnMaxIdleTime(time.Duration)
+
+	// ConnMaxIdleTime returns the default maximum amount of time a connection
+	// may remain idle.
+	ConnMaxIdleTime() time.Duration
+
+	// SetMaxIdleConns sets the default maximum number of connections in the idle
+	// connection pool.
+	SetMaxIdleConns(int)
+
+	// MaxIdleConns returns the default maximum number of connections in the idle
+	// connection pool.
+	MaxIdleConns() int
+
+	// SetMaxOpenConns sets the default maximum number of open connections to the
+	// database.
+	SetMaxOpenConns(int)
+
+	// MaxOpenConns returns the default maximum number of open connections to the
+	// database.
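+	//
+	// These pool settings mirror their database/sql counterparts; a short
+	// sketch:
+	//
+	//	sess.SetMaxOpenConns(8)
+	//	sess.SetMaxIdleConns(4)
+	//	sess.SetConnMaxLifetime(time.Hour)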
+	MaxOpenConns() int
+
+	// SetMaxTransactionRetries sets the number of times a transaction can
+	// be retried.
+	SetMaxTransactionRetries(int)
+
+	// MaxTransactionRetries returns the maximum number of times a
+	// transaction can be retried.
+	MaxTransactionRetries() int
+}
+
+type settings struct {
+	sync.RWMutex
+
+	preparedStatementCacheEnabled uint32
+
+	connMaxLifetime time.Duration
+	connMaxIdleTime time.Duration
+	maxOpenConns    int
+	maxIdleConns    int
+
+	maxTransactionRetries int
+}
+
+func (c *settings) binaryOption(opt *uint32) bool {
+	return atomic.LoadUint32(opt) == 1
+}
+
+func (c *settings) setBinaryOption(opt *uint32, value bool) {
+	if value {
+		atomic.StoreUint32(opt, 1)
+		return
+	}
+	atomic.StoreUint32(opt, 0)
+}
+
+func (c *settings) SetPreparedStatementCache(value bool) {
+	c.setBinaryOption(&c.preparedStatementCacheEnabled, value)
+}
+
+func (c *settings) PreparedStatementCacheEnabled() bool {
+	return c.binaryOption(&c.preparedStatementCacheEnabled)
+}
+
+func (c *settings) SetConnMaxLifetime(t time.Duration) {
+	c.Lock()
+	c.connMaxLifetime = t
+	c.Unlock()
+}
+
+func (c *settings) ConnMaxLifetime() time.Duration {
+	c.RLock()
+	defer c.RUnlock()
+	return c.connMaxLifetime
+}
+
+func (c *settings) SetConnMaxIdleTime(t time.Duration) {
+	c.Lock()
+	c.connMaxIdleTime = t
+	c.Unlock()
+}
+
+func (c *settings) ConnMaxIdleTime() time.Duration {
+	c.RLock()
+	defer c.RUnlock()
+	return c.connMaxIdleTime
+}
+
+func (c *settings) SetMaxIdleConns(n int) {
+	c.Lock()
+	c.maxIdleConns = n
+	c.Unlock()
+}
+
+func (c *settings) MaxIdleConns() int {
+	c.RLock()
+	defer c.RUnlock()
+	return c.maxIdleConns
+}
+
+func (c *settings) SetMaxTransactionRetries(n int) {
+	c.Lock()
+	c.maxTransactionRetries = n
+	c.Unlock()
+}
+
+func (c *settings) MaxTransactionRetries() int {
+	c.RLock()
+	defer c.RUnlock()
+	if c.maxTransactionRetries < 1 {
+		return 1
+	}
+	return c.maxTransactionRetries
+}
+
+func (c *settings) SetMaxOpenConns(n int) {
+	c.Lock()
+	c.maxOpenConns = n
+	c.Unlock()
+}
+
+func (c *settings) MaxOpenConns() int {
+	c.RLock()
+	defer c.RUnlock()
+	return c.maxOpenConns
+}
+
+// NewSettings returns a new settings value prefilled with the current default
+// settings.
+func NewSettings() Settings {
+	def := DefaultSettings.(*settings)
+	return &settings{
+		preparedStatementCacheEnabled: def.preparedStatementCacheEnabled,
+		connMaxLifetime:               def.connMaxLifetime,
+		connMaxIdleTime:               def.connMaxIdleTime,
+		maxIdleConns:                  def.maxIdleConns,
+		maxOpenConns:                  def.maxOpenConns,
+		maxTransactionRetries:         def.maxTransactionRetries,
+	}
+}
+
+// DefaultSettings provides default global configuration settings for database
+// sessions.
+var DefaultSettings Settings = &settings{
+	preparedStatementCacheEnabled: 0,
+	connMaxLifetime:               time.Duration(0),
+	connMaxIdleTime:               time.Duration(0),
+	maxIdleConns:                  10,
+	maxOpenConns:                  0,
+	maxTransactionRetries:         1,
+}
diff --git a/sql.go b/sql.go
new file mode 100644
index 0000000..f6934a1
--- /dev/null
+++ b/sql.go
@@ -0,0 +1,190 @@
+package mydb
+
+import (
+	"context"
+	"database/sql"
+)
+
+// SQL defines methods that can be used to build a SQL query with chainable
+// method calls.
+//
+// Queries are immutable, so every call to any method will return a new
+// pointer. If you want to build a query using variables you need to reassign
+// them, like this:
+//
+//	a = builder.Select("name").From("foo") // "a" is created
+//
+//	a.Where(...) // No effect, the value returned from Where is ignored.
+//
+//	a = a.Where(...) // "a" is reassigned and points to a different address.
+type SQL interface {
+
+	// Select initializes and returns a Selector; it accepts column names as
+	// parameters.
+	//
+	// The returned Selector does not initially point to any table; a call to
+	// From() is required after Select() to complete a valid query.
+	//
+	// Example:
+	//
+	//	q := sqlbuilder.Select("first_name", "last_name").From("people").Where(...)
+	Select(columns ...interface{}) Selector
+
+	// SelectFrom creates a Selector that selects all columns (like SELECT *)
+	// from the given table.
+	//
+	// Example:
+	//
+	//	q := sqlbuilder.SelectFrom("people").Where(...)
+	SelectFrom(table ...interface{}) Selector
+
+	// InsertInto prepares and returns an Inserter targeted at the given table.
+	//
+	// Example:
+	//
+	//	q := sqlbuilder.InsertInto("books").Columns(...).Values(...)
+	InsertInto(table string) Inserter
+
+	// DeleteFrom prepares a Deleter targeted at the given table.
+	//
+	// Example:
+	//
+	//	q := sqlbuilder.DeleteFrom("tasks").Where(...)
+	DeleteFrom(table string) Deleter
+
+	// Update prepares and returns an Updater targeted at the given table.
+	//
+	// Example:
+	//
+	//	q := sqlbuilder.Update("profile").Set(...).Where(...)
+	Update(table string) Updater
+
+	// Exec executes a SQL query that does not return any rows, like sql.Exec.
+	// Queries can be either strings or upper-db statements.
+	//
+	// Example:
+	//
+	//	sqlbuilder.Exec(`INSERT INTO books (title) VALUES("La Ciudad y los Perros")`)
+	Exec(query interface{}, args ...interface{}) (sql.Result, error)
+
+	// ExecContext executes a SQL query that does not return any rows, like sql.ExecContext.
+	// Queries can be either strings or upper-db statements.
+	//
+	// Example:
+	//
+	//	sqlbuilder.ExecContext(ctx, `INSERT INTO books (title) VALUES(?)`, "La Ciudad y los Perros")
+	ExecContext(ctx context.Context, query interface{}, args ...interface{}) (sql.Result, error)
+
+	// Prepare creates a prepared statement for later queries or executions. The
+	// caller must call the statement's Close method when the statement is no
+	// longer needed.
+	Prepare(query interface{}) (*sql.Stmt, error)
+
+	// PrepareContext creates a prepared statement on the given context for later
+	// queries or executions. The caller must call the statement's Close method
+	// when the statement is no longer needed.
+	PrepareContext(ctx context.Context, query interface{}) (*sql.Stmt, error)
+
+	// Query executes a SQL query that returns rows, like sql.Query. Queries can
+	// be either strings or upper-db statements.
+	//
+	// Example:
+	//
+	//	sqlbuilder.Query(`SELECT * FROM people WHERE name = "Mateo"`)
+	Query(query interface{}, args ...interface{}) (*sql.Rows, error)
+
+	// QueryContext executes a SQL query that returns rows, like
+	// sql.QueryContext. Queries can be either strings or upper-db statements.
+	//
+	// Example:
+	//
+	//	sqlbuilder.QueryContext(ctx, `SELECT * FROM people WHERE name = ?`, "Mateo")
+	QueryContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Rows, error)
+
+	// QueryRow executes a SQL query that returns one row, like sql.QueryRow.
+	// Queries can be either strings or upper-db statements.
+	//
+	// Example:
+	//
+	//	sqlbuilder.QueryRow(`SELECT * FROM people WHERE name = "Haruki" AND last_name = "Murakami" LIMIT 1`)
+	QueryRow(query interface{}, args ...interface{}) (*sql.Row, error)
+
+	// QueryRowContext executes a SQL query that returns one row, like
+	// sql.QueryRowContext. Queries can be either strings or upper-db statements.
+	//
+	// Example:
+	//
+	//	sqlbuilder.QueryRowContext(ctx, `SELECT * FROM people WHERE name = "Haruki" AND last_name = "Murakami" LIMIT 1`)
+	QueryRowContext(ctx context.Context, query interface{}, args ...interface{}) (*sql.Row, error)
+
+	// Iterator executes a SQL query that returns rows and creates an Iterator
+	// with it.
+	//
+	// Example:
+	//
+	//	sqlbuilder.Iterator(`SELECT * FROM people WHERE name LIKE "M%"`)
+	Iterator(query interface{}, args ...interface{}) Iterator
+
+	// IteratorContext executes a SQL query that returns rows and creates an Iterator
+	// with it.
+	//
+	// Example:
+	//
+	//	sqlbuilder.IteratorContext(ctx, `SELECT * FROM people WHERE name LIKE "M%"`)
+	IteratorContext(ctx context.Context, query interface{}, args ...interface{}) Iterator
+
+	// NewIterator converts a *sql.Rows value into an Iterator.
+	NewIterator(rows *sql.Rows) Iterator
+
+	// NewIteratorContext converts a *sql.Rows value into an Iterator.
+	NewIteratorContext(ctx context.Context, rows *sql.Rows) Iterator
+}
+
+// SQLExecer provides methods for executing statements that do not return
+// results.
+type SQLExecer interface {
+	// Exec executes a statement and returns sql.Result.
+	Exec() (sql.Result, error)
+
+	// ExecContext executes a statement and returns sql.Result.
+	ExecContext(context.Context) (sql.Result, error)
+}
+
+// SQLPreparer provides the Prepare and PrepareContext methods for creating
+// prepared statements.
+type SQLPreparer interface {
+	// Prepare creates a prepared statement.
+	Prepare() (*sql.Stmt, error)
+
+	// PrepareContext creates a prepared statement.
+	PrepareContext(context.Context) (*sql.Stmt, error)
+}
+
+// SQLGetter provides methods for executing statements that return results.
+type SQLGetter interface {
+	// Query returns *sql.Rows.
+	Query() (*sql.Rows, error)
+
+	// QueryContext returns *sql.Rows.
+	QueryContext(context.Context) (*sql.Rows, error)
+
+	// QueryRow returns only one row.
+	QueryRow() (*sql.Row, error)
+
+	// QueryRowContext returns only one row.
+	QueryRowContext(ctx context.Context) (*sql.Row, error)
+}
+
+// SQLEngine represents a SQL engine that can execute SQL queries. This is
+// compatible with *sql.DB.
+type SQLEngine interface {
+	Exec(string, ...interface{}) (sql.Result, error)
+	Prepare(string) (*sql.Stmt, error)
+	Query(string, ...interface{}) (*sql.Rows, error)
+	QueryRow(string, ...interface{}) *sql.Row
+
+	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
+	PrepareContext(context.Context, string) (*sql.Stmt, error)
+	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
+	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
+}
diff --git a/store.go b/store.go
new file mode 100644
index 0000000..4225df2
--- /dev/null
+++ b/store.go
@@ -0,0 +1,36 @@
+package mydb
+
+// Store represents a data store.
+type Store interface {
+	Collection
+}
+
+// StoreSaver is an interface for data stores that defines a Save method that
+// has the task of persisting a record.
+type StoreSaver interface {
+	Save(record Record) error
+}
+
+// StoreCreator is an interface for data stores that defines a Create method
+// that has the task of creating a new record.
+type StoreCreator interface {
+	Create(record Record) error
+}
+
+// StoreDeleter is an interface for data stores that defines a Delete method
+// that has the task of removing a record.
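+//
+// A sketch under assumed types (Book implements Record); Session.Delete is
+// the usual high-level entry point backed by this interface:
+//
+//	var book Book
+//	if err := sess.Get(&book, Cond{"id": 1}); err != nil {
+//		return err
+//	}
+//	if err := sess.Delete(&book); err != nil {
+//		return err
+//	}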
+type StoreDeleter interface {
+	Delete(record Record) error
+}
+
+// StoreUpdater is an interface for data stores that defines an Update method
+// that has the task of updating a record.
+type StoreUpdater interface {
+	Update(record Record) error
+}
+
+// StoreGetter is an interface for data stores that defines a Get method that
+// has the task of retrieving a record.
+type StoreGetter interface {
+	Get(record Record, id interface{}) error
+}
diff --git a/union.go b/union.go
new file mode 100644
index 0000000..d97b3a9
--- /dev/null
+++ b/union.go
@@ -0,0 +1,41 @@
+package mydb
+
+import "git.hexq.cn/tiglog/mydb/internal/adapter"
+
+// OrExpr represents an expression joined by a logical disjunction (OR).
+type OrExpr struct {
+	*adapter.LogicalExprGroup
+}
+
+// Or adds more expressions to the group.
+func (o *OrExpr) Or(orConds ...LogicalExpr) *OrExpr {
+	var fn func(*[]LogicalExpr) error
+	if len(orConds) > 0 {
+		fn = func(in *[]LogicalExpr) error {
+			*in = append(*in, orConds...)
+			return nil
+		}
+	}
+	return &OrExpr{o.LogicalExprGroup.Frame(fn)}
+}
+
+// Empty returns true if the group has no conditions.
+func (o *OrExpr) Empty() bool {
+	return o.LogicalExprGroup.Empty()
+}
+
+// Or joins conditions under logical disjunction. Conditions can be represented
+// by `db.Cond{}`, `db.Or()` or `db.And()`.
+//
+// Example:
+//
+//	// year = 2012 OR year = 1987
+//	db.Or(
+//		db.Cond{"year": 2012},
+//		db.Cond{"year": 1987},
+//	)
+func Or(conds ...LogicalExpr) *OrExpr {
+	return &OrExpr{adapter.NewLogicalExprGroup(adapter.LogicalOperatorOr, defaultJoin(conds...)...)}
+}
+
+var _ = adapter.LogicalExpr(&OrExpr{})
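+
+// And and Or can be nested; an illustrative sketch:
+//
+//	// year = 1987 OR (name = "Peter" AND last_name = "Parker")
+//	db.Or(
+//		db.Cond{"year": 1987},
+//		db.And(
+//			db.Cond{"name": "Peter"},
+//			db.Cond{"last_name": "Parker"},
+//		),
+//	)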