diff --git a/goatherd.go b/goatherd.go index e0d20bbf70d7224cf36c6055096c6c77a78f96b6..eb05e9da669155a84ae85ebdb07a7bbae65f054b 100644 --- a/goatherd.go +++ b/goatherd.go @@ -66,6 +66,8 @@ var faildelay struct { userlocks map[string]*sync.Mutex } +var db *sql.DB + func debug(v ...interface{}) { if cfg.Debug { log.Println(v...) } @@ -111,7 +113,7 @@ type db_conn interface { } -func create_table(db *sql.DB) { +func create_table() { debug("Creating table 'users' in DB") secret_type := "BLOB" @@ -127,9 +129,34 @@ func create_table(db *sql.DB) { err_fatalf("Failed to create table: %v\n", err) } +func pg_set_default_isolation() { + var dbname string + err := db.QueryRow("SELECT current_database()").Scan(&dbname) + err_fatalf("SELECT current_database(): %v\n", err) + + _, err = db.Exec(` + ALTER DATABASE ` + dbname + ` + SET default_transaction_isolation TO "serializable" + `) + err_fatalf("Failed to set default_transaction_isolation for DB: %v\n", err) + + // database settings only take effect on future sessions -> reconnect + err_fatal(db.Close()) + db, err = sql.Open(cfg.Db_driver, cfg.Db_url) + err_fatal(err) +} + +func init_db() { + if cfg.Db_driver == "postgres" { + pg_set_default_isolation() + } + + create_table() +} + // global var for testing, not modified during normal execution var stdin_scanner = bufio.NewScanner(os.Stdin) -func create_user(db *sql.DB, name string, secret_b64 string) { +func create_user(name string, secret_b64 string) { debug("Creating user") var err error @@ -212,7 +239,7 @@ func transaction_failed(err error) bool { // Retrieve secret and count for given username and try to find a match within // the lookahead range. Update count in DB if match is found. All within a // transaction, retrying if it fails. -func check_offer(db *sql.DB, remote string, name string, offer string) (bool, error) { +func check_offer(remote string, name string, offer string) (bool, error) { for { debugf("[%v] begin transaction", remote) tx, err := db.Begin() @@ -264,7 +291,7 @@ retry: } } -func handle_conn(db *sql.DB, remote string, reader *bufio.Reader, +func handle_conn(remote string, reader *bufio.Reader, writer *bufio.Writer) (delay *sync.Mutex) { s := bufio.NewScanner(reader) b := make([]byte, 80) @@ -318,7 +345,7 @@ func handle_conn(db *sql.DB, remote string, reader *bufio.Reader, debugf("[%v] checking for match", remote) result := "FAIL" - match, err := check_offer(db, remote, name, offer) + match, err := check_offer(remote, name, offer) if err != nil { log.Panic(err) } else if match { @@ -347,7 +374,7 @@ func handle_conn(db *sql.DB, remote string, reader *bufio.Reader, return } -func listen(db *sql.DB, wg sync.WaitGroup, listener net.Listener) { +func listen(wg sync.WaitGroup, listener net.Listener) { defer wg.Done() for { @@ -362,7 +389,7 @@ func listen(db *sql.DB, wg sync.WaitGroup, listener net.Listener) { reader := bufio.NewReader(conn) writer := bufio.NewWriter(conn) - delay := handle_conn(db, remote, reader, writer) + delay := handle_conn(remote, reader, writer) debugf("[%v] closing", remote) conn.Close() // XXX: check err? @@ -376,7 +403,7 @@ func listen(db *sql.DB, wg sync.WaitGroup, listener net.Listener) { } } -func serve(db *sql.DB) { +func serve() { faildelay.userlocks = make(map[string]*sync.Mutex) var wg sync.WaitGroup @@ -400,7 +427,7 @@ func serve(db *sql.DB) { log.Println("Listening on", listen_addr) wg.Add(1) - go listen(db, wg, listener) + go listen(wg, listener) } wg.Wait() @@ -438,7 +465,7 @@ func main() { // are not set again flag.Parse() - db, err := sql.Open(cfg.Db_driver, cfg.Db_url) + db, err = sql.Open(cfg.Db_driver, cfg.Db_url) err_fatal(err) // default action is to serve, but not if one of the other actions is given @@ -462,15 +489,15 @@ func main() { if *flag_init_db { serve_default = false - create_table(db) + init_db() } if *flag_add_user != "" { serve_default = false - create_user(db, *flag_add_user, *flag_secret) + create_user(*flag_add_user, *flag_secret) } if serve_default || *flag_serve { - serve(db) + serve() } } diff --git a/goatherd_test.go b/goatherd_test.go index 5ce6225c8cef407065b137ae6fee08f3bc623741..c954559830e2e6db493035421e21d9e812b1c9ae 100644 --- a/goatherd_test.go +++ b/goatherd_test.go @@ -14,7 +14,6 @@ import ( "github.com/gokyle/hotp" ) -var db *sql.DB var stdin_writer *bufio.Writer @@ -33,7 +32,7 @@ func t_err_fatal(t *testing.T, err error) { // always uses global username and secret for comparison, so we can test "-" func create_and_check(user string, sec string) func(*testing.T) { return func(t *testing.T) { - create_user(db, user, sec) + create_user(user, sec) var result struct { secret []byte @@ -109,7 +108,7 @@ func transaction_conflict_t(t *testing.T) { func check_offer_t(t *testing.T) { t.Run("no_such_user", func(t *testing.T) { - _, err := check_offer(db, "mock", "no such user", "dummy offer") + _, err := check_offer("mock", "no such user", "dummy offer") if err != sql.ErrNoRows { t.Error("err:", nil) } }) @@ -117,32 +116,32 @@ func check_offer_t(t *testing.T) { t_err_fatal(t, err) t.Run("fail", func(t *testing.T) { - ok, err := check_offer(db, "mock", username, "dummy offer") + ok, err := check_offer("mock", username, "dummy offer") t_err_fatal(t, err) if ok { t.Fail() } t.Run("too_far_out", func(t *testing.T) { ahead := hotp.NewHOTP(secret, count + cfg.Lookahead + 1, otpLen) - ok, err = check_offer(db, "mock", username, ahead.OTP()) + ok, err = check_offer("mock", username, ahead.OTP()) if ok { t.Fail() } }) }) t.Run("ok", func(t *testing.T) { cur := hotp.NewHOTP(secret, count, otpLen) - ok, err := check_offer(db, "mock", username, cur.OTP()) + ok, err := check_offer("mock", username, cur.OTP()) t_err_fatal(t, err) if !ok { t.Fail() } cur.Increment() t.Run("incremented", func(t *testing.T) { - ok, err := check_offer(db, "mock", username, cur.OTP()) + ok, err := check_offer("mock", username, cur.OTP()) t_err_fatal(t, err) if !ok { t.Fail() } }) t.Run("lookahead", func(t *testing.T) { - ok, err := check_offer(db, "mock", username, + ok, err := check_offer("mock", username, hotp.NewHOTP(secret, cur.Counter() + cfg.Lookahead, otpLen).OTP()) t_err_fatal(t, err) if !ok { t.Fail() } @@ -163,7 +162,7 @@ func mock_conn(client client_t) (delay *sync.Mutex) { go client(client_r, client_w) - delay = handle_conn(db, "handle_conn_t", server_r_buf, server_w_buf) + delay = handle_conn("handle_conn_t", server_r_buf, server_w_buf) return } @@ -315,7 +314,7 @@ func TestMain(t *testing.T) { var err error db, err = sql.Open(cfg.Db_driver, cfg.Db_url) t_err_fatal(t, err) - create_table(db) + init_db() t.Run("create_user", create_user_t) t.Run("transaction_conflict_handling", transaction_conflict_t)