commit f21ddce0c8e54c2bd4e14dcef16f63d2ed7f1081 Author: tiglog Date: Thu Jun 15 21:22:51 2023 +0800 init repo diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0232850 --- /dev/null +++ b/.gitignore @@ -0,0 +1,18 @@ +# ---> Go +# # Binaries for programs and plugins +# *.exe +# *.exe~ +# *.dll +# *.so +# *.dylib +# +# # Test binary, built with `go test -c` +# *.test +# +# # Output of the go coverage tool, specifically when used with LiteIDE +# *.out +# +# # Dependency directories (remove the comment below to include it) +# +.env_test.sh +*.log diff --git a/console/cli_about.go b/console/cli_about.go new file mode 100644 index 0000000..9034529 --- /dev/null +++ b/console/cli_about.go @@ -0,0 +1,59 @@ +// +// cli_about.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package console + +import ( + "fmt" + "path/filepath" + "runtime" +) + +type cAbout struct { + BaseCmd + cli IConsole +} + +func NewAboutCmd(cli IConsole) *cAbout { + c := new(cAbout) + c.Name = "about" + c.Desc = "应用基本环境" + c.cli = cli + return c +} + +func (c *cAbout) Init(args []string) { + + c.Action = func() error { + fmt.Printf("About:\n") + fmt.Printf("==================================================\n") + fmt.Printf(" Version: %s\n", c.cli.GetVersion()) + wd, _ := filepath.Abs("./") + fmt.Printf(" BaseDir: %s\n", wd) + fmt.Printf(" Env: %s\n", "dev") + fmt.Printf(" Debug: %v\n", true) + fmt.Printf(" GOOS: %s %s\n", runtime.GOOS, runtime.GOARCH) + + hd := c.cli.GetExtraAbout() + if hd != nil { + return hd() + } + // if sqldb.Db != nil { + // fmt.Printf(" Db Type: %s\n", config.Conf.Db.Type) + // fmt.Printf(" Db Host: %s\n", config.Conf.Db.Host) + // fmt.Printf(" Db Port: %d\n", config.Conf.Db.Port) + // fmt.Printf(" Db Name: %s\n", config.Conf.Db.Name) + // fmt.Printf(" Db User: %s\n", config.Conf.Db.Username) + // } + return nil + } + +} + +func (c *cAbout) GetHelp() string { + return "" +} diff --git a/console/cli_air.go b/console/cli_air.go new file mode 100644 index 0000000..d5a35a0 --- /dev/null +++ b/console/cli_air.go @@ -0,0 +1,142 @@ +// +// cli_air.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package console + +import ( + "flag" + "fmt" + "io/ioutil" + "os" + "os/exec" + + "hexq.cn/tiglog/golib/gfile" +) + +type cAir struct { + BaseCmd +} + +func NewAirCmd(cli IConsole) *cAir { + return &cAir{ + BaseCmd{ + Name: "serve", + Desc: "代码变动时自动重启应用", + }, + } +} + +func (c *cAir) Init(args []string) { + var cmd string + if len(args) == 0 { + cmd = "dev" + } else { + cmd = args[0] + args = args[1:] + } + if cmd == "conf" { + c.initConfCmd(args) + } else { + c.initDevCmd(args) + } + +} +func (c *cAir) initDevCmd(args []string) { + cmd := flag.NewFlagSet("dev", flag.ExitOnError) + cmd.Parse(args) + c.Action = func() error { + fmt.Println("run dev cmd") + cm := exec.Command("air") + cm.Stderr = os.Stderr + cm.Stdout = os.Stdout + err := cm.Start() + if err != nil { + return err + } + err = cm.Wait() + if err != nil { + return err + } + return nil + } + +} + +func (c *cAir) initConfCmd(args []string) { + cmd := flag.NewFlagSet("conf", flag.ExitOnError) + output := cmd.String("output", "./.air.toml", "指定文件路径") + show := cmd.Bool("show", false, "显示配置文件内容") + cmd.Parse(args) + c.Action = func() error { + if *show { + fmt.Println(conf) + } else { + if gfile.Exists(*output) { + fmt.Printf("Conf file %s exists, SKIP\n\n", *output) + } else { + fmt.Printf("Writing conf to %s ...\n", *output) + ioutil.WriteFile(*output, []byte(conf), 0644) + fmt.Println("Done") + } + } + + return nil + } + +} + +func (c *cAir) GetHelp() string { + p1 := fmt.Sprintf("%s \n%s\n\nSub Commands:\n", c.Name, c.Desc) + p2 := fmt.Sprintf("%10s: %s\n", "dev", "启动开发服务") + p3 := fmt.Sprintf("%10s: %s\n", "conf", "生成配置信息") + return p1 + p2 + p3 +} + +var conf = `# [Air](https://github.com/cosmtrek/air) TOML 格式的配置文件 + +# 工作目录 +# 使用 . 或绝对路径,请注意 tmp_dir 目录必须在 root 目录下 +root = "." +tmp_dir = "var/tmp" + +[build] +# 只需要写你平常编译使用的shell命令。你也可以使用 make +cmd = "go build -o ./var/tmp/main entry/web/main.go" +# 由 cmd 命令得到的二进制文件名 +bin = "var/tmp/main" +# 自定义的二进制,可以添加额外的编译标识例如添加 GIN_MODE=release +# full_bin = "APP_ENV=dev APP_USER=air ./var/tmp/main" +# 监听以下文件扩展名的文件. +include_ext = ["go", "tpl", "tmpl", "html"] +# 忽略这些文件扩展名或目录 +exclude_dir = ["assets", "var", "vendor", "frontend/node_modules"] +# 监听以下指定目录的文件 +include_dir = [] +# 排除以下文件 +exclude_file = [] +# 如果文件更改过于频繁,则没有必要在每次更改时都触发构建。可以设置触发构建的延迟时间 +delay = 1000 # ms +# 发生构建错误时,停止运行旧的二进制文件。 +stop_on_error = true +# air的日志文件名,该日志文件放置在你的 tmp_dir 中 +log = "./var/log/air_errors.log" + +[log] +# 显示日志时间 +time = true + +[color] +# 自定义每个部分显示的颜色。如果找不到颜色,使用原始的应用程序日志。 +main = "magenta" +watcher = "cyan" +build = "yellow" +runner = "green" + +[misc] +# 退出时删除tmp目录 +clean_on_exit = true +` diff --git a/console/cli_base.go b/console/cli_base.go new file mode 100644 index 0000000..78e329e --- /dev/null +++ b/console/cli_base.go @@ -0,0 +1,28 @@ +// +// cli_base.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package console + +import "flag" + +type BaseCmd struct { + Name string + Desc string + FlagSet *flag.FlagSet + Action ActionHandler +} + +func (c *BaseCmd) GetName() string { + return c.Name +} +func (c *BaseCmd) GetDesc() string { + return c.Desc +} + +func (c *BaseCmd) GetAction() ActionHandler { + return c.Action +} diff --git a/console/cli_contract.go b/console/cli_contract.go new file mode 100644 index 0000000..07255cb --- /dev/null +++ b/console/cli_contract.go @@ -0,0 +1,25 @@ +// +// cli_contact.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package console + +type ICommand interface { + Init(args []string) + GetName() string + GetDesc() string + GetAction() ActionHandler + GetHelp() string +} + +type IConsole interface { + GetCmds() map[string]ICommand + GetName() string + GetDesc() string + GetVersion() string + GetExtraAbout() ActionHandler +} +type ActionHandler func() error diff --git a/console/cli_help.go b/console/cli_help.go new file mode 100644 index 0000000..bacd481 --- /dev/null +++ b/console/cli_help.go @@ -0,0 +1,47 @@ +// +// cli_help.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package console + +import ( + "errors" + "fmt" +) + +type cHelp struct { + BaseCmd + cli IConsole +} + +func NewHelpCmd(cli IConsole) *cHelp { + return &cHelp{ + BaseCmd{ + Name: "help", + Desc: "查看命令的使用方法", + }, + cli, + } +} + +func (c *cHelp) Init(args []string) { + if len(args) == 0 { + return + } + c.Action = func() error { + cmds := c.cli.GetCmds() + cmd, ok := cmds[args[0]] + if !ok { + return errors.New("指定的命令不存在") + } + fmt.Println(cmd.GetHelp()) + return nil + } +} + +func (c *cHelp) GetHelp() string { + return c.Desc +} diff --git a/console/cli_list.go b/console/cli_list.go new file mode 100644 index 0000000..c19c033 --- /dev/null +++ b/console/cli_list.go @@ -0,0 +1,42 @@ +// +// cli_list.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package console + +import "fmt" + +type cList struct { + BaseCmd + cli IConsole +} + +func NewListCmd(cli IConsole) *cList { + return &cList{ + BaseCmd{ + Name: "list", + Desc: "列出支持的命令", + }, + cli, + } +} + +func (c *cList) Init(args []string) { + c.Action = func() error { + fmt.Printf("Usage: %s []\n\n%s\n\nCommands:\n", c.cli.GetName(), c.cli.GetDesc()) + cmds := c.cli.GetCmds() + for _, cmd := range cmds { + fmt.Printf("%8s: %s\n", cmd.GetName(), cmd.GetDesc()) + } + fmt.Println() + return nil + } + +} + +func (c *cList) GetHelp() string { + return c.Desc +} diff --git a/console/console.go b/console/console.go new file mode 100644 index 0000000..edaee70 --- /dev/null +++ b/console/console.go @@ -0,0 +1,96 @@ +// +// cli.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package console + +import ( + "fmt" + "os" +) + +type sApp struct { + name string + version string + desc string + cmds map[string]ICommand + about ActionHandler +} + +func New(name, desc string) *sApp { + app := &sApp{ + name: name, + version: "v0.1.0", + desc: desc, + cmds: make(map[string]ICommand), + } + app.AddCmd(NewAboutCmd(app)) + app.AddCmd(NewAirCmd(app)) + app.AddCmd(NewListCmd(app)) + app.AddCmd(NewHelpCmd(app)) + return app +} + +func (s *sApp) GetCmds() map[string]ICommand { + return s.cmds +} + +func (s *sApp) GetName() string { + return s.name +} + +func (s *sApp) GetDesc() string { + return s.desc +} + +func (s *sApp) GetVersion() string { + return s.version +} + +func (s *sApp) AddCmd(cmd ICommand) { + s.cmds[cmd.GetName()] = cmd +} + +func (s *sApp) HasCmd(cmd string) bool { + _, ok := s.cmds[cmd] + return ok +} + +func (s *sApp) SetExtraAbout(about ActionHandler) { + s.about = about +} + +func (s *sApp) GetExtraAbout() ActionHandler { + return s.about +} + +func (s *sApp) Run(args []string) { + cmd := "list" + if len(args) == 1 { + args = []string{cmd} + } else { + cmd = args[1] + args = args[2:] + } + if !s.HasCmd(cmd) { + fmt.Printf("%q is not valid command.\n", cmd) + os.Exit(2) + } + + scmd := s.cmds[cmd] + scmd.Init(args) + act := scmd.GetAction() + if act == nil { + // fmt.Println("未定义 Action,无法执行 Action.") + fmt.Println(scmd.GetHelp()) + os.Exit(1) + } + err := act() + if err != nil { + fmt.Printf("执行异常:%v\n", err) + os.Exit(1) + } +} diff --git a/console/tablewriter/csv.go b/console/tablewriter/csv.go new file mode 100644 index 0000000..53578fd --- /dev/null +++ b/console/tablewriter/csv.go @@ -0,0 +1,53 @@ +// +// csv.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package tablewriter + +import ( + "encoding/csv" + "io" + "os" +) + +// Start A new table by importing from a CSV file +// Takes io.Writer and csv File name +func NewCSV(writer io.Writer, fileName string, hasHeader bool) (*Table, error) { + file, err := os.Open(fileName) + if err != nil { + return &Table{}, err + } + defer file.Close() + csvReader := csv.NewReader(file) + t, err := NewCSVReader(writer, csvReader, hasHeader) + return t, err +} + +// Start a New Table Writer with csv.Reader +// +// This enables customisation such as reader.Comma = ';' +// See http://golang.org/src/pkg/encoding/csv/reader.go?s=3213:3671#L94 +func NewCSVReader(writer io.Writer, csvReader *csv.Reader, hasHeader bool) (*Table, error) { + t := NewWriter(writer) + if hasHeader { + // Read the first row + headers, err := csvReader.Read() + if err != nil { + return &Table{}, err + } + t.SetHeader(headers) + } + for { + record, err := csvReader.Read() + if err == io.EOF { + break + } else if err != nil { + return &Table{}, err + } + t.Append(record) + } + return t, nil +} diff --git a/console/tablewriter/table.go b/console/tablewriter/table.go new file mode 100644 index 0000000..923338e --- /dev/null +++ b/console/tablewriter/table.go @@ -0,0 +1,1057 @@ +// +// table.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package tablewriter + +import ( + "bytes" + "errors" + "fmt" + "io" + "reflect" + "regexp" + "strings" +) + +const ( + MAX_ROW_WIDTH = 30 +) + +const ( + CENTER = "+" + ROW = "-" + COLUMN = "|" + SPACE = " " + NEWLINE = "\n" +) + +const ( + ALIGN_DEFAULT = iota + ALIGN_CENTER + ALIGN_RIGHT + ALIGN_LEFT +) + +var ( + decimal = regexp.MustCompile(`^-?(?:\d{1,3}(?:,\d{3})*|\d+)(?:\.\d+)?$`) + percent = regexp.MustCompile(`^-?\d+\.?\d*$%$`) +) + +type Border struct { + Left bool + Right bool + Top bool + Bottom bool +} + +type Table struct { + out io.Writer + rows [][]string + lines [][][]string + cs map[int]int + rs map[int]int + headers [][]string + footers [][]string + caption bool + captionText string + autoFmt bool + autoWrap bool + reflowText bool + mW int + pCenter string + pRow string + pColumn string + tColumn int + tRow int + hAlign int + fAlign int + align int + newLine string + rowLine bool + autoMergeCells bool + columnsToAutoMergeCells map[int]bool + noWhiteSpace bool + tablePadding string + hdrLine bool + borders Border + colSize int + headerParams []string + columnsParams []string + footerParams []string + columnsAlign []int +} + +// Start New Table +// Take io.Writer Directly +func NewWriter(writer io.Writer) *Table { + t := &Table{ + out: writer, + rows: [][]string{}, + lines: [][][]string{}, + cs: make(map[int]int), + rs: make(map[int]int), + headers: [][]string{}, + footers: [][]string{}, + caption: false, + captionText: "Table caption.", + autoFmt: true, + autoWrap: true, + reflowText: true, + mW: MAX_ROW_WIDTH, + pCenter: CENTER, + pRow: ROW, + pColumn: COLUMN, + tColumn: -1, + tRow: -1, + hAlign: ALIGN_DEFAULT, + fAlign: ALIGN_DEFAULT, + align: ALIGN_DEFAULT, + newLine: NEWLINE, + rowLine: false, + hdrLine: true, + borders: Border{Left: true, Right: true, Bottom: true, Top: true}, + colSize: -1, + headerParams: []string{}, + columnsParams: []string{}, + footerParams: []string{}, + columnsAlign: []int{}} + return t +} + +// Render table output +func (t *Table) Render() { + if t.borders.Top { + t.printLine(true) + } + t.printHeading() + if t.autoMergeCells { + t.printRowsMergeCells() + } else { + t.printRows() + } + if !t.rowLine && t.borders.Bottom { + t.printLine(true) + } + t.printFooter() + + if t.caption { + t.printCaption() + } +} + +const ( + headerRowIdx = -1 + footerRowIdx = -2 +) + +// Set table header +func (t *Table) SetHeader(keys []string) { + t.colSize = len(keys) + for i, v := range keys { + lines := t.parseDimension(v, i, headerRowIdx) + t.headers = append(t.headers, lines) + } +} + +// Set table Footer +func (t *Table) SetFooter(keys []string) { + //t.colSize = len(keys) + for i, v := range keys { + lines := t.parseDimension(v, i, footerRowIdx) + t.footers = append(t.footers, lines) + } +} + +// Set table Caption +func (t *Table) SetCaption(caption bool, captionText ...string) { + t.caption = caption + if len(captionText) == 1 { + t.captionText = captionText[0] + } +} + +// Turn header autoformatting on/off. Default is on (true). +func (t *Table) SetAutoFormatHeaders(auto bool) { + t.autoFmt = auto +} + +// Turn automatic multiline text adjustment on/off. Default is on (true). +func (t *Table) SetAutoWrapText(auto bool) { + t.autoWrap = auto +} + +// Turn automatic reflowing of multiline text when rewrapping. Default is on (true). +func (t *Table) SetReflowDuringAutoWrap(auto bool) { + t.reflowText = auto +} + +// Set the Default column width +func (t *Table) SetColWidth(width int) { + t.mW = width +} + +// Set the minimal width for a column +func (t *Table) SetColMinWidth(column int, width int) { + t.cs[column] = width +} + +// Set the Column Separator +func (t *Table) SetColumnSeparator(sep string) { + t.pColumn = sep +} + +// Set the Row Separator +func (t *Table) SetRowSeparator(sep string) { + t.pRow = sep +} + +// Set the center Separator +func (t *Table) SetCenterSeparator(sep string) { + t.pCenter = sep +} + +// Set Header Alignment +func (t *Table) SetHeaderAlignment(hAlign int) { + t.hAlign = hAlign +} + +// Set Footer Alignment +func (t *Table) SetFooterAlignment(fAlign int) { + t.fAlign = fAlign +} + +// Set Table Alignment +func (t *Table) SetAlignment(align int) { + t.align = align +} + +// Set No White Space +func (t *Table) SetNoWhiteSpace(allow bool) { + t.noWhiteSpace = allow +} + +// Set Table Padding +func (t *Table) SetTablePadding(padding string) { + t.tablePadding = padding +} + +func (t *Table) SetColumnAlignment(keys []int) { + for _, v := range keys { + switch v { + case ALIGN_CENTER: + break + case ALIGN_LEFT: + break + case ALIGN_RIGHT: + break + default: + v = ALIGN_DEFAULT + } + t.columnsAlign = append(t.columnsAlign, v) + } +} + +// Set New Line +func (t *Table) SetNewLine(nl string) { + t.newLine = nl +} + +// Set Header Line +// This would enable / disable a line after the header +func (t *Table) SetHeaderLine(line bool) { + t.hdrLine = line +} + +// Set Row Line +// This would enable / disable a line on each row of the table +func (t *Table) SetRowLine(line bool) { + t.rowLine = line +} + +// Set Auto Merge Cells +// This would enable / disable the merge of cells with identical values +func (t *Table) SetAutoMergeCells(auto bool) { + t.autoMergeCells = auto +} + +// Set Auto Merge Cells By Column Index +// This would enable / disable the merge of cells with identical values for specific columns +// If cols is empty, it is the same as `SetAutoMergeCells(true)`. +func (t *Table) SetAutoMergeCellsByColumnIndex(cols []int) { + t.autoMergeCells = true + + if len(cols) > 0 { + m := make(map[int]bool) + for _, col := range cols { + m[col] = true + } + t.columnsToAutoMergeCells = m + } +} + +// Set Table Border +// This would enable / disable line around the table +func (t *Table) SetBorder(border bool) { + t.SetBorders(Border{border, border, border, border}) +} + +func (t *Table) SetBorders(border Border) { + t.borders = border +} + +// SetStructs sets header and rows from slice of struct. +// If something that is not a slice is passed, error will be returned. +// The tag specified by "tablewriter" for the struct becomes the header. +// If not specified or empty, the field name will be used. +// The field of the first element of the slice is used as the header. +// If the element implements fmt.Stringer, the result will be used. +// And the slice contains nil, it will be skipped without rendering. +func (t *Table) SetStructs(v interface{}) error { + if v == nil { + return errors.New("nil value") + } + vt := reflect.TypeOf(v) + vv := reflect.ValueOf(v) + switch vt.Kind() { + case reflect.Slice, reflect.Array: + if vv.Len() < 1 { + return errors.New("empty value") + } + + // check first element to set header + first := vv.Index(0) + e := first.Type() + switch e.Kind() { + case reflect.Struct: + // OK + case reflect.Ptr: + if first.IsNil() { + return errors.New("the first element is nil") + } + e = first.Elem().Type() + if e.Kind() != reflect.Struct { + return fmt.Errorf("invalid kind %s", e.Kind()) + } + default: + return fmt.Errorf("invalid kind %s", e.Kind()) + } + n := e.NumField() + headers := make([]string, n) + for i := 0; i < n; i++ { + f := e.Field(i) + header := f.Tag.Get("tablewriter") + if header == "" { + header = f.Name + } + headers[i] = header + } + t.SetHeader(headers) + + for i := 0; i < vv.Len(); i++ { + item := reflect.Indirect(vv.Index(i)) + itemType := reflect.TypeOf(item) + switch itemType.Kind() { + case reflect.Struct: + // OK + default: + return fmt.Errorf("invalid item type %v", itemType.Kind()) + } + if !item.IsValid() { + // skip rendering + continue + } + nf := item.NumField() + if n != nf { + return errors.New("invalid num of field") + } + rows := make([]string, nf) + for j := 0; j < nf; j++ { + f := reflect.Indirect(item.Field(j)) + if f.Kind() == reflect.Ptr { + f = f.Elem() + } + if f.IsValid() { + if s, ok := f.Interface().(fmt.Stringer); ok { + rows[j] = s.String() + continue + } + rows[j] = fmt.Sprint(f) + } else { + rows[j] = "nil" + } + } + t.Append(rows) + } + default: + return fmt.Errorf("invalid type %T", v) + } + return nil +} + +// Append row to table +func (t *Table) Append(row []string) { + rowSize := len(t.headers) + if rowSize > t.colSize { + t.colSize = rowSize + } + + n := len(t.lines) + line := [][]string{} + for i, v := range row { + + // Detect string width + // Detect String height + // Break strings into words + out := t.parseDimension(v, i, n) + + // Append broken words + line = append(line, out) + } + t.lines = append(t.lines, line) +} + +// Append row to table with color attributes +func (t *Table) Rich(row []string, colors []Colors) { + rowSize := len(t.headers) + if rowSize > t.colSize { + t.colSize = rowSize + } + + n := len(t.lines) + line := [][]string{} + for i, v := range row { + + // Detect string width + // Detect String height + // Break strings into words + out := t.parseDimension(v, i, n) + + if len(colors) > i { + color := colors[i] + out[0] = format(out[0], color) + } + + // Append broken words + line = append(line, out) + } + t.lines = append(t.lines, line) +} + +// Allow Support for Bulk Append +// Eliminates repeated for loops +func (t *Table) AppendBulk(rows [][]string) { + for _, row := range rows { + t.Append(row) + } +} + +// NumLines to get the number of lines +func (t *Table) NumLines() int { + return len(t.lines) +} + +// Clear rows +func (t *Table) ClearRows() { + t.lines = [][][]string{} +} + +// Clear footer +func (t *Table) ClearFooter() { + t.footers = [][]string{} +} + +// Center based on position and border. +func (t *Table) center(i int) string { + if i == -1 && !t.borders.Left { + return t.pRow + } + + if i == len(t.cs)-1 && !t.borders.Right { + return t.pRow + } + + return t.pCenter +} + +// Print line based on row width +func (t *Table) printLine(nl bool) { + fmt.Fprint(t.out, t.center(-1)) + for i := 0; i < len(t.cs); i++ { + v := t.cs[i] + fmt.Fprintf(t.out, "%s%s%s%s", + t.pRow, + strings.Repeat(string(t.pRow), v), + t.pRow, + t.center(i)) + } + if nl { + fmt.Fprint(t.out, t.newLine) + } +} + +// Print line based on row width with our without cell separator +func (t *Table) printLineOptionalCellSeparators(nl bool, displayCellSeparator []bool) { + fmt.Fprint(t.out, t.pCenter) + for i := 0; i < len(t.cs); i++ { + v := t.cs[i] + if i > len(displayCellSeparator) || displayCellSeparator[i] { + // Display the cell separator + fmt.Fprintf(t.out, "%s%s%s%s", + t.pRow, + strings.Repeat(string(t.pRow), v), + t.pRow, + t.pCenter) + } else { + // Don't display the cell separator for this cell + fmt.Fprintf(t.out, "%s%s", + strings.Repeat(" ", v+2), + t.pCenter) + } + } + if nl { + fmt.Fprint(t.out, t.newLine) + } +} + +// Return the PadRight function if align is left, PadLeft if align is right, +// and Pad by default +func pad(align int) func(string, string, int) string { + padFunc := Pad + switch align { + case ALIGN_LEFT: + padFunc = PadRight + case ALIGN_RIGHT: + padFunc = PadLeft + } + return padFunc +} + +// Print heading information +func (t *Table) printHeading() { + // Check if headers is available + if len(t.headers) < 1 { + return + } + + // Identify last column + end := len(t.cs) - 1 + + // Get pad function + padFunc := pad(t.hAlign) + + // Checking for ANSI escape sequences for header + is_esc_seq := false + if len(t.headerParams) > 0 { + is_esc_seq = true + } + + // Maximum height. + max := t.rs[headerRowIdx] + + // Print Heading + for x := 0; x < max; x++ { + // Check if border is set + // Replace with space if not set + if !t.noWhiteSpace { + fmt.Fprint(t.out, ConditionString(t.borders.Left, t.pColumn, SPACE)) + } + + for y := 0; y <= end; y++ { + v := t.cs[y] + h := "" + + if y < len(t.headers) && x < len(t.headers[y]) { + h = t.headers[y][x] + } + if t.autoFmt { + h = Title(h) + } + pad := ConditionString((y == end && !t.borders.Left), SPACE, t.pColumn) + if t.noWhiteSpace { + pad = ConditionString((y == end && !t.borders.Left), SPACE, t.tablePadding) + } + if is_esc_seq { + if !t.noWhiteSpace { + fmt.Fprintf(t.out, " %s %s", + format(padFunc(h, SPACE, v), + t.headerParams[y]), pad) + } else { + fmt.Fprintf(t.out, "%s %s", + format(padFunc(h, SPACE, v), + t.headerParams[y]), pad) + } + } else { + if !t.noWhiteSpace { + fmt.Fprintf(t.out, " %s %s", + padFunc(h, SPACE, v), + pad) + } else { + // the spaces between breaks the kube formatting + fmt.Fprintf(t.out, "%s%s", + padFunc(h, SPACE, v), + pad) + } + } + } + // Next line + fmt.Fprint(t.out, t.newLine) + } + if t.hdrLine { + t.printLine(true) + } +} + +// Print heading information +func (t *Table) printFooter() { + // Check if headers is available + if len(t.footers) < 1 { + return + } + + // Only print line if border is not set + if !t.borders.Bottom { + t.printLine(true) + } + + // Identify last column + end := len(t.cs) - 1 + + // Get pad function + padFunc := pad(t.fAlign) + + // Checking for ANSI escape sequences for header + is_esc_seq := false + if len(t.footerParams) > 0 { + is_esc_seq = true + } + + // Maximum height. + max := t.rs[footerRowIdx] + + // Print Footer + erasePad := make([]bool, len(t.footers)) + for x := 0; x < max; x++ { + // Check if border is set + // Replace with space if not set + fmt.Fprint(t.out, ConditionString(t.borders.Bottom, t.pColumn, SPACE)) + + for y := 0; y <= end; y++ { + v := t.cs[y] + f := "" + if y < len(t.footers) && x < len(t.footers[y]) { + f = t.footers[y][x] + } + if t.autoFmt { + f = Title(f) + } + pad := ConditionString((y == end && !t.borders.Top), SPACE, t.pColumn) + + if erasePad[y] || (x == 0 && len(f) == 0) { + pad = SPACE + erasePad[y] = true + } + + if is_esc_seq { + fmt.Fprintf(t.out, " %s %s", + format(padFunc(f, SPACE, v), + t.footerParams[y]), pad) + } else { + fmt.Fprintf(t.out, " %s %s", + padFunc(f, SPACE, v), + pad) + } + + //fmt.Fprintf(t.out, " %s %s", + // padFunc(f, SPACE, v), + // pad) + } + // Next line + fmt.Fprint(t.out, t.newLine) + //t.printLine(true) + } + + hasPrinted := false + + for i := 0; i <= end; i++ { + v := t.cs[i] + pad := t.pRow + center := t.pCenter + length := len(t.footers[i][0]) + + if length > 0 { + hasPrinted = true + } + + // Set center to be space if length is 0 + if length == 0 && !t.borders.Right { + center = SPACE + } + + // Print first junction + if i == 0 { + if length > 0 && !t.borders.Left { + center = t.pRow + } + fmt.Fprint(t.out, center) + } + + // Pad With space of length is 0 + if length == 0 { + pad = SPACE + } + // Ignore left space as it has printed before + if hasPrinted || t.borders.Left { + pad = t.pRow + center = t.pCenter + } + + // Change Center end position + if center != SPACE { + if i == end && !t.borders.Right { + center = t.pRow + } + } + + // Change Center start position + if center == SPACE { + if i < end && len(t.footers[i+1][0]) != 0 { + if !t.borders.Left { + center = t.pRow + } else { + center = t.pCenter + } + } + } + + // Print the footer + fmt.Fprintf(t.out, "%s%s%s%s", + pad, + strings.Repeat(string(pad), v), + pad, + center) + + } + + fmt.Fprint(t.out, t.newLine) +} + +// Print caption text +func (t Table) printCaption() { + width := t.getTableWidth() + paragraph, _ := WrapString(t.captionText, width) + for linecount := 0; linecount < len(paragraph); linecount++ { + fmt.Fprintln(t.out, paragraph[linecount]) + } +} + +// Calculate the total number of characters in a row +func (t Table) getTableWidth() int { + var chars int + for _, v := range t.cs { + chars += v + } + + // Add chars, spaces, seperators to calculate the total width of the table. + // ncols := t.colSize + // spaces := ncols * 2 + // seps := ncols + 1 + + return (chars + (3 * t.colSize) + 2) +} + +func (t Table) printRows() { + for i, lines := range t.lines { + t.printRow(lines, i) + } +} + +func (t *Table) fillAlignment(num int) { + if len(t.columnsAlign) < num { + t.columnsAlign = make([]int, num) + for i := range t.columnsAlign { + t.columnsAlign[i] = t.align + } + } +} + +// Print Row Information +// Adjust column alignment based on type + +func (t *Table) printRow(columns [][]string, rowIdx int) { + // Get Maximum Height + max := t.rs[rowIdx] + total := len(columns) + + // TODO Fix uneven col size + // if total < t.colSize { + // for n := t.colSize - total; n < t.colSize ; n++ { + // columns = append(columns, []string{SPACE}) + // t.cs[n] = t.mW + // } + //} + + // Pad Each Height + pads := []int{} + + // Checking for ANSI escape sequences for columns + is_esc_seq := false + if len(t.columnsParams) > 0 { + is_esc_seq = true + } + t.fillAlignment(total) + + for i, line := range columns { + length := len(line) + pad := max - length + pads = append(pads, pad) + for n := 0; n < pad; n++ { + columns[i] = append(columns[i], " ") + } + } + //fmt.Println(max, "\n") + for x := 0; x < max; x++ { + for y := 0; y < total; y++ { + + // Check if border is set + if !t.noWhiteSpace { + fmt.Fprint(t.out, ConditionString((!t.borders.Left && y == 0), SPACE, t.pColumn)) + fmt.Fprintf(t.out, SPACE) + } + + str := columns[y][x] + + // Embedding escape sequence with column value + if is_esc_seq { + str = format(str, t.columnsParams[y]) + } + + // This would print alignment + // Default alignment would use multiple configuration + switch t.columnsAlign[y] { + case ALIGN_CENTER: // + fmt.Fprintf(t.out, "%s", Pad(str, SPACE, t.cs[y])) + case ALIGN_RIGHT: + fmt.Fprintf(t.out, "%s", PadLeft(str, SPACE, t.cs[y])) + case ALIGN_LEFT: + fmt.Fprintf(t.out, "%s", PadRight(str, SPACE, t.cs[y])) + default: + if decimal.MatchString(strings.TrimSpace(str)) || percent.MatchString(strings.TrimSpace(str)) { + fmt.Fprintf(t.out, "%s", PadLeft(str, SPACE, t.cs[y])) + } else { + fmt.Fprintf(t.out, "%s", PadRight(str, SPACE, t.cs[y])) + + // TODO Custom alignment per column + //if max == 1 || pads[y] > 0 { + // fmt.Fprintf(t.out, "%s", Pad(str, SPACE, t.cs[y])) + //} else { + // fmt.Fprintf(t.out, "%s", PadRight(str, SPACE, t.cs[y])) + //} + + } + } + if !t.noWhiteSpace { + fmt.Fprintf(t.out, SPACE) + } else { + fmt.Fprintf(t.out, t.tablePadding) + } + } + // Check if border is set + // Replace with space if not set + if !t.noWhiteSpace { + fmt.Fprint(t.out, ConditionString(t.borders.Left, t.pColumn, SPACE)) + } + fmt.Fprint(t.out, t.newLine) + } + + if t.rowLine { + t.printLine(true) + } +} + +// Print the rows of the table and merge the cells that are identical +func (t *Table) printRowsMergeCells() { + var previousLine []string + var displayCellBorder []bool + var tmpWriter bytes.Buffer + for i, lines := range t.lines { + // We store the display of the current line in a tmp writer, as we need to know which border needs to be print above + previousLine, displayCellBorder = t.printRowMergeCells(&tmpWriter, lines, i, previousLine) + if i > 0 { //We don't need to print borders above first line + if t.rowLine { + t.printLineOptionalCellSeparators(true, displayCellBorder) + } + } + tmpWriter.WriteTo(t.out) + } + //Print the end of the table + if t.rowLine { + t.printLine(true) + } +} + +// Print Row Information to a writer and merge identical cells. +// Adjust column alignment based on type + +func (t *Table) printRowMergeCells(writer io.Writer, columns [][]string, rowIdx int, previousLine []string) ([]string, []bool) { + // Get Maximum Height + max := t.rs[rowIdx] + total := len(columns) + + // Pad Each Height + pads := []int{} + + // Checking for ANSI escape sequences for columns + is_esc_seq := false + if len(t.columnsParams) > 0 { + is_esc_seq = true + } + for i, line := range columns { + length := len(line) + pad := max - length + pads = append(pads, pad) + for n := 0; n < pad; n++ { + columns[i] = append(columns[i], " ") + } + } + + var displayCellBorder []bool + t.fillAlignment(total) + for x := 0; x < max; x++ { + for y := 0; y < total; y++ { + + // Check if border is set + fmt.Fprint(writer, ConditionString((!t.borders.Left && y == 0), SPACE, t.pColumn)) + + fmt.Fprintf(writer, SPACE) + + str := columns[y][x] + + // Embedding escape sequence with column value + if is_esc_seq { + str = format(str, t.columnsParams[y]) + } + + if t.autoMergeCells { + var mergeCell bool + if t.columnsToAutoMergeCells != nil { + // Check to see if the column index is in columnsToAutoMergeCells. + if t.columnsToAutoMergeCells[y] { + mergeCell = true + } + } else { + // columnsToAutoMergeCells was not set. + mergeCell = true + } + //Store the full line to merge mutli-lines cells + fullLine := strings.TrimRight(strings.Join(columns[y], " "), " ") + if len(previousLine) > y && fullLine == previousLine[y] && fullLine != "" && mergeCell { + // If this cell is identical to the one above but not empty, we don't display the border and keep the cell empty. + displayCellBorder = append(displayCellBorder, false) + str = "" + } else { + // First line or different content, keep the content and print the cell border + displayCellBorder = append(displayCellBorder, true) + } + } + + // This would print alignment + // Default alignment would use multiple configuration + switch t.columnsAlign[y] { + case ALIGN_CENTER: // + fmt.Fprintf(writer, "%s", Pad(str, SPACE, t.cs[y])) + case ALIGN_RIGHT: + fmt.Fprintf(writer, "%s", PadLeft(str, SPACE, t.cs[y])) + case ALIGN_LEFT: + fmt.Fprintf(writer, "%s", PadRight(str, SPACE, t.cs[y])) + default: + if decimal.MatchString(strings.TrimSpace(str)) || percent.MatchString(strings.TrimSpace(str)) { + fmt.Fprintf(writer, "%s", PadLeft(str, SPACE, t.cs[y])) + } else { + fmt.Fprintf(writer, "%s", PadRight(str, SPACE, t.cs[y])) + } + } + fmt.Fprintf(writer, SPACE) + } + // Check if border is set + // Replace with space if not set + fmt.Fprint(writer, ConditionString(t.borders.Left, t.pColumn, SPACE)) + fmt.Fprint(writer, t.newLine) + } + + //The new previous line is the current one + previousLine = make([]string, total) + for y := 0; y < total; y++ { + previousLine[y] = strings.TrimRight(strings.Join(columns[y], " "), " ") //Store the full line for multi-lines cells + } + //Returns the newly added line and wether or not a border should be displayed above. + return previousLine, displayCellBorder +} + +func (t *Table) parseDimension(str string, colKey, rowKey int) []string { + var ( + raw []string + maxWidth int + ) + + raw = getLines(str) + maxWidth = 0 + for _, line := range raw { + if w := DisplayWidth(line); w > maxWidth { + maxWidth = w + } + } + + // If wrapping, ensure that all paragraphs in the cell fit in the + // specified width. + if t.autoWrap { + // If there's a maximum allowed width for wrapping, use that. + if maxWidth > t.mW { + maxWidth = t.mW + } + + // In the process of doing so, we need to recompute maxWidth. This + // is because perhaps a word in the cell is longer than the + // allowed maximum width in t.mW. + newMaxWidth := maxWidth + newRaw := make([]string, 0, len(raw)) + + if t.reflowText { + // Make a single paragraph of everything. + raw = []string{strings.Join(raw, " ")} + } + for i, para := range raw { + paraLines, _ := WrapString(para, maxWidth) + for _, line := range paraLines { + if w := DisplayWidth(line); w > newMaxWidth { + newMaxWidth = w + } + } + if i > 0 { + newRaw = append(newRaw, " ") + } + newRaw = append(newRaw, paraLines...) + } + raw = newRaw + maxWidth = newMaxWidth + } + + // Store the new known maximum width. + v, ok := t.cs[colKey] + if !ok || v < maxWidth || v == 0 { + t.cs[colKey] = maxWidth + } + + // Remember the number of lines for the row printer. + h := len(raw) + v, ok = t.rs[rowKey] + + if !ok || v < h || v == 0 { + t.rs[rowKey] = h + } + //fmt.Printf("Raw %+v %d\n", raw, len(raw)) + return raw +} diff --git a/console/tablewriter/table_with_color.go b/console/tablewriter/table_with_color.go new file mode 100644 index 0000000..90f69b0 --- /dev/null +++ b/console/tablewriter/table_with_color.go @@ -0,0 +1,143 @@ +// +// table_with_color.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package tablewriter + +import ( + "fmt" + "strconv" + "strings" +) + +const ESC = "\033" +const SEP = ";" + +const ( + BgBlackColor int = iota + 40 + BgRedColor + BgGreenColor + BgYellowColor + BgBlueColor + BgMagentaColor + BgCyanColor + BgWhiteColor +) + +const ( + FgBlackColor int = iota + 30 + FgRedColor + FgGreenColor + FgYellowColor + FgBlueColor + FgMagentaColor + FgCyanColor + FgWhiteColor +) + +const ( + BgHiBlackColor int = iota + 100 + BgHiRedColor + BgHiGreenColor + BgHiYellowColor + BgHiBlueColor + BgHiMagentaColor + BgHiCyanColor + BgHiWhiteColor +) + +const ( + FgHiBlackColor int = iota + 90 + FgHiRedColor + FgHiGreenColor + FgHiYellowColor + FgHiBlueColor + FgHiMagentaColor + FgHiCyanColor + FgHiWhiteColor +) + +const ( + Normal = 0 + Bold = 1 + UnderlineSingle = 4 + Italic +) + +type Colors []int + +func startFormat(seq string) string { + return fmt.Sprintf("%s[%sm", ESC, seq) +} + +func stopFormat() string { + return fmt.Sprintf("%s[%dm", ESC, Normal) +} + +// Making the SGR (Select Graphic Rendition) sequence. +func makeSequence(codes []int) string { + codesInString := []string{} + for _, code := range codes { + codesInString = append(codesInString, strconv.Itoa(code)) + } + return strings.Join(codesInString, SEP) +} + +// Adding ANSI escape sequences before and after string +func format(s string, codes interface{}) string { + var seq string + + switch v := codes.(type) { + + case string: + seq = v + case []int: + seq = makeSequence(v) + case Colors: + seq = makeSequence(v) + default: + return s + } + + if len(seq) == 0 { + return s + } + return startFormat(seq) + s + stopFormat() +} + +// Adding header colors (ANSI codes) +func (t *Table) SetHeaderColor(colors ...Colors) { + if t.colSize != len(colors) { + panic("Number of header colors must be equal to number of headers.") + } + for i := 0; i < len(colors); i++ { + t.headerParams = append(t.headerParams, makeSequence(colors[i])) + } +} + +// Adding column colors (ANSI codes) +func (t *Table) SetColumnColor(colors ...Colors) { + if t.colSize != len(colors) { + panic("Number of column colors must be equal to number of headers.") + } + for i := 0; i < len(colors); i++ { + t.columnsParams = append(t.columnsParams, makeSequence(colors[i])) + } +} + +// Adding column colors (ANSI codes) +func (t *Table) SetFooterColor(colors ...Colors) { + if len(t.footers) != len(colors) { + panic("Number of footer colors must be equal to number of footer.") + } + for i := 0; i < len(colors); i++ { + t.footerParams = append(t.footerParams, makeSequence(colors[i])) + } +} + +func Color(colors ...int) []int { + return colors +} diff --git a/console/tablewriter/util.go b/console/tablewriter/util.go new file mode 100644 index 0000000..9e1cc9f --- /dev/null +++ b/console/tablewriter/util.go @@ -0,0 +1,93 @@ +// +// util.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package tablewriter + +import ( + "math" + "regexp" + "strings" + + "github.com/mattn/go-runewidth" +) + +var ansi = regexp.MustCompile("\033\\[(?:[0-9]{1,3}(?:;[0-9]{1,3})*)?[m|K]") + +func DisplayWidth(str string) int { + return runewidth.StringWidth(ansi.ReplaceAllLiteralString(str, "")) +} + +// Simple Condition for string +// Returns value based on condition +func ConditionString(cond bool, valid, inValid string) string { + if cond { + return valid + } + return inValid +} + +func isNumOrSpace(r rune) bool { + return ('0' <= r && r <= '9') || r == ' ' +} + +// Format Table Header +// Replace _ , . and spaces +func Title(name string) string { + origLen := len(name) + rs := []rune(name) + for i, r := range rs { + switch r { + case '_': + rs[i] = ' ' + case '.': + // ignore floating number 0.0 + if (i != 0 && !isNumOrSpace(rs[i-1])) || (i != len(rs)-1 && !isNumOrSpace(rs[i+1])) { + rs[i] = ' ' + } + } + } + name = string(rs) + name = strings.TrimSpace(name) + if len(name) == 0 && origLen > 0 { + // Keep at least one character. This is important to preserve + // empty lines in multi-line headers/footers. + name = " " + } + return strings.ToUpper(name) +} + +// Pad String +// Attempts to place string in the center +func Pad(s, pad string, width int) string { + gap := width - DisplayWidth(s) + if gap > 0 { + gapLeft := int(math.Ceil(float64(gap / 2))) + gapRight := gap - gapLeft + return strings.Repeat(string(pad), gapLeft) + s + strings.Repeat(string(pad), gapRight) + } + return s +} + +// Pad String Right position +// This would place string at the left side of the screen +func PadRight(s, pad string, width int) string { + gap := width - DisplayWidth(s) + if gap > 0 { + return s + strings.Repeat(string(pad), gap) + } + return s +} + +// Pad String Left position +// This would place string at the right side of the screen +func PadLeft(s, pad string, width int) string { + gap := width - DisplayWidth(s) + if gap > 0 { + return strings.Repeat(string(pad), gap) + s + } + return s +} diff --git a/console/tablewriter/wrap.go b/console/tablewriter/wrap.go new file mode 100644 index 0000000..91b91b4 --- /dev/null +++ b/console/tablewriter/wrap.go @@ -0,0 +1,99 @@ +// +// wrap.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package tablewriter + +import ( + "math" + "strings" + + "github.com/mattn/go-runewidth" +) + +var ( + nl = "\n" + sp = " " +) + +const defaultPenalty = 1e5 + +// Wrap wraps s into a paragraph of lines of length lim, with minimal +// raggedness. +func WrapString(s string, lim int) ([]string, int) { + words := strings.Split(strings.Replace(s, nl, sp, -1), sp) + var lines []string + max := 0 + for _, v := range words { + max = runewidth.StringWidth(v) + if max > lim { + lim = max + } + } + for _, line := range WrapWords(words, 1, lim, defaultPenalty) { + lines = append(lines, strings.Join(line, sp)) + } + return lines, lim +} + +// WrapWords is the low-level line-breaking algorithm, useful if you need more +// control over the details of the text wrapping process. For most uses, +// WrapString will be sufficient and more convenient. +// +// WrapWords splits a list of words into lines with minimal "raggedness", +// treating each rune as one unit, accounting for spc units between adjacent +// words on each line, and attempting to limit lines to lim units. Raggedness +// is the total error over all lines, where error is the square of the +// difference of the length of the line and lim. Too-long lines (which only +// happen when a single word is longer than lim units) have pen penalty units +// added to the error. +func WrapWords(words []string, spc, lim, pen int) [][]string { + n := len(words) + + length := make([][]int, n) + for i := 0; i < n; i++ { + length[i] = make([]int, n) + length[i][i] = runewidth.StringWidth(words[i]) + for j := i + 1; j < n; j++ { + length[i][j] = length[i][j-1] + spc + runewidth.StringWidth(words[j]) + } + } + nbrk := make([]int, n) + cost := make([]int, n) + for i := range cost { + cost[i] = math.MaxInt32 + } + for i := n - 1; i >= 0; i-- { + if length[i][n-1] <= lim { + cost[i] = 0 + nbrk[i] = n + } else { + for j := i + 1; j < n; j++ { + d := lim - length[i][j-1] + c := d*d + cost[j] + if length[i][j-1] > lim { + c += pen // too-long lines get a worse penalty + } + if c < cost[i] { + cost[i] = c + nbrk[i] = j + } + } + } + } + var lines [][]string + i := 0 + for i < n { + lines = append(lines, words[i:nbrk[i]]) + i = nbrk[i] + } + return lines +} + +// getLines decomposes a multiline string into a slice of strings. +func getLines(s string) []string { + return strings.Split(s, nl) +} diff --git a/crypto/gaes/aes.go b/crypto/gaes/aes.go new file mode 100644 index 0000000..78878fb --- /dev/null +++ b/crypto/gaes/aes.go @@ -0,0 +1,172 @@ +// +// aes.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gaes + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "errors" +) + +const ( + // IVDefaultValue is the default value for IV. + IVDefaultValue = "I Love Xiao Quan" +) + +// Encrypt is alias of EncryptCBC. +func Encrypt(plainText []byte, key []byte, iv ...[]byte) ([]byte, error) { + return EncryptCBC(plainText, key, iv...) +} + +// Decrypt is alias of DecryptCBC. +func Decrypt(cipherText []byte, key []byte, iv ...[]byte) ([]byte, error) { + return DecryptCBC(cipherText, key, iv...) +} + +// EncryptCBC encrypts `plainText` using CBC mode. +// Note that the key must be 16/24/32 bit length. +// The parameter `iv` initialization vector is unnecessary. +func EncryptCBC(plainText []byte, key []byte, iv ...[]byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + plainText = PKCS5Padding(plainText, blockSize) + ivValue := ([]byte)(nil) + if len(iv) > 0 { + ivValue = iv[0] + } else { + ivValue = []byte(IVDefaultValue) + } + blockMode := cipher.NewCBCEncrypter(block, ivValue) + cipherText := make([]byte, len(plainText)) + blockMode.CryptBlocks(cipherText, plainText) + + return cipherText, nil +} + +// DecryptCBC decrypts `cipherText` using CBC mode. +// Note that the key must be 16/24/32 bit length. +// The parameter `iv` initialization vector is unnecessary. +func DecryptCBC(cipherText []byte, key []byte, iv ...[]byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + if len(cipherText) < blockSize { + return nil, errors.New("cipherText too short") + } + ivValue := ([]byte)(nil) + if len(iv) > 0 { + ivValue = iv[0] + } else { + ivValue = []byte(IVDefaultValue) + } + if len(cipherText)%blockSize != 0 { + return nil, errors.New("cipherText is not a multiple of the block size") + } + blockModel := cipher.NewCBCDecrypter(block, ivValue) + plainText := make([]byte, len(cipherText)) + blockModel.CryptBlocks(plainText, cipherText) + plainText, e := PKCS5UnPadding(plainText, blockSize) + if e != nil { + return nil, e + } + return plainText, nil +} + +func PKCS5Padding(src []byte, blockSize int) []byte { + padding := blockSize - len(src)%blockSize + padtext := bytes.Repeat([]byte{byte(padding)}, padding) + return append(src, padtext...) +} + +func PKCS5UnPadding(src []byte, blockSize int) ([]byte, error) { + length := len(src) + if blockSize <= 0 { + return nil, errors.New("invalid blocklen") + } + + if length%blockSize != 0 || length == 0 { + return nil, errors.New("invalid data len") + } + + unpadding := int(src[length-1]) + if unpadding > blockSize || unpadding == 0 { + return nil, errors.New("invalid padding") + } + + padding := src[length-unpadding:] + for i := 0; i < unpadding; i++ { + if padding[i] != byte(unpadding) { + return nil, errors.New("invalid padding") + } + } + + return src[:(length - unpadding)], nil +} + +// EncryptCFB encrypts `plainText` using CFB mode. +// Note that the key must be 16/24/32 bit length. +// The parameter `iv` initialization vector is unnecessary. +func EncryptCFB(plainText []byte, key []byte, padding *int, iv ...[]byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + blockSize := block.BlockSize() + plainText, *padding = ZeroPadding(plainText, blockSize) + ivValue := ([]byte)(nil) + if len(iv) > 0 { + ivValue = iv[0] + } else { + ivValue = []byte(IVDefaultValue) + } + stream := cipher.NewCFBEncrypter(block, ivValue) + cipherText := make([]byte, len(plainText)) + stream.XORKeyStream(cipherText, plainText) + return cipherText, nil +} + +// DecryptCFB decrypts `plainText` using CFB mode. +// Note that the key must be 16/24/32 bit length. +// The parameter `iv` initialization vector is unnecessary. +func DecryptCFB(cipherText []byte, key []byte, unPadding int, iv ...[]byte) ([]byte, error) { + block, err := aes.NewCipher(key) + if err != nil { + return nil, err + } + if len(cipherText) < aes.BlockSize { + return nil, errors.New("cipherText too short") + } + ivValue := ([]byte)(nil) + if len(iv) > 0 { + ivValue = iv[0] + } else { + ivValue = []byte(IVDefaultValue) + } + stream := cipher.NewCFBDecrypter(block, ivValue) + plainText := make([]byte, len(cipherText)) + stream.XORKeyStream(plainText, cipherText) + plainText = ZeroUnPadding(plainText, unPadding) + return plainText, nil +} + +func ZeroPadding(cipherText []byte, blockSize int) ([]byte, int) { + padding := blockSize - len(cipherText)%blockSize + padText := bytes.Repeat([]byte{byte(0)}, padding) + return append(cipherText, padText...), padding +} + +func ZeroUnPadding(plaintext []byte, unPadding int) []byte { + length := len(plaintext) + return plaintext[:(length - unPadding)] +} diff --git a/crypto/gmd5/md5.go b/crypto/gmd5/md5.go new file mode 100644 index 0000000..f503727 --- /dev/null +++ b/crypto/gmd5/md5.go @@ -0,0 +1,91 @@ +// +// md5.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gmd5 + +import ( + "crypto/md5" + "fmt" + "io" + "os" +) + +// Encrypt encrypts any type of variable using MD5 algorithms. +// It uses gconv package to convert `v` to its bytes type. +func Encrypt(in string) (encrypt string, err error) { + return EncryptBytes([]byte(in)) +} + +// MustEncrypt encrypts any type of variable using MD5 algorithms. +// It uses gconv package to convert `v` to its bytes type. +// It panics if any error occurs. +func MustEncrypt(in string) string { + result, err := Encrypt(in) + if err != nil { + panic(err) + } + return result +} + +// EncryptBytes encrypts `data` using MD5 algorithms. +func EncryptBytes(data []byte) (encrypt string, err error) { + h := md5.New() + if _, err = h.Write(data); err != nil { + return "", err + } + return fmt.Sprintf("%x", h.Sum(nil)), nil +} + +// MustEncryptBytes encrypts `data` using MD5 algorithms. +// It panics if any error occurs. +func MustEncryptBytes(data []byte) string { + result, err := EncryptBytes(data) + if err != nil { + panic(err) + } + return result +} + +// EncryptString encrypts string `data` using MD5 algorithms. +func EncryptString(data string) (encrypt string, err error) { + return EncryptBytes([]byte(data)) +} + +// MustEncryptString encrypts string `data` using MD5 algorithms. +// It panics if any error occurs. +func MustEncryptString(data string) string { + result, err := EncryptString(data) + if err != nil { + panic(err) + } + return result +} + +// EncryptFile encrypts file content of `path` using MD5 algorithms. +func EncryptFile(path string) (encrypt string, err error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + h := md5.New() + _, err = io.Copy(h, f) + if err != nil { + return "", err + } + return fmt.Sprintf("%x", h.Sum(nil)), nil +} + +// MustEncryptFile encrypts file content of `path` using MD5 algorithms. +// It panics if any error occurs. +func MustEncryptFile(path string) string { + result, err := EncryptFile(path) + if err != nil { + panic(err) + } + return result +} diff --git a/crypto/gsha1/sha1.go b/crypto/gsha1/sha1.go new file mode 100644 index 0000000..3622fdc --- /dev/null +++ b/crypto/gsha1/sha1.go @@ -0,0 +1,47 @@ +// +// sha1.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gsha1 + +import ( + "crypto/sha1" + "encoding/hex" + "io" + "os" +) + +// Encrypt encrypts any type of variable using SHA1 algorithms. +// It uses package gconv to convert `v` to its bytes type. +func Encrypt(in string) string { + r := sha1.Sum([]byte(in)) + return hex.EncodeToString(r[:]) +} + +// EncryptFile encrypts file content of `path` using SHA1 algorithms. +func EncryptFile(path string) (encrypt string, err error) { + f, err := os.Open(path) + if err != nil { + return "", err + } + defer f.Close() + h := sha1.New() + _, err = io.Copy(h, f) + if err != nil { + return "", err + } + return hex.EncodeToString(h.Sum(nil)), nil +} + +// MustEncryptFile encrypts file content of `path` using SHA1 algorithms. +// It panics if any error occurs. +func MustEncryptFile(path string) string { + result, err := EncryptFile(path) + if err != nil { + panic(err) + } + return result +} diff --git a/encoding/gbase64/base64.go b/encoding/gbase64/base64.go new file mode 100644 index 0000000..8ed3ac1 --- /dev/null +++ b/encoding/gbase64/base64.go @@ -0,0 +1,121 @@ +// +// base64.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gbase64 + +import ( + "encoding/base64" + "io/ioutil" +) + +// Encode encodes bytes with BASE64 algorithm. +func Encode(src []byte) []byte { + dst := make([]byte, base64.StdEncoding.EncodedLen(len(src))) + base64.StdEncoding.Encode(dst, src) + return dst +} + +// EncodeString encodes string with BASE64 algorithm. +func EncodeString(src string) string { + return EncodeToString([]byte(src)) +} + +// EncodeToString encodes bytes to string with BASE64 algorithm. +func EncodeToString(src []byte) string { + return string(Encode(src)) +} + +// EncodeFile encodes file content of `path` using BASE64 algorithms. +func EncodeFile(path string) ([]byte, error) { + content, err := ioutil.ReadFile(path) + if err != nil { + return nil, err + } + return Encode(content), nil +} + +// MustEncodeFile encodes file content of `path` using BASE64 algorithms. +// It panics if any error occurs. +func MustEncodeFile(path string) []byte { + result, err := EncodeFile(path) + if err != nil { + panic(err) + } + return result +} + +// EncodeFileToString encodes file content of `path` to string using BASE64 algorithms. +func EncodeFileToString(path string) (string, error) { + content, err := EncodeFile(path) + if err != nil { + return "", err + } + return string(content), nil +} + +// MustEncodeFileToString encodes file content of `path` to string using BASE64 algorithms. +// It panics if any error occurs. +func MustEncodeFileToString(path string) string { + result, err := EncodeFileToString(path) + if err != nil { + panic(err) + } + return result +} + +// Decode decodes bytes with BASE64 algorithm. +func Decode(data []byte) ([]byte, error) { + var ( + src = make([]byte, base64.StdEncoding.DecodedLen(len(data))) + n, err = base64.StdEncoding.Decode(src, data) + ) + if err != nil { + return nil, err + } + return src[:n], nil +} + +// MustDecode decodes bytes with BASE64 algorithm. +// It panics if any error occurs. +func MustDecode(data []byte) []byte { + result, err := Decode(data) + if err != nil { + panic(err) + } + return result +} + +// DecodeString decodes string with BASE64 algorithm. +func DecodeString(data string) ([]byte, error) { + return Decode([]byte(data)) +} + +// MustDecodeString decodes string with BASE64 algorithm. +// It panics if any error occurs. +func MustDecodeString(data string) []byte { + result, err := DecodeString(data) + if err != nil { + panic(err) + } + return result +} + +// DecodeToString decodes string with BASE64 algorithm. +func DecodeToString(data string) (string, error) { + b, err := DecodeString(data) + return string(b), err +} + +// MustDecodeToString decodes string with BASE64 algorithm. +// It panics if any error occurs. +func MustDecodeToString(data string) string { + result, err := DecodeToString(data) + if err != nil { + panic(err) + } + return result +} diff --git a/encoding/gurl/url.go b/encoding/gurl/url.go new file mode 100644 index 0000000..6dbf8be --- /dev/null +++ b/encoding/gurl/url.go @@ -0,0 +1,87 @@ +// +// url.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gurl + +import ( + "net/url" + "strings" +) + +// Encode escapes the string so it can be safely placed +// inside a URL query. +func Encode(str string) string { + return url.QueryEscape(str) +} + +// Decode does the inverse transformation of Encode, +// converting each 3-byte encoded substring of the form "%AB" into the +// hex-decoded byte 0xAB. +// It returns an error if any % is not followed by two hexadecimal +// digits. +func Decode(str string) (string, error) { + return url.QueryUnescape(str) +} + +// RawEncode does encode the given string according +// URL-encode according to RFC 3986. +// See http://php.net/manual/en/function.rawurlencode.php. +func RawEncode(str string) string { + return strings.Replace(url.QueryEscape(str), "+", "%20", -1) +} + +// RawDecode does decode the given string +// Decode URL-encoded strings. +// See http://php.net/manual/en/function.rawurldecode.php. +func RawDecode(str string) (string, error) { + return url.QueryUnescape(strings.Replace(str, "%20", "+", -1)) +} + +// BuildQuery Generate URL-encoded query string. +// See http://php.net/manual/en/function.http-build-query.php. +func BuildQuery(queryData url.Values) string { + return queryData.Encode() +} + +// ParseURL Parse a URL and return its components. +// -1: all; 1: scheme; 2: host; 4: port; 8: user; 16: pass; 32: path; 64: query; 128: fragment. +// See http://php.net/manual/en/function.parse-url.php. +func ParseURL(str string, component int) (map[string]string, error) { + u, err := url.Parse(str) + if err != nil { + return nil, err + } + if component == -1 { + component = 1 | 2 | 4 | 8 | 16 | 32 | 64 | 128 + } + var components = make(map[string]string) + if (component & 1) == 1 { + components["scheme"] = u.Scheme + } + if (component & 2) == 2 { + components["host"] = u.Hostname() + } + if (component & 4) == 4 { + components["port"] = u.Port() + } + if (component & 8) == 8 { + components["user"] = u.User.Username() + } + if (component & 16) == 16 { + components["pass"], _ = u.User.Password() + } + if (component & 32) == 32 { + components["path"] = u.Path + } + if (component & 64) == 64 { + components["query"] = u.RawQuery + } + if (component & 128) == 128 { + components["fragment"] = u.Fragment + } + return components, nil +} diff --git a/gauth/auth.go b/gauth/auth.go new file mode 100644 index 0000000..a22bb52 --- /dev/null +++ b/gauth/auth.go @@ -0,0 +1,20 @@ +// +// auth.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gauth + +type sAuth struct { +} + +func New() *sAuth { + return &sAuth{} +} + +// 是否支持 +func (s *sAuth) Support() error { + return nil +} diff --git a/gauth/helper.go b/gauth/helper.go new file mode 100644 index 0000000..a58acf5 --- /dev/null +++ b/gauth/helper.go @@ -0,0 +1,25 @@ +// +// helper.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gauth + +import "golang.org/x/crypto/bcrypt" + +// 加密密码 +func EncryptPassword(password string) (string, error) { + bt, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + if err != nil { + return "", err + } + return string(bt), nil +} + +// 检查密码 +func CheckPassword(pwd_plain, pwd_hash string) bool { + err := bcrypt.CompareHashAndPassword([]byte(pwd_hash), []byte(pwd_plain)) + return err == nil +} diff --git a/gauth/middleware.go b/gauth/middleware.go new file mode 100644 index 0000000..cc6df38 --- /dev/null +++ b/gauth/middleware.go @@ -0,0 +1,19 @@ +// +// middleware.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gauth + +import "github.com/gin-gonic/gin" + +func GinAuth() gin.HandlerFunc { + + return func(c *gin.Context) { + + c.Next() + } + +} diff --git a/gauth/readme.adoc b/gauth/readme.adoc new file mode 100644 index 0000000..4dea337 --- /dev/null +++ b/gauth/readme.adoc @@ -0,0 +1,17 @@ += 认证 +:author: tiglog +:experimental: +:toc: left +:toclevels: 3 +:toc-title: 目录 +:sectnums: +:icons: font +:!webfonts: +:autofit-option: +:source-highlighter: rouge +:rouge-style: github +:source-linenums-option: +:revdate: 2022-12-01 +:imagesdir: ./img + + diff --git a/gcache/adapter_file.go b/gcache/adapter_file.go new file mode 100644 index 0000000..3011c71 --- /dev/null +++ b/gcache/adapter_file.go @@ -0,0 +1,10 @@ +// +// adapter_file.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gcache + +// 本地文件缓存 diff --git a/gcache/adapter_local.go b/gcache/adapter_local.go new file mode 100644 index 0000000..126b033 --- /dev/null +++ b/gcache/adapter_local.go @@ -0,0 +1,10 @@ +// +// adapter_local.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gcache + +// 本地内存缓存 diff --git a/gcache/adapter_redis.go b/gcache/adapter_redis.go new file mode 100644 index 0000000..946996d --- /dev/null +++ b/gcache/adapter_redis.go @@ -0,0 +1,10 @@ +// +// adapter_redis.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gcache + +// 使用 redis 服务缓存 diff --git a/gcache/cache.go b/gcache/cache.go new file mode 100644 index 0000000..b1af227 --- /dev/null +++ b/gcache/cache.go @@ -0,0 +1,16 @@ +// +// cache.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gcache + +type Engine struct { + ICacheAdapter +} + +func New(adapter ICacheAdapter) *Engine { + return &Engine{adapter} +} diff --git a/gcache/cache_contract.go b/gcache/cache_contract.go new file mode 100644 index 0000000..cebcaeb --- /dev/null +++ b/gcache/cache_contract.go @@ -0,0 +1,18 @@ +// +// cache_contact.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gcache + +import "time" + +type ICacheAdapter interface { + Get(key string) (string, error) + Set(key string, val interface{}, exp time.Duration) error + Del(keys ...string) int64 + Has(key string) bool + End() +} diff --git a/gcache/readme.adoc b/gcache/readme.adoc new file mode 100644 index 0000000..812ff77 --- /dev/null +++ b/gcache/readme.adoc @@ -0,0 +1,54 @@ += 缓存设计 +:author: tiglog +:experimental: +:toc: left +:toclevels: 3 +:toc-title: 目录 +:sectnums: +:icons: font +:!webfonts: +:autofit-option: +:source-highlighter: rouge +:rouge-style: github +:source-linenums-option: +:revdate: 2022-11-30 +:imagesdir: ./img + + +**注:** 暂时直接使用 `go-resdis/cache` + + +从使用倒推设计。 + +== 场景1 + +自己管理 `key`: + +[source,golang] +---- +ck := "key_foo" +data := cache.get(ck) +if !data { // <1> + data = FETCH_DATA() + cache.set(ck, data, 7200) // <2> +} +return data +---- + +<1> `get` 值为 `false` 表示没有缓存或缓存已过期 +<2> 7200 为缓存有效期(单位为秒),若指定为 0 表示不过期。 + +== 场景2 + +程序自动管理 `key`: + +[source,golang] +---- +cache.get(func() { + return 'foo' +}, 7200) +---- + +这种方式一般情况下比较方便,要是需要手动使缓存失效,则要麻烦一些。因此,这种方 +式暂时不实现。 + diff --git a/gcasbin/adapter_redis.go b/gcasbin/adapter_redis.go new file mode 100644 index 0000000..e3c87b9 --- /dev/null +++ b/gcasbin/adapter_redis.go @@ -0,0 +1,172 @@ +// +// adapter_redis.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gcasbin + +import ( + "context" + "errors" + "fmt" + "strings" + + "github.com/casbin/casbin/v2/model" + "github.com/casbin/casbin/v2/persist" + "github.com/casbin/casbin/v2/util" + "github.com/go-redis/redis/v8" +) + +const ( + // The key under which the policies are stored in redis + PolicyKey = "casbin:policy" +) + +// Adapter is an adapter for policy storage based on Redis +type RedisAdapter struct { + redisCli *redis.Client +} + +// NewFromDSN returns a new Adapter by using the given DSN. +// Format: redis://:{password}@{host}:{port}/{database} +// Example: redis://:123@localhost:6379/0 +func NewRedisAdapterFromURL(url string) (adapter *RedisAdapter, err error) { + opt, err := redis.ParseURL(url) + if err != nil { + return nil, err + } + + redisCli := redis.NewClient(opt) + if err = redisCli.Ping(context.Background()).Err(); err != nil { + return nil, fmt.Errorf("failed to ping redis: %v", err) + } + + return NewRedisAdapterFromClient(redisCli), nil +} + +// NewFromClient returns a new instance of Adapter from an already existing go-redis client. +func NewRedisAdapterFromClient(redisCli *redis.Client) (adapter *RedisAdapter) { + return &RedisAdapter{redisCli: redisCli} +} + +// LoadPolicy loads all policy rules from the storage. +func (a *RedisAdapter) LoadPolicy(model model.Model) (err error) { + ctx := context.Background() + + // Using the LoadPolicyLine handler from the Casbin repo for building rules + return a.loadPolicy(ctx, model, persist.LoadPolicyArray) +} + +func (a *RedisAdapter) loadPolicy(ctx context.Context, model model.Model, handler func([]string, model.Model) error) (err error) { + // 0, -1 fetches all entries from the list + rules, err := a.redisCli.LRange(ctx, PolicyKey, 0, -1).Result() + if err != nil { + return err + } + + // Parse the rules from Redis + for _, rule := range rules { + handler(strings.Split(rule, ", "), model) + } + + return +} + +// SavePolicy saves all policy rules to the storage. +func (a *RedisAdapter) SavePolicy(model model.Model) (err error) { + ctx := context.Background() + var rules []string + + // Serialize the policies into a string slice + for ptype, assertion := range model["p"] { + for _, rule := range assertion.Policy { + rules = append(rules, buildRuleStr(ptype, rule)) + } + } + + // Append the group policies to the slice + for ptype, assertion := range model["g"] { + for _, rule := range assertion.Policy { + rules = append(rules, buildRuleStr(ptype, rule)) + } + } + + // If an empty ruleset is saved, the policy is completely deleted from Redis. + if len(rules) > 0 { + return a.savePolicy(ctx, rules) + } + return a.delPolicy(ctx) +} + +func (a *RedisAdapter) savePolicy(ctx context.Context, rules []string) (err error) { + // Use a transaction for deleting the key & creating a new one. + // This only uses one round trip to Redis and also makes sure nothing bad happens. + cmd, err := a.redisCli.TxPipelined(ctx, func(tx redis.Pipeliner) error { + tx.Del(ctx, PolicyKey) + tx.RPush(ctx, PolicyKey, strToInterfaceSlice(rules)...) + + return nil + }) + if err = cmd[0].Err(); err != nil { + return fmt.Errorf("failed to delete policy key: %v", err) + } + if err = cmd[1].Err(); err != nil { + return fmt.Errorf("failed to save policy: %v", err) + } + + return +} + +func (a *RedisAdapter) delPolicy(ctx context.Context) (err error) { + if err = a.redisCli.Del(ctx, PolicyKey).Err(); err != nil { + return err + } + return +} + +// AddPolicy adds a policy rule to the storage. +func (a *RedisAdapter) AddPolicy(_ string, ptype string, rule []string) (err error) { + ctx := context.Background() + return a.addPolicy(ctx, buildRuleStr(ptype, rule)) +} + +func (a *RedisAdapter) addPolicy(ctx context.Context, rule string) (err error) { + if err = a.redisCli.RPush(ctx, PolicyKey, rule).Err(); err != nil { + return err + } + return +} + +// RemovePolicy removes a policy rule from the storage. +func (a *RedisAdapter) RemovePolicy(_ string, ptype string, rule []string) (err error) { + ctx := context.Background() + + return a.removePolicy(ctx, buildRuleStr(ptype, rule)) +} + +func (a *RedisAdapter) removePolicy(ctx context.Context, rule string) (err error) { + if err = a.redisCli.LRem(ctx, PolicyKey, 1, rule).Err(); err != nil { + return err + } + return +} + +// RemoveFilteredPolicy removes policy rules that match the filter from the storage. +func (a *RedisAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { + return errors.New("not implemented") +} + +// Converts a string slice to an interface{} slice. +// Needed for pushing elements to a redis list. +func strToInterfaceSlice(ss []string) (is []interface{}) { + for _, s := range ss { + is = append(is, s) + } + return +} + +func buildRuleStr(ptype string, rule []string) string { + return ptype + ", " + util.ArrayToString(rule) +} diff --git a/gcasbin/adapter_redis_test.go b/gcasbin/adapter_redis_test.go new file mode 100644 index 0000000..c4c68d0 --- /dev/null +++ b/gcasbin/adapter_redis_test.go @@ -0,0 +1,160 @@ +// +// adapter_redis_test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gcasbin_test + +import ( + "context" + "os" + "testing" + + "github.com/casbin/casbin/v2" + "github.com/go-redis/redis/v8" + "hexq.cn/tiglog/golib/gcasbin" + "hexq.cn/tiglog/golib/gtest" +) + +func getRedis() *redis.Client { + addr := os.Getenv("REDIS_ADDR") + username := os.Getenv("REDIS_USERNAME") + pass := os.Getenv("REDIS_PASSWORD") + + return redis.NewClient(&redis.Options{ + Addr: addr, + Username: username, + Password: pass, + DB: 1, + }) +} + +func TestNewRedisAdapterFromClient(t *testing.T) { + rdb := getRedis() + defer rdb.Close() + a := gcasbin.NewRedisAdapterFromClient(rdb) + gtest.NotEmpty(t, a) +} + +func TestNewRedisAdapterFromURL(t *testing.T) { + url := os.Getenv("REDIS_URL") + a, err := gcasbin.NewRedisAdapterFromURL(url) + gtest.NotEmpty(t, a) + gtest.NoError(t, err) +} + +func TestSavePolicy(t *testing.T) { + e, err := casbin.NewEnforcer("testdata/model.conf", "testdata/policy.csv") + gtest.NoError(t, err) + + fileModel := e.GetModel() + + rdb := getRedis() + defer rdb.Close() + + // Create the adapter + a := gcasbin.NewRedisAdapterFromClient(rdb) + + // Save the file model to redis + err = a.SavePolicy(fileModel) + gtest.NoError(t, err) + + // Create a new Enforcer, this time with the redis adapter + e, err = casbin.NewEnforcer("testdata/model.conf", a) + gtest.NoError(t, err) + + // Load policies from redis + err = e.LoadPolicy() + gtest.NoError(t, err) + // gtest.Equal(t, fileModel, e.GetModel()) + + _ = e.SavePolicy() + polLength, err := rdb.LLen(context.Background(), gcasbin.PolicyKey).Result() + gtest.NoError(t, err) + gtest.Equal(t, int64(3), polLength) + + // Delete current policies + e.ClearPolicy() + + // Save empty model for comparison + // emptyModel := e.GetModel() + + // Save empty model + err = e.SavePolicy() + gtest.NoError(t, err) + + // Load empty model again + err = e.LoadPolicy() + gtest.NoError(t, err) + + // Check if the loaded model equals the empty model from before + // gtest.Equal(t, emptyModel, e.GetModel()) +} + +func TestLoadPolicy(t *testing.T) { + gtest.True(t, true) +} + +func TestAddPolicy(t *testing.T) { + rdb := getRedis() + defer rdb.Close() + // Create the adapter + a := gcasbin.NewRedisAdapterFromClient(rdb) + + // Create a new Enforcer, this time with the redis adapter + e, err := casbin.NewEnforcer("testdata/model.conf", a) + gtest.NoError(t, err) + + // Add policies + _, err = e.AddPolicy("bob", "data1", "read") + gtest.NoError(t, err) + _, err = e.AddPolicy("alice", "data1", "write") + gtest.NoError(t, err) + + // Clear all policies from memory + e.ClearPolicy() + + // Policy is deleted now + hasPol := e.HasPolicy("bob", "data1", "read") + gtest.False(t, hasPol) + + // Load policies from redis + err = e.LoadPolicy() + gtest.NoError(t, err) + + // Policy is there again + hasPol = e.HasPolicy("bob", "data1", "read") + gtest.True(t, hasPol) + hasPol = e.HasPolicy("alice", "data1", "write") + gtest.True(t, hasPol) + +} + +func TestRmovePolicy(t *testing.T) { + rdb := getRedis() + defer rdb.Close() + // Create the adapter + a := gcasbin.NewRedisAdapterFromClient(rdb) + + // Create a new Enforcer, this time with the redis adapter + e, err := casbin.NewEnforcer("testdata/model.conf", a) + gtest.NoError(t, err) + + // Add policy + _, err = e.AddPolicy("bob", "data1", "read") + gtest.NoError(t, err) + + // Policy is available + hasPol := e.HasPolicy("bob", "data1", "read") + gtest.True(t, hasPol) + + // Remove the policy + _, err = e.RemovePolicy("bob", "data1", "read") + gtest.NoError(t, err) + + // Policy is gone + hasPol = e.HasPolicy("bob", "data1", "read") + gtest.False(t, hasPol) +} diff --git a/gcasbin/adapter_sqlx.go b/gcasbin/adapter_sqlx.go new file mode 100644 index 0000000..e12df23 --- /dev/null +++ b/gcasbin/adapter_sqlx.go @@ -0,0 +1,749 @@ +// +// adapter_sqlx.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gcasbin + +import ( + "bytes" + "context" + "errors" + "fmt" + "strconv" + + "github.com/casbin/casbin/v2/model" + "github.com/casbin/casbin/v2/persist" + "github.com/jmoiron/sqlx" +) + +// defaultTableName if tableName == "", the Adapter will use this default table name. +const defaultTableName = "casbin_rule" + +// maxParamLength . +const maxParamLength = 7 + +// general sql +const ( + sqlCreateTable = ` +CREATE TABLE %[1]s( + p_type VARCHAR(32), + v0 VARCHAR(255), + v1 VARCHAR(255), + v2 VARCHAR(255), + v3 VARCHAR(255), + v4 VARCHAR(255), + v5 VARCHAR(255) +); +CREATE INDEX idx_%[1]s ON %[1]s (p_type,v0,v1);` + sqlTruncateTable = "TRUNCATE TABLE %s" + sqlIsTableExist = "SELECT 1 FROM %s" + sqlInsertRow = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (?,?,?,?,?,?,?)" + sqlUpdateRow = "UPDATE %s SET p_type=?,v0=?,v1=?,v2=?,v3=?,v4=?,v5=? WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?" + sqlDeleteAll = "DELETE FROM %s" + sqlDeleteRow = "DELETE FROM %s WHERE p_type=? AND v0=? AND v1=? AND v2=? AND v3=? AND v4=? AND v5=?" + sqlDeleteByArgs = "DELETE FROM %s WHERE p_type=?" + sqlSelectAll = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s" + sqlSelectWhere = "SELECT p_type,v0,v1,v2,v3,v4,v5 FROM %s WHERE " +) + +// for Sqlite3 +const ( + sqlCreateTableSqlite3 = ` +CREATE TABLE IF NOT EXISTS %[1]s( + p_type VARCHAR(32) DEFAULT '' NOT NULL, + v0 VARCHAR(255) DEFAULT '' NOT NULL, + v1 VARCHAR(255) DEFAULT '' NOT NULL, + v2 VARCHAR(255) DEFAULT '' NOT NULL, + v3 VARCHAR(255) DEFAULT '' NOT NULL, + v4 VARCHAR(255) DEFAULT '' NOT NULL, + v5 VARCHAR(255) DEFAULT '' NOT NULL, + CHECK (TYPEOF("p_type") = "text" AND + LENGTH("p_type") <= 32), + CHECK (TYPEOF("v0") = "text" AND + LENGTH("v0") <= 255), + CHECK (TYPEOF("v1") = "text" AND + LENGTH("v1") <= 255), + CHECK (TYPEOF("v2") = "text" AND + LENGTH("v2") <= 255), + CHECK (TYPEOF("v3") = "text" AND + LENGTH("v3") <= 255), + CHECK (TYPEOF("v4") = "text" AND + LENGTH("v4") <= 255), + CHECK (TYPEOF("v5") = "text" AND + LENGTH("v5") <= 255) +); +CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);` + sqlTruncateTableSqlite3 = "DROP TABLE IF EXISTS %[1]s;" + sqlCreateTableSqlite3 +) + +// for Mysql +const ( + sqlCreateTableMysql = ` +CREATE TABLE IF NOT EXISTS %[1]s( + p_type VARCHAR(32) DEFAULT '' NOT NULL, + v0 VARCHAR(255) DEFAULT '' NOT NULL, + v1 VARCHAR(255) DEFAULT '' NOT NULL, + v2 VARCHAR(255) DEFAULT '' NOT NULL, + v3 VARCHAR(255) DEFAULT '' NOT NULL, + v4 VARCHAR(255) DEFAULT '' NOT NULL, + v5 VARCHAR(255) DEFAULT '' NOT NULL, + INDEX idx_%[1]s (p_type,v0,v1) +) ENGINE = InnoDB DEFAULT CHARSET = utf8mb4;` +) + +// for Postgres +const ( + sqlCreateTablePostgres = ` +CREATE TABLE IF NOT EXISTS %[1]s( + p_type VARCHAR(32) DEFAULT '' NOT NULL, + v0 VARCHAR(255) DEFAULT '' NOT NULL, + v1 VARCHAR(255) DEFAULT '' NOT NULL, + v2 VARCHAR(255) DEFAULT '' NOT NULL, + v3 VARCHAR(255) DEFAULT '' NOT NULL, + v4 VARCHAR(255) DEFAULT '' NOT NULL, + v5 VARCHAR(255) DEFAULT '' NOT NULL +); +CREATE INDEX IF NOT EXISTS idx_%[1]s ON %[1]s (p_type,v0,v1);` + sqlInsertRowPostgres = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES ($1,$2,$3,$4,$5,$6,$7)" + sqlUpdateRowPostgres = "UPDATE %s SET p_type=$1,v0=$2,v1=$3,v2=$4,v3=$5,v4=$6,v5=$7 WHERE p_type=$8 AND v0=$9 AND v1=$10 AND v2=$11 AND v3=$12 AND v4=$13 AND v5=$14" + sqlDeleteRowPostgres = "DELETE FROM %s WHERE p_type=$1 AND v0=$2 AND v1=$3 AND v2=$4 AND v3=$5 AND v4=$6 AND v5=$7" +) + +// for Sqlserver +const ( + sqlCreateTableSqlserver = ` +CREATE TABLE %[1]s( + p_type NVARCHAR(32) DEFAULT '' NOT NULL, + v0 NVARCHAR(255) DEFAULT '' NOT NULL, + v1 NVARCHAR(255) DEFAULT '' NOT NULL, + v2 NVARCHAR(255) DEFAULT '' NOT NULL, + v3 NVARCHAR(255) DEFAULT '' NOT NULL, + v4 NVARCHAR(255) DEFAULT '' NOT NULL, + v5 NVARCHAR(255) DEFAULT '' NOT NULL +); +CREATE INDEX idx_%[1]s ON %[1]s (p_type, v0, v1);` + sqlInsertRowSqlserver = "INSERT INTO %s (p_type,v0,v1,v2,v3,v4,v5) VALUES (@p1,@p2,@p3,@p4,@p5,@p6,@p7)" + sqlUpdateRowSqlserver = "UPDATE %s SET p_type=@p1,v0=@p2,v1=@p3,v2=@p4,v3=@p5,v4=@p6,v5=@p7 WHERE p_type=@p8 AND v0=@p9 AND v1=@p10 AND v2=@p11 AND v3=@p12 AND v4=@p13 AND v5=@p14" + sqlDeleteRowSqlserver = "DELETE FROM %s WHERE p_type=@p1 AND v0=@p2 AND v1=@p3 AND v2=@p4 AND v3=@p5 AND v4=@p6 AND v5=@p7" +) + +// CasbinRule defines the casbin rule model. +// It used for save or load policy lines from sqlx connected database. +type SqlCasbinRule struct { + PType string `db:"p_type"` + V0 string `db:"v0"` + V1 string `db:"v1"` + V2 string `db:"v2"` + V3 string `db:"v3"` + V4 string `db:"v4"` + V5 string `db:"v5"` +} + +// Adapter define the sqlx adapter for Casbin. +// It can load policy lines or save policy lines from sqlx connected database. +type SqlAdapter struct { + db *sqlx.DB + ctx context.Context + tableName string + + isFiltered bool + + SqlCreateTable string + SqlTruncateTable string + SqlIsTableExist string + SqlInsertRow string + SqlUpdateRow string + SqlDeleteAll string + SqlDeleteRow string + SqlDeleteByArgs string + SqlSelectAll string + SqlSelectWhere string +} + +// Filter defines the filtering rules for a FilteredAdapter's policy. +// Empty values are ignored, but all others must match the filter. +type SqlFilter struct { + PType []string + V0 []string + V1 []string + V2 []string + V3 []string + V4 []string + V5 []string +} + +// NewAdapter the constructor for Adapter. +// db should connected to database and controlled by user. +// If tableName == "", the Adapter will automatically create a table named "casbin_rule". +func NewSqlAdapter(db *sqlx.DB, tableName string) (*SqlAdapter, error) { + return NewSqlAdapterContext(context.Background(), db, tableName) +} + +// NewAdapterContext the constructor for Adapter. +// db should connected to database and controlled by user. +// If tableName == "", the Adapter will automatically create a table named "casbin_rule". +func NewSqlAdapterContext(ctx context.Context, db *sqlx.DB, tableName string) (*SqlAdapter, error) { + if db == nil { + return nil, errors.New("db is nil") + } + + // check db connecting + err := db.PingContext(ctx) + if err != nil { + return nil, err + } + + switch db.DriverName() { + case "oci8", "ora", "goracle": + return nil, errors.New("sqlxadapter: please checkout 'oracle' branch") + } + + if tableName == "" { + tableName = defaultTableName + } + + adapter := SqlAdapter{ + db: db, + ctx: ctx, + tableName: tableName, + } + + // generate different databases sql + adapter.genSQL() + + if !adapter.IsTableExist() { + if err = adapter.CreateTable(); err != nil { + return nil, err + } + } + + return &adapter, nil +} + +// genSQL generate sql based on db driver name. +func (p *SqlAdapter) genSQL() { + p.SqlCreateTable = fmt.Sprintf(sqlCreateTable, p.tableName) + p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTable, p.tableName) + + p.SqlIsTableExist = fmt.Sprintf(sqlIsTableExist, p.tableName) + + p.SqlInsertRow = fmt.Sprintf(sqlInsertRow, p.tableName) + p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRow, p.tableName) + p.SqlDeleteAll = fmt.Sprintf(sqlDeleteAll, p.tableName) + p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRow, p.tableName) + p.SqlDeleteByArgs = fmt.Sprintf(sqlDeleteByArgs, p.tableName) + + p.SqlSelectAll = fmt.Sprintf(sqlSelectAll, p.tableName) + p.SqlSelectWhere = fmt.Sprintf(sqlSelectWhere, p.tableName) + + switch p.db.DriverName() { + case "postgres", "pgx", "pq-timeouts", "cloudsqlpostgres": + p.SqlCreateTable = fmt.Sprintf(sqlCreateTablePostgres, p.tableName) + p.SqlInsertRow = fmt.Sprintf(sqlInsertRowPostgres, p.tableName) + p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowPostgres, p.tableName) + p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowPostgres, p.tableName) + case "mysql": + p.SqlCreateTable = fmt.Sprintf(sqlCreateTableMysql, p.tableName) + case "sqlite3": + p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlite3, p.tableName) + p.SqlTruncateTable = fmt.Sprintf(sqlTruncateTableSqlite3, p.tableName) + case "sqlserver": + p.SqlCreateTable = fmt.Sprintf(sqlCreateTableSqlserver, p.tableName) + p.SqlInsertRow = fmt.Sprintf(sqlInsertRowSqlserver, p.tableName) + p.SqlUpdateRow = fmt.Sprintf(sqlUpdateRowSqlserver, p.tableName) + p.SqlDeleteRow = fmt.Sprintf(sqlDeleteRowSqlserver, p.tableName) + } +} + +// createTable create a not exists table. +func (p *SqlAdapter) CreateTable() error { + _, err := p.db.ExecContext(p.ctx, p.SqlCreateTable) + + return err +} + +// truncateTable clear the table. +func (p *SqlAdapter) TruncateTable() error { + _, err := p.db.ExecContext(p.ctx, p.SqlTruncateTable) + + return err +} + +// deleteAll clear the table. +func (p *SqlAdapter) DeleteAll() error { + _, err := p.db.ExecContext(p.ctx, p.SqlDeleteAll) + + return err +} + +// isTableExist check the table exists. +func (p *SqlAdapter) IsTableExist() bool { + _, err := p.db.ExecContext(p.ctx, p.SqlIsTableExist) + + return err == nil +} + +// deleteRows delete eligible data. +func (p *SqlAdapter) DeleteRows(query string, args ...interface{}) error { + query = p.db.Rebind(query) + + _, err := p.db.ExecContext(p.ctx, query, args...) + + return err +} + +// truncateAndInsertRows clear table and insert new rows. +func (p *SqlAdapter) TruncateAndInsertRows(rules [][]interface{}) error { + if err := p.TruncateTable(); err != nil { + return err + } + return p.execTxSqlRows(p.SqlInsertRow, rules) +} + +// deleteAllAndInsertRows clear table and insert new rows. +func (p *SqlAdapter) DeleteAllAndInsertRows(rules [][]interface{}) error { + if err := p.DeleteAll(); err != nil { + return err + } + return p.execTxSqlRows(p.SqlInsertRow, rules) +} + +// execTxSqlRows exec sql rows. +func (p *SqlAdapter) execTxSqlRows(query string, rules [][]interface{}) (err error) { + tx, err := p.db.BeginTx(p.ctx, nil) + if err != nil { + return + } + + var action string + + stmt, err := tx.PrepareContext(p.ctx, query) + if err != nil { + action = "prepare context" + goto ROLLBACK + } + + for _, rule := range rules { + if _, err = stmt.ExecContext(p.ctx, rule...); err != nil { + action = "stmt exec" + goto ROLLBACK + } + } + + if err = stmt.Close(); err != nil { + action = "stmt close" + goto ROLLBACK + } + + if err = tx.Commit(); err != nil { + action = "commit" + goto ROLLBACK + } + + return + +ROLLBACK: + + if err1 := tx.Rollback(); err1 != nil { + err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1) + } + + return +} + +// selectRows select eligible data by args from the table. +func (p *SqlAdapter) SelectRows(query string, args ...interface{}) ([]*SqlCasbinRule, error) { + // make a slice with capacity + lines := make([]*SqlCasbinRule, 0, 64) + + if len(args) == 0 { + return lines, p.db.SelectContext(p.ctx, &lines, query) + } + + query = p.db.Rebind(query) + + return lines, p.db.SelectContext(p.ctx, &lines, query, args...) +} + +// selectWhereIn select eligible data by filter from the table. +func (p *SqlAdapter) SelectWhereIn(filter *SqlFilter) (lines []*SqlCasbinRule, err error) { + var sqlBuf bytes.Buffer + + sqlBuf.Grow(64) + sqlBuf.WriteString(p.SqlSelectWhere) + + args := make([]interface{}, 0, 4) + + hasInCond := false + + for _, col := range [maxParamLength]struct { + name string + arg []string + }{ + {"p_type", filter.PType}, + {"v0", filter.V0}, + {"v1", filter.V1}, + {"v2", filter.V2}, + {"v3", filter.V3}, + {"v4", filter.V4}, + {"v5", filter.V5}, + } { + l := len(col.arg) + if l == 0 { + continue + } + + switch sqlBuf.Bytes()[sqlBuf.Len()-1] { + case '?', ')': + sqlBuf.WriteString(" AND ") + } + + sqlBuf.WriteString(col.name) + + if l == 1 { + sqlBuf.WriteString("=?") + args = append(args, col.arg[0]) + } else { + sqlBuf.WriteString(" IN (?)") + args = append(args, col.arg) + + hasInCond = true + } + } + + var query string + + if hasInCond { + if query, args, err = sqlx.In(sqlBuf.String(), args...); err != nil { + return + } + } else { + query = sqlBuf.String() + } + + return p.SelectRows(query, args...) +} + +// LoadPolicy load all policy rules from the storage. +func (p *SqlAdapter) LoadPolicy(model model.Model) error { + lines, err := p.SelectRows(p.SqlSelectAll) + if err != nil { + return err + } + + for _, line := range lines { + p.loadPolicyLine(line, model) + } + + return nil +} + +// SavePolicy save policy rules to the storage. +func (p *SqlAdapter) SavePolicy(model model.Model) error { + args := make([][]interface{}, 0, 64) + + for ptype, ast := range model["p"] { + for _, rule := range ast.Policy { + arg := p.GenArgs(ptype, rule) + args = append(args, arg) + } + } + + for ptype, ast := range model["g"] { + for _, rule := range ast.Policy { + arg := p.GenArgs(ptype, rule) + args = append(args, arg) + } + } + + return p.DeleteAllAndInsertRows(args) +} + +// AddPolicy add one policy rule to the storage. +func (p *SqlAdapter) AddPolicy(sec string, ptype string, rule []string) error { + args := p.GenArgs(ptype, rule) + + _, err := p.db.ExecContext(p.ctx, p.SqlInsertRow, args...) + + return err +} + +// AddPolicies add multiple policy rules to the storage. +func (p *SqlAdapter) AddPolicies(sec string, ptype string, rules [][]string) error { + args := make([][]interface{}, 0, 8) + + for _, rule := range rules { + arg := p.GenArgs(ptype, rule) + args = append(args, arg) + } + + return p.execTxSqlRows(p.SqlInsertRow, args) +} + +// RemovePolicy remove policy rules from the storage. +func (p *SqlAdapter) RemovePolicy(sec string, ptype string, rule []string) error { + var sqlBuf bytes.Buffer + + sqlBuf.Grow(64) + sqlBuf.WriteString(p.SqlDeleteByArgs) + + args := make([]interface{}, 0, 4) + args = append(args, ptype) + + for idx, arg := range rule { + if arg != "" { + sqlBuf.WriteString(" AND v") + sqlBuf.WriteString(strconv.Itoa(idx)) + sqlBuf.WriteString("=?") + + args = append(args, arg) + } + } + + return p.DeleteRows(sqlBuf.String(), args...) +} + +// RemoveFilteredPolicy remove policy rules that match the filter from the storage. +func (p *SqlAdapter) RemoveFilteredPolicy(sec string, ptype string, fieldIndex int, fieldValues ...string) error { + var sqlBuf bytes.Buffer + + sqlBuf.Grow(64) + sqlBuf.WriteString(p.SqlDeleteByArgs) + + args := make([]interface{}, 0, 4) + args = append(args, ptype) + + var value string + + l := fieldIndex + len(fieldValues) + + for idx := 0; idx < 6; idx++ { + if fieldIndex <= idx && idx < l { + value = fieldValues[idx-fieldIndex] + + if value != "" { + sqlBuf.WriteString(" AND v") + sqlBuf.WriteString(strconv.Itoa(idx)) + sqlBuf.WriteString("=?") + + args = append(args, value) + } + } + } + + return p.DeleteRows(sqlBuf.String(), args...) +} + +// RemovePolicies remove policy rules. +func (p *SqlAdapter) RemovePolicies(sec string, ptype string, rules [][]string) (err error) { + args := make([][]interface{}, 0, 8) + + for _, rule := range rules { + arg := p.GenArgs(ptype, rule) + args = append(args, arg) + } + + return p.execTxSqlRows(p.SqlDeleteRow, args) +} + +// LoadFilteredPolicy load policy rules that match the filter. +// filterPtr must be a pointer. +func (p *SqlAdapter) LoadFilteredPolicy(model model.Model, filterPtr interface{}) error { + if filterPtr == nil { + return p.LoadPolicy(model) + } + + filter, ok := filterPtr.(*SqlFilter) + if !ok { + return errors.New("invalid filter type") + } + + lines, err := p.SelectWhereIn(filter) + if err != nil { + return err + } + + for _, line := range lines { + p.loadPolicyLine(line, model) + } + + p.isFiltered = true + + return nil +} + +// IsFiltered returns true if the loaded policy rules has been filtered. +func (p *SqlAdapter) IsFiltered() bool { + return p.isFiltered +} + +// UpdatePolicy update a policy rule from storage. +// This is part of the Auto-Save feature. +func (p *SqlAdapter) UpdatePolicy(sec, ptype string, oldRule, newPolicy []string) error { + oldArg := p.GenArgs(ptype, oldRule) + newArg := p.GenArgs(ptype, newPolicy) + + _, err := p.db.ExecContext(p.ctx, p.SqlUpdateRow, append(newArg, oldArg...)...) + + return err +} + +// UpdatePolicies updates policy rules to storage. +func (p *SqlAdapter) UpdatePolicies(sec, ptype string, oldRules, newRules [][]string) (err error) { + if len(oldRules) != len(newRules) { + return errors.New("old rules size not equal to new rules size") + } + + args := make([][]interface{}, 0, 16) + + for idx := range oldRules { + oldArg := p.GenArgs(ptype, oldRules[idx]) + newArg := p.GenArgs(ptype, newRules[idx]) + args = append(args, append(newArg, oldArg...)) + } + + return p.execTxSqlRows(p.SqlUpdateRow, args) +} + +// UpdateFilteredPolicies deletes old rules and adds new rules. +func (p *SqlAdapter) UpdateFilteredPolicies(sec, ptype string, newPolicies [][]string, fieldIndex int, fieldValues ...string) (oldPolicies [][]string, err error) { + var value string + + var whereBuf bytes.Buffer + whereBuf.Grow(32) + + l := fieldIndex + len(fieldValues) + + whereArgs := make([]interface{}, 0, 4) + whereArgs = append(whereArgs, ptype) + + for idx := 0; idx < 6; idx++ { + if fieldIndex <= idx && idx < l { + value = fieldValues[idx-fieldIndex] + + if value != "" { + whereBuf.WriteString(" AND v") + whereBuf.WriteString(strconv.Itoa(idx)) + whereBuf.WriteString("=?") + + whereArgs = append(whereArgs, value) + } + } + } + + var selectBuf bytes.Buffer + selectBuf.Grow(64) + selectBuf.WriteString(p.SqlSelectWhere) + selectBuf.WriteString("p_type=?") + selectBuf.Write(whereBuf.Bytes()) + + var oldRows []*SqlCasbinRule + value = p.db.Rebind(selectBuf.String()) + oldRows, err = p.SelectRows(value, whereArgs...) + if err != nil { + return + } + + var deleteBuf bytes.Buffer + deleteBuf.Grow(64) + deleteBuf.WriteString(p.SqlDeleteByArgs) + deleteBuf.Write(whereBuf.Bytes()) + + var tx *sqlx.Tx + tx, err = p.db.BeginTxx(p.ctx, nil) + if err != nil { + return + } + + var ( + stmt *sqlx.Stmt + action string + ) + value = p.db.Rebind(deleteBuf.String()) + if _, err = tx.ExecContext(p.ctx, value, whereArgs...); err != nil { + action = "delete old policies" + goto ROLLBACK + } + + stmt, err = tx.PreparexContext(p.ctx, p.SqlInsertRow) + if err != nil { + action = "preparex context" + goto ROLLBACK + } + + for _, policy := range newPolicies { + arg := p.GenArgs(ptype, policy) + if _, err = stmt.ExecContext(p.ctx, arg...); err != nil { + action = "stmt exec context" + goto ROLLBACK + } + } + + if err = stmt.Close(); err != nil { + action = "stmt close" + goto ROLLBACK + } + + if err = tx.Commit(); err != nil { + action = "commit" + goto ROLLBACK + } + + oldPolicies = make([][]string, 0, len(oldRows)) + for _, rule := range oldRows { + oldPolicies = append(oldPolicies, []string{rule.PType, rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5}) + } + + return + +ROLLBACK: + + if err1 := tx.Rollback(); err1 != nil { + err = fmt.Errorf("%s err: %v, rollback err: %v", action, err, err1) + } + + return +} + +// loadPolicyLine load a policy line to model. +func (SqlAdapter) loadPolicyLine(line *SqlCasbinRule, model model.Model) { + if line == nil { + return + } + + var lineBuf bytes.Buffer + + lineBuf.Grow(64) + lineBuf.WriteString(line.PType) + + args := [6]string{line.V0, line.V1, line.V2, line.V3, line.V4, line.V5} + for _, arg := range args { + if arg != "" { + lineBuf.WriteByte(',') + lineBuf.WriteString(arg) + } + } + + persist.LoadPolicyLine(lineBuf.String(), model) +} + +// genArgs generate args from ptype and rule. +func (SqlAdapter) GenArgs(ptype string, rule []string) []interface{} { + l := len(rule) + + args := make([]interface{}, maxParamLength) + args[0] = ptype + + for idx := 0; idx < l; idx++ { + args[idx+1] = rule[idx] + } + + for idx := l + 1; idx < maxParamLength; idx++ { + args[idx] = "" + } + + return args +} diff --git a/gcasbin/adapter_sqlx_test.go b/gcasbin/adapter_sqlx_test.go new file mode 100644 index 0000000..c42109e --- /dev/null +++ b/gcasbin/adapter_sqlx_test.go @@ -0,0 +1,466 @@ +// +// adapter_sqlx_test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gcasbin_test + +import ( + "os" + "strings" + "testing" + + "github.com/casbin/casbin/v2" + "github.com/casbin/casbin/v2/util" + _ "github.com/go-sql-driver/mysql" + "github.com/jmoiron/sqlx" + "hexq.cn/tiglog/golib/gcasbin" +) + +const ( + rbacModelFile = "testdata/rbac_model.conf" + rbacPolicyFile = "testdata/rbac_policy.csv" +) + +var ( + dataSourceNames = map[string]string{ + // "sqlite3": ":memory:", + // "mysql": "root:@tcp(127.0.0.1:3306)/sqlx_adapter_test", + "postgres": os.Getenv("DB_DSN"), + // "sqlserver": "sqlserver://sa:YourPassword@127.0.0.1:1433?database=sqlx_adapter_test&connection+timeout=30", + } + + lines = []gcasbin.SqlCasbinRule{ + {PType: "p", V0: "alice", V1: "data1", V2: "read"}, + {PType: "p", V0: "bob", V1: "data2", V2: "read"}, + {PType: "p", V0: "bob", V1: "data2", V2: "write"}, + {PType: "p", V0: "data2_admin", V1: "data1", V2: "read", V3: "test1", V4: "test2", V5: "test3"}, + {PType: "p", V0: "data2_admin", V1: "data2", V2: "write", V3: "test1", V4: "test2", V5: "test3"}, + {PType: "p", V0: "data1_admin", V1: "data2", V2: "write"}, + {PType: "g", V0: "alice", V1: "data2_admin"}, + {PType: "g", V0: "bob", V1: "data2_admin", V2: "test"}, + {PType: "g", V0: "bob", V1: "data1_admin", V2: "test2", V3: "test3", V4: "test4", V5: "test5"}, + } + + filter = gcasbin.SqlFilter{ + PType: []string{"p"}, + V0: []string{"bob", "data2_admin"}, + V1: []string{"data1", "data2"}, + V2: []string{"read", "write"}, + V3: []string{"test1"}, + V4: []string{"test2"}, + V5: []string{"test3"}, + } +) + +func TestSqlAdapters(t *testing.T) { + for key, value := range dataSourceNames { + t.Logf("-------------------- test [%s] start, dataSourceName: [%s]", key, value) + + db, err := sqlx.Connect(key, value) + if err != nil { + t.Fatalf("sqlx.Connect failed, err: %v", err) + } + + t.Log("---------- testTableName start") + testTableName(t, db) + t.Log("---------- testTableName finished") + + t.Log("---------- testSQL start") + testSQL(t, db, "sqlxadapter_sql") + t.Log("---------- testSQL finished") + + t.Log("---------- testSaveLoad start") + testSaveLoad(t, db, "sqlxadapter_save_load") + t.Log("---------- testSaveLoad finished") + + t.Log("---------- testAutoSave start") + testAutoSave(t, db, "sqlxadapter_auto_save") + t.Log("---------- testAutoSave finished") + + t.Log("---------- testFilteredSqlPolicy start") + testFilteredSqlPolicy(t, db, "sqlxadapter_filtered_policy") + t.Log("---------- testFilteredSqlPolicy finished") + + // t.Log("---------- testUpdateSqlPolicy start") + // testUpdateSqlPolicy(t, db, "sqladapter_filtered_policy") + // t.Log("---------- testUpdateSqlPolicy finished") + + // t.Log("---------- testUpdateSqlPolicies start") + // testUpdateSqlPolicies(t, db, "sqladapter_filtered_policy") + // t.Log("---------- testUpdateSqlPolicies finished") + + // t.Log("---------- testUpdateFilteredSqlPolicies start") + // testUpdateFilteredSqlPolicies(t, db, "sqladapter_filtered_policy") + // t.Log("---------- testUpdateFilteredSqlPolicies finished") + + } +} + +func testTableName(t *testing.T, db *sqlx.DB) { + _, err := gcasbin.NewSqlAdapter(db, "") + if err != nil { + t.Fatalf("NewAdapter failed, err: %v", err) + } +} + +func testSQL(t *testing.T, db *sqlx.DB, tableName string) { + var err error + logErr := func(action string) { + if err != nil { + t.Errorf("%s test failed, err: %v", action, err) + } + } + + equalValue := func(line1, line2 gcasbin.SqlCasbinRule) bool { + if line1.PType != line2.PType || + line1.V0 != line2.V0 || + line1.V1 != line2.V1 || + line1.V2 != line2.V2 || + line1.V3 != line2.V3 || + line1.V4 != line2.V4 || + line1.V5 != line2.V5 { + return false + } + return true + } + + var a *gcasbin.SqlAdapter + a, err = gcasbin.NewSqlAdapter(db, tableName) + logErr("NewSqlAdapter") + + // createTable test has passed when adapter create + // err = a.CreateTable() + // logErr("createTable") + + if b := a.IsTableExist(); b == false { + t.Fatal("isTableExist test failed") + } + + rules := make([][]interface{}, len(lines)) + for idx, rule := range lines { + args := a.GenArgs(rule.PType, []string{rule.V0, rule.V1, rule.V2, rule.V3, rule.V4, rule.V5}) + rules[idx] = args + } + + err = a.TruncateAndInsertRows(rules) + logErr("truncateAndInsertRows") + + err = a.DeleteAllAndInsertRows(rules) + logErr("truncateAndInsertRows") + + err = a.DeleteRows(a.SqlDeleteByArgs, "g") + logErr("deleteRows sqlDeleteByArgs g") + + err = a.DeleteRows(a.SqlDeleteAll) + logErr("deleteRows sqlDeleteAll") + + _ = a.TruncateAndInsertRows(rules) + + records, err := a.SelectRows(a.SqlSelectAll) + logErr("selectRows sqlSelectAll") + for idx, record := range records { + line := lines[idx] + if !equalValue(*record, line) { + t.Fatalf("selectRows records test not equal, query record: %+v, need record: %+v", record, line) + } + } + + records, err = a.SelectWhereIn(&filter) + logErr("selectWhereIn") + i := 3 + for _, record := range records { + line := lines[i] + if !equalValue(*record, line) { + t.Fatalf("selectWhereIn records test not equal, query record: %+v, need record: %+v", record, line) + } + i++ + } + + err = a.TruncateTable() + logErr("truncateTable") +} + +func initSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) { + // Because the DB is empty at first, + // so we need to load the policy from the file adapter (.CSV) first. + e, _ := casbin.NewEnforcer(rbacModelFile, rbacPolicyFile) + + a, err := gcasbin.NewSqlAdapter(db, tableName) + if err != nil { + t.Fatal("NewAdapter test failed, err: ", err) + } + + // This is a trick to save the current policy to the DB. + // We can't call e.SavePolicy() because the adapter in the enforcer is still the file adapter. + // The current policy means the policy in the Casbin enforcer (aka in memory). + err = a.SavePolicy(e.GetModel()) + if err != nil { + t.Fatal("SavePolicy test failed, err: ", err) + } + + // Clear the current policy. + e.ClearPolicy() + testGetSqlPolicy(t, e, [][]string{}) + + // Load the policy from DB. + err = a.LoadPolicy(e.GetModel()) + if err != nil { + t.Fatal("LoadPolicy test failed, err: ", err) + } + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) +} + +func testSaveLoad(t *testing.T, db *sqlx.DB, tableName string) { + // Initialize some policy in DB. + initSqlPolicy(t, db, tableName) + // Note: you don't need to look at the above code + // if you already have a working DB with policy inside. + + // Now the DB has policy, so we can provide a normal use case. + // Create an adapter and an enforcer. + // NewEnforcer() will load the policy automatically. + a, _ := gcasbin.NewSqlAdapter(db, tableName) + e, _ := casbin.NewEnforcer(rbacModelFile, a) + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) +} + +func testAutoSave(t *testing.T, db *sqlx.DB, tableName string) { + // Initialize some policy in DB. + initSqlPolicy(t, db, tableName) + // Note: you don't need to look at the above code + // if you already have a working DB with policy inside. + + // Now the DB has policy, so we can provide a normal use case. + // Create an adapter and an enforcer. + // NewEnforcer() will load the policy automatically. + a, _ := gcasbin.NewSqlAdapter(db, tableName) + e, _ := casbin.NewEnforcer(rbacModelFile, a) + + // AutoSave is enabled by default. + // Now we disable it. + e.EnableAutoSave(false) + + var err error + logErr := func(action string) { + if err != nil { + t.Errorf("%s test failed, err: %v", action, err) + } + } + + // Because AutoSave is disabled, the policy change only affects the policy in Casbin enforcer, + // it doesn't affect the policy in the storage. + _, err = e.AddPolicy("alice", "data1", "write") + logErr("AddPolicy1") + // Reload the policy from the storage to see the effect. + err = e.LoadPolicy() + logErr("LoadPolicy1") + // This is still the original policy. + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) + + _, err = e.AddPolicies([][]string{{"alice_1", "data_1", "read_1"}, {"bob_1", "data_1", "write_1"}}) + logErr("AddPolicies1") + // Reload the policy from the storage to see the effect. + err = e.LoadPolicy() + logErr("LoadPolicy2") + // This is still the original policy. + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) + + // Now we enable the AutoSave. + e.EnableAutoSave(true) + + // Because AutoSave is enabled, the policy change not only affects the policy in Casbin enforcer, + // but also affects the policy in the storage. + _, err = e.AddPolicy("alice", "data1", "write") + logErr("AddPolicy2") + // Reload the policy from the storage to see the effect. + err = e.LoadPolicy() + logErr("LoadPolicy3") + // The policy has a new rule: {"alice", "data1", "write"}. + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}) + + _, err = e.AddPolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}}) + logErr("AddPolicies2") + // Reload the policy from the storage to see the effect. + err = e.LoadPolicy() + logErr("LoadPolicy4") + // This is still the original policy. + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}, + {"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}}) + + _, err = e.RemovePolicies([][]string{{"alice_2", "data_2", "read_2"}, {"bob_2", "data_2", "write_2"}}) + logErr("RemovePolicies") + err = e.LoadPolicy() + logErr("LoadPolicy5") + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"alice", "data1", "write"}}) + + // Remove the added rule. + _, err = e.RemovePolicy("alice", "data1", "write") + logErr("RemovePolicy") + err = e.LoadPolicy() + logErr("LoadPolicy6") + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) + + // Remove "data2_admin" related policy rules via a filter. + // Two rules: {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"} are deleted. + _, err = e.RemoveFilteredPolicy(0, "data2_admin") + logErr("RemoveFilteredPolicy") + err = e.LoadPolicy() + logErr("LoadPolicy7") + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}}) +} + +func testFilteredSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) { + // Initialize some policy in DB. + initSqlPolicy(t, db, tableName) + // Note: you don't need to look at the above code + // if you already have a working DB with policy inside. + + // Now the DB has policy, so we can provide a normal use case. + // Create an adapter and an enforcer. + // NewEnforcer() will load the policy automatically. + a, _ := gcasbin.NewSqlAdapter(db, tableName) + e, _ := casbin.NewEnforcer(rbacModelFile, a) + // Now set the adapter + e.SetAdapter(a) + + var err error + logErr := func(action string) { + if err != nil { + t.Errorf("%s test failed, err: %v", action, err) + } + } + + // Load only alice's policies + err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice"}}) + logErr("LoadFilteredPolicy alice") + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}}) + + // Load only bob's policies + err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"bob"}}) + logErr("LoadFilteredPolicy bob") + testGetSqlPolicy(t, e, [][]string{{"bob", "data2", "write"}}) + + // Load policies for data2_admin + err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"data2_admin"}}) + logErr("LoadFilteredPolicy data2_admin") + testGetSqlPolicy(t, e, [][]string{{"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) + + // Load policies for alice and bob + err = e.LoadFilteredPolicy(&gcasbin.SqlFilter{V0: []string{"alice", "bob"}}) + logErr("LoadFilteredPolicy alice bob") + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "write"}}) + + _, err = e.AddPolicy("bob", "data1", "write", "test1", "test2", "test3") + logErr("AddPolicy") + + err = e.LoadFilteredPolicy(&filter) + logErr("LoadFilteredPolicy filter") + testGetSqlPolicy(t, e, [][]string{{"bob", "data1", "write", "test1", "test2", "test3"}}) +} + +func testUpdateSqlPolicy(t *testing.T, db *sqlx.DB, tableName string) { + // Initialize some policy in DB. + initSqlPolicy(t, db, tableName) + + a, _ := gcasbin.NewSqlAdapter(db, tableName) + e, _ := casbin.NewEnforcer(rbacModelFile, a) + + e.EnableAutoSave(true) + e.UpdatePolicy([]string{"alice", "data1", "read"}, []string{"alice", "data1", "write"}) + e.LoadPolicy() + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) +} + +func testUpdateSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) { + // Initialize some policy in DB. + initSqlPolicy(t, db, tableName) + + a, _ := gcasbin.NewSqlAdapter(db, tableName) + e, _ := casbin.NewEnforcer(rbacModelFile, a) + + e.EnableAutoSave(true) + e.UpdatePolicies([][]string{{"alice", "data1", "write"}, {"bob", "data2", "write"}}, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}}) + e.LoadPolicy() + testGetSqlPolicy(t, e, [][]string{{"alice", "data1", "read"}, {"bob", "data2", "read"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}}) +} + +func testUpdateFilteredSqlPolicies(t *testing.T, db *sqlx.DB, tableName string) { + // Initialize some policy in DB. + initSqlPolicy(t, db, tableName) + + a, _ := gcasbin.NewSqlAdapter(db, tableName) + e, _ := casbin.NewEnforcer(rbacModelFile, a) + + e.EnableAutoSave(true) + e.UpdateFilteredPolicies([][]string{{"alice", "data1", "write"}}, 0, "alice", "data1", "read") + e.UpdateFilteredPolicies([][]string{{"bob", "data2", "read"}}, 0, "bob", "data2", "write") + e.LoadPolicy() + testGetSqlPolicyWithoutOrder(t, e, [][]string{{"alice", "data1", "write"}, {"data2_admin", "data2", "read"}, {"data2_admin", "data2", "write"}, {"bob", "data2", "read"}}) +} + +func testGetSqlPolicy(t *testing.T, e *casbin.Enforcer, res [][]string) { + t.Helper() + myRes := e.GetPolicy() + t.Log("Policy: ", myRes) + + m := make(map[string]struct{}, len(myRes)) + for _, record := range myRes { + key := strings.Join(record, ",") + m[key] = struct{}{} + } + + for _, record := range res { + key := strings.Join(record, ",") + if _, ok := m[key]; !ok { + t.Error("Policy: \n", myRes, ", supposed to be \n", res) + break + } + } +} + +func testGetSqlPolicyWithoutOrder(t *testing.T, e *casbin.Enforcer, res [][]string) { + myRes := e.GetPolicy() + // log.Print("Policy: \n", myRes) + + if !arraySqlEqualsWithoutOrder(myRes, res) { + t.Error("Policy: \n", myRes, ", supposed to be \n", res) + } +} + +func arraySqlEqualsWithoutOrder(a [][]string, b [][]string) bool { + if len(a) != len(b) { + return false + } + + mapA := make(map[int]string) + mapB := make(map[int]string) + order := make(map[int]struct{}) + l := len(a) + + for i := 0; i < l; i++ { + mapA[i] = util.ArrayToString(a[i]) + mapB[i] = util.ArrayToString(b[i]) + } + + for i := 0; i < l; i++ { + for j := 0; j < l; j++ { + if _, ok := order[j]; ok { + if j == l-1 { + return false + } else { + continue + } + } + if mapA[i] == mapB[j] { + order[j] = struct{}{} + break + } else if j == l-1 { + return false + } + } + } + return true +} diff --git a/gcasbin/casbin.go b/gcasbin/casbin.go new file mode 100644 index 0000000..1dcd272 --- /dev/null +++ b/gcasbin/casbin.go @@ -0,0 +1,8 @@ +// +// casbin.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gcasbin diff --git a/gcasbin/middleware.go b/gcasbin/middleware.go new file mode 100644 index 0000000..263e7e4 --- /dev/null +++ b/gcasbin/middleware.go @@ -0,0 +1,17 @@ +// +// middleware.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gcasbin + +import "github.com/gin-gonic/gin" + +func GinCasbin() gin.HandlerFunc { + return func(c *gin.Context) { + c.Next() + } + +} diff --git a/gcasbin/readme.adoc b/gcasbin/readme.adoc new file mode 100644 index 0000000..2217b58 --- /dev/null +++ b/gcasbin/readme.adoc @@ -0,0 +1,18 @@ += 授权 +:author: tiglog +:experimental: +:toc: left +:toclevels: 3 +:toc-title: 目录 +:sectnums: +:icons: font +:!webfonts: +:autofit-option: +:source-highlighter: rouge +:rouge-style: github +:source-linenums-option: +:revdate: 2022-12-01 +:imagesdir: ./img + + + diff --git a/gcasbin/testdata/model.conf b/gcasbin/testdata/model.conf new file mode 100644 index 0000000..286ccf2 --- /dev/null +++ b/gcasbin/testdata/model.conf @@ -0,0 +1,18 @@ +# Request definition +[request_definition] +r = sub, obj, act + +# Policy definition +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +# Policy effect +[policy_effect] +e = some(where (p.eft == allow)) + +# Matchers +[matchers] +m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act diff --git a/gcasbin/testdata/policy.csv b/gcasbin/testdata/policy.csv new file mode 100644 index 0000000..f74dd0f --- /dev/null +++ b/gcasbin/testdata/policy.csv @@ -0,0 +1,3 @@ +p, admin, data, read +p, admin, data, write +g, bob, admin diff --git a/gcasbin/testdata/rbac_model.conf b/gcasbin/testdata/rbac_model.conf new file mode 100644 index 0000000..9ca4b92 --- /dev/null +++ b/gcasbin/testdata/rbac_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[role_definition] +g = _, _ + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act diff --git a/gcasbin/testdata/rbac_policy.csv b/gcasbin/testdata/rbac_policy.csv new file mode 100644 index 0000000..8479f3c --- /dev/null +++ b/gcasbin/testdata/rbac_policy.csv @@ -0,0 +1,5 @@ +p, alice, data1, read +p, bob, data2, write +p, data2_admin, data2, read +p, data2_admin, data2, write +g, alice, data2_admin diff --git a/gcasbin/testdata/rbac_tenant_service.conf b/gcasbin/testdata/rbac_tenant_service.conf new file mode 100644 index 0000000..81486c4 --- /dev/null +++ b/gcasbin/testdata/rbac_tenant_service.conf @@ -0,0 +1,14 @@ +[request_definition] +r = tenant, sub, obj, act, service + +[policy_definition] +p =tenant, sub, obj, act, service, eft + +[role_definition] +g = _, _ + +[policy_effect] +e = priority(p.eft) || deny + +[matchers] +m = r.tenant == p.tenant && g(r.sub, p.sub) && keyMatch(r.obj, p.obj) && (r.act == p.act || p.act == "*") && (r.service == p.service || p.service == "*") diff --git a/gconfig/auth.go b/gconfig/auth.go new file mode 100644 index 0000000..6e2e33a --- /dev/null +++ b/gconfig/auth.go @@ -0,0 +1,13 @@ +// +// auth.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gconfig + +type AuthConfig struct { + TokenTtl int `yaml:"token_ttl"` + RefreshTtl int `yaml:"refresh_ttl"` +} diff --git a/gconfig/config.go b/gconfig/config.go new file mode 100644 index 0000000..5394cf8 --- /dev/null +++ b/gconfig/config.go @@ -0,0 +1,60 @@ +// +// config.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gconfig + +import ( + "bytes" + "io/ioutil" + "text/template" + + "hexq.cn/tiglog/golib/gfile" + + "gopkg.in/yaml.v2" + "hexq.cn/tiglog/golib/helper" +) + +type BaseConfig struct { + BaseDir string + Http HttpConfig `json:"http" yaml:"http"` + Auth AuthConfig `json:"auth" yaml:"auth"` + Param ParamConfig `json:"param" yaml:"param"` + Db DbConfig `json:"db" yaml:"db"` + Mongo MongoConfig `json:"mongo" yaml:"mongo"` + Redis RedisConfig `json:"redis" yaml:"redis"` +} + +func (c *BaseConfig) LoadParams() { + if c.BaseDir == "" { + c.BaseDir = "./etc" + } + c.Param.Load(c.BaseDir + "/params.yaml") +} + +func (c *BaseConfig) ParseAppConfig() { + buf := c.GetData("/app.yaml", true) + err := yaml.Unmarshal(buf.Bytes(), c) + helper.CheckErr(err) +} + +func (c *BaseConfig) GetData(fname string, must bool) *bytes.Buffer { + fp := c.BaseDir + fname + if !gfile.Exists(fp) { + if must { + panic("配置文件" + fp + "不存在") + } else { + return nil + } + } + dat, err := ioutil.ReadFile(fp) + helper.CheckErr(err) + tpl, err := template.New("config").Parse(string(dat)) + helper.CheckErr(err) + buf := new(bytes.Buffer) + tpl.Execute(buf, c.Param.Params) + return buf +} diff --git a/gconfig/db.go b/gconfig/db.go new file mode 100644 index 0000000..bd726d2 --- /dev/null +++ b/gconfig/db.go @@ -0,0 +1,56 @@ +// +// db.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gconfig + +import "fmt" + +type DbConfig struct { + Type string `yaml:"type"` + Host string `yaml:"host"` + Username string `yaml:"user"` + Password string `yaml:"pass"` + Port int `yaml:"port"` + Name string `yaml:"name"` +} + +type MongoConfig struct { + Host string `yaml:"host"` + Port int `yaml:"port"` + Username string `yaml:"user"` + Password string `yaml:"pass"` + Name string `yaml:"name"` + PoolSize int `yaml:"pool_size"` +} + +func (c *DbConfig) GetUri() string { + switch c.Type { + case "postgres": + return fmt.Sprintf("%s://host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", c.Type, c.Host, c.Port, c.Username, c.Password, c.Name) + case "mysql": + return fmt.Sprintf("%s://%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local", c.Type, c.Username, c.Password, c.Host, c.Port, c.Name) + } + return "" +} + +func (c *MongoConfig) GetUri() string { + if c.Host == "" { + return "" + } + if c.Username == "" { + return fmt.Sprintf("mongodb://%s:%d/%s", c.Host, c.Port, c.Name) + } else { + return fmt.Sprintf("mongodb://%s:%s@%s:%d/%s", c.Username, c.Password, c.Host, c.Port, c.Name) + } +} + +type RedisConfig struct { + Addr string `yaml:"addr"` + Username string `yaml:"user"` + Password string `yaml:"pass"` + Database int `yaml:"db"` +} diff --git a/gconfig/http.go b/gconfig/http.go new file mode 100644 index 0000000..c191196 --- /dev/null +++ b/gconfig/http.go @@ -0,0 +1,30 @@ +// +// http.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gconfig + +import "path/filepath" + +type HttpConfig struct { + Addr string `yaml:"addr"` + Env string `yaml:"env"` + Debug bool `yaml:"debug"` + Storage string `yaml:"storage"` // 存储文件的目录。如果不是绝对路径,前面的 "./" 也不需要 +} + +func (c HttpConfig) GetBaseDir() string { + dir, _ := filepath.Abs("./") + return dir +} + +// 全路径的 storage dir +func (c HttpConfig) GetStorageDir() string { + if c.Storage[0] == '/' { + return c.Storage + } + return filepath.Join(c.GetBaseDir(), "/", c.Storage) +} diff --git a/gconfig/param.go b/gconfig/param.go new file mode 100644 index 0000000..499e8b5 --- /dev/null +++ b/gconfig/param.go @@ -0,0 +1,33 @@ +// +// param.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gconfig + +import ( + "io/ioutil" + + "gopkg.in/yaml.v2" + "hexq.cn/tiglog/golib/gfile" +) + +type ParamConfig struct { + Params map[string]any +} + +func (c *ParamConfig) Load(fp string) { + if !gfile.Exists(fp) { + panic("配置文件 " + fp + " 不存在") + } + dat, err := ioutil.ReadFile(fp) + if err != nil { + panic(err) + } + err = yaml.Unmarshal(dat, &c.Params) + if err != nil { + panic(err) + } +} diff --git a/gconfig/testdata/app.yaml b/gconfig/testdata/app.yaml new file mode 100644 index 0000000..e69de29 diff --git a/gconfig/testdata/params.yaml b/gconfig/testdata/params.yaml new file mode 100644 index 0000000..e69de29 diff --git a/gconsts/err_code.go b/gconsts/err_code.go new file mode 100644 index 0000000..3c324f2 --- /dev/null +++ b/gconsts/err_code.go @@ -0,0 +1,24 @@ +// +// err_code.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gconsts + +const ( + ErrCodeNone = 0 // 正确的结果 + ErrCodeBadRequest = 40000 // 无效的请求 + ErrCodeValidateFail = 40001 // 验证失败 + ErrCodeNoLogin = 40100 // 没有认证 + ErrCodeNoToken = 40101 // 没有带 token + ErrCodeExpiredToken = 40102 // token 过期 + ErrCodeInvalidToken = 40103 // 1️无效的 token + ErrCodeNoPermission = 40300 // 没有权限 + ErrCodePageNotFound = 40400 // 页面不存在 + ErrCodeEntryNotFound = 40401 // 对象不存在 + ErrCodeResNotFound = 40402 // 资源不存在 + ErrCodeInternal = 50000 // 服务内部错误 + +) diff --git a/gdb/mgodb/bson.go b/gdb/mgodb/bson.go new file mode 100644 index 0000000..5c9e206 --- /dev/null +++ b/gdb/mgodb/bson.go @@ -0,0 +1,13 @@ +// +// bson.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package mgodb + +import "go.mongodb.org/mongo-driver/bson" + +type M = bson.M +type D = bson.D diff --git a/gdb/mgodb/error.go b/gdb/mgodb/error.go new file mode 100644 index 0000000..761b5b0 --- /dev/null +++ b/gdb/mgodb/error.go @@ -0,0 +1,33 @@ +// +// error.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package mgodb + +import "go.mongodb.org/mongo-driver/mongo" + +var ( + // ErrNilDocument is returned when a nil document is passed to a CRUD method. + ErrNilDocument = mongo.ErrNilDocument + + // ErrNilValue is returned when a nil value is passed to a CRUD method. + ErrNilValue = mongo.ErrNilValue + + // ErrNoDocuments is returned by SingleResult methods when the operation that + // created the SingleResult did not return any documents. + ErrNoDocuments = mongo.ErrNoDocuments +) + +func IsNoDocuments(err error) bool { + return err == mongo.ErrNoDocuments +} + +func IsNilValue(err error) bool { + return err == mongo.ErrNilValue +} +func IsNilDocument(err error) bool { + return err == mongo.ErrNilDocument +} diff --git a/gdb/mgodb/mongo.go b/gdb/mgodb/mongo.go new file mode 100644 index 0000000..1d345b7 --- /dev/null +++ b/gdb/mgodb/mongo.go @@ -0,0 +1,99 @@ +// +// mongo.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package mgodb + +import ( + "context" + "time" + + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" +) + +var Mm *sMongoManager + +type sMongoManager struct { + cli *mongo.Client + name string + poolSize int + uri string +} + +func Init(uri, dbName string, poolSize int) { + Mm = &sMongoManager{ + name: dbName, + uri: uri, + poolSize: poolSize, + } +} + +func (s *sMongoManager) SetDb(name string) { + s.name = name +} + +func (s *sMongoManager) Connect() error { + var err error + clientOptions := options.Client().ApplyURI(s.uri) + clientOptions.SetMaxPoolSize(uint64(s.poolSize)) + + // 连接到MongoDB + s.cli, err = mongo.Connect(context.TODO(), clientOptions) + if err != nil { + return err + } + // 检查连接 + err = s.cli.Ping(context.TODO(), nil) + if err != nil { + return err + } + return nil +} +func (s *sMongoManager) Client() (*mongo.Client, error) { + if s.cli == nil { + err := s.Connect() + if err != nil { + return nil, err + } + } + return s.cli, nil +} + +func (s *sMongoManager) Db() *mongo.Database { + c, err := s.Client() + if err != nil { + panic(err) + } + return c.Database(s.name) +} + +func (s *sMongoManager) Collection(name string) *mongo.Collection { + c, err := s.Client() + if err != nil { + panic(err) + } + return c.Database(s.name).Collection(name) +} + +func NewCtx() context.Context { + return context.Background() +} + +func NewTtlCtx(ttl time.Duration) (context.Context, context.CancelFunc) { + return context.WithTimeout(context.Background(), ttl) +} + +func (s *sMongoManager) Close() error { + if s.cli == nil { + return nil + } + err := s.cli.Disconnect(NewCtx()) + if err != nil { + return err + } + return nil +} diff --git a/gdb/mgodb/mongo_test.go b/gdb/mgodb/mongo_test.go new file mode 100644 index 0000000..e2e85c7 --- /dev/null +++ b/gdb/mgodb/mongo_test.go @@ -0,0 +1,38 @@ +// +// mongo_test.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package mgodb_test + +import ( + "os" + "testing" + + "go.mongodb.org/mongo-driver/bson" + "hexq.cn/tiglog/golib/gdb/mgodb" + "hexq.cn/tiglog/golib/gtest" +) + +func initMM() { + url := os.Getenv("MONGO_URL") + mgodb.Init(url, "test", 8) +} + +func TestConnect(t *testing.T) { + initMM() + err := mgodb.Mm.Connect() + gtest.NoError(t, err) +} + +func TestListDatabases(t *testing.T) { + initMM() + cli, err := mgodb.Mm.Client() + gtest.NoError(t, err) + gtest.NotNil(t, cli) + names, err := cli.ListDatabaseNames(nil, bson.D{}) + gtest.NoError(t, err) + gtest.Greater(t, 0, names) +} diff --git a/gdb/sqldb/base_test.go b/gdb/sqldb/base_test.go new file mode 100644 index 0000000..4f0675b --- /dev/null +++ b/gdb/sqldb/base_test.go @@ -0,0 +1,180 @@ +// +// base_test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb_test + +import ( + "database/sql" + "fmt" + "os" + "strings" + "testing" + + _ "github.com/lib/pq" + "hexq.cn/tiglog/golib/gdb/sqldb" + // _ "github.com/go-sql-driver/mysql" +) + +type Schema struct { + create string + drop string +} + +var defaultSchema = Schema{ + create: ` +CREATE TABLE person ( + id serial, + first_name text, + last_name text, + email text, + added_at int default 0, + PRIMARY KEY (id) +); + +CREATE TABLE place ( + country text, + city text NULL, + telcode integer +); + +CREATE TABLE capplace ( + country text, + city text NULL, + telcode integer +); + +CREATE TABLE nullperson ( + first_name text NULL, + last_name text NULL, + email text NULL +); + +CREATE TABLE employees ( + name text, + id integer, + boss_id integer +); + +`, + drop: ` +drop table person; +drop table place; +drop table capplace; +drop table nullperson; +drop table employees; +`, +} + +type Person struct { + Id int64 `db:"id"` + FirstName string `db:"first_name"` + LastName string `db:"last_name"` + Email string `db:"email"` + AddedAt int64 `db:"added_at"` +} + +type Person2 struct { + FirstName sql.NullString `db:"first_name"` + LastName sql.NullString `db:"last_name"` + Email sql.NullString +} + +type Place struct { + Country string + City sql.NullString + TelCode int +} + +type PlacePtr struct { + Country string + City *string + TelCode int +} + +type PersonPlace struct { + Person + Place +} + +type PersonPlacePtr struct { + *Person + *Place +} + +type EmbedConflict struct { + FirstName string `db:"first_name"` + Person +} + +type SliceMember struct { + Country string + City sql.NullString + TelCode int + People []Person `db:"-"` + Addresses []Place `db:"-"` +} + +func loadDefaultFixture(db *sqldb.Engine, t *testing.T) { + tx := db.MustBegin() + + s1 := "INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)" + tx.MustExec(db.Rebind(s1), "Jason", "Moiron", "jmoiron@jmoiron.net") + + s1 = "INSERT INTO person (first_name, last_name, email) VALUES (?, ?, ?)" + tx.MustExec(db.Rebind(s1), "John", "Doe", "johndoeDNE@gmail.net") + + s1 = "INSERT INTO place (country, city, telcode) VALUES (?, ?, ?)" + tx.MustExec(db.Rebind(s1), "United States", "New York", "1") + + s1 = "INSERT INTO place (country, telcode) VALUES (?, ?)" + tx.MustExec(db.Rebind(s1), "Hong Kong", "852") + + s1 = "INSERT INTO place (country, telcode) VALUES (?, ?)" + tx.MustExec(db.Rebind(s1), "Singapore", "65") + + s1 = "INSERT INTO capplace (country, telcode) VALUES (?, ?)" + tx.MustExec(db.Rebind(s1), "Sarf Efrica", "27") + + s1 = "INSERT INTO employees (name, id) VALUES (?, ?)" + tx.MustExec(db.Rebind(s1), "Peter", "4444") + + s1 = "INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)" + tx.MustExec(db.Rebind(s1), "Joe", "1", "4444") + + s1 = "INSERT INTO employees (name, id, boss_id) VALUES (?, ?, ?)" + tx.MustExec(db.Rebind(s1), "Martin", "2", "4444") + tx.Commit() +} + +func MultiExec(e *sqldb.Engine, query string) { + stmts := strings.Split(query, ";\n") + if len(strings.Trim(stmts[len(stmts)-1], " \n\t\r")) == 0 { + stmts = stmts[:len(stmts)-1] + } + for _, s := range stmts { + _, err := e.Exec(s) + if err != nil { + fmt.Println(err, s) + } + } +} + +func RunDbTest(t *testing.T, test func(db *sqldb.Engine, t *testing.T)) { + // 先初始化数据库 + url := os.Getenv("DB_URL") + var db = sqldb.New(url) + + // 再注册清空数据库 + defer func() { + MultiExec(db, defaultSchema.drop) + }() + // 再加入一些数据 + MultiExec(db, defaultSchema.create) + loadDefaultFixture(db, t) + // 最后测试 + test(db, t) +} diff --git a/gdb/sqldb/db.go b/gdb/sqldb/db.go new file mode 100644 index 0000000..cec2764 --- /dev/null +++ b/gdb/sqldb/db.go @@ -0,0 +1,58 @@ +// +// db.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "database/sql" + "errors" + "strings" + + "github.com/jmoiron/sqlx" +) + +var Db *Engine + +type Engine struct { + *sqlx.DB +} + +var ErrNoRows = sql.ErrNoRows + +type DbOption struct { + Url string + MaxOpenConns int + MaxIdleConns int +} + +// mysql://[username[:password]@][protocol[(address)]]/dbname[?param1=value1&...¶mN=valueN] +// pgsql://host=X.X.X.X port=54321 user=postgres password=admin123 dbname=postgres sslmode=disable" +func NewWithOption(opt *DbOption) *Engine { + urls := strings.Split(opt.Url, "://") + if len(urls) != 2 { + panic(errors.New("wrong database url:" + opt.Url)) + } + dbx, err := sqlx.Open(urls[0], urls[1]) + if err != nil { + panic(err) + } + dbx.SetMaxIdleConns(opt.MaxIdleConns) + dbx.SetMaxOpenConns(opt.MaxOpenConns) + err = dbx.Ping() + if err != nil { + panic(err) + } + Db = &Engine{ + dbx, + } + return Db +} + +func New(url string) *Engine { + opt := &DbOption{Url: url, MaxOpenConns: 256, MaxIdleConns: 2} + return NewWithOption(opt) +} diff --git a/gdb/sqldb/db_func.go b/gdb/sqldb/db_func.go new file mode 100644 index 0000000..c77153f --- /dev/null +++ b/gdb/sqldb/db_func.go @@ -0,0 +1,220 @@ +// +// db_func.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "errors" + "fmt" + "strconv" + "strings" + + "github.com/jmoiron/sqlx" +) + +func (e *Engine) Begin() (*sqlx.Tx, error) { + return e.Beginx() +} + +// 插入一条记录 +func (e *Engine) NamedInsertRecord(opt *QueryOption, arg interface{}) (int64, error) { // {{{ + if len(opt.fields) == 0 { + return 0, errors.New("empty fields") + } + var tmp = make([]string, 0) + for _, field := range opt.fields { + tmp = append(tmp, fmt.Sprintf(":%s", field)) + } + fields_str := strings.Join(opt.fields, ",") + fields_pl := strings.Join(tmp, ",") + sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s)", opt.table, fields_str, fields_pl) + if e.DriverName() == "postgres" { + sql += " returning id" + } + // sql = e.Rebind(sql) + stmt, err := e.PrepareNamed(sql) + if err != nil { + return 0, err + } + var id int64 + err = stmt.Get(&id, arg) + if err != nil { + return 0, err + } + return id, err +} // }}} + +// 插入一条记录 +func (e *Engine) InsertRecord(opt *QueryOption) (int64, error) { // {{{ + if len(opt.fields) == 0 { + return 0, errors.New("empty fields") + } + fields_str := strings.Join(opt.fields, ",") + fields_pl := strings.TrimRight(strings.Repeat("?,", len(opt.fields)), ",") + sql := fmt.Sprintf("INSERT INTO %s (%s) VALUES (%s);", opt.table, fields_str, fields_pl) + if e.DriverName() == "postgres" { + sql += " returning id" + } + sql = e.Rebind(sql) + result, err := e.Exec(sql, opt.args...) + if err != nil { + return 0, err + } + return result.LastInsertId() +} // }}} + +// 查询一条记录 +// dest 目标对象 +// table 查询表 +// query 查询条件 +// args bindvars +func (e *Engine) GetRecord(dest interface{}, opt *QueryOption) error { // {{{ + if opt.query == "" { + return errors.New("empty query") + } + opt.query = "WHERE " + opt.query + sql := fmt.Sprintf("SELECT * FROM %s %s limit 1", opt.table, opt.query) + sql = e.Rebind(sql) + err := e.Get(dest, sql, opt.args...) + if err != nil { + return err + } + return nil +} // }}} + +// 查询多条记录 +// dest 目标变量 +// opt 查询对象 +// args bindvars +func (e *Engine) GetRecords(dest interface{}, opt *QueryOption) error { // {{{ + var tmp = []string{} + if opt.query != "" { + tmp = append(tmp, "where", opt.query) + } + if opt.sort != "" { + tmp = append(tmp, "order by", opt.sort) + } + if opt.offset > 0 { + tmp = append(tmp, "offset", strconv.Itoa(opt.offset)) + } + if opt.limit > 0 { + tmp = append(tmp, "limit", strconv.Itoa(opt.limit)) + } + sql := fmt.Sprintf("select * from %s %s", opt.table, strings.Join(tmp, " ")) + sql = e.Rebind(sql) + return e.Select(dest, sql, opt.args...) +} // }}} + +// 更新一条记录 +// table 待处理的表 +// set 需要设置的语句, eg: age=:age +// query 查询语句,不能为空,确保误更新所有记录 +// arg 值 +func (e *Engine) NamedUpdateRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{ + if opt.set == "" || opt.query == "" { + return 0, errors.New("empty set or query") + } + sql := fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query) + result, err := e.NamedExec(sql, arg) + if err != nil { + return 0, err + } + rows, err := result.RowsAffected() + if err != nil { + return 0, err + } + return rows, nil +} // }}} + +func (e *Engine) UpdateRecords(opt *QueryOption) (int64, error) { // {{{ + if opt.set == "" || opt.query == "" { + return 0, errors.New("empty set or query") + } + sql := fmt.Sprintf("update %s set %s where %s", opt.table, opt.set, opt.query) + sql = e.Rebind(sql) + result, err := e.Exec(sql, opt.args...) + if err != nil { + return 0, err + } + rows, err := result.RowsAffected() + if err != nil { + return 0, err + } + return rows, nil +} // }}} + +// 删除若干条记录 +// opt 的 query 不能为空 +// arg bindvars +func (e *Engine) NamedDeleteRecords(opt *QueryOption, arg interface{}) (int64, error) { // {{{ + if opt.query == "" { + return 0, errors.New("emtpy query") + } + sql := fmt.Sprintf("delete from %s where %s", opt.table, opt.query) + result, err := e.NamedExec(sql, arg) + if err != nil { + return 0, err + } + rows, err := result.RowsAffected() + if err != nil { + return 0, err + } + return rows, nil +} // }}} + +func (e *Engine) DeleteRecords(opt *QueryOption) (int64, error) { + if opt.query == "" { + return 0, errors.New("emtpy query") + } + sql := fmt.Sprintf("delete from %s where %s", opt.table, opt.query) + sql = e.Rebind(sql) + result, err := e.Exec(sql, opt.args...) + if err != nil { + return 0, err + } + rows, err := result.RowsAffected() + if err != nil { + return 0, err + } + return rows, nil +} + +func (e *Engine) CountRecords(opt *QueryOption) (int, error) { + sql := fmt.Sprintf("select count(*) from %s where %s", opt.table, opt.query) + sql = e.Rebind(sql) + var num int + err := e.Get(&num, sql, opt.args...) + if err != nil { + return 0, err + } + return num, nil +} + +// var levels = []int{4, 6, 7} +// query, args, err := sqlx.In("SELECT * FROM users WHERE level IN (?);", levels) +// sqlx.In returns queries with the `?` bindvar, we can rebind it for our backend +// query = db.Rebind(query) +// rows, err := db.Query(query, args...) +func (e *Engine) In(query string, args ...interface{}) (string, []interface{}, error) { + return sqlx.In(query, args...) +} + +func IsNoRows(err error) bool { + return err == ErrNoRows +} + +// 把 fields 转换为 field1=:field1, field2=:field2, ..., fieldN=:fieldN +func GetSetString(fields []string) string { + items := []string{} + for _, field := range fields { + if field == "id" { + continue + } + items = append(items, fmt.Sprintf("%s=:%s", field, field)) + } + return strings.Join(items, ",") +} diff --git a/gdb/sqldb/db_func_opt.go b/gdb/sqldb/db_func_opt.go new file mode 100644 index 0000000..1d596c5 --- /dev/null +++ b/gdb/sqldb/db_func_opt.go @@ -0,0 +1,75 @@ +// +// db_func_opt.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +type QueryOption struct { + table string + query string + set string + fields []string + sort string + offset int + limit int + args []any + joins []string +} + +func NewQueryOption(table string) *QueryOption { + return &QueryOption{ + table: table, + fields: []string{"*"}, + offset: 0, + limit: 0, + args: make([]any, 0), + joins: make([]string, 0), + } +} +func (opt *QueryOption) Query(query string) *QueryOption { + opt.query = query + return opt +} +func (opt *QueryOption) Fields(args []string) *QueryOption { + opt.fields = args + return opt +} +func (opt *QueryOption) Select(cols ...string) *QueryOption { + opt.fields = cols + return opt +} +func (opt *QueryOption) Offset(offset int) *QueryOption { + opt.offset = offset + return opt +} +func (opt *QueryOption) Limit(limit int) *QueryOption { + opt.limit = limit + return opt +} +func (opt *QueryOption) Sort(sort string) *QueryOption { + opt.sort = sort + return opt +} +func (opt *QueryOption) Set(set string) *QueryOption { + opt.set = set + return opt +} +func (opt *QueryOption) Args(args ...any) *QueryOption { + opt.args = args + return opt +} +func (opt *QueryOption) Join(table string, cond string) *QueryOption { + opt.joins = append(opt.joins, "join "+table+" on "+cond) + return opt +} +func (opt *QueryOption) LeftJoin(table string, cond string) *QueryOption { + opt.joins = append(opt.joins, "left join "+table+" on "+cond) + return opt +} +func (opt *QueryOption) RightJoin(table string, cond string) *QueryOption { + opt.joins = append(opt.joins, "right join "+table+" on "+cond) + return opt +} diff --git a/gdb/sqldb/db_func_test.go b/gdb/sqldb/db_func_test.go new file mode 100644 index 0000000..a2e027f --- /dev/null +++ b/gdb/sqldb/db_func_test.go @@ -0,0 +1,114 @@ +// +// db_func_test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb_test + +import ( + "testing" + "time" + + "hexq.cn/tiglog/golib/gdb/sqldb" + "hexq.cn/tiglog/golib/gtest" +) + +// 经过测试,发现数据库里面使用 time 类型容易出现 timezone 不一致的情况 +// 在存入数据库时,可能会导致时区丢失 +// 因此,为了更好的兼容性,使用 int 时间戳会更合适 +func dbFuncTest(db *sqldb.Engine, t *testing.T) { + var err error + fields := []string{"first_name", "last_name", "email"} + p := &Person{ + FirstName: "三", + LastName: "张", + Email: "zs@foo.com", + } + // InsertRecord 的用法 + opt := sqldb.NewQueryOption("person").Fields(fields) + rows, err := db.NamedInsertRecord(opt, p) + gtest.Nil(t, err) + gtest.True(t, rows > 0) + // fmt.Println(rows) + + // GetRecord 的用法 + var p3 Person + opt = sqldb.NewQueryOption("person").Query("email=?").Args("zs@foo.com") + err = db.GetRecord(&p3, opt) + // fmt.Println(p3) + gtest.Equal(t, "张", p3.LastName) + gtest.Equal(t, "三", p3.FirstName) + gtest.Equal(t, int64(0), p3.AddedAt) + gtest.Nil(t, err) + + p2 := &Person{ + FirstName: "四", + LastName: "李", + Email: "ls@foo.com", + AddedAt: time.Now().Unix(), + } + fields2 := append(fields, "added_at") + opt = sqldb.NewQueryOption("person").Fields(fields2) + _, err = db.NamedInsertRecord(opt, p2) + gtest.Nil(t, err) + + var p4 Person + opt = sqldb.NewQueryOption("person") + err = db.GetRecord(&p4, opt) + gtest.NotNil(t, err) + gtest.Equal(t, "", p4.FirstName) + + opt = sqldb.NewQueryOption("person").Query("first_name=?").Args("四") + err = db.GetRecord(&p4, opt) + gtest.Nil(t, err) + gtest.Equal(t, time.Now().Unix(), p4.AddedAt) + gtest.Equal(t, "ls@foo.com", p4.Email) + + // GetRecords + var ps []Person + opt = sqldb.NewQueryOption("person").Query("id > ?").Args(0) + err = db.GetRecords(&ps, opt) + gtest.Nil(t, err) + gtest.Greater(t, int64(1), ps) + + var ps2 []Person + opt = sqldb.NewQueryOption("person").Query("id=?").Args(1) + err = db.GetRecords(&ps2, opt) + gtest.Equal(t, 1, len(ps2)) + if len(ps2) > 1 { + gtest.Equal(t, int64(1), ps2[0].Id) + } + + // DeleteRecords + opt = sqldb.NewQueryOption("person").Query("id=?").Args(2) + n, err := db.DeleteRecords(opt) + gtest.Nil(t, err) + gtest.Greater(t, int64(0), n) + + // UpdateRecords + opt = sqldb.NewQueryOption("person").Set("first_name=?").Query("email=?").Args("哈哈", "zs@foo.com") + n, err = db.UpdateRecords(opt) + gtest.Nil(t, err) + gtest.Greater(t, int64(0), n) + + // NamedUpdateRecords + var p5 = ps[0] + p5.FirstName = "中华人民共和国" + opt = sqldb.NewQueryOption("person").Set("first_name=:first_name").Query("email=:email") + n, err = db.NamedUpdateRecords(opt, p5) + gtest.Nil(t, err) + gtest.Greater(t, int64(0), n) + + var p6 Person + opt = sqldb.NewQueryOption("person").Query("first_name=?").Args(p5.FirstName) + err = db.GetRecord(&p6, opt) + gtest.Nil(t, err) + gtest.Greater(t, int64(0), p6.Id) + gtest.Equal(t, p6.FirstName, p5.FirstName) +} + +func TestFunc(t *testing.T) { + RunDbTest(t, dbFuncTest) +} diff --git a/gdb/sqldb/db_model.go b/gdb/sqldb/db_model.go new file mode 100644 index 0000000..7a68fd3 --- /dev/null +++ b/gdb/sqldb/db_model.go @@ -0,0 +1,20 @@ +// +// db_model.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +// TODO 暂时不好实现,以后再说 + +type Model struct { + db *Engine +} + +func NewModel() *Model { + return &Model{ + db: Db, + } +} diff --git a/gdb/sqldb/db_query.go b/gdb/sqldb/db_query.go new file mode 100644 index 0000000..5acda3c --- /dev/null +++ b/gdb/sqldb/db_query.go @@ -0,0 +1,322 @@ +// +// db_query.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb + +import ( + "errors" + "fmt" + "strconv" + "strings" + + "github.com/jmoiron/sqlx" +) + +type Query struct { + db *Engine + table string + fields []string + wheres []string // 不能太复杂 + joins []string + orderBy string + groupBy string + offset int + limit int +} + +func NewQueryBuild(table string, db *Engine) *Query { + return &Query{ + db: db, + table: table, + fields: []string{}, + wheres: []string{}, + joins: []string{}, + offset: 0, + limit: 0, + } +} + +func (q *Query) Table(table string) *Query { + q.table = table + return q +} + +// 设置 select fields +func (q *Query) Select(fields ...string) *Query { + q.fields = fields + return q +} + +// 增加一个 select field +func (q *Query) AddFields(fields ...string) *Query { + q.fields = append(q.fields, fields...) + return q +} + +func (q *Query) Where(query string) *Query { + q.wheres = []string{query} + return q +} +func (q *Query) AndWhere(query string) *Query { + q.wheres = append(q.wheres, "and "+query) + return q +} + +func (q *Query) OrWhere(query string) *Query { + q.wheres = append(q.wheres, "or "+query) + return q +} + +func (q *Query) Join(table string, on string) *Query { + var join = "join " + table + if on != "" { + join = join + " on " + on + } + q.joins = append(q.joins, join) + return q +} + +func (q *Query) LeftJoin(table string, on string) *Query { + var join = "left join " + table + if on != "" { + join = join + " on " + on + } + q.joins = append(q.joins, join) + return q +} + +func (q *Query) RightJoin(table string, on string) *Query { + var join = "right join " + table + if on != "" { + join = join + " on " + on + } + q.joins = append(q.joins, join) + return q +} + +func (q *Query) InnerJoin(table string, on string) *Query { + var join = "inner join " + table + if on != "" { + join = join + " on " + on + } + q.joins = append(q.joins, join) + return q +} + +func (q *Query) OrderBy(order string) *Query { + q.orderBy = order + return q +} +func (q *Query) GroupBy(group string) *Query { + q.groupBy = group + return q +} + +func (q *Query) Offset(offset int) *Query { + q.offset = offset + return q +} + +func (q *Query) Limit(limit int) *Query { + q.limit = limit + return q +} + +// returningId postgres 数据库返回 LastInsertId 处理 +// TODO returningId 暂时不处理 +func (q *Query) getInsertSql(named, returningId bool) string { + fields_str := strings.Join(q.fields, ",") + var pl string + if named { + var tmp []string + for _, field := range q.fields { + tmp = append(tmp, ":"+field) + } + pl = strings.Join(tmp, ",") + } else { + pl = strings.Repeat("?,", len(q.fields)) + pl = strings.TrimRight(pl, ",") + } + + sql := fmt.Sprintf("insert into %s (%s) values (%s);", q.table, fields_str, pl) + sql = q.db.Rebind(sql) + // fmt.Println(sql) + return sql +} + +// return RowsAffected, error +func (q *Query) Insert(args ...interface{}) (int64, error) { + if len(q.fields) == 0 { + return 0, errors.New("empty fields") + } + sql := q.getInsertSql(false, false) + result, err := q.db.Exec(sql, args...) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +// return RowsAffected, error +func (q *Query) NamedInsert(arg interface{}) (int64, error) { + if len(q.fields) == 0 { + return 0, errors.New("empty fields") + } + sql := q.getInsertSql(true, false) + result, err := q.db.NamedExec(sql, arg) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +func (q *Query) getQuerySql() string { + var ( + fields_str string = "*" + join_str string + where_str string + offlim string + ) + if len(q.fields) > 0 { + fields_str = strings.Join(q.fields, ",") + } + + if len(q.joins) > 0 { + join_str = strings.Join(q.joins, " ") + } + if len(q.wheres) > 0 { + where_str = "where " + strings.Join(q.wheres, " ") + } + + if q.offset > 0 { + offlim = " offset " + strconv.Itoa(q.offset) + } + if q.limit > 0 { + offlim = " limit " + strconv.Itoa(q.limit) + } + // select fields from table t join where groupby orderby offset limit + sql := fmt.Sprintf("select %s from %s t %s %s %s %s%s", fields_str, q.table, join_str, where_str, q.groupBy, q.orderBy, offlim) + return sql +} + +func (q *Query) One(dest interface{}, args ...interface{}) error { + q.Limit(1) + sql := q.getQuerySql() + sql = q.db.Rebind(sql) + return q.db.Get(dest, sql, args...) +} + +func (q *Query) NamedOne(dest interface{}, arg interface{}) error { + q.Limit(1) + sql := q.getQuerySql() + rows, err := q.db.NamedQuery(sql, arg) + if err != nil { + return err + } + if rows.Next() { + return rows.Scan(dest) + } + return errors.New("nr") // no record +} + +func (q *Query) All(dest interface{}, args ...interface{}) error { + sql := q.getQuerySql() + sql = q.db.Rebind(sql) + return q.db.Select(dest, sql, args...) +} + +// 为了省内存,直接返回迭代器 +func (q *Query) NamedAll(dest interface{}, arg interface{}) (*sqlx.Rows, error) { + sql := q.getQuerySql() + return q.db.NamedQuery(sql, arg) +} + +// set age=? / age=:age +func (q *Query) NamedUpdate(set string, arg interface{}) (int64, error) { + var where_str string + if len(q.wheres) > 0 { + where_str = strings.Join(q.wheres, " ") + } + if set == "" || where_str == "" { + return 0, errors.New("empty set or where") + } + + // update table t where + sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str) + sql = q.db.Rebind(sql) + result, err := q.db.NamedExec(sql, arg) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +// 顺序容易弄反,记得先是 set 的参数,再是 where 里面的参数 +func (q *Query) Update(set string, args ...interface{}) (int64, error) { + var where_str string + if len(q.wheres) > 0 { + where_str = strings.Join(q.wheres, " ") + } + if set == "" || where_str == "" { + return 0, errors.New("empty set or where") + } + + // update table t where + sql := fmt.Sprintf("update %s t set %s where %s", q.table, set, where_str) + sql = q.db.Rebind(sql) + result, err := q.db.Exec(sql, args...) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +// 普通的删除 +func (q *Query) Delete(args ...interface{}) (int64, error) { + var where_str string + if len(q.wheres) == 0 { + return 0, errors.New("missing where clause") + } + where_str = strings.Join(q.wheres, " ") + + sql := fmt.Sprintf("delete from %s where %s", q.table, where_str) + sql = q.db.Rebind(sql) + result, err := q.db.Exec(sql, args...) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +func (q *Query) NamedDelete(arg interface{}) (int64, error) { + if len(q.wheres) == 0 { + return 0, errors.New("missing where clause") + } + var where_str string + where_str = strings.Join(q.wheres, " ") + + sql := fmt.Sprintf("delete from %s where %s", q.table, where_str) + sql = q.db.Rebind(sql) + + result, err := q.db.NamedExec(sql, arg) + if err != nil { + return 0, err + } + return result.RowsAffected() +} + +func (q *Query) Count(args ...interface{}) (int64, error) { + var where_str string + if len(q.wheres) > 0 { + where_str = " where " + strings.Join(q.wheres, " ") + } + sql := fmt.Sprintf("select count(1) as num from %s t%s", q.table, where_str) + sql = q.db.Rebind(sql) + var num int64 + err := q.db.Get(&num, sql, args...) + return num, err +} diff --git a/gdb/sqldb/db_query_test.go b/gdb/sqldb/db_query_test.go new file mode 100644 index 0000000..754c066 --- /dev/null +++ b/gdb/sqldb/db_query_test.go @@ -0,0 +1,109 @@ +// +// db_query_test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package sqldb_test + +import ( + "testing" + "time" + + "hexq.cn/tiglog/golib/gtest" + + "hexq.cn/tiglog/golib/gdb/sqldb" +) + +func dbQueryTest(db *sqldb.Engine, t *testing.T) { + query := sqldb.NewQueryBuild("person", db) + // query one + var p1 Person + query.Where("id=?") + err := query.One(&p1, 1) + gtest.Nil(t, err) + gtest.Equal(t, int64(1), p1.Id) + + // query all + var ps1 []Person + query = sqldb.NewQueryBuild("person", db) + query.Where("id > ?") + err = query.All(&ps1, 1) + gtest.Nil(t, err) + gtest.True(t, len(ps1) > 0) + // fmt.Println(ps1) + if len(ps1) > 0 { + var val int64 = 2 + gtest.Equal(t, val, ps1[0].Id) + } + + // insert + query = sqldb.NewQueryBuild("person", db) + query.AddFields("first_name", "last_name", "email") + id, err := query.Insert("三", "张", "zs@bar.com") + gtest.Nil(t, err) + gtest.Greater(t, int64(0), id) + // fmt.Println(id) + + // named insert + query = sqldb.NewQueryBuild("person", db) + query.AddFields("first_name", "last_name", "email") + row, err := query.NamedInsert(&Person{ + FirstName: "四", + LastName: "李", + Email: "ls@bar.com", + AddedAt: time.Now().Unix(), + }) + gtest.Nil(t, err) + gtest.Equal(t, int64(1), row) + + // update + query = sqldb.NewQueryBuild("person", db) + query.Where("email=?") + n, err := query.Update("first_name=?", "哈哈", "ls@bar.com") + gtest.Nil(t, err) + gtest.Equal(t, int64(1), n) + + // named update map + query = sqldb.NewQueryBuild("person", db) + query.Where("email=:email") + n, err = query.NamedUpdate("first_name=:first_name", map[string]interface{}{ + "email": "ls@bar.com", + "first_name": "中华人民共和国", + }) + gtest.Nil(t, err) + gtest.Equal(t, int64(1), n) + + // named update struct + query = sqldb.NewQueryBuild("person", db) + query.Where("email=:email") + var p = &Person{ + Email: "ls@bar.com", + LastName: "中华人民共和国,救民于水火", + } + n, err = query.NamedUpdate("last_name=:last_name", p) + gtest.Nil(t, err) + gtest.Equal(t, int64(1), n) + + // count + query = sqldb.NewQueryBuild("person", db) + n, err = query.Count() + gtest.Nil(t, err) + // fmt.Println(n) + gtest.Greater(t, int64(0), n) + + // delete + query = sqldb.NewQueryBuild("person", db) + n, err = query.Delete() + gtest.NotNil(t, err) + gtest.Equal(t, int64(0), n) + + n, err = query.Where("id=?").Delete(2) + gtest.Nil(t, err) + gtest.Equal(t, int64(1), n) +} + +func TestQuery(t *testing.T) { + RunDbTest(t, dbQueryTest) +} diff --git a/gfile/file_copy.go b/gfile/file_copy.go new file mode 100644 index 0000000..8d2fbbb --- /dev/null +++ b/gfile/file_copy.go @@ -0,0 +1,133 @@ +// +// file_copy.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gfile + +import ( + "errors" + "io" + "io/ioutil" + "os" + "path/filepath" +) + +// Copy file/directory from `src` to `dst`. +// +// If `src` is file, it calls CopyFile to implements copy feature, +// or else it calls CopyDir. +func Copy(src string, dst string) error { + if src == "" { + return errors.New("source path cannot be empty") + } + if dst == "" { + return errors.New("destination path cannot be empty") + } + if IsFile(src) { + return CopyFile(src, dst) + } + return CopyDir(src, dst) +} + +// CopyFile copies the contents of the file named `src` to the file named +// by `dst`. The file will be created if it does not exist. If the +// destination file exists, all it's contents will be replaced by the contents +// of the source file. The file mode will be copied from the source and +// the copied data is synced/flushed to stable storage. +// Thanks: https://gist.github.com/r0l1/92462b38df26839a3ca324697c8cba04 +func CopyFile(src, dst string) (err error) { + if src == "" { + return errors.New("source file cannot be empty") + } + if dst == "" { + return errors.New("destination file cannot be empty") + } + // If src and dst are the same path, it does nothing. + if src == dst { + return nil + } + in, err := os.Open(src) + if err != nil { + return + } + defer func() { + if e := in.Close(); e != nil { + err = e + } + }() + out, err := os.Create(dst) + if err != nil { + return + } + defer func() { + if e := out.Close(); e != nil { + err = e + } + }() + if _, err = io.Copy(out, in); err != nil { + return + } + if err = out.Sync(); err != nil { + return + } + err = os.Chmod(dst, DefaultPermCopy) + if err != nil { + return + } + return +} + +// CopyDir recursively copies a directory tree, attempting to preserve permissions. +// +// Note that, the Source directory must exist and symlinks are ignored and skipped. +func CopyDir(src string, dst string) (err error) { + if src == "" { + return errors.New("source directory cannot be empty") + } + if dst == "" { + return errors.New("destination directory cannot be empty") + } + // If src and dst are the same path, it does nothing. + if src == dst { + return nil + } + src = filepath.Clean(src) + dst = filepath.Clean(dst) + si, err := os.Stat(src) + if err != nil { + return err + } + if !si.IsDir() { + return errors.New("source is not a directory") + } + if !Exists(dst) { + if err = os.MkdirAll(dst, DefaultPermCopy); err != nil { + return + } + } + entries, err := ioutil.ReadDir(src) + if err != nil { + return + } + for _, entry := range entries { + srcPath := filepath.Join(src, entry.Name()) + dstPath := filepath.Join(dst, entry.Name()) + if entry.IsDir() { + if err = CopyDir(srcPath, dstPath); err != nil { + return + } + } else { + // Skip symlinks. + if entry.Mode()&os.ModeSymlink != 0 { + continue + } + if err = CopyFile(srcPath, dstPath); err != nil { + return + } + } + } + return +} diff --git a/gfile/file_home.go b/gfile/file_home.go new file mode 100644 index 0000000..6c681eb --- /dev/null +++ b/gfile/file_home.go @@ -0,0 +1,81 @@ +// +// file_home.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gfile + +import ( + "bytes" + "errors" + "os" + "os/exec" + "os/user" + "runtime" + "strings" +) + +// Home returns absolute path of current user's home directory. +// The optional parameter `names` specifies the sub-folders/sub-files, +// which will be joined with current system separator and returned with the path. +func Home(names ...string) (string, error) { + path, err := getHomePath() + if err != nil { + return "", err + } + for _, name := range names { + path += Separator + name + } + return path, nil +} + +// getHomePath returns absolute path of current user's home directory. +func getHomePath() (string, error) { + u, err := user.Current() + if nil == err { + return u.HomeDir, nil + } + if runtime.GOOS == "windows" { + return homeWindows() + } + return homeUnix() +} + +// homeUnix retrieves and returns the home on unix system. +func homeUnix() (string, error) { + if home := os.Getenv("HOME"); home != "" { + return home, nil + } + var stdout bytes.Buffer + cmd := exec.Command("sh", "-c", "eval echo ~$USER") + cmd.Stdout = &stdout + if err := cmd.Run(); err != nil { + return "", err + } + + result := strings.TrimSpace(stdout.String()) + if result == "" { + return "", errors.New("blank output when reading home directory") + } + + return result, nil +} + +// homeWindows retrieves and returns the home on windows system. +func homeWindows() (string, error) { + var ( + drive = os.Getenv("HOMEDRIVE") + path = os.Getenv("HOMEPATH") + home = drive + path + ) + if drive == "" || path == "" { + home = os.Getenv("USERPROFILE") + } + if home == "" { + return "", errors.New("environment keys HOMEDRIVE, HOMEPATH and USERPROFILE are empty") + } + + return home, nil +} diff --git a/gfile/file_path.go b/gfile/file_path.go new file mode 100644 index 0000000..e8fbbf9 --- /dev/null +++ b/gfile/file_path.go @@ -0,0 +1,177 @@ +// +// file_path.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gfile + +import ( + "os" + "path/filepath" + "strings" +) + +const ( + // Separator for file system. + // It here defines the separator as variable + // to allow it modified by developer if necessary. + Separator = string(filepath.Separator) + + // DefaultPermOpen is the default perm for file opening. + DefaultPermOpen = os.FileMode(0655) + + // DefaultPermCopy is the default perm for file/folder copy. + DefaultPermCopy = os.FileMode(0755) +) + +// Exists checks whether given `path` exist. +func Exists(path string) bool { + if stat, err := os.Stat(path); stat != nil && !os.IsNotExist(err) { + return true + } + return false +} + +// IsDir checks whether given `path` a directory. +// Note that it returns false if the `path` does not exist. +func IsDir(path string) bool { + s, err := os.Stat(path) + if err != nil { + return false + } + return s.IsDir() +} + +// Pwd returns absolute path of current working directory. +// Note that it returns an empty string if retrieving current +// working directory failed. +func Pwd() string { + path, err := os.Getwd() + if err != nil { + return "" + } + return path +} + +// Chdir changes the current working directory to the named directory. +// If there is an error, it will be of type *PathError. +func Chdir(dir string) (err error) { + err = os.Chdir(dir) + return +} + +// IsFile checks whether given `path` a file, which means it's not a directory. +// Note that it returns false if the `path` does not exist. +func IsFile(path string) bool { + s, err := os.Stat(path) + if err != nil { + return false + } + return !s.IsDir() +} + +// DirNames returns sub-file names of given directory `path`. +// Note that the returned names are NOT absolute paths. +func DirNames(path string) ([]string, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + list, err := f.Readdirnames(-1) + _ = f.Close() + if err != nil { + return nil, err + } + return list, nil +} + +// IsReadable checks whether given `path` is readable. +func IsReadable(path string) bool { + result := true + file, err := os.OpenFile(path, os.O_RDONLY, DefaultPermOpen) + if err != nil { + result = false + } + file.Close() + return result +} + +// Name returns the last element of path without file extension. +// Example: +// /var/www/file.js -> file +// file.js -> file +func Name(path string) string { + base := filepath.Base(path) + if i := strings.LastIndexByte(base, '.'); i != -1 { + return base[:i] + } + return base +} + +// Dir returns all but the last element of path, typically the path's directory. +// After dropping the final element, Dir calls Clean on the path and trailing +// slashes are removed. +// If the `path` is empty, Dir returns ".". +// If the `path` is ".", Dir treats the path as current working directory. +// If the `path` consists entirely of separators, Dir returns a single separator. +// The returned path does not end in a separator unless it is the root directory. +func Dir(path string) string { + if path == "." { + p, _ := filepath.Abs(path) + return filepath.Dir(p) + } + return filepath.Dir(path) +} + +// IsEmpty checks whether the given `path` is empty. +// If `path` is a folder, it checks if there's any file under it. +// If `path` is a file, it checks if the file size is zero. +// +// Note that it returns true if `path` does not exist. +func IsEmpty(path string) bool { + stat, err := os.Stat(path) + if err != nil { + return true + } + if stat.IsDir() { + file, err := os.Open(path) + if err != nil { + return true + } + defer file.Close() + names, err := file.Readdirnames(-1) + if err != nil { + return true + } + return len(names) == 0 + } else { + return stat.Size() == 0 + } +} + +// Ext returns the file name extension used by path. +// The extension is the suffix beginning at the final dot +// in the final element of path; it is empty if there is +// no dot. +// Note: the result contains symbol '.'. +// Eg: +// main.go => .go +// api.json => .json +func Ext(path string) string { + ext := filepath.Ext(path) + if p := strings.IndexByte(ext, '?'); p != -1 { + ext = ext[0:p] + } + return ext +} + +// ExtName is like function Ext, which returns the file name extension used by path, +// but the result does not contain symbol '.'. +// Eg: +// main.go => go +// api.json => json +func ExtName(path string) string { + return strings.TrimLeft(Ext(path), ".") +} diff --git a/gfile/file_size.go b/gfile/file_size.go new file mode 100644 index 0000000..dc5bd45 --- /dev/null +++ b/gfile/file_size.go @@ -0,0 +1,132 @@ +// +// file_size.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gfile + +import ( + "fmt" + "os" + "strconv" + "strings" +) + +// Size returns the size of file specified by `path` in byte. +func Size(path string) int64 { + s, e := os.Stat(path) + if e != nil { + return 0 + } + return s.Size() +} + +// SizeFormat returns the size of file specified by `path` in format string. +func SizeFormat(path string) string { + return FormatSize(Size(path)) +} + +// ReadableSize formats size of file given by `path`, for more human readable. +func ReadableSize(path string) string { + return FormatSize(Size(path)) +} + +// StrToSize converts formatted size string to its size in bytes. +func StrToSize(sizeStr string) int64 { + i := 0 + for ; i < len(sizeStr); i++ { + if sizeStr[i] == '.' || (sizeStr[i] >= '0' && sizeStr[i] <= '9') { + continue + } else { + break + } + } + var ( + unit = sizeStr[i:] + number, _ = strconv.ParseFloat(sizeStr[:i], 64) + ) + if unit == "" { + return int64(number) + } + switch strings.ToLower(unit) { + case "b", "bytes": + return int64(number) + case "k", "kb", "ki", "kib", "kilobyte": + return int64(number * 1024) + case "m", "mb", "mi", "mib", "mebibyte": + return int64(number * 1024 * 1024) + case "g", "gb", "gi", "gib", "gigabyte": + return int64(number * 1024 * 1024 * 1024) + case "t", "tb", "ti", "tib", "terabyte": + return int64(number * 1024 * 1024 * 1024 * 1024) + case "p", "pb", "pi", "pib", "petabyte": + return int64(number * 1024 * 1024 * 1024 * 1024 * 1024) + case "e", "eb", "ei", "eib", "exabyte": + return int64(number * 1024 * 1024 * 1024 * 1024 * 1024 * 1024) + case "z", "zb", "zi", "zib", "zettabyte": + return int64(number * 1024 * 1024 * 1024 * 1024 * 1024 * 1024 * 1024) + case "y", "yb", "yi", "yib", "yottabyte": + return int64(number * 1024 * 1024 * 1024 * 1024 * 1024 * 1024 * 1024 * 1024) + case "bb", "brontobyte": + return int64(number * 1024 * 1024 * 1024 * 1024 * 1024 * 1024 * 1024 * 1024 * 1024) + } + return -1 +} + +// FormatSize formats size `raw` for more manually readable. +func FormatSize(raw int64) string { + var r float64 = float64(raw) + var t float64 = 1024 + var d float64 = 1 + if r < t { + return fmt.Sprintf("%.2fB", r/d) + } + d *= 1024 + t *= 1024 + if r < t { + return fmt.Sprintf("%.2fK", r/d) + } + d *= 1024 + t *= 1024 + if r < t { + return fmt.Sprintf("%.2fM", r/d) + } + d *= 1024 + t *= 1024 + if r < t { + return fmt.Sprintf("%.2fG", r/d) + } + d *= 1024 + t *= 1024 + if r < t { + return fmt.Sprintf("%.2fT", r/d) + } + d *= 1024 + t *= 1024 + if r < t { + return fmt.Sprintf("%.2fP", r/d) + } + d *= 1024 + t *= 1024 + if r < t { + return fmt.Sprintf("%.2fE", r/d) + } + d *= 1024 + t *= 1024 + if r < t { + return fmt.Sprintf("%.2fZ", r/d) + } + d *= 1024 + t *= 1024 + if r < t { + return fmt.Sprintf("%.2fY", r/d) + } + d *= 1024 + t *= 1024 + if r < t { + return fmt.Sprintf("%.2fBB", r/d) + } + return "TooLarge" +} diff --git a/gfile/file_time.go b/gfile/file_time.go new file mode 100644 index 0000000..6ad59b3 --- /dev/null +++ b/gfile/file_time.go @@ -0,0 +1,40 @@ +// +// file_time.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gfile + +import ( + "os" + "time" +) + +// MTime returns the modification time of file given by `path` in second. +func MTime(path string) time.Time { + s, e := os.Stat(path) + if e != nil { + return time.Time{} + } + return s.ModTime() +} + +// MTimestamp returns the modification time of file given by `path` in second. +func MTimestamp(path string) int64 { + mtime := MTime(path) + if mtime.IsZero() { + return -1 + } + return mtime.Unix() +} + +// MTimestampMilli returns the modification time of file given by `path` in millisecond. +func MTimestampMilli(path string) int64 { + mtime := MTime(path) + if mtime.IsZero() { + return -1 + } + return mtime.UnixNano() / 1000000 +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..99e8caf --- /dev/null +++ b/go.mod @@ -0,0 +1,68 @@ +module hexq.cn/tiglog/golib + +go 1.20 + +require ( + github.com/casbin/casbin/v2 v2.70.0 + github.com/gin-gonic/gin v1.9.1 + github.com/go-redis/redis/v8 v8.11.5 + github.com/go-sql-driver/mysql v1.7.1 + github.com/hibiken/asynq v0.24.1 + github.com/jmoiron/sqlx v1.3.5 + github.com/lib/pq v1.10.9 + github.com/mattn/go-runewidth v0.0.14 + github.com/natefinch/lumberjack v2.0.0+incompatible + github.com/rs/xid v1.5.0 + github.com/rs/zerolog v1.29.1 + go.mongodb.org/mongo-driver v1.11.7 + golang.org/x/crypto v0.10.0 + gopkg.in/yaml.v2 v2.4.0 +) + +require ( + github.com/BurntSushi/toml v1.3.2 // indirect + github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible // indirect + github.com/bytedance/sonic v1.9.1 // indirect + github.com/cespare/xxhash/v2 v2.2.0 // indirect + github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 // indirect + github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f // indirect + github.com/gabriel-vasile/mimetype v1.4.2 // indirect + github.com/gin-contrib/sse v0.1.0 // indirect + github.com/go-playground/locales v0.14.1 // indirect + github.com/go-playground/universal-translator v0.18.1 // indirect + github.com/go-playground/validator/v10 v10.14.0 // indirect + github.com/goccy/go-json v0.10.2 // indirect + github.com/golang/protobuf v1.5.2 // indirect + github.com/golang/snappy v0.0.1 // indirect + github.com/google/uuid v1.2.0 // indirect + github.com/json-iterator/go v1.1.12 // indirect + github.com/klauspost/compress v1.13.6 // indirect + github.com/klauspost/cpuid/v2 v2.2.4 // indirect + github.com/leodido/go-urn v1.2.4 // indirect + github.com/mattn/go-colorable v0.1.12 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect + github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd // indirect + github.com/modern-go/reflect2 v1.0.2 // indirect + github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe // indirect + github.com/pelletier/go-toml/v2 v2.0.8 // indirect + github.com/pkg/errors v0.9.1 // indirect + github.com/redis/go-redis/v9 v9.0.3 // indirect + github.com/rivo/uniseg v0.2.0 // indirect + github.com/robfig/cron/v3 v3.0.1 // indirect + github.com/spf13/cast v1.3.1 // indirect + github.com/twitchyliquid64/golang-asm v0.15.1 // indirect + github.com/ugorji/go/codec v1.2.11 // indirect + github.com/xdg-go/pbkdf2 v1.0.0 // indirect + github.com/xdg-go/scram v1.1.1 // indirect + github.com/xdg-go/stringprep v1.0.3 // indirect + github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d // indirect + golang.org/x/arch v0.3.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c // indirect + golang.org/x/sys v0.9.0 // indirect + golang.org/x/text v0.10.0 // indirect + golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 // indirect + google.golang.org/protobuf v1.30.0 // indirect + gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..f28bcdf --- /dev/null +++ b/go.sum @@ -0,0 +1,223 @@ +github.com/BurntSushi/toml v1.3.2 h1:o7IhLm0Msx3BaB+n3Ag7L8EVlByGnpq14C4YWiu/gL8= +github.com/BurntSushi/toml v1.3.2/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= +github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible h1:1G1pk05UrOh0NlF1oeaaix1x8XzrfjIDK47TY0Zehcw= +github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= +github.com/bsm/ginkgo/v2 v2.7.0 h1:ItPMPH90RbmZJt5GtkcNvIRuGEdwlBItdNVoyzaNQao= +github.com/bsm/ginkgo/v2 v2.7.0/go.mod h1:AiKlXPm7ItEHNc/2+OkrNG4E0ITzojb9/xWzvQ9XZ9w= +github.com/bsm/gomega v1.26.0 h1:LhQm+AFcgV2M0WyKroMASzAzCAJVpAxQXv4SaI9a69Y= +github.com/bsm/gomega v1.26.0/go.mod h1:JyEr/xRbxbtgWNi8tIEVPUYZ5Dzef52k01W3YH0H+O0= +github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= +github.com/bytedance/sonic v1.9.1 h1:6iJ6NqdoxCDr6mbY8h18oSO+cShGSMRGCEo7F2h0x8s= +github.com/bytedance/sonic v1.9.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= +github.com/casbin/casbin/v2 v2.70.0 h1:CuoWeWpMj6GsXf5K1npAKHEMb+9k9QE/Mo7cVZmSJ98= +github.com/casbin/casbin/v2 v2.70.0/go.mod h1:vByNa/Fchek0KZUgG5wEsl7iFsiviAYKRtgrQfcJqHg= +github.com/cespare/xxhash/v2 v2.2.0 h1:DC2CZ1Ep5Y4k3ZQ899DldepgrayRUGE6BBZ/cd9Cj44= +github.com/cespare/xxhash/v2 v2.2.0/go.mod h1:VGX0DQ3Q6kWi7AoAeZDth3/j3BFtOZR5XLFGgcrjCOs= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= +github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= +github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f h1:lO4WD4F/rVNCu3HqELle0jiPLLBs70cWOduZpkS1E78= +github.com/dgryski/go-rendezvous v0.0.0-20200823014737-9f7001d12a5f/go.mod h1:cuUVRXasLTGF7a8hSLbxyZXjz+1KgoB3wDUb6vlszIc= +github.com/fsnotify/fsnotify v1.4.9 h1:hsms1Qyu0jgnwNXIxa+/V/PDsU6CfLf6CNO8H7IWoS4= +github.com/gabriel-vasile/mimetype v1.4.2 h1:w5qFW6JKBz9Y393Y4q372O9A7cUSequkh1Q7OhCmWKU= +github.com/gabriel-vasile/mimetype v1.4.2/go.mod h1:zApsH/mKG4w07erKIaJPFiX0Tsq9BFQgN3qGY5GnNgA= +github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE= +github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= +github.com/gin-gonic/gin v1.9.1 h1:4idEAncQnU5cB7BeOkPtxjfCSye0AAm1R0RVIqJ+Jmg= +github.com/gin-gonic/gin v1.9.1/go.mod h1:hPrL7YrpYKXt5YId3A/Tnip5kqbEAP+KLuI3SUcPTeU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.14.0 h1:vgvQWe3XCz3gIeFDm/HnTIbj6UGmg/+t63MyGU2n5js= +github.com/go-playground/validator/v10 v10.14.0/go.mod h1:9iXMNT7sEkjXb0I+enO7QXmzG6QCsPWY4zveKFVRSyU= +github.com/go-redis/redis/v8 v8.11.5 h1:AcZZR7igkdvfVmQTPnu9WE37LRrO/YrBH5zWyjDC0oI= +github.com/go-redis/redis/v8 v8.11.5/go.mod h1:gREzHqY1hg6oD9ngVRbLStwAWKhA0FEgq8Jd4h5lpwo= +github.com/go-sql-driver/mysql v1.6.0/go.mod h1:DCzpHaOWr8IXmIStZouvnhqoel9Qv2LBy8hT2VhHyBg= +github.com/go-sql-driver/mysql v1.7.1 h1:lUIinVbN1DY0xBg0eMOzmmtGoHwWBbvnWubQUrtU8EI= +github.com/go-sql-driver/mysql v1.7.1/go.mod h1:OXbVy3sEdcQ2Doequ6Z5BW6fXNQTmx+9S1MCJN5yJMI= +github.com/goccy/go-json v0.10.2 h1:CrxCmQqYDkv1z7lO7Wbh2HN93uovUHgrECaO5ZrCXAU= +github.com/goccy/go-json v0.10.2/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/godbus/dbus/v5 v5.0.4/go.mod h1:xhWf0FNVPg57R7Z0UbKHbJfkEywrmjJnf7w5xrFpKfA= +github.com/golang/mock v1.4.4 h1:l75CXGRSwbaYNpl/Z2X1XIIAMSCquvXgpVZDhwEIJsc= +github.com/golang/mock v1.4.4/go.mod h1:l3mdAwkq5BuhzHwde/uurv3sEJeZMXNpwsxVWU71h+4= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/golang/protobuf v1.5.2 h1:ROPKBNFfQgOUMifHyP+KYbvpjbdoFNs+aK7DXlji0Tw= +github.com/golang/protobuf v1.5.2/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY= +github.com/golang/snappy v0.0.1 h1:Qgr9rKW7uDUkrbSmQeiDsGa8SjGyCOGtuasMWwvp2P4= +github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= +github.com/google/go-cmp v0.5.2/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ= +github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.2.0 h1:qJYtXnJRWmpe7m/3XlyhrsLrEURqHRM2kxzoxXqyUDs= +github.com/google/uuid v1.2.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/hibiken/asynq v0.24.1 h1:+5iIEAyA9K/lcSPvx3qoPtsKJeKI5u9aOIvUmSsazEw= +github.com/hibiken/asynq v0.24.1/go.mod h1:u5qVeSbrnfT+vtG5Mq8ZPzQu/BmCKMHvTGb91uy9Tts= +github.com/jmoiron/sqlx v1.3.5 h1:vFFPA71p1o5gAeqtEAwLU4dnX2napprKtHr7PYIcN3g= +github.com/jmoiron/sqlx v1.3.5/go.mod h1:nRVWtLre0KfCLJvgxzCsLVMogSvQ1zNJtpYr2Ccp0mQ= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/compress v1.13.6 h1:P76CopJELS0TiO2mebmnzgWaajssP/EszplttgQxcgc= +github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/klauspost/cpuid/v2 v2.2.4 h1:acbojRNwl3o09bUq+yDCtZFc1aiwaAAxtcn8YkZXnvk= +github.com/klauspost/cpuid/v2 v2.2.4/go.mod h1:RVVoqg1df56z8g3pUjL/3lE5UfnlrJX8tyFgg4nqhuY= +github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/leodido/go-urn v1.2.4 h1:XlAE/cm/ms7TE/VMVoduSpNBoyc2dOxHs5MZSwAN63Q= +github.com/leodido/go-urn v1.2.4/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/lib/pq v1.2.0/go.mod h1:5WUZQaWbwv1U+lTReE5YruASi9Al49XbQIvNi/34Woo= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= +github.com/mattn/go-colorable v0.1.12 h1:jF+Du6AlPIjs2BiUiQlKOX0rt3SujHxPnksPKZbaA40= +github.com/mattn/go-colorable v0.1.12/go.mod h1:u5H1YNBxpqRaxsYJYSkiCWKzEfiAb1Gb520KVy5xxl4= +github.com/mattn/go-isatty v0.0.14/go.mod h1:7GGIvUiUoEMVVmxf/4nioHXj79iQHKdU27kJ6hsGG94= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= +github.com/mattn/go-runewidth v0.0.14 h1:+xnbZSEeDbOIg5/mE6JF0w6n9duR1l3/WmbinWVwUuU= +github.com/mattn/go-runewidth v0.0.14/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= +github.com/mattn/go-sqlite3 v1.14.6 h1:dNPt6NO46WmLVt2DLNpwczCmdV5boIZ6g/tlDrlRUbg= +github.com/mattn/go-sqlite3 v1.14.6/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd h1:TRLaZ9cD/w8PVh93nsPXa1VrQ6jlwL5oN8l14QlcNfg= +github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe h1:iruDEfMl2E6fbMZ9s0scYfZQ84/6SPL6zC8ACM2oIL0= +github.com/montanaflynn/stats v0.0.0-20171201202039-1bf9dbcd8cbe/go.mod h1:wL8QJuTMNUDYhXwkmfOly8iTdp5TEcJFWZD2D7SIkUc= +github.com/natefinch/lumberjack v2.0.0+incompatible h1:4QJd3OLAMgj7ph+yZTuX13Ld4UpgHp07nNdFX7mqFfM= +github.com/natefinch/lumberjack v2.0.0+incompatible/go.mod h1:Wi9p2TTF5DG5oU+6YfsmYQpsTIOm0B1VNzQg9Mw6nPk= +github.com/nxadm/tail v1.4.8 h1:nPr65rt6Y5JFSKQO7qToXr7pePgD6Gwiw05lkbyAQTE= +github.com/onsi/ginkgo v1.16.5 h1:8xi0RTUf59SOSfEtZMvwTvXYMzG4gV23XVHOZiXNtnE= +github.com/onsi/gomega v1.18.1 h1:M1GfJqGRrBrrGGsbxzV5dqM2U2ApXefZCQpkukxYRLE= +github.com/pelletier/go-toml/v2 v2.0.8 h1:0ctb6s9mE31h0/lhu+J6OPmVeDxJn+kYnJc2jZR9tGQ= +github.com/pelletier/go-toml/v2 v2.0.8/go.mod h1:vuYfssBdrU2XDZ9bYydBu6t+6a6PYNcZljzZR9VXg+4= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/redis/go-redis/v9 v9.0.3 h1:+7mmR26M0IvyLxGZUHxu4GiBkJkVDid0Un+j4ScYu4k= +github.com/redis/go-redis/v9 v9.0.3/go.mod h1:WqMKv5vnQbRuZstUwxQI195wHy+t4PuXDOjzMvcuQHk= +github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= +github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= +github.com/robfig/cron/v3 v3.0.1 h1:WdRxkvbJztn8LMz/QEvLN5sBU+xKpSqwwUO1Pjr4qDs= +github.com/robfig/cron/v3 v3.0.1/go.mod h1:eQICP3HwyT7UooqI/z+Ov+PtYAWygg1TEWWzGIFLtro= +github.com/rs/xid v1.4.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/xid v1.5.0 h1:mKX4bl4iPYJtEIxp6CYiUuLQ/8DYMoz0PUdtGgMFRVc= +github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= +github.com/rs/zerolog v1.29.1 h1:cO+d60CHkknCbvzEWxP0S9K6KqyTjrCNUy1LdQLCGPc= +github.com/rs/zerolog v1.29.1/go.mod h1:Le6ESbR7hc+DP6Lt1THiV8CQSdkkNrd3R0XbEgp3ZBU= +github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= +github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.2/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/stretchr/testify v1.8.3 h1:RP3t2pwF7cMEbC1dqtB6poj3niw/9gnV4Cjg5oW5gtY= +github.com/stretchr/testify v1.8.3/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/tidwall/pretty v1.0.0 h1:HsD+QiTn7sK6flMKIvNmpqz1qrpP3Ps6jOKIKMooyg4= +github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/ugorji/go/codec v1.2.11 h1:BMaWp1Bb6fHwEtbplGBGJ498wD+LKlNSl25MjdZY4dU= +github.com/ugorji/go/codec v1.2.11/go.mod h1:UNopzCgEMSXjBc6AOMqYvWC1ktqTAfzJZUZgYf6w6lg= +github.com/xdg-go/pbkdf2 v1.0.0 h1:Su7DPu48wXMwC3bs7MCNG+z4FhcyEuz5dlvchbq0B0c= +github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI= +github.com/xdg-go/scram v1.1.1 h1:VOMT+81stJgXW3CpHyqHN3AXDYIMsx56mEFrB37Mb/E= +github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g= +github.com/xdg-go/stringprep v1.0.3 h1:kdwGpVNwPFtjs98xCGkHjQtGKh86rDcRZN17QEMCOIs= +github.com/xdg-go/stringprep v1.0.3/go.mod h1:W3f5j4i+9rC0kuIEJL0ky1VpHXQU3ocBgklLGvcBnW8= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d h1:splanxYIlg+5LfHAM6xpdFEAYOk8iySO56hMFq6uLyA= +github.com/youmark/pkcs8 v0.0.0-20181117223130-1be2e3e5546d/go.mod h1:rHwXgn7JulP+udvsHwJoVG1YGAP6VLg4y9I5dyZdqmA= +github.com/yuin/goldmark v1.3.5/go.mod h1:mwnBkeHKe2W/ZEtQ+71ViKU8L12m81fl3OWwC1Zlc8k= +go.mongodb.org/mongo-driver v1.11.7 h1:LIwYxASDLGUg/8wOhgOOZhX8tQa/9tgZPgzZoVqJvcs= +go.mongodb.org/mongo-driver v1.11.7/go.mod h1:G9TgswdsWjX4tmDA5zfs2+6AEPpYJwqblyjsfuh8oXY= +go.uber.org/goleak v1.1.12 h1:gZAh5/EyT/HQwlpkCy6wTpqfH9H8Lz8zbm3dZh+OyzA= +go.uber.org/goleak v1.1.12/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/arch v0.3.0 h1:02VY4/ZcO/gBOH6PUaoiptASxtXU10jazRCP865E97k= +golang.org/x/arch v0.3.0/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= +golang.org/x/crypto v0.0.0-20220622213112-05595931fe9d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/crypto v0.10.0 h1:LKqV2xt9+kDzSTfOhx4FrkEBcMrAgHSYgzywV9zcGmM= +golang.org/x/crypto v0.10.0/go.mod h1:o4eNf7Ede1fv+hwOwZsTHl9EsPFO6q6ZvYR8vYfY45I= +golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= +golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210405180319-a5a99cb37ef4/go.mod h1:p54w0d4576C0XHj96bSt6lcn1PtDYWL6XObtHCRCNQM= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.10.0 h1:X2//UzNDwYmtCLn7To6G58Wr6f5ahEAQgKNzv9Y951M= +golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210330210617-4fbd30eecc44/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210510120138-977fb7262007/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210630005230-0f9fa26af87c/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210927094055-39ccf1dd6fa6/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20211216021012-1d35b9e2eb4e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220704084225-05e143d24a9e/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.9.0 h1:KS/R3tvhPqvJvwcKfnBHJwwthS11LRhmM5D59eEXa0s= +golang.org/x/sys v0.9.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.10.0 h1:UpjohKhiEgNc0CSauXmwYftY1+LlaC75SJwh0SgCX58= +golang.org/x/text v0.10.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ= +golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1 h1:go1bK/D/BFZV2I8cIQd1NKEZ+0owSTG1fDTci4IqFcE= +golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +google.golang.org/protobuf v1.30.0 h1:kPPoIgf3TsEvrm0PFe15JQ+570QVxYzEvvHqChK+cng= +google.golang.org/protobuf v1.30.0/go.mod h1:HV8QOd/L58Z+nl8r43ehVNZIU/HEI6OcFqwMG9pJV4I= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= +gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/natefinch/lumberjack.v2 v2.2.1 h1:bBRl1b0OH9s/DuPhuXpNl+VtCaJXFZ5/uEFST95x9zc= +gopkg.in/natefinch/lumberjack.v2 v2.2.1/go.mod h1:YD8tP3GAjkrDg1eZH7EGmyESg/lsYskCTPBJVb9jqSc= +gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 h1:uRGJdciOHaEIrze2W8Q3AKkepLTh2hOroT7a+7czfdQ= +gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= +gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/gqueue/queue.go b/gqueue/queue.go new file mode 100644 index 0000000..06d1599 --- /dev/null +++ b/gqueue/queue.go @@ -0,0 +1,49 @@ +// +// queue.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gqueue + +import ( + "sync" + + "github.com/hibiken/asynq" +) + +var onceCli sync.Once +var onceSvc sync.Once + +var redisOpt asynq.RedisClientOpt + +func Init(addr, username, password string, database int) { + redisOpt = asynq.RedisClientOpt{ + Addr: addr, + Username: username, + Password: password, + DB: database, + } +} + +var cli *asynq.Client + +func Client() *asynq.Client { + onceCli.Do(func() { + cli = asynq.NewClient(redisOpt) + }) + return cli +} + +var svc *asynq.Server + +func Server() *asynq.Server { + onceSvc.Do(func() { + svc = asynq.NewServer( + redisOpt, + asynq.Config{Concurrency: 10}, + ) + }) + return svc +} diff --git a/gqueue/queue_test.go b/gqueue/queue_test.go new file mode 100644 index 0000000..544928b --- /dev/null +++ b/gqueue/queue_test.go @@ -0,0 +1,32 @@ +// +// queue_test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gqueue_test + +import ( + "os" + "testing" + + "hexq.cn/tiglog/golib/gqueue" + "hexq.cn/tiglog/golib/gtest" +) + +func TestClient(t *testing.T) { + gqueue.Init(os.Getenv("REDIS_ADDR"), os.Getenv("REDIS_USERNAME"), os.Getenv("REDIS_PASSWORD"), 0) + c1 := gqueue.Client() + c2 := gqueue.Client() + gtest.NotNil(t, c1) + gtest.Equal(t, c1, c2) +} + +func TestServer(t *testing.T) { + gqueue.Init(os.Getenv("REDIS_ADDR"), os.Getenv("REDIS_USERNAME"), os.Getenv("REDIS_PASSWORD"), 0) + s1 := gqueue.Server() + s2 := gqueue.Server() + gtest.NotNil(t, s1) + gtest.Equal(t, s1, s2) +} diff --git a/gtest/test.go b/gtest/test.go new file mode 100644 index 0000000..07cb4f0 --- /dev/null +++ b/gtest/test.go @@ -0,0 +1,179 @@ +// +// test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package gtest + +import ( + "reflect" + "strings" + "testing" +) + +func Equal(t *testing.T, expected, val interface{}) { + if val != expected { + t.Errorf("Expected [%v] (type %v), but got [%v] (type %v)", expected, reflect.TypeOf(expected), val, reflect.TypeOf(val)) + } +} + +func NotEqual(t *testing.T, expected, val interface{}) { + if val == expected { + t.Errorf("Expected not [%v] (type %v), but got [%v] (type %v)", expected, reflect.TypeOf(expected), val, reflect.TypeOf(val)) + } +} + +func True(t *testing.T, val interface{}) { + if val != true { + t.Errorf("Expected true, but got [%v] (type %v)", val, reflect.TypeOf(val)) + } +} + +func False(t *testing.T, val interface{}) { + if val != false { + t.Errorf("Expected false, but got [%v] (type %v)", val, reflect.TypeOf(val)) + } +} + +func MsgNil(t *testing.T, val interface{}) { + t.Errorf("Expected nil, but got [%v] (type %v)", val, reflect.TypeOf(val)) +} + +func MsgNotNil(t *testing.T, val interface{}) { + t.Errorf("Expected not nil, but got [%v] (type %v)", val, reflect.TypeOf(val)) +} + +func MsgExpect(t *testing.T, expect string, val string) { + t.Errorf("Expected %s, but got [%v] (type %v)", expect, val, reflect.TypeOf(val)) +} + +// 有些情况下不能使用 +func Nil(t *testing.T, val interface{}) { + if val != nil { + t.Errorf("Expected nil, but got [%v] (type %v)", val, reflect.TypeOf(val)) + } +} + +// 有些情况下不能使用 +func NotNil(t *testing.T, val interface{}) { + if val == nil { + t.Errorf("Expected not nil, but got [%v] (type %v)", val, reflect.TypeOf(val)) + } +} + +func IsString(t *testing.T, val interface{}) { + _, ok := val.(string) + if !ok { + t.Errorf("Expected a string, but got [%v] (type %v)", val, reflect.TypeOf(val)) + } +} + +func IsBool(t *testing.T, val interface{}) { + _, ok := val.(bool) + if !ok { + t.Errorf("Expected a bool, but got [%v] (type %v)", val, reflect.TypeOf(val)) + } +} + +// val 要大于 base +func Greater(t *testing.T, base int64, val interface{}) { + v := reflect.ValueOf(val) + switch v.Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + if v.Int() <= base { + t.Errorf("Expected greater than %d, but got %v", base, v) + } + return + case reflect.Chan, reflect.Map, reflect.Slice: + if int64(v.Len()) <= base { + t.Errorf("Expected greater than %d, but got %d (%v)", base, v.Len(), v) + } + return + } + t.Errorf("Expected a num or countable val, but got [%v] (type %v)", val, reflect.TypeOf(val)) +} + +// val 要小于 base +func Less(t *testing.T, base int64, val interface{}) { + v := reflect.ValueOf(val) + switch v.Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + if v.Int() >= base { + t.Errorf("Expected less than %d, but got %v", base, v) + } + return + case reflect.Chan, reflect.Map, reflect.Slice: + if int64(v.Len()) >= base { + t.Errorf("Expected less than %d, but got %d (%v)", base, v.Len(), v) + } + return + } + t.Errorf("Expected a num or countable val, but got [%v] (type %v)", val, reflect.TypeOf(val)) +} + +func StartWith(t *testing.T, e string, val string) { + if !strings.HasPrefix(val, e) { + t.Errorf("Expected a string has prefix %s, but got %s", e, val) + } +} + +func EndWith(t *testing.T, e string, val string) { + if !strings.HasSuffix(val, e) { + t.Errorf("Expected a string has suffix %s, but got %s", e, val) + } +} +func ContainWith(t *testing.T, e string, val string) { + if !strings.Contains(val, e) { + t.Errorf("Expected a string contain %s, but got %s", e, val) + } +} + +func Error(t *testing.T, err error) { + if err == nil { + t.Errorf("Expected an error, but got nil") + } +} + +// 不希望是一个 error +func NoError(t *testing.T, err error) { + if err != nil { + t.Errorf("Received unexpected error:\n%+v", err) + } +} + +func Empty(t *testing.T, val interface{}) { + if !isEmpty(val) { + t.Errorf("Should be empty, but got [%v] (type %v)", val, reflect.TypeOf(val)) + } +} + +func NotEmpty(t *testing.T, val interface{}) { + if isEmpty(val) { + t.Errorf("should be not emtpy, but got [%v] (type %v)", val, reflect.TypeOf(val)) + } +} + +func isEmpty(object interface{}) bool { + + if object == nil { + return true + } + + objValue := reflect.ValueOf(object) + + switch objValue.Kind() { + case reflect.Chan, reflect.Map, reflect.Slice: + return objValue.Len() == 0 + case reflect.Ptr: + if objValue.IsNil() { + return true + } + deref := objValue.Elem().Interface() + return isEmpty(deref) + default: + zero := reflect.Zero(objValue.Type()) + return reflect.DeepEqual(object, zero.Interface()) + } +} diff --git a/helper/conv_helper.go b/helper/conv_helper.go new file mode 100644 index 0000000..4b01a44 --- /dev/null +++ b/helper/conv_helper.go @@ -0,0 +1,117 @@ +// +// conv_helper.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package helper + +import ( + "strconv" +) + +func AnyToInt(val any, dv int) (int, error) { + switch val.(type) { + case int: + rv, _ := val.(int) + return rv, nil + case int8: + rv, _ := val.(int8) + return int(rv), nil + case int16: + rv, _ := val.(int16) + return int(rv), nil + case int32: + rv, _ := val.(int32) + return int(rv), nil + case int64: + rv, _ := val.(int64) + return int(rv), nil + case string: + sv, _ := val.(string) + rv, err := strconv.Atoi(sv) + if err != nil { + return dv, err + } + return rv, nil + case float32: + rv, _ := val.(float32) + return int(rv), nil + case float64: + rv, _ := val.(float64) + return int(rv), nil + } + return dv, nil +} + +func AnyToFloat(val any, dv float64) (float64, error) { + + switch val.(type) { + case int: + iv, _ := val.(int) + return float64(iv), nil + case int8: + iv, _ := val.(int8) + return float64(iv), nil + case int16: + iv, _ := val.(int16) + return float64(iv), nil + case int32: + iv, _ := val.(int32) + return float64(iv), nil + case int64: + iv, _ := val.(int64) + return float64(iv), nil + case string: + sv, _ := val.(string) + return strconv.ParseFloat(sv, 64) + case float32: + fv, _ := val.(float32) + return float64(fv), nil + case float64: + fv, _ := val.(float64) + return fv, nil + + } + return dv, nil +} + +// 只处理简单数据类型 +func AnyToString(val any, dv string) (string, error) { + switch val.(type) { + case string: + sv, _ := val.(string) + return sv, nil + case int: + iv, _ := val.(int) + return strconv.Itoa(iv), nil + case int8: + iv, _ := val.(int8) + return strconv.FormatInt(int64(iv), 10), nil + case int16: + iv, _ := val.(int16) + return strconv.FormatInt(int64(iv), 10), nil + case int32: + iv, _ := val.(int32) + return strconv.FormatInt(int64(iv), 10), nil + case int64: + iv, _ := val.(int64) + return strconv.FormatInt(iv, 10), nil + case float32: + fv, _ := val.(float32) + return strconv.FormatFloat(float64(fv), 'f', -1, 64), nil + case float64: + fv, _ := val.(float64) + return strconv.FormatFloat(fv, 'f', -1, 64), nil + case bool: + bv, _ := val.(bool) + if bv { + return "true", nil + } + return "false", nil + + } + + return dv, nil +} diff --git a/helper/conv_helper_test.go b/helper/conv_helper_test.go new file mode 100644 index 0000000..ccc4744 --- /dev/null +++ b/helper/conv_helper_test.go @@ -0,0 +1,156 @@ +// +// conv_helper_test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package helper_test + +import ( + "fmt" + "testing" + + "hexq.cn/tiglog/golib/gtest" + "hexq.cn/tiglog/golib/helper" +) + +func TestAnyToInt(t *testing.T) { + var i1 int8 = 11 + v1, err := helper.AnyToInt(i1, 0) + gtest.NoError(t, err) + gtest.Equal(t, 11, v1) + fmt.Printf("v1 %d type is %T\n", v1, v1) + + var i2 int16 = 22 + v2, err := helper.AnyToInt(i2, 0) + gtest.NoError(t, err) + gtest.Equal(t, 22, v2) + fmt.Printf("v2 %d type is %T\n", v2, v2) + + var i3 int32 = 33 + v3, err := helper.AnyToInt(i3, 0) + gtest.NoError(t, err) + gtest.Equal(t, 33, v3) + fmt.Printf("v3 %d type is %T\n", v3, v3) + + var i4 int64 = 44 + v4, err := helper.AnyToInt(i4, 0) + gtest.NoError(t, err) + gtest.Equal(t, 44, v4) + fmt.Printf("v4 %d type is %T\n", v4, v4) + + var i5 int = 55 + v5, err := helper.AnyToInt(i5, 0) + gtest.NoError(t, err) + gtest.Equal(t, 55, v5) + fmt.Printf("v5 %d type is %T\n", v5, v5) + + var i6 string = "66" + v6, err := helper.AnyToInt(i6, 0) + gtest.NoError(t, err) + gtest.Equal(t, 66, v6) + fmt.Printf("v6 %d type is %T\n", v6, v6) + + var i7 float32 = 77 + v7, err := helper.AnyToInt(i7, 0) + gtest.NoError(t, err) + gtest.Equal(t, 77, v7) + fmt.Printf("v7 %d type is %T\n", v7, v7) + + var i8 float64 = 88.8 + v8, err := helper.AnyToInt(i8, 0) + gtest.NoError(t, err) + gtest.Equal(t, 88, v8) + fmt.Printf("v8 %d type is %T\n", v8, v8) +} + +func TestAnyToFloat(t *testing.T) { + var f1 int8 = 11 + v1, err := helper.AnyToFloat(f1, 0) + gtest.NoError(t, err) + gtest.Equal(t, float64(11), v1) + fmt.Printf("v1 %f type is %T\n", v1, v1) + + var f2 int16 = 22 + v2, err := helper.AnyToFloat(f2, 0) + gtest.NoError(t, err) + gtest.Equal(t, float64(22), v2) + fmt.Printf("v2 %f type is %T\n", v2, v2) + + var f3 int32 = 33 + v3, err := helper.AnyToFloat(f3, 0) + gtest.NoError(t, err) + gtest.Equal(t, float64(33), v3) + fmt.Printf("v3 %f type is %T\n", v3, v3) + + var f4 int64 = 44 + v4, err := helper.AnyToFloat(f4, 0) + gtest.NoError(t, err) + gtest.Equal(t, float64(44), v4) + fmt.Printf("v4 %f type is %T\n", v4, v4) + + var f5 int = 55 + v5, err := helper.AnyToFloat(f5, 0) + gtest.NoError(t, err) + gtest.Equal(t, float64(55), v5) + fmt.Printf("v5 %f type is %T\n", v5, v5) + + var f6 string = "66" + v6, err := helper.AnyToFloat(f6, 0) + gtest.NoError(t, err) + gtest.Equal(t, float64(66), v6) + fmt.Printf("v6 %f type is %T\n", v6, v6) + + var f7 float32 = 77 + v7, err := helper.AnyToFloat(f7, 0) + gtest.NoError(t, err) + gtest.Equal(t, float64(77), v7) + gtest.Equal(t, 77.0, v7) + fmt.Printf("v7 %f type is %T\n", v7, v7) + + var f8 float64 = 88.8 + v8, err := helper.AnyToFloat(f8, 0) + gtest.NoError(t, err) + gtest.Equal(t, 88.8, v8) + fmt.Printf("v8 %f type is %T\n", v8, v8) +} + +func TestAnyToString(t *testing.T) { + + var s1 int8 = 11 + v1, err := helper.AnyToString(s1, "") + gtest.NoError(t, err) + gtest.Equal(t, "11", v1) + fmt.Printf("v1 %s type is %T\n", v1, v1) + + var s2 int16 = 22 + v2, err := helper.AnyToString(s2, "") + gtest.NoError(t, err) + gtest.Equal(t, "22", v2) + fmt.Printf("v2 %s type is %T\n", v2, v2) + + var s5 int = 55 + v5, err := helper.AnyToString(s5, "") + gtest.NoError(t, err) + gtest.Equal(t, "55", v5) + fmt.Printf("v5 %s type is %T\n", v5, v5) + + var s6 string = "66" + v6, err := helper.AnyToString(s6, "") + gtest.NoError(t, err) + gtest.Equal(t, "66", v6) + fmt.Printf("v6 %s type is %T\n", v6, v6) + + var s7 float32 = 77 + v7, err := helper.AnyToString(s7, "") + gtest.NoError(t, err) + gtest.Equal(t, "77", v7) + fmt.Printf("v7 %s type is %T\n", v7, v7) + + var s8 float64 = 88.8 + v8, err := helper.AnyToString(s8, "") + gtest.NoError(t, err) + gtest.Equal(t, "88.8", v8) + fmt.Printf("v8 %s type is %T\n", v8, v8) +} diff --git a/helper/error_helper.go b/helper/error_helper.go new file mode 100644 index 0000000..4b4b004 --- /dev/null +++ b/helper/error_helper.go @@ -0,0 +1,14 @@ +// +// error_helper.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package helper + +func CheckErr(err error) { + if err != nil { + panic(err) + } +} diff --git a/helper/http_helper.go b/helper/http_helper.go new file mode 100644 index 0000000..9b8d6aa --- /dev/null +++ b/helper/http_helper.go @@ -0,0 +1,42 @@ +// +// http_helper.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package helper + +import ( + "bytes" + "encoding/json" + "errors" + "net/http" +) + +func RequestJson(url string, data []byte) (*http.Response, error) { + res, err := http.Post(url, "application/json", bytes.NewBuffer(data)) + if err != nil { + return nil, err + } + return res, nil +} + +func NotifyBack(url string, vals map[string]any, errmsg string) (*http.Response, error) { + if url == "" { + return nil, errors.New("no url") + } + vals["msg"] = errmsg + if errmsg == "" { + vals["stat"] = "ok" + } else { + vals["stat"] = "fail" + } + + body, _ := json.Marshal(vals) + res, err := http.Post(url, "application/json", bytes.NewReader(body)) + if err != nil { + return nil, err + } + return res, nil +} diff --git a/helper/resp_helper.go b/helper/resp_helper.go new file mode 100644 index 0000000..c548105 --- /dev/null +++ b/helper/resp_helper.go @@ -0,0 +1,59 @@ +// +// resp_helper.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package helper + +import ( + "net/http" + + "github.com/gin-gonic/gin" +) + +func RenderJson(c *gin.Context, code int, msg string, data interface{}) { + c.JSON(http.StatusOK, gin.H{ + "code": code, + "msg": msg, + "data": data, + }) +} + +// 成功时,返回的 code 为 0 +func RenderOk(c *gin.Context, data interface{}) { + RenderJson(c, 0, "success", data) +} + +// 失败时,返回的 code 为指定的 code +// 一般会比 http status code 要详细一点 +func RenderFail(c *gin.Context, code int, msg string) { + RenderJson(c, code, msg, nil) +} + +// 成功时,返回自定义消息和数据 +func RenderSuccess(c *gin.Context, msg string, data interface{}) { + RenderJson(c, 0, msg, data) +} + +func RenderClientError(c *gin.Context, err error) { + RenderJson(c, http.StatusBadRequest, err.Error(), nil) +} + +func RenderServerError(c *gin.Context, err error) { + RenderJson(c, http.StatusInternalServerError, err.Error(), nil) +} + +func RenderClientFail(c *gin.Context, msg string) { + RenderJson(c, http.StatusBadRequest, msg, nil) +} + +func RenderServerFail(c *gin.Context, msg string) { + RenderJson(c, http.StatusInternalServerError, msg, nil) +} + +// 未发现的各种情况 +func RenderNotFound(c *gin.Context, msg string) { + RenderJson(c, http.StatusNotFound, msg, nil) +} diff --git a/helper/slice_helper.go b/helper/slice_helper.go new file mode 100644 index 0000000..9ce4ca6 --- /dev/null +++ b/helper/slice_helper.go @@ -0,0 +1,41 @@ +// +// slice_helper.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package helper + +func InStringSlice(need string, haystack []string) bool { + for _, e := range haystack { + if e == need { + return true + } + } + return false +} + +func InIntSlice(need int, haystack []int) bool { + for _, e := range haystack { + if e == need { + return true + } + } + return false +} + +func IsAnySlice(v interface{}) bool { + _, ok := v.([]interface{}) + return ok +} + +func IsStringSlice(v interface{}) bool { + _, ok := v.([]string) + return ok +} + +func IsIntSlice(v interface{}) bool { + _, ok := v.([]int) + return ok +} diff --git a/helper/str_helper.go b/helper/str_helper.go new file mode 100644 index 0000000..51e34b3 --- /dev/null +++ b/helper/str_helper.go @@ -0,0 +1,120 @@ +// +// str_helper.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package helper + +import ( + "math/rand" + "time" + "unicode" + + "github.com/rs/xid" + "hexq.cn/tiglog/golib/crypto/gmd5" +) + +// 是否是字符串 +func IsString(v interface{}) bool { + _, ok := v.(string) + return ok +} + +// 随机字符串 +func RandString(n int) string { + letterRunes := []rune("1234567890abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + rand.Seed(time.Now().UnixNano()) + b := make([]rune, n) + for i := range b { + b[i] = letterRunes[rand.Intn(len(letterRunes))] + } + return string(b) +} + +// 首字母大写 +func UcFirst(str string) string { + for _, v := range str { + u := string(unicode.ToUpper(v)) + return u + str[len(u):] + } + return "" +} + +// 首字母小写 +func LcFirst(str string) string { + for _, v := range str { + u := string(unicode.ToLower(v)) + return u + str[len(u):] + } + return "" +} + +// 乱序字符串 +func Shuffle(str string) string { + if str == "" { + return str + } + + runes := []rune(str) + index := 0 + + for i := len(runes) - 1; i > 0; i-- { + index = rand.Intn(i + 1) + + if i != index { + runes[i], runes[index] = runes[index], runes[i] + } + } + return string(runes) +} + +// 反序字符串 +func Reverse(str string) string { + n := len(str) + runes := make([]rune, n) + for _, r := range str { + n-- + runes[n] = r + } + return string(runes[n:]) +} + +// 唯一字符串 +// 返回字符串长度为 20 +func GenId() string { + guid := xid.New() + return guid.String() +} + +// 可变长度的唯一字符串 +// 长度太短,可能就不唯一了 +// 长度大于等于 16 为最佳 +// 长度小于20时,为 GenId 的值 md5 后的前缀,因此,理论上前6位也能在大多数情况 +// 下唯一 +func Uniq(l int) string { + if l <= 0 { + panic("wrong length param") + } + ret := GenId() + hl := len(ret) + if l < hl { + t, err := gmd5.EncryptString(ret) + if err != nil { + return ret[hl-l:] + } + return t[:l] + } + mhac_len := 6 + pl := len(ret) + var hash string + for l > pl { + hash = GenId() + hash = hash[mhac_len:] + ret += hash + pl += len(hash) + } + // log.Println("ret=", ret, ", pl=", pl, ", l=", l) + return ret[0:l] +} diff --git a/helper/str_helper_test.go b/helper/str_helper_test.go new file mode 100644 index 0000000..9bfa1e3 --- /dev/null +++ b/helper/str_helper_test.go @@ -0,0 +1,120 @@ +// +// str_helper_test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package helper_test + +import ( + "fmt" + "testing" + + "hexq.cn/tiglog/golib/gtest" + "hexq.cn/tiglog/golib/helper" +) + +func TestIsString(t *testing.T) { + v1 := 111 + r1 := helper.IsString(v1) + gtest.False(t, r1) + + v2 := "hello" + r2 := helper.IsString(v2) + gtest.True(t, r2) +} + +func TestRandString(t *testing.T) { + r1 := helper.RandString(10) + fmt.Println(r1) + gtest.NotEqual(t, "", r1) + gtest.Equal(t, 10, len(r1)) +} + +func TestUcFirst(t *testing.T) { + v1 := "hello" + r1 := helper.UcFirst(v1) + gtest.Equal(t, "Hello", r1) + + v2 := "hello world" + r2 := helper.UcFirst(v2) + gtest.Equal(t, "Hello world", r2) + + v3 := "helloWorld" + r3 := helper.UcFirst(v3) + gtest.Equal(t, "HelloWorld", r3) +} + +func TestGenId(t *testing.T) { + s1 := helper.GenId() + s2 := helper.GenId() + s3 := helper.GenId() + s4 := helper.GenId() + fmt.Println("gen id: ", s4) + gtest.NotNil(t, s1) + gtest.NotNil(t, s2) + gtest.NotNil(t, s3) + gtest.NotNil(t, s4) + gtest.NotEqual(t, s1, s2) + gtest.NotEqual(t, s1, s3) + gtest.NotEqual(t, s1, s4) + // fmt.Println(s1) + // fmt.Println(s2) + // fmt.Println(s3) + // fmt.Println(s4) + +} + +func TestUniq(t *testing.T) { + s1 := helper.Uniq(1) + fmt.Println("s1=", s1) + gtest.True(t, 1 == len(s1)) + + s12 := helper.Uniq(6) + s13 := helper.Uniq(6) + s14 := helper.Uniq(6) + s15 := helper.Uniq(6) + fmt.Println("s12..15", s12, s13, s14, s15) + gtest.NotNil(t, s12) + gtest.NotNil(t, s13) + gtest.NotNil(t, s14) + gtest.NotNil(t, s15) + gtest.NotEqual(t, s12, s13) + gtest.NotEqual(t, s12, s14) + gtest.NotEqual(t, s12, s15) + + s2 := helper.Uniq(16) + s3 := helper.Uniq(16) + s4 := helper.Uniq(16) + s5 := helper.Uniq(16) + gtest.NotNil(t, s2) + gtest.NotNil(t, s3) + gtest.NotNil(t, s4) + gtest.NotNil(t, s5) + gtest.NotEqual(t, s2, s3) + gtest.NotEqual(t, s2, s4) + gtest.NotEqual(t, s2, s5) + + s6 := helper.Uniq(32) + fmt.Println("s6=", s6) + s7 := helper.Uniq(32) + s8 := helper.Uniq(32) + s9 := helper.Uniq(32) + gtest.NotNil(t, s6) + gtest.NotNil(t, s7) + gtest.NotNil(t, s8) + gtest.NotNil(t, s9) + // fmt.Println("s6789=", s6, s7, s8, s9) + + s60 := helper.Uniq(64) + fmt.Println("s60=", s60) + s70 := helper.Uniq(64) + s80 := helper.Uniq(64) + s90 := helper.Uniq(64) + gtest.NotNil(t, s60) + gtest.NotNil(t, s70) + gtest.NotNil(t, s80) + gtest.NotNil(t, s90) + // fmt.Println(s60, s70, s80, s90) +} diff --git a/helper/time_helper.go b/helper/time_helper.go new file mode 100644 index 0000000..dcfda62 --- /dev/null +++ b/helper/time_helper.go @@ -0,0 +1,26 @@ +// +// time_helper.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package helper + +import "time" + +func Format(ts int64, layout string) string { + return time.Unix(ts, 0).Format(layout) +} + +func FormatDate(ts int64) string { + return Format(ts, "2006-01-02") +} + +func FormatDt(ts int64) string { + return Format(ts, "2006-01-02 15:04") +} + +func FormatDateTime(ts int64) string { + return Format(ts, "2006-01-02 15:04:05") +} diff --git a/logger/access.go b/logger/access.go new file mode 100644 index 0000000..a25203f --- /dev/null +++ b/logger/access.go @@ -0,0 +1,55 @@ +// +// access.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package logger + +import ( + "os" + "sync" + "time" + + "github.com/natefinch/lumberjack" + "github.com/rs/zerolog" + "github.com/rs/zerolog/pkgerrors" +) + +var access_once sync.Once + +var access_log *zerolog.Logger + +var access_log_path = "./var/log/access.log" + +func SetupAccessLogFile(path string) { + access_log_path = path +} + +func Access() *zerolog.Logger { + access_once.Do(func() { + zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack + zerolog.TimeFieldFormat = time.RFC3339Nano + + fileLogger := &lumberjack.Logger{ + Filename: access_log_path, + MaxSize: 5, // + MaxBackups: 10, + MaxAge: 14, + Compress: true, + } + + output := zerolog.MultiLevelWriter(os.Stderr, fileLogger) + + l := zerolog.New(output). + Level(zerolog.InfoLevel). + With(). + Timestamp(). + Logger() + + access_log = &l + }) + + return access_log +} diff --git a/logger/console.go b/logger/console.go new file mode 100644 index 0000000..b9b520c --- /dev/null +++ b/logger/console.go @@ -0,0 +1,55 @@ +// +// console.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package logger + +import ( + "os" + "sync" + "time" + + "github.com/natefinch/lumberjack" + "github.com/rs/zerolog" + "github.com/rs/zerolog/pkgerrors" +) + +var console_once sync.Once + +var console_log *zerolog.Logger + +var console_log_path = "./var/log/console.log" + +func SetupConsoleLogFile(path string) { + access_log_path = path +} + +func Console() *zerolog.Logger { + console_once.Do(func() { + zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack + zerolog.TimeFieldFormat = time.RFC3339Nano + + fileLogger := &lumberjack.Logger{ + Filename: console_log_path, + MaxSize: 5, // + MaxBackups: 10, + MaxAge: 14, + Compress: true, + } + + output := zerolog.MultiLevelWriter(os.Stderr, fileLogger) + + l := zerolog.New(output). + Level(zerolog.InfoLevel). + With(). + Timestamp(). + Logger() + + console_log = &l + }) + + return console_log +} diff --git a/logger/log.go b/logger/log.go new file mode 100644 index 0000000..a706038 --- /dev/null +++ b/logger/log.go @@ -0,0 +1,83 @@ +// +// log.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package logger + +import ( + "os" + "sync" + "time" + + "github.com/natefinch/lumberjack" + "github.com/rs/zerolog" + "github.com/rs/zerolog/pkgerrors" +) + +var once sync.Once + +var log_path = "./var/log/app.log" +var log_level = zerolog.InfoLevel + +func SetLogPath(path string) { + log_path = path +} +func SetLogLevel(level zerolog.Level) { + log_level = level +} + +var log *zerolog.Logger + +func Get() *zerolog.Logger { + once.Do(func() { + zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack + zerolog.TimeFieldFormat = time.RFC3339Nano + + fileLogger := &lumberjack.Logger{ + Filename: log_path, + MaxSize: 5, // + MaxBackups: 10, + MaxAge: 14, + Compress: true, + } + + output := zerolog.MultiLevelWriter(os.Stderr, fileLogger) + + l := zerolog.New(output). + Level(log_level). + With(). + Timestamp(). + Logger() + log = &l + }) + + return log +} + +func Debug(msg string) { + Get().Debug().Msg(msg) +} +func Debugf(format string, v ...interface{}) { + Get().Debug().Msgf(format, v...) +} +func Info(msg string) { + Get().Info().Msg(msg) +} +func Infof(format string, v ...interface{}) { + Get().Info().Msgf(format, v...) +} +func Warn(msg string) { + Get().Warn().Msg(msg) +} +func Warnf(format string, v ...interface{}) { + Get().Warn().Msgf(format, v...) +} +func Error(msg string) { + Get().Error().Msg(msg) +} +func Errorf(format string, v ...interface{}) { + Get().Error().Msgf(format, v...) +} diff --git a/logger/log_test.go b/logger/log_test.go new file mode 100644 index 0000000..794ceaf --- /dev/null +++ b/logger/log_test.go @@ -0,0 +1,22 @@ +// +// log_test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package logger_test + +import ( + "testing" + + "hexq.cn/tiglog/golib/gtest" + "hexq.cn/tiglog/golib/logger" +) + +func TestLogToFile(t *testing.T) { + // logger.SetupLog("./var/log/test.log", zerolog.DebugLevel) + var log = logger.Get() + gtest.NotNil(t, log) + log.Log().Msg("hello world") +} diff --git a/logger/recover.go b/logger/recover.go new file mode 100644 index 0000000..a7d8e07 --- /dev/null +++ b/logger/recover.go @@ -0,0 +1,55 @@ +// +// recover.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package logger + +import ( + "os" + "sync" + "time" + + "github.com/natefinch/lumberjack" + "github.com/rs/zerolog" + "github.com/rs/zerolog/pkgerrors" +) + +var recover_once sync.Once + +var recover_log *zerolog.Logger + +var recover_log_path = "./var/log/recover.log" + +func SetupRecoverLogFile(path string) { + recover_log_path = path +} + +func Recover() *zerolog.Logger { + recover_once.Do(func() { + zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack + zerolog.TimeFieldFormat = time.RFC3339Nano + + fileLogger := &lumberjack.Logger{ + Filename: recover_log_path, + MaxSize: 5, // + MaxBackups: 10, + MaxAge: 14, + Compress: true, + } + + output := zerolog.MultiLevelWriter(os.Stderr, fileLogger) + + l := zerolog.New(output). + // Level(zerolog.InfoLevel). + With(). + Timestamp(). + Logger() + + recover_log = &l + }) + + return recover_log +} diff --git a/logger/work.go b/logger/work.go new file mode 100644 index 0000000..838c3ee --- /dev/null +++ b/logger/work.go @@ -0,0 +1,55 @@ +// +// work.go +// Copyright (C) 2023 tiglog +// +// Distributed under terms of the MIT license. +// + +package logger + +import ( + "os" + "sync" + "time" + + "github.com/natefinch/lumberjack" + "github.com/rs/zerolog" + "github.com/rs/zerolog/pkgerrors" +) + +var work_once sync.Once + +var work_log *zerolog.Logger + +var work_log_path = "./var/log/work.log" + +func SetupWorkLogFile(path string) { + access_log_path = path +} + +func Work() *zerolog.Logger { + work_once.Do(func() { + zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack + zerolog.TimeFieldFormat = time.RFC3339Nano + + fileLogger := &lumberjack.Logger{ + Filename: work_log_path, + MaxSize: 5, // + MaxBackups: 10, + MaxAge: 14, + Compress: true, + } + + output := zerolog.MultiLevelWriter(os.Stderr, fileLogger) + + l := zerolog.New(output). + Level(zerolog.InfoLevel). + With(). + Timestamp(). + Logger() + + work_log = &l + }) + + return work_log +} diff --git a/middleware/access_log.go b/middleware/access_log.go new file mode 100644 index 0000000..2ad2ccf --- /dev/null +++ b/middleware/access_log.go @@ -0,0 +1,66 @@ +// +// access_log.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package middleware + +import ( + "bytes" + "io" + "io/ioutil" + "time" + + "github.com/gin-gonic/gin" + "hexq.cn/tiglog/golib/logger" +) + +func GinLogger() gin.HandlerFunc { + log := logger.Access() + return func(c *gin.Context) { + start := time.Now() + path := c.Request.URL.Path + raw := c.Request.URL.RawQuery + latency := time.Since(start) + + // 处理请求 + c.Next() + + clientIP := c.ClientIP() + method := c.Request.Method + statusCode := c.Writer.Status() + + comment := c.Errors.ByType(gin.ErrorTypePrivate).String() + + if raw != "" { + path = path + "?" + raw + } + + l := log.Log() + if comment != "" { + l = log.Error() + } + var buf bytes.Buffer + tee := io.TeeReader(c.Request.Body, &buf) + requestBody, _ := ioutil.ReadAll(tee) + c.Request.Body = ioutil.NopCloser(&buf) + + l. + Str("proto", c.Request.Proto). + Str("server_name", c.Request.Host). + Str("content_type", c.Request.Header.Get("Content-Type")). + Str("user_agent", c.Request.UserAgent()). + Str("method", method). + Str("path", path). + Int("status_code", statusCode). + Str("client_ip", clientIP). + Dur("latency", latency) + + if gin.IsDebugging() { + l.Str("content", string(requestBody)) + } + l.Msg(comment) + } +} diff --git a/middleware/cors.go b/middleware/cors.go new file mode 100644 index 0000000..7a0c675 --- /dev/null +++ b/middleware/cors.go @@ -0,0 +1,36 @@ +// +// cors.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package middleware + +import ( + "net/http" + + "github.com/gin-gonic/gin" + "hexq.cn/tiglog/golib/helper" +) + +func NewCors(origins []string) gin.HandlerFunc { + return func(c *gin.Context) { + method := c.Request.Method + origin := c.Request.Header.Get("Origin") + if helper.InStringSlice(origin, origins) { + c.Header("Access-Control-Allow-Origin", origin) + c.Header("Access-Control-Allow-Headers", "Content-Type,X-CSRF-Token, Authorization") + c.Header("Access-Control-Allow-Methods", "POST,GET,OPTIONS,DELETE,PUT") + c.Header("Access-Control-Expose-Headers", "Content-Length, Content-Type") + c.Header("Access-Control-Allow-Credentials", "true") + } + + // 放行所有OPTIONS方法 + if method == "OPTIONS" { + c.AbortWithStatus(http.StatusNoContent) + } + // 处理请求 + c.Next() + } +} diff --git a/middleware/cors_test.go b/middleware/cors_test.go new file mode 100644 index 0000000..05ea464 --- /dev/null +++ b/middleware/cors_test.go @@ -0,0 +1,104 @@ +// +// cors_test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package middleware_test + +import ( + "context" + "net/http" + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" + "hexq.cn/tiglog/golib/gtest" + "hexq.cn/tiglog/golib/middleware" +) + +func newTestRouter(origins []string) *gin.Engine { + gin.SetMode(gin.TestMode) + router := gin.New() + router.Use(middleware.NewCors(origins)) + router.GET("/", func(c *gin.Context) { + c.String(http.StatusOK, "get") + }) + router.POST("/", func(c *gin.Context) { + c.String(http.StatusOK, "post") + }) + router.PATCH("/", func(c *gin.Context) { + c.String(http.StatusOK, "patch") + }) + return router +} + +func performRequest(r http.Handler, method, origin string) *httptest.ResponseRecorder { + return performRequestWithHeaders(r, method, origin, http.Header{}) +} + +func performRequestWithHeaders(r http.Handler, method, origin string, header http.Header) *httptest.ResponseRecorder { + req, _ := http.NewRequestWithContext(context.Background(), method, "/", nil) + // From go/net/http/request.go: + // For incoming requests, the Host header is promoted to the + // Request.Host field and removed from the Header map. + req.Host = header.Get("Host") + header.Del("Host") + if len(origin) > 0 { + header.Set("Origin", origin) + } + req.Header = header + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + return w +} + +func TestPassesAllowOrigins(t *testing.T) { + router := newTestRouter([]string{"http://google.com"}) + + // no CORS request, origin == "" + w := performRequest(router, "GET", "") + gtest.Equal(t, "get", w.Body.String()) + gtest.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + gtest.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) + gtest.Empty(t, w.Header().Get("Access-Control-Expose-Headers")) + + // no CORS request, origin == host + h := http.Header{} + h.Set("Host", "facebook.com") + w = performRequestWithHeaders(router, "GET", "http://facebook.com", h) + gtest.Equal(t, "get", w.Body.String()) + gtest.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + gtest.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) + gtest.Empty(t, w.Header().Get("Access-Control-Expose-Headers")) + + // allowed CORS request + w = performRequest(router, "GET", "http://google.com") + gtest.Equal(t, "get", w.Body.String()) + gtest.Equal(t, "http://google.com", w.Header().Get("Access-Control-Allow-Origin")) + gtest.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials")) + + // deny CORS request + w = performRequest(router, "GET", "https://google.com") + // gtest.Equal(t, http.StatusForbidden, w.Code) + gtest.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + gtest.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) + gtest.Empty(t, w.Header().Get("Access-Control-Expose-Headers")) + + // allowed CORS prefligh request + w = performRequest(router, "OPTIONS", "http://google.com") + gtest.Equal(t, http.StatusNoContent, w.Code) + gtest.Equal(t, "http://google.com", w.Header().Get("Access-Control-Allow-Origin")) + gtest.Equal(t, "true", w.Header().Get("Access-Control-Allow-Credentials")) + // gtest.Equal(t, "GET,POST,PUT,HEAD", w.Header().Get("Access-Control-Allow-Methods")) + + // deny CORS prefligh request + w = performRequest(router, "OPTIONS", "http://example.com") + // gtest.Equal(t, http.StatusForbidden, w.Code) + gtest.Empty(t, w.Header().Get("Access-Control-Allow-Origin")) + gtest.Empty(t, w.Header().Get("Access-Control-Allow-Credentials")) + gtest.Empty(t, w.Header().Get("Access-Control-Allow-Methods")) + gtest.Empty(t, w.Header().Get("Access-Control-Allow-Headers")) + gtest.Empty(t, w.Header().Get("Access-Control-Max-Age")) +} diff --git a/middleware/recover_log.go b/middleware/recover_log.go new file mode 100644 index 0000000..413c639 --- /dev/null +++ b/middleware/recover_log.go @@ -0,0 +1,54 @@ +// +// recover_log.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package middleware + +import ( + "net" + "net/http" + "net/http/httputil" + "os" + "strings" + + "github.com/gin-gonic/gin" + "hexq.cn/tiglog/golib/logger" +) + +func GinRecover() gin.HandlerFunc { + var log = logger.Recover() + return func(c *gin.Context) { + defer func() { + if err := recover(); err != nil { + // Check for a broken connection, as it is not really a + // condition that warrants a panic stack trace. + var brokenPipe bool + if ne, ok := err.(*net.OpError); ok { + if se, ok := ne.Err.(*os.SyscallError); ok { + if strings.Contains(strings.ToLower(se.Error()), "broken pipe") || + strings.Contains(strings.ToLower(se.Error()), "connection reset by peer") { + brokenPipe = true + } + } + } + + httpRequest, _ := httputil.DumpRequest(c.Request, false) + if brokenPipe { + log.Error().Err(err.(error)).Str("request", string(httpRequest)).Msg(c.Request.URL.Path) + // If the connection is dead, we can't write a status to it. + c.Error(err.(error)) // nolint: errcheck + c.Abort() + return + } + var er = err.(error) + log.Error().Stack().Str("request", string(httpRequest)).Err(er).Msg("Recovery from panic") + c.AbortWithError(http.StatusInternalServerError, er) + } + }() + c.Next() + } + +} diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..de523c6 --- /dev/null +++ b/readme.md @@ -0,0 +1,4 @@ +说明 +==== + +lib of go web project. diff --git a/storage/adapter_local.go b/storage/adapter_local.go new file mode 100644 index 0000000..c8a1ce7 --- /dev/null +++ b/storage/adapter_local.go @@ -0,0 +1,61 @@ +// +// adapter_local.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package storage + +import ( + "fmt" + "io" + "mime/multipart" + "os" + "path/filepath" + "time" +) + +type LocalStorage struct { + rootDir string +} + +func NewLocalStorage(dir string) *LocalStorage { + return &LocalStorage{rootDir: dir} +} + +func (s *LocalStorage) GetName() string { + return "local" +} + +func (s *LocalStorage) Upload(upfile *multipart.FileHeader) (string, error) { + now := time.Now() + path := fmt.Sprintf("%d/%d/%d", now.Year(), now.Month(), now.Day()) + fp := filepath.Join(s.rootDir, path) + if _, err := os.Stat(fp); err != nil { + os.MkdirAll(fp, 0755) + } + outfp := filepath.Join(fp, upfile.Filename) + fd, err := os.Create(outfp) + if err != nil { + return "", err + } + defer fd.Close() + + ifd, err := upfile.Open() + if err != nil { + return "", err + } + defer ifd.Close() + + io.Copy(fd, ifd) + return filepath.Join(path, upfile.Filename), nil +} + +func (s *LocalStorage) GetBaseDir() string { + return s.rootDir +} + +func (s *LocalStorage) GetFullPath(path string) string { + return filepath.Join(s.GetBaseDir(), path) +} diff --git a/storage/readme.adoc b/storage/readme.adoc new file mode 100644 index 0000000..8923820 --- /dev/null +++ b/storage/readme.adoc @@ -0,0 +1,17 @@ += 存储 +:author: tiglog +:experimental: +:toc: left +:toclevels: 3 +:toc-title: 目录 +:sectnums: +:icons: font +:!webfonts: +:autofit-option: +:source-highlighter: rouge +:rouge-style: github +:source-linenums-option: +:revdate: 2022-11-27 +:imagesdir: ./img + +存储实现。 diff --git a/storage/storage.go b/storage/storage.go new file mode 100644 index 0000000..5843d18 --- /dev/null +++ b/storage/storage.go @@ -0,0 +1,16 @@ +// +// storage.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package storage + +type Engine struct { + IStorageAdapter +} + +func NewStorage(adapter IStorageAdapter) *Engine { + return &Engine{adapter} +} diff --git a/storage/storage_contract.go b/storage/storage_contract.go new file mode 100644 index 0000000..944b62e --- /dev/null +++ b/storage/storage_contract.go @@ -0,0 +1,17 @@ +// +// storage_contact.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package storage + +import "mime/multipart" + +type IStorageAdapter interface { + GetName() string + Upload(upfile *multipart.FileHeader) (string, error) + GetBaseDir() string + GetFullPath(path string) string +} diff --git a/storage/storage_local_test.go b/storage/storage_local_test.go new file mode 100644 index 0000000..a993c42 --- /dev/null +++ b/storage/storage_local_test.go @@ -0,0 +1,78 @@ +// +// storage_local_test.go +// Copyright (C) 2022 tiglog +// +// Distributed under terms of the MIT license. +// + +package storage_test + +import ( + "bytes" + "fmt" + "io" + "mime/multipart" + "net/http/httptest" + "os" + "path/filepath" + "strings" + "testing" + + "hexq.cn/tiglog/golib/gtest" + "hexq.cn/tiglog/golib/storage" +) + +func getLocalStorage() *storage.Engine { + ls := storage.NewLocalStorage("/tmp") + return storage.NewStorage(ls) +} + +func TestNewStorage(t *testing.T) { + s := getLocalStorage() + name := s.GetName() + gtest.Equal(t, "local", name) +} + +func getFileheader() (*multipart.FileHeader, error) { + path := "testdata/hello.txt" + file, err := os.Open(path) + if err != nil { + return nil, err + } + + defer file.Close() + body := &bytes.Buffer{} + writer := multipart.NewWriter(body) + part, err := writer.CreateFormFile("my_file", filepath.Base(path)) + if err != nil { + writer.Close() + return nil, err + } + io.Copy(part, file) + writer.Close() + + req := httptest.NewRequest("POST", "/upload", body) + req.Header.Set("Content-Type", writer.FormDataContentType()) + _, h, err := req.FormFile("my_file") + if err != nil { + return nil, err + } + return h, nil +} + +func TestUpload(t *testing.T) { + s := getLocalStorage() + upfile, err := getFileheader() + gtest.Nil(t, err) + + path, err2 := s.Upload(upfile) + gtest.Nil(t, err2) + fmt.Println("path is ", path) + gtest.NotEqual(t, "", path) + tmp := strings.Split(path, "/") + fmt.Println(tmp) + if len(tmp) > 1 { + dir := filepath.Join(s.GetBaseDir(), tmp[0]) + os.RemoveAll(dir) + } +} diff --git a/storage/testdata/hello.txt b/storage/testdata/hello.txt new file mode 100644 index 0000000..3b18e51 --- /dev/null +++ b/storage/testdata/hello.txt @@ -0,0 +1 @@ +hello world