Skip to content
Snippets Groups Projects
Commit 026c1cee authored by Lukas Braun's avatar Lukas Braun
Browse files

abstract from sql.DB vs sql.Tx

Now we can use get_user/get_otp without having to open a transaction
each time (which also fixes a leak of transactions if an error occurs).
parent 019fe686
No related branches found
No related tags found
No related merge requests found
...@@ -92,6 +92,13 @@ func get_line(reader *bufio.Reader) (string, error) { ...@@ -92,6 +92,13 @@ func get_line(reader *bufio.Reader) (string, error) {
} }
type db_conn interface {
Exec(query string, args ...interface{}) (sql.Result, error)
QueryRow(query string, args ...interface{}) *sql.Row
Prepare(query string) (*sql.Stmt, error)
}
func create_table(db *sql.DB) { func create_table(db *sql.DB) {
debug("Creating table 'users' in DB") debug("Creating table 'users' in DB")
...@@ -134,11 +141,11 @@ func create_user(db *sql.DB, name string, secret_b64 string) { ...@@ -134,11 +141,11 @@ func create_user(db *sql.DB, name string, secret_b64 string) {
func get_user(tx *sql.Tx, name string) (secret []byte, count uint64, err error) { func get_user(db db_conn, name string) (secret []byte, count uint64, err error) {
stmt, err := tx.Prepare("SELECT secret, count FROM users WHERE name = ?") stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = ?")
if transaction_failed(err) { if err != nil {
return return
} else { err_panic(err) } }
err = stmt.QueryRow(name).Scan(&secret, &count) err = stmt.QueryRow(name).Scan(&secret, &count)
stmt.Close() stmt.Close()
...@@ -146,19 +153,19 @@ func get_user(tx *sql.Tx, name string) (secret []byte, count uint64, err error) ...@@ -146,19 +153,19 @@ func get_user(tx *sql.Tx, name string) (secret []byte, count uint64, err error)
return return
} }
func set_count(tx *sql.Tx, name string, count uint64) error { func set_count(db db_conn, name string, count uint64) error {
inc, err := tx.Prepare("UPDATE users SET count = ? WHERE name = ?") inc, err := db.Prepare("UPDATE users SET count = ? WHERE name = ?")
if transaction_failed(err) { if err != nil {
return err return err
} else { err_panic(err) } }
_, err = inc.Exec(count, name) _, err = inc.Exec(count, name)
inc.Close() inc.Close()
return err return err
} }
func get_otp(tx *sql.Tx, name string) (*hotp.HOTP, error) { func get_otp(db db_conn, name string) (*hotp.HOTP, error) {
secret, count, err := get_user(tx, name) secret, count, err := get_user(db, name)
if err != nil { return nil, err } if err != nil { return nil, err }
return hotp.NewHOTP(secret, count, 6), nil return hotp.NewHOTP(secret, count, 6), nil
......
...@@ -40,11 +40,9 @@ func create_and_check(user string, sec string) func(*testing.T) { ...@@ -40,11 +40,9 @@ func create_and_check(user string, sec string) func(*testing.T) {
} }
t.Run("exists", func(t *testing.T) { t.Run("exists", func(t *testing.T) {
tx, err := db.Begin() var err error
result.secret, result.count, err = get_user(db, username)
t_err_fatal(t, err) t_err_fatal(t, err)
result.secret, result.count, err = get_user(tx, username)
t_err_fatal(t, err)
tx.Commit()
}) })
t.Run("correct_secret", func(t *testing.T) { t.Run("correct_secret", func(t *testing.T) {
...@@ -108,11 +106,8 @@ func check_offer_t(t *testing.T) { ...@@ -108,11 +106,8 @@ func check_offer_t(t *testing.T) {
if err != sql.ErrNoRows { t.Error("err:", nil) } if err != sql.ErrNoRows { t.Error("err:", nil) }
}) })
tx, err := db.Begin() secret, count, err := get_user(db, username)
t_err_fatal(t, err)
secret, count, err := get_user(tx, username)
t_err_fatal(t, err) t_err_fatal(t, err)
tx.Commit()
t.Run("fail", func(t *testing.T) { t.Run("fail", func(t *testing.T) {
ok, err := check_offer(db, "mock", username, "dummy offer") ok, err := check_offer(db, "mock", username, "dummy offer")
...@@ -213,11 +208,8 @@ func handle_conn_t(t *testing.T) { ...@@ -213,11 +208,8 @@ func handle_conn_t(t *testing.T) {
} }
}) })
tx, err := db.Begin() otp, err := get_otp(db, username)
t_err_fatal(t, err)
otp, err := get_otp(tx, username)
t_err_fatal(t, err) t_err_fatal(t, err)
tx.Commit()
t.Run("ok", func(t *testing.T) { t.Run("ok", func(t *testing.T) {
recv := make(chan string) recv := make(chan string)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment