//
// adapter_sqlx.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
|
||
|
|
||
|
package gcasbin
|
||
|
|
||
|
import (
|
||
|
"bytes"
|
||
|
"context"
|
||
|
"errors"
|
||
|
"fmt"
|
||
|
"strconv"
|
||
|
|
||
|
"github.com/casbin/casbin/v2/model"
|
||
|
"github.com/casbin/casbin/v2/persist"
|
||
|
"github.com/jmoiron/sqlx"
|
||
|
)
|
||
|
|
||
|
const (
	// defaultTableName is the table name the adapter falls back to when the
	// caller passes an empty tableName.
	defaultTableName = "casbin_rule"

	// maxParamLength is the number of values bound for one policy row:
	// p_type followed by v0 through v5.
	maxParamLength = 7
)
|
||
|
|
||
|
// Generic SQL templates, used for any driver without a dedicated variant.
// %s / %[1]s is replaced with the table name; "?" marks a bind parameter.
const (
	// sqlCreateTable creates the policy table and its lookup index.
	sqlCreateTable = `
CREATE TABLE %[1]s(
    p_type VARCHAR(32),
    v0 VARCHAR(255),
    v1 VARCHAR(255),
    v2 VARCHAR(255),
    v3 VARCHAR(255),
    v4 VARCHAR(255),
    v5 VARCHAR(255)
);
CREATE INDEX idx_%[1]s ON %[1]s (p_type,v0,v1);`
	sqlTruncateTable = "TRUNCATE TABLE %s"
	// sqlIsTableExist is a probe: it only succeeds when the table exists.
	sqlIsTableExist = "SELECT 1 FROM %s"
	sqlInsertRow    = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (?,?,?,?,?,?,?)"
	// sqlUpdateRow binds the seven new values first, then the seven old ones.
	sqlUpdateRow = "UPDATE %s SET p_type=?,v0=?,v1=?,v2=?,v3=?,v4=?,v5=? WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
	sqlDeleteAll = "DELETE FROM %s"
	sqlDeleteRow = "DELETE FROM %s WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
	// sqlDeleteByArgs is a prefix; callers append " AND vN=?" conditions.
	sqlDeleteByArgs = "DELETE FROM %s WHERE p_type=?"
	sqlSelectAll    = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s"
	// sqlSelectWhere is a prefix; callers append the conditions after WHERE.
	sqlSelectWhere = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s WHERE "
)
|
||
|
|
||
|
// SQL templates for the "sqlite3" driver.
const (
	// sqlCreateTableSqlite3 creates the policy table and index if missing.
	// The CHECK clauses enforce text type and length per column. The type
	// name is written as the single-quoted string 'text'; a double-quoted
	// "text" is an identifier in SQL and only works as a string through
	// SQLite's legacy double-quoted-literal fallback.
	sqlCreateTableSqlite3 = `
CREATE TABLE IF NOT EXISTS %[1]s(
    p_type VARCHAR(32) DEFAULT '' NOT NULL,
    v0 VARCHAR(255) DEFAULT '' NOT NULL,
    v1 VARCHAR(255) DEFAULT '' NOT NULL,
    v2 VARCHAR(255) DEFAULT '' NOT NULL,
    v3 VARCHAR(255) DEFAULT '' NOT NULL,
    v4 VARCHAR(255) DEFAULT '' NOT NULL,
    v5 VARCHAR(255) DEFAULT '' NOT NULL,
    CHECK (TYPEOF("p_type") = 'text' AND
           LENGTH("p_type") <= 32),
    CHECK (TYPEOF("v0") = 'text' AND
           LENGTH("v0") <= 255),
    CHECK (TYPEOF("v1") = 'text' AND
           LENGTH("v1") <= 255),
    CHECK (TYPEOF("v2") = 'text' AND
           LENGTH("v2") <= 255),
    CHECK (TYPEOF("v3") = 'text' AND
           LENGTH("v3") <= 255),
    CHECK (TYPEOF("v4") = 'text' AND
           LENGTH("v4") <= 255),
    CHECK (TYPEOF("v5") = 'text' AND
           LENGTH("v5") <= 255)
);
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
	// sqlTruncateTableSqlite3 empties the table by dropping and re-creating
	// it, replacing the generic TRUNCATE statement for this driver.
	sqlTruncateTableSqlite3 = "DROP TABLE IF EXISTS %[1]s;" + sqlCreateTableSqlite3
)
|
||
|
|
||
|
// SQL templates for the "mysql" driver.
const (
	// sqlCreateTableMysql creates the policy table, with its index declared
	// inline, if it does not exist yet.
	sqlCreateTableMysql = `
CREATE TABLE IF NOT EXISTS %[1]s(
    p_type VARCHAR(32) DEFAULT '' NOT NULL,
    v0 VARCHAR(255) DEFAULT '' NOT NULL,
    v1 VARCHAR(255) DEFAULT '' NOT NULL,
    v2 VARCHAR(255) DEFAULT '' NOT NULL,
    v3 VARCHAR(255) DEFAULT '' NOT NULL,
    v4 VARCHAR(255) DEFAULT '' NOT NULL,
    v5 VARCHAR(255) DEFAULT '' NOT NULL,
    INDEX idx_%[1]s (p_type,v0,v1)
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;`
)
|
||
|
|
||
|
// SQL templates for the PostgreSQL drivers. Row statements use numbered
// "$n" placeholders instead of "?".
const (
	// sqlCreateTablePostgres creates the policy table and index if missing.
	sqlCreateTablePostgres = `
CREATE TABLE IF NOT EXISTS %[1]s(
    p_type VARCHAR(32) DEFAULT '' NOT NULL,
    v0 VARCHAR(255) DEFAULT '' NOT NULL,
    v1 VARCHAR(255) DEFAULT '' NOT NULL,
    v2 VARCHAR(255) DEFAULT '' NOT NULL,
    v3 VARCHAR(255) DEFAULT '' NOT NULL,
    v4 VARCHAR(255) DEFAULT '' NOT NULL,
    v5 VARCHAR(255) DEFAULT '' NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
	sqlInsertRowPostgres = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES ($1,$2,$3,$4,$5,$6,$7)"
	// sqlUpdateRowPostgres binds the new row as $1..$7 and the old row as $8..$14.
	sqlUpdateRowPostgres = "UPDATE %s SET p_type=$1,v0=$2,v1=$3,v2=$4,v3=$5,v4=$6,v5=$7 WHERE p_type=$8 AND v0=$9 AND v1=$10 AND v2=$11 AND v3=$12 AND v4=$13 AND v5=$14"
	sqlDeleteRowPostgres = "DELETE FROM %s WHERE p_type=$1 AND v0=$2 AND v1=$3 AND v2=$4 AND v3=$5 AND v4=$6 AND v5=$7"
)
|
||
|
|
||
|
// SQL templates for the "sqlserver" driver. Row statements use "@pN"
// placeholders instead of "?".
const (
	// sqlCreateTableSqlserver creates the policy table and its index. It has
	// no "IF NOT EXISTS" guard, so it fails when the table already exists.
	sqlCreateTableSqlserver = `
CREATE TABLE %[1]s(
    p_type NVARCHAR(32) DEFAULT '' NOT NULL,
    v0 NVARCHAR(255) DEFAULT '' NOT NULL,
    v1 NVARCHAR(255) DEFAULT '' NOT NULL,
    v2 NVARCHAR(255) DEFAULT '' NOT NULL,
    v3 NVARCHAR(255) DEFAULT '' NOT NULL,
    v4 NVARCHAR(255) DEFAULT '' NOT NULL,
    v5 NVARCHAR(255) DEFAULT '' NOT NULL
);
CREATE INDEX idx_%[1]s ON %[1]s (p_type, v0, v1);`
	sqlInsertRowSqlserver = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (@p1,@p2,@p3,@p4,@p5,@p6,@p7)"
	// sqlUpdateRowSqlserver binds the new row as @p1..@p7 and the old row as @p8..@p14.
	sqlUpdateRowSqlserver = "UPDATE %s SET p_type=@p1,v0=@p2,v1=@p3,v2=@p4,v3=@p5,v4=@p6,v5=@p7 WHERE p_type=@p8 AND v0=@p9 AND v1=@p10 AND v2=@p11 AND v3=@p12 AND v4=@p13 AND v5=@p14"
	sqlDeleteRowSqlserver = "DELETE FROM %s WHERE p_type=@p1 AND v0=@p2 AND v1=@p3 AND v2=@p4 AND v3=@p5 AND v4=@p6 AND v5=@p7"
)
|
||
|
|
||
|
// SqlCasbinRule is one row of the policy table. The db tags map each field
// to the column of the same name when rows are scanned through sqlx.
type SqlCasbinRule struct {
	// PType is the policy type column (p_type).
	PType string `db:"p_type"`

	// V0 through V5 are the six value columns of the rule.
	V0 string `db:"v0"`
	V1 string `db:"v1"`
	V2 string `db:"v2"`
	V3 string `db:"v3"`
	V4 string `db:"v4"`
	V5 string `db:"v5"`
}
|
||
|
|
||
|
// Adapter define the sqlx adapter for Casbin.
|
||
|
// It can load policy lines or save policy lines from sqlx connected database.
|
||
|
type SqlAdapter struct {
|
||
|
db *sqlx.DB
|
||
|
ctx context.Context
|
||
|
tableName string
|
||
|
|
||
|
isFiltered bool
|
||
|
|
||
|
SqlCreateTable string
|
||
|
SqlTruncateTable string
|
||
|
SqlIsTableExist string
|
||
|
SqlInsertRow string
|
||
|
SqlUpdateRow string
|
||
|
SqlDeleteAll string
|
||
|
SqlDeleteRow string
|
||
|
SqlDeleteByArgs string
|
||
|
SqlSelectAll string
|
||
|
SqlSelectWhere string
|
||
|
}
|
||
|
|
||
|
// SqlFilter describes which rows a filtered load should select. Each field
// lists the accepted values for one column: an empty list puts no constraint
// on that column, a single value is matched with "=", several with "IN".
type SqlFilter struct {
	// PType lists the accepted p_type values.
	PType []string

	// V0 through V5 list the accepted values of the matching vN column.
	V0 []string
	V1 []string
	V2 []string
	V3 []string
	V4 []string
	V5 []string
}
|
||
|
|
||
|
// NewAdapter the constructor for Adapter.
|
||
|
// db should connected to database and controlled by user.
|
||
|
// If tableName == "", the Adapter will automatically create a table named "casbin_rule".
|
||
|
func NewSqlAdapter(db *sqlx.DB, tableName string) (*SqlAdapter, error) {
|
||
|
return NewSqlAdapterContext(context.Background(), db, tableName)
|
||
|
}
|
||
|
|
||
|
// NewAdapterContext the constructor for Adapter.
|
||
|
// db should connected to database and controlled by user.
|
||
|
// If tableName == "", the Adapter will automatically create a table named "casbin_rule".
|
||
|
func NewSqlAdapterContext(ctx context.Context, db *sqlx.DB, tableName string) (*SqlAdapter, error) {
|
||
|
if db == nil {
|
||
|
return nil, errors.New("db is nil")
|
||
|
}
|
||
|
|
||
|
// check db connecting
|
||
|
err := db.PingContext(ctx)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
|
||
|
switch db.DriverName() {
|
||
|
case "oci8", "ora", "goracle":
|
||
|
return nil, errors.New("sqlxadapter: please checkout 'oracle' branch")
|
||
|
}
|
||
|
|
||
|
if tableName == "" {
|
||
|
tableName = defaultTableName
|
||
|
}
|
||
|
|
||
|
adapter := SqlAdapter{
|
||
|
db: db,
|
||
|
ctx: ctx,
|
||
|
tableName: tableName,
|
||
|
}
|
||
|
|
||
|
// generate different databases sql
|
||
|
adapter.genSQL()
|
||
|
|
||
|
if !adapter.IsTableExist() {
|
||
|
if err = adapter.CreateTable(); err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return &adapter, nil
|
||
|
}
|
||
|
|
||
|
// genSQL generate sql based on db driver name.
|
||
|
func (p *SqlAdapter) genSQL() {
|
||
|
p.SqlCreateTable = fmt.Sprintf(sqlCreateTable, p.tableName)
|
||
|
p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTable, p.tableName)
|
||
|
|
||
|
p.SqlIsTableExist = fmt.Sprintf(sqlIsTableExist, p.tableName)
|
||
|
|
||
|
p.SqlInsertRow = fmt.Sprintf(sqlInsertRow, p.tableName)
|
||
|
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRow, p.tableName)
|
||
|
p.SqlDeleteAll = fmt.Sprintf(sqlDeleteAll, p.tableName)
|
||
|
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRow, p.tableName)
|
||
|
p.SqlDeleteByArgs = fmt.Sprintf(sqlDeleteByArgs, p.tableName)
|
||
|
|
||
|
p.SqlSelectAll = fmt.Sprintf(sqlSelectAll, p.tableName)
|
||
|
p.SqlSelectWhere = fmt.Sprintf(sqlSelectWhere, p.tableName)
|
||
|
|
||
|
switch p.db.DriverName() {
|
||
|
case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres":
|
||
|
p.SqlCreateTable = fmt.Sprintf(sqlCreateTablePostgres, p.tableName)
|
||
|
p.SqlInsertRow = fmt.Sprintf(sqlInsertRowPostgres, p.tableName)
|
||
|
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowPostgres, p.tableName)
|
||
|
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowPostgres, p.tableName)
|
||
|
case "mysql":
|
||
|
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableMysql, p.tableName)
|
||
|
case "sqlite3":
|
||
|
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlite3, p.tableName)
|
||
|
p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTableSqlite3, p.tableName)
|
||
|
case "sqlserver":
|
||
|
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlserver, p.tableName)
|
||
|
p.SqlInsertRow = fmt.Sprintf(sqlInsertRowSqlserver, p.tableName)
|
||
|
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowSqlserver, p.tableName)
|
||
|
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowSqlserver, p.tableName)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// createTable create a not exists table.
|
||
|
func (p *SqlAdapter) CreateTable() error {
|
||
|
_, err := p.db.ExecContext(p.ctx, p.SqlCreateTable)
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// truncateTable clear the table.
|
||
|
func (p *SqlAdapter) TruncateTable() error {
|
||
|
_, err := p.db.ExecContext(p.ctx, p.SqlTruncateTable)
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// deleteAll clear the table.
|
||
|
func (p *SqlAdapter) DeleteAll() error {
|
||
|
_, err := p.db.ExecContext(p.ctx, p.SqlDeleteAll)
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// isTableExist check the table exists.
|
||
|
func (p *SqlAdapter) IsTableExist() bool {
|
||
|
_, err := p.db.ExecContext(p.ctx, p.SqlIsTableExist)
|
||
|
|
||
|
return err == nil
|
||
|
}
|
||
|
|
||
|
// deleteRows delete eligible data.
|
||
|
func (p *SqlAdapter) DeleteRows(query string, args ...interface{}) error {
|
||
|
query = p.db.Rebind(query)
|
||
|
|
||
|
_, err := p.db.ExecContext(p.ctx, query, args...)
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// truncateAndInsertRows clear table and insert new rows.
|
||
|
func (p *SqlAdapter) TruncateAndInsertRows(rules [][]interface{}) error {
|
||
|
if err := p.TruncateTable(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return p.execTxSqlRows(p.SqlInsertRow, rules)
|
||
|
}
|
||
|
|
||
|
// deleteAllAndInsertRows clear table and insert new rows.
|
||
|
func (p *SqlAdapter) DeleteAllAndInsertRows(rules [][]interface{}) error {
|
||
|
if err := p.DeleteAll(); err != nil {
|
||
|
return err
|
||
|
}
|
||
|
return p.execTxSqlRows(p.SqlInsertRow, rules)
|
||
|
}
|
||
|
|
||
|
// execTxSqlRows exec sql rows.
|
||
|
func (p *SqlAdapter) execTxSqlRows(query string, rules [][]interface{}) (err error) {
|
||
|
tx, err := p.db.BeginTx(p.ctx, nil)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var action string
|
||
|
|
||
|
stmt, err := tx.PrepareContext(p.ctx, query)
|
||
|
if err != nil {
|
||
|
action = "prepare context"
|
||
|
goto ROLLBACK
|
||
|
}
|
||
|
|
||
|
for _, rule := range rules {
|
||
|
if _, err = stmt.ExecContext(p.ctx, rule...); err != nil {
|
||
|
action = "stmt exec"
|
||
|
goto ROLLBACK
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if err = stmt.Close(); err != nil {
|
||
|
action = "stmt close"
|
||
|
goto ROLLBACK
|
||
|
}
|
||
|
|
||
|
if err = tx.Commit(); err != nil {
|
||
|
action = "commit"
|
||
|
goto ROLLBACK
|
||
|
}
|
||
|
|
||
|
return
|
||
|
|
||
|
ROLLBACK:
|
||
|
|
||
|
if err1 := tx.Rollback(); err1 != nil {
|
||
|
err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1)
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// selectRows select eligible data by args from the table.
|
||
|
func (p *SqlAdapter) SelectRows(query string, args ...interface{}) ([]*SqlCasbinRule, error) {
|
||
|
// make a slice with capacity
|
||
|
lines := make([]*SqlCasbinRule, 0, 64)
|
||
|
|
||
|
if len(args) == 0 {
|
||
|
return lines, p.db.SelectContext(p.ctx, &lines, query)
|
||
|
}
|
||
|
|
||
|
query = p.db.Rebind(query)
|
||
|
|
||
|
return lines, p.db.SelectContext(p.ctx, &lines, query, args...)
|
||
|
}
|
||
|
|
||
|
// selectWhereIn select eligible data by filter from the table.
|
||
|
func (p *SqlAdapter) SelectWhereIn(filter *SqlFilter) (lines []*SqlCasbinRule, err error) {
|
||
|
var sqlBuf bytes.Buffer
|
||
|
|
||
|
sqlBuf.Grow(64)
|
||
|
sqlBuf.WriteString(p.SqlSelectWhere)
|
||
|
|
||
|
args := make([]interface{}, 0, 4)
|
||
|
|
||
|
hasInCond := false
|
||
|
|
||
|
for _, col := range [maxParamLength]struct {
|
||
|
name string
|
||
|
arg []string
|
||
|
}{
|
||
|
{"p_type", filter.PType},
|
||
|
{"v0", filter.V0},
|
||
|
{"v1", filter.V1},
|
||
|
{"v2", filter.V2},
|
||
|
{"v3", filter.V3},
|
||
|
{"v4", filter.V4},
|
||
|
{"v5", filter.V5},
|
||
|
} {
|
||
|
l := len(col.arg)
|
||
|
if l == 0 {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
switch sqlBuf.Bytes()[sqlBuf.Len()-1] {
|
||
|
case '?', ')':
|
||
|
sqlBuf.WriteString(" AND ")
|
||
|
}
|
||
|
|
||
|
sqlBuf.WriteString(col.name)
|
||
|
|
||
|
if l == 1 {
|
||
|
sqlBuf.WriteString("=?")
|
||
|
args = append(args, col.arg[0])
|
||
|
} else {
|
||
|
sqlBuf.WriteString(" IN (?)")
|
||
|
args = append(args, col.arg)
|
||
|
|
||
|
hasInCond = true
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var query string
|
||
|
|
||
|
if hasInCond {
|
||
|
if query, args, err = sqlx.In(sqlBuf.String(), args...); err != nil {
|
||
|
return
|
||
|
}
|
||
|
} else {
|
||
|
query = sqlBuf.String()
|
||
|
}
|
||
|
|
||
|
return p.SelectRows(query, args...)
|
||
|
}
|
||
|
|
||
|
// LoadPolicy load all policy rules from the storage.
|
||
|
func (p *SqlAdapter) LoadPolicy(model model.Model) error {
|
||
|
lines, err := p.SelectRows(p.SqlSelectAll)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for _, line := range lines {
|
||
|
p.loadPolicyLine(line, model)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// SavePolicy save policy rules to the storage.
|
||
|
func (p *SqlAdapter) SavePolicy(model model.Model) error {
|
||
|
args := make([][]interface{}, 0, 64)
|
||
|
|
||
|
for ptype, ast := range model["p"] {
|
||
|
for _, rule := range ast.Policy {
|
||
|
arg := p.GenArgs(ptype, rule)
|
||
|
args = append(args, arg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
for ptype, ast := range model["g"] {
|
||
|
for _, rule := range ast.Policy {
|
||
|
arg := p.GenArgs(ptype, rule)
|
||
|
args = append(args, arg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return p.DeleteAllAndInsertRows(args)
|
||
|
}
|
||
|
|
||
|
// AddPolicy add one policy rule to the storage.
|
||
|
func (p *SqlAdapter) AddPolicy(sec string, ptype string, rule []string) error {
|
||
|
args := p.GenArgs(ptype, rule)
|
||
|
|
||
|
_, err := p.db.ExecContext(p.ctx, p.SqlInsertRow, args...)
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// AddPolicies add multiple policy rules to the storage.
|
||
|
func (p *SqlAdapter) AddPolicies(sec string, ptype string, rules [][]string) error {
|
||
|
args := make([][]interface{}, 0, 8)
|
||
|
|
||
|
for _, rule := range rules {
|
||
|
arg := p.GenArgs(ptype, rule)
|
||
|
args = append(args, arg)
|
||
|
}
|
||
|
|
||
|
return p.execTxSqlRows(p.SqlInsertRow, args)
|
||
|
}
|
||
|
|
||
|
// RemovePolicy remove policy rules from the storage.
|
||
|
func (p *SqlAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
|
||
|
var sqlBuf bytes.Buffer
|
||
|
|
||
|
sqlBuf.Grow(64)
|
||
|
sqlBuf.WriteString(p.SqlDeleteByArgs)
|
||
|
|
||
|
args := make([]interface{}, 0, 4)
|
||
|
args = append(args, ptype)
|
||
|
|
||
|
for idx, arg := range rule {
|
||
|
if arg != "" {
|
||
|
sqlBuf.WriteString(" AND v")
|
||
|
sqlBuf.WriteString(strconv.Itoa(idx))
|
||
|
sqlBuf.WriteString("=?")
|
||
|
|
||
|
args = append(args, arg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return p.DeleteRows(sqlBuf.String(), args...)
|
||
|
}
|
||
|
|
||
|
// RemoveFilteredPolicy remove policy rules that match the filter from the storage.
|
||
|
func (p *SqlAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
|
||
|
var sqlBuf bytes.Buffer
|
||
|
|
||
|
sqlBuf.Grow(64)
|
||
|
sqlBuf.WriteString(p.SqlDeleteByArgs)
|
||
|
|
||
|
args := make([]interface{}, 0, 4)
|
||
|
args = append(args, ptype)
|
||
|
|
||
|
var value string
|
||
|
|
||
|
l := fieldIndex + len(fieldValues)
|
||
|
|
||
|
for idx := 0; idx < 6; idx++ {
|
||
|
if fieldIndex <= idx && idx < l {
|
||
|
value = fieldValues[idx-fieldIndex]
|
||
|
|
||
|
if value != "" {
|
||
|
sqlBuf.WriteString(" AND v")
|
||
|
sqlBuf.WriteString(strconv.Itoa(idx))
|
||
|
sqlBuf.WriteString("=?")
|
||
|
|
||
|
args = append(args, value)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return p.DeleteRows(sqlBuf.String(), args...)
|
||
|
}
|
||
|
|
||
|
// RemovePolicies remove policy rules.
|
||
|
func (p *SqlAdapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) {
|
||
|
args := make([][]interface{}, 0, 8)
|
||
|
|
||
|
for _, rule := range rules {
|
||
|
arg := p.GenArgs(ptype, rule)
|
||
|
args = append(args, arg)
|
||
|
}
|
||
|
|
||
|
return p.execTxSqlRows(p.SqlDeleteRow, args)
|
||
|
}
|
||
|
|
||
|
// LoadFilteredPolicy load policy rules that match the filter.
|
||
|
// filterPtr must be a pointer.
|
||
|
func (p *SqlAdapter) LoadFilteredPolicy(model model.Model, filterPtr interface{}) error {
|
||
|
if filterPtr == nil {
|
||
|
return p.LoadPolicy(model)
|
||
|
}
|
||
|
|
||
|
filter, ok := filterPtr.(*SqlFilter)
|
||
|
if !ok {
|
||
|
return errors.New("invalid filter type")
|
||
|
}
|
||
|
|
||
|
lines, err := p.SelectWhereIn(filter)
|
||
|
if err != nil {
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
for _, line := range lines {
|
||
|
p.loadPolicyLine(line, model)
|
||
|
}
|
||
|
|
||
|
p.isFiltered = true
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// IsFiltered returns true if the loaded policy rules has been filtered.
|
||
|
func (p *SqlAdapter) IsFiltered() bool {
|
||
|
return p.isFiltered
|
||
|
}
|
||
|
|
||
|
// UpdatePolicy update a policy rule from storage.
|
||
|
// This is part of the Auto-Save feature.
|
||
|
func (p *SqlAdapter) UpdatePolicy(sec, ptype string, oldRule, newPolicy []string) error {
|
||
|
oldArg := p.GenArgs(ptype, oldRule)
|
||
|
newArg := p.GenArgs(ptype, newPolicy)
|
||
|
|
||
|
_, err := p.db.ExecContext(p.ctx, p.SqlUpdateRow, append(newArg, oldArg...)...)
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
// UpdatePolicies updates policy rules to storage.
|
||
|
func (p *SqlAdapter) UpdatePolicies(sec, ptype string, oldRules, newRules [][]string) (err error) {
|
||
|
if len(oldRules) != len(newRules) {
|
||
|
return errors.New("old rules size not equal to new rules size")
|
||
|
}
|
||
|
|
||
|
args := make([][]interface{}, 0, 16)
|
||
|
|
||
|
for idx := range oldRules {
|
||
|
oldArg := p.GenArgs(ptype, oldRules[idx])
|
||
|
newArg := p.GenArgs(ptype, newRules[idx])
|
||
|
args = append(args, append(newArg, oldArg...))
|
||
|
}
|
||
|
|
||
|
return p.execTxSqlRows(p.SqlUpdateRow, args)
|
||
|
}
|
||
|
|
||
|
// UpdateFilteredPolicies deletes old rules and adds new rules.
|
||
|
func (p *SqlAdapter) UpdateFilteredPolicies(sec, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) (oldPolicies [][]string, err error) {
|
||
|
var value string
|
||
|
|
||
|
var whereBuf bytes.Buffer
|
||
|
whereBuf.Grow(32)
|
||
|
|
||
|
l := fieldIndex + len(fieldValues)
|
||
|
|
||
|
whereArgs := make([]interface{}, 0, 4)
|
||
|
whereArgs = append(whereArgs, ptype)
|
||
|
|
||
|
for idx := 0; idx < 6; idx++ {
|
||
|
if fieldIndex <= idx && idx < l {
|
||
|
value = fieldValues[idx-fieldIndex]
|
||
|
|
||
|
if value != "" {
|
||
|
whereBuf.WriteString(" AND v")
|
||
|
whereBuf.WriteString(strconv.Itoa(idx))
|
||
|
whereBuf.WriteString("=?")
|
||
|
|
||
|
whereArgs = append(whereArgs, value)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
var selectBuf bytes.Buffer
|
||
|
selectBuf.Grow(64)
|
||
|
selectBuf.WriteString(p.SqlSelectWhere)
|
||
|
selectBuf.WriteString("p_type=?")
|
||
|
selectBuf.Write(whereBuf.Bytes())
|
||
|
|
||
|
var oldRows []*SqlCasbinRule
|
||
|
value = p.db.Rebind(selectBuf.String())
|
||
|
oldRows, err = p.SelectRows(value, whereArgs...)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var deleteBuf bytes.Buffer
|
||
|
deleteBuf.Grow(64)
|
||
|
deleteBuf.WriteString(p.SqlDeleteByArgs)
|
||
|
deleteBuf.Write(whereBuf.Bytes())
|
||
|
|
||
|
var tx *sqlx.Tx
|
||
|
tx, err = p.db.BeginTxx(p.ctx, nil)
|
||
|
if err != nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var (
|
||
|
stmt *sqlx.Stmt
|
||
|
action string
|
||
|
)
|
||
|
value = p.db.Rebind(deleteBuf.String())
|
||
|
if _, err = tx.ExecContext(p.ctx, value, whereArgs...); err != nil {
|
||
|
action = "delete old policies"
|
||
|
goto ROLLBACK
|
||
|
}
|
||
|
|
||
|
stmt, err = tx.PreparexContext(p.ctx, p.SqlInsertRow)
|
||
|
if err != nil {
|
||
|
action = "preparex context"
|
||
|
goto ROLLBACK
|
||
|
}
|
||
|
|
||
|
for _, policy := range newPolicies {
|
||
|
arg := p.GenArgs(ptype, policy)
|
||
|
if _, err = stmt.ExecContext(p.ctx, arg...); err != nil {
|
||
|
action = "stmt exec context"
|
||
|
goto ROLLBACK
|
||
|
}
|
||
|
}
|
||
|
|
||
|
if err = stmt.Close(); err != nil {
|
||
|
action = "stmt close"
|
||
|
goto ROLLBACK
|
||
|
}
|
||
|
|
||
|
if err = tx.Commit(); err != nil {
|
||
|
action = "commit"
|
||
|
goto ROLLBACK
|
||
|
}
|
||
|
|
||
|
oldPolicies = make([][]string, 0, len(oldRows))
|
||
|
for _, rule := range oldRows {
|
||
|
oldPolicies = append(oldPolicies, []string{rule.PType, rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
|
||
|
}
|
||
|
|
||
|
return
|
||
|
|
||
|
ROLLBACK:
|
||
|
|
||
|
if err1 := tx.Rollback(); err1 != nil {
|
||
|
err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1)
|
||
|
}
|
||
|
|
||
|
return
|
||
|
}
|
||
|
|
||
|
// loadPolicyLine load a policy line to model.
|
||
|
func (SqlAdapter) loadPolicyLine(line *SqlCasbinRule, model model.Model) {
|
||
|
if line == nil {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
var lineBuf bytes.Buffer
|
||
|
|
||
|
lineBuf.Grow(64)
|
||
|
lineBuf.WriteString(line.PType)
|
||
|
|
||
|
args := [6]string{line.V0, line.V1, line.V2, line.V3, line.V4, line.V5}
|
||
|
for _, arg := range args {
|
||
|
if arg != "" {
|
||
|
lineBuf.WriteByte(',')
|
||
|
lineBuf.WriteString(arg)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
persist.LoadPolicyLine(lineBuf.String(), model)
|
||
|
}
|
||
|
|
||
|
// genArgs generate args from ptype and rule.
|
||
|
func (SqlAdapter) GenArgs(ptype string, rule []string) []interface{} {
|
||
|
l := len(rule)
|
||
|
|
||
|
args := make([]interface{}, maxParamLength)
|
||
|
args[0] = ptype
|
||
|
|
||
|
for idx := 0; idx < l; idx++ {
|
||
|
args[idx+1] = rule[idx]
|
||
|
}
|
||
|
|
||
|
for idx := l + 1; idx < maxParamLength; idx++ {
|
||
|
args[idx] = ""
|
||
|
}
|
||
|
|
||
|
return args
|
||
|
}
|