From 0bbf0bf45b62c2c5074a8d7118f7e2fdc9492653 Mon Sep 17 00:00:00 2001
From: Lukas Braun <lukas.braun@fau.de>
Date: Wed, 1 Feb 2017 02:23:47 +0100
Subject: [PATCH] add tests for the server

---
 goatherd_test.go | 217 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 217 insertions(+)
 create mode 100644 goatherd_test.go

diff --git a/goatherd_test.go b/goatherd_test.go
new file mode 100644
index 0000000..1692204
--- /dev/null
+++ b/goatherd_test.go
@@ -0,0 +1,217 @@
+package main
+
+import (
+    "bufio"
+    "bytes"
+    "database/sql"
+    "encoding/base64"
+    "io"
+    "sync"
+    "testing"
+
+    "github.com/gokyle/hotp"
+)
+
+var db *sql.DB
+
+var (
+    username = "foobar"
+    secret = []byte("foobar")
+    secret_b64 = base64.StdEncoding.EncodeToString(secret)
+)
+
+func t_err_fatal(t *testing.T, err error) {
+    if err != nil {
+        t.Fatal(err)
+    }
+}
+
+func create_user_t(t *testing.T) {
+    create_user(db, username, secret_b64)
+
+    var result struct {
+        secret []byte
+        count uint64
+    }
+
+    t.Run("exists", func(t *testing.T) {
+        tx, err := db.Begin()
+        t_err_fatal(t, err)
+        result.secret, result.count, err = get_user(tx, username)
+        t_err_fatal(t, err)
+        tx.Commit()
+    })
+
+    t.Run("correct_secret", func(t *testing.T) {
+        if bytes.Compare(secret, result.secret) != 0 {
+            t.Errorf("secret: %v; result: %v", secret, result.secret)
+        }
+    })
+
+    t.Run("correct_count", func(t *testing.T) {
+        if result.count != 1 {
+            t.Errorf("initial count: %v", result.count)
+        }
+    })
+}
+
+// runs two updating transactions concurrently, fails if none of them aborts
+func transaction_conflict_t(t *testing.T) {
+    var err error
+    var txs [2]*sql.Tx
+    for i, _ := range txs {
+        txs[i], err = db.Begin()
+        t_err_fatal(t, err)
+        defer func(i int) { t.Log("rollback", i); txs[i].Rollback() }(i)
+    }
+
+    for i, tx := range txs {
+        t.Log("update", i)
+        _, err = tx.Exec("UPDATE users SET count = ? WHERE name == ?", uint64(i), username)
+        if transaction_failed(err) {
+            return
+        } else { t_err_fatal(t, err) }
+    }
+
+    for i, tx := range txs {
+        t.Log("commit", i)
+        err = tx.Commit()
+        if transaction_failed(err) {
+            return
+        } else { t_err_fatal(t, err) }
+    }
+
+    t.Error("No transaction failure despite concurrent updates!")
+}
+
+
+func check_offer_t(t *testing.T) {
+    t.Run("no_such_user", func(t *testing.T) {
+        _, err := check_offer(db, "mock", "no such user", "dummy offer")
+        if err != sql.ErrNoRows { t.Error("err:", nil) }
+    })
+
+    tx, err := db.Begin()
+    t_err_fatal(t, err)
+    secret, count, err := get_user(tx, username)
+    t_err_fatal(t, err)
+    tx.Commit()
+
+    t.Run("fail", func(t *testing.T) {
+        ok, err := check_offer(db, "mock", username, "dummy offer")
+        t_err_fatal(t, err)
+        if ok { t.Fail() }
+
+        t.Run("too_far_out", func(t *testing.T) {
+            ahead := hotp.NewHOTP(secret, count + cfg.Lookahead + 1, 6)
+            ok, err = check_offer(db, "mock", username, ahead.OTP())
+            if ok { t.Fail() }
+        })
+    })
+
+    t.Run("ok", func(t *testing.T) {
+        cur := hotp.NewHOTP(secret, count, 6)
+        ok, err := check_offer(db, "mock", username, cur.OTP())
+        t_err_fatal(t, err)
+        if !ok { t.Fail() }
+
+        cur.Increment()
+        t.Run("incremented", func(t *testing.T) {
+            ok, err := check_offer(db, "mock", username, cur.OTP())
+            t_err_fatal(t, err)
+            if !ok { t.Fail() }
+        })
+
+        t.Run("lookahead", func(t *testing.T) {
+            ok, err := check_offer(db, "mock", username,
+                hotp.NewHOTP(secret, cur.Counter() + cfg.Lookahead, 6).OTP())
+            t_err_fatal(t, err)
+            if !ok { t.Fail() }
+        })
+    })
+
+}
+
+
+func mock_conn(name string, offer string) (delay *sync.Mutex, answer string) {
+    in_r, in_w := io.Pipe()
+    in_r_buf := bufio.NewReader(in_r)
+    defer in_r.Close()
+    out_r, out_w := io.Pipe()
+    out_r_buf, out_w_buf := bufio.NewReader(out_r), bufio.NewWriter(out_w)
+    defer out_r.Close()
+
+    answer_c := make(chan string)
+
+    go func() {
+        in_w.Write(append([]byte(username), '\n'))
+        in_w.Write(append([]byte(offer), '\n'))
+        answer, _ := get_line(out_r_buf)
+        answer_c <- answer
+    }()
+
+    delay = handle_conn(db, "handle_conn_t", in_r_buf, out_w_buf)
+    answer = <-answer_c
+
+    return
+}
+
+
+func handle_conn_t(t *testing.T) {
+    // zero value Mutex is unlocked, can be compared to a given Mutex to
+    // (unreliably?) check if it is locked
+    unlocked := sync.Mutex{}
+
+    t.Run("dummy_offer", func(t *testing.T) {
+        delay, answer := mock_conn(username, "dummy offer")
+        if delay == nil {
+            t.Fail()
+        } else if *delay == unlocked {
+            t.Error("delay mutex is unlocked")
+        } else {
+            delay.Unlock()
+        }
+        if answer != "FAIL" {
+            t.Error("answer:", answer)
+        }
+    })
+
+    t.Run("ok", func(t *testing.T) {
+        tx, err := db.Begin()
+        t_err_fatal(t, err)
+        otp, err := get_otp(tx, username)
+        t_err_fatal(t, err)
+        tx.Commit()
+
+        delay, answer := mock_conn(username, otp.OTP())
+        if delay != nil {
+            t.Fail()
+        }
+        if *faildelay.userlocks[username] != unlocked {
+            t.Error("userlock not unlocked")
+        }
+        if answer != "OK" {
+            t.Error("answer:", answer)
+        }
+    })
+}
+
+
+func TestMain(t *testing.T) {
+    // multiple connections only talk to the same in-memory database if a shared
+    // cache is used, otherwise a new database is created
+    cfg.Db_url = "file::memory:?cache=shared"
+    cfg.Lookahead = 3
+    cfg.Debug = false
+    faildelay.userlocks = make(map[string]*sync.Mutex)
+
+    var err error
+    db, err = sql.Open("sqlite3", cfg.Db_url)
+    t_err_fatal(t, err)
+    create_table(db)
+
+    t.Run("create_user", create_user_t)
+    t.Run("transaction_conflict_handling", transaction_conflict_t)
+    t.Run("check_offer", check_offer_t)
+    t.Run("handle_conn", handle_conn_t)
+}
-- 
GitLab