diff --git a/gcasbin/adapter_sqlx.go b/gcasbin/adapter_sqlx.go deleted file mode 100644 index e12df23..0000000 --- a/gcasbin/adapter_sqlx.go +++ /dev/null @@ -1,749 +0,0 @@ -// -// adapter_sqlx.go -// Copyright (C) 2022 tiglog -// -// Distributed under terms of the MIT license. -// - -package gcasbin - -import ( - "bytes" - "context" - "errors" - "fmt" - "strconv" - - "github.com/casbin/casbin/v2/model" - "github.com/casbin/casbin/v2/persist" - "github.com/jmoiron/sqlx" -) - -// defaultTableName if tableName == "", the Adapter will use this default table name. -const defaultTableName = "casbin_rule" - -// maxParamLength . -const maxParamLength = 7 - -// general sql -const ( - sqlCreateTable = ` -CREATE TABLE %[1]s( - p_type VARCHAR(32), - v0 VARCHAR(255), - v1 VARCHAR(255), - v2 VARCHAR(255), - v3 VARCHAR(255), - v4 VARCHAR(255), - v5 VARCHAR(255) -); -CREATE INDEX idx_%[1]s ON %[1]s (p_type,v0,v1);` - sqlTruncateTable = "TRUNCATE TABLE %s" - sqlIsTableExist = "SELECT 1 FROM %s" - sqlInsertRow = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (?,?,?,?,?,?,?)" - sqlUpdateRow = "UPDATE %s SET p_type=?,v0=?,v1=?,v2=?,v3=?,v4=?,v5=? WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?" - sqlDeleteAll = "DELETE FROM %s" - sqlDeleteRow = "DELETE FROM %s WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?" - sqlDeleteByArgs = "DELETE FROM %s WHERE p_type=?" - sqlSelectAll = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s" - sqlSelectWhere = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s WHERE " -) - -// for Sqlite3 -const ( - sqlCreateTableSqlite3 = ` -CREATE TABLE IF NOT EXISTS %[1]s( - p_type VARCHAR(32) DEFAULT '' NOT NULL, - v0 VARCHAR(255) DEFAULT '' NOT NULL, - v1 VARCHAR(255) DEFAULT '' NOT NULL, - v2 VARCHAR(255) DEFAULT '' NOT NULL, - v3 VARCHAR(255) DEFAULT '' NOT NULL, - v4 VARCHAR(255) DEFAULT '' NOT NULL, - v5 VARCHAR(255) DEFAULT '' NOT NULL, - CHECK (TYPEOF("p_type") = "text" AND - LENGTH("p_type") <= 32), - CHECK (TYPEOF("v0") = "text" AND - LENGTH("v0") <= 255), - CHECK (TYPEOF("v1") = "text" AND - LENGTH("v1") <= 255), - CHECK (TYPEOF("v2") = "text" AND - LENGTH("v2") <= 255), - CHECK (TYPEOF("v3") = "text" AND - LENGTH("v3") <= 255), - CHECK (TYPEOF("v4") = "text" AND - LENGTH("v4") <= 255), - CHECK (TYPEOF("v5") = "text" AND - LENGTH("v5") <= 255) -); -CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);` - sqlTruncateTableSqlite3 = "DROP TABLE IF EXISTS %[1]s;" + sqlCreateTableSqlite3 -) - -// for Mysql -const ( - sqlCreateTableMysql = ` -CREATE TABLE IF NOT EXISTS %[1]s( - p_type VARCHAR(32) DEFAULT '' NOT NULL, - v0 VARCHAR(255) DEFAULT '' NOT NULL, - v1 VARCHAR(255) DEFAULT '' NOT NULL, - v2 VARCHAR(255) DEFAULT '' NOT NULL, - v3 VARCHAR(255) DEFAULT '' NOT NULL, - v4 VARCHAR(255) DEFAULT '' NOT NULL, - v5 VARCHAR(255) DEFAULT '' NOT NULL, - INDEX idx_%[1]s (p_type,v0,v1) -) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;` -) - -// for Postgres -const ( - sqlCreateTablePostgres = ` -CREATE TABLE IF NOT EXISTS %[1]s( - p_type VARCHAR(32) DEFAULT '' NOT NULL, - v0 VARCHAR(255) DEFAULT '' NOT NULL, - v1 VARCHAR(255) DEFAULT '' NOT NULL, - v2 VARCHAR(255) DEFAULT '' NOT NULL, - v3 VARCHAR(255) DEFAULT '' NOT NULL, - v4 VARCHAR(255) DEFAULT '' NOT NULL, - v5 VARCHAR(255) DEFAULT '' NOT NULL -); -CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);` - sqlInsertRowPostgres = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES ($1,$2,$3,$4,$5,$6,$7)" - sqlUpdateRowPostgres = "UPDATE %s SET p_type=$1,v0=$2,v1=$3,v2=$4,v3=$5,v4=$6,v5=$7 WHERE p_type=$8 AND v0=$9 AND v1=$10 AND v2=$11 AND v3=$12 AND v4=$13 AND v5=$14" - sqlDeleteRowPostgres = "DELETE FROM %s WHERE p_type=$1 AND v0=$2 AND v1=$3 AND v2=$4 AND v3=$5 AND v4=$6 AND v5=$7" -) - -// for Sqlserver -const ( - sqlCreateTableSqlserver = ` -CREATE TABLE %[1]s( - p_type NVARCHAR(32) DEFAULT '' NOT NULL, - v0 NVARCHAR(255) DEFAULT '' NOT NULL, - v1 NVARCHAR(255) DEFAULT '' NOT NULL, - v2 NVARCHAR(255) DEFAULT '' NOT NULL, - v3 NVARCHAR(255) DEFAULT '' NOT NULL, - v4 NVARCHAR(255) DEFAULT '' NOT NULL, - v5 NVARCHAR(255) DEFAULT '' NOT NULL -); -CREATE INDEX idx_%[1]s ON %[1]s (p_type, v0, v1);` - sqlInsertRowSqlserver = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (@p1,@p2,@p3,@p4,@p5,@p6,@p7)" - sqlUpdateRowSqlserver = "UPDATE %s SET p_type=@p1,v0=@p2,v1=@p3,v2=@p4,v3=@p5,v4=@p6,v5=@p7 WHERE p_type=@p8 AND v0=@p9 AND v1=@p10 AND v2=@p11 AND v3=@p12 AND v4=@p13 AND v5=@p14" - sqlDeleteRowSqlserver = "DELETE FROM %s WHERE p_type=@p1 AND v0=@p2 AND v1=@p3 AND v2=@p4 AND v3=@p5 AND v4=@p6 AND v5=@p7" -) - -// CasbinRule defines the casbin rule model. -// It used for save or load policy lines from sqlx connected database. -type SqlCasbinRule struct { - PType string `db:"p_type"` - V0 string `db:"v0"` - V1 string `db:"v1"` - V2 string `db:"v2"` - V3 string `db:"v3"` - V4 string `db:"v4"` - V5 string `db:"v5"` -} - -// Adapter define the sqlx adapter for Casbin. -// It can load policy lines or save policy lines from sqlx connected database. -type SqlAdapter struct { - db *sqlx.DB - ctx context.Context - tableName string - - isFiltered bool - - SqlCreateTable string - SqlTruncateTable string - SqlIsTableExist string - SqlInsertRow string - SqlUpdateRow string - SqlDeleteAll string - SqlDeleteRow string - SqlDeleteByArgs string - SqlSelectAll string - SqlSelectWhere string -} - -// Filter defines the filtering rules for a FilteredAdapter's policy. -// Empty values are ignored, but all others must match the filter. -type SqlFilter struct { - PType []string - V0 []string - V1 []string - V2 []string - V3 []string - V4 []string - V5 []string -} - -// NewAdapter the constructor for Adapter. -// db should connected to database and controlled by user. -// If tableName == "", the Adapter will automatically create a table named "casbin_rule". -func NewSqlAdapter(db *sqlx.DB, tableName string) (*SqlAdapter, error) { - return NewSqlAdapterContext(context.Background(), db, tableName) -} - -// NewAdapterContext the constructor for Adapter. -// db should connected to database and controlled by user. -// If tableName == "", the Adapter will automatically create a table named "casbin_rule". -func NewSqlAdapterContext(ctx context.Context, db *sqlx.DB, tableName string) (*SqlAdapter, error) { - if db == nil { - return nil, errors.New("db is nil") - } - - // check db connecting - err := db.PingContext(ctx) - if err != nil { - return nil, err - } - - switch db.DriverName() { - case "oci8", "ora", "goracle": - return nil, errors.New("sqlxadapter: please checkout 'oracle' branch") - } - - if tableName == "" { - tableName = defaultTableName - } - - adapter := SqlAdapter{ - db: db, - ctx: ctx, - tableName: tableName, - } - - // generate different databases sql - adapter.genSQL() - - if !adapter.IsTableExist() { - if err = adapter.CreateTable(); err != nil { - return nil, err - } - } - - return &adapter, nil -} - -// genSQL generate sql based on db driver name. -func (p *SqlAdapter) genSQL() { - p.SqlCreateTable = fmt.Sprintf(sqlCreateTable, p.tableName) - p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTable, p.tableName) - - p.SqlIsTableExist = fmt.Sprintf(sqlIsTableExist, p.tableName) - - p.SqlInsertRow = fmt.Sprintf(sqlInsertRow, p.tableName) - p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRow, p.tableName) - p.SqlDeleteAll = fmt.Sprintf(sqlDeleteAll, p.tableName) - p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRow, p.tableName) - p.SqlDeleteByArgs = fmt.Sprintf(sqlDeleteByArgs, p.tableName) - - p.SqlSelectAll = fmt.Sprintf(sqlSelectAll, p.tableName) - p.SqlSelectWhere = fmt.Sprintf(sqlSelectWhere, p.tableName) - - switch p.db.DriverName() { - case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres": - p.SqlCreateTable = fmt.Sprintf(sqlCreateTablePostgres, p.tableName) - p.SqlInsertRow = fmt.Sprintf(sqlInsertRowPostgres, p.tableName) - p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowPostgres, p.tableName) - p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowPostgres, p.tableName) - case "mysql": - p.SqlCreateTable = fmt.Sprintf(sqlCreateTableMysql, p.tableName) - case "sqlite3": - p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlite3, p.tableName) - p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTableSqlite3, p.tableName) - case "sqlserver": - p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlserver, p.tableName) - p.SqlInsertRow = fmt.Sprintf(sqlInsertRowSqlserver, p.tableName) - p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowSqlserver, p.tableName) - p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowSqlserver, p.tableName) - } -} - -// createTable create a not exists table. -func (p *SqlAdapter) CreateTable() error { - _, err := p.db.ExecContext(p.ctx, p.SqlCreateTable) - - return err -} - -// truncateTable clear the table. -func (p *SqlAdapter) TruncateTable() error { - _, err := p.db.ExecContext(p.ctx, p.SqlTruncateTable) - - return err -} - -// deleteAll clear the table. -func (p *SqlAdapter) DeleteAll() error { - _, err := p.db.ExecContext(p.ctx, p.SqlDeleteAll) - - return err -} - -// isTableExist check the table exists. -func (p *SqlAdapter) IsTableExist() bool { - _, err := p.db.ExecContext(p.ctx, p.SqlIsTableExist) - - return err == nil -} - -// deleteRows delete eligible data. -func (p *SqlAdapter) DeleteRows(query string, args ...interface{}) error { - query = p.db.Rebind(query) - - _, err := p.db.ExecContext(p.ctx, query, args...) - - return err -} - -// truncateAndInsertRows clear table and insert new rows. -func (p *SqlAdapter) TruncateAndInsertRows(rules [][]interface{}) error { - if err := p.TruncateTable(); err != nil { - return err - } - return p.execTxSqlRows(p.SqlInsertRow, rules) -} - -// deleteAllAndInsertRows clear table and insert new rows. -func (p *SqlAdapter) DeleteAllAndInsertRows(rules [][]interface{}) error { - if err := p.DeleteAll(); err != nil { - return err - } - return p.execTxSqlRows(p.SqlInsertRow, rules) -} - -// execTxSqlRows exec sql rows. -func (p *SqlAdapter) execTxSqlRows(query string, rules [][]interface{}) (err error) { - tx, err := p.db.BeginTx(p.ctx, nil) - if err != nil { - return - } - - var action string - - stmt, err := tx.PrepareContext(p.ctx, query) - if err != nil { - action = "prepare context" - goto ROLLBACK - } - - for _, rule := range rules { - if _, err = stmt.ExecContext(p.ctx, rule...); err != nil { - action = "stmt exec" - goto ROLLBACK - } - } - - if err = stmt.Close(); err != nil { - action = "stmt close" - goto ROLLBACK - } - - if err = tx.Commit(); err != nil { - action = "commit" - goto ROLLBACK - } - - return - -ROLLBACK: - - if err1 := tx.Rollback(); err1 != nil { - err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1) - } - - return -} - -// selectRows select eligible data by args from the table. -func (p *SqlAdapter) SelectRows(query string, args ...interface{}) ([]*SqlCasbinRule, error) { - // make a slice with capacity - lines := make([]*SqlCasbinRule, 0, 64) - - if len(args) == 0 { - return lines, p.db.SelectContext(p.ctx, &lines, query) - } - - query = p.db.Rebind(query) - - return lines, p.db.SelectContext(p.ctx, &lines, query, args...) -} - -// selectWhereIn select eligible data by filter from the table. -func (p *SqlAdapter) SelectWhereIn(filter *SqlFilter) (lines []*SqlCasbinRule, err error) { - var sqlBuf bytes.Buffer - - sqlBuf.Grow(64) - sqlBuf.WriteString(p.SqlSelectWhere) - - args := make([]interface{}, 0, 4) - - hasInCond := false - - for _, col := range [maxParamLength]struct { - name string - arg []string - }{ - {"p_type", filter.PType}, - {"v0", filter.V0}, - {"v1", filter.V1}, - {"v2", filter.V2}, - {"v3", filter.V3}, - {"v4", filter.V4}, - {"v5", filter.V5}, - } { - l := len(col.arg) - if l == 0 { - continue - } - - switch sqlBuf.Bytes()[sqlBuf.Len()-1] { - case '?', ')': - sqlBuf.WriteString(" AND ") - } - - sqlBuf.WriteString(col.name) - - if l == 1 { - sqlBuf.WriteString("=?") - args = append(args, col.arg[0]) - } else { - sqlBuf.WriteString(" IN (?)") - args = append(args, col.arg) - - hasInCond = true - } - } - - var query string - - if hasInCond { - if query, args, err = sqlx.In(sqlBuf.String(), args...); err != nil { - return - } - } else { - query = sqlBuf.String() - } - - return p.SelectRows(query, args...) -} - -// LoadPolicy load all policy rules from the storage. -func (p *SqlAdapter) LoadPolicy(model model.Model) error { - lines, err := p.SelectRows(p.SqlSelectAll) - if err != nil { - return err - } - - for _, line := range lines { - p.loadPolicyLine(line, model) - } - - return nil -} - -// SavePolicy save policy rules to the storage. -func (p *SqlAdapter) SavePolicy(model model.Model) error { - args := make([][]interface{}, 0, 64) - - for ptype, ast := range model["p"] { - for _, rule := range ast.Policy { - arg := p.GenArgs(ptype, rule) - args = append(args, arg) - } - } - - for ptype, ast := range model["g"] { - for _, rule := range ast.Policy { - arg := p.GenArgs(ptype, rule) - args = append(args, arg) - } - } - - return p.DeleteAllAndInsertRows(args) -} - -// AddPolicy add one policy rule to the storage. -func (p *SqlAdapter) AddPolicy(sec string, ptype string, rule []string) error { - args := p.GenArgs(ptype, rule) - - _, err := p.db.ExecContext(p.ctx, p.SqlInsertRow, args...) - - return err -} - -// AddPolicies add multiple policy rules to the storage. -func (p *SqlAdapter) AddPolicies(sec string, ptype string, rules [][]string) error { - args := make([][]interface{}, 0, 8) - - for _, rule := range rules { - arg := p.GenArgs(ptype, rule) - args = append(args, arg) - } - - return p.execTxSqlRows(p.SqlInsertRow, args) -} - -// RemovePolicy remove policy rules from the storage. -func (p *SqlAdapter) RemovePolicy(sec string, ptype string, rule []string) error { - var sqlBuf bytes.Buffer - - sqlBuf.Grow(64) - sqlBuf.WriteString(p.SqlDeleteByArgs) - - args := make([]interface{}, 0, 4) - args = append(args, ptype) - - for idx, arg := range rule { - if arg != "" { - sqlBuf.WriteString(" AND v") - sqlBuf.WriteString(strconv.Itoa(idx)) - sqlBuf.WriteString("=?") - - args = append(args, arg) - } - } - - return p.DeleteRows(sqlBuf.String(), args...) -} - -// RemoveFilteredPolicy remove policy rules that match the filter from the storage. -func (p *SqlAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { - var sqlBuf bytes.Buffer - - sqlBuf.Grow(64) - sqlBuf.WriteString(p.SqlDeleteByArgs) - - args := make([]interface{}, 0, 4) - args = append(args, ptype) - - var value string - - l := fieldIndex + len(fieldValues) - - for idx := 0; idx < 6; idx++ { - if fieldIndex <= idx && idx < l { - value = fieldValues[idx-fieldIndex] - - if value != "" { - sqlBuf.WriteString(" AND v") - sqlBuf.WriteString(strconv.Itoa(idx)) - sqlBuf.WriteString("=?") - - args = append(args, value) - } - } - } - - return p.DeleteRows(sqlBuf.String(), args...) -} - -// RemovePolicies remove policy rules. -func (p *SqlAdapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) { - args := make([][]interface{}, 0, 8) - - for _, rule := range rules { - arg := p.GenArgs(ptype, rule) - args = append(args, arg) - } - - return p.execTxSqlRows(p.SqlDeleteRow, args) -} - -// LoadFilteredPolicy load policy rules that match the filter. -// filterPtr must be a pointer. -func (p *SqlAdapter) LoadFilteredPolicy(model model.Model, filterPtr interface{}) error { - if filterPtr == nil { - return p.LoadPolicy(model) - } - - filter, ok := filterPtr.(*SqlFilter) - if !ok { - return errors.New("invalid filter type") - } - - lines, err := p.SelectWhereIn(filter) - if err != nil { - return err - } - - for _, line := range lines { - p.loadPolicyLine(line, model) - } - - p.isFiltered = true - - return nil -} - -// IsFiltered returns true if the loaded policy rules has been filtered. -func (p *SqlAdapter) IsFiltered() bool { - return p.isFiltered -} - -// UpdatePolicy update a policy rule from storage. -// This is part of the Auto-Save feature. -func (p *SqlAdapter) UpdatePolicy(sec, ptype string, oldRule, newPolicy []string) error { - oldArg := p.GenArgs(ptype, oldRule) - newArg := p.GenArgs(ptype, newPolicy) - - _, err := p.db.ExecContext(p.ctx, p.SqlUpdateRow, append(newArg, oldArg...)...) - - return err -} - -// UpdatePolicies updates policy rules to storage. -func (p *SqlAdapter) UpdatePolicies(sec, ptype string, oldRules, newRules [][]string) (err error) { - if len(oldRules) != len(newRules) { - return errors.New("old rules size not equal to new rules size") - } - - args := make([][]interface{}, 0, 16) - - for idx := range oldRules { - oldArg := p.GenArgs(ptype, oldRules[idx]) - newArg := p.GenArgs(ptype, newRules[idx]) - args = append(args, append(newArg, oldArg...)) - } - - return p.execTxSqlRows(p.SqlUpdateRow, args) -} - -// UpdateFilteredPolicies deletes old rules and adds new rules. -func (p *SqlAdapter) UpdateFilteredPolicies(sec, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) (oldPolicies [][]string, err error) { - var value string - - var whereBuf bytes.Buffer - whereBuf.Grow(32) - - l := fieldIndex + len(fieldValues) - - whereArgs := make([]interface{}, 0, 4) - whereArgs = append(whereArgs, ptype) - - for idx := 0; idx < 6; idx++ { - if fieldIndex <= idx && idx < l { - value = fieldValues[idx-fieldIndex] - - if value != "" { - whereBuf.WriteString(" AND v") - whereBuf.WriteString(strconv.Itoa(idx)) - whereBuf.WriteString("=?") - - whereArgs = append(whereArgs, value) - } - } - } - - var selectBuf bytes.Buffer - selectBuf.Grow(64) - selectBuf.WriteString(p.SqlSelectWhere) - selectBuf.WriteString("p_type=?") - selectBuf.Write(whereBuf.Bytes()) - - var oldRows []*SqlCasbinRule - value = p.db.Rebind(selectBuf.String()) - oldRows, err = p.SelectRows(value, whereArgs...) - if err != nil { - return - } - - var deleteBuf bytes.Buffer - deleteBuf.Grow(64) - deleteBuf.WriteString(p.SqlDeleteByArgs) - deleteBuf.Write(whereBuf.Bytes()) - - var tx *sqlx.Tx - tx, err = p.db.BeginTxx(p.ctx, nil) - if err != nil { - return - } - - var ( - stmt *sqlx.Stmt - action string - ) - value = p.db.Rebind(deleteBuf.String()) - if _, err = tx.ExecContext(p.ctx, value, whereArgs...); err != nil { - action = "delete old policies" - goto ROLLBACK - } - - stmt, err = tx.PreparexContext(p.ctx, p.SqlInsertRow) - if err != nil { - action = "preparex context" - goto ROLLBACK - } - - for _, policy := range newPolicies { - arg := p.GenArgs(ptype, policy) - if _, err = stmt.ExecContext(p.ctx, arg...); err != nil { - action = "stmt exec context" - goto ROLLBACK - } - } - - if err = stmt.Close(); err != nil { - action = "stmt close" - goto ROLLBACK - } - - if err = tx.Commit(); err != nil { - action = "commit" - goto ROLLBACK - } - - oldPolicies = make([][]string, 0, len(oldRows)) - for _, rule := range oldRows { - oldPolicies = append(oldPolicies, []string{rule.PType, rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5}) - } - - return - -ROLLBACK: - - if err1 := tx.Rollback(); err1 != nil { - err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1) - } - - return -} - -// loadPolicyLine load a policy line to model. -func (SqlAdapter) loadPolicyLine(line *SqlCasbinRule, model model.Model) { - if line == nil { - return - } - - var lineBuf bytes.Buffer - - lineBuf.Grow(64) - lineBuf.WriteString(line.PType) - - args := [6]string{line.V0, line.V1, line.V2, line.V3, line.V4, line.V5} - for _, arg := range args { - if arg != "" { - lineBuf.WriteByte(',') - lineBuf.WriteString(arg) - } - } - - persist.LoadPolicyLine(lineBuf.String(), model) -} - -// genArgs generate args from ptype and rule. -func (SqlAdapter) GenArgs(ptype string, rule []string) []interface{} { - l := len(rule) - - args := make([]interface{}, maxParamLength) - args[0] = ptype - - for idx := 0; idx < l; idx++ { - args[idx+1] = rule[idx] - } - - for idx := l + 1; idx < maxParamLength; idx++ { - args[idx] = "" - } - - return args -} diff --git a/gcasbin/adapter_sqlx_test.go b/gcasbin/adapter_sqlx_test.go deleted file mode 100644 index 5aa87c9..0000000 --- a/gcasbin/adapter_sqlx_test.go +++ /dev/null @@ -1,466 +0,0 @@ -// -// adapter_sqlx_test.go -// Copyright (C) 2022 tiglog -// -// Distributed under terms of the MIT license. -// - -package gcasbin_test - -import ( - "os" - "strings" - "testing" - - "github.com/casbin/casbin/v2" - "github.com/casbin/casbin/v2/util" - _ "github.com/go-sql-driver/mysql" - "github.com/jmoiron/sqlx" - "git.hexq.cn/tiglog/golib/gcasbin" -) - -const ( - rbacModelFile = "testdata/rbac_model.conf" - rbacPolicyFile = "testdata/rbac_policy.csv" -) - -var ( - dataSourceNames = map[string]string{ - // "sqlite3": ":memory:", - // "mysql": "root:@tcp(127.0.0.1:3306)/sqlx_adapter_test", - "postgres": os.Getenv("DB_DSN"), - // "sqlserver": "sqlserver://sa:YourPassword@127.0.0.1:1433?database=sqlx_adapter_test&connection+timeout=30", - } - - lines = []gcasbin.SqlCasbinRule{ - {PType: "p", V0: "alice", V1: "data1", V2: "read"}, - {PType: "p", V0: "bob", V1: "data2", V2: "read"}, - {PType: "p", V0: "bob", V1: "data2", V2: "write"}, - {PType: "p", V0: "data2_admin", V1: "data1", V2: "read", V3: "test1", V4: "test2", V5: "test3"}, - {PType: "p", V0: "data2_admin", V1: "data2", V2: "write", V3: "test1", V4: "test2", V5: "test3"}, - {PType: "p", V0: "data1_admin", V1: "data2", V2: "write"}, - {PType: "g", V0: "alice", V1: "data2_admin"}, - {PType: "g", V0: "bob", V1: "data2_admin", V2: "test"}, - {PType: "g", V0: "bob", V1: "data1_admin", V2: "test2", V3: "test3", V4: "test4", V5: "test5"}, - } - - filter = gcasbin.SqlFilter{ - PType: []string{"p"}, - V0: []string{"bob", "data2_admin"}, - V1: []string{"data1", "data2"}, - V2: []string{"read", "write"}, - V3: []string{"test1"}, - V4: []string{"test2"}, - V5: []string{"test3"}, - } -) - -func TestSqlAdapters(t *testing.T) { - for key, value := range dataSourceNames { - t.Logf("-------------------- test [%s] start, dataSourceName: [%s]", key, value) - - db, err := sqlx.Connect(key, value) - if err != nil { - t.Fatalf("sqlx.Connect failed, err: %v", err) - } - - t.Log("---------- testTableName start") - testTableName(t, db) - t.Log("---------- testTableName finished") - - t.Log("---------- testSQL start") - testSQL(t, db, "sqlxadapter_sql") - t.Log("---------- testSQL finished") - - t.Log("---------- testSaveLoad start") - testSaveLoad(t, db, "sqlxadapter_save_load") - t.Log("---------- testSaveLoad finished") - - t.Log("---------- testAutoSave start") - testAutoSave(t, db, "sqlxadapter_auto_save") - t.Log("---------- testAutoSave finished") - - t.Log("---------- testFilteredSqlPolicy start") - testFilteredSqlPolicy(t, db, "sqlxadapter_filtered_policy") - t.Log("---------- testFilteredSqlPolicy finished") - - // t.Log("---------- testUpdateSqlPolicy start") - // testUpdateSqlPolicy(t, db, "sqladapter_filtered_policy") - // t.Log("---------- testUpdateSqlPolicy finished") - - // t.Log("---------- testUpdateSqlPolicies start") - // testUpdateSqlPolicies(t, db, "sqladapter_filtered_policy") - // t.Log("---------- testUpdateSqlPolicies finished") - - // t.Log("---------- testUpdateFilteredSqlPolicies start") - // testUpdateFilteredSqlPolicies(t, db, "sqladapter_filtered_policy") - // t.Log("---------- testUpdateFilteredSqlPolicies finished") - - } -} - -func testTableName(t *testing.T, db *sqlx.DB) { - _, err := gcasbin.NewSqlAdapter(db, "") - if err != nil { - t.Fatalf("NewAdapter failed, err: %v", err) - } -} - -func testSQL(t *testing.T, db *sqlx.DB, tableName string) { - var err error - logErr := func(action string) { - if err != nil { - t.Errorf("%s test failed, err: %v", action, err) - } - } - - equalValue := func(line1, line2 gcasbin.SqlCasbinRule) bool { - if line1.PType != line2.PType || - line1.V0 != line2.V0 || - line1.V1 != line2.V1 || - line1.V2 != line2.V2 || - line1.V3 != line2.V3 || - line1.V4 != line2.V4 || - line1.V5 != line2.V5 { - return false - } - return true - } - - var a *gcasbin.SqlAdapter - a, err = gcasbin.NewSqlAdapter(db, tableName) - logErr("NewSqlAdapter") - - // createTable test has passed when adapter create - // err = a.CreateTable() - // logErr("createTable") - - if b := a.IsTableExist(); b == false { - t.Fatal("isTableExist test failed") - } - - rules := make([][]interface{}, len(lines)) - for idx, rule := range lines { - args := a.GenArgs(rule.PType, []string{rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5}) - rules[idx] = args - } - - err = a.TruncateAndInsertRows(rules) - logErr("truncateAndInsertRows") - - err = a.DeleteAllAndInsertRows(rules) - logErr("truncateAndInsertRows") - - err = a.DeleteRows(a.SqlDeleteByArgs, "g") - logErr("deleteRows sqlDeleteByArgs g") - - err = a.DeleteRows(a.SqlDeleteAll) - logErr("deleteRows sqlDeleteAll") - - _ = a.TruncateAndInsertRows(rules) - - records, err := a.SelectRows(a.SqlSelectAll) - logErr("selectRows sqlSelectAll") - for idx, record := range records { - line := lines[idx] - if !equalValue(*record, line) { - t.Fatalf("selectRows records test not equal, query record: %+v, need record: %+v", record, line) - } - } - - records, err = a.SelectWhereIn(&filter) - logErr("selectWhereIn") - i := 3 - for _, record := range records { - line := lines[i] - if !equalValue(*record, line) { - t.Fatalf("selectWhereIn records test not equal, query record: %+v, need record: %+v", record, line) - } - i++ - } - - err = a.TruncateTable() - logErr("truncateTable") -} - -func initSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) { - // Because the DB is empty at first, - // so we need to load the policy from the file adapter (.CSV) first. - e, _ := casbin.NewEnforcer(rbacModelFile, rbacPolicyFile) - - a, err := gcasbin.NewSqlAdapter(db, tableName) - if err != nil { - t.Fatal("NewAdapter test failed, err: ", err) - } - - // This is a trick to save the current policy to the DB. - // We can't call e.SavePolicy() because the adapter in the enforcer is still the file adapter. - // The current policy means the policy in the Casbin enforcer (aka in memory). - err = a.SavePolicy(e.GetModel()) - if err != nil { - t.Fatal("SavePolicy test failed, err: ", err) - } - - // Clear the current policy. - e.ClearPolicy() - testGetSqlPolicy(t, e, [][]string{}) - - // Load the policy from DB. - err = a.LoadPolicy(e.GetModel()) - if err != nil { - t.Fatal("LoadPolicy test failed, err: ", err) - } - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) -} - -func testSaveLoad(t *testing.T, db *sqlx.DB, tableName string) { - // Initialize some policy in DB. - initSqlPolicy(t, db, tableName) - // Note: you don't need to look at the above code - // if you already have a working DB with policy inside. - - // Now the DB has policy, so we can provide a normal use case. - // Create an adapter and an enforcer. - // NewEnforcer() will load the policy automatically. - a, _ := gcasbin.NewSqlAdapter(db, tableName) - e, _ := casbin.NewEnforcer(rbacModelFile, a) - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) -} - -func testAutoSave(t *testing.T, db *sqlx.DB, tableName string) { - // Initialize some policy in DB. - initSqlPolicy(t, db, tableName) - // Note: you don't need to look at the above code - // if you already have a working DB with policy inside. - - // Now the DB has policy, so we can provide a normal use case. - // Create an adapter and an enforcer. - // NewEnforcer() will load the policy automatically. - a, _ := gcasbin.NewSqlAdapter(db, tableName) - e, _ := casbin.NewEnforcer(rbacModelFile, a) - - // AutoSave is enabled by default. - // Now we disable it. - e.EnableAutoSave(false) - - var err error - logErr := func(action string) { - if err != nil { - t.Errorf("%s test failed, err: %v", action, err) - } - } - - // Because AutoSave is disabled, the policy change only affects the policy in Casbin enforcer, - // it doesn't affect the policy in the storage. - _, err = e.AddPolicy("alice", "data1", "write") - logErr("AddPolicy1") - // Reload the policy from the storage to see the effect. - err = e.LoadPolicy() - logErr("LoadPolicy1") - // This is still the original policy. - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) - - _, err = e.AddPolicies([][]string{{"alice_1", "data_1", "read_1"}, {"bob_1", "data_1", "write_1"}}) - logErr("AddPolicies1") - // Reload the policy from the storage to see the effect. - err = e.LoadPolicy() - logErr("LoadPolicy2") - // This is still the original policy. - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) - - // Now we enable the AutoSave. - e.EnableAutoSave(true) - - // Because AutoSave is enabled, the policy change not only affects the policy in Casbin enforcer, - // but also affects the policy in the storage. - _, err = e.AddPolicy("alice", "data1", "write") - logErr("AddPolicy2") - // Reload the policy from the storage to see the effect. - err = e.LoadPolicy() - logErr("LoadPolicy3") - // The policy has a new rule: {"alice", "data1", "write"}. - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}) - - _, err = e.AddPolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}}) - logErr("AddPolicies2") - // Reload the policy from the storage to see the effect. - err = e.LoadPolicy() - logErr("LoadPolicy4") - // This is still the original policy. - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}, - {"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}}) - - _, err = e.RemovePolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}}) - logErr("RemovePolicies") - err = e.LoadPolicy() - logErr("LoadPolicy5") - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}) - - // Remove the added rule. - _, err = e.RemovePolicy("alice", "data1", "write") - logErr("RemovePolicy") - err = e.LoadPolicy() - logErr("LoadPolicy6") - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) - - // Remove "data2_admin" related policy rules via a filter. - // Two rules: {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"} are deleted. - _, err = e.RemoveFilteredPolicy(0, "data2_admin") - logErr("RemoveFilteredPolicy") - err = e.LoadPolicy() - logErr("LoadPolicy7") - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}}) -} - -func testFilteredSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) { - // Initialize some policy in DB. - initSqlPolicy(t, db, tableName) - // Note: you don't need to look at the above code - // if you already have a working DB with policy inside. - - // Now the DB has policy, so we can provide a normal use case. - // Create an adapter and an enforcer. - // NewEnforcer() will load the policy automatically. - a, _ := gcasbin.NewSqlAdapter(db, tableName) - e, _ := casbin.NewEnforcer(rbacModelFile, a) - // Now set the adapter - e.SetAdapter(a) - - var err error - logErr := func(action string) { - if err != nil { - t.Errorf("%s test failed, err: %v", action, err) - } - } - - // Load only alice's policies - err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice"}}) - logErr("LoadFilteredPolicy alice") - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}}) - - // Load only bob's policies - err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"bob"}}) - logErr("LoadFilteredPolicy bob") - testGetSqlPolicy(t, e, [][]string{{"bob", "data2", "write"}}) - - // Load policies for data2_admin - err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"data2_admin"}}) - logErr("LoadFilteredPolicy data2_admin") - testGetSqlPolicy(t, e, [][]string{{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) - - // Load policies for alice and bob - err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice", "bob"}}) - logErr("LoadFilteredPolicy alice bob") - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}}) - - _, err = e.AddPolicy("bob", "data1", "write", "test1", "test2", "test3") - logErr("AddPolicy") - - err = e.LoadFilteredPolicy(&filter) - logErr("LoadFilteredPolicy filter") - testGetSqlPolicy(t, e, [][]string{{"bob", "data1", "write", "test1", "test2", "test3"}}) -} - -func testUpdateSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) { - // Initialize some policy in DB. - initSqlPolicy(t, db, tableName) - - a, _ := gcasbin.NewSqlAdapter(db, tableName) - e, _ := casbin.NewEnforcer(rbacModelFile, a) - - e.EnableAutoSave(true) - e.UpdatePolicy([]string{"alice", "data1", "read"}, []string{"alice", "data1", "write"}) - e.LoadPolicy() - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) -} - -func testUpdateSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) { - // Initialize some policy in DB. - initSqlPolicy(t, db, tableName) - - a, _ := gcasbin.NewSqlAdapter(db, tableName) - e, _ := casbin.NewEnforcer(rbacModelFile, a) - - e.EnableAutoSave(true) - e.UpdatePolicies([][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}}, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}}) - e.LoadPolicy() - testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) -} - -func testUpdateFilteredSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) { - // Initialize some policy in DB. - initSqlPolicy(t, db, tableName) - - a, _ := gcasbin.NewSqlAdapter(db, tableName) - e, _ := casbin.NewEnforcer(rbacModelFile, a) - - e.EnableAutoSave(true) - e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read") - e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write") - e.LoadPolicy() - testGetSqlPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}}) -} - -func testGetSqlPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) { - t.Helper() - myRes := e.GetPolicy() - t.Log("Policy: ", myRes) - - m := make(map[string]struct{}, len(myRes)) - for _, record := range myRes { - key := strings.Join(record, ",") - m[key] = struct{}{} - } - - for _, record := range res { - key := strings.Join(record, ",") - if _, ok := m[key]; !ok { - t.Error("Policy: \n", myRes, ", supposed to be \n", res) - break - } - } -} - -func testGetSqlPolicyWithoutOrder(t *testing.T, e *casbin.Enforcer, res [][]string) { - myRes := e.GetPolicy() - // log.Print("Policy: \n", myRes) - - if !arraySqlEqualsWithoutOrder(myRes, res) { - t.Error("Policy: \n", myRes, ", supposed to be \n", res) - } -} - -func arraySqlEqualsWithoutOrder(a [][]string, b [][]string) bool { - if len(a) != len(b) { - return false - } - - mapA := make(map[int]string) - mapB := make(map[int]string) - order := make(map[int]struct{}) - l := len(a) - - for i := 0; i < l; i++ { - mapA[i] = util.ArrayToString(a[i]) - mapB[i] = util.ArrayToString(b[i]) - } - - for i := 0; i < l; i++ { - for j := 0; j < l; j++ { - if _, ok := order[j]; ok { - if j == l-1 { - return false - } else { - continue - } - } - if mapA[i] == mapB[j] { - order[j] = struct{}{} - break - } else if j == l-1 { - return false - } - } - } - return true -}