package sqladapter import ( "fmt" "reflect" "git.hexq.cn/tiglog/mydb" "git.hexq.cn/tiglog/mydb/internal/sqladapter/exql" "git.hexq.cn/tiglog/mydb/internal/sqlbuilder" ) // CollectionAdapter defines methods to be implemented by SQL adapters. type CollectionAdapter interface { // Insert prepares and executes an INSERT statament. When the item is // succefully added, Insert returns a unique identifier of the newly added // element (or nil if the unique identifier couldn't be determined). Insert(Collection, interface{}) (interface{}, error) } // Collection satisfies mydb.Collection. type Collection interface { // Insert inserts a new item into the collection. Insert(interface{}) (mydb.InsertResult, error) // Name returns the name of the collection. Name() string // Session returns the mydb.Session the collection belongs to. Session() mydb.Session // Exists returns true if the collection exists, false otherwise. Exists() (bool, error) // Find defined a new result set. Find(conds ...interface{}) mydb.Result Count() (uint64, error) // Truncate removes all elements on the collection and resets the // collection's IDs. Truncate() error // InsertReturning inserts a new item into the collection and refreshes the // item with actual data from the database. This is useful to get automatic // values, such as timestamps, or IDs. InsertReturning(item interface{}) error // UpdateReturning updates a record from the collection and refreshes the item // with actual data from the database. This is useful to get automatic // values, such as timestamps, or IDs. UpdateReturning(item interface{}) error // PrimaryKeys returns the names of all primary keys in the table. PrimaryKeys() ([]string, error) // SQLBuilder returns a mydb.SQL instance. SQL() mydb.SQL } type finder interface { Find(Collection, *Result, ...interface{}) mydb.Result } type condsFilter interface { FilterConds(...interface{}) []interface{} } // collection is the implementation of Collection. type collection struct { name string adapter CollectionAdapter } type collectionWithSession struct { *collection session Session } func newCollection(name string, adapter CollectionAdapter) *collection { if adapter == nil { panic("mydb: nil adapter") } return &collection{ name: name, adapter: adapter, } } func (c *collectionWithSession) SQL() mydb.SQL { return c.session.SQL() } func (c *collectionWithSession) Session() mydb.Session { return c.session } func (c *collectionWithSession) Name() string { return c.name } func (c *collectionWithSession) Count() (uint64, error) { return c.Find().Count() } func (c *collectionWithSession) Insert(item interface{}) (mydb.InsertResult, error) { id, err := c.adapter.Insert(c, item) if err != nil { return nil, err } return mydb.NewInsertResult(id), nil } func (c *collectionWithSession) PrimaryKeys() ([]string, error) { return c.session.PrimaryKeys(c.Name()) } func (c *collectionWithSession) filterConds(conds ...interface{}) ([]interface{}, error) { pk, err := c.PrimaryKeys() if err != nil { return nil, err } if len(conds) == 1 && len(pk) == 1 { if id := conds[0]; IsKeyValue(id) { conds[0] = mydb.Cond{pk[0]: mydb.Eq(id)} } } if tr, ok := c.adapter.(condsFilter); ok { return tr.FilterConds(conds...), nil } return conds, nil } func (c *collectionWithSession) Find(conds ...interface{}) mydb.Result { filteredConds, err := c.filterConds(conds...) if err != nil { res := &Result{} res.setErr(err) return res } res := NewResult( c.session.SQL(), c.Name(), filteredConds, ) if f, ok := c.adapter.(finder); ok { return f.Find(c, res, conds...) } return res } func (c *collectionWithSession) Exists() (bool, error) { if err := c.session.TableExists(c.Name()); err != nil { return false, err } return true, nil } func (c *collectionWithSession) InsertReturning(item interface{}) error { if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr { return fmt.Errorf("Expecting a pointer but got %T", item) } // Grab primary keys pks, err := c.PrimaryKeys() if err != nil { return err } if len(pks) == 0 { if ok, err := c.Exists(); !ok { return err } return fmt.Errorf(mydb.ErrMissingPrimaryKeys.Error(), c.Name()) } var tx Session isTransaction := c.session.IsTransaction() if isTransaction { tx = c.session } else { var err error tx, err = c.session.NewTransaction(c.session.Context(), nil) if err != nil { return err } defer tx.Close() } // Allocate a clone of item. newItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface() var newItemFieldMap map[string]reflect.Value itemValue := reflect.ValueOf(item) col := tx.Collection(c.Name()) // Insert item as is and grab the returning ID. var newItemRes mydb.Result id, err := col.Insert(item) if err != nil { goto cancel } if id == nil { err = fmt.Errorf("InsertReturning: Could not get a valid ID after inserting. Does the %q table have a primary key?", c.Name()) goto cancel } if len(pks) > 1 { newItemRes = col.Find(id) } else { // We have one primary key, build a explicit mydb.Cond with it to prevent // string keys to be considered as raw conditions. newItemRes = col.Find(mydb.Cond{pks[0]: id}) // We already checked that pks is not empty, so pks[0] is defined. } // Fetch the row that was just interted into newItem err = newItemRes.One(newItem) if err != nil { goto cancel } switch reflect.ValueOf(newItem).Elem().Kind() { case reflect.Struct: // Get valid fields from newItem to overwrite those that are on item. newItemFieldMap = sqlbuilder.Mapper.ValidFieldMap(reflect.ValueOf(newItem)) for fieldName := range newItemFieldMap { sqlbuilder.Mapper.FieldByName(itemValue, fieldName).Set(newItemFieldMap[fieldName]) } case reflect.Map: newItemV := reflect.ValueOf(newItem).Elem() itemV := reflect.ValueOf(item) if itemV.Kind() == reflect.Ptr { itemV = itemV.Elem() } for _, keyV := range newItemV.MapKeys() { itemV.SetMapIndex(keyV, newItemV.MapIndex(keyV)) } default: err = fmt.Errorf("InsertReturning: expecting a pointer to map or struct, got %T", newItem) goto cancel } if !isTransaction { // This is only executed if t.Session() was **not** a transaction and if // sess was created with sess.NewTransaction(). return tx.Commit() } return err cancel: // This goto label should only be used when we got an error within a // transaction and we don't want to continue. if !isTransaction { // This is only executed if t.Session() was **not** a transaction and if // sess was created with sess.NewTransaction(). _ = tx.Rollback() } return err } func (c *collectionWithSession) UpdateReturning(item interface{}) error { if item == nil || reflect.TypeOf(item).Kind() != reflect.Ptr { return fmt.Errorf("Expecting a pointer but got %T", item) } // Grab primary keys pks, err := c.PrimaryKeys() if err != nil { return err } if len(pks) == 0 { if ok, err := c.Exists(); !ok { return err } return fmt.Errorf(mydb.ErrMissingPrimaryKeys.Error(), c.Name()) } var tx Session isTransaction := c.session.IsTransaction() if isTransaction { tx = c.session } else { // Not within a transaction, let's create one. var err error tx, err = c.session.NewTransaction(c.session.Context(), nil) if err != nil { return err } defer tx.Close() } // Allocate a clone of item. defaultItem := reflect.New(reflect.ValueOf(item).Elem().Type()).Interface() var defaultItemFieldMap map[string]reflect.Value itemValue := reflect.ValueOf(item) conds := mydb.Cond{} for _, pk := range pks { conds[pk] = mydb.Eq(sqlbuilder.Mapper.FieldByName(itemValue, pk).Interface()) } col := tx.(Session).Collection(c.Name()) err = col.Find(conds).Update(item) if err != nil { goto cancel } if err = col.Find(conds).One(defaultItem); err != nil { goto cancel } switch reflect.ValueOf(defaultItem).Elem().Kind() { case reflect.Struct: // Get valid fields from defaultItem to overwrite those that are on item. defaultItemFieldMap = sqlbuilder.Mapper.ValidFieldMap(reflect.ValueOf(defaultItem)) for fieldName := range defaultItemFieldMap { sqlbuilder.Mapper.FieldByName(itemValue, fieldName).Set(defaultItemFieldMap[fieldName]) } case reflect.Map: defaultItemV := reflect.ValueOf(defaultItem).Elem() itemV := reflect.ValueOf(item) if itemV.Kind() == reflect.Ptr { itemV = itemV.Elem() } for _, keyV := range defaultItemV.MapKeys() { itemV.SetMapIndex(keyV, defaultItemV.MapIndex(keyV)) } default: panic("default") } if !isTransaction { // This is only executed if t.Session() was **not** a transaction and if // sess was created with sess.NewTransaction(). return tx.Commit() } return err cancel: // This goto label should only be used when we got an error within a // transaction and we don't want to continue. if !isTransaction { // This is only executed if t.Session() was **not** a transaction and if // sess was created with sess.NewTransaction(). _ = tx.Rollback() } return err } func (c *collectionWithSession) Truncate() error { stmt := exql.Statement{ Type: exql.Truncate, Table: exql.TableWithName(c.Name()), } if _, err := c.session.SQL().Exec(&stmt); err != nil { return err } return nil }