From 019fe686f0051913db0ecca05df0f2f0fb540073 Mon Sep 17 00:00:00 2001
From: Lukas Braun <lukas.braun@fau.de>
Date: Thu, 2 Feb 2017 15:39:48 +0100
Subject: [PATCH] test reading username and secret from stdin

---
 goatherd.go      |  8 +++---
 goatherd_test.go | 65 +++++++++++++++++++++++++++++++-----------------
 2 files changed, 47 insertions(+), 26 deletions(-)

diff --git a/goatherd.go b/goatherd.go
index 033ab19..d50a6cd 100644
--- a/goatherd.go
+++ b/goatherd.go
@@ -104,19 +104,20 @@ func create_table(db *sql.DB) {
     err_fatalf("Failed to create table: %v\n", err)
 }
 
+// global var for testing, not modified during normal execution
+var stdin_reader = bufio.NewReader(os.Stdin)
 func create_user(db *sql.DB, name string, secret_b64 string) {
     debug("Creating user")
 
     var err error
-    reader := bufio.NewReader(os.Stdin)
     if name == "-" {
         fmt.Printf("Enter username: ")
-        name, err = get_line(reader)
+        name, err = get_line(stdin_reader)
         err_fatalf("Can't read username: %v\n", err)
     }
     if secret_b64 == "-" {
         fmt.Printf("Enter secret: ")
-        secret_b64, err = get_line(reader)
+        secret_b64, err = get_line(stdin_reader)
         err_fatalf("Can't read secret: %v\n", err)
     }
 
@@ -218,6 +219,7 @@ commit:
         return ok, nil
 
 retry:
+        debugf("[%v] retry", remote)
         ok = false
         err = nil
         tx.Rollback()
diff --git a/goatherd_test.go b/goatherd_test.go
index 526517b..c2142ae 100644
--- a/goatherd_test.go
+++ b/goatherd_test.go
@@ -5,6 +5,7 @@ import (
     "bytes"
     "database/sql"
     "encoding/base64"
+    "fmt"
     "io"
     "sync"
     "testing"
@@ -13,6 +14,8 @@ import (
 )
 
 var db *sql.DB
+var stdin_writer *bufio.Writer
+
 
 var (
     username = "foobar"
@@ -26,33 +29,47 @@ func t_err_fatal(t *testing.T, err error) {
     }
 }
 
-func create_user_t(t *testing.T) {
-    create_user(db, username, secret_b64)
+// always uses global username and secret for comparison, so we can test "-"
+func create_and_check(user string, sec string) func(*testing.T) {
+    return func(t *testing.T) {
+        create_user(db, user, sec)
 
-    var result struct {
-        secret []byte
-        count uint64
-    }
+        var result struct {
+            secret []byte
+            count uint64
+        }
 
-    t.Run("exists", func(t *testing.T) {
-        tx, err := db.Begin()
-        t_err_fatal(t, err)
-        result.secret, result.count, err = get_user(tx, username)
-        t_err_fatal(t, err)
-        tx.Commit()
-    })
+        t.Run("exists", func(t *testing.T) {
+            tx, err := db.Begin()
+            t_err_fatal(t, err)
+            result.secret, result.count, err = get_user(tx, username)
+            t_err_fatal(t, err)
+            tx.Commit()
+        })
 
-    t.Run("correct_secret", func(t *testing.T) {
-        if bytes.Compare(secret, result.secret) != 0 {
-            t.Errorf("secret: %v; result: %v", secret, result.secret)
-        }
-    })
+        t.Run("correct_secret", func(t *testing.T) {
+            if bytes.Compare(secret, result.secret) != 0 {
+                t.Errorf("secret: %v; result: %v", string(secret), string(result.secret))
+            }
+        })
 
-    t.Run("correct_count", func(t *testing.T) {
-        if result.count != 1 {
-            t.Errorf("initial count: %v", result.count)
-        }
-    })
+        t.Run("correct_count", func(t *testing.T) {
+            if result.count != 1 {
+                t.Errorf("initial count: %v", result.count)
+            }
+        })
+    }
+}
+
+func create_user_t(t *testing.T) {
+    t.Run("args", create_and_check(username, secret_b64))
+
+    go func() {
+        fmt.Fprintln(stdin_writer, username)
+        fmt.Fprintln(stdin_writer, secret_b64)
+        stdin_writer.Flush()
+    }()
+    t.Run("stdin", create_and_check("-", "-"))
 }
 
 // runs two updating transactions concurrently, fails if none of them aborts
@@ -286,6 +303,8 @@ func TestMain(t *testing.T) {
     cfg.Lookahead = 3
     cfg.Debug = false
     faildelay.userlocks = make(map[string]*sync.Mutex)
+    stdin_r, stdin_w := io.Pipe()
+    stdin_reader, stdin_writer = bufio.NewReader(stdin_r), bufio.NewWriter(stdin_w)
 
     var err error
     db, err = sql.Open("sqlite3", cfg.Db_url)
-- 
GitLab