// // adapter_sqlx.go // Copyright (C) 2022 tiglog // // Distributed under terms of the MIT license. // package gcasbin import ( "bytes" "context" "errors" "fmt" "strconv" "github.com/casbin/casbin/v2/model" "github.com/casbin/casbin/v2/persist" "github.com/jmoiron/sqlx" ) // defaultTableName if tableName == "", the Adapter will use this default table name. const defaultTableName = "casbin_rule" // maxParamLength . const maxParamLength = 7 // general sql const ( sqlCreateTable = ` CREATE TABLE %[1]s( p_type VARCHAR(32), v0 VARCHAR(255), v1 VARCHAR(255), v2 VARCHAR(255), v3 VARCHAR(255), v4 VARCHAR(255), v5 VARCHAR(255) ); CREATE INDEX idx_%[1]s ON %[1]s (p_type,v0,v1);` sqlTruncateTable = "TRUNCATE TABLE %s" sqlIsTableExist = "SELECT 1 FROM %s" sqlInsertRow = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (?,?,?,?,?,?,?)" sqlUpdateRow = "UPDATE %s SET p_type=?,v0=?,v1=?,v2=?,v3=?,v4=?,v5=? WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?" sqlDeleteAll = "DELETE FROM %s" sqlDeleteRow = "DELETE FROM %s WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?" sqlDeleteByArgs = "DELETE FROM %s WHERE p_type=?" sqlSelectAll = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s" sqlSelectWhere = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s WHERE " ) // for Sqlite3 const ( sqlCreateTableSqlite3 = ` CREATE TABLE IF NOT EXISTS %[1]s( p_type VARCHAR(32) DEFAULT '' NOT NULL, v0 VARCHAR(255) DEFAULT '' NOT NULL, v1 VARCHAR(255) DEFAULT '' NOT NULL, v2 VARCHAR(255) DEFAULT '' NOT NULL, v3 VARCHAR(255) DEFAULT '' NOT NULL, v4 VARCHAR(255) DEFAULT '' NOT NULL, v5 VARCHAR(255) DEFAULT '' NOT NULL, CHECK (TYPEOF("p_type") = "text" AND LENGTH("p_type") <= 32), CHECK (TYPEOF("v0") = "text" AND LENGTH("v0") <= 255), CHECK (TYPEOF("v1") = "text" AND LENGTH("v1") <= 255), CHECK (TYPEOF("v2") = "text" AND LENGTH("v2") <= 255), CHECK (TYPEOF("v3") = "text" AND LENGTH("v3") <= 255), CHECK (TYPEOF("v4") = "text" AND LENGTH("v4") <= 255), CHECK (TYPEOF("v5") = "text" AND LENGTH("v5") <= 255) ); CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);` sqlTruncateTableSqlite3 = "DROP TABLE IF EXISTS %[1]s;" + sqlCreateTableSqlite3 ) // for Mysql const ( sqlCreateTableMysql = ` CREATE TABLE IF NOT EXISTS %[1]s( p_type VARCHAR(32) DEFAULT '' NOT NULL, v0 VARCHAR(255) DEFAULT '' NOT NULL, v1 VARCHAR(255) DEFAULT '' NOT NULL, v2 VARCHAR(255) DEFAULT '' NOT NULL, v3 VARCHAR(255) DEFAULT '' NOT NULL, v4 VARCHAR(255) DEFAULT '' NOT NULL, v5 VARCHAR(255) DEFAULT '' NOT NULL, INDEX idx_%[1]s (p_type,v0,v1) ) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;` ) // for Postgres const ( sqlCreateTablePostgres = ` CREATE TABLE IF NOT EXISTS %[1]s( p_type VARCHAR(32) DEFAULT '' NOT NULL, v0 VARCHAR(255) DEFAULT '' NOT NULL, v1 VARCHAR(255) DEFAULT '' NOT NULL, v2 VARCHAR(255) DEFAULT '' NOT NULL, v3 VARCHAR(255) DEFAULT '' NOT NULL, v4 VARCHAR(255) DEFAULT '' NOT NULL, v5 VARCHAR(255) DEFAULT '' NOT NULL ); CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);` sqlInsertRowPostgres = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES ($1,$2,$3,$4,$5,$6,$7)" sqlUpdateRowPostgres = "UPDATE %s SET p_type=$1,v0=$2,v1=$3,v2=$4,v3=$5,v4=$6,v5=$7 WHERE p_type=$8 AND v0=$9 AND v1=$10 AND v2=$11 AND v3=$12 AND v4=$13 AND v5=$14" sqlDeleteRowPostgres = "DELETE FROM %s WHERE p_type=$1 AND v0=$2 AND v1=$3 AND v2=$4 AND v3=$5 AND v4=$6 AND v5=$7" ) // for Sqlserver const ( sqlCreateTableSqlserver = ` CREATE TABLE %[1]s( p_type NVARCHAR(32) DEFAULT '' NOT NULL, v0 NVARCHAR(255) DEFAULT '' NOT NULL, v1 NVARCHAR(255) DEFAULT '' NOT NULL, v2 NVARCHAR(255) DEFAULT '' NOT NULL, v3 NVARCHAR(255) DEFAULT '' NOT NULL, v4 NVARCHAR(255) DEFAULT '' NOT NULL, v5 NVARCHAR(255) DEFAULT '' NOT NULL ); CREATE INDEX idx_%[1]s ON %[1]s (p_type, v0, v1);` sqlInsertRowSqlserver = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (@p1,@p2,@p3,@p4,@p5,@p6,@p7)" sqlUpdateRowSqlserver = "UPDATE %s SET p_type=@p1,v0=@p2,v1=@p3,v2=@p4,v3=@p5,v4=@p6,v5=@p7 WHERE p_type=@p8 AND v0=@p9 AND v1=@p10 AND v2=@p11 AND v3=@p12 AND v4=@p13 AND v5=@p14" sqlDeleteRowSqlserver = "DELETE FROM %s WHERE p_type=@p1 AND v0=@p2 AND v1=@p3 AND v2=@p4 AND v3=@p5 AND v4=@p6 AND v5=@p7" ) // CasbinRule defines the casbin rule model. // It used for save or load policy lines from sqlx connected database. type SqlCasbinRule struct { PType string `db:"p_type"` V0 string `db:"v0"` V1 string `db:"v1"` V2 string `db:"v2"` V3 string `db:"v3"` V4 string `db:"v4"` V5 string `db:"v5"` } // Adapter define the sqlx adapter for Casbin. // It can load policy lines or save policy lines from sqlx connected database. type SqlAdapter struct { db *sqlx.DB ctx context.Context tableName string isFiltered bool SqlCreateTable string SqlTruncateTable string SqlIsTableExist string SqlInsertRow string SqlUpdateRow string SqlDeleteAll string SqlDeleteRow string SqlDeleteByArgs string SqlSelectAll string SqlSelectWhere string } // Filter defines the filtering rules for a FilteredAdapter's policy. // Empty values are ignored, but all others must match the filter. type SqlFilter struct { PType []string V0 []string V1 []string V2 []string V3 []string V4 []string V5 []string } // NewAdapter the constructor for Adapter. // db should connected to database and controlled by user. // If tableName == "", the Adapter will automatically create a table named "casbin_rule". func NewSqlAdapter(db *sqlx.DB, tableName string) (*SqlAdapter, error) { return NewSqlAdapterContext(context.Background(), db, tableName) } // NewAdapterContext the constructor for Adapter. // db should connected to database and controlled by user. // If tableName == "", the Adapter will automatically create a table named "casbin_rule". func NewSqlAdapterContext(ctx context.Context, db *sqlx.DB, tableName string) (*SqlAdapter, error) { if db == nil { return nil, errors.New("db is nil") } // check db connecting err := db.PingContext(ctx) if err != nil { return nil, err } switch db.DriverName() { case "oci8", "ora", "goracle": return nil, errors.New("sqlxadapter: please checkout 'oracle' branch") } if tableName == "" { tableName = defaultTableName } adapter := SqlAdapter{ db: db, ctx: ctx, tableName: tableName, } // generate different databases sql adapter.genSQL() if !adapter.IsTableExist() { if err = adapter.CreateTable(); err != nil { return nil, err } } return &adapter, nil } // genSQL generate sql based on db driver name. func (p *SqlAdapter) genSQL() { p.SqlCreateTable = fmt.Sprintf(sqlCreateTable, p.tableName) p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTable, p.tableName) p.SqlIsTableExist = fmt.Sprintf(sqlIsTableExist, p.tableName) p.SqlInsertRow = fmt.Sprintf(sqlInsertRow, p.tableName) p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRow, p.tableName) p.SqlDeleteAll = fmt.Sprintf(sqlDeleteAll, p.tableName) p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRow, p.tableName) p.SqlDeleteByArgs = fmt.Sprintf(sqlDeleteByArgs, p.tableName) p.SqlSelectAll = fmt.Sprintf(sqlSelectAll, p.tableName) p.SqlSelectWhere = fmt.Sprintf(sqlSelectWhere, p.tableName) switch p.db.DriverName() { case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres": p.SqlCreateTable = fmt.Sprintf(sqlCreateTablePostgres, p.tableName) p.SqlInsertRow = fmt.Sprintf(sqlInsertRowPostgres, p.tableName) p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowPostgres, p.tableName) p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowPostgres, p.tableName) case "mysql": p.SqlCreateTable = fmt.Sprintf(sqlCreateTableMysql, p.tableName) case "sqlite3": p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlite3, p.tableName) p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTableSqlite3, p.tableName) case "sqlserver": p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlserver, p.tableName) p.SqlInsertRow = fmt.Sprintf(sqlInsertRowSqlserver, p.tableName) p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowSqlserver, p.tableName) p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowSqlserver, p.tableName) } } // createTable create a not exists table. func (p *SqlAdapter) CreateTable() error { _, err := p.db.ExecContext(p.ctx, p.SqlCreateTable) return err } // truncateTable clear the table. func (p *SqlAdapter) TruncateTable() error { _, err := p.db.ExecContext(p.ctx, p.SqlTruncateTable) return err } // deleteAll clear the table. func (p *SqlAdapter) DeleteAll() error { _, err := p.db.ExecContext(p.ctx, p.SqlDeleteAll) return err } // isTableExist check the table exists. func (p *SqlAdapter) IsTableExist() bool { _, err := p.db.ExecContext(p.ctx, p.SqlIsTableExist) return err == nil } // deleteRows delete eligible data. func (p *SqlAdapter) DeleteRows(query string, args ...interface{}) error { query = p.db.Rebind(query) _, err := p.db.ExecContext(p.ctx, query, args...) return err } // truncateAndInsertRows clear table and insert new rows. func (p *SqlAdapter) TruncateAndInsertRows(rules [][]interface{}) error { if err := p.TruncateTable(); err != nil { return err } return p.execTxSqlRows(p.SqlInsertRow, rules) } // deleteAllAndInsertRows clear table and insert new rows. func (p *SqlAdapter) DeleteAllAndInsertRows(rules [][]interface{}) error { if err := p.DeleteAll(); err != nil { return err } return p.execTxSqlRows(p.SqlInsertRow, rules) } // execTxSqlRows exec sql rows. func (p *SqlAdapter) execTxSqlRows(query string, rules [][]interface{}) (err error) { tx, err := p.db.BeginTx(p.ctx, nil) if err != nil { return } var action string stmt, err := tx.PrepareContext(p.ctx, query) if err != nil { action = "prepare context" goto ROLLBACK } for _, rule := range rules { if _, err = stmt.ExecContext(p.ctx, rule...); err != nil { action = "stmt exec" goto ROLLBACK } } if err = stmt.Close(); err != nil { action = "stmt close" goto ROLLBACK } if err = tx.Commit(); err != nil { action = "commit" goto ROLLBACK } return ROLLBACK: if err1 := tx.Rollback(); err1 != nil { err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1) } return } // selectRows select eligible data by args from the table. func (p *SqlAdapter) SelectRows(query string, args ...interface{}) ([]*SqlCasbinRule, error) { // make a slice with capacity lines := make([]*SqlCasbinRule, 0, 64) if len(args) == 0 { return lines, p.db.SelectContext(p.ctx, &lines, query) } query = p.db.Rebind(query) return lines, p.db.SelectContext(p.ctx, &lines, query, args...) } // selectWhereIn select eligible data by filter from the table. func (p *SqlAdapter) SelectWhereIn(filter *SqlFilter) (lines []*SqlCasbinRule, err error) { var sqlBuf bytes.Buffer sqlBuf.Grow(64) sqlBuf.WriteString(p.SqlSelectWhere) args := make([]interface{}, 0, 4) hasInCond := false for _, col := range [maxParamLength]struct { name string arg []string }{ {"p_type", filter.PType}, {"v0", filter.V0}, {"v1", filter.V1}, {"v2", filter.V2}, {"v3", filter.V3}, {"v4", filter.V4}, {"v5", filter.V5}, } { l := len(col.arg) if l == 0 { continue } switch sqlBuf.Bytes()[sqlBuf.Len()-1] { case '?', ')': sqlBuf.WriteString(" AND ") } sqlBuf.WriteString(col.name) if l == 1 { sqlBuf.WriteString("=?") args = append(args, col.arg[0]) } else { sqlBuf.WriteString(" IN (?)") args = append(args, col.arg) hasInCond = true } } var query string if hasInCond { if query, args, err = sqlx.In(sqlBuf.String(), args...); err != nil { return } } else { query = sqlBuf.String() } return p.SelectRows(query, args...) } // LoadPolicy load all policy rules from the storage. func (p *SqlAdapter) LoadPolicy(model model.Model) error { lines, err := p.SelectRows(p.SqlSelectAll) if err != nil { return err } for _, line := range lines { p.loadPolicyLine(line, model) } return nil } // SavePolicy save policy rules to the storage. func (p *SqlAdapter) SavePolicy(model model.Model) error { args := make([][]interface{}, 0, 64) for ptype, ast := range model["p"] { for _, rule := range ast.Policy { arg := p.GenArgs(ptype, rule) args = append(args, arg) } } for ptype, ast := range model["g"] { for _, rule := range ast.Policy { arg := p.GenArgs(ptype, rule) args = append(args, arg) } } return p.DeleteAllAndInsertRows(args) } // AddPolicy add one policy rule to the storage. func (p *SqlAdapter) AddPolicy(sec string, ptype string, rule []string) error { args := p.GenArgs(ptype, rule) _, err := p.db.ExecContext(p.ctx, p.SqlInsertRow, args...) return err } // AddPolicies add multiple policy rules to the storage. func (p *SqlAdapter) AddPolicies(sec string, ptype string, rules [][]string) error { args := make([][]interface{}, 0, 8) for _, rule := range rules { arg := p.GenArgs(ptype, rule) args = append(args, arg) } return p.execTxSqlRows(p.SqlInsertRow, args) } // RemovePolicy remove policy rules from the storage. func (p *SqlAdapter) RemovePolicy(sec string, ptype string, rule []string) error { var sqlBuf bytes.Buffer sqlBuf.Grow(64) sqlBuf.WriteString(p.SqlDeleteByArgs) args := make([]interface{}, 0, 4) args = append(args, ptype) for idx, arg := range rule { if arg != "" { sqlBuf.WriteString(" AND v") sqlBuf.WriteString(strconv.Itoa(idx)) sqlBuf.WriteString("=?") args = append(args, arg) } } return p.DeleteRows(sqlBuf.String(), args...) } // RemoveFilteredPolicy remove policy rules that match the filter from the storage. func (p *SqlAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { var sqlBuf bytes.Buffer sqlBuf.Grow(64) sqlBuf.WriteString(p.SqlDeleteByArgs) args := make([]interface{}, 0, 4) args = append(args, ptype) var value string l := fieldIndex + len(fieldValues) for idx := 0; idx < 6; idx++ { if fieldIndex <= idx && idx < l { value = fieldValues[idx-fieldIndex] if value != "" { sqlBuf.WriteString(" AND v") sqlBuf.WriteString(strconv.Itoa(idx)) sqlBuf.WriteString("=?") args = append(args, value) } } } return p.DeleteRows(sqlBuf.String(), args...) } // RemovePolicies remove policy rules. func (p *SqlAdapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) { args := make([][]interface{}, 0, 8) for _, rule := range rules { arg := p.GenArgs(ptype, rule) args = append(args, arg) } return p.execTxSqlRows(p.SqlDeleteRow, args) } // LoadFilteredPolicy load policy rules that match the filter. // filterPtr must be a pointer. func (p *SqlAdapter) LoadFilteredPolicy(model model.Model, filterPtr interface{}) error { if filterPtr == nil { return p.LoadPolicy(model) } filter, ok := filterPtr.(*SqlFilter) if !ok { return errors.New("invalid filter type") } lines, err := p.SelectWhereIn(filter) if err != nil { return err } for _, line := range lines { p.loadPolicyLine(line, model) } p.isFiltered = true return nil } // IsFiltered returns true if the loaded policy rules has been filtered. func (p *SqlAdapter) IsFiltered() bool { return p.isFiltered } // UpdatePolicy update a policy rule from storage. // This is part of the Auto-Save feature. func (p *SqlAdapter) UpdatePolicy(sec, ptype string, oldRule, newPolicy []string) error { oldArg := p.GenArgs(ptype, oldRule) newArg := p.GenArgs(ptype, newPolicy) _, err := p.db.ExecContext(p.ctx, p.SqlUpdateRow, append(newArg, oldArg...)...) return err } // UpdatePolicies updates policy rules to storage. func (p *SqlAdapter) UpdatePolicies(sec, ptype string, oldRules, newRules [][]string) (err error) { if len(oldRules) != len(newRules) { return errors.New("old rules size not equal to new rules size") } args := make([][]interface{}, 0, 16) for idx := range oldRules { oldArg := p.GenArgs(ptype, oldRules[idx]) newArg := p.GenArgs(ptype, newRules[idx]) args = append(args, append(newArg, oldArg...)) } return p.execTxSqlRows(p.SqlUpdateRow, args) } // UpdateFilteredPolicies deletes old rules and adds new rules. func (p *SqlAdapter) UpdateFilteredPolicies(sec, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) (oldPolicies [][]string, err error) { var value string var whereBuf bytes.Buffer whereBuf.Grow(32) l := fieldIndex + len(fieldValues) whereArgs := make([]interface{}, 0, 4) whereArgs = append(whereArgs, ptype) for idx := 0; idx < 6; idx++ { if fieldIndex <= idx && idx < l { value = fieldValues[idx-fieldIndex] if value != "" { whereBuf.WriteString(" AND v") whereBuf.WriteString(strconv.Itoa(idx)) whereBuf.WriteString("=?") whereArgs = append(whereArgs, value) } } } var selectBuf bytes.Buffer selectBuf.Grow(64) selectBuf.WriteString(p.SqlSelectWhere) selectBuf.WriteString("p_type=?") selectBuf.Write(whereBuf.Bytes()) var oldRows []*SqlCasbinRule value = p.db.Rebind(selectBuf.String()) oldRows, err = p.SelectRows(value, whereArgs...) if err != nil { return } var deleteBuf bytes.Buffer deleteBuf.Grow(64) deleteBuf.WriteString(p.SqlDeleteByArgs) deleteBuf.Write(whereBuf.Bytes()) var tx *sqlx.Tx tx, err = p.db.BeginTxx(p.ctx, nil) if err != nil { return } var ( stmt *sqlx.Stmt action string ) value = p.db.Rebind(deleteBuf.String()) if _, err = tx.ExecContext(p.ctx, value, whereArgs...); err != nil { action = "delete old policies" goto ROLLBACK } stmt, err = tx.PreparexContext(p.ctx, p.SqlInsertRow) if err != nil { action = "preparex context" goto ROLLBACK } for _, policy := range newPolicies { arg := p.GenArgs(ptype, policy) if _, err = stmt.ExecContext(p.ctx, arg...); err != nil { action = "stmt exec context" goto ROLLBACK } } if err = stmt.Close(); err != nil { action = "stmt close" goto ROLLBACK } if err = tx.Commit(); err != nil { action = "commit" goto ROLLBACK } oldPolicies = make([][]string, 0, len(oldRows)) for _, rule := range oldRows { oldPolicies = append(oldPolicies, []string{rule.PType, rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5}) } return ROLLBACK: if err1 := tx.Rollback(); err1 != nil { err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1) } return } // loadPolicyLine load a policy line to model. func (SqlAdapter) loadPolicyLine(line *SqlCasbinRule, model model.Model) { if line == nil { return } var lineBuf bytes.Buffer lineBuf.Grow(64) lineBuf.WriteString(line.PType) args := [6]string{line.V0, line.V1, line.V2, line.V3, line.V4, line.V5} for _, arg := range args { if arg != "" { lineBuf.WriteByte(',') lineBuf.WriteString(arg) } } persist.LoadPolicyLine(lineBuf.String(), model) } // genArgs generate args from ptype and rule. func (SqlAdapter) GenArgs(ptype string, rule []string) []interface{} { l := len(rule) args := make([]interface{}, maxParamLength) args[0] = ptype for idx := 0; idx < l; idx++ { args[idx+1] = rule[idx] } for idx := l + 1; idx < maxParamLength; idx++ { args[idx] = "" } return args }