golib/gcasbin/adapter_sqlx.go
2023-06-15 21:22:51 +08:00

750 lines
19 KiB
Go

//
// adapter_sqlx.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package gcasbin
import (
"bytes"
"context"
"errors"
"fmt"
"strconv"
"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/persist"
"github.com/jmoiron/sqlx"
)
// Defaults shared by every SQL dialect handled in this file.
const (
	// defaultTableName is the table name used when the caller passes tableName == "".
	defaultTableName = "casbin_rule"

	// maxParamLength is the number of values in one policy row:
	// p_type followed by v0 through v5.
	maxParamLength = 7
)
// General SQL templates. They use "?" placeholders and are the fallback for
// any driver without a dialect-specific override. Each template takes the
// table name as its only format argument.
const (
	// sqlCreateTable creates the rule table and its (p_type,v0,v1) index.
	sqlCreateTable = `
CREATE TABLE %[1]s(
p_type VARCHAR(32),
v0 VARCHAR(255),
v1 VARCHAR(255),
v2 VARCHAR(255),
v3 VARCHAR(255),
v4 VARCHAR(255),
v5 VARCHAR(255)
);
CREATE INDEX idx_%[1]s ON %[1]s (p_type,v0,v1);`
	sqlTruncateTable = "TRUNCATE TABLE %s"
	// sqlIsTableExist only probes whether the table can be queried. The
	// always-false predicate keeps the probe from selecting every stored rule.
	sqlIsTableExist = "SELECT 1 FROM %s WHERE 1=0"
	sqlInsertRow    = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (?,?,?,?,?,?,?)"
	// sqlUpdateRow takes the seven new values followed by the seven old values.
	sqlUpdateRow = "UPDATE %s SET p_type=?,v0=?,v1=?,v2=?,v3=?,v4=?,v5=? WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
	sqlDeleteAll = "DELETE FROM %s"
	sqlDeleteRow = "DELETE FROM %s WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
	// sqlDeleteByArgs is a prefix: callers append " AND vN=?" conditions to it.
	sqlDeleteByArgs = "DELETE FROM %s WHERE p_type=?"
	sqlSelectAll    = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s"
	// sqlSelectWhere is a prefix ending in "WHERE "; callers append the conditions.
	sqlSelectWhere = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s WHERE "
)
// SQLite3-specific SQL templates. Each takes the table name as its only
// format argument.
const (
	// sqlCreateTableSqlite3 creates the rule table and its (p_type,v0,v1)
	// index. SQLite does not enforce VARCHAR lengths, so CHECK constraints
	// enforce the text type and the maximum lengths. The type name is written
	// as the single-quoted string literal 'text': a double-quoted "text" is an
	// identifier in SQL and is only accepted as a string through SQLite's
	// legacy double-quoted-string fallback.
	sqlCreateTableSqlite3 = `
CREATE TABLE IF NOT EXISTS %[1]s(
p_type VARCHAR(32) DEFAULT '' NOT NULL,
v0 VARCHAR(255) DEFAULT '' NOT NULL,
v1 VARCHAR(255) DEFAULT '' NOT NULL,
v2 VARCHAR(255) DEFAULT '' NOT NULL,
v3 VARCHAR(255) DEFAULT '' NOT NULL,
v4 VARCHAR(255) DEFAULT '' NOT NULL,
v5 VARCHAR(255) DEFAULT '' NOT NULL,
CHECK (TYPEOF("p_type") = 'text' AND
LENGTH("p_type") <= 32),
CHECK (TYPEOF("v0") = 'text' AND
LENGTH("v0") <= 255),
CHECK (TYPEOF("v1") = 'text' AND
LENGTH("v1") <= 255),
CHECK (TYPEOF("v2") = 'text' AND
LENGTH("v2") <= 255),
CHECK (TYPEOF("v3") = 'text' AND
LENGTH("v3") <= 255),
CHECK (TYPEOF("v4") = 'text' AND
LENGTH("v4") <= 255),
CHECK (TYPEOF("v5") = 'text' AND
LENGTH("v5") <= 255)
);
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
	// sqlTruncateTableSqlite3 stands in for TRUNCATE: it drops the table and
	// then recreates it with sqlCreateTableSqlite3.
	sqlTruncateTableSqlite3 = "DROP TABLE IF EXISTS %[1]s;" + sqlCreateTableSqlite3
)
// MySQL-specific SQL template.
const (
// sqlCreateTableMysql creates the rule table together with its
// (p_type,v0,v1) index in a single statement. Every column is NOT NULL and
// defaults to the empty string. %[1]s is replaced with the table name.
sqlCreateTableMysql = `
CREATE TABLE IF NOT EXISTS %[1]s(
p_type VARCHAR(32) DEFAULT '' NOT NULL,
v0 VARCHAR(255) DEFAULT '' NOT NULL,
v1 VARCHAR(255) DEFAULT '' NOT NULL,
v2 VARCHAR(255) DEFAULT '' NOT NULL,
v3 VARCHAR(255) DEFAULT '' NOT NULL,
v4 VARCHAR(255) DEFAULT '' NOT NULL,
v5 VARCHAR(255) DEFAULT '' NOT NULL,
INDEX idx_%[1]s (p_type,v0,v1)
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;`
)
// PostgreSQL-specific SQL templates. They mirror the general templates but use
// numbered "$n" placeholders. Each takes the table name as its only format
// argument.
const (
// sqlCreateTablePostgres creates the rule table and its (p_type,v0,v1) index
// if they do not exist yet.
sqlCreateTablePostgres = `
CREATE TABLE IF NOT EXISTS %[1]s(
p_type VARCHAR(32) DEFAULT '' NOT NULL,
v0 VARCHAR(255) DEFAULT '' NOT NULL,
v1 VARCHAR(255) DEFAULT '' NOT NULL,
v2 VARCHAR(255) DEFAULT '' NOT NULL,
v3 VARCHAR(255) DEFAULT '' NOT NULL,
v4 VARCHAR(255) DEFAULT '' NOT NULL,
v5 VARCHAR(255) DEFAULT '' NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
sqlInsertRowPostgres = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES ($1,$2,$3,$4,$5,$6,$7)"
// sqlUpdateRowPostgres takes the seven new values ($1-$7) followed by the
// seven old values ($8-$14).
sqlUpdateRowPostgres = "UPDATE %s SET p_type=$1,v0=$2,v1=$3,v2=$4,v3=$5,v4=$6,v5=$7 WHERE p_type=$8 AND v0=$9 AND v1=$10 AND v2=$11 AND v3=$12 AND v4=$13 AND v5=$14"
sqlDeleteRowPostgres = "DELETE FROM %s WHERE p_type=$1 AND v0=$2 AND v1=$3 AND v2=$4 AND v3=$5 AND v4=$6 AND v5=$7"
)
// SQL Server-specific SQL templates. They mirror the general templates but use
// NVARCHAR columns and "@pN" placeholders. Each takes the table name as its
// only format argument.
const (
// sqlCreateTableSqlserver creates the rule table and its (p_type,v0,v1)
// index. Unlike the other dialect templates it has no IF NOT EXISTS guard.
sqlCreateTableSqlserver = `
CREATE TABLE %[1]s(
p_type NVARCHAR(32) DEFAULT '' NOT NULL,
v0 NVARCHAR(255) DEFAULT '' NOT NULL,
v1 NVARCHAR(255) DEFAULT '' NOT NULL,
v2 NVARCHAR(255) DEFAULT '' NOT NULL,
v3 NVARCHAR(255) DEFAULT '' NOT NULL,
v4 NVARCHAR(255) DEFAULT '' NOT NULL,
v5 NVARCHAR(255) DEFAULT '' NOT NULL
);
CREATE INDEX idx_%[1]s ON %[1]s (p_type, v0, v1);`
sqlInsertRowSqlserver = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (@p1,@p2,@p3,@p4,@p5,@p6,@p7)"
// sqlUpdateRowSqlserver takes the seven new values (@p1-@p7) followed by the
// seven old values (@p8-@p14).
sqlUpdateRowSqlserver = "UPDATE %s SET p_type=@p1,v0=@p2,v1=@p3,v2=@p4,v3=@p5,v4=@p6,v5=@p7 WHERE p_type=@p8 AND v0=@p9 AND v1=@p10 AND v2=@p11 AND v3=@p12 AND v4=@p13 AND v5=@p14"
sqlDeleteRowSqlserver = "DELETE FROM %s WHERE p_type=@p1 AND v0=@p2 AND v1=@p3 AND v2=@p4 AND v3=@p5 AND v4=@p6 AND v5=@p7"
)
// SqlCasbinRule is one row of the casbin rule table.
// It is the scan target when policy lines are selected through sqlx; the db
// tags name the column each field is read from.
type SqlCasbinRule struct {
PType string `db:"p_type"`
V0 string `db:"v0"`
V1 string `db:"v1"`
V2 string `db:"v2"`
V3 string `db:"v3"`
V4 string `db:"v4"`
V5 string `db:"v5"`
}
// SqlAdapter is the sqlx adapter for Casbin.
// It loads policy lines from, and saves policy lines to, the database reached
// through db.
type SqlAdapter struct {
// db is the caller-supplied connection; ctx is passed to every database call.
db *sqlx.DB
ctx context.Context
// tableName is the rule table all generated statements operate on.
tableName string
// isFiltered is set by LoadFilteredPolicy after a filtered load.
isFiltered bool
// The statements below are filled in by genSQL for the driver in use.
// SqlDeleteByArgs and SqlSelectWhere are prefixes that callers extend with
// conditions before executing them.
SqlCreateTable string
SqlTruncateTable string
SqlIsTableExist string
SqlInsertRow string
SqlUpdateRow string
SqlDeleteAll string
SqlDeleteRow string
SqlDeleteByArgs string
SqlSelectAll string
SqlSelectWhere string
}
// SqlFilter describes which policy rows a filtered load should return.
// A field with no values places no constraint on its column; a field with one
// or more values requires the column to equal one of them. All constrained
// columns must match.
type SqlFilter struct {
	// PType lists the accepted p_type values.
	PType []string
	// V0 through V5 list the accepted values for the column of the same name.
	V0, V1, V2, V3, V4, V5 []string
}
// NewSqlAdapter builds a SqlAdapter using a background context.
// db must already be connected; its lifetime stays under the caller's control.
// An empty tableName selects the default table name "casbin_rule".
// See NewSqlAdapterContext for the checks performed.
func NewSqlAdapter(db *sqlx.DB, tableName string) (*SqlAdapter, error) {
	ctx := context.Background()
	return NewSqlAdapterContext(ctx, db, tableName)
}
// NewSqlAdapterContext builds a SqlAdapter bound to ctx.
// db must already be connected; its lifetime stays under the caller's control.
// An empty tableName selects the default table name "casbin_rule".
//
// The constructor pings the database, rejects the Oracle drivers, generates
// the SQL for the driver in use and creates the rule table when the existence
// probe fails. ctx is stored and reused for every later database call made by
// the adapter.
func NewSqlAdapterContext(ctx context.Context, db *sqlx.DB, tableName string) (*SqlAdapter, error) {
	if db == nil {
		return nil, errors.New("db is nil")
	}

	// Verify the connection is usable before doing anything else.
	if err := db.PingContext(ctx); err != nil {
		return nil, err
	}

	if driver := db.DriverName(); driver == "oci8" || driver == "ora" || driver == "goracle" {
		return nil, errors.New("sqlxadapter: please checkout 'oracle' branch")
	}

	if tableName == "" {
		tableName = defaultTableName
	}

	adapter := &SqlAdapter{
		db:        db,
		ctx:       ctx,
		tableName: tableName,
	}
	// Pick the statements matching the driver.
	adapter.genSQL()

	if adapter.IsTableExist() {
		return adapter, nil
	}
	if err := adapter.CreateTable(); err != nil {
		return nil, err
	}
	return adapter, nil
}
// genSQL fills in the adapter's SQL statements for the driver behind p.db.
// The general templates are the baseline; postgres, mysql, sqlite3 and
// sqlserver drivers replace the ones that need a dialect-specific form.
func (p *SqlAdapter) genSQL() {
	// render substitutes the table name into one template.
	render := func(tpl string) string {
		return fmt.Sprintf(tpl, p.tableName)
	}

	// Templates that may be overridden per driver.
	createTpl, truncateTpl := sqlCreateTable, sqlTruncateTable
	insertTpl, updateTpl, deleteRowTpl := sqlInsertRow, sqlUpdateRow, sqlDeleteRow

	switch p.db.DriverName() {
	case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres":
		createTpl = sqlCreateTablePostgres
		insertTpl, updateTpl, deleteRowTpl = sqlInsertRowPostgres, sqlUpdateRowPostgres, sqlDeleteRowPostgres
	case "mysql":
		createTpl = sqlCreateTableMysql
	case "sqlite3":
		createTpl, truncateTpl = sqlCreateTableSqlite3, sqlTruncateTableSqlite3
	case "sqlserver":
		createTpl = sqlCreateTableSqlserver
		insertTpl, updateTpl, deleteRowTpl = sqlInsertRowSqlserver, sqlUpdateRowSqlserver, sqlDeleteRowSqlserver
	}

	p.SqlCreateTable = render(createTpl)
	p.SqlTruncateTable = render(truncateTpl)
	p.SqlInsertRow = render(insertTpl)
	p.SqlUpdateRow = render(updateTpl)
	p.SqlDeleteRow = render(deleteRowTpl)

	// These statements are the same for every driver.
	p.SqlIsTableExist = render(sqlIsTableExist)
	p.SqlDeleteAll = render(sqlDeleteAll)
	p.SqlDeleteByArgs = render(sqlDeleteByArgs)
	p.SqlSelectAll = render(sqlSelectAll)
	p.SqlSelectWhere = render(sqlSelectWhere)
}
// CreateTable executes the adapter's SqlCreateTable statement.
func (p *SqlAdapter) CreateTable() error {
	if _, err := p.db.ExecContext(p.ctx, p.SqlCreateTable); err != nil {
		return err
	}
	return nil
}
// TruncateTable empties the rule table by executing SqlTruncateTable.
func (p *SqlAdapter) TruncateTable() error {
	if _, err := p.db.ExecContext(p.ctx, p.SqlTruncateTable); err != nil {
		return err
	}
	return nil
}
// DeleteAll removes every row from the rule table by executing SqlDeleteAll.
func (p *SqlAdapter) DeleteAll() error {
	if _, err := p.db.ExecContext(p.ctx, p.SqlDeleteAll); err != nil {
		return err
	}
	return nil
}
// IsTableExist reports whether the SqlIsTableExist probe executes successfully.
// Any execution error, not only a missing table, is reported as false.
func (p *SqlAdapter) IsTableExist() bool {
	if _, err := p.db.ExecContext(p.ctx, p.SqlIsTableExist); err != nil {
		return false
	}
	return true
}
// DeleteRows executes a delete statement written with "?" placeholders.
// The query is rebound to the driver's placeholder style before execution.
func (p *SqlAdapter) DeleteRows(query string, args ...interface{}) error {
	rebound := p.db.Rebind(query)
	if _, err := p.db.ExecContext(p.ctx, rebound, args...); err != nil {
		return err
	}
	return nil
}
// TruncateAndInsertRows truncates the rule table and then inserts rules.
// Each element of rules is the argument list for one SqlInsertRow execution.
// The truncate runs before, and outside of, the insert transaction.
func (p *SqlAdapter) TruncateAndInsertRows(rules [][]interface{}) error {
	err := p.TruncateTable()
	if err == nil {
		err = p.execTxSqlRows(p.SqlInsertRow, rules)
	}
	return err
}
// DeleteAllAndInsertRows replaces the whole content of the rule table with rules.
// Each element of rules is the argument list for one SqlInsertRow execution.
//
// The delete and all inserts run inside a single transaction, so a failure
// while inserting rolls the delete back as well instead of leaving the table
// empty. If the rollback itself fails, the returned error wraps the original
// error and reports the rollback error too.
func (p *SqlAdapter) DeleteAllAndInsertRows(rules [][]interface{}) (err error) {
	tx, err := p.db.BeginTx(p.ctx, nil)
	if err != nil {
		return err
	}

	// action names the step that failed; it is only used in the combined
	// error produced when the rollback fails too.
	var action string
	defer func() {
		if err == nil {
			return
		}
		if rbErr := tx.Rollback(); rbErr != nil {
			err = fmt.Errorf("%s err: %w, rollback err: %v", action, err, rbErr)
		}
	}()

	if _, err = tx.ExecContext(p.ctx, p.SqlDeleteAll); err != nil {
		action = "delete all"
		return err
	}

	stmt, err := tx.PrepareContext(p.ctx, p.SqlInsertRow)
	if err != nil {
		action = "prepare context"
		return err
	}

	for _, rule := range rules {
		if _, err = stmt.ExecContext(p.ctx, rule...); err != nil {
			action = "stmt exec"
			return err
		}
	}

	if err = stmt.Close(); err != nil {
		action = "stmt close"
		return err
	}

	if err = tx.Commit(); err != nil {
		action = "commit"
		return err
	}
	return nil
}
// execTxSqlRows executes query once per element of rules inside a single
// transaction. Each element is the argument list for one execution.
//
// On any failure the transaction is rolled back and the failing step's error
// is returned. If the rollback fails as well, the returned error names the
// failed step, wraps the original error (so errors.Is and errors.As still see
// it) and includes the rollback error.
func (p *SqlAdapter) execTxSqlRows(query string, rules [][]interface{}) (err error) {
	tx, err := p.db.BeginTx(p.ctx, nil)
	if err != nil {
		return err
	}

	// action names the step that failed; it is only used in the combined
	// error produced when the rollback fails too.
	var action string
	defer func() {
		if err == nil {
			return
		}
		if rbErr := tx.Rollback(); rbErr != nil {
			err = fmt.Errorf("%s err: %w, rollback err: %v", action, err, rbErr)
		}
	}()

	stmt, err := tx.PrepareContext(p.ctx, query)
	if err != nil {
		action = "prepare context"
		return err
	}

	for _, rule := range rules {
		if _, err = stmt.ExecContext(p.ctx, rule...); err != nil {
			action = "stmt exec"
			return err
		}
	}

	if err = stmt.Close(); err != nil {
		action = "stmt close"
		return err
	}

	if err = tx.Commit(); err != nil {
		action = "commit"
		return err
	}
	return nil
}
// SelectRows runs query and scans every returned row into a SqlCasbinRule.
// When args are given, the query's "?" placeholders are rebound to the
// driver's placeholder style first. On error the returned slice is nil.
//
// The select is executed as its own statement before the result is returned:
// reading lines and calling a function that mutates it within one return
// statement would depend on an evaluation order the Go specification leaves
// unspecified.
func (p *SqlAdapter) SelectRows(query string, args ...interface{}) ([]*SqlCasbinRule, error) {
	// Start with some capacity to limit regrowth while scanning.
	lines := make([]*SqlCasbinRule, 0, 64)

	if len(args) > 0 {
		query = p.db.Rebind(query)
	}

	if err := p.db.SelectContext(p.ctx, &lines, query, args...); err != nil {
		return nil, err
	}
	return lines, nil
}
// SelectWhereIn selects the rows matching filter.
// A filter field with one value becomes "col=?", a field with several values
// becomes "col IN (...)", and empty fields are skipped; the conditions are
// joined with AND.
//
// A nil filter, or a filter whose fields are all empty, constrains nothing, so
// all rows are selected with SqlSelectAll instead of building a WHERE clause
// with no conditions (which would be invalid SQL).
func (p *SqlAdapter) SelectWhereIn(filter *SqlFilter) (lines []*SqlCasbinRule, err error) {
	if filter == nil {
		return p.SelectRows(p.SqlSelectAll)
	}

	var sqlBuf bytes.Buffer
	sqlBuf.Grow(64)
	sqlBuf.WriteString(p.SqlSelectWhere)

	args := make([]interface{}, 0, 4)
	hasInCond := false
	conds := 0 // number of conditions written so far

	for _, col := range [maxParamLength]struct {
		name string
		arg  []string
	}{
		{"p_type", filter.PType},
		{"v0", filter.V0},
		{"v1", filter.V1},
		{"v2", filter.V2},
		{"v3", filter.V3},
		{"v4", filter.V4},
		{"v5", filter.V5},
	} {
		l := len(col.arg)
		if l == 0 {
			continue
		}

		// A trailing '?' or ')' means something precedes this condition,
		// so it has to be joined with AND.
		switch sqlBuf.Bytes()[sqlBuf.Len()-1] {
		case '?', ')':
			sqlBuf.WriteString(" AND ")
		}

		sqlBuf.WriteString(col.name)
		if l == 1 {
			sqlBuf.WriteString("=?")
			args = append(args, col.arg[0])
		} else {
			// sqlx.In expands the single "?" to one placeholder per value.
			sqlBuf.WriteString(" IN (?)")
			args = append(args, col.arg)
			hasInCond = true
		}
		conds++
	}

	if conds == 0 {
		return p.SelectRows(p.SqlSelectAll)
	}

	query := sqlBuf.String()
	if hasInCond {
		if query, args, err = sqlx.In(query, args...); err != nil {
			return nil, err
		}
	}

	return p.SelectRows(query, args...)
}
// LoadPolicy loads every policy rule from the rule table into model.
// After a successful full load the adapter no longer reports itself as
// filtered, so an earlier LoadFilteredPolicy does not leave IsFiltered
// returning true for a complete policy.
func (p *SqlAdapter) LoadPolicy(model model.Model) error {
	lines, err := p.SelectRows(p.SqlSelectAll)
	if err != nil {
		return err
	}

	for _, line := range lines {
		p.loadPolicyLine(line, model)
	}

	p.isFiltered = false
	return nil
}
// SavePolicy writes all rules of the "p" and "g" sections of model to the
// rule table, replacing its previous content via DeleteAllAndInsertRows.
func (p *SqlAdapter) SavePolicy(model model.Model) error {
	rows := make([][]interface{}, 0, 64)

	// "p" rules first, then "g" rules.
	for _, sec := range [...]string{"p", "g"} {
		for ptype, ast := range model[sec] {
			for _, rule := range ast.Policy {
				rows = append(rows, p.GenArgs(ptype, rule))
			}
		}
	}

	return p.DeleteAllAndInsertRows(rows)
}
// AddPolicy inserts one policy rule into the rule table.
// sec is not used; the row is identified by ptype and rule.
func (p *SqlAdapter) AddPolicy(sec string, ptype string, rule []string) error {
	row := p.GenArgs(ptype, rule)
	if _, err := p.db.ExecContext(p.ctx, p.SqlInsertRow, row...); err != nil {
		return err
	}
	return nil
}
// AddPolicies inserts several policy rules of the same ptype in one transaction.
// sec is not used.
func (p *SqlAdapter) AddPolicies(sec string, ptype string, rules [][]string) error {
	rows := make([][]interface{}, len(rules))
	for i := range rules {
		rows[i] = p.GenArgs(ptype, rules[i])
	}
	return p.execTxSqlRows(p.SqlInsertRow, rows)
}
// RemovePolicy deletes the rows matching ptype and the non-empty fields of rule.
// Field i of rule is compared with column v<i>. Empty fields, and columns
// beyond len(rule), are not constrained, so rows that differ only in those
// columns are deleted as well. sec is not used.
func (p *SqlAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
	var query bytes.Buffer
	query.Grow(64)
	query.WriteString(p.SqlDeleteByArgs)

	params := append(make([]interface{}, 0, 4), ptype)

	for col := 0; col < len(rule); col++ {
		value := rule[col]
		if value == "" {
			continue
		}
		query.WriteString(" AND v" + strconv.Itoa(col) + "=?")
		params = append(params, value)
	}

	return p.DeleteRows(query.String(), params...)
}
// RemoveFilteredPolicy deletes the rows of ptype that match the field filter.
// fieldValues[i] is compared with column v<fieldIndex+i>; empty values and
// positions outside v0..v5 are skipped. When no condition remains, every row
// of ptype is deleted. sec is not used.
func (p *SqlAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
	var query bytes.Buffer
	query.Grow(64)
	query.WriteString(p.SqlDeleteByArgs)

	params := append(make([]interface{}, 0, 4), ptype)

	for i, fieldValue := range fieldValues {
		col := fieldIndex + i
		// Only v0..v5 exist; unset values do not constrain their column.
		if col < 0 || col >= 6 || fieldValue == "" {
			continue
		}
		query.WriteString(" AND v" + strconv.Itoa(col) + "=?")
		params = append(params, fieldValue)
	}

	return p.DeleteRows(query.String(), params...)
}
// RemovePolicies deletes several policy rules of the same ptype in one
// transaction. Each rule is matched on all seven columns via SqlDeleteRow.
// sec is not used.
func (p *SqlAdapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) {
	rows := make([][]interface{}, len(rules))
	for i := range rules {
		rows[i] = p.GenArgs(ptype, rules[i])
	}
	return p.execTxSqlRows(p.SqlDeleteRow, rows)
}
// LoadFilteredPolicy loads only the policy rules that match the filter.
// filterPtr must be a *SqlFilter. A nil filterPtr, including a nil *SqlFilter
// stored in the interface, loads the complete policy through LoadPolicy
// instead. Any other type yields an "invalid filter type" error.
// After a successful filtered load the adapter reports itself as filtered.
func (p *SqlAdapter) LoadFilteredPolicy(model model.Model, filterPtr interface{}) error {
	if filterPtr == nil {
		return p.LoadPolicy(model)
	}

	filter, ok := filterPtr.(*SqlFilter)
	if !ok {
		return errors.New("invalid filter type")
	}
	// A typed nil pointer passes the assertion above but carries no filter;
	// treat it like an absent filter rather than dereferencing it.
	if filter == nil {
		return p.LoadPolicy(model)
	}

	lines, err := p.SelectWhereIn(filter)
	if err != nil {
		return err
	}

	for _, line := range lines {
		p.loadPolicyLine(line, model)
	}

	p.isFiltered = true
	return nil
}
// IsFiltered reports whether the adapter's isFiltered flag is set.
// The flag is set by LoadFilteredPolicy after it has loaded a filtered policy.
func (p *SqlAdapter) IsFiltered() bool {
return p.isFiltered
}
// UpdatePolicy replaces the stored rule oldRule with newPolicy.
// The statement receives the seven new values followed by the seven old
// values, matching the placeholder order of SqlUpdateRow. sec is not used.
// This is part of the Auto-Save feature.
func (p *SqlAdapter) UpdatePolicy(sec, ptype string, oldRule, newPolicy []string) error {
	previous := p.GenArgs(ptype, oldRule)
	params := append(p.GenArgs(ptype, newPolicy), previous...)

	if _, err := p.db.ExecContext(p.ctx, p.SqlUpdateRow, params...); err != nil {
		return err
	}
	return nil
}
// UpdatePolicies replaces oldRules[i] with newRules[i] for every i, all in one
// transaction. Both slices must have the same length. sec is not used.
func (p *SqlAdapter) UpdatePolicies(sec, ptype string, oldRules, newRules [][]string) (err error) {
	if len(newRules) != len(oldRules) {
		return errors.New("old rules size not equal to new rules size")
	}

	rows := make([][]interface{}, len(oldRules))
	for i, oldRule := range oldRules {
		previous := p.GenArgs(ptype, oldRule)
		updated := p.GenArgs(ptype, newRules[i])
		// New values first, then the old values used in the WHERE clause.
		rows[i] = append(updated, previous...)
	}

	return p.execTxSqlRows(p.SqlUpdateRow, rows)
}
// UpdateFilteredPolicies replaces the rows of ptype matching the field filter
// with newPolicies and returns the rows that were matched.
//
// fieldValues[i] is compared with column v<fieldIndex+i>; empty values and
// positions outside v0..v5 are skipped. The matching rows are selected first,
// then deleted and the new rows inserted inside one transaction. Each returned
// old policy has seven elements: p_type followed by v0..v5. sec is not used.
//
// On failure the transaction is rolled back and the failing step's error is
// returned. If the rollback fails as well, the returned error names the failed
// step, wraps the original error (so errors.Is and errors.As still see it) and
// includes the rollback error.
func (p *SqlAdapter) UpdateFilteredPolicies(sec, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) (oldPolicies [][]string, err error) {
	// Build the " AND vN=?" suffix shared by the select and delete statements.
	var whereBuf bytes.Buffer
	whereBuf.Grow(32)

	whereArgs := make([]interface{}, 0, 4)
	whereArgs = append(whereArgs, ptype)

	for i, fieldValue := range fieldValues {
		col := fieldIndex + i
		if col < 0 || col >= 6 || fieldValue == "" {
			continue
		}
		whereBuf.WriteString(" AND v")
		whereBuf.WriteString(strconv.Itoa(col))
		whereBuf.WriteString("=?")
		whereArgs = append(whereArgs, fieldValue)
	}

	var selectBuf bytes.Buffer
	selectBuf.Grow(64)
	selectBuf.WriteString(p.SqlSelectWhere)
	selectBuf.WriteString("p_type=?")
	selectBuf.Write(whereBuf.Bytes())

	// SelectRows rebinds the placeholders itself when arguments are passed,
	// so the query is handed over in its "?" form.
	oldRows, err := p.SelectRows(selectBuf.String(), whereArgs...)
	if err != nil {
		return nil, err
	}

	var deleteBuf bytes.Buffer
	deleteBuf.Grow(64)
	deleteBuf.WriteString(p.SqlDeleteByArgs)
	deleteBuf.Write(whereBuf.Bytes())
	deleteQuery := p.db.Rebind(deleteBuf.String())

	tx, err := p.db.BeginTxx(p.ctx, nil)
	if err != nil {
		return nil, err
	}

	// action names the step that failed; it is only used in the combined
	// error produced when the rollback fails too.
	var action string
	defer func() {
		if err == nil {
			return
		}
		if rbErr := tx.Rollback(); rbErr != nil {
			err = fmt.Errorf("%s err: %w, rollback err: %v", action, err, rbErr)
		}
	}()

	if _, err = tx.ExecContext(p.ctx, deleteQuery, whereArgs...); err != nil {
		action = "delete old policies"
		return nil, err
	}

	stmt, err := tx.PreparexContext(p.ctx, p.SqlInsertRow)
	if err != nil {
		action = "preparex context"
		return nil, err
	}

	for _, policy := range newPolicies {
		arg := p.GenArgs(ptype, policy)
		if _, err = stmt.ExecContext(p.ctx, arg...); err != nil {
			action = "stmt exec context"
			return nil, err
		}
	}

	if err = stmt.Close(); err != nil {
		action = "stmt close"
		return nil, err
	}

	if err = tx.Commit(); err != nil {
		action = "commit"
		return nil, err
	}

	oldPolicies = make([][]string, 0, len(oldRows))
	for _, rule := range oldRows {
		oldPolicies = append(oldPolicies, []string{rule.PType, rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
	}
	return oldPolicies, nil
}
// loadPolicyLine turns one stored row into a comma-separated policy line and
// hands it to persist.LoadPolicyLine. The line is p_type followed by every
// non-empty value of v0..v5; empty values are left out. A nil line is ignored.
func (SqlAdapter) loadPolicyLine(line *SqlCasbinRule, model model.Model) {
	if line == nil {
		return
	}

	var text bytes.Buffer
	text.Grow(64)
	text.WriteString(line.PType)

	for _, field := range []string{line.V0, line.V1, line.V2, line.V3, line.V4, line.V5} {
		if field == "" {
			continue
		}
		text.WriteByte(',')
		text.WriteString(field)
	}

	persist.LoadPolicyLine(text.String(), model)
}
// GenArgs builds the argument list for one rule row: ptype followed by the
// rule's fields, padded with empty strings up to maxParamLength values.
// rule must not have more than maxParamLength-1 fields; a longer rule causes
// an index-out-of-range panic.
func (SqlAdapter) GenArgs(ptype string, rule []string) []interface{} {
	row := make([]interface{}, maxParamLength)

	// Pad every value slot first, then overwrite the ones the rule provides.
	for i := 1; i < maxParamLength; i++ {
		row[i] = ""
	}

	row[0] = ptype
	for i, field := range rule {
		row[i+1] = field
	}
	return row
}