first commit
This commit is contained in: commit 4ac92c26b0

.gitignore (vendored, new file, 4 lines)
@@ -0,0 +1,4 @@
*.sw?
*.db
*.tmp
generated_*.go
Makefile (new file, 39 lines)
@@ -0,0 +1,39 @@
SHELL ?= /bin/bash

PARALLEL_FLAGS ?= --halt-on-error 2 --jobs=2 -v -u

TEST_FLAGS ?=

UPPER_DB_LOG ?= WARN

export TEST_FLAGS
export PARALLEL_FLAGS
export UPPER_DB_LOG

test: go-test-internal test-adapters

benchmark: go-benchmark-internal

go-benchmark-%:
	go test -v -benchtime=500ms -bench=. ./$*/...

go-test-%:
	go test -v ./$*/...

test-adapters: \
	test-adapter-postgresql \
	# test-adapter-mysql \
	# test-adapter-sqlite \
	# test-adapter-mongo

test-adapter-%:
	($(MAKE) -C adapter/$* test-extended || exit 1)

test-generic:
	export TEST_FLAGS="-run TestGeneric"; \
	$(MAKE) test-adapters

goimports:
	for FILE in $$(find -name "*.go" | grep -v vendor); do \
		goimports -w $$FILE; \
	done
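Reading the pattern rules above: `make go-test-internal` expands to `go test -v ./internal/...`, and `make test-adapter-postgresql` runs `$(MAKE) -C adapter/postgresql test-extended`. Only the PostgreSQL adapter is wired into `test-adapters` in this commit; the MySQL, SQLite and Mongo entries are still commented out.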
adapter.go (new file, 54 lines)
@@ -0,0 +1,54 @@
package mydb

import (
	"fmt"
	"sync"
)

var (
	adapterMap   = make(map[string]Adapter)
	adapterMapMu sync.RWMutex
)

// Adapter defines the interface that a database adapter must satisfy.
type Adapter interface {
	Open(ConnectionURL) (Session, error)
}

type missingAdapter struct {
	name string
}

func (ma *missingAdapter) Open(ConnectionURL) (Session, error) {
	return nil, fmt.Errorf("mydb: Missing adapter %q, did you forget to import it?", ma.name)
}

// RegisterAdapter registers a generic database adapter.
func RegisterAdapter(name string, adapter Adapter) {
	adapterMapMu.Lock()
	defer adapterMapMu.Unlock()

	if name == "" {
		panic(`Missing adapter name`)
	}
	if _, ok := adapterMap[name]; ok {
		panic(`db.RegisterAdapter() called twice for adapter: ` + name)
	}
	adapterMap[name] = adapter
}

// LookupAdapter returns a previously registered adapter by name.
func LookupAdapter(name string) Adapter {
	adapterMapMu.RLock()
	defer adapterMapMu.RUnlock()

	if adapter, ok := adapterMap[name]; ok {
		return adapter
	}
	return &missingAdapter{name: name}
}

// Open attempts to establish a connection with a database.
func Open(adapterName string, settings ConnectionURL) (Session, error) {
	return LookupAdapter(adapterName).Open(settings)
}
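For orientation (not part of the commit): an adapter package registers itself by name, typically from an init function, and callers go through mydb.Open. A minimal sketch with a hypothetical "example" adapter; Session and ConnectionURL are the interfaces declared elsewhere in this package:

package example // hypothetical adapter package

import "git.hexq.cn/tiglog/mydb"

type exampleAdapter struct{}

// Open satisfies mydb.Adapter; a real adapter would dial its backend here.
func (exampleAdapter) Open(settings mydb.ConnectionURL) (mydb.Session, error) {
	return nil, nil
}

func init() {
	// Registration panics if the name is empty or already taken.
	mydb.RegisterAdapter("example", exampleAdapter{})
}

// A caller then selects the adapter by name:
//
//	sess, err := mydb.Open("example", settings)
//
// LookupAdapter itself never fails: an unknown name yields a missingAdapter
// whose Open returns the descriptive error that mydb.Open surfaces.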
adapter/mongo/Makefile (new file, 43 lines)
@@ -0,0 +1,43 @@
SHELL ?= bash

MONGO_VERSION ?= 4
MONGO_SUPPORTED ?= $(MONGO_VERSION) 3
PROJECT ?= upper_mongo_$(MONGO_VERSION)

DB_HOST ?= 127.0.0.1
DB_PORT ?= 27017

DB_NAME ?= admin
DB_USERNAME ?= upperio_user
DB_PASSWORD ?= upperio//s3cr37

TEST_FLAGS ?=
PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1

export MONGO_VERSION

export DB_HOST
export DB_NAME
export DB_PASSWORD
export DB_PORT
export DB_USERNAME

export TEST_FLAGS

test:
	go test -v -failfast -race -timeout 20m $(TEST_FLAGS)

test-no-race:
	go test -v -failfast $(TEST_FLAGS)

server-up: server-down
	docker-compose -p $(PROJECT) up -d && \
	sleep 10

server-down:
	docker-compose -p $(PROJECT) down

test-extended:
	parallel $(PARALLEL_FLAGS) \
		"MONGO_VERSION={} DB_PORT=\$$((27017+{#})) $(MAKE) server-up test server-down" ::: \
		$(MONGO_SUPPORTED)
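In `test-extended`, GNU parallel substitutes each supported MongoDB version for `{}` and its one-based job slot number for `{#}`, so every version gets its own compose project and its own host port (27017+1, 27017+2, ...) while `server-up`, `test` and `server-down` run against it.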
adapter/mongo/README.md (new file, 4 lines)
@@ -0,0 +1,4 @@
# MongoDB adapter for upper/db

Please read the full docs, acknowledgements and examples at
[https://upper.io/v4/adapter/mongo/](https://upper.io/v4/adapter/mongo/).
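The README defers to the upper.io docs. Purely as a hedged sketch (not part of this commit), this is roughly what using the adapter looks like, modeled on the tests in this directory; the connection values are placeholders:

package main

import (
	"fmt"
	"log"

	"git.hexq.cn/tiglog/mydb"
	"git.hexq.cn/tiglog/mydb/adapter/mongo"
)

func main() {
	// Placeholder credentials; see docker-compose.yml and the Makefile in
	// this directory for the values the test suite uses.
	settings := mongo.ConnectionURL{
		Database: "upperio",
		Host:     "127.0.0.1:27017",
		User:     "upperio_user",
		Password: "upperio//s3cr37",
	}

	sess, err := mongo.Open(settings)
	if err != nil {
		log.Fatal(err)
	}
	defer sess.Close()

	// Insert accepts maps or structs; Find takes mydb.Cond conditions.
	artists := sess.Collection("artist")
	if _, err := artists.Insert(map[string]string{"name": "Ozzie"}); err != nil {
		log.Fatal(err)
	}

	total, err := artists.Find(mydb.Cond{"name": "Ozzie"}).Count()
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println("matches:", total)
}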
adapter/mongo/collection.go (new file, 346 lines)
@@ -0,0 +1,346 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"reflect"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/adapter"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
)
|
||||
|
||||
// Collection represents a mongodb collection.
|
||||
type Collection struct {
|
||||
parent *Source
|
||||
collection *mgo.Collection
|
||||
}
|
||||
|
||||
var (
|
||||
// idCache should be a struct if we're going to cache more than just
|
||||
// _id field here
|
||||
idCache = make(map[reflect.Type]string)
|
||||
idCacheMutex sync.RWMutex
|
||||
)
|
||||
|
||||
// Find creates a result set with the given conditions.
|
||||
func (col *Collection) Find(terms ...interface{}) mydb.Result {
|
||||
fields := []string{"*"}
|
||||
|
||||
conditions := col.compileQuery(terms...)
|
||||
|
||||
res := &result{}
|
||||
res = res.frame(func(r *resultQuery) error {
|
||||
r.c = col
|
||||
r.conditions = conditions
|
||||
r.fields = fields
|
||||
return nil
|
||||
})
|
||||
|
||||
return res
|
||||
}
|
||||
|
||||
var comparisonOperators = map[adapter.ComparisonOperator]string{
|
||||
adapter.ComparisonOperatorEqual: "$eq",
|
||||
adapter.ComparisonOperatorNotEqual: "$ne",
|
||||
|
||||
adapter.ComparisonOperatorLessThan: "$lt",
|
||||
adapter.ComparisonOperatorGreaterThan: "$gt",
|
||||
|
||||
adapter.ComparisonOperatorLessThanOrEqualTo: "$lte",
|
||||
adapter.ComparisonOperatorGreaterThanOrEqualTo: "$gte",
|
||||
|
||||
adapter.ComparisonOperatorIn: "$in",
|
||||
adapter.ComparisonOperatorNotIn: "$nin",
|
||||
}
|
||||
|
||||
func compare(field string, cmp *adapter.Comparison) (string, interface{}) {
|
||||
op := cmp.Operator()
|
||||
value := cmp.Value()
|
||||
|
||||
switch op {
|
||||
case adapter.ComparisonOperatorEqual:
|
||||
return field, value
|
||||
case adapter.ComparisonOperatorBetween:
|
||||
values := value.([]interface{})
|
||||
return field, bson.M{
|
||||
"$gte": values[0],
|
||||
"$lte": values[1],
|
||||
}
|
||||
case adapter.ComparisonOperatorNotBetween:
|
||||
values := value.([]interface{})
|
||||
return "$or", []bson.M{
|
||||
{field: bson.M{"$gt": values[1]}},
|
||||
{field: bson.M{"$lt": values[0]}},
|
||||
}
|
||||
case adapter.ComparisonOperatorIs:
|
||||
if value == nil {
|
||||
return field, bson.M{"$exists": false}
|
||||
}
|
||||
return field, bson.M{"$eq": value}
|
||||
case adapter.ComparisonOperatorIsNot:
|
||||
if value == nil {
|
||||
return field, bson.M{"$exists": true}
|
||||
}
|
||||
return field, bson.M{"$ne": value}
|
||||
case adapter.ComparisonOperatorRegExp, adapter.ComparisonOperatorLike:
|
||||
return field, bson.RegEx{Pattern: value.(string), Options: ""}
|
||||
case adapter.ComparisonOperatorNotRegExp, adapter.ComparisonOperatorNotLike:
|
||||
return field, bson.M{"$not": bson.RegEx{Pattern: value.(string), Options: ""}}
|
||||
}
|
||||
|
||||
if cmpOp, ok := comparisonOperators[op]; ok {
|
||||
return field, bson.M{
|
||||
cmpOp: value,
|
||||
}
|
||||
}
|
||||
|
||||
panic(fmt.Sprintf("Unsupported operator %v", op))
|
||||
}
|
||||
|
||||
// compileStatement transforms conditions into something *mgo.Session can
|
||||
// understand.
|
||||
func compileStatement(cond mydb.Cond) bson.M {
|
||||
conds := bson.M{}
|
||||
|
||||
// Walking over conditions
|
||||
for fieldI, value := range cond {
|
||||
field := strings.TrimSpace(fmt.Sprintf("%v", fieldI))
|
||||
|
||||
if cmp, ok := value.(*mydb.Comparison); ok {
|
||||
k, v := compare(field, cmp.Comparison)
|
||||
conds[k] = v
|
||||
continue
|
||||
}
|
||||
|
||||
var op string
|
||||
chunks := strings.SplitN(field, ` `, 2)
|
||||
|
||||
if len(chunks) > 1 {
|
||||
switch chunks[1] {
|
||||
case `IN`:
|
||||
op = `$in`
|
||||
case `NOT IN`:
|
||||
op = `$nin`
|
||||
case `>`:
|
||||
op = `$gt`
|
||||
case `<`:
|
||||
op = `$lt`
|
||||
case `<=`:
|
||||
op = `$lte`
|
||||
case `>=`:
|
||||
op = `$gte`
|
||||
default:
|
||||
op = chunks[1]
|
||||
}
|
||||
}
|
||||
field = chunks[0]
|
||||
|
||||
if op == "" {
|
||||
conds[field] = value
|
||||
} else {
|
||||
conds[field] = bson.M{op: value}
|
||||
}
|
||||
}
|
||||
|
||||
return conds
|
||||
}
|
||||
|
||||
// compileConditions compiles terms into something *mgo.Session can
|
||||
// understand.
|
||||
func (col *Collection) compileConditions(term interface{}) interface{} {
|
||||
|
||||
switch t := term.(type) {
|
||||
case []interface{}:
|
||||
values := []interface{}{}
|
||||
for i := range t {
|
||||
value := col.compileConditions(t[i])
|
||||
if value != nil {
|
||||
values = append(values, value)
|
||||
}
|
||||
}
|
||||
if len(values) > 0 {
|
||||
return values
|
||||
}
|
||||
case mydb.Cond:
|
||||
return compileStatement(t)
|
||||
case adapter.LogicalExpr:
|
||||
values := []interface{}{}
|
||||
|
||||
for _, s := range t.Expressions() {
|
||||
values = append(values, col.compileConditions(s))
|
||||
}
|
||||
|
||||
var op string
|
||||
switch t.Operator() {
|
||||
case adapter.LogicalOperatorOr:
|
||||
op = `$or`
|
||||
default:
|
||||
op = `$and`
|
||||
}
|
||||
|
||||
return bson.M{op: values}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// compileQuery compiles terms into something that *mgo.Session can
|
||||
// understand.
|
||||
func (col *Collection) compileQuery(terms ...interface{}) interface{} {
|
||||
compiled := col.compileConditions(terms)
|
||||
if compiled == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
conditions := compiled.([]interface{})
|
||||
if len(conditions) == 1 {
|
||||
return conditions[0]
|
||||
}
|
||||
// this should be correct.
|
||||
// query = map[string]interface{}{"$and": conditions}
|
||||
|
||||
// attempt to work around https://jira.mongodb.org/browse/SERVER-4572
|
||||
mapped := map[string]interface{}{}
|
||||
for _, v := range conditions {
|
||||
for kk := range v.(map[string]interface{}) {
|
||||
mapped[kk] = v.(map[string]interface{})[kk]
|
||||
}
|
||||
}
|
||||
|
||||
return mapped
|
||||
}
|
||||
|
||||
// Name returns the name of the table or tables that form the collection.
|
||||
func (col *Collection) Name() string {
|
||||
return col.collection.Name
|
||||
}
|
||||
|
||||
// Truncate deletes all rows from the table.
|
||||
func (col *Collection) Truncate() error {
|
||||
err := col.collection.DropCollection()
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (col *Collection) Session() mydb.Session {
|
||||
return col.parent
|
||||
}
|
||||
|
||||
func (col *Collection) Count() (uint64, error) {
|
||||
return col.Find().Count()
|
||||
}
|
||||
|
||||
func (col *Collection) InsertReturning(item interface{}) error {
|
||||
return mydb.ErrUnsupported
|
||||
}
|
||||
|
||||
func (col *Collection) UpdateReturning(item interface{}) error {
|
||||
return mydb.ErrUnsupported
|
||||
}
|
||||
|
||||
// Insert inserts a record (map or struct) into the collection.
|
||||
func (col *Collection) Insert(item interface{}) (mydb.InsertResult, error) {
|
||||
var err error
|
||||
|
||||
id := getID(item)
|
||||
|
||||
if col.parent.versionAtLeast(2, 6, 0, 0) {
|
||||
// this breaks MongoDb older than 2.6
|
||||
if _, err = col.collection.Upsert(bson.M{"_id": id}, item); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Allocating a new ID.
|
||||
if err = col.collection.Insert(bson.M{"_id": id}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Now append data the user wants to append.
|
||||
if err = col.collection.Update(bson.M{"_id": id}, item); err != nil {
|
||||
// Cleanup allocated ID
|
||||
if err := col.collection.Remove(bson.M{"_id": id}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return mydb.NewInsertResult(id), nil
|
||||
}
|
||||
|
||||
// Exists returns true if the collection exists.
|
||||
func (col *Collection) Exists() (bool, error) {
|
||||
query := col.parent.database.C(`system.namespaces`).Find(map[string]string{`name`: fmt.Sprintf(`%s.%s`, col.parent.database.Name, col.collection.Name)})
|
||||
count, err := query.Count()
|
||||
return count > 0, err
|
||||
}
|
||||
|
||||
// getID fetches the object's _id, or generates a new one if the object doesn't have one or the one it has is invalid.
|
||||
func getID(item interface{}) interface{} {
|
||||
v := reflect.ValueOf(item) // convert interface to Value
|
||||
v = reflect.Indirect(v) // convert pointers
|
||||
|
||||
switch v.Kind() {
|
||||
case reflect.Map:
|
||||
if inItem, ok := item.(map[string]interface{}); ok {
|
||||
if id, ok := inItem["_id"]; ok {
|
||||
bsonID, ok := id.(bson.ObjectId)
|
||||
if ok {
|
||||
return bsonID
|
||||
}
|
||||
}
|
||||
}
|
||||
case reflect.Struct:
|
||||
t := v.Type()
|
||||
|
||||
idCacheMutex.RLock()
|
||||
fieldName, found := idCache[t]
|
||||
idCacheMutex.RUnlock()
|
||||
|
||||
if !found {
|
||||
for n := 0; n < t.NumField(); n++ {
|
||||
field := t.Field(n)
|
||||
if field.PkgPath != "" {
|
||||
continue // Private field
|
||||
}
|
||||
|
||||
tag := field.Tag.Get("bson")
|
||||
if tag == "" {
|
||||
tag = field.Tag.Get("db")
|
||||
}
|
||||
|
||||
if tag == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
parts := strings.Split(tag, ",")
|
||||
|
||||
if parts[0] == "_id" {
|
||||
fieldName = field.Name
|
||||
idCacheMutex.Lock()
idCache[t] = fieldName
idCacheMutex.Unlock()
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if fieldName != "" {
|
||||
if bsonID, ok := v.FieldByName(fieldName).Interface().(bson.ObjectId); ok {
|
||||
if bsonID.Valid() {
|
||||
return bsonID
|
||||
}
|
||||
} else {
|
||||
return v.FieldByName(fieldName).Interface()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bson.NewObjectId()
|
||||
}
|
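A hedged illustration (not part of the commit) of what the condition compiler above produces; the helper name sketchConditions is made up, but the mappings follow directly from compileStatement:

package mongo

import "git.hexq.cn/tiglog/mydb"

// sketchConditions is illustrative only; the comments show the bson.M
// documents compileStatement builds for a few typical conditions.
func sketchConditions() {
	_ = compileStatement(mydb.Cond{"name": "Ozzie"})
	// -> bson.M{"name": "Ozzie"} (bare field, plain equality)

	_ = compileStatement(mydb.Cond{"age >": 21})
	// -> bson.M{"age": bson.M{"$gt": 21}} (trailing operator is split off)

	_ = compileStatement(mydb.Cond{"state IN": []string{"new", "open"}})
	// -> bson.M{"state": bson.M{"$in": []string{"new", "open"}}}
}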
adapter/mongo/connection.go (new file, 98 lines)
@@ -0,0 +1,98 @@
package mongo

import (
	"fmt"
	"net/url"
	"strings"
)

const connectionScheme = `mongodb`

// ConnectionURL implements a MongoDB connection struct.
type ConnectionURL struct {
	User     string
	Password string
	Host     string
	Database string
	Options  map[string]string
}

func (c ConnectionURL) String() (s string) {

	if c.Database == "" {
		return ""
	}

	vv := url.Values{}

	// Do we have any options?
	if c.Options == nil {
		c.Options = map[string]string{}
	}

	// Converting options into URL values.
	for k, v := range c.Options {
		vv.Set(k, v)
	}

	// Has user?
	var userInfo *url.Userinfo

	if c.User != "" {
		if c.Password == "" {
			userInfo = url.User(c.User)
		} else {
			userInfo = url.UserPassword(c.User, c.Password)
		}
	}

	// Building URL.
	u := url.URL{
		Scheme:   connectionScheme,
		Path:     c.Database,
		Host:     c.Host,
		User:     userInfo,
		RawQuery: vv.Encode(),
	}

	return u.String()
}

// ParseURL parses s into a ConnectionURL struct.
func ParseURL(s string) (conn ConnectionURL, err error) {
	var u *url.URL

	if !strings.HasPrefix(s, connectionScheme+"://") {
		return conn, fmt.Errorf(`Expecting mongodb:// connection scheme.`)
	}

	if u, err = url.Parse(s); err != nil {
		return conn, err
	}

	conn.Host = u.Host

	// Strip the leading / from the path.
	conn.Database = strings.Trim(u.Path, "/")

	// Adding user / password.
	if u.User != nil {
		conn.User = u.User.Username()
		conn.Password, _ = u.User.Password()
	}

	// Adding options.
	conn.Options = map[string]string{}

	var vv url.Values

	if vv, err = url.ParseQuery(u.RawQuery); err != nil {
		return conn, err
	}

	for k := range vv {
		conn.Options[k] = vv.Get(k)
	}

	return conn, err
}
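A small sketch (not in the commit) of the round trip these two functions implement; the values mirror the tests in connection_test.go and the helper name is made up:

package mongo

import "fmt"

// sketchConnectionURL is illustrative only.
func sketchConnectionURL() {
	conn := ConnectionURL{
		User:     "user",
		Password: "pass",
		Host:     "localhost:27017",
		Database: "upperio",
		Options:  map[string]string{"mode": "ro"},
	}

	// String builds a mongodb:// URL; it returns "" when Database is empty.
	fmt.Println(conn.String())
	// mongodb://user:pass@localhost:27017/upperio?mode=ro

	// ParseURL reverses the transformation and rejects any other scheme.
	parsed, err := ParseURL("mongodb://user:pass@localhost:27017/upperio?mode=ro")
	fmt.Println(parsed.Database, parsed.Host, err)
	// upperio localhost:27017 <nil>
}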
adapter/mongo/connection_test.go (new file, 114 lines)
@@ -0,0 +1,114 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConnectionURL(t *testing.T) {
|
||||
|
||||
c := ConnectionURL{}
|
||||
|
||||
// The default connection string is empty when no database is set.
|
||||
if c.String() != "" {
|
||||
t.Fatal(`Expecting default connection string to be empty, got:`, c.String())
|
||||
}
|
||||
|
||||
// Adding a database name.
|
||||
c.Database = "myfilename"
|
||||
|
||||
if c.String() != "mongodb://myfilename" {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
// Adding an option.
|
||||
c.Options = map[string]string{
|
||||
"cache": "foobar",
|
||||
"mode": "ro",
|
||||
}
|
||||
|
||||
// Adding username and password
|
||||
c.User = "user"
|
||||
c.Password = "pass"
|
||||
|
||||
// Setting host.
|
||||
c.Host = "localhost"
|
||||
|
||||
if c.String() != "mongodb://user:pass@localhost/myfilename?cache=foobar&mode=ro" {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
// Setting host and port.
|
||||
c.Host = "localhost:27017"
|
||||
|
||||
if c.String() != "mongodb://user:pass@localhost:27017/myfilename?cache=foobar&mode=ro" {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
// Setting cluster.
|
||||
c.Host = "localhost,1.2.3.4,example.org:1234"
|
||||
|
||||
if c.String() != "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/myfilename?cache=foobar&mode=ro" {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
// Setting another database.
|
||||
c.Database = "another_database"
|
||||
|
||||
if c.String() != "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/another_database?cache=foobar&mode=ro" {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestParseConnectionURL(t *testing.T) {
|
||||
var u ConnectionURL
|
||||
var s string
|
||||
var err error
|
||||
|
||||
s = "mongodb:///mydatabase"
|
||||
|
||||
if u, err = ParseURL(s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if u.Database != "mydatabase" {
|
||||
t.Fatal("Failed to parse database.")
|
||||
}
|
||||
|
||||
s = "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/another_database?cache=foobar&mode=ro"
|
||||
|
||||
if u, err = ParseURL(s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if u.Database != "another_database" {
|
||||
t.Fatal("Failed to get database.")
|
||||
}
|
||||
|
||||
if u.Options["cache"] != "foobar" {
|
||||
t.Fatal("Expecting option.")
|
||||
}
|
||||
|
||||
if u.Options["mode"] != "ro" {
|
||||
t.Fatal("Expecting option.")
|
||||
}
|
||||
|
||||
if u.User != "user" {
|
||||
t.Fatal("Expecting user.")
|
||||
}
|
||||
|
||||
if u.Password != "pass" {
|
||||
t.Fatal("Expecting password.")
|
||||
}
|
||||
|
||||
if u.Host != "localhost,1.2.3.4,example.org:1234" {
|
||||
t.Fatal("Expecting host.")
|
||||
}
|
||||
|
||||
s = "http://example.org"
|
||||
|
||||
if _, err = ParseURL(s); err == nil {
|
||||
t.Fatal("Expecting error.")
|
||||
}
|
||||
|
||||
}
|
adapter/mongo/database.go (new file, 245 lines)
@@ -0,0 +1,245 @@
|
||||
// Package mongo wraps the gopkg.in/mgo.v2 MongoDB driver. See
|
||||
// https://github.com/upper/db/adapter/mongo for documentation, particularities and usage
|
||||
// examples.
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
)
|
||||
|
||||
// Adapter holds the name of the mongodb adapter.
|
||||
const Adapter = `mongo`
|
||||
|
||||
var connTimeout = time.Second * 5
|
||||
|
||||
// Source represents a MongoDB database.
|
||||
type Source struct {
|
||||
mydb.Settings
|
||||
|
||||
ctx context.Context
|
||||
|
||||
name string
|
||||
connURL mydb.ConnectionURL
|
||||
session *mgo.Session
|
||||
database *mgo.Database
|
||||
version []int
|
||||
collections map[string]*Collection
|
||||
collectionsMu sync.Mutex
|
||||
}
|
||||
|
||||
type mongoAdapter struct {
|
||||
}
|
||||
|
||||
func (mongoAdapter) Open(dsn mydb.ConnectionURL) (mydb.Session, error) {
|
||||
return Open(dsn)
|
||||
}
|
||||
|
||||
func init() {
|
||||
mydb.RegisterAdapter(Adapter, mydb.Adapter(&mongoAdapter{}))
|
||||
}
|
||||
|
||||
// Open establishes a new connection to a MongoDB server.
|
||||
func Open(settings mydb.ConnectionURL) (mydb.Session, error) {
|
||||
d := &Source{Settings: mydb.NewSettings(), ctx: context.Background()}
|
||||
if err := d.Open(settings); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return d, nil
|
||||
}
|
||||
|
||||
func (s *Source) TxContext(context.Context, func(tx mydb.Session) error, *sql.TxOptions) error {
|
||||
return mydb.ErrNotSupportedByAdapter
|
||||
}
|
||||
|
||||
func (s *Source) Tx(func(mydb.Session) error) error {
|
||||
return mydb.ErrNotSupportedByAdapter
|
||||
}
|
||||
|
||||
func (s *Source) SQL() mydb.SQL {
|
||||
// Not supported
|
||||
panic("sql builder is not supported by mongodb")
|
||||
}
|
||||
|
||||
func (s *Source) ConnectionURL() mydb.ConnectionURL {
|
||||
return s.connURL
|
||||
}
|
||||
|
||||
// SetConnMaxLifetime is not supported.
|
||||
func (s *Source) SetConnMaxLifetime(time.Duration) {
|
||||
s.Settings.SetConnMaxLifetime(time.Duration(0))
|
||||
}
|
||||
|
||||
// SetMaxIdleConns is not supported.
|
||||
func (s *Source) SetMaxIdleConns(int) {
|
||||
s.Settings.SetMaxIdleConns(0)
|
||||
}
|
||||
|
||||
// SetMaxOpenConns is not supported.
|
||||
func (s *Source) SetMaxOpenConns(int) {
|
||||
s.Settings.SetMaxOpenConns(0)
|
||||
}
|
||||
|
||||
// Name returns the name of the database.
|
||||
func (s *Source) Name() string {
|
||||
return s.name
|
||||
}
|
||||
|
||||
// Open attempts to connect to the database.
|
||||
func (s *Source) Open(connURL mydb.ConnectionURL) error {
|
||||
s.connURL = connURL
|
||||
return s.open()
|
||||
}
|
||||
|
||||
// Clone returns a cloned mydb.Session.
|
||||
func (s *Source) Clone() (mydb.Session, error) {
|
||||
newSession := s.session.Copy()
|
||||
clone := &Source{
|
||||
Settings: mydb.NewSettings(),
|
||||
|
||||
name: s.name,
|
||||
connURL: s.connURL,
|
||||
session: newSession,
|
||||
database: newSession.DB(s.database.Name),
|
||||
version: s.version,
|
||||
collections: map[string]*Collection{},
|
||||
}
|
||||
return clone, nil
|
||||
}
|
||||
|
||||
// Ping checks whether a connection to the database is still alive by pinging
|
||||
// it, establishing a connection if necessary.
|
||||
func (s *Source) Ping() error {
|
||||
return s.session.Ping()
|
||||
}
|
||||
|
||||
func (s *Source) Reset() {
|
||||
s.collectionsMu.Lock()
|
||||
defer s.collectionsMu.Unlock()
|
||||
s.collections = make(map[string]*Collection)
|
||||
}
|
||||
|
||||
// Driver returns the underlying *mgo.Session instance.
|
||||
func (s *Source) Driver() interface{} {
|
||||
return s.session
|
||||
}
|
||||
|
||||
func (s *Source) open() error {
|
||||
var err error
|
||||
|
||||
if s.session, err = mgo.DialWithTimeout(s.connURL.String(), connTimeout); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
s.collections = map[string]*Collection{}
|
||||
s.database = s.session.DB("")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close terminates the current database session.
|
||||
func (s *Source) Close() error {
|
||||
if s.session != nil {
|
||||
s.session.Close()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Collections returns a list of non-system collections from the database.
|
||||
func (s *Source) Collections() (cols []mydb.Collection, err error) {
|
||||
var rawcols []string
|
||||
var col string
|
||||
|
||||
if rawcols, err = s.database.CollectionNames(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
cols = make([]mydb.Collection, 0, len(rawcols))
|
||||
|
||||
for _, col = range rawcols {
|
||||
if !strings.HasPrefix(col, "system.") {
|
||||
cols = append(cols, s.Collection(col))
|
||||
}
|
||||
}
|
||||
|
||||
return cols, nil
|
||||
}
|
||||
|
||||
func (s *Source) Delete(mydb.Record) error {
|
||||
return mydb.ErrNotImplemented
|
||||
}
|
||||
|
||||
func (s *Source) Get(mydb.Record, interface{}) error {
|
||||
return mydb.ErrNotImplemented
|
||||
}
|
||||
|
||||
func (s *Source) Save(mydb.Record) error {
|
||||
return mydb.ErrNotImplemented
|
||||
}
|
||||
|
||||
func (s *Source) Context() context.Context {
|
||||
return s.ctx
|
||||
}
|
||||
|
||||
func (s *Source) WithContext(ctx context.Context) mydb.Session {
|
||||
return &Source{
|
||||
ctx: ctx,
|
||||
Settings: s.Settings,
|
||||
name: s.name,
|
||||
connURL: s.connURL,
|
||||
session: s.session,
|
||||
database: s.database,
|
||||
version: s.version,
|
||||
}
|
||||
}
|
||||
|
||||
// Collection returns a collection by name.
|
||||
func (s *Source) Collection(name string) mydb.Collection {
|
||||
s.collectionsMu.Lock()
|
||||
defer s.collectionsMu.Unlock()
|
||||
|
||||
var col *Collection
|
||||
var ok bool
|
||||
|
||||
if col, ok = s.collections[name]; !ok {
|
||||
col = &Collection{
|
||||
parent: s,
|
||||
collection: s.database.C(name),
|
||||
}
|
||||
s.collections[name] = col
|
||||
}
|
||||
|
||||
return col
|
||||
}
|
||||
|
||||
func (s *Source) versionAtLeast(version ...int) bool {
|
||||
// only fetch this once - it makes a db call
|
||||
if len(s.version) == 0 {
|
||||
buildInfo, err := s.database.Session.BuildInfo()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
s.version = buildInfo.VersionArray
|
||||
}
|
||||
|
||||
// Check major version first
|
||||
if s.version[0] > version[0] {
|
||||
return true
|
||||
}
|
||||
|
||||
for i := range version {
|
||||
if i == len(s.version) {
|
||||
return false
|
||||
}
|
||||
if s.version[i] < version[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
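Collection.Insert (in collection.go above) calls versionAtLeast(2, 6, 0, 0) to decide between a single Upsert and the older insert-then-update fallback. A hedged sketch of how the comparison behaves, assuming a server whose BuildInfo reports version [3, 6, 4, 0]; the helper name is made up:

package mongo

// sketchVersionCheck is illustrative only.
func sketchVersionCheck(s *Source) {
	_ = s.versionAtLeast(2, 6, 0, 0) // true: major version 3 > 2
	_ = s.versionAtLeast(3, 4, 0, 0) // true: 3 == 3, then 6 >= 4
	_ = s.versionAtLeast(4, 0, 0, 0) // false: 3 < 4
}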
adapter/mongo/docker-compose.yml (new file, 13 lines)
@@ -0,0 +1,13 @@
version: '3'

services:

  server:
    image: mongo:${MONGO_VERSION:-3}
    environment:
      MONGO_INITDB_ROOT_USERNAME: ${DB_USERNAME:-upperio_user}
      MONGO_INITDB_ROOT_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37}
      MONGO_INITDB_DATABASE: ${DB_NAME:-upperio}
    ports:
      - '${BIND_HOST:-127.0.0.1}:${DB_PORT:-27017}:27017'
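The defaults in this compose file mirror the DB_USERNAME, DB_PASSWORD, DB_NAME and DB_PORT variables exported by adapter/mongo/Makefile, so `make server-up` and the Go tests agree on credentials without extra configuration.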
adapter/mongo/generic_test.go (new file, 20 lines)
@@ -0,0 +1,20 @@
package mongo

import (
	"testing"

	"git.hexq.cn/tiglog/mydb/internal/testsuite"
	"github.com/stretchr/testify/suite"
)

type GenericTests struct {
	testsuite.GenericTestSuite
}

func (s *GenericTests) SetupSuite() {
	s.Helper = &Helper{}
}

func TestGeneric(t *testing.T) {
	suite.Run(t, &GenericTests{})
}
adapter/mongo/helper_test.go (new file, 77 lines)
@@ -0,0 +1,77 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
)
|
||||
|
||||
var settings = ConnectionURL{
|
||||
Database: os.Getenv("DB_NAME"),
|
||||
User: os.Getenv("DB_USERNAME"),
|
||||
Password: os.Getenv("DB_PASSWORD"),
|
||||
Host: os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"),
|
||||
}
|
||||
|
||||
type Helper struct {
|
||||
sess mydb.Session
|
||||
}
|
||||
|
||||
func (h *Helper) Session() mydb.Session {
|
||||
return h.sess
|
||||
}
|
||||
|
||||
func (h *Helper) Adapter() string {
|
||||
return "mongo"
|
||||
}
|
||||
|
||||
func (h *Helper) TearDown() error {
|
||||
return h.sess.Close()
|
||||
}
|
||||
|
||||
func (h *Helper) TearUp() error {
|
||||
var err error
|
||||
|
||||
h.sess, err = Open(settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
mgod, ok := h.sess.Driver().(*mgo.Session)
|
||||
if !ok {
|
||||
panic("expecting mgo.Session")
|
||||
}
|
||||
|
||||
var col *mgo.Collection
|
||||
col = mgod.DB(settings.Database).C("birthdays")
|
||||
_ = col.DropCollection()
|
||||
|
||||
col = mgod.DB(settings.Database).C("fibonacci")
|
||||
_ = col.DropCollection()
|
||||
|
||||
col = mgod.DB(settings.Database).C("is_even")
|
||||
_ = col.DropCollection()
|
||||
|
||||
col = mgod.DB(settings.Database).C("CaSe_TesT")
|
||||
_ = col.DropCollection()
|
||||
|
||||
// Getting a pointer to the "artist" collection.
|
||||
artist := h.sess.Collection("artist")
|
||||
|
||||
_ = artist.Truncate()
|
||||
for i := 0; i < 999; i++ {
|
||||
_, err = artist.Insert(artistType{
|
||||
Name: fmt.Sprintf("artist-%d", i),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ testsuite.Helper = &Helper{}
|
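Note that TearUp above drops the collections the suite writes to and then seeds 999 "artist" documents; the pagination assertions in mongo_test.go below (77 pages of 13, a final page of 11) depend on exactly that row count.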
adapter/mongo/mongo_test.go (new file, 754 lines)
@@ -0,0 +1,754 @@
|
||||
// Tests for the mongodb adapter.
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"math/rand"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
)
|
||||
|
||||
type artistType struct {
|
||||
ID bson.ObjectId `bson:"_id,omitempty"`
|
||||
Name string `bson:"name"`
|
||||
}
|
||||
|
||||
// Structure for testing conversions and datatypes.
|
||||
type testValuesStruct struct {
|
||||
Uint uint `bson:"_uint"`
|
||||
Uint8 uint8 `bson:"_uint8"`
|
||||
Uint16 uint16 `bson:"_uint16"`
|
||||
Uint32 uint32 `bson:"_uint32"`
|
||||
Uint64 uint64 `bson:"_uint64"`
|
||||
|
||||
Int int `bson:"_int"`
|
||||
Int8 int8 `bson:"_int8"`
|
||||
Int16 int16 `bson:"_int16"`
|
||||
Int32 int32 `bson:"_int32"`
|
||||
Int64 int64 `bson:"_int64"`
|
||||
|
||||
Float32 float32 `bson:"_float32"`
|
||||
Float64 float64 `bson:"_float64"`
|
||||
|
||||
Bool bool `bson:"_bool"`
|
||||
String string `bson:"_string"`
|
||||
|
||||
Date time.Time `bson:"_date"`
|
||||
DateN *time.Time `bson:"_nildate"`
|
||||
DateP *time.Time `bson:"_ptrdate"`
|
||||
Time time.Duration `bson:"_time"`
|
||||
}
|
||||
|
||||
var testValues testValuesStruct
|
||||
|
||||
func init() {
|
||||
t := time.Date(2012, 7, 28, 1, 2, 3, 0, time.Local)
|
||||
|
||||
testValues = testValuesStruct{
|
||||
1, 1, 1, 1, 1,
|
||||
-1, -1, -1, -1, -1,
|
||||
1.337, 1.337,
|
||||
true,
|
||||
"Hello world!",
|
||||
t,
|
||||
nil,
|
||||
&t,
|
||||
time.Second * time.Duration(7331),
|
||||
}
|
||||
}
|
||||
|
||||
type AdapterTests struct {
|
||||
testsuite.Suite
|
||||
}
|
||||
|
||||
func (s *AdapterTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestOpenWithWrongData() {
|
||||
var err error
|
||||
var rightSettings, wrongSettings ConnectionURL
|
||||
|
||||
// Attempt to open with safe settings.
|
||||
rightSettings = ConnectionURL{
|
||||
Database: settings.Database,
|
||||
Host: settings.Host,
|
||||
User: settings.User,
|
||||
Password: settings.Password,
|
||||
}
|
||||
|
||||
// Attempt to open an empty database.
|
||||
_, err = Open(rightSettings)
|
||||
s.NoError(err)
|
||||
|
||||
// Attempt to open with wrong password.
|
||||
wrongSettings = ConnectionURL{
|
||||
Database: settings.Database,
|
||||
Host: settings.Host,
|
||||
User: settings.User,
|
||||
Password: "fail",
|
||||
}
|
||||
|
||||
_, err = Open(wrongSettings)
|
||||
s.Error(err)
|
||||
|
||||
// Attempt to open with wrong database.
|
||||
wrongSettings = ConnectionURL{
|
||||
Database: "fail",
|
||||
Host: settings.Host,
|
||||
User: settings.User,
|
||||
Password: settings.Password,
|
||||
}
|
||||
|
||||
_, err = Open(wrongSettings)
|
||||
s.Error(err)
|
||||
|
||||
// Attempt to open with wrong username.
|
||||
wrongSettings = ConnectionURL{
|
||||
Database: settings.Database,
|
||||
Host: settings.Host,
|
||||
User: "fail",
|
||||
Password: settings.Password,
|
||||
}
|
||||
|
||||
_, err = Open(wrongSettings)
|
||||
s.Error(err)
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestTruncate() {
|
||||
// Opening database.
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
// We should close the database when it's no longer in use.
|
||||
defer sess.Close()
|
||||
|
||||
// Getting a list of all collections in this database.
|
||||
collections, err := sess.Collections()
|
||||
s.NoError(err)
|
||||
|
||||
for _, col := range collections {
|
||||
// The collection may or may not exist.
|
||||
if ok, _ := col.Exists(); ok {
|
||||
// Truncating the structure, if exists.
|
||||
err = col.Truncate()
|
||||
s.NoError(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestInsert() {
|
||||
// Opening database.
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
// We should close the database when it's no longer in use.
|
||||
defer sess.Close()
|
||||
|
||||
// Getting a pointer to the "artist" collection.
|
||||
artist := sess.Collection("artist")
|
||||
_ = artist.Truncate()
|
||||
|
||||
// Inserting a map.
|
||||
record, err := artist.Insert(map[string]string{
|
||||
"name": "Ozzie",
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
id := record.ID()
|
||||
s.NotZero(record.ID())
|
||||
|
||||
_, ok := id.(bson.ObjectId)
|
||||
s.True(ok)
|
||||
|
||||
s.True(id.(bson.ObjectId).Valid())
|
||||
|
||||
// Inserting a struct.
|
||||
record, err = artist.Insert(struct {
|
||||
Name string
|
||||
}{
|
||||
"Flea",
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
id = record.ID()
|
||||
s.NotZero(id)
|
||||
|
||||
_, ok = id.(bson.ObjectId)
|
||||
s.True(ok)
|
||||
s.True(id.(bson.ObjectId).Valid())
|
||||
|
||||
// Inserting a struct (using tags to specify the field name).
|
||||
record, err = artist.Insert(struct {
|
||||
ArtistName string `bson:"name"`
|
||||
}{
|
||||
"Slash",
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
id = record.ID()
|
||||
s.NotNil(id)
|
||||
|
||||
_, ok = id.(bson.ObjectId)
|
||||
|
||||
s.True(ok)
|
||||
s.True(id.(bson.ObjectId).Valid())
|
||||
|
||||
// Inserting a pointer to a struct
|
||||
record, err = artist.Insert(&struct {
|
||||
ArtistName string `bson:"name"`
|
||||
}{
|
||||
"Metallica",
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
id = record.ID()
|
||||
s.NotNil(id)
|
||||
|
||||
_, ok = id.(bson.ObjectId)
|
||||
s.True(ok)
|
||||
s.True(id.(bson.ObjectId).Valid())
|
||||
|
||||
// Inserting a pointer to a map
|
||||
record, err = artist.Insert(&map[string]string{
|
||||
"name": "Freddie",
|
||||
})
|
||||
s.NoError(err)
|
||||
s.NotZero(id)
|
||||
|
||||
_, ok = id.(bson.ObjectId)
|
||||
s.True(ok)
|
||||
|
||||
id = record.ID()
|
||||
s.NotNil(id)
|
||||
|
||||
s.True(id.(bson.ObjectId).Valid())
|
||||
|
||||
// Counting elements; there must be exactly 5.
|
||||
total, err := artist.Find().Count()
|
||||
s.NoError(err)
|
||||
s.Equal(uint64(5), total)
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestGetNonExistentRow_Issue426() {
|
||||
// Opening database.
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
defer sess.Close()
|
||||
|
||||
artist := sess.Collection("artist")
|
||||
|
||||
var one artistType
|
||||
err = artist.Find(mydb.Cond{"name": "nothing"}).One(&one)
|
||||
|
||||
s.NotZero(err)
|
||||
s.Equal(mydb.ErrNoMoreRows, err)
|
||||
|
||||
var all []artistType
|
||||
err = artist.Find(mydb.Cond{"name": "nothing"}).All(&all)
|
||||
|
||||
s.Zero(err, "All should not return mgo.ErrNotFound")
|
||||
s.Equal(0, len(all))
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestResultCount() {
|
||||
var err error
|
||||
var res mydb.Result
|
||||
|
||||
// Opening database.
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
defer sess.Close()
|
||||
|
||||
// We should close the database when it's no longer in use.
|
||||
artist := sess.Collection("artist")
|
||||
|
||||
res = artist.Find()
|
||||
|
||||
// Counting all the matching rows.
|
||||
total, err := res.Count()
|
||||
s.NoError(err)
|
||||
s.NotZero(total)
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestGroup() {
|
||||
var stats mydb.Collection
|
||||
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
type statsT struct {
|
||||
Numeric int `db:"numeric" bson:"numeric"`
|
||||
Value int `db:"value" bson:"value"`
|
||||
}
|
||||
|
||||
defer sess.Close()
|
||||
|
||||
stats = sess.Collection("statsTest")
|
||||
|
||||
// Truncating table.
|
||||
_ = stats.Truncate()
|
||||
|
||||
// Adding row append.
|
||||
for i := 0; i < 1000; i++ {
|
||||
numeric, value := rand.Intn(10), rand.Intn(100)
|
||||
_, err = stats.Insert(statsT{numeric, value})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
// mydb.statsTest.group({key: {numeric: true}, initial: {sum: 0}, reduce: function(doc, prev) { prev.sum += 1}});
|
||||
|
||||
// Testing GROUP BY
|
||||
res := stats.Find().GroupBy(bson.M{
|
||||
"key": bson.M{"numeric": true},
|
||||
"initial": bson.M{"sum": 0},
|
||||
"reduce": `function(doc, prev) { prev.sum += 1}`,
|
||||
})
|
||||
|
||||
var results []map[string]interface{}
|
||||
|
||||
err = res.All(&results)
|
||||
s.Equal(mydb.ErrUnsupported, err)
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestResultNonExistentCount() {
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
defer sess.Close()
|
||||
|
||||
total, err := sess.Collection("notartist").Find().Count()
|
||||
s.NoError(err)
|
||||
s.Zero(total)
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestResultFetch() {
|
||||
|
||||
// Opening database.
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
// We should close the database when it's no longer in use.
|
||||
defer sess.Close()
|
||||
|
||||
artist := sess.Collection("artist")
|
||||
|
||||
// Testing map
|
||||
res := artist.Find()
|
||||
|
||||
rowM := map[string]interface{}{}
|
||||
|
||||
for res.Next(&rowM) {
|
||||
s.NotZero(rowM["_id"])
|
||||
|
||||
_, ok := rowM["_id"].(bson.ObjectId)
|
||||
s.True(ok)
|
||||
|
||||
s.True(rowM["_id"].(bson.ObjectId).Valid())
|
||||
|
||||
name, ok := rowM["name"].(string)
|
||||
s.True(ok)
|
||||
s.NotZero(name)
|
||||
}
|
||||
|
||||
err = res.Close()
|
||||
s.NoError(err)
|
||||
|
||||
// Testing struct
|
||||
rowS := struct {
|
||||
ID bson.ObjectId `bson:"_id"`
|
||||
Name string `bson:"name"`
|
||||
}{}
|
||||
|
||||
res = artist.Find()
|
||||
|
||||
for res.Next(&rowS) {
|
||||
s.True(rowS.ID.Valid())
|
||||
s.NotZero(rowS.Name)
|
||||
}
|
||||
|
||||
err = res.Close()
|
||||
s.NoError(err)
|
||||
|
||||
// Testing tagged struct
|
||||
rowT := struct {
|
||||
Value1 bson.ObjectId `bson:"_id"`
|
||||
Value2 string `bson:"name"`
|
||||
}{}
|
||||
|
||||
res = artist.Find()
|
||||
|
||||
for res.Next(&rowT) {
|
||||
s.True(rowT.Value1.Valid())
|
||||
s.NotZero(rowT.Value2)
|
||||
}
|
||||
|
||||
err = res.Close()
|
||||
s.NoError(err)
|
||||
|
||||
// Testing Result.All() with a slice of maps.
|
||||
res = artist.Find()
|
||||
|
||||
allRowsM := []map[string]interface{}{}
|
||||
err = res.All(&allRowsM)
|
||||
s.NoError(err)
|
||||
|
||||
for _, singleRowM := range allRowsM {
|
||||
s.NotZero(singleRowM["_id"])
|
||||
}
|
||||
|
||||
// Testing Result.All() with a slice of structs.
|
||||
res = artist.Find()
|
||||
|
||||
allRowsS := []struct {
|
||||
ID bson.ObjectId `bson:"_id"`
|
||||
Name string
|
||||
}{}
|
||||
err = res.All(&allRowsS)
|
||||
s.NoError(err)
|
||||
|
||||
for _, singleRowS := range allRowsS {
|
||||
s.True(singleRowS.ID.Valid())
|
||||
}
|
||||
|
||||
// Testing Result.All() with a slice of tagged structs.
|
||||
res = artist.Find()
|
||||
|
||||
allRowsT := []struct {
|
||||
Value1 bson.ObjectId `bson:"_id"`
|
||||
Value2 string `bson:"name"`
|
||||
}{}
|
||||
err = res.All(&allRowsT)
|
||||
s.NoError(err)
|
||||
|
||||
for _, singleRowT := range allRowsT {
|
||||
s.True(singleRowT.Value1.Valid())
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestUpdate() {
|
||||
// Opening database.
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
// We should close the database when it's no longer in use.
|
||||
defer sess.Close()
|
||||
|
||||
// Getting a pointer to the "artist" collection.
|
||||
artist := sess.Collection("artist")
|
||||
|
||||
// Value
|
||||
value := struct {
|
||||
ID bson.ObjectId `bson:"_id"`
|
||||
Name string
|
||||
}{}
|
||||
|
||||
// Getting the first artist.
|
||||
res := artist.Find(mydb.Cond{"_id": mydb.NotEq(nil)}).Limit(1)
|
||||
|
||||
err = res.One(&value)
|
||||
s.NoError(err)
|
||||
|
||||
// Updating with a map
|
||||
rowM := map[string]interface{}{
|
||||
"name": strings.ToUpper(value.Name),
|
||||
}
|
||||
|
||||
err = res.Update(rowM)
|
||||
s.NoError(err)
|
||||
|
||||
err = res.One(&value)
|
||||
s.NoError(err)
|
||||
|
||||
s.Equal(value.Name, rowM["name"])
|
||||
|
||||
// Updating with a struct
|
||||
rowS := struct {
|
||||
Name string
|
||||
}{strings.ToLower(value.Name)}
|
||||
|
||||
err = res.Update(rowS)
|
||||
s.NoError(err)
|
||||
|
||||
err = res.One(&value)
|
||||
s.NoError(err)
|
||||
|
||||
s.Equal(value.Name, rowS.Name)
|
||||
|
||||
// Updating with a tagged struct
|
||||
rowT := struct {
|
||||
Value1 string `bson:"name"`
|
||||
}{strings.Replace(value.Name, "z", "Z", -1)}
|
||||
|
||||
err = res.Update(rowT)
|
||||
s.NoError(err)
|
||||
|
||||
err = res.One(&value)
|
||||
s.NoError(err)
|
||||
|
||||
s.Equal(value.Name, rowT.Value1)
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestOperators() {
|
||||
// Opening database.
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
// We should close the database when it's no longer in use.
|
||||
defer sess.Close()
|
||||
|
||||
// Getting a pointer to the "artist" collection.
|
||||
artist := sess.Collection("artist")
|
||||
|
||||
rowS := struct {
|
||||
ID uint64
|
||||
Name string
|
||||
}{}
|
||||
|
||||
res := artist.Find(mydb.Cond{"_id": mydb.NotIn(0, -1)})
|
||||
|
||||
err = res.One(&rowS)
|
||||
s.NoError(err)
|
||||
|
||||
err = res.Close()
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestDelete() {
|
||||
// Opening database.
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
// We should close the database when it's no longer in use.
|
||||
defer sess.Close()
|
||||
|
||||
// Getting a pointer to the "artist" collection.
|
||||
artist := sess.Collection("artist")
|
||||
|
||||
// Getting the first artist.
|
||||
res := artist.Find(mydb.Cond{"_id": mydb.NotEq(nil)}).Limit(1)
|
||||
|
||||
var first struct {
|
||||
ID bson.ObjectId `bson:"_id"`
|
||||
}
|
||||
|
||||
err = res.One(&first)
|
||||
s.NoError(err)
|
||||
|
||||
res = artist.Find(mydb.Cond{"_id": mydb.Eq(first.ID)})
|
||||
|
||||
// Trying to remove the row.
|
||||
err = res.Delete()
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestDataTypes() {
|
||||
// Opening database.
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
// We should close the database when it's no longer in use.
|
||||
defer sess.Close()
|
||||
|
||||
// Getting a pointer to the "data_types" collection.
|
||||
dataTypes := sess.Collection("data_types")
|
||||
|
||||
// Inserting our test subject.
|
||||
record, err := dataTypes.Insert(testValues)
|
||||
s.NoError(err)
|
||||
|
||||
id := record.ID()
|
||||
s.NotZero(id)
|
||||
|
||||
// Trying to get the same subject we added.
|
||||
res := dataTypes.Find(mydb.Cond{"_id": mydb.Eq(id)})
|
||||
|
||||
exists, err := res.Count()
|
||||
s.NoError(err)
|
||||
s.NotZero(exists)
|
||||
|
||||
// Trying to dump the subject into an empty structure of the same type.
|
||||
var item testValuesStruct
|
||||
err = res.One(&item)
|
||||
s.NoError(err)
|
||||
|
||||
// The original value and the test subject must match.
|
||||
s.Equal(testValues, item)
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestPaginator() {
|
||||
// Opening database.
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
|
||||
// We should close the database when it's no longer in use.
|
||||
defer sess.Close()
|
||||
|
||||
// Getting a pointer to the "artist" collection.
|
||||
artist := sess.Collection("artist")
|
||||
|
||||
err = artist.Truncate()
|
||||
s.NoError(err)
|
||||
|
||||
for i := 0; i < 999; i++ {
|
||||
_, err = artist.Insert(artistType{
|
||||
Name: fmt.Sprintf("artist-%d", i),
|
||||
})
|
||||
s.NoError(err)
|
||||
}
|
||||
|
||||
q := sess.Collection("artist").Find().Paginate(15)
|
||||
paginator := q.Paginate(13)
|
||||
|
||||
var zerothPage []artistType
|
||||
err = paginator.Page(0).All(&zerothPage)
|
||||
s.NoError(err)
|
||||
s.Equal(13, len(zerothPage))
|
||||
|
||||
var secondPage []artistType
|
||||
err = paginator.Page(2).All(&secondPage)
|
||||
s.NoError(err)
|
||||
s.Equal(13, len(secondPage))
|
||||
|
||||
tp, err := paginator.TotalPages()
|
||||
s.NoError(err)
|
||||
s.NotZero(tp)
|
||||
s.Equal(uint(77), tp)
|
||||
|
||||
ti, err := paginator.TotalEntries()
|
||||
s.NoError(err)
|
||||
s.NotZero(ti)
|
||||
s.Equal(uint64(999), ti)
|
||||
|
||||
var seventySixthPage []artistType
|
||||
err = paginator.Page(76).All(&seventySixthPage)
|
||||
s.NoError(err)
|
||||
s.Equal(11, len(seventySixthPage))
|
||||
|
||||
var seventySeventhPage []artistType
|
||||
err = paginator.Page(77).All(&seventySeventhPage)
|
||||
s.NoError(err)
|
||||
s.Equal(0, len(seventySeventhPage))
|
||||
|
||||
var hundredthPage []artistType
|
||||
err = paginator.Page(100).All(&hundredthPage)
|
||||
s.NoError(err)
|
||||
s.Equal(0, len(hundredthPage))
|
||||
|
||||
for i := uint(0); i < tp; i++ {
|
||||
current := paginator.Page(i)
|
||||
|
||||
var items []artistType
|
||||
err := current.All(&items)
|
||||
s.NoError(err)
|
||||
if len(items) < 1 {
|
||||
break
|
||||
}
|
||||
for j := 0; j < len(items); j++ {
|
||||
s.Equal(fmt.Sprintf("artist-%d", int64(13*int(i)+j)), items[j].Name)
|
||||
}
|
||||
}
|
||||
|
||||
paginator = paginator.Cursor("_id")
|
||||
{
|
||||
current := paginator.Page(0)
|
||||
for i := 0; ; i++ {
|
||||
var items []artistType
|
||||
err := current.All(&items)
|
||||
s.NoError(err)
|
||||
|
||||
if len(items) < 1 {
|
||||
break
|
||||
}
|
||||
|
||||
for j := 0; j < len(items); j++ {
|
||||
s.Equal(fmt.Sprintf("artist-%d", int64(13*int(i)+j)), items[j].Name)
|
||||
}
|
||||
current = current.NextPage(items[len(items)-1].ID)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
log.Printf("Page 76")
|
||||
current := paginator.Page(76)
|
||||
for i := 76; ; i-- {
|
||||
var items []artistType
|
||||
|
||||
err := current.All(&items)
|
||||
s.NoError(err)
|
||||
|
||||
if len(items) < 1 {
|
||||
s.Equal(0, len(items))
|
||||
break
|
||||
}
|
||||
for j := 0; j < len(items); j++ {
|
||||
s.Equal(fmt.Sprintf("artist-%d", 13*int(i)+j), items[j].Name)
|
||||
}
|
||||
|
||||
current = current.PrevPage(items[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
resultPaginator := sess.Collection("artist").Find().Paginate(15)
|
||||
|
||||
count, err := resultPaginator.TotalPages()
|
||||
s.Equal(uint(67), count)
|
||||
s.NoError(err)
|
||||
|
||||
var items []artistType
|
||||
err = resultPaginator.Page(5).All(&items)
|
||||
s.NoError(err)
|
||||
|
||||
for j := 0; j < len(items); j++ {
|
||||
s.Equal(fmt.Sprintf("artist-%d", 15*5+j), items[j].Name)
|
||||
}
|
||||
|
||||
resultPaginator = resultPaginator.Cursor("_id").Page(0)
|
||||
for i := 0; ; i++ {
|
||||
var items []artistType
|
||||
|
||||
err = resultPaginator.All(&items)
|
||||
s.NoError(err)
|
||||
|
||||
if len(items) < 1 {
|
||||
break
|
||||
}
|
||||
|
||||
for j := 0; j < len(items); j++ {
|
||||
s.Equal(fmt.Sprintf("artist-%d", 15*i+j), items[j].Name)
|
||||
}
|
||||
resultPaginator = resultPaginator.NextPage(items[len(items)-1].ID)
|
||||
}
|
||||
|
||||
resultPaginator = resultPaginator.Cursor("_id").Page(66)
|
||||
for i := 66; ; i-- {
|
||||
var items []artistType
|
||||
|
||||
err = resultPaginator.All(&items)
|
||||
s.NoError(err)
|
||||
|
||||
if len(items) < 1 {
|
||||
break
|
||||
}
|
||||
|
||||
for j := 0; j < len(items); j++ {
|
||||
s.Equal(fmt.Sprintf("artist-%d", 15*i+j), items[j].Name)
|
||||
}
|
||||
resultPaginator = resultPaginator.PrevPage(items[0].ID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter(t *testing.T) {
|
||||
suite.Run(t, &AdapterTests{})
|
||||
}
|
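The magic numbers in TestPaginator above follow from the 999 seeded artists: with 13 per page there are ceil(999/13) = 77 pages, the last one (page 76, zero-based) holding 999 - 76*13 = 11 documents, and with 15 per page there are ceil(999/15) = 67 pages.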
adapter/mongo/result.go (new file, 565 lines)
@@ -0,0 +1,565 @@
|
||||
package mongo
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"encoding/json"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/immutable"
|
||||
mgo "gopkg.in/mgo.v2"
|
||||
"gopkg.in/mgo.v2/bson"
|
||||
)
|
||||
|
||||
type resultQuery struct {
|
||||
c *Collection
|
||||
|
||||
fields []string
|
||||
limit int
|
||||
offset int
|
||||
sort []string
|
||||
conditions interface{}
|
||||
groupBy []interface{}
|
||||
|
||||
pageSize uint
|
||||
pageNumber uint
|
||||
cursorColumn string
|
||||
cursorValue interface{}
|
||||
cursorCond mydb.Cond
|
||||
cursorReverseOrder bool
|
||||
}
|
||||
|
||||
type result struct {
|
||||
iter *mgo.Iter
|
||||
err error
|
||||
errMu sync.Mutex
|
||||
|
||||
fn func(*resultQuery) error
|
||||
prev *result
|
||||
}
|
||||
|
||||
var _ = immutable.Immutable(&result{})
|
||||
|
||||
func (res *result) frame(fn func(*resultQuery) error) *result {
|
||||
return &result{prev: res, fn: fn}
|
||||
}
|
||||
|
||||
func (r *resultQuery) and(terms ...interface{}) error {
|
||||
if r.conditions == nil {
|
||||
return r.where(terms...)
|
||||
}
|
||||
|
||||
r.conditions = map[string]interface{}{
|
||||
"$and": []interface{}{
|
||||
r.conditions,
|
||||
r.c.compileQuery(terms...),
|
||||
},
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *resultQuery) where(terms ...interface{}) error {
|
||||
r.conditions = r.c.compileQuery(terms...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (res *result) And(terms ...interface{}) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
return r.and(terms...)
|
||||
})
|
||||
}
|
||||
|
||||
func (res *result) Where(terms ...interface{}) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
return r.where(terms...)
|
||||
})
|
||||
}
|
||||
|
||||
func (res *result) Paginate(pageSize uint) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
r.pageSize = pageSize
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (res *result) Page(pageNumber uint) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
r.pageNumber = pageNumber
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (res *result) Cursor(cursorColumn string) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
r.cursorColumn = cursorColumn
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (res *result) NextPage(cursorValue interface{}) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
r.cursorValue = cursorValue
|
||||
r.cursorReverseOrder = false
|
||||
r.cursorCond = mydb.Cond{
|
||||
r.cursorColumn: bson.M{"$gt": cursorValue},
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (res *result) PrevPage(cursorValue interface{}) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
r.cursorValue = cursorValue
|
||||
r.cursorReverseOrder = true
|
||||
r.cursorCond = mydb.Cond{
|
||||
r.cursorColumn: bson.M{"$lt": cursorValue},
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
func (res *result) TotalEntries() (uint64, error) {
|
||||
return res.Count()
|
||||
}
|
||||
|
||||
func (res *result) TotalPages() (uint, error) {
|
||||
count, err := res.Count()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
rq, err := res.build()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if rq.pageSize < 1 {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
total := uint(math.Ceil(float64(count) / float64(rq.pageSize)))
|
||||
return total, nil
|
||||
}
|
||||
|
||||
// Limit sets the maximum number of results to be returned.
|
||||
func (res *result) Limit(n int) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
r.limit = n
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Offset determines how many documents will be skipped before starting to grab
|
||||
// results.
|
||||
func (res *result) Offset(n int) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
r.offset = n
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// OrderBy determines sorting of results according to the provided names. Fields
|
||||
// may be prefixed by - (minus) which means descending order, ascending order
|
||||
// would be used otherwise.
|
||||
func (res *result) OrderBy(fields ...interface{}) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
ss := make([]string, len(fields))
|
||||
for i, field := range fields {
|
||||
ss[i] = fmt.Sprintf(`%v`, field)
|
||||
}
|
||||
r.sort = ss
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// String satisfies fmt.Stringer
|
||||
func (res *result) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Select marks the specific fields the user wants to retrieve.
|
||||
func (res *result) Select(fields ...interface{}) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
fieldslen := len(fields)
|
||||
r.fields = make([]string, 0, fieldslen)
|
||||
for i := 0; i < fieldslen; i++ {
|
||||
r.fields = append(r.fields, fmt.Sprintf(`%v`, fields[i]))
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// All dumps all results into a pointer to a slice of structs or maps.
|
||||
func (res *result) All(dst interface{}) error {
|
||||
rq, err := res.build()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q, err := rq.query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func(start time.Time) {
|
||||
queryLog(&mydb.QueryStatus{
|
||||
RawQuery: rq.debugQuery("Find.All"),
|
||||
Err: err,
|
||||
Start: start,
|
||||
End: time.Now(),
|
||||
})
|
||||
}(time.Now())
|
||||
|
||||
err = q.All(dst)
|
||||
if errors.Is(err, mgo.ErrNotFound) {
|
||||
return mydb.ErrNoMoreRows
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// GroupBy is used to group results that have the same value in the same column
|
||||
// or columns.
|
||||
func (res *result) GroupBy(fields ...interface{}) mydb.Result {
|
||||
return res.frame(func(r *resultQuery) error {
|
||||
r.groupBy = fields
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// One fetches only one result from the resultset.
|
||||
func (res *result) One(dst interface{}) error {
|
||||
rq, err := res.build()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q, err := rq.query()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func(start time.Time) {
|
||||
queryLog(&mydb.QueryStatus{
|
||||
RawQuery: rq.debugQuery("Find.One"),
|
||||
Err: err,
|
||||
Start: start,
|
||||
End: time.Now(),
|
||||
})
|
||||
}(time.Now())
|
||||
|
||||
err = q.One(dst)
|
||||
if errors.Is(err, mgo.ErrNotFound) {
|
||||
return mydb.ErrNoMoreRows
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (res *result) Err() error {
|
||||
res.errMu.Lock()
|
||||
defer res.errMu.Unlock()
|
||||
|
||||
return res.err
|
||||
}
|
||||
|
||||
func (res *result) setErr(err error) {
|
||||
res.errMu.Lock()
|
||||
defer res.errMu.Unlock()
|
||||
res.err = err
|
||||
}
|
||||
|
||||
func (res *result) Next(dst interface{}) bool {
|
||||
if res.iter == nil {
|
||||
rq, err := res.build()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
q, err := rq.query()
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
defer func(start time.Time) {
|
||||
queryLog(&mydb.QueryStatus{
|
||||
RawQuery: rq.debugQuery("Find.Next"),
|
||||
Err: err,
|
||||
Start: start,
|
||||
End: time.Now(),
|
||||
})
|
||||
}(time.Now())
|
||||
|
||||
res.iter = q.Iter()
|
||||
}
|
||||
|
||||
if !res.iter.Next(dst) {
|
||||
res.setErr(res.iter.Err())
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
// Delete removes the matching items from the collection.
|
||||
func (res *result) Delete() error {
|
||||
rq, err := res.build()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func(start time.Time) {
|
||||
queryLog(&mydb.QueryStatus{
|
||||
RawQuery: rq.debugQuery("Remove"),
|
||||
Err: err,
|
||||
Start: start,
|
||||
End: time.Now(),
|
||||
})
|
||||
}(time.Now())
|
||||
|
||||
_, err = rq.c.collection.RemoveAll(rq.conditions)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Close closes the result set.
|
||||
func (r *result) Close() error {
|
||||
var err error
|
||||
if r.iter != nil {
|
||||
err = r.iter.Close()
|
||||
r.iter = nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
// Update modifies matching items in the collection with values from the given
|
||||
// map or struct.
|
||||
func (res *result) Update(src interface{}) (err error) {
|
||||
updateSet := map[string]interface{}{"$set": src}
|
||||
|
||||
rq, err := res.build()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func(start time.Time) {
|
||||
queryLog(&mydb.QueryStatus{
|
||||
RawQuery: rq.debugQuery("Update"),
|
||||
Err: err,
|
||||
Start: start,
|
||||
End: time.Now(),
|
||||
})
|
||||
}(time.Now())
|
||||
|
||||
_, err = rq.c.collection.UpdateAll(rq.conditions, updateSet)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
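// Usage sketch (illustrative only): Update wraps the given value in a "$set"
// document, so only the listed fields change on every matching item. The
// "people" collection and its fields are hypothetical.
func exampleUpdateUsage(col mydb.Collection) error {
	// Roughly: db.people.updateMany({active: false}, {$set: {status: "archived"}})
	return col.Find(mydb.Cond{"active": false}).Update(map[string]interface{}{
		"status": "archived",
	})
}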
|
||||
|
||||
func (res *result) build() (*resultQuery, error) {
|
||||
rqi, err := immutable.FastForward(res)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rq := rqi.(*resultQuery)
|
||||
if !rq.cursorCond.Empty() {
|
||||
if err := rq.and(rq.cursorCond); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if rq.cursorColumn != "" {
|
||||
if rq.cursorReverseOrder {
|
||||
rq.sort = append(rq.sort, "-"+rq.cursorColumn)
|
||||
} else {
|
||||
rq.sort = append(rq.sort, rq.cursorColumn)
|
||||
}
|
||||
}
|
||||
return rq, nil
|
||||
}
|
||||
|
||||
// query builds the underlying mgo query from the result's conditions and options.
|
||||
func (r *resultQuery) query() (*mgo.Query, error) {
|
||||
if len(r.groupBy) > 0 {
|
||||
return nil, mydb.ErrUnsupported
|
||||
}
|
||||
|
||||
q := r.c.collection.Find(r.conditions)
|
||||
|
||||
if r.pageSize > 0 {
|
||||
r.offset = int(r.pageSize * r.pageNumber)
|
||||
r.limit = int(r.pageSize)
|
||||
}
|
||||
|
||||
if r.offset > 0 {
|
||||
q.Skip(r.offset)
|
||||
}
|
||||
|
||||
if r.limit > 0 {
|
||||
q.Limit(r.limit)
|
||||
}
|
||||
|
||||
if len(r.sort) > 0 {
|
||||
q.Sort(r.sort...)
|
||||
}
|
||||
|
||||
selectedFields := bson.M{}
|
||||
if len(r.fields) > 0 {
|
||||
for _, field := range r.fields {
|
||||
if field == `*` {
|
||||
break
|
||||
}
|
||||
selectedFields[field] = true
|
||||
}
|
||||
}
|
||||
|
||||
if r.cursorReverseOrder {
|
||||
ids := make([]bson.ObjectId, 0, r.limit)
|
||||
|
||||
iter := q.Select(bson.M{"_id": true}).Iter()
|
||||
defer iter.Close()
|
||||
|
||||
var item map[string]bson.ObjectId
|
||||
for iter.Next(&item) {
|
||||
ids = append(ids, item["_id"])
|
||||
}
|
||||
|
||||
r.conditions = bson.M{"_id": bson.M{"$in": ids}}
|
||||
|
||||
q = r.c.collection.Find(r.conditions)
|
||||
}
|
||||
|
||||
if len(selectedFields) > 0 {
|
||||
q.Select(selectedFields)
|
||||
}
|
||||
|
||||
return q, nil
|
||||
}
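// Note on the pagination math above (illustrative): when a page size is set,
// the adapter derives offset = pageSize * pageNumber and limit = pageSize, so
// for example pageSize = 20 with pageNumber = 3 results in Skip(60) and
// Limit(20) on the mgo query.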
|
||||
|
||||
func (res *result) Exists() (bool, error) {
|
||||
total, err := res.Count()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if total > 0 {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// Count counts matching elements.
|
||||
func (res *result) Count() (total uint64, err error) {
|
||||
rq, err := res.build()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
defer func(start time.Time) {
|
||||
queryLog(&mydb.QueryStatus{
|
||||
RawQuery: rq.debugQuery("Find.Count"),
|
||||
Err: err,
|
||||
Start: start,
|
||||
End: time.Now(),
|
||||
})
|
||||
}(time.Now())
|
||||
|
||||
q := rq.c.collection.Find(rq.conditions)
|
||||
|
||||
var c int
|
||||
c, err = q.Count()
|
||||
|
||||
return uint64(c), err
|
||||
}
|
||||
|
||||
func (res *result) Prev() immutable.Immutable {
|
||||
if res == nil {
|
||||
return nil
|
||||
}
|
||||
return res.prev
|
||||
}
|
||||
|
||||
func (res *result) Fn(in interface{}) error {
|
||||
if res.fn == nil {
|
||||
return nil
|
||||
}
|
||||
return res.fn(in.(*resultQuery))
|
||||
}
|
||||
|
||||
func (res *result) Base() interface{} {
|
||||
return &resultQuery{}
|
||||
}
|
||||
|
||||
func (r *resultQuery) debugQuery(action string) string {
|
||||
query := fmt.Sprintf("mydb.%s.%s", r.c.collection.Name, action)
|
||||
|
||||
if r.conditions != nil {
|
||||
query = fmt.Sprintf("%s.conds(%v)", query, r.conditions)
|
||||
}
|
||||
if r.limit > 0 {
|
||||
query = fmt.Sprintf("%s.limit(%d)", query, r.limit)
|
||||
}
|
||||
if r.offset > 0 {
|
||||
query = fmt.Sprintf("%s.offset(%d)", query, r.offset)
|
||||
}
|
||||
if len(r.fields) > 0 {
|
||||
selectedFields := bson.M{}
|
||||
for _, field := range r.fields {
|
||||
if field == `*` {
|
||||
break
|
||||
}
|
||||
selectedFields[field] = true
|
||||
}
|
||||
if len(selectedFields) > 0 {
|
||||
query = fmt.Sprintf("%s.select(%v)", query, selectedFields)
|
||||
}
|
||||
}
|
||||
if len(r.groupBy) > 0 {
|
||||
escaped := make([]string, len(r.groupBy))
|
||||
for i := range r.groupBy {
|
||||
escaped[i] = string(mustJSON(r.groupBy[i]))
|
||||
}
|
||||
query = fmt.Sprintf("%s.groupBy(%v)", query, strings.Join(escaped, ", "))
|
||||
}
|
||||
if len(r.sort) > 0 {
|
||||
escaped := make([]string, len(r.sort))
|
||||
for i := range r.sort {
|
||||
escaped[i] = string(mustJSON(r.sort[i]))
|
||||
}
|
||||
query = fmt.Sprintf("%s.sort(%s)", query, strings.Join(escaped, ", "))
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func mustJSON(in interface{}) (out []byte) {
|
||||
out, err := json.Marshal(in)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func queryLog(status *mydb.QueryStatus) {
|
||||
diff := status.End.Sub(status.Start)
|
||||
|
||||
slowQuery := false
|
||||
if diff >= time.Millisecond*100 {
|
||||
status.Err = mydb.ErrWarnSlowQuery
|
||||
slowQuery = true
|
||||
}
|
||||
|
||||
if status.Err != nil || slowQuery {
|
||||
mydb.LC().Warn(status)
|
||||
return
|
||||
}
|
||||
|
||||
mydb.LC().Debug(status)
|
||||
}
|
43
adapter/mysql/Makefile
Normal file
43
adapter/mysql/Makefile
Normal file
@ -0,0 +1,43 @@
|
||||
SHELL ?= bash
|
||||
|
||||
MYSQL_VERSION ?= 8.1
|
||||
MYSQL_SUPPORTED ?= $(MYSQL_VERSION) 5.7
|
||||
PROJECT ?= upper_mysql_$(MYSQL_VERSION)
|
||||
|
||||
DB_HOST ?= 127.0.0.1
|
||||
DB_PORT ?= 3306
|
||||
|
||||
DB_NAME ?= upperio
|
||||
DB_USERNAME ?= upperio_user
|
||||
DB_PASSWORD ?= upperio//s3cr37
|
||||
|
||||
TEST_FLAGS ?=
|
||||
PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1
|
||||
|
||||
export MYSQL_VERSION
|
||||
|
||||
export DB_HOST
|
||||
export DB_NAME
|
||||
export DB_PASSWORD
|
||||
export DB_PORT
|
||||
export DB_USERNAME
|
||||
|
||||
export TEST_FLAGS
|
||||
|
||||
test:
|
||||
go test -v -failfast -race -timeout 20m $(TEST_FLAGS)
|
||||
|
||||
test-no-race:
|
||||
go test -v -failfast $(TEST_FLAGS)
|
||||
|
||||
server-up: server-down
|
||||
docker-compose -p $(PROJECT) up -d && \
|
||||
sleep 15
|
||||
|
||||
server-down:
|
||||
docker-compose -p $(PROJECT) down
|
||||
|
||||
test-extended:
|
||||
parallel $(PARALLEL_FLAGS) \
|
||||
"MYSQL_VERSION={} DB_PORT=\$$((3306+{#})) $(MAKE) server-up test server-down" ::: \
|
||||
$(MYSQL_SUPPORTED)
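# Illustrative note: {#} is GNU parallel's job sequence number, so each MySQL
# version under test gets its own host port (3306+1, 3306+2, ...) and its own
# docker-compose project, letting the servers run side by side.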
|
5
adapter/mysql/README.md
Normal file
5
adapter/mysql/README.md
Normal file
@ -0,0 +1,5 @@
|
||||
# MySQL adapter for upper/db
|
||||
|
||||
Please read the full docs, acknowledgements and examples at
|
||||
[https://upper.io/v4/adapter/mysql/](https://upper.io/v4/adapter/mysql/).
|
||||
|
56
adapter/mysql/collection.go
Normal file
56
adapter/mysql/collection.go
Normal file
@ -0,0 +1,56 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
|
||||
)
|
||||
|
||||
type collectionAdapter struct {
|
||||
}
|
||||
|
||||
func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) {
|
||||
columnNames, columnValues, err := sqlbuilder.Map(item, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pKey, err := col.PrimaryKeys()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := col.SQL().InsertInto(col.Name()).
|
||||
Columns(columnNames...).
|
||||
Values(columnValues...)
|
||||
|
||||
res, err := q.Exec()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lastID, err := res.LastInsertId()
|
||||
if err == nil && len(pKey) <= 1 {
|
||||
return lastID, nil
|
||||
}
|
||||
|
||||
keyMap := mydb.Cond{}
|
||||
for i := range columnNames {
|
||||
for j := 0; j < len(pKey); j++ {
|
||||
if pKey[j] == columnNames[i] {
|
||||
keyMap[pKey[j]] = columnValues[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If an auto-increment column is part of the primary key, fill in its value
// with the last inserted ID.
|
||||
if lastID > 0 {
|
||||
for j := 0; j < len(pKey); j++ {
|
||||
if keyMap[pKey[j]] == nil {
|
||||
keyMap[pKey[j]] = lastID
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return keyMap, nil
|
||||
}
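// Illustrative note (not part of the adapter): for a table with a composite
// primary key, such as composite_keys(code, user_id) in the test schema, the
// LastInsertId shortcut above does not apply and Insert returns a mydb.Cond
// mapping every primary-key column to its inserted value, for example
// mydb.Cond{"code": "abc", "user_id": "123"}; a single auto-increment key
// returns the new id directly.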
|
244
adapter/mysql/connection.go
Normal file
244
adapter/mysql/connection.go
Normal file
@ -0,0 +1,244 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// From https://github.com/go-sql-driver/mysql/blob/master/utils.go
|
||||
var (
|
||||
errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
|
||||
errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
|
||||
errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name")
|
||||
)
|
||||
|
||||
// From https://github.com/go-sql-driver/mysql/blob/master/utils.go
|
||||
type config struct {
|
||||
user string
|
||||
passwd string
|
||||
net string
|
||||
addr string
|
||||
dbname string
|
||||
params map[string]string
|
||||
}
|
||||
|
||||
// ConnectionURL implements a MySQL connection struct.
|
||||
type ConnectionURL struct {
|
||||
User string
|
||||
Password string
|
||||
Database string
|
||||
Host string
|
||||
Socket string
|
||||
Options map[string]string
|
||||
}
|
||||
|
||||
func (c ConnectionURL) String() (s string) {
|
||||
|
||||
if c.Database == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Adding username.
|
||||
if c.User != "" {
|
||||
s = s + c.User
|
||||
// Adding password.
|
||||
if c.Password != "" {
|
||||
s = s + ":" + c.Password
|
||||
}
|
||||
s = s + "@"
|
||||
}
|
||||
|
||||
// Adding protocol and address
|
||||
if c.Socket != "" {
|
||||
s = s + fmt.Sprintf("unix(%s)", c.Socket)
|
||||
} else if c.Host != "" {
|
||||
host, port, err := net.SplitHostPort(c.Host)
|
||||
if err != nil {
|
||||
host = c.Host
|
||||
port = "3306"
|
||||
}
|
||||
s = s + fmt.Sprintf("tcp(%s:%s)", host, port)
|
||||
}
|
||||
|
||||
// Adding database
|
||||
s = s + "/" + c.Database
|
||||
|
||||
// Do we have any options?
|
||||
if c.Options == nil {
|
||||
c.Options = map[string]string{}
|
||||
}
|
||||
|
||||
// Default options.
|
||||
if _, ok := c.Options["charset"]; !ok {
|
||||
c.Options["charset"] = "utf8"
|
||||
}
|
||||
|
||||
if _, ok := c.Options["parseTime"]; !ok {
|
||||
c.Options["parseTime"] = "true"
|
||||
}
|
||||
|
||||
// Converting options into URL values.
|
||||
vv := url.Values{}
|
||||
|
||||
for k, v := range c.Options {
|
||||
vv.Set(k, v)
|
||||
}
|
||||
|
||||
// Inserting options.
|
||||
if p := vv.Encode(); p != "" {
|
||||
s = s + "?" + p
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// ParseURL parses s into a ConnectionURL struct.
|
||||
func ParseURL(s string) (conn ConnectionURL, err error) {
|
||||
var cfg *config
|
||||
|
||||
if cfg, err = parseDSN(s); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
conn.User = cfg.user
|
||||
conn.Password = cfg.passwd
|
||||
|
||||
if cfg.net == "unix" {
|
||||
conn.Socket = cfg.addr
|
||||
} else if cfg.net == "tcp" {
|
||||
conn.Host = cfg.addr
|
||||
}
|
||||
|
||||
conn.Database = cfg.dbname
|
||||
|
||||
conn.Options = map[string]string{}
|
||||
|
||||
for k, v := range cfg.params {
|
||||
conn.Options[k] = v
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// from https://github.com/go-sql-driver/mysql/blob/master/utils.go
|
||||
// parseDSN parses the DSN string to a config
|
||||
func parseDSN(dsn string) (cfg *config, err error) {
|
||||
// New config with some default values
|
||||
cfg = &config{}
|
||||
|
||||
// TODO: use strings.IndexByte when we can depend on Go 1.2
|
||||
|
||||
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
|
||||
// Find the last '/' (since the password or the net addr might contain a '/')
|
||||
foundSlash := false
|
||||
for i := len(dsn) - 1; i >= 0; i-- {
|
||||
if dsn[i] == '/' {
|
||||
foundSlash = true
|
||||
var j, k int
|
||||
|
||||
// left part is empty if i <= 0
|
||||
if i > 0 {
|
||||
// [username[:password]@][protocol[(address)]]
|
||||
// Find the last '@' in dsn[:i]
|
||||
for j = i; j >= 0; j-- {
|
||||
if dsn[j] == '@' {
|
||||
// username[:password]
|
||||
// Find the first ':' in dsn[:j]
|
||||
for k = 0; k < j; k++ {
|
||||
if dsn[k] == ':' {
|
||||
cfg.passwd = dsn[k+1 : j]
|
||||
break
|
||||
}
|
||||
}
|
||||
cfg.user = dsn[:k]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// [protocol[(address)]]
|
||||
// Find the first '(' in dsn[j+1:i]
|
||||
for k = j + 1; k < i; k++ {
|
||||
if dsn[k] == '(' {
|
||||
// dsn[i-1] must be == ')' if an address is specified
|
||||
if dsn[i-1] != ')' {
|
||||
if strings.ContainsRune(dsn[k+1:i], ')') {
|
||||
return nil, errInvalidDSNUnescaped
|
||||
}
|
||||
return nil, errInvalidDSNAddr
|
||||
}
|
||||
cfg.addr = dsn[k+1 : i-1]
|
||||
break
|
||||
}
|
||||
}
|
||||
cfg.net = dsn[j+1 : k]
|
||||
}
|
||||
|
||||
// dbname[?param1=value1&...&paramN=valueN]
|
||||
// Find the first '?' in dsn[i+1:]
|
||||
for j = i + 1; j < len(dsn); j++ {
|
||||
if dsn[j] == '?' {
|
||||
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
|
||||
return
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
cfg.dbname = dsn[i+1 : j]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if !foundSlash && len(dsn) > 0 {
|
||||
return nil, errInvalidDSNNoSlash
|
||||
}
|
||||
|
||||
// Set default network if empty
|
||||
if cfg.net == "" {
|
||||
cfg.net = "tcp"
|
||||
}
|
||||
|
||||
// Set default address if empty
|
||||
if cfg.addr == "" {
|
||||
switch cfg.net {
|
||||
case "tcp":
|
||||
cfg.addr = "127.0.0.1:3306"
|
||||
case "unix":
|
||||
cfg.addr = "/tmp/mysql.sock"
|
||||
default:
|
||||
return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// From https://github.com/go-sql-driver/mysql/blob/master/utils.go
|
||||
// parseDSNParams parses the DSN "query string"
|
||||
// Values must be url.QueryEscape'ed
|
||||
func parseDSNParams(cfg *config, params string) (err error) {
|
||||
for _, v := range strings.Split(params, "&") {
|
||||
param := strings.SplitN(v, "=", 2)
|
||||
if len(param) != 2 {
|
||||
continue
|
||||
}
|
||||
|
||||
value := param[1]
|
||||
|
||||
// lazy init
|
||||
if cfg.params == nil {
|
||||
cfg.params = make(map[string]string)
|
||||
}
|
||||
|
||||
if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
return
|
||||
}
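// Illustrative sketch (not part of the adapter): parseDSN splits a DSN of the
// form user:pass@tcp(host:port)/dbname?param=value into its parts. The values
// below are hypothetical.
//
//	cfg, err := parseDSN("user:pass@tcp(1.2.3.4:3306)/mydbname?charset=utf8")
//	// err == nil, cfg.user == "user", cfg.passwd == "pass",
//	// cfg.net == "tcp", cfg.addr == "1.2.3.4:3306",
//	// cfg.dbname == "mydbname", cfg.params["charset"] == "utf8"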
|
117
adapter/mysql/connection_test.go
Normal file
117
adapter/mysql/connection_test.go
Normal file
@ -0,0 +1,117 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConnectionURL(t *testing.T) {
|
||||
|
||||
c := ConnectionURL{}
|
||||
|
||||
// The zero value produces an empty string.
|
||||
if c.String() != "" {
|
||||
t.Fatal(`Expecting default connection string to be empty, got:`, c.String())
|
||||
}
|
||||
|
||||
// Adding a database name.
|
||||
c.Database = "mydbname"
|
||||
|
||||
if c.String() != "/mydbname?charset=utf8&parseTime=true" {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
// Adding an option.
|
||||
c.Options = map[string]string{
|
||||
"charset": "utf8mb4,utf8",
|
||||
"sys_var": "esc@ped",
|
||||
}
|
||||
|
||||
if c.String() != "/mydbname?charset=utf8mb4%2Cutf8&parseTime=true&sys_var=esc%40ped" {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
// Setting default options
|
||||
c.Options = nil
|
||||
|
||||
// Setting user and password.
|
||||
c.User = "user"
|
||||
c.Password = "pass"
|
||||
|
||||
if c.String() != `user:pass@/mydbname?charset=utf8&parseTime=true` {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
// Setting host.
|
||||
c.Host = "1.2.3.4:3306"
|
||||
|
||||
if c.String() != `user:pass@tcp(1.2.3.4:3306)/mydbname?charset=utf8&parseTime=true` {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
// Setting socket.
|
||||
c.Socket = "/path/to/socket"
|
||||
|
||||
if c.String() != `user:pass@unix(/path/to/socket)/mydbname?charset=utf8&parseTime=true` {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestParseConnectionURL(t *testing.T) {
|
||||
var u ConnectionURL
|
||||
var s string
|
||||
var err error
|
||||
|
||||
s = "user:pass@unix(/path/to/socket)/mydbname?charset=utf8"
|
||||
|
||||
if u, err = ParseURL(s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if u.User != "user" {
|
||||
t.Fatal("Expecting username.")
|
||||
}
|
||||
|
||||
if u.Password != "pass" {
|
||||
t.Fatal("Expecting password.")
|
||||
}
|
||||
|
||||
if u.Socket != "/path/to/socket" {
|
||||
t.Fatal("Expecting socket.")
|
||||
}
|
||||
|
||||
if u.Database != "mydbname" {
|
||||
t.Fatal("Expecting database.")
|
||||
}
|
||||
|
||||
if u.Options["charset"] != "utf8" {
|
||||
t.Fatal("Expecting charset.")
|
||||
}
|
||||
|
||||
s = "user:pass@tcp(1.2.3.4:5678)/mydbname?charset=utf8"
|
||||
|
||||
if u, err = ParseURL(s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if u.User != "user" {
|
||||
t.Fatal("Expecting username.")
|
||||
}
|
||||
|
||||
if u.Password != "pass" {
|
||||
t.Fatal("Expecting password.")
|
||||
}
|
||||
|
||||
if u.Host != "1.2.3.4:5678" {
|
||||
t.Fatal("Expecting host.")
|
||||
}
|
||||
|
||||
if u.Database != "mydbname" {
|
||||
t.Fatal("Expecting database.")
|
||||
}
|
||||
|
||||
if u.Options["charset"] != "utf8" {
|
||||
t.Fatal("Expecting charset.")
|
||||
}
|
||||
|
||||
}
|
151
adapter/mysql/custom_types.go
Normal file
151
adapter/mysql/custom_types.go
Normal file
@ -0,0 +1,151 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
|
||||
)
|
||||
|
||||
// JSON represents a MySQL JSON value:
|
||||
// https://dev.mysql.com/doc/refman/5.7/en/json.html. JSON
|
||||
// satisfies sqlbuilder.ScannerValuer.
|
||||
type JSON struct {
|
||||
V interface{}
|
||||
}
|
||||
|
||||
// MarshalJSON encodes the wrapper value as JSON.
|
||||
func (j JSON) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(j.V)
|
||||
}
|
||||
|
||||
// UnmarshalJSON decodes the given JSON into the wrapped value.
|
||||
func (j *JSON) UnmarshalJSON(b []byte) error {
|
||||
var v interface{}
|
||||
if err := json.Unmarshal(b, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
j.V = v
|
||||
return nil
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (j *JSON) Scan(src interface{}) error {
|
||||
if j.V == nil {
|
||||
return nil
|
||||
}
|
||||
if src == nil {
|
||||
dv := reflect.Indirect(reflect.ValueOf(j.V))
|
||||
dv.Set(reflect.Zero(dv.Type()))
|
||||
return nil
|
||||
}
|
||||
b, ok := src.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Scan source was not []bytes")
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(b, j.V); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (j JSON) Value() (driver.Value, error) {
|
||||
if j.V == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if v, ok := j.V.(json.RawMessage); ok {
|
||||
return string(v), nil
|
||||
}
|
||||
b, err := json.Marshal(j.V)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// JSONMap represents a map of interfaces with string keys
|
||||
// (`map[string]interface{}`) that is compatible with MySQL's JSON type.
|
||||
// JSONMap satisfies sqlbuilder.ScannerValuer.
|
||||
type JSONMap map[string]interface{}
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (m JSONMap) Value() (driver.Value, error) {
|
||||
return JSONValue(m)
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (m *JSONMap) Scan(src interface{}) error {
|
||||
*m = map[string]interface{}(nil)
|
||||
return ScanJSON(m, src)
|
||||
}
|
||||
|
||||
// JSONArray represents an array of any type (`[]interface{}`) that is
|
||||
// compatible with MySQL's JSON type. JSONArray satisfies
|
||||
// sqlbuilder.ScannerValuer.
|
||||
type JSONArray []interface{}
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (a JSONArray) Value() (driver.Value, error) {
|
||||
return JSONValue(a)
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (a *JSONArray) Scan(src interface{}) error {
|
||||
return ScanJSON(a, src)
|
||||
}
|
||||
|
||||
// JSONValue takes an interface and provides a driver.Value that can be
|
||||
// stored as a JSON column.
|
||||
func JSONValue(i interface{}) (driver.Value, error) {
|
||||
v := JSON{i}
|
||||
return v.Value()
|
||||
}
|
||||
|
||||
// ScanJSON decodes a JSON byte stream into the passed dst value.
|
||||
func ScanJSON(dst interface{}, src interface{}) error {
|
||||
v := JSON{dst}
|
||||
return v.Scan(src)
|
||||
}
|
||||
|
||||
// EncodeJSON is deprecated and going to be removed. Use JSONValue instead.
|
||||
func EncodeJSON(i interface{}) (driver.Value, error) {
|
||||
return JSONValue(i)
|
||||
}
|
||||
|
||||
// DecodeJSON is deprecated and going to be removed. Use ScanJSON instead.
|
||||
func DecodeJSON(dst interface{}, src interface{}) error {
|
||||
return ScanJSON(dst, src)
|
||||
}
|
||||
|
||||
// JSONConverter provides a helper method, ConvertValue, that satisfies
|
||||
// sqlbuilder.ValueWrapper and can be used to encode Go structs into
|
||||
// MySQL JSON types and vice versa.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// type MyCustomStruct struct {
|
||||
// ID int64 `db:"id" json:"id"`
|
||||
// Name string `db:"name" json:"name"`
|
||||
// ...
|
||||
// mysql.JSONConverter
|
||||
// }
|
||||
type JSONConverter struct{}
|
||||
|
||||
func (*JSONConverter) ConvertValue(in interface{}) interface {
|
||||
sql.Scanner
|
||||
driver.Valuer
|
||||
} {
|
||||
return &JSON{in}
|
||||
}
|
||||
|
||||
// Type checks.
|
||||
var (
|
||||
_ sqlbuilder.ScannerValuer = &JSONMap{}
|
||||
_ sqlbuilder.ScannerValuer = &JSONArray{}
|
||||
_ sqlbuilder.ScannerValuer = &JSON{}
|
||||
)
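// Usage sketch (illustrative only): JSONMap and JSONArray can be used directly
// as struct fields backed by JSON columns; the my_types table in the test
// suite follows this pattern. The Profile struct and its "settings" column are
// hypothetical.
//
//	type Profile struct {
//		ID       int64   `db:"id,omitempty"`
//		Settings JSONMap `db:"settings"`
//	}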
|
168
adapter/mysql/database.go
Normal file
168
adapter/mysql/database.go
Normal file
@ -0,0 +1,168 @@
|
||||
// Package mysql wraps the github.com/go-sql-driver/mysql MySQL driver. See
|
||||
// https://github.com/upper/db/adapter/mysql for documentation, particularities and usage
|
||||
// examples.
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"database/sql"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
|
||||
_ "github.com/go-sql-driver/mysql" // MySQL driver.
|
||||
)
|
||||
|
||||
// database is the actual implementation of Database
|
||||
type database struct {
|
||||
}
|
||||
|
||||
func (*database) Template() *exql.Template {
|
||||
return template
|
||||
}
|
||||
|
||||
func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) {
|
||||
return sql.Open("mysql", dsn)
|
||||
}
|
||||
|
||||
func (*database) Collections(sess sqladapter.Session) (collections []string, err error) {
|
||||
q := sess.SQL().
|
||||
Select("table_name").
|
||||
From("information_schema.tables").
|
||||
Where("table_schema = ?", sess.Name())
|
||||
|
||||
iter := q.Iterator()
|
||||
defer iter.Close()
|
||||
|
||||
for iter.Next() {
|
||||
var tableName string
|
||||
if err := iter.Scan(&tableName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
collections = append(collections, tableName)
|
||||
}
|
||||
if err := iter.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return collections, nil
|
||||
}
|
||||
|
||||
func (d *database) ConvertValue(in interface{}) interface{} {
|
||||
switch v := in.(type) {
|
||||
case *map[string]interface{}:
|
||||
return (*JSONMap)(v)
|
||||
|
||||
case map[string]interface{}:
|
||||
return (*JSONMap)(&v)
|
||||
}
|
||||
|
||||
dv := reflect.ValueOf(in)
|
||||
if dv.IsValid() {
|
||||
if dv.Type().Kind() == reflect.Ptr {
|
||||
dv = dv.Elem()
|
||||
}
|
||||
|
||||
switch dv.Kind() {
|
||||
case reflect.Map:
|
||||
if reflect.TypeOf(in).Kind() == reflect.Ptr {
|
||||
w := reflect.ValueOf(in)
|
||||
z := reflect.New(w.Elem().Type())
|
||||
w.Elem().Set(z.Elem())
|
||||
}
|
||||
return &JSON{in}
|
||||
case reflect.Slice:
|
||||
return &JSON{in}
|
||||
}
|
||||
}
|
||||
|
||||
return in
|
||||
}
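// Illustrative note (not part of the adapter): ConvertValue decides which
// values get transparently wrapped for MySQL JSON columns. A minimal sketch of
// the effective behavior, with hypothetical inputs:
//
//	d := &database{}
//	_ = d.ConvertValue(map[string]interface{}{"a": 1}) // wrapped as *JSONMap
//	_ = d.ConvertValue([]int{1, 2, 3})                 // slices wrapped as *JSON
//	_ = d.ConvertValue(42)                             // scalars pass through unchanged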
|
||||
|
||||
func (*database) Err(err error) error {
|
||||
if err != nil {
|
||||
// This error is not exported so we have to check it by its string value.
|
||||
s := err.Error()
|
||||
if strings.Contains(s, `many connections`) {
|
||||
return mydb.ErrTooManyClients
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (*database) NewCollection() sqladapter.CollectionAdapter {
|
||||
return &collectionAdapter{}
|
||||
}
|
||||
|
||||
func (*database) LookupName(sess sqladapter.Session) (string, error) {
|
||||
q := sess.SQL().
|
||||
Select(mydb.Raw("DATABASE() AS name"))
|
||||
|
||||
iter := q.Iterator()
|
||||
defer iter.Close()
|
||||
|
||||
if iter.Next() {
|
||||
var name string
|
||||
if err := iter.Scan(&name); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
return "", iter.Err()
|
||||
}
|
||||
|
||||
func (*database) TableExists(sess sqladapter.Session, name string) error {
|
||||
q := sess.SQL().
|
||||
Select("table_name").
|
||||
From("information_schema.tables").
|
||||
Where("table_schema = ? AND table_name = ?", sess.Name(), name)
|
||||
|
||||
iter := q.Iterator()
|
||||
defer iter.Close()
|
||||
|
||||
if iter.Next() {
|
||||
var name string
|
||||
if err := iter.Scan(&name); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if err := iter.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return mydb.ErrCollectionDoesNotExist
|
||||
}
|
||||
|
||||
func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) {
|
||||
q := sess.SQL().
|
||||
Select("k.column_name").
|
||||
From("information_schema.key_column_usage AS k").
|
||||
Where(`
|
||||
k.constraint_name = 'PRIMARY'
|
||||
AND k.table_schema = ?
|
||||
AND k.table_name = ?
|
||||
`, sess.Name(), tableName).
|
||||
OrderBy("k.ordinal_position")
|
||||
|
||||
iter := q.Iterator()
|
||||
defer iter.Close()
|
||||
|
||||
pk := []string{}
|
||||
|
||||
for iter.Next() {
|
||||
var k string
|
||||
if err := iter.Scan(&k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pk = append(pk, k)
|
||||
}
|
||||
if err := iter.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
14
adapter/mysql/docker-compose.yml
Normal file
14
adapter/mysql/docker-compose.yml
Normal file
@ -0,0 +1,14 @@
|
||||
version: '3'
|
||||
|
||||
services:
|
||||
|
||||
server:
|
||||
image: mysql:${MYSQL_VERSION:-5}
|
||||
environment:
|
||||
MYSQL_USER: ${DB_USERNAME:-upperio_user}
|
||||
MYSQL_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37}
|
||||
MYSQL_ALLOW_EMPTY_PASSWORD: 1
|
||||
MYSQL_DATABASE: ${DB_NAME:-upperio}
|
||||
ports:
|
||||
- '${DB_HOST:-127.0.0.1}:${DB_PORT:-3306}:3306'
|
||||
|
20
adapter/mysql/generic_test.go
Normal file
20
adapter/mysql/generic_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GenericTests struct {
|
||||
testsuite.GenericTestSuite
|
||||
}
|
||||
|
||||
func (s *GenericTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func TestGeneric(t *testing.T) {
|
||||
suite.Run(t, &GenericTests{})
|
||||
}
|
276
adapter/mysql/helper_test.go
Normal file
276
adapter/mysql/helper_test.go
Normal file
@ -0,0 +1,276 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
)
|
||||
|
||||
var settings = ConnectionURL{
|
||||
Database: os.Getenv("DB_NAME"),
|
||||
User: os.Getenv("DB_USERNAME"),
|
||||
Password: os.Getenv("DB_PASSWORD"),
|
||||
Host: os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"),
|
||||
Options: map[string]string{
|
||||
// See https://github.com/go-sql-driver/mysql/issues/9
|
||||
"parseTime": "true",
|
||||
// Might require you to use mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql -u root -p mysql
|
||||
"time_zone": fmt.Sprintf(`'%s'`, testsuite.TimeZone),
|
||||
"loc": testsuite.TimeZone,
|
||||
},
|
||||
}
|
||||
|
||||
type Helper struct {
|
||||
sess mydb.Session
|
||||
}
|
||||
|
||||
func cleanUp(sess mydb.Session) error {
|
||||
if activeStatements := sqladapter.NumActiveStatements(); activeStatements > 128 {
|
||||
return fmt.Errorf("Expecting active statements to be at most 128, got %d", activeStatements)
|
||||
}
|
||||
|
||||
sess.Reset()
|
||||
|
||||
if activeStatements := sqladapter.NumActiveStatements(); activeStatements != 0 {
|
||||
return fmt.Errorf("Expecting active statements to be 0, got %d", activeStatements)
|
||||
}
|
||||
|
||||
var err error
|
||||
var stats map[string]int
|
||||
for i := 0; i < 10; i++ {
|
||||
stats, err = getStats(sess)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if stats["Prepared_stmt_count"] != 0 {
|
||||
time.Sleep(time.Millisecond * 200) // Sometimes it takes a bit to clean prepared statements
|
||||
err = fmt.Errorf(`Expecting "Prepared_stmt_count" to be 0, got %d`, stats["Prepared_stmt_count"])
|
||||
continue
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func getStats(sess mydb.Session) (map[string]int, error) {
|
||||
stats := make(map[string]int)
|
||||
|
||||
res, err := sess.Driver().(*sql.DB).Query(`SHOW GLOBAL STATUS LIKE '%stmt%'`)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var result struct {
|
||||
VariableName string `db:"Variable_name"`
|
||||
Value int `db:"Value"`
|
||||
}
|
||||
|
||||
iter := sess.SQL().NewIterator(res)
|
||||
for iter.Next(&result) {
|
||||
stats[result.VariableName] = result.Value
|
||||
}
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (h *Helper) Session() mydb.Session {
|
||||
return h.sess
|
||||
}
|
||||
|
||||
func (h *Helper) Adapter() string {
|
||||
return "mysql"
|
||||
}
|
||||
|
||||
func (h *Helper) TearDown() error {
|
||||
if err := cleanUp(h.sess); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.sess.Close()
|
||||
}
|
||||
|
||||
func (h *Helper) TearUp() error {
|
||||
var err error
|
||||
|
||||
h.sess, err = Open(settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
batch := []string{
|
||||
`DROP TABLE IF EXISTS artist`,
|
||||
`CREATE TABLE artist (
|
||||
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
|
||||
PRIMARY KEY(id),
|
||||
name VARCHAR(60)
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS publication`,
|
||||
`CREATE TABLE publication (
|
||||
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
|
||||
PRIMARY KEY(id),
|
||||
title VARCHAR(80),
|
||||
author_id BIGINT(20)
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS review`,
|
||||
`CREATE TABLE review (
|
||||
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
|
||||
PRIMARY KEY(id),
|
||||
publication_id BIGINT(20),
|
||||
name VARCHAR(80),
|
||||
comments TEXT,
|
||||
created DATETIME NOT NULL
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS data_types`,
|
||||
`CREATE TABLE data_types (
|
||||
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
|
||||
PRIMARY KEY(id),
|
||||
_uint INT(10) UNSIGNED DEFAULT 0,
|
||||
_uint8 INT(10) UNSIGNED DEFAULT 0,
|
||||
_uint16 INT(10) UNSIGNED DEFAULT 0,
|
||||
_uint32 INT(10) UNSIGNED DEFAULT 0,
|
||||
_uint64 INT(10) UNSIGNED DEFAULT 0,
|
||||
_int INT(10) DEFAULT 0,
|
||||
_int8 INT(10) DEFAULT 0,
|
||||
_int16 INT(10) DEFAULT 0,
|
||||
_int32 INT(10) DEFAULT 0,
|
||||
_int64 INT(10) DEFAULT 0,
|
||||
_float32 DECIMAL(10,6),
|
||||
_float64 DECIMAL(10,6),
|
||||
_bool TINYINT(1),
|
||||
_string text,
|
||||
_blob blob,
|
||||
_date TIMESTAMP NULL,
|
||||
_nildate DATETIME NULL,
|
||||
_ptrdate DATETIME NULL,
|
||||
_defaultdate TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
|
||||
_time BIGINT UNSIGNED NOT NULL DEFAULT 0
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS stats_test`,
|
||||
`CREATE TABLE stats_test (
|
||||
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id),
|
||||
` + "`numeric`" + ` INT(10),
|
||||
` + "`value`" + ` INT(10)
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS composite_keys`,
|
||||
`CREATE TABLE composite_keys (
|
||||
code VARCHAR(255) default '',
|
||||
user_id VARCHAR(255) default '',
|
||||
some_val VARCHAR(255) default '',
|
||||
primary key (code, user_id)
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS admin`,
|
||||
`CREATE TABLE admin (
|
||||
ID int(11) NOT NULL AUTO_INCREMENT,
|
||||
Accounts varchar(255) DEFAULT '',
|
||||
LoginPassWord varchar(255) DEFAULT '',
|
||||
Date TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
|
||||
PRIMARY KEY (ID,Date)
|
||||
) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8`,
|
||||
|
||||
`DROP TABLE IF EXISTS my_types`,
|
||||
|
||||
`CREATE TABLE my_types (id int(11) NOT NULL AUTO_INCREMENT, PRIMARY KEY(id)
|
||||
, json_map JSON
|
||||
, json_map_ptr JSON
|
||||
|
||||
, auto_json_map JSON
|
||||
, auto_json_map_string JSON
|
||||
, auto_json_map_integer JSON
|
||||
|
||||
, json_object JSON
|
||||
, json_array JSON
|
||||
|
||||
, custom_json_object JSON
|
||||
, auto_custom_json_object JSON
|
||||
|
||||
, custom_json_object_ptr JSON
|
||||
, auto_custom_json_object_ptr JSON
|
||||
|
||||
, custom_json_object_array JSON
|
||||
, auto_custom_json_object_array JSON
|
||||
, auto_custom_json_object_map JSON
|
||||
|
||||
, integer_compat_value_json_array JSON
|
||||
, string_compat_value_json_array JSON
|
||||
, uinteger_compat_value_json_array JSON
|
||||
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS ` + "`" + `birthdays` + "`" + ``,
|
||||
`CREATE TABLE ` + "`" + `birthdays` + "`" + ` (
|
||||
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id),
|
||||
name VARCHAR(50),
|
||||
born DATE,
|
||||
born_ut BIGINT(20) SIGNED
|
||||
) CHARSET=utf8`,
|
||||
|
||||
`DROP TABLE IF EXISTS ` + "`" + `fibonacci` + "`" + ``,
|
||||
`CREATE TABLE ` + "`" + `fibonacci` + "`" + ` (
|
||||
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id),
|
||||
input BIGINT(20) UNSIGNED NOT NULL,
|
||||
output BIGINT(20) UNSIGNED NOT NULL
|
||||
) CHARSET=utf8`,
|
||||
|
||||
`DROP TABLE IF EXISTS ` + "`" + `is_even` + "`" + ``,
|
||||
`CREATE TABLE ` + "`" + `is_even` + "`" + ` (
|
||||
input BIGINT(20) UNSIGNED NOT NULL,
|
||||
is_even TINYINT(1)
|
||||
) CHARSET=utf8`,
|
||||
|
||||
`DROP TABLE IF EXISTS ` + "`" + `CaSe_TesT` + "`" + ``,
|
||||
`CREATE TABLE ` + "`" + `CaSe_TesT` + "`" + ` (
|
||||
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id),
|
||||
case_test VARCHAR(60)
|
||||
) CHARSET=utf8`,
|
||||
|
||||
`DROP TABLE IF EXISTS accounts`,
|
||||
|
||||
`CREATE TABLE accounts (
|
||||
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
|
||||
PRIMARY KEY(id),
|
||||
name varchar(255),
|
||||
disabled BOOL,
|
||||
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS users`,
|
||||
|
||||
`CREATE TABLE users (
|
||||
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
|
||||
PRIMARY KEY(id),
|
||||
account_id BIGINT(20),
|
||||
username varchar(255) UNIQUE
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS logs`,
|
||||
|
||||
`CREATE TABLE logs (
|
||||
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
|
||||
PRIMARY KEY(id),
|
||||
message VARCHAR(255)
|
||||
)`,
|
||||
}
|
||||
|
||||
for _, query := range batch {
|
||||
driver := h.sess.Driver().(*sql.DB)
|
||||
if _, err := driver.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ testsuite.Helper = &Helper{}
|
30
adapter/mysql/mysql.go
Normal file
30
adapter/mysql/mysql.go
Normal file
@ -0,0 +1,30 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
|
||||
)
|
||||
|
||||
// Adapter is the public name of the adapter.
|
||||
const Adapter = `mysql`
|
||||
|
||||
var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{})
|
||||
|
||||
// Open establishes a connection to the database server and returns a
|
||||
// mydb.Session instance.
|
||||
func Open(connURL mydb.ConnectionURL) (mydb.Session, error) {
|
||||
return registeredAdapter.OpenDSN(connURL)
|
||||
}
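// Usage sketch (illustrative only): opening a session from a ConnectionURL.
// Host, credentials and database name below are hypothetical.
//
//	sess, err := Open(ConnectionURL{
//		User:     "upperio_user",
//		Password: "secret",
//		Host:     "127.0.0.1:3306",
//		Database: "upperio",
//	})
//	if err != nil {
//		// handle the connection error
//	}
//	defer sess.Close()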
|
||||
|
||||
// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value.
|
||||
func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) {
|
||||
return registeredAdapter.NewTx(sqlTx)
|
||||
}
|
||||
|
||||
// New creates a mydb.Session instance by wrapping a *sql.DB value.
|
||||
func New(sqlDB *sql.DB) (mydb.Session, error) {
|
||||
return registeredAdapter.New(sqlDB)
|
||||
}
|
379
adapter/mysql/mysql_test.go
Normal file
379
adapter/mysql/mysql_test.go
Normal file
@ -0,0 +1,379 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type int64Compat int64
|
||||
|
||||
type uintCompat uint
|
||||
|
||||
type stringCompat string
|
||||
|
||||
type uintCompatArray []uintCompat
|
||||
|
||||
func (u *int64Compat) Scan(src interface{}) error {
|
||||
if src != nil {
|
||||
switch v := src.(type) {
|
||||
case int64:
|
||||
*u = int64Compat(v)
|
||||
case []byte:
|
||||
i, err := strconv.ParseInt(string(v), 10, 64)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*u = int64Compat(i)
|
||||
default:
|
||||
panic(fmt.Sprintf("expected type %T", src))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type customJSON struct {
|
||||
N string `json:"name"`
|
||||
V float64 `json:"value"`
|
||||
}
|
||||
|
||||
func (c customJSON) Value() (driver.Value, error) {
|
||||
return JSONValue(c)
|
||||
}
|
||||
|
||||
func (c *customJSON) Scan(src interface{}) error {
|
||||
return ScanJSON(c, src)
|
||||
}
|
||||
|
||||
type autoCustomJSON struct {
|
||||
N string `json:"name"`
|
||||
V float64 `json:"value"`
|
||||
|
||||
*JSONConverter
|
||||
}
|
||||
|
||||
var (
|
||||
_ = driver.Valuer(&customJSON{})
|
||||
_ = sql.Scanner(&customJSON{})
|
||||
)
|
||||
|
||||
type AdapterTests struct {
|
||||
testsuite.Suite
|
||||
}
|
||||
|
||||
func (s *AdapterTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestInsertReturningCompositeKey_Issue383() {
|
||||
sess := s.Session()
|
||||
|
||||
type Admin struct {
|
||||
ID int `db:"ID,omitempty"`
|
||||
Accounts string `db:"Accounts"`
|
||||
LoginPassWord string `db:"LoginPassWord"`
|
||||
Date time.Time `db:"Date"`
|
||||
}
|
||||
|
||||
dateNow := time.Now()
|
||||
|
||||
a := Admin{
|
||||
Accounts: "admin",
|
||||
LoginPassWord: "E10ADC3949BA59ABBE56E057F20F883E",
|
||||
Date: dateNow,
|
||||
}
|
||||
|
||||
adminCollection := sess.Collection("admin")
|
||||
err := adminCollection.InsertReturning(&a)
|
||||
s.NoError(err)
|
||||
|
||||
s.NotZero(a.ID)
|
||||
s.NotZero(a.Date)
|
||||
s.Equal("admin", a.Accounts)
|
||||
s.Equal("E10ADC3949BA59ABBE56E057F20F883E", a.LoginPassWord)
|
||||
|
||||
b := Admin{
|
||||
Accounts: "admin2",
|
||||
LoginPassWord: "E10ADC3949BA59ABBE56E057F20F883E",
|
||||
Date: dateNow,
|
||||
}
|
||||
|
||||
err = adminCollection.InsertReturning(&b)
|
||||
s.NoError(err)
|
||||
|
||||
s.NotZero(b.ID)
|
||||
s.NotZero(b.Date)
|
||||
s.Equal("admin2", b.Accounts)
|
||||
s.Equal("E10ADC3949BA59ABBE56E057F20F883E", a.LoginPassWord)
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestIssue469_BadConnection() {
|
||||
var err error
|
||||
sess := s.Session()
|
||||
|
||||
// Ask the MySQL server to disconnect sessions that remain inactive for more
|
||||
// than 1 second.
|
||||
_, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`)
|
||||
s.NoError(err)
|
||||
|
||||
// Remain inactive for 2 seconds.
|
||||
time.Sleep(time.Second * 2)
|
||||
|
||||
// A query should start a new connection, even if the server disconnected us.
|
||||
_, err = sess.Collection("artist").Find().Count()
|
||||
s.NoError(err)
|
||||
|
||||
// This is a new session, ask the MySQL server to disconnect sessions that
|
||||
// remain inactive for more than 1 second.
|
||||
_, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`)
|
||||
s.NoError(err)
|
||||
|
||||
// Remain inactive for 2 seconds.
|
||||
time.Sleep(time.Second * 2)
|
||||
|
||||
// At this point the server should have disconnected us. Let's try to create
|
||||
// a transaction anyway.
|
||||
err = sess.Tx(func(sess mydb.Session) error {
|
||||
var err error
|
||||
|
||||
_, err = sess.Collection("artist").Find().Count()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
// This is a new session, ask the MySQL server to disconnect sessions that
|
||||
// remain inactive for more than 1 second.
|
||||
_, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`)
|
||||
s.NoError(err)
|
||||
|
||||
err = sess.Tx(func(sess mydb.Session) error {
|
||||
var err error
|
||||
|
||||
// This query should succeed.
|
||||
_, err = sess.Collection("artist").Find().Count()
|
||||
if err != nil {
|
||||
panic(err.Error())
|
||||
}
|
||||
|
||||
// Remain inactive for 2 seconds.
|
||||
time.Sleep(time.Second * 2)
|
||||
|
||||
// This query should fail because the server disconnected us in the middle
|
||||
// of a transaction.
|
||||
_, err = sess.Collection("artist").Find().Count()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
s.Error(err, "Expecting an error (can't recover from this)")
|
||||
}
|
||||
|
||||
func (s *AdapterTests) TestMySQLTypes() {
|
||||
sess := s.Session()
|
||||
|
||||
type MyType struct {
|
||||
ID int64 `db:"id,omitempty"`
|
||||
|
||||
JSONMap JSONMap `db:"json_map"`
|
||||
|
||||
JSONObject JSONMap `db:"json_object"`
|
||||
JSONArray JSONArray `db:"json_array"`
|
||||
|
||||
CustomJSONObject customJSON `db:"custom_json_object"`
|
||||
AutoCustomJSONObject autoCustomJSON `db:"auto_custom_json_object"`
|
||||
|
||||
CustomJSONObjectPtr *customJSON `db:"custom_json_object_ptr,omitempty"`
|
||||
AutoCustomJSONObjectPtr *autoCustomJSON `db:"auto_custom_json_object_ptr,omitempty"`
|
||||
|
||||
AutoCustomJSONObjectArray []autoCustomJSON `db:"auto_custom_json_object_array"`
|
||||
AutoCustomJSONObjectMap map[string]autoCustomJSON `db:"auto_custom_json_object_map"`
|
||||
|
||||
Int64CompatValueJSONArray []int64Compat `db:"integer_compat_value_json_array"`
|
||||
UIntCompatValueJSONArray uintCompatArray `db:"uinteger_compat_value_json_array"`
|
||||
StringCompatValueJSONArray []stringCompat `db:"string_compat_value_json_array"`
|
||||
}
|
||||
|
||||
origMyTypeTests := []MyType{
|
||||
MyType{
|
||||
Int64CompatValueJSONArray: []int64Compat{1, -2, 3, -4},
|
||||
UIntCompatValueJSONArray: []uintCompat{1, 2, 3, 4},
|
||||
StringCompatValueJSONArray: []stringCompat{"a", "b", "", "c"},
|
||||
},
|
||||
MyType{
|
||||
Int64CompatValueJSONArray: []int64Compat(nil),
|
||||
UIntCompatValueJSONArray: []uintCompat(nil),
|
||||
StringCompatValueJSONArray: []stringCompat(nil),
|
||||
},
|
||||
MyType{
|
||||
AutoCustomJSONObjectArray: []autoCustomJSON{
|
||||
autoCustomJSON{
|
||||
N: "Hello",
|
||||
},
|
||||
autoCustomJSON{
|
||||
N: "World",
|
||||
},
|
||||
},
|
||||
AutoCustomJSONObjectMap: map[string]autoCustomJSON{
|
||||
"a": autoCustomJSON{
|
||||
N: "Hello",
|
||||
},
|
||||
"b": autoCustomJSON{
|
||||
N: "World",
|
||||
},
|
||||
},
|
||||
JSONArray: JSONArray{float64(1), float64(2), float64(3), float64(4)},
|
||||
},
|
||||
MyType{
|
||||
JSONArray: JSONArray{},
|
||||
},
|
||||
MyType{
|
||||
JSONArray: JSONArray(nil),
|
||||
},
|
||||
MyType{},
|
||||
MyType{
|
||||
CustomJSONObject: customJSON{
|
||||
N: "Hello",
|
||||
},
|
||||
AutoCustomJSONObject: autoCustomJSON{
|
||||
N: "World",
|
||||
},
|
||||
},
|
||||
MyType{
|
||||
CustomJSONObject: customJSON{},
|
||||
AutoCustomJSONObject: autoCustomJSON{},
|
||||
},
|
||||
MyType{
|
||||
CustomJSONObject: customJSON{
|
||||
N: "Hello 1",
|
||||
},
|
||||
AutoCustomJSONObject: autoCustomJSON{
|
||||
N: "World 2",
|
||||
},
|
||||
},
|
||||
MyType{
|
||||
CustomJSONObjectPtr: nil,
|
||||
AutoCustomJSONObjectPtr: nil,
|
||||
},
|
||||
MyType{
|
||||
CustomJSONObjectPtr: &customJSON{},
|
||||
AutoCustomJSONObjectPtr: &autoCustomJSON{},
|
||||
},
|
||||
MyType{
|
||||
CustomJSONObjectPtr: &customJSON{
|
||||
N: "Hello 3",
|
||||
},
|
||||
AutoCustomJSONObjectPtr: &autoCustomJSON{
|
||||
N: "World 4",
|
||||
},
|
||||
},
|
||||
MyType{},
|
||||
MyType{
|
||||
CustomJSONObject: customJSON{
|
||||
V: 4.4,
|
||||
},
|
||||
},
|
||||
MyType{
|
||||
CustomJSONObject: customJSON{},
|
||||
},
|
||||
MyType{
|
||||
CustomJSONObject: customJSON{
|
||||
N: "Peter",
|
||||
V: 5.56,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for i := 0; i < 100; i++ {
|
||||
|
||||
myTypeTests := make([]MyType, len(origMyTypeTests))
|
||||
perm := rand.Perm(len(origMyTypeTests))
|
||||
for i, v := range perm {
|
||||
myTypeTests[v] = origMyTypeTests[i]
|
||||
}
|
||||
|
||||
for i := range myTypeTests {
|
||||
result, err := sess.Collection("my_types").Insert(myTypeTests[i])
|
||||
s.NoError(err)
|
||||
|
||||
var actual MyType
|
||||
err = sess.Collection("my_types").Find(result).One(&actual)
|
||||
s.NoError(err)
|
||||
|
||||
expected := myTypeTests[i]
|
||||
expected.ID = result.ID().(int64)
|
||||
s.Equal(expected, actual)
|
||||
}
|
||||
|
||||
for i := range myTypeTests {
|
||||
res, err := sess.SQL().InsertInto("my_types").Values(myTypeTests[i]).Exec()
|
||||
s.NoError(err)
|
||||
|
||||
id, err := res.LastInsertId()
|
||||
s.NoError(err)
|
||||
s.NotEqual(0, id)
|
||||
|
||||
var actual MyType
|
||||
err = sess.Collection("my_types").Find(id).One(&actual)
|
||||
s.NoError(err)
|
||||
|
||||
expected := myTypeTests[i]
|
||||
expected.ID = id
|
||||
|
||||
s.Equal(expected, actual)
|
||||
|
||||
var actual2 MyType
|
||||
err = sess.SQL().SelectFrom("my_types").Where("id = ?", id).One(&actual2)
|
||||
s.NoError(err)
|
||||
s.Equal(expected, actual2)
|
||||
}
|
||||
|
||||
inserter := sess.SQL().InsertInto("my_types")
|
||||
for i := range myTypeTests {
|
||||
inserter = inserter.Values(myTypeTests[i])
|
||||
}
|
||||
_, err := inserter.Exec()
|
||||
s.NoError(err)
|
||||
|
||||
err = sess.Collection("my_types").Truncate()
|
||||
s.NoError(err)
|
||||
|
||||
batch := sess.SQL().InsertInto("my_types").Batch(50)
|
||||
go func() {
|
||||
defer batch.Done()
|
||||
for i := range myTypeTests {
|
||||
batch.Values(myTypeTests[i])
|
||||
}
|
||||
}()
|
||||
|
||||
err = batch.Wait()
|
||||
s.NoError(err)
|
||||
|
||||
var values []MyType
|
||||
err = sess.SQL().SelectFrom("my_types").All(&values)
|
||||
s.NoError(err)
|
||||
|
||||
for i := range values {
|
||||
expected := myTypeTests[i]
|
||||
expected.ID = values[i].ID
|
||||
s.Equal(expected, values[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestAdapter(t *testing.T) {
|
||||
suite.Run(t, &AdapterTests{})
|
||||
}
|
20
adapter/mysql/record_test.go
Normal file
20
adapter/mysql/record_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type RecordTests struct {
|
||||
testsuite.RecordTestSuite
|
||||
}
|
||||
|
||||
func (s *RecordTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func TestRecord(t *testing.T) {
|
||||
suite.Run(t, &RecordTests{})
|
||||
}
|
20
adapter/mysql/sql_test.go
Normal file
20
adapter/mysql/sql_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type SQLTests struct {
|
||||
testsuite.SQLTestSuite
|
||||
}
|
||||
|
||||
func (s *SQLTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func TestSQL(t *testing.T) {
|
||||
suite.Run(t, &SQLTests{})
|
||||
}
|
198
adapter/mysql/template.go
Normal file
198
adapter/mysql/template.go
Normal file
@ -0,0 +1,198 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"git.hexq.cn/tiglog/mydb/internal/cache"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
|
||||
)
|
||||
|
||||
const (
|
||||
adapterColumnSeparator = `.`
|
||||
adapterIdentifierSeparator = `, `
|
||||
adapterIdentifierQuote = "`{{.Value}}`"
|
||||
adapterValueSeparator = `, `
|
||||
adapterValueQuote = `'{{.}}'`
|
||||
adapterAndKeyword = `AND`
|
||||
adapterOrKeyword = `OR`
|
||||
adapterDescKeyword = `DESC`
|
||||
adapterAscKeyword = `ASC`
|
||||
adapterAssignmentOperator = `=`
|
||||
adapterClauseGroup = `({{.}})`
|
||||
adapterClauseOperator = ` {{.}} `
|
||||
adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}`
|
||||
adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
|
||||
adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
|
||||
adapterSortByColumnLayout = `{{.Column}} {{.Order}}`
|
||||
|
||||
adapterOrderByLayout = `
|
||||
{{if .SortColumns}}
|
||||
ORDER BY {{.SortColumns}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterWhereLayout = `
|
||||
{{if .Conds}}
|
||||
WHERE {{.Conds}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterUsingLayout = `
|
||||
{{if .Columns}}
|
||||
USING ({{.Columns}})
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterJoinLayout = `
|
||||
{{if .Table}}
|
||||
{{ if .On }}
|
||||
{{.Type}} JOIN {{.Table}}
|
||||
{{.On}}
|
||||
{{ else if .Using }}
|
||||
{{.Type}} JOIN {{.Table}}
|
||||
{{.Using}}
|
||||
{{ else if .Type | eq "CROSS" }}
|
||||
{{.Type}} JOIN {{.Table}}
|
||||
{{else}}
|
||||
NATURAL {{.Type}} JOIN {{.Table}}
|
||||
{{end}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterOnLayout = `
|
||||
{{if .Conds}}
|
||||
ON {{.Conds}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterSelectLayout = `
|
||||
SELECT
|
||||
{{if .Distinct}}
|
||||
DISTINCT
|
||||
{{end}}
|
||||
|
||||
{{if defined .Columns}}
|
||||
{{.Columns | compile}}
|
||||
{{else}}
|
||||
*
|
||||
{{end}}
|
||||
|
||||
{{if defined .Table}}
|
||||
FROM {{.Table | compile}}
|
||||
{{end}}
|
||||
|
||||
{{.Joins | compile}}
|
||||
|
||||
{{.Where | compile}}
|
||||
|
||||
{{if defined .GroupBy}}
|
||||
{{.GroupBy | compile}}
|
||||
{{end}}
|
||||
|
||||
{{.OrderBy | compile}}
|
||||
|
||||
{{if .Limit}}
|
||||
LIMIT {{.Limit}}
|
||||
{{end}}
|
||||
` +
|
||||
// The argument for LIMIT when only OFFSET is specified is a pretty odd magic
|
||||
// number; this comes directly from MySQL's manual, see:
|
||||
// https://dev.mysql.com/doc/refman/5.7/en/select.html
|
||||
//
|
||||
// "To retrieve all rows from a certain offset up to the end of the result
|
||||
// set, you can use some large number for the second parameter. This
|
||||
// statement retrieves all rows from the 96th row to the last:
|
||||
// SELECT * FROM tbl LIMIT 95,18446744073709551615; "
|
||||
//
|
||||
// ¯\_(ツ)_/¯
|
||||
`
|
||||
{{if .Offset}}
|
||||
{{if not .Limit}}
|
||||
LIMIT 18446744073709551615
|
||||
{{end}}
|
||||
OFFSET {{.Offset}}
|
||||
{{end}}
|
||||
`
|
||||
adapterDeleteLayout = `
|
||||
DELETE
|
||||
FROM {{.Table | compile}}
|
||||
{{.Where | compile}}
|
||||
`
|
||||
adapterUpdateLayout = `
|
||||
UPDATE
|
||||
{{.Table | compile}}
|
||||
SET {{.ColumnValues | compile}}
|
||||
{{.Where | compile}}
|
||||
`
|
||||
|
||||
adapterSelectCountLayout = `
|
||||
SELECT
|
||||
COUNT(1) AS _t
|
||||
FROM {{.Table | compile}}
|
||||
{{.Where | compile}}
|
||||
`
|
||||
|
||||
adapterInsertLayout = `
|
||||
INSERT INTO {{.Table | compile}}
|
||||
{{if defined .Columns}}({{.Columns | compile}}){{end}}
|
||||
VALUES
|
||||
{{if defined .Values}}
|
||||
{{.Values | compile}}
|
||||
{{else}}
|
||||
()
|
||||
{{end}}
|
||||
{{if defined .Returning}}
|
||||
RETURNING {{.Returning | compile}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterTruncateLayout = `
|
||||
TRUNCATE TABLE {{.Table | compile}}
|
||||
`
|
||||
|
||||
adapterDropDatabaseLayout = `
|
||||
DROP DATABASE {{.Database | compile}}
|
||||
`
|
||||
|
||||
adapterDropTableLayout = `
|
||||
DROP TABLE {{.Table | compile}}
|
||||
`
|
||||
|
||||
adapterGroupByLayout = `
|
||||
{{if .GroupColumns}}
|
||||
GROUP BY {{.GroupColumns}}
|
||||
{{end}}
|
||||
`
|
||||
)
|
||||
|
||||
var template = &exql.Template{
|
||||
ColumnSeparator: adapterColumnSeparator,
|
||||
IdentifierSeparator: adapterIdentifierSeparator,
|
||||
IdentifierQuote: adapterIdentifierQuote,
|
||||
ValueSeparator: adapterValueSeparator,
|
||||
ValueQuote: adapterValueQuote,
|
||||
AndKeyword: adapterAndKeyword,
|
||||
OrKeyword: adapterOrKeyword,
|
||||
DescKeyword: adapterDescKeyword,
|
||||
AscKeyword: adapterAscKeyword,
|
||||
AssignmentOperator: adapterAssignmentOperator,
|
||||
ClauseGroup: adapterClauseGroup,
|
||||
ClauseOperator: adapterClauseOperator,
|
||||
ColumnValue: adapterColumnValue,
|
||||
TableAliasLayout: adapterTableAliasLayout,
|
||||
ColumnAliasLayout: adapterColumnAliasLayout,
|
||||
SortByColumnLayout: adapterSortByColumnLayout,
|
||||
WhereLayout: adapterWhereLayout,
|
||||
JoinLayout: adapterJoinLayout,
|
||||
OnLayout: adapterOnLayout,
|
||||
UsingLayout: adapterUsingLayout,
|
||||
OrderByLayout: adapterOrderByLayout,
|
||||
InsertLayout: adapterInsertLayout,
|
||||
SelectLayout: adapterSelectLayout,
|
||||
UpdateLayout: adapterUpdateLayout,
|
||||
DeleteLayout: adapterDeleteLayout,
|
||||
TruncateLayout: adapterTruncateLayout,
|
||||
DropDatabaseLayout: adapterDropDatabaseLayout,
|
||||
DropTableLayout: adapterDropTableLayout,
|
||||
CountLayout: adapterSelectCountLayout,
|
||||
GroupByLayout: adapterGroupByLayout,
|
||||
Cache: cache.NewCache(),
|
||||
}
|
269
adapter/mysql/template_test.go
Normal file
269
adapter/mysql/template_test.go
Normal file
@ -0,0 +1,269 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTemplateSelect(t *testing.T) {
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist`",
|
||||
b.SelectFrom("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist`",
|
||||
b.Select().From("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` ORDER BY `name` DESC",
|
||||
b.Select().From("artist").OrderBy("name DESC").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` ORDER BY `name` DESC",
|
||||
b.Select().From("artist").OrderBy("-name").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` ORDER BY `name` ASC",
|
||||
b.Select().From("artist").OrderBy("name").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` ORDER BY `name` ASC",
|
||||
b.Select().From("artist").OrderBy("name ASC").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` LIMIT 18446744073709551615 OFFSET 5",
|
||||
b.Select().From("artist").Limit(-1).Offset(5).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT `id` FROM `artist`",
|
||||
b.Select("id").From("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT `id`, `name` FROM `artist`",
|
||||
b.Select("id", "name").From("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` WHERE (`name` = $1)",
|
||||
b.SelectFrom("artist").Where("name", "Haruki").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` WHERE (name LIKE $1)",
|
||||
b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT `id` FROM `artist` WHERE (name LIKE $1 OR name LIKE $2)",
|
||||
b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` WHERE (`id` > $1)",
|
||||
b.SelectFrom("artist").Where("id >", 2).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` WHERE (id <= 2 AND name != $1)",
|
||||
b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` WHERE (`id` IN ($1, $2, $3, $4))",
|
||||
b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` WHERE (name IS NOT NULL)",
|
||||
b.SelectFrom("artist").Where("name IS NOT NULL").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` AS `a`, `publication` AS `p` WHERE (p.author_id = a.id) LIMIT 1",
|
||||
b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT `id` FROM `artist` NATURAL JOIN `publication`",
|
||||
b.Select("id").From("artist").Join("publication").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.author_id = a.id) LIMIT 1",
|
||||
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.author_id = a.id) WHERE (`a`.`id` = $1) LIMIT 1",
|
||||
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` JOIN `publication` AS `p` ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1",
|
||||
b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1",
|
||||
b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` AS `a` LEFT JOIN `publication` AS `p1` ON (p1.id = a.id) RIGHT JOIN `publication` AS `p2` ON (p2.id = a.id)",
|
||||
b.SelectFrom("artist a").
|
||||
LeftJoin("publication p1").On("p1.id = a.id").
|
||||
RightJoin("publication p2").On("p2.id = a.id").
|
||||
String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` CROSS JOIN `publication`",
|
||||
b.SelectFrom("artist").CrossJoin("publication").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` JOIN `publication` USING (`id`)",
|
||||
b.SelectFrom("artist").Join("publication").Using("id").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT DATE()",
|
||||
b.Select(mydb.Raw("DATE()")).String(),
|
||||
)
|
||||
|
||||
// Issue #408
|
||||
{
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)",
|
||||
b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id IN": []int{1, 2}}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` WHERE (`id` = $1 AND `name` LIKE $2)",
|
||||
b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": []byte{1, 2}}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)",
|
||||
b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": mydb.In(1, 2)}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)",
|
||||
b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": mydb.AnyOf([]int{1, 2})}).String(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTemplateInsert(t *testing.T) {
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
"INSERT INTO `artist` VALUES ($1, $2), ($3, $4), ($5, $6)",
|
||||
b.InsertInto("artist").
|
||||
Values(10, "Ryuichi Sakamoto").
|
||||
Values(11, "Alondra de la Parra").
|
||||
Values(12, "Haruki Murakami").
|
||||
String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)",
|
||||
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2) RETURNING `id`",
|
||||
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)",
|
||||
b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)",
|
||||
b.InsertInto("artist").Values(struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"name"`
|
||||
}{12, "Chavela Vargas"}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"INSERT INTO `artist` (`name`, `id`) VALUES ($1, $2)",
|
||||
b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(),
|
||||
)
|
||||
}
|
||||
|
||||
func TestTemplateUpdate(t *testing.T) {
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
"UPDATE `artist` SET `name` = $1",
|
||||
b.Update("artist").Set("name", "Artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)",
|
||||
b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)",
|
||||
b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)",
|
||||
b.Update("artist").Set(struct {
|
||||
Nombre string `db:"name"`
|
||||
}{"Artist"}).Where(mydb.Cond{"id <": 5}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"UPDATE `artist` SET `name` = $1, `last_name` = $2 WHERE (`id` < $3)",
|
||||
b.Update("artist").Set(struct {
|
||||
Nombre string `db:"name"`
|
||||
}{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"UPDATE `artist` SET `name` = $1 || ' ' || $2 || id, `id` = id + $3 WHERE (id > $4)",
|
||||
b.Update("artist").Set(
|
||||
"name = ? || ' ' || ? || id", "Artist", "#",
|
||||
"id = id + ?", 10,
|
||||
).Where("id > ?", 0).String(),
|
||||
)
|
||||
}
|
||||
|
||||
func TestTemplateDelete(t *testing.T) {
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
"DELETE FROM `artist` WHERE (name = $1)",
|
||||
b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
"DELETE FROM `artist` WHERE (id > 5)",
|
||||
b.DeleteFrom("artist").Where("id > 5").String(),
|
||||
)
|
||||
}
|
44
adapter/postgresql/Makefile
Normal file
@ -0,0 +1,44 @@
|
||||
SHELL ?= bash
|
||||
|
||||
POSTGRES_VERSION ?= 15-alpine
|
||||
POSTGRES_SUPPORTED ?= $(POSTGRES_VERSION) 14-alpine 13-alpine 12-alpine
|
||||
|
||||
PROJECT ?= upper_postgres_$(POSTGRES_VERSION)
|
||||
|
||||
DB_HOST ?= 127.0.0.1
|
||||
DB_PORT ?= 5432
|
||||
|
||||
DB_NAME ?= upperio
|
||||
DB_USERNAME ?= upperio_user
|
||||
DB_PASSWORD ?= upperio//s3cr37
|
||||
|
||||
TEST_FLAGS ?=
|
||||
PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1
|
||||
|
||||
export POSTGRES_VERSION
|
||||
|
||||
export DB_HOST
|
||||
export DB_NAME
|
||||
export DB_PASSWORD
|
||||
export DB_PORT
|
||||
export DB_USERNAME
|
||||
|
||||
export TEST_FLAGS
|
||||
|
||||
test:
|
||||
go test -v -failfast -race -timeout 20m $(TEST_FLAGS)
|
||||
|
||||
test-no-race:
|
||||
go test -v -failfast $(TEST_FLAGS)
|
||||
|
||||
server-up: server-down
|
||||
docker-compose -p $(PROJECT) up -d && \
|
||||
sleep 10
|
||||
|
||||
server-down:
|
||||
docker-compose -p $(PROJECT) down
|
||||
|
||||
test-extended:
|
||||
parallel $(PARALLEL_FLAGS) \
|
||||
"POSTGRES_VERSION={} DB_PORT=\$$((5432+{#})) $(MAKE) server-up test server-down" ::: \
|
||||
$(POSTGRES_SUPPORTED)
|
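# Illustration (assumes GNU parallel and docker-compose are available):
# `make test-extended` starts one PostgreSQL container per entry in
# POSTGRES_SUPPORTED, each bound to its own port (5432 plus the parallel job
# slot number), runs the test target against it, and tears the container down.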
5
adapter/postgresql/README.md
Normal file
@ -0,0 +1,5 @@
|
||||
# PostgreSQL adapter for upper/db
|
||||
|
||||
Please read the full docs, acknowledgements and examples at
|
||||
[https://upper.io/v4/adapter/postgresql/](https://upper.io/v4/adapter/postgresql/).
|
||||
|
50
adapter/postgresql/collection.go
Normal file
@ -0,0 +1,50 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
)
|
||||
|
||||
type collectionAdapter struct {
|
||||
}
|
||||
|
||||
func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) {
|
||||
pKey, err := col.PrimaryKeys()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := col.SQL().InsertInto(col.Name()).Values(item)
|
||||
|
||||
if len(pKey) == 0 {
|
||||
// There is no primary key.
|
||||
res, err := q.Exec()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Attempt to use LastInsertId() (probably won't work, but the Exec()
|
||||
// succeeded, so we can safely ignore the error from LastInsertId()).
|
||||
lastID, err := res.LastInsertId()
|
||||
if err != nil {
|
||||
return nil, nil
|
||||
}
|
||||
return lastID, nil
|
||||
}
|
||||
|
||||
// Asking the database to return the primary key after insertion.
|
||||
q = q.Returning(pKey...)
|
||||
|
||||
var keyMap mydb.Cond
|
||||
if err := q.Iterator().One(&keyMap); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// A single-column primary key: return its value directly.
|
||||
if len(keyMap) == 1 {
|
||||
return keyMap[pKey[0]], nil
|
||||
}
|
||||
|
||||
// This was a compound key and no interface matched it, let's return a map.
|
||||
return keyMap, nil
|
||||
}
|
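// Caller-side sketch (illustration only; not part of this file): the same
// RETURNING-based flow used by Insert above can be reproduced with the SQL
// builder directly, assuming an open postgresql session named sess:
//
//	var keyMap mydb.Cond
//	err := sess.SQL().
//		InsertInto("artist").
//		Values(map[string]interface{}{"name": "Chavela Vargas"}).
//		Returning("id").
//		Iterator().
//		One(&keyMap)
//	// keyMap["id"] now holds the generated primary key, mirroring what
//	// collectionAdapter.Insert returns for a single-column key.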
289
adapter/postgresql/connection.go
Normal file
@ -0,0 +1,289 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/url"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
"unicode"
|
||||
)
|
||||
|
||||
// scanner implements a tokenizer for libpq-style option strings.
|
||||
type scanner struct {
|
||||
s []rune
|
||||
i int
|
||||
}
|
||||
|
||||
// Next returns the next rune. It returns 0, false if the end of the text has
|
||||
// been reached.
|
||||
func (s *scanner) Next() (rune, bool) {
|
||||
if s.i >= len(s.s) {
|
||||
return 0, false
|
||||
}
|
||||
r := s.s[s.i]
|
||||
s.i++
|
||||
return r, true
|
||||
}
|
||||
|
||||
// SkipSpaces returns the next non-whitespace rune. It returns 0, false if the
|
||||
// end of the text has been reached.
|
||||
func (s *scanner) SkipSpaces() (rune, bool) {
|
||||
r, ok := s.Next()
|
||||
for unicode.IsSpace(r) && ok {
|
||||
r, ok = s.Next()
|
||||
}
|
||||
return r, ok
|
||||
}
|
||||
|
||||
type values map[string]string
|
||||
|
||||
func (vs values) Set(k, v string) {
|
||||
vs[k] = v
|
||||
}
|
||||
|
||||
func (vs values) Get(k string) (v string) {
|
||||
return vs[k]
|
||||
}
|
||||
|
||||
func (vs values) Isset(k string) bool {
|
||||
_, ok := vs[k]
|
||||
return ok
|
||||
}
|
||||
|
||||
// ConnectionURL represents a parsed PostgreSQL connection URL.
|
||||
//
|
||||
// You can use a ConnectionURL struct as an argument for Open:
|
||||
//
|
||||
// var settings = postgresql.ConnectionURL{
|
||||
// Host: "localhost", // PostgreSQL server IP or name.
|
||||
// Database: "peanuts", // Database name.
|
||||
// User: "cbrown", // Optional user name.
|
||||
// Password: "snoopy", // Optional user password.
|
||||
// }
|
||||
//
|
||||
// sess, err = postgresql.Open(settings)
|
||||
//
|
||||
// If you already have a valid DSN, you can use ParseURL to convert it into
|
||||
// a ConnectionURL before passing it to Open.
|
||||
type ConnectionURL struct {
|
||||
User string
|
||||
Password string
|
||||
Host string
|
||||
Socket string
|
||||
Database string
|
||||
Options map[string]string
|
||||
|
||||
timezone *time.Location
|
||||
}
|
||||
|
||||
var escaper = strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
|
||||
|
||||
// ParseURL parses the given DSN into a ConnectionURL struct.
|
||||
// A typical PostgreSQL connection URL looks like:
|
||||
//
|
||||
// postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full
|
||||
func ParseURL(s string) (u *ConnectionURL, err error) {
|
||||
o := make(values)
|
||||
|
||||
if strings.HasPrefix(s, "postgres://") || strings.HasPrefix(s, "postgresql://") {
|
||||
s, err = parseURL(s)
|
||||
if err != nil {
|
||||
return u, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := parseOpts(s, o); err != nil {
|
||||
return u, err
|
||||
}
|
||||
u = &ConnectionURL{}
|
||||
|
||||
u.User = o.Get("user")
|
||||
u.Password = o.Get("password")
|
||||
|
||||
h := o.Get("host")
|
||||
p := o.Get("port")
|
||||
|
||||
if strings.HasPrefix(h, "/") {
|
||||
u.Socket = h
|
||||
} else {
|
||||
if p == "" {
|
||||
u.Host = h
|
||||
} else {
|
||||
u.Host = fmt.Sprintf("%s:%s", h, p)
|
||||
}
|
||||
}
|
||||
|
||||
u.Database = o.Get("dbname")
|
||||
|
||||
u.Options = make(map[string]string)
|
||||
|
||||
for k := range o {
|
||||
switch k {
|
||||
case "user", "password", "host", "port", "dbname":
|
||||
// Skip
|
||||
default:
|
||||
u.Options[k] = o[k]
|
||||
}
|
||||
}
|
||||
|
||||
if timezone, ok := u.Options["timezone"]; ok {
|
||||
u.timezone, _ = time.LoadLocation(timezone)
|
||||
}
|
||||
|
||||
return u, err
|
||||
}
|
||||
|
||||
// parseOpts parses the options from name and adds them to the values.
|
||||
//
|
||||
// The parsing code is based on conninfo_parse from libpq's fe-connect.c
|
||||
func parseOpts(name string, o values) error {
|
||||
s := newScanner(name)
|
||||
|
||||
for {
|
||||
var (
|
||||
keyRunes, valRunes []rune
|
||||
r rune
|
||||
ok bool
|
||||
)
|
||||
|
||||
if r, ok = s.SkipSpaces(); !ok {
|
||||
break
|
||||
}
|
||||
|
||||
// Scan the key
|
||||
for !unicode.IsSpace(r) && r != '=' {
|
||||
keyRunes = append(keyRunes, r)
|
||||
if r, ok = s.Next(); !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Skip any whitespace if we're not at the = yet
|
||||
if r != '=' {
|
||||
r, ok = s.SkipSpaces()
|
||||
}
|
||||
|
||||
// The current character should be =
|
||||
if r != '=' || !ok {
|
||||
return fmt.Errorf(`missing "=" after %q in connection info string`, string(keyRunes))
|
||||
}
|
||||
|
||||
// Skip any whitespace after the =
|
||||
if r, ok = s.SkipSpaces(); !ok {
|
||||
// If we reach the end here, the last value is just an empty string as per libpq.
|
||||
o.Set(string(keyRunes), "")
|
||||
break
|
||||
}
|
||||
|
||||
if r != '\'' {
|
||||
for !unicode.IsSpace(r) {
|
||||
if r == '\\' {
|
||||
if r, ok = s.Next(); !ok {
|
||||
return fmt.Errorf(`missing character after backslash`)
|
||||
}
|
||||
}
|
||||
valRunes = append(valRunes, r)
|
||||
|
||||
if r, ok = s.Next(); !ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
quote:
|
||||
for {
|
||||
if r, ok = s.Next(); !ok {
|
||||
return fmt.Errorf(`unterminated quoted string literal in connection string`)
|
||||
}
|
||||
switch r {
|
||||
case '\'':
|
||||
break quote
|
||||
case '\\':
|
||||
r, _ = s.Next()
|
||||
fallthrough
|
||||
default:
|
||||
valRunes = append(valRunes, r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
o.Set(string(keyRunes), string(valRunes))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
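// exampleConninfo is a hypothetical helper (illustration only, not used by
// the adapter) showing the key/value pairs parseOpts extracts from a
// libpq-style option string.
func exampleConninfo() (values, error) {
	o := make(values)
	if err := parseOpts(`user=bob password='s3cr 3t' host=localhost dbname=mydb`, o); err != nil {
		return nil, err
	}
	// o now contains: user=bob, password=s3cr 3t, host=localhost, dbname=mydb
	return o, nil
}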
||||
|
||||
// newScanner returns a new scanner initialized with the option string s.
|
||||
func newScanner(s string) *scanner {
|
||||
return &scanner{[]rune(s), 0}
|
||||
}
|
||||
|
||||
// ParseURL no longer needs to be used by clients of this library since supplying a URL as a
|
||||
// connection string to sql.Open() is now supported:
|
||||
//
|
||||
// sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full")
|
||||
//
|
||||
// It remains exported here for backwards-compatibility.
|
||||
//
|
||||
// ParseURL converts a url to a connection string for driver.Open.
|
||||
// Example:
|
||||
//
|
||||
// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
|
||||
//
|
||||
// converts to:
|
||||
//
|
||||
// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full"
|
||||
//
|
||||
// A minimal example:
|
||||
//
|
||||
// "postgres://"
|
||||
//
|
||||
// This will be blank, causing driver.Open to use all of the defaults
|
||||
//
|
||||
// NOTE: vendored/copied from github.com/lib/pq
|
||||
func parseURL(uri string) (string, error) {
|
||||
u, err := url.Parse(uri)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if u.Scheme != "postgres" && u.Scheme != "postgresql" {
|
||||
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
|
||||
}
|
||||
|
||||
var kvs []string
|
||||
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
|
||||
accrue := func(k, v string) {
|
||||
if v != "" {
|
||||
kvs = append(kvs, k+"="+escaper.Replace(v))
|
||||
}
|
||||
}
|
||||
|
||||
if u.User != nil {
|
||||
v := u.User.Username()
|
||||
accrue("user", v)
|
||||
|
||||
v, _ = u.User.Password()
|
||||
accrue("password", v)
|
||||
}
|
||||
|
||||
if host, port, err := net.SplitHostPort(u.Host); err != nil {
|
||||
accrue("host", u.Host)
|
||||
} else {
|
||||
accrue("host", host)
|
||||
accrue("port", port)
|
||||
}
|
||||
|
||||
if u.Path != "" {
|
||||
accrue("dbname", u.Path[1:])
|
||||
}
|
||||
|
||||
q := u.Query()
|
||||
for k := range q {
|
||||
accrue(k, q.Get(k))
|
||||
}
|
||||
|
||||
sort.Strings(kvs) // Makes testing easier (not a performance concern)
|
||||
return strings.Join(kvs, " "), nil
|
||||
}
|
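// End-to-end sketch (illustration only): ParseURL accepts either form shown
// above and normalizes it into a ConnectionURL.
//
//	u, err := ParseURL("postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full")
//	// err == nil
//	// u.User == "bob", u.Password == "secret"
//	// u.Host == "1.2.3.4:5432", u.Database == "mydb"
//	// u.Options["sslmode"] == "verify-full"
//
// Calling u.String() reassembles the value into a space-separated key/value
// DSN; the exact output depends on the driver selected at build time (pgx or
// pq), see connection_pgx.go and connection_pq.go.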
73
adapter/postgresql/connection_pgx.go
Normal file
@ -0,0 +1,73 @@
|
||||
//go:build !pq
|
||||
// +build !pq
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// String reassembles the parsed PostgreSQL connection URL into a valid DSN.
|
||||
func (c ConnectionURL) String() (s string) {
|
||||
u := []string{}
|
||||
|
||||
// TODO: This surely needs some sort of escaping.
|
||||
if c.User != "" {
|
||||
u = append(u, "user="+escaper.Replace(c.User))
|
||||
}
|
||||
|
||||
if c.Password != "" {
|
||||
u = append(u, "password="+escaper.Replace(c.Password))
|
||||
}
|
||||
|
||||
if c.Host != "" {
|
||||
host, port, err := net.SplitHostPort(c.Host)
|
||||
if err == nil {
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
if port == "" {
|
||||
port = "5432"
|
||||
}
|
||||
u = append(u, "host="+escaper.Replace(host))
|
||||
u = append(u, "port="+escaper.Replace(port))
|
||||
} else {
|
||||
u = append(u, "host="+escaper.Replace(c.Host))
|
||||
}
|
||||
}
|
||||
|
||||
if c.Socket != "" {
|
||||
u = append(u, "host="+escaper.Replace(c.Socket))
|
||||
}
|
||||
|
||||
if c.Database != "" {
|
||||
u = append(u, "dbname="+escaper.Replace(c.Database))
|
||||
}
|
||||
|
||||
// Is there actually any connection data?
|
||||
if len(u) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if c.Options == nil {
|
||||
c.Options = map[string]string{}
|
||||
}
|
||||
|
||||
// If not present, SSL mode is assumed "prefer".
|
||||
if sslMode, ok := c.Options["sslmode"]; !ok || sslMode == "" {
|
||||
c.Options["sslmode"] = "prefer"
|
||||
}
|
||||
|
||||
// Statement caching is disabled by default.
|
||||
c.Options["statement_cache_capacity"] = "0"
|
||||
|
||||
for k, v := range c.Options {
|
||||
u = append(u, escaper.Replace(k)+"="+escaper.Replace(v))
|
||||
}
|
||||
|
||||
sort.Strings(u)
|
||||
|
||||
return strings.Join(u, " ")
|
||||
}
|
108
adapter/postgresql/connection_pgx_test.go
Normal file
@ -0,0 +1,108 @@
|
||||
//go:build !pq
|
||||
// +build !pq
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConnectionURL(t *testing.T) {
|
||||
c := ConnectionURL{}
|
||||
|
||||
// Default connection string is empty.
|
||||
assert.Equal(t, "", c.String(), "Expecting default connectiong string to be empty")
|
||||
|
||||
// Adding a host with port.
|
||||
c.Host = "localhost:1234"
|
||||
assert.Equal(t, "host=localhost port=1234 sslmode=prefer statement_cache_capacity=0", c.String())
|
||||
|
||||
// Adding a host.
|
||||
c.Host = "localhost"
|
||||
assert.Equal(t, "host=localhost sslmode=prefer statement_cache_capacity=0", c.String())
|
||||
|
||||
// Adding a username.
|
||||
c.User = "Anakin"
|
||||
assert.Equal(t, `host=localhost sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String())
|
||||
|
||||
// Adding a password with special characters.
|
||||
c.Password = "Some Sort of ' Password"
|
||||
assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String())
|
||||
|
||||
// Adding a port.
|
||||
c.Host = "localhost:1234"
|
||||
assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String())
|
||||
|
||||
// Adding a database.
|
||||
c.Database = "MyDatabase"
|
||||
assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String())
|
||||
|
||||
// Adding options.
|
||||
c.Options = map[string]string{
|
||||
"sslmode": "verify-full",
|
||||
}
|
||||
assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=verify-full statement_cache_capacity=0 user=Anakin`, c.String())
|
||||
}
|
||||
|
||||
func TestParseConnectionURL(t *testing.T) {
|
||||
|
||||
{
|
||||
s := "postgres://anakin:skywalker@localhost/jedis"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "anakin", u.User)
|
||||
assert.Equal(t, "skywalker", u.Password)
|
||||
assert.Equal(t, "localhost", u.Host)
|
||||
assert.Equal(t, "jedis", u.Database)
|
||||
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
|
||||
}
|
||||
|
||||
{
|
||||
// case with port
|
||||
s := "postgres://anakin:skywalker@localhost:1234/jedis"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "anakin", u.User)
|
||||
assert.Equal(t, "skywalker", u.Password)
|
||||
assert.Equal(t, "jedis", u.Database)
|
||||
assert.Equal(t, "localhost:1234", u.Host)
|
||||
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
|
||||
}
|
||||
|
||||
{
|
||||
s := "postgres://anakin:skywalker@localhost/jedis?sslmode=verify-full"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "verify-full", u.Options["sslmode"])
|
||||
}
|
||||
|
||||
{
|
||||
s := "user=anakin password=skywalker host=localhost dbname=jedis"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "anakin", u.User)
|
||||
assert.Equal(t, "skywalker", u.Password)
|
||||
assert.Equal(t, "jedis", u.Database)
|
||||
assert.Equal(t, "localhost", u.Host)
|
||||
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
|
||||
}
|
||||
|
||||
{
|
||||
s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "verify-full", u.Options["sslmode"])
|
||||
}
|
||||
|
||||
{
|
||||
s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full timezone=UTC"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(u.Options), "Expecting exactly two options")
|
||||
assert.Equal(t, "verify-full", u.Options["sslmode"])
|
||||
assert.Equal(t, "UTC", u.Options["timezone"])
|
||||
}
|
||||
}
|
70
adapter/postgresql/connection_pq.go
Normal file
@ -0,0 +1,70 @@
|
||||
//go:build pq
|
||||
// +build pq
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"net"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// String reassembles the parsed PostgreSQL connection URL into a valid DSN.
|
||||
func (c ConnectionURL) String() (s string) {
|
||||
u := []string{}
|
||||
|
||||
// TODO: This surely needs some sort of escaping.
|
||||
if c.User != "" {
|
||||
u = append(u, "user="+escaper.Replace(c.User))
|
||||
}
|
||||
|
||||
if c.Password != "" {
|
||||
u = append(u, "password="+escaper.Replace(c.Password))
|
||||
}
|
||||
|
||||
if c.Host != "" {
|
||||
host, port, err := net.SplitHostPort(c.Host)
|
||||
if err == nil {
|
||||
if host == "" {
|
||||
host = "127.0.0.1"
|
||||
}
|
||||
if port == "" {
|
||||
port = "5432"
|
||||
}
|
||||
u = append(u, "host="+escaper.Replace(host))
|
||||
u = append(u, "port="+escaper.Replace(port))
|
||||
} else {
|
||||
u = append(u, "host="+escaper.Replace(c.Host))
|
||||
}
|
||||
}
|
||||
|
||||
if c.Socket != "" {
|
||||
u = append(u, "host="+escaper.Replace(c.Socket))
|
||||
}
|
||||
|
||||
if c.Database != "" {
|
||||
u = append(u, "dbname="+escaper.Replace(c.Database))
|
||||
}
|
||||
|
||||
// Is there actually any connection data?
|
||||
if len(u) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
if c.Options == nil {
|
||||
c.Options = map[string]string{}
|
||||
}
|
||||
|
||||
// If not present, SSL mode is assumed "prefer".
|
||||
if sslMode, ok := c.Options["sslmode"]; !ok || sslMode == "" {
|
||||
c.Options["sslmode"] = "prefer"
|
||||
}
|
||||
|
||||
for k, v := range c.Options {
|
||||
u = append(u, escaper.Replace(k)+"="+escaper.Replace(v))
|
||||
}
|
||||
|
||||
sort.Strings(u)
|
||||
|
||||
return strings.Join(u, " ")
|
||||
}
|
108
adapter/postgresql/connection_pq_test.go
Normal file
@ -0,0 +1,108 @@
|
||||
//go:build pq
|
||||
// +build pq
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestConnectionURL(t *testing.T) {
|
||||
c := ConnectionURL{}
|
||||
|
||||
// Default connection string is empty.
|
||||
assert.Equal(t, "", c.String(), "Expecting default connectiong string to be empty")
|
||||
|
||||
// Adding a host with port.
|
||||
c.Host = "localhost:1234"
|
||||
assert.Equal(t, "host=localhost port=1234 sslmode=prefer", c.String())
|
||||
|
||||
// Adding a host.
|
||||
c.Host = "localhost"
|
||||
assert.Equal(t, "host=localhost sslmode=prefer", c.String())
|
||||
|
||||
// Adding a username.
|
||||
c.User = "Anakin"
|
||||
assert.Equal(t, `host=localhost sslmode=prefer user=Anakin`, c.String())
|
||||
|
||||
// Adding a password with special characters.
|
||||
c.Password = "Some Sort of ' Password"
|
||||
assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password sslmode=prefer user=Anakin`, c.String())
|
||||
|
||||
// Adding a port.
|
||||
c.Host = "localhost:1234"
|
||||
assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer user=Anakin`, c.String())
|
||||
|
||||
// Adding a database.
|
||||
c.Database = "MyDatabase"
|
||||
assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer user=Anakin`, c.String())
|
||||
|
||||
// Adding options.
|
||||
c.Options = map[string]string{
|
||||
"sslmode": "verify-full",
|
||||
}
|
||||
assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=verify-full user=Anakin`, c.String())
|
||||
}
|
||||
|
||||
func TestParseConnectionURL(t *testing.T) {
|
||||
|
||||
{
|
||||
s := "postgres://anakin:skywalker@localhost/jedis"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
|
||||
assert.Equal(t, "anakin", u.User)
|
||||
assert.Equal(t, "skywalker", u.Password)
|
||||
assert.Equal(t, "localhost", u.Host)
|
||||
assert.Equal(t, "jedis", u.Database)
|
||||
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
|
||||
}
|
||||
|
||||
{
|
||||
// case with port
|
||||
s := "postgres://anakin:skywalker@localhost:1234/jedis"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "anakin", u.User)
|
||||
assert.Equal(t, "skywalker", u.Password)
|
||||
assert.Equal(t, "jedis", u.Database)
|
||||
assert.Equal(t, "localhost:1234", u.Host)
|
||||
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
|
||||
}
|
||||
|
||||
{
|
||||
s := "postgres://anakin:skywalker@localhost/jedis?sslmode=verify-full"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "verify-full", u.Options["sslmode"])
|
||||
}
|
||||
|
||||
{
|
||||
s := "user=anakin password=skywalker host=localhost dbname=jedis"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "anakin", u.User)
|
||||
assert.Equal(t, "skywalker", u.Password)
|
||||
assert.Equal(t, "jedis", u.Database)
|
||||
assert.Equal(t, "localhost", u.Host)
|
||||
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
|
||||
}
|
||||
|
||||
{
|
||||
s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "verify-full", u.Options["sslmode"])
|
||||
}
|
||||
|
||||
{
|
||||
s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full timezone=UTC"
|
||||
u, err := ParseURL(s)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 2, len(u.Options), "Expecting exactly two options")
|
||||
assert.Equal(t, "verify-full", u.Options["sslmode"])
|
||||
assert.Equal(t, "UTC", u.Options["timezone"])
|
||||
}
|
||||
}
|
126
adapter/postgresql/custom_types.go
Normal file
@ -0,0 +1,126 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
|
||||
)
|
||||
|
||||
// JSONBMap represents a map of interfaces with string keys
|
||||
// (`map[string]interface{}`) that is compatible with PostgreSQL's JSONB type.
|
||||
// JSONBMap satisfies sqlbuilder.ScannerValuer.
|
||||
type JSONBMap map[string]interface{}
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (m JSONBMap) Value() (driver.Value, error) {
|
||||
return JSONBValue(m)
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (m *JSONBMap) Scan(src interface{}) error {
|
||||
*m = map[string]interface{}(nil)
|
||||
return ScanJSONB(m, src)
|
||||
}
|
||||
|
||||
// JSONBArray represents an array of any type (`[]interface{}`) that is
|
||||
// compatible with PostgreSQL's JSONB type. JSONBArray satisfies
|
||||
// sqlbuilder.ScannerValuer.
|
||||
type JSONBArray []interface{}
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (a JSONBArray) Value() (driver.Value, error) {
|
||||
return JSONBValue(a)
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (a *JSONBArray) Scan(src interface{}) error {
|
||||
return ScanJSONB(a, src)
|
||||
}
|
||||
|
||||
// JSONBValue takes an interface and provides a driver.Value that can be
|
||||
// stored as a JSONB column.
|
||||
func JSONBValue(i interface{}) (driver.Value, error) {
|
||||
v := JSONB{i}
|
||||
return v.Value()
|
||||
}
|
||||
|
||||
// ScanJSONB decodes a JSON byte stream into the passed dst value.
|
||||
func ScanJSONB(dst interface{}, src interface{}) error {
|
||||
v := JSONB{dst}
|
||||
return v.Scan(src)
|
||||
}
|
||||
|
||||
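// JSONBConverter wraps values in a JSONB so they can be scanned from and
// stored into JSONB columns.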
type JSONBConverter struct {
|
||||
}
|
||||
|
||||
func (*JSONBConverter) ConvertValue(in interface{}) interface {
|
||||
sql.Scanner
|
||||
driver.Valuer
|
||||
} {
|
||||
return &JSONB{in}
|
||||
}
|
||||
|
||||
type timeWrapper struct {
|
||||
v **time.Time
|
||||
loc *time.Location
|
||||
}
|
||||
|
||||
func (t timeWrapper) Value() (driver.Value, error) {
|
||||
if *t.v != nil {
|
||||
return **t.v, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (t *timeWrapper) Scan(src interface{}) error {
|
||||
if src == nil {
|
||||
nilTime := (*time.Time)(nil)
|
||||
if t.v == nil {
|
||||
t.v = &nilTime
|
||||
} else {
|
||||
*(t.v) = nilTime
|
||||
}
|
||||
return nil
|
||||
}
|
||||
tz := src.(time.Time)
|
||||
if t.loc != nil && (tz.Location() == time.Local) {
|
||||
tz = tz.In(t.loc)
|
||||
}
|
||||
if tz.Location().String() == "" {
|
||||
tz = tz.In(time.UTC)
|
||||
}
|
||||
if *(t.v) == nil {
|
||||
*(t.v) = &tz
|
||||
} else {
|
||||
**t.v = tz
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *database) ConvertValueContext(ctx context.Context, in interface{}) interface{} {
|
||||
tz, _ := ctx.Value("timezone").(*time.Location)
|
||||
|
||||
switch v := in.(type) {
|
||||
case *time.Time:
|
||||
return &timeWrapper{&v, tz}
|
||||
case **time.Time:
|
||||
return &timeWrapper{v, tz}
|
||||
}
|
||||
|
||||
return d.ConvertValue(in)
|
||||
}
|
||||
|
||||
// Type checks.
|
||||
var (
|
||||
_ sqlbuilder.ScannerValuer = &StringArray{}
|
||||
_ sqlbuilder.ScannerValuer = &Int64Array{}
|
||||
_ sqlbuilder.ScannerValuer = &Float64Array{}
|
||||
_ sqlbuilder.ScannerValuer = &Float32Array{}
|
||||
_ sqlbuilder.ScannerValuer = &BoolArray{}
|
||||
_ sqlbuilder.ScannerValuer = &JSONBMap{}
|
||||
_ sqlbuilder.ScannerValuer = &JSONBArray{}
|
||||
_ sqlbuilder.ScannerValuer = &JSONB{}
|
||||
)
|
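// Usage sketch (illustration only, assuming a table with a jsonb column named
// "settings"): JSONBMap can back a struct field directly, with Value and Scan
// handling the JSON encoding.
//
//	type Post struct {
//		ID       int64    `db:"id"`
//		Settings JSONBMap `db:"settings"`
//	}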
286
adapter/postgresql/custom_types_pgx.go
Normal file
@ -0,0 +1,286 @@
|
||||
//go:build !pq
|
||||
// +build !pq
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
)
|
||||
|
||||
// JSONB represents a PostgreSQL's JSONB value:
|
||||
// https://www.postgresql.org/docs/9.6/static/datatype-json.html. JSONB
|
||||
// satisfies sqlbuilder.ScannerValuer.
|
||||
type JSONB struct {
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
// MarshalJSON encodes the wrapper value as JSON.
|
||||
func (j JSONB) MarshalJSON() ([]byte, error) {
|
||||
t := &pgtype.JSONB{}
|
||||
if err := t.Set(j.Data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.MarshalJSON()
|
||||
}
|
||||
|
||||
// UnmarshalJSON decodes the given JSON into the wrapped value.
|
||||
func (j *JSONB) UnmarshalJSON(b []byte) error {
|
||||
t := &pgtype.JSONB{}
|
||||
if err := t.UnmarshalJSON(b); err != nil {
|
||||
return err
|
||||
}
|
||||
if j.Data == nil {
|
||||
j.Data = t.Get()
|
||||
return nil
|
||||
}
|
||||
if err := t.AssignTo(&j.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (j *JSONB) Scan(src interface{}) error {
|
||||
t := &pgtype.JSONB{}
|
||||
if err := t.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
if j.Data == nil {
|
||||
j.Data = t.Get()
|
||||
return nil
|
||||
}
|
||||
if err := t.AssignTo(j.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (j JSONB) Value() (driver.Value, error) {
|
||||
t := &pgtype.JSONB{}
|
||||
if err := t.Set(j.Data); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.Value()
|
||||
}
|
||||
|
||||
// StringArray represents a one-dimensional array of strings (`[]string{}`)
|
||||
// that is compatible with PostgreSQL's text array (`text[]`). StringArray
|
||||
// satisfies sqlbuilder.ScannerValuer.
|
||||
type StringArray []string
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (a StringArray) Value() (driver.Value, error) {
|
||||
t := pgtype.TextArray{}
|
||||
if err := t.Set(a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (sa *StringArray) Scan(src interface{}) error {
|
||||
d := []string{}
|
||||
t := pgtype.TextArray{}
|
||||
if err := t.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.AssignTo(&d); err != nil {
|
||||
return err
|
||||
}
|
||||
*sa = StringArray(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
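// Bytea represents a PostgreSQL bytea value (`[]byte`); it implements
// driver.Valuer and sql.Scanner.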
type Bytea []byte
|
||||
|
||||
func (b Bytea) Value() (driver.Value, error) {
|
||||
t := pgtype.Bytea{Bytes: b}
|
||||
if err := t.Set(b); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.Value()
|
||||
}
|
||||
|
||||
func (b *Bytea) Scan(src interface{}) error {
|
||||
d := []byte{}
|
||||
t := pgtype.Bytea{}
|
||||
if err := t.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.AssignTo(&d); err != nil {
|
||||
return err
|
||||
}
|
||||
*b = Bytea(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
// ByteaArray represents a one-dimensional array of byte slices (`[][]byte{}`)
|
||||
// that is compatible with PostgreSQL's bytea array (`bytea[]`). ByteaArray
|
||||
// satisfies sqlbuilder.ScannerValuer.
|
||||
type ByteaArray [][]byte
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (a ByteaArray) Value() (driver.Value, error) {
|
||||
t := pgtype.ByteaArray{}
|
||||
if err := t.Set(a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (ba *ByteaArray) Scan(src interface{}) error {
|
||||
d := [][]byte{}
|
||||
t := pgtype.ByteaArray{}
|
||||
if err := t.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.AssignTo(&d); err != nil {
|
||||
return err
|
||||
}
|
||||
*ba = ByteaArray(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Int64Array represents a one-dimensional array of int64s (`[]int64{}`) that
|
||||
// is compatible with PostgreSQL's bigint array (`bigint[]`). Int64Array
|
||||
// satisfies sqlbuilder.ScannerValuer.
|
||||
type Int64Array []int64
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (i64a Int64Array) Value() (driver.Value, error) {
|
||||
t := pgtype.Int8Array{}
|
||||
if err := t.Set(i64a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (i64a *Int64Array) Scan(src interface{}) error {
|
||||
d := []int64{}
|
||||
t := pgtype.Int8Array{}
|
||||
if err := t.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.AssignTo(&d); err != nil {
|
||||
return err
|
||||
}
|
||||
*i64a = Int64Array(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Int32Array represents a one-dimensional array of int32s (`[]int32{}`) that
|
||||
// is compatible with PostgreSQL's integer array (`integer[]`). Int32Array
|
||||
// satisfies sqlbuilder.ScannerValuer.
|
||||
type Int32Array []int32
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (i32a Int32Array) Value() (driver.Value, error) {
|
||||
t := pgtype.Int4Array{}
|
||||
if err := t.Set(i32a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (i32a *Int32Array) Scan(src interface{}) error {
|
||||
d := []int32{}
|
||||
t := pgtype.Int4Array{}
|
||||
if err := t.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.AssignTo(&d); err != nil {
|
||||
return err
|
||||
}
|
||||
*i32a = Int32Array(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Float64Array represents a one-dimensional array of float64s (`[]float64{}`)
|
||||
// that is compatible with PostgreSQL's double precision array (`double
|
||||
// precision[]`). Float64Array satisfies sqlbuilder.ScannerValuer.
|
||||
type Float64Array []float64
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (f64a Float64Array) Value() (driver.Value, error) {
|
||||
t := pgtype.Float8Array{}
|
||||
if err := t.Set(f64a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (f64a *Float64Array) Scan(src interface{}) error {
|
||||
d := []float64{}
|
||||
t := pgtype.Float8Array{}
|
||||
if err := t.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.AssignTo(&d); err != nil {
|
||||
return err
|
||||
}
|
||||
*f64a = Float64Array(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Float32Array represents a one-dimensional array of float32s (`[]float32{}`)
|
||||
// that is compatible with PostgreSQL's double precision array (`double
|
||||
// precision[]`). Float32Array satisfies sqlbuilder.ScannerValuer.
|
||||
type Float32Array []float32
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (f32a Float32Array) Value() (driver.Value, error) {
|
||||
t := pgtype.Float8Array{}
|
||||
if err := t.Set(f32a); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (f32a *Float32Array) Scan(src interface{}) error {
|
||||
d := []float32{}
|
||||
t := pgtype.Float8Array{}
|
||||
if err := t.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.AssignTo(&d); err != nil {
|
||||
return err
|
||||
}
|
||||
*f32a = Float32Array(d)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BoolArray represents a one-dimensional array of bools (`[]bool{}`) that
|
||||
// is compatible with PostgreSQL's boolean array (`boolean[]`). BoolArray
|
||||
// satisfies sqlbuilder.ScannerValuer.
|
||||
type BoolArray []bool
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (ba BoolArray) Value() (driver.Value, error) {
|
||||
t := pgtype.BoolArray{}
|
||||
if err := t.Set(ba); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return t.Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (ba *BoolArray) Scan(src interface{}) error {
|
||||
d := []bool{}
|
||||
t := pgtype.BoolArray{}
|
||||
if err := t.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := t.AssignTo(&d); err != nil {
|
||||
return err
|
||||
}
|
||||
*ba = BoolArray(d)
|
||||
return nil
|
||||
}
|
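// Usage sketch (illustration only): the array wrappers can back struct fields
// mapped to PostgreSQL array columns, e.g.
//
//	type Book struct {
//		ID     int64       `db:"id"`
//		Tags   StringArray `db:"tags"`   // text[]
//		Scores Int64Array  `db:"scores"` // bigint[]
//	}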
249
adapter/postgresql/custom_types_pq.go
Normal file
@ -0,0 +1,249 @@
|
||||
//go:build pq
|
||||
// +build pq
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
)
|
||||
|
||||
// JSONB represents a PostgreSQL's JSONB value:
|
||||
// https://www.postgresql.org/docs/9.6/static/datatype-json.html. JSONB
|
||||
// satisfies sqlbuilder.ScannerValuer.
|
||||
type JSONB struct {
|
||||
Data interface{}
|
||||
}
|
||||
|
||||
// MarshalJSON encodes the wrapper value as JSON.
|
||||
func (j JSONB) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(j.Data)
|
||||
}
|
||||
|
||||
// UnmarshalJSON decodes the given JSON into the wrapped value.
|
||||
func (j *JSONB) UnmarshalJSON(b []byte) error {
|
||||
var v interface{}
|
||||
if err := json.Unmarshal(b, &v); err != nil {
|
||||
return err
|
||||
}
|
||||
j.Data = v
|
||||
return nil
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (j *JSONB) Scan(src interface{}) error {
|
||||
if j.Data == nil {
|
||||
return nil
|
||||
}
|
||||
if src == nil {
|
||||
dv := reflect.Indirect(reflect.ValueOf(j.Data))
|
||||
dv.Set(reflect.Zero(dv.Type()))
|
||||
return nil
|
||||
}
|
||||
|
||||
b, ok := src.([]byte)
|
||||
if !ok {
|
||||
return errors.New("Scan source was not []bytes")
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(b, j.Data); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (j JSONB) Value() (driver.Value, error) {
|
||||
// See https://github.com/lib/pq/issues/528#issuecomment-257197239 on why
|
||||
// we return string instead of []byte.
|
||||
if j.Data == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if v, ok := j.Data.(json.RawMessage); ok {
|
||||
return string(v), nil
|
||||
}
|
||||
b, err := json.Marshal(j.Data)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return string(b), nil
|
||||
}
|
||||
|
||||
// StringArray represents a one-dimensional array of strings (`[]string{}`)
|
||||
// that is compatible with PostgreSQL's text array (`text[]`). StringArray
|
||||
// satisfies sqlbuilder.ScannerValuer.
|
||||
type StringArray pq.StringArray
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (a StringArray) Value() (driver.Value, error) {
|
||||
return pq.StringArray(a).Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (a *StringArray) Scan(src interface{}) error {
|
||||
s := pq.StringArray(*a)
|
||||
if err := s.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
*a = StringArray(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Int64Array represents a one-dimensional array of int64s (`[]int64{}`) that
|
||||
// is compatible with PostgreSQL's bigint array (`bigint[]`). Int64Array
|
||||
// satisfies sqlbuilder.ScannerValuer.
|
||||
type Int64Array pq.Int64Array
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (i Int64Array) Value() (driver.Value, error) {
|
||||
return pq.Int64Array(i).Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (i *Int64Array) Scan(src interface{}) error {
|
||||
s := pq.Int64Array(*i)
|
||||
if err := s.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
*i = Int64Array(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Float64Array represents a one-dimensional array of float64s (`[]float64{}`)
|
||||
// that is compatible with PostgreSQL's double precision array (`double
|
||||
// precision[]`). Float64Array satisfies sqlbuilder.ScannerValuer.
|
||||
type Float64Array pq.Float64Array
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (f Float64Array) Value() (driver.Value, error) {
|
||||
return pq.Float64Array(f).Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (f *Float64Array) Scan(src interface{}) error {
|
||||
s := pq.Float64Array(*f)
|
||||
if err := s.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
*f = Float64Array(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Float32Array represents a one-dimensional array of float32s (`[]float32{}`)
|
||||
// that is compatible with PostgreSQL's double precision array (`double
|
||||
// precision[]`). Float32Array satisfies sqlbuilder.ScannerValuer.
|
||||
type Float32Array pq.Float32Array
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (f Float32Array) Value() (driver.Value, error) {
|
||||
return pq.Float32Array(f).Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (f *Float32Array) Scan(src interface{}) error {
|
||||
s := pq.Float32Array(*f)
|
||||
if err := s.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
*f = Float32Array(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
// BoolArray represents a one-dimensional array of bools (`[]bool{}`) that
|
||||
// is compatible with PostgreSQL's boolean array (`boolean[]`). BoolArray
|
||||
// satisfies sqlbuilder.ScannerValuer.
|
||||
type BoolArray pq.BoolArray
|
||||
|
||||
// Value satisfies the driver.Valuer interface.
|
||||
func (b BoolArray) Value() (driver.Value, error) {
|
||||
return pq.BoolArray(b).Value()
|
||||
}
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (b *BoolArray) Scan(src interface{}) error {
|
||||
s := pq.BoolArray(*b)
|
||||
if err := s.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
*b = BoolArray(s)
|
||||
return nil
|
||||
}
|
||||
|
||||
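// Bytea represents a PostgreSQL bytea value (`[]byte`) decoded from either
// the "hex" or the legacy "escape" output format (see parseBytea below).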
type Bytea []byte
|
||||
|
||||
// Scan satisfies the sql.Scanner interface.
|
||||
func (b *Bytea) Scan(src interface{}) error {
|
||||
decoded, err := parseBytea(src.([]byte))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(decoded) < 1 {
|
||||
*b = nil
|
||||
return nil
|
||||
}
|
||||
(*b) = make(Bytea, len(decoded))
|
||||
for i := range decoded {
|
||||
(*b)[i] = decoded[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type Time time.Time
|
||||
|
||||
// Parse a bytea value received from the server. Both "hex" and the legacy
|
||||
// "escape" format are supported.
|
||||
func parseBytea(s []byte) (result []byte, err error) {
|
||||
if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
|
||||
// bytea_output = hex
|
||||
s = s[2:] // trim off leading "\\x"
|
||||
result = make([]byte, hex.DecodedLen(len(s)))
|
||||
_, err := hex.Decode(result, s)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// bytea_output = escape
|
||||
for len(s) > 0 {
|
||||
if s[0] == '\\' {
|
||||
// escaped '\\'
|
||||
if len(s) >= 2 && s[1] == '\\' {
|
||||
result = append(result, '\\')
|
||||
s = s[2:]
|
||||
continue
|
||||
}
|
||||
|
||||
// '\\' followed by an octal number
|
||||
if len(s) < 4 {
|
||||
return nil, fmt.Errorf("invalid bytea sequence %v", s)
|
||||
}
|
||||
r, err := strconv.ParseInt(string(s[1:4]), 8, 9)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
|
||||
}
|
||||
result = append(result, byte(r))
|
||||
s = s[4:]
|
||||
} else {
|
||||
// We hit an unescaped, raw byte. Try to read in as many as
|
||||
// possible in one go.
|
||||
i := bytes.IndexByte(s, '\\')
|
||||
if i == -1 {
|
||||
result = append(result, s...)
|
||||
break
|
||||
}
|
||||
result = append(result, s[:i]...)
|
||||
s = s[i:]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
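// Sketch (illustration only): parseBytea accepts both server output formats.
//
//	b, _ := parseBytea([]byte(`\x68656c6c6f`)) // "hex" format    -> []byte("hello")
//	b, _ = parseBytea([]byte(`hello\\world`))  // "escape" format -> []byte(`hello\world`)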
105
adapter/postgresql/custom_types_test.go
Normal file
@ -0,0 +1,105 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type testStruct struct {
|
||||
X int `json:"x"`
|
||||
Z string `json:"z"`
|
||||
V interface{} `json:"v"`
|
||||
}
|
||||
|
||||
func TestScanJSONB(t *testing.T) {
|
||||
{
|
||||
a := testStruct{}
|
||||
err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello", "v": 1}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Hello", a.Z)
|
||||
assert.Equal(t, float64(1), a.V)
|
||||
assert.Equal(t, 5, a.X)
|
||||
}
|
||||
{
|
||||
a := testStruct{}
|
||||
err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello", "v": null}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Hello", a.Z)
|
||||
assert.Equal(t, nil, a.V)
|
||||
assert.Equal(t, 5, a.X)
|
||||
}
|
||||
{
|
||||
a := testStruct{}
|
||||
err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello"}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Hello", a.Z)
|
||||
assert.Equal(t, nil, a.V)
|
||||
assert.Equal(t, 5, a.X)
|
||||
}
|
||||
{
|
||||
a := testStruct{}
|
||||
err := ScanJSONB(&a, []byte(`{"v": "Hello"}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "Hello", a.V)
|
||||
}
|
||||
{
|
||||
a := testStruct{}
|
||||
err := ScanJSONB(&a, []byte(`{"v": true}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, true, a.V)
|
||||
}
|
||||
{
|
||||
a := testStruct{}
|
||||
err := ScanJSONB(&a, []byte(`{}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, nil, a.V)
|
||||
}
|
||||
{
|
||||
var a []byte
|
||||
err := ScanJSONB(&a, []byte(`{"age":[{"\u003e":"1h"}]}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"age":[{"\u003e":"1h"}]}`, string(a))
|
||||
}
|
||||
{
|
||||
var a json.RawMessage
|
||||
err := ScanJSONB(&a, []byte(`{"age":[{"\u003e":"1h"}]}`))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"age":[{"\u003e":"1h"}]}`, string(a))
|
||||
}
|
||||
{
|
||||
var a json.RawMessage
|
||||
err := ScanJSONB(&a, []byte("{\"age\":[{\"\u003e\":\"1h\"}]}"))
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, `{"age":[{">":"1h"}]}`, string(a))
|
||||
}
|
||||
{
|
||||
a := []*testStruct{}
|
||||
err := json.Unmarshal([]byte(`[{}]`), &a)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(a))
|
||||
assert.Nil(t, a[0].V)
|
||||
}
|
||||
{
|
||||
a := []*testStruct{}
|
||||
err := json.Unmarshal([]byte(`[{"v": true}]`), &a)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(a))
|
||||
assert.Equal(t, true, a[0].V)
|
||||
}
|
||||
{
|
||||
a := []*testStruct{}
|
||||
err := json.Unmarshal([]byte(`[{"v": null}]`), &a)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(a))
|
||||
assert.Nil(t, a[0].V)
|
||||
}
|
||||
{
|
||||
a := []*testStruct{}
|
||||
err := json.Unmarshal([]byte(`[{"v": 12.34}]`), &a)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, 1, len(a))
|
||||
assert.Equal(t, 12.34, a[0].V)
|
||||
}
|
||||
}
|
180
adapter/postgresql/database.go
Normal file
@ -0,0 +1,180 @@
|
||||
// Package postgresql provides an adapter for PostgreSQL.
|
||||
// See https://github.com/upper/db/adapter/postgresql for documentation,
|
||||
// particularities and usage examples.
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
|
||||
)
|
||||
|
||||
type database struct {
|
||||
}
|
||||
|
||||
func (*database) Template() *exql.Template {
|
||||
return template
|
||||
}
|
||||
|
||||
func (*database) Collections(sess sqladapter.Session) (collections []string, err error) {
|
||||
q := sess.SQL().
|
||||
Select("table_name").
|
||||
From("information_schema.tables").
|
||||
Where("table_schema = ?", "public")
|
||||
|
||||
iter := q.Iterator()
|
||||
defer iter.Close()
|
||||
|
||||
for iter.Next() {
|
||||
var name string
|
||||
if err := iter.Scan(&name); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
collections = append(collections, name)
|
||||
}
|
||||
if err := iter.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return collections, nil
|
||||
}
|
||||
|
||||
func (*database) ConvertValue(in interface{}) interface{} {
|
||||
switch v := in.(type) {
|
||||
case *[]int64:
|
||||
return (*Int64Array)(v)
|
||||
case *[]string:
|
||||
return (*StringArray)(v)
|
||||
case *[]float64:
|
||||
return (*Float64Array)(v)
|
||||
case *[]bool:
|
||||
return (*BoolArray)(v)
|
||||
case *map[string]interface{}:
|
||||
return (*JSONBMap)(v)
|
||||
|
||||
case []int64:
|
||||
return (*Int64Array)(&v)
|
||||
case []string:
|
||||
return (*StringArray)(&v)
|
||||
case []float64:
|
||||
return (*Float64Array)(&v)
|
||||
case []bool:
|
||||
return (*BoolArray)(&v)
|
||||
case map[string]interface{}:
|
||||
return (*JSONBMap)(&v)
|
||||
|
||||
}
|
||||
return in
|
||||
}
|
||||
|
||||
func (*database) CompileStatement(sess sqladapter.Session, stmt *exql.Statement, args []interface{}) (string, []interface{}, error) {
|
||||
compiled, err := stmt.Compile(template)
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
|
||||
query, args := sqlbuilder.Preprocess(compiled, args)
|
||||
query = string(sqladapter.ReplaceWithDollarSign([]byte(query)))
|
||||
return query, args, nil
|
||||
}
|
||||
|
||||
func (*database) Err(err error) error {
|
||||
if err != nil {
|
||||
s := err.Error()
|
||||
// These errors are not exported, so we have to check them by their string values.
|
||||
if strings.Contains(s, `too many clients`) || strings.Contains(s, `remaining connection slots are reserved`) || strings.Contains(s, `too many open`) {
|
||||
return mydb.ErrTooManyClients
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (*database) NewCollection() sqladapter.CollectionAdapter {
|
||||
return &collectionAdapter{}
|
||||
}
|
||||
|
||||
func (*database) LookupName(sess sqladapter.Session) (string, error) {
|
||||
q := sess.SQL().
|
||||
Select(mydb.Raw("CURRENT_DATABASE() AS name"))
|
||||
|
||||
iter := q.Iterator()
|
||||
defer iter.Close()
|
||||
|
||||
if iter.Next() {
|
||||
var name string
|
||||
if err := iter.Scan(&name); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return name, nil
|
||||
}
|
||||
|
||||
return "", iter.Err()
|
||||
}
|
||||
|
||||
func (*database) TableExists(sess sqladapter.Session, name string) error {
|
||||
q := sess.SQL().
|
||||
Select("table_name").
|
||||
From("information_schema.tables").
|
||||
Where("table_catalog = ? AND table_name = ?", sess.Name(), name)
|
||||
|
||||
iter := q.Iterator()
|
||||
defer iter.Close()
|
||||
|
||||
if iter.Next() {
|
||||
var name string
|
||||
if err := iter.Scan(&name); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if err := iter.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return mydb.ErrCollectionDoesNotExist
|
||||
}
|
||||
|
||||
func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) {
|
||||
q := sess.SQL().
|
||||
Select("pg_attribute.attname AS pkey").
|
||||
From("pg_index", "pg_class", "pg_attribute").
|
||||
Where(`
|
||||
pg_class.oid = '` + quotedTableName(tableName) + `'::regclass
|
||||
AND indrelid = pg_class.oid
|
||||
AND pg_attribute.attrelid = pg_class.oid
|
||||
AND pg_attribute.attnum = ANY(pg_index.indkey)
|
||||
AND indisprimary
|
||||
`).OrderBy("pkey")
|
||||
|
||||
iter := q.Iterator()
|
||||
defer iter.Close()
|
||||
|
||||
pk := []string{}
|
||||
|
||||
for iter.Next() {
|
||||
var k string
|
||||
if err := iter.Scan(&k); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pk = append(pk, k)
|
||||
}
|
||||
if err := iter.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
||||
|
||||
// quotedTableName returns a valid regclass name for both regular tables and
|
||||
// for schemas.
|
||||
func quotedTableName(s string) string {
|
||||
chunks := strings.Split(s, ".")
|
||||
for i := range chunks {
|
||||
chunks[i] = fmt.Sprintf("%q", chunks[i])
|
||||
}
|
||||
return strings.Join(chunks, ".")
|
||||
}
|
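For illustration only, a minimal, self-contained sketch of what quotedTableName produces; the helper is copied verbatim from the file above and main is just a hypothetical driver. Schema-qualified names such as the test_schema.test fixture used by the test suite come out as a double-quoted regclass literal, which is what PrimaryKeys interpolates into its '...'::regclass condition:

package main

import (
	"fmt"
	"strings"
)

// quotedTableName is copied from adapter/postgresql/database.go above.
func quotedTableName(s string) string {
	chunks := strings.Split(s, ".")
	for i := range chunks {
		chunks[i] = fmt.Sprintf("%q", chunks[i])
	}
	return strings.Join(chunks, ".")
}

func main() {
	fmt.Println(quotedTableName("artist"))           // "artist"
	fmt.Println(quotedTableName("test_schema.test")) // "test_schema"."test"
}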
26
adapter/postgresql/database_pgx.go
Normal file
@ -0,0 +1,26 @@
|
||||
//go:build !pq
|
||||
// +build !pq
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
)
|
||||
|
||||
func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) {
|
||||
connURL, err := ParseURL(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tz := connURL.Options["timezone"]; tz != "" {
|
||||
loc, _ := time.LoadLocation(tz)
|
||||
ctx := context.WithValue(sess.Context(), "timezone", loc)
|
||||
sess.SetContext(ctx)
|
||||
}
|
||||
return sql.Open("pgx", dsn)
|
||||
}
|
26
adapter/postgresql/database_pq.go
Normal file
@ -0,0 +1,26 @@
|
||||
//go:build pq
|
||||
// +build pq
|
||||
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) {
|
||||
connURL, err := ParseURL(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if tz := connURL.Options["timezone"]; tz != "" {
|
||||
loc, _ := time.LoadLocation(tz)
|
||||
ctx := context.WithValue(sess.Context(), "timezone", loc)
|
||||
sess.SetContext(ctx)
|
||||
}
|
||||
return sql.Open("postgres", dsn)
|
||||
}
|
13
adapter/postgresql/docker-compose.yml
Normal file
@ -0,0 +1,13 @@
|
||||
version: '3'
|
||||
|
||||
services:
|
||||
|
||||
server:
|
||||
image: postgres:${POSTGRES_VERSION:-11}
|
||||
environment:
|
||||
POSTGRES_USER: ${DB_USERNAME:-upperio_user}
|
||||
POSTGRES_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37}
|
||||
POSTGRES_DB: ${DB_NAME:-upperio}
|
||||
ports:
|
||||
- '${DB_HOST:-127.0.0.1}:${DB_PORT:-5432}:5432'
|
||||
|
20
adapter/postgresql/generic_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GenericTests struct {
|
||||
testsuite.GenericTestSuite
|
||||
}
|
||||
|
||||
func (s *GenericTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func TestGeneric(t *testing.T) {
|
||||
suite.Run(t, &GenericTests{})
|
||||
}
|
321
adapter/postgresql/helper_test.go
Normal file
@ -0,0 +1,321 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
)
|
||||
|
||||
var settings = ConnectionURL{
|
||||
Database: os.Getenv("DB_NAME"),
|
||||
User: os.Getenv("DB_USERNAME"),
|
||||
Password: os.Getenv("DB_PASSWORD"),
|
||||
Host: os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"),
|
||||
Options: map[string]string{
|
||||
"timezone": testsuite.TimeZone,
|
||||
},
|
||||
}
|
||||
|
||||
const preparedStatementsKey = "pg_prepared_statements_count"
|
||||
|
||||
type Helper struct {
|
||||
sess mydb.Session
|
||||
}
|
||||
|
||||
func cleanUp(sess mydb.Session) error {
|
||||
if activeStatements := sqladapter.NumActiveStatements(); activeStatements > 128 {
|
||||
return fmt.Errorf("Expecting active statements to be less than 128, got %d", activeStatements)
|
||||
}
|
||||
|
||||
sess.Reset()
|
||||
|
||||
stats, err := getStats(sess)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if stats[preparedStatementsKey] != 0 {
|
||||
return fmt.Errorf(`Expecting %q to be 0, got %d`, preparedStatementsKey, stats[preparedStatementsKey])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getStats(sess mydb.Session) (map[string]int, error) {
|
||||
stats := make(map[string]int)
|
||||
|
||||
row := sess.Driver().(*sql.DB).QueryRow(`SELECT count(1) AS value FROM pg_prepared_statements`)
|
||||
|
||||
var value int
|
||||
err := row.Scan(&value)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stats[preparedStatementsKey] = value
|
||||
|
||||
return stats, nil
|
||||
}
|
||||
|
||||
func (h *Helper) Session() mydb.Session {
|
||||
return h.sess
|
||||
}
|
||||
|
||||
func (h *Helper) Adapter() string {
|
||||
return Adapter
|
||||
}
|
||||
|
||||
func (h *Helper) TearDown() error {
|
||||
if err := cleanUp(h.sess); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return h.sess.Close()
|
||||
}
|
||||
|
||||
func (h *Helper) TearUp() error {
|
||||
var err error
|
||||
|
||||
h.sess, err = Open(settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
batch := []string{
|
||||
`DROP TABLE IF EXISTS artist`,
|
||||
`CREATE TABLE artist (
|
||||
id serial primary key,
|
||||
name varchar(60)
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS publication`,
|
||||
`CREATE TABLE publication (
|
||||
id serial primary key,
|
||||
title varchar(80),
|
||||
author_id integer
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS review`,
|
||||
`CREATE TABLE review (
|
||||
id serial primary key,
|
||||
publication_id integer,
|
||||
name varchar(80),
|
||||
comments text,
|
||||
created timestamp without time zone
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS data_types`,
|
||||
`CREATE TABLE data_types (
|
||||
id serial primary key,
|
||||
_uint integer,
|
||||
_uint8 integer,
|
||||
_uint16 integer,
|
||||
_uint32 integer,
|
||||
_uint64 integer,
|
||||
_int integer,
|
||||
_int8 integer,
|
||||
_int16 integer,
|
||||
_int32 integer,
|
||||
_int64 integer,
|
||||
_float32 numeric(10,6),
|
||||
_float64 numeric(10,6),
|
||||
_bool boolean,
|
||||
_string text,
|
||||
_blob bytea,
|
||||
_date timestamp with time zone,
|
||||
_nildate timestamp without time zone null,
|
||||
_ptrdate timestamp without time zone,
|
||||
_defaultdate timestamp without time zone DEFAULT now(),
|
||||
_time bigint
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS stats_test`,
|
||||
`CREATE TABLE stats_test (
|
||||
id serial primary key,
|
||||
numeric integer,
|
||||
value integer
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS composite_keys`,
|
||||
`CREATE TABLE composite_keys (
|
||||
code varchar(255) default '',
|
||||
user_id varchar(255) default '',
|
||||
some_val varchar(255) default '',
|
||||
primary key (code, user_id)
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS option_types`,
|
||||
`CREATE TABLE option_types (
|
||||
id serial primary key,
|
||||
name varchar(255) default '',
|
||||
tags varchar(64)[],
|
||||
settings jsonb
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS test_schema.test`,
|
||||
`DROP SCHEMA IF EXISTS test_schema`,
|
||||
|
||||
`CREATE SCHEMA test_schema`,
|
||||
`CREATE TABLE test_schema.test (id integer)`,
|
||||
|
||||
`DROP TABLE IF EXISTS pg_types`,
|
||||
`CREATE TABLE pg_types (id serial primary key
|
||||
, uint8_value smallint
|
||||
, uint8_value_array bytea
|
||||
|
||||
, int64_value smallint
|
||||
, int64_value_array smallint[]
|
||||
|
||||
, integer_array integer[]
|
||||
, string_array text[]
|
||||
, jsonb_map jsonb
|
||||
, raw_jsonb_map jsonb
|
||||
, raw_jsonb_text jsonb
|
||||
|
||||
, integer_array_ptr integer[]
|
||||
, string_array_ptr text[]
|
||||
, jsonb_map_ptr jsonb
|
||||
|
||||
, auto_integer_array integer[]
|
||||
, auto_string_array text[]
|
||||
, auto_jsonb_map jsonb
|
||||
, auto_jsonb_map_string jsonb
|
||||
, auto_jsonb_map_integer jsonb
|
||||
|
||||
, jsonb_object jsonb
|
||||
, jsonb_array jsonb
|
||||
|
||||
, custom_jsonb_object jsonb
|
||||
, auto_custom_jsonb_object jsonb
|
||||
|
||||
, custom_jsonb_object_ptr jsonb
|
||||
, auto_custom_jsonb_object_ptr jsonb
|
||||
|
||||
, custom_jsonb_object_array jsonb
|
||||
, auto_custom_jsonb_object_array jsonb
|
||||
, auto_custom_jsonb_object_map jsonb
|
||||
|
||||
, string_value varchar(255)
|
||||
, integer_value int
|
||||
, varchar_value varchar(64)
|
||||
, decimal_value decimal
|
||||
|
||||
, integer_compat_value int
|
||||
, uinteger_compat_value int
|
||||
, string_compat_value text
|
||||
|
||||
, integer_compat_value_jsonb_array jsonb
|
||||
, string_compat_value_jsonb_array jsonb
|
||||
, uinteger_compat_value_jsonb_array jsonb
|
||||
|
||||
, string_value_ptr varchar(255)
|
||||
, integer_value_ptr int
|
||||
, varchar_value_ptr varchar(64)
|
||||
, decimal_value_ptr decimal
|
||||
|
||||
, uuid_value_string UUID
|
||||
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS issue_370`,
|
||||
`CREATE TABLE issue_370 (
|
||||
id UUID PRIMARY KEY,
|
||||
name VARCHAR(25)
|
||||
)`,
|
||||
|
||||
`CREATE EXTENSION IF NOT EXISTS "uuid-ossp"`,
|
||||
|
||||
`DROP TABLE IF EXISTS issue_602_organizations`,
|
||||
`CREATE TABLE issue_602_organizations (
|
||||
name character varying(256) NOT NULL,
|
||||
created_at timestamp without time zone DEFAULT now() NOT NULL,
|
||||
updated_at timestamp without time zone DEFAULT now() NOT NULL,
|
||||
id uuid DEFAULT public.uuid_generate_v4() NOT NULL
|
||||
)`,
|
||||
|
||||
`ALTER TABLE ONLY issue_602_organizations ADD CONSTRAINT issue_602_organizations_pkey PRIMARY KEY (id)`,
|
||||
|
||||
`DROP TABLE IF EXISTS issue_370_2`,
|
||||
`CREATE TABLE issue_370_2 (
|
||||
id INTEGER[3] PRIMARY KEY,
|
||||
name VARCHAR(25)
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS varchar_primary_key`,
|
||||
`CREATE TABLE varchar_primary_key (
|
||||
address VARCHAR(42) PRIMARY KEY NOT NULL,
|
||||
name VARCHAR(25)
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS "birthdays"`,
|
||||
`CREATE TABLE "birthdays" (
|
||||
"id" serial primary key,
|
||||
"name" CHARACTER VARYING(50),
|
||||
"born" TIMESTAMP WITH TIME ZONE,
|
||||
"born_ut" INT
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS "fibonacci"`,
|
||||
`CREATE TABLE "fibonacci" (
|
||||
"id" serial primary key,
|
||||
"input" NUMERIC,
|
||||
"output" NUMERIC
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS "is_even"`,
|
||||
`CREATE TABLE "is_even" (
|
||||
"input" NUMERIC,
|
||||
"is_even" BOOL
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS "CaSe_TesT"`,
|
||||
`CREATE TABLE "CaSe_TesT" (
|
||||
"id" SERIAL PRIMARY KEY,
|
||||
"case_test" VARCHAR(60)
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS accounts`,
|
||||
`CREATE TABLE accounts (
|
||||
id serial primary key,
|
||||
name varchar(255),
|
||||
disabled boolean,
|
||||
created_at timestamp with time zone
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS users`,
|
||||
`CREATE TABLE users (
|
||||
id serial primary key,
|
||||
account_id integer,
|
||||
username varchar(255) UNIQUE
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS logs`,
|
||||
`CREATE TABLE logs (
|
||||
id serial primary key,
|
||||
message VARCHAR
|
||||
)`,
|
||||
}
|
||||
|
||||
driver := h.sess.Driver().(*sql.DB)
|
||||
tx, err := driver.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, query := range batch {
|
||||
if _, err := tx.Exec(query); err != nil {
|
||||
_ = tx.Rollback()
|
||||
return err
|
||||
}
|
||||
}
|
||||
if err := tx.Commit(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ testsuite.Helper = &Helper{}
|
30
adapter/postgresql/postgresql.go
Normal file
@ -0,0 +1,30 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
|
||||
)
|
||||
|
||||
// Adapter is the internal name of the adapter.
|
||||
const Adapter = "postgresql"
|
||||
|
||||
var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{})
|
||||
|
||||
// Open establishes a connection to the database server and returns a
|
||||
// sqlbuilder.Session instance (which is compatible with mydb.Session).
|
||||
func Open(connURL mydb.ConnectionURL) (mydb.Session, error) {
|
||||
return registeredAdapter.OpenDSN(connURL)
|
||||
}
|
||||
|
||||
// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value.
|
||||
func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) {
|
||||
return registeredAdapter.NewTx(sqlTx)
|
||||
}
|
||||
|
||||
// New creates a sqlbuilder.Session instance by wrapping a *sql.DB value.
|
||||
func New(sqlDB *sql.DB) (mydb.Session, error) {
|
||||
return registeredAdapter.New(sqlDB)
|
||||
}
|
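As a usage sketch (not part of this commit), opening a session with this adapter mirrors the settings exercised by helper_test.go above; the connection values below are placeholders:

package main

import (
	"log"

	"git.hexq.cn/tiglog/mydb/adapter/postgresql"
)

// Placeholder credentials; in the tests these come from DB_* environment variables.
var settings = postgresql.ConnectionURL{
	Database: "upperio",
	User:     "upperio_user",
	Password: "upperio//s3cr37",
	Host:     "127.0.0.1:5432",
	Options: map[string]string{
		"timezone": "UTC",
	},
}

func main() {
	sess, err := postgresql.Open(settings)
	if err != nil {
		log.Fatal("Open: ", err)
	}
	defer sess.Close()

	log.Println("connected to", sess.Name())
}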
1404
adapter/postgresql/postgresql_test.go
Normal file
File diff suppressed because it is too large
20
adapter/postgresql/record_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type RecordTests struct {
|
||||
testsuite.RecordTestSuite
|
||||
}
|
||||
|
||||
func (s *RecordTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func TestRecord(t *testing.T) {
|
||||
suite.Run(t, &RecordTests{})
|
||||
}
|
20
adapter/postgresql/sql_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type SQLTests struct {
|
||||
testsuite.SQLTestSuite
|
||||
}
|
||||
|
||||
func (s *SQLTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func TestSQL(t *testing.T) {
|
||||
suite.Run(t, &SQLTests{})
|
||||
}
|
189
adapter/postgresql/template.go
Normal file
@ -0,0 +1,189 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"git.hexq.cn/tiglog/mydb/internal/adapter"
|
||||
"git.hexq.cn/tiglog/mydb/internal/cache"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
|
||||
)
|
||||
|
||||
const (
|
||||
adapterColumnSeparator = `.`
|
||||
adapterIdentifierSeparator = `, `
|
||||
adapterIdentifierQuote = `"{{.Value}}"`
|
||||
adapterValueSeparator = `, `
|
||||
adapterValueQuote = `'{{.}}'`
|
||||
adapterAndKeyword = `AND`
|
||||
adapterOrKeyword = `OR`
|
||||
adapterDescKeyword = `DESC`
|
||||
adapterAscKeyword = `ASC`
|
||||
adapterAssignmentOperator = `=`
|
||||
adapterClauseGroup = `({{.}})`
|
||||
adapterClauseOperator = ` {{.}} `
|
||||
adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}`
|
||||
adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
|
||||
adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
|
||||
adapterSortByColumnLayout = `{{.Column}} {{.Order}}`
|
||||
|
||||
adapterOrderByLayout = `
|
||||
{{if .SortColumns}}
|
||||
ORDER BY {{.SortColumns}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterWhereLayout = `
|
||||
{{if .Conds}}
|
||||
WHERE {{.Conds}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterUsingLayout = `
|
||||
{{if .Columns}}
|
||||
USING ({{.Columns}})
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterJoinLayout = `
|
||||
{{if .Table}}
|
||||
{{ if .On }}
|
||||
{{.Type}} JOIN {{.Table}}
|
||||
{{.On}}
|
||||
{{ else if .Using }}
|
||||
{{.Type}} JOIN {{.Table}}
|
||||
{{.Using}}
|
||||
{{ else if .Type | eq "CROSS" }}
|
||||
{{.Type}} JOIN {{.Table}}
|
||||
{{else}}
|
||||
NATURAL {{.Type}} JOIN {{.Table}}
|
||||
{{end}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterOnLayout = `
|
||||
{{if .Conds}}
|
||||
ON {{.Conds}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterSelectLayout = `
|
||||
SELECT
|
||||
{{if .Distinct}}
|
||||
DISTINCT
|
||||
{{end}}
|
||||
|
||||
{{if defined .Columns}}
|
||||
{{.Columns | compile}}
|
||||
{{else}}
|
||||
*
|
||||
{{end}}
|
||||
|
||||
{{if defined .Table}}
|
||||
FROM {{.Table | compile}}
|
||||
{{end}}
|
||||
|
||||
{{.Joins | compile}}
|
||||
|
||||
{{.Where | compile}}
|
||||
|
||||
{{if defined .GroupBy}}
|
||||
{{.GroupBy | compile}}
|
||||
{{end}}
|
||||
|
||||
{{.OrderBy | compile}}
|
||||
|
||||
{{if .Limit}}
|
||||
LIMIT {{.Limit}}
|
||||
{{end}}
|
||||
|
||||
{{if .Offset}}
|
||||
OFFSET {{.Offset}}
|
||||
{{end}}
|
||||
`
|
||||
adapterDeleteLayout = `
|
||||
DELETE
|
||||
FROM {{.Table | compile}}
|
||||
{{.Where | compile}}
|
||||
`
|
||||
adapterUpdateLayout = `
|
||||
UPDATE
|
||||
{{.Table | compile}}
|
||||
SET {{.ColumnValues | compile}}
|
||||
{{.Where | compile}}
|
||||
`
|
||||
|
||||
adapterSelectCountLayout = `
|
||||
SELECT
|
||||
COUNT(1) AS _t
|
||||
FROM {{.Table | compile}}
|
||||
{{.Where | compile}}
|
||||
`
|
||||
|
||||
adapterInsertLayout = `
|
||||
INSERT INTO {{.Table | compile}}
|
||||
{{if defined .Columns}}({{.Columns | compile}}){{end}}
|
||||
VALUES
|
||||
{{if defined .Values}}
|
||||
{{.Values | compile}}
|
||||
{{else}}
|
||||
(default)
|
||||
{{end}}
|
||||
{{if defined .Returning}}
|
||||
RETURNING {{.Returning | compile}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterTruncateLayout = `
|
||||
TRUNCATE TABLE {{.Table | compile}} RESTART IDENTITY
|
||||
`
|
||||
|
||||
adapterDropDatabaseLayout = `
|
||||
DROP DATABASE {{.Database | compile}}
|
||||
`
|
||||
|
||||
adapterDropTableLayout = `
|
||||
DROP TABLE {{.Table | compile}}
|
||||
`
|
||||
|
||||
adapterGroupByLayout = `
|
||||
{{if .GroupColumns}}
|
||||
GROUP BY {{.GroupColumns}}
|
||||
{{end}}
|
||||
`
|
||||
)
|
||||
|
||||
var template = &exql.Template{
|
||||
ColumnSeparator: adapterColumnSeparator,
|
||||
IdentifierSeparator: adapterIdentifierSeparator,
|
||||
IdentifierQuote: adapterIdentifierQuote,
|
||||
ValueSeparator: adapterValueSeparator,
|
||||
ValueQuote: adapterValueQuote,
|
||||
AndKeyword: adapterAndKeyword,
|
||||
OrKeyword: adapterOrKeyword,
|
||||
DescKeyword: adapterDescKeyword,
|
||||
AscKeyword: adapterAscKeyword,
|
||||
AssignmentOperator: adapterAssignmentOperator,
|
||||
ClauseGroup: adapterClauseGroup,
|
||||
ClauseOperator: adapterClauseOperator,
|
||||
ColumnValue: adapterColumnValue,
|
||||
TableAliasLayout: adapterTableAliasLayout,
|
||||
ColumnAliasLayout: adapterColumnAliasLayout,
|
||||
SortByColumnLayout: adapterSortByColumnLayout,
|
||||
WhereLayout: adapterWhereLayout,
|
||||
JoinLayout: adapterJoinLayout,
|
||||
OnLayout: adapterOnLayout,
|
||||
UsingLayout: adapterUsingLayout,
|
||||
OrderByLayout: adapterOrderByLayout,
|
||||
InsertLayout: adapterInsertLayout,
|
||||
SelectLayout: adapterSelectLayout,
|
||||
UpdateLayout: adapterUpdateLayout,
|
||||
DeleteLayout: adapterDeleteLayout,
|
||||
TruncateLayout: adapterTruncateLayout,
|
||||
DropDatabaseLayout: adapterDropDatabaseLayout,
|
||||
DropTableLayout: adapterDropTableLayout,
|
||||
CountLayout: adapterSelectCountLayout,
|
||||
GroupByLayout: adapterGroupByLayout,
|
||||
Cache: cache.NewCache(),
|
||||
ComparisonOperator: map[adapter.ComparisonOperator]string{
|
||||
adapter.ComparisonOperatorRegExp: "~",
|
||||
adapter.ComparisonOperatorNotRegExp: "!~",
|
||||
},
|
||||
}
|
262
adapter/postgresql/template_test.go
Normal file
@ -0,0 +1,262 @@
|
||||
package postgresql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTemplateSelect(t *testing.T) {
|
||||
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist"`,
|
||||
b.SelectFrom("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist"`,
|
||||
b.Select().From("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" ORDER BY "name" DESC`,
|
||||
b.Select().From("artist").OrderBy("name DESC").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" ORDER BY "name" DESC`,
|
||||
b.Select().From("artist").OrderBy("-name").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" ORDER BY "name" ASC`,
|
||||
b.Select().From("artist").OrderBy("name").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" ORDER BY "name" ASC`,
|
||||
b.Select().From("artist").OrderBy("name ASC").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" LIMIT 1 OFFSET 5`,
|
||||
b.Select().From("artist").Limit(1).Offset(5).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" LIMIT 1 OFFSET 5`,
|
||||
b.Select().From("artist").Offset(5).Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" OFFSET 5`,
|
||||
b.Select().From("artist").Limit(-1).Offset(5).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" OFFSET 5`,
|
||||
b.Select().From("artist").Offset(5).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT "id" FROM "artist"`,
|
||||
b.Select("id").From("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT "id", "name" FROM "artist"`,
|
||||
b.Select("id", "name").From("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE ("name" = $1)`,
|
||||
b.SelectFrom("artist").Where("name", "Haruki").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE (name LIKE $1)`,
|
||||
b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT "id" FROM "artist" WHERE (name LIKE $1 OR name LIKE $2)`,
|
||||
b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE ("id" > $1)`,
|
||||
b.SelectFrom("artist").Where("id >", 2).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE (id <= 2 AND name != $1)`,
|
||||
b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE ("id" IN ($1, $2, $3, $4))`,
|
||||
b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE (name IS NOT NULL)`,
|
||||
b.SelectFrom("artist").Where("name IS NOT NULL").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" AS "a", "publication" AS "p" WHERE (p.author_id = a.id) LIMIT 1`,
|
||||
b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT "id" FROM "artist" NATURAL JOIN "publication"`,
|
||||
b.Select("id").From("artist").Join("publication").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) LIMIT 1`,
|
||||
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE ("a"."id" = $1) LIMIT 1`,
|
||||
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1`,
|
||||
b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1`,
|
||||
b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" AS "a" LEFT JOIN "publication" AS "p1" ON (p1.id = a.id) RIGHT JOIN "publication" AS "p2" ON (p2.id = a.id)`,
|
||||
b.SelectFrom("artist a").
|
||||
LeftJoin("publication p1").On("p1.id = a.id").
|
||||
RightJoin("publication p2").On("p2.id = a.id").
|
||||
String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" CROSS JOIN "publication"`,
|
||||
b.SelectFrom("artist").CrossJoin("publication").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" JOIN "publication" USING ("id")`,
|
||||
b.SelectFrom("artist").Join("publication").Using("id").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT DATE()`,
|
||||
b.Select(mydb.Raw("DATE()")).String(),
|
||||
)
|
||||
}
|
||||
|
||||
func TestTemplateInsert(t *testing.T) {
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" VALUES ($1, $2), ($3, $4), ($5, $6)`,
|
||||
b.InsertInto("artist").
|
||||
Values(10, "Ryuichi Sakamoto").
|
||||
Values(11, "Alondra de la Parra").
|
||||
Values(12, "Haruki Murakami").
|
||||
String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
|
||||
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`,
|
||||
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
|
||||
b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
|
||||
b.InsertInto("artist").Values(struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"name"`
|
||||
}{12, "Chavela Vargas"}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`,
|
||||
b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(),
|
||||
)
|
||||
}
|
||||
|
||||
func TestTemplateUpdate(t *testing.T) {
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1`,
|
||||
b.Update("artist").Set("name", "Artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
|
||||
b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
|
||||
b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
|
||||
b.Update("artist").Set(struct {
|
||||
Nombre string `db:"name"`
|
||||
}{"Artist"}).Where(mydb.Cond{"id <": 5}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`,
|
||||
b.Update("artist").Set(struct {
|
||||
Nombre string `db:"name"`
|
||||
}{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1 || ' ' || $2 || id, "id" = id + $3 WHERE (id > $4)`,
|
||||
b.Update("artist").Set(
|
||||
"name = ? || ' ' || ? || id", "Artist", "#",
|
||||
"id = id + ?", 10,
|
||||
).Where("id > ?", 0).String(),
|
||||
)
|
||||
}
|
||||
|
||||
func TestTemplateDelete(t *testing.T) {
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
`DELETE FROM "artist" WHERE (name = $1)`,
|
||||
b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`DELETE FROM "artist" WHERE (id > 5)`,
|
||||
b.DeleteFrom("artist").Where("id > 5").String(),
|
||||
)
|
||||
}
|
27
adapter/sqlite/Makefile
Normal file
@ -0,0 +1,27 @@
|
||||
SHELL ?= bash
|
||||
DB_NAME ?= sqlite3-test.db
|
||||
TEST_FLAGS ?=
|
||||
|
||||
export DB_NAME
|
||||
|
||||
export TEST_FLAGS
|
||||
|
||||
build:
|
||||
go build && go install
|
||||
|
||||
require-client:
|
||||
@if [ -z "$$(which sqlite3)" ]; then \
|
||||
echo 'Missing "sqlite3" command. Please install SQLite3 and try again.' && \
|
||||
exit 1; \
|
||||
fi
|
||||
|
||||
reset-db: require-client
|
||||
rm -f $(DB_NAME)
|
||||
|
||||
test: reset-db
|
||||
go test -v -failfast -race -timeout 20m $(TEST_FLAGS)
|
||||
|
||||
test-no-race:
|
||||
go test -v -failfast $(TEST_FLAGS)
|
||||
|
||||
test-extended: test
|
4
adapter/sqlite/README.md
Normal file
@ -0,0 +1,4 @@
|
||||
# SQLite adapter for upper/db
|
||||
|
||||
Please read the full docs, acknowledgements and examples at
|
||||
[https://upper.io/v4/adapter/sqlite/](https://upper.io/v4/adapter/sqlite/).
|
49
adapter/sqlite/collection.go
Normal file
@ -0,0 +1,49 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
|
||||
)
|
||||
|
||||
type collectionAdapter struct {
|
||||
}
|
||||
|
||||
func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) {
|
||||
columnNames, columnValues, err := sqlbuilder.Map(item, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
pKey, err := col.PrimaryKeys()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
q := col.SQL().InsertInto(col.Name()).
|
||||
Columns(columnNames...).
|
||||
Values(columnValues...)
|
||||
|
||||
var res sql.Result
|
||||
if res, err = q.Exec(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(pKey) <= 1 {
|
||||
return res.LastInsertId()
|
||||
}
|
||||
|
||||
keyMap := mydb.Cond{}
|
||||
|
||||
for i := range columnNames {
|
||||
for j := 0; j < len(pKey); j++ {
|
||||
if pKey[j] == columnNames[i] {
|
||||
keyMap[pKey[j]] = columnValues[i]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return keyMap, nil
|
||||
}
|
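The composite-key branch above returns a mydb.Cond keyed by the primary-key columns instead of LastInsertId. A self-contained sketch of that pairing logic, using plain maps and the composite_keys fixture columns (code, user_id) for illustration; the values are made up:

package main

import "fmt"

func main() {
	columnNames := []string{"code", "user_id", "some_val"}
	columnValues := []interface{}{"abc", "42", "x"}
	pKey := []string{"code", "user_id"}

	// Same pairing loop as collectionAdapter.Insert, minus the sqladapter types.
	keyMap := map[string]interface{}{}
	for i := range columnNames {
		for j := 0; j < len(pKey); j++ {
			if pKey[j] == columnNames[i] {
				keyMap[pKey[j]] = columnValues[i]
			}
		}
	}

	fmt.Println(keyMap) // map[code:abc user_id:42]
}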
89
adapter/sqlite/connection.go
Normal file
@ -0,0 +1,89 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const connectionScheme = `file`
|
||||
|
||||
// ConnectionURL implements a SQLite connection struct.
|
||||
type ConnectionURL struct {
|
||||
Database string
|
||||
Options map[string]string
|
||||
}
|
||||
|
||||
func (c ConnectionURL) String() (s string) {
|
||||
vv := url.Values{}
|
||||
|
||||
if c.Database == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Did the user provide a full database path?
|
||||
if !strings.HasPrefix(c.Database, "/") {
|
||||
c.Database, _ = filepath.Abs(c.Database)
|
||||
if runtime.GOOS == "windows" {
|
||||
// Closes https://github.com/upper/db/issues/60
|
||||
c.Database = "/" + strings.Replace(c.Database, `\`, `/`, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// Do we have any options?
|
||||
if c.Options == nil {
|
||||
c.Options = map[string]string{}
|
||||
}
|
||||
|
||||
if _, ok := c.Options["_busy_timeout"]; !ok {
|
||||
c.Options["_busy_timeout"] = "10000"
|
||||
}
|
||||
|
||||
// Converting options into URL values.
|
||||
for k, v := range c.Options {
|
||||
vv.Set(k, v)
|
||||
}
|
||||
|
||||
// Building URL.
|
||||
u := url.URL{
|
||||
Scheme: connectionScheme,
|
||||
Path: c.Database,
|
||||
RawQuery: vv.Encode(),
|
||||
}
|
||||
|
||||
return u.String()
|
||||
}
|
||||
|
||||
// ParseURL parses s into a ConnectionURL struct.
|
||||
func ParseURL(s string) (conn ConnectionURL, err error) {
|
||||
var u *url.URL
|
||||
|
||||
if !strings.HasPrefix(s, connectionScheme+"://") {
|
||||
return conn, fmt.Errorf(`Expecting file:// connection scheme.`)
|
||||
}
|
||||
|
||||
if u, err = url.Parse(s); err != nil {
|
||||
return conn, err
|
||||
}
|
||||
|
||||
conn.Database = u.Host + u.Path
|
||||
conn.Options = map[string]string{}
|
||||
|
||||
var vv url.Values
|
||||
|
||||
if vv, err = url.ParseQuery(u.RawQuery); err != nil {
|
||||
return conn, err
|
||||
}
|
||||
|
||||
for k := range vv {
|
||||
conn.Options[k] = vv.Get(k)
|
||||
}
|
||||
|
||||
if _, ok := conn.Options["cache"]; !ok {
|
||||
conn.Options["cache"] = "shared"
|
||||
}
|
||||
|
||||
return conn, err
|
||||
}
|
88
adapter/sqlite/connection_test.go
Normal file
@ -0,0 +1,88 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestConnectionURL(t *testing.T) {
|
||||
|
||||
c := ConnectionURL{}
|
||||
|
||||
// The default connection string is empty.
|
||||
if c.String() != "" {
|
||||
t.Fatal(`Expecting default connection string to be empty, got:`, c.String())
|
||||
}
|
||||
|
||||
// Adding a database name.
|
||||
c.Database = "myfilename"
|
||||
|
||||
absoluteName, _ := filepath.Abs(c.Database)
|
||||
|
||||
if c.String() != "file://"+absoluteName+"?_busy_timeout=10000" {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
// Adding an option.
|
||||
c.Options = map[string]string{
|
||||
"cache": "foobar",
|
||||
"mode": "ro",
|
||||
}
|
||||
|
||||
if c.String() != "file://"+absoluteName+"?_busy_timeout=10000&cache=foobar&mode=ro" {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
// Setting another database.
|
||||
c.Database = "/another/database"
|
||||
|
||||
if c.String() != `file:///another/database?_busy_timeout=10000&cache=foobar&mode=ro` {
|
||||
t.Fatal(`Test failed, got:`, c.String())
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestParseConnectionURL(t *testing.T) {
|
||||
var u ConnectionURL
|
||||
var s string
|
||||
var err error
|
||||
|
||||
s = "file://mydatabase.db"
|
||||
|
||||
if u, err = ParseURL(s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if u.Database != "mydatabase.db" {
|
||||
t.Fatal("Failed to parse database.")
|
||||
}
|
||||
|
||||
if u.Options["cache"] != "shared" {
|
||||
t.Fatal("If not defined, cache should be shared by default.")
|
||||
}
|
||||
|
||||
s = "file:///path/to/my/database.db?_busy_timeout=10000&mode=ro&cache=foobar"
|
||||
|
||||
if u, err = ParseURL(s); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if u.Database != "/path/to/my/database.db" {
|
||||
t.Fatal("Failed to parse username.")
|
||||
}
|
||||
|
||||
if u.Options["cache"] != "foobar" {
|
||||
t.Fatal("Expecting option.")
|
||||
}
|
||||
|
||||
if u.Options["mode"] != "ro" {
|
||||
t.Fatal("Expecting option.")
|
||||
}
|
||||
|
||||
s = "http://example.org"
|
||||
|
||||
if _, err = ParseURL(s); err == nil {
|
||||
t.Fatal("Expecting error.")
|
||||
}
|
||||
|
||||
}
|
168
adapter/sqlite/database.go
Normal file
@ -0,0 +1,168 @@
|
||||
// Package sqlite wraps the github.com/mattn/go-sqlite3 SQLite driver. See
|
||||
// https://github.com/upper/db/adapter/sqlite for documentation, particularities and
|
||||
// usage examples.
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter/compat"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
|
||||
_ "github.com/mattn/go-sqlite3" // SQLite3 driver.
|
||||
)
|
||||
|
||||
// database is the actual implementation of Database
|
||||
type database struct {
|
||||
}
|
||||
|
||||
func (*database) Template() *exql.Template {
|
||||
return template
|
||||
}
|
||||
|
||||
func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) {
|
||||
return sql.Open("sqlite3", dsn)
|
||||
}
|
||||
|
||||
func (*database) Collections(sess sqladapter.Session) (collections []string, err error) {
|
||||
q := sess.SQL().
|
||||
Select("tbl_name").
|
||||
From("sqlite_master").
|
||||
Where("type = ?", "table")
|
||||
|
||||
iter := q.Iterator()
|
||||
defer iter.Close()
|
||||
|
||||
for iter.Next() {
|
||||
var tableName string
|
||||
if err := iter.Scan(&tableName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
collections = append(collections, tableName)
|
||||
}
|
||||
if err := iter.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return collections, nil
|
||||
}
|
||||
|
||||
func (*database) StatementExec(sess sqladapter.Session, ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) {
|
||||
if sess.Transaction() != nil {
|
||||
return compat.ExecContext(sess.Driver().(*sql.Tx), ctx, query, args)
|
||||
}
|
||||
|
||||
sqlTx, err := compat.BeginTx(sess.Driver().(*sql.DB), ctx, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if res, err = compat.ExecContext(sqlTx, ctx, query, args); err != nil {
|
||||
_ = sqlTx.Rollback()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = sqlTx.Commit(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (*database) NewCollection() sqladapter.CollectionAdapter {
|
||||
return &collectionAdapter{}
|
||||
}
|
||||
|
||||
func (*database) LookupName(sess sqladapter.Session) (string, error) {
|
||||
connURL := sess.ConnectionURL()
|
||||
if connURL != nil {
|
||||
connURL, err := ParseURL(connURL.String())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return connURL.Database, nil
|
||||
}
|
||||
|
||||
// sess.ConnectionURL() is nil if using sqlite.New
|
||||
rows, err := sess.SQL().Query(exql.RawSQL("PRAGMA database_list"))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
dbInfo := struct {
|
||||
Name string `db:"name"`
|
||||
File string `db:"file"`
|
||||
}{}
|
||||
|
||||
if err := sess.SQL().NewIterator(rows).One(&dbInfo); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if dbInfo.File != "" {
|
||||
return dbInfo.File, nil
|
||||
}
|
||||
// dbInfo.File is empty if in memory mode
|
||||
return dbInfo.Name, nil
|
||||
}
|
||||
|
||||
func (*database) TableExists(sess sqladapter.Session, name string) error {
|
||||
q := sess.SQL().
|
||||
Select("tbl_name").
|
||||
From("sqlite_master").
|
||||
Where("type = 'table' AND tbl_name = ?", name)
|
||||
|
||||
iter := q.Iterator()
|
||||
defer iter.Close()
|
||||
|
||||
if iter.Next() {
|
||||
var name string
|
||||
if err := iter.Scan(&name); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if err := iter.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return mydb.ErrCollectionDoesNotExist
|
||||
}
|
||||
|
||||
func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) {
|
||||
pk := make([]string, 0, 1)
|
||||
|
||||
stmt := exql.RawSQL(fmt.Sprintf("PRAGMA TABLE_INFO('%s')", tableName))
|
||||
|
||||
rows, err := sess.SQL().Query(stmt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
columns := []struct {
|
||||
Name string `db:"name"`
|
||||
PK int `db:"pk"`
|
||||
}{}
|
||||
|
||||
if err := sess.SQL().NewIterator(rows).All(&columns); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxValue := -1
|
||||
|
||||
for _, column := range columns {
|
||||
if column.PK > 0 && column.PK > maxValue {
|
||||
maxValue = column.PK
|
||||
}
|
||||
}
|
||||
|
||||
if maxValue > 0 {
|
||||
for _, column := range columns {
|
||||
if column.PK > 0 {
|
||||
pk = append(pk, column.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return pk, nil
|
||||
}
|
20
adapter/sqlite/generic_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type GenericTests struct {
|
||||
testsuite.GenericTestSuite
|
||||
}
|
||||
|
||||
func (s *GenericTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func TestGeneric(t *testing.T) {
|
||||
suite.Run(t, &GenericTests{})
|
||||
}
|
170
adapter/sqlite/helper_test.go
Normal file
@ -0,0 +1,170 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"os"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
)
|
||||
|
||||
var settings = ConnectionURL{
|
||||
Database: os.Getenv("DB_NAME"),
|
||||
}
|
||||
|
||||
type Helper struct {
|
||||
sess mydb.Session
|
||||
}
|
||||
|
||||
func (h *Helper) Session() mydb.Session {
|
||||
return h.sess
|
||||
}
|
||||
|
||||
func (h *Helper) Adapter() string {
|
||||
return "sqlite"
|
||||
}
|
||||
|
||||
func (h *Helper) TearDown() error {
|
||||
return h.sess.Close()
|
||||
}
|
||||
|
||||
func (h *Helper) TearUp() error {
|
||||
var err error
|
||||
|
||||
h.sess, err = Open(settings)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
batch := []string{
|
||||
`PRAGMA foreign_keys=OFF`,
|
||||
|
||||
`BEGIN TRANSACTION`,
|
||||
|
||||
`DROP TABLE IF EXISTS artist`,
|
||||
`CREATE TABLE artist (
|
||||
id integer primary key,
|
||||
name varchar(60)
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS publication`,
|
||||
`CREATE TABLE publication (
|
||||
id integer primary key,
|
||||
title varchar(80),
|
||||
author_id integer
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS review`,
|
||||
`CREATE TABLE review (
|
||||
id integer primary key,
|
||||
publication_id integer,
|
||||
name varchar(80),
|
||||
comments text,
|
||||
created datetime
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS data_types`,
|
||||
`CREATE TABLE data_types (
|
||||
id integer primary key,
|
||||
_uint integer,
|
||||
_uintptr integer,
|
||||
_uint8 integer,
|
||||
_uint16 int,
|
||||
_uint32 int,
|
||||
_uint64 int,
|
||||
_int integer,
|
||||
_int8 integer,
|
||||
_int16 integer,
|
||||
_int32 integer,
|
||||
_int64 integer,
|
||||
_float32 real,
|
||||
_float64 real,
|
||||
_byte integer,
|
||||
_rune integer,
|
||||
_bool integer,
|
||||
_string text,
|
||||
_blob blob,
|
||||
_date datetime,
|
||||
_nildate datetime,
|
||||
_ptrdate datetime,
|
||||
_defaultdate datetime default current_timestamp,
|
||||
_time text
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS stats_test`,
|
||||
`CREATE TABLE stats_test (
|
||||
id integer primary key,
|
||||
numeric integer,
|
||||
value integer
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS composite_keys`,
|
||||
`CREATE TABLE composite_keys (
|
||||
code VARCHAR(255) default '',
|
||||
user_id VARCHAR(255) default '',
|
||||
some_val VARCHAR(255) default '',
|
||||
primary key (code, user_id)
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS "birthdays"`,
|
||||
`CREATE TABLE "birthdays" (
|
||||
"id" INTEGER PRIMARY KEY,
|
||||
"name" VARCHAR(50) DEFAULT NULL,
|
||||
"born" DATETIME DEFAULT NULL,
|
||||
"born_ut" INTEGER
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS "fibonacci"`,
|
||||
`CREATE TABLE "fibonacci" (
|
||||
"id" INTEGER PRIMARY KEY,
|
||||
"input" INTEGER,
|
||||
"output" INTEGER
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS "is_even"`,
|
||||
`CREATE TABLE "is_even" (
|
||||
"input" INTEGER,
|
||||
"is_even" INTEGER
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS "CaSe_TesT"`,
|
||||
`CREATE TABLE "CaSe_TesT" (
|
||||
"id" INTEGER PRIMARY KEY,
|
||||
"case_test" VARCHAR
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS accounts`,
|
||||
`CREATE TABLE accounts (
|
||||
id integer primary key,
|
||||
name varchar,
|
||||
disabled integer,
|
||||
created_at datetime default current_timestamp
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS users`,
|
||||
`CREATE TABLE users (
|
||||
id integer primary key,
|
||||
account_id integer,
|
||||
username varchar UNIQUE
|
||||
)`,
|
||||
|
||||
`DROP TABLE IF EXISTS logs`,
|
||||
`CREATE TABLE logs (
|
||||
id integer primary key,
|
||||
message VARCHAR
|
||||
)`,
|
||||
|
||||
`COMMIT`,
|
||||
}
|
||||
|
||||
for _, query := range batch {
|
||||
driver := h.sess.Driver().(*sql.DB)
|
||||
if _, err := driver.Exec(query); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ testsuite.Helper = &Helper{}
|
20
adapter/sqlite/record_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type RecordTests struct {
|
||||
testsuite.RecordTestSuite
|
||||
}
|
||||
|
||||
func (s *RecordTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func TestRecord(t *testing.T) {
|
||||
suite.Run(t, &RecordTests{})
|
||||
}
|
20
adapter/sqlite/sql_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type SQLTests struct {
|
||||
testsuite.SQLTestSuite
|
||||
}
|
||||
|
||||
func (s *SQLTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func TestSQL(t *testing.T) {
|
||||
suite.Run(t, &SQLTests{})
|
||||
}
|
30
adapter/sqlite/sqlite.go
Normal file
@ -0,0 +1,30 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
|
||||
)
|
||||
|
||||
// Adapter is the public name of the adapter.
|
||||
const Adapter = `sqlite`
|
||||
|
||||
var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{})
|
||||
|
||||
// Open establishes a connection to the database server and returns a
|
||||
// mydb.Session instance.
|
||||
func Open(connURL mydb.ConnectionURL) (mydb.Session, error) {
|
||||
return registeredAdapter.OpenDSN(connURL)
|
||||
}
|
||||
|
||||
// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value.
|
||||
func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) {
|
||||
return registeredAdapter.NewTx(sqlTx)
|
||||
}
|
||||
|
||||
// New creates a sqlbuilder.Session instance by wrapping a *sql.DB value.
|
||||
func New(sqlDB *sql.DB) (mydb.Session, error) {
|
||||
return registeredAdapter.New(sqlDB)
|
||||
}
|
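A usage sketch for the SQLite adapter (not part of this commit); the database file name is a placeholder, and the same Open/New entry points are exercised by sqlite_test.go below:

package main

import (
	"log"

	"git.hexq.cn/tiglog/mydb/adapter/sqlite"
)

var settings = sqlite.ConnectionURL{
	Database: "example.db", // placeholder file; created on first use
}

func main() {
	sess, err := sqlite.Open(settings)
	if err != nil {
		log.Fatal("Open: ", err)
	}
	defer sess.Close()

	log.Println("session name:", sess.Name())
}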
55
adapter/sqlite/sqlite_test.go
Normal file
@ -0,0 +1,55 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"database/sql"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/testsuite"
|
||||
"github.com/stretchr/testify/suite"
|
||||
)
|
||||
|
||||
type AdapterTests struct {
|
||||
testsuite.Suite
|
||||
}
|
||||
|
||||
func (s *AdapterTests) SetupSuite() {
|
||||
s.Helper = &Helper{}
|
||||
}
|
||||
|
||||
func (s *AdapterTests) Test_Issue633_OpenSession() {
|
||||
sess, err := Open(settings)
|
||||
s.NoError(err)
|
||||
defer sess.Close()
|
||||
|
||||
absoluteName, _ := filepath.Abs(settings.Database)
|
||||
s.Equal(absoluteName, sess.Name())
|
||||
}
|
||||
|
||||
func (s *AdapterTests) Test_Issue633_NewAdapterWithFile() {
|
||||
sqldb, err := sql.Open("sqlite3", settings.Database)
|
||||
s.NoError(err)
|
||||
|
||||
sess, err := New(sqldb)
|
||||
s.NoError(err)
|
||||
defer sess.Close()
|
||||
|
||||
absoluteName, _ := filepath.Abs(settings.Database)
|
||||
s.Equal(absoluteName, sess.Name())
|
||||
}
|
||||
|
||||
func (s *AdapterTests) Test_Issue633_NewAdapterWithMemory() {
|
||||
sqldb, err := sql.Open("sqlite3", ":memory:")
|
||||
s.NoError(err)
|
||||
|
||||
sess, err := New(sqldb)
|
||||
s.NoError(err)
|
||||
defer sess.Close()
|
||||
|
||||
s.Equal("main", sess.Name())
|
||||
}
|
||||
|
||||
func TestAdapter(t *testing.T) {
|
||||
suite.Run(t, &AdapterTests{})
|
||||
}
|
187
adapter/sqlite/template.go
Normal file
@ -0,0 +1,187 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"git.hexq.cn/tiglog/mydb/internal/cache"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
|
||||
)
|
||||
|
||||
const (
|
||||
adapterColumnSeparator = `.`
|
||||
adapterIdentifierSeparator = `, `
|
||||
adapterIdentifierQuote = `"{{.Value}}"`
|
||||
adapterValueSeparator = `, `
|
||||
adapterValueQuote = `'{{.}}'`
|
||||
adapterAndKeyword = `AND`
|
||||
adapterOrKeyword = `OR`
|
||||
adapterDescKeyword = `DESC`
|
||||
adapterAscKeyword = `ASC`
|
||||
adapterAssignmentOperator = `=`
|
||||
adapterClauseGroup = `({{.}})`
|
||||
adapterClauseOperator = ` {{.}} `
|
||||
adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}`
|
||||
adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
|
||||
adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
|
||||
adapterSortByColumnLayout = `{{.Column}} {{.Order}}`
|
||||
|
||||
adapterOrderByLayout = `
|
||||
{{if .SortColumns}}
|
||||
ORDER BY {{.SortColumns}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterWhereLayout = `
|
||||
{{if .Conds}}
|
||||
WHERE {{.Conds}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterUsingLayout = `
|
||||
{{if .Columns}}
|
||||
USING ({{.Columns}})
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterJoinLayout = `
|
||||
{{if .Table}}
|
||||
{{ if .On }}
|
||||
{{.Type}} JOIN {{.Table}}
|
||||
{{.On}}
|
||||
{{ else if .Using }}
|
||||
{{.Type}} JOIN {{.Table}}
|
||||
{{.Using}}
|
||||
{{ else if .Type | eq "CROSS" }}
|
||||
{{.Type}} JOIN {{.Table}}
|
||||
{{else}}
|
||||
NATURAL {{.Type}} JOIN {{.Table}}
|
||||
{{end}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterOnLayout = `
|
||||
{{if .Conds}}
|
||||
ON {{.Conds}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterSelectLayout = `
|
||||
SELECT
|
||||
{{if .Distinct}}
|
||||
DISTINCT
|
||||
{{end}}
|
||||
|
||||
{{if defined .Columns}}
|
||||
{{.Columns | compile}}
|
||||
{{else}}
|
||||
*
|
||||
{{end}}
|
||||
|
||||
{{if defined .Table}}
|
||||
FROM {{.Table | compile}}
|
||||
{{end}}
|
||||
|
||||
{{.Joins | compile}}
|
||||
|
||||
{{.Where | compile}}
|
||||
|
||||
{{if defined .GroupBy}}
|
||||
{{.GroupBy | compile}}
|
||||
{{end}}
|
||||
|
||||
{{.OrderBy | compile}}
|
||||
|
||||
{{if .Limit}}
|
||||
LIMIT {{.Limit}}
|
||||
{{end}}
|
||||
|
||||
{{if .Offset}}
|
||||
{{if not .Limit}}
|
||||
LIMIT -1
|
||||
{{end}}
|
||||
OFFSET {{.Offset}}
|
||||
{{end}}
|
||||
`
|
||||
adapterDeleteLayout = `
|
||||
DELETE
|
||||
FROM {{.Table | compile}}
|
||||
{{.Where | compile}}
|
||||
`
|
||||
adapterUpdateLayout = `
|
||||
UPDATE
|
||||
{{.Table | compile}}
|
||||
SET {{.ColumnValues | compile}}
|
||||
{{.Where | compile}}
|
||||
`
|
||||
|
||||
adapterSelectCountLayout = `
|
||||
SELECT
|
||||
COUNT(1) AS _t
|
||||
FROM {{.Table | compile}}
|
||||
{{.Where | compile}}
|
||||
`
|
||||
|
||||
adapterInsertLayout = `
|
||||
INSERT INTO {{.Table | compile}}
|
||||
{{if .Columns }}({{.Columns | compile}}){{end}}
|
||||
{{if defined .Values}}
|
||||
VALUES
|
||||
{{.Values | compile}}
|
||||
{{else}}
|
||||
DEFAULT VALUES
|
||||
{{end}}
|
||||
{{if defined .Returning}}
|
||||
RETURNING {{.Returning | compile}}
|
||||
{{end}}
|
||||
`
|
||||
|
||||
adapterTruncateLayout = `
|
||||
DELETE FROM {{.Table | compile}}
|
||||
`
|
||||
|
||||
adapterDropDatabaseLayout = `
|
||||
DROP DATABASE {{.Database | compile}}
|
||||
`
|
||||
|
||||
adapterDropTableLayout = `
|
||||
DROP TABLE {{.Table | compile}}
|
||||
`
|
||||
|
||||
adapterGroupByLayout = `
|
||||
{{if .GroupColumns}}
|
||||
GROUP BY {{.GroupColumns}}
|
||||
{{end}}
|
||||
`
|
||||
)
|
||||
|
||||
var template = &exql.Template{
|
||||
ColumnSeparator: adapterColumnSeparator,
|
||||
IdentifierSeparator: adapterIdentifierSeparator,
|
||||
IdentifierQuote: adapterIdentifierQuote,
|
||||
ValueSeparator: adapterValueSeparator,
|
||||
ValueQuote: adapterValueQuote,
|
||||
AndKeyword: adapterAndKeyword,
|
||||
OrKeyword: adapterOrKeyword,
|
||||
DescKeyword: adapterDescKeyword,
|
||||
AscKeyword: adapterAscKeyword,
|
||||
AssignmentOperator: adapterAssignmentOperator,
|
||||
ClauseGroup: adapterClauseGroup,
|
||||
ClauseOperator: adapterClauseOperator,
|
||||
ColumnValue: adapterColumnValue,
|
||||
TableAliasLayout: adapterTableAliasLayout,
|
||||
ColumnAliasLayout: adapterColumnAliasLayout,
|
||||
SortByColumnLayout: adapterSortByColumnLayout,
|
||||
WhereLayout: adapterWhereLayout,
|
||||
JoinLayout: adapterJoinLayout,
|
||||
OnLayout: adapterOnLayout,
|
||||
UsingLayout: adapterUsingLayout,
|
||||
OrderByLayout: adapterOrderByLayout,
|
||||
InsertLayout: adapterInsertLayout,
|
||||
SelectLayout: adapterSelectLayout,
|
||||
UpdateLayout: adapterUpdateLayout,
|
||||
DeleteLayout: adapterDeleteLayout,
|
||||
TruncateLayout: adapterTruncateLayout,
|
||||
DropDatabaseLayout: adapterDropDatabaseLayout,
|
||||
DropTableLayout: adapterDropTableLayout,
|
||||
CountLayout: adapterSelectCountLayout,
|
||||
GroupByLayout: adapterGroupByLayout,
|
||||
Cache: cache.NewCache(),
|
||||
}
|
246
adapter/sqlite/template_test.go
Normal file
@ -0,0 +1,246 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb"
|
||||
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestTemplateSelect(t *testing.T) {
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist"`,
|
||||
b.SelectFrom("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist"`,
|
||||
b.Select().From("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" ORDER BY "name" DESC`,
|
||||
b.Select().From("artist").OrderBy("name DESC").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" ORDER BY "name" DESC`,
|
||||
b.Select().From("artist").OrderBy("-name").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" ORDER BY "name" ASC`,
|
||||
b.Select().From("artist").OrderBy("name").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" ORDER BY "name" ASC`,
|
||||
b.Select().From("artist").OrderBy("name ASC").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" LIMIT -1 OFFSET 5`,
|
||||
b.Select().From("artist").Limit(-1).Offset(5).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT "id" FROM "artist"`,
|
||||
b.Select("id").From("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT "id", "name" FROM "artist"`,
|
||||
b.Select("id", "name").From("artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE ("name" = $1)`,
|
||||
b.SelectFrom("artist").Where("name", "Haruki").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE (name LIKE $1)`,
|
||||
b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT "id" FROM "artist" WHERE (name LIKE $1 OR name LIKE $2)`,
|
||||
b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE ("id" > $1)`,
|
||||
b.SelectFrom("artist").Where("id >", 2).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE (id <= 2 AND name != $1)`,
|
||||
b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE ("id" IN ($1, $2, $3, $4))`,
|
||||
b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" WHERE (name IS NOT NULL)`,
|
||||
b.SelectFrom("artist").Where("name IS NOT NULL").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" AS "a", "publication" AS "p" WHERE (p.author_id = a.id) LIMIT 1`,
|
||||
b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT "id" FROM "artist" NATURAL JOIN "publication"`,
|
||||
b.Select("id").From("artist").Join("publication").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) LIMIT 1`,
|
||||
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE ("a"."id" = $1) LIMIT 1`,
|
||||
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1`,
|
||||
b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1`,
|
||||
b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" AS "a" LEFT JOIN "publication" AS "p1" ON (p1.id = a.id) RIGHT JOIN "publication" AS "p2" ON (p2.id = a.id)`,
|
||||
b.SelectFrom("artist a").
|
||||
LeftJoin("publication p1").On("p1.id = a.id").
|
||||
RightJoin("publication p2").On("p2.id = a.id").
|
||||
String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" CROSS JOIN "publication"`,
|
||||
b.SelectFrom("artist").CrossJoin("publication").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT * FROM "artist" JOIN "publication" USING ("id")`,
|
||||
b.SelectFrom("artist").Join("publication").Using("id").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`SELECT DATE()`,
|
||||
b.Select(mydb.Raw("DATE()")).String(),
|
||||
)
|
||||
}
|
||||
|
||||
func TestTemplateInsert(t *testing.T) {
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" VALUES ($1, $2), ($3, $4), ($5, $6)`,
|
||||
b.InsertInto("artist").
|
||||
Values(10, "Ryuichi Sakamoto").
|
||||
Values(11, "Alondra de la Parra").
|
||||
Values(12, "Haruki Murakami").
|
||||
String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
|
||||
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`,
|
||||
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
|
||||
b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
|
||||
b.InsertInto("artist").Values(struct {
|
||||
ID int `db:"id"`
|
||||
Name string `db:"name"`
|
||||
}{12, "Chavela Vargas"}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`,
|
||||
b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(),
|
||||
)
|
||||
}
|
||||
|
||||
func TestTemplateUpdate(t *testing.T) {
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1`,
|
||||
b.Update("artist").Set("name", "Artist").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
|
||||
b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
|
||||
b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
|
||||
b.Update("artist").Set(struct {
|
||||
Nombre string `db:"name"`
|
||||
}{"Artist"}).Where(mydb.Cond{"id <": 5}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`,
|
||||
b.Update("artist").Set(struct {
|
||||
Nombre string `db:"name"`
|
||||
}{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`UPDATE "artist" SET "name" = $1 || ' ' || $2 || id, "id" = id + $3 WHERE (id > $4)`,
|
||||
b.Update("artist").Set(
|
||||
"name = ? || ' ' || ? || id", "Artist", "#",
|
||||
"id = id + ?", 10,
|
||||
).Where("id > ?", 0).String(),
|
||||
)
|
||||
}
|
||||
|
||||
func TestTemplateDelete(t *testing.T) {
|
||||
b := sqlbuilder.WithTemplate(template)
|
||||
assert := assert.New(t)
|
||||
|
||||
assert.Equal(
|
||||
`DELETE FROM "artist" WHERE (name = $1)`,
|
||||
b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").String(),
|
||||
)
|
||||
|
||||
assert.Equal(
|
||||
`DELETE FROM "artist" WHERE (id > 5)`,
|
||||
b.DeleteFrom("artist").Where("id > 5").String(),
|
||||
)
|
||||
}
|
468
clauses.go
Normal file
468
clauses.go
Normal file
@ -0,0 +1,468 @@
|
||||
package mydb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Selector represents a SELECT statement.
|
||||
type Selector interface {
|
||||
// Columns defines which columns to retrieve.
|
||||
//
|
||||
// You should call From() after Columns() if you want to query data from a
|
||||
// specific table.
|
||||
//
|
||||
// s.Columns("name", "last_name").From(...)
|
||||
//
|
||||
// It is also possible to use an alias for the column; this could be handy if
|
||||
// you plan to use the alias later. Use the "AS" keyword to denote an alias.
|
||||
//
|
||||
// s.Columns("name AS n")
|
||||
//
|
||||
// or the shortcut:
|
||||
//
|
||||
// s.Columns("name n")
|
||||
//
|
||||
// If you don't want the column to be escaped use the db.Raw
|
||||
// function.
|
||||
//
|
||||
// s.Columns(db.Raw("MAX(id)"))
|
||||
//
|
||||
// The above statement is equivalent to:
|
||||
//
|
||||
// s.Columns(db.Func("MAX", "id"))
|
||||
Columns(columns ...interface{}) Selector
|
||||
|
||||
// From represents a FROM clause and is typically used after Columns().
|
||||
//
|
||||
// FROM defines from which table data is going to be retrieved
|
||||
//
|
||||
// s.Columns(...).From("people")
|
||||
//
|
||||
// It is also possible to use an alias for the table; this could be handy if
|
||||
// you plan to use the alias later:
|
||||
//
|
||||
// s.Columns(...).From("people AS p").Where("p.name = ?", ...)
|
||||
//
|
||||
// Or with the shortcut:
|
||||
//
|
||||
// s.Columns(...).From("people p").Where("p.name = ?", ...)
|
||||
From(tables ...interface{}) Selector
|
||||
|
||||
// Distinct represents a DISTINCT clause.
|
||||
//
|
||||
// DISTINCT is used to ask the database to return only values that are
|
||||
// different.
|
||||
Distinct(columns ...interface{}) Selector
|
||||
|
||||
// As defines an alias for a table.
|
||||
As(string) Selector
|
||||
|
||||
// Where specifies the conditions that columns must match in order to be
|
||||
// retrieved.
|
||||
//
|
||||
// Where accepts raw strings and fmt.Stringer to define conditions and
|
||||
// interface{} to specify parameters. Be careful not to embed any parameters
|
||||
// within the SQL part as that could lead to security problems. You can use
|
||||
// the question mark (?) as a placeholder for parameters.
|
||||
//
|
||||
// s.Where("name = ?", "max")
|
||||
//
|
||||
// s.Where("name = ? AND last_name = ?", "Mary", "Doe")
|
||||
//
|
||||
// s.Where("last_name IS NULL")
|
||||
//
|
||||
// You can also use parameters of types other than string, like:
|
||||
//
|
||||
// s.Where("online = ? AND last_logged <= ?", true, time.Now())
|
||||
//
|
||||
// and Where() will transform them into strings before feeding them to the
|
||||
// database.
|
||||
//
|
||||
// When an unknown type is provided, Where() will first try to match it with
|
||||
// the Marshaler interface, then with fmt.Stringer and finally, if the
|
||||
// argument does not satisfy any of those interfaces Where() will use
|
||||
// fmt.Sprintf("%v", arg) to transform the type into a string.
|
||||
//
|
||||
// Subsequent calls to Where() will overwrite previously set conditions, if
|
||||
// you want these new conditions to be appended use And() instead.
|
||||
Where(conds ...interface{}) Selector
|
||||
|
||||
// And appends more constraints to the WHERE clause without overwriting
|
||||
// conditions that have been already set.
|
||||
And(conds ...interface{}) Selector
|
||||
|
||||
// GroupBy represents a GROUP BY statement.
|
||||
//
|
||||
// GROUP BY defines which columns should be used to aggregate and group
|
||||
// results.
|
||||
//
|
||||
// s.GroupBy("country_id")
|
||||
//
|
||||
// GroupBy accepts more than one column:
|
||||
//
|
||||
// s.GroupBy("country_id", "city_id")
|
||||
GroupBy(columns ...interface{}) Selector
|
||||
|
||||
// Having(...interface{}) Selector
|
||||
|
||||
// OrderBy represents an ORDER BY statement.
|
||||
//
|
||||
// ORDER BY is used to define which columns are going to be used to sort
|
||||
// results.
|
||||
//
|
||||
// Use the column name to sort results in ascending order.
|
||||
//
|
||||
// // "last_name" ASC
|
||||
// s.OrderBy("last_name")
|
||||
//
|
||||
// Prefix the column name with the minus sign (-) to sort results in
|
||||
// descending order.
|
||||
//
|
||||
// // "last_name" DESC
|
||||
// s.OrderBy("-last_name")
|
||||
//
|
||||
// If you would rather be very explicit, you can also use ASC and DESC.
|
||||
//
|
||||
// s.OrderBy("last_name ASC")
|
||||
//
|
||||
// s.OrderBy("last_name DESC", "name ASC")
|
||||
OrderBy(columns ...interface{}) Selector
|
||||
|
||||
// Join represents a JOIN statement.
|
||||
//
|
||||
// JOIN statements are used to define external tables that the user wants to
|
||||
// include as part of the result.
|
||||
//
|
||||
// You can use the On() method after Join() to define the conditions of the
|
||||
// join.
|
||||
//
|
||||
// s.Join("author").On("author.id = book.author_id")
|
||||
//
|
||||
// If you don't specify conditions for the join, a NATURAL JOIN will be used.
|
||||
//
|
||||
// On() accepts the same arguments as Where()
|
||||
//
|
||||
// You can also use Using() after Join().
|
||||
//
|
||||
// s.Join("employee").Using("department_id")
|
||||
Join(table ...interface{}) Selector
|
||||
|
||||
// FullJoin is like Join() but with FULL JOIN.
|
||||
FullJoin(...interface{}) Selector
|
||||
|
||||
// CrossJoin is like Join() but with CROSS JOIN.
|
||||
CrossJoin(...interface{}) Selector
|
||||
|
||||
// RightJoin is like Join() but with RIGHT JOIN.
|
||||
RightJoin(...interface{}) Selector
|
||||
|
||||
// LeftJoin is like Join() but with LEFT JOIN.
|
||||
LeftJoin(...interface{}) Selector
|
||||
|
||||
// Using represents the USING clause.
|
||||
//
|
||||
// USING is used to specify columns to join results.
|
||||
//
|
||||
// s.LeftJoin(...).Using("country_id")
|
||||
Using(...interface{}) Selector
|
||||
|
||||
// On represents the ON clause.
|
||||
//
|
||||
// ON is used to define conditions on a join.
|
||||
//
|
||||
// s.Join(...).On("b.author_id = a.id")
|
||||
On(...interface{}) Selector
|
||||
|
||||
// Limit represents the LIMIT parameter.
|
||||
//
|
||||
// LIMIT defines the maximum number of rows to return from the table. A
|
||||
// negative limit cancels any previous limit settings.
|
||||
//
|
||||
// s.Limit(42)
|
||||
Limit(int) Selector
|
||||
|
||||
// Offset represents the OFFSET parameter.
|
||||
//
|
||||
// OFFSET defines how many results are going to be skipped before starting to
|
||||
// return results. A negative offset cancels any previous offset settings.
|
||||
//
|
||||
// s.Offset(56)
|
||||
Offset(int) Selector
|
||||
|
||||
// Amend lets you alter the query's text just before sending it to the
|
||||
// database server.
|
||||
Amend(func(queryIn string) (queryOut string)) Selector
|
||||
|
||||
// Paginate returns a paginator that can display a paginated list of items.
|
||||
// Paginators ignore previous Offset and Limit settings. Page numbering
|
||||
// starts at 1.
|
||||
Paginate(uint) Paginator
|
||||
|
||||
// Iterator provides methods to iterate over the results returned by the
|
||||
// Selector.
|
||||
Iterator() Iterator
|
||||
|
||||
// IteratorContext provides methods to iterate over the results returned by
|
||||
// the Selector.
|
||||
IteratorContext(ctx context.Context) Iterator
|
||||
|
||||
// SQLPreparer provides methods for creating prepared statements.
|
||||
SQLPreparer
|
||||
|
||||
// SQLGetter provides methods to compile and execute a query that returns
|
||||
// results.
|
||||
SQLGetter
|
||||
|
||||
// ResultMapper provides methods to retrieve and map results.
|
||||
ResultMapper
|
||||
|
||||
// fmt.Stringer provides `String() string`, you can use `String()` to compile
|
||||
// the `Selector` into a string.
|
||||
fmt.Stringer
|
||||
|
||||
// Arguments returns the arguments that are prepared for this query.
|
||||
Arguments() []interface{}
|
||||
}
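// buildRecentBooks is an illustrative sketch of how the Selector methods above
// chain together; the table and column names are assumptions, and sel can be
// any Selector (for example, the value returned by a builder's Select() call,
// as in the adapter tests).
func buildRecentBooks(sel Selector) (query string, args []interface{}) {
	q := sel.
		Columns("id", "title").
		From("books").
		Where("author_id = ?", 42). // placeholders keep values out of the SQL text
		OrderBy("-created_at").     // a leading minus sorts in descending order
		Limit(10).
		Offset(20)
	return q.String(), q.Arguments()
}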
|
||||
|
||||
// Inserter represents an INSERT statement.
|
||||
type Inserter interface {
|
||||
// Columns represents the COLUMNS clause.
|
||||
//
|
||||
// COLUMNS defines the columns that we are going to provide values for.
|
||||
//
|
||||
// i.Columns("name", "last_name").Values(...)
|
||||
Columns(...string) Inserter
|
||||
|
||||
// Values represents the VALUES clause.
|
||||
//
|
||||
// VALUES defines the values of the columns.
|
||||
//
|
||||
// i.Columns(...).Values("María", "Méndez")
|
||||
//
|
||||
// i.Values(map[string]string{"name": "María"})
|
||||
Values(...interface{}) Inserter
|
||||
|
||||
// Arguments returns the arguments that are prepared for this query.
|
||||
Arguments() []interface{}
|
||||
|
||||
// Returning represents a RETURNING clause.
|
||||
//
|
||||
// RETURNING specifies which columns should be returned after INSERT.
|
||||
//
|
||||
// RETURNING may not be supported by all SQL databases.
|
||||
Returning(columns ...string) Inserter
|
||||
|
||||
// Iterator provides methods to iterate over the results returned by the
|
||||
// Inserter. This is only possible when using Returning().
|
||||
Iterator() Iterator
|
||||
|
||||
// IteratorContext provides methods to iterate over the results returned by
|
||||
// the Inserter. This is only possible when using Returning().
|
||||
IteratorContext(ctx context.Context) Iterator
|
||||
|
||||
// Amend lets you alter the query's text just before sending it to the
|
||||
// database server.
|
||||
Amend(func(queryIn string) (queryOut string)) Inserter
|
||||
|
||||
// Batch provides a BatchInserter that can be used to insert many elements at
|
||||
// once by issuing several calls to Values(). It accepts a size parameter
|
||||
// which defines the batch size. If size is < 1, the batch size is set to 1.
|
||||
Batch(size int) BatchInserter
|
||||
|
||||
// SQLExecer provides the Exec method.
|
||||
SQLExecer
|
||||
|
||||
// SQLPreparer provides methods for creating prepared statements.
|
||||
SQLPreparer
|
||||
|
||||
// SQLGetter provides methods to return query results from INSERT statements
|
||||
// that support such feature (e.g.: queries with Returning).
|
||||
SQLGetter
|
||||
|
||||
// fmt.Stringer provides `String() string`, you can use `String()` to compile
|
||||
// the `Inserter` into a string.
|
||||
fmt.Stringer
|
||||
}
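// buildArtistInsert is an illustrative sketch of chaining Columns, Values and
// Returning on an Inserter; the column names and values are assumptions.
func buildArtistInsert(ins Inserter) (string, []interface{}) {
	q := ins.
		Columns("id", "name").
		Values(12, "Chavela Vargas").
		Returning("id") // only meaningful on databases that support RETURNING
	return q.String(), q.Arguments()
}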
|
||||
|
||||
// Deleter represents a DELETE statement.
|
||||
type Deleter interface {
|
||||
// Where represents the WHERE clause.
|
||||
//
|
||||
// See Selector.Where for documentation and usage examples.
|
||||
Where(...interface{}) Deleter
|
||||
|
||||
// And appends more constraints to the WHERE clause without overwriting
|
||||
// conditions that have been already set.
|
||||
And(conds ...interface{}) Deleter
|
||||
|
||||
// Limit represents the LIMIT clause.
|
||||
//
|
||||
// See Selector.Limit for documentation and usage examples.
|
||||
Limit(int) Deleter
|
||||
|
||||
// Amend lets you alter the query's text just before sending it to the
|
||||
// database server.
|
||||
Amend(func(queryIn string) (queryOut string)) Deleter
|
||||
|
||||
// SQLPreparer provides methods for creating prepared statements.
|
||||
SQLPreparer
|
||||
|
||||
// SQLExecer provides the Exec method.
|
||||
SQLExecer
|
||||
|
||||
// fmt.Stringer provides `String() string`, you can use `String()` to compile
|
||||
// the `Deleter` into a string.
|
||||
fmt.Stringer
|
||||
|
||||
// Arguments returns the arguments that are prepared for this query.
|
||||
Arguments() []interface{}
|
||||
}
|
||||
|
||||
// Updater represents an UPDATE statement.
|
||||
type Updater interface {
|
||||
// Set represents the SET clause.
|
||||
Set(...interface{}) Updater
|
||||
|
||||
// Where represents the WHERE clause.
|
||||
//
|
||||
// See Selector.Where for documentation and usage examples.
|
||||
Where(...interface{}) Updater
|
||||
|
||||
// And appends more constraints to the WHERE clause without overwriting
|
||||
// conditions that have been already set.
|
||||
And(conds ...interface{}) Updater
|
||||
|
||||
// Limit represents the LIMIT parameter.
|
||||
//
|
||||
// See Selector.Limit for documentation and usage examples.
|
||||
Limit(int) Updater
|
||||
|
||||
// SQLPreparer provides methods for creating prepared statements.
|
||||
SQLPreparer
|
||||
|
||||
// SQLExecer provides the Exec method.
|
||||
SQLExecer
|
||||
|
||||
// fmt.Stringer provides `String() string`, you can use `String()` to compile
|
||||
// the `Updater` into a string.
|
||||
fmt.Stringer
|
||||
|
||||
// Arguments returns the arguments that are prepared for this query.
|
||||
Arguments() []interface{}
|
||||
|
||||
// Amend lets you alter the query's text just before sending it to the
|
||||
// database server.
|
||||
Amend(func(queryIn string) (queryOut string)) Updater
|
||||
}
|
||||
|
||||
// Paginator provides tools for splitting the results of a query into chunks
|
||||
// containing a fixed number of items.
|
||||
type Paginator interface {
|
||||
// Page sets the page number.
|
||||
Page(uint) Paginator
|
||||
|
||||
// Cursor defines the column that is going to be taken as basis for
|
||||
// cursor-based pagination.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// a = q.Paginate(10).Cursor("id")
|
||||
// b = q.Paginate(12).Cursor("-id")
|
||||
//
|
||||
// You can set "" as cursorColumn to disable cursors.
|
||||
Cursor(cursorColumn string) Paginator
|
||||
|
||||
// NextPage returns the next page according to the cursor. It expects a
|
||||
// cursorValue, which is the value the cursor column has on the last item of
|
||||
// the current result set (lower bound).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// p = q.NextPage(items[len(items)-1].ID)
|
||||
NextPage(cursorValue interface{}) Paginator
|
||||
|
||||
// PrevPage returns the previous page according to the cursor. It expects a
|
||||
// cursorValue, which is the value the cursor column has on the first item of
|
||||
// the current result set (upper bound).
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
// p = q.PrevPage(items[0].ID)
|
||||
PrevPage(cursorValue interface{}) Paginator
|
||||
|
||||
// TotalPages returns the total number of pages in the query.
|
||||
TotalPages() (uint, error)
|
||||
|
||||
// TotalEntries returns the total number of entries in the query.
|
||||
TotalEntries() (uint64, error)
|
||||
|
||||
// SQLPreparer provides methods for creating prepared statements.
|
||||
SQLPreparer
|
||||
|
||||
// SQLGetter provides methods to compile and execute a query that returns
|
||||
// results.
|
||||
SQLGetter
|
||||
|
||||
// Iterator provides methods to iterate over the results returned by the
|
||||
// Selector.
|
||||
Iterator() Iterator
|
||||
|
||||
// IteratorContext provides methods to iterate over the results returned by
|
||||
// the Selector.
|
||||
IteratorContext(ctx context.Context) Iterator
|
||||
|
||||
// ResultMapper provides methods to retrieve and map results.
|
||||
ResultMapper
|
||||
|
||||
// fmt.Stringer provides `String() string`, you can use `String()` to compile
|
||||
// the `Selector` into a string.
|
||||
fmt.Stringer
|
||||
|
||||
// Arguments returns the arguments that are prepared for this query.
|
||||
Arguments() []interface{}
|
||||
}
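// listByCursor is an illustrative sketch of cursor-based pagination with the
// Paginator above; the page size and the "id" cursor column are assumptions.
func listByCursor(sel Selector) error {
	p := sel.Paginate(20).Cursor("id") // 20 items per page, paged by "id"

	var page []map[string]interface{}
	if err := p.All(&page); err != nil {
		return err
	}
	if len(page) == 0 {
		return nil
	}
	// The next page starts right after the cursor value of the last item.
	return p.NextPage(page[len(page)-1]["id"]).All(&page)
}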
|
||||
|
||||
// ResultMapper defines methods for a result mapper.
|
||||
type ResultMapper interface {
|
||||
// All dumps all the results into the given slice, All() expects a pointer to
|
||||
// slice of maps or structs.
|
||||
//
|
||||
// The behaviour of One() extends to each one of the results.
|
||||
All(destSlice interface{}) error
|
||||
|
||||
// One maps the row that is in the current query cursor into the
|
||||
// given interface, which can be a pointer to either a map or a
|
||||
// struct.
|
||||
//
|
||||
// If dest is a pointer to map, each one of the columns will create a new map
|
||||
// key and the values of the result will be set as values for the keys.
|
||||
//
|
||||
// Depending on the type of map key and value, the results columns and values
|
||||
// may need to be transformed.
|
||||
//
|
||||
// If dest is a pointer to struct, each one of the fields will be tested for
|
||||
// a `db` tag which defines the column mapping. The value of the result will
|
||||
// be set as the value of the field.
|
||||
One(dest interface{}) error
|
||||
}
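// mapRow is an illustrative sketch of the two destination kinds ResultMapper
// accepts: a struct mapped through `db` tags and a slice of maps; the field
// names are assumptions.
func mapRow(res ResultMapper) error {
	var user struct {
		ID   int64  `db:"id"`
		Name string `db:"name"`
	}
	if err := res.One(&user); err != nil { // a single row into the struct
		return err
	}
	var rows []map[string]interface{}
	return res.All(&rows) // every row into a slice of maps
}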
|
||||
|
||||
// BatchInserter provides an interface to do massive insertions in batches.
|
||||
type BatchInserter interface {
|
||||
// Values pushes column values to be inserted as part of the batch.
|
||||
Values(...interface{}) BatchInserter
|
||||
|
||||
// NextResult dumps the next slice of results to dst, which can mean having
|
||||
// the IDs of all inserted elements in the batch.
|
||||
NextResult(dst interface{}) bool
|
||||
|
||||
// Done signals that no more elements are going to be added.
|
||||
Done()
|
||||
|
||||
// Wait blocks until the whole batch is executed.
|
||||
Wait() error
|
||||
|
||||
// Err returns the last error that happened while executing the batch (or nil
|
||||
// if no error happened).
|
||||
Err() error
|
||||
}
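// insertInBatches is an illustrative sketch of the BatchInserter flow: push
// values from a goroutine, call Done, then Wait for the batch to be executed;
// the column layout is an assumption.
func insertInBatches(ins Inserter, names []string) error {
	batch := ins.Columns("name").Batch(100) // flush every 100 rows

	go func() {
		defer batch.Done() // signal that no more values are coming
		for _, name := range names {
			batch.Values(name)
		}
	}()

	return batch.Wait() // block until the whole batch has been executed
}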
|
45
collection.go
Normal file
45
collection.go
Normal file
@ -0,0 +1,45 @@
|
||||
package mydb
|
||||
|
||||
// Collection defines methods to work with database tables or collections.
|
||||
type Collection interface {
|
||||
|
||||
// Name returns the name of the collection.
|
||||
Name() string
|
||||
|
||||
// Session returns the Session that was used to create the collection
|
||||
// reference.
|
||||
Session() Session
|
||||
|
||||
// Find defines a new result set.
|
||||
Find(...interface{}) Result
|
||||
|
||||
// Count returns the number of items in the collection.
Count() (uint64, error)
|
||||
|
||||
// Insert inserts a new item into the collection, the type of this item could
|
||||
// be a map, a struct or pointer to either of them. If the call succeeds and
|
||||
// if the collection has a primary key, Insert returns the ID of the newly
|
||||
// added element as an `interface{}`. The underlying type of this ID depends
|
||||
// on both the database adapter and the column storing the ID. The ID
|
||||
// returned by Insert() could be passed directly to Find() to retrieve the
|
||||
// newly added element.
|
||||
Insert(interface{}) (InsertResult, error)
|
||||
|
||||
// InsertReturning is like Insert() but it takes a pointer to map or struct
|
||||
// and, if the operation succeeds, updates it with data from the newly
|
||||
// inserted row. If the database does not support transactions this method
|
||||
// returns db.ErrUnsupported.
|
||||
InsertReturning(interface{}) error
|
||||
|
||||
// UpdateReturning takes a pointer to a map or struct and tries to update the
|
||||
// row the item is referring to. If the element is updated successfully,
|
||||
// UpdateReturning will fetch the row and update the fields of the passed
|
||||
// item. If the database does not support transactions this method returns
|
||||
// db.ErrUnsupported
|
||||
UpdateReturning(interface{}) error
|
||||
|
||||
// Exists returns true if the collection exists, false otherwise.
|
||||
Exists() (bool, error)
|
||||
|
||||
// Truncate removes all elements from the collection.
|
||||
Truncate() error
|
||||
}
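// The sketch below illustrates a typical round trip through the Collection
// methods above; exampleBook, its tags and the sample titles are assumptions.
type exampleBook struct {
	ID    int64  `db:"id,omitempty"`
	Title string `db:"title"`
}

func collectionRoundTrip(col Collection) error {
	// Insert returns the new row's ID when the collection has a primary key.
	if _, err := col.Insert(exampleBook{Title: "La tregua"}); err != nil {
		return err
	}

	// InsertReturning fills the passed struct with the stored row.
	book := exampleBook{Title: "Pedro Páramo"}
	if err := col.InsertReturning(&book); err != nil {
		return err
	}

	// Count reports how many items the collection holds; the ID returned by
	// Insert (or set by InsertReturning) can be passed straight to Find.
	if _, err := col.Count(); err != nil {
		return err
	}
	_ = col.Find(book.ID)
	return nil
}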
|
158
comparison.go
Normal file
158
comparison.go
Normal file
@ -0,0 +1,158 @@
|
||||
package mydb
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/adapter"
|
||||
)
|
||||
|
||||
// Comparison represents a relationship between values.
|
||||
type Comparison struct {
|
||||
*adapter.Comparison
|
||||
}
|
||||
|
||||
// Gte is a comparison that means: is greater than or equal to value.
|
||||
func Gte(value interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, value)}
|
||||
}
|
||||
|
||||
// Lte is a comparison that means: is less than or equal to value.
|
||||
func Lte(value interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, value)}
|
||||
}
|
||||
|
||||
// Eq is a comparison that means: is equal to value.
|
||||
func Eq(value interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorEqual, value)}
|
||||
}
|
||||
|
||||
// NotEq is a comparison that means: is not equal to value.
|
||||
func NotEq(value interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotEqual, value)}
|
||||
}
|
||||
|
||||
// Gt is a comparison that means: is greater than value.
|
||||
func Gt(value interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, value)}
|
||||
}
|
||||
|
||||
// Lt is a comparison that means: is less than value.
|
||||
func Lt(value interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, value)}
|
||||
}
|
||||
|
||||
// In is a comparison that means: is any of the values.
|
||||
func In(value ...interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, toInterfaceArray(value))}
|
||||
}
|
||||
|
||||
// AnyOf is a comparison that means: is any of the values of the slice.
|
||||
func AnyOf(value interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, toInterfaceArray(value))}
|
||||
}
|
||||
|
||||
// NotIn is a comparison that means: is none of the values.
|
||||
func NotIn(value ...interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotIn, toInterfaceArray(value))}
|
||||
}
|
||||
|
||||
// NotAnyOf is a comparison that means: is none of the values of the slice.
|
||||
func NotAnyOf(value interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotIn, toInterfaceArray(value))}
|
||||
}
|
||||
|
||||
// After is a comparison that means: is after the (time.Time) value.
|
||||
func After(value time.Time) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, value)}
|
||||
}
|
||||
|
||||
// Before is a comparison that means: is before the (time.Time) value.
|
||||
func Before(value time.Time) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, value)}
|
||||
}
|
||||
|
||||
// OnOrAfter is a comparison that means: is on or after the (time.Time) value.
|
||||
func OnOrAfter(value time.Time) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, value)}
|
||||
}
|
||||
|
||||
// OnOrBefore is a comparison that means: is on or before the (time.Time) value.
|
||||
func OnOrBefore(value time.Time) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, value)}
|
||||
}
|
||||
|
||||
// Between is a comparison that means: is between lowerBound and upperBound.
|
||||
func Between(lowerBound interface{}, upperBound interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorBetween, []interface{}{lowerBound, upperBound})}
|
||||
}
|
||||
|
||||
// NotBetween is a comparison that means: is not between lowerBound and upperBound.
|
||||
func NotBetween(lowerBound interface{}, upperBound interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotBetween, []interface{}{lowerBound, upperBound})}
|
||||
}
|
||||
|
||||
// Is is a comparison that means: is equivalent to nil, true or false.
|
||||
func Is(value interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, value)}
|
||||
}
|
||||
|
||||
// IsNot is a comparison that means: is not equivalent to nil, true nor false.
|
||||
func IsNot(value interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, value)}
|
||||
}
|
||||
|
||||
// IsNull is a comparison that means: is equivalent to nil.
|
||||
func IsNull() *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, nil)}
|
||||
}
|
||||
|
||||
// IsNotNull is a comparison that means: is not equivalent to nil.
|
||||
func IsNotNull() *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, nil)}
|
||||
}
|
||||
|
||||
// Like is a comparison that checks whether the reference matches the wildcard
|
||||
// value.
|
||||
func Like(value string) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLike, value)}
|
||||
}
|
||||
|
||||
// NotLike is a comparison that checks whether the reference does not match the
|
||||
// wildcard value.
|
||||
func NotLike(value string) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotLike, value)}
|
||||
}
|
||||
|
||||
// RegExp is a comparison that checks whether the reference matches the regular
|
||||
// expression.
|
||||
func RegExp(value string) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorRegExp, value)}
|
||||
}
|
||||
|
||||
// NotRegExp is a comparison that checks whether the reference does not match
|
||||
// the regular expression.
|
||||
func NotRegExp(value string) *Comparison {
|
||||
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotRegExp, value)}
|
||||
}
|
||||
|
||||
// Op returns a custom comparison operator.
|
||||
func Op(customOperator string, value interface{}) *Comparison {
|
||||
return &Comparison{adapter.NewCustomComparisonOperator(customOperator, value)}
|
||||
}
|
||||
|
||||
func toInterfaceArray(value interface{}) []interface{} {
|
||||
rv := reflect.ValueOf(value)
|
||||
switch rv.Type().Kind() {
|
||||
case reflect.Ptr:
|
||||
return toInterfaceArray(rv.Elem().Interface())
|
||||
case reflect.Slice:
|
||||
elems := rv.Len()
|
||||
args := make([]interface{}, elems)
|
||||
for i := 0; i < elems; i++ {
|
||||
args[i] = rv.Index(i).Interface()
|
||||
}
|
||||
return args
|
||||
}
|
||||
return []interface{}{value}
|
||||
}
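// adultFilter is an illustrative sketch of using the comparison helpers above
// as values of a Cond map; the column names and bounds are assumptions.
func adultFilter(from, to time.Time) Cond {
	return Cond{
		"age":        Gte(18),           // age >= 18
		"name":       Like("A%"),        // name LIKE 'A%'
		"deleted_at": IsNull(),          // deleted_at IS NULL
		"created_at": Between(from, to), // created_at BETWEEN from AND to
	}
}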
|
111
comparison_test.go
Normal file
111
comparison_test.go
Normal file
@ -0,0 +1,111 @@
|
||||
package mydb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/adapter"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestComparison(t *testing.T) {
|
||||
testTimeVal := time.Now()
|
||||
|
||||
testCases := []struct {
|
||||
expects *adapter.Comparison
|
||||
result *Comparison
|
||||
}{
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, 1),
|
||||
Gte(1),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, 22),
|
||||
Lte(22),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorEqual, 6),
|
||||
Eq(6),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorNotEqual, 67),
|
||||
NotEq(67),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, 4),
|
||||
Gt(4),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, 47),
|
||||
Lt(47),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, []interface{}{1, 22, 34}),
|
||||
In(1, 22, 34),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, testTimeVal),
|
||||
After(testTimeVal),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, testTimeVal),
|
||||
Before(testTimeVal),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, testTimeVal),
|
||||
OnOrAfter(testTimeVal),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, testTimeVal),
|
||||
OnOrBefore(testTimeVal),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorBetween, []interface{}{11, 35}),
|
||||
Between(11, 35),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorNotBetween, []interface{}{11, 35}),
|
||||
NotBetween(11, 35),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, 178),
|
||||
Is(178),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, 32),
|
||||
IsNot(32),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, nil),
|
||||
IsNull(),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, nil),
|
||||
IsNotNull(),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorLike, "%a%"),
|
||||
Like("%a%"),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorNotLike, "%z%"),
|
||||
NotLike("%z%"),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorRegExp, ".*"),
|
||||
RegExp(".*"),
|
||||
},
|
||||
{
|
||||
adapter.NewComparisonOperator(adapter.ComparisonOperatorNotRegExp, ".*"),
|
||||
NotRegExp(".*"),
|
||||
},
|
||||
{
|
||||
adapter.NewCustomComparisonOperator("~", 56),
|
||||
Op("~", 56),
|
||||
},
|
||||
}
|
||||
|
||||
for i := range testCases {
|
||||
assert.Equal(t, testCases[i].expects, testCases[i].result.Comparison)
|
||||
}
|
||||
}
|
109
cond.go
Normal file
109
cond.go
Normal file
@ -0,0 +1,109 @@
|
||||
package mydb
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"git.hexq.cn/tiglog/mydb/internal/adapter"
|
||||
)
|
||||
|
||||
// LogicalExpr represents an expression to be used in logical statements.
|
||||
type LogicalExpr = adapter.LogicalExpr
|
||||
|
||||
// LogicalOperator represents a logical operation.
|
||||
type LogicalOperator = adapter.LogicalOperator
|
||||
|
||||
// Cond is a map that defines conditions for a query.
|
||||
//
|
||||
// Each entry of the map represents a condition (a column-value relation bound
|
||||
// by a comparison Operator). The comparison can be specified after the column
|
||||
// name; if no comparison operator is provided, the equality operator is used as
|
||||
// default.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// // Age equals 18.
|
||||
// db.Cond{"age": 18}
|
||||
//
|
||||
// // Age is greater than or equal to 18.
|
||||
// db.Cond{"age >=": 18}
|
||||
//
|
||||
// // id is any of the values 1, 2 or 3.
|
||||
// db.Cond{"id IN": []{1, 2, 3}}
|
||||
//
|
||||
// // Age is lower than 18 (MongoDB syntax)
|
||||
// db.Cond{"age $lt": 18}
|
||||
//
|
||||
// // age > 32 and age < 35
|
||||
// db.Cond{"age >": 32, "age <": 35}
|
||||
type Cond map[interface{}]interface{}
|
||||
|
||||
// Empty returns true if there are no conditions.
|
||||
func (c Cond) Empty() bool {
|
||||
for range c {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// Constraints returns each one of the Cond map entries as a constraint.
|
||||
func (c Cond) Constraints() []adapter.Constraint {
|
||||
z := make([]adapter.Constraint, 0, len(c))
|
||||
for _, k := range c.keys() {
|
||||
z = append(z, adapter.NewConstraint(k, c[k]))
|
||||
}
|
||||
return z
|
||||
}
|
||||
|
||||
// Operator returns the default logical operator.
|
||||
func (c Cond) Operator() LogicalOperator {
|
||||
return adapter.DefaultLogicalOperator
|
||||
}
|
||||
|
||||
func (c Cond) keys() []interface{} {
|
||||
keys := make(condKeys, 0, len(c))
|
||||
for k := range c {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
if len(c) > 1 {
|
||||
sort.Sort(keys)
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
// Expressions returns all the expressions contained in the condition.
|
||||
func (c Cond) Expressions() []LogicalExpr {
|
||||
z := make([]LogicalExpr, 0, len(c))
|
||||
for _, k := range c.keys() {
|
||||
z = append(z, Cond{k: c[k]})
|
||||
}
|
||||
return z
|
||||
}
|
||||
|
||||
type condKeys []interface{}
|
||||
|
||||
func (ck condKeys) Len() int {
|
||||
return len(ck)
|
||||
}
|
||||
|
||||
func (ck condKeys) Less(i, j int) bool {
|
||||
return fmt.Sprintf("%v", ck[i]) < fmt.Sprintf("%v", ck[j])
|
||||
}
|
||||
|
||||
func (ck condKeys) Swap(i, j int) {
|
||||
ck[i], ck[j] = ck[j], ck[i]
|
||||
}
|
||||
|
||||
func defaultJoin(in ...adapter.LogicalExpr) []adapter.LogicalExpr {
|
||||
for i := range in {
|
||||
cond, ok := in[i].(Cond)
|
||||
if ok && !cond.Empty() {
|
||||
in[i] = And(cond)
|
||||
}
|
||||
}
|
||||
return in
|
||||
}
|
||||
|
||||
var (
|
||||
_ = LogicalExpr(Cond{})
|
||||
)
|
69
cond_test.go
Normal file
69
cond_test.go
Normal file
@ -0,0 +1,69 @@
|
||||
package mydb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCond(t *testing.T) {
|
||||
c := Cond{}
|
||||
|
||||
if !c.Empty() {
|
||||
t.Fatal("Cond is empty.")
|
||||
}
|
||||
|
||||
c = Cond{"id": 1}
|
||||
if c.Empty() {
|
||||
t.Fatal("Cond is not empty.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCondAnd(t *testing.T) {
|
||||
a := And()
|
||||
|
||||
if !a.Empty() {
|
||||
t.Fatal("Cond is empty")
|
||||
}
|
||||
|
||||
_ = a.And(Cond{"id": 1})
|
||||
|
||||
if !a.Empty() {
|
||||
t.Fatal("Cond is still empty")
|
||||
}
|
||||
|
||||
a = a.And(Cond{"name": "Ana"})
|
||||
|
||||
if a.Empty() {
|
||||
t.Fatal("Cond is not empty anymore")
|
||||
}
|
||||
|
||||
a = a.And().And()
|
||||
|
||||
if a.Empty() {
|
||||
t.Fatal("Cond is not empty anymore")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCondOr(t *testing.T) {
|
||||
a := Or()
|
||||
|
||||
if !a.Empty() {
|
||||
t.Fatal("Cond is empty")
|
||||
}
|
||||
|
||||
_ = a.Or(Cond{"id": 1})
|
||||
|
||||
if !a.Empty() {
|
||||
t.Fatal("Cond is empty")
|
||||
}
|
||||
|
||||
a = a.Or(Cond{"name": "Ana"})
|
||||
|
||||
if a.Empty() {
|
||||
t.Fatal("Cond is not empty")
|
||||
}
|
||||
|
||||
a = a.Or().Or()
|
||||
if a.Empty() {
|
||||
t.Fatal("Cond is not empty")
|
||||
}
|
||||
}
|
8
connection_url.go
Normal file
8
connection_url.go
Normal file
@ -0,0 +1,8 @@
|
||||
package mydb
|
||||
|
||||
// ConnectionURL represents a data source name (DSN).
|
||||
type ConnectionURL interface {
|
||||
// String returns the connection string that is going to be passed to the
|
||||
// adapter.
|
||||
String() string
|
||||
}
|
42
errors.go
Normal file
42
errors.go
Normal file
@ -0,0 +1,42 @@
|
||||
package mydb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Error messages
|
||||
var (
|
||||
ErrMissingAdapter = errors.New(`mydb: missing adapter`)
|
||||
ErrAlreadyWithinTransaction = errors.New(`mydb: already within a transaction`)
|
||||
ErrCollectionDoesNotExist = errors.New(`mydb: collection does not exist`)
|
||||
ErrExpectingNonNilModel = errors.New(`mydb: expecting non nil model`)
|
||||
ErrExpectingPointerToStruct = errors.New(`mydb: expecting pointer to struct`)
|
||||
ErrGivingUpTryingToConnect = errors.New(`mydb: giving up trying to connect: too many clients`)
|
||||
ErrInvalidCollection = errors.New(`mydb: invalid collection`)
|
||||
ErrMissingCollectionName = errors.New(`mydb: missing collection name`)
|
||||
ErrMissingConditions = errors.New(`mydb: missing selector conditions`)
|
||||
ErrMissingConnURL = errors.New(`mydb: missing DSN`)
|
||||
ErrMissingDatabaseName = errors.New(`mydb: missing database name`)
|
||||
ErrNoMoreRows = errors.New(`mydb: no more rows in this result set`)
|
||||
ErrNotConnected = errors.New(`mydb: not connected to a database`)
|
||||
ErrNotImplemented = errors.New(`mydb: call not implemented`)
|
||||
ErrQueryIsPending = errors.New(`mydb: can't execute this instruction while the result set is still open`)
|
||||
ErrQueryLimitParam = errors.New(`mydb: a query can accept only one limit parameter`)
|
||||
ErrQueryOffsetParam = errors.New(`mydb: a query can accept only one offset parameter`)
|
||||
ErrQuerySortParam = errors.New(`mydb: a query can accept only one order-by parameter`)
|
||||
ErrSockerOrHost = errors.New(`mydb: you may connect either to a UNIX socket or a TCP address, but not both`)
|
||||
ErrTooManyClients = errors.New(`mydb: can't connect to database server: too many clients`)
|
||||
ErrUndefined = errors.New(`mydb: value is undefined`)
|
||||
ErrUnknownConditionType = errors.New(`mydb: arguments of type %T can't be used as constraints`)
|
||||
ErrUnsupported = errors.New(`mydb: action is not supported by the DBMS`)
|
||||
ErrUnsupportedDestination = errors.New(`mydb: unsupported destination type`)
|
||||
ErrUnsupportedType = errors.New(`mydb: type does not support marshaling`)
|
||||
ErrUnsupportedValue = errors.New(`mydb: value does not support unmarshaling`)
|
||||
ErrNilRecord = errors.New(`mydb: invalid item (nil)`)
|
||||
ErrRecordIDIsZero = errors.New(`mydb: item ID is not defined`)
|
||||
ErrMissingPrimaryKeys = errors.New(`mydb: collection %q has no primary keys`)
|
||||
ErrWarnSlowQuery = errors.New(`mydb: slow query`)
|
||||
ErrTransactionAborted = errors.New(`mydb: transaction was aborted`)
|
||||
ErrNotWithinTransaction = errors.New(`mydb: not within transaction`)
|
||||
ErrNotSupportedByAdapter = errors.New(`mydb: not supported by adapter`)
|
||||
)
|
14
errors_test.go
Normal file
14
errors_test.go
Normal file
@ -0,0 +1,14 @@
|
||||
package mydb
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestErrorWrap(t *testing.T) {
|
||||
adapterFakeErr := fmt.Errorf("could not find item in %q: %w", "users", ErrCollectionDoesNotExist)
|
||||
assert.True(t, errors.Is(adapterFakeErr, ErrCollectionDoesNotExist))
|
||||
}
|
25
function.go
Normal file
25
function.go
Normal file
@ -0,0 +1,25 @@
|
||||
package mydb
|
||||
|
||||
import "git.hexq.cn/tiglog/mydb/internal/adapter"
|
||||
|
||||
// FuncExpr represents functions.
|
||||
type FuncExpr = adapter.FuncExpr
|
||||
|
||||
// Func returns a database function expression.
|
||||
//
|
||||
// Examples:
|
||||
//
|
||||
// // MOD(29, 9)
|
||||
// db.Func("MOD", 29, 9)
|
||||
//
|
||||
// // CONCAT("foo", "bar")
|
||||
// db.Func("CONCAT", "foo", "bar")
|
||||
//
|
||||
// // NOW()
|
||||
// db.Func("NOW")
|
||||
//
|
||||
// // RTRIM("Hello ")
|
||||
// db.Func("RTRIM", "Hello ")
|
||||
func Func(name string, args ...interface{}) *FuncExpr {
|
||||
return adapter.NewFuncExpr(name, args)
|
||||
}
|
51
function_test.go
Normal file
51
function_test.go
Normal file
@ -0,0 +1,51 @@
|
||||
package mydb
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFunction(t *testing.T) {
|
||||
{
|
||||
fn := Func("MOD", 29, 9)
|
||||
assert.Equal(t, "MOD", fn.Name())
|
||||
assert.Equal(t, []interface{}{29, 9}, fn.Arguments())
|
||||
}
|
||||
|
||||
{
|
||||
fn := Func("HELLO")
|
||||
assert.Equal(t, "HELLO", fn.Name())
|
||||
assert.Equal(t, []interface{}(nil), fn.Arguments())
|
||||
}
|
||||
|
||||
{
|
||||
fn := Func("CONCAT", "a")
|
||||
assert.Equal(t, "CONCAT", fn.Name())
|
||||
assert.Equal(t, []interface{}{"a"}, fn.Arguments())
|
||||
}
|
||||
|
||||
{
|
||||
fn := Func("CONCAT", "a", "b", "c")
|
||||
assert.Equal(t, "CONCAT", fn.Name())
|
||||
assert.Equal(t, []interface{}{"a", "b", "c"}, fn.Arguments())
|
||||
}
|
||||
|
||||
{
|
||||
fn := Func("IN", []interface{}{"a", "b", "c"})
|
||||
assert.Equal(t, "IN", fn.Name())
|
||||
assert.Equal(t, []interface{}{[]interface{}{"a", "b", "c"}}, fn.Arguments())
|
||||
}
|
||||
|
||||
{
|
||||
fn := Func("IN", []interface{}{"a"})
|
||||
assert.Equal(t, "IN", fn.Name())
|
||||
assert.Equal(t, []interface{}{[]interface{}{"a"}}, fn.Arguments())
|
||||
}
|
||||
|
||||
{
|
||||
fn := Func("IN", []interface{}(nil))
|
||||
assert.Equal(t, "IN", fn.Name())
|
||||
assert.Equal(t, []interface{}{[]interface{}(nil)}, fn.Arguments())
|
||||
}
|
||||
}
|
33
go.mod
Normal file
33
go.mod
Normal file
@ -0,0 +1,33 @@
|
||||
module git.hexq.cn/tiglog/mydb
|
||||
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/go-sql-driver/mysql v1.7.1
|
||||
github.com/google/uuid v1.1.1
|
||||
github.com/ipfs/go-detect-race v0.0.1
|
||||
github.com/jackc/pgtype v1.14.0
|
||||
github.com/jackc/pgx/v4 v4.18.1
|
||||
github.com/lib/pq v1.10.9
|
||||
github.com/mattn/go-sqlite3 v1.14.17
|
||||
github.com/segmentio/fasthash v1.0.3
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/stretchr/testify v1.8.4
|
||||
gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/davecgh/go-spew v1.1.1 // indirect
|
||||
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
|
||||
github.com/jackc/pgconn v1.14.1 // indirect
|
||||
github.com/jackc/pgio v1.0.0 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgproto3/v2 v2.3.2 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
golang.org/x/crypto v0.12.0 // indirect
|
||||
golang.org/x/sys v0.12.0 // indirect
|
||||
golang.org/x/text v0.13.0 // indirect
|
||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||
gopkg.in/yaml.v3 v3.0.1 // indirect
|
||||
)
|
226
go.sum
Normal file
226
go.sum
Normal file
@ -0,0 +1,226 @@
|
||||
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
|
||||
github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc=
|
||||
github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
|
||||
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
|
||||
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
|
||||
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
|
||||
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
|
||||
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
|
||||
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
|
||||
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
|
||||
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
|
||||
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
|
||||
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
|
||||
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
|
||||
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
|
||||
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
|
||||
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
|
||||
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
|
||||
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
github.com/ipfs/go-detect-race v0.0.1 h1:qX/xay2W3E4Q1U7d9lNs1sU9nvguX0a7319XbyQ6cOk=
|
||||
github.com/ipfs/go-detect-race v0.0.1/go.mod h1:8BNT7shDZPo99Q74BpGMK+4D8Mn4j46UU0LZ723meps=
|
||||
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
|
||||
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
|
||||
github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
|
||||
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
|
||||
github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
|
||||
github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s=
|
||||
github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o=
|
||||
github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY=
|
||||
github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI=
|
||||
github.com/jackc/pgconn v1.14.0/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E=
|
||||
github.com/jackc/pgconn v1.14.1 h1:smbxIaZA08n6YuxEX1sDyjV/qkbtUtkH20qLkR9MUR4=
|
||||
github.com/jackc/pgconn v1.14.1/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E=
|
||||
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
|
||||
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
|
||||
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
|
||||
github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c=
|
||||
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc=
|
||||
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak=
|
||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
||||
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
|
||||
github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
||||
github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
||||
github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0=
|
||||
github.com/jackc/pgproto3/v2 v2.3.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
|
||||
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
|
||||
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
|
||||
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
|
||||
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
|
||||
github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM=
|
||||
github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw=
|
||||
github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4=
|
||||
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
|
||||
github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
|
||||
github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
|
||||
github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs=
|
||||
github.com/jackc/pgx/v4 v4.18.1 h1:YP7G1KABtKpB5IHrO9vYwSrCOhs7p3uqhvhhQBptya0=
|
||||
github.com/jackc/pgx/v4 v4.18.1/go.mod h1:FydWkUyadDmdNH/mHnGob881GawxeEm7TcMCzkb+qQE=
|
||||
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
|
||||
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
|
||||
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
|
||||
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
|
||||
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
|
||||
github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
|
||||
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
|
||||
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
|
||||
github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
|
||||
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
|
||||
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
|
||||
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
|
||||
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
|
||||
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
|
||||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
|
||||
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
|
||||
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
|
||||
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
|
||||
github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM=
|
||||
github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY=
|
||||
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
|
||||
github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ=
|
||||
github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
|
||||
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
|
||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
|
||||
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
|
||||
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
|
||||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
|
||||
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
|
||||
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
|
||||
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
||||
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
|
||||
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
|
||||
go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4=
|
||||
go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU=
|
||||
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
|
||||
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
|
||||
go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
|
||||
golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
|
||||
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
|
||||
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
|
||||
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
|
||||
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
|
||||
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
|
||||
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
|
||||
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
|
||||
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
|
||||
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
|
||||
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
|
||||
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
|
||||
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
|
||||
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
|
||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
|
||||
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
|
||||
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
|
||||
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
||||
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
|
||||
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
|
||||
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
|
||||
golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
|
||||
golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
|
||||
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
|
||||
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
|
||||
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
|
||||
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
|
||||
gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 h1:VpOs+IwYnYBaFnrNAeB8UUWtL3vEUnzSCL1nVjPhqrw=
|
||||
gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA=
|
||||
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
|
||||
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
|
||||
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
|
||||
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
||||
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=
|
60
internal/adapter/comparison.go
Normal file
60
internal/adapter/comparison.go
Normal file
@ -0,0 +1,60 @@
|
||||
package adapter
|
||||
|
||||
// ComparisonOperator is the base type for comparison operators.
|
||||
type ComparisonOperator uint8
|
||||
|
||||
// Comparison operators
|
||||
const (
|
||||
ComparisonOperatorNone ComparisonOperator = iota
|
||||
ComparisonOperatorCustom
|
||||
|
||||
ComparisonOperatorEqual
|
||||
ComparisonOperatorNotEqual
|
||||
|
||||
ComparisonOperatorLessThan
|
||||
ComparisonOperatorGreaterThan
|
||||
|
||||
ComparisonOperatorLessThanOrEqualTo
|
||||
ComparisonOperatorGreaterThanOrEqualTo
|
||||
|
||||
ComparisonOperatorBetween
|
||||
ComparisonOperatorNotBetween
|
||||
|
||||
ComparisonOperatorIn
|
||||
ComparisonOperatorNotIn
|
||||
|
||||
ComparisonOperatorIs
|
||||
ComparisonOperatorIsNot
|
||||
|
||||
ComparisonOperatorLike
|
||||
ComparisonOperatorNotLike
|
||||
|
||||
ComparisonOperatorRegExp
|
||||
ComparisonOperatorNotRegExp
|
||||
)
|
||||
|
||||
type Comparison struct {
|
||||
t ComparisonOperator
|
||||
op string
|
||||
v interface{}
|
||||
}
|
||||
|
||||
func (c *Comparison) CustomOperator() string {
|
||||
return c.op
|
||||
}
|
||||
|
||||
func (c *Comparison) Operator() ComparisonOperator {
|
||||
return c.t
|
||||
}
|
||||
|
||||
func (c *Comparison) Value() interface{} {
|
||||
return c.v
|
||||
}
|
||||
|
||||
func NewComparisonOperator(t ComparisonOperator, v interface{}) *Comparison {
|
||||
return &Comparison{t: t, v: v}
|
||||
}
|
||||
|
||||
func NewCustomComparisonOperator(op string, v interface{}) *Comparison {
|
||||
return &Comparison{t: ComparisonOperatorCustom, op: op, v: v}
|
||||
}
|
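Example (not part of this commit): a minimal sketch of the comparison helpers above. It assumes it is compiled inside the mydb module (the internal packages cannot be imported from outside) and that the package main wrapper and sample values are purely illustrative.

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/mydb/internal/adapter"
)

func main() {
	// A predefined ">=" comparison carrying the value 18.
	ge := adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, 18)
	fmt.Println(ge.Operator() == adapter.ComparisonOperatorGreaterThanOrEqualTo, ge.Value()) // true 18

	// A custom operator string for anything the constants above do not cover.
	custom := adapter.NewCustomComparisonOperator("<@", []int{1, 2, 3})
	fmt.Println(custom.Operator() == adapter.ComparisonOperatorCustom, custom.CustomOperator()) // true <@
}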
51
internal/adapter/constraint.go
Normal file
51
internal/adapter/constraint.go
Normal file
@ -0,0 +1,51 @@
|
||||
package adapter
|
||||
|
||||
// ConstraintValuer allows constraints to use specific values of their own.
|
||||
type ConstraintValuer interface {
|
||||
ConstraintValue() interface{}
|
||||
}
|
||||
|
||||
// Constraint interface represents a single condition, like "a = 1", where `a`
|
||||
// is the key and `1` is the value. This is an exported interface but it's
|
||||
// rarely used directly, you may want to use the `db.Cond{}` map instead.
|
||||
type Constraint interface {
|
||||
// Key is the leftmost part of the constraint and usually contains a column
|
||||
// name.
|
||||
Key() interface{}
|
||||
|
||||
// Value is the rightmost part of the constraint and usually contains a
|
||||
// column value.
|
||||
Value() interface{}
|
||||
}
|
||||
|
||||
// Constraints interface represents an array of constraints, like "a = 1, b =
|
||||
// 2, c = 3".
|
||||
type Constraints interface {
|
||||
// Constraints returns an array of constraints.
|
||||
Constraints() []Constraint
|
||||
}
|
||||
|
||||
type constraint struct {
|
||||
k interface{}
|
||||
v interface{}
|
||||
}
|
||||
|
||||
func (c constraint) Key() interface{} {
|
||||
return c.k
|
||||
}
|
||||
|
||||
func (c constraint) Value() interface{} {
|
||||
if constraintValuer, ok := c.v.(ConstraintValuer); ok {
|
||||
return constraintValuer.ConstraintValue()
|
||||
}
|
||||
return c.v
|
||||
}
|
||||
|
||||
// NewConstraint creates a constraint.
|
||||
func NewConstraint(key interface{}, value interface{}) Constraint {
|
||||
return &constraint{k: key, v: value}
|
||||
}
|
||||
|
||||
var (
|
||||
_ = Constraint(&constraint{})
|
||||
)
|
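Example (not part of this commit): a short sketch of NewConstraint and the ConstraintValuer hook, again assuming it is built inside the module; customValue is a hypothetical type used only for illustration.

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/mydb/internal/adapter"
)

// customValue shows how ConstraintValuer lets a value substitute itself.
type customValue struct{}

func (customValue) ConstraintValue() interface{} { return "resolved" }

func main() {
	plain := adapter.NewConstraint("id", 1)
	fmt.Println(plain.Key(), plain.Value()) // id 1

	// When the value implements ConstraintValuer, Value() returns the
	// substituted value instead of the wrapper itself.
	wrapped := adapter.NewConstraint("status", customValue{})
	fmt.Println(wrapped.Key(), wrapped.Value()) // status resolved
}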
18
internal/adapter/func.go
Normal file
18
internal/adapter/func.go
Normal file
@ -0,0 +1,18 @@
|
||||
package adapter
|
||||
|
||||
type FuncExpr struct {
|
||||
name string
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
func (f *FuncExpr) Arguments() []interface{} {
|
||||
return f.args
|
||||
}
|
||||
|
||||
func (f *FuncExpr) Name() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
func NewFuncExpr(name string, args []interface{}) *FuncExpr {
|
||||
return &FuncExpr{name: name, args: args}
|
||||
}
|
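Example (not part of this commit): how a FuncExpr keeps a function name and its arguments separate until an adapter renders them; the function name and arguments are illustrative.

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/mydb/internal/adapter"
)

func main() {
	fn := adapter.NewFuncExpr("COALESCE", []interface{}{"nickname", "name"})
	fmt.Println(fn.Name(), fn.Arguments()) // COALESCE [nickname name]
}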
100
internal/adapter/logical_expr.go
Normal file
100
internal/adapter/logical_expr.go
Normal file
@ -0,0 +1,100 @@
|
||||
package adapter
|
||||
|
||||
import "git.hexq.cn/tiglog/mydb/internal/immutable"
|
||||
|
||||
// LogicalExpr represents a group formed by one or more sentences joined by
|
||||
// an Operator like "AND" or "OR".
|
||||
type LogicalExpr interface {
|
||||
// Expressions returns child sentences.
|
||||
Expressions() []LogicalExpr
|
||||
|
||||
// Operator returns the Operator that joins all the sentences in the group.
|
||||
Operator() LogicalOperator
|
||||
|
||||
// Empty returns true if the compound has zero children, false otherwise.
|
||||
Empty() bool
|
||||
}
|
||||
|
||||
// LogicalOperator represents the operation on a compound statement.
|
||||
type LogicalOperator uint
|
||||
|
||||
// LogicalExpr Operators.
|
||||
const (
|
||||
LogicalOperatorNone LogicalOperator = iota
|
||||
LogicalOperatorAnd
|
||||
LogicalOperatorOr
|
||||
)
|
||||
|
||||
const DefaultLogicalOperator = LogicalOperatorAnd
|
||||
|
||||
type LogicalExprGroup struct {
|
||||
op LogicalOperator
|
||||
|
||||
prev *LogicalExprGroup
|
||||
fn func(*[]LogicalExpr) error
|
||||
}
|
||||
|
||||
func NewLogicalExprGroup(op LogicalOperator, conds ...LogicalExpr) *LogicalExprGroup {
|
||||
group := &LogicalExprGroup{op: op}
|
||||
if len(conds) == 0 {
|
||||
return group
|
||||
}
|
||||
return group.Frame(func(in *[]LogicalExpr) error {
|
||||
*in = append(*in, conds...)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// Expressions returns each one of the conditions as a compound.
|
||||
func (g *LogicalExprGroup) Expressions() []LogicalExpr {
|
||||
conds, err := immutable.FastForward(g)
|
||||
if err == nil {
|
||||
return *(conds.(*[]LogicalExpr))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Operator returns the logical operator that joins the group's expressions; it panics if the operator was never set.
|
||||
func (g *LogicalExprGroup) Operator() LogicalOperator {
|
||||
if g.op == LogicalOperatorNone {
|
||||
panic("operator is not defined")
|
||||
}
|
||||
return g.op
|
||||
}
|
||||
|
||||
// Empty returns true if this condition has no elements. False otherwise.
|
||||
func (g *LogicalExprGroup) Empty() bool {
|
||||
if g.fn != nil {
|
||||
return false
|
||||
}
|
||||
if g.prev != nil {
|
||||
return g.prev.Empty()
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func (g *LogicalExprGroup) Frame(fn func(*[]LogicalExpr) error) *LogicalExprGroup {
|
||||
return &LogicalExprGroup{prev: g, op: g.op, fn: fn}
|
||||
}
|
||||
|
||||
func (g *LogicalExprGroup) Prev() immutable.Immutable {
|
||||
if g == nil {
|
||||
return nil
|
||||
}
|
||||
return g.prev
|
||||
}
|
||||
|
||||
func (g *LogicalExprGroup) Fn(in interface{}) error {
|
||||
if g.fn == nil {
|
||||
return nil
|
||||
}
|
||||
return g.fn(in.(*[]LogicalExpr))
|
||||
}
|
||||
|
||||
func (g *LogicalExprGroup) Base() interface{} {
|
||||
return &[]LogicalExpr{}
|
||||
}
|
||||
|
||||
var (
|
||||
_ = immutable.Immutable(&LogicalExprGroup{})
|
||||
)
|
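Example (not part of this commit): grouping two conditions with a LogicalExprGroup. It leans on RawExpr (defined in raw.go below) as a ready-made LogicalExpr and assumes it is compiled inside the module; the SQL fragments are illustrative.

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/mydb/internal/adapter"
)

func main() {
	// RawExpr already satisfies LogicalExpr, so it can act as a leaf condition.
	a := adapter.NewRawExpr("id > ?", []interface{}{10})
	b := adapter.NewRawExpr("deleted_at IS NULL", nil)

	// Frames are immutable: NewLogicalExprGroup builds a new node per change,
	// and Expressions() fast-forwards the chain to collect the conditions.
	g := adapter.NewLogicalExprGroup(adapter.LogicalOperatorAnd, a, b)
	fmt.Println(g.Operator() == adapter.LogicalOperatorAnd, g.Empty(), len(g.Expressions())) // true false 2
}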
49
internal/adapter/raw.go
Normal file
49
internal/adapter/raw.go
Normal file
@ -0,0 +1,49 @@
|
||||
package adapter
|
||||
|
||||
// RawExpr represents values that can bypass SQL filters. This is an
|
||||
// exported struct but it's rarely used directly; you may want to use the
|
||||
// `db.Raw()` function instead.
|
||||
type RawExpr struct {
|
||||
value string
|
||||
args *[]interface{}
|
||||
}
|
||||
|
||||
func (r *RawExpr) Arguments() []interface{} {
|
||||
if r.args != nil {
|
||||
return *r.args
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r RawExpr) Raw() string {
|
||||
return r.value
|
||||
}
|
||||
|
||||
func (r RawExpr) String() string {
|
||||
return r.Raw()
|
||||
}
|
||||
|
||||
// Expressions returns this raw expression as a single logical expression.
|
||||
func (r *RawExpr) Expressions() []LogicalExpr {
|
||||
return []LogicalExpr{r}
|
||||
}
|
||||
|
||||
// Operator returns the default compound operator.
|
||||
func (r RawExpr) Operator() LogicalOperator {
|
||||
return LogicalOperatorNone
|
||||
}
|
||||
|
||||
// Empty returns true if this struct has no value, false otherwise.
|
||||
func (r *RawExpr) Empty() bool {
|
||||
return r.value == ""
|
||||
}
|
||||
|
||||
func NewRawExpr(value string, args []interface{}) *RawExpr {
|
||||
r := &RawExpr{value: value, args: nil}
|
||||
if len(args) > 0 {
|
||||
r.args = &args
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
var _ = LogicalExpr(&RawExpr{})
|
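Example (not part of this commit): the Raw/Arguments/Empty accessors on a RawExpr, assuming the code lives inside the module; the SQL strings are illustrative.

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/mydb/internal/adapter"
)

func main() {
	raw := adapter.NewRawExpr("created_at > NOW() - INTERVAL ?", []interface{}{"1 day"})
	fmt.Println(raw.Raw(), raw.Arguments(), raw.Empty()) // created_at > NOW() - INTERVAL ? [1 day] false

	// With no arguments, Arguments() reports nil; Empty() only turns true when
	// the raw string itself is empty.
	now := adapter.NewRawExpr("NOW()", nil)
	fmt.Println(now.Arguments() == nil, now.Empty()) // true false
}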
113
internal/cache/cache.go
vendored
Normal file
113
internal/cache/cache.go
vendored
Normal file
@ -0,0 +1,113 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"container/list"
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
const defaultCapacity = 128
|
||||
|
||||
// Cache holds a map of volatile key -> values.
|
||||
type Cache struct {
|
||||
keys *list.List
|
||||
items map[uint64]*list.Element
|
||||
mu sync.RWMutex
|
||||
capacity int
|
||||
}
|
||||
|
||||
type cacheItem struct {
|
||||
key uint64
|
||||
value interface{}
|
||||
}
|
||||
|
||||
// NewCacheWithCapacity initializes a new caching space with the given
|
||||
// capacity.
|
||||
func NewCacheWithCapacity(capacity int) (*Cache, error) {
|
||||
if capacity < 1 {
|
||||
return nil, errors.New("Capacity must be greater than zero.")
|
||||
}
|
||||
c := &Cache{
|
||||
capacity: capacity,
|
||||
}
|
||||
c.init()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
// NewCache initializes a new caching space with default settings.
|
||||
func NewCache() *Cache {
|
||||
c, err := NewCacheWithCapacity(defaultCapacity)
|
||||
if err != nil {
|
||||
panic(err.Error()) // Should never happen as we're not providing a negative defaultCapacity.
|
||||
}
|
||||
return c
|
||||
}
|
||||
|
||||
func (c *Cache) init() {
|
||||
c.items = make(map[uint64]*list.Element)
|
||||
c.keys = list.New()
|
||||
}
|
||||
|
||||
// Read attempts to retrieve a cached value as a string; if the value does not
|
||||
// exist, it returns an empty string and false.
|
||||
func (c *Cache) Read(h Hashable) (string, bool) {
|
||||
if v, ok := c.ReadRaw(h); ok {
|
||||
if s, ok := v.(string); ok {
|
||||
return s, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// ReadRaw attempts to retrieve a cached value as an interface{}; if the value
|
||||
// does not exist, it returns nil and false.
|
||||
func (c *Cache) ReadRaw(h Hashable) (interface{}, bool) {
|
||||
c.mu.RLock()
|
||||
defer c.mu.RUnlock()
|
||||
|
||||
item, ok := c.items[h.Hash()]
|
||||
if ok {
|
||||
return item.Value.(*cacheItem).value, true
|
||||
}
|
||||
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// Write stores a value in memory. If the value already exists, it is overwritten.
|
||||
func (c *Cache) Write(h Hashable, value interface{}) {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
key := h.Hash()
|
||||
|
||||
if item, ok := c.items[key]; ok {
|
||||
item.Value.(*cacheItem).value = value
|
||||
c.keys.MoveToFront(item)
|
||||
return
|
||||
}
|
||||
|
||||
c.items[key] = c.keys.PushFront(&cacheItem{key, value})
|
||||
|
||||
for c.keys.Len() > c.capacity {
|
||||
item := c.keys.Remove(c.keys.Back()).(*cacheItem)
|
||||
delete(c.items, item.key)
|
||||
if p, ok := item.value.(HasOnEvict); ok {
|
||||
p.OnEvict()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Clear generates a new memory space, leaving the old memory unreferenced, so
|
||||
// it can be claimed by the garbage collector.
|
||||
func (c *Cache) Clear() {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
|
||||
for _, item := range c.items {
|
||||
if p, ok := item.Value.(*cacheItem).value.(HasOnEvict); ok {
|
||||
p.OnEvict()
|
||||
}
|
||||
}
|
||||
|
||||
c.init()
|
||||
}
|
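Example (not part of this commit): a small usage sketch of the Cache above, assuming it runs inside the module. The query type is hypothetical and builds its key with the hash helpers from hash.go further below; the SQL strings are illustrative.

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/mydb/internal/cache"
)

// query is a toy Hashable key; its hash is built with the helpers from hash.go.
type query struct {
	table string
	limit int
}

func (q query) Hash() uint64 {
	return cache.NewHash(0, q.table, q.limit)
}

func main() {
	c, err := cache.NewCacheWithCapacity(2) // tiny capacity to show eviction
	if err != nil {
		panic(err)
	}

	c.Write(query{"users", 10}, "SELECT * FROM users LIMIT 10")
	c.Write(query{"posts", 5}, "SELECT * FROM posts LIMIT 5")

	if s, ok := c.Read(query{"users", 10}); ok {
		fmt.Println("hit:", s)
	}

	// A third key pushes the cache over capacity and evicts the oldest entry.
	c.Write(query{"tags", 1}, "SELECT * FROM tags LIMIT 1")
	_, ok := c.Read(query{"posts", 5})
	fmt.Println("posts still cached:", ok) // true; "users" was evicted
}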
97
internal/cache/cache_test.go
vendored
Normal file
97
internal/cache/cache_test.go
vendored
Normal file
@ -0,0 +1,97 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"hash/fnv"
|
||||
"testing"
|
||||
)
|
||||
|
||||
var c *Cache
|
||||
|
||||
type cacheableT struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (ct *cacheableT) Hash() uint64 {
|
||||
s := fnv.New64()
|
||||
_, _ = s.Write([]byte(ct.Name)) // feed the name into the hash state (Sum does not add data)
|
||||
return s.Sum64()
|
||||
}
|
||||
|
||||
var (
|
||||
key = cacheableT{"foo"}
|
||||
value = "bar"
|
||||
)
|
||||
|
||||
func TestNewCache(t *testing.T) {
|
||||
c = NewCache()
|
||||
if c == nil {
|
||||
t.Fatal("Expecting a new cache object.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheReadNonExistentValue(t *testing.T) {
|
||||
if _, ok := c.Read(&key); ok {
|
||||
t.Fatal("Expecting false.")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheWritingValue(t *testing.T) {
|
||||
c.Write(&key, value)
|
||||
c.Write(&key, value)
|
||||
}
|
||||
|
||||
func TestCacheReadExistentValue(t *testing.T) {
|
||||
s, ok := c.Read(&key)
|
||||
|
||||
if !ok {
|
||||
t.Fatal("Expecting true.")
|
||||
}
|
||||
|
||||
if s != value {
|
||||
t.Fatal("Expecting value.")
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewCache(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
NewCache()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewCacheAndClear(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
c := NewCache()
|
||||
c.Clear()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReadNonExistentValue(b *testing.B) {
|
||||
z := NewCache()
|
||||
for i := 0; i < b.N; i++ {
|
||||
z.Read(&key)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteSameValue(b *testing.B) {
|
||||
z := NewCache()
|
||||
for i := 0; i < b.N; i++ {
|
||||
z.Write(&key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkWriteNewValue(b *testing.B) {
|
||||
z := NewCache()
|
||||
for i := 0; i < b.N; i++ {
|
||||
key := cacheableT{fmt.Sprintf("item-%d", i)}
|
||||
z.Write(&key, value)
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkReadExistentValue(b *testing.B) {
|
||||
z := NewCache()
|
||||
z.Write(&key, value)
|
||||
for i := 0; i < b.N; i++ {
|
||||
z.Read(&key)
|
||||
}
|
||||
}
|
109
internal/cache/hash.go
vendored
Normal file
109
internal/cache/hash.go
vendored
Normal file
@ -0,0 +1,109 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/segmentio/fasthash/fnv1a"
|
||||
)
|
||||
|
||||
const (
|
||||
hashTypeInt uint64 = 1 << iota
|
||||
hashTypeSignedInt
|
||||
hashTypeBool
|
||||
hashTypeString
|
||||
hashTypeHashable
|
||||
hashTypeNil
|
||||
)
|
||||
|
||||
type hasher struct {
|
||||
t uint64
|
||||
v interface{}
|
||||
}
|
||||
|
||||
func (h *hasher) Hash() uint64 {
|
||||
return NewHash(h.t, h.v)
|
||||
}
|
||||
|
||||
func NewHashable(t uint64, v interface{}) Hashable {
|
||||
return &hasher{t: t, v: v}
|
||||
}
|
||||
|
||||
func InitHash(t uint64) uint64 {
|
||||
return fnv1a.AddUint64(fnv1a.Init64, t)
|
||||
}
|
||||
|
||||
func NewHash(t uint64, in ...interface{}) uint64 {
|
||||
return AddToHash(InitHash(t), in...)
|
||||
}
|
||||
|
||||
func AddToHash(h uint64, in ...interface{}) uint64 {
|
||||
for i := range in {
|
||||
if in[i] == nil {
|
||||
continue
|
||||
}
|
||||
h = addToHash(h, in[i])
|
||||
}
|
||||
return h
|
||||
}
|
||||
|
||||
func addToHash(h uint64, in interface{}) uint64 {
|
||||
switch v := in.(type) {
|
||||
case uint64:
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), v)
|
||||
case uint32:
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
|
||||
case uint16:
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
|
||||
case uint8:
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
|
||||
case uint:
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
|
||||
case int64:
|
||||
if v < 0 {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
|
||||
} else {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
|
||||
}
|
||||
case int32:
|
||||
if v < 0 {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
|
||||
} else {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
|
||||
}
|
||||
case int16:
|
||||
if v < 0 {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
|
||||
} else {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
|
||||
}
|
||||
case int8:
|
||||
if v < 0 {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
|
||||
} else {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
|
||||
}
|
||||
case int:
|
||||
if v < 0 {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
|
||||
} else {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
|
||||
}
|
||||
case bool:
|
||||
if v {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 1)
|
||||
} else {
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 2)
|
||||
}
|
||||
case string:
|
||||
return fnv1a.AddString64(fnv1a.AddUint64(h, hashTypeString), v)
|
||||
case Hashable:
|
||||
if in == nil {
|
||||
panic(fmt.Sprintf("could not hash nil element %T", in))
|
||||
}
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeHashable), v.Hash())
|
||||
case nil:
|
||||
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeNil), 0)
|
||||
default:
|
||||
panic(fmt.Sprintf("unsupported value type %T", in))
|
||||
}
|
||||
}
|
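Example (not part of this commit): the hashing helpers above combined into cache keys, assuming it is compiled inside the module; the seed and sample values are illustrative.

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/mydb/internal/cache"
)

func main() {
	// Hashes are deterministic: identical inputs always produce the same key,
	// which is how the cache recognizes repeated lookups.
	a := cache.NewHash(1, "users", 10, true)
	b := cache.NewHash(1, "users", 10, true)
	c := cache.NewHash(1, "users", 11, true)
	fmt.Println(a == b, a == c) // true false

	// AddToHash extends an existing hash with more values.
	d := cache.AddToHash(a, "extra")
	fmt.Println(a == d) // false
}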
13
internal/cache/interface.go
vendored
Normal file
13
internal/cache/interface.go
vendored
Normal file
@ -0,0 +1,13 @@
|
||||
package cache
|
||||
|
||||
// Hashable types must implement a method that returns a key. This key will be
|
||||
// associated with a cached value.
|
||||
type Hashable interface {
|
||||
Hash() uint64
|
||||
}
|
||||
|
||||
// HasOnEvict type is (optionally) implemented by cache objects to clean after
|
||||
// themselves.
|
||||
type HasOnEvict interface {
|
||||
OnEvict()
|
||||
}
|
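Example (not part of this commit): how HasOnEvict lets a cached value clean up after itself when the cache evicts it. The token and stmt types are hypothetical stand-ins; the sketch assumes it is built inside the module.

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/mydb/internal/cache"
)

// token is a trivial Hashable key.
type token struct{ id int }

func (t token) Hash() uint64 { return uint64(t.id) }

// stmt stands in for a resource that must be released; OnEvict is its hook.
type stmt struct{ sql string }

func (s *stmt) OnEvict() { fmt.Println("closing:", s.sql) }

func main() {
	c, err := cache.NewCacheWithCapacity(1)
	if err != nil {
		panic(err)
	}
	c.Write(token{1}, &stmt{sql: "SELECT 1"})
	c.Write(token{2}, &stmt{sql: "SELECT 2"}) // evicts token{1}; its OnEvict runs
}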
28
internal/immutable/immutable.go
Normal file
28
internal/immutable/immutable.go
Normal file
@ -0,0 +1,28 @@
|
||||
package immutable
|
||||
|
||||
// Immutable represents an immutable chain that, if passed to FastForward,
|
||||
// applies Fn() to every element of the chain; the first element of the chain is
|
||||
// represented by Base().
|
||||
type Immutable interface {
|
||||
// Prev is the previous element on a chain.
|
||||
Prev() Immutable
|
||||
// Fn a function that is able to modify the passed element.
|
||||
Fn(interface{}) error
|
||||
// Base is the first element on a chain, there's no previous element before
|
||||
// the Base element.
|
||||
Base() interface{}
|
||||
}
|
||||
|
||||
// FastForward applies all Fn methods in order on the given new Base.
|
||||
func FastForward(curr Immutable) (interface{}, error) {
|
||||
prev := curr.Prev()
|
||||
if prev == nil {
|
||||
return curr.Base(), nil
|
||||
}
|
||||
in, err := FastForward(prev)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = curr.Fn(in)
|
||||
return in, err
|
||||
}
|
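Example (not part of this commit): a toy chain implementing the Immutable interface above, mirroring how LogicalExprGroup uses it. The steps type is hypothetical and the sketch assumes it is compiled inside the module.

package main

import (
	"fmt"

	"git.hexq.cn/tiglog/mydb/internal/immutable"
)

// steps is a toy immutable chain: Add returns a new node that appends one
// string when the chain is fast-forwarded, leaving older nodes untouched.
type steps struct {
	prev *steps
	fn   func(*[]string) error
}

func (s *steps) Add(v string) *steps {
	return &steps{prev: s, fn: func(in *[]string) error {
		*in = append(*in, v)
		return nil
	}}
}

func (s *steps) Prev() immutable.Immutable {
	if s == nil || s.prev == nil {
		return nil
	}
	return s.prev
}

func (s *steps) Fn(in interface{}) error {
	if s == nil || s.fn == nil {
		return nil
	}
	return s.fn(in.(*[]string))
}

func (s *steps) Base() interface{} {
	return &[]string{}
}

func main() {
	chain := (&steps{}).Add("connect").Add("query").Add("close")
	out, err := immutable.FastForward(chain)
	if err != nil {
		panic(err)
	}
	fmt.Println(*(out.(*[]string))) // [connect query close]
}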
23
internal/reflectx/LICENSE
Normal file
23
internal/reflectx/LICENSE
Normal file
@ -0,0 +1,23 @@
|
||||
Copyright (c) 2013, Jason Moiron
|
||||
|
||||
Permission is hereby granted, free of charge, to any person
|
||||
obtaining a copy of this software and associated documentation
|
||||
files (the "Software"), to deal in the Software without
|
||||
restriction, including without limitation the rights to use,
|
||||
copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the
|
||||
Software is furnished to do so, subject to the following
|
||||
conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be
|
||||
included in all copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
||||
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
|
||||
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
||||
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
||||
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
||||
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
||||
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
|
||||
OTHER DEALINGS IN THE SOFTWARE.
|
||||
|
17
internal/reflectx/README.md
Normal file
17
internal/reflectx/README.md
Normal file
@ -0,0 +1,17 @@
|
||||
# reflectx
|
||||
|
||||
The sqlx package has special reflect needs. In particular, it needs to:
|
||||
|
||||
* be able to map a name to a field
|
||||
* understand embedded structs
|
||||
* understand mapping names to fields by a particular tag
|
||||
* user specified name -> field mapping functions
|
||||
|
||||
These behaviors mimic the behaviors of the standard library marshallers and also the
|
||||
behavior of standard Go accessors.
|
||||
|
||||
The first two are amply taken care of by `reflect.Value.FieldByName`, and the third is
|
||||
addressed by `reflect.Value.FieldByNameFunc`, but these don't quite understand struct
|
||||
tags in the ways that are vital to most marshalers, and they are slow.
|
||||
|
||||
This reflectx package extends reflect to achieve these goals.
|
404
internal/reflectx/reflect.go
Normal file
404
internal/reflectx/reflect.go
Normal file
@ -0,0 +1,404 @@
|
||||
// Package reflectx implements extensions to the standard reflect lib suitable
|
||||
// for implementing marshaling and unmarshaling packages. The main Mapper type
|
||||
// allows for Go-compatible named attribute access, including accessing embedded
|
||||
// struct attributes and the ability to use functions and struct tags to
|
||||
// customize field names.
|
||||
package reflectx
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// A FieldInfo is a collection of metadata about a struct field.
|
||||
type FieldInfo struct {
|
||||
Index []int
|
||||
Path string
|
||||
Field reflect.StructField
|
||||
Zero reflect.Value
|
||||
Name string
|
||||
Options map[string]string
|
||||
Embedded bool
|
||||
Children []*FieldInfo
|
||||
Parent *FieldInfo
|
||||
}
|
||||
|
||||
// A StructMap is an index of field metadata for a struct.
|
||||
type StructMap struct {
|
||||
Tree *FieldInfo
|
||||
Index []*FieldInfo
|
||||
Paths map[string]*FieldInfo
|
||||
Names map[string]*FieldInfo
|
||||
}
|
||||
|
||||
// GetByPath returns a *FieldInfo for a given string path.
|
||||
func (f StructMap) GetByPath(path string) *FieldInfo {
|
||||
return f.Paths[path]
|
||||
}
|
||||
|
||||
// GetByTraversal returns a *FieldInfo for a given integer path. It is
|
||||
// analogous to reflect.FieldByIndex.
|
||||
func (f StructMap) GetByTraversal(index []int) *FieldInfo {
|
||||
if len(index) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
tree := f.Tree
|
||||
for _, i := range index {
|
||||
if i >= len(tree.Children) || tree.Children[i] == nil {
|
||||
return nil
|
||||
}
|
||||
tree = tree.Children[i]
|
||||
}
|
||||
return tree
|
||||
}
|
||||
|
||||
// Mapper is a general purpose mapper of names to struct fields. A Mapper
|
||||
// behaves like most marshallers, optionally obeying a field tag for name
|
||||
// mapping and a function to provide a basic mapping of fields to names.
|
||||
type Mapper struct {
|
||||
cache map[reflect.Type]*StructMap
|
||||
tagName string
|
||||
tagMapFunc func(string) string
|
||||
mapFunc func(string) string
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// NewMapper returns a new mapper which optionally obeys the field tag given
|
||||
// by tagName. If tagName is the empty string, it is ignored.
|
||||
func NewMapper(tagName string) *Mapper {
|
||||
return &Mapper{
|
||||
cache: make(map[reflect.Type]*StructMap),
|
||||
tagName: tagName,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMapperTagFunc returns a new mapper which contains a mapper for field names
|
||||
// AND a mapper for tag values. This is useful for tags like json which can
|
||||
// have values like "name,omitempty".
|
||||
func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper {
|
||||
return &Mapper{
|
||||
cache: make(map[reflect.Type]*StructMap),
|
||||
tagName: tagName,
|
||||
mapFunc: mapFunc,
|
||||
tagMapFunc: tagMapFunc,
|
||||
}
|
||||
}
|
||||
|
||||
// NewMapperFunc returns a new mapper which optionally obeys a field tag and
|
||||
// a struct field name mapper func given by f. Tags will take precedence, but
|
||||
// for any other field, the mapped name will be f(field.Name)
|
||||
func NewMapperFunc(tagName string, f func(string) string) *Mapper {
|
||||
return &Mapper{
|
||||
cache: make(map[reflect.Type]*StructMap),
|
||||
tagName: tagName,
|
||||
mapFunc: f,
|
||||
}
|
||||
}
|
||||
|
||||
// TypeMap returns a mapping of field strings to int slices representing
|
||||
// the traversal down the struct to reach the field.
|
||||
func (m *Mapper) TypeMap(t reflect.Type) *StructMap {
|
||||
m.mutex.Lock()
|
||||
mapping, ok := m.cache[t]
|
||||
if !ok {
|
||||
mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc)
|
||||
m.cache[t] = mapping
|
||||
}
|
||||
m.mutex.Unlock()
|
||||
return mapping
|
||||
}
|
||||
|
||||
// FieldMap returns the mapper's mapping of field names to reflect values. Panics
|
||||
// if v's Kind is not Struct, or v is not Indirectable to a struct kind.
|
||||
func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
mustBe(v, reflect.Struct)
|
||||
|
||||
r := map[string]reflect.Value{}
|
||||
tm := m.TypeMap(v.Type())
|
||||
for tagName, fi := range tm.Names {
|
||||
r[tagName] = FieldByIndexes(v, fi.Index)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// ValidFieldMap returns the mapper's mapping of field names to reflect valid
|
||||
// field values. Panics if v's Kind is not Struct, or v is not Indirectable to
|
||||
// a struct kind.
|
||||
func (m *Mapper) ValidFieldMap(v reflect.Value) map[string]reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
mustBe(v, reflect.Struct)
|
||||
|
||||
r := map[string]reflect.Value{}
|
||||
tm := m.TypeMap(v.Type())
|
||||
for tagName, fi := range tm.Names {
|
||||
v := ValidFieldByIndexes(v, fi.Index)
|
||||
if v.IsValid() {
|
||||
r[tagName] = v
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FieldByName returns a field by its mapped name as a reflect.Value.
|
||||
// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind.
|
||||
// Returns zero Value if the name is not found.
|
||||
func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
mustBe(v, reflect.Struct)
|
||||
|
||||
tm := m.TypeMap(v.Type())
|
||||
fi, ok := tm.Names[name]
|
||||
if !ok {
|
||||
return v
|
||||
}
|
||||
return FieldByIndexes(v, fi.Index)
|
||||
}
|
||||
|
||||
// FieldsByName returns a slice of values corresponding to the slice of names
|
||||
// for the value. Panics if v's Kind is not Struct or v is not Indirectable
|
||||
// to a struct Kind. Returns zero Value for each name not found.
|
||||
func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value {
|
||||
v = reflect.Indirect(v)
|
||||
mustBe(v, reflect.Struct)
|
||||
|
||||
tm := m.TypeMap(v.Type())
|
||||
vals := make([]reflect.Value, 0, len(names))
|
||||
for _, name := range names {
|
||||
fi, ok := tm.Names[name]
|
||||
if !ok {
|
||||
vals = append(vals, *new(reflect.Value))
|
||||
} else {
|
||||
vals = append(vals, FieldByIndexes(v, fi.Index))
|
||||
}
|
||||
}
|
||||
return vals
|
||||
}
|
||||
|
||||
// TraversalsByName returns a slice of int slices which represent the struct
|
||||
// traversals for each mapped name. Panics if t is not a struct or Indirectable
|
||||
// to a struct. Returns empty int slice for each name not found.
|
||||
func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int {
|
||||
t = Deref(t)
|
||||
mustBe(t, reflect.Struct)
|
||||
tm := m.TypeMap(t)
|
||||
|
||||
r := make([][]int, 0, len(names))
|
||||
for _, name := range names {
|
||||
fi, ok := tm.Names[name]
|
||||
if !ok {
|
||||
r = append(r, []int{})
|
||||
} else {
|
||||
r = append(r, fi.Index)
|
||||
}
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// FieldByIndexes returns a value for a particular struct traversal.
|
||||
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
|
||||
for _, i := range indexes {
|
||||
v = reflect.Indirect(v).Field(i)
|
||||
// if this is a pointer, it's possible it is nil
|
||||
if v.Kind() == reflect.Ptr && v.IsNil() {
|
||||
alloc := reflect.New(Deref(v.Type()))
|
||||
v.Set(alloc)
|
||||
}
|
||||
if v.Kind() == reflect.Map && v.IsNil() {
|
||||
v.Set(reflect.MakeMap(v.Type()))
|
||||
}
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// ValidFieldByIndexes returns a value for a particular struct traversal.
|
||||
func ValidFieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
|
||||
|
||||
for _, i := range indexes {
|
||||
v = reflect.Indirect(v)
|
||||
if !v.IsValid() {
|
||||
return reflect.Value{}
|
||||
}
|
||||
v = v.Field(i)
|
||||
// if this is a pointer, it's possible it is nil
|
||||
if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Map) && v.IsNil() {
|
||||
return reflect.Value{}
|
||||
}
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
|
||||
// FieldByIndexesReadOnly returns a value for a particular struct traversal,
|
||||
// but is not concerned with allocating nil pointers because the value is
|
||||
// going to be used for reading and not setting.
|
||||
func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value {
|
||||
for _, i := range indexes {
|
||||
v = reflect.Indirect(v).Field(i)
|
||||
}
|
||||
return v
|
||||
}
|
||||
|
||||
// Deref is Indirect for reflect.Types
|
||||
func Deref(t reflect.Type) reflect.Type {
|
||||
if t.Kind() == reflect.Ptr {
|
||||
t = t.Elem()
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
// -- helpers & utilities --
|
||||
|
||||
type kinder interface {
|
||||
Kind() reflect.Kind
|
||||
}
|
||||
|
||||
// mustBe checks a value against a kind, panicking with a reflect.ValueError
|
||||
// if the kind isn't that which is required.
|
||||
func mustBe(v kinder, expected reflect.Kind) {
|
||||
k := v.Kind()
|
||||
if k != expected {
|
||||
panic(&reflect.ValueError{Method: methodName(), Kind: k})
|
||||
}
|
||||
}
|
||||
|
||||
// methodName returns the name of the caller of the function calling methodName.
|
||||
func methodName() string {
|
||||
pc, _, _, _ := runtime.Caller(2)
|
||||
f := runtime.FuncForPC(pc)
|
||||
if f == nil {
|
||||
return "unknown method"
|
||||
}
|
||||
return f.Name()
|
||||
}
|
||||
|
||||
type typeQueue struct {
|
||||
t reflect.Type
|
||||
fi *FieldInfo
|
||||
pp string // Parent path
|
||||
}
|
||||
|
||||
// A copying append that creates a new slice each time.
|
||||
func apnd(is []int, i int) []int {
|
||||
x := make([]int, len(is)+1)
|
||||
copy(x, is)
|
||||
x[len(x)-1] = i
|
||||
return x
|
||||
}
|
||||
|
||||
// getMapping returns a mapping for the t type, using the tagName, mapFunc and
|
||||
// tagMapFunc to determine the canonical names of fields.
|
||||
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) string) *StructMap {
|
||||
m := []*FieldInfo{}
|
||||
|
||||
root := &FieldInfo{}
|
||||
queue := []typeQueue{}
|
||||
queue = append(queue, typeQueue{Deref(t), root, ""})
|
||||
|
||||
for len(queue) != 0 {
|
||||
// pop the first item off of the queue
|
||||
tq := queue[0]
|
||||
queue = queue[1:]
|
||||
nChildren := 0
|
||||
if tq.t.Kind() == reflect.Struct {
|
||||
nChildren = tq.t.NumField()
|
||||
}
|
||||
tq.fi.Children = make([]*FieldInfo, nChildren)
|
||||
|
||||
// iterate through all of its fields
|
||||
for fieldPos := 0; fieldPos < nChildren; fieldPos++ {
|
||||
f := tq.t.Field(fieldPos)
|
||||
|
||||
fi := FieldInfo{}
|
||||
fi.Field = f
|
||||
fi.Zero = reflect.New(f.Type).Elem()
|
||||
fi.Options = map[string]string{}
|
||||
|
||||
var tag, name string
|
||||
if tagName != "" && strings.Contains(string(f.Tag), tagName+":") {
|
||||
tag = f.Tag.Get(tagName)
|
||||
name = tag
|
||||
} else {
|
||||
if mapFunc != nil {
|
||||
name = mapFunc(f.Name)
|
||||
}
|
||||
}
|
||||
|
||||
parts := strings.Split(name, ",")
|
||||
if len(parts) > 1 {
|
||||
name = parts[0]
|
||||
for _, opt := range parts[1:] {
|
||||
kv := strings.Split(opt, "=")
|
||||
if len(kv) > 1 {
|
||||
fi.Options[kv[0]] = kv[1]
|
||||
} else {
|
||||
fi.Options[kv[0]] = ""
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if tagMapFunc != nil {
|
||||
tag = tagMapFunc(tag)
|
||||
}
|
||||
|
||||
fi.Name = name
|
||||
|
||||
if tq.pp == "" || (tq.pp == "" && tag == "") {
|
||||
fi.Path = fi.Name
|
||||
} else {
|
||||
fi.Path = fmt.Sprintf("%s.%s", tq.pp, fi.Name)
|
||||
}
|
||||
|
||||
// if the name is "-", disabled via a tag, skip it
|
||||
if name == "-" {
|
||||
continue
|
||||
}
|
||||
|
||||
// skip unexported fields
|
||||
if len(f.PkgPath) != 0 && !f.Anonymous {
|
||||
continue
|
||||
}
|
||||
|
||||
// bfs search of anonymous embedded structs
|
||||
if f.Anonymous {
|
||||
pp := tq.pp
|
||||
if tag != "" {
|
||||
pp = fi.Path
|
||||
}
|
||||
|
||||
fi.Embedded = true
|
||||
fi.Index = apnd(tq.fi.Index, fieldPos)
|
||||
nChildren := 0
|
||||
ft := Deref(f.Type)
|
||||
if ft.Kind() == reflect.Struct {
|
||||
nChildren = ft.NumField()
|
||||
}
|
||||
fi.Children = make([]*FieldInfo, nChildren)
|
||||
queue = append(queue, typeQueue{Deref(f.Type), &fi, pp})
|
||||
} else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) {
|
||||
fi.Index = apnd(tq.fi.Index, fieldPos)
|
||||
fi.Children = make([]*FieldInfo, Deref(f.Type).NumField())
|
||||
queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path})
|
||||
}
|
||||
|
||||
fi.Index = apnd(tq.fi.Index, fieldPos)
|
||||
fi.Parent = tq.fi
|
||||
tq.fi.Children[fieldPos] = &fi
|
||||
m = append(m, &fi)
|
||||
}
|
||||
}
|
||||
|
||||
flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}}
|
||||
for _, fi := range flds.Index {
|
||||
flds.Paths[fi.Path] = fi
|
||||
if fi.Name != "" && !fi.Embedded {
|
||||
flds.Names[fi.Path] = fi
|
||||
}
|
||||
}
|
||||
|
||||
return flds
|
||||
}
|
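Example (not part of this commit): a minimal sketch of the Mapper above, assuming it is compiled inside the module. The Author/Post structs and their tags are hypothetical and exist only to show tagged, nested, and untagged lookups.

package main

import (
	"fmt"
	"reflect"
	"strings"

	"git.hexq.cn/tiglog/mydb/internal/reflectx"
)

type Author struct {
	Name string `db:"name"`
}

type Post struct {
	Title  string `db:"title"`
	Author Author `db:"author"`
	Draft  bool   // no tag: falls back to the lowercased field name
}

func main() {
	m := reflectx.NewMapperFunc("db", strings.ToLower)

	p := Post{Title: "Hello", Author: Author{Name: "Ana"}, Draft: true}
	v := reflect.ValueOf(p)

	// Tagged fields are addressed by tag, nested fields by dotted path, and
	// untagged fields by the mapped (lowercased) Go name.
	fmt.Println(m.FieldByName(v, "title").Interface())       // Hello
	fmt.Println(m.FieldByName(v, "author.name").Interface()) // Ana
	fmt.Println(m.FieldByName(v, "draft").Interface())       // true
}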
587
internal/reflectx/reflect_test.go
Normal file
587
internal/reflectx/reflect_test.go
Normal file
@ -0,0 +1,587 @@
|
||||
package reflectx
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func ival(v reflect.Value) int {
|
||||
return v.Interface().(int)
|
||||
}
|
||||
|
||||
func TestBasic(t *testing.T) {
|
||||
type Foo struct {
|
||||
A int
|
||||
B int
|
||||
C int
|
||||
}
|
||||
|
||||
f := Foo{1, 2, 3}
|
||||
fv := reflect.ValueOf(f)
|
||||
m := NewMapperFunc("", func(s string) string { return s })
|
||||
|
||||
v := m.FieldByName(fv, "A")
|
||||
if ival(v) != f.A {
|
||||
t.Errorf("Expecting %d, got %d", ival(v), f.A)
|
||||
}
|
||||
v = m.FieldByName(fv, "B")
|
||||
if ival(v) != f.B {
|
||||
t.Errorf("Expecting %d, got %d", f.B, ival(v))
|
||||
}
|
||||
v = m.FieldByName(fv, "C")
|
||||
if ival(v) != f.C {
|
||||
t.Errorf("Expecting %d, got %d", f.C, ival(v))
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicEmbedded(t *testing.T) {
|
||||
type Foo struct {
|
||||
A int
|
||||
}
|
||||
|
||||
type Bar struct {
|
||||
Foo // `db:""` is implied for an embedded struct
|
||||
B int
|
||||
C int `db:"-"`
|
||||
}
|
||||
|
||||
type Baz struct {
|
||||
A int
|
||||
Bar `db:"Bar"`
|
||||
}
|
||||
|
||||
m := NewMapperFunc("db", func(s string) string { return s })
|
||||
|
||||
z := Baz{}
|
||||
z.A = 1
|
||||
z.B = 2
|
||||
z.C = 4
|
||||
z.Bar.Foo.A = 3
|
||||
|
||||
zv := reflect.ValueOf(z)
|
||||
fields := m.TypeMap(reflect.TypeOf(z))
|
||||
|
||||
if len(fields.Index) != 5 {
|
||||
t.Errorf("Expecting 5 fields")
|
||||
}
|
||||
|
||||
// for _, fi := range fields.Index {
|
||||
// log.Println(fi)
|
||||
// }
|
||||
|
||||
v := m.FieldByName(zv, "A")
|
||||
if ival(v) != z.A {
|
||||
t.Errorf("Expecting %d, got %d", z.A, ival(v))
|
||||
}
|
||||
v = m.FieldByName(zv, "Bar.B")
|
||||
if ival(v) != z.Bar.B {
|
||||
t.Errorf("Expecting %d, got %d", z.Bar.B, ival(v))
|
||||
}
|
||||
v = m.FieldByName(zv, "Bar.A")
|
||||
if ival(v) != z.Bar.Foo.A {
|
||||
t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v))
|
||||
}
|
||||
v = m.FieldByName(zv, "Bar.C")
|
||||
if _, ok := v.Interface().(int); ok {
|
||||
t.Errorf("Expecting Bar.C to not exist")
|
||||
}
|
||||
|
||||
fi := fields.GetByPath("Bar.C")
|
||||
if fi != nil {
|
||||
t.Errorf("Bar.C should not exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEmbeddedSimple(t *testing.T) {
|
||||
type UUID [16]byte
|
||||
type MyID struct {
|
||||
UUID
|
||||
}
|
||||
type Item struct {
|
||||
ID MyID
|
||||
}
|
||||
z := Item{}
|
||||
|
||||
m := NewMapper("db")
|
||||
m.TypeMap(reflect.TypeOf(z))
|
||||
}
|
||||
|
||||
func TestBasicEmbeddedWithTags(t *testing.T) {
|
||||
type Foo struct {
|
||||
A int `db:"a"`
|
||||
}
|
||||
|
||||
type Bar struct {
|
||||
Foo // `db:""` is implied for an embedded struct
|
||||
B int `db:"b"`
|
||||
}
|
||||
|
||||
type Baz struct {
|
||||
A int `db:"a"`
|
||||
Bar // `db:""` is implied for an embedded struct
|
||||
}
|
||||
|
||||
m := NewMapper("db")
|
||||
|
||||
z := Baz{}
|
||||
z.A = 1
|
||||
z.B = 2
|
||||
z.Bar.Foo.A = 3
|
||||
|
||||
zv := reflect.ValueOf(z)
|
||||
fields := m.TypeMap(reflect.TypeOf(z))
|
||||
|
||||
if len(fields.Index) != 5 {
|
||||
t.Errorf("Expecting 5 fields")
|
||||
}
|
||||
|
||||
// for _, fi := range fields.index {
|
||||
// log.Println(fi)
|
||||
// }
|
||||
|
||||
v := m.FieldByName(zv, "a")
|
||||
if ival(v) != z.Bar.Foo.A { // the dominant field
|
||||
t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v))
|
||||
}
|
||||
v = m.FieldByName(zv, "b")
|
||||
if ival(v) != z.B {
|
||||
t.Errorf("Expecting %d, got %d", z.B, ival(v))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFlatTags(t *testing.T) {
|
||||
m := NewMapper("db")
|
||||
|
||||
type Asset struct {
|
||||
Title string `db:"title"`
|
||||
}
|
||||
type Post struct {
|
||||
Author string `db:"author,required"`
|
||||
Asset Asset `db:""`
|
||||
}
|
||||
// Post columns: (author title)
|
||||
|
||||
post := Post{Author: "Joe", Asset: Asset{Title: "Hello"}}
|
||||
pv := reflect.ValueOf(post)
|
||||
|
||||
v := m.FieldByName(pv, "author")
|
||||
if v.Interface().(string) != post.Author {
|
||||
t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(pv, "title")
|
||||
if v.Interface().(string) != post.Asset.Title {
|
||||
t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNestedStruct(t *testing.T) {
|
||||
m := NewMapper("db")
|
||||
|
||||
type Details struct {
|
||||
Active bool `db:"active"`
|
||||
}
|
||||
type Asset struct {
|
||||
Title string `db:"title"`
|
||||
Details Details `db:"details"`
|
||||
}
|
||||
type Post struct {
|
||||
Author string `db:"author,required"`
|
||||
Asset `db:"asset"`
|
||||
}
|
||||
// Post columns: (author asset.title asset.details.active)
|
||||
|
||||
post := Post{
|
||||
Author: "Joe",
|
||||
Asset: Asset{Title: "Hello", Details: Details{Active: true}},
|
||||
}
|
||||
pv := reflect.ValueOf(post)
|
||||
|
||||
v := m.FieldByName(pv, "author")
|
||||
if v.Interface().(string) != post.Author {
|
||||
t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(pv, "title")
|
||||
if _, ok := v.Interface().(string); ok {
|
||||
t.Errorf("Expecting field to not exist")
|
||||
}
|
||||
v = m.FieldByName(pv, "asset.title")
|
||||
if v.Interface().(string) != post.Asset.Title {
|
||||
t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(pv, "asset.details.active")
|
||||
if v.Interface().(bool) != post.Asset.Details.Active {
|
||||
t.Errorf("Expecting %v, got %v", post.Asset.Details.Active, v.Interface().(bool))
|
||||
}
|
||||
}
|
||||
|
||||
func TestInlineStruct(t *testing.T) {
|
||||
m := NewMapperTagFunc("db", strings.ToLower, nil)
|
||||
|
||||
type Employee struct {
|
||||
Name string
|
||||
ID int
|
||||
}
|
||||
type Boss Employee
|
||||
type person struct {
|
||||
Employee `db:"employee"`
|
||||
Boss `db:"boss"`
|
||||
}
|
||||
// employees columns: (employee.name employee.id boss.name boss.id)
|
||||
|
||||
em := person{Employee: Employee{Name: "Joe", ID: 2}, Boss: Boss{Name: "Dick", ID: 1}}
|
||||
ev := reflect.ValueOf(em)
|
||||
|
||||
fields := m.TypeMap(reflect.TypeOf(em))
|
||||
if len(fields.Index) != 6 {
|
||||
t.Errorf("Expecting 6 fields")
|
||||
}
|
||||
|
||||
v := m.FieldByName(ev, "employee.name")
|
||||
if v.Interface().(string) != em.Employee.Name {
|
||||
t.Errorf("Expecting %s, got %s", em.Employee.Name, v.Interface().(string))
|
||||
}
|
||||
v = m.FieldByName(ev, "boss.id")
|
||||
if ival(v) != em.Boss.ID {
|
||||
t.Errorf("Expecting %v, got %v", em.Boss.ID, ival(v))
|
||||
}
|
||||
}
|
||||
|
||||

func TestFieldsEmbedded(t *testing.T) {
	m := NewMapper("db")

	type Person struct {
		Name string `db:"name"`
	}
	type Place struct {
		Name string `db:"name"`
	}
	type Article struct {
		Title string `db:"title"`
	}
	type PP struct {
		Person  `db:"person,required"`
		Place   `db:",someflag"`
		Article `db:",required"`
	}
	// PP columns: (person.name name title)

	pp := PP{}
	pp.Person.Name = "Peter"
	pp.Place.Name = "Toronto"
	pp.Article.Title = "Best city ever"

	fields := m.TypeMap(reflect.TypeOf(pp))
	// for i, f := range fields {
	// log.Println(i, f)
	// }

	ppv := reflect.ValueOf(pp)

	v := m.FieldByName(ppv, "person.name")
	if v.Interface().(string) != pp.Person.Name {
		t.Errorf("Expecting %s, got %s", pp.Person.Name, v.Interface().(string))
	}

	v = m.FieldByName(ppv, "name")
	if v.Interface().(string) != pp.Place.Name {
		t.Errorf("Expecting %s, got %s", pp.Place.Name, v.Interface().(string))
	}

	v = m.FieldByName(ppv, "title")
	if v.Interface().(string) != pp.Article.Title {
		t.Errorf("Expecting %s, got %s", pp.Article.Title, v.Interface().(string))
	}

	fi := fields.GetByPath("person")
	if _, ok := fi.Options["required"]; !ok {
		t.Errorf("Expecting required option to be set")
	}
	if !fi.Embedded {
		t.Errorf("Expecting field to be embedded")
	}
	if len(fi.Index) != 1 || fi.Index[0] != 0 {
		t.Errorf("Expecting index to be [0]")
	}

	fi = fields.GetByPath("person.name")
	if fi == nil {
		t.Errorf("Expecting person.name to exist")
	}
	if fi.Path != "person.name" {
		t.Errorf("Expecting %s, got %s", "person.name", fi.Path)
	}

	fi = fields.GetByTraversal([]int{1, 0})
	if fi == nil {
		t.Errorf("Expecting traversal to exist")
	}
	if fi.Path != "name" {
		t.Errorf("Expecting %s, got %s", "name", fi.Path)
	}

	fi = fields.GetByTraversal([]int{2})
	if fi == nil {
		t.Errorf("Expecting traversal to exist")
	}
	if _, ok := fi.Options["required"]; !ok {
		t.Errorf("Expecting required option to be set")
	}

	trs := m.TraversalsByName(reflect.TypeOf(pp), []string{"person.name", "name", "title"})
	if !reflect.DeepEqual(trs, [][]int{{0, 0}, {1, 0}, {2, 0}}) {
		t.Errorf("Expecting traversal: %v", trs)
	}
}

func TestPtrFields(t *testing.T) {
	m := NewMapperTagFunc("db", strings.ToLower, nil)
	type Asset struct {
		Title string
	}
	type Post struct {
		*Asset `db:"asset"`
		Author string
	}

	post := &Post{Author: "Joe", Asset: &Asset{Title: "Hiyo"}}
	pv := reflect.ValueOf(post)

	fields := m.TypeMap(reflect.TypeOf(post))
	if len(fields.Index) != 3 {
		t.Errorf("Expecting 3 fields")
	}

	v := m.FieldByName(pv, "asset.title")
	if v.Interface().(string) != post.Asset.Title {
		t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string))
	}
	v = m.FieldByName(pv, "author")
	if v.Interface().(string) != post.Author {
		t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
	}
}

func TestNamedPtrFields(t *testing.T) {
	m := NewMapperTagFunc("db", strings.ToLower, nil)

	type User struct {
		Name string
	}

	type Asset struct {
		Title string

		Owner *User `db:"owner"`
	}
	type Post struct {
		Author string

		Asset1 *Asset `db:"asset1"`
		Asset2 *Asset `db:"asset2"`
	}

	post := &Post{Author: "Joe", Asset1: &Asset{Title: "Hiyo", Owner: &User{"Username"}}} // Let Asset2 be nil
	pv := reflect.ValueOf(post)

	fields := m.TypeMap(reflect.TypeOf(post))
	if len(fields.Index) != 9 {
		t.Errorf("Expecting 9 fields")
	}

	v := m.FieldByName(pv, "asset1.title")
	if v.Interface().(string) != post.Asset1.Title {
		t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string))
	}
	v = m.FieldByName(pv, "asset1.owner.name")
	if v.Interface().(string) != post.Asset1.Owner.Name {
		t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string))
	}
	v = m.FieldByName(pv, "asset2.title")
	if v.Interface().(string) != post.Asset2.Title {
		t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string))
	}
	v = m.FieldByName(pv, "asset2.owner.name")
	if v.Interface().(string) != post.Asset2.Owner.Name {
		t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string))
	}
	v = m.FieldByName(pv, "author")
	if v.Interface().(string) != post.Author {
		t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
	}
}

func TestFieldMap(t *testing.T) {
	type Foo struct {
		A int
		B int
		C int
	}

	f := Foo{1, 2, 3}
	m := NewMapperFunc("db", strings.ToLower)

	fm := m.FieldMap(reflect.ValueOf(f))

	if len(fm) != 3 {
		t.Errorf("Expecting %d keys, got %d", 3, len(fm))
	}
	if fm["a"].Interface().(int) != 1 {
		t.Errorf("Expecting %d, got %d", 1, ival(fm["a"]))
	}
	if fm["b"].Interface().(int) != 2 {
		t.Errorf("Expecting %d, got %d", 2, ival(fm["b"]))
	}
	if fm["c"].Interface().(int) != 3 {
		t.Errorf("Expecting %d, got %d", 3, ival(fm["c"]))
	}
}

func TestTagNameMapping(t *testing.T) {
	type Strategy struct {
		StrategyID   string `protobuf:"bytes,1,opt,name=strategy_id" json:"strategy_id,omitempty"`
		StrategyName string
	}

	m := NewMapperTagFunc("json", strings.ToUpper, func(value string) string {
		if strings.Contains(value, ",") {
			return strings.Split(value, ",")[0]
		}
		return value
	})
	strategy := Strategy{"1", "Alpah"}
	mapping := m.TypeMap(reflect.TypeOf(strategy))

	for _, key := range []string{"strategy_id", "STRATEGYNAME"} {
		if fi := mapping.GetByPath(key); fi == nil {
			t.Errorf("Expecting to find key %s in mapping but did not.", key)
		}
	}
}

func TestMapping(t *testing.T) {
	type Person struct {
		ID           int
		Name         string
		WearsGlasses bool `db:"wears_glasses"`
	}

	m := NewMapperFunc("db", strings.ToLower)
	p := Person{1, "Jason", true}
	mapping := m.TypeMap(reflect.TypeOf(p))

	for _, key := range []string{"id", "name", "wears_glasses"} {
		if fi := mapping.GetByPath(key); fi == nil {
			t.Errorf("Expecting to find key %s in mapping but did not.", key)
		}
	}

	type SportsPerson struct {
		Weight int
		Age    int
		Person
	}
	s := SportsPerson{Weight: 100, Age: 30, Person: p}
	mapping = m.TypeMap(reflect.TypeOf(s))
	for _, key := range []string{"id", "name", "wears_glasses", "weight", "age"} {
		if fi := mapping.GetByPath(key); fi == nil {
			t.Errorf("Expecting to find key %s in mapping but did not.", key)
		}
	}

	type RugbyPlayer struct {
		Position   int
		IsIntense  bool `db:"is_intense"`
		IsAllBlack bool `db:"-"`
		SportsPerson
	}
	r := RugbyPlayer{12, true, false, s}
	mapping = m.TypeMap(reflect.TypeOf(r))
	for _, key := range []string{"id", "name", "wears_glasses", "weight", "age", "position", "is_intense"} {
		if fi := mapping.GetByPath(key); fi == nil {
			t.Errorf("Expecting to find key %s in mapping but did not.", key)
		}
	}

	if fi := mapping.GetByPath("isallblack"); fi != nil {
		t.Errorf("Expecting to ignore `IsAllBlack` field")
	}
}

type E1 struct {
	A int
}
type E2 struct {
	E1
	B int
}
type E3 struct {
	E2
	C int
}
type E4 struct {
	E3
	D int
}

func BenchmarkFieldNameL1(b *testing.B) {
	e4 := E4{D: 1}
	for i := 0; i < b.N; i++ {
		v := reflect.ValueOf(e4)
		f := v.FieldByName("D")
		if f.Interface().(int) != 1 {
			b.Fatal("Wrong value.")
		}
	}
}

func BenchmarkFieldNameL4(b *testing.B) {
	e4 := E4{}
	e4.A = 1
	for i := 0; i < b.N; i++ {
		v := reflect.ValueOf(e4)
		f := v.FieldByName("A")
		if f.Interface().(int) != 1 {
			b.Fatal("Wrong value.")
		}
	}
}

func BenchmarkFieldPosL1(b *testing.B) {
	e4 := E4{D: 1}
	for i := 0; i < b.N; i++ {
		v := reflect.ValueOf(e4)
		f := v.Field(1)
		if f.Interface().(int) != 1 {
			b.Fatal("Wrong value.")
		}
	}
}

func BenchmarkFieldPosL4(b *testing.B) {
	e4 := E4{}
	e4.A = 1
	for i := 0; i < b.N; i++ {
		v := reflect.ValueOf(e4)
		f := v.Field(0)
		f = f.Field(0)
		f = f.Field(0)
		f = f.Field(0)
		if f.Interface().(int) != 1 {
			b.Fatal("Wrong value.")
		}
	}
}

func BenchmarkFieldByIndexL4(b *testing.B) {
	e4 := E4{}
	e4.A = 1
	idx := []int{0, 0, 0, 0}
	for i := 0; i < b.N; i++ {
		v := reflect.ValueOf(e4)
		f := FieldByIndexes(v, idx)
		if f.Interface().(int) != 1 {
			b.Fatal("Wrong value.")
		}
	}
}
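
Outside of these tests, the pieces exercised above are typically combined when scanning query results: TraversalsByName resolves the selected column names into field index paths once, and FieldByIndexes walks (and, for nil pointers, allocates) those paths on each destination struct. The following is only a sketch; it assumes m is the mapper returned by NewMapper/NewMapperFunc (the *Mapper type name is assumed), dest is a pointer to a struct, and cols came from rows.Columns().

func scanTargetsSketch(m *Mapper, dest interface{}, cols []string) []interface{} {
	v := reflect.ValueOf(dest).Elem()
	traversals := m.TraversalsByName(v.Type(), cols)

	targets := make([]interface{}, len(cols))
	for i, tr := range traversals {
		// FieldByIndexes allocates intermediate nil pointers while descending.
		targets[i] = FieldByIndexes(v, tr).Addr().Interface()
	}

	// targets can now be handed to rows.Scan(targets...).
	return targets
}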
369
internal/sqladapter/collection.go
Normal file
369
internal/sqladapter/collection.go
Normal file
@ -0,0 +1,369 @@
package sqladapter

import (
	"fmt"
	"reflect"

	"git.hexq.cn/tiglog/mydb"
	"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
	"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
)

// CollectionAdapter defines methods to be implemented by SQL adapters.
type CollectionAdapter interface {
	// Insert prepares and executes an INSERT statement. When the item is
	// successfully added, Insert returns a unique identifier of the newly
	// added element (or nil if the unique identifier couldn't be determined).
	Insert(Collection, interface{}) (interface{}, error)
}
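
To make the contract above concrete, here is a minimal adapter-side sketch. It is hypothetical: the chained InsertInto(...).Values(...).Exec() calls are assumed to exist on mydb.SQL (upper/db-style builder), and a real adapter would map struct fields or map keys to columns rather than passing the item straight through.

// insertOnlyAdapter is a hypothetical CollectionAdapter, shown only to
// illustrate the interface above.
type insertOnlyAdapter struct{}

func (insertOnlyAdapter) Insert(col Collection, item interface{}) (interface{}, error) {
	// Build and execute the INSERT through the collection's SQL builder
	// (builder API assumed, upper/db style).
	res, err := col.SQL().InsertInto(col.Name()).Values(item).Exec()
	if err != nil {
		return nil, err
	}

	// Report the generated ID when the driver exposes it; nil is acceptable
	// when it cannot be determined.
	if id, err := res.LastInsertId(); err == nil {
		return id, nil
	}
	return nil, nil
}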

// Collection satisfies mydb.Collection.
type Collection interface {
	// Insert inserts a new item into the collection.
	Insert(interface{}) (mydb.InsertResult, error)

	// Name returns the name of the collection.
	Name() string

	// Session returns the mydb.Session the collection belongs to.
	Session() mydb.Session

	// Exists returns true if the collection exists, false otherwise.
	Exists() (bool, error)

	// Find defines a new result set.
	Find(conds ...interface{}) mydb.Result

	Count() (uint64, error)

	// Truncate removes all elements on the collection and resets the
	// collection's IDs.
	Truncate() error

	// InsertReturning inserts a new item into the collection and refreshes the
	// item with actual data from the database. This is useful to get automatic
	// values, such as timestamps, or IDs.
	InsertReturning(item interface{}) error

	// UpdateReturning updates a record from the collection and refreshes the item
	// with actual data from the database. This is useful to get automatic
	// values, such as timestamps, or IDs.
	UpdateReturning(item interface{}) error

	// PrimaryKeys returns the names of all primary keys in the table.
	PrimaryKeys() ([]string, error)

	// SQL returns a mydb.SQL instance.
	SQL() mydb.SQL
}
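
A usage sketch of the interface above from application code. It assumes sess is an open mydb.Session (for example, the result of mydb.Open), that an "accounts" table exists, and that mydb.Result exposes an upper/db-style All method; the Account type and its tags are illustrative only.

type Account struct {
	ID   int64  `db:"id,omitempty"`
	Name string `db:"name"`
}

func collectionUsageSketch(sess mydb.Session) error {
	accounts := sess.Collection("accounts")

	if ok, err := accounts.Exists(); err != nil || !ok {
		return err
	}

	if _, err := accounts.Insert(Account{Name: "Ana"}); err != nil {
		return err
	}

	total, err := accounts.Count()
	if err != nil {
		return err
	}
	_ = total

	// Find returns a mydb.Result that can be filtered further or dumped into a slice.
	var matches []Account
	return accounts.Find(mydb.Cond{"name": "Ana"}).All(&matches)
}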

type finder interface {
	Find(Collection, *Result, ...interface{}) mydb.Result
}

type condsFilter interface {
	FilterConds(...interface{}) []interface{}
}

// collection is the implementation of Collection.
type collection struct {
	name    string
	adapter CollectionAdapter
}

type collectionWithSession struct {
	*collection

	session Session
}

func newCollection(name string, adapter CollectionAdapter) *collection {
	if adapter == nil {
		panic("mydb: nil adapter")
	}
	return &collection{
		name:    name,
		adapter: adapter,
	}
}

func (c *collectionWithSession) SQL() mydb.SQL {
	return c.session.SQL()
}

func (c *collectionWithSession) Session() mydb.Session {
	return c.session
}

func (c *collectionWithSession) Name() string {
	return c.name
}

func (c *collectionWithSession) Count() (uint64, error) {
	return c.Find().Count()
}

func (c *collectionWithSession) Insert(item interface{}) (mydb.InsertResult, error) {
	id, err := c.adapter.Insert(c, item)
	if err != nil {
		return nil, err
	}

	return mydb.NewInsertResult(id), nil
}

func (c *collectionWithSession) PrimaryKeys() ([]string, error) {
	return c.session.PrimaryKeys(c.Name())
}

func (c *collectionWithSession) filterConds(conds ...interface{}) ([]interface{}, error) {
	pk, err := c.PrimaryKeys()
	if err != nil {
		return nil, err
	}
	if len(conds) == 1 && len(pk) == 1 {
		if id := conds[0]; IsKeyValue(id) {
			conds[0] = mydb.Cond{pk[0]: mydb.Eq(id)}
		}
	}
	if tr, ok := c.adapter.(condsFilter); ok {
		return tr.FilterConds(conds...), nil
	}
	return conds, nil
}
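
The practical effect of filterConds: on a table whose only primary key is, say, "id", passing a bare value to Find is shorthand for an explicit condition on that key. A small sketch (the "id" column name is an assumption):

func findByIDSketch(col mydb.Collection) (mydb.Result, mydb.Result) {
	// These two calls select the same row when "id" is the table's only
	// primary key; the first form is rewritten into the second by filterConds.
	byValue := col.Find(42)
	byCond := col.Find(mydb.Cond{"id": mydb.Eq(42)})
	return byValue, byCond
}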

func (c *collectionWithSession) Find(conds ...interface{}) mydb.Result {
	filteredConds, err := c.filterConds(conds...)
	if err != nil {
		res := &Result{}
		res.setErr(err)
		return res
	}

	res := NewResult(
		c.session.SQL(),
		c.Name(),
		filteredConds,
	)
	if f, ok := c.adapter.(finder); ok {
		return f.Find(c, res, conds...)
	}
	return res
}

func (c *collectionWithSession) Exists() (bool, error) {
	if err := c.session.TableExists(c.Name()); err != nil {
		return false, err
	}
	return true, nil
}

func (c *collectionWithSession) InsertReturning(item interface{}) error {
	if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
		return fmt.Errorf("Expecting a pointer but got %T", item)
	}

	// Grab primary keys
	pks, err := c.PrimaryKeys()
	if err != nil {
		return err
	}

	if len(pks) == 0 {
		if ok, err := c.Exists(); !ok {
			return err
		}
		return fmt.Errorf(mydb.ErrMissingPrimaryKeys.Error(), c.Name())
	}

	var tx Session
	isTransaction := c.session.IsTransaction()
	if isTransaction {
		tx = c.session
	} else {
		var err error
		tx, err = c.session.NewTransaction(c.session.Context(), nil)
		if err != nil {
			return err
		}
		defer tx.Close()
	}

	// Allocate a clone of item.
	newItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface()
	var newItemFieldMap map[string]reflect.Value

	itemValue := reflect.ValueOf(item)

	col := tx.Collection(c.Name())

	// Insert item as is and grab the returning ID.
	var newItemRes mydb.Result
	id, err := col.Insert(item)
	if err != nil {
		goto cancel
	}
	if id == nil {
		err = fmt.Errorf("InsertReturning: Could not get a valid ID after inserting. Does the %q table have a primary key?", c.Name())
		goto cancel
	}

	if len(pks) > 1 {
		newItemRes = col.Find(id)
	} else {
		// We have one primary key, build an explicit mydb.Cond with it to
		// prevent string keys from being interpreted as raw conditions.
		newItemRes = col.Find(mydb.Cond{pks[0]: id}) // We already checked that pks is not empty, so pks[0] is defined.
	}

	// Fetch the row that was just inserted into newItem.
	err = newItemRes.One(newItem)
	if err != nil {
		goto cancel
	}

	switch reflect.ValueOf(newItem).Elem().Kind() {
	case reflect.Struct:
		// Get valid fields from newItem to overwrite those that are on item.
		newItemFieldMap = sqlbuilder.Mapper.ValidFieldMap(reflect.ValueOf(newItem))
		for fieldName := range newItemFieldMap {
			sqlbuilder.Mapper.FieldByName(itemValue, fieldName).Set(newItemFieldMap[fieldName])
		}
	case reflect.Map:
		newItemV := reflect.ValueOf(newItem).Elem()
		itemV := reflect.ValueOf(item)
		if itemV.Kind() == reflect.Ptr {
			itemV = itemV.Elem()
		}
		for _, keyV := range newItemV.MapKeys() {
			itemV.SetMapIndex(keyV, newItemV.MapIndex(keyV))
		}
	default:
		err = fmt.Errorf("InsertReturning: expecting a pointer to map or struct, got %T", newItem)
		goto cancel
	}

	if !isTransaction {
		// This is only executed if t.Session() was **not** a transaction and if
		// sess was created with sess.NewTransaction().
		return tx.Commit()
	}

	return err

cancel:
	// This goto label should only be used when we got an error within a
	// transaction and we don't want to continue.

	if !isTransaction {
		// This is only executed if t.Session() was **not** a transaction and if
		// sess was created with sess.NewTransaction().
		_ = tx.Rollback()
	}
	return err
}
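
A sketch of how InsertReturning is meant to be called. The Account type, its tags, the "accounts" table, and the database-generated created_at column are assumptions for illustration (the time import is implied):

func insertReturningSketch(sess mydb.Session) error {
	type Account struct {
		ID        int64     `db:"id,omitempty"`
		Name      string    `db:"name"`
		CreatedAt time.Time `db:"created_at,omitempty"`
	}

	account := Account{Name: "Ana"}

	// The row is inserted and then read back into the same struct, so
	// account.ID and account.CreatedAt now hold the database-generated values.
	return sess.Collection("accounts").InsertReturning(&account)
}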

func (c *collectionWithSession) UpdateReturning(item interface{}) error {
	if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
		return fmt.Errorf("Expecting a pointer but got %T", item)
	}

	// Grab primary keys
	pks, err := c.PrimaryKeys()
	if err != nil {
		return err
	}

	if len(pks) == 0 {
		if ok, err := c.Exists(); !ok {
			return err
		}
		return fmt.Errorf(mydb.ErrMissingPrimaryKeys.Error(), c.Name())
	}

	var tx Session
	isTransaction := c.session.IsTransaction()

	if isTransaction {
		tx = c.session
	} else {
		// Not within a transaction, let's create one.
		var err error
		tx, err = c.session.NewTransaction(c.session.Context(), nil)
		if err != nil {
			return err
		}
		defer tx.Close()
	}

	// Allocate a clone of item.
	defaultItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface()
	var defaultItemFieldMap map[string]reflect.Value

	itemValue := reflect.ValueOf(item)

	conds := mydb.Cond{}
	for _, pk := range pks {
		conds[pk] = mydb.Eq(sqlbuilder.Mapper.FieldByName(itemValue, pk).Interface())
	}

	col := tx.(Session).Collection(c.Name())

	err = col.Find(conds).Update(item)
	if err != nil {
		goto cancel
	}

	if err = col.Find(conds).One(defaultItem); err != nil {
		goto cancel
	}

	switch reflect.ValueOf(defaultItem).Elem().Kind() {
	case reflect.Struct:
		// Get valid fields from defaultItem to overwrite those that are on item.
		defaultItemFieldMap = sqlbuilder.Mapper.ValidFieldMap(reflect.ValueOf(defaultItem))
		for fieldName := range defaultItemFieldMap {
			sqlbuilder.Mapper.FieldByName(itemValue, fieldName).Set(defaultItemFieldMap[fieldName])
		}
	case reflect.Map:
		defaultItemV := reflect.ValueOf(defaultItem).Elem()
		itemV := reflect.ValueOf(item)
		if itemV.Kind() == reflect.Ptr {
			itemV = itemV.Elem()
		}
		for _, keyV := range defaultItemV.MapKeys() {
			itemV.SetMapIndex(keyV, defaultItemV.MapIndex(keyV))
		}
	default:
		panic("default")
	}

	if !isTransaction {
		// This is only executed if t.Session() was **not** a transaction and if
		// sess was created with sess.NewTransaction().
		return tx.Commit()
	}
	return err

cancel:
	// This goto label should only be used when we got an error within a
	// transaction and we don't want to continue.

	if !isTransaction {
		// This is only executed if t.Session() was **not** a transaction and if
		// sess was created with sess.NewTransaction().
		_ = tx.Rollback()
	}
	return err
}
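
UpdateReturning is the mirror image: the item's primary-key fields locate the row, the update is applied, and the row is fetched back so database-maintained columns are refreshed on the struct. A brief sketch, reusing the illustrative Account type from the Collection usage sketch above:

func updateReturningSketch(sess mydb.Session, account *Account) error {
	account.Name = "Ana Sofía"

	// Locates the row by primary key, updates it, and refreshes account with
	// whatever the database now stores for that row.
	return sess.Collection("accounts").UpdateReturning(account)
}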

func (c *collectionWithSession) Truncate() error {
	stmt := exql.Statement{
		Type:  exql.Truncate,
		Table: exql.TableWithName(c.Name()),
	}
	if _, err := c.session.SQL().Exec(&stmt); err != nil {
		return err
	}
	return nil
}
72
internal/sqladapter/compat/query.go
Normal file
72
internal/sqladapter/compat/query.go
Normal file
@ -0,0 +1,72 @@
// +build !go1.8

package compat

import (
	"context"
	"database/sql"
)

type PreparedExecer interface {
	Exec(...interface{}) (sql.Result, error)
}

func PreparedExecContext(p PreparedExecer, ctx context.Context, args []interface{}) (sql.Result, error) {
	return p.Exec(args...)
}

type Execer interface {
	Exec(string, ...interface{}) (sql.Result, error)
}

func ExecContext(p Execer, ctx context.Context, query string, args []interface{}) (sql.Result, error) {
	return p.Exec(query, args...)
}

type PreparedQueryer interface {
	Query(...interface{}) (*sql.Rows, error)
}

func PreparedQueryContext(p PreparedQueryer, ctx context.Context, args []interface{}) (*sql.Rows, error) {
	return p.Query(args...)
}

type Queryer interface {
	Query(string, ...interface{}) (*sql.Rows, error)
}

func QueryContext(p Queryer, ctx context.Context, query string, args []interface{}) (*sql.Rows, error) {
	return p.Query(query, args...)
}

type PreparedRowQueryer interface {
	QueryRow(...interface{}) *sql.Row
}

func PreparedQueryRowContext(p PreparedRowQueryer, ctx context.Context, args []interface{}) *sql.Row {
	return p.QueryRow(args...)
}

type RowQueryer interface {
	QueryRow(string, ...interface{}) *sql.Row
}

func QueryRowContext(p RowQueryer, ctx context.Context, query string, args []interface{}) *sql.Row {
	return p.QueryRow(query, args...)
}

type Preparer interface {
	Prepare(string) (*sql.Stmt, error)
}

func PrepareContext(p Preparer, ctx context.Context, query string) (*sql.Stmt, error) {
	return p.Prepare(query)
}

type TxStarter interface {
	Begin() (*sql.Tx, error)
}

func BeginTx(p TxStarter, ctx context.Context, opts interface{}) (*sql.Tx, error) {
	return p.Begin()
}
72
internal/sqladapter/compat/query_go18.go
Normal file
72
internal/sqladapter/compat/query_go18.go
Normal file
@ -0,0 +1,72 @@
// +build go1.8

package compat

import (
	"context"
	"database/sql"
)

type PreparedExecer interface {
	ExecContext(context.Context, ...interface{}) (sql.Result, error)
}

func PreparedExecContext(p PreparedExecer, ctx context.Context, args []interface{}) (sql.Result, error) {
	return p.ExecContext(ctx, args...)
}

type Execer interface {
	ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
}

func ExecContext(p Execer, ctx context.Context, query string, args []interface{}) (sql.Result, error) {
	return p.ExecContext(ctx, query, args...)
}

type PreparedQueryer interface {
	QueryContext(context.Context, ...interface{}) (*sql.Rows, error)
}

func PreparedQueryContext(p PreparedQueryer, ctx context.Context, args []interface{}) (*sql.Rows, error) {
	return p.QueryContext(ctx, args...)
}

type Queryer interface {
	QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
}

func QueryContext(p Queryer, ctx context.Context, query string, args []interface{}) (*sql.Rows, error) {
	return p.QueryContext(ctx, query, args...)
}

type PreparedRowQueryer interface {
	QueryRowContext(context.Context, ...interface{}) *sql.Row
}

func PreparedQueryRowContext(p PreparedRowQueryer, ctx context.Context, args []interface{}) *sql.Row {
	return p.QueryRowContext(ctx, args...)
}

type RowQueryer interface {
	QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}

func QueryRowContext(p RowQueryer, ctx context.Context, query string, args []interface{}) *sql.Row {
	return p.QueryRowContext(ctx, query, args...)
}

type Preparer interface {
	PrepareContext(context.Context, string) (*sql.Stmt, error)
}

func PrepareContext(p Preparer, ctx context.Context, query string) (*sql.Stmt, error) {
	return p.PrepareContext(ctx, query)
}

type TxStarter interface {
	BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error)
}

func BeginTx(p TxStarter, ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
	return p.BeginTx(ctx, opts)
}
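
The two query files above are build-tag twins: on Go releases before 1.8 the !go1.8 variant forwards to the context-less database/sql methods (dropping the context), while on Go 1.8+ this variant forwards to the native *Context methods. Callers go through the compat functions so the same code compiles on either toolchain. A minimal caller sketch, assuming db is a *sql.DB (which satisfies both Queryer variants) and that a "users" table exists:

func queryUsersSketch(ctx context.Context, db *sql.DB) (*sql.Rows, error) {
	// Forwards to db.QueryContext on Go 1.8+, and to db.Query otherwise.
	return compat.QueryContext(db, ctx, "SELECT id, name FROM users WHERE id = ?", []interface{}{1})
}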
83
internal/sqladapter/exql/column.go
Normal file
83
internal/sqladapter/exql/column.go
Normal file
@ -0,0 +1,83 @@
package exql

import (
	"fmt"
	"strings"

	"git.hexq.cn/tiglog/mydb/internal/cache"
)

type columnWithAlias struct {
	Name  string
	Alias string
}

// Column represents a SQL column.
type Column struct {
	Name interface{}
}

var _ = Fragment(&Column{})

// ColumnWithName creates and returns a Column with the given name.
func ColumnWithName(name string) *Column {
	return &Column{Name: name}
}

// Hash returns a unique identifier for the struct.
func (c *Column) Hash() uint64 {
	if c == nil {
		return cache.NewHash(FragmentType_Column, nil)
	}
	return cache.NewHash(FragmentType_Column, c.Name)
}

// Compile transforms the Column into an equivalent SQL representation.
func (c *Column) Compile(layout *Template) (compiled string, err error) {
	if z, ok := layout.Read(c); ok {
		return z, nil
	}

	var alias string
	switch value := c.Name.(type) {
	case string:
		value = trimString(value)

		chunks := separateByAS(value)
		if len(chunks) == 1 {
			chunks = separateBySpace(value)
		}

		name := chunks[0]
		nameChunks := strings.SplitN(name, layout.ColumnSeparator, 2)

		for i := range nameChunks {
			nameChunks[i] = trimString(nameChunks[i])
			if nameChunks[i] == "*" {
				continue
			}
			nameChunks[i] = layout.MustCompile(layout.IdentifierQuote, Raw{Value: nameChunks[i]})
		}

		compiled = strings.Join(nameChunks, layout.ColumnSeparator)

		if len(chunks) > 1 {
			alias = trimString(chunks[1])
			alias = layout.MustCompile(layout.IdentifierQuote, Raw{Value: alias})
		}
	case compilable:
		compiled, err = value.Compile(layout)
		if err != nil {
			return "", err
		}
	default:
		return "", fmt.Errorf(errExpectingHashableFmt, c.Name)
	}

	if alias != "" {
		compiled = layout.MustCompile(layout.ColumnAliasLayout, columnWithAlias{compiled, alias})
	}

	layout.Write(c, compiled)
	return
}
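
Compile accepts either an explicit AS or a plain space as the alias separator, quotes each dot-separated part with the template's IdentifierQuote (leaving * alone), and memoizes the result in the template's cache keyed by Hash. A sketch of the expected output, assuming a template that quotes identifiers with double quotes, as the tests in the next file do:

func columnCompileSketch(tpl *Template) {
	s, _ := ColumnWithName("role.name AS foo").Compile(tpl)
	_ = s // `"role"."name" AS "foo"`

	// Compiling an equivalent column again is answered from the template's
	// cache (compare BenchmarkColumnCompile and BenchmarkColumnCompileNoCache).
	s2, _ := ColumnWithName("role.name AS foo").Compile(tpl)
	_ = s2
}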
88
internal/sqladapter/exql/column_test.go
Normal file
88
internal/sqladapter/exql/column_test.go
Normal file
@ -0,0 +1,88 @@
package exql

import (
	"testing"

	"github.com/stretchr/testify/assert"
)

func TestColumnString(t *testing.T) {
	column := Column{Name: "role.name"}
	s, err := column.Compile(defaultTemplate)
	assert.NoError(t, err)
	assert.Equal(t, `"role"."name"`, s)
}

func TestColumnAs(t *testing.T) {
	column := Column{Name: "role.name as foo"}
	s, err := column.Compile(defaultTemplate)
	assert.NoError(t, err)
	assert.Equal(t, `"role"."name" AS "foo"`, s)
}

func TestColumnImplicitAs(t *testing.T) {
	column := Column{Name: "role.name foo"}
	s, err := column.Compile(defaultTemplate)
	assert.NoError(t, err)
	assert.Equal(t, `"role"."name" AS "foo"`, s)
}

func TestColumnRaw(t *testing.T) {
	column := Column{Name: &Raw{Value: "role.name As foo"}}
	s, err := column.Compile(defaultTemplate)
	assert.NoError(t, err)
	assert.Equal(t, `role.name As foo`, s)
}

func BenchmarkColumnWithName(b *testing.B) {
	for i := 0; i < b.N; i++ {
		_ = ColumnWithName("a")
	}
}

func BenchmarkColumnHash(b *testing.B) {
	c := Column{Name: "name"}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		c.Hash()
	}
}

func BenchmarkColumnCompile(b *testing.B) {
	c := Column{Name: "name"}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, _ = c.Compile(defaultTemplate)
	}
}

func BenchmarkColumnCompileNoCache(b *testing.B) {
	for i := 0; i < b.N; i++ {
		c := Column{Name: "name"}
		_, _ = c.Compile(defaultTemplate)
	}
}

func BenchmarkColumnWithDotCompile(b *testing.B) {
	c := Column{Name: "role.name"}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, _ = c.Compile(defaultTemplate)
	}
}

func BenchmarkColumnWithImplicitAsKeywordCompile(b *testing.B) {
	c := Column{Name: "role.name foo"}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, _ = c.Compile(defaultTemplate)
	}
}

func BenchmarkColumnWithAsKeywordCompile(b *testing.B) {
	c := Column{Name: "role.name AS foo"}
	b.ResetTimer()
	for i := 0; i < b.N; i++ {
		_, _ = c.Compile(defaultTemplate)
	}
}