Skip to content
Snippets Groups Projects
Commit 019fe686 authored by Lukas Braun's avatar Lukas Braun
Browse files

test reading username and secret from stdin

parent 6c451e74
No related branches found
No related tags found
No related merge requests found
...@@ -104,19 +104,20 @@ func create_table(db *sql.DB) { ...@@ -104,19 +104,20 @@ func create_table(db *sql.DB) {
err_fatalf("Failed to create table: %v\n", err) err_fatalf("Failed to create table: %v\n", err)
} }
// global var for testing, not modified during normal execution
var stdin_reader = bufio.NewReader(os.Stdin)
func create_user(db *sql.DB, name string, secret_b64 string) { func create_user(db *sql.DB, name string, secret_b64 string) {
debug("Creating user") debug("Creating user")
var err error var err error
reader := bufio.NewReader(os.Stdin)
if name == "-" { if name == "-" {
fmt.Printf("Enter username: ") fmt.Printf("Enter username: ")
name, err = get_line(reader) name, err = get_line(stdin_reader)
err_fatalf("Can't read username: %v\n", err) err_fatalf("Can't read username: %v\n", err)
} }
if secret_b64 == "-" { if secret_b64 == "-" {
fmt.Printf("Enter secret: ") fmt.Printf("Enter secret: ")
secret_b64, err = get_line(reader) secret_b64, err = get_line(stdin_reader)
err_fatalf("Can't read secret: %v\n", err) err_fatalf("Can't read secret: %v\n", err)
} }
...@@ -218,6 +219,7 @@ commit: ...@@ -218,6 +219,7 @@ commit:
return ok, nil return ok, nil
retry: retry:
debugf("[%v] retry", remote)
ok = false ok = false
err = nil err = nil
tx.Rollback() tx.Rollback()
......
...@@ -5,6 +5,7 @@ import ( ...@@ -5,6 +5,7 @@ import (
"bytes" "bytes"
"database/sql" "database/sql"
"encoding/base64" "encoding/base64"
"fmt"
"io" "io"
"sync" "sync"
"testing" "testing"
...@@ -13,6 +14,8 @@ import ( ...@@ -13,6 +14,8 @@ import (
) )
var db *sql.DB var db *sql.DB
var stdin_writer *bufio.Writer
var ( var (
username = "foobar" username = "foobar"
...@@ -26,33 +29,47 @@ func t_err_fatal(t *testing.T, err error) { ...@@ -26,33 +29,47 @@ func t_err_fatal(t *testing.T, err error) {
} }
} }
func create_user_t(t *testing.T) { // always uses global username and secret for comparison, so we can test "-"
create_user(db, username, secret_b64) func create_and_check(user string, sec string) func(*testing.T) {
return func(t *testing.T) {
create_user(db, user, sec)
var result struct { var result struct {
secret []byte secret []byte
count uint64 count uint64
} }
t.Run("exists", func(t *testing.T) { t.Run("exists", func(t *testing.T) {
tx, err := db.Begin() tx, err := db.Begin()
t_err_fatal(t, err) t_err_fatal(t, err)
result.secret, result.count, err = get_user(tx, username) result.secret, result.count, err = get_user(tx, username)
t_err_fatal(t, err) t_err_fatal(t, err)
tx.Commit() tx.Commit()
}) })
t.Run("correct_secret", func(t *testing.T) { t.Run("correct_secret", func(t *testing.T) {
if bytes.Compare(secret, result.secret) != 0 { if bytes.Compare(secret, result.secret) != 0 {
t.Errorf("secret: %v; result: %v", secret, result.secret) t.Errorf("secret: %v; result: %v", string(secret), string(result.secret))
} }
}) })
t.Run("correct_count", func(t *testing.T) { t.Run("correct_count", func(t *testing.T) {
if result.count != 1 { if result.count != 1 {
t.Errorf("initial count: %v", result.count) t.Errorf("initial count: %v", result.count)
} }
}) })
}
}
func create_user_t(t *testing.T) {
t.Run("args", create_and_check(username, secret_b64))
go func() {
fmt.Fprintln(stdin_writer, username)
fmt.Fprintln(stdin_writer, secret_b64)
stdin_writer.Flush()
}()
t.Run("stdin", create_and_check("-", "-"))
} }
// runs two updating transactions concurrently, fails if none of them aborts // runs two updating transactions concurrently, fails if none of them aborts
...@@ -286,6 +303,8 @@ func TestMain(t *testing.T) { ...@@ -286,6 +303,8 @@ func TestMain(t *testing.T) {
cfg.Lookahead = 3 cfg.Lookahead = 3
cfg.Debug = false cfg.Debug = false
faildelay.userlocks = make(map[string]*sync.Mutex) faildelay.userlocks = make(map[string]*sync.Mutex)
stdin_r, stdin_w := io.Pipe()
stdin_reader, stdin_writer = bufio.NewReader(stdin_r), bufio.NewWriter(stdin_w)
var err error var err error
db, err = sql.Open("sqlite3", cfg.Db_url) db, err = sql.Open("sqlite3", cfg.Db_url)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment