From bc7f25aab05c132baffe0f9f33e4621bea40712b Mon Sep 17 00:00:00 2001 From: Lukas Braun <lukas.braun@fau.de> Date: Tue, 7 Mar 2017 01:14:11 +0100 Subject: [PATCH] Postgres: set default isolation level serializable In the process, make the database connection object global because we have to reconnect in order for the database settings to take effect and thus have to reassign the variable. --- goatherd.go | 53 ++++++++++++++++++++++++++++++++++++------------ goatherd_test.go | 19 ++++++++--------- 2 files changed, 49 insertions(+), 23 deletions(-) diff --git a/goatherd.go b/goatherd.go index e0d20bb..eb05e9d 100644 --- a/goatherd.go +++ b/goatherd.go @@ -66,6 +66,8 @@ var faildelay struct { userlocks map[string]*sync.Mutex } +var db *sql.DB + func debug(v ...interface{}) { if cfg.Debug { log.Println(v...) } @@ -111,7 +113,7 @@ type db_conn interface { } -func create_table(db *sql.DB) { +func create_table() { debug("Creating table 'users' in DB") secret_type := "BLOB" @@ -127,9 +129,34 @@ func create_table(db *sql.DB) { err_fatalf("Failed to create table: %v\n", err) } +func pg_set_default_isolation() { + var dbname string + err := db.QueryRow("SELECT current_database()").Scan(&dbname) + err_fatalf("SELECT current_database(): %v\n", err) + + _, err = db.Exec(` + ALTER DATABASE ` + dbname + ` + SET default_transaction_isolation TO "serializable" + `) + err_fatalf("Failed to set default_transaction_isolation for DB: %v\n", err) + + // database settings only take effect on future sessions -> reconnect + err_fatal(db.Close()) + db, err = sql.Open(cfg.Db_driver, cfg.Db_url) + err_fatal(err) +} + +func init_db() { + if cfg.Db_driver == "postgres" { + pg_set_default_isolation() + } + + create_table() +} + // global var for testing, not modified during normal execution var stdin_scanner = bufio.NewScanner(os.Stdin) -func create_user(db *sql.DB, name string, secret_b64 string) { +func create_user(name string, secret_b64 string) { debug("Creating user") var err error @@ -212,7 +239,7 @@ func transaction_failed(err error) bool { // Retrieve secret and count for given username and try to find a match within // the lookahead range. Update count in DB if match is found. All within a // transaction, retrying if it fails. -func check_offer(db *sql.DB, remote string, name string, offer string) (bool, error) { +func check_offer(remote string, name string, offer string) (bool, error) { for { debugf("[%v] begin transaction", remote) tx, err := db.Begin() @@ -264,7 +291,7 @@ retry: } } -func handle_conn(db *sql.DB, remote string, reader *bufio.Reader, +func handle_conn(remote string, reader *bufio.Reader, writer *bufio.Writer) (delay *sync.Mutex) { s := bufio.NewScanner(reader) b := make([]byte, 80) @@ -318,7 +345,7 @@ func handle_conn(db *sql.DB, remote string, reader *bufio.Reader, debugf("[%v] checking for match", remote) result := "FAIL" - match, err := check_offer(db, remote, name, offer) + match, err := check_offer(remote, name, offer) if err != nil { log.Panic(err) } else if match { @@ -347,7 +374,7 @@ func handle_conn(db *sql.DB, remote string, reader *bufio.Reader, return } -func listen(db *sql.DB, wg sync.WaitGroup, listener net.Listener) { +func listen(wg sync.WaitGroup, listener net.Listener) { defer wg.Done() for { @@ -362,7 +389,7 @@ func listen(db *sql.DB, wg sync.WaitGroup, listener net.Listener) { reader := bufio.NewReader(conn) writer := bufio.NewWriter(conn) - delay := handle_conn(db, remote, reader, writer) + delay := handle_conn(remote, reader, writer) debugf("[%v] closing", remote) conn.Close() // XXX: check err? @@ -376,7 +403,7 @@ func listen(db *sql.DB, wg sync.WaitGroup, listener net.Listener) { } } -func serve(db *sql.DB) { +func serve() { faildelay.userlocks = make(map[string]*sync.Mutex) var wg sync.WaitGroup @@ -400,7 +427,7 @@ func serve(db *sql.DB) { log.Println("Listening on", listen_addr) wg.Add(1) - go listen(db, wg, listener) + go listen(wg, listener) } wg.Wait() @@ -438,7 +465,7 @@ func main() { // are not set again flag.Parse() - db, err := sql.Open(cfg.Db_driver, cfg.Db_url) + db, err = sql.Open(cfg.Db_driver, cfg.Db_url) err_fatal(err) // default action is to serve, but not if one of the other actions is given @@ -462,15 +489,15 @@ func main() { if *flag_init_db { serve_default = false - create_table(db) + init_db() } if *flag_add_user != "" { serve_default = false - create_user(db, *flag_add_user, *flag_secret) + create_user(*flag_add_user, *flag_secret) } if serve_default || *flag_serve { - serve(db) + serve() } } diff --git a/goatherd_test.go b/goatherd_test.go index 5ce6225..c954559 100644 --- a/goatherd_test.go +++ b/goatherd_test.go @@ -14,7 +14,6 @@ import ( "github.com/gokyle/hotp" ) -var db *sql.DB var stdin_writer *bufio.Writer @@ -33,7 +32,7 @@ func t_err_fatal(t *testing.T, err error) { // always uses global username and secret for comparison, so we can test "-" func create_and_check(user string, sec string) func(*testing.T) { return func(t *testing.T) { - create_user(db, user, sec) + create_user(user, sec) var result struct { secret []byte @@ -109,7 +108,7 @@ func transaction_conflict_t(t *testing.T) { func check_offer_t(t *testing.T) { t.Run("no_such_user", func(t *testing.T) { - _, err := check_offer(db, "mock", "no such user", "dummy offer") + _, err := check_offer("mock", "no such user", "dummy offer") if err != sql.ErrNoRows { t.Error("err:", nil) } }) @@ -117,32 +116,32 @@ func check_offer_t(t *testing.T) { t_err_fatal(t, err) t.Run("fail", func(t *testing.T) { - ok, err := check_offer(db, "mock", username, "dummy offer") + ok, err := check_offer("mock", username, "dummy offer") t_err_fatal(t, err) if ok { t.Fail() } t.Run("too_far_out", func(t *testing.T) { ahead := hotp.NewHOTP(secret, count + cfg.Lookahead + 1, otpLen) - ok, err = check_offer(db, "mock", username, ahead.OTP()) + ok, err = check_offer("mock", username, ahead.OTP()) if ok { t.Fail() } }) }) t.Run("ok", func(t *testing.T) { cur := hotp.NewHOTP(secret, count, otpLen) - ok, err := check_offer(db, "mock", username, cur.OTP()) + ok, err := check_offer("mock", username, cur.OTP()) t_err_fatal(t, err) if !ok { t.Fail() } cur.Increment() t.Run("incremented", func(t *testing.T) { - ok, err := check_offer(db, "mock", username, cur.OTP()) + ok, err := check_offer("mock", username, cur.OTP()) t_err_fatal(t, err) if !ok { t.Fail() } }) t.Run("lookahead", func(t *testing.T) { - ok, err := check_offer(db, "mock", username, + ok, err := check_offer("mock", username, hotp.NewHOTP(secret, cur.Counter() + cfg.Lookahead, otpLen).OTP()) t_err_fatal(t, err) if !ok { t.Fail() } @@ -163,7 +162,7 @@ func mock_conn(client client_t) (delay *sync.Mutex) { go client(client_r, client_w) - delay = handle_conn(db, "handle_conn_t", server_r_buf, server_w_buf) + delay = handle_conn("handle_conn_t", server_r_buf, server_w_buf) return } @@ -315,7 +314,7 @@ func TestMain(t *testing.T) { var err error db, err = sql.Open(cfg.Db_driver, cfg.Db_url) t_err_fatal(t, err) - create_table(db) + init_db() t.Run("create_user", create_user_t) t.Run("transaction_conflict_handling", transaction_conflict_t) -- GitLab