diff --git a/goatherd.go b/goatherd.go index d50a6cd364c6525b3363bf5eb339069f953094c3..074e80f4d40223ac693915cf8653645ecf2bfc5b 100644 --- a/goatherd.go +++ b/goatherd.go @@ -92,6 +92,13 @@ func get_line(reader *bufio.Reader) (string, error) { } +type db_conn interface { + Exec(query string, args ...interface{}) (sql.Result, error) + QueryRow(query string, args ...interface{}) *sql.Row + Prepare(query string) (*sql.Stmt, error) +} + + func create_table(db *sql.DB) { debug("Creating table 'users' in DB") @@ -134,11 +141,11 @@ func create_user(db *sql.DB, name string, secret_b64 string) { -func get_user(tx *sql.Tx, name string) (secret []byte, count uint64, err error) { - stmt, err := tx.Prepare("SELECT secret, count FROM users WHERE name = ?") - if transaction_failed(err) { +func get_user(db db_conn, name string) (secret []byte, count uint64, err error) { + stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = ?") + if err != nil { return - } else { err_panic(err) } + } err = stmt.QueryRow(name).Scan(&secret, &count) stmt.Close() @@ -146,19 +153,19 @@ func get_user(tx *sql.Tx, name string) (secret []byte, count uint64, err error) return } -func set_count(tx *sql.Tx, name string, count uint64) error { - inc, err := tx.Prepare("UPDATE users SET count = ? WHERE name = ?") - if transaction_failed(err) { +func set_count(db db_conn, name string, count uint64) error { + inc, err := db.Prepare("UPDATE users SET count = ? WHERE name = ?") + if err != nil { return err - } else { err_panic(err) } + } _, err = inc.Exec(count, name) inc.Close() return err } -func get_otp(tx *sql.Tx, name string) (*hotp.HOTP, error) { - secret, count, err := get_user(tx, name) +func get_otp(db db_conn, name string) (*hotp.HOTP, error) { + secret, count, err := get_user(db, name) if err != nil { return nil, err } return hotp.NewHOTP(secret, count, 6), nil diff --git a/goatherd_test.go b/goatherd_test.go index c2142aea4a35ac6e450c51ce8653f85d8de19a1d..310ad1193ce26942221e447526ce1afdc7aedffe 100644 --- a/goatherd_test.go +++ b/goatherd_test.go @@ -40,11 +40,9 @@ func create_and_check(user string, sec string) func(*testing.T) { } t.Run("exists", func(t *testing.T) { - tx, err := db.Begin() + var err error + result.secret, result.count, err = get_user(db, username) t_err_fatal(t, err) - result.secret, result.count, err = get_user(tx, username) - t_err_fatal(t, err) - tx.Commit() }) t.Run("correct_secret", func(t *testing.T) { @@ -108,11 +106,8 @@ func check_offer_t(t *testing.T) { if err != sql.ErrNoRows { t.Error("err:", nil) } }) - tx, err := db.Begin() - t_err_fatal(t, err) - secret, count, err := get_user(tx, username) + secret, count, err := get_user(db, username) t_err_fatal(t, err) - tx.Commit() t.Run("fail", func(t *testing.T) { ok, err := check_offer(db, "mock", username, "dummy offer") @@ -213,11 +208,8 @@ func handle_conn_t(t *testing.T) { } }) - tx, err := db.Begin() - t_err_fatal(t, err) - otp, err := get_otp(tx, username) + otp, err := get_otp(db, username) t_err_fatal(t, err) - tx.Commit() t.Run("ok", func(t *testing.T) { recv := make(chan string)