290 lines
6.0 KiB
Go
290 lines
6.0 KiB
Go
package postgresql
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"sort"
|
|
"strings"
|
|
"time"
|
|
"unicode"
|
|
)
|
|
|
|
// scanner is a rune-by-rune cursor over a libpq-style option string.
type scanner struct {
	s []rune // text being tokenized, decoded into runes
	i int    // index into s of the next rune to hand out
}

// Next consumes and returns the rune under the cursor. Once the text is
// exhausted it returns 0, false and leaves the cursor where it is.
func (sc *scanner) Next() (rune, bool) {
	if sc.i < len(sc.s) {
		next := sc.s[sc.i]
		sc.i++
		return next, true
	}
	return 0, false
}

// SkipSpaces consumes runes until it reaches one that unicode.IsSpace does not
// classify as whitespace, and returns that rune. It returns 0, false if the
// text ends before such a rune is found.
func (sc *scanner) SkipSpaces() (rune, bool) {
	for {
		next, more := sc.Next()
		if !more || !unicode.IsSpace(next) {
			return next, more
		}
	}
}
|
|
|
|
// values is a string-keyed map of option names to option values.
type values map[string]string

// Set stores v under the key k, replacing any value already there.
func (m values) Set(k, v string) {
	m[k] = v
}

// Get returns the value stored under k, or the empty string if k is absent.
func (m values) Get(k string) (v string) {
	v = m[k]
	return v
}

// Isset reports whether k is present in the map, even when its value is empty.
func (m values) Isset(k string) bool {
	_, found := m[k]
	return found
}
|
|
|
|
// ConnectionURL holds the individual parts of a PostgreSQL connection URL.
//
// A ConnectionURL value can be handed to Open directly:
//
//	var settings = postgresql.ConnectionURL{
//		Host:     "localhost", // PostgreSQL server IP or name.
//		Database: "peanuts",   // Database name.
//		User:     "cbrown",    // Optional user name.
//		Password: "snoopy",    // Optional user password.
//	}
//
//	sess, err = postgresql.Open(settings)
//
// When a valid DSN is already at hand, ParseURL turns it into a ConnectionURL
// that can then be passed to Open.
type ConnectionURL struct {
	// User is the optional user name.
	User string
	// Password is the optional user password.
	Password string
	// Host is the PostgreSQL server IP or name; ParseURL appends ":port" to
	// it when the DSN carries a port.
	Host string
	// Socket is filled in by ParseURL, instead of Host, when the DSN host
	// starts with "/".
	Socket string
	// Database is the database name.
	Database string
	// Options holds the remaining connection options by name.
	Options map[string]string

	// timezone is set by ParseURL from the "timezone" option when that name
	// can be loaded as a time.Location.
	timezone *time.Location
}
|
|
|
|
// escaper prefixes spaces, single quotes and backslashes with a backslash.
var escaper = strings.NewReplacer(
	` `, `\ `,
	`'`, `\'`,
	`\`, `\\`,
)
|
|
|
|
// ParseURL parses the given DSN into a ConnectionURL struct.
|
|
// A typical PostgreSQL connection URL looks like:
|
|
//
|
|
// postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full
|
|
func ParseURL(s string) (u *ConnectionURL, err error) {
|
|
o := make(values)
|
|
|
|
if strings.HasPrefix(s, "postgres://") || strings.HasPrefix(s, "postgresql://") {
|
|
s, err = parseURL(s)
|
|
if err != nil {
|
|
return u, err
|
|
}
|
|
}
|
|
|
|
if err := parseOpts(s, o); err != nil {
|
|
return u, err
|
|
}
|
|
u = &ConnectionURL{}
|
|
|
|
u.User = o.Get("user")
|
|
u.Password = o.Get("password")
|
|
|
|
h := o.Get("host")
|
|
p := o.Get("port")
|
|
|
|
if strings.HasPrefix(h, "/") {
|
|
u.Socket = h
|
|
} else {
|
|
if p == "" {
|
|
u.Host = h
|
|
} else {
|
|
u.Host = fmt.Sprintf("%s:%s", h, p)
|
|
}
|
|
}
|
|
|
|
u.Database = o.Get("dbname")
|
|
|
|
u.Options = make(map[string]string)
|
|
|
|
for k := range o {
|
|
switch k {
|
|
case "user", "password", "host", "port", "dbname":
|
|
// Skip
|
|
default:
|
|
u.Options[k] = o[k]
|
|
}
|
|
}
|
|
|
|
if timezone, ok := u.Options["timezone"]; ok {
|
|
u.timezone, _ = time.LoadLocation(timezone)
|
|
}
|
|
|
|
return u, err
|
|
}
|
|
|
|
// parseOpts parses the options from name and adds them to the values.
|
|
//
|
|
// The parsing code is based on conninfo_parse from libpq's fe-connect.c
|
|
func parseOpts(name string, o values) error {
|
|
s := newScanner(name)
|
|
|
|
for {
|
|
var (
|
|
keyRunes, valRunes []rune
|
|
r rune
|
|
ok bool
|
|
)
|
|
|
|
if r, ok = s.SkipSpaces(); !ok {
|
|
break
|
|
}
|
|
|
|
// Scan the key
|
|
for !unicode.IsSpace(r) && r != '=' {
|
|
keyRunes = append(keyRunes, r)
|
|
if r, ok = s.Next(); !ok {
|
|
break
|
|
}
|
|
}
|
|
|
|
// Skip any whitespace if we're not at the = yet
|
|
if r != '=' {
|
|
r, ok = s.SkipSpaces()
|
|
}
|
|
|
|
// The current character should be =
|
|
if r != '=' || !ok {
|
|
return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
|
|
}
|
|
|
|
// Skip any whitespace after the =
|
|
if r, ok = s.SkipSpaces(); !ok {
|
|
// If we reach the end here, the last value is just an empty string as per libpq.
|
|
o.Set(string(keyRunes), "")
|
|
break
|
|
}
|
|
|
|
if r != '\'' {
|
|
for !unicode.IsSpace(r) {
|
|
if r == '\\' {
|
|
if r, ok = s.Next(); !ok {
|
|
return fmt.Errorf(`missing character after backslash`)
|
|
}
|
|
}
|
|
valRunes = append(valRunes, r)
|
|
|
|
if r, ok = s.Next(); !ok {
|
|
break
|
|
}
|
|
}
|
|
} else {
|
|
quote:
|
|
for {
|
|
if r, ok = s.Next(); !ok {
|
|
return fmt.Errorf(`unterminated quoted string literal in connection string`)
|
|
}
|
|
switch r {
|
|
case '\'':
|
|
break quote
|
|
case '\\':
|
|
r, _ = s.Next()
|
|
fallthrough
|
|
default:
|
|
valRunes = append(valRunes, r)
|
|
}
|
|
}
|
|
}
|
|
|
|
o.Set(string(keyRunes), string(valRunes))
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// newScanner returns a new scanner initialized with the option string s.
|
|
func newScanner(s string) *scanner {
|
|
return &scanner{[]rune(s), 0}
|
|
}
|
|
|
|
// parseURL converts a PostgreSQL connection URL into a libpq-style
// "key=value" connection string.
//
// Example:
//
//	"postgres://bob:secret@1.2.3.4:5432/mydb?sslmode=verify-full"
//
// converts to:
//
//	"dbname=mydb host=1.2.3.4 password=secret port=5432 sslmode=verify-full user=bob"
//
// A minimal example:
//
//	"postgres://"
//
// converts to the empty string.
//
// The user, password, host, port and path of the URL become the user,
// password, host, port and dbname keys; every query parameter is passed
// through under its own name, using its first value. Pairs whose value is
// empty are omitted, and the pairs are sorted so the output is deterministic.
//
// Whitespace, single quotes and backslashes inside a value are prefixed with
// a backslash, so that a parser which splits pairs on whitespace (as parseOpts
// does) reads the value back unchanged.
//
// It returns an error if uri cannot be parsed as a URL or if its scheme is
// neither "postgres" nor "postgresql".
//
// NOTE: vendored/copied from github.com/lib/pq
func parseURL(uri string) (string, error) {
	u, err := url.Parse(uri)
	if err != nil {
		return "", err
	}

	if u.Scheme != "postgres" && u.Scheme != "postgresql" {
		return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
	}

	// escape backslash-escapes every rune that would otherwise end or alter
	// an unquoted value: any whitespace (not only the ASCII space), the
	// single quote and the backslash itself. Untouched stretches of v are
	// copied as they are.
	escape := func(v string) string {
		var b strings.Builder
		b.Grow(len(v))
		last := 0
		for i, r := range v {
			if r == '\'' || r == '\\' || unicode.IsSpace(r) {
				b.WriteString(v[last:i])
				b.WriteByte('\\')
				last = i
			}
		}
		b.WriteString(v[last:])
		return b.String()
	}

	var kvs []string
	accrue := func(k, v string) {
		if v != "" {
			kvs = append(kvs, k+"="+escape(v))
		}
	}

	if u.User != nil {
		accrue("user", u.User.Username())

		// An absent password is the empty string, which accrue drops.
		password, _ := u.User.Password()
		accrue("password", password)
	}

	if host, port, err := net.SplitHostPort(u.Host); err != nil {
		// No port to split off: the whole authority is the host.
		accrue("host", u.Host)
	} else {
		accrue("host", host)
		accrue("port", port)
	}

	if u.Path != "" {
		// Drop the leading "/" of the URL path.
		accrue("dbname", u.Path[1:])
	}

	q := u.Query()
	for k := range q {
		accrue(k, q.Get(k))
	}

	sort.Strings(kvs) // Makes testing easier (not a performance concern)
	return strings.Join(kvs, " "), nil
}
|