chore: 移除 sqlx
This commit is contained in:
parent
36f9acb8f7
commit
2822110dec
@ -1,749 +0,0 @@
|
|||||||
//
|
|
||||||
// adapter_sqlx.go
|
|
||||||
// Copyright (C) 2022 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package gcasbin
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"fmt"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
"github.com/casbin/casbin/v2/model"
|
|
||||||
"github.com/casbin/casbin/v2/persist"
|
|
||||||
"github.com/jmoiron/sqlx"
|
|
||||||
)
|
|
||||||
|
|
||||||
// defaultTableName if tableName == "", the Adapter will use this default table name.
|
|
||||||
const defaultTableName = "casbin_rule"
|
|
||||||
|
|
||||||
// maxParamLength .
|
|
||||||
const maxParamLength = 7
|
|
||||||
|
|
||||||
// general sql
|
|
||||||
const (
|
|
||||||
sqlCreateTable = `
|
|
||||||
CREATE TABLE %[1]s(
|
|
||||||
p_type VARCHAR(32),
|
|
||||||
v0 VARCHAR(255),
|
|
||||||
v1 VARCHAR(255),
|
|
||||||
v2 VARCHAR(255),
|
|
||||||
v3 VARCHAR(255),
|
|
||||||
v4 VARCHAR(255),
|
|
||||||
v5 VARCHAR(255)
|
|
||||||
);
|
|
||||||
CREATE INDEX idx_%[1]s ON %[1]s (p_type,v0,v1);`
|
|
||||||
sqlTruncateTable = "TRUNCATE TABLE %s"
|
|
||||||
sqlIsTableExist = "SELECT 1 FROM %s"
|
|
||||||
sqlInsertRow = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (?,?,?,?,?,?,?)"
|
|
||||||
sqlUpdateRow = "UPDATE %s SET p_type=?,v0=?,v1=?,v2=?,v3=?,v4=?,v5=? WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
|
|
||||||
sqlDeleteAll = "DELETE FROM %s"
|
|
||||||
sqlDeleteRow = "DELETE FROM %s WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?"
|
|
||||||
sqlDeleteByArgs = "DELETE FROM %s WHERE p_type=?"
|
|
||||||
sqlSelectAll = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s"
|
|
||||||
sqlSelectWhere = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s WHERE "
|
|
||||||
)
|
|
||||||
|
|
||||||
// for Sqlite3
|
|
||||||
const (
|
|
||||||
sqlCreateTableSqlite3 = `
|
|
||||||
CREATE TABLE IF NOT EXISTS %[1]s(
|
|
||||||
p_type VARCHAR(32) DEFAULT '' NOT NULL,
|
|
||||||
v0 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v1 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v2 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v3 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v4 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v5 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
CHECK (TYPEOF("p_type") = "text" AND
|
|
||||||
LENGTH("p_type") <= 32),
|
|
||||||
CHECK (TYPEOF("v0") = "text" AND
|
|
||||||
LENGTH("v0") <= 255),
|
|
||||||
CHECK (TYPEOF("v1") = "text" AND
|
|
||||||
LENGTH("v1") <= 255),
|
|
||||||
CHECK (TYPEOF("v2") = "text" AND
|
|
||||||
LENGTH("v2") <= 255),
|
|
||||||
CHECK (TYPEOF("v3") = "text" AND
|
|
||||||
LENGTH("v3") <= 255),
|
|
||||||
CHECK (TYPEOF("v4") = "text" AND
|
|
||||||
LENGTH("v4") <= 255),
|
|
||||||
CHECK (TYPEOF("v5") = "text" AND
|
|
||||||
LENGTH("v5") <= 255)
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
|
|
||||||
sqlTruncateTableSqlite3 = "DROP TABLE IF EXISTS %[1]s;" + sqlCreateTableSqlite3
|
|
||||||
)
|
|
||||||
|
|
||||||
// for Mysql
|
|
||||||
const (
|
|
||||||
sqlCreateTableMysql = `
|
|
||||||
CREATE TABLE IF NOT EXISTS %[1]s(
|
|
||||||
p_type VARCHAR(32) DEFAULT '' NOT NULL,
|
|
||||||
v0 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v1 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v2 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v3 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v4 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v5 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
INDEX idx_%[1]s (p_type,v0,v1)
|
|
||||||
) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;`
|
|
||||||
)
|
|
||||||
|
|
||||||
// for Postgres
|
|
||||||
const (
|
|
||||||
sqlCreateTablePostgres = `
|
|
||||||
CREATE TABLE IF NOT EXISTS %[1]s(
|
|
||||||
p_type VARCHAR(32) DEFAULT '' NOT NULL,
|
|
||||||
v0 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v1 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v2 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v3 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v4 VARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v5 VARCHAR(255) DEFAULT '' NOT NULL
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);`
|
|
||||||
sqlInsertRowPostgres = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES ($1,$2,$3,$4,$5,$6,$7)"
|
|
||||||
sqlUpdateRowPostgres = "UPDATE %s SET p_type=$1,v0=$2,v1=$3,v2=$4,v3=$5,v4=$6,v5=$7 WHERE p_type=$8 AND v0=$9 AND v1=$10 AND v2=$11 AND v3=$12 AND v4=$13 AND v5=$14"
|
|
||||||
sqlDeleteRowPostgres = "DELETE FROM %s WHERE p_type=$1 AND v0=$2 AND v1=$3 AND v2=$4 AND v3=$5 AND v4=$6 AND v5=$7"
|
|
||||||
)
|
|
||||||
|
|
||||||
// for Sqlserver
|
|
||||||
const (
|
|
||||||
sqlCreateTableSqlserver = `
|
|
||||||
CREATE TABLE %[1]s(
|
|
||||||
p_type NVARCHAR(32) DEFAULT '' NOT NULL,
|
|
||||||
v0 NVARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v1 NVARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v2 NVARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v3 NVARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v4 NVARCHAR(255) DEFAULT '' NOT NULL,
|
|
||||||
v5 NVARCHAR(255) DEFAULT '' NOT NULL
|
|
||||||
);
|
|
||||||
CREATE INDEX idx_%[1]s ON %[1]s (p_type, v0, v1);`
|
|
||||||
sqlInsertRowSqlserver = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (@p1,@p2,@p3,@p4,@p5,@p6,@p7)"
|
|
||||||
sqlUpdateRowSqlserver = "UPDATE %s SET p_type=@p1,v0=@p2,v1=@p3,v2=@p4,v3=@p5,v4=@p6,v5=@p7 WHERE p_type=@p8 AND v0=@p9 AND v1=@p10 AND v2=@p11 AND v3=@p12 AND v4=@p13 AND v5=@p14"
|
|
||||||
sqlDeleteRowSqlserver = "DELETE FROM %s WHERE p_type=@p1 AND v0=@p2 AND v1=@p3 AND v2=@p4 AND v3=@p5 AND v4=@p6 AND v5=@p7"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CasbinRule defines the casbin rule model.
|
|
||||||
// It used for save or load policy lines from sqlx connected database.
|
|
||||||
type SqlCasbinRule struct {
|
|
||||||
PType string `db:"p_type"`
|
|
||||||
V0 string `db:"v0"`
|
|
||||||
V1 string `db:"v1"`
|
|
||||||
V2 string `db:"v2"`
|
|
||||||
V3 string `db:"v3"`
|
|
||||||
V4 string `db:"v4"`
|
|
||||||
V5 string `db:"v5"`
|
|
||||||
}
|
|
||||||
|
|
||||||
// Adapter define the sqlx adapter for Casbin.
|
|
||||||
// It can load policy lines or save policy lines from sqlx connected database.
|
|
||||||
type SqlAdapter struct {
|
|
||||||
db *sqlx.DB
|
|
||||||
ctx context.Context
|
|
||||||
tableName string
|
|
||||||
|
|
||||||
isFiltered bool
|
|
||||||
|
|
||||||
SqlCreateTable string
|
|
||||||
SqlTruncateTable string
|
|
||||||
SqlIsTableExist string
|
|
||||||
SqlInsertRow string
|
|
||||||
SqlUpdateRow string
|
|
||||||
SqlDeleteAll string
|
|
||||||
SqlDeleteRow string
|
|
||||||
SqlDeleteByArgs string
|
|
||||||
SqlSelectAll string
|
|
||||||
SqlSelectWhere string
|
|
||||||
}
|
|
||||||
|
|
||||||
// Filter defines the filtering rules for a FilteredAdapter's policy.
|
|
||||||
// Empty values are ignored, but all others must match the filter.
|
|
||||||
type SqlFilter struct {
|
|
||||||
PType []string
|
|
||||||
V0 []string
|
|
||||||
V1 []string
|
|
||||||
V2 []string
|
|
||||||
V3 []string
|
|
||||||
V4 []string
|
|
||||||
V5 []string
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAdapter the constructor for Adapter.
|
|
||||||
// db should connected to database and controlled by user.
|
|
||||||
// If tableName == "", the Adapter will automatically create a table named "casbin_rule".
|
|
||||||
func NewSqlAdapter(db *sqlx.DB, tableName string) (*SqlAdapter, error) {
|
|
||||||
return NewSqlAdapterContext(context.Background(), db, tableName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewAdapterContext the constructor for Adapter.
|
|
||||||
// db should connected to database and controlled by user.
|
|
||||||
// If tableName == "", the Adapter will automatically create a table named "casbin_rule".
|
|
||||||
func NewSqlAdapterContext(ctx context.Context, db *sqlx.DB, tableName string) (*SqlAdapter, error) {
|
|
||||||
if db == nil {
|
|
||||||
return nil, errors.New("db is nil")
|
|
||||||
}
|
|
||||||
|
|
||||||
// check db connecting
|
|
||||||
err := db.PingContext(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
switch db.DriverName() {
|
|
||||||
case "oci8", "ora", "goracle":
|
|
||||||
return nil, errors.New("sqlxadapter: please checkout 'oracle' branch")
|
|
||||||
}
|
|
||||||
|
|
||||||
if tableName == "" {
|
|
||||||
tableName = defaultTableName
|
|
||||||
}
|
|
||||||
|
|
||||||
adapter := SqlAdapter{
|
|
||||||
db: db,
|
|
||||||
ctx: ctx,
|
|
||||||
tableName: tableName,
|
|
||||||
}
|
|
||||||
|
|
||||||
// generate different databases sql
|
|
||||||
adapter.genSQL()
|
|
||||||
|
|
||||||
if !adapter.IsTableExist() {
|
|
||||||
if err = adapter.CreateTable(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &adapter, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// genSQL generate sql based on db driver name.
|
|
||||||
func (p *SqlAdapter) genSQL() {
|
|
||||||
p.SqlCreateTable = fmt.Sprintf(sqlCreateTable, p.tableName)
|
|
||||||
p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTable, p.tableName)
|
|
||||||
|
|
||||||
p.SqlIsTableExist = fmt.Sprintf(sqlIsTableExist, p.tableName)
|
|
||||||
|
|
||||||
p.SqlInsertRow = fmt.Sprintf(sqlInsertRow, p.tableName)
|
|
||||||
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRow, p.tableName)
|
|
||||||
p.SqlDeleteAll = fmt.Sprintf(sqlDeleteAll, p.tableName)
|
|
||||||
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRow, p.tableName)
|
|
||||||
p.SqlDeleteByArgs = fmt.Sprintf(sqlDeleteByArgs, p.tableName)
|
|
||||||
|
|
||||||
p.SqlSelectAll = fmt.Sprintf(sqlSelectAll, p.tableName)
|
|
||||||
p.SqlSelectWhere = fmt.Sprintf(sqlSelectWhere, p.tableName)
|
|
||||||
|
|
||||||
switch p.db.DriverName() {
|
|
||||||
case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres":
|
|
||||||
p.SqlCreateTable = fmt.Sprintf(sqlCreateTablePostgres, p.tableName)
|
|
||||||
p.SqlInsertRow = fmt.Sprintf(sqlInsertRowPostgres, p.tableName)
|
|
||||||
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowPostgres, p.tableName)
|
|
||||||
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowPostgres, p.tableName)
|
|
||||||
case "mysql":
|
|
||||||
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableMysql, p.tableName)
|
|
||||||
case "sqlite3":
|
|
||||||
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlite3, p.tableName)
|
|
||||||
p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTableSqlite3, p.tableName)
|
|
||||||
case "sqlserver":
|
|
||||||
p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlserver, p.tableName)
|
|
||||||
p.SqlInsertRow = fmt.Sprintf(sqlInsertRowSqlserver, p.tableName)
|
|
||||||
p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowSqlserver, p.tableName)
|
|
||||||
p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowSqlserver, p.tableName)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// createTable create a not exists table.
|
|
||||||
func (p *SqlAdapter) CreateTable() error {
|
|
||||||
_, err := p.db.ExecContext(p.ctx, p.SqlCreateTable)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// truncateTable clear the table.
|
|
||||||
func (p *SqlAdapter) TruncateTable() error {
|
|
||||||
_, err := p.db.ExecContext(p.ctx, p.SqlTruncateTable)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteAll clear the table.
|
|
||||||
func (p *SqlAdapter) DeleteAll() error {
|
|
||||||
_, err := p.db.ExecContext(p.ctx, p.SqlDeleteAll)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// isTableExist check the table exists.
|
|
||||||
func (p *SqlAdapter) IsTableExist() bool {
|
|
||||||
_, err := p.db.ExecContext(p.ctx, p.SqlIsTableExist)
|
|
||||||
|
|
||||||
return err == nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteRows delete eligible data.
|
|
||||||
func (p *SqlAdapter) DeleteRows(query string, args ...interface{}) error {
|
|
||||||
query = p.db.Rebind(query)
|
|
||||||
|
|
||||||
_, err := p.db.ExecContext(p.ctx, query, args...)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// truncateAndInsertRows clear table and insert new rows.
|
|
||||||
func (p *SqlAdapter) TruncateAndInsertRows(rules [][]interface{}) error {
|
|
||||||
if err := p.TruncateTable(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return p.execTxSqlRows(p.SqlInsertRow, rules)
|
|
||||||
}
|
|
||||||
|
|
||||||
// deleteAllAndInsertRows clear table and insert new rows.
|
|
||||||
func (p *SqlAdapter) DeleteAllAndInsertRows(rules [][]interface{}) error {
|
|
||||||
if err := p.DeleteAll(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return p.execTxSqlRows(p.SqlInsertRow, rules)
|
|
||||||
}
|
|
||||||
|
|
||||||
// execTxSqlRows exec sql rows.
|
|
||||||
func (p *SqlAdapter) execTxSqlRows(query string, rules [][]interface{}) (err error) {
|
|
||||||
tx, err := p.db.BeginTx(p.ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var action string
|
|
||||||
|
|
||||||
stmt, err := tx.PrepareContext(p.ctx, query)
|
|
||||||
if err != nil {
|
|
||||||
action = "prepare context"
|
|
||||||
goto ROLLBACK
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
if _, err = stmt.ExecContext(p.ctx, rule...); err != nil {
|
|
||||||
action = "stmt exec"
|
|
||||||
goto ROLLBACK
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = stmt.Close(); err != nil {
|
|
||||||
action = "stmt close"
|
|
||||||
goto ROLLBACK
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = tx.Commit(); err != nil {
|
|
||||||
action = "commit"
|
|
||||||
goto ROLLBACK
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
ROLLBACK:
|
|
||||||
|
|
||||||
if err1 := tx.Rollback(); err1 != nil {
|
|
||||||
err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectRows select eligible data by args from the table.
|
|
||||||
func (p *SqlAdapter) SelectRows(query string, args ...interface{}) ([]*SqlCasbinRule, error) {
|
|
||||||
// make a slice with capacity
|
|
||||||
lines := make([]*SqlCasbinRule, 0, 64)
|
|
||||||
|
|
||||||
if len(args) == 0 {
|
|
||||||
return lines, p.db.SelectContext(p.ctx, &lines, query)
|
|
||||||
}
|
|
||||||
|
|
||||||
query = p.db.Rebind(query)
|
|
||||||
|
|
||||||
return lines, p.db.SelectContext(p.ctx, &lines, query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// selectWhereIn select eligible data by filter from the table.
|
|
||||||
func (p *SqlAdapter) SelectWhereIn(filter *SqlFilter) (lines []*SqlCasbinRule, err error) {
|
|
||||||
var sqlBuf bytes.Buffer
|
|
||||||
|
|
||||||
sqlBuf.Grow(64)
|
|
||||||
sqlBuf.WriteString(p.SqlSelectWhere)
|
|
||||||
|
|
||||||
args := make([]interface{}, 0, 4)
|
|
||||||
|
|
||||||
hasInCond := false
|
|
||||||
|
|
||||||
for _, col := range [maxParamLength]struct {
|
|
||||||
name string
|
|
||||||
arg []string
|
|
||||||
}{
|
|
||||||
{"p_type", filter.PType},
|
|
||||||
{"v0", filter.V0},
|
|
||||||
{"v1", filter.V1},
|
|
||||||
{"v2", filter.V2},
|
|
||||||
{"v3", filter.V3},
|
|
||||||
{"v4", filter.V4},
|
|
||||||
{"v5", filter.V5},
|
|
||||||
} {
|
|
||||||
l := len(col.arg)
|
|
||||||
if l == 0 {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
switch sqlBuf.Bytes()[sqlBuf.Len()-1] {
|
|
||||||
case '?', ')':
|
|
||||||
sqlBuf.WriteString(" AND ")
|
|
||||||
}
|
|
||||||
|
|
||||||
sqlBuf.WriteString(col.name)
|
|
||||||
|
|
||||||
if l == 1 {
|
|
||||||
sqlBuf.WriteString("=?")
|
|
||||||
args = append(args, col.arg[0])
|
|
||||||
} else {
|
|
||||||
sqlBuf.WriteString(" IN (?)")
|
|
||||||
args = append(args, col.arg)
|
|
||||||
|
|
||||||
hasInCond = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var query string
|
|
||||||
|
|
||||||
if hasInCond {
|
|
||||||
if query, args, err = sqlx.In(sqlBuf.String(), args...); err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
query = sqlBuf.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.SelectRows(query, args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadPolicy load all policy rules from the storage.
|
|
||||||
func (p *SqlAdapter) LoadPolicy(model model.Model) error {
|
|
||||||
lines, err := p.SelectRows(p.SqlSelectAll)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, line := range lines {
|
|
||||||
p.loadPolicyLine(line, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SavePolicy save policy rules to the storage.
|
|
||||||
func (p *SqlAdapter) SavePolicy(model model.Model) error {
|
|
||||||
args := make([][]interface{}, 0, 64)
|
|
||||||
|
|
||||||
for ptype, ast := range model["p"] {
|
|
||||||
for _, rule := range ast.Policy {
|
|
||||||
arg := p.GenArgs(ptype, rule)
|
|
||||||
args = append(args, arg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for ptype, ast := range model["g"] {
|
|
||||||
for _, rule := range ast.Policy {
|
|
||||||
arg := p.GenArgs(ptype, rule)
|
|
||||||
args = append(args, arg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.DeleteAllAndInsertRows(args)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPolicy add one policy rule to the storage.
|
|
||||||
func (p *SqlAdapter) AddPolicy(sec string, ptype string, rule []string) error {
|
|
||||||
args := p.GenArgs(ptype, rule)
|
|
||||||
|
|
||||||
_, err := p.db.ExecContext(p.ctx, p.SqlInsertRow, args...)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddPolicies add multiple policy rules to the storage.
|
|
||||||
func (p *SqlAdapter) AddPolicies(sec string, ptype string, rules [][]string) error {
|
|
||||||
args := make([][]interface{}, 0, 8)
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
arg := p.GenArgs(ptype, rule)
|
|
||||||
args = append(args, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.execTxSqlRows(p.SqlInsertRow, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePolicy remove policy rules from the storage.
|
|
||||||
func (p *SqlAdapter) RemovePolicy(sec string, ptype string, rule []string) error {
|
|
||||||
var sqlBuf bytes.Buffer
|
|
||||||
|
|
||||||
sqlBuf.Grow(64)
|
|
||||||
sqlBuf.WriteString(p.SqlDeleteByArgs)
|
|
||||||
|
|
||||||
args := make([]interface{}, 0, 4)
|
|
||||||
args = append(args, ptype)
|
|
||||||
|
|
||||||
for idx, arg := range rule {
|
|
||||||
if arg != "" {
|
|
||||||
sqlBuf.WriteString(" AND v")
|
|
||||||
sqlBuf.WriteString(strconv.Itoa(idx))
|
|
||||||
sqlBuf.WriteString("=?")
|
|
||||||
|
|
||||||
args = append(args, arg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.DeleteRows(sqlBuf.String(), args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveFilteredPolicy remove policy rules that match the filter from the storage.
|
|
||||||
func (p *SqlAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error {
|
|
||||||
var sqlBuf bytes.Buffer
|
|
||||||
|
|
||||||
sqlBuf.Grow(64)
|
|
||||||
sqlBuf.WriteString(p.SqlDeleteByArgs)
|
|
||||||
|
|
||||||
args := make([]interface{}, 0, 4)
|
|
||||||
args = append(args, ptype)
|
|
||||||
|
|
||||||
var value string
|
|
||||||
|
|
||||||
l := fieldIndex + len(fieldValues)
|
|
||||||
|
|
||||||
for idx := 0; idx < 6; idx++ {
|
|
||||||
if fieldIndex <= idx && idx < l {
|
|
||||||
value = fieldValues[idx-fieldIndex]
|
|
||||||
|
|
||||||
if value != "" {
|
|
||||||
sqlBuf.WriteString(" AND v")
|
|
||||||
sqlBuf.WriteString(strconv.Itoa(idx))
|
|
||||||
sqlBuf.WriteString("=?")
|
|
||||||
|
|
||||||
args = append(args, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.DeleteRows(sqlBuf.String(), args...)
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemovePolicies remove policy rules.
|
|
||||||
func (p *SqlAdapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) {
|
|
||||||
args := make([][]interface{}, 0, 8)
|
|
||||||
|
|
||||||
for _, rule := range rules {
|
|
||||||
arg := p.GenArgs(ptype, rule)
|
|
||||||
args = append(args, arg)
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.execTxSqlRows(p.SqlDeleteRow, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
// LoadFilteredPolicy load policy rules that match the filter.
|
|
||||||
// filterPtr must be a pointer.
|
|
||||||
func (p *SqlAdapter) LoadFilteredPolicy(model model.Model, filterPtr interface{}) error {
|
|
||||||
if filterPtr == nil {
|
|
||||||
return p.LoadPolicy(model)
|
|
||||||
}
|
|
||||||
|
|
||||||
filter, ok := filterPtr.(*SqlFilter)
|
|
||||||
if !ok {
|
|
||||||
return errors.New("invalid filter type")
|
|
||||||
}
|
|
||||||
|
|
||||||
lines, err := p.SelectWhereIn(filter)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, line := range lines {
|
|
||||||
p.loadPolicyLine(line, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
p.isFiltered = true
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsFiltered returns true if the loaded policy rules has been filtered.
|
|
||||||
func (p *SqlAdapter) IsFiltered() bool {
|
|
||||||
return p.isFiltered
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePolicy update a policy rule from storage.
|
|
||||||
// This is part of the Auto-Save feature.
|
|
||||||
func (p *SqlAdapter) UpdatePolicy(sec, ptype string, oldRule, newPolicy []string) error {
|
|
||||||
oldArg := p.GenArgs(ptype, oldRule)
|
|
||||||
newArg := p.GenArgs(ptype, newPolicy)
|
|
||||||
|
|
||||||
_, err := p.db.ExecContext(p.ctx, p.SqlUpdateRow, append(newArg, oldArg...)...)
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePolicies updates policy rules to storage.
|
|
||||||
func (p *SqlAdapter) UpdatePolicies(sec, ptype string, oldRules, newRules [][]string) (err error) {
|
|
||||||
if len(oldRules) != len(newRules) {
|
|
||||||
return errors.New("old rules size not equal to new rules size")
|
|
||||||
}
|
|
||||||
|
|
||||||
args := make([][]interface{}, 0, 16)
|
|
||||||
|
|
||||||
for idx := range oldRules {
|
|
||||||
oldArg := p.GenArgs(ptype, oldRules[idx])
|
|
||||||
newArg := p.GenArgs(ptype, newRules[idx])
|
|
||||||
args = append(args, append(newArg, oldArg...))
|
|
||||||
}
|
|
||||||
|
|
||||||
return p.execTxSqlRows(p.SqlUpdateRow, args)
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdateFilteredPolicies deletes old rules and adds new rules.
|
|
||||||
func (p *SqlAdapter) UpdateFilteredPolicies(sec, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) (oldPolicies [][]string, err error) {
|
|
||||||
var value string
|
|
||||||
|
|
||||||
var whereBuf bytes.Buffer
|
|
||||||
whereBuf.Grow(32)
|
|
||||||
|
|
||||||
l := fieldIndex + len(fieldValues)
|
|
||||||
|
|
||||||
whereArgs := make([]interface{}, 0, 4)
|
|
||||||
whereArgs = append(whereArgs, ptype)
|
|
||||||
|
|
||||||
for idx := 0; idx < 6; idx++ {
|
|
||||||
if fieldIndex <= idx && idx < l {
|
|
||||||
value = fieldValues[idx-fieldIndex]
|
|
||||||
|
|
||||||
if value != "" {
|
|
||||||
whereBuf.WriteString(" AND v")
|
|
||||||
whereBuf.WriteString(strconv.Itoa(idx))
|
|
||||||
whereBuf.WriteString("=?")
|
|
||||||
|
|
||||||
whereArgs = append(whereArgs, value)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var selectBuf bytes.Buffer
|
|
||||||
selectBuf.Grow(64)
|
|
||||||
selectBuf.WriteString(p.SqlSelectWhere)
|
|
||||||
selectBuf.WriteString("p_type=?")
|
|
||||||
selectBuf.Write(whereBuf.Bytes())
|
|
||||||
|
|
||||||
var oldRows []*SqlCasbinRule
|
|
||||||
value = p.db.Rebind(selectBuf.String())
|
|
||||||
oldRows, err = p.SelectRows(value, whereArgs...)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var deleteBuf bytes.Buffer
|
|
||||||
deleteBuf.Grow(64)
|
|
||||||
deleteBuf.WriteString(p.SqlDeleteByArgs)
|
|
||||||
deleteBuf.Write(whereBuf.Bytes())
|
|
||||||
|
|
||||||
var tx *sqlx.Tx
|
|
||||||
tx, err = p.db.BeginTxx(p.ctx, nil)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
stmt *sqlx.Stmt
|
|
||||||
action string
|
|
||||||
)
|
|
||||||
value = p.db.Rebind(deleteBuf.String())
|
|
||||||
if _, err = tx.ExecContext(p.ctx, value, whereArgs...); err != nil {
|
|
||||||
action = "delete old policies"
|
|
||||||
goto ROLLBACK
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt, err = tx.PreparexContext(p.ctx, p.SqlInsertRow)
|
|
||||||
if err != nil {
|
|
||||||
action = "preparex context"
|
|
||||||
goto ROLLBACK
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, policy := range newPolicies {
|
|
||||||
arg := p.GenArgs(ptype, policy)
|
|
||||||
if _, err = stmt.ExecContext(p.ctx, arg...); err != nil {
|
|
||||||
action = "stmt exec context"
|
|
||||||
goto ROLLBACK
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = stmt.Close(); err != nil {
|
|
||||||
action = "stmt close"
|
|
||||||
goto ROLLBACK
|
|
||||||
}
|
|
||||||
|
|
||||||
if err = tx.Commit(); err != nil {
|
|
||||||
action = "commit"
|
|
||||||
goto ROLLBACK
|
|
||||||
}
|
|
||||||
|
|
||||||
oldPolicies = make([][]string, 0, len(oldRows))
|
|
||||||
for _, rule := range oldRows {
|
|
||||||
oldPolicies = append(oldPolicies, []string{rule.PType, rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
|
|
||||||
ROLLBACK:
|
|
||||||
|
|
||||||
if err1 := tx.Rollback(); err1 != nil {
|
|
||||||
err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1)
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// loadPolicyLine load a policy line to model.
|
|
||||||
func (SqlAdapter) loadPolicyLine(line *SqlCasbinRule, model model.Model) {
|
|
||||||
if line == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var lineBuf bytes.Buffer
|
|
||||||
|
|
||||||
lineBuf.Grow(64)
|
|
||||||
lineBuf.WriteString(line.PType)
|
|
||||||
|
|
||||||
args := [6]string{line.V0, line.V1, line.V2, line.V3, line.V4, line.V5}
|
|
||||||
for _, arg := range args {
|
|
||||||
if arg != "" {
|
|
||||||
lineBuf.WriteByte(',')
|
|
||||||
lineBuf.WriteString(arg)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
persist.LoadPolicyLine(lineBuf.String(), model)
|
|
||||||
}
|
|
||||||
|
|
||||||
// genArgs generate args from ptype and rule.
|
|
||||||
func (SqlAdapter) GenArgs(ptype string, rule []string) []interface{} {
|
|
||||||
l := len(rule)
|
|
||||||
|
|
||||||
args := make([]interface{}, maxParamLength)
|
|
||||||
args[0] = ptype
|
|
||||||
|
|
||||||
for idx := 0; idx < l; idx++ {
|
|
||||||
args[idx+1] = rule[idx]
|
|
||||||
}
|
|
||||||
|
|
||||||
for idx := l + 1; idx < maxParamLength; idx++ {
|
|
||||||
args[idx] = ""
|
|
||||||
}
|
|
||||||
|
|
||||||
return args
|
|
||||||
}
|
|
@ -1,466 +0,0 @@
|
|||||||
//
|
|
||||||
// adapter_sqlx_test.go
|
|
||||||
// Copyright (C) 2022 tiglog <me@tiglog.com>
|
|
||||||
//
|
|
||||||
// Distributed under terms of the MIT license.
|
|
||||||
//
|
|
||||||
|
|
||||||
package gcasbin_test
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"strings"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/casbin/casbin/v2"
|
|
||||||
"github.com/casbin/casbin/v2/util"
|
|
||||||
_ "github.com/go-sql-driver/mysql"
|
|
||||||
"github.com/jmoiron/sqlx"
|
|
||||||
"git.hexq.cn/tiglog/golib/gcasbin"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
rbacModelFile = "testdata/rbac_model.conf"
|
|
||||||
rbacPolicyFile = "testdata/rbac_policy.csv"
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
|
||||||
dataSourceNames = map[string]string{
|
|
||||||
// "sqlite3": ":memory:",
|
|
||||||
// "mysql": "root:@tcp(127.0.0.1:3306)/sqlx_adapter_test",
|
|
||||||
"postgres": os.Getenv("DB_DSN"),
|
|
||||||
// "sqlserver": "sqlserver://sa:YourPassword@127.0.0.1:1433?database=sqlx_adapter_test&connection+timeout=30",
|
|
||||||
}
|
|
||||||
|
|
||||||
lines = []gcasbin.SqlCasbinRule{
|
|
||||||
{PType: "p", V0: "alice", V1: "data1", V2: "read"},
|
|
||||||
{PType: "p", V0: "bob", V1: "data2", V2: "read"},
|
|
||||||
{PType: "p", V0: "bob", V1: "data2", V2: "write"},
|
|
||||||
{PType: "p", V0: "data2_admin", V1: "data1", V2: "read", V3: "test1", V4: "test2", V5: "test3"},
|
|
||||||
{PType: "p", V0: "data2_admin", V1: "data2", V2: "write", V3: "test1", V4: "test2", V5: "test3"},
|
|
||||||
{PType: "p", V0: "data1_admin", V1: "data2", V2: "write"},
|
|
||||||
{PType: "g", V0: "alice", V1: "data2_admin"},
|
|
||||||
{PType: "g", V0: "bob", V1: "data2_admin", V2: "test"},
|
|
||||||
{PType: "g", V0: "bob", V1: "data1_admin", V2: "test2", V3: "test3", V4: "test4", V5: "test5"},
|
|
||||||
}
|
|
||||||
|
|
||||||
filter = gcasbin.SqlFilter{
|
|
||||||
PType: []string{"p"},
|
|
||||||
V0: []string{"bob", "data2_admin"},
|
|
||||||
V1: []string{"data1", "data2"},
|
|
||||||
V2: []string{"read", "write"},
|
|
||||||
V3: []string{"test1"},
|
|
||||||
V4: []string{"test2"},
|
|
||||||
V5: []string{"test3"},
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestSqlAdapters(t *testing.T) {
|
|
||||||
for key, value := range dataSourceNames {
|
|
||||||
t.Logf("-------------------- test [%s] start, dataSourceName: [%s]", key, value)
|
|
||||||
|
|
||||||
db, err := sqlx.Connect(key, value)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("sqlx.Connect failed, err: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Log("---------- testTableName start")
|
|
||||||
testTableName(t, db)
|
|
||||||
t.Log("---------- testTableName finished")
|
|
||||||
|
|
||||||
t.Log("---------- testSQL start")
|
|
||||||
testSQL(t, db, "sqlxadapter_sql")
|
|
||||||
t.Log("---------- testSQL finished")
|
|
||||||
|
|
||||||
t.Log("---------- testSaveLoad start")
|
|
||||||
testSaveLoad(t, db, "sqlxadapter_save_load")
|
|
||||||
t.Log("---------- testSaveLoad finished")
|
|
||||||
|
|
||||||
t.Log("---------- testAutoSave start")
|
|
||||||
testAutoSave(t, db, "sqlxadapter_auto_save")
|
|
||||||
t.Log("---------- testAutoSave finished")
|
|
||||||
|
|
||||||
t.Log("---------- testFilteredSqlPolicy start")
|
|
||||||
testFilteredSqlPolicy(t, db, "sqlxadapter_filtered_policy")
|
|
||||||
t.Log("---------- testFilteredSqlPolicy finished")
|
|
||||||
|
|
||||||
// t.Log("---------- testUpdateSqlPolicy start")
|
|
||||||
// testUpdateSqlPolicy(t, db, "sqladapter_filtered_policy")
|
|
||||||
// t.Log("---------- testUpdateSqlPolicy finished")
|
|
||||||
|
|
||||||
// t.Log("---------- testUpdateSqlPolicies start")
|
|
||||||
// testUpdateSqlPolicies(t, db, "sqladapter_filtered_policy")
|
|
||||||
// t.Log("---------- testUpdateSqlPolicies finished")
|
|
||||||
|
|
||||||
// t.Log("---------- testUpdateFilteredSqlPolicies start")
|
|
||||||
// testUpdateFilteredSqlPolicies(t, db, "sqladapter_filtered_policy")
|
|
||||||
// t.Log("---------- testUpdateFilteredSqlPolicies finished")
|
|
||||||
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testTableName(t *testing.T, db *sqlx.DB) {
|
|
||||||
_, err := gcasbin.NewSqlAdapter(db, "")
|
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("NewAdapter failed, err: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSQL(t *testing.T, db *sqlx.DB, tableName string) {
|
|
||||||
var err error
|
|
||||||
logErr := func(action string) {
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("%s test failed, err: %v", action, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
equalValue := func(line1, line2 gcasbin.SqlCasbinRule) bool {
|
|
||||||
if line1.PType != line2.PType ||
|
|
||||||
line1.V0 != line2.V0 ||
|
|
||||||
line1.V1 != line2.V1 ||
|
|
||||||
line1.V2 != line2.V2 ||
|
|
||||||
line1.V3 != line2.V3 ||
|
|
||||||
line1.V4 != line2.V4 ||
|
|
||||||
line1.V5 != line2.V5 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
var a *gcasbin.SqlAdapter
|
|
||||||
a, err = gcasbin.NewSqlAdapter(db, tableName)
|
|
||||||
logErr("NewSqlAdapter")
|
|
||||||
|
|
||||||
// createTable test has passed when adapter create
|
|
||||||
// err = a.CreateTable()
|
|
||||||
// logErr("createTable")
|
|
||||||
|
|
||||||
if b := a.IsTableExist(); b == false {
|
|
||||||
t.Fatal("isTableExist test failed")
|
|
||||||
}
|
|
||||||
|
|
||||||
rules := make([][]interface{}, len(lines))
|
|
||||||
for idx, rule := range lines {
|
|
||||||
args := a.GenArgs(rule.PType, []string{rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
|
|
||||||
rules[idx] = args
|
|
||||||
}
|
|
||||||
|
|
||||||
err = a.TruncateAndInsertRows(rules)
|
|
||||||
logErr("truncateAndInsertRows")
|
|
||||||
|
|
||||||
err = a.DeleteAllAndInsertRows(rules)
|
|
||||||
logErr("truncateAndInsertRows")
|
|
||||||
|
|
||||||
err = a.DeleteRows(a.SqlDeleteByArgs, "g")
|
|
||||||
logErr("deleteRows sqlDeleteByArgs g")
|
|
||||||
|
|
||||||
err = a.DeleteRows(a.SqlDeleteAll)
|
|
||||||
logErr("deleteRows sqlDeleteAll")
|
|
||||||
|
|
||||||
_ = a.TruncateAndInsertRows(rules)
|
|
||||||
|
|
||||||
records, err := a.SelectRows(a.SqlSelectAll)
|
|
||||||
logErr("selectRows sqlSelectAll")
|
|
||||||
for idx, record := range records {
|
|
||||||
line := lines[idx]
|
|
||||||
if !equalValue(*record, line) {
|
|
||||||
t.Fatalf("selectRows records test not equal, query record: %+v, need record: %+v", record, line)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
records, err = a.SelectWhereIn(&filter)
|
|
||||||
logErr("selectWhereIn")
|
|
||||||
i := 3
|
|
||||||
for _, record := range records {
|
|
||||||
line := lines[i]
|
|
||||||
if !equalValue(*record, line) {
|
|
||||||
t.Fatalf("selectWhereIn records test not equal, query record: %+v, need record: %+v", record, line)
|
|
||||||
}
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
|
|
||||||
err = a.TruncateTable()
|
|
||||||
logErr("truncateTable")
|
|
||||||
}
|
|
||||||
|
|
||||||
func initSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
|
|
||||||
// Because the DB is empty at first,
|
|
||||||
// so we need to load the policy from the file adapter (.CSV) first.
|
|
||||||
e, _ := casbin.NewEnforcer(rbacModelFile, rbacPolicyFile)
|
|
||||||
|
|
||||||
a, err := gcasbin.NewSqlAdapter(db, tableName)
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("NewAdapter test failed, err: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is a trick to save the current policy to the DB.
|
|
||||||
// We can't call e.SavePolicy() because the adapter in the enforcer is still the file adapter.
|
|
||||||
// The current policy means the policy in the Casbin enforcer (aka in memory).
|
|
||||||
err = a.SavePolicy(e.GetModel())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("SavePolicy test failed, err: ", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear the current policy.
|
|
||||||
e.ClearPolicy()
|
|
||||||
testGetSqlPolicy(t, e, [][]string{})
|
|
||||||
|
|
||||||
// Load the policy from DB.
|
|
||||||
err = a.LoadPolicy(e.GetModel())
|
|
||||||
if err != nil {
|
|
||||||
t.Fatal("LoadPolicy test failed, err: ", err)
|
|
||||||
}
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
||||||
}
|
|
||||||
|
|
||||||
func testSaveLoad(t *testing.T, db *sqlx.DB, tableName string) {
|
|
||||||
// Initialize some policy in DB.
|
|
||||||
initSqlPolicy(t, db, tableName)
|
|
||||||
// Note: you don't need to look at the above code
|
|
||||||
// if you already have a working DB with policy inside.
|
|
||||||
|
|
||||||
// Now the DB has policy, so we can provide a normal use case.
|
|
||||||
// Create an adapter and an enforcer.
|
|
||||||
// NewEnforcer() will load the policy automatically.
|
|
||||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
||||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
||||||
}
|
|
||||||
|
|
||||||
func testAutoSave(t *testing.T, db *sqlx.DB, tableName string) {
|
|
||||||
// Initialize some policy in DB.
|
|
||||||
initSqlPolicy(t, db, tableName)
|
|
||||||
// Note: you don't need to look at the above code
|
|
||||||
// if you already have a working DB with policy inside.
|
|
||||||
|
|
||||||
// Now the DB has policy, so we can provide a normal use case.
|
|
||||||
// Create an adapter and an enforcer.
|
|
||||||
// NewEnforcer() will load the policy automatically.
|
|
||||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
||||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
||||||
|
|
||||||
// AutoSave is enabled by default.
|
|
||||||
// Now we disable it.
|
|
||||||
e.EnableAutoSave(false)
|
|
||||||
|
|
||||||
var err error
|
|
||||||
logErr := func(action string) {
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("%s test failed, err: %v", action, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Because AutoSave is disabled, the policy change only affects the policy in Casbin enforcer,
|
|
||||||
// it doesn't affect the policy in the storage.
|
|
||||||
_, err = e.AddPolicy("alice", "data1", "write")
|
|
||||||
logErr("AddPolicy1")
|
|
||||||
// Reload the policy from the storage to see the effect.
|
|
||||||
err = e.LoadPolicy()
|
|
||||||
logErr("LoadPolicy1")
|
|
||||||
// This is still the original policy.
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
||||||
|
|
||||||
_, err = e.AddPolicies([][]string{{"alice_1", "data_1", "read_1"}, {"bob_1", "data_1", "write_1"}})
|
|
||||||
logErr("AddPolicies1")
|
|
||||||
// Reload the policy from the storage to see the effect.
|
|
||||||
err = e.LoadPolicy()
|
|
||||||
logErr("LoadPolicy2")
|
|
||||||
// This is still the original policy.
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
||||||
|
|
||||||
// Now we enable the AutoSave.
|
|
||||||
e.EnableAutoSave(true)
|
|
||||||
|
|
||||||
// Because AutoSave is enabled, the policy change not only affects the policy in Casbin enforcer,
|
|
||||||
// but also affects the policy in the storage.
|
|
||||||
_, err = e.AddPolicy("alice", "data1", "write")
|
|
||||||
logErr("AddPolicy2")
|
|
||||||
// Reload the policy from the storage to see the effect.
|
|
||||||
err = e.LoadPolicy()
|
|
||||||
logErr("LoadPolicy3")
|
|
||||||
// The policy has a new rule: {"alice", "data1", "write"}.
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
|
|
||||||
|
|
||||||
_, err = e.AddPolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
|
|
||||||
logErr("AddPolicies2")
|
|
||||||
// Reload the policy from the storage to see the effect.
|
|
||||||
err = e.LoadPolicy()
|
|
||||||
logErr("LoadPolicy4")
|
|
||||||
// This is still the original policy.
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"},
|
|
||||||
{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
|
|
||||||
|
|
||||||
_, err = e.RemovePolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
|
|
||||||
logErr("RemovePolicies")
|
|
||||||
err = e.LoadPolicy()
|
|
||||||
logErr("LoadPolicy5")
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
|
|
||||||
|
|
||||||
// Remove the added rule.
|
|
||||||
_, err = e.RemovePolicy("alice", "data1", "write")
|
|
||||||
logErr("RemovePolicy")
|
|
||||||
err = e.LoadPolicy()
|
|
||||||
logErr("LoadPolicy6")
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
||||||
|
|
||||||
// Remove "data2_admin" related policy rules via a filter.
|
|
||||||
// Two rules: {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"} are deleted.
|
|
||||||
_, err = e.RemoveFilteredPolicy(0, "data2_admin")
|
|
||||||
logErr("RemoveFilteredPolicy")
|
|
||||||
err = e.LoadPolicy()
|
|
||||||
logErr("LoadPolicy7")
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
|
|
||||||
}
|
|
||||||
|
|
||||||
func testFilteredSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
|
|
||||||
// Initialize some policy in DB.
|
|
||||||
initSqlPolicy(t, db, tableName)
|
|
||||||
// Note: you don't need to look at the above code
|
|
||||||
// if you already have a working DB with policy inside.
|
|
||||||
|
|
||||||
// Now the DB has policy, so we can provide a normal use case.
|
|
||||||
// Create an adapter and an enforcer.
|
|
||||||
// NewEnforcer() will load the policy automatically.
|
|
||||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
||||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
||||||
// Now set the adapter
|
|
||||||
e.SetAdapter(a)
|
|
||||||
|
|
||||||
var err error
|
|
||||||
logErr := func(action string) {
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("%s test failed, err: %v", action, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Load only alice's policies
|
|
||||||
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice"}})
|
|
||||||
logErr("LoadFilteredPolicy alice")
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}})
|
|
||||||
|
|
||||||
// Load only bob's policies
|
|
||||||
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"bob"}})
|
|
||||||
logErr("LoadFilteredPolicy bob")
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"bob", "data2", "write"}})
|
|
||||||
|
|
||||||
// Load policies for data2_admin
|
|
||||||
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"data2_admin"}})
|
|
||||||
logErr("LoadFilteredPolicy data2_admin")
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
||||||
|
|
||||||
// Load policies for alice and bob
|
|
||||||
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice", "bob"}})
|
|
||||||
logErr("LoadFilteredPolicy alice bob")
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
|
|
||||||
|
|
||||||
_, err = e.AddPolicy("bob", "data1", "write", "test1", "test2", "test3")
|
|
||||||
logErr("AddPolicy")
|
|
||||||
|
|
||||||
err = e.LoadFilteredPolicy(&filter)
|
|
||||||
logErr("LoadFilteredPolicy filter")
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"bob", "data1", "write", "test1", "test2", "test3"}})
|
|
||||||
}
|
|
||||||
|
|
||||||
func testUpdateSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
|
|
||||||
// Initialize some policy in DB.
|
|
||||||
initSqlPolicy(t, db, tableName)
|
|
||||||
|
|
||||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
||||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
||||||
|
|
||||||
e.EnableAutoSave(true)
|
|
||||||
e.UpdatePolicy([]string{"alice", "data1", "read"}, []string{"alice", "data1", "write"})
|
|
||||||
e.LoadPolicy()
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
||||||
}
|
|
||||||
|
|
||||||
func testUpdateSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
|
|
||||||
// Initialize some policy in DB.
|
|
||||||
initSqlPolicy(t, db, tableName)
|
|
||||||
|
|
||||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
||||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
||||||
|
|
||||||
e.EnableAutoSave(true)
|
|
||||||
e.UpdatePolicies([][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}}, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}})
|
|
||||||
e.LoadPolicy()
|
|
||||||
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
||||||
}
|
|
||||||
|
|
||||||
func testUpdateFilteredSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
|
|
||||||
// Initialize some policy in DB.
|
|
||||||
initSqlPolicy(t, db, tableName)
|
|
||||||
|
|
||||||
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
||||||
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
||||||
|
|
||||||
e.EnableAutoSave(true)
|
|
||||||
e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read")
|
|
||||||
e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write")
|
|
||||||
e.LoadPolicy()
|
|
||||||
testGetSqlPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}})
|
|
||||||
}
|
|
||||||
|
|
||||||
func testGetSqlPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
|
|
||||||
t.Helper()
|
|
||||||
myRes := e.GetPolicy()
|
|
||||||
t.Log("Policy: ", myRes)
|
|
||||||
|
|
||||||
m := make(map[string]struct{}, len(myRes))
|
|
||||||
for _, record := range myRes {
|
|
||||||
key := strings.Join(record, ",")
|
|
||||||
m[key] = struct{}{}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, record := range res {
|
|
||||||
key := strings.Join(record, ",")
|
|
||||||
if _, ok := m[key]; !ok {
|
|
||||||
t.Error("Policy: \n", myRes, ", supposed to be \n", res)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func testGetSqlPolicyWithoutOrder(t *testing.T, e *casbin.Enforcer, res [][]string) {
|
|
||||||
myRes := e.GetPolicy()
|
|
||||||
// log.Print("Policy: \n", myRes)
|
|
||||||
|
|
||||||
if !arraySqlEqualsWithoutOrder(myRes, res) {
|
|
||||||
t.Error("Policy: \n", myRes, ", supposed to be \n", res)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func arraySqlEqualsWithoutOrder(a [][]string, b [][]string) bool {
|
|
||||||
if len(a) != len(b) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
mapA := make(map[int]string)
|
|
||||||
mapB := make(map[int]string)
|
|
||||||
order := make(map[int]struct{})
|
|
||||||
l := len(a)
|
|
||||||
|
|
||||||
for i := 0; i < l; i++ {
|
|
||||||
mapA[i] = util.ArrayToString(a[i])
|
|
||||||
mapB[i] = util.ArrayToString(b[i])
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := 0; i < l; i++ {
|
|
||||||
for j := 0; j < l; j++ {
|
|
||||||
if _, ok := order[j]; ok {
|
|
||||||
if j == l-1 {
|
|
||||||
return false
|
|
||||||
} else {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if mapA[i] == mapB[j] {
|
|
||||||
order[j] = struct{}{}
|
|
||||||
break
|
|
||||||
} else if j == l-1 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
Loading…
Reference in New Issue
Block a user