golib/gcasbin/adapter_sqlx_test.go
2023-06-15 21:38:12 +08:00

467 lines
15 KiB
Go

//
// adapter_sqlx_test.go
// Copyright (C) 2022 tiglog <me@tiglog.com>
//
// Distributed under terms of the MIT license.
//
package gcasbin_test
import (
"os"
"strings"
"testing"
"github.com/casbin/casbin/v2"
"github.com/casbin/casbin/v2/util"
_ "github.com/go-sql-driver/mysql"
"github.com/jmoiron/sqlx"
"git.hexq.cn/tiglog/golib/gcasbin"
)
// Fixture files used by the enforcer tests, given relative to the package
// directory.
const (
	// rbacPolicyFile holds the seed policy that initSqlPolicy copies into the
	// database table.
	rbacPolicyFile = "testdata/rbac_policy.csv"
	// rbacModelFile is the casbin model loaded by every enforcer in this file.
	rbacModelFile = "testdata/rbac_model.conf"
)
var (
	// dataSourceNames maps a database/sql driver name to its data source name;
	// TestSqlAdapters runs the whole suite once per entry. Only postgres is
	// enabled, with its DSN taken from the DB_DSN environment variable.
	// NOTE(review): this file only blank-imports the mysql driver, so the
	// "postgres" driver must be registered elsewhere in the test binary —
	// confirm.
	dataSourceNames = map[string]string{
		// "sqlite3": ":memory:",
		// "mysql": "root:@tcp(127.0.0.1:3306)/sqlx_adapter_test",
		"postgres": os.Getenv("DB_DSN"),
		// "sqlserver": "sqlserver://sa:YourPassword@127.0.0.1:1433?database=sqlx_adapter_test&connection+timeout=30",
	}
	// lines is the rule fixture that testSQL writes to the table and reads
	// back. It mixes "p" and "g" rows, and several rows fill V3..V5 so that
	// filtering on those columns can be exercised.
	lines = []gcasbin.SqlCasbinRule{
		{PType: "p", V0: "alice", V1: "data1", V2: "read"},
		{PType: "p", V0: "bob", V1: "data2", V2: "read"},
		{PType: "p", V0: "bob", V1: "data2", V2: "write"},
		{PType: "p", V0: "data2_admin", V1: "data1", V2: "read", V3: "test1", V4: "test2", V5: "test3"},
		{PType: "p", V0: "data2_admin", V1: "data2", V2: "write", V3: "test1", V4: "test2", V5: "test3"},
		{PType: "p", V0: "data1_admin", V1: "data2", V2: "write"},
		{PType: "g", V0: "alice", V1: "data2_admin"},
		{PType: "g", V0: "bob", V1: "data2_admin", V2: "test"},
		{PType: "g", V0: "bob", V1: "data1_admin", V2: "test2", V3: "test3", V4: "test4", V5: "test5"},
	}
	// filter constrains every column (PType and V0..V5). Among lines, only
	// lines[3] and lines[4] satisfy all of it, which testSQL relies on; it is
	// also used by testFilteredSqlPolicy.
	filter = gcasbin.SqlFilter{
		PType: []string{"p"},
		V0:    []string{"bob", "data2_admin"},
		V1:    []string{"data1", "data2"},
		V2:    []string{"read", "write"},
		V3:    []string{"test1"},
		V4:    []string{"test2"},
		V5:    []string{"test3"},
	}
)
// TestSqlAdapters runs the adapter test suite against every database listed
// in dataSourceNames.
//
// Each driver runs as its own subtest, so a failure on one database does not
// stop the others. A driver whose data source name is empty (for example
// because DB_DSN is not exported) is skipped rather than failed. The
// connection is closed when the subtest finishes.
func TestSqlAdapters(t *testing.T) {
	for driver, dsn := range dataSourceNames {
		// t.Run without t.Parallel runs the closure before returning, so
		// capturing the loop variables is safe on every Go version.
		t.Run(driver, func(t *testing.T) {
			if dsn == "" {
				t.Skipf("no data source name configured for driver %q", driver)
			}
			// The DSN itself is not logged: it may embed credentials.
			t.Logf("-------------------- test [%s] start", driver)
			db, err := sqlx.Connect(driver, dsn)
			if err != nil {
				t.Fatalf("sqlx.Connect failed, err: %v", err)
			}
			t.Cleanup(func() {
				if err := db.Close(); err != nil {
					t.Errorf("db.Close failed, err: %v", err)
				}
			})

			t.Log("---------- testTableName start")
			testTableName(t, db)
			t.Log("---------- testTableName finished")

			t.Log("---------- testSQL start")
			testSQL(t, db, "sqlxadapter_sql")
			t.Log("---------- testSQL finished")

			t.Log("---------- testSaveLoad start")
			testSaveLoad(t, db, "sqlxadapter_save_load")
			t.Log("---------- testSaveLoad finished")

			t.Log("---------- testAutoSave start")
			testAutoSave(t, db, "sqlxadapter_auto_save")
			t.Log("---------- testAutoSave finished")

			t.Log("---------- testFilteredSqlPolicy start")
			testFilteredSqlPolicy(t, db, "sqlxadapter_filtered_policy")
			t.Log("---------- testFilteredSqlPolicy finished")

			// The update tests below are currently disabled.
			// t.Log("---------- testUpdateSqlPolicy start")
			// testUpdateSqlPolicy(t, db, "sqladapter_filtered_policy")
			// t.Log("---------- testUpdateSqlPolicy finished")
			// t.Log("---------- testUpdateSqlPolicies start")
			// testUpdateSqlPolicies(t, db, "sqladapter_filtered_policy")
			// t.Log("---------- testUpdateSqlPolicies finished")
			// t.Log("---------- testUpdateFilteredSqlPolicies start")
			// testUpdateFilteredSqlPolicies(t, db, "sqladapter_filtered_policy")
			// t.Log("---------- testUpdateFilteredSqlPolicies finished")
		})
	}
}
// testTableName checks that an adapter can be constructed with an empty table
// name. NOTE(review): presumably NewSqlAdapter substitutes a default table
// name in that case — confirm in the adapter.
func testTableName(t *testing.T, db *sqlx.DB) {
	if _, err := gcasbin.NewSqlAdapter(db, ""); err != nil {
		t.Fatalf("NewAdapter failed, err: %v", err)
	}
}
// testSQL exercises the adapter's low-level SQL helpers directly against the
// table tableName.
//
// It covers the existence check, both bulk-insert variants, the delete
// helpers, a full select and a filtered select (SelectWhereIn with filter).
// The table is truncated again at the end.
func testSQL(t *testing.T, db *sqlx.DB, tableName string) {
	// Every step below needs the adapter, so a construction failure is fatal
	// instead of leading to a nil dereference.
	a, err := gcasbin.NewSqlAdapter(db, tableName)
	if err != nil {
		t.Fatalf("NewSqlAdapter test failed, err: %v", err)
	}

	// logErr reports (without stopping) the latest value of err.
	logErr := func(action string) {
		t.Helper()
		if err != nil {
			t.Errorf("%s test failed, err: %v", action, err)
		}
	}
	// equalValue compares only the policy columns (PType, V0..V5); any other
	// fields of gcasbin.SqlCasbinRule are ignored.
	equalValue := func(line1, line2 gcasbin.SqlCasbinRule) bool {
		return line1.PType == line2.PType &&
			line1.V0 == line2.V0 &&
			line1.V1 == line2.V1 &&
			line1.V2 == line2.V2 &&
			line1.V3 == line2.V3 &&
			line1.V4 == line2.V4 &&
			line1.V5 == line2.V5
	}

	// The table is created when the adapter is constructed, so no explicit
	// CreateTable call is needed before checking that it exists.
	if !a.IsTableExist() {
		t.Fatal("isTableExist test failed")
	}

	rules := make([][]interface{}, len(lines))
	for idx, rule := range lines {
		rules[idx] = a.GenArgs(rule.PType, []string{rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5})
	}

	err = a.TruncateAndInsertRows(rules)
	logErr("truncateAndInsertRows")
	err = a.DeleteAllAndInsertRows(rules)
	logErr("deleteAllAndInsertRows")
	err = a.DeleteRows(a.SqlDeleteByArgs, "g")
	logErr("deleteRows sqlDeleteByArgs g")
	err = a.DeleteRows(a.SqlDeleteAll)
	logErr("deleteRows sqlDeleteAll")

	// Re-seed the table so the select checks below see every fixture row.
	if err = a.TruncateAndInsertRows(rules); err != nil {
		t.Fatalf("truncateAndInsertRows test failed, err: %v", err)
	}

	// NOTE(review): both select checks assume rows come back in insertion
	// order; confirm the adapter's SELECT statements order by id.
	records, err := a.SelectRows(a.SqlSelectAll)
	if err != nil {
		t.Fatalf("selectRows sqlSelectAll test failed, err: %v", err)
	}
	if len(records) != len(lines) {
		t.Fatalf("selectRows returned %d records, need %d", len(records), len(lines))
	}
	for idx, record := range records {
		if !equalValue(*record, lines[idx]) {
			t.Fatalf("selectRows records test not equal, query record: %+v, need record: %+v", record, lines[idx])
		}
	}

	// Only lines[3] and lines[4] satisfy every column of filter.
	want := lines[3:5]
	records, err = a.SelectWhereIn(&filter)
	if err != nil {
		t.Fatalf("selectWhereIn test failed, err: %v", err)
	}
	if len(records) != len(want) {
		t.Fatalf("selectWhereIn returned %d records, need %d", len(records), len(want))
	}
	for idx, record := range records {
		if !equalValue(*record, want[idx]) {
			t.Fatalf("selectWhereIn records test not equal, query record: %+v, need record: %+v", record, want[idx])
		}
	}

	err = a.TruncateTable()
	logErr("truncateTable")
}
// initSqlPolicy seeds table tableName with the policy from rbacPolicyFile and
// checks that it can be read back.
//
// The file-backed enforcer's in-memory policy is written through the SQL
// adapter's SavePolicy. The enforcer is then cleared and reloaded through the
// adapter, and the four "p" rules from the file must be present again.
func initSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
	t.Helper()
	// The DB may be empty at first, so load the policy from the file adapter
	// (.CSV) first.
	e, err := casbin.NewEnforcer(rbacModelFile, rbacPolicyFile)
	if err != nil {
		t.Fatal("NewEnforcer test failed, err: ", err)
	}
	a, err := gcasbin.NewSqlAdapter(db, tableName)
	if err != nil {
		t.Fatal("NewAdapter test failed, err: ", err)
	}
	// This is a trick to save the current policy to the DB.
	// We can't call e.SavePolicy() because the adapter in the enforcer is still the file adapter.
	// The current policy means the policy in the Casbin enforcer (aka in memory).
	if err = a.SavePolicy(e.GetModel()); err != nil {
		t.Fatal("SavePolicy test failed, err: ", err)
	}
	// Clear the in-memory policy and make sure nothing is left, so the check
	// after the reload proves the rules came from the DB.
	e.ClearPolicy()
	if got := e.GetPolicy(); len(got) != 0 {
		t.Fatal("ClearPolicy test failed, policy still contains: ", got)
	}
	// Load the policy from DB.
	if err = a.LoadPolicy(e.GetModel()); err != nil {
		t.Fatal("LoadPolicy test failed, err: ", err)
	}
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
// testSaveLoad seeds tableName, then checks that a fresh enforcer built on the
// SQL adapter loads the seeded policy automatically.
func testSaveLoad(t *testing.T, db *sqlx.DB, tableName string) {
	// Initialize some policy in DB.
	initSqlPolicy(t, db, tableName)
	// Note: you don't need to look at the above code
	// if you already have a working DB with policy inside.

	// Now the DB has policy, so we can provide a normal use case.
	// Create an adapter and an enforcer.
	// NewEnforcer() will load the policy automatically.
	a, err := gcasbin.NewSqlAdapter(db, tableName)
	if err != nil {
		t.Fatal("NewAdapter test failed, err: ", err)
	}
	e, err := casbin.NewEnforcer(rbacModelFile, a)
	if err != nil {
		t.Fatal("NewEnforcer test failed, err: ", err)
	}
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
// testAutoSave checks that policy changes reach the database only while
// AutoSave is enabled.
//
// With AutoSave disabled, rules added in memory disappear after LoadPolicy.
// With it enabled, AddPolicy, AddPolicies, RemovePolicies, RemovePolicy and
// RemoveFilteredPolicy are persisted and survive a reload.
func testAutoSave(t *testing.T, db *sqlx.DB, tableName string) {
	// Initialize some policy in DB.
	initSqlPolicy(t, db, tableName)
	// Note: you don't need to look at the above code
	// if you already have a working DB with policy inside.

	// Now the DB has policy, so we can provide a normal use case.
	// Create an adapter and an enforcer.
	// NewEnforcer() will load the policy automatically.
	a, err := gcasbin.NewSqlAdapter(db, tableName)
	if err != nil {
		t.Fatal("NewAdapter test failed, err: ", err)
	}
	e, err := casbin.NewEnforcer(rbacModelFile, a)
	if err != nil {
		t.Fatal("NewEnforcer test failed, err: ", err)
	}

	// logErr reports (without stopping) the latest value of err.
	logErr := func(action string) {
		t.Helper()
		if err != nil {
			t.Errorf("%s test failed, err: %v", action, err)
		}
	}

	// AutoSave is enabled by default.
	// Now we disable it.
	e.EnableAutoSave(false)

	// Because AutoSave is disabled, the policy change only affects the policy in Casbin enforcer,
	// it doesn't affect the policy in the storage.
	_, err = e.AddPolicy("alice", "data1", "write")
	logErr("AddPolicy1")
	// Reload the policy from the storage to see the effect.
	err = e.LoadPolicy()
	logErr("LoadPolicy1")
	// This is still the original policy.
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})

	_, err = e.AddPolicies([][]string{{"alice_1", "data_1", "read_1"}, {"bob_1", "data_1", "write_1"}})
	logErr("AddPolicies1")
	// Reload the policy from the storage to see the effect.
	err = e.LoadPolicy()
	logErr("LoadPolicy2")
	// This is still the original policy.
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})

	// Now we enable the AutoSave.
	e.EnableAutoSave(true)

	// Because AutoSave is enabled, the policy change not only affects the policy in Casbin enforcer,
	// but also affects the policy in the storage.
	_, err = e.AddPolicy("alice", "data1", "write")
	logErr("AddPolicy2")
	// Reload the policy from the storage to see the effect.
	err = e.LoadPolicy()
	logErr("LoadPolicy3")
	// The policy has a new rule: {"alice", "data1", "write"}.
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})

	_, err = e.AddPolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
	logErr("AddPolicies2")
	// Reload the policy from the storage to see the effect.
	err = e.LoadPolicy()
	logErr("LoadPolicy4")
	// The two rules added by AddPolicies were persisted as well.
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"},
		{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})

	_, err = e.RemovePolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}})
	logErr("RemovePolicies")
	err = e.LoadPolicy()
	logErr("LoadPolicy5")
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}})

	// Remove the added rule.
	_, err = e.RemovePolicy("alice", "data1", "write")
	logErr("RemovePolicy")
	err = e.LoadPolicy()
	logErr("LoadPolicy6")
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})

	// Remove "data2_admin" related policy rules via a filter.
	// Two rules: {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"} are deleted.
	_, err = e.RemoveFilteredPolicy(0, "data2_admin")
	logErr("RemoveFilteredPolicy")
	err = e.LoadPolicy()
	logErr("LoadPolicy7")
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})
}
// testFilteredSqlPolicy checks LoadFilteredPolicy with several SqlFilter
// values. Each filtered load replaces the in-memory policy with exactly the
// matching rows. It also checks that a 6-field rule added through the enforcer
// can be selected by the package-level filter.
func testFilteredSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
	// Initialize some policy in DB.
	initSqlPolicy(t, db, tableName)
	// Note: you don't need to look at the above code
	// if you already have a working DB with policy inside.

	// Now the DB has policy, so we can provide a normal use case.
	// Create an adapter and an enforcer. Passing the adapter to NewEnforcer
	// installs it on the enforcer, so no separate SetAdapter call is needed.
	a, err := gcasbin.NewSqlAdapter(db, tableName)
	if err != nil {
		t.Fatal("NewAdapter test failed, err: ", err)
	}
	e, err := casbin.NewEnforcer(rbacModelFile, a)
	if err != nil {
		t.Fatal("NewEnforcer test failed, err: ", err)
	}

	// logErr reports (without stopping) the latest value of err.
	logErr := func(action string) {
		t.Helper()
		if err != nil {
			t.Errorf("%s test failed, err: %v", action, err)
		}
	}

	// Load only alice's policies
	err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice"}})
	logErr("LoadFilteredPolicy alice")
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}})

	// Load only bob's policies
	err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"bob"}})
	logErr("LoadFilteredPolicy bob")
	testGetSqlPolicy(t, e, [][]string{{"bob", "data2", "write"}})

	// Load policies for data2_admin
	err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"data2_admin"}})
	logErr("LoadFilteredPolicy data2_admin")
	testGetSqlPolicy(t, e, [][]string{{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})

	// Load policies for alice and bob
	err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice", "bob"}})
	logErr("LoadFilteredPolicy alice bob")
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}})

	// AutoSave is still at its default (enabled), so this rule is written to
	// the DB.
	_, err = e.AddPolicy("bob", "data1", "write", "test1", "test2", "test3")
	logErr("AddPolicy")
	// filter requires V3..V5 = test1/test2/test3, which only the rule just
	// added has.
	err = e.LoadFilteredPolicy(&filter)
	logErr("LoadFilteredPolicy filter")
	testGetSqlPolicy(t, e, [][]string{{"bob", "data1", "write", "test1", "test2", "test3"}})
}
// testUpdateSqlPolicy checks that UpdatePolicy with AutoSave enabled rewrites
// a stored rule: {"alice", "data1", "read"} becomes {"alice", "data1", "write"}.
func testUpdateSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) {
	// Initialize some policy in DB.
	initSqlPolicy(t, db, tableName)
	a, err := gcasbin.NewSqlAdapter(db, tableName)
	if err != nil {
		t.Fatal("NewAdapter test failed, err: ", err)
	}
	e, err := casbin.NewEnforcer(rbacModelFile, a)
	if err != nil {
		t.Fatal("NewEnforcer test failed, err: ", err)
	}
	e.EnableAutoSave(true)
	ok, err := e.UpdatePolicy([]string{"alice", "data1", "read"}, []string{"alice", "data1", "write"})
	if err != nil {
		t.Fatal("UpdatePolicy test failed, err: ", err)
	}
	if !ok {
		t.Fatal("UpdatePolicy test failed: no rule was updated")
	}
	if err = e.LoadPolicy(); err != nil {
		t.Fatal("LoadPolicy test failed, err: ", err)
	}
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
// testUpdateSqlPolicies checks that UpdatePolicies with AutoSave enabled
// rewrites several stored rules at once. The old rules must be ones that
// initSqlPolicy actually stored, otherwise nothing is updated.
func testUpdateSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
	// Initialize some policy in DB.
	initSqlPolicy(t, db, tableName)
	a, err := gcasbin.NewSqlAdapter(db, tableName)
	if err != nil {
		t.Fatal("NewAdapter test failed, err: ", err)
	}
	e, err := casbin.NewEnforcer(rbacModelFile, a)
	if err != nil {
		t.Fatal("NewEnforcer test failed, err: ", err)
	}
	e.EnableAutoSave(true)
	ok, err := e.UpdatePolicies(
		[][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}}, // old rules (seeded)
		[][]string{{"alice", "data1", "write"}, {"bob", "data2", "read"}}, // new rules
	)
	if err != nil {
		t.Fatal("UpdatePolicies test failed, err: ", err)
	}
	if !ok {
		t.Fatal("UpdatePolicies test failed: no rule was updated")
	}
	if err = e.LoadPolicy(); err != nil {
		t.Fatal("LoadPolicy test failed, err: ", err)
	}
	testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "write"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}})
}
// testUpdateFilteredSqlPolicies checks that UpdateFilteredPolicies with
// AutoSave enabled replaces the rules matching a field filter: alice's read
// on data1 becomes write, and bob's write on data2 becomes read.
func testUpdateFilteredSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) {
	// Initialize some policy in DB.
	initSqlPolicy(t, db, tableName)
	a, err := gcasbin.NewSqlAdapter(db, tableName)
	if err != nil {
		t.Fatal("NewAdapter test failed, err: ", err)
	}
	e, err := casbin.NewEnforcer(rbacModelFile, a)
	if err != nil {
		t.Fatal("NewEnforcer test failed, err: ", err)
	}
	e.EnableAutoSave(true)

	ok, err := e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read")
	if err != nil {
		t.Fatal("UpdateFilteredPolicies alice test failed, err: ", err)
	}
	if !ok {
		t.Fatal("UpdateFilteredPolicies alice test failed: no rule was updated")
	}
	ok, err = e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write")
	if err != nil {
		t.Fatal("UpdateFilteredPolicies bob test failed, err: ", err)
	}
	if !ok {
		t.Fatal("UpdateFilteredPolicies bob test failed: no rule was updated")
	}

	if err = e.LoadPolicy(); err != nil {
		t.Fatal("LoadPolicy test failed, err: ", err)
	}
	testGetSqlPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}})
}
// testGetSqlPolicy fails the test unless e.GetPolicy() contains exactly the
// rules in res, in any order.
//
// Rules are compared by their comma-joined fields, and duplicates are counted.
// Both missing rules and unexpected extra rules are reported, so checks after
// a removal or a ClearPolicy can actually fail.
func testGetSqlPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) {
	t.Helper()
	myRes := e.GetPolicy()
	t.Log("Policy: ", myRes)

	// Add one per actual rule and subtract one per expected rule; any
	// non-zero leftover means the two collections differ.
	counts := make(map[string]int, len(myRes))
	for _, record := range myRes {
		counts[strings.Join(record, ",")]++
	}
	for _, record := range res {
		counts[strings.Join(record, ",")]--
	}
	for _, c := range counts {
		if c != 0 {
			t.Error("Policy: \n", myRes, ", supposed to be \n", res)
			return
		}
	}
}
// testGetSqlPolicyWithoutOrder fails the test unless e.GetPolicy() holds
// exactly the rules in res, ignoring order (see arraySqlEqualsWithoutOrder).
func testGetSqlPolicyWithoutOrder(t *testing.T, e *casbin.Enforcer, res [][]string) {
	t.Helper()
	myRes := e.GetPolicy()
	if !arraySqlEqualsWithoutOrder(myRes, res) {
		t.Error("Policy: \n", myRes, ", supposed to be \n", res)
	}
}
// arraySqlEqualsWithoutOrder reports whether a and b contain the same rules
// the same number of times, ignoring order. Each rule is compared through its
// util.ArrayToString form.
func arraySqlEqualsWithoutOrder(a [][]string, b [][]string) bool {
	if len(a) != len(b) {
		return false
	}
	// Count each rule of a and cancel it against b. A rule left with a
	// non-zero count occurs a different number of times in the two inputs.
	counts := make(map[string]int, len(a))
	for i := range a {
		counts[util.ArrayToString(a[i])]++
		counts[util.ArrayToString(b[i])]--
	}
	for _, c := range counts {
		if c != 0 {
			return false
		}
	}
	return true
}