467 lines
15 KiB
Go
467 lines
15 KiB
Go
//
|
|
// adapter_sqlx_test.go
|
|
// Copyright (C) 2022 tiglog <me@tiglog.com>
|
|
//
|
|
// Distributed under terms of the MIT license.
|
|
//
|
|
|
|
package gcasbin_test
|
|
|
|
import (
|
|
"os"
|
|
"strings"
|
|
"testing"
|
|
|
|
"github.com/casbin/casbin/v2"
|
|
"github.com/casbin/casbin/v2/util"
|
|
_ "github.com/go-sql-driver/mysql"
|
|
"github.com/jmoiron/sqlx"
|
|
"git.hexq.cn/tiglog/golib/gcasbin"
|
|
)
|
|
|
|
const (
	// rbacModelFile is the casbin RBAC model definition shared by every test.
	rbacModelFile = "testdata/rbac_model.conf"
	// rbacPolicyFile is the CSV policy used to seed each test table.
	rbacPolicyFile = "testdata/rbac_policy.csv"
)
|
|
|
|
var (
|
|
dataSourceNames = map[string]string{
|
|
// "sqlite3": ":memory:",
|
|
// "mysql": "root:@tcp(127.0.0.1:3306)/sqlx_adapter_test",
|
|
"postgres": os.Getenv("DB_DSN"),
|
|
// "sqlserver": "sqlserver://sa:YourPassword@127.0.0.1:1433?database=sqlx_adapter_test&connection+timeout=30",
|
|
}
|
|
|
|
lines = []gcasbin.SqlCasbinRule{
|
|
{PType: "p", V0: "alice", V1: "data1", V2: "read"},
|
|
{PType: "p", V0: "bob", V1: "data2", V2: "read"},
|
|
{PType: "p", V0: "bob", V1: "data2", V2: "write"},
|
|
{PType: "p", V0: "data2_admin", V1: "data1", V2: "read", V3: "test1", V4: "test2", V5: "test3"},
|
|
{PType: "p", V0: "data2_admin", V1: "data2", V2: "write", V3: "test1", V4: "test2", V5: "test3"},
|
|
{PType: "p", V0: "data1_admin", V1: "data2", V2: "write"},
|
|
{PType: "g", V0: "alice", V1: "data2_admin"},
|
|
{PType: "g", V0: "bob", V1: "data2_admin", V2: "test"},
|
|
{PType: "g", V0: "bob", V1: "data1_admin", V2: "test2", V3: "test3", V4: "test4", V5: "test5"},
|
|
}
|
|
|
|
filter = gcasbin.SqlFilter{
|
|
PType: []string{"p"},
|
|
V0: []string{"bob", "data2_admin"},
|
|
V1: []string{"data1", "data2"},
|
|
V2: []string{"read", "write"},
|
|
V3: []string{"test1"},
|
|
V4: []string{"test2"},
|
|
V5: []string{"test3"},
|
|
}
|
|
)
|
|
|
|
func TestSqlAdapters(t *testing.T) {
|
|
for key, value := range dataSourceNames {
|
|
t.Logf("-------------------- test [%s] start, dataSourceName: [%s]", key, value)
|
|
|
|
db, err := sqlx.Connect(key, value)
|
|
if err != nil {
|
|
t.Fatalf("sqlx.Connect failed, err: %v", err)
|
|
}
|
|
|
|
t.Log("---------- testTableName start")
|
|
testTableName(t, db)
|
|
t.Log("---------- testTableName finished")
|
|
|
|
t.Log("---------- testSQL start")
|
|
testSQL(t, db, "sqlxadapter_sql")
|
|
t.Log("---------- testSQL finished")
|
|
|
|
t.Log("---------- testSaveLoad start")
|
|
testSaveLoad(t, db, "sqlxadapter_save_load")
|
|
t.Log("---------- testSaveLoad finished")
|
|
|
|
t.Log("---------- testAutoSave start")
|
|
testAutoSave(t, db, "sqlxadapter_auto_save")
|
|
t.Log("---------- testAutoSave finished")
|
|
|
|
t.Log("---------- testFilteredSqlPolicy start")
|
|
testFilteredSqlPolicy(t, db, "sqlxadapter_filtered_policy")
|
|
t.Log("---------- testFilteredSqlPolicy finished")
|
|
|
|
// t.Log("---------- testUpdateSqlPolicy start")
|
|
// testUpdateSqlPolicy(t, db, "sqladapter_filtered_policy")
|
|
// t.Log("---------- testUpdateSqlPolicy finished")
|
|
|
|
// t.Log("---------- testUpdateSqlPolicies start")
|
|
// testUpdateSqlPolicies(t, db, "sqladapter_filtered_policy")
|
|
// t.Log("---------- testUpdateSqlPolicies finished")
|
|
|
|
// t.Log("---------- testUpdateFilteredSqlPolicies start")
|
|
// testUpdateFilteredSqlPolicies(t, db, "sqladapter_filtered_policy")
|
|
// t.Log("---------- testUpdateFilteredSqlPolicies finished")
|
|
|
|
}
|
|
}
|
|
|
|
func testTableName(t *testing.T, db *sqlx.DB) {
|
|
_, err := gcasbin.NewSqlAdapter(db, "")
|
|
if err != nil {
|
|
t.Fatalf("NewAdapter failed, err: %v", err)
|
|
}
|
|
}
|
|
|
|
func testSQL(t *testing.T, db *sqlx.DB, tableName string) {
|
|
var err error
|
|
logErr := func(action string) {
|
|
if err != nil {
|
|
t.Errorf("%s test failed, err: %v", action, err)
|
|
}
|
|
}
|
|
|
|
equalValue := func(line1, line2 gcasbin.SqlCasbinRule) bool {
|
|
if line1.PType != line2.PType ||
|
|
line1.V0 != line2.V0 ||
|
|
line1.V1 != line2.V1 ||
|
|
line1.V2 != line2.V2 ||
|
|
line1.V3 != line2.V3 ||
|
|
line1.V4 != line2.V4 ||
|
|
line1.V5 != line2.V5 {
|
|
return false
|
|
}
|
|
return true
|
|
}
|
|
|
|
var a *gcasbin.SqlAdapter
|
|
a, err = gcasbin.NewSqlAdapter(db, tableName)
|
|
logErr("NewSqlAdapter")
|
|
|
|
// createTable test has passed when adapter create
|
|
// err = a.CreateTable()
|
|
// logErr("createTable")
|
|
|
|
if b := a.IsTableExist(); b == false {
|
|
t.Fatal("isTableExist test failed")
|
|
}
|
|
|
|
rules := make([][]interface{}, len(lines))
|
|
for idx, rule := range lines {
|
|
args := a.GenArgs(rule.PType, []string{rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
|
|
rules[idx] = args
|
|
}
|
|
|
|
err = a.TruncateAndInsertRows(rules)
|
|
logErr("truncateAndInsertRows")
|
|
|
|
err = a.DeleteAllAndInsertRows(rules)
|
|
logErr("truncateAndInsertRows")
|
|
|
|
err = a.DeleteRows(a.SqlDeleteByArgs, "g")
|
|
logErr("deleteRows sqlDeleteByArgs g")
|
|
|
|
err = a.DeleteRows(a.SqlDeleteAll)
|
|
logErr("deleteRows sqlDeleteAll")
|
|
|
|
_ = a.TruncateAndInsertRows(rules)
|
|
|
|
records, err := a.SelectRows(a.SqlSelectAll)
|
|
logErr("selectRows sqlSelectAll")
|
|
for idx, record := range records {
|
|
line := lines[idx]
|
|
if !equalValue(*record, line) {
|
|
t.Fatalf("selectRows records test not equal, query record: %+v, need record: %+v", record, line)
|
|
}
|
|
}
|
|
|
|
records, err = a.SelectWhereIn(&filter)
|
|
logErr("selectWhereIn")
|
|
i := 3
|
|
for _, record := range records {
|
|
line := lines[i]
|
|
if !equalValue(*record, line) {
|
|
t.Fatalf("selectWhereIn records test not equal, query record: %+v, need record: %+v", record, line)
|
|
}
|
|
i++
|
|
}
|
|
|
|
err = a.TruncateTable()
|
|
logErr("truncateTable")
|
|
}
|
|
|
|
func initSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
|
|
// Because the DB is empty at first,
|
|
// so we need to load the policy from the file adapter (.CSV) first.
|
|
e, _ := casbin.NewEnforcer(rbacModelFile, rbacPolicyFile)
|
|
|
|
a, err := gcasbin.NewSqlAdapter(db, tableName)
|
|
if err != nil {
|
|
t.Fatal("NewAdapter test failed, err: ", err)
|
|
}
|
|
|
|
// This is a trick to save the current policy to the DB.
|
|
// We can't call e.SavePolicy() because the adapter in the enforcer is still the file adapter.
|
|
// The current policy means the policy in the Casbin enforcer (aka in memory).
|
|
err = a.SavePolicy(e.GetModel())
|
|
if err != nil {
|
|
t.Fatal("SavePolicy test failed, err: ", err)
|
|
}
|
|
|
|
// Clear the current policy.
|
|
e.ClearPolicy()
|
|
testGetSqlPolicy(t, e, [][]string{})
|
|
|
|
// Load the policy from DB.
|
|
err = a.LoadPolicy(e.GetModel())
|
|
if err != nil {
|
|
t.Fatal("LoadPolicy test failed, err: ", err)
|
|
}
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
}
|
|
|
|
func testSaveLoad(t *testing.T, db *sqlx.DB, tableName string) {
|
|
// Initialize some policy in DB.
|
|
initSqlPolicy(t, db, tableName)
|
|
// Note: you don't need to look at the above code
|
|
// if you already have a working DB with policy inside.
|
|
|
|
// Now the DB has policy, so we can provide a normal use case.
|
|
// Create an adapter and an enforcer.
|
|
// NewEnforcer() will load the policy automatically.
|
|
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
}
|
|
|
|
func testAutoSave(t *testing.T, db *sqlx.DB, tableName string) {
|
|
// Initialize some policy in DB.
|
|
initSqlPolicy(t, db, tableName)
|
|
// Note: you don't need to look at the above code
|
|
// if you already have a working DB with policy inside.
|
|
|
|
// Now the DB has policy, so we can provide a normal use case.
|
|
// Create an adapter and an enforcer.
|
|
// NewEnforcer() will load the policy automatically.
|
|
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
|
|
// AutoSave is enabled by default.
|
|
// Now we disable it.
|
|
e.EnableAutoSave(false)
|
|
|
|
var err error
|
|
logErr := func(action string) {
|
|
if err != nil {
|
|
t.Errorf("%s test failed, err: %v", action, err)
|
|
}
|
|
}
|
|
|
|
// Because AutoSave is disabled, the policy change only affects the policy in Casbin enforcer,
|
|
// it doesn't affect the policy in the storage.
|
|
_, err = e.AddPolicy("alice", "data1", "write")
|
|
logErr("AddPolicy1")
|
|
// Reload the policy from the storage to see the effect.
|
|
err = e.LoadPolicy()
|
|
logErr("LoadPolicy1")
|
|
// This is still the original policy.
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
|
|
_, err = e.AddPolicies([][]string{{"alice_1", "data_1", "read_1"}, {"bob_1", "data_1", "write_1"}})
|
|
logErr("AddPolicies1")
|
|
// Reload the policy from the storage to see the effect.
|
|
err = e.LoadPolicy()
|
|
logErr("LoadPolicy2")
|
|
// This is still the original policy.
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
|
|
// Now we enable the AutoSave.
|
|
e.EnableAutoSave(true)
|
|
|
|
// Because AutoSave is enabled, the policy change not only affects the policy in Casbin enforcer,
|
|
// but also affects the policy in the storage.
|
|
_, err = e.AddPolicy("alice", "data1", "write")
|
|
logErr("AddPolicy2")
|
|
// Reload the policy from the storage to see the effect.
|
|
err = e.LoadPolicy()
|
|
logErr("LoadPolicy3")
|
|
// The policy has a new rule: {"alice", "data1", "write"}.
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
|
|
|
|
_, err = e.AddPolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
|
|
logErr("AddPolicies2")
|
|
// Reload the policy from the storage to see the effect.
|
|
err = e.LoadPolicy()
|
|
logErr("LoadPolicy4")
|
|
// This is still the original policy.
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"},
|
|
{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
|
|
|
|
_, err = e.RemovePolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
|
|
logErr("RemovePolicies")
|
|
err = e.LoadPolicy()
|
|
logErr("LoadPolicy5")
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})
|
|
|
|
// Remove the added rule.
|
|
_, err = e.RemovePolicy("alice", "data1", "write")
|
|
logErr("RemovePolicy")
|
|
err = e.LoadPolicy()
|
|
logErr("LoadPolicy6")
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
|
|
// Remove "data2_admin" related policy rules via a filter.
|
|
// Two rules: {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"} are deleted.
|
|
_, err = e.RemoveFilteredPolicy(0, "data2_admin")
|
|
logErr("RemoveFilteredPolicy")
|
|
err = e.LoadPolicy()
|
|
logErr("LoadPolicy7")
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
|
|
}
|
|
|
|
func testFilteredSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
|
|
// Initialize some policy in DB.
|
|
initSqlPolicy(t, db, tableName)
|
|
// Note: you don't need to look at the above code
|
|
// if you already have a working DB with policy inside.
|
|
|
|
// Now the DB has policy, so we can provide a normal use case.
|
|
// Create an adapter and an enforcer.
|
|
// NewEnforcer() will load the policy automatically.
|
|
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
// Now set the adapter
|
|
e.SetAdapter(a)
|
|
|
|
var err error
|
|
logErr := func(action string) {
|
|
if err != nil {
|
|
t.Errorf("%s test failed, err: %v", action, err)
|
|
}
|
|
}
|
|
|
|
// Load only alice's policies
|
|
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice"}})
|
|
logErr("LoadFilteredPolicy alice")
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}})
|
|
|
|
// Load only bob's policies
|
|
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"bob"}})
|
|
logErr("LoadFilteredPolicy bob")
|
|
testGetSqlPolicy(t, e, [][]string{{"bob", "data2", "write"}})
|
|
|
|
// Load policies for data2_admin
|
|
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"data2_admin"}})
|
|
logErr("LoadFilteredPolicy data2_admin")
|
|
testGetSqlPolicy(t, e, [][]string{{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
|
|
// Load policies for alice and bob
|
|
err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice", "bob"}})
|
|
logErr("LoadFilteredPolicy alice bob")
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
|
|
|
|
_, err = e.AddPolicy("bob", "data1", "write", "test1", "test2", "test3")
|
|
logErr("AddPolicy")
|
|
|
|
err = e.LoadFilteredPolicy(&filter)
|
|
logErr("LoadFilteredPolicy filter")
|
|
testGetSqlPolicy(t, e, [][]string{{"bob", "data1", "write", "test1", "test2", "test3"}})
|
|
}
|
|
|
|
func testUpdateSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
|
|
// Initialize some policy in DB.
|
|
initSqlPolicy(t, db, tableName)
|
|
|
|
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
|
|
e.EnableAutoSave(true)
|
|
e.UpdatePolicy([]string{"alice", "data1", "read"}, []string{"alice", "data1", "write"})
|
|
e.LoadPolicy()
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
}
|
|
|
|
func testUpdateSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
|
|
// Initialize some policy in DB.
|
|
initSqlPolicy(t, db, tableName)
|
|
|
|
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
|
|
e.EnableAutoSave(true)
|
|
e.UpdatePolicies([][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}}, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}})
|
|
e.LoadPolicy()
|
|
testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
|
|
}
|
|
|
|
func testUpdateFilteredSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
|
|
// Initialize some policy in DB.
|
|
initSqlPolicy(t, db, tableName)
|
|
|
|
a, _ := gcasbin.NewSqlAdapter(db, tableName)
|
|
e, _ := casbin.NewEnforcer(rbacModelFile, a)
|
|
|
|
e.EnableAutoSave(true)
|
|
e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read")
|
|
e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write")
|
|
e.LoadPolicy()
|
|
testGetSqlPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}})
|
|
}
|
|
|
|
func testGetSqlPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
|
|
t.Helper()
|
|
myRes := e.GetPolicy()
|
|
t.Log("Policy: ", myRes)
|
|
|
|
m := make(map[string]struct{}, len(myRes))
|
|
for _, record := range myRes {
|
|
key := strings.Join(record, ",")
|
|
m[key] = struct{}{}
|
|
}
|
|
|
|
for _, record := range res {
|
|
key := strings.Join(record, ",")
|
|
if _, ok := m[key]; !ok {
|
|
t.Error("Policy: \n", myRes, ", supposed to be \n", res)
|
|
break
|
|
}
|
|
}
|
|
}
|
|
|
|
func testGetSqlPolicyWithoutOrder(t *testing.T, e *casbin.Enforcer, res [][]string) {
|
|
myRes := e.GetPolicy()
|
|
// log.Print("Policy: \n", myRes)
|
|
|
|
if !arraySqlEqualsWithoutOrder(myRes, res) {
|
|
t.Error("Policy: \n", myRes, ", supposed to be \n", res)
|
|
}
|
|
}
|
|
|
|
func arraySqlEqualsWithoutOrder(a [][]string, b [][]string) bool {
|
|
if len(a) != len(b) {
|
|
return false
|
|
}
|
|
|
|
mapA := make(map[int]string)
|
|
mapB := make(map[int]string)
|
|
order := make(map[int]struct{})
|
|
l := len(a)
|
|
|
|
for i := 0; i < l; i++ {
|
|
mapA[i] = util.ArrayToString(a[i])
|
|
mapB[i] = util.ArrayToString(b[i])
|
|
}
|
|
|
|
for i := 0; i < l; i++ {
|
|
for j := 0; j < l; j++ {
|
|
if _, ok := order[j]; ok {
|
|
if j == l-1 {
|
|
return false
|
|
} else {
|
|
continue
|
|
}
|
|
}
|
|
if mapA[i] == mapB[j] {
|
|
order[j] = struct{}{}
|
|
break
|
|
} else if j == l-1 {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
}
|