diff --git a/goatherd.go b/goatherd.go index da00c55f9c7608b3353f796f6a8aaa3929107788..e0d20bbf70d7224cf36c6055096c6c77a78f96b6 100644 --- a/goatherd.go +++ b/goatherd.go @@ -16,6 +16,7 @@ import ( "time" "github.com/gokyle/hotp" + "github.com/lib/pq" "github.com/mattn/go-sqlite3" ) @@ -45,6 +46,7 @@ type tls_cfg struct { Key, Cert string } var cfg struct { + Db_driver string Db_url string Lookahead uint64 Debug bool @@ -112,10 +114,14 @@ type db_conn interface { func create_table(db *sql.DB) { debug("Creating table 'users' in DB") + secret_type := "BLOB" + if cfg.Db_driver == "postgres" { + secret_type = "BYTEA" + } _, err := db.Exec(` CREATE TABLE users( name TEXT NOT NULL PRIMARY KEY, - secret BLOB NOT NULL, + secret ` + secret_type + ` NOT NULL, count INTEGER) `) err_fatalf("Failed to create table: %v\n", err) @@ -141,18 +147,25 @@ func create_user(db *sql.DB, name string, secret_b64 string) { secret, err := base64.StdEncoding.DecodeString(secret_b64) err_fatalf("Can't decode secret: %v\n", err) + var q string + if cfg.Db_driver == "sqlite3" { + q = "REPLACE INTO users(name, secret, count) values($1, $2, $3)" + } else if cfg.Db_driver == "postgres" { + q = `INSERT INTO users(name, secret, count) values($1, $2, $3) + ON CONFLICT (name) DO UPDATE SET secret = $2, count = $3` + } else { + log.Panic("Unkown driver: ", cfg.Db_driver) + } + debug("Adding user with name", name) - _, err = db.Exec(` - REPLACE INTO users(name, secret, count) values($1, $2, $3)`, - // ON CONFLICT (name) DO UPDATE SET secret = $2, count = $3`, - name, secret, 1) + _, err = db.Exec(q, name, secret, 1) err_fatalf("Failed to create user: %v\n", err) } func get_user(db db_conn, name string) (secret []byte, count uint64, err error) { - stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = ?") + stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = $1") if err != nil { return } @@ -164,7 +177,7 @@ func get_user(db db_conn, name string) (secret []byte, count uint64, err error) } func set_count(db db_conn, name string, count uint64) error { - inc, err := db.Prepare("UPDATE users SET count = ? WHERE name = ?") + inc, err := db.Prepare("UPDATE users SET count = $1 WHERE name = $2") if err != nil { return err } @@ -182,10 +195,18 @@ func get_otp(db db_conn, name string) (*hotp.HOTP, error) { } func transaction_failed(err error) bool { - if err, ok := err.(sqlite3.Error); ok { + if err == nil { + return false + } + + switch err := err.(type) { + case sqlite3.Error: return err.Code == sqlite3.ErrLocked || err.Code == sqlite3.ErrBusy + case *pq.Error: + return err.Code == "40001" // serialization_failure + default: + return false } - return false } // Retrieve secret and count for given username and try to find a match within @@ -399,6 +420,7 @@ func main() { // also settable in config file flag.StringVar(&cfg.Db_url, "db-url", ":memory:", "URL used to connect to the database.") + flag.StringVar(&cfg.Db_driver, "db-driver", "sqlite3", "Name of the database driver.") flag.Uint64Var(&cfg.Lookahead, "lookahead", 10, "Counter range to check for matching OTPs.") flag.BoolVar(&cfg.Debug, "debug", false, "Enable debug output.") flag.DurationVar(&cfg.Faildelay.Duration, "faildelay", 1*time.Second, @@ -416,7 +438,7 @@ func main() { // are not set again flag.Parse() - db, err := sql.Open("sqlite3", cfg.Db_url) + db, err := sql.Open(cfg.Db_driver, cfg.Db_url) err_fatal(err) // default action is to serve, but not if one of the other actions is given diff --git a/goatherd_test.go b/goatherd_test.go index c1ddc74d0a4d8fdc05981f0b5f70f36d5287e101..5ce6225c8cef407065b137ae6fee08f3bc623741 100644 --- a/goatherd_test.go +++ b/goatherd_test.go @@ -7,6 +7,7 @@ import ( "encoding/base64" "fmt" "io" + "os" "sync" "testing" @@ -80,15 +81,21 @@ func transaction_conflict_t(t *testing.T) { defer func(i int) { t.Log("rollback", i); txs[i].Rollback() }(i) } - for i, tx := range txs { - t.Log("update", i) - _, err = tx.Exec("UPDATE users SET count = ? WHERE name == ?", uint64(i), username) + var c uint64 + for _, tx := range txs { + err = tx.QueryRow("SELECT count FROM users WHERE name = $1", username).Scan(&c) if transaction_failed(err) { return } else { t_err_fatal(t, err) } } for i, tx := range txs { + t.Log("update", i) + _, err := tx.Exec("UPDATE users SET count = $1 WHERE name = $2", c + uint64(i), username) + if transaction_failed(err) { + return + } else { t_err_fatal(t, err) } + t.Log("commit", i) err = tx.Commit() if transaction_failed(err) { @@ -292,6 +299,13 @@ func TestMain(t *testing.T) { // multiple connections only talk to the same in-memory database if a shared // cache is used, otherwise a new database is created cfg.Db_url = "file::memory:?cache=shared" + if db_url, ok := os.LookupEnv("DB_URL"); ok { + cfg.Db_url = db_url + } + cfg.Db_driver = "sqlite3" + if db_driver, ok := os.LookupEnv("DB_DRIVER"); ok { + cfg.Db_driver = db_driver + } cfg.Lookahead = 3 cfg.Debug = false faildelay.userlocks = make(map[string]*sync.Mutex) @@ -299,7 +313,7 @@ func TestMain(t *testing.T) { stdin_scanner, stdin_writer = bufio.NewScanner(stdin_r), bufio.NewWriter(stdin_w) var err error - db, err = sql.Open("sqlite3", cfg.Db_url) + db, err = sql.Open(cfg.Db_driver, cfg.Db_url) t_err_fatal(t, err) create_table(db)