chore: 移除 sqlx

This commit is contained in:
tiglog 2023-08-17 17:13:45 +08:00
parent 36f9acb8f7
commit 2822110dec
2 changed files with 0 additions and 1215 deletions

View File

@ -1,749 +0,0 @@
//
// adapter_sqlx.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package gcasbin
import (
"bytes"
"context"
"errors"
"fmt"
"strconv"
"github.com/casbin/casbin/v2/model"
"github.com/casbin/casbin/v2/persist"
"github.com/jmoiron/sqlx"
)
// defaultTableName if tableName == "", the Adapter will use this default table name.
const defaultTableName = "casbin_rule"
// maxParamLength .
const maxParamLength = 7
// general sql
const (
sqlCreateTable = `
CREATE TABLE %[1]s(
p_type VARCHAR(32),
v0 VARCHAR(255),
v1 VARCHAR(255),
v2 VARCHAR(255),
v3 VARCHAR(255),
v4 VARCHAR(255),
v5 VARCHAR(255)
);
CREATE INDEX idx_%[1]s ON %[1]s (p_type,v0,v1);`
sqlTruncateTable = "TRUNCATE TABLE %s"
sqlIsTableExist = "SELECT 1 FROM %s"
sqlInsertRow = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (?,?,?,?,?,?,?)"
sqlUpdateRow = "UPDATE %s SET p_type=?,v0=?,v1=?,v2=?,v3=?,v4=?,v5=? WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
sqlDeleteAll = "DELETE FROM %s"
sqlDeleteRow = "DELETE FROM %s WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
sqlDeleteByArgs = "DELETE FROM %s WHERE p_type=?"
sqlSelectAll = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s"
sqlSelectWhere = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s WHERE "
)
// for Sqlite3
const (
sqlCreateTableSqlite3 = `
CREATE TABLE IF NOT EXISTS %[1]s(
p_type VARCHAR(32) DEFAULT '' NOT NULL,
v0 VARCHAR(255) DEFAULT '' NOT NULL,
v1 VARCHAR(255) DEFAULT '' NOT NULL,
v2 VARCHAR(255) DEFAULT '' NOT NULL,
v3 VARCHAR(255) DEFAULT '' NOT NULL,
v4 VARCHAR(255) DEFAULT '' NOT NULL,
v5 VARCHAR(255) DEFAULT '' NOT NULL,
CHECK (TYPEOF("p_type") = "text" AND
LENGTH("p_type") <= 32),
CHECK (TYPEOF("v0") = "text" AND
LENGTH("v0") <= 255),
CHECK (TYPEOF("v1") = "text" AND
LENGTH("v1") <= 255),
CHECK (TYPEOF("v2") = "text" AND
LENGTH("v2") <= 255),
CHECK (TYPEOF("v3") = "text" AND
LENGTH("v3") <= 255),
CHECK (TYPEOF("v4") = "text" AND
LENGTH("v4") <= 255),
CHECK (TYPEOF("v5") = "text" AND
LENGTH("v5") <= 255)
);
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
sqlTruncateTableSqlite3 = "DROP TABLE IF EXISTS %[1]s;" + sqlCreateTableSqlite3
)
// for Mysql
const (
sqlCreateTableMysql = `
CREATE TABLE IF NOT EXISTS %[1]s(
p_type VARCHAR(32) DEFAULT '' NOT NULL,
v0 VARCHAR(255) DEFAULT '' NOT NULL,
v1 VARCHAR(255) DEFAULT '' NOT NULL,
v2 VARCHAR(255) DEFAULT '' NOT NULL,
v3 VARCHAR(255) DEFAULT '' NOT NULL,
v4 VARCHAR(255) DEFAULT '' NOT NULL,
v5 VARCHAR(255) DEFAULT '' NOT NULL,
INDEX idx_%[1]s (p_type,v0,v1)
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;`
)
// for Postgres
const (
sqlCreateTablePostgres = `
CREATE TABLE IF NOT EXISTS %[1]s(
p_type VARCHAR(32) DEFAULT '' NOT NULL,
v0 VARCHAR(255) DEFAULT '' NOT NULL,
v1 VARCHAR(255) DEFAULT '' NOT NULL,
v2 VARCHAR(255) DEFAULT '' NOT NULL,
v3 VARCHAR(255) DEFAULT '' NOT NULL,
v4 VARCHAR(255) DEFAULT '' NOT NULL,
v5 VARCHAR(255) DEFAULT '' NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
sqlInsertRowPostgres = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES ($1,$2,$3,$4,$5,$6,$7)"
sqlUpdateRowPostgres = "UPDATE %s SET p_type=$1,v0=$2,v1=$3,v2=$4,v3=$5,v4=$6,v5=$7 WHERE p_type=$8 AND v0=$9 AND v1=$10 AND v2=$11 AND v3=$12 AND v4=$13 AND v5=$14"
sqlDeleteRowPostgres = "DELETE FROM %s WHERE p_type=$1 AND v0=$2 AND v1=$3 AND v2=$4 AND v3=$5 AND v4=$6 AND v5=$7"
)
// for Sqlserver
const (
sqlCreateTableSqlserver = `
CREATE TABLE %[1]s(
p_type NVARCHAR(32) DEFAULT '' NOT NULL,
v0 NVARCHAR(255) DEFAULT '' NOT NULL,
v1 NVARCHAR(255) DEFAULT '' NOT NULL,
v2 NVARCHAR(255) DEFAULT '' NOT NULL,
v3 NVARCHAR(255) DEFAULT '' NOT NULL,
v4 NVARCHAR(255) DEFAULT '' NOT NULL,
v5 NVARCHAR(255) DEFAULT '' NOT NULL
);
CREATE INDEX idx_%[1]s ON %[1]s (p_type, v0, v1);`
sqlInsertRowSqlserver = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (@p1,@p2,@p3,@p4,@p5,@p6,@p7)"
sqlUpdateRowSqlserver = "UPDATE %s SET p_type=@p1,v0=@p2,v1=@p3,v2=@p4,v3=@p5,v4=@p6,v5=@p7 WHERE p_type=@p8 AND v0=@p9 AND v1=@p10 AND v2=@p11 AND v3=@p12 AND v4=@p13 AND v5=@p14"
sqlDeleteRowSqlserver = "DELETE FROM %s WHERE p_type=@p1 AND v0=@p2 AND v1=@p3 AND v2=@p4 AND v3=@p5 AND v4=@p6 AND v5=@p7"
)
// CasbinRule defines the casbin rule model.
// It used for save or load policy lines from sqlx connected database.
type SqlCasbinRule struct {
PType string `db:"p_type"`
V0 string `db:"v0"`
V1 string `db:"v1"`
V2 string `db:"v2"`
V3 string `db:"v3"`
V4 string `db:"v4"`
V5 string `db:"v5"`
}
// Adapter define the sqlx adapter for Casbin.
// It can load policy lines or save policy lines from sqlx connected database.
type SqlAdapter struct {
db *sqlx.DB
ctx context.Context
tableName string
isFiltered bool
SqlCreateTable string
SqlTruncateTable string
SqlIsTableExist string
SqlInsertRow string
SqlUpdateRow string
SqlDeleteAll string
SqlDeleteRow string
SqlDeleteByArgs string
SqlSelectAll string
SqlSelectWhere string
}
// Filter defines the filtering rules for a FilteredAdapter's policy.
// Empty values are ignored, but all others must match the filter.
type SqlFilter struct {
PType []string
V0 []string
V1 []string
V2 []string
V3 []string
V4 []string
V5 []string
}
// NewAdapter the constructor for Adapter.
// db should connected to database and controlled by user.
// If tableName == "", the Adapter will automatically create a table named "casbin_rule".
func NewSqlAdapter(db *sqlx.DB, tableName string) (*SqlAdapter, error) {
return NewSqlAdapterContext(context.Background(), db, tableName)
}
// NewAdapterContext the constructor for Adapter.
// db should connected to database and controlled by user.
// If tableName == "", the Adapter will automatically create a table named "casbin_rule".
func NewSqlAdapterContext(ctx context.Context, db *sqlx.DB, tableName string) (*SqlAdapter, error) {
if db == nil {
return nil, errors.New("db is nil")
}
// check db connecting
err := db.PingContext(ctx)
if err != nil {
return nil, err
}
switch db.DriverName() {
case "oci8", "ora", "goracle":
return nil, errors.New("sqlxadapter: please checkout 'oracle' branch")
}
if tableName == "" {
tableName = defaultTableName
}
adapter := SqlAdapter{
db: db,
ctx: ctx,
tableName: tableName,
}
// generate different databases sql
adapter.genSQL()
if !adapter.IsTableExist() {
if err = adapter.CreateTable(); err != nil {
return nil, err
}
}
return &adapter, nil
}
// genSQL generate sql based on db driver name.
func (p *SqlAdapter) genSQL() {
p.SqlCreateTable = fmt.Sprintf(sqlCreateTable, p.tableName)
p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTable, p.tableName)
p.SqlIsTableExist = fmt.Sprintf(sqlIsTableExist, p.tableName)
p.SqlInsertRow = fmt.Sprintf(sqlInsertRow, p.tableName)
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRow, p.tableName)
p.SqlDeleteAll = fmt.Sprintf(sqlDeleteAll, p.tableName)
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRow, p.tableName)
p.SqlDeleteByArgs = fmt.Sprintf(sqlDeleteByArgs, p.tableName)
p.SqlSelectAll = fmt.Sprintf(sqlSelectAll, p.tableName)
p.SqlSelectWhere = fmt.Sprintf(sqlSelectWhere, p.tableName)
switch p.db.DriverName() {
case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres":
p.SqlCreateTable = fmt.Sprintf(sqlCreateTablePostgres, p.tableName)
p.SqlInsertRow = fmt.Sprintf(sqlInsertRowPostgres, p.tableName)
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowPostgres, p.tableName)
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowPostgres, p.tableName)
case "mysql":
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableMysql, p.tableName)
case "sqlite3":
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlite3, p.tableName)
p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTableSqlite3, p.tableName)
case "sqlserver":
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlserver, p.tableName)
p.SqlInsertRow = fmt.Sprintf(sqlInsertRowSqlserver, p.tableName)
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowSqlserver, p.tableName)
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowSqlserver, p.tableName)
}
}
// createTable create a not exists table.
func (p *SqlAdapter) CreateTable() error {
_, err := p.db.ExecContext(p.ctx, p.SqlCreateTable)
return err
}
// truncateTable clear the table.
func (p *SqlAdapter) TruncateTable() error {
_, err := p.db.ExecContext(p.ctx, p.SqlTruncateTable)
return err
}
// deleteAll clear the table.
func (p *SqlAdapter) DeleteAll() error {
_, err := p.db.ExecContext(p.ctx, p.SqlDeleteAll)
return err
}
// isTableExist check the table exists.
func (p *SqlAdapter) IsTableExist() bool {
_, err := p.db.ExecContext(p.ctx, p.SqlIsTableExist)
return err == nil
}
// deleteRows delete eligible data.
func (p *SqlAdapter) DeleteRows(query string, args ...interface{}) error {
query = p.db.Rebind(query)
_, err := p.db.ExecContext(p.ctx, query, args...)
return err
}
// truncateAndInsertRows clear table and insert new rows.
func (p *SqlAdapter) TruncateAndInsertRows(rules [][]interface{}) error {
if err := p.TruncateTable(); err != nil {
return err
}
return p.execTxSqlRows(p.SqlInsertRow, rules)
}
// deleteAllAndInsertRows clear table and insert new rows.
func (p *SqlAdapter) DeleteAllAndInsertRows(rules [][]interface{}) error {
if err := p.DeleteAll(); err != nil {
return err
}
return p.execTxSqlRows(p.SqlInsertRow, rules)
}
// execTxSqlRows exec sql rows.
func (p *SqlAdapter) execTxSqlRows(query string, rules [][]interface{}) (err error) {
tx, err := p.db.BeginTx(p.ctx, nil)
if err != nil {
return
}
var action string
stmt, err := tx.PrepareContext(p.ctx, query)
if err != nil {
action = "prepare context"
goto ROLLBACK
}
for _, rule := range rules {
if _, err = stmt.ExecContext(p.ctx, rule...); err != nil {
action = "stmt exec"
goto ROLLBACK
}
}
if err = stmt.Close(); err != nil {
action = "stmt close"
goto ROLLBACK
}
if err = tx.Commit(); err != nil {
action = "commit"
goto ROLLBACK
}
return
ROLLBACK:
if err1 := tx.Rollback(); err1 != nil {
err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1)
}
return
}
// selectRows select eligible data by args from the table.
func (p *SqlAdapter) SelectRows(query string, args ...interface{}) ([]*SqlCasbinRule, error) {
// make a slice with capacity
lines := make([]*SqlCasbinRule, 0, 64)
if len(args) == 0 {
return lines, p.db.SelectContext(p.ctx, &lines, query)
}
query = p.db.Rebind(query)
return lines, p.db.SelectContext(p.ctx, &lines, query, args...)
}
// selectWhereIn select eligible data by filter from the table.
func (p *SqlAdapter) SelectWhereIn(filter *SqlFilter) (lines []*SqlCasbinRule, err error) {
var sqlBuf bytes.Buffer
sqlBuf.Grow(64)
sqlBuf.WriteString(p.SqlSelectWhere)
args := make([]interface{}, 0, 4)
hasInCond := false
for _, col := range [maxParamLength]struct {
name string
arg []string
}{
{"p_type", filter.PType},
{"v0", filter.V0},
{"v1", filter.V1},
{"v2", filter.V2},
{"v3", filter.V3},
{"v4", filter.V4},
{"v5", filter.V5},
} {
l := len(col.arg)
if l == 0 {
continue
}
switch sqlBuf.Bytes()[sqlBuf.Len()-1] {
case '?', ')':
sqlBuf.WriteString(" AND ")
}
sqlBuf.WriteString(col.name)
if l == 1 {
sqlBuf.WriteString("=?")
args = append(args, col.arg[0])
} else {
sqlBuf.WriteString(" IN (?)")
args = append(args, col.arg)
hasInCond = true
}
}
var query string
if hasInCond {
if query, args, err = sqlx.In(sqlBuf.String(), args...); err != nil {
return
}
} else {
query = sqlBuf.String()
}
return p.SelectRows(query, args...)
}
// LoadPolicy load all policy rules from the storage.
func (p *SqlAdapter) LoadPolicy(model model.Model) error {
lines, err := p.SelectRows(p.SqlSelectAll)
if err != nil {
return err
}
for _, line := range lines {
p.loadPolicyLine(line, model)
}
return nil
}
// SavePolicy save policy rules to the storage.
func (p *SqlAdapter) SavePolicy(model model.Model) error {
args := make([][]interface{}, 0, 64)
for ptype, ast := range model["p"] {
for _, rule := range ast.Policy {
arg := p.GenArgs(ptype, rule)
args = append(args, arg)
}
}
for ptype, ast := range model["g"] {
for _, rule := range ast.Policy {
arg := p.GenArgs(ptype, rule)
args = append(args, arg)
}
}
return p.DeleteAllAndInsertRows(args)
}
// AddPolicy add one policy rule to the storage.
func (p *SqlAdapter) AddPolicy(sec string, ptype string, rule []string) error {
args := p.GenArgs(ptype, rule)
_, err := p.db.ExecContext(p.ctx, p.SqlInsertRow, args...)
return err
}
// AddPolicies add multiple policy rules to the storage.
func (p *SqlAdapter) AddPolicies(sec string, ptype string, rules [][]string) error {
args := make([][]interface{}, 0, 8)
for _, rule := range rules {
arg := p.GenArgs(ptype, rule)
args = append(args, arg)
}
return p.execTxSqlRows(p.SqlInsertRow, args)
}
// RemovePolicy remove policy rules from the storage.
func (p *SqlAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
var sqlBuf bytes.Buffer
sqlBuf.Grow(64)
sqlBuf.WriteString(p.SqlDeleteByArgs)
args := make([]interface{}, 0, 4)
args = append(args, ptype)
for idx, arg := range rule {
if arg != "" {
sqlBuf.WriteString(" AND v")
sqlBuf.WriteString(strconv.Itoa(idx))
sqlBuf.WriteString("=?")
args = append(args, arg)
}
}
return p.DeleteRows(sqlBuf.String(), args...)
}
// RemoveFilteredPolicy remove policy rules that match the filter from the storage.
func (p *SqlAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
var sqlBuf bytes.Buffer
sqlBuf.Grow(64)
sqlBuf.WriteString(p.SqlDeleteByArgs)
args := make([]interface{}, 0, 4)
args = append(args, ptype)
var value string
l := fieldIndex + len(fieldValues)
for idx := 0; idx < 6; idx++ {
if fieldIndex <= idx && idx < l {
value = fieldValues[idx-fieldIndex]
if value != "" {
sqlBuf.WriteString(" AND v")
sqlBuf.WriteString(strconv.Itoa(idx))
sqlBuf.WriteString("=?")
args = append(args, value)
}
}
}
return p.DeleteRows(sqlBuf.String(), args...)
}
// RemovePolicies remove policy rules.
func (p *SqlAdapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) {
args := make([][]interface{}, 0, 8)
for _, rule := range rules {
arg := p.GenArgs(ptype, rule)
args = append(args, arg)
}
return p.execTxSqlRows(p.SqlDeleteRow, args)
}
// LoadFilteredPolicy load policy rules that match the filter.
// filterPtr must be a pointer.
func (p *SqlAdapter) LoadFilteredPolicy(model model.Model, filterPtr interface{}) error {
if filterPtr == nil {
return p.LoadPolicy(model)
}
filter, ok := filterPtr.(*SqlFilter)
if !ok {
return errors.New("invalid filter type")
}
lines, err := p.SelectWhereIn(filter)
if err != nil {
return err
}
for _, line := range lines {
p.loadPolicyLine(line, model)
}
p.isFiltered = true
return nil
}
// IsFiltered returns true if the loaded policy rules has been filtered.
func (p *SqlAdapter) IsFiltered() bool {
return p.isFiltered
}
// UpdatePolicy update a policy rule from storage.
// This is part of the Auto-Save feature.
func (p *SqlAdapter) UpdatePolicy(sec, ptype string, oldRule, newPolicy []string) error {
oldArg := p.GenArgs(ptype, oldRule)
newArg := p.GenArgs(ptype, newPolicy)
_, err := p.db.ExecContext(p.ctx, p.SqlUpdateRow, append(newArg, oldArg...)...)
return err
}
// UpdatePolicies updates policy rules to storage.
func (p *SqlAdapter) UpdatePolicies(sec, ptype string, oldRules, newRules [][]string) (err error) {
if len(oldRules) != len(newRules) {
return errors.New("old rules size not equal to new rules size")
}
args := make([][]interface{}, 0, 16)
for idx := range oldRules {
oldArg := p.GenArgs(ptype, oldRules[idx])
newArg := p.GenArgs(ptype, newRules[idx])
args = append(args, append(newArg, oldArg...))
}
return p.execTxSqlRows(p.SqlUpdateRow, args)
}
// UpdateFilteredPolicies deletes old rules and adds new rules.
func (p *SqlAdapter) UpdateFilteredPolicies(sec, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) (oldPolicies [][]string, err error) {
var value string
var whereBuf bytes.Buffer
whereBuf.Grow(32)
l := fieldIndex + len(fieldValues)
whereArgs := make([]interface{}, 0, 4)
whereArgs = append(whereArgs, ptype)
for idx := 0; idx < 6; idx++ {
if fieldIndex <= idx && idx < l {
value = fieldValues[idx-fieldIndex]
if value != "" {
whereBuf.WriteString(" AND v")
whereBuf.WriteString(strconv.Itoa(idx))
whereBuf.WriteString("=?")
whereArgs = append(whereArgs, value)
}
}
}
var selectBuf bytes.Buffer
selectBuf.Grow(64)
selectBuf.WriteString(p.SqlSelectWhere)
selectBuf.WriteString("p_type=?")
selectBuf.Write(whereBuf.Bytes())
var oldRows []*SqlCasbinRule
value = p.db.Rebind(selectBuf.String())
oldRows, err = p.SelectRows(value, whereArgs...)
if err != nil {
return
}
var deleteBuf bytes.Buffer
deleteBuf.Grow(64)
deleteBuf.WriteString(p.SqlDeleteByArgs)
deleteBuf.Write(whereBuf.Bytes())
var tx *sqlx.Tx
tx, err = p.db.BeginTxx(p.ctx, nil)
if err != nil {
return
}
var (
stmt *sqlx.Stmt
action string
)
value = p.db.Rebind(deleteBuf.String())
if _, err = tx.ExecContext(p.ctx, value, whereArgs...); err != nil {
action = "delete old policies"
goto ROLLBACK
}
stmt, err = tx.PreparexContext(p.ctx, p.SqlInsertRow)
if err != nil {
action = "preparex context"
goto ROLLBACK
}
for _, policy := range newPolicies {
arg := p.GenArgs(ptype, policy)
if _, err = stmt.ExecContext(p.ctx, arg...); err != nil {
action = "stmt exec context"
goto ROLLBACK
}
}
if err = stmt.Close(); err != nil {
action = "stmt close"
goto ROLLBACK
}
if err = tx.Commit(); err != nil {
action = "commit"
goto ROLLBACK
}
oldPolicies = make([][]string, 0, len(oldRows))
for _, rule := range oldRows {
oldPolicies = append(oldPolicies, []string{rule.PType, rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
}
return
ROLLBACK:
if err1 := tx.Rollback(); err1 != nil {
err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1)
}
return
}
// loadPolicyLine load a policy line to model.
func (SqlAdapter) loadPolicyLine(line *SqlCasbinRule, model model.Model) {
if line == nil {
return
}
var lineBuf bytes.Buffer
lineBuf.Grow(64)
lineBuf.WriteString(line.PType)
args := [6]string{line.V0, line.V1, line.V2, line.V3, line.V4, line.V5}
for _, arg := range args {
if arg != "" {
lineBuf.WriteByte(',')
lineBuf.WriteString(arg)
}
}
persist.LoadPolicyLine(lineBuf.String(), model)
}
// genArgs generate args from ptype and rule.
func (SqlAdapter) GenArgs(ptype string, rule []string) []interface{} {
l := len(rule)
args := make([]interface{}, maxParamLength)
args[0] = ptype
for idx := 0; idx < l; idx++ {
args[idx+1] = rule[idx]
}
for idx := l + 1; idx < maxParamLength; idx++ {
args[idx] = ""
}
return args
}

View File

@ -1,466 +0,0 @@
//
// adapter_sqlx_test.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package gcasbin_test
import (
"os"
"strings"
"testing"
"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/util"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"git.hexq.cn/tiglog/golib/gcasbin"
)
const (
rbacModelFile = "testdata/rbac_model.conf"
rbacPolicyFile = "testdata/rbac_policy.csv"
)
var (
dataSourceNames = map[string]string{
// "sqlite3": ":memory:",
// "mysql": "root:@tcp(127.0.0.1:3306)/sqlx_adapter_test",
"postgres": os.Getenv("DB_DSN"),
// "sqlserver": "sqlserver://sa:YourPassword@127.0.0.1:1433?database=sqlx_adapter_test&connection+timeout=30",
}
lines = []gcasbin.SqlCasbinRule{
{PType: "p", V0: "alice", V1: "data1", V2: "read"},
{PType: "p", V0: "bob", V1: "data2", V2: "read"},
{PType: "p", V0: "bob", V1: "data2", V2: "write"},
{PType: "p", V0: "data2_admin", V1: "data1", V2: "read", V3: "test1", V4: "test2", V5: "test3"},
{PType: "p", V0: "data2_admin", V1: "data2", V2: "write", V3: "test1", V4: "test2", V5: "test3"},
{PType: "p", V0: "data1_admin", V1: "data2", V2: "write"},
{PType: "g", V0: "alice", V1: "data2_admin"},
{PType: "g", V0: "bob", V1: "data2_admin", V2: "test"},
{PType: "g", V0: "bob", V1: "data1_admin", V2: "test2", V3: "test3", V4: "test4", V5: "test5"},
}
filter = gcasbin.SqlFilter{
PType: []string{"p"},
V0: []string{"bob", "data2_admin"},
V1: []string{"data1", "data2"},
V2: []string{"read", "write"},
V3: []string{"test1"},
V4: []string{"test2"},
V5: []string{"test3"},
}
)
func TestSqlAdapters(t *testing.T) {
for key, value := range dataSourceNames {
t.Logf("-------------------- test [%s] start, dataSourceName: [%s]", key, value)
db, err := sqlx.Connect(key, value)
if err != nil {
t.Fatalf("sqlx.Connect failed, err: %v", err)
}
t.Log("---------- testTableName start")
testTableName(t, db)
t.Log("---------- testTableName finished")
t.Log("---------- testSQL start")
testSQL(t, db, "sqlxadapter_sql")
t.Log("---------- testSQL finished")
t.Log("---------- testSaveLoad start")
testSaveLoad(t, db, "sqlxadapter_save_load")
t.Log("---------- testSaveLoad finished")
t.Log("---------- testAutoSave start")
testAutoSave(t, db, "sqlxadapter_auto_save")
t.Log("---------- testAutoSave finished")
t.Log("---------- testFilteredSqlPolicy start")
testFilteredSqlPolicy(t, db, "sqlxadapter_filtered_policy")
t.Log("---------- testFilteredSqlPolicy finished")
// t.Log("---------- testUpdateSqlPolicy start")
// testUpdateSqlPolicy(t, db, "sqladapter_filtered_policy")
// t.Log("---------- testUpdateSqlPolicy finished")
// t.Log("---------- testUpdateSqlPolicies start")
// testUpdateSqlPolicies(t, db, "sqladapter_filtered_policy")
// t.Log("---------- testUpdateSqlPolicies finished")
// t.Log("---------- testUpdateFilteredSqlPolicies start")
// testUpdateFilteredSqlPolicies(t, db, "sqladapter_filtered_policy")
// t.Log("---------- testUpdateFilteredSqlPolicies finished")
}
}
func testTableName(t *testing.T, db *sqlx.DB) {
_, err := gcasbin.NewSqlAdapter(db, "")
if err != nil {
t.Fatalf("NewAdapter failed, err: %v", err)
}
}
func testSQL(t *testing.T, db *sqlx.DB, tableName string) {
var err error
logErr := func(action string) {
if err != nil {
t.Errorf("%s test failed, err: %v", action, err)
}
}
equalValue := func(line1, line2 gcasbin.SqlCasbinRule) bool {
if line1.PType != line2.PType ||
line1.V0 != line2.V0 ||
line1.V1 != line2.V1 ||
line1.V2 != line2.V2 ||
line1.V3 != line2.V3 ||
line1.V4 != line2.V4 ||
line1.V5 != line2.V5 {
return false
}
return true
}
var a *gcasbin.SqlAdapter
a, err = gcasbin.NewSqlAdapter(db, tableName)
logErr("NewSqlAdapter")
// createTable test has passed when adapter create
// err = a.CreateTable()
// logErr("createTable")
if b := a.IsTableExist(); b == false {
t.Fatal("isTableExist test failed")
}
rules := make([][]interface{}, len(lines))
for idx, rule := range lines {
args := a.GenArgs(rule.PType, []string{rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
rules[idx] = args
}
err = a.TruncateAndInsertRows(rules)
logErr("truncateAndInsertRows")
err = a.DeleteAllAndInsertRows(rules)
logErr("truncateAndInsertRows")
err = a.DeleteRows(a.SqlDeleteByArgs, "g")
logErr("deleteRows sqlDeleteByArgs g")
err = a.DeleteRows(a.SqlDeleteAll)
logErr("deleteRows sqlDeleteAll")
_ = a.TruncateAndInsertRows(rules)
records, err := a.SelectRows(a.SqlSelectAll)
logErr("selectRows sqlSelectAll")
for idx, record := range records {
line := lines[idx]
if !equalValue(*record, line) {
t.Fatalf("selectRows records test not equal, query record: %+v, need record: %+v", record, line)
}
}
records, err = a.SelectWhereIn(&filter)
logErr("selectWhereIn")
i := 3
for _, record := range records {
line := lines[i]
if !equalValue(*record, line) {
t.Fatalf("selectWhereIn records test not equal, query record: %+v, need record: %+v", record, line)
}
i++
}
err = a.TruncateTable()
logErr("truncateTable")
}
func initSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
// Because the DB is empty at first,
// so we need to load the policy from the file adapter (.CSV) first.
e, _ := casbin.NewEnforcer(rbacModelFile, rbacPolicyFile)
a, err := gcasbin.NewSqlAdapter(db, tableName)
if err != nil {
t.Fatal("NewAdapter test failed, err: ", err)
}
// This is a trick to save the current policy to the DB.
// We can't call e.SavePolicy() because the adapter in the enforcer is still the file adapter.
// The current policy means the policy in the Casbin enforcer (aka in memory).
err = a.SavePolicy(e.GetModel())
if err != nil {
t.Fatal("SavePolicy test failed, err: ", err)
}
// Clear the current policy.
e.ClearPolicy()
testGetSqlPolicy(t, e, [][]string{})
// Load the policy from DB.
err = a.LoadPolicy(e.GetModel())
if err != nil {
t.Fatal("LoadPolicy test failed, err: ", err)
}
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
func testSaveLoad(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.
// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
func testAutoSave(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.
// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
// AutoSave is enabled by default.
// Now we disable it.
e.EnableAutoSave(false)
var err error
logErr := func(action string) {
if err != nil {
t.Errorf("%s test failed, err: %v", action, err)
}
}
// Because AutoSave is disabled, the policy change only affects the policy in Casbin enforcer,
// it doesn't affect the policy in the storage.
_, err = e.AddPolicy("alice", "data1", "write")
logErr("AddPolicy1")
// Reload the policy from the storage to see the effect.
err = e.LoadPolicy()
logErr("LoadPolicy1")
// This is still the original policy.
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
_, err = e.AddPolicies([][]string{{"alice_1", "data_1", "read_1"}, {"bob_1", "data_1", "write_1"}})
logErr("AddPolicies1")
// Reload the policy from the storage to see the effect.
err = e.LoadPolicy()
logErr("LoadPolicy2")
// This is still the original policy.
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
// Now we enable the AutoSave.
e.EnableAutoSave(true)
// Because AutoSave is enabled, the policy change not only affects the policy in Casbin enforcer,
// but also affects the policy in the storage.
_, err = e.AddPolicy("alice", "data1", "write")
logErr("AddPolicy2")
// Reload the policy from the storage to see the effect.
err = e.LoadPolicy()
logErr("LoadPolicy3")
// The policy has a new rule: {"alice", "data1", "write"}.
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
_, err = e.AddPolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
logErr("AddPolicies2")
// Reload the policy from the storage to see the effect.
err = e.LoadPolicy()
logErr("LoadPolicy4")
// This is still the original policy.
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"},
{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
_, err = e.RemovePolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
logErr("RemovePolicies")
err = e.LoadPolicy()
logErr("LoadPolicy5")
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
// Remove the added rule.
_, err = e.RemovePolicy("alice", "data1", "write")
logErr("RemovePolicy")
err = e.LoadPolicy()
logErr("LoadPolicy6")
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
// Remove "data2_admin" related policy rules via a filter.
// Two rules: {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"} are deleted.
_, err = e.RemoveFilteredPolicy(0, "data2_admin")
logErr("RemoveFilteredPolicy")
err = e.LoadPolicy()
logErr("LoadPolicy7")
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
}
func testFilteredSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
// Note: you don't need to look at the above code
// if you already have a working DB with policy inside.
// Now the DB has policy, so we can provide a normal use case.
// Create an adapter and an enforcer.
// NewEnforcer() will load the policy automatically.
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
// Now set the adapter
e.SetAdapter(a)
var err error
logErr := func(action string) {
if err != nil {
t.Errorf("%s test failed, err: %v", action, err)
}
}
// Load only alice's policies
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice"}})
logErr("LoadFilteredPolicy alice")
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}})
// Load only bob's policies
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"bob"}})
logErr("LoadFilteredPolicy bob")
testGetSqlPolicy(t, e, [][]string{{"bob", "data2", "write"}})
// Load policies for data2_admin
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"data2_admin"}})
logErr("LoadFilteredPolicy data2_admin")
testGetSqlPolicy(t, e, [][]string{{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
// Load policies for alice and bob
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice", "bob"}})
logErr("LoadFilteredPolicy alice bob")
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
_, err = e.AddPolicy("bob", "data1", "write", "test1", "test2", "test3")
logErr("AddPolicy")
err = e.LoadFilteredPolicy(&filter)
logErr("LoadFilteredPolicy filter")
testGetSqlPolicy(t, e, [][]string{{"bob", "data1", "write", "test1", "test2", "test3"}})
}
func testUpdateSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
e.EnableAutoSave(true)
e.UpdatePolicy([]string{"alice", "data1", "read"}, []string{"alice", "data1", "write"})
e.LoadPolicy()
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
func testUpdateSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
e.EnableAutoSave(true)
e.UpdatePolicies([][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}}, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}})
e.LoadPolicy()
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
func testUpdateFilteredSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
// Initialize some policy in DB.
initSqlPolicy(t, db, tableName)
a, _ := gcasbin.NewSqlAdapter(db, tableName)
e, _ := casbin.NewEnforcer(rbacModelFile, a)
e.EnableAutoSave(true)
e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read")
e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write")
e.LoadPolicy()
testGetSqlPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}})
}
func testGetSqlPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
t.Helper()
myRes := e.GetPolicy()
t.Log("Policy: ", myRes)
m := make(map[string]struct{}, len(myRes))
for _, record := range myRes {
key := strings.Join(record, ",")
m[key] = struct{}{}
}
for _, record := range res {
key := strings.Join(record, ",")
if _, ok := m[key]; !ok {
t.Error("Policy: \n", myRes, ", supposed to be \n", res)
break
}
}
}
func testGetSqlPolicyWithoutOrder(t *testing.T, e *casbin.Enforcer, res [][]string) {
myRes := e.GetPolicy()
// log.Print("Policy: \n", myRes)
if !arraySqlEqualsWithoutOrder(myRes, res) {
t.Error("Policy: \n", myRes, ", supposed to be \n", res)
}
}
func arraySqlEqualsWithoutOrder(a [][]string, b [][]string) bool {
if len(a) != len(b) {
return false
}
mapA := make(map[int]string)
mapB := make(map[int]string)
order := make(map[int]struct{})
l := len(a)
for i := 0; i < l; i++ {
mapA[i] = util.ArrayToString(a[i])
mapB[i] = util.ArrayToString(b[i])
}
for i := 0; i < l; i++ {
for j := 0; j < l; j++ {
if _, ok := order[j]; ok {
if j == l-1 {
return false
} else {
continue
}
}
if mapA[i] == mapB[j] {
order[j] = struct{}{}
break
} else if j == l-1 {
return false
}
}
}
return true
}