first commit

This commit is contained in:
tiglog 2023-09-18 15:15:42 +08:00
commit 4ac92c26b0
175 changed files with 28823 additions and 0 deletions

4
.gitignore vendored Normal file
View File

@ -0,0 +1,4 @@
*.sw?
*.db
*.tmp
generated_*.go

39
Makefile Normal file
View File

@ -0,0 +1,39 @@
SHELL ?= /bin/bash
PARALLEL_FLAGS ?= --halt-on-error 2 --jobs=2 -v -u
TEST_FLAGS ?=
UPPER_DB_LOG ?= WARN
export TEST_FLAGS
export PARALLEL_FLAGS
export UPPER_DB_LOG
test: go-test-internal test-adapters
benchmark: go-benchmark-internal
go-benchmark-%:
go test -v -benchtime=500ms -bench=. ./$*/...
go-test-%:
go test -v ./$*/...
test-adapters: \
test-adapter-postgresql \
# test-adapter-mysql \
# test-adapter-sqlite \
# test-adapter-mongo
test-adapter-%:
($(MAKE) -C adapter/$* test-extended || exit 1)
test-generic:
export TEST_FLAGS="-run TestGeneric"; \
$(MAKE) test-adapters
goimports:
for FILE in $$(find -name "*.go" | grep -v vendor); do \
goimports -w $$FILE; \
done

54
adapter.go Normal file
View File

@ -0,0 +1,54 @@
package mydb
import (
"fmt"
"sync"
)
var (
adapterMap = make(map[string]Adapter)
adapterMapMu sync.RWMutex
)
// Adapter interface defines an adapter
type Adapter interface {
Open(ConnectionURL) (Session, error)
}
type missingAdapter struct {
name string
}
func (ma *missingAdapter) Open(ConnectionURL) (Session, error) {
return nil, fmt.Errorf("mydb: Missing adapter %q, did you forget to import it?", ma.name)
}
// RegisterAdapter registers a generic database adapter.
func RegisterAdapter(name string, adapter Adapter) {
adapterMapMu.Lock()
defer adapterMapMu.Unlock()
if name == "" {
panic(`Missing adapter name`)
}
if _, ok := adapterMap[name]; ok {
panic(`db.RegisterAdapter() called twice for adapter: ` + name)
}
adapterMap[name] = adapter
}
// LookupAdapter returns a previously registered adapter by name.
func LookupAdapter(name string) Adapter {
adapterMapMu.RLock()
defer adapterMapMu.RUnlock()
if adapter, ok := adapterMap[name]; ok {
return adapter
}
return &missingAdapter{name: name}
}
// Open attempts to stablish a connection with a database.
func Open(adapterName string, settings ConnectionURL) (Session, error) {
return LookupAdapter(adapterName).Open(settings)
}

43
adapter/mongo/Makefile Normal file
View File

@ -0,0 +1,43 @@
SHELL ?= bash
MONGO_VERSION ?= 4
MONGO_SUPPORTED ?= $(MONGO_VERSION) 3
PROJECT ?= upper_mongo_$(MONGO_VERSION)
DB_HOST ?= 127.0.0.1
DB_PORT ?= 27017
DB_NAME ?= admin
DB_USERNAME ?= upperio_user
DB_PASSWORD ?= upperio//s3cr37
TEST_FLAGS ?=
PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1
export MONGO_VERSION
export DB_HOST
export DB_NAME
export DB_PASSWORD
export DB_PORT
export DB_USERNAME
export TEST_FLAGS
test:
go test -v -failfast -race -timeout 20m $(TEST_FLAGS)
test-no-race:
go test -v -failfast $(TEST_FLAGS)
server-up: server-down
docker-compose -p $(PROJECT) up -d && \
sleep 10
server-down:
docker-compose -p $(PROJECT) down
test-extended:
parallel $(PARALLEL_FLAGS) \
"MONGO_VERSION={} DB_PORT=\$$((27017+{#})) $(MAKE) server-up test server-down" ::: \
$(MONGO_SUPPORTED)

4
adapter/mongo/README.md Normal file
View File

@ -0,0 +1,4 @@
# MongoDB adapter for upper/db
Please read the full docs, acknowledgements and examples at
[https://upper.io/v4/adapter/mongo/](https://upper.io/v4/adapter/mongo/).

346
adapter/mongo/collection.go Normal file
View File

@ -0,0 +1,346 @@
package mongo
import (
"fmt"
"strings"
"sync"
"reflect"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/adapter"
mgo "gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
)
// Collection represents a mongodb collection.
type Collection struct {
parent *Source
collection *mgo.Collection
}
var (
// idCache should be a struct if we're going to cache more than just
// _id field here
idCache = make(map[reflect.Type]string)
idCacheMutex sync.RWMutex
)
// Find creates a result set with the given conditions.
func (col *Collection) Find(terms ...interface{}) mydb.Result {
fields := []string{"*"}
conditions := col.compileQuery(terms...)
res := &result{}
res = res.frame(func(r *resultQuery) error {
r.c = col
r.conditions = conditions
r.fields = fields
return nil
})
return res
}
var comparisonOperators = map[adapter.ComparisonOperator]string{
adapter.ComparisonOperatorEqual: "$eq",
adapter.ComparisonOperatorNotEqual: "$ne",
adapter.ComparisonOperatorLessThan: "$lt",
adapter.ComparisonOperatorGreaterThan: "$gt",
adapter.ComparisonOperatorLessThanOrEqualTo: "$lte",
adapter.ComparisonOperatorGreaterThanOrEqualTo: "$gte",
adapter.ComparisonOperatorIn: "$in",
adapter.ComparisonOperatorNotIn: "$nin",
}
func compare(field string, cmp *adapter.Comparison) (string, interface{}) {
op := cmp.Operator()
value := cmp.Value()
switch op {
case adapter.ComparisonOperatorEqual:
return field, value
case adapter.ComparisonOperatorBetween:
values := value.([]interface{})
return field, bson.M{
"$gte": values[0],
"$lte": values[1],
}
case adapter.ComparisonOperatorNotBetween:
values := value.([]interface{})
return "$or", []bson.M{
{field: bson.M{"$gt": values[1]}},
{field: bson.M{"$lt": values[0]}},
}
case adapter.ComparisonOperatorIs:
if value == nil {
return field, bson.M{"$exists": false}
}
return field, bson.M{"$eq": value}
case adapter.ComparisonOperatorIsNot:
if value == nil {
return field, bson.M{"$exists": true}
}
return field, bson.M{"$ne": value}
case adapter.ComparisonOperatorRegExp, adapter.ComparisonOperatorLike:
return field, bson.RegEx{Pattern: value.(string), Options: ""}
case adapter.ComparisonOperatorNotRegExp, adapter.ComparisonOperatorNotLike:
return field, bson.M{"$not": bson.RegEx{Pattern: value.(string), Options: ""}}
}
if cmpOp, ok := comparisonOperators[op]; ok {
return field, bson.M{
cmpOp: value,
}
}
panic(fmt.Sprintf("Unsupported operator %v", op))
}
// compileStatement transforms conditions into something *mgo.Session can
// understand.
func compileStatement(cond mydb.Cond) bson.M {
conds := bson.M{}
// Walking over conditions
for fieldI, value := range cond {
field := strings.TrimSpace(fmt.Sprintf("%v", fieldI))
if cmp, ok := value.(*mydb.Comparison); ok {
k, v := compare(field, cmp.Comparison)
conds[k] = v
continue
}
var op string
chunks := strings.SplitN(field, ` `, 2)
if len(chunks) > 1 {
switch chunks[1] {
case `IN`:
op = `$in`
case `NOT IN`:
op = `$nin`
case `>`:
op = `$gt`
case `<`:
op = `$lt`
case `<=`:
op = `$lte`
case `>=`:
op = `$gte`
default:
op = chunks[1]
}
}
field = chunks[0]
if op == "" {
conds[field] = value
} else {
conds[field] = bson.M{op: value}
}
}
return conds
}
// compileConditions compiles terms into something *mgo.Session can
// understand.
func (col *Collection) compileConditions(term interface{}) interface{} {
switch t := term.(type) {
case []interface{}:
values := []interface{}{}
for i := range t {
value := col.compileConditions(t[i])
if value != nil {
values = append(values, value)
}
}
if len(values) > 0 {
return values
}
case mydb.Cond:
return compileStatement(t)
case adapter.LogicalExpr:
values := []interface{}{}
for _, s := range t.Expressions() {
values = append(values, col.compileConditions(s))
}
var op string
switch t.Operator() {
case adapter.LogicalOperatorOr:
op = `$or`
default:
op = `$and`
}
return bson.M{op: values}
}
return nil
}
// compileQuery compiles terms into something that *mgo.Session can
// understand.
func (col *Collection) compileQuery(terms ...interface{}) interface{} {
compiled := col.compileConditions(terms)
if compiled == nil {
return nil
}
conditions := compiled.([]interface{})
if len(conditions) == 1 {
return conditions[0]
}
// this should be correct.
// query = map[string]interface{}{"$and": conditions}
// attempt to workaround https://jira.mongomydb.org/browse/SERVER-4572
mapped := map[string]interface{}{}
for _, v := range conditions {
for kk := range v.(map[string]interface{}) {
mapped[kk] = v.(map[string]interface{})[kk]
}
}
return mapped
}
// Name returns the name of the table or tables that form the collection.
func (col *Collection) Name() string {
return col.collection.Name
}
// Truncate deletes all rows from the table.
func (col *Collection) Truncate() error {
err := col.collection.DropCollection()
if err != nil {
return err
}
return nil
}
func (col *Collection) Session() mydb.Session {
return col.parent
}
func (col *Collection) Count() (uint64, error) {
return col.Find().Count()
}
func (col *Collection) InsertReturning(item interface{}) error {
return mydb.ErrUnsupported
}
func (col *Collection) UpdateReturning(item interface{}) error {
return mydb.ErrUnsupported
}
// Insert inserts a record (map or struct) into the collection.
func (col *Collection) Insert(item interface{}) (mydb.InsertResult, error) {
var err error
id := getID(item)
if col.parent.versionAtLeast(2, 6, 0, 0) {
// this breaks MongoDb older than 2.6
if _, err = col.collection.Upsert(bson.M{"_id": id}, item); err != nil {
return nil, err
}
} else {
// Allocating a new ID.
if err = col.collection.Insert(bson.M{"_id": id}); err != nil {
return nil, err
}
// Now append data the user wants to append.
if err = col.collection.Update(bson.M{"_id": id}, item); err != nil {
// Cleanup allocated ID
if err := col.collection.Remove(bson.M{"_id": id}); err != nil {
return nil, err
}
return nil, err
}
}
return mydb.NewInsertResult(id), nil
}
// Exists returns true if the collection exists.
func (col *Collection) Exists() (bool, error) {
query := col.parent.database.C(`system.namespaces`).Find(map[string]string{`name`: fmt.Sprintf(`%s.%s`, col.parent.database.Name, col.collection.Name)})
count, err := query.Count()
return count > 0, err
}
// Fetches object _id or generates a new one if object doesn't have one or the one it has is invalid
func getID(item interface{}) interface{} {
v := reflect.ValueOf(item) // convert interface to Value
v = reflect.Indirect(v) // convert pointers
switch v.Kind() {
case reflect.Map:
if inItem, ok := item.(map[string]interface{}); ok {
if id, ok := inItem["_id"]; ok {
bsonID, ok := id.(bson.ObjectId)
if ok {
return bsonID
}
}
}
case reflect.Struct:
t := v.Type()
idCacheMutex.RLock()
fieldName, found := idCache[t]
idCacheMutex.RUnlock()
if !found {
for n := 0; n < t.NumField(); n++ {
field := t.Field(n)
if field.PkgPath != "" {
continue // Private field
}
tag := field.Tag.Get("bson")
if tag == "" {
tag = field.Tag.Get("db")
}
if tag == "" {
continue
}
parts := strings.Split(tag, ",")
if parts[0] == "_id" {
fieldName = field.Name
idCacheMutex.RLock()
idCache[t] = fieldName
idCacheMutex.RUnlock()
break
}
}
}
if fieldName != "" {
if bsonID, ok := v.FieldByName(fieldName).Interface().(bson.ObjectId); ok {
if bsonID.Valid() {
return bsonID
}
} else {
return v.FieldByName(fieldName).Interface()
}
}
}
return bson.NewObjectId()
}

View File

@ -0,0 +1,98 @@
package mongo
import (
"fmt"
"net/url"
"strings"
)
const connectionScheme = `mongodb`
// ConnectionURL implements a MongoDB connection struct.
type ConnectionURL struct {
User string
Password string
Host string
Database string
Options map[string]string
}
func (c ConnectionURL) String() (s string) {
if c.Database == "" {
return ""
}
vv := url.Values{}
// Do we have any options?
if c.Options == nil {
c.Options = map[string]string{}
}
// Converting options into URL values.
for k, v := range c.Options {
vv.Set(k, v)
}
// Has user?
var userInfo *url.Userinfo
if c.User != "" {
if c.Password == "" {
userInfo = url.User(c.User)
} else {
userInfo = url.UserPassword(c.User, c.Password)
}
}
// Building URL.
u := url.URL{
Scheme: connectionScheme,
Path: c.Database,
Host: c.Host,
User: userInfo,
RawQuery: vv.Encode(),
}
return u.String()
}
// ParseURL parses s into a ConnectionURL struct.
func ParseURL(s string) (conn ConnectionURL, err error) {
var u *url.URL
if !strings.HasPrefix(s, connectionScheme+"://") {
return conn, fmt.Errorf(`Expecting mongodb:// connection scheme.`)
}
if u, err = url.Parse(s); err != nil {
return conn, err
}
conn.Host = u.Host
// Deleting / from start of the string.
conn.Database = strings.Trim(u.Path, "/")
// Adding user / password.
if u.User != nil {
conn.User = u.User.Username()
conn.Password, _ = u.User.Password()
}
// Adding options.
conn.Options = map[string]string{}
var vv url.Values
if vv, err = url.ParseQuery(u.RawQuery); err != nil {
return conn, err
}
for k := range vv {
conn.Options[k] = vv.Get(k)
}
return conn, err
}

View File

@ -0,0 +1,114 @@
package mongo
import (
"testing"
)
func TestConnectionURL(t *testing.T) {
c := ConnectionURL{}
// Default connection string is only the protocol.
if c.String() != "" {
t.Fatal(`Expecting default connectiong string to be empty, got:`, c.String())
}
// Adding a database name.
c.Database = "myfilename"
if c.String() != "mongodb://myfilename" {
t.Fatal(`Test failed, got:`, c.String())
}
// Adding an option.
c.Options = map[string]string{
"cache": "foobar",
"mode": "ro",
}
// Adding username and password
c.User = "user"
c.Password = "pass"
// Setting host.
c.Host = "localhost"
if c.String() != "mongodb://user:pass@localhost/myfilename?cache=foobar&mode=ro" {
t.Fatal(`Test failed, got:`, c.String())
}
// Setting host and port.
c.Host = "localhost:27017"
if c.String() != "mongodb://user:pass@localhost:27017/myfilename?cache=foobar&mode=ro" {
t.Fatal(`Test failed, got:`, c.String())
}
// Setting cluster.
c.Host = "localhost,1.2.3.4,example.org:1234"
if c.String() != "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/myfilename?cache=foobar&mode=ro" {
t.Fatal(`Test failed, got:`, c.String())
}
// Setting another database.
c.Database = "another_database"
if c.String() != "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/another_database?cache=foobar&mode=ro" {
t.Fatal(`Test failed, got:`, c.String())
}
}
func TestParseConnectionURL(t *testing.T) {
var u ConnectionURL
var s string
var err error
s = "mongodb:///mydatabase"
if u, err = ParseURL(s); err != nil {
t.Fatal(err)
}
if u.Database != "mydatabase" {
t.Fatal("Failed to parse database.")
}
s = "mongodb://user:pass@localhost,1.2.3.4,example.org:1234/another_database?cache=foobar&mode=ro"
if u, err = ParseURL(s); err != nil {
t.Fatal(err)
}
if u.Database != "another_database" {
t.Fatal("Failed to get database.")
}
if u.Options["cache"] != "foobar" {
t.Fatal("Expecting option.")
}
if u.Options["mode"] != "ro" {
t.Fatal("Expecting option.")
}
if u.User != "user" {
t.Fatal("Expecting user.")
}
if u.Password != "pass" {
t.Fatal("Expecting password.")
}
if u.Host != "localhost,1.2.3.4,example.org:1234" {
t.Fatal("Expecting host.")
}
s = "http://example.org"
if _, err = ParseURL(s); err == nil {
t.Fatal("Expecting error.")
}
}

245
adapter/mongo/database.go Normal file
View File

@ -0,0 +1,245 @@
// Package mongo wraps the gopkg.in/mgo.v2 MongoDB driver. See
// https://github.com/upper/db/adapter/mongo for documentation, particularities and usage
// examples.
package mongo
import (
"context"
"database/sql"
"strings"
"sync"
"time"
"git.hexq.cn/tiglog/mydb"
mgo "gopkg.in/mgo.v2"
)
// Adapter holds the name of the mongodb adapter.
const Adapter = `mongo`
var connTimeout = time.Second * 5
// Source represents a MongoDB database.
type Source struct {
mydb.Settings
ctx context.Context
name string
connURL mydb.ConnectionURL
session *mgo.Session
database *mgo.Database
version []int
collections map[string]*Collection
collectionsMu sync.Mutex
}
type mongoAdapter struct {
}
func (mongoAdapter) Open(dsn mydb.ConnectionURL) (mydb.Session, error) {
return Open(dsn)
}
func init() {
mydb.RegisterAdapter(Adapter, mydb.Adapter(&mongoAdapter{}))
}
// Open stablishes a new connection to a SQL server.
func Open(settings mydb.ConnectionURL) (mydb.Session, error) {
d := &Source{Settings: mydb.NewSettings(), ctx: context.Background()}
if err := d.Open(settings); err != nil {
return nil, err
}
return d, nil
}
func (s *Source) TxContext(context.Context, func(tx mydb.Session) error, *sql.TxOptions) error {
return mydb.ErrNotSupportedByAdapter
}
func (s *Source) Tx(func(mydb.Session) error) error {
return mydb.ErrNotSupportedByAdapter
}
func (s *Source) SQL() mydb.SQL {
// Not supported
panic("sql builder is not supported by mongodb")
}
func (s *Source) ConnectionURL() mydb.ConnectionURL {
return s.connURL
}
// SetConnMaxLifetime is not supported.
func (s *Source) SetConnMaxLifetime(time.Duration) {
s.Settings.SetConnMaxLifetime(time.Duration(0))
}
// SetMaxIdleConns is not supported.
func (s *Source) SetMaxIdleConns(int) {
s.Settings.SetMaxIdleConns(0)
}
// SetMaxOpenConns is not supported.
func (s *Source) SetMaxOpenConns(int) {
s.Settings.SetMaxOpenConns(0)
}
// Name returns the name of the database.
func (s *Source) Name() string {
return s.name
}
// Open attempts to connect to the database.
func (s *Source) Open(connURL mydb.ConnectionURL) error {
s.connURL = connURL
return s.open()
}
// Clone returns a cloned mydb.Session session.
func (s *Source) Clone() (mydb.Session, error) {
newSession := s.session.Copy()
clone := &Source{
Settings: mydb.NewSettings(),
name: s.name,
connURL: s.connURL,
session: newSession,
database: newSession.DB(s.database.Name),
version: s.version,
collections: map[string]*Collection{},
}
return clone, nil
}
// Ping checks whether a connection to the database is still alive by pinging
// it, establishing a connection if necessary.
func (s *Source) Ping() error {
return s.session.Ping()
}
func (s *Source) Reset() {
s.collectionsMu.Lock()
defer s.collectionsMu.Unlock()
s.collections = make(map[string]*Collection)
}
// Driver returns the underlying *mgo.Session instance.
func (s *Source) Driver() interface{} {
return s.session
}
func (s *Source) open() error {
var err error
if s.session, err = mgo.DialWithTimeout(s.connURL.String(), connTimeout); err != nil {
return err
}
s.collections = map[string]*Collection{}
s.database = s.session.DB("")
return nil
}
// Close terminates the current database session.
func (s *Source) Close() error {
if s.session != nil {
s.session.Close()
}
return nil
}
// Collections returns a list of non-system tables from the database.
func (s *Source) Collections() (cols []mydb.Collection, err error) {
var rawcols []string
var col string
if rawcols, err = s.database.CollectionNames(); err != nil {
return nil, err
}
cols = make([]mydb.Collection, 0, len(rawcols))
for _, col = range rawcols {
if !strings.HasPrefix(col, "system.") {
cols = append(cols, s.Collection(col))
}
}
return cols, nil
}
func (s *Source) Delete(mydb.Record) error {
return mydb.ErrNotImplemented
}
func (s *Source) Get(mydb.Record, interface{}) error {
return mydb.ErrNotImplemented
}
func (s *Source) Save(mydb.Record) error {
return mydb.ErrNotImplemented
}
func (s *Source) Context() context.Context {
return s.ctx
}
func (s *Source) WithContext(ctx context.Context) mydb.Session {
return &Source{
ctx: ctx,
Settings: s.Settings,
name: s.name,
connURL: s.connURL,
session: s.session,
database: s.database,
version: s.version,
}
}
// Collection returns a collection by name.
func (s *Source) Collection(name string) mydb.Collection {
s.collectionsMu.Lock()
defer s.collectionsMu.Unlock()
var col *Collection
var ok bool
if col, ok = s.collections[name]; !ok {
col = &Collection{
parent: s,
collection: s.database.C(name),
}
s.collections[name] = col
}
return col
}
func (s *Source) versionAtLeast(version ...int) bool {
// only fetch this once - it makes a db call
if len(s.version) == 0 {
buildInfo, err := s.database.Session.BuildInfo()
if err != nil {
return false
}
s.version = buildInfo.VersionArray
}
// Check major version first
if s.version[0] > version[0] {
return true
}
for i := range version {
if i == len(s.version) {
return false
}
if s.version[i] < version[i] {
return false
}
}
return true
}

View File

@ -0,0 +1,13 @@
version: '3'
services:
server:
image: mongo:${MONGO_VERSION:-3}
environment:
MONGO_INITDB_ROOT_USERNAME: ${DB_USERNAME:-upperio_user}
MONGO_INITDB_ROOT_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37}
MONGO_INITDB_DATABASE: ${DB_NAME:-upperio}
ports:
- '${BIND_HOST:-127.0.0.1}:${DB_PORT:-27017}:27017'

View File

@ -0,0 +1,20 @@
package mongo
import (
"testing"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type GenericTests struct {
testsuite.GenericTestSuite
}
func (s *GenericTests) SetupSuite() {
s.Helper = &Helper{}
}
func TestGeneric(t *testing.T) {
suite.Run(t, &GenericTests{})
}

View File

@ -0,0 +1,77 @@
package mongo
import (
"fmt"
"os"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
mgo "gopkg.in/mgo.v2"
)
var settings = ConnectionURL{
Database: os.Getenv("DB_NAME"),
User: os.Getenv("DB_USERNAME"),
Password: os.Getenv("DB_PASSWORD"),
Host: os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"),
}
type Helper struct {
sess mydb.Session
}
func (h *Helper) Session() mydb.Session {
return h.sess
}
func (h *Helper) Adapter() string {
return "mongo"
}
func (h *Helper) TearDown() error {
return h.sess.Close()
}
func (h *Helper) TearUp() error {
var err error
h.sess, err = Open(settings)
if err != nil {
return err
}
mgod, ok := h.sess.Driver().(*mgo.Session)
if !ok {
panic("expecting mgo.Session")
}
var col *mgo.Collection
col = mgod.DB(settings.Database).C("birthdays")
_ = col.DropCollection()
col = mgod.DB(settings.Database).C("fibonacci")
_ = col.DropCollection()
col = mgod.DB(settings.Database).C("is_even")
_ = col.DropCollection()
col = mgod.DB(settings.Database).C("CaSe_TesT")
_ = col.DropCollection()
// Getting a pointer to the "artist" collection.
artist := h.sess.Collection("artist")
_ = artist.Truncate()
for i := 0; i < 999; i++ {
_, err = artist.Insert(artistType{
Name: fmt.Sprintf("artist-%d", i),
})
if err != nil {
return err
}
}
return nil
}
var _ testsuite.Helper = &Helper{}

754
adapter/mongo/mongo_test.go Normal file
View File

@ -0,0 +1,754 @@
// Tests for the mongodb adapter.
package mongo
import (
"fmt"
"log"
"math/rand"
"strings"
"testing"
"time"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
"gopkg.in/mgo.v2/bson"
)
type artistType struct {
ID bson.ObjectId `bson:"_id,omitempty"`
Name string `bson:"name"`
}
// Structure for testing conversions and datatypes.
type testValuesStruct struct {
Uint uint `bson:"_uint"`
Uint8 uint8 `bson:"_uint8"`
Uint16 uint16 `bson:"_uint16"`
Uint32 uint32 `bson:"_uint32"`
Uint64 uint64 `bson:"_uint64"`
Int int `bson:"_int"`
Int8 int8 `bson:"_int8"`
Int16 int16 `bson:"_int16"`
Int32 int32 `bson:"_int32"`
Int64 int64 `bson:"_int64"`
Float32 float32 `bson:"_float32"`
Float64 float64 `bson:"_float64"`
Bool bool `bson:"_bool"`
String string `bson:"_string"`
Date time.Time `bson:"_date"`
DateN *time.Time `bson:"_nildate"`
DateP *time.Time `bson:"_ptrdate"`
Time time.Duration `bson:"_time"`
}
var testValues testValuesStruct
func init() {
t := time.Date(2012, 7, 28, 1, 2, 3, 0, time.Local)
testValues = testValuesStruct{
1, 1, 1, 1, 1,
-1, -1, -1, -1, -1,
1.337, 1.337,
true,
"Hello world!",
t,
nil,
&t,
time.Second * time.Duration(7331),
}
}
type AdapterTests struct {
testsuite.Suite
}
func (s *AdapterTests) SetupSuite() {
s.Helper = &Helper{}
}
func (s *AdapterTests) TestOpenWithWrongData() {
var err error
var rightSettings, wrongSettings ConnectionURL
// Attempt to open with safe settings.
rightSettings = ConnectionURL{
Database: settings.Database,
Host: settings.Host,
User: settings.User,
Password: settings.Password,
}
// Attempt to open an empty database.
_, err = Open(rightSettings)
s.NoError(err)
// Attempt to open with wrong password.
wrongSettings = ConnectionURL{
Database: settings.Database,
Host: settings.Host,
User: settings.User,
Password: "fail",
}
_, err = Open(wrongSettings)
s.Error(err)
// Attempt to open with wrong database.
wrongSettings = ConnectionURL{
Database: "fail",
Host: settings.Host,
User: settings.User,
Password: settings.Password,
}
_, err = Open(wrongSettings)
s.Error(err)
// Attempt to open with wrong username.
wrongSettings = ConnectionURL{
Database: settings.Database,
Host: settings.Host,
User: "fail",
Password: settings.Password,
}
_, err = Open(wrongSettings)
s.Error(err)
}
func (s *AdapterTests) TestTruncate() {
// Opening database.
sess, err := Open(settings)
s.NoError(err)
// We should close the database when it's no longer in use.
defer sess.Close()
// Getting a list of all collections in this database.
collections, err := sess.Collections()
s.NoError(err)
for _, col := range collections {
// The collection may ot may not exists.
if ok, _ := col.Exists(); ok {
// Truncating the structure, if exists.
err = col.Truncate()
s.NoError(err)
}
}
}
func (s *AdapterTests) TestInsert() {
// Opening database.
sess, err := Open(settings)
s.NoError(err)
// We should close the database when it's no longer in use.
defer sess.Close()
// Getting a pointer to the "artist" collection.
artist := sess.Collection("artist")
_ = artist.Truncate()
// Inserting a map.
record, err := artist.Insert(map[string]string{
"name": "Ozzie",
})
s.NoError(err)
id := record.ID()
s.NotZero(record.ID())
_, ok := id.(bson.ObjectId)
s.True(ok)
s.True(id.(bson.ObjectId).Valid())
// Inserting a struct.
record, err = artist.Insert(struct {
Name string
}{
"Flea",
})
s.NoError(err)
id = record.ID()
s.NotZero(id)
_, ok = id.(bson.ObjectId)
s.True(ok)
s.True(id.(bson.ObjectId).Valid())
// Inserting a struct (using tags to specify the field name).
record, err = artist.Insert(struct {
ArtistName string `bson:"name"`
}{
"Slash",
})
s.NoError(err)
id = record.ID()
s.NotNil(id)
_, ok = id.(bson.ObjectId)
s.True(ok)
s.True(id.(bson.ObjectId).Valid())
// Inserting a pointer to a struct
record, err = artist.Insert(&struct {
ArtistName string `bson:"name"`
}{
"Metallica",
})
s.NoError(err)
id = record.ID()
s.NotNil(id)
_, ok = id.(bson.ObjectId)
s.True(ok)
s.True(id.(bson.ObjectId).Valid())
// Inserting a pointer to a map
record, err = artist.Insert(&map[string]string{
"name": "Freddie",
})
s.NoError(err)
s.NotZero(id)
_, ok = id.(bson.ObjectId)
s.True(ok)
id = record.ID()
s.NotNil(id)
s.True(id.(bson.ObjectId).Valid())
// Counting elements, must be exactly 6 elements.
total, err := artist.Find().Count()
s.NoError(err)
s.Equal(uint64(5), total)
}
func (s *AdapterTests) TestGetNonExistentRow_Issue426() {
// Opening database.
sess, err := Open(settings)
s.NoError(err)
defer sess.Close()
artist := sess.Collection("artist")
var one artistType
err = artist.Find(mydb.Cond{"name": "nothing"}).One(&one)
s.NotZero(err)
s.Equal(mydb.ErrNoMoreRows, err)
var all []artistType
err = artist.Find(mydb.Cond{"name": "nothing"}).All(&all)
s.Zero(err, "All should not return mgo.ErrNotFound")
s.Equal(0, len(all))
}
func (s *AdapterTests) TestResultCount() {
var err error
var res mydb.Result
// Opening database.
sess, err := Open(settings)
s.NoError(err)
defer sess.Close()
// We should close the database when it's no longer in use.
artist := sess.Collection("artist")
res = artist.Find()
// Counting all the matching rows.
total, err := res.Count()
s.NoError(err)
s.NotZero(total)
}
func (s *AdapterTests) TestGroup() {
var stats mydb.Collection
sess, err := Open(settings)
s.NoError(err)
type statsT struct {
Numeric int `db:"numeric" bson:"numeric"`
Value int `db:"value" bson:"value"`
}
defer sess.Close()
stats = sess.Collection("statsTest")
// Truncating table.
_ = stats.Truncate()
// Adding row append.
for i := 0; i < 1000; i++ {
numeric, value := rand.Intn(10), rand.Intn(100)
_, err = stats.Insert(statsT{numeric, value})
s.NoError(err)
}
// mydb.statsTest.group({key: {numeric: true}, initial: {sum: 0}, reduce: function(doc, prev) { prev.sum += 1}});
// Testing GROUP BY
res := stats.Find().GroupBy(bson.M{
"key": bson.M{"numeric": true},
"initial": bson.M{"sum": 0},
"reduce": `function(doc, prev) { prev.sum += 1}`,
})
var results []map[string]interface{}
err = res.All(&results)
s.Equal(mydb.ErrUnsupported, err)
}
func (s *AdapterTests) TestResultNonExistentCount() {
sess, err := Open(settings)
s.NoError(err)
defer sess.Close()
total, err := sess.Collection("notartist").Find().Count()
s.NoError(err)
s.Zero(total)
}
func (s *AdapterTests) TestResultFetch() {
// Opening database.
sess, err := Open(settings)
s.NoError(err)
// We should close the database when it's no longer in use.
defer sess.Close()
artist := sess.Collection("artist")
// Testing map
res := artist.Find()
rowM := map[string]interface{}{}
for res.Next(&rowM) {
s.NotZero(rowM["_id"])
_, ok := rowM["_id"].(bson.ObjectId)
s.True(ok)
s.True(rowM["_id"].(bson.ObjectId).Valid())
name, ok := rowM["name"].(string)
s.True(ok)
s.NotZero(name)
}
err = res.Close()
s.NoError(err)
// Testing struct
rowS := struct {
ID bson.ObjectId `bson:"_id"`
Name string `bson:"name"`
}{}
res = artist.Find()
for res.Next(&rowS) {
s.True(rowS.ID.Valid())
s.NotZero(rowS.Name)
}
err = res.Close()
s.NoError(err)
// Testing tagged struct
rowT := struct {
Value1 bson.ObjectId `bson:"_id"`
Value2 string `bson:"name"`
}{}
res = artist.Find()
for res.Next(&rowT) {
s.True(rowT.Value1.Valid())
s.NotZero(rowT.Value2)
}
err = res.Close()
s.NoError(err)
// Testing Result.All() with a slice of maps.
res = artist.Find()
allRowsM := []map[string]interface{}{}
err = res.All(&allRowsM)
s.NoError(err)
for _, singleRowM := range allRowsM {
s.NotZero(singleRowM["_id"])
}
// Testing Result.All() with a slice of structs.
res = artist.Find()
allRowsS := []struct {
ID bson.ObjectId `bson:"_id"`
Name string
}{}
err = res.All(&allRowsS)
s.NoError(err)
for _, singleRowS := range allRowsS {
s.True(singleRowS.ID.Valid())
}
// Testing Result.All() with a slice of tagged structs.
res = artist.Find()
allRowsT := []struct {
Value1 bson.ObjectId `bson:"_id"`
Value2 string `bson:"name"`
}{}
err = res.All(&allRowsT)
s.NoError(err)
for _, singleRowT := range allRowsT {
s.True(singleRowT.Value1.Valid())
}
}
func (s *AdapterTests) TestUpdate() {
// Opening database.
sess, err := Open(settings)
s.NoError(err)
// We should close the database when it's no longer in use.
defer sess.Close()
// Getting a pointer to the "artist" collection.
artist := sess.Collection("artist")
// Value
value := struct {
ID bson.ObjectId `bson:"_id"`
Name string
}{}
// Getting the first artist.
res := artist.Find(mydb.Cond{"_id": mydb.NotEq(nil)}).Limit(1)
err = res.One(&value)
s.NoError(err)
// Updating with a map
rowM := map[string]interface{}{
"name": strings.ToUpper(value.Name),
}
err = res.Update(rowM)
s.NoError(err)
err = res.One(&value)
s.NoError(err)
s.Equal(value.Name, rowM["name"])
// Updating with a struct
rowS := struct {
Name string
}{strings.ToLower(value.Name)}
err = res.Update(rowS)
s.NoError(err)
err = res.One(&value)
s.NoError(err)
s.Equal(value.Name, rowS.Name)
// Updating with a tagged struct
rowT := struct {
Value1 string `bson:"name"`
}{strings.Replace(value.Name, "z", "Z", -1)}
err = res.Update(rowT)
s.NoError(err)
err = res.One(&value)
s.NoError(err)
s.Equal(value.Name, rowT.Value1)
}
func (s *AdapterTests) TestOperators() {
// Opening database.
sess, err := Open(settings)
s.NoError(err)
// We should close the database when it's no longer in use.
defer sess.Close()
// Getting a pointer to the "artist" collection.
artist := sess.Collection("artist")
rowS := struct {
ID uint64
Name string
}{}
res := artist.Find(mydb.Cond{"_id": mydb.NotIn(0, -1)})
err = res.One(&rowS)
s.NoError(err)
err = res.Close()
s.NoError(err)
}
func (s *AdapterTests) TestDelete() {
// Opening database.
sess, err := Open(settings)
s.NoError(err)
// We should close the database when it's no longer in use.
defer sess.Close()
// Getting a pointer to the "artist" collection.
artist := sess.Collection("artist")
// Getting the first artist.
res := artist.Find(mydb.Cond{"_id": mydb.NotEq(nil)}).Limit(1)
var first struct {
ID bson.ObjectId `bson:"_id"`
}
err = res.One(&first)
s.NoError(err)
res = artist.Find(mydb.Cond{"_id": mydb.Eq(first.ID)})
// Trying to remove the row.
err = res.Delete()
s.NoError(err)
}
func (s *AdapterTests) TestDataTypes() {
// Opening database.
sess, err := Open(settings)
s.NoError(err)
// We should close the database when it's no longer in use.
defer sess.Close()
// Getting a pointer to the "data_types" collection.
dataTypes := sess.Collection("data_types")
// Inserting our test subject.
record, err := dataTypes.Insert(testValues)
s.NoError(err)
id := record.ID()
s.NotZero(id)
// Trying to get the same subject we added.
res := dataTypes.Find(mydb.Cond{"_id": mydb.Eq(id)})
exists, err := res.Count()
s.NoError(err)
s.NotZero(exists)
// Trying to dump the subject into an empty structure of the same type.
var item testValuesStruct
err = res.One(&item)
s.NoError(err)
// The original value and the test subject must match.
s.Equal(testValues, item)
}
func (s *AdapterTests) TestPaginator() {
// Opening database.
sess, err := Open(settings)
s.NoError(err)
// We should close the database when it's no longer in use.
defer sess.Close()
// Getting a pointer to the "artist" collection.
artist := sess.Collection("artist")
err = artist.Truncate()
s.NoError(err)
for i := 0; i < 999; i++ {
_, err = artist.Insert(artistType{
Name: fmt.Sprintf("artist-%d", i),
})
s.NoError(err)
}
q := sess.Collection("artist").Find().Paginate(15)
paginator := q.Paginate(13)
var zerothPage []artistType
err = paginator.Page(0).All(&zerothPage)
s.NoError(err)
s.Equal(13, len(zerothPage))
var secondPage []artistType
err = paginator.Page(2).All(&secondPage)
s.NoError(err)
s.Equal(13, len(secondPage))
tp, err := paginator.TotalPages()
s.NoError(err)
s.NotZero(tp)
s.Equal(uint(77), tp)
ti, err := paginator.TotalEntries()
s.NoError(err)
s.NotZero(ti)
s.Equal(uint64(999), ti)
var seventySixthPage []artistType
err = paginator.Page(76).All(&seventySixthPage)
s.NoError(err)
s.Equal(11, len(seventySixthPage))
var seventySeventhPage []artistType
err = paginator.Page(77).All(&seventySeventhPage)
s.NoError(err)
s.Equal(0, len(seventySeventhPage))
var hundredthPage []artistType
err = paginator.Page(100).All(&hundredthPage)
s.NoError(err)
s.Equal(0, len(hundredthPage))
for i := uint(0); i < tp; i++ {
current := paginator.Page(i)
var items []artistType
err := current.All(&items)
s.NoError(err)
if len(items) < 1 {
break
}
for j := 0; j < len(items); j++ {
s.Equal(fmt.Sprintf("artist-%d", int64(13*int(i)+j)), items[j].Name)
}
}
paginator = paginator.Cursor("_id")
{
current := paginator.Page(0)
for i := 0; ; i++ {
var items []artistType
err := current.All(&items)
s.NoError(err)
if len(items) < 1 {
break
}
for j := 0; j < len(items); j++ {
s.Equal(fmt.Sprintf("artist-%d", int64(13*int(i)+j)), items[j].Name)
}
current = current.NextPage(items[len(items)-1].ID)
}
}
{
log.Printf("Page 76")
current := paginator.Page(76)
for i := 76; ; i-- {
var items []artistType
err := current.All(&items)
s.NoError(err)
if len(items) < 1 {
s.Equal(0, len(items))
break
}
for j := 0; j < len(items); j++ {
s.Equal(fmt.Sprintf("artist-%d", 13*int(i)+j), items[j].Name)
}
current = current.PrevPage(items[0].ID)
}
}
{
resultPaginator := sess.Collection("artist").Find().Paginate(15)
count, err := resultPaginator.TotalPages()
s.Equal(uint(67), count)
s.NoError(err)
var items []artistType
err = resultPaginator.Page(5).All(&items)
s.NoError(err)
for j := 0; j < len(items); j++ {
s.Equal(fmt.Sprintf("artist-%d", 15*5+j), items[j].Name)
}
resultPaginator = resultPaginator.Cursor("_id").Page(0)
for i := 0; ; i++ {
var items []artistType
err = resultPaginator.All(&items)
s.NoError(err)
if len(items) < 1 {
break
}
for j := 0; j < len(items); j++ {
s.Equal(fmt.Sprintf("artist-%d", 15*i+j), items[j].Name)
}
resultPaginator = resultPaginator.NextPage(items[len(items)-1].ID)
}
resultPaginator = resultPaginator.Cursor("_id").Page(66)
for i := 66; ; i-- {
var items []artistType
err = resultPaginator.All(&items)
s.NoError(err)
if len(items) < 1 {
break
}
for j := 0; j < len(items); j++ {
s.Equal(fmt.Sprintf("artist-%d", 15*i+j), items[j].Name)
}
resultPaginator = resultPaginator.PrevPage(items[0].ID)
}
}
}
func TestAdapter(t *testing.T) {
suite.Run(t, &AdapterTests{})
}

565
adapter/mongo/result.go Normal file
View File

@ -0,0 +1,565 @@
package mongo
import (
"errors"
"fmt"
"math"
"strings"
"sync"
"time"
"encoding/json"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/immutable"
mgo "gopkg.in/mgo.v2"
"gopkg.in/mgo.v2/bson"
)
type resultQuery struct {
c *Collection
fields []string
limit int
offset int
sort []string
conditions interface{}
groupBy []interface{}
pageSize uint
pageNumber uint
cursorColumn string
cursorValue interface{}
cursorCond mydb.Cond
cursorReverseOrder bool
}
type result struct {
iter *mgo.Iter
err error
errMu sync.Mutex
fn func(*resultQuery) error
prev *result
}
var _ = immutable.Immutable(&result{})
func (res *result) frame(fn func(*resultQuery) error) *result {
return &result{prev: res, fn: fn}
}
func (r *resultQuery) and(terms ...interface{}) error {
if r.conditions == nil {
return r.where(terms...)
}
r.conditions = map[string]interface{}{
"$and": []interface{}{
r.conditions,
r.c.compileQuery(terms...),
},
}
return nil
}
func (r *resultQuery) where(terms ...interface{}) error {
r.conditions = r.c.compileQuery(terms...)
return nil
}
func (res *result) And(terms ...interface{}) mydb.Result {
return res.frame(func(r *resultQuery) error {
return r.and(terms...)
})
}
func (res *result) Where(terms ...interface{}) mydb.Result {
return res.frame(func(r *resultQuery) error {
return r.where(terms...)
})
}
func (res *result) Paginate(pageSize uint) mydb.Result {
return res.frame(func(r *resultQuery) error {
r.pageSize = pageSize
return nil
})
}
func (res *result) Page(pageNumber uint) mydb.Result {
return res.frame(func(r *resultQuery) error {
r.pageNumber = pageNumber
return nil
})
}
func (res *result) Cursor(cursorColumn string) mydb.Result {
return res.frame(func(r *resultQuery) error {
r.cursorColumn = cursorColumn
return nil
})
}
func (res *result) NextPage(cursorValue interface{}) mydb.Result {
return res.frame(func(r *resultQuery) error {
r.cursorValue = cursorValue
r.cursorReverseOrder = false
r.cursorCond = mydb.Cond{
r.cursorColumn: bson.M{"$gt": cursorValue},
}
return nil
})
}
func (res *result) PrevPage(cursorValue interface{}) mydb.Result {
return res.frame(func(r *resultQuery) error {
r.cursorValue = cursorValue
r.cursorReverseOrder = true
r.cursorCond = mydb.Cond{
r.cursorColumn: bson.M{"$lt": cursorValue},
}
return nil
})
}
func (res *result) TotalEntries() (uint64, error) {
return res.Count()
}
func (res *result) TotalPages() (uint, error) {
count, err := res.Count()
if err != nil {
return 0, err
}
rq, err := res.build()
if err != nil {
return 0, err
}
if rq.pageSize < 1 {
return 1, nil
}
total := uint(math.Ceil(float64(count) / float64(rq.pageSize)))
return total, nil
}
// Limit determines the maximum limit of results to be returned.
func (res *result) Limit(n int) mydb.Result {
return res.frame(func(r *resultQuery) error {
r.limit = n
return nil
})
}
// Offset determines how many documents will be skipped before starting to grab
// results.
func (res *result) Offset(n int) mydb.Result {
return res.frame(func(r *resultQuery) error {
r.offset = n
return nil
})
}
// OrderBy determines sorting of results according to the provided names. Fields
// may be prefixed by - (minus) which means descending order, ascending order
// would be used otherwise.
func (res *result) OrderBy(fields ...interface{}) mydb.Result {
return res.frame(func(r *resultQuery) error {
ss := make([]string, len(fields))
for i, field := range fields {
ss[i] = fmt.Sprintf(`%v`, field)
}
r.sort = ss
return nil
})
}
// String satisfies fmt.Stringer
func (res *result) String() string {
return ""
}
// Select marks the specific fields the user wants to retrieve.
func (res *result) Select(fields ...interface{}) mydb.Result {
return res.frame(func(r *resultQuery) error {
fieldslen := len(fields)
r.fields = make([]string, 0, fieldslen)
for i := 0; i < fieldslen; i++ {
r.fields = append(r.fields, fmt.Sprintf(`%v`, fields[i]))
}
return nil
})
}
// All dumps all results into a pointer to an slice of structs or maps.
func (res *result) All(dst interface{}) error {
rq, err := res.build()
if err != nil {
return err
}
q, err := rq.query()
if err != nil {
return err
}
defer func(start time.Time) {
queryLog(&mydb.QueryStatus{
RawQuery: rq.debugQuery("Find.All"),
Err: err,
Start: start,
End: time.Now(),
})
}(time.Now())
err = q.All(dst)
if errors.Is(err, mgo.ErrNotFound) {
return mydb.ErrNoMoreRows
}
return err
}
// GroupBy is used to group results that have the same value in the same column
// or columns.
func (res *result) GroupBy(fields ...interface{}) mydb.Result {
return res.frame(func(r *resultQuery) error {
r.groupBy = fields
return nil
})
}
// One fetches only one result from the resultset.
func (res *result) One(dst interface{}) error {
rq, err := res.build()
if err != nil {
return err
}
q, err := rq.query()
if err != nil {
return err
}
defer func(start time.Time) {
queryLog(&mydb.QueryStatus{
RawQuery: rq.debugQuery("Find.One"),
Err: err,
Start: start,
End: time.Now(),
})
}(time.Now())
err = q.One(dst)
if errors.Is(err, mgo.ErrNotFound) {
return mydb.ErrNoMoreRows
}
return err
}
func (res *result) Err() error {
res.errMu.Lock()
defer res.errMu.Unlock()
return res.err
}
func (res *result) setErr(err error) {
res.errMu.Lock()
defer res.errMu.Unlock()
res.err = err
}
func (res *result) Next(dst interface{}) bool {
if res.iter == nil {
rq, err := res.build()
if err != nil {
return false
}
q, err := rq.query()
if err != nil {
return false
}
defer func(start time.Time) {
queryLog(&mydb.QueryStatus{
RawQuery: rq.debugQuery("Find.Next"),
Err: err,
Start: start,
End: time.Now(),
})
}(time.Now())
res.iter = q.Iter()
}
if !res.iter.Next(dst) {
res.setErr(res.iter.Err())
return false
}
return true
}
// Delete remove the matching items from the collection.
func (res *result) Delete() error {
rq, err := res.build()
if err != nil {
return err
}
defer func(start time.Time) {
queryLog(&mydb.QueryStatus{
RawQuery: rq.debugQuery("Remove"),
Err: err,
Start: start,
End: time.Now(),
})
}(time.Now())
_, err = rq.c.collection.RemoveAll(rq.conditions)
if err != nil {
return err
}
return nil
}
// Close closes the result set.
func (r *result) Close() error {
var err error
if r.iter != nil {
err = r.iter.Close()
r.iter = nil
}
return err
}
// Update modified matching items from the collection with values of the given
// map or struct.
func (res *result) Update(src interface{}) (err error) {
updateSet := map[string]interface{}{"$set": src}
rq, err := res.build()
if err != nil {
return err
}
defer func(start time.Time) {
queryLog(&mydb.QueryStatus{
RawQuery: rq.debugQuery("Update"),
Err: err,
Start: start,
End: time.Now(),
})
}(time.Now())
_, err = rq.c.collection.UpdateAll(rq.conditions, updateSet)
if err != nil {
return err
}
return nil
}
func (res *result) build() (*resultQuery, error) {
rqi, err := immutable.FastForward(res)
if err != nil {
return nil, err
}
rq := rqi.(*resultQuery)
if !rq.cursorCond.Empty() {
if err := rq.and(rq.cursorCond); err != nil {
return nil, err
}
}
if rq.cursorColumn != "" {
if rq.cursorReverseOrder {
rq.sort = append(rq.sort, "-"+rq.cursorColumn)
} else {
rq.sort = append(rq.sort, rq.cursorColumn)
}
}
return rq, nil
}
// query executes a mgo query.
func (r *resultQuery) query() (*mgo.Query, error) {
if len(r.groupBy) > 0 {
return nil, mydb.ErrUnsupported
}
q := r.c.collection.Find(r.conditions)
if r.pageSize > 0 {
r.offset = int(r.pageSize * r.pageNumber)
r.limit = int(r.pageSize)
}
if r.offset > 0 {
q.Skip(r.offset)
}
if r.limit > 0 {
q.Limit(r.limit)
}
if len(r.sort) > 0 {
q.Sort(r.sort...)
}
selectedFields := bson.M{}
if len(r.fields) > 0 {
for _, field := range r.fields {
if field == `*` {
break
}
selectedFields[field] = true
}
}
if r.cursorReverseOrder {
ids := make([]bson.ObjectId, 0, r.limit)
iter := q.Select(bson.M{"_id": true}).Iter()
defer iter.Close()
var item map[string]bson.ObjectId
for iter.Next(&item) {
ids = append(ids, item["_id"])
}
r.conditions = bson.M{"_id": bson.M{"$in": ids}}
q = r.c.collection.Find(r.conditions)
}
if len(selectedFields) > 0 {
q.Select(selectedFields)
}
return q, nil
}
func (res *result) Exists() (bool, error) {
total, err := res.Count()
if err != nil {
return false, err
}
if total > 0 {
return true, nil
}
return false, nil
}
// Count counts matching elements.
func (res *result) Count() (total uint64, err error) {
rq, err := res.build()
if err != nil {
return 0, err
}
defer func(start time.Time) {
queryLog(&mydb.QueryStatus{
RawQuery: rq.debugQuery("Find.Count"),
Err: err,
Start: start,
End: time.Now(),
})
}(time.Now())
q := rq.c.collection.Find(rq.conditions)
var c int
c, err = q.Count()
return uint64(c), err
}
func (res *result) Prev() immutable.Immutable {
if res == nil {
return nil
}
return res.prev
}
func (res *result) Fn(in interface{}) error {
if res.fn == nil {
return nil
}
return res.fn(in.(*resultQuery))
}
func (res *result) Base() interface{} {
return &resultQuery{}
}
func (r *resultQuery) debugQuery(action string) string {
query := fmt.Sprintf("mydb.%s.%s", r.c.collection.Name, action)
if r.conditions != nil {
query = fmt.Sprintf("%s.conds(%v)", query, r.conditions)
}
if r.limit > 0 {
query = fmt.Sprintf("%s.limit(%d)", query, r.limit)
}
if r.offset > 0 {
query = fmt.Sprintf("%s.offset(%d)", query, r.offset)
}
if len(r.fields) > 0 {
selectedFields := bson.M{}
for _, field := range r.fields {
if field == `*` {
break
}
selectedFields[field] = true
}
if len(selectedFields) > 0 {
query = fmt.Sprintf("%s.select(%v)", query, selectedFields)
}
}
if len(r.groupBy) > 0 {
escaped := make([]string, len(r.groupBy))
for i := range r.groupBy {
escaped[i] = string(mustJSON(r.groupBy[i]))
}
query = fmt.Sprintf("%s.groupBy(%v)", query, strings.Join(escaped, ", "))
}
if len(r.sort) > 0 {
escaped := make([]string, len(r.sort))
for i := range r.sort {
escaped[i] = string(mustJSON(r.sort[i]))
}
query = fmt.Sprintf("%s.sort(%s)", query, strings.Join(escaped, ", "))
}
return query
}
func mustJSON(in interface{}) (out []byte) {
out, err := json.Marshal(in)
if err != nil {
panic(err)
}
return out
}
func queryLog(status *mydb.QueryStatus) {
diff := status.End.Sub(status.Start)
slowQuery := false
if diff >= time.Millisecond*100 {
status.Err = mydb.ErrWarnSlowQuery
slowQuery = true
}
if status.Err != nil || slowQuery {
mydb.LC().Warn(status)
return
}
mydb.LC().Debug(status)
}

43
adapter/mysql/Makefile Normal file
View File

@ -0,0 +1,43 @@
SHELL ?= bash
MYSQL_VERSION ?= 8.1
MYSQL_SUPPORTED ?= $(MYSQL_VERSION) 5.7
PROJECT ?= upper_mysql_$(MYSQL_VERSION)
DB_HOST ?= 127.0.0.1
DB_PORT ?= 3306
DB_NAME ?= upperio
DB_USERNAME ?= upperio_user
DB_PASSWORD ?= upperio//s3cr37
TEST_FLAGS ?=
PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1
export MYSQL_VERSION
export DB_HOST
export DB_NAME
export DB_PASSWORD
export DB_PORT
export DB_USERNAME
export TEST_FLAGS
test:
go test -v -failfast -race -timeout 20m $(TEST_FLAGS)
test-no-race:
go test -v -failfast $(TEST_FLAGS)
server-up: server-down
docker-compose -p $(PROJECT) up -d && \
sleep 15
server-down:
docker-compose -p $(PROJECT) down
test-extended:
parallel $(PARALLEL_FLAGS) \
"MYSQL_VERSION={} DB_PORT=\$$((3306+{#})) $(MAKE) server-up test server-down" ::: \
$(MYSQL_SUPPORTED)

5
adapter/mysql/README.md Normal file
View File

@ -0,0 +1,5 @@
# MySQL adapter for upper/db
Please read the full docs, acknowledgements and examples at
[https://upper.io/v4/adapter/mysql/](https://upper.io/v4/adapter/mysql/).

View File

@ -0,0 +1,56 @@
package mysql
import (
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
)
type collectionAdapter struct {
}
func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) {
columnNames, columnValues, err := sqlbuilder.Map(item, nil)
if err != nil {
return nil, err
}
pKey, err := col.PrimaryKeys()
if err != nil {
return nil, err
}
q := col.SQL().InsertInto(col.Name()).
Columns(columnNames...).
Values(columnValues...)
res, err := q.Exec()
if err != nil {
return nil, err
}
lastID, err := res.LastInsertId()
if err == nil && len(pKey) <= 1 {
return lastID, nil
}
keyMap := mydb.Cond{}
for i := range columnNames {
for j := 0; j < len(pKey); j++ {
if pKey[j] == columnNames[i] {
keyMap[pKey[j]] = columnValues[i]
}
}
}
// There was an auto column among primary keys, let's search for it.
if lastID > 0 {
for j := 0; j < len(pKey); j++ {
if keyMap[pKey[j]] == nil {
keyMap[pKey[j]] = lastID
}
}
}
return keyMap, nil
}

244
adapter/mysql/connection.go Normal file
View File

@ -0,0 +1,244 @@
package mysql
import (
"errors"
"fmt"
"net"
"net/url"
"strings"
)
// From https://github.com/go-sql-driver/mysql/blob/master/utils.go
var (
errInvalidDSNUnescaped = errors.New("Invalid DSN: Did you forget to escape a param value?")
errInvalidDSNAddr = errors.New("Invalid DSN: Network Address not terminated (missing closing brace)")
errInvalidDSNNoSlash = errors.New("Invalid DSN: Missing the slash separating the database name")
)
// From https://github.com/go-sql-driver/mysql/blob/master/utils.go
type config struct {
user string
passwd string
net string
addr string
dbname string
params map[string]string
}
// ConnectionURL implements a MySQL connection struct.
type ConnectionURL struct {
User string
Password string
Database string
Host string
Socket string
Options map[string]string
}
func (c ConnectionURL) String() (s string) {
if c.Database == "" {
return ""
}
// Adding username.
if c.User != "" {
s = s + c.User
// Adding password.
if c.Password != "" {
s = s + ":" + c.Password
}
s = s + "@"
}
// Adding protocol and address
if c.Socket != "" {
s = s + fmt.Sprintf("unix(%s)", c.Socket)
} else if c.Host != "" {
host, port, err := net.SplitHostPort(c.Host)
if err != nil {
host = c.Host
port = "3306"
}
s = s + fmt.Sprintf("tcp(%s:%s)", host, port)
}
// Adding database
s = s + "/" + c.Database
// Do we have any options?
if c.Options == nil {
c.Options = map[string]string{}
}
// Default options.
if _, ok := c.Options["charset"]; !ok {
c.Options["charset"] = "utf8"
}
if _, ok := c.Options["parseTime"]; !ok {
c.Options["parseTime"] = "true"
}
// Converting options into URL values.
vv := url.Values{}
for k, v := range c.Options {
vv.Set(k, v)
}
// Inserting options.
if p := vv.Encode(); p != "" {
s = s + "?" + p
}
return s
}
// ParseURL parses s into a ConnectionURL struct.
func ParseURL(s string) (conn ConnectionURL, err error) {
var cfg *config
if cfg, err = parseDSN(s); err != nil {
return
}
conn.User = cfg.user
conn.Password = cfg.passwd
if cfg.net == "unix" {
conn.Socket = cfg.addr
} else if cfg.net == "tcp" {
conn.Host = cfg.addr
}
conn.Database = cfg.dbname
conn.Options = map[string]string{}
for k, v := range cfg.params {
conn.Options[k] = v
}
return
}
// from https://github.com/go-sql-driver/mysql/blob/master/utils.go
// parseDSN parses the DSN string to a config
func parseDSN(dsn string) (cfg *config, err error) {
// New config with some default values
cfg = &config{}
// TODO: use strings.IndexByte when we can depend on Go 1.2
// [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
// Find the last '/' (since the password or the net addr might contain a '/')
foundSlash := false
for i := len(dsn) - 1; i >= 0; i-- {
if dsn[i] == '/' {
foundSlash = true
var j, k int
// left part is empty if i <= 0
if i > 0 {
// [username[:password]@][protocol[(address)]]
// Find the last '@' in dsn[:i]
for j = i; j >= 0; j-- {
if dsn[j] == '@' {
// username[:password]
// Find the first ':' in dsn[:j]
for k = 0; k < j; k++ {
if dsn[k] == ':' {
cfg.passwd = dsn[k+1 : j]
break
}
}
cfg.user = dsn[:k]
break
}
}
// [protocol[(address)]]
// Find the first '(' in dsn[j+1:i]
for k = j + 1; k < i; k++ {
if dsn[k] == '(' {
// dsn[i-1] must be == ')' if an address is specified
if dsn[i-1] != ')' {
if strings.ContainsRune(dsn[k+1:i], ')') {
return nil, errInvalidDSNUnescaped
}
return nil, errInvalidDSNAddr
}
cfg.addr = dsn[k+1 : i-1]
break
}
}
cfg.net = dsn[j+1 : k]
}
// dbname[?param1=value1&...&paramN=valueN]
// Find the first '?' in dsn[i+1:]
for j = i + 1; j < len(dsn); j++ {
if dsn[j] == '?' {
if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
return
}
break
}
}
cfg.dbname = dsn[i+1 : j]
break
}
}
if !foundSlash && len(dsn) > 0 {
return nil, errInvalidDSNNoSlash
}
// Set default network if empty
if cfg.net == "" {
cfg.net = "tcp"
}
// Set default address if empty
if cfg.addr == "" {
switch cfg.net {
case "tcp":
cfg.addr = "127.0.0.1:3306"
case "unix":
cfg.addr = "/tmp/mysql.sock"
default:
return nil, errors.New("Default addr for network '" + cfg.net + "' unknown")
}
}
return
}
// From https://github.com/go-sql-driver/mysql/blob/master/utils.go
// parseDSNParams parses the DSN "query string"
// Values must be url.QueryEscape'ed
func parseDSNParams(cfg *config, params string) (err error) {
for _, v := range strings.Split(params, "&") {
param := strings.SplitN(v, "=", 2)
if len(param) != 2 {
continue
}
value := param[1]
// lazy init
if cfg.params == nil {
cfg.params = make(map[string]string)
}
if cfg.params[param[0]], err = url.QueryUnescape(value); err != nil {
return
}
}
return
}

View File

@ -0,0 +1,117 @@
package mysql
import (
"testing"
)
func TestConnectionURL(t *testing.T) {
c := ConnectionURL{}
// Zero value equals to an empty string.
if c.String() != "" {
t.Fatal(`Expecting default connectiong string to be empty, got:`, c.String())
}
// Adding a database name.
c.Database = "mydbname"
if c.String() != "/mydbname?charset=utf8&parseTime=true" {
t.Fatal(`Test failed, got:`, c.String())
}
// Adding an option.
c.Options = map[string]string{
"charset": "utf8mb4,utf8",
"sys_var": "esc@ped",
}
if c.String() != "/mydbname?charset=utf8mb4%2Cutf8&parseTime=true&sys_var=esc%40ped" {
t.Fatal(`Test failed, got:`, c.String())
}
// Setting default options
c.Options = nil
// Setting user and password.
c.User = "user"
c.Password = "pass"
if c.String() != `user:pass@/mydbname?charset=utf8&parseTime=true` {
t.Fatal(`Test failed, got:`, c.String())
}
// Setting host.
c.Host = "1.2.3.4:3306"
if c.String() != `user:pass@tcp(1.2.3.4:3306)/mydbname?charset=utf8&parseTime=true` {
t.Fatal(`Test failed, got:`, c.String())
}
// Setting socket.
c.Socket = "/path/to/socket"
if c.String() != `user:pass@unix(/path/to/socket)/mydbname?charset=utf8&parseTime=true` {
t.Fatal(`Test failed, got:`, c.String())
}
}
func TestParseConnectionURL(t *testing.T) {
var u ConnectionURL
var s string
var err error
s = "user:pass@unix(/path/to/socket)/mydbname?charset=utf8"
if u, err = ParseURL(s); err != nil {
t.Fatal(err)
}
if u.User != "user" {
t.Fatal("Expecting username.")
}
if u.Password != "pass" {
t.Fatal("Expecting password.")
}
if u.Socket != "/path/to/socket" {
t.Fatal("Expecting socket.")
}
if u.Database != "mydbname" {
t.Fatal("Expecting database.")
}
if u.Options["charset"] != "utf8" {
t.Fatal("Expecting charset.")
}
s = "user:pass@tcp(1.2.3.4:5678)/mydbname?charset=utf8"
if u, err = ParseURL(s); err != nil {
t.Fatal(err)
}
if u.User != "user" {
t.Fatal("Expecting username.")
}
if u.Password != "pass" {
t.Fatal("Expecting password.")
}
if u.Host != "1.2.3.4:5678" {
t.Fatal("Expecting host.")
}
if u.Database != "mydbname" {
t.Fatal("Expecting database.")
}
if u.Options["charset"] != "utf8" {
t.Fatal("Expecting charset.")
}
}

View File

@ -0,0 +1,151 @@
package mysql
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"reflect"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
)
// JSON represents a MySQL's JSON value:
// https://www.mysql.org/docs/9.6/static/datatype-json.html. JSON
// satisfies sqlbuilder.ScannerValuer.
type JSON struct {
V interface{}
}
// MarshalJSON encodes the wrapper value as JSON.
func (j JSON) MarshalJSON() ([]byte, error) {
return json.Marshal(j.V)
}
// UnmarshalJSON decodes the given JSON into the wrapped value.
func (j *JSON) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
j.V = v
return nil
}
// Scan satisfies the sql.Scanner interface.
func (j *JSON) Scan(src interface{}) error {
if j.V == nil {
return nil
}
if src == nil {
dv := reflect.Indirect(reflect.ValueOf(j.V))
dv.Set(reflect.Zero(dv.Type()))
return nil
}
b, ok := src.([]byte)
if !ok {
return errors.New("Scan source was not []bytes")
}
if err := json.Unmarshal(b, j.V); err != nil {
return err
}
return nil
}
// Value satisfies the driver.Valuer interface.
func (j JSON) Value() (driver.Value, error) {
if j.V == nil {
return nil, nil
}
if v, ok := j.V.(json.RawMessage); ok {
return string(v), nil
}
b, err := json.Marshal(j.V)
if err != nil {
return nil, err
}
return string(b), nil
}
// JSONMap represents a map of interfaces with string keys
// (`map[string]interface{}`) that is compatible with MySQL's JSON type.
// JSONMap satisfies sqlbuilder.ScannerValuer.
type JSONMap map[string]interface{}
// Value satisfies the driver.Valuer interface.
func (m JSONMap) Value() (driver.Value, error) {
return JSONValue(m)
}
// Scan satisfies the sql.Scanner interface.
func (m *JSONMap) Scan(src interface{}) error {
*m = map[string]interface{}(nil)
return ScanJSON(m, src)
}
// JSONArray represents an array of any type (`[]interface{}`) that is
// compatible with MySQL's JSON type. JSONArray satisfies
// sqlbuilder.ScannerValuer.
type JSONArray []interface{}
// Value satisfies the driver.Valuer interface.
func (a JSONArray) Value() (driver.Value, error) {
return JSONValue(a)
}
// Scan satisfies the sql.Scanner interface.
func (a *JSONArray) Scan(src interface{}) error {
return ScanJSON(a, src)
}
// JSONValue takes an interface and provides a driver.Value that can be
// stored as a JSON column.
func JSONValue(i interface{}) (driver.Value, error) {
v := JSON{i}
return v.Value()
}
// ScanJSON decodes a JSON byte stream into the passed dst value.
func ScanJSON(dst interface{}, src interface{}) error {
v := JSON{dst}
return v.Scan(src)
}
// EncodeJSON is deprecated and going to be removed. Use ScanJSON instead.
func EncodeJSON(i interface{}) (driver.Value, error) {
return JSONValue(i)
}
// DecodeJSON is deprecated and going to be removed. Use JSONValue instead.
func DecodeJSON(dst interface{}, src interface{}) error {
return ScanJSON(dst, src)
}
// JSONConverter provides a helper method WrapValue that satisfies
// sqlbuilder.ValueWrapper, can be used to encode Go structs into JSON
// MySQL types and vice versa.
//
// Example:
//
// type MyCustomStruct struct {
// ID int64 `db:"id" json:"id"`
// Name string `db:"name" json:"name"`
// ...
// mysql.JSONConverter
// }
type JSONConverter struct{}
func (*JSONConverter) ConvertValue(in interface{}) interface {
sql.Scanner
driver.Valuer
} {
return &JSON{in}
}
// Type checks.
var (
_ sqlbuilder.ScannerValuer = &JSONMap{}
_ sqlbuilder.ScannerValuer = &JSONArray{}
_ sqlbuilder.ScannerValuer = &JSON{}
)

168
adapter/mysql/database.go Normal file
View File

@ -0,0 +1,168 @@
// Package mysql wraps the github.com/go-sql-driver/mysql MySQL driver. See
// https://github.com/upper/db/adapter/mysql for documentation, particularities and usage
// examples.
package mysql
import (
"reflect"
"strings"
"database/sql"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
_ "github.com/go-sql-driver/mysql" // MySQL driver.
)
// database is the actual implementation of Database
type database struct {
}
func (*database) Template() *exql.Template {
return template
}
func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) {
return sql.Open("mysql", dsn)
}
func (*database) Collections(sess sqladapter.Session) (collections []string, err error) {
q := sess.SQL().
Select("table_name").
From("information_schema.tables").
Where("table_schema = ?", sess.Name())
iter := q.Iterator()
defer iter.Close()
for iter.Next() {
var tableName string
if err := iter.Scan(&tableName); err != nil {
return nil, err
}
collections = append(collections, tableName)
}
if err := iter.Err(); err != nil {
return nil, err
}
return collections, nil
}
func (d *database) ConvertValue(in interface{}) interface{} {
switch v := in.(type) {
case *map[string]interface{}:
return (*JSONMap)(v)
case map[string]interface{}:
return (*JSONMap)(&v)
}
dv := reflect.ValueOf(in)
if dv.IsValid() {
if dv.Type().Kind() == reflect.Ptr {
dv = dv.Elem()
}
switch dv.Kind() {
case reflect.Map:
if reflect.TypeOf(in).Kind() == reflect.Ptr {
w := reflect.ValueOf(in)
z := reflect.New(w.Elem().Type())
w.Elem().Set(z.Elem())
}
return &JSON{in}
case reflect.Slice:
return &JSON{in}
}
}
return in
}
func (*database) Err(err error) error {
if err != nil {
// This error is not exported so we have to check it by its string value.
s := err.Error()
if strings.Contains(s, `many connections`) {
return mydb.ErrTooManyClients
}
}
return err
}
func (*database) NewCollection() sqladapter.CollectionAdapter {
return &collectionAdapter{}
}
func (*database) LookupName(sess sqladapter.Session) (string, error) {
q := sess.SQL().
Select(mydb.Raw("DATABASE() AS name"))
iter := q.Iterator()
defer iter.Close()
if iter.Next() {
var name string
if err := iter.Scan(&name); err != nil {
return "", err
}
return name, nil
}
return "", iter.Err()
}
func (*database) TableExists(sess sqladapter.Session, name string) error {
q := sess.SQL().
Select("table_name").
From("information_schema.tables").
Where("table_schema = ? AND table_name = ?", sess.Name(), name)
iter := q.Iterator()
defer iter.Close()
if iter.Next() {
var name string
if err := iter.Scan(&name); err != nil {
return err
}
return nil
}
if err := iter.Err(); err != nil {
return err
}
return mydb.ErrCollectionDoesNotExist
}
func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) {
q := sess.SQL().
Select("k.column_name").
From("information_schema.key_column_usage AS k").
Where(`
k.constraint_name = 'PRIMARY'
AND k.table_schema = ?
AND k.table_name = ?
`, sess.Name(), tableName).
OrderBy("k.ordinal_position")
iter := q.Iterator()
defer iter.Close()
pk := []string{}
for iter.Next() {
var k string
if err := iter.Scan(&k); err != nil {
return nil, err
}
pk = append(pk, k)
}
if err := iter.Err(); err != nil {
return nil, err
}
return pk, nil
}

View File

@ -0,0 +1,14 @@
version: '3'
services:
server:
image: mysql:${MYSQL_VERSION:-5}
environment:
MYSQL_USER: ${DB_USERNAME:-upperio_user}
MYSQL_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37}
MYSQL_ALLOW_EMPTY_PASSWORD: 1
MYSQL_DATABASE: ${DB_NAME:-upperio}
ports:
- '${DB_HOST:-127.0.0.1}:${DB_PORT:-3306}:3306'

View File

@ -0,0 +1,20 @@
package mysql
import (
"testing"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type GenericTests struct {
testsuite.GenericTestSuite
}
func (s *GenericTests) SetupSuite() {
s.Helper = &Helper{}
}
func TestGeneric(t *testing.T) {
suite.Run(t, &GenericTests{})
}

View File

@ -0,0 +1,276 @@
package mysql
import (
"database/sql"
"fmt"
"os"
"time"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
)
var settings = ConnectionURL{
Database: os.Getenv("DB_NAME"),
User: os.Getenv("DB_USERNAME"),
Password: os.Getenv("DB_PASSWORD"),
Host: os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"),
Options: map[string]string{
// See https://github.com/go-sql-driver/mysql/issues/9
"parseTime": "true",
// Might require you to use mysql_tzinfo_to_sql /usr/share/zoneinfo | mysql -u root -p mysql
"time_zone": fmt.Sprintf(`'%s'`, testsuite.TimeZone),
"loc": testsuite.TimeZone,
},
}
type Helper struct {
sess mydb.Session
}
func cleanUp(sess mydb.Session) error {
if activeStatements := sqladapter.NumActiveStatements(); activeStatements > 128 {
return fmt.Errorf("Expecting active statements to be at most 128, got %d", activeStatements)
}
sess.Reset()
if activeStatements := sqladapter.NumActiveStatements(); activeStatements != 0 {
return fmt.Errorf("Expecting active statements to be 0, got %d", activeStatements)
}
var err error
var stats map[string]int
for i := 0; i < 10; i++ {
stats, err = getStats(sess)
if err != nil {
return err
}
if stats["Prepared_stmt_count"] != 0 {
time.Sleep(time.Millisecond * 200) // Sometimes it takes a bit to clean prepared statements
err = fmt.Errorf(`Expecting "Prepared_stmt_count" to be 0, got %d`, stats["Prepared_stmt_count"])
continue
}
break
}
return err
}
func getStats(sess mydb.Session) (map[string]int, error) {
stats := make(map[string]int)
res, err := sess.Driver().(*sql.DB).Query(`SHOW GLOBAL STATUS LIKE '%stmt%'`)
if err != nil {
return nil, err
}
var result struct {
VariableName string `db:"Variable_name"`
Value int `db:"Value"`
}
iter := sess.SQL().NewIterator(res)
for iter.Next(&result) {
stats[result.VariableName] = result.Value
}
return stats, nil
}
func (h *Helper) Session() mydb.Session {
return h.sess
}
func (h *Helper) Adapter() string {
return "mysql"
}
func (h *Helper) TearDown() error {
if err := cleanUp(h.sess); err != nil {
return err
}
return h.sess.Close()
}
func (h *Helper) TearUp() error {
var err error
h.sess, err = Open(settings)
if err != nil {
return err
}
batch := []string{
`DROP TABLE IF EXISTS artist`,
`CREATE TABLE artist (
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
PRIMARY KEY(id),
name VARCHAR(60)
)`,
`DROP TABLE IF EXISTS publication`,
`CREATE TABLE publication (
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
PRIMARY KEY(id),
title VARCHAR(80),
author_id BIGINT(20)
)`,
`DROP TABLE IF EXISTS review`,
`CREATE TABLE review (
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
PRIMARY KEY(id),
publication_id BIGINT(20),
name VARCHAR(80),
comments TEXT,
created DATETIME NOT NULL
)`,
`DROP TABLE IF EXISTS data_types`,
`CREATE TABLE data_types (
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
PRIMARY KEY(id),
_uint INT(10) UNSIGNED DEFAULT 0,
_uint8 INT(10) UNSIGNED DEFAULT 0,
_uint16 INT(10) UNSIGNED DEFAULT 0,
_uint32 INT(10) UNSIGNED DEFAULT 0,
_uint64 INT(10) UNSIGNED DEFAULT 0,
_int INT(10) DEFAULT 0,
_int8 INT(10) DEFAULT 0,
_int16 INT(10) DEFAULT 0,
_int32 INT(10) DEFAULT 0,
_int64 INT(10) DEFAULT 0,
_float32 DECIMAL(10,6),
_float64 DECIMAL(10,6),
_bool TINYINT(1),
_string text,
_blob blob,
_date TIMESTAMP NULL,
_nildate DATETIME NULL,
_ptrdate DATETIME NULL,
_defaultdate TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
_time BIGINT UNSIGNED NOT NULL DEFAULT 0
)`,
`DROP TABLE IF EXISTS stats_test`,
`CREATE TABLE stats_test (
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id),
` + "`numeric`" + ` INT(10),
` + "`value`" + ` INT(10)
)`,
`DROP TABLE IF EXISTS composite_keys`,
`CREATE TABLE composite_keys (
code VARCHAR(255) default '',
user_id VARCHAR(255) default '',
some_val VARCHAR(255) default '',
primary key (code, user_id)
)`,
`DROP TABLE IF EXISTS admin`,
`CREATE TABLE admin (
ID int(11) NOT NULL AUTO_INCREMENT,
Accounts varchar(255) DEFAULT '',
LoginPassWord varchar(255) DEFAULT '',
Date TIMESTAMP(6) NOT NULL DEFAULT CURRENT_TIMESTAMP(6),
PRIMARY KEY (ID,Date)
) ENGINE=InnoDB AUTO_INCREMENT=2 DEFAULT CHARSET=utf8`,
`DROP TABLE IF EXISTS my_types`,
`CREATE TABLE my_types (id int(11) NOT NULL AUTO_INCREMENT, PRIMARY KEY(id)
, json_map JSON
, json_map_ptr JSON
, auto_json_map JSON
, auto_json_map_string JSON
, auto_json_map_integer JSON
, json_object JSON
, json_array JSON
, custom_json_object JSON
, auto_custom_json_object JSON
, custom_json_object_ptr JSON
, auto_custom_json_object_ptr JSON
, custom_json_object_array JSON
, auto_custom_json_object_array JSON
, auto_custom_json_object_map JSON
, integer_compat_value_json_array JSON
, string_compat_value_json_array JSON
, uinteger_compat_value_json_array JSON
)`,
`DROP TABLE IF EXISTS ` + "`" + `birthdays` + "`" + ``,
`CREATE TABLE ` + "`" + `birthdays` + "`" + ` (
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id),
name VARCHAR(50),
born DATE,
born_ut BIGINT(20) SIGNED
) CHARSET=utf8`,
`DROP TABLE IF EXISTS ` + "`" + `fibonacci` + "`" + ``,
`CREATE TABLE ` + "`" + `fibonacci` + "`" + ` (
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id),
input BIGINT(20) UNSIGNED NOT NULL,
output BIGINT(20) UNSIGNED NOT NULL
) CHARSET=utf8`,
`DROP TABLE IF EXISTS ` + "`" + `is_even` + "`" + ``,
`CREATE TABLE ` + "`" + `is_even` + "`" + ` (
input BIGINT(20) UNSIGNED NOT NULL,
is_even TINYINT(1)
) CHARSET=utf8`,
`DROP TABLE IF EXISTS ` + "`" + `CaSe_TesT` + "`" + ``,
`CREATE TABLE ` + "`" + `CaSe_TesT` + "`" + ` (
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT, PRIMARY KEY(id),
case_test VARCHAR(60)
) CHARSET=utf8`,
`DROP TABLE IF EXISTS accounts`,
`CREATE TABLE accounts (
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
PRIMARY KEY(id),
name varchar(255),
disabled BOOL,
created_at TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP
)`,
`DROP TABLE IF EXISTS users`,
`CREATE TABLE users (
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
PRIMARY KEY(id),
account_id BIGINT(20),
username varchar(255) UNIQUE
)`,
`DROP TABLE IF EXISTS logs`,
`CREATE TABLE logs (
id BIGINT(20) UNSIGNED NOT NULL AUTO_INCREMENT,
PRIMARY KEY(id),
message VARCHAR(255)
)`,
}
for _, query := range batch {
driver := h.sess.Driver().(*sql.DB)
if _, err := driver.Exec(query); err != nil {
return err
}
}
return nil
}
var _ testsuite.Helper = &Helper{}

30
adapter/mysql/mysql.go Normal file
View File

@ -0,0 +1,30 @@
package mysql
import (
"database/sql"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
)
// Adapter is the public name of the adapter.
const Adapter = `mysql`
var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{})
// Open establishes a connection to the database server and returns a
// mydb.Session instance (which is compatible with mydb.Session).
func Open(connURL mydb.ConnectionURL) (mydb.Session, error) {
return registeredAdapter.OpenDSN(connURL)
}
// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value.
func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) {
return registeredAdapter.NewTx(sqlTx)
}
// New creates a sqlbuilder.Sesion instance by wrapping a *sql.DB value.
func New(sqlDB *sql.DB) (mydb.Session, error) {
return registeredAdapter.New(sqlDB)
}

379
adapter/mysql/mysql_test.go Normal file
View File

@ -0,0 +1,379 @@
package mysql
import (
"database/sql"
"database/sql/driver"
"fmt"
"math/rand"
"strconv"
"testing"
"time"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type int64Compat int64
type uintCompat uint
type stringCompat string
type uintCompatArray []uintCompat
func (u *int64Compat) Scan(src interface{}) error {
if src != nil {
switch v := src.(type) {
case int64:
*u = int64Compat((src).(int64))
case []byte:
i, err := strconv.ParseInt(string(v), 10, 64)
if err != nil {
return err
}
*u = int64Compat(i)
default:
panic(fmt.Sprintf("expected type %T", src))
}
}
return nil
}
type customJSON struct {
N string `json:"name"`
V float64 `json:"value"`
}
func (c customJSON) Value() (driver.Value, error) {
return JSONValue(c)
}
func (c *customJSON) Scan(src interface{}) error {
return ScanJSON(c, src)
}
type autoCustomJSON struct {
N string `json:"name"`
V float64 `json:"value"`
*JSONConverter
}
var (
_ = driver.Valuer(&customJSON{})
_ = sql.Scanner(&customJSON{})
)
type AdapterTests struct {
testsuite.Suite
}
func (s *AdapterTests) SetupSuite() {
s.Helper = &Helper{}
}
func (s *AdapterTests) TestInsertReturningCompositeKey_Issue383() {
sess := s.Session()
type Admin struct {
ID int `db:"ID,omitempty"`
Accounts string `db:"Accounts"`
LoginPassWord string `db:"LoginPassWord"`
Date time.Time `db:"Date"`
}
dateNow := time.Now()
a := Admin{
Accounts: "admin",
LoginPassWord: "E10ADC3949BA59ABBE56E057F20F883E",
Date: dateNow,
}
adminCollection := sess.Collection("admin")
err := adminCollection.InsertReturning(&a)
s.NoError(err)
s.NotZero(a.ID)
s.NotZero(a.Date)
s.Equal("admin", a.Accounts)
s.Equal("E10ADC3949BA59ABBE56E057F20F883E", a.LoginPassWord)
b := Admin{
Accounts: "admin2",
LoginPassWord: "E10ADC3949BA59ABBE56E057F20F883E",
Date: dateNow,
}
err = adminCollection.InsertReturning(&b)
s.NoError(err)
s.NotZero(b.ID)
s.NotZero(b.Date)
s.Equal("admin2", b.Accounts)
s.Equal("E10ADC3949BA59ABBE56E057F20F883E", a.LoginPassWord)
}
func (s *AdapterTests) TestIssue469_BadConnection() {
var err error
sess := s.Session()
// Ask the MySQL server to disconnect sessions that remain inactive for more
// than 1 second.
_, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`)
s.NoError(err)
// Remain inactive for 2 seconds.
time.Sleep(time.Second * 2)
// A query should start a new connection, even if the server disconnected us.
_, err = sess.Collection("artist").Find().Count()
s.NoError(err)
// This is a new session, ask the MySQL server to disconnect sessions that
// remain inactive for more than 1 second.
_, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`)
s.NoError(err)
// Remain inactive for 2 seconds.
time.Sleep(time.Second * 2)
// At this point the server should have disconnected us. Let's try to create
// a transaction anyway.
err = sess.Tx(func(sess mydb.Session) error {
var err error
_, err = sess.Collection("artist").Find().Count()
if err != nil {
return err
}
return nil
})
s.NoError(err)
// This is a new session, ask the MySQL server to disconnect sessions that
// remain inactive for more than 1 second.
_, err = sess.SQL().Exec(`SET SESSION wait_timeout=1`)
s.NoError(err)
err = sess.Tx(func(sess mydb.Session) error {
var err error
// This query should succeed.
_, err = sess.Collection("artist").Find().Count()
if err != nil {
panic(err.Error())
}
// Remain inactive for 2 seconds.
time.Sleep(time.Second * 2)
// This query should fail because the server disconnected us in the middle
// of a transaction.
_, err = sess.Collection("artist").Find().Count()
if err != nil {
return err
}
return nil
})
s.Error(err, "Expecting an error (can't recover from this)")
}
func (s *AdapterTests) TestMySQLTypes() {
sess := s.Session()
type MyType struct {
ID int64 `db:"id,omitempty"`
JSONMap JSONMap `db:"json_map"`
JSONObject JSONMap `db:"json_object"`
JSONArray JSONArray `db:"json_array"`
CustomJSONObject customJSON `db:"custom_json_object"`
AutoCustomJSONObject autoCustomJSON `db:"auto_custom_json_object"`
CustomJSONObjectPtr *customJSON `db:"custom_json_object_ptr,omitempty"`
AutoCustomJSONObjectPtr *autoCustomJSON `db:"auto_custom_json_object_ptr,omitempty"`
AutoCustomJSONObjectArray []autoCustomJSON `db:"auto_custom_json_object_array"`
AutoCustomJSONObjectMap map[string]autoCustomJSON `db:"auto_custom_json_object_map"`
Int64CompatValueJSONArray []int64Compat `db:"integer_compat_value_json_array"`
UIntCompatValueJSONArray uintCompatArray `db:"uinteger_compat_value_json_array"`
StringCompatValueJSONArray []stringCompat `db:"string_compat_value_json_array"`
}
origMyTypeTests := []MyType{
MyType{
Int64CompatValueJSONArray: []int64Compat{1, -2, 3, -4},
UIntCompatValueJSONArray: []uintCompat{1, 2, 3, 4},
StringCompatValueJSONArray: []stringCompat{"a", "b", "", "c"},
},
MyType{
Int64CompatValueJSONArray: []int64Compat(nil),
UIntCompatValueJSONArray: []uintCompat(nil),
StringCompatValueJSONArray: []stringCompat(nil),
},
MyType{
AutoCustomJSONObjectArray: []autoCustomJSON{
autoCustomJSON{
N: "Hello",
},
autoCustomJSON{
N: "World",
},
},
AutoCustomJSONObjectMap: map[string]autoCustomJSON{
"a": autoCustomJSON{
N: "Hello",
},
"b": autoCustomJSON{
N: "World",
},
},
JSONArray: JSONArray{float64(1), float64(2), float64(3), float64(4)},
},
MyType{
JSONArray: JSONArray{},
},
MyType{
JSONArray: JSONArray(nil),
},
MyType{},
MyType{
CustomJSONObject: customJSON{
N: "Hello",
},
AutoCustomJSONObject: autoCustomJSON{
N: "World",
},
},
MyType{
CustomJSONObject: customJSON{},
AutoCustomJSONObject: autoCustomJSON{},
},
MyType{
CustomJSONObject: customJSON{
N: "Hello 1",
},
AutoCustomJSONObject: autoCustomJSON{
N: "World 2",
},
},
MyType{
CustomJSONObjectPtr: nil,
AutoCustomJSONObjectPtr: nil,
},
MyType{
CustomJSONObjectPtr: &customJSON{},
AutoCustomJSONObjectPtr: &autoCustomJSON{},
},
MyType{
CustomJSONObjectPtr: &customJSON{
N: "Hello 3",
},
AutoCustomJSONObjectPtr: &autoCustomJSON{
N: "World 4",
},
},
MyType{},
MyType{
CustomJSONObject: customJSON{
V: 4.4,
},
},
MyType{
CustomJSONObject: customJSON{},
},
MyType{
CustomJSONObject: customJSON{
N: "Peter",
V: 5.56,
},
},
}
for i := 0; i < 100; i++ {
myTypeTests := make([]MyType, len(origMyTypeTests))
perm := rand.Perm(len(origMyTypeTests))
for i, v := range perm {
myTypeTests[v] = origMyTypeTests[i]
}
for i := range myTypeTests {
result, err := sess.Collection("my_types").Insert(myTypeTests[i])
s.NoError(err)
var actual MyType
err = sess.Collection("my_types").Find(result).One(&actual)
s.NoError(err)
expected := myTypeTests[i]
expected.ID = result.ID().(int64)
s.Equal(expected, actual)
}
for i := range myTypeTests {
res, err := sess.SQL().InsertInto("my_types").Values(myTypeTests[i]).Exec()
s.NoError(err)
id, err := res.LastInsertId()
s.NoError(err)
s.NotEqual(0, id)
var actual MyType
err = sess.Collection("my_types").Find(id).One(&actual)
s.NoError(err)
expected := myTypeTests[i]
expected.ID = id
s.Equal(expected, actual)
var actual2 MyType
err = sess.SQL().SelectFrom("my_types").Where("id = ?", id).One(&actual2)
s.NoError(err)
s.Equal(expected, actual2)
}
inserter := sess.SQL().InsertInto("my_types")
for i := range myTypeTests {
inserter = inserter.Values(myTypeTests[i])
}
_, err := inserter.Exec()
s.NoError(err)
err = sess.Collection("my_types").Truncate()
s.NoError(err)
batch := sess.SQL().InsertInto("my_types").Batch(50)
go func() {
defer batch.Done()
for i := range myTypeTests {
batch.Values(myTypeTests[i])
}
}()
err = batch.Wait()
s.NoError(err)
var values []MyType
err = sess.SQL().SelectFrom("my_types").All(&values)
s.NoError(err)
for i := range values {
expected := myTypeTests[i]
expected.ID = values[i].ID
s.Equal(expected, values[i])
}
}
}
func TestAdapter(t *testing.T) {
suite.Run(t, &AdapterTests{})
}

View File

@ -0,0 +1,20 @@
package mysql
import (
"testing"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type RecordTests struct {
testsuite.RecordTestSuite
}
func (s *RecordTests) SetupSuite() {
s.Helper = &Helper{}
}
func TestRecord(t *testing.T) {
suite.Run(t, &RecordTests{})
}

20
adapter/mysql/sql_test.go Normal file
View File

@ -0,0 +1,20 @@
package mysql
import (
"testing"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type SQLTests struct {
testsuite.SQLTestSuite
}
func (s *SQLTests) SetupSuite() {
s.Helper = &Helper{}
}
func TestSQL(t *testing.T) {
suite.Run(t, &SQLTests{})
}

198
adapter/mysql/template.go Normal file
View File

@ -0,0 +1,198 @@
package mysql
import (
"git.hexq.cn/tiglog/mydb/internal/cache"
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
)
const (
adapterColumnSeparator = `.`
adapterIdentifierSeparator = `, `
adapterIdentifierQuote = "`{{.Value}}`"
adapterValueSeparator = `, `
adapterValueQuote = `'{{.}}'`
adapterAndKeyword = `AND`
adapterOrKeyword = `OR`
adapterDescKeyword = `DESC`
adapterAscKeyword = `ASC`
adapterAssignmentOperator = `=`
adapterClauseGroup = `({{.}})`
adapterClauseOperator = ` {{.}} `
adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}`
adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
adapterSortByColumnLayout = `{{.Column}} {{.Order}}`
adapterOrderByLayout = `
{{if .SortColumns}}
ORDER BY {{.SortColumns}}
{{end}}
`
adapterWhereLayout = `
{{if .Conds}}
WHERE {{.Conds}}
{{end}}
`
adapterUsingLayout = `
{{if .Columns}}
USING ({{.Columns}})
{{end}}
`
adapterJoinLayout = `
{{if .Table}}
{{ if .On }}
{{.Type}} JOIN {{.Table}}
{{.On}}
{{ else if .Using }}
{{.Type}} JOIN {{.Table}}
{{.Using}}
{{ else if .Type | eq "CROSS" }}
{{.Type}} JOIN {{.Table}}
{{else}}
NATURAL {{.Type}} JOIN {{.Table}}
{{end}}
{{end}}
`
adapterOnLayout = `
{{if .Conds}}
ON {{.Conds}}
{{end}}
`
adapterSelectLayout = `
SELECT
{{if .Distinct}}
DISTINCT
{{end}}
{{if defined .Columns}}
{{.Columns | compile}}
{{else}}
*
{{end}}
{{if defined .Table}}
FROM {{.Table | compile}}
{{end}}
{{.Joins | compile}}
{{.Where | compile}}
{{if defined .GroupBy}}
{{.GroupBy | compile}}
{{end}}
{{.OrderBy | compile}}
{{if .Limit}}
LIMIT {{.Limit}}
{{end}}
` +
// The argument for LIMIT when only OFFSET is specified is a pretty odd magic
// number; this comes directly from MySQL's manual, see:
// https://dev.mysql.com/doc/refman/5.7/en/select.html
//
// "To retrieve all rows from a certain offset up to the end of the result
// set, you can use some large number for the second parameter. This
// statement retrieves all rows from the 96th row to the last:
// SELECT * FROM tbl LIMIT 95,18446744073709551615; "
//
// ¯\_(ツ)_/¯
`
{{if .Offset}}
{{if not .Limit}}
LIMIT 18446744073709551615
{{end}}
OFFSET {{.Offset}}
{{end}}
`
adapterDeleteLayout = `
DELETE
FROM {{.Table | compile}}
{{.Where | compile}}
`
adapterUpdateLayout = `
UPDATE
{{.Table | compile}}
SET {{.ColumnValues | compile}}
{{.Where | compile}}
`
adapterSelectCountLayout = `
SELECT
COUNT(1) AS _t
FROM {{.Table | compile}}
{{.Where | compile}}
`
adapterInsertLayout = `
INSERT INTO {{.Table | compile}}
{{if defined .Columns}}({{.Columns | compile}}){{end}}
VALUES
{{if defined .Values}}
{{.Values | compile}}
{{else}}
()
{{end}}
{{if defined .Returning}}
RETURNING {{.Returning | compile}}
{{end}}
`
adapterTruncateLayout = `
TRUNCATE TABLE {{.Table | compile}}
`
adapterDropDatabaseLayout = `
DROP DATABASE {{.Database | compile}}
`
adapterDropTableLayout = `
DROP TABLE {{.Table | compile}}
`
adapterGroupByLayout = `
{{if .GroupColumns}}
GROUP BY {{.GroupColumns}}
{{end}}
`
)
var template = &exql.Template{
ColumnSeparator: adapterColumnSeparator,
IdentifierSeparator: adapterIdentifierSeparator,
IdentifierQuote: adapterIdentifierQuote,
ValueSeparator: adapterValueSeparator,
ValueQuote: adapterValueQuote,
AndKeyword: adapterAndKeyword,
OrKeyword: adapterOrKeyword,
DescKeyword: adapterDescKeyword,
AscKeyword: adapterAscKeyword,
AssignmentOperator: adapterAssignmentOperator,
ClauseGroup: adapterClauseGroup,
ClauseOperator: adapterClauseOperator,
ColumnValue: adapterColumnValue,
TableAliasLayout: adapterTableAliasLayout,
ColumnAliasLayout: adapterColumnAliasLayout,
SortByColumnLayout: adapterSortByColumnLayout,
WhereLayout: adapterWhereLayout,
JoinLayout: adapterJoinLayout,
OnLayout: adapterOnLayout,
UsingLayout: adapterUsingLayout,
OrderByLayout: adapterOrderByLayout,
InsertLayout: adapterInsertLayout,
SelectLayout: adapterSelectLayout,
UpdateLayout: adapterUpdateLayout,
DeleteLayout: adapterDeleteLayout,
TruncateLayout: adapterTruncateLayout,
DropDatabaseLayout: adapterDropDatabaseLayout,
DropTableLayout: adapterDropTableLayout,
CountLayout: adapterSelectCountLayout,
GroupByLayout: adapterGroupByLayout,
Cache: cache.NewCache(),
}

View File

@ -0,0 +1,269 @@
package mysql
import (
"testing"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
"github.com/stretchr/testify/assert"
)
func TestTemplateSelect(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
"SELECT * FROM `artist`",
b.SelectFrom("artist").String(),
)
assert.Equal(
"SELECT * FROM `artist`",
b.Select().From("artist").String(),
)
assert.Equal(
"SELECT * FROM `artist` ORDER BY `name` DESC",
b.Select().From("artist").OrderBy("name DESC").String(),
)
assert.Equal(
"SELECT * FROM `artist` ORDER BY `name` DESC",
b.Select().From("artist").OrderBy("-name").String(),
)
assert.Equal(
"SELECT * FROM `artist` ORDER BY `name` ASC",
b.Select().From("artist").OrderBy("name").String(),
)
assert.Equal(
"SELECT * FROM `artist` ORDER BY `name` ASC",
b.Select().From("artist").OrderBy("name ASC").String(),
)
assert.Equal(
"SELECT * FROM `artist` LIMIT 18446744073709551615 OFFSET 5",
b.Select().From("artist").Limit(-1).Offset(5).String(),
)
assert.Equal(
"SELECT `id` FROM `artist`",
b.Select("id").From("artist").String(),
)
assert.Equal(
"SELECT `id`, `name` FROM `artist`",
b.Select("id", "name").From("artist").String(),
)
assert.Equal(
"SELECT * FROM `artist` WHERE (`name` = $1)",
b.SelectFrom("artist").Where("name", "Haruki").String(),
)
assert.Equal(
"SELECT * FROM `artist` WHERE (name LIKE $1)",
b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(),
)
assert.Equal(
"SELECT `id` FROM `artist` WHERE (name LIKE $1 OR name LIKE $2)",
b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(),
)
assert.Equal(
"SELECT * FROM `artist` WHERE (`id` > $1)",
b.SelectFrom("artist").Where("id >", 2).String(),
)
assert.Equal(
"SELECT * FROM `artist` WHERE (id <= 2 AND name != $1)",
b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(),
)
assert.Equal(
"SELECT * FROM `artist` WHERE (`id` IN ($1, $2, $3, $4))",
b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(),
)
assert.Equal(
"SELECT * FROM `artist` WHERE (name IS NOT NULL)",
b.SelectFrom("artist").Where("name IS NOT NULL").String(),
)
assert.Equal(
"SELECT * FROM `artist` AS `a`, `publication` AS `p` WHERE (p.author_id = a.id) LIMIT 1",
b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(),
)
assert.Equal(
"SELECT `id` FROM `artist` NATURAL JOIN `publication`",
b.Select("id").From("artist").Join("publication").String(),
)
assert.Equal(
"SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.author_id = a.id) LIMIT 1",
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(),
)
assert.Equal(
"SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.author_id = a.id) WHERE (`a`.`id` = $1) LIMIT 1",
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(),
)
assert.Equal(
"SELECT * FROM `artist` JOIN `publication` AS `p` ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1",
b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(),
)
assert.Equal(
"SELECT * FROM `artist` AS `a` JOIN `publication` AS `p` ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1",
b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(),
)
assert.Equal(
"SELECT * FROM `artist` AS `a` LEFT JOIN `publication` AS `p1` ON (p1.id = a.id) RIGHT JOIN `publication` AS `p2` ON (p2.id = a.id)",
b.SelectFrom("artist a").
LeftJoin("publication p1").On("p1.id = a.id").
RightJoin("publication p2").On("p2.id = a.id").
String(),
)
assert.Equal(
"SELECT * FROM `artist` CROSS JOIN `publication`",
b.SelectFrom("artist").CrossJoin("publication").String(),
)
assert.Equal(
"SELECT * FROM `artist` JOIN `publication` USING (`id`)",
b.SelectFrom("artist").Join("publication").Using("id").String(),
)
assert.Equal(
"SELECT DATE()",
b.Select(mydb.Raw("DATE()")).String(),
)
// Issue #408
{
assert.Equal(
"SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)",
b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id IN": []int{1, 2}}).String(),
)
assert.Equal(
"SELECT * FROM `artist` WHERE (`id` = $1 AND `name` LIKE $2)",
b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": []byte{1, 2}}).String(),
)
assert.Equal(
"SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)",
b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": mydb.In(1, 2)}).String(),
)
assert.Equal(
"SELECT * FROM `artist` WHERE (`id` IN ($1, $2) AND `name` LIKE $3)",
b.SelectFrom("artist").Where(mydb.Cond{"name LIKE": "%foo", "id": mydb.AnyOf([]int{1, 2})}).String(),
)
}
}
func TestTemplateInsert(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
"INSERT INTO `artist` VALUES ($1, $2), ($3, $4), ($5, $6)",
b.InsertInto("artist").
Values(10, "Ryuichi Sakamoto").
Values(11, "Alondra de la Parra").
Values(12, "Haruki Murakami").
String(),
)
assert.Equal(
"INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)",
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(),
)
assert.Equal(
"INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2) RETURNING `id`",
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(),
)
assert.Equal(
"INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)",
b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(),
)
assert.Equal(
"INSERT INTO `artist` (`id`, `name`) VALUES ($1, $2)",
b.InsertInto("artist").Values(struct {
ID int `db:"id"`
Name string `db:"name"`
}{12, "Chavela Vargas"}).String(),
)
assert.Equal(
"INSERT INTO `artist` (`name`, `id`) VALUES ($1, $2)",
b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(),
)
}
func TestTemplateUpdate(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
"UPDATE `artist` SET `name` = $1",
b.Update("artist").Set("name", "Artist").String(),
)
assert.Equal(
"UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)",
b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(),
)
assert.Equal(
"UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)",
b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(),
)
assert.Equal(
"UPDATE `artist` SET `name` = $1 WHERE (`id` < $2)",
b.Update("artist").Set(struct {
Nombre string `db:"name"`
}{"Artist"}).Where(mydb.Cond{"id <": 5}).String(),
)
assert.Equal(
"UPDATE `artist` SET `name` = $1, `last_name` = $2 WHERE (`id` < $3)",
b.Update("artist").Set(struct {
Nombre string `db:"name"`
}{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(),
)
assert.Equal(
"UPDATE `artist` SET `name` = $1 || ' ' || $2 || id, `id` = id + $3 WHERE (id > $4)",
b.Update("artist").Set(
"name = ? || ' ' || ? || id", "Artist", "#",
"id = id + ?", 10,
).Where("id > ?", 0).String(),
)
}
func TestTemplateDelete(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
"DELETE FROM `artist` WHERE (name = $1)",
b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").String(),
)
assert.Equal(
"DELETE FROM `artist` WHERE (id > 5)",
b.DeleteFrom("artist").Where("id > 5").String(),
)
}

View File

@ -0,0 +1,44 @@
SHELL ?= bash
POSTGRES_VERSION ?= 15-alpine
POSTGRES_SUPPORTED ?= $(POSTGRES_VERSION) 14-alpine 13-alpine 12-alpine
PROJECT ?= upper_postgres_$(POSTGRES_VERSION)
DB_HOST ?= 127.0.0.1
DB_PORT ?= 5432
DB_NAME ?= upperio
DB_USERNAME ?= upperio_user
DB_PASSWORD ?= upperio//s3cr37
TEST_FLAGS ?=
PARALLEL_FLAGS ?= --halt-on-error 2 --jobs 1
export POSTGRES_VERSION
export DB_HOST
export DB_NAME
export DB_PASSWORD
export DB_PORT
export DB_USERNAME
export TEST_FLAGS
test:
go test -v -failfast -race -timeout 20m $(TEST_FLAGS)
test-no-race:
go test -v -failfast $(TEST_FLAGS)
server-up: server-down
docker-compose -p $(PROJECT) up -d && \
sleep 10
server-down:
docker-compose -p $(PROJECT) down
test-extended:
parallel $(PARALLEL_FLAGS) \
"POSTGRES_VERSION={} DB_PORT=\$$((5432+{#})) $(MAKE) server-up test server-down" ::: \
$(POSTGRES_SUPPORTED)

View File

@ -0,0 +1,5 @@
# PostgreSQL adapter for upper/db
Please read the full docs, acknowledgements and examples at
[https://upper.io/v4/adapter/postgresql/](https://upper.io/v4/adapter/postgresql/).

View File

@ -0,0 +1,50 @@
package postgresql
import (
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
)
type collectionAdapter struct {
}
func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) {
pKey, err := col.PrimaryKeys()
if err != nil {
return nil, err
}
q := col.SQL().InsertInto(col.Name()).Values(item)
if len(pKey) == 0 {
// There is no primary key.
res, err := q.Exec()
if err != nil {
return nil, err
}
// Attempt to use LastInsertId() (probably won't work, but the Exec()
// succeeded, so we can safely ignore the error from LastInsertId()).
lastID, err := res.LastInsertId()
if err != nil {
return nil, nil
}
return lastID, nil
}
// Asking the database to return the primary key after insertion.
q = q.Returning(pKey...)
var keyMap mydb.Cond
if err := q.Iterator().One(&keyMap); err != nil {
return nil, err
}
// The IDSetter interface does not match, look for another interface match.
if len(keyMap) == 1 {
return keyMap[pKey[0]], nil
}
// This was a compound key and no interface matched it, let's return a map.
return keyMap, nil
}

View File

@ -0,0 +1,289 @@
package postgresql
import (
"fmt"
"net"
"net/url"
"sort"
"strings"
"time"
"unicode"
)
// scanner implements a tokenizer for libpq-style option strings.
type scanner struct {
s []rune
i int
}
// Next returns the next rune. It returns 0, false if the end of the text has
// been reached.
func (s *scanner) Next() (rune, bool) {
if s.i >= len(s.s) {
return 0, false
}
r := s.s[s.i]
s.i++
return r, true
}
// SkipSpaces returns the next non-whitespace rune. It returns 0, false if the
// end of the text has been reached.
func (s *scanner) SkipSpaces() (rune, bool) {
r, ok := s.Next()
for unicode.IsSpace(r) && ok {
r, ok = s.Next()
}
return r, ok
}
type values map[string]string
func (vs values) Set(k, v string) {
vs[k] = v
}
func (vs values) Get(k string) (v string) {
return vs[k]
}
func (vs values) Isset(k string) bool {
_, ok := vs[k]
return ok
}
// ConnectionURL represents a parsed PostgreSQL connection URL.
//
// You can use a ConnectionURL struct as an argument for Open:
//
// var settings = postgresql.ConnectionURL{
// Host: "localhost", // PostgreSQL server IP or name.
// Database: "peanuts", // Database name.
// User: "cbrown", // Optional user name.
// Password: "snoopy", // Optional user password.
// }
//
// sess, err = postgresql.Open(settings)
//
// If you already have a valid DSN, you can use ParseURL to convert it into
// a ConnectionURL before passing it to Open.
type ConnectionURL struct {
User string
Password string
Host string
Socket string
Database string
Options map[string]string
timezone *time.Location
}
var escaper = strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
// ParseURL parses the given DSN into a ConnectionURL struct.
// A typical PostgreSQL connection URL looks like:
//
// postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full
func ParseURL(s string) (u *ConnectionURL, err error) {
o := make(values)
if strings.HasPrefix(s, "postgres://") || strings.HasPrefix(s, "postgresql://") {
s, err = parseURL(s)
if err != nil {
return u, err
}
}
if err := parseOpts(s, o); err != nil {
return u, err
}
u = &ConnectionURL{}
u.User = o.Get("user")
u.Password = o.Get("password")
h := o.Get("host")
p := o.Get("port")
if strings.HasPrefix(h, "/") {
u.Socket = h
} else {
if p == "" {
u.Host = h
} else {
u.Host = fmt.Sprintf("%s:%s", h, p)
}
}
u.Database = o.Get("dbname")
u.Options = make(map[string]string)
for k := range o {
switch k {
case "user", "password", "host", "port", "dbname":
// Skip
default:
u.Options[k] = o[k]
}
}
if timezone, ok := u.Options["timezone"]; ok {
u.timezone, _ = time.LoadLocation(timezone)
}
return u, err
}
// parseOpts parses the options from name and adds them to the values.
//
// The parsing code is based on conninfo_parse from libpq's fe-connect.c
func parseOpts(name string, o values) error {
s := newScanner(name)
for {
var (
keyRunes, valRunes []rune
r rune
ok bool
)
if r, ok = s.SkipSpaces(); !ok {
break
}
// Scan the key
for !unicode.IsSpace(r) && r != '=' {
keyRunes = append(keyRunes, r)
if r, ok = s.Next(); !ok {
break
}
}
// Skip any whitespace if we're not at the = yet
if r != '=' {
r, ok = s.SkipSpaces()
}
// The current character should be =
if r != '=' || !ok {
return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
}
// Skip any whitespace after the =
if r, ok = s.SkipSpaces(); !ok {
// If we reach the end here, the last value is just an empty string as per libpq.
o.Set(string(keyRunes), "")
break
}
if r != '\'' {
for !unicode.IsSpace(r) {
if r == '\\' {
if r, ok = s.Next(); !ok {
return fmt.Errorf(`missing character after backslash`)
}
}
valRunes = append(valRunes, r)
if r, ok = s.Next(); !ok {
break
}
}
} else {
quote:
for {
if r, ok = s.Next(); !ok {
return fmt.Errorf(`unterminated quoted string literal in connection string`)
}
switch r {
case '\'':
break quote
case '\\':
r, _ = s.Next()
fallthrough
default:
valRunes = append(valRunes, r)
}
}
}
o.Set(string(keyRunes), string(valRunes))
}
return nil
}
// newScanner returns a new scanner initialized with the option string s.
func newScanner(s string) *scanner {
return &scanner{[]rune(s), 0}
}
// ParseURL no longer needs to be used by clients of this library since supplying a URL as a
// connection string to sql.Open() is now supported:
//
// sql.Open("postgres", "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full")
//
// It remains exported here for backwards-compatibility.
//
// ParseURL converts a url to a connection string for driver.Open.
// Example:
//
// "postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
//
// converts to:
//
// "user=bob password=secret host=1.2.3.4 port=5432 dbname=mydb sslmode=verify-full"
//
// A minimal example:
//
// "postgres://"
//
// # This will be blank, causing driver.Open to use all of the defaults
//
// NOTE: vendored/copied from github.com/lib/pq
func parseURL(uri string) (string, error) {
u, err := url.Parse(uri)
if err != nil {
return "", err
}
if u.Scheme != "postgres" && u.Scheme != "postgresql" {
return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
}
var kvs []string
escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
accrue := func(k, v string) {
if v != "" {
kvs = append(kvs, k+"="+escaper.Replace(v))
}
}
if u.User != nil {
v := u.User.Username()
accrue("user", v)
v, _ = u.User.Password()
accrue("password", v)
}
if host, port, err := net.SplitHostPort(u.Host); err != nil {
accrue("host", u.Host)
} else {
accrue("host", host)
accrue("port", port)
}
if u.Path != "" {
accrue("dbname", u.Path[1:])
}
q := u.Query()
for k := range q {
accrue(k, q.Get(k))
}
sort.Strings(kvs) // Makes testing easier (not a performance concern)
return strings.Join(kvs, " "), nil
}

View File

@ -0,0 +1,73 @@
//go:build !pq
// +build !pq
package postgresql
import (
"net"
"sort"
"strings"
)
// String reassembles the parsed PostgreSQL connection URL into a valid DSN.
func (c ConnectionURL) String() (s string) {
u := []string{}
// TODO: This surely needs some sort of escaping.
if c.User != "" {
u = append(u, "user="+escaper.Replace(c.User))
}
if c.Password != "" {
u = append(u, "password="+escaper.Replace(c.Password))
}
if c.Host != "" {
host, port, err := net.SplitHostPort(c.Host)
if err == nil {
if host == "" {
host = "127.0.0.1"
}
if port == "" {
port = "5432"
}
u = append(u, "host="+escaper.Replace(host))
u = append(u, "port="+escaper.Replace(port))
} else {
u = append(u, "host="+escaper.Replace(c.Host))
}
}
if c.Socket != "" {
u = append(u, "host="+escaper.Replace(c.Socket))
}
if c.Database != "" {
u = append(u, "dbname="+escaper.Replace(c.Database))
}
// Is there actually any connection data?
if len(u) == 0 {
return ""
}
if c.Options == nil {
c.Options = map[string]string{}
}
// If not present, SSL mode is assumed "prefer".
if sslMode, ok := c.Options["sslmode"]; !ok || sslMode == "" {
c.Options["sslmode"] = "prefer"
}
// Disabled by default
c.Options["statement_cache_capacity"] = "0"
for k, v := range c.Options {
u = append(u, escaper.Replace(k)+"="+escaper.Replace(v))
}
sort.Strings(u)
return strings.Join(u, " ")
}

View File

@ -0,0 +1,108 @@
//go:build !pq
// +build !pq
package postgresql
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConnectionURL(t *testing.T) {
c := ConnectionURL{}
// Default connection string is empty.
assert.Equal(t, "", c.String(), "Expecting default connectiong string to be empty")
// Adding a host with port.
c.Host = "localhost:1234"
assert.Equal(t, "host=localhost port=1234 sslmode=prefer statement_cache_capacity=0", c.String())
// Adding a host.
c.Host = "localhost"
assert.Equal(t, "host=localhost sslmode=prefer statement_cache_capacity=0", c.String())
// Adding a username.
c.User = "Anakin"
assert.Equal(t, `host=localhost sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String())
// Adding a password with special characters.
c.Password = "Some Sort of ' Password"
assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String())
// Adding a port.
c.Host = "localhost:1234"
assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String())
// Adding a database.
c.Database = "MyDatabase"
assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer statement_cache_capacity=0 user=Anakin`, c.String())
// Adding options.
c.Options = map[string]string{
"sslmode": "verify-full",
}
assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=verify-full statement_cache_capacity=0 user=Anakin`, c.String())
}
func TestParseConnectionURL(t *testing.T) {
{
s := "postgres://anakin:skywalker@localhost/jedis"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, "anakin", u.User)
assert.Equal(t, "skywalker", u.Password)
assert.Equal(t, "localhost", u.Host)
assert.Equal(t, "jedis", u.Database)
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
}
{
// case with port
s := "postgres://anakin:skywalker@localhost:1234/jedis"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, "anakin", u.User)
assert.Equal(t, "skywalker", u.Password)
assert.Equal(t, "jedis", u.Database)
assert.Equal(t, "localhost:1234", u.Host)
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
}
{
s := "postgres://anakin:skywalker@localhost/jedis?sslmode=verify-full"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, "verify-full", u.Options["sslmode"])
}
{
s := "user=anakin password=skywalker host=localhost dbname=jedis"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, "anakin", u.User)
assert.Equal(t, "skywalker", u.Password)
assert.Equal(t, "jedis", u.Database)
assert.Equal(t, "localhost", u.Host)
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
}
{
s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, "verify-full", u.Options["sslmode"])
}
{
s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full timezone=UTC"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, 2, len(u.Options), "Expecting exactly two options")
assert.Equal(t, "verify-full", u.Options["sslmode"])
assert.Equal(t, "UTC", u.Options["timezone"])
}
}

View File

@ -0,0 +1,70 @@
//go:build pq
// +build pq
package postgresql
import (
"net"
"sort"
"strings"
)
// String reassembles the parsed PostgreSQL connection URL into a valid DSN.
func (c ConnectionURL) String() (s string) {
u := []string{}
// TODO: This surely needs some sort of escaping.
if c.User != "" {
u = append(u, "user="+escaper.Replace(c.User))
}
if c.Password != "" {
u = append(u, "password="+escaper.Replace(c.Password))
}
if c.Host != "" {
host, port, err := net.SplitHostPort(c.Host)
if err == nil {
if host == "" {
host = "127.0.0.1"
}
if port == "" {
port = "5432"
}
u = append(u, "host="+escaper.Replace(host))
u = append(u, "port="+escaper.Replace(port))
} else {
u = append(u, "host="+escaper.Replace(c.Host))
}
}
if c.Socket != "" {
u = append(u, "host="+escaper.Replace(c.Socket))
}
if c.Database != "" {
u = append(u, "dbname="+escaper.Replace(c.Database))
}
// Is there actually any connection data?
if len(u) == 0 {
return ""
}
if c.Options == nil {
c.Options = map[string]string{}
}
// If not present, SSL mode is assumed "prefer".
if sslMode, ok := c.Options["sslmode"]; !ok || sslMode == "" {
c.Options["sslmode"] = "prefer"
}
for k, v := range c.Options {
u = append(u, escaper.Replace(k)+"="+escaper.Replace(v))
}
sort.Strings(u)
return strings.Join(u, " ")
}

View File

@ -0,0 +1,108 @@
//go:build pq
// +build pq
package postgresql
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestConnectionURL(t *testing.T) {
c := ConnectionURL{}
// Default connection string is empty.
assert.Equal(t, "", c.String(), "Expecting default connectiong string to be empty")
// Adding a host with port.
c.Host = "localhost:1234"
assert.Equal(t, "host=localhost port=1234 sslmode=prefer", c.String())
// Adding a host.
c.Host = "localhost"
assert.Equal(t, "host=localhost sslmode=prefer", c.String())
// Adding a username.
c.User = "Anakin"
assert.Equal(t, `host=localhost sslmode=prefer user=Anakin`, c.String())
// Adding a password with special characters.
c.Password = "Some Sort of ' Password"
assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password sslmode=prefer user=Anakin`, c.String())
// Adding a port.
c.Host = "localhost:1234"
assert.Equal(t, `host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer user=Anakin`, c.String())
// Adding a database.
c.Database = "MyDatabase"
assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=prefer user=Anakin`, c.String())
// Adding options.
c.Options = map[string]string{
"sslmode": "verify-full",
}
assert.Equal(t, `dbname=MyDatabase host=localhost password=Some\ Sort\ of\ \'\ Password port=1234 sslmode=verify-full user=Anakin`, c.String())
}
func TestParseConnectionURL(t *testing.T) {
{
s := "postgres://anakin:skywalker@localhost/jedis"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, "anakin", u.User)
assert.Equal(t, "skywalker", u.Password)
assert.Equal(t, "localhost", u.Host)
assert.Equal(t, "jedis", u.Database)
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
}
{
// case with port
s := "postgres://anakin:skywalker@localhost:1234/jedis"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, "anakin", u.User)
assert.Equal(t, "skywalker", u.Password)
assert.Equal(t, "jedis", u.Database)
assert.Equal(t, "localhost:1234", u.Host)
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
}
{
s := "postgres://anakin:skywalker@localhost/jedis?sslmode=verify-full"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, "verify-full", u.Options["sslmode"])
}
{
s := "user=anakin password=skywalker host=localhost dbname=jedis"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, "anakin", u.User)
assert.Equal(t, "skywalker", u.Password)
assert.Equal(t, "jedis", u.Database)
assert.Equal(t, "localhost", u.Host)
assert.Zero(t, u.Options["sslmode"], "Failed to parse SSLMode.")
}
{
s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, "verify-full", u.Options["sslmode"])
}
{
s := "user=anakin password=skywalker host=localhost dbname=jedis sslmode=verify-full timezone=UTC"
u, err := ParseURL(s)
assert.NoError(t, err)
assert.Equal(t, 2, len(u.Options), "Expecting exactly two options")
assert.Equal(t, "verify-full", u.Options["sslmode"])
assert.Equal(t, "UTC", u.Options["timezone"])
}
}

View File

@ -0,0 +1,126 @@
package postgresql
import (
"context"
"database/sql"
"database/sql/driver"
"time"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
)
// JSONBMap represents a map of interfaces with string keys
// (`map[string]interface{}`) that is compatible with PostgreSQL's JSONB type.
// JSONBMap satisfies sqlbuilder.ScannerValuer.
type JSONBMap map[string]interface{}
// Value satisfies the driver.Valuer interface.
func (m JSONBMap) Value() (driver.Value, error) {
return JSONBValue(m)
}
// Scan satisfies the sql.Scanner interface.
func (m *JSONBMap) Scan(src interface{}) error {
*m = map[string]interface{}(nil)
return ScanJSONB(m, src)
}
// JSONBArray represents an array of any type (`[]interface{}`) that is
// compatible with PostgreSQL's JSONB type. JSONBArray satisfies
// sqlbuilder.ScannerValuer.
type JSONBArray []interface{}
// Value satisfies the driver.Valuer interface.
func (a JSONBArray) Value() (driver.Value, error) {
return JSONBValue(a)
}
// Scan satisfies the sql.Scanner interface.
func (a *JSONBArray) Scan(src interface{}) error {
return ScanJSONB(a, src)
}
// JSONBValue takes an interface and provides a driver.Value that can be
// stored as a JSONB column.
func JSONBValue(i interface{}) (driver.Value, error) {
v := JSONB{i}
return v.Value()
}
// ScanJSONB decodes a JSON byte stream into the passed dst value.
func ScanJSONB(dst interface{}, src interface{}) error {
v := JSONB{dst}
return v.Scan(src)
}
type JSONBConverter struct {
}
func (*JSONBConverter) ConvertValue(in interface{}) interface {
sql.Scanner
driver.Valuer
} {
return &JSONB{in}
}
type timeWrapper struct {
v **time.Time
loc *time.Location
}
func (t timeWrapper) Value() (driver.Value, error) {
if *t.v != nil {
return **t.v, nil
}
return nil, nil
}
func (t *timeWrapper) Scan(src interface{}) error {
if src == nil {
nilTime := (*time.Time)(nil)
if t.v == nil {
t.v = &nilTime
} else {
*(t.v) = nilTime
}
return nil
}
tz := src.(time.Time)
if t.loc != nil && (tz.Location() == time.Local) {
tz = tz.In(t.loc)
}
if tz.Location().String() == "" {
tz = tz.In(time.UTC)
}
if *(t.v) == nil {
*(t.v) = &tz
} else {
**t.v = tz
}
return nil
}
func (d *database) ConvertValueContext(ctx context.Context, in interface{}) interface{} {
tz, _ := ctx.Value("timezone").(*time.Location)
switch v := in.(type) {
case *time.Time:
return &timeWrapper{&v, tz}
case **time.Time:
return &timeWrapper{v, tz}
}
return d.ConvertValue(in)
}
// Type checks.
var (
_ sqlbuilder.ScannerValuer = &StringArray{}
_ sqlbuilder.ScannerValuer = &Int64Array{}
_ sqlbuilder.ScannerValuer = &Float64Array{}
_ sqlbuilder.ScannerValuer = &Float32Array{}
_ sqlbuilder.ScannerValuer = &BoolArray{}
_ sqlbuilder.ScannerValuer = &JSONBMap{}
_ sqlbuilder.ScannerValuer = &JSONBArray{}
_ sqlbuilder.ScannerValuer = &JSONB{}
)

View File

@ -0,0 +1,286 @@
//go:build !pq
// +build !pq
package postgresql
import (
"database/sql/driver"
"github.com/jackc/pgtype"
)
// JSONB represents a PostgreSQL's JSONB value:
// https://www.postgresql.org/docs/9.6/static/datatype-json.html. JSONB
// satisfies sqlbuilder.ScannerValuer.
type JSONB struct {
Data interface{}
}
// MarshalJSON encodes the wrapper value as JSON.
func (j JSONB) MarshalJSON() ([]byte, error) {
t := &pgtype.JSONB{}
if err := t.Set(j.Data); err != nil {
return nil, err
}
return t.MarshalJSON()
}
// UnmarshalJSON decodes the given JSON into the wrapped value.
func (j *JSONB) UnmarshalJSON(b []byte) error {
t := &pgtype.JSONB{}
if err := t.UnmarshalJSON(b); err != nil {
return err
}
if j.Data == nil {
j.Data = t.Get()
return nil
}
if err := t.AssignTo(&j.Data); err != nil {
return err
}
return nil
}
// Scan satisfies the sql.Scanner interface.
func (j *JSONB) Scan(src interface{}) error {
t := &pgtype.JSONB{}
if err := t.Scan(src); err != nil {
return err
}
if j.Data == nil {
j.Data = t.Get()
return nil
}
if err := t.AssignTo(j.Data); err != nil {
return err
}
return nil
}
// Value satisfies the driver.Valuer interface.
func (j JSONB) Value() (driver.Value, error) {
t := &pgtype.JSONB{}
if err := t.Set(j.Data); err != nil {
return nil, err
}
return t.Value()
}
// StringArray represents a one-dimensional array of strings (`[]string{}`)
// that is compatible with PostgreSQL's text array (`text[]`). StringArray
// satisfies sqlbuilder.ScannerValuer.
type StringArray []string
// Value satisfies the driver.Valuer interface.
func (a StringArray) Value() (driver.Value, error) {
t := pgtype.TextArray{}
if err := t.Set(a); err != nil {
return nil, err
}
return t.Value()
}
// Scan satisfies the sql.Scanner interface.
func (sa *StringArray) Scan(src interface{}) error {
d := []string{}
t := pgtype.TextArray{}
if err := t.Scan(src); err != nil {
return err
}
if err := t.AssignTo(&d); err != nil {
return err
}
*sa = StringArray(d)
return nil
}
type Bytea []byte
func (b Bytea) Value() (driver.Value, error) {
t := pgtype.Bytea{Bytes: b}
if err := t.Set(b); err != nil {
return nil, err
}
return t.Value()
}
func (b *Bytea) Scan(src interface{}) error {
d := []byte{}
t := pgtype.Bytea{}
if err := t.Scan(src); err != nil {
return err
}
if err := t.AssignTo(&d); err != nil {
return err
}
*b = Bytea(d)
return nil
}
// ByteaArray represents a one-dimensional array of strings (`[]string{}`)
// that is compatible with PostgreSQL's text array (`text[]`). ByteaArray
// satisfies sqlbuilder.ScannerValuer.
type ByteaArray [][]byte
// Value satisfies the driver.Valuer interface.
func (a ByteaArray) Value() (driver.Value, error) {
t := pgtype.ByteaArray{}
if err := t.Set(a); err != nil {
return nil, err
}
return t.Value()
}
// Scan satisfies the sql.Scanner interface.
func (ba *ByteaArray) Scan(src interface{}) error {
d := [][]byte{}
t := pgtype.ByteaArray{}
if err := t.Scan(src); err != nil {
return err
}
if err := t.AssignTo(&d); err != nil {
return err
}
*ba = ByteaArray(d)
return nil
}
// Int64Array represents a one-dimensional array of int64s (`[]int64{}`) that
// is compatible with PostgreSQL's integer array (`integer[]`). Int64Array
// satisfies sqlbuilder.ScannerValuer.
type Int64Array []int64
// Value satisfies the driver.Valuer interface.
func (i64a Int64Array) Value() (driver.Value, error) {
t := pgtype.Int8Array{}
if err := t.Set(i64a); err != nil {
return nil, err
}
return t.Value()
}
// Scan satisfies the sql.Scanner interface.
func (i64a *Int64Array) Scan(src interface{}) error {
d := []int64{}
t := pgtype.Int8Array{}
if err := t.Scan(src); err != nil {
return err
}
if err := t.AssignTo(&d); err != nil {
return err
}
*i64a = Int64Array(d)
return nil
}
// Int32Array represents a one-dimensional array of int32s (`[]int32{}`) that
// is compatible with PostgreSQL's integer array (`integer[]`). Int32Array
// satisfies sqlbuilder.ScannerValuer.
type Int32Array []int32
// Value satisfies the driver.Valuer interface.
func (i32a Int32Array) Value() (driver.Value, error) {
t := pgtype.Int4Array{}
if err := t.Set(i32a); err != nil {
return nil, err
}
return t.Value()
}
// Scan satisfies the sql.Scanner interface.
func (i32a *Int32Array) Scan(src interface{}) error {
d := []int32{}
t := pgtype.Int4Array{}
if err := t.Scan(src); err != nil {
return err
}
if err := t.AssignTo(&d); err != nil {
return err
}
*i32a = Int32Array(d)
return nil
}
// Float64Array represents a one-dimensional array of float64s (`[]float64{}`)
// that is compatible with PostgreSQL's double precision array (`double
// precision[]`). Float64Array satisfies sqlbuilder.ScannerValuer.
type Float64Array []float64
// Value satisfies the driver.Valuer interface.
func (f64a Float64Array) Value() (driver.Value, error) {
t := pgtype.Float8Array{}
if err := t.Set(f64a); err != nil {
return nil, err
}
return t.Value()
}
// Scan satisfies the sql.Scanner interface.
func (f64a *Float64Array) Scan(src interface{}) error {
d := []float64{}
t := pgtype.Float8Array{}
if err := t.Scan(src); err != nil {
return err
}
if err := t.AssignTo(&d); err != nil {
return err
}
*f64a = Float64Array(d)
return nil
}
// Float32Array represents a one-dimensional array of float32s (`[]float32{}`)
// that is compatible with PostgreSQL's double precision array (`double
// precision[]`). Float32Array satisfies sqlbuilder.ScannerValuer.
type Float32Array []float32
// Value satisfies the driver.Valuer interface.
func (f32a Float32Array) Value() (driver.Value, error) {
t := pgtype.Float8Array{}
if err := t.Set(f32a); err != nil {
return nil, err
}
return t.Value()
}
// Scan satisfies the sql.Scanner interface.
func (f32a *Float32Array) Scan(src interface{}) error {
d := []float32{}
t := pgtype.Float8Array{}
if err := t.Scan(src); err != nil {
return err
}
if err := t.AssignTo(&d); err != nil {
return err
}
*f32a = Float32Array(d)
return nil
}
// BoolArray represents a one-dimensional array of int64s (`[]bool{}`) that
// is compatible with PostgreSQL's boolean type (`boolean[]`). BoolArray
// satisfies sqlbuilder.ScannerValuer.
type BoolArray []bool
// Value satisfies the driver.Valuer interface.
func (ba BoolArray) Value() (driver.Value, error) {
t := pgtype.BoolArray{}
if err := t.Set(ba); err != nil {
return nil, err
}
return t.Value()
}
// Scan satisfies the sql.Scanner interface.
func (ba *BoolArray) Scan(src interface{}) error {
d := []bool{}
t := pgtype.BoolArray{}
if err := t.Scan(src); err != nil {
return err
}
if err := t.AssignTo(&d); err != nil {
return err
}
*ba = BoolArray(d)
return nil
}

View File

@ -0,0 +1,249 @@
//go:build pq
// +build pq
package postgresql
import (
"bytes"
"database/sql/driver"
"encoding/hex"
"encoding/json"
"errors"
"fmt"
"reflect"
"strconv"
"time"
"github.com/lib/pq"
)
// JSONB represents a PostgreSQL's JSONB value:
// https://www.postgresql.org/docs/9.6/static/datatype-json.html. JSONB
// satisfies sqlbuilder.ScannerValuer.
type JSONB struct {
Data interface{}
}
// MarshalJSON encodes the wrapper value as JSON.
func (j JSONB) MarshalJSON() ([]byte, error) {
return json.Marshal(j.Data)
}
// UnmarshalJSON decodes the given JSON into the wrapped value.
func (j *JSONB) UnmarshalJSON(b []byte) error {
var v interface{}
if err := json.Unmarshal(b, &v); err != nil {
return err
}
j.Data = v
return nil
}
// Scan satisfies the sql.Scanner interface.
func (j *JSONB) Scan(src interface{}) error {
if j.Data == nil {
return nil
}
if src == nil {
dv := reflect.Indirect(reflect.ValueOf(j.Data))
dv.Set(reflect.Zero(dv.Type()))
return nil
}
b, ok := src.([]byte)
if !ok {
return errors.New("Scan source was not []bytes")
}
if err := json.Unmarshal(b, j.Data); err != nil {
return err
}
return nil
}
// Value satisfies the driver.Valuer interface.
func (j JSONB) Value() (driver.Value, error) {
// See https://github.com/lib/pq/issues/528#issuecomment-257197239 on why are
// we returning string instead of []byte.
if j.Data == nil {
return nil, nil
}
if v, ok := j.Data.(json.RawMessage); ok {
return string(v), nil
}
b, err := json.Marshal(j.Data)
if err != nil {
return nil, err
}
return string(b), nil
}
// StringArray represents a one-dimensional array of strings (`[]string{}`)
// that is compatible with PostgreSQL's text array (`text[]`). StringArray
// satisfies sqlbuilder.ScannerValuer.
type StringArray pq.StringArray
// Value satisfies the driver.Valuer interface.
func (a StringArray) Value() (driver.Value, error) {
return pq.StringArray(a).Value()
}
// Scan satisfies the sql.Scanner interface.
func (a *StringArray) Scan(src interface{}) error {
s := pq.StringArray(*a)
if err := s.Scan(src); err != nil {
return err
}
*a = StringArray(s)
return nil
}
// Int64Array represents a one-dimensional array of int64s (`[]int64{}`) that
// is compatible with PostgreSQL's integer array (`integer[]`). Int64Array
// satisfies sqlbuilder.ScannerValuer.
type Int64Array pq.Int64Array
// Value satisfies the driver.Valuer interface.
func (i Int64Array) Value() (driver.Value, error) {
return pq.Int64Array(i).Value()
}
// Scan satisfies the sql.Scanner interface.
func (i *Int64Array) Scan(src interface{}) error {
s := pq.Int64Array(*i)
if err := s.Scan(src); err != nil {
return err
}
*i = Int64Array(s)
return nil
}
// Float64Array represents a one-dimensional array of float64s (`[]float64{}`)
// that is compatible with PostgreSQL's double precision array (`double
// precision[]`). Float64Array satisfies sqlbuilder.ScannerValuer.
type Float64Array pq.Float64Array
// Value satisfies the driver.Valuer interface.
func (f Float64Array) Value() (driver.Value, error) {
return pq.Float64Array(f).Value()
}
// Scan satisfies the sql.Scanner interface.
func (f *Float64Array) Scan(src interface{}) error {
s := pq.Float64Array(*f)
if err := s.Scan(src); err != nil {
return err
}
*f = Float64Array(s)
return nil
}
// Float32Array represents a one-dimensional array of float32s (`[]float32{}`)
// that is compatible with PostgreSQL's double precision array (`double
// precision[]`). Float32Array satisfies sqlbuilder.ScannerValuer.
type Float32Array pq.Float32Array
// Value satisfies the driver.Valuer interface.
func (f Float32Array) Value() (driver.Value, error) {
return pq.Float32Array(f).Value()
}
// Scan satisfies the sql.Scanner interface.
func (f *Float32Array) Scan(src interface{}) error {
s := pq.Float32Array(*f)
if err := s.Scan(src); err != nil {
return err
}
*f = Float32Array(s)
return nil
}
// BoolArray represents a one-dimensional array of int64s (`[]bool{}`) that
// is compatible with PostgreSQL's boolean type (`boolean[]`). BoolArray
// satisfies sqlbuilder.ScannerValuer.
type BoolArray pq.BoolArray
// Value satisfies the driver.Valuer interface.
func (b BoolArray) Value() (driver.Value, error) {
return pq.BoolArray(b).Value()
}
// Scan satisfies the sql.Scanner interface.
func (b *BoolArray) Scan(src interface{}) error {
s := pq.BoolArray(*b)
if err := s.Scan(src); err != nil {
return err
}
*b = BoolArray(s)
return nil
}
type Bytea []byte
// Scan satisfies the sql.Scanner interface.
func (b *Bytea) Scan(src interface{}) error {
decoded, err := parseBytea(src.([]byte))
if err != nil {
return err
}
if len(decoded) < 1 {
*b = nil
return nil
}
(*b) = make(Bytea, len(decoded))
for i := range decoded {
(*b)[i] = decoded[i]
}
return nil
}
type Time time.Time
// Parse a bytea value received from the server. Both "hex" and the legacy
// "escape" format are supported.
func parseBytea(s []byte) (result []byte, err error) {
if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
// bytea_output = hex
s = s[2:] // trim off leading "\\x"
result = make([]byte, hex.DecodedLen(len(s)))
_, err := hex.Decode(result, s)
if err != nil {
return nil, err
}
} else {
// bytea_output = escape
for len(s) > 0 {
if s[0] == '\\' {
// escaped '\\'
if len(s) >= 2 && s[1] == '\\' {
result = append(result, '\\')
s = s[2:]
continue
}
// '\\' followed by an octal number
if len(s) < 4 {
return nil, fmt.Errorf("invalid bytea sequence %v", s)
}
r, err := strconv.ParseInt(string(s[1:4]), 8, 9)
if err != nil {
return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
}
result = append(result, byte(r))
s = s[4:]
} else {
// We hit an unescaped, raw byte. Try to read in as many as
// possible in one go.
i := bytes.IndexByte(s, '\\')
if i == -1 {
result = append(result, s...)
break
}
result = append(result, s[:i]...)
s = s[i:]
}
}
}
return result, nil
}

View File

@ -0,0 +1,105 @@
package postgresql
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
)
type testStruct struct {
X int `json:"x"`
Z string `json:"z"`
V interface{} `json:"v"`
}
func TestScanJSONB(t *testing.T) {
{
a := testStruct{}
err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello", "v": 1}`))
assert.NoError(t, err)
assert.Equal(t, "Hello", a.Z)
assert.Equal(t, float64(1), a.V)
assert.Equal(t, 5, a.X)
}
{
a := testStruct{}
err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello", "v": null}`))
assert.NoError(t, err)
assert.Equal(t, "Hello", a.Z)
assert.Equal(t, nil, a.V)
assert.Equal(t, 5, a.X)
}
{
a := testStruct{}
err := ScanJSONB(&a, []byte(`{"x": 5, "z": "Hello"}`))
assert.NoError(t, err)
assert.Equal(t, "Hello", a.Z)
assert.Equal(t, nil, a.V)
assert.Equal(t, 5, a.X)
}
{
a := testStruct{}
err := ScanJSONB(&a, []byte(`{"v": "Hello"}`))
assert.NoError(t, err)
assert.Equal(t, "Hello", a.V)
}
{
a := testStruct{}
err := ScanJSONB(&a, []byte(`{"v": true}`))
assert.NoError(t, err)
assert.Equal(t, true, a.V)
}
{
a := testStruct{}
err := ScanJSONB(&a, []byte(`{}`))
assert.NoError(t, err)
assert.Equal(t, nil, a.V)
}
{
var a []byte
err := ScanJSONB(&a, []byte(`{"age":[{"\u003e":"1h"}]}`))
assert.NoError(t, err)
assert.Equal(t, `{"age":[{"\u003e":"1h"}]}`, string(a))
}
{
var a json.RawMessage
err := ScanJSONB(&a, []byte(`{"age":[{"\u003e":"1h"}]}`))
assert.NoError(t, err)
assert.Equal(t, `{"age":[{"\u003e":"1h"}]}`, string(a))
}
{
var a json.RawMessage
err := ScanJSONB(&a, []byte("{\"age\":[{\"\u003e\":\"1h\"}]}"))
assert.NoError(t, err)
assert.Equal(t, `{"age":[{">":"1h"}]}`, string(a))
}
{
a := []*testStruct{}
err := json.Unmarshal([]byte(`[{}]`), &a)
assert.NoError(t, err)
assert.Equal(t, 1, len(a))
assert.Nil(t, a[0].V)
}
{
a := []*testStruct{}
err := json.Unmarshal([]byte(`[{"v": true}]`), &a)
assert.NoError(t, err)
assert.Equal(t, 1, len(a))
assert.Equal(t, true, a[0].V)
}
{
a := []*testStruct{}
err := json.Unmarshal([]byte(`[{"v": null}]`), &a)
assert.NoError(t, err)
assert.Equal(t, 1, len(a))
assert.Nil(t, a[0].V)
}
{
a := []*testStruct{}
err := json.Unmarshal([]byte(`[{"v": 12.34}]`), &a)
assert.NoError(t, err)
assert.Equal(t, 1, len(a))
assert.Equal(t, 12.34, a[0].V)
}
}

View File

@ -0,0 +1,180 @@
// Package postgresql provides an adapter for PostgreSQL.
// See https://github.com/upper/db/adapter/postgresql for documentation,
// particularities and usage examples.
package postgresql
import (
"fmt"
"strings"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
)
type database struct {
}
func (*database) Template() *exql.Template {
return template
}
func (*database) Collections(sess sqladapter.Session) (collections []string, err error) {
q := sess.SQL().
Select("table_name").
From("information_schema.tables").
Where("table_schema = ?", "public")
iter := q.Iterator()
defer iter.Close()
for iter.Next() {
var name string
if err := iter.Scan(&name); err != nil {
return nil, err
}
collections = append(collections, name)
}
if err := iter.Err(); err != nil {
return nil, err
}
return collections, nil
}
func (*database) ConvertValue(in interface{}) interface{} {
switch v := in.(type) {
case *[]int64:
return (*Int64Array)(v)
case *[]string:
return (*StringArray)(v)
case *[]float64:
return (*Float64Array)(v)
case *[]bool:
return (*BoolArray)(v)
case *map[string]interface{}:
return (*JSONBMap)(v)
case []int64:
return (*Int64Array)(&v)
case []string:
return (*StringArray)(&v)
case []float64:
return (*Float64Array)(&v)
case []bool:
return (*BoolArray)(&v)
case map[string]interface{}:
return (*JSONBMap)(&v)
}
return in
}
func (*database) CompileStatement(sess sqladapter.Session, stmt *exql.Statement, args []interface{}) (string, []interface{}, error) {
compiled, err := stmt.Compile(template)
if err != nil {
return "", nil, err
}
query, args := sqlbuilder.Preprocess(compiled, args)
query = string(sqladapter.ReplaceWithDollarSign([]byte(query)))
return query, args, nil
}
func (*database) Err(err error) error {
if err != nil {
s := err.Error()
// These errors are not exported so we have to check them by they string value.
if strings.Contains(s, `too many clients`) || strings.Contains(s, `remaining connection slots are reserved`) || strings.Contains(s, `too many open`) {
return mydb.ErrTooManyClients
}
}
return err
}
func (*database) NewCollection() sqladapter.CollectionAdapter {
return &collectionAdapter{}
}
func (*database) LookupName(sess sqladapter.Session) (string, error) {
q := sess.SQL().
Select(mydb.Raw("CURRENT_DATABASE() AS name"))
iter := q.Iterator()
defer iter.Close()
if iter.Next() {
var name string
if err := iter.Scan(&name); err != nil {
return "", err
}
return name, nil
}
return "", iter.Err()
}
func (*database) TableExists(sess sqladapter.Session, name string) error {
q := sess.SQL().
Select("table_name").
From("information_schema.tables").
Where("table_catalog = ? AND table_name = ?", sess.Name(), name)
iter := q.Iterator()
defer iter.Close()
if iter.Next() {
var name string
if err := iter.Scan(&name); err != nil {
return err
}
return nil
}
if err := iter.Err(); err != nil {
return err
}
return mydb.ErrCollectionDoesNotExist
}
func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) {
q := sess.SQL().
Select("pg_attribute.attname AS pkey").
From("pg_index", "pg_class", "pg_attribute").
Where(`
pg_class.oid = '` + quotedTableName(tableName) + `'::regclass
AND indrelid = pg_class.oid
AND pg_attribute.attrelid = pg_class.oid
AND pg_attribute.attnum = ANY(pg_index.indkey)
AND indisprimary
`).OrderBy("pkey")
iter := q.Iterator()
defer iter.Close()
pk := []string{}
for iter.Next() {
var k string
if err := iter.Scan(&k); err != nil {
return nil, err
}
pk = append(pk, k)
}
if err := iter.Err(); err != nil {
return nil, err
}
return pk, nil
}
// quotedTableName returns a valid regclass name for both regular tables and
// for schemas.
func quotedTableName(s string) string {
chunks := strings.Split(s, ".")
for i := range chunks {
chunks[i] = fmt.Sprintf("%q", chunks[i])
}
return strings.Join(chunks, ".")
}

View File

@ -0,0 +1,26 @@
//go:build !pq
// +build !pq
package postgresql
import (
"context"
"database/sql"
"time"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
_ "github.com/jackc/pgx/v4/stdlib"
)
func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) {
connURL, err := ParseURL(dsn)
if err != nil {
return nil, err
}
if tz := connURL.Options["timezone"]; tz != "" {
loc, _ := time.LoadLocation(tz)
ctx := context.WithValue(sess.Context(), "timezone", loc)
sess.SetContext(ctx)
}
return sql.Open("pgx", dsn)
}

View File

@ -0,0 +1,26 @@
//go:build pq
// +build pq
package postgresql
import (
"context"
"database/sql"
"time"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
_ "github.com/lib/pq"
)
func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) {
connURL, err := ParseURL(dsn)
if err != nil {
return nil, err
}
if tz := connURL.Options["timezone"]; tz != "" {
loc, _ := time.LoadLocation(tz)
ctx := context.WithValue(sess.Context(), "timezone", loc)
sess.SetContext(ctx)
}
return sql.Open("postgres", dsn)
}

View File

@ -0,0 +1,13 @@
version: '3'
services:
server:
image: postgres:${POSTGRES_VERSION:-11}
environment:
POSTGRES_USER: ${DB_USERNAME:-upperio_user}
POSTGRES_PASSWORD: ${DB_PASSWORD:-upperio//s3cr37}
POSTGRES_DB: ${DB_NAME:-upperio}
ports:
- '${DB_HOST:-127.0.0.1}:${DB_PORT:-5432}:5432'

View File

@ -0,0 +1,20 @@
package postgresql
import (
"testing"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type GenericTests struct {
testsuite.GenericTestSuite
}
func (s *GenericTests) SetupSuite() {
s.Helper = &Helper{}
}
func TestGeneric(t *testing.T) {
suite.Run(t, &GenericTests{})
}

View File

@ -0,0 +1,321 @@
package postgresql
import (
"database/sql"
"fmt"
"os"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
)
var settings = ConnectionURL{
Database: os.Getenv("DB_NAME"),
User: os.Getenv("DB_USERNAME"),
Password: os.Getenv("DB_PASSWORD"),
Host: os.Getenv("DB_HOST") + ":" + os.Getenv("DB_PORT"),
Options: map[string]string{
"timezone": testsuite.TimeZone,
},
}
const preparedStatementsKey = "pg_prepared_statements_count"
type Helper struct {
sess mydb.Session
}
func cleanUp(sess mydb.Session) error {
if activeStatements := sqladapter.NumActiveStatements(); activeStatements > 128 {
return fmt.Errorf("Expecting active statements to be less than 128, got %d", activeStatements)
}
sess.Reset()
stats, err := getStats(sess)
if err != nil {
return err
}
if stats[preparedStatementsKey] != 0 {
return fmt.Errorf(`Expecting %q to be 0, got %d`, preparedStatementsKey, stats[preparedStatementsKey])
}
return nil
}
func getStats(sess mydb.Session) (map[string]int, error) {
stats := make(map[string]int)
row := sess.Driver().(*sql.DB).QueryRow(`SELECT count(1) AS value FROM pg_prepared_statements`)
var value int
err := row.Scan(&value)
if err != nil {
return nil, err
}
stats[preparedStatementsKey] = value
return stats, nil
}
func (h *Helper) Session() mydb.Session {
return h.sess
}
func (h *Helper) Adapter() string {
return Adapter
}
func (h *Helper) TearDown() error {
if err := cleanUp(h.sess); err != nil {
return err
}
return h.sess.Close()
}
func (h *Helper) TearUp() error {
var err error
h.sess, err = Open(settings)
if err != nil {
return err
}
batch := []string{
`DROP TABLE IF EXISTS artist`,
`CREATE TABLE artist (
id serial primary key,
name varchar(60)
)`,
`DROP TABLE IF EXISTS publication`,
`CREATE TABLE publication (
id serial primary key,
title varchar(80),
author_id integer
)`,
`DROP TABLE IF EXISTS review`,
`CREATE TABLE review (
id serial primary key,
publication_id integer,
name varchar(80),
comments text,
created timestamp without time zone
)`,
`DROP TABLE IF EXISTS data_types`,
`CREATE TABLE data_types (
id serial primary key,
_uint integer,
_uint8 integer,
_uint16 integer,
_uint32 integer,
_uint64 integer,
_int integer,
_int8 integer,
_int16 integer,
_int32 integer,
_int64 integer,
_float32 numeric(10,6),
_float64 numeric(10,6),
_bool boolean,
_string text,
_blob bytea,
_date timestamp with time zone,
_nildate timestamp without time zone null,
_ptrdate timestamp without time zone,
_defaultdate timestamp without time zone DEFAULT now(),
_time bigint
)`,
`DROP TABLE IF EXISTS stats_test`,
`CREATE TABLE stats_test (
id serial primary key,
numeric integer,
value integer
)`,
`DROP TABLE IF EXISTS composite_keys`,
`CREATE TABLE composite_keys (
code varchar(255) default '',
user_id varchar(255) default '',
some_val varchar(255) default '',
primary key (code, user_id)
)`,
`DROP TABLE IF EXISTS option_types`,
`CREATE TABLE option_types (
id serial primary key,
name varchar(255) default '',
tags varchar(64)[],
settings jsonb
)`,
`DROP TABLE IF EXISTS test_schema.test`,
`DROP SCHEMA IF EXISTS test_schema`,
`CREATE SCHEMA test_schema`,
`CREATE TABLE test_schema.test (id integer)`,
`DROP TABLE IF EXISTS pg_types`,
`CREATE TABLE pg_types (id serial primary key
, uint8_value smallint
, uint8_value_array bytea
, int64_value smallint
, int64_value_array smallint[]
, integer_array integer[]
, string_array text[]
, jsonb_map jsonb
, raw_jsonb_map jsonb
, raw_jsonb_text jsonb
, integer_array_ptr integer[]
, string_array_ptr text[]
, jsonb_map_ptr jsonb
, auto_integer_array integer[]
, auto_string_array text[]
, auto_jsonb_map jsonb
, auto_jsonb_map_string jsonb
, auto_jsonb_map_integer jsonb
, jsonb_object jsonb
, jsonb_array jsonb
, custom_jsonb_object jsonb
, auto_custom_jsonb_object jsonb
, custom_jsonb_object_ptr jsonb
, auto_custom_jsonb_object_ptr jsonb
, custom_jsonb_object_array jsonb
, auto_custom_jsonb_object_array jsonb
, auto_custom_jsonb_object_map jsonb
, string_value varchar(255)
, integer_value int
, varchar_value varchar(64)
, decimal_value decimal
, integer_compat_value int
, uinteger_compat_value int
, string_compat_value text
, integer_compat_value_jsonb_array jsonb
, string_compat_value_jsonb_array jsonb
, uinteger_compat_value_jsonb_array jsonb
, string_value_ptr varchar(255)
, integer_value_ptr int
, varchar_value_ptr varchar(64)
, decimal_value_ptr decimal
, uuid_value_string UUID
)`,
`DROP TABLE IF EXISTS issue_370`,
`CREATE TABLE issue_370 (
id UUID PRIMARY KEY,
name VARCHAR(25)
)`,
`CREATE EXTENSION IF NOT EXISTS "uuid-ossp"`,
`DROP TABLE IF EXISTS issue_602_organizations`,
`CREATE TABLE issue_602_organizations (
name character varying(256) NOT NULL,
created_at timestamp without time zone DEFAULT now() NOT NULL,
updated_at timestamp without time zone DEFAULT now() NOT NULL,
id uuid DEFAULT public.uuid_generate_v4() NOT NULL
)`,
`ALTER TABLE ONLY issue_602_organizations ADD CONSTRAINT issue_602_organizations_pkey PRIMARY KEY (id)`,
`DROP TABLE IF EXISTS issue_370_2`,
`CREATE TABLE issue_370_2 (
id INTEGER[3] PRIMARY KEY,
name VARCHAR(25)
)`,
`DROP TABLE IF EXISTS varchar_primary_key`,
`CREATE TABLE varchar_primary_key (
address VARCHAR(42) PRIMARY KEY NOT NULL,
name VARCHAR(25)
)`,
`DROP TABLE IF EXISTS "birthdays"`,
`CREATE TABLE "birthdays" (
"id" serial primary key,
"name" CHARACTER VARYING(50),
"born" TIMESTAMP WITH TIME ZONE,
"born_ut" INT
)`,
`DROP TABLE IF EXISTS "fibonacci"`,
`CREATE TABLE "fibonacci" (
"id" serial primary key,
"input" NUMERIC,
"output" NUMERIC
)`,
`DROP TABLE IF EXISTS "is_even"`,
`CREATE TABLE "is_even" (
"input" NUMERIC,
"is_even" BOOL
)`,
`DROP TABLE IF EXISTS "CaSe_TesT"`,
`CREATE TABLE "CaSe_TesT" (
"id" SERIAL PRIMARY KEY,
"case_test" VARCHAR(60)
)`,
`DROP TABLE IF EXISTS accounts`,
`CREATE TABLE accounts (
id serial primary key,
name varchar(255),
disabled boolean,
created_at timestamp with time zone
)`,
`DROP TABLE IF EXISTS users`,
`CREATE TABLE users (
id serial primary key,
account_id integer,
username varchar(255) UNIQUE
)`,
`DROP TABLE IF EXISTS logs`,
`CREATE TABLE logs (
id serial primary key,
message VARCHAR
)`,
}
driver := h.sess.Driver().(*sql.DB)
tx, err := driver.Begin()
if err != nil {
return err
}
for _, query := range batch {
if _, err := tx.Exec(query); err != nil {
_ = tx.Rollback()
return err
}
}
if err := tx.Commit(); err != nil {
return err
}
return nil
}
var _ testsuite.Helper = &Helper{}

View File

@ -0,0 +1,30 @@
package postgresql
import (
"database/sql"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
)
// Adapter is the internal name of the adapter.
const Adapter = "postgresql"
var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{})
// Open establishes a connection to the database server and returns a
// sqlbuilder.Session instance (which is compatible with mydb.Session).
func Open(connURL mydb.ConnectionURL) (mydb.Session, error) {
return registeredAdapter.OpenDSN(connURL)
}
// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value.
func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) {
return registeredAdapter.NewTx(sqlTx)
}
// New creates a sqlbuilder.Sesion instance by wrapping a *sql.DB value.
func New(sqlDB *sql.DB) (mydb.Session, error) {
return registeredAdapter.New(sqlDB)
}

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,20 @@
package postgresql
import (
"testing"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type RecordTests struct {
testsuite.RecordTestSuite
}
func (s *RecordTests) SetupSuite() {
s.Helper = &Helper{}
}
func TestRecord(t *testing.T) {
suite.Run(t, &RecordTests{})
}

View File

@ -0,0 +1,20 @@
package postgresql
import (
"testing"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type SQLTests struct {
testsuite.SQLTestSuite
}
func (s *SQLTests) SetupSuite() {
s.Helper = &Helper{}
}
func TestSQL(t *testing.T) {
suite.Run(t, &SQLTests{})
}

View File

@ -0,0 +1,189 @@
package postgresql
import (
"git.hexq.cn/tiglog/mydb/internal/adapter"
"git.hexq.cn/tiglog/mydb/internal/cache"
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
)
const (
adapterColumnSeparator = `.`
adapterIdentifierSeparator = `, `
adapterIdentifierQuote = `"{{.Value}}"`
adapterValueSeparator = `, `
adapterValueQuote = `'{{.}}'`
adapterAndKeyword = `AND`
adapterOrKeyword = `OR`
adapterDescKeyword = `DESC`
adapterAscKeyword = `ASC`
adapterAssignmentOperator = `=`
adapterClauseGroup = `({{.}})`
adapterClauseOperator = ` {{.}} `
adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}`
adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
adapterSortByColumnLayout = `{{.Column}} {{.Order}}`
adapterOrderByLayout = `
{{if .SortColumns}}
ORDER BY {{.SortColumns}}
{{end}}
`
adapterWhereLayout = `
{{if .Conds}}
WHERE {{.Conds}}
{{end}}
`
adapterUsingLayout = `
{{if .Columns}}
USING ({{.Columns}})
{{end}}
`
adapterJoinLayout = `
{{if .Table}}
{{ if .On }}
{{.Type}} JOIN {{.Table}}
{{.On}}
{{ else if .Using }}
{{.Type}} JOIN {{.Table}}
{{.Using}}
{{ else if .Type | eq "CROSS" }}
{{.Type}} JOIN {{.Table}}
{{else}}
NATURAL {{.Type}} JOIN {{.Table}}
{{end}}
{{end}}
`
adapterOnLayout = `
{{if .Conds}}
ON {{.Conds}}
{{end}}
`
adapterSelectLayout = `
SELECT
{{if .Distinct}}
DISTINCT
{{end}}
{{if defined .Columns}}
{{.Columns | compile}}
{{else}}
*
{{end}}
{{if defined .Table}}
FROM {{.Table | compile}}
{{end}}
{{.Joins | compile}}
{{.Where | compile}}
{{if defined .GroupBy}}
{{.GroupBy | compile}}
{{end}}
{{.OrderBy | compile}}
{{if .Limit}}
LIMIT {{.Limit}}
{{end}}
{{if .Offset}}
OFFSET {{.Offset}}
{{end}}
`
adapterDeleteLayout = `
DELETE
FROM {{.Table | compile}}
{{.Where | compile}}
`
adapterUpdateLayout = `
UPDATE
{{.Table | compile}}
SET {{.ColumnValues | compile}}
{{.Where | compile}}
`
adapterSelectCountLayout = `
SELECT
COUNT(1) AS _t
FROM {{.Table | compile}}
{{.Where | compile}}
`
adapterInsertLayout = `
INSERT INTO {{.Table | compile}}
{{if defined .Columns}}({{.Columns | compile}}){{end}}
VALUES
{{if defined .Values}}
{{.Values | compile}}
{{else}}
(default)
{{end}}
{{if defined .Returning}}
RETURNING {{.Returning | compile}}
{{end}}
`
adapterTruncateLayout = `
TRUNCATE TABLE {{.Table | compile}} RESTART IDENTITY
`
adapterDropDatabaseLayout = `
DROP DATABASE {{.Database | compile}}
`
adapterDropTableLayout = `
DROP TABLE {{.Table | compile}}
`
adapterGroupByLayout = `
{{if .GroupColumns}}
GROUP BY {{.GroupColumns}}
{{end}}
`
)
var template = &exql.Template{
ColumnSeparator: adapterColumnSeparator,
IdentifierSeparator: adapterIdentifierSeparator,
IdentifierQuote: adapterIdentifierQuote,
ValueSeparator: adapterValueSeparator,
ValueQuote: adapterValueQuote,
AndKeyword: adapterAndKeyword,
OrKeyword: adapterOrKeyword,
DescKeyword: adapterDescKeyword,
AscKeyword: adapterAscKeyword,
AssignmentOperator: adapterAssignmentOperator,
ClauseGroup: adapterClauseGroup,
ClauseOperator: adapterClauseOperator,
ColumnValue: adapterColumnValue,
TableAliasLayout: adapterTableAliasLayout,
ColumnAliasLayout: adapterColumnAliasLayout,
SortByColumnLayout: adapterSortByColumnLayout,
WhereLayout: adapterWhereLayout,
JoinLayout: adapterJoinLayout,
OnLayout: adapterOnLayout,
UsingLayout: adapterUsingLayout,
OrderByLayout: adapterOrderByLayout,
InsertLayout: adapterInsertLayout,
SelectLayout: adapterSelectLayout,
UpdateLayout: adapterUpdateLayout,
DeleteLayout: adapterDeleteLayout,
TruncateLayout: adapterTruncateLayout,
DropDatabaseLayout: adapterDropDatabaseLayout,
DropTableLayout: adapterDropTableLayout,
CountLayout: adapterSelectCountLayout,
GroupByLayout: adapterGroupByLayout,
Cache: cache.NewCache(),
ComparisonOperator: map[adapter.ComparisonOperator]string{
adapter.ComparisonOperatorRegExp: "~",
adapter.ComparisonOperatorNotRegExp: "!~",
},
}

View File

@ -0,0 +1,262 @@
package postgresql
import (
"testing"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
"github.com/stretchr/testify/assert"
)
func TestTemplateSelect(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
`SELECT * FROM "artist"`,
b.SelectFrom("artist").String(),
)
assert.Equal(
`SELECT * FROM "artist"`,
b.Select().From("artist").String(),
)
assert.Equal(
`SELECT * FROM "artist" ORDER BY "name" DESC`,
b.Select().From("artist").OrderBy("name DESC").String(),
)
assert.Equal(
`SELECT * FROM "artist" ORDER BY "name" DESC`,
b.Select().From("artist").OrderBy("-name").String(),
)
assert.Equal(
`SELECT * FROM "artist" ORDER BY "name" ASC`,
b.Select().From("artist").OrderBy("name").String(),
)
assert.Equal(
`SELECT * FROM "artist" ORDER BY "name" ASC`,
b.Select().From("artist").OrderBy("name ASC").String(),
)
assert.Equal(
`SELECT * FROM "artist" LIMIT 1 OFFSET 5`,
b.Select().From("artist").Limit(1).Offset(5).String(),
)
assert.Equal(
`SELECT * FROM "artist" LIMIT 1 OFFSET 5`,
b.Select().From("artist").Offset(5).Limit(1).String(),
)
assert.Equal(
`SELECT * FROM "artist" OFFSET 5`,
b.Select().From("artist").Limit(-1).Offset(5).String(),
)
assert.Equal(
`SELECT * FROM "artist" OFFSET 5`,
b.Select().From("artist").Offset(5).String(),
)
assert.Equal(
`SELECT "id" FROM "artist"`,
b.Select("id").From("artist").String(),
)
assert.Equal(
`SELECT "id", "name" FROM "artist"`,
b.Select("id", "name").From("artist").String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE ("name" = $1)`,
b.SelectFrom("artist").Where("name", "Haruki").String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE (name LIKE $1)`,
b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(),
)
assert.Equal(
`SELECT "id" FROM "artist" WHERE (name LIKE $1 OR name LIKE $2)`,
b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE ("id" > $1)`,
b.SelectFrom("artist").Where("id >", 2).String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE (id <= 2 AND name != $1)`,
b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE ("id" IN ($1, $2, $3, $4))`,
b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE (name IS NOT NULL)`,
b.SelectFrom("artist").Where("name IS NOT NULL").String(),
)
assert.Equal(
`SELECT * FROM "artist" AS "a", "publication" AS "p" WHERE (p.author_id = a.id) LIMIT 1`,
b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(),
)
assert.Equal(
`SELECT "id" FROM "artist" NATURAL JOIN "publication"`,
b.Select("id").From("artist").Join("publication").String(),
)
assert.Equal(
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) LIMIT 1`,
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(),
)
assert.Equal(
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE ("a"."id" = $1) LIMIT 1`,
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(),
)
assert.Equal(
`SELECT * FROM "artist" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1`,
b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(),
)
assert.Equal(
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1`,
b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(),
)
assert.Equal(
`SELECT * FROM "artist" AS "a" LEFT JOIN "publication" AS "p1" ON (p1.id = a.id) RIGHT JOIN "publication" AS "p2" ON (p2.id = a.id)`,
b.SelectFrom("artist a").
LeftJoin("publication p1").On("p1.id = a.id").
RightJoin("publication p2").On("p2.id = a.id").
String(),
)
assert.Equal(
`SELECT * FROM "artist" CROSS JOIN "publication"`,
b.SelectFrom("artist").CrossJoin("publication").String(),
)
assert.Equal(
`SELECT * FROM "artist" JOIN "publication" USING ("id")`,
b.SelectFrom("artist").Join("publication").Using("id").String(),
)
assert.Equal(
`SELECT DATE()`,
b.Select(mydb.Raw("DATE()")).String(),
)
}
func TestTemplateInsert(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
`INSERT INTO "artist" VALUES ($1, $2), ($3, $4), ($5, $6)`,
b.InsertInto("artist").
Values(10, "Ryuichi Sakamoto").
Values(11, "Alondra de la Parra").
Values(12, "Haruki Murakami").
String(),
)
assert.Equal(
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(),
)
assert.Equal(
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`,
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(),
)
assert.Equal(
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(),
)
assert.Equal(
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
b.InsertInto("artist").Values(struct {
ID int `db:"id"`
Name string `db:"name"`
}{12, "Chavela Vargas"}).String(),
)
assert.Equal(
`INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`,
b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(),
)
}
func TestTemplateUpdate(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
`UPDATE "artist" SET "name" = $1`,
b.Update("artist").Set("name", "Artist").String(),
)
assert.Equal(
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(),
)
assert.Equal(
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(),
)
assert.Equal(
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
b.Update("artist").Set(struct {
Nombre string `db:"name"`
}{"Artist"}).Where(mydb.Cond{"id <": 5}).String(),
)
assert.Equal(
`UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`,
b.Update("artist").Set(struct {
Nombre string `db:"name"`
}{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(),
)
assert.Equal(
`UPDATE "artist" SET "name" = $1 || ' ' || $2 || id, "id" = id + $3 WHERE (id > $4)`,
b.Update("artist").Set(
"name = ? || ' ' || ? || id", "Artist", "#",
"id = id + ?", 10,
).Where("id > ?", 0).String(),
)
}
func TestTemplateDelete(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
`DELETE FROM "artist" WHERE (name = $1)`,
b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").Limit(1).String(),
)
assert.Equal(
`DELETE FROM "artist" WHERE (id > 5)`,
b.DeleteFrom("artist").Where("id > 5").String(),
)
}

27
adapter/sqlite/Makefile Normal file
View File

@ -0,0 +1,27 @@
SHELL ?= bash
DB_NAME ?= sqlite3-test.db
TEST_FLAGS ?=
export DB_NAME
export TEST_FLAGS
build:
go build && go install
require-client:
@if [ -z "$$(which sqlite3)" ]; then \
echo 'Missing "sqlite3" command. Please install SQLite3 and try again.' && \
exit 1; \
fi
reset-db: require-client
rm -f $(DB_NAME)
test: reset-db
go test -v -failfast -race -timeout 20m $(TEST_FLAGS)
test-no-race:
go test -v -failfast $(TEST_FLAGS)
test-extended: test

4
adapter/sqlite/README.md Normal file
View File

@ -0,0 +1,4 @@
# SQLite adapter for upper/db
Please read the full docs, acknowledgements and examples at
[https://upper.io/v4/adapter/sqlite/](https://upper.io/v4/adapter/sqlite/).

View File

@ -0,0 +1,49 @@
package sqlite
import (
"database/sql"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
)
type collectionAdapter struct {
}
func (*collectionAdapter) Insert(col sqladapter.Collection, item interface{}) (interface{}, error) {
columnNames, columnValues, err := sqlbuilder.Map(item, nil)
if err != nil {
return nil, err
}
pKey, err := col.PrimaryKeys()
if err != nil {
return nil, err
}
q := col.SQL().InsertInto(col.Name()).
Columns(columnNames...).
Values(columnValues...)
var res sql.Result
if res, err = q.Exec(); err != nil {
return nil, err
}
if len(pKey) <= 1 {
return res.LastInsertId()
}
keyMap := mydb.Cond{}
for i := range columnNames {
for j := 0; j < len(pKey); j++ {
if pKey[j] == columnNames[i] {
keyMap[pKey[j]] = columnValues[i]
}
}
}
return keyMap, nil
}

View File

@ -0,0 +1,89 @@
package sqlite
import (
"fmt"
"net/url"
"path/filepath"
"runtime"
"strings"
)
const connectionScheme = `file`
// ConnectionURL implements a SQLite connection struct.
type ConnectionURL struct {
Database string
Options map[string]string
}
func (c ConnectionURL) String() (s string) {
vv := url.Values{}
if c.Database == "" {
return ""
}
// Did the user provided a full database path?
if !strings.HasPrefix(c.Database, "/") {
c.Database, _ = filepath.Abs(c.Database)
if runtime.GOOS == "windows" {
// Closes https://github.com/upper/db/issues/60
c.Database = "/" + strings.Replace(c.Database, `\`, `/`, -1)
}
}
// Do we have any options?
if c.Options == nil {
c.Options = map[string]string{}
}
if _, ok := c.Options["_busy_timeout"]; !ok {
c.Options["_busy_timeout"] = "10000"
}
// Converting options into URL values.
for k, v := range c.Options {
vv.Set(k, v)
}
// Building URL.
u := url.URL{
Scheme: connectionScheme,
Path: c.Database,
RawQuery: vv.Encode(),
}
return u.String()
}
// ParseURL parses s into a ConnectionURL struct.
func ParseURL(s string) (conn ConnectionURL, err error) {
var u *url.URL
if !strings.HasPrefix(s, connectionScheme+"://") {
return conn, fmt.Errorf(`Expecting file:// connection scheme.`)
}
if u, err = url.Parse(s); err != nil {
return conn, err
}
conn.Database = u.Host + u.Path
conn.Options = map[string]string{}
var vv url.Values
if vv, err = url.ParseQuery(u.RawQuery); err != nil {
return conn, err
}
for k := range vv {
conn.Options[k] = vv.Get(k)
}
if _, ok := conn.Options["cache"]; !ok {
conn.Options["cache"] = "shared"
}
return conn, err
}

View File

@ -0,0 +1,88 @@
package sqlite
import (
"path/filepath"
"testing"
)
func TestConnectionURL(t *testing.T) {
c := ConnectionURL{}
// Default connection string is only the protocol.
if c.String() != "" {
t.Fatal(`Expecting default connectiong string to be empty, got:`, c.String())
}
// Adding a database name.
c.Database = "myfilename"
absoluteName, _ := filepath.Abs(c.Database)
if c.String() != "file://"+absoluteName+"?_busy_timeout=10000" {
t.Fatal(`Test failed, got:`, c.String())
}
// Adding an option.
c.Options = map[string]string{
"cache": "foobar",
"mode": "ro",
}
if c.String() != "file://"+absoluteName+"?_busy_timeout=10000&cache=foobar&mode=ro" {
t.Fatal(`Test failed, got:`, c.String())
}
// Setting another database.
c.Database = "/another/database"
if c.String() != `file:///another/database?_busy_timeout=10000&cache=foobar&mode=ro` {
t.Fatal(`Test failed, got:`, c.String())
}
}
func TestParseConnectionURL(t *testing.T) {
var u ConnectionURL
var s string
var err error
s = "file://mydatabase.db"
if u, err = ParseURL(s); err != nil {
t.Fatal(err)
}
if u.Database != "mydatabase.db" {
t.Fatal("Failed to parse database.")
}
if u.Options["cache"] != "shared" {
t.Fatal("If not defined, cache should be shared by default.")
}
s = "file:///path/to/my/database.db?_busy_timeout=10000&mode=ro&cache=foobar"
if u, err = ParseURL(s); err != nil {
t.Fatal(err)
}
if u.Database != "/path/to/my/database.db" {
t.Fatal("Failed to parse username.")
}
if u.Options["cache"] != "foobar" {
t.Fatal("Expecting option.")
}
if u.Options["mode"] != "ro" {
t.Fatal("Expecting option.")
}
s = "http://example.org"
if _, err = ParseURL(s); err == nil {
t.Fatal("Expecting error.")
}
}

168
adapter/sqlite/database.go Normal file
View File

@ -0,0 +1,168 @@
// Package sqlite wraps the github.com/lib/sqlite SQLite driver. See
// https://github.com/upper/db/adapter/sqlite for documentation, particularities and
// usage examples.
package sqlite
import (
"context"
"database/sql"
"fmt"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
"git.hexq.cn/tiglog/mydb/internal/sqladapter/compat"
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
_ "github.com/mattn/go-sqlite3" // SQLite3 driver.
)
// database is the actual implementation of Database
type database struct {
}
func (*database) Template() *exql.Template {
return template
}
func (*database) OpenDSN(sess sqladapter.Session, dsn string) (*sql.DB, error) {
return sql.Open("sqlite3", dsn)
}
func (*database) Collections(sess sqladapter.Session) (collections []string, err error) {
q := sess.SQL().
Select("tbl_name").
From("sqlite_master").
Where("type = ?", "table")
iter := q.Iterator()
defer iter.Close()
for iter.Next() {
var tableName string
if err := iter.Scan(&tableName); err != nil {
return nil, err
}
collections = append(collections, tableName)
}
if err := iter.Err(); err != nil {
return nil, err
}
return collections, nil
}
func (*database) StatementExec(sess sqladapter.Session, ctx context.Context, query string, args ...interface{}) (res sql.Result, err error) {
if sess.Transaction() != nil {
return compat.ExecContext(sess.Driver().(*sql.Tx), ctx, query, args)
}
sqlTx, err := compat.BeginTx(sess.Driver().(*sql.DB), ctx, nil)
if err != nil {
return nil, err
}
if res, err = compat.ExecContext(sqlTx, ctx, query, args); err != nil {
_ = sqlTx.Rollback()
return nil, err
}
if err = sqlTx.Commit(); err != nil {
return nil, err
}
return res, err
}
func (*database) NewCollection() sqladapter.CollectionAdapter {
return &collectionAdapter{}
}
func (*database) LookupName(sess sqladapter.Session) (string, error) {
connURL := sess.ConnectionURL()
if connURL != nil {
connURL, err := ParseURL(connURL.String())
if err != nil {
return "", err
}
return connURL.Database, nil
}
// sess.ConnectionURL() is nil if using sqlite.New
rows, err := sess.SQL().Query(exql.RawSQL("PRAGMA database_list"))
if err != nil {
return "", err
}
dbInfo := struct {
Name string `db:"name"`
File string `db:"file"`
}{}
if err := sess.SQL().NewIterator(rows).One(&dbInfo); err != nil {
return "", err
}
if dbInfo.File != "" {
return dbInfo.File, nil
}
// dbInfo.File is empty if in memory mode
return dbInfo.Name, nil
}
func (*database) TableExists(sess sqladapter.Session, name string) error {
q := sess.SQL().
Select("tbl_name").
From("sqlite_master").
Where("type = 'table' AND tbl_name = ?", name)
iter := q.Iterator()
defer iter.Close()
if iter.Next() {
var name string
if err := iter.Scan(&name); err != nil {
return err
}
return nil
}
if err := iter.Err(); err != nil {
return err
}
return mydb.ErrCollectionDoesNotExist
}
func (*database) PrimaryKeys(sess sqladapter.Session, tableName string) ([]string, error) {
pk := make([]string, 0, 1)
stmt := exql.RawSQL(fmt.Sprintf("PRAGMA TABLE_INFO('%s')", tableName))
rows, err := sess.SQL().Query(stmt)
if err != nil {
return nil, err
}
columns := []struct {
Name string `db:"name"`
PK int `db:"pk"`
}{}
if err := sess.SQL().NewIterator(rows).All(&columns); err != nil {
return nil, err
}
maxValue := -1
for _, column := range columns {
if column.PK > 0 && column.PK > maxValue {
maxValue = column.PK
}
}
if maxValue > 0 {
for _, column := range columns {
if column.PK > 0 {
pk = append(pk, column.Name)
}
}
}
return pk, nil
}

View File

@ -0,0 +1,20 @@
package sqlite
import (
"testing"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type GenericTests struct {
testsuite.GenericTestSuite
}
func (s *GenericTests) SetupSuite() {
s.Helper = &Helper{}
}
func TestGeneric(t *testing.T) {
suite.Run(t, &GenericTests{})
}

View File

@ -0,0 +1,170 @@
package sqlite
import (
"database/sql"
"os"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
)
var settings = ConnectionURL{
Database: os.Getenv("DB_NAME"),
}
type Helper struct {
sess mydb.Session
}
func (h *Helper) Session() mydb.Session {
return h.sess
}
func (h *Helper) Adapter() string {
return "sqlite"
}
func (h *Helper) TearDown() error {
return h.sess.Close()
}
func (h *Helper) TearUp() error {
var err error
h.sess, err = Open(settings)
if err != nil {
return err
}
batch := []string{
`PRAGMA foreign_keys=OFF`,
`BEGIN TRANSACTION`,
`DROP TABLE IF EXISTS artist`,
`CREATE TABLE artist (
id integer primary key,
name varchar(60)
)`,
`DROP TABLE IF EXISTS publication`,
`CREATE TABLE publication (
id integer primary key,
title varchar(80),
author_id integer
)`,
`DROP TABLE IF EXISTS review`,
`CREATE TABLE review (
id integer primary key,
publication_id integer,
name varchar(80),
comments text,
created datetime
)`,
`DROP TABLE IF EXISTS data_types`,
`CREATE TABLE data_types (
id integer primary key,
_uint integer,
_uintptr integer,
_uint8 integer,
_uint16 int,
_uint32 int,
_uint64 int,
_int integer,
_int8 integer,
_int16 integer,
_int32 integer,
_int64 integer,
_float32 real,
_float64 real,
_byte integer,
_rune integer,
_bool integer,
_string text,
_blob blob,
_date datetime,
_nildate datetime,
_ptrdate datetime,
_defaultdate datetime default current_timestamp,
_time text
)`,
`DROP TABLE IF EXISTS stats_test`,
`CREATE TABLE stats_test (
id integer primary key,
numeric integer,
value integer
)`,
`DROP TABLE IF EXISTS composite_keys`,
`CREATE TABLE composite_keys (
code VARCHAR(255) default '',
user_id VARCHAR(255) default '',
some_val VARCHAR(255) default '',
primary key (code, user_id)
)`,
`DROP TABLE IF EXISTS "birthdays"`,
`CREATE TABLE "birthdays" (
"id" INTEGER PRIMARY KEY,
"name" VARCHAR(50) DEFAULT NULL,
"born" DATETIME DEFAULT NULL,
"born_ut" INTEGER
)`,
`DROP TABLE IF EXISTS "fibonacci"`,
`CREATE TABLE "fibonacci" (
"id" INTEGER PRIMARY KEY,
"input" INTEGER,
"output" INTEGER
)`,
`DROP TABLE IF EXISTS "is_even"`,
`CREATE TABLE "is_even" (
"input" INTEGER,
"is_even" INTEGER
)`,
`DROP TABLE IF EXISTS "CaSe_TesT"`,
`CREATE TABLE "CaSe_TesT" (
"id" INTEGER PRIMARY KEY,
"case_test" VARCHAR
)`,
`DROP TABLE IF EXISTS accounts`,
`CREATE TABLE accounts (
id integer primary key,
name varchar,
disabled integer,
created_at datetime default current_timestamp
)`,
`DROP TABLE IF EXISTS users`,
`CREATE TABLE users (
id integer primary key,
account_id integer,
username varchar UNIQUE
)`,
`DROP TABLE IF EXISTS logs`,
`CREATE TABLE logs (
id integer primary key,
message VARCHAR
)`,
`COMMIT`,
}
for _, query := range batch {
driver := h.sess.Driver().(*sql.DB)
if _, err := driver.Exec(query); err != nil {
return err
}
}
return nil
}
var _ testsuite.Helper = &Helper{}

View File

@ -0,0 +1,20 @@
package sqlite
import (
"testing"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type RecordTests struct {
testsuite.RecordTestSuite
}
func (s *RecordTests) SetupSuite() {
s.Helper = &Helper{}
}
func TestRecord(t *testing.T) {
suite.Run(t, &RecordTests{})
}

View File

@ -0,0 +1,20 @@
package sqlite
import (
"testing"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type SQLTests struct {
testsuite.SQLTestSuite
}
func (s *SQLTests) SetupSuite() {
s.Helper = &Helper{}
}
func TestSQL(t *testing.T) {
suite.Run(t, &SQLTests{})
}

30
adapter/sqlite/sqlite.go Normal file
View File

@ -0,0 +1,30 @@
package sqlite
import (
"database/sql"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
)
// Adapter is the public name of the adapter.
const Adapter = `sqlite`
var registeredAdapter = sqladapter.RegisterAdapter(Adapter, &database{})
// Open establishes a connection to the database server and returns a
// mydb.Session instance (which is compatible with mydb.Session).
func Open(connURL mydb.ConnectionURL) (mydb.Session, error) {
return registeredAdapter.OpenDSN(connURL)
}
// NewTx creates a sqlbuilder.Tx instance by wrapping a *sql.Tx value.
func NewTx(sqlTx *sql.Tx) (sqlbuilder.Tx, error) {
return registeredAdapter.NewTx(sqlTx)
}
// New creates a sqlbuilder.Sesion instance by wrapping a *sql.DB value.
func New(sqlDB *sql.DB) (mydb.Session, error) {
return registeredAdapter.New(sqlDB)
}

View File

@ -0,0 +1,55 @@
package sqlite
import (
"path/filepath"
"testing"
"database/sql"
"git.hexq.cn/tiglog/mydb/internal/testsuite"
"github.com/stretchr/testify/suite"
)
type AdapterTests struct {
testsuite.Suite
}
func (s *AdapterTests) SetupSuite() {
s.Helper = &Helper{}
}
func (s *AdapterTests) Test_Issue633_OpenSession() {
sess, err := Open(settings)
s.NoError(err)
defer sess.Close()
absoluteName, _ := filepath.Abs(settings.Database)
s.Equal(absoluteName, sess.Name())
}
func (s *AdapterTests) Test_Issue633_NewAdapterWithFile() {
sqldb, err := sql.Open("sqlite3", settings.Database)
s.NoError(err)
sess, err := New(sqldb)
s.NoError(err)
defer sess.Close()
absoluteName, _ := filepath.Abs(settings.Database)
s.Equal(absoluteName, sess.Name())
}
func (s *AdapterTests) Test_Issue633_NewAdapterWithMemory() {
sqldb, err := sql.Open("sqlite3", ":memory:")
s.NoError(err)
sess, err := New(sqldb)
s.NoError(err)
defer sess.Close()
s.Equal("main", sess.Name())
}
func TestAdapter(t *testing.T) {
suite.Run(t, &AdapterTests{})
}

187
adapter/sqlite/template.go Normal file
View File

@ -0,0 +1,187 @@
package sqlite
import (
"git.hexq.cn/tiglog/mydb/internal/cache"
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
)
const (
adapterColumnSeparator = `.`
adapterIdentifierSeparator = `, `
adapterIdentifierQuote = `"{{.Value}}"`
adapterValueSeparator = `, `
adapterValueQuote = `'{{.}}'`
adapterAndKeyword = `AND`
adapterOrKeyword = `OR`
adapterDescKeyword = `DESC`
adapterAscKeyword = `ASC`
adapterAssignmentOperator = `=`
adapterClauseGroup = `({{.}})`
adapterClauseOperator = ` {{.}} `
adapterColumnValue = `{{.Column}} {{.Operator}} {{.Value}}`
adapterTableAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
adapterColumnAliasLayout = `{{.Name}}{{if .Alias}} AS {{.Alias}}{{end}}`
adapterSortByColumnLayout = `{{.Column}} {{.Order}}`
adapterOrderByLayout = `
{{if .SortColumns}}
ORDER BY {{.SortColumns}}
{{end}}
`
adapterWhereLayout = `
{{if .Conds}}
WHERE {{.Conds}}
{{end}}
`
adapterUsingLayout = `
{{if .Columns}}
USING ({{.Columns}})
{{end}}
`
adapterJoinLayout = `
{{if .Table}}
{{ if .On }}
{{.Type}} JOIN {{.Table}}
{{.On}}
{{ else if .Using }}
{{.Type}} JOIN {{.Table}}
{{.Using}}
{{ else if .Type | eq "CROSS" }}
{{.Type}} JOIN {{.Table}}
{{else}}
NATURAL {{.Type}} JOIN {{.Table}}
{{end}}
{{end}}
`
adapterOnLayout = `
{{if .Conds}}
ON {{.Conds}}
{{end}}
`
adapterSelectLayout = `
SELECT
{{if .Distinct}}
DISTINCT
{{end}}
{{if defined .Columns}}
{{.Columns | compile}}
{{else}}
*
{{end}}
{{if defined .Table}}
FROM {{.Table | compile}}
{{end}}
{{.Joins | compile}}
{{.Where | compile}}
{{if defined .GroupBy}}
{{.GroupBy | compile}}
{{end}}
{{.OrderBy | compile}}
{{if .Limit}}
LIMIT {{.Limit}}
{{end}}
{{if .Offset}}
{{if not .Limit}}
LIMIT -1
{{end}}
OFFSET {{.Offset}}
{{end}}
`
adapterDeleteLayout = `
DELETE
FROM {{.Table | compile}}
{{.Where | compile}}
`
adapterUpdateLayout = `
UPDATE
{{.Table | compile}}
SET {{.ColumnValues | compile}}
{{.Where | compile}}
`
adapterSelectCountLayout = `
SELECT
COUNT(1) AS _t
FROM {{.Table | compile}}
{{.Where | compile}}
`
adapterInsertLayout = `
INSERT INTO {{.Table | compile}}
{{if .Columns }}({{.Columns | compile}}){{end}}
{{if defined .Values}}
VALUES
{{.Values | compile}}
{{else}}
DEFAULT VALUES
{{end}}
{{if defined .Returning}}
RETURNING {{.Returning | compile}}
{{end}}
`
adapterTruncateLayout = `
DELETE FROM {{.Table | compile}}
`
adapterDropDatabaseLayout = `
DROP DATABASE {{.Database | compile}}
`
adapterDropTableLayout = `
DROP TABLE {{.Table | compile}}
`
adapterGroupByLayout = `
{{if .GroupColumns}}
GROUP BY {{.GroupColumns}}
{{end}}
`
)
var template = &exql.Template{
ColumnSeparator: adapterColumnSeparator,
IdentifierSeparator: adapterIdentifierSeparator,
IdentifierQuote: adapterIdentifierQuote,
ValueSeparator: adapterValueSeparator,
ValueQuote: adapterValueQuote,
AndKeyword: adapterAndKeyword,
OrKeyword: adapterOrKeyword,
DescKeyword: adapterDescKeyword,
AscKeyword: adapterAscKeyword,
AssignmentOperator: adapterAssignmentOperator,
ClauseGroup: adapterClauseGroup,
ClauseOperator: adapterClauseOperator,
ColumnValue: adapterColumnValue,
TableAliasLayout: adapterTableAliasLayout,
ColumnAliasLayout: adapterColumnAliasLayout,
SortByColumnLayout: adapterSortByColumnLayout,
WhereLayout: adapterWhereLayout,
JoinLayout: adapterJoinLayout,
OnLayout: adapterOnLayout,
UsingLayout: adapterUsingLayout,
OrderByLayout: adapterOrderByLayout,
InsertLayout: adapterInsertLayout,
SelectLayout: adapterSelectLayout,
UpdateLayout: adapterUpdateLayout,
DeleteLayout: adapterDeleteLayout,
TruncateLayout: adapterTruncateLayout,
DropDatabaseLayout: adapterDropDatabaseLayout,
DropTableLayout: adapterDropTableLayout,
CountLayout: adapterSelectCountLayout,
GroupByLayout: adapterGroupByLayout,
Cache: cache.NewCache(),
}

View File

@ -0,0 +1,246 @@
package sqlite
import (
"testing"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
"github.com/stretchr/testify/assert"
)
func TestTemplateSelect(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
`SELECT * FROM "artist"`,
b.SelectFrom("artist").String(),
)
assert.Equal(
`SELECT * FROM "artist"`,
b.Select().From("artist").String(),
)
assert.Equal(
`SELECT * FROM "artist" ORDER BY "name" DESC`,
b.Select().From("artist").OrderBy("name DESC").String(),
)
assert.Equal(
`SELECT * FROM "artist" ORDER BY "name" DESC`,
b.Select().From("artist").OrderBy("-name").String(),
)
assert.Equal(
`SELECT * FROM "artist" ORDER BY "name" ASC`,
b.Select().From("artist").OrderBy("name").String(),
)
assert.Equal(
`SELECT * FROM "artist" ORDER BY "name" ASC`,
b.Select().From("artist").OrderBy("name ASC").String(),
)
assert.Equal(
`SELECT * FROM "artist" LIMIT -1 OFFSET 5`,
b.Select().From("artist").Limit(-1).Offset(5).String(),
)
assert.Equal(
`SELECT "id" FROM "artist"`,
b.Select("id").From("artist").String(),
)
assert.Equal(
`SELECT "id", "name" FROM "artist"`,
b.Select("id", "name").From("artist").String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE ("name" = $1)`,
b.SelectFrom("artist").Where("name", "Haruki").String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE (name LIKE $1)`,
b.SelectFrom("artist").Where("name LIKE ?", `%F%`).String(),
)
assert.Equal(
`SELECT "id" FROM "artist" WHERE (name LIKE $1 OR name LIKE $2)`,
b.Select("id").From("artist").Where(`name LIKE ? OR name LIKE ?`, `%Miya%`, `F%`).String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE ("id" > $1)`,
b.SelectFrom("artist").Where("id >", 2).String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE (id <= 2 AND name != $1)`,
b.SelectFrom("artist").Where("id <= 2 AND name != ?", "A").String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE ("id" IN ($1, $2, $3, $4))`,
b.SelectFrom("artist").Where("id IN", []int{1, 9, 8, 7}).String(),
)
assert.Equal(
`SELECT * FROM "artist" WHERE (name IS NOT NULL)`,
b.SelectFrom("artist").Where("name IS NOT NULL").String(),
)
assert.Equal(
`SELECT * FROM "artist" AS "a", "publication" AS "p" WHERE (p.author_id = a.id) LIMIT 1`,
b.Select().From("artist a", "publication as p").Where("p.author_id = a.id").Limit(1).String(),
)
assert.Equal(
`SELECT "id" FROM "artist" NATURAL JOIN "publication"`,
b.Select("id").From("artist").Join("publication").String(),
)
assert.Equal(
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) LIMIT 1`,
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Limit(1).String(),
)
assert.Equal(
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE ("a"."id" = $1) LIMIT 1`,
b.SelectFrom("artist a").Join("publication p").On("p.author_id = a.id").Where("a.id", 2).Limit(1).String(),
)
assert.Equal(
`SELECT * FROM "artist" JOIN "publication" AS "p" ON (p.author_id = a.id) WHERE (a.id = 2) LIMIT 1`,
b.SelectFrom("artist").Join("publication p").On("p.author_id = a.id").Where("a.id = 2").Limit(1).String(),
)
assert.Equal(
`SELECT * FROM "artist" AS "a" JOIN "publication" AS "p" ON (p.title LIKE $1 OR p.title LIKE $2) WHERE (a.id = $3) LIMIT 1`,
b.SelectFrom("artist a").Join("publication p").On("p.title LIKE ? OR p.title LIKE ?", "%Totoro%", "%Robot%").Where("a.id = ?", 2).Limit(1).String(),
)
assert.Equal(
`SELECT * FROM "artist" AS "a" LEFT JOIN "publication" AS "p1" ON (p1.id = a.id) RIGHT JOIN "publication" AS "p2" ON (p2.id = a.id)`,
b.SelectFrom("artist a").
LeftJoin("publication p1").On("p1.id = a.id").
RightJoin("publication p2").On("p2.id = a.id").
String(),
)
assert.Equal(
`SELECT * FROM "artist" CROSS JOIN "publication"`,
b.SelectFrom("artist").CrossJoin("publication").String(),
)
assert.Equal(
`SELECT * FROM "artist" JOIN "publication" USING ("id")`,
b.SelectFrom("artist").Join("publication").Using("id").String(),
)
assert.Equal(
`SELECT DATE()`,
b.Select(mydb.Raw("DATE()")).String(),
)
}
func TestTemplateInsert(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
`INSERT INTO "artist" VALUES ($1, $2), ($3, $4), ($5, $6)`,
b.InsertInto("artist").
Values(10, "Ryuichi Sakamoto").
Values(11, "Alondra de la Parra").
Values(12, "Haruki Murakami").
String(),
)
assert.Equal(
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).String(),
)
assert.Equal(
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2) RETURNING "id"`,
b.InsertInto("artist").Values(map[string]string{"id": "12", "name": "Chavela Vargas"}).Returning("id").String(),
)
assert.Equal(
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
b.InsertInto("artist").Values(map[string]interface{}{"name": "Chavela Vargas", "id": 12}).String(),
)
assert.Equal(
`INSERT INTO "artist" ("id", "name") VALUES ($1, $2)`,
b.InsertInto("artist").Values(struct {
ID int `db:"id"`
Name string `db:"name"`
}{12, "Chavela Vargas"}).String(),
)
assert.Equal(
`INSERT INTO "artist" ("name", "id") VALUES ($1, $2)`,
b.InsertInto("artist").Columns("name", "id").Values("Chavela Vargas", 12).String(),
)
}
func TestTemplateUpdate(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
`UPDATE "artist" SET "name" = $1`,
b.Update("artist").Set("name", "Artist").String(),
)
assert.Equal(
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
b.Update("artist").Set("name = ?", "Artist").Where("id <", 5).String(),
)
assert.Equal(
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
b.Update("artist").Set(map[string]string{"name": "Artist"}).Where(mydb.Cond{"id <": 5}).String(),
)
assert.Equal(
`UPDATE "artist" SET "name" = $1 WHERE ("id" < $2)`,
b.Update("artist").Set(struct {
Nombre string `db:"name"`
}{"Artist"}).Where(mydb.Cond{"id <": 5}).String(),
)
assert.Equal(
`UPDATE "artist" SET "name" = $1, "last_name" = $2 WHERE ("id" < $3)`,
b.Update("artist").Set(struct {
Nombre string `db:"name"`
}{"Artist"}).Set(map[string]string{"last_name": "Foo"}).Where(mydb.Cond{"id <": 5}).String(),
)
assert.Equal(
`UPDATE "artist" SET "name" = $1 || ' ' || $2 || id, "id" = id + $3 WHERE (id > $4)`,
b.Update("artist").Set(
"name = ? || ' ' || ? || id", "Artist", "#",
"id = id + ?", 10,
).Where("id > ?", 0).String(),
)
}
func TestTemplateDelete(t *testing.T) {
b := sqlbuilder.WithTemplate(template)
assert := assert.New(t)
assert.Equal(
`DELETE FROM "artist" WHERE (name = $1)`,
b.DeleteFrom("artist").Where("name = ?", "Chavela Vargas").String(),
)
assert.Equal(
`DELETE FROM "artist" WHERE (id > 5)`,
b.DeleteFrom("artist").Where("id > 5").String(),
)
}

468
clauses.go Normal file
View File

@ -0,0 +1,468 @@
package mydb
import (
"context"
"fmt"
)
// Selector represents a SELECT statement.
type Selector interface {
// Columns defines which columns to retrive.
//
// You should call From() after Columns() if you want to query data from an
// specific table.
//
// s.Columns("name", "last_name").From(...)
//
// It is also possible to use an alias for the column, this could be handy if
// you plan to use the alias later, use the "AS" keyword to denote an alias.
//
// s.Columns("name AS n")
//
// or the shortcut:
//
// s.Columns("name n")
//
// If you don't want the column to be escaped use the db.Raw
// function.
//
// s.Columns(db.Raw("MAX(id)"))
//
// The above statement is equivalent to:
//
// s.Columns(db.Func("MAX", "id"))
Columns(columns ...interface{}) Selector
// From represents a FROM clause and is tipically used after Columns().
//
// FROM defines from which table data is going to be retrieved
//
// s.Columns(...).From("people")
//
// It is also possible to use an alias for the table, this could be handy if
// you plan to use the alias later:
//
// s.Columns(...).From("people AS p").Where("p.name = ?", ...)
//
// Or with the shortcut:
//
// s.Columns(...).From("people p").Where("p.name = ?", ...)
From(tables ...interface{}) Selector
// Distict represents a DISTINCT clause
//
// DISTINCT is used to ask the database to return only values that are
// different.
Distinct(columns ...interface{}) Selector
// As defines an alias for a table.
As(string) Selector
// Where specifies the conditions that columns must match in order to be
// retrieved.
//
// Where accepts raw strings and fmt.Stringer to define conditions and
// interface{} to specify parameters. Be careful not to embed any parameters
// within the SQL part as that could lead to security problems. You can use
// que question mark (?) as placeholder for parameters.
//
// s.Where("name = ?", "max")
//
// s.Where("name = ? AND last_name = ?", "Mary", "Doe")
//
// s.Where("last_name IS NULL")
//
// You can also use other types of parameters besides only strings, like:
//
// s.Where("online = ? AND last_logged <= ?", true, time.Now())
//
// and Where() will transform them into strings before feeding them to the
// database.
//
// When an unknown type is provided, Where() will first try to match it with
// the Marshaler interface, then with fmt.Stringer and finally, if the
// argument does not satisfy any of those interfaces Where() will use
// fmt.Sprintf("%v", arg) to transform the type into a string.
//
// Subsequent calls to Where() will overwrite previously set conditions, if
// you want these new conditions to be appended use And() instead.
Where(conds ...interface{}) Selector
// And appends more constraints to the WHERE clause without overwriting
// conditions that have been already set.
And(conds ...interface{}) Selector
// GroupBy represents a GROUP BY statement.
//
// GROUP BY defines which columns should be used to aggregate and group
// results.
//
// s.GroupBy("country_id")
//
// GroupBy accepts more than one column:
//
// s.GroupBy("country_id", "city_id")
GroupBy(columns ...interface{}) Selector
// Having(...interface{}) Selector
// OrderBy represents a ORDER BY statement.
//
// ORDER BY is used to define which columns are going to be used to sort
// results.
//
// Use the column name to sort results in ascendent order.
//
// // "last_name" ASC
// s.OrderBy("last_name")
//
// Prefix the column name with the minus sign (-) to sort results in
// descendent order.
//
// // "last_name" DESC
// s.OrderBy("-last_name")
//
// If you would rather be very explicit, you can also use ASC and DESC.
//
// s.OrderBy("last_name ASC")
//
// s.OrderBy("last_name DESC", "name ASC")
OrderBy(columns ...interface{}) Selector
// Join represents a JOIN statement.
//
// JOIN statements are used to define external tables that the user wants to
// include as part of the result.
//
// You can use the On() method after Join() to define the conditions of the
// join.
//
// s.Join("author").On("author.id = book.author_id")
//
// If you don't specify conditions for the join, a NATURAL JOIN will be used.
//
// On() accepts the same arguments as Where()
//
// You can also use Using() after Join().
//
// s.Join("employee").Using("department_id")
Join(table ...interface{}) Selector
// FullJoin is like Join() but with FULL JOIN.
FullJoin(...interface{}) Selector
// CrossJoin is like Join() but with CROSS JOIN.
CrossJoin(...interface{}) Selector
// RightJoin is like Join() but with RIGHT JOIN.
RightJoin(...interface{}) Selector
// LeftJoin is like Join() but with LEFT JOIN.
LeftJoin(...interface{}) Selector
// Using represents the USING clause.
//
// USING is used to specifiy columns to join results.
//
// s.LeftJoin(...).Using("country_id")
Using(...interface{}) Selector
// On represents the ON clause.
//
// ON is used to define conditions on a join.
//
// s.Join(...).On("b.author_id = a.id")
On(...interface{}) Selector
// Limit represents the LIMIT parameter.
//
// LIMIT defines the maximum number of rows to return from the table. A
// negative limit cancels any previous limit settings.
//
// s.Limit(42)
Limit(int) Selector
// Offset represents the OFFSET parameter.
//
// OFFSET defines how many results are going to be skipped before starting to
// return results. A negative offset cancels any previous offset settings.
//
// s.Offset(56)
Offset(int) Selector
// Amend lets you alter the query's text just before sending it to the
// database server.
Amend(func(queryIn string) (queryOut string)) Selector
// Paginate returns a paginator that can display a paginated lists of items.
// Paginators ignore previous Offset and Limit settings. Page numbering
// starts at 1.
Paginate(uint) Paginator
// Iterator provides methods to iterate over the results returned by the
// Selector.
Iterator() Iterator
// IteratorContext provides methods to iterate over the results returned by
// the Selector.
IteratorContext(ctx context.Context) Iterator
// SQLPreparer provides methods for creating prepared statements.
SQLPreparer
// SQLGetter provides methods to compile and execute a query that returns
// results.
SQLGetter
// ResultMapper provides methods to retrieve and map results.
ResultMapper
// fmt.Stringer provides `String() string`, you can use `String()` to compile
// the `Selector` into a string.
fmt.Stringer
// Arguments returns the arguments that are prepared for this query.
Arguments() []interface{}
}
// Inserter represents an INSERT statement.
type Inserter interface {
// Columns represents the COLUMNS clause.
//
// COLUMNS defines the columns that we are going to provide values for.
//
// i.Columns("name", "last_name").Values(...)
Columns(...string) Inserter
// Values represents the VALUES clause.
//
// VALUES defines the values of the columns.
//
// i.Columns(...).Values("María", "Méndez")
//
// i.Values(map[string][string]{"name": "María"})
Values(...interface{}) Inserter
// Arguments returns the arguments that are prepared for this query.
Arguments() []interface{}
// Returning represents a RETURNING clause.
//
// RETURNING specifies which columns should be returned after INSERT.
//
// RETURNING may not be supported by all SQL databases.
Returning(columns ...string) Inserter
// Iterator provides methods to iterate over the results returned by the
// Inserter. This is only possible when using Returning().
Iterator() Iterator
// IteratorContext provides methods to iterate over the results returned by
// the Inserter. This is only possible when using Returning().
IteratorContext(ctx context.Context) Iterator
// Amend lets you alter the query's text just before sending it to the
// database server.
Amend(func(queryIn string) (queryOut string)) Inserter
// Batch provies a BatchInserter that can be used to insert many elements at
// once by issuing several calls to Values(). It accepts a size parameter
// which defines the batch size. If size is < 1, the batch size is set to 1.
Batch(size int) BatchInserter
// SQLExecer provides the Exec method.
SQLExecer
// SQLPreparer provides methods for creating prepared statements.
SQLPreparer
// SQLGetter provides methods to return query results from INSERT statements
// that support such feature (e.g.: queries with Returning).
SQLGetter
// fmt.Stringer provides `String() string`, you can use `String()` to compile
// the `Inserter` into a string.
fmt.Stringer
}
// Deleter represents a DELETE statement.
type Deleter interface {
// Where represents the WHERE clause.
//
// See Selector.Where for documentation and usage examples.
Where(...interface{}) Deleter
// And appends more constraints to the WHERE clause without overwriting
// conditions that have been already set.
And(conds ...interface{}) Deleter
// Limit represents the LIMIT clause.
//
// See Selector.Limit for documentation and usage examples.
Limit(int) Deleter
// Amend lets you alter the query's text just before sending it to the
// database server.
Amend(func(queryIn string) (queryOut string)) Deleter
// SQLPreparer provides methods for creating prepared statements.
SQLPreparer
// SQLExecer provides the Exec method.
SQLExecer
// fmt.Stringer provides `String() string`, you can use `String()` to compile
// the `Inserter` into a string.
fmt.Stringer
// Arguments returns the arguments that are prepared for this query.
Arguments() []interface{}
}
// Updater represents an UPDATE statement.
type Updater interface {
// Set represents the SET clause.
Set(...interface{}) Updater
// Where represents the WHERE clause.
//
// See Selector.Where for documentation and usage examples.
Where(...interface{}) Updater
// And appends more constraints to the WHERE clause without overwriting
// conditions that have been already set.
And(conds ...interface{}) Updater
// Limit represents the LIMIT parameter.
//
// See Selector.Limit for documentation and usage examples.
Limit(int) Updater
// SQLPreparer provides methods for creating prepared statements.
SQLPreparer
// SQLExecer provides the Exec method.
SQLExecer
// fmt.Stringer provides `String() string`, you can use `String()` to compile
// the `Inserter` into a string.
fmt.Stringer
// Arguments returns the arguments that are prepared for this query.
Arguments() []interface{}
// Amend lets you alter the query's text just before sending it to the
// database server.
Amend(func(queryIn string) (queryOut string)) Updater
}
// Paginator provides tools for splitting the results of a query into chunks
// containing a fixed number of items.
type Paginator interface {
// Page sets the page number.
Page(uint) Paginator
// Cursor defines the column that is going to be taken as basis for
// cursor-based pagination.
//
// Example:
//
// a = q.Paginate(10).Cursor("id")
// b = q.Paginate(12).Cursor("-id")
//
// You can set "" as cursorColumn to disable cursors.
Cursor(cursorColumn string) Paginator
// NextPage returns the next page according to the cursor. It expects a
// cursorValue, which is the value the cursor column has on the last item of
// the current result set (lower bound).
//
// Example:
//
// p = q.NextPage(items[len(items)-1].ID)
NextPage(cursorValue interface{}) Paginator
// PrevPage returns the previous page according to the cursor. It expects a
// cursorValue, which is the value the cursor column has on the fist item of
// the current result set (mydb bound).
//
// Example:
//
// p = q.PrevPage(items[0].ID)
PrevPage(cursorValue interface{}) Paginator
// TotalPages returns the total number of pages in the query.
TotalPages() (uint, error)
// TotalEntries returns the total number of entries in the query.
TotalEntries() (uint64, error)
// SQLPreparer provides methods for creating prepared statements.
SQLPreparer
// SQLGetter provides methods to compile and execute a query that returns
// results.
SQLGetter
// Iterator provides methods to iterate over the results returned by the
// Selector.
Iterator() Iterator
// IteratorContext provides methods to iterate over the results returned by
// the Selector.
IteratorContext(ctx context.Context) Iterator
// ResultMapper provides methods to retrieve and map results.
ResultMapper
// fmt.Stringer provides `String() string`, you can use `String()` to compile
// the `Selector` into a string.
fmt.Stringer
// Arguments returns the arguments that are prepared for this query.
Arguments() []interface{}
}
// ResultMapper defined methods for a result mapper.
type ResultMapper interface {
// All dumps all the results into the given slice, All() expects a pointer to
// slice of maps or structs.
//
// The behaviour of One() extends to each one of the results.
All(destSlice interface{}) error
// One maps the row that is in the current query cursor into the
// given interface, which can be a pointer to either a map or a
// struct.
//
// If dest is a pointer to map, each one of the columns will create a new map
// key and the values of the result will be set as values for the keys.
//
// Depending on the type of map key and value, the results columns and values
// may need to be transformed.
//
// If dest if a pointer to struct, each one of the fields will be tested for
// a `db` tag which defines the column mapping. The value of the result will
// be set as the value of the field.
One(dest interface{}) error
}
// BatchInserter provides an interface to do massive insertions in batches.
type BatchInserter interface {
// Values pushes column values to be inserted as part of the batch.
Values(...interface{}) BatchInserter
// NextResult dumps the next slice of results to dst, which can mean having
// the IDs of all inserted elements in the batch.
NextResult(dst interface{}) bool
// Done signals that no more elements are going to be added.
Done()
// Wait blocks until the whole batch is executed.
Wait() error
// Err returns the last error that happened while executing the batch (or nil
// if no error happened).
Err() error
}

45
collection.go Normal file
View File

@ -0,0 +1,45 @@
package mydb
// Collection defines methods to work with database tables or collections.
type Collection interface {
// Name returns the name of the collection.
Name() string
// Session returns the Session that was used to create the collection
// reference.
Session() Session
// Find defines a new result set.
Find(...interface{}) Result
Count() (uint64, error)
// Insert inserts a new item into the collection, the type of this item could
// be a map, a struct or pointer to either of them. If the call succeeds and
// if the collection has a primary key, Insert returns the ID of the newly
// added element as an `interface{}`. The underlying type of this ID depends
// on both the database adapter and the column storing the ID. The ID
// returned by Insert() could be passed directly to Find() to retrieve the
// newly added element.
Insert(interface{}) (InsertResult, error)
// InsertReturning is like Insert() but it takes a pointer to map or struct
// and, if the operation succeeds, updates it with data from the newly
// inserted row. If the database does not support transactions this method
// returns db.ErrUnsupported.
InsertReturning(interface{}) error
// UpdateReturning takes a pointer to a map or struct and tries to update the
// row the item is refering to. If the element is updated sucessfully,
// UpdateReturning will fetch the row and update the fields of the passed
// item. If the database does not support transactions this method returns
// db.ErrUnsupported
UpdateReturning(interface{}) error
// Exists returns true if the collection exists, false otherwise.
Exists() (bool, error)
// Truncate removes all elements on the collection.
Truncate() error
}

158
comparison.go Normal file
View File

@ -0,0 +1,158 @@
package mydb
import (
"reflect"
"time"
"git.hexq.cn/tiglog/mydb/internal/adapter"
)
// Comparison represents a relationship between values.
type Comparison struct {
*adapter.Comparison
}
// Gte is a comparison that means: is greater than or equal to value.
func Gte(value interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, value)}
}
// Lte is a comparison that means: is less than or equal to value.
func Lte(value interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, value)}
}
// Eq is a comparison that means: is equal to value.
func Eq(value interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorEqual, value)}
}
// NotEq is a comparison that means: is not equal to value.
func NotEq(value interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotEqual, value)}
}
// Gt is a comparison that means: is greater than value.
func Gt(value interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, value)}
}
// Lt is a comparison that means: is less than value.
func Lt(value interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, value)}
}
// In is a comparison that means: is any of the values.
func In(value ...interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, toInterfaceArray(value))}
}
// AnyOf is a comparison that means: is any of the values of the slice.
func AnyOf(value interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, toInterfaceArray(value))}
}
// NotIn is a comparison that means: is none of the values.
func NotIn(value ...interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotIn, toInterfaceArray(value))}
}
// NotAnyOf is a comparison that means: is none of the values of the slice.
func NotAnyOf(value interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotIn, toInterfaceArray(value))}
}
// After is a comparison that means: is after the (time.Time) value.
func After(value time.Time) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, value)}
}
// Before is a comparison that means: is before the (time.Time) value.
func Before(value time.Time) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, value)}
}
// OnOrAfter is a comparison that means: is on or after the (time.Time) value.
func OnOrAfter(value time.Time) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, value)}
}
// OnOrBefore is a comparison that means: is on or before the (time.Time) value.
func OnOrBefore(value time.Time) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, value)}
}
// Between is a comparison that means: is between lowerBound and upperBound.
func Between(lowerBound interface{}, upperBound interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorBetween, []interface{}{lowerBound, upperBound})}
}
// NotBetween is a comparison that means: is not between lowerBound and upperBound.
func NotBetween(lowerBound interface{}, upperBound interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotBetween, []interface{}{lowerBound, upperBound})}
}
// Is is a comparison that means: is equivalent to nil, true or false.
func Is(value interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, value)}
}
// IsNot is a comparison that means: is not equivalent to nil, true nor false.
func IsNot(value interface{}) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, value)}
}
// IsNull is a comparison that means: is equivalent to nil.
func IsNull() *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, nil)}
}
// IsNotNull is a comparison that means: is not equivalent to nil.
func IsNotNull() *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, nil)}
}
// Like is a comparison that checks whether the reference matches the wildcard
// value.
func Like(value string) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorLike, value)}
}
// NotLike is a comparison that checks whether the reference does not match the
// wildcard value.
func NotLike(value string) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotLike, value)}
}
// RegExp is a comparison that checks whether the reference matches the regular
// expression.
func RegExp(value string) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorRegExp, value)}
}
// NotRegExp is a comparison that checks whether the reference does not match
// the regular expression.
func NotRegExp(value string) *Comparison {
return &Comparison{adapter.NewComparisonOperator(adapter.ComparisonOperatorNotRegExp, value)}
}
// Op returns a custom comparison operator.
func Op(customOperator string, value interface{}) *Comparison {
return &Comparison{adapter.NewCustomComparisonOperator(customOperator, value)}
}
func toInterfaceArray(value interface{}) []interface{} {
rv := reflect.ValueOf(value)
switch rv.Type().Kind() {
case reflect.Ptr:
return toInterfaceArray(rv.Elem().Interface())
case reflect.Slice:
elems := rv.Len()
args := make([]interface{}, elems)
for i := 0; i < elems; i++ {
args[i] = rv.Index(i).Interface()
}
return args
}
return []interface{}{value}
}

111
comparison_test.go Normal file
View File

@ -0,0 +1,111 @@
package mydb
import (
"testing"
"time"
"git.hexq.cn/tiglog/mydb/internal/adapter"
"github.com/stretchr/testify/assert"
)
func TestComparison(t *testing.T) {
testTimeVal := time.Now()
testCases := []struct {
expects *adapter.Comparison
result *Comparison
}{
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, 1),
Gte(1),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, 22),
Lte(22),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorEqual, 6),
Eq(6),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorNotEqual, 67),
NotEq(67),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, 4),
Gt(4),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, 47),
Lt(47),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorIn, []interface{}{1, 22, 34}),
In(1, 22, 34),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThan, testTimeVal),
After(testTimeVal),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThan, testTimeVal),
Before(testTimeVal),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorGreaterThanOrEqualTo, testTimeVal),
OnOrAfter(testTimeVal),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorLessThanOrEqualTo, testTimeVal),
OnOrBefore(testTimeVal),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorBetween, []interface{}{11, 35}),
Between(11, 35),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorNotBetween, []interface{}{11, 35}),
NotBetween(11, 35),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, 178),
Is(178),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, 32),
IsNot(32),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorIs, nil),
IsNull(),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorIsNot, nil),
IsNotNull(),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorLike, "%a%"),
Like("%a%"),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorNotLike, "%z%"),
NotLike("%z%"),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorRegExp, ".*"),
RegExp(".*"),
},
{
adapter.NewComparisonOperator(adapter.ComparisonOperatorNotRegExp, ".*"),
NotRegExp(".*"),
},
{
adapter.NewCustomComparisonOperator("~", 56),
Op("~", 56),
},
}
for i := range testCases {
assert.Equal(t, testCases[i].expects, testCases[i].result.Comparison)
}
}

109
cond.go Normal file
View File

@ -0,0 +1,109 @@
package mydb
import (
"fmt"
"sort"
"git.hexq.cn/tiglog/mydb/internal/adapter"
)
// LogicalExpr represents an expression to be used in logical statements.
type LogicalExpr = adapter.LogicalExpr
// LogicalOperator represents a logical operation.
type LogicalOperator = adapter.LogicalOperator
// Cond is a map that defines conditions for a query.
//
// Each entry of the map represents a condition (a column-value relation bound
// by a comparison Operator). The comparison can be specified after the column
// name, if no comparison operator is provided the equality operator is used as
// default.
//
// Examples:
//
// // Age equals 18.
// db.Cond{"age": 18}
//
// // Age is greater than or equal to 18.
// db.Cond{"age >=": 18}
//
// // id is any of the values 1, 2 or 3.
// db.Cond{"id IN": []{1, 2, 3}}
//
// // Age is lower than 18 (MongoDB syntax)
// db.Cond{"age $lt": 18}
//
// // age > 32 and age < 35
// db.Cond{"age >": 32, "age <": 35}
type Cond map[interface{}]interface{}
// Empty returns false if there are no conditions.
func (c Cond) Empty() bool {
for range c {
return false
}
return true
}
// Constraints returns each one of the Cond map entires as a constraint.
func (c Cond) Constraints() []adapter.Constraint {
z := make([]adapter.Constraint, 0, len(c))
for _, k := range c.keys() {
z = append(z, adapter.NewConstraint(k, c[k]))
}
return z
}
// Operator returns the equality operator.
func (c Cond) Operator() LogicalOperator {
return adapter.DefaultLogicalOperator
}
func (c Cond) keys() []interface{} {
keys := make(condKeys, 0, len(c))
for k := range c {
keys = append(keys, k)
}
if len(c) > 1 {
sort.Sort(keys)
}
return keys
}
// Expressions returns all the expressions contained in the condition.
func (c Cond) Expressions() []LogicalExpr {
z := make([]LogicalExpr, 0, len(c))
for _, k := range c.keys() {
z = append(z, Cond{k: c[k]})
}
return z
}
type condKeys []interface{}
func (ck condKeys) Len() int {
return len(ck)
}
func (ck condKeys) Less(i, j int) bool {
return fmt.Sprintf("%v", ck[i]) < fmt.Sprintf("%v", ck[j])
}
func (ck condKeys) Swap(i, j int) {
ck[i], ck[j] = ck[j], ck[i]
}
func defaultJoin(in ...adapter.LogicalExpr) []adapter.LogicalExpr {
for i := range in {
cond, ok := in[i].(Cond)
if ok && !cond.Empty() {
in[i] = And(cond)
}
}
return in
}
var (
_ = LogicalExpr(Cond{})
)

69
cond_test.go Normal file
View File

@ -0,0 +1,69 @@
package mydb
import (
"testing"
)
func TestCond(t *testing.T) {
c := Cond{}
if !c.Empty() {
t.Fatal("Cond is empty.")
}
c = Cond{"id": 1}
if c.Empty() {
t.Fatal("Cond is not empty.")
}
}
func TestCondAnd(t *testing.T) {
a := And()
if !a.Empty() {
t.Fatal("Cond is empty")
}
_ = a.And(Cond{"id": 1})
if !a.Empty() {
t.Fatal("Cond is still empty")
}
a = a.And(Cond{"name": "Ana"})
if a.Empty() {
t.Fatal("Cond is not empty anymore")
}
a = a.And().And()
if a.Empty() {
t.Fatal("Cond is not empty anymore")
}
}
func TestCondOr(t *testing.T) {
a := Or()
if !a.Empty() {
t.Fatal("Cond is empty")
}
_ = a.Or(Cond{"id": 1})
if !a.Empty() {
t.Fatal("Cond is empty")
}
a = a.Or(Cond{"name": "Ana"})
if a.Empty() {
t.Fatal("Cond is not empty")
}
a = a.Or().Or()
if a.Empty() {
t.Fatal("Cond is not empty")
}
}

8
connection_url.go Normal file
View File

@ -0,0 +1,8 @@
package mydb
// ConnectionURL represents a data source name (DSN).
type ConnectionURL interface {
// String returns the connection string that is going to be passed to the
// adapter.
String() string
}

42
errors.go Normal file
View File

@ -0,0 +1,42 @@
package mydb
import (
"errors"
)
// Error messages
var (
ErrMissingAdapter = errors.New(`mydb: missing adapter`)
ErrAlreadyWithinTransaction = errors.New(`mydb: already within a transaction`)
ErrCollectionDoesNotExist = errors.New(`mydb: collection does not exist`)
ErrExpectingNonNilModel = errors.New(`mydb: expecting non nil model`)
ErrExpectingPointerToStruct = errors.New(`mydb: expecting pointer to struct`)
ErrGivingUpTryingToConnect = errors.New(`mydb: giving up trying to connect: too many clients`)
ErrInvalidCollection = errors.New(`mydb: invalid collection`)
ErrMissingCollectionName = errors.New(`mydb: missing collection name`)
ErrMissingConditions = errors.New(`mydb: missing selector conditions`)
ErrMissingConnURL = errors.New(`mydb: missing DSN`)
ErrMissingDatabaseName = errors.New(`mydb: missing database name`)
ErrNoMoreRows = errors.New(`mydb: no more rows in this result set`)
ErrNotConnected = errors.New(`mydb: not connected to a database`)
ErrNotImplemented = errors.New(`mydb: call not implemented`)
ErrQueryIsPending = errors.New(`mydb: can't execute this instruction while the result set is still open`)
ErrQueryLimitParam = errors.New(`mydb: a query can accept only one limit parameter`)
ErrQueryOffsetParam = errors.New(`mydb: a query can accept only one offset parameter`)
ErrQuerySortParam = errors.New(`mydb: a query can accept only one order-by parameter`)
ErrSockerOrHost = errors.New(`mydb: you may connect either to a UNIX socket or a TCP address, but not both`)
ErrTooManyClients = errors.New(`mydb: can't connect to database server: too many clients`)
ErrUndefined = errors.New(`mydb: value is undefined`)
ErrUnknownConditionType = errors.New(`mydb: arguments of type %T can't be used as constraints`)
ErrUnsupported = errors.New(`mydb: action is not supported by the DBMS`)
ErrUnsupportedDestination = errors.New(`mydb: unsupported destination type`)
ErrUnsupportedType = errors.New(`mydb: type does not support marshaling`)
ErrUnsupportedValue = errors.New(`mydb: value does not support unmarshaling`)
ErrNilRecord = errors.New(`mydb: invalid item (nil)`)
ErrRecordIDIsZero = errors.New(`mydb: item ID is not defined`)
ErrMissingPrimaryKeys = errors.New(`mydb: collection %q has no primary keys`)
ErrWarnSlowQuery = errors.New(`mydb: slow query`)
ErrTransactionAborted = errors.New(`mydb: transaction was aborted`)
ErrNotWithinTransaction = errors.New(`mydb: not within transaction`)
ErrNotSupportedByAdapter = errors.New(`mydb: not supported by adapter`)
)

14
errors_test.go Normal file
View File

@ -0,0 +1,14 @@
package mydb
import (
"errors"
"fmt"
"testing"
"github.com/stretchr/testify/assert"
)
func TestErrorWrap(t *testing.T) {
adapterFakeErr := fmt.Errorf("could not find item in %q: %w", "users", ErrCollectionDoesNotExist)
assert.True(t, errors.Is(adapterFakeErr, ErrCollectionDoesNotExist))
}

25
function.go Normal file
View File

@ -0,0 +1,25 @@
package mydb
import "git.hexq.cn/tiglog/mydb/internal/adapter"
// FuncExpr represents functions.
type FuncExpr = adapter.FuncExpr
// Func returns a database function expression.
//
// Examples:
//
// // MOD(29, 9)
// db.Func("MOD", 29, 9)
//
// // CONCAT("foo", "bar")
// db.Func("CONCAT", "foo", "bar")
//
// // NOW()
// db.Func("NOW")
//
// // RTRIM("Hello ")
// db.Func("RTRIM", "Hello ")
func Func(name string, args ...interface{}) *FuncExpr {
return adapter.NewFuncExpr(name, args)
}

51
function_test.go Normal file
View File

@ -0,0 +1,51 @@
package mydb
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestFunction(t *testing.T) {
{
fn := Func("MOD", 29, 9)
assert.Equal(t, "MOD", fn.Name())
assert.Equal(t, []interface{}{29, 9}, fn.Arguments())
}
{
fn := Func("HELLO")
assert.Equal(t, "HELLO", fn.Name())
assert.Equal(t, []interface{}(nil), fn.Arguments())
}
{
fn := Func("CONCAT", "a")
assert.Equal(t, "CONCAT", fn.Name())
assert.Equal(t, []interface{}{"a"}, fn.Arguments())
}
{
fn := Func("CONCAT", "a", "b", "c")
assert.Equal(t, "CONCAT", fn.Name())
assert.Equal(t, []interface{}{"a", "b", "c"}, fn.Arguments())
}
{
fn := Func("IN", []interface{}{"a", "b", "c"})
assert.Equal(t, "IN", fn.Name())
assert.Equal(t, []interface{}{[]interface{}{"a", "b", "c"}}, fn.Arguments())
}
{
fn := Func("IN", []interface{}{"a"})
assert.Equal(t, "IN", fn.Name())
assert.Equal(t, []interface{}{[]interface{}{"a"}}, fn.Arguments())
}
{
fn := Func("IN", []interface{}(nil))
assert.Equal(t, "IN", fn.Name())
assert.Equal(t, []interface{}{[]interface{}(nil)}, fn.Arguments())
}
}

33
go.mod Normal file
View File

@ -0,0 +1,33 @@
module git.hexq.cn/tiglog/mydb
go 1.20
require (
github.com/go-sql-driver/mysql v1.7.1
github.com/google/uuid v1.1.1
github.com/ipfs/go-detect-race v0.0.1
github.com/jackc/pgtype v1.14.0
github.com/jackc/pgx/v4 v4.18.1
github.com/lib/pq v1.10.9
github.com/mattn/go-sqlite3 v1.14.17
github.com/segmentio/fasthash v1.0.3
github.com/sirupsen/logrus v1.9.3
github.com/stretchr/testify v1.8.4
gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22
)
require (
github.com/davecgh/go-spew v1.1.1 // indirect
github.com/jackc/chunkreader/v2 v2.0.1 // indirect
github.com/jackc/pgconn v1.14.1 // indirect
github.com/jackc/pgio v1.0.0 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgproto3/v2 v2.3.2 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
golang.org/x/crypto v0.12.0 // indirect
golang.org/x/sys v0.12.0 // indirect
golang.org/x/text v0.13.0 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
)

226
go.sum Normal file
View File

@ -0,0 +1,226 @@
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/Masterminds/semver/v3 v3.1.1 h1:hLg3sBzpNErnxhQtUy/mmLR2I9foDujNK030IGemrRc=
github.com/Masterminds/semver/v3 v3.1.1/go.mod h1:VPu/7SZ7ePZ3QOrcuXROw5FAcLl4a0cBrbBpGY/8hQs=
github.com/cockroachdb/apd v1.1.0 h1:3LFP3629v+1aKXU5Q37mxmRxX/pIu1nijXydLShEq5I=
github.com/cockroachdb/apd v1.1.0/go.mod h1:8Sl8LxpKi29FqWXR16WEFZRNSz3SoPzUzeMeY4+DwBQ=
github.com/coreos/go-systemd v0.0.0-20190321100706-95778dfbb74e/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/coreos/go-systemd v0.0.0-20190719114852-fd7a80b32e1f/go.mod h1:F5haX7vjVVG0kc13fIWeqUViNPyEJxv/OmvnBo0Yme4=
github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/go-kit/log v0.1.0/go.mod h1:zbhenjAZHb184qTLMA9ZjW7ThYL0H2mk7Q6pNt4vbaY=
github.com/go-logfmt/logfmt v0.5.0/go.mod h1:wCYkCAKZfumFQihp8CzCvQ3paCTfi41vtzG1KdI/P7A=
github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI=
github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI=
github.com/go-stack/stack v1.8.0/go.mod h1:v0f6uXyyMGvRgIKkXu+yp6POWl0qKG85gN/melR3HDY=
github.com/gofrs/uuid v4.0.0+incompatible h1:1SD/1F5pU8p29ybwgQSwpQk+mwdRrXCYuPhW6m+TnJw=
github.com/gofrs/uuid v4.0.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=
github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI=
github.com/google/uuid v1.1.1 h1:Gkbcsh/GbpXz7lPftLA3P6TYMwjCLYm83jiFQZF/3gY=
github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/ipfs/go-detect-race v0.0.1 h1:qX/xay2W3E4Q1U7d9lNs1sU9nvguX0a7319XbyQ6cOk=
github.com/ipfs/go-detect-race v0.0.1/go.mod h1:8BNT7shDZPo99Q74BpGMK+4D8Mn4j46UU0LZ723meps=
github.com/jackc/chunkreader v1.0.0/go.mod h1:RT6O25fNZIuasFJRyZ4R/Y2BbhasbmZXF9QQ7T3kePo=
github.com/jackc/chunkreader/v2 v2.0.0/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/chunkreader/v2 v2.0.1 h1:i+RDz65UE+mmpjTfyz0MoVTnzeYxroil2G82ki7MGG8=
github.com/jackc/chunkreader/v2 v2.0.1/go.mod h1:odVSm741yZoC3dpHEUXIqA9tQRhFrgOHwnPIn9lDKlk=
github.com/jackc/pgconn v0.0.0-20190420214824-7e0022ef6ba3/go.mod h1:jkELnwuX+w9qN5YIfX0fl88Ehu4XC3keFuOJJk9pcnA=
github.com/jackc/pgconn v0.0.0-20190824142844-760dd75542eb/go.mod h1:lLjNuW/+OfW9/pnVKPazfWOgNfH2aPem8YQ7ilXGvJE=
github.com/jackc/pgconn v0.0.0-20190831204454-2fabfa3c18b7/go.mod h1:ZJKsE/KZfsUgOEh9hBm+xYTstcNHg7UPMVJqRfQxq4s=
github.com/jackc/pgconn v1.8.0/go.mod h1:1C2Pb36bGIP9QHGBYCjnyhqu7Rv3sGshaQUvmfGIB/o=
github.com/jackc/pgconn v1.9.0/go.mod h1:YctiPyvzfU11JFxoXokUOOKQXQmDMoJL9vJzHH8/2JY=
github.com/jackc/pgconn v1.9.1-0.20210724152538-d89c8390a530/go.mod h1:4z2w8XhRbP1hYxkpTuBjTS3ne3J48K83+u0zoyvg2pI=
github.com/jackc/pgconn v1.14.0/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E=
github.com/jackc/pgconn v1.14.1 h1:smbxIaZA08n6YuxEX1sDyjV/qkbtUtkH20qLkR9MUR4=
github.com/jackc/pgconn v1.14.1/go.mod h1:9mBNlny0UvkgJdCDvdVHYSjI+8tD2rnKK69Wz8ti++E=
github.com/jackc/pgio v1.0.0 h1:g12B9UwVnzGhueNavwioyEEpAmqMe1E/BN9ES+8ovkE=
github.com/jackc/pgio v1.0.0/go.mod h1:oP+2QK2wFfUWgr+gxjoBH9KGBb31Eio69xUb0w5bYf8=
github.com/jackc/pgmock v0.0.0-20190831213851-13a1b77aafa2/go.mod h1:fGZlG77KXmcq05nJLRkk0+p82V8B8Dw8KN2/V9c/OAE=
github.com/jackc/pgmock v0.0.0-20201204152224-4fe30f7445fd/go.mod h1:hrBW0Enj2AZTNpt/7Y5rr2xe/9Mn757Wtb2xeBzPv2c=
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65 h1:DadwsjnMwFjfWc9y5Wi/+Zz7xoE5ALHsRQlOctkOiHc=
github.com/jackc/pgmock v0.0.0-20210724152146-4ad1a8207f65/go.mod h1:5R2h2EEX+qri8jOWMbJCtaPWkrrNc7OHwsp2TCqp7ak=
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
github.com/jackc/pgproto3 v1.1.0/go.mod h1:eR5FA3leWg7p9aeAqi37XOTgTIbkABlvcPB3E5rlc78=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190420180111-c116219b62db/go.mod h1:bhq50y+xrl9n5mRYyCBFKkpRVTLYJVWeCc+mEAI3yXA=
github.com/jackc/pgproto3/v2 v2.0.0-alpha1.0.20190609003834-432c2951c711/go.mod h1:uH0AWtUmuShn0bcesswc4aBTWGvw0cAxIJp+6OB//Wg=
github.com/jackc/pgproto3/v2 v2.0.0-rc3/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
github.com/jackc/pgproto3/v2 v2.0.0-rc3.0.20190831210041-4c03ce451f29/go.mod h1:ryONWYqW6dqSg1Lw6vXNMXoBJhpzvWKnT95C46ckYeM=
github.com/jackc/pgproto3/v2 v2.0.6/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgproto3/v2 v2.1.1/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgproto3/v2 v2.3.2 h1:7eY55bdBeCz1F2fTzSz69QC+pG46jYq9/jtSPiJ5nn0=
github.com/jackc/pgproto3/v2 v2.3.2/go.mod h1:WfJCnwN3HIg9Ish/j3sgWXnAfK8A9Y0bwXYU5xKaEdA=
github.com/jackc/pgservicefile v0.0.0-20200714003250-2b9c44734f2b/go.mod h1:vsD4gTJCa9TptPL8sPkXrLZ+hDuNrZCnj29CQpr4X1E=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
github.com/jackc/pgtype v0.0.0-20190421001408-4ed0de4755e0/go.mod h1:hdSHsc1V01CGwFsrv11mJRHWJ6aifDLfdV3aVjFF0zg=
github.com/jackc/pgtype v0.0.0-20190824184912-ab885b375b90/go.mod h1:KcahbBH1nCMSo2DXpzsoWOAfFkdEtEJpPbVLq8eE+mc=
github.com/jackc/pgtype v0.0.0-20190828014616-a8802b16cc59/go.mod h1:MWlu30kVJrUS8lot6TQqcg7mtthZ9T0EoIBFiJcmcyw=
github.com/jackc/pgtype v1.8.1-0.20210724151600-32e20a603178/go.mod h1:C516IlIV9NKqfsMCXTdChteoXmwgUceqaLfjg2e3NlM=
github.com/jackc/pgtype v1.14.0 h1:y+xUdabmyMkJLyApYuPj38mW+aAIqCe5uuBB51rH3Vw=
github.com/jackc/pgtype v1.14.0/go.mod h1:LUMuVrfsFfdKGLw+AFFVv6KtHOFMwRgDDzBt76IqCA4=
github.com/jackc/pgx/v4 v4.0.0-20190420224344-cc3461e65d96/go.mod h1:mdxmSJJuR08CZQyj1PVQBHy9XOp5p8/SHH6a0psbY9Y=
github.com/jackc/pgx/v4 v4.0.0-20190421002000-1b8f0016e912/go.mod h1:no/Y67Jkk/9WuGR0JG/JseM9irFbnEPbuWV2EELPNuM=
github.com/jackc/pgx/v4 v4.0.0-pre1.0.20190824185557-6972a5742186/go.mod h1:X+GQnOEnf1dqHGpw7JmHqHc1NxDoalibchSk9/RWuDc=
github.com/jackc/pgx/v4 v4.12.1-0.20210724153913-640aa07df17c/go.mod h1:1QD0+tgSXP7iUjYm9C1NxKhny7lq6ee99u/z+IHFcgs=
github.com/jackc/pgx/v4 v4.18.1 h1:YP7G1KABtKpB5IHrO9vYwSrCOhs7p3uqhvhhQBptya0=
github.com/jackc/pgx/v4 v4.18.1/go.mod h1:FydWkUyadDmdNH/mHnGob881GawxeEm7TcMCzkb+qQE=
github.com/jackc/puddle v0.0.0-20190413234325-e4ced69a3a2b/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v0.0.0-20190608224051-11cab39313c9/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v1.1.3/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/jackc/puddle v1.3.0/go.mod h1:m4B5Dj62Y0fbyuIc15OsIqK0+JU8nkqQjsgx7dvjSWk=
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/konsorten/go-windows-terminal-sequences v1.0.2/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
github.com/kr/pty v1.1.8/go.mod h1:O1sed60cT9XZ5uDucP5qwvh+TE3NnUj51EiZO/lmSfw=
github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
github.com/lib/pq v1.0.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.1.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo=
github.com/lib/pq v1.10.2/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw=
github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o=
github.com/mattn/go-colorable v0.1.1/go.mod h1:FuOcm+DKB9mbwrcAfNl7/TZVBZ6rcnceauSikq3lYCQ=
github.com/mattn/go-colorable v0.1.6/go.mod h1:u6P/XSegPjTcexA+o6vUJrdnUu04hMope9wVRipJSqc=
github.com/mattn/go-isatty v0.0.5/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.7/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-isatty v0.0.12/go.mod h1:cbi8OIDigv2wuxKPP5vlRcQ1OAZbq2CE4Kysco4FUpU=
github.com/mattn/go-sqlite3 v1.14.17 h1:mCRHCLDUBXgpKAqIKsaAaAsrAlbkeomtRFKXh2L6YIM=
github.com/mattn/go-sqlite3 v1.14.17/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
github.com/rs/xid v1.2.1/go.mod h1:+uKXf+4Djp6Md1KODXJxgGQPKngRmWyn10oCKFzNHOQ=
github.com/rs/zerolog v1.13.0/go.mod h1:YbFCdg8HfsridGWAh22vktObvhZbQsZXe4/zB0OKkWU=
github.com/rs/zerolog v1.15.0/go.mod h1:xYTKnLHcpfU2225ny5qZjxnj9NvkumZYjJHlAThCjNc=
github.com/satori/go.uuid v1.2.0/go.mod h1:dA0hQrYB0VpLJoorglMZABFdXlWrHn1NEOzdhQKdks0=
github.com/segmentio/fasthash v1.0.3 h1:EI9+KE1EwvMLBWwjpRDc+fEM+prwxDYbslddQGtrmhM=
github.com/segmentio/fasthash v1.0.3/go.mod h1:waKX8l2N8yckOgmSsXJi7x1ZfdKZ4x7KRMzBtS3oedY=
github.com/shopspring/decimal v0.0.0-20180709203117-cd690d0c9e24/go.mod h1:M+9NzErvs504Cn4c5DxATwIqPbtswREoFCre64PpcG4=
github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ=
github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o=
github.com/sirupsen/logrus v1.4.1/go.mod h1:ni0Sbl8bgC9z8RoU9G6nDWqqs/fq4eDPysMBDgk/93Q=
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/objx v0.2.0/go.mod h1:qt09Ya8vawLte6SNmTgCsAVtYtaKzEcn8ATUoHMkEqE=
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA=
github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU=
github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/zenazn/goji v0.9.0/go.mod h1:7S9M489iMyHBNxwZnk9/EHS098H4/F6TATF2mIxtB1Q=
go.uber.org/atomic v1.3.2/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.4.0/go.mod h1:gD2HeocX3+yG+ygLZcrzQJaqmWj9AIm7n08wl/qW/PE=
go.uber.org/atomic v1.5.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
go.uber.org/atomic v1.6.0/go.mod h1:sABNBOSYdrvTF6hTgEIbc7YasKWGhgEQZyfxyTvoXHQ=
go.uber.org/multierr v1.1.0/go.mod h1:wR5kodmAFQ0UK8QlbwjlSNy0Z68gJhDJUG5sjR94q/0=
go.uber.org/multierr v1.3.0/go.mod h1:VgVr7evmIr6uPjLBxg28wmKNXyqE9akIJ5XnfpiKl+4=
go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU=
go.uber.org/tools v0.0.0-20190618225709-2cfd321de3ee/go.mod h1:vJERXedbb3MVM5f9Ejo0C68/HhF8uaILCdgjnY+goOA=
go.uber.org/zap v1.9.1/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.10.0/go.mod h1:vwi/ZaCAaUcBkycHslxD9B2zi4UTXhF60s6SWpuDF0Q=
go.uber.org/zap v1.13.0/go.mod h1:zwrFLgMcdUuIBviXEYEH1YKNaOBnKXsx2IPda5bBwHM=
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/crypto v0.0.0-20190411191339-88737f569e3a/go.mod h1:WFFai1msRO1wXaEeE5yQxYXgSfI8pQAWXbQop6sCtWE=
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20190820162420-60c769a6c586/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto=
golang.org/x/crypto v0.0.0-20201203163018-be400aefbc4c/go.mod h1:jdWPYTVW3xRLrWPugEBEK3UY2ZEsg3UU495nc5E+M+I=
golang.org/x/crypto v0.0.0-20210616213533-5ff15b29337e/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210711020723-a769d52b0f97/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc=
golang.org/x/crypto v0.6.0/go.mod h1:OFC/31mSvZgRz0V1QTNCzfAI1aIRzbiufJtkMIlEp58=
golang.org/x/crypto v0.12.0 h1:tFM/ta59kqch6LlvYnPa0yx5a83cL2nHflFhYKvv9Yk=
golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw=
golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc=
golang.org/x/mod v0.0.0-20190513183733-4bf6d317e70e/go.mod h1:mXi4GBBbnImb6dmsKGUJ2LatrhH/nqhxcFungHvyanc=
golang.org/x/mod v0.1.1-0.20191105210325-c90efee705ee/go.mod h1:QqPTAvyqsEbceGzBzNggFXnrqF1CaUcvgkdR5Ot7KZg=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20190813141303-74dc4d7220e7/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg=
golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
golang.org/x/sys v0.0.0-20180905080454-ebe1bf3edb33/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
golang.org/x/sys v0.0.0-20190403152447-81d4e9dc473e/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20190813064441-fde4db37ae7a/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200116001909-b77594299b42/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20200223170610-d5e6a3e2c0ae/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0 h1:CM0HF96J0hcLAwsHPJZjfdNzs0gftsLfgKt57wWHJ0o=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201117132131-f5c789dd3221/go.mod h1:Nr5EML6q2oocZ2LXRh80K7BxOlk5/8JxuGnuhpl+muw=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
golang.org/x/text v0.13.0 h1:ablQoSUd0tRdKxZewP80B+BaqeKJuVhuRxj/dkrun3k=
golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs=
golang.org/x/tools v0.0.0-20190425163242-31fd60d6bfdc/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q=
golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc=
golang.org/x/tools v0.0.0-20190823170909-c4a336ef6a2f/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo=
golang.org/x/tools v0.0.0-20200103221440-774c71fcf114/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/xerrors v0.0.0-20190410155217-1f06c39b4373/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190513163551-3ee3066db522/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI=
gopkg.in/inconshreveable/log15.v2 v2.0.0-20180818164646-67afb5ed74ec/go.mod h1:aPpfJ7XW+gOuirDoZ8gHhLh3kZ1B08FtV2bbmy7Jv3s=
gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22 h1:VpOs+IwYnYBaFnrNAeB8UUWtL3vEUnzSCL1nVjPhqrw=
gopkg.in/mgo.v2 v2.0.0-20190816093944-a6b53ec6cb22/go.mod h1:yeKp02qBN3iKW1OzL3MGk2IdtZzaj7SFntXj72NppTA=
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY=
gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ=
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg=

View File

@ -0,0 +1,60 @@
package adapter
// ComparisonOperator is the base type for comparison operators.
type ComparisonOperator uint8
// Comparison operators
const (
ComparisonOperatorNone ComparisonOperator = iota
ComparisonOperatorCustom
ComparisonOperatorEqual
ComparisonOperatorNotEqual
ComparisonOperatorLessThan
ComparisonOperatorGreaterThan
ComparisonOperatorLessThanOrEqualTo
ComparisonOperatorGreaterThanOrEqualTo
ComparisonOperatorBetween
ComparisonOperatorNotBetween
ComparisonOperatorIn
ComparisonOperatorNotIn
ComparisonOperatorIs
ComparisonOperatorIsNot
ComparisonOperatorLike
ComparisonOperatorNotLike
ComparisonOperatorRegExp
ComparisonOperatorNotRegExp
)
type Comparison struct {
t ComparisonOperator
op string
v interface{}
}
func (c *Comparison) CustomOperator() string {
return c.op
}
func (c *Comparison) Operator() ComparisonOperator {
return c.t
}
func (c *Comparison) Value() interface{} {
return c.v
}
func NewComparisonOperator(t ComparisonOperator, v interface{}) *Comparison {
return &Comparison{t: t, v: v}
}
func NewCustomComparisonOperator(op string, v interface{}) *Comparison {
return &Comparison{t: ComparisonOperatorCustom, op: op, v: v}
}

View File

@ -0,0 +1,51 @@
package adapter
// ConstraintValuer allows constraints to use specific values of their own.
type ConstraintValuer interface {
ConstraintValue() interface{}
}
// Constraint interface represents a single condition, like "a = 1". where `a`
// is the key and `1` is the value. This is an exported interface but it's
// rarely used directly, you may want to use the `db.Cond{}` map instead.
type Constraint interface {
// Key is the leftmost part of the constraint and usually contains a column
// name.
Key() interface{}
// Value if the rightmost part of the constraint and usually contains a
// column value.
Value() interface{}
}
// Constraints interface represents an array of constraints, like "a = 1, b =
// 2, c = 3".
type Constraints interface {
// Constraints returns an array of constraints.
Constraints() []Constraint
}
type constraint struct {
k interface{}
v interface{}
}
func (c constraint) Key() interface{} {
return c.k
}
func (c constraint) Value() interface{} {
if constraintValuer, ok := c.v.(ConstraintValuer); ok {
return constraintValuer.ConstraintValue()
}
return c.v
}
// NewConstraint creates a constraint.
func NewConstraint(key interface{}, value interface{}) Constraint {
return &constraint{k: key, v: value}
}
var (
_ = Constraint(&constraint{})
)

18
internal/adapter/func.go Normal file
View File

@ -0,0 +1,18 @@
package adapter
type FuncExpr struct {
name string
args []interface{}
}
func (f *FuncExpr) Arguments() []interface{} {
return f.args
}
func (f *FuncExpr) Name() string {
return f.name
}
func NewFuncExpr(name string, args []interface{}) *FuncExpr {
return &FuncExpr{name: name, args: args}
}

View File

@ -0,0 +1,100 @@
package adapter
import "git.hexq.cn/tiglog/mydb/internal/immutable"
// LogicalExpr represents a group formed by one or more sentences joined by
// an Operator like "AND" or "OR".
type LogicalExpr interface {
// Expressions returns child sentences.
Expressions() []LogicalExpr
// Operator returns the Operator that joins all the sentences in the group.
Operator() LogicalOperator
// Empty returns true if the compound has zero children, false otherwise.
Empty() bool
}
// LogicalOperator represents the operation on a compound statement.
type LogicalOperator uint
// LogicalExpr Operators.
const (
LogicalOperatorNone LogicalOperator = iota
LogicalOperatorAnd
LogicalOperatorOr
)
const DefaultLogicalOperator = LogicalOperatorAnd
type LogicalExprGroup struct {
op LogicalOperator
prev *LogicalExprGroup
fn func(*[]LogicalExpr) error
}
func NewLogicalExprGroup(op LogicalOperator, conds ...LogicalExpr) *LogicalExprGroup {
group := &LogicalExprGroup{op: op}
if len(conds) == 0 {
return group
}
return group.Frame(func(in *[]LogicalExpr) error {
*in = append(*in, conds...)
return nil
})
}
// Expressions returns each one of the conditions as a compound.
func (g *LogicalExprGroup) Expressions() []LogicalExpr {
conds, err := immutable.FastForward(g)
if err == nil {
return *(conds.(*[]LogicalExpr))
}
return nil
}
// Operator is undefined for a logical group.
func (g *LogicalExprGroup) Operator() LogicalOperator {
if g.op == LogicalOperatorNone {
panic("operator is not defined")
}
return g.op
}
// Empty returns true if this condition has no elements. False otherwise.
func (g *LogicalExprGroup) Empty() bool {
if g.fn != nil {
return false
}
if g.prev != nil {
return g.prev.Empty()
}
return true
}
func (g *LogicalExprGroup) Frame(fn func(*[]LogicalExpr) error) *LogicalExprGroup {
return &LogicalExprGroup{prev: g, op: g.op, fn: fn}
}
func (g *LogicalExprGroup) Prev() immutable.Immutable {
if g == nil {
return nil
}
return g.prev
}
func (g *LogicalExprGroup) Fn(in interface{}) error {
if g.fn == nil {
return nil
}
return g.fn(in.(*[]LogicalExpr))
}
func (g *LogicalExprGroup) Base() interface{} {
return &[]LogicalExpr{}
}
var (
_ = immutable.Immutable(&LogicalExprGroup{})
)

49
internal/adapter/raw.go Normal file
View File

@ -0,0 +1,49 @@
package adapter
// RawExpr interface represents values that can bypass SQL filters. This is an
// exported interface but it's rarely used directly, you may want to use the
// `db.Raw()` function instead.
type RawExpr struct {
value string
args *[]interface{}
}
func (r *RawExpr) Arguments() []interface{} {
if r.args != nil {
return *r.args
}
return nil
}
func (r RawExpr) Raw() string {
return r.value
}
func (r RawExpr) String() string {
return r.Raw()
}
// Expressions returns a logical expressio.n
func (r *RawExpr) Expressions() []LogicalExpr {
return []LogicalExpr{r}
}
// Operator returns the default compound operator.
func (r RawExpr) Operator() LogicalOperator {
return LogicalOperatorNone
}
// Empty return false if this struct has no value.
func (r *RawExpr) Empty() bool {
return r.value == ""
}
func NewRawExpr(value string, args []interface{}) *RawExpr {
r := &RawExpr{value: value, args: nil}
if len(args) > 0 {
r.args = &args
}
return r
}
var _ = LogicalExpr(&RawExpr{})

113
internal/cache/cache.go vendored Normal file
View File

@ -0,0 +1,113 @@
package cache
import (
"container/list"
"errors"
"sync"
)
const defaultCapacity = 128
// Cache holds a map of volatile key -> values.
type Cache struct {
keys *list.List
items map[uint64]*list.Element
mu sync.RWMutex
capacity int
}
type cacheItem struct {
key uint64
value interface{}
}
// NewCacheWithCapacity initializes a new caching space with the given
// capacity.
func NewCacheWithCapacity(capacity int) (*Cache, error) {
if capacity < 1 {
return nil, errors.New("Capacity must be greater than zero.")
}
c := &Cache{
capacity: capacity,
}
c.init()
return c, nil
}
// NewCache initializes a new caching space with default settings.
func NewCache() *Cache {
c, err := NewCacheWithCapacity(defaultCapacity)
if err != nil {
panic(err.Error()) // Should never happen as we're not providing a negative defaultCapacity.
}
return c
}
func (c *Cache) init() {
c.items = make(map[uint64]*list.Element)
c.keys = list.New()
}
// Read attempts to retrieve a cached value as a string, if the value does not
// exists returns an empty string and false.
func (c *Cache) Read(h Hashable) (string, bool) {
if v, ok := c.ReadRaw(h); ok {
if s, ok := v.(string); ok {
return s, true
}
}
return "", false
}
// ReadRaw attempts to retrieve a cached value as an interface{}, if the value
// does not exists returns nil and false.
func (c *Cache) ReadRaw(h Hashable) (interface{}, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
item, ok := c.items[h.Hash()]
if ok {
return item.Value.(*cacheItem).value, true
}
return nil, false
}
// Write stores a value in memory. If the value already exists its overwritten.
func (c *Cache) Write(h Hashable, value interface{}) {
c.mu.Lock()
defer c.mu.Unlock()
key := h.Hash()
if item, ok := c.items[key]; ok {
item.Value.(*cacheItem).value = value
c.keys.MoveToFront(item)
return
}
c.items[key] = c.keys.PushFront(&cacheItem{key, value})
for c.keys.Len() > c.capacity {
item := c.keys.Remove(c.keys.Back()).(*cacheItem)
delete(c.items, item.key)
if p, ok := item.value.(HasOnEvict); ok {
p.OnEvict()
}
}
}
// Clear generates a new memory space, leaving the old memory unreferenced, so
// it can be claimed by the garbage collector.
func (c *Cache) Clear() {
c.mu.Lock()
defer c.mu.Unlock()
for _, item := range c.items {
if p, ok := item.Value.(*cacheItem).value.(HasOnEvict); ok {
p.OnEvict()
}
}
c.init()
}

97
internal/cache/cache_test.go vendored Normal file
View File

@ -0,0 +1,97 @@
package cache
import (
"fmt"
"hash/fnv"
"testing"
)
var c *Cache
type cacheableT struct {
Name string
}
func (ct *cacheableT) Hash() uint64 {
s := fnv.New64()
s.Sum([]byte(ct.Name))
return s.Sum64()
}
var (
key = cacheableT{"foo"}
value = "bar"
)
func TestNewCache(t *testing.T) {
c = NewCache()
if c == nil {
t.Fatal("Expecting a new cache object.")
}
}
func TestCacheReadNonExistentValue(t *testing.T) {
if _, ok := c.Read(&key); ok {
t.Fatal("Expecting false.")
}
}
func TestCacheWritingValue(t *testing.T) {
c.Write(&key, value)
c.Write(&key, value)
}
func TestCacheReadExistentValue(t *testing.T) {
s, ok := c.Read(&key)
if !ok {
t.Fatal("Expecting true.")
}
if s != value {
t.Fatal("Expecting value.")
}
}
func BenchmarkNewCache(b *testing.B) {
for i := 0; i < b.N; i++ {
NewCache()
}
}
func BenchmarkNewCacheAndClear(b *testing.B) {
for i := 0; i < b.N; i++ {
c := NewCache()
c.Clear()
}
}
func BenchmarkReadNonExistentValue(b *testing.B) {
z := NewCache()
for i := 0; i < b.N; i++ {
z.Read(&key)
}
}
func BenchmarkWriteSameValue(b *testing.B) {
z := NewCache()
for i := 0; i < b.N; i++ {
z.Write(&key, value)
}
}
func BenchmarkWriteNewValue(b *testing.B) {
z := NewCache()
for i := 0; i < b.N; i++ {
key := cacheableT{fmt.Sprintf("item-%d", i)}
z.Write(&key, value)
}
}
func BenchmarkReadExistentValue(b *testing.B) {
z := NewCache()
z.Write(&key, value)
for i := 0; i < b.N; i++ {
z.Read(&key)
}
}

109
internal/cache/hash.go vendored Normal file
View File

@ -0,0 +1,109 @@
package cache
import (
"fmt"
"github.com/segmentio/fasthash/fnv1a"
)
const (
hashTypeInt uint64 = 1 << iota
hashTypeSignedInt
hashTypeBool
hashTypeString
hashTypeHashable
hashTypeNil
)
type hasher struct {
t uint64
v interface{}
}
func (h *hasher) Hash() uint64 {
return NewHash(h.t, h.v)
}
func NewHashable(t uint64, v interface{}) Hashable {
return &hasher{t: t, v: v}
}
func InitHash(t uint64) uint64 {
return fnv1a.AddUint64(fnv1a.Init64, t)
}
func NewHash(t uint64, in ...interface{}) uint64 {
return AddToHash(InitHash(t), in...)
}
func AddToHash(h uint64, in ...interface{}) uint64 {
for i := range in {
if in[i] == nil {
continue
}
h = addToHash(h, in[i])
}
return h
}
func addToHash(h uint64, in interface{}) uint64 {
switch v := in.(type) {
case uint64:
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), v)
case uint32:
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
case uint16:
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
case uint8:
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
case uint:
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
case int64:
if v < 0 {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
} else {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
}
case int32:
if v < 0 {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
} else {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
}
case int16:
if v < 0 {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
} else {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
}
case int8:
if v < 0 {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
} else {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
}
case int:
if v < 0 {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeSignedInt), uint64(-v))
} else {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeInt), uint64(v))
}
case bool:
if v {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 1)
} else {
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeBool), 2)
}
case string:
return fnv1a.AddString64(fnv1a.AddUint64(h, hashTypeString), v)
case Hashable:
if in == nil {
panic(fmt.Sprintf("could not hash nil element %T", in))
}
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeHashable), v.Hash())
case nil:
return fnv1a.AddUint64(fnv1a.AddUint64(h, hashTypeNil), 0)
default:
panic(fmt.Sprintf("unsupported value type %T", in))
}
}

13
internal/cache/interface.go vendored Normal file
View File

@ -0,0 +1,13 @@
package cache
// Hashable types must implement a method that returns a key. This key will be
// associated with a cached value.
type Hashable interface {
Hash() uint64
}
// HasOnEvict type is (optionally) implemented by cache objects to clean after
// themselves.
type HasOnEvict interface {
OnEvict()
}

View File

@ -0,0 +1,28 @@
package immutable
// Immutable represents an immutable chain that, if passed to FastForward,
// applies Fn() to every element of a chain, the first element of this chain is
// represented by Base().
type Immutable interface {
// Prev is the previous element on a chain.
Prev() Immutable
// Fn a function that is able to modify the passed element.
Fn(interface{}) error
// Base is the first element on a chain, there's no previous element before
// the Base element.
Base() interface{}
}
// FastForward applies all Fn methods in order on the given new Base.
func FastForward(curr Immutable) (interface{}, error) {
prev := curr.Prev()
if prev == nil {
return curr.Base(), nil
}
in, err := FastForward(prev)
if err != nil {
return nil, err
}
err = curr.Fn(in)
return in, err
}

23
internal/reflectx/LICENSE Normal file
View File

@ -0,0 +1,23 @@
Copyright (c) 2013, Jason Moiron
Permission is hereby granted, free of charge, to any person
obtaining a copy of this software and associated documentation
files (the "Software"), to deal in the Software without
restriction, including without limitation the rights to use,
copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the
Software is furnished to do so, subject to the following
conditions:
The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
OTHER DEALINGS IN THE SOFTWARE.

View File

@ -0,0 +1,17 @@
# reflectx
The sqlx package has special reflect needs. In particular, it needs to:
* be able to map a name to a field
* understand embedded structs
* understand mapping names to fields by a particular tag
* user specified name -> field mapping functions
These behaviors mimic the behaviors by the standard library marshallers and also the
behavior of standard Go accessors.
The first two are amply taken care of by `Reflect.Value.FieldByName`, and the third is
addressed by `Reflect.Value.FieldByNameFunc`, but these don't quite understand struct
tags in the ways that are vital to most marshalers, and they are slow.
This reflectx package extends reflect to achieve these goals.

View File

@ -0,0 +1,404 @@
// Package reflectx implements extensions to the standard reflect lib suitable
// for implementing marshaling and unmarshaling packages. The main Mapper type
// allows for Go-compatible named attribute access, including accessing embedded
// struct attributes and the ability to use functions and struct tags to
// customize field names.
package reflectx
import (
"fmt"
"reflect"
"runtime"
"strings"
"sync"
)
// A FieldInfo is a collection of metadata about a struct field.
type FieldInfo struct {
Index []int
Path string
Field reflect.StructField
Zero reflect.Value
Name string
Options map[string]string
Embedded bool
Children []*FieldInfo
Parent *FieldInfo
}
// A StructMap is an index of field metadata for a struct.
type StructMap struct {
Tree *FieldInfo
Index []*FieldInfo
Paths map[string]*FieldInfo
Names map[string]*FieldInfo
}
// GetByPath returns a *FieldInfo for a given string path.
func (f StructMap) GetByPath(path string) *FieldInfo {
return f.Paths[path]
}
// GetByTraversal returns a *FieldInfo for a given integer path. It is
// analogous to reflect.FieldByIndex.
func (f StructMap) GetByTraversal(index []int) *FieldInfo {
if len(index) == 0 {
return nil
}
tree := f.Tree
for _, i := range index {
if i >= len(tree.Children) || tree.Children[i] == nil {
return nil
}
tree = tree.Children[i]
}
return tree
}
// Mapper is a general purpose mapper of names to struct fields. A Mapper
// behaves like most marshallers, optionally obeying a field tag for name
// mapping and a function to provide a basic mapping of fields to names.
type Mapper struct {
cache map[reflect.Type]*StructMap
tagName string
tagMapFunc func(string) string
mapFunc func(string) string
mutex sync.Mutex
}
// NewMapper returns a new mapper which optionally obeys the field tag given
// by tagName. If tagName is the empty string, it is ignored.
func NewMapper(tagName string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]*StructMap),
tagName: tagName,
}
}
// NewMapperTagFunc returns a new mapper which contains a mapper for field names
// AND a mapper for tag values. This is useful for tags like json which can
// have values like "name,omitempty".
func NewMapperTagFunc(tagName string, mapFunc, tagMapFunc func(string) string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]*StructMap),
tagName: tagName,
mapFunc: mapFunc,
tagMapFunc: tagMapFunc,
}
}
// NewMapperFunc returns a new mapper which optionally obeys a field tag and
// a struct field name mapper func given by f. Tags will take precedence, but
// for any other field, the mapped name will be f(field.Name)
func NewMapperFunc(tagName string, f func(string) string) *Mapper {
return &Mapper{
cache: make(map[reflect.Type]*StructMap),
tagName: tagName,
mapFunc: f,
}
}
// TypeMap returns a mapping of field strings to int slices representing
// the traversal down the struct to reach the field.
func (m *Mapper) TypeMap(t reflect.Type) *StructMap {
m.mutex.Lock()
mapping, ok := m.cache[t]
if !ok {
mapping = getMapping(t, m.tagName, m.mapFunc, m.tagMapFunc)
m.cache[t] = mapping
}
m.mutex.Unlock()
return mapping
}
// FieldMap returns the mapper's mapping of field names to reflect values. Panics
// if v's Kind is not Struct, or v is not Indirectable to a struct kind.
func (m *Mapper) FieldMap(v reflect.Value) map[string]reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
r := map[string]reflect.Value{}
tm := m.TypeMap(v.Type())
for tagName, fi := range tm.Names {
r[tagName] = FieldByIndexes(v, fi.Index)
}
return r
}
// ValidFieldMap returns the mapper's mapping of field names to reflect valid
// field values. Panics if v's Kind is not Struct, or v is not Indirectable to
// a struct kind.
func (m *Mapper) ValidFieldMap(v reflect.Value) map[string]reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
r := map[string]reflect.Value{}
tm := m.TypeMap(v.Type())
for tagName, fi := range tm.Names {
v := ValidFieldByIndexes(v, fi.Index)
if v.IsValid() {
r[tagName] = v
}
}
return r
}
// FieldByName returns a field by the its mapped name as a reflect.Value.
// Panics if v's Kind is not Struct or v is not Indirectable to a struct Kind.
// Returns zero Value if the name is not found.
func (m *Mapper) FieldByName(v reflect.Value, name string) reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
tm := m.TypeMap(v.Type())
fi, ok := tm.Names[name]
if !ok {
return v
}
return FieldByIndexes(v, fi.Index)
}
// FieldsByName returns a slice of values corresponding to the slice of names
// for the value. Panics if v's Kind is not Struct or v is not Indirectable
// to a struct Kind. Returns zero Value for each name not found.
func (m *Mapper) FieldsByName(v reflect.Value, names []string) []reflect.Value {
v = reflect.Indirect(v)
mustBe(v, reflect.Struct)
tm := m.TypeMap(v.Type())
vals := make([]reflect.Value, 0, len(names))
for _, name := range names {
fi, ok := tm.Names[name]
if !ok {
vals = append(vals, *new(reflect.Value))
} else {
vals = append(vals, FieldByIndexes(v, fi.Index))
}
}
return vals
}
// TraversalsByName returns a slice of int slices which represent the struct
// traversals for each mapped name. Panics if t is not a struct or Indirectable
// to a struct. Returns empty int slice for each name not found.
func (m *Mapper) TraversalsByName(t reflect.Type, names []string) [][]int {
t = Deref(t)
mustBe(t, reflect.Struct)
tm := m.TypeMap(t)
r := make([][]int, 0, len(names))
for _, name := range names {
fi, ok := tm.Names[name]
if !ok {
r = append(r, []int{})
} else {
r = append(r, fi.Index)
}
}
return r
}
// FieldByIndexes returns a value for a particular struct traversal.
func FieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
for _, i := range indexes {
v = reflect.Indirect(v).Field(i)
// if this is a pointer, it's possible it is nil
if v.Kind() == reflect.Ptr && v.IsNil() {
alloc := reflect.New(Deref(v.Type()))
v.Set(alloc)
}
if v.Kind() == reflect.Map && v.IsNil() {
v.Set(reflect.MakeMap(v.Type()))
}
}
return v
}
// ValidFieldByIndexes returns a value for a particular struct traversal.
func ValidFieldByIndexes(v reflect.Value, indexes []int) reflect.Value {
for _, i := range indexes {
v = reflect.Indirect(v)
if !v.IsValid() {
return reflect.Value{}
}
v = v.Field(i)
// if this is a pointer, it's possible it is nil
if (v.Kind() == reflect.Ptr || v.Kind() == reflect.Map) && v.IsNil() {
return reflect.Value{}
}
}
return v
}
// FieldByIndexesReadOnly returns a value for a particular struct traversal,
// but is not concerned with allocating nil pointers because the value is
// going to be used for reading and not setting.
func FieldByIndexesReadOnly(v reflect.Value, indexes []int) reflect.Value {
for _, i := range indexes {
v = reflect.Indirect(v).Field(i)
}
return v
}
// Deref is Indirect for reflect.Types
func Deref(t reflect.Type) reflect.Type {
if t.Kind() == reflect.Ptr {
t = t.Elem()
}
return t
}
// -- helpers & utilities --
type kinder interface {
Kind() reflect.Kind
}
// mustBe checks a value against a kind, panicing with a reflect.ValueError
// if the kind isn't that which is required.
func mustBe(v kinder, expected reflect.Kind) {
k := v.Kind()
if k != expected {
panic(&reflect.ValueError{Method: methodName(), Kind: k})
}
}
// methodName is returns the caller of the function calling methodName
func methodName() string {
pc, _, _, _ := runtime.Caller(2)
f := runtime.FuncForPC(pc)
if f == nil {
return "unknown method"
}
return f.Name()
}
type typeQueue struct {
t reflect.Type
fi *FieldInfo
pp string // Parent path
}
// A copying append that creates a new slice each time.
func apnd(is []int, i int) []int {
x := make([]int, len(is)+1)
copy(x, is)
x[len(x)-1] = i
return x
}
// getMapping returns a mapping for the t type, using the tagName, mapFunc and
// tagMapFunc to determine the canonical names of fields.
func getMapping(t reflect.Type, tagName string, mapFunc, tagMapFunc func(string) string) *StructMap {
m := []*FieldInfo{}
root := &FieldInfo{}
queue := []typeQueue{}
queue = append(queue, typeQueue{Deref(t), root, ""})
for len(queue) != 0 {
// pop the first item off of the queue
tq := queue[0]
queue = queue[1:]
nChildren := 0
if tq.t.Kind() == reflect.Struct {
nChildren = tq.t.NumField()
}
tq.fi.Children = make([]*FieldInfo, nChildren)
// iterate through all of its fields
for fieldPos := 0; fieldPos < nChildren; fieldPos++ {
f := tq.t.Field(fieldPos)
fi := FieldInfo{}
fi.Field = f
fi.Zero = reflect.New(f.Type).Elem()
fi.Options = map[string]string{}
var tag, name string
if tagName != "" && strings.Contains(string(f.Tag), tagName+":") {
tag = f.Tag.Get(tagName)
name = tag
} else {
if mapFunc != nil {
name = mapFunc(f.Name)
}
}
parts := strings.Split(name, ",")
if len(parts) > 1 {
name = parts[0]
for _, opt := range parts[1:] {
kv := strings.Split(opt, "=")
if len(kv) > 1 {
fi.Options[kv[0]] = kv[1]
} else {
fi.Options[kv[0]] = ""
}
}
}
if tagMapFunc != nil {
tag = tagMapFunc(tag)
}
fi.Name = name
if tq.pp == "" || (tq.pp == "" && tag == "") {
fi.Path = fi.Name
} else {
fi.Path = fmt.Sprintf("%s.%s", tq.pp, fi.Name)
}
// if the name is "-", disabled via a tag, skip it
if name == "-" {
continue
}
// skip unexported fields
if len(f.PkgPath) != 0 && !f.Anonymous {
continue
}
// bfs search of anonymous embedded structs
if f.Anonymous {
pp := tq.pp
if tag != "" {
pp = fi.Path
}
fi.Embedded = true
fi.Index = apnd(tq.fi.Index, fieldPos)
nChildren := 0
ft := Deref(f.Type)
if ft.Kind() == reflect.Struct {
nChildren = ft.NumField()
}
fi.Children = make([]*FieldInfo, nChildren)
queue = append(queue, typeQueue{Deref(f.Type), &fi, pp})
} else if fi.Zero.Kind() == reflect.Struct || (fi.Zero.Kind() == reflect.Ptr && fi.Zero.Type().Elem().Kind() == reflect.Struct) {
fi.Index = apnd(tq.fi.Index, fieldPos)
fi.Children = make([]*FieldInfo, Deref(f.Type).NumField())
queue = append(queue, typeQueue{Deref(f.Type), &fi, fi.Path})
}
fi.Index = apnd(tq.fi.Index, fieldPos)
fi.Parent = tq.fi
tq.fi.Children[fieldPos] = &fi
m = append(m, &fi)
}
}
flds := &StructMap{Index: m, Tree: root, Paths: map[string]*FieldInfo{}, Names: map[string]*FieldInfo{}}
for _, fi := range flds.Index {
flds.Paths[fi.Path] = fi
if fi.Name != "" && !fi.Embedded {
flds.Names[fi.Path] = fi
}
}
return flds
}

View File

@ -0,0 +1,587 @@
package reflectx
import (
"reflect"
"strings"
"testing"
)
func ival(v reflect.Value) int {
return v.Interface().(int)
}
func TestBasic(t *testing.T) {
type Foo struct {
A int
B int
C int
}
f := Foo{1, 2, 3}
fv := reflect.ValueOf(f)
m := NewMapperFunc("", func(s string) string { return s })
v := m.FieldByName(fv, "A")
if ival(v) != f.A {
t.Errorf("Expecting %d, got %d", ival(v), f.A)
}
v = m.FieldByName(fv, "B")
if ival(v) != f.B {
t.Errorf("Expecting %d, got %d", f.B, ival(v))
}
v = m.FieldByName(fv, "C")
if ival(v) != f.C {
t.Errorf("Expecting %d, got %d", f.C, ival(v))
}
}
func TestBasicEmbedded(t *testing.T) {
type Foo struct {
A int
}
type Bar struct {
Foo // `db:""` is implied for an embedded struct
B int
C int `db:"-"`
}
type Baz struct {
A int
Bar `db:"Bar"`
}
m := NewMapperFunc("db", func(s string) string { return s })
z := Baz{}
z.A = 1
z.B = 2
z.C = 4
z.Bar.Foo.A = 3
zv := reflect.ValueOf(z)
fields := m.TypeMap(reflect.TypeOf(z))
if len(fields.Index) != 5 {
t.Errorf("Expecting 5 fields")
}
// for _, fi := range fields.Index {
// log.Println(fi)
// }
v := m.FieldByName(zv, "A")
if ival(v) != z.A {
t.Errorf("Expecting %d, got %d", z.A, ival(v))
}
v = m.FieldByName(zv, "Bar.B")
if ival(v) != z.Bar.B {
t.Errorf("Expecting %d, got %d", z.Bar.B, ival(v))
}
v = m.FieldByName(zv, "Bar.A")
if ival(v) != z.Bar.Foo.A {
t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v))
}
v = m.FieldByName(zv, "Bar.C")
if _, ok := v.Interface().(int); ok {
t.Errorf("Expecting Bar.C to not exist")
}
fi := fields.GetByPath("Bar.C")
if fi != nil {
t.Errorf("Bar.C should not exist")
}
}
func TestEmbeddedSimple(t *testing.T) {
type UUID [16]byte
type MyID struct {
UUID
}
type Item struct {
ID MyID
}
z := Item{}
m := NewMapper("db")
m.TypeMap(reflect.TypeOf(z))
}
func TestBasicEmbeddedWithTags(t *testing.T) {
type Foo struct {
A int `db:"a"`
}
type Bar struct {
Foo // `db:""` is implied for an embedded struct
B int `db:"b"`
}
type Baz struct {
A int `db:"a"`
Bar // `db:""` is implied for an embedded struct
}
m := NewMapper("db")
z := Baz{}
z.A = 1
z.B = 2
z.Bar.Foo.A = 3
zv := reflect.ValueOf(z)
fields := m.TypeMap(reflect.TypeOf(z))
if len(fields.Index) != 5 {
t.Errorf("Expecting 5 fields")
}
// for _, fi := range fields.index {
// log.Println(fi)
// }
v := m.FieldByName(zv, "a")
if ival(v) != z.Bar.Foo.A { // the dominant field
t.Errorf("Expecting %d, got %d", z.Bar.Foo.A, ival(v))
}
v = m.FieldByName(zv, "b")
if ival(v) != z.B {
t.Errorf("Expecting %d, got %d", z.B, ival(v))
}
}
func TestFlatTags(t *testing.T) {
m := NewMapper("db")
type Asset struct {
Title string `db:"title"`
}
type Post struct {
Author string `db:"author,required"`
Asset Asset `db:""`
}
// Post columns: (author title)
post := Post{Author: "Joe", Asset: Asset{Title: "Hello"}}
pv := reflect.ValueOf(post)
v := m.FieldByName(pv, "author")
if v.Interface().(string) != post.Author {
t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
}
v = m.FieldByName(pv, "title")
if v.Interface().(string) != post.Asset.Title {
t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string))
}
}
func TestNestedStruct(t *testing.T) {
m := NewMapper("db")
type Details struct {
Active bool `db:"active"`
}
type Asset struct {
Title string `db:"title"`
Details Details `db:"details"`
}
type Post struct {
Author string `db:"author,required"`
Asset `db:"asset"`
}
// Post columns: (author asset.title asset.details.active)
post := Post{
Author: "Joe",
Asset: Asset{Title: "Hello", Details: Details{Active: true}},
}
pv := reflect.ValueOf(post)
v := m.FieldByName(pv, "author")
if v.Interface().(string) != post.Author {
t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
}
v = m.FieldByName(pv, "title")
if _, ok := v.Interface().(string); ok {
t.Errorf("Expecting field to not exist")
}
v = m.FieldByName(pv, "asset.title")
if v.Interface().(string) != post.Asset.Title {
t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string))
}
v = m.FieldByName(pv, "asset.details.active")
if v.Interface().(bool) != post.Asset.Details.Active {
t.Errorf("Expecting %v, got %v", post.Asset.Details.Active, v.Interface().(bool))
}
}
func TestInlineStruct(t *testing.T) {
m := NewMapperTagFunc("db", strings.ToLower, nil)
type Employee struct {
Name string
ID int
}
type Boss Employee
type person struct {
Employee `db:"employee"`
Boss `db:"boss"`
}
// employees columns: (employee.name employee.id boss.name boss.id)
em := person{Employee: Employee{Name: "Joe", ID: 2}, Boss: Boss{Name: "Dick", ID: 1}}
ev := reflect.ValueOf(em)
fields := m.TypeMap(reflect.TypeOf(em))
if len(fields.Index) != 6 {
t.Errorf("Expecting 6 fields")
}
v := m.FieldByName(ev, "employee.name")
if v.Interface().(string) != em.Employee.Name {
t.Errorf("Expecting %s, got %s", em.Employee.Name, v.Interface().(string))
}
v = m.FieldByName(ev, "boss.id")
if ival(v) != em.Boss.ID {
t.Errorf("Expecting %v, got %v", em.Boss.ID, ival(v))
}
}
func TestFieldsEmbedded(t *testing.T) {
m := NewMapper("db")
type Person struct {
Name string `db:"name"`
}
type Place struct {
Name string `db:"name"`
}
type Article struct {
Title string `db:"title"`
}
type PP struct {
Person `db:"person,required"`
Place `db:",someflag"`
Article `db:",required"`
}
// PP columns: (person.name name title)
pp := PP{}
pp.Person.Name = "Peter"
pp.Place.Name = "Toronto"
pp.Article.Title = "Best city ever"
fields := m.TypeMap(reflect.TypeOf(pp))
// for i, f := range fields {
// log.Println(i, f)
// }
ppv := reflect.ValueOf(pp)
v := m.FieldByName(ppv, "person.name")
if v.Interface().(string) != pp.Person.Name {
t.Errorf("Expecting %s, got %s", pp.Person.Name, v.Interface().(string))
}
v = m.FieldByName(ppv, "name")
if v.Interface().(string) != pp.Place.Name {
t.Errorf("Expecting %s, got %s", pp.Place.Name, v.Interface().(string))
}
v = m.FieldByName(ppv, "title")
if v.Interface().(string) != pp.Article.Title {
t.Errorf("Expecting %s, got %s", pp.Article.Title, v.Interface().(string))
}
fi := fields.GetByPath("person")
if _, ok := fi.Options["required"]; !ok {
t.Errorf("Expecting required option to be set")
}
if !fi.Embedded {
t.Errorf("Expecting field to be embedded")
}
if len(fi.Index) != 1 || fi.Index[0] != 0 {
t.Errorf("Expecting index to be [0]")
}
fi = fields.GetByPath("person.name")
if fi == nil {
t.Errorf("Expecting person.name to exist")
}
if fi.Path != "person.name" {
t.Errorf("Expecting %s, got %s", "person.name", fi.Path)
}
fi = fields.GetByTraversal([]int{1, 0})
if fi == nil {
t.Errorf("Expecting traveral to exist")
}
if fi.Path != "name" {
t.Errorf("Expecting %s, got %s", "name", fi.Path)
}
fi = fields.GetByTraversal([]int{2})
if fi == nil {
t.Errorf("Expecting traversal to exist")
}
if _, ok := fi.Options["required"]; !ok {
t.Errorf("Expecting required option to be set")
}
trs := m.TraversalsByName(reflect.TypeOf(pp), []string{"person.name", "name", "title"})
if !reflect.DeepEqual(trs, [][]int{{0, 0}, {1, 0}, {2, 0}}) {
t.Errorf("Expecting traversal: %v", trs)
}
}
func TestPtrFields(t *testing.T) {
m := NewMapperTagFunc("db", strings.ToLower, nil)
type Asset struct {
Title string
}
type Post struct {
*Asset `db:"asset"`
Author string
}
post := &Post{Author: "Joe", Asset: &Asset{Title: "Hiyo"}}
pv := reflect.ValueOf(post)
fields := m.TypeMap(reflect.TypeOf(post))
if len(fields.Index) != 3 {
t.Errorf("Expecting 3 fields")
}
v := m.FieldByName(pv, "asset.title")
if v.Interface().(string) != post.Asset.Title {
t.Errorf("Expecting %s, got %s", post.Asset.Title, v.Interface().(string))
}
v = m.FieldByName(pv, "author")
if v.Interface().(string) != post.Author {
t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
}
}
func TestNamedPtrFields(t *testing.T) {
m := NewMapperTagFunc("db", strings.ToLower, nil)
type User struct {
Name string
}
type Asset struct {
Title string
Owner *User `db:"owner"`
}
type Post struct {
Author string
Asset1 *Asset `db:"asset1"`
Asset2 *Asset `db:"asset2"`
}
post := &Post{Author: "Joe", Asset1: &Asset{Title: "Hiyo", Owner: &User{"Username"}}} // Let Asset2 be nil
pv := reflect.ValueOf(post)
fields := m.TypeMap(reflect.TypeOf(post))
if len(fields.Index) != 9 {
t.Errorf("Expecting 9 fields")
}
v := m.FieldByName(pv, "asset1.title")
if v.Interface().(string) != post.Asset1.Title {
t.Errorf("Expecting %s, got %s", post.Asset1.Title, v.Interface().(string))
}
v = m.FieldByName(pv, "asset1.owner.name")
if v.Interface().(string) != post.Asset1.Owner.Name {
t.Errorf("Expecting %s, got %s", post.Asset1.Owner.Name, v.Interface().(string))
}
v = m.FieldByName(pv, "asset2.title")
if v.Interface().(string) != post.Asset2.Title {
t.Errorf("Expecting %s, got %s", post.Asset2.Title, v.Interface().(string))
}
v = m.FieldByName(pv, "asset2.owner.name")
if v.Interface().(string) != post.Asset2.Owner.Name {
t.Errorf("Expecting %s, got %s", post.Asset2.Owner.Name, v.Interface().(string))
}
v = m.FieldByName(pv, "author")
if v.Interface().(string) != post.Author {
t.Errorf("Expecting %s, got %s", post.Author, v.Interface().(string))
}
}
func TestFieldMap(t *testing.T) {
type Foo struct {
A int
B int
C int
}
f := Foo{1, 2, 3}
m := NewMapperFunc("db", strings.ToLower)
fm := m.FieldMap(reflect.ValueOf(f))
if len(fm) != 3 {
t.Errorf("Expecting %d keys, got %d", 3, len(fm))
}
if fm["a"].Interface().(int) != 1 {
t.Errorf("Expecting %d, got %d", 1, ival(fm["a"]))
}
if fm["b"].Interface().(int) != 2 {
t.Errorf("Expecting %d, got %d", 2, ival(fm["b"]))
}
if fm["c"].Interface().(int) != 3 {
t.Errorf("Expecting %d, got %d", 3, ival(fm["c"]))
}
}
func TestTagNameMapping(t *testing.T) {
type Strategy struct {
StrategyID string `protobuf:"bytes,1,opt,name=strategy_id" json:"strategy_id,omitempty"`
StrategyName string
}
m := NewMapperTagFunc("json", strings.ToUpper, func(value string) string {
if strings.Contains(value, ",") {
return strings.Split(value, ",")[0]
}
return value
})
strategy := Strategy{"1", "Alpah"}
mapping := m.TypeMap(reflect.TypeOf(strategy))
for _, key := range []string{"strategy_id", "STRATEGYNAME"} {
if fi := mapping.GetByPath(key); fi == nil {
t.Errorf("Expecting to find key %s in mapping but did not.", key)
}
}
}
func TestMapping(t *testing.T) {
type Person struct {
ID int
Name string
WearsGlasses bool `db:"wears_glasses"`
}
m := NewMapperFunc("db", strings.ToLower)
p := Person{1, "Jason", true}
mapping := m.TypeMap(reflect.TypeOf(p))
for _, key := range []string{"id", "name", "wears_glasses"} {
if fi := mapping.GetByPath(key); fi == nil {
t.Errorf("Expecting to find key %s in mapping but did not.", key)
}
}
type SportsPerson struct {
Weight int
Age int
Person
}
s := SportsPerson{Weight: 100, Age: 30, Person: p}
mapping = m.TypeMap(reflect.TypeOf(s))
for _, key := range []string{"id", "name", "wears_glasses", "weight", "age"} {
if fi := mapping.GetByPath(key); fi == nil {
t.Errorf("Expecting to find key %s in mapping but did not.", key)
}
}
type RugbyPlayer struct {
Position int
IsIntense bool `db:"is_intense"`
IsAllBlack bool `db:"-"`
SportsPerson
}
r := RugbyPlayer{12, true, false, s}
mapping = m.TypeMap(reflect.TypeOf(r))
for _, key := range []string{"id", "name", "wears_glasses", "weight", "age", "position", "is_intense"} {
if fi := mapping.GetByPath(key); fi == nil {
t.Errorf("Expecting to find key %s in mapping but did not.", key)
}
}
if fi := mapping.GetByPath("isallblack"); fi != nil {
t.Errorf("Expecting to ignore `IsAllBlack` field")
}
}
type E1 struct {
A int
}
type E2 struct {
E1
B int
}
type E3 struct {
E2
C int
}
type E4 struct {
E3
D int
}
func BenchmarkFieldNameL1(b *testing.B) {
e4 := E4{D: 1}
for i := 0; i < b.N; i++ {
v := reflect.ValueOf(e4)
f := v.FieldByName("D")
if f.Interface().(int) != 1 {
b.Fatal("Wrong value.")
}
}
}
func BenchmarkFieldNameL4(b *testing.B) {
e4 := E4{}
e4.A = 1
for i := 0; i < b.N; i++ {
v := reflect.ValueOf(e4)
f := v.FieldByName("A")
if f.Interface().(int) != 1 {
b.Fatal("Wrong value.")
}
}
}
func BenchmarkFieldPosL1(b *testing.B) {
e4 := E4{D: 1}
for i := 0; i < b.N; i++ {
v := reflect.ValueOf(e4)
f := v.Field(1)
if f.Interface().(int) != 1 {
b.Fatal("Wrong value.")
}
}
}
func BenchmarkFieldPosL4(b *testing.B) {
e4 := E4{}
e4.A = 1
for i := 0; i < b.N; i++ {
v := reflect.ValueOf(e4)
f := v.Field(0)
f = f.Field(0)
f = f.Field(0)
f = f.Field(0)
if f.Interface().(int) != 1 {
b.Fatal("Wrong value.")
}
}
}
func BenchmarkFieldByIndexL4(b *testing.B) {
e4 := E4{}
e4.A = 1
idx := []int{0, 0, 0, 0}
for i := 0; i < b.N; i++ {
v := reflect.ValueOf(e4)
f := FieldByIndexes(v, idx)
if f.Interface().(int) != 1 {
b.Fatal("Wrong value.")
}
}
}

View File

@ -0,0 +1,369 @@
package sqladapter
import (
"fmt"
"reflect"
"git.hexq.cn/tiglog/mydb"
"git.hexq.cn/tiglog/mydb/internal/sqladapter/exql"
"git.hexq.cn/tiglog/mydb/internal/sqlbuilder"
)
// CollectionAdapter defines methods to be implemented by SQL adapters.
type CollectionAdapter interface {
// Insert prepares and executes an INSERT statament. When the item is
// succefully added, Insert returns a unique identifier of the newly added
// element (or nil if the unique identifier couldn't be determined).
Insert(Collection, interface{}) (interface{}, error)
}
// Collection satisfies mydb.Collection.
type Collection interface {
// Insert inserts a new item into the collection.
Insert(interface{}) (mydb.InsertResult, error)
// Name returns the name of the collection.
Name() string
// Session returns the mydb.Session the collection belongs to.
Session() mydb.Session
// Exists returns true if the collection exists, false otherwise.
Exists() (bool, error)
// Find defined a new result set.
Find(conds ...interface{}) mydb.Result
Count() (uint64, error)
// Truncate removes all elements on the collection and resets the
// collection's IDs.
Truncate() error
// InsertReturning inserts a new item into the collection and refreshes the
// item with actual data from the database. This is useful to get automatic
// values, such as timestamps, or IDs.
InsertReturning(item interface{}) error
// UpdateReturning updates a record from the collection and refreshes the item
// with actual data from the database. This is useful to get automatic
// values, such as timestamps, or IDs.
UpdateReturning(item interface{}) error
// PrimaryKeys returns the names of all primary keys in the table.
PrimaryKeys() ([]string, error)
// SQLBuilder returns a mydb.SQL instance.
SQL() mydb.SQL
}
type finder interface {
Find(Collection, *Result, ...interface{}) mydb.Result
}
type condsFilter interface {
FilterConds(...interface{}) []interface{}
}
// collection is the implementation of Collection.
type collection struct {
name string
adapter CollectionAdapter
}
type collectionWithSession struct {
*collection
session Session
}
func newCollection(name string, adapter CollectionAdapter) *collection {
if adapter == nil {
panic("mydb: nil adapter")
}
return &collection{
name: name,
adapter: adapter,
}
}
func (c *collectionWithSession) SQL() mydb.SQL {
return c.session.SQL()
}
func (c *collectionWithSession) Session() mydb.Session {
return c.session
}
func (c *collectionWithSession) Name() string {
return c.name
}
func (c *collectionWithSession) Count() (uint64, error) {
return c.Find().Count()
}
func (c *collectionWithSession) Insert(item interface{}) (mydb.InsertResult, error) {
id, err := c.adapter.Insert(c, item)
if err != nil {
return nil, err
}
return mydb.NewInsertResult(id), nil
}
func (c *collectionWithSession) PrimaryKeys() ([]string, error) {
return c.session.PrimaryKeys(c.Name())
}
func (c *collectionWithSession) filterConds(conds ...interface{}) ([]interface{}, error) {
pk, err := c.PrimaryKeys()
if err != nil {
return nil, err
}
if len(conds) == 1 && len(pk) == 1 {
if id := conds[0]; IsKeyValue(id) {
conds[0] = mydb.Cond{pk[0]: mydb.Eq(id)}
}
}
if tr, ok := c.adapter.(condsFilter); ok {
return tr.FilterConds(conds...), nil
}
return conds, nil
}
func (c *collectionWithSession) Find(conds ...interface{}) mydb.Result {
filteredConds, err := c.filterConds(conds...)
if err != nil {
res := &Result{}
res.setErr(err)
return res
}
res := NewResult(
c.session.SQL(),
c.Name(),
filteredConds,
)
if f, ok := c.adapter.(finder); ok {
return f.Find(c, res, conds...)
}
return res
}
func (c *collectionWithSession) Exists() (bool, error) {
if err := c.session.TableExists(c.Name()); err != nil {
return false, err
}
return true, nil
}
func (c *collectionWithSession) InsertReturning(item interface{}) error {
if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
return fmt.Errorf("Expecting a pointer but got %T", item)
}
// Grab primary keys
pks, err := c.PrimaryKeys()
if err != nil {
return err
}
if len(pks) == 0 {
if ok, err := c.Exists(); !ok {
return err
}
return fmt.Errorf(mydb.ErrMissingPrimaryKeys.Error(), c.Name())
}
var tx Session
isTransaction := c.session.IsTransaction()
if isTransaction {
tx = c.session
} else {
var err error
tx, err = c.session.NewTransaction(c.session.Context(), nil)
if err != nil {
return err
}
defer tx.Close()
}
// Allocate a clone of item.
newItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface()
var newItemFieldMap map[string]reflect.Value
itemValue := reflect.ValueOf(item)
col := tx.Collection(c.Name())
// Insert item as is and grab the returning ID.
var newItemRes mydb.Result
id, err := col.Insert(item)
if err != nil {
goto cancel
}
if id == nil {
err = fmt.Errorf("InsertReturning: Could not get a valid ID after inserting. Does the %q table have a primary key?", c.Name())
goto cancel
}
if len(pks) > 1 {
newItemRes = col.Find(id)
} else {
// We have one primary key, build a explicit mydb.Cond with it to prevent
// string keys to be considered as raw conditions.
newItemRes = col.Find(mydb.Cond{pks[0]: id}) // We already checked that pks is not empty, so pks[0] is defined.
}
// Fetch the row that was just interted into newItem
err = newItemRes.One(newItem)
if err != nil {
goto cancel
}
switch reflect.ValueOf(newItem).Elem().Kind() {
case reflect.Struct:
// Get valid fields from newItem to overwrite those that are on item.
newItemFieldMap = sqlbuilder.Mapper.ValidFieldMap(reflect.ValueOf(newItem))
for fieldName := range newItemFieldMap {
sqlbuilder.Mapper.FieldByName(itemValue, fieldName).Set(newItemFieldMap[fieldName])
}
case reflect.Map:
newItemV := reflect.ValueOf(newItem).Elem()
itemV := reflect.ValueOf(item)
if itemV.Kind() == reflect.Ptr {
itemV = itemV.Elem()
}
for _, keyV := range newItemV.MapKeys() {
itemV.SetMapIndex(keyV, newItemV.MapIndex(keyV))
}
default:
err = fmt.Errorf("InsertReturning: expecting a pointer to map or struct, got %T", newItem)
goto cancel
}
if !isTransaction {
// This is only executed if t.Session() was **not** a transaction and if
// sess was created with sess.NewTransaction().
return tx.Commit()
}
return err
cancel:
// This goto label should only be used when we got an error within a
// transaction and we don't want to continue.
if !isTransaction {
// This is only executed if t.Session() was **not** a transaction and if
// sess was created with sess.NewTransaction().
_ = tx.Rollback()
}
return err
}
func (c *collectionWithSession) UpdateReturning(item interface{}) error {
if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr {
return fmt.Errorf("Expecting a pointer but got %T", item)
}
// Grab primary keys
pks, err := c.PrimaryKeys()
if err != nil {
return err
}
if len(pks) == 0 {
if ok, err := c.Exists(); !ok {
return err
}
return fmt.Errorf(mydb.ErrMissingPrimaryKeys.Error(), c.Name())
}
var tx Session
isTransaction := c.session.IsTransaction()
if isTransaction {
tx = c.session
} else {
// Not within a transaction, let's create one.
var err error
tx, err = c.session.NewTransaction(c.session.Context(), nil)
if err != nil {
return err
}
defer tx.Close()
}
// Allocate a clone of item.
defaultItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface()
var defaultItemFieldMap map[string]reflect.Value
itemValue := reflect.ValueOf(item)
conds := mydb.Cond{}
for _, pk := range pks {
conds[pk] = mydb.Eq(sqlbuilder.Mapper.FieldByName(itemValue, pk).Interface())
}
col := tx.(Session).Collection(c.Name())
err = col.Find(conds).Update(item)
if err != nil {
goto cancel
}
if err = col.Find(conds).One(defaultItem); err != nil {
goto cancel
}
switch reflect.ValueOf(defaultItem).Elem().Kind() {
case reflect.Struct:
// Get valid fields from defaultItem to overwrite those that are on item.
defaultItemFieldMap = sqlbuilder.Mapper.ValidFieldMap(reflect.ValueOf(defaultItem))
for fieldName := range defaultItemFieldMap {
sqlbuilder.Mapper.FieldByName(itemValue, fieldName).Set(defaultItemFieldMap[fieldName])
}
case reflect.Map:
defaultItemV := reflect.ValueOf(defaultItem).Elem()
itemV := reflect.ValueOf(item)
if itemV.Kind() == reflect.Ptr {
itemV = itemV.Elem()
}
for _, keyV := range defaultItemV.MapKeys() {
itemV.SetMapIndex(keyV, defaultItemV.MapIndex(keyV))
}
default:
panic("default")
}
if !isTransaction {
// This is only executed if t.Session() was **not** a transaction and if
// sess was created with sess.NewTransaction().
return tx.Commit()
}
return err
cancel:
// This goto label should only be used when we got an error within a
// transaction and we don't want to continue.
if !isTransaction {
// This is only executed if t.Session() was **not** a transaction and if
// sess was created with sess.NewTransaction().
_ = tx.Rollback()
}
return err
}
func (c *collectionWithSession) Truncate() error {
stmt := exql.Statement{
Type: exql.Truncate,
Table: exql.TableWithName(c.Name()),
}
if _, err := c.session.SQL().Exec(&stmt); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,72 @@
// +build !go1.8
package compat
import (
"context"
"database/sql"
)
type PreparedExecer interface {
Exec(...interface{}) (sql.Result, error)
}
func PreparedExecContext(p PreparedExecer, ctx context.Context, args []interface{}) (sql.Result, error) {
return p.Exec(args...)
}
type Execer interface {
Exec(string, ...interface{}) (sql.Result, error)
}
func ExecContext(p Execer, ctx context.Context, query string, args []interface{}) (sql.Result, error) {
return p.Exec(query, args...)
}
type PreparedQueryer interface {
Query(...interface{}) (*sql.Rows, error)
}
func PreparedQueryContext(p PreparedQueryer, ctx context.Context, args []interface{}) (*sql.Rows, error) {
return p.Query(args...)
}
type Queryer interface {
Query(string, ...interface{}) (*sql.Rows, error)
}
func QueryContext(p Queryer, ctx context.Context, query string, args []interface{}) (*sql.Rows, error) {
return p.Query(query, args...)
}
type PreparedRowQueryer interface {
QueryRow(...interface{}) *sql.Row
}
func PreparedQueryRowContext(p PreparedRowQueryer, ctx context.Context, args []interface{}) *sql.Row {
return p.QueryRow(args...)
}
type RowQueryer interface {
QueryRow(string, ...interface{}) *sql.Row
}
func QueryRowContext(p RowQueryer, ctx context.Context, query string, args []interface{}) *sql.Row {
return p.QueryRow(query, args...)
}
type Preparer interface {
Prepare(string) (*sql.Stmt, error)
}
func PrepareContext(p Preparer, ctx context.Context, query string) (*sql.Stmt, error) {
return p.Prepare(query)
}
type TxStarter interface {
Begin() (*sql.Tx, error)
}
func BeginTx(p TxStarter, ctx context.Context, opts interface{}) (*sql.Tx, error) {
return p.Begin()
}

View File

@ -0,0 +1,72 @@
// +build go1.8
package compat
import (
"context"
"database/sql"
)
type PreparedExecer interface {
ExecContext(context.Context, ...interface{}) (sql.Result, error)
}
func PreparedExecContext(p PreparedExecer, ctx context.Context, args []interface{}) (sql.Result, error) {
return p.ExecContext(ctx, args...)
}
type Execer interface {
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
}
func ExecContext(p Execer, ctx context.Context, query string, args []interface{}) (sql.Result, error) {
return p.ExecContext(ctx, query, args...)
}
type PreparedQueryer interface {
QueryContext(context.Context, ...interface{}) (*sql.Rows, error)
}
func PreparedQueryContext(p PreparedQueryer, ctx context.Context, args []interface{}) (*sql.Rows, error) {
return p.QueryContext(ctx, args...)
}
type Queryer interface {
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error)
}
func QueryContext(p Queryer, ctx context.Context, query string, args []interface{}) (*sql.Rows, error) {
return p.QueryContext(ctx, query, args...)
}
type PreparedRowQueryer interface {
QueryRowContext(context.Context, ...interface{}) *sql.Row
}
func PreparedQueryRowContext(p PreparedRowQueryer, ctx context.Context, args []interface{}) *sql.Row {
return p.QueryRowContext(ctx, args...)
}
type RowQueryer interface {
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
}
func QueryRowContext(p RowQueryer, ctx context.Context, query string, args []interface{}) *sql.Row {
return p.QueryRowContext(ctx, query, args...)
}
type Preparer interface {
PrepareContext(context.Context, string) (*sql.Stmt, error)
}
func PrepareContext(p Preparer, ctx context.Context, query string) (*sql.Stmt, error) {
return p.PrepareContext(ctx, query)
}
type TxStarter interface {
BeginTx(context.Context, *sql.TxOptions) (*sql.Tx, error)
}
func BeginTx(p TxStarter, ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) {
return p.BeginTx(ctx, opts)
}

View File

@ -0,0 +1,83 @@
package exql
import (
"fmt"
"strings"
"git.hexq.cn/tiglog/mydb/internal/cache"
)
type columnWithAlias struct {
Name string
Alias string
}
// Column represents a SQL column.
type Column struct {
Name interface{}
}
var _ = Fragment(&Column{})
// ColumnWithName creates and returns a Column with the given name.
func ColumnWithName(name string) *Column {
return &Column{Name: name}
}
// Hash returns a unique identifier for the struct.
func (c *Column) Hash() uint64 {
if c == nil {
return cache.NewHash(FragmentType_Column, nil)
}
return cache.NewHash(FragmentType_Column, c.Name)
}
// Compile transforms the ColumnValue into an equivalent SQL representation.
func (c *Column) Compile(layout *Template) (compiled string, err error) {
if z, ok := layout.Read(c); ok {
return z, nil
}
var alias string
switch value := c.Name.(type) {
case string:
value = trimString(value)
chunks := separateByAS(value)
if len(chunks) == 1 {
chunks = separateBySpace(value)
}
name := chunks[0]
nameChunks := strings.SplitN(name, layout.ColumnSeparator, 2)
for i := range nameChunks {
nameChunks[i] = trimString(nameChunks[i])
if nameChunks[i] == "*" {
continue
}
nameChunks[i] = layout.MustCompile(layout.IdentifierQuote, Raw{Value: nameChunks[i]})
}
compiled = strings.Join(nameChunks, layout.ColumnSeparator)
if len(chunks) > 1 {
alias = trimString(chunks[1])
alias = layout.MustCompile(layout.IdentifierQuote, Raw{Value: alias})
}
case compilable:
compiled, err = value.Compile(layout)
if err != nil {
return "", err
}
default:
return "", fmt.Errorf(errExpectingHashableFmt, c.Name)
}
if alias != "" {
compiled = layout.MustCompile(layout.ColumnAliasLayout, columnWithAlias{compiled, alias})
}
layout.Write(c, compiled)
return
}

View File

@ -0,0 +1,88 @@
package exql
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestColumnString(t *testing.T) {
column := Column{Name: "role.name"}
s, err := column.Compile(defaultTemplate)
assert.NoError(t, err)
assert.Equal(t, `"role"."name"`, s)
}
func TestColumnAs(t *testing.T) {
column := Column{Name: "role.name as foo"}
s, err := column.Compile(defaultTemplate)
assert.NoError(t, err)
assert.Equal(t, `"role"."name" AS "foo"`, s)
}
func TestColumnImplicitAs(t *testing.T) {
column := Column{Name: "role.name foo"}
s, err := column.Compile(defaultTemplate)
assert.NoError(t, err)
assert.Equal(t, `"role"."name" AS "foo"`, s)
}
func TestColumnRaw(t *testing.T) {
column := Column{Name: &Raw{Value: "role.name As foo"}}
s, err := column.Compile(defaultTemplate)
assert.NoError(t, err)
assert.Equal(t, `role.name As foo`, s)
}
func BenchmarkColumnWithName(b *testing.B) {
for i := 0; i < b.N; i++ {
_ = ColumnWithName("a")
}
}
func BenchmarkColumnHash(b *testing.B) {
c := Column{Name: "name"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.Hash()
}
}
func BenchmarkColumnCompile(b *testing.B) {
c := Column{Name: "name"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = c.Compile(defaultTemplate)
}
}
func BenchmarkColumnCompileNoCache(b *testing.B) {
for i := 0; i < b.N; i++ {
c := Column{Name: "name"}
_, _ = c.Compile(defaultTemplate)
}
}
func BenchmarkColumnWithDotCompile(b *testing.B) {
c := Column{Name: "role.name"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = c.Compile(defaultTemplate)
}
}
func BenchmarkColumnWithImplicitAsKeywordCompile(b *testing.B) {
c := Column{Name: "role.name foo"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = c.Compile(defaultTemplate)
}
}
func BenchmarkColumnWithAsKeywordCompile(b *testing.B) {
c := Column{Name: "role.name AS foo"}
b.ResetTimer()
for i := 0; i < b.N; i++ {
_, _ = c.Compile(defaultTemplate)
}
}

Some files were not shown because too many files have changed in this diff Show More