158 lines
4.1 KiB
Go
158 lines
4.1 KiB
Go
|
//
|
||
|
// binder.go
|
||
|
// Copyright (C) 2023 tiglog <me@tiglog.com>
|
||
|
//
|
||
|
// Distributed under terms of the MIT license.
|
||
|
//
|
||
|
|
||
|
package orm
|
||
|
|
||
|
import (
|
||
|
"database/sql"
|
||
|
"database/sql/driver"
|
||
|
"fmt"
|
||
|
"reflect"
|
||
|
"unsafe"
|
||
|
)
|
||
|
|
||
|
// makeNewPointersOf creates a map of [field name] -> pointer to fill it
|
||
|
// recursively. it will go down until reaches a driver.Valuer implementation, it will stop there.
|
||
|
func (b *binder) makeNewPointersOf(v reflect.Value) interface{} {
|
||
|
m := map[string]interface{}{}
|
||
|
actualV := v
|
||
|
for actualV.Type().Kind() == reflect.Ptr {
|
||
|
actualV = actualV.Elem()
|
||
|
}
|
||
|
if actualV.Type().Kind() == reflect.Struct {
|
||
|
for i := 0; i < actualV.NumField(); i++ {
|
||
|
f := actualV.Field(i)
|
||
|
if (f.Type().Kind() == reflect.Struct || f.Type().Kind() == reflect.Ptr) && !f.Type().Implements(reflect.TypeOf((*driver.Valuer)(nil)).Elem()) {
|
||
|
f = reflect.NewAt(actualV.Type().Field(i).Type, unsafe.Pointer(actualV.Field(i).UnsafeAddr()))
|
||
|
fm := b.makeNewPointersOf(f).(map[string]interface{})
|
||
|
for k, p := range fm {
|
||
|
m[k] = p
|
||
|
}
|
||
|
} else {
|
||
|
var fm *field
|
||
|
fm = b.s.getField(actualV.Type().Field(i))
|
||
|
if fm == nil {
|
||
|
fm = fieldMetadata(actualV.Type().Field(i), b.s.columnConstraints)[0]
|
||
|
}
|
||
|
m[fm.Name] = reflect.NewAt(actualV.Field(i).Type(), unsafe.Pointer(actualV.Field(i).UnsafeAddr())).Interface()
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
return v.Addr().Interface()
|
||
|
}
|
||
|
|
||
|
return m
|
||
|
}
|
||
|
|
||
|
// ptrsFor first allocates for all struct fields recursively until reaches a driver.Value impl
|
||
|
// then it will put them in a map with their correct field name as key, then loops over cts
|
||
|
// and for each one gets appropriate one from the map and adds it to pointer list.
|
||
|
func (b *binder) ptrsFor(v reflect.Value, cts []*sql.ColumnType) []interface{} {
|
||
|
ptrs := b.makeNewPointersOf(v)
|
||
|
var scanInto []interface{}
|
||
|
if reflect.TypeOf(ptrs).Kind() == reflect.Map {
|
||
|
nameToPtr := ptrs.(map[string]interface{})
|
||
|
for _, ct := range cts {
|
||
|
if nameToPtr[ct.Name()] != nil {
|
||
|
scanInto = append(scanInto, nameToPtr[ct.Name()])
|
||
|
}
|
||
|
}
|
||
|
} else {
|
||
|
scanInto = append(scanInto, ptrs)
|
||
|
}
|
||
|
|
||
|
return scanInto
|
||
|
}
|
||
|
|
||
|
type binder struct {
|
||
|
s *schema
|
||
|
}
|
||
|
|
||
|
func newBinder(s *schema) *binder {
|
||
|
return &binder{s: s}
|
||
|
}
|
||
|
|
||
|
// bind binds given rows to the given object at obj. obj should be a pointer
|
||
|
func (b *binder) bind(rows *sql.Rows, obj interface{}) error {
|
||
|
cts, err := rows.ColumnTypes()
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
t := reflect.TypeOf(obj)
|
||
|
v := reflect.ValueOf(obj)
|
||
|
if t.Kind() != reflect.Ptr {
|
||
|
return fmt.Errorf("obj should be a ptr")
|
||
|
}
|
||
|
// since passed input is always a pointer one deref is necessary
|
||
|
t = t.Elem()
|
||
|
v = v.Elem()
|
||
|
if t.Kind() == reflect.Slice {
|
||
|
// getting slice elemnt type -> slice[t]
|
||
|
t = t.Elem()
|
||
|
for rows.Next() {
|
||
|
var rowValue reflect.Value
|
||
|
// Since reflect.SetupConnections returns a pointer to the type, we need to unwrap it to get actual
|
||
|
rowValue = reflect.New(t).Elem()
|
||
|
// till we reach a not pointer type continue newing the underlying type.
|
||
|
for rowValue.IsZero() && rowValue.Type().Kind() == reflect.Ptr {
|
||
|
rowValue = reflect.New(rowValue.Type().Elem()).Elem()
|
||
|
}
|
||
|
newCts := make([]*sql.ColumnType, len(cts))
|
||
|
copy(newCts, cts)
|
||
|
ptrs := b.ptrsFor(rowValue, newCts)
|
||
|
err = rows.Scan(ptrs...)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
for rowValue.Type() != t {
|
||
|
tmp := reflect.New(rowValue.Type())
|
||
|
tmp.Elem().Set(rowValue)
|
||
|
rowValue = tmp
|
||
|
}
|
||
|
v = reflect.Append(v, rowValue)
|
||
|
}
|
||
|
} else {
|
||
|
for rows.Next() {
|
||
|
ptrs := b.ptrsFor(v, cts)
|
||
|
err = rows.Scan(ptrs...)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
// v is either struct or slice
|
||
|
reflect.ValueOf(obj).Elem().Set(v)
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func bindToMap(rows *sql.Rows) ([]map[string]interface{}, error) {
|
||
|
cts, err := rows.ColumnTypes()
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
var ms []map[string]interface{}
|
||
|
for rows.Next() {
|
||
|
var ptrs []interface{}
|
||
|
for _, ct := range cts {
|
||
|
ptrs = append(ptrs, reflect.New(ct.ScanType()).Interface())
|
||
|
}
|
||
|
|
||
|
err = rows.Scan(ptrs...)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
m := map[string]interface{}{}
|
||
|
for i, ptr := range ptrs {
|
||
|
m[cts[i].Name()] = reflect.ValueOf(ptr).Elem().Interface()
|
||
|
}
|
||
|
|
||
|
ms = append(ms, m)
|
||
|
}
|
||
|
return ms, nil
|
||
|
}
|