From 71ab0f62a77d59157bec0b7215e17b8bafed9095 Mon Sep 17 00:00:00 2001
From: Lukas Braun <lukas.braun@fau.de>
Date: Mon, 6 Mar 2017 22:43:07 +0100
Subject: [PATCH] Add support for PostgreSQL

The transaction_conflict test in its previous form would deadlock
because PostgreSQL locks tables/rows instead of aborting on every
conflict. The fix is to force a rollback by first reading and then
updating in both transactions, which would result in a lost update
otherwise.
---
 goatherd.go      | 42 ++++++++++++++++++++++++++++++++----------
 goatherd_test.go | 22 ++++++++++++++++++----
 2 files changed, 50 insertions(+), 14 deletions(-)

diff --git a/goatherd.go b/goatherd.go
index da00c55..e0d20bb 100644
--- a/goatherd.go
+++ b/goatherd.go
@@ -16,6 +16,7 @@ import (
     "time"
 
     "github.com/gokyle/hotp"
+    "github.com/lib/pq"
     "github.com/mattn/go-sqlite3"
 )
 
@@ -45,6 +46,7 @@ type tls_cfg struct {
     Key, Cert string
 }
 var cfg struct {
+    Db_driver string
     Db_url string
     Lookahead uint64
     Debug bool
@@ -112,10 +114,14 @@ type db_conn interface {
 func create_table(db *sql.DB) {
     debug("Creating table 'users' in DB")
 
+    secret_type := "BLOB"
+    if cfg.Db_driver == "postgres" {
+        secret_type = "BYTEA"
+    }
     _, err := db.Exec(`
         CREATE TABLE users(
             name TEXT NOT NULL PRIMARY KEY,
-            secret BLOB NOT NULL,
+            secret ` + secret_type + ` NOT NULL,
             count INTEGER)
         `)
     err_fatalf("Failed to create table: %v\n", err)
@@ -141,18 +147,25 @@ func create_user(db *sql.DB, name string, secret_b64 string) {
     secret, err := base64.StdEncoding.DecodeString(secret_b64)
     err_fatalf("Can't decode secret: %v\n", err)
 
+    var q string
+    if cfg.Db_driver == "sqlite3" {
+        q = "REPLACE INTO users(name, secret, count) values($1, $2, $3)"
+    } else if cfg.Db_driver == "postgres" {
+        q = `INSERT INTO users(name, secret, count) values($1, $2, $3)
+            ON CONFLICT (name) DO UPDATE SET secret = $2, count = $3`
+    } else {
+        log.Panic("Unkown driver: ", cfg.Db_driver)
+    }
+
     debug("Adding user with name", name)
-    _, err = db.Exec(`
-        REPLACE INTO users(name, secret, count) values($1, $2, $3)`,
-        // ON CONFLICT (name) DO UPDATE SET secret = $2, count = $3`,
-        name, secret, 1)
+    _, err = db.Exec(q, name, secret, 1)
     err_fatalf("Failed to create user: %v\n", err)
 }
 
 
 
 func get_user(db db_conn, name string) (secret []byte, count uint64, err error) {
-    stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = ?")
+    stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = $1")
     if err != nil {
         return
     }
@@ -164,7 +177,7 @@ func get_user(db db_conn, name string) (secret []byte, count uint64, err error)
 }
 
 func set_count(db db_conn, name string, count uint64) error {
-    inc, err := db.Prepare("UPDATE users SET count = ? WHERE name = ?")
+    inc, err := db.Prepare("UPDATE users SET count = $1 WHERE name = $2")
     if err != nil {
         return err
     }
@@ -182,10 +195,18 @@ func get_otp(db db_conn, name string) (*hotp.HOTP, error) {
 }
 
 func transaction_failed(err error) bool {
-    if err, ok := err.(sqlite3.Error); ok {
+    if err == nil {
+        return false
+    }
+
+    switch err := err.(type) {
+    case sqlite3.Error:
         return err.Code == sqlite3.ErrLocked || err.Code == sqlite3.ErrBusy
+    case *pq.Error:
+        return err.Code == "40001" // serialization_failure
+    default:
+        return false
     }
-    return false
 }
 
 // Retrieve secret and count for given username and try to find a match within
@@ -399,6 +420,7 @@ func main() {
 
     // also settable in config file
     flag.StringVar(&cfg.Db_url, "db-url", ":memory:", "URL used to connect to the database.")
+    flag.StringVar(&cfg.Db_driver, "db-driver", "sqlite3", "Name of the database driver.")
     flag.Uint64Var(&cfg.Lookahead, "lookahead", 10, "Counter range to check for matching OTPs.")
     flag.BoolVar(&cfg.Debug, "debug", false, "Enable debug output.")
     flag.DurationVar(&cfg.Faildelay.Duration, "faildelay", 1*time.Second,
@@ -416,7 +438,7 @@ func main() {
     // are not set again
     flag.Parse()
 
-    db, err := sql.Open("sqlite3", cfg.Db_url)
+    db, err := sql.Open(cfg.Db_driver, cfg.Db_url)
     err_fatal(err)
 
     // default action is to serve, but not if one of the other actions is given
diff --git a/goatherd_test.go b/goatherd_test.go
index c1ddc74..5ce6225 100644
--- a/goatherd_test.go
+++ b/goatherd_test.go
@@ -7,6 +7,7 @@ import (
     "encoding/base64"
     "fmt"
     "io"
+    "os"
     "sync"
     "testing"
 
@@ -80,15 +81,21 @@ func transaction_conflict_t(t *testing.T) {
         defer func(i int) { t.Log("rollback", i); txs[i].Rollback() }(i)
     }
 
-    for i, tx := range txs {
-        t.Log("update", i)
-        _, err = tx.Exec("UPDATE users SET count = ? WHERE name == ?", uint64(i), username)
+    var c uint64
+    for _, tx := range txs {
+        err = tx.QueryRow("SELECT count FROM users WHERE name = $1", username).Scan(&c)
         if transaction_failed(err) {
             return
         } else { t_err_fatal(t, err) }
     }
 
     for i, tx := range txs {
+        t.Log("update", i)
+        _, err := tx.Exec("UPDATE users SET count = $1 WHERE name = $2", c + uint64(i), username)
+        if transaction_failed(err) {
+            return
+        } else { t_err_fatal(t, err) }
+
         t.Log("commit", i)
         err = tx.Commit()
         if transaction_failed(err) {
@@ -292,6 +299,13 @@ func TestMain(t *testing.T) {
     // multiple connections only talk to the same in-memory database if a shared
     // cache is used, otherwise a new database is created
     cfg.Db_url = "file::memory:?cache=shared"
+    if db_url, ok := os.LookupEnv("DB_URL"); ok {
+        cfg.Db_url = db_url
+    }
+    cfg.Db_driver = "sqlite3"
+    if db_driver, ok := os.LookupEnv("DB_DRIVER"); ok {
+        cfg.Db_driver = db_driver
+    }
     cfg.Lookahead = 3
     cfg.Debug = false
     faildelay.userlocks = make(map[string]*sync.Mutex)
@@ -299,7 +313,7 @@ func TestMain(t *testing.T) {
     stdin_scanner, stdin_writer = bufio.NewScanner(stdin_r), bufio.NewWriter(stdin_w)
 
     var err error
-    db, err = sql.Open("sqlite3", cfg.Db_url)
+    db, err = sql.Open(cfg.Db_driver, cfg.Db_url)
     t_err_fatal(t, err)
     create_table(db)
 
-- 
GitLab