    // goatherd_test.go
    // Author (per the repository blame view this file was extracted from): Lukas Braun
    package main
    
    import (
        "bufio"
        "bytes"
        "database/sql"
        "encoding/base64"
        "fmt"
        "io"
        "os"
        "sync"
        "testing"

        "github.com/gokyle/hotp"
    )
    
    
    // stdin_writer is the buffered write end of the pipe whose read end TestMain
    // installs as stdin_scanner; tests write lines to it (followed by Flush) to
    // feed input to code that reads from stdin_scanner.
    var stdin_writer *bufio.Writer
    
    
    
    
    // Fixtures shared by all tests in this file.
    var (
        // username is the name of the test user that the tests create and query.
        username = "foobar"
        // secret is the raw secret of the test user.
        secret = []byte("foobar")
        // secret_b64 is secret in standard (padded) base64 encoding; this is the
        // form handed to create_user.
        secret_b64 = base64.StdEncoding.EncodeToString(secret)
    )
    
    // t_err_fatal aborts the running test with t.Fatal when err is non-nil and
    // does nothing otherwise.
    func t_err_fatal(t *testing.T, err error) {
        // Mark this function as a helper so the failure is reported at the
        // caller's line rather than inside this function.
        t.Helper()
        if err != nil {
            t.Fatal(err)
        }
    }
    
    
    // create_and_check returns a subtest that calls create_user(user, sec) and
    // then verifies the stored record in three nested subtests: the user can be
    // fetched with get_user, the stored secret equals the global secret, and the
    // stored counter is 1.
    //
    // The lookup and the comparisons always use the global username and secret
    // rather than the user/sec arguments, so the same expectations also apply
    // when "-" is passed for them.
    // NOTE(review): "-" presumably makes create_user read the value from stdin;
    // confirm against create_user.
    func create_and_check(user string, sec string) func(*testing.T) {
        return func(t *testing.T) {
            create_user(user, sec)

            var result struct {
                secret []byte
                count  uint64
            }

            found := t.Run("exists", func(t *testing.T) {
                var err error
                result.secret, result.count, err = get_user(db, username)
                t_err_fatal(t, err)
            })
            // If the user could not be fetched, result still holds zero values;
            // comparing them would only add misleading follow-up failures.
            if !found {
                return
            }

            t.Run("correct_secret", func(t *testing.T) {
                if !bytes.Equal(secret, result.secret) {
                    t.Errorf("secret: %v; result: %v", string(secret), string(result.secret))
                }
            })

            t.Run("correct_count", func(t *testing.T) {
                if result.count != 1 {
                    t.Errorf("initial count: %v", result.count)
                }
            })
        }
    }
    
    func create_user_t(t *testing.T) {
        t.Run("args", create_and_check(username, secret_b64))
    
        go func() {
            fmt.Fprintln(stdin_writer, username)
            fmt.Fprintln(stdin_writer, secret_b64)
            stdin_writer.Flush()
        }()
        t.Run("stdin", create_and_check("-", "-"))
    
    
    // aborted reports whether err is a transaction failure according to
    // transaction_failed. In that case the extra arguments are written to the
    // test log and true is returned. Otherwise err is handed to t_err_fatal
    // (which aborts the test if it is non-nil) and false is returned.
    //
    // Note that the variadic parameter shadows the function's own name inside
    // the body.
    func aborted(t *testing.T, err error, aborted ...interface{}) bool {
        // Attribute log lines and failures to the calling test line.
        t.Helper()

        if transaction_failed(err) {
            t.Log(aborted...)
            return true
        }

        t_err_fatal(t, err)
        return false
    }
    
    
    func interleaved_transactions_t(t *testing.T) {
    
    Lukas Braun's avatar
    Lukas Braun committed
        var err error
        var txs [2]*sql.Tx
        for i, _ := range txs {
            txs[i], err = db.Begin()
            t_err_fatal(t, err)
    
            defer func(i int) { txs[i].Rollback() }(i)
    
    Lukas Braun's avatar
    Lukas Braun committed
        var c uint64
    
        for i, tx := range txs {
    
    Lukas Braun's avatar
    Lukas Braun committed
            err = tx.QueryRow("SELECT count FROM users WHERE name = $1", username).Scan(&c)
    
            if aborted(t, err, "rollback after select", i) { return }
    
    Lukas Braun's avatar
    Lukas Braun committed
        }
    
        for i, tx := range txs {
    
    Lukas Braun's avatar
    Lukas Braun committed
            t.Log("update", i)
            _, err := tx.Exec("UPDATE users SET count = $1 WHERE name = $2", c + uint64(i), username)
    
            if aborted(t, err, "rollback after update", i) { return }
    
    Lukas Braun's avatar
    Lukas Braun committed
            t.Log("commit", i)
            err = tx.Commit()
    
            if aborted(t, err, "rollback after commit", i) { return }
    
        t.Error("No transaction failure despite interleaved transactions!")
    }
    
    func nested_transactions_t(t *testing.T) {
        outer, err := db.Begin()
        t_err_fatal(t, err)
        defer outer.Rollback()
        inner, err := db.Begin()
        t_err_fatal(t, err)
        defer inner.Rollback()
    
        var c uint64
        t_err_fatal(t, outer.QueryRow("SELECT count FROM users WHERE name = $1", username).Scan(&c))
    
        err = inner.QueryRow("SELECT count FROM users WHERE name = $1", username).Scan(&c)
    
        if aborted(t, err, "rollback after inner.Query") { return }
    
    
        t.Log("update inner")
        _, err = inner.Exec("UPDATE users SET count = $1 WHERE name = $2", c + 1, username)
    
        if aborted(t, err, "rollback after inner.Exec") { return }
    
    
        err = inner.Commit()
    
        if aborted(t, err, "rollback after inner.Commit") { return }
    
    
        t.Log("update outer")
        _, err = outer.Exec("UPDATE users SET count = $1 WHERE name = $2", c + 1, username)
    
        if aborted(t, err, "rollback after outer.Exec") { return }
    
    
        err = outer.Commit()
    
        if aborted(t, err, "rollback after outer.Commit") { return }
    
    
        t.Error("No transaction failure despite nested transactions")
    
    // transaction_conflict_t runs the transaction-conflict scenarios as named
    // subtests, in the order listed.
    func transaction_conflict_t(t *testing.T) {
        scenarios := []struct {
            name string
            run  func(*testing.T)
        }{
            {"interleaved_transactions_t", interleaved_transactions_t},
            {"nested_transactions_t", nested_transactions_t},
        }

        for _, scenario := range scenarios {
            t.Run(scenario.name, scenario.run)
        }
    }
    
    
    
    func check_offer_t(t *testing.T) {
        t.Run("no_such_user", func(t *testing.T) {
    
            _, err := check_offer("mock", "no such user", "dummy offer")
    
    Lukas Braun's avatar
    Lukas Braun committed
            if err != sql.ErrNoRows { t.Error("err:", nil) }
        })
    
    
        secret, count, err := get_user(db, username)
    
    Lukas Braun's avatar
    Lukas Braun committed
        t_err_fatal(t, err)
    
        t.Run("fail", func(t *testing.T) {
    
            ok, err := check_offer("mock", username, "dummy offer")
    
    Lukas Braun's avatar
    Lukas Braun committed
            t_err_fatal(t, err)
            if ok { t.Fail() }
    
            t.Run("too_far_out", func(t *testing.T) {
    
    Lukas Braun's avatar
    Lukas Braun committed
                ahead := hotp.NewHOTP(secret, count + cfg.Lookahead + 1, otpLen)
    
                ok, err = check_offer("mock", username, ahead.OTP())
    
    Lukas Braun's avatar
    Lukas Braun committed
                if ok { t.Fail() }
            })
        })
    
        t.Run("ok", func(t *testing.T) {
    
    Lukas Braun's avatar
    Lukas Braun committed
            cur := hotp.NewHOTP(secret, count, otpLen)
    
            ok, err := check_offer("mock", username, cur.OTP())
    
    Lukas Braun's avatar
    Lukas Braun committed
            t_err_fatal(t, err)
            if !ok { t.Fail() }
    
            cur.Increment()
            t.Run("incremented", func(t *testing.T) {
    
                ok, err := check_offer("mock", username, cur.OTP())
    
    Lukas Braun's avatar
    Lukas Braun committed
                t_err_fatal(t, err)
                if !ok { t.Fail() }
            })
    
            t.Run("lookahead", func(t *testing.T) {
    
                ok, err := check_offer("mock", username,
    
    Lukas Braun's avatar
    Lukas Braun committed
                    hotp.NewHOTP(secret, cur.Counter() + cfg.Lookahead, otpLen).OTP())
    
    Lukas Braun's avatar
    Lukas Braun committed
                t_err_fatal(t, err)
                if !ok { t.Fail() }
            })
        })
    
    }
    
    
    
    // client_t is the client side of a mocked connection: it receives the pipe
    // end it reads server output from and the pipe end it writes its own input to.
    type client_t func(from_server *io.PipeReader, to_server *io.PipeWriter)
    
    
    func mock_conn(client client_t) (delay *sync.Mutex) {
        server_r, client_w := io.Pipe()
        server_r_buf := bufio.NewReader(server_r)
        defer server_r.Close()
        client_r, server_w := io.Pipe()
        server_w_buf := bufio.NewWriter(server_w)
        defer client_r.Close()
    
    Lukas Braun's avatar
    Lukas Braun committed
        go client(client_r, client_w)
    
        delay = handle_conn("handle_conn_t", server_r_buf, server_w_buf)
    
    Lukas Braun's avatar
    Lukas Braun committed
    func interact(name string, offer string, recv chan string) client_t {
        return func(r *io.PipeReader, w *io.PipeWriter) {
            w.Write(append([]byte(name), '\n'))
            w.Write(append([]byte(offer), '\n'))
    
            answer, _ := get_line(bufio.NewScanner(r))
    
    Lukas Braun's avatar
    Lukas Braun committed
            recv <- answer
        }
    }
    
    
    
    
    func handle_conn_t(t *testing.T) {
        // zero value Mutex is unlocked, can be compared to a given Mutex to
        // (unreliably?) check if it is locked
        unlocked := sync.Mutex{}
    
    
    Lukas Braun's avatar
    Lukas Braun committed
        t.Run("dummy_user", func(t *testing.T) {
            recv := make(chan string)
            delay, answer := mock_conn(interact("dummy user", "dummy offer", recv)), <-recv
            if delay != nil {
                t.Fail()
            }
            if answer != "" {
                t.Error("answer:", answer)
            }
            if faildelay.userlocks["dummy user"] != nil {
                t.Error("userlock created for dummy user")
            }
        })
    
    
    Lukas Braun's avatar
    Lukas Braun committed
        t.Run("dummy_offer", func(t *testing.T) {
    
    Lukas Braun's avatar
    Lukas Braun committed
            recv := make(chan string)
            delay, answer := mock_conn(interact(username, "dummy offer", recv)), <-recv
    
    Lukas Braun's avatar
    Lukas Braun committed
            if delay == nil {
                t.Fail()
            } else if *delay == unlocked {
                t.Error("delay mutex is unlocked")
            } else {
                delay.Unlock()
            }
            if answer != "FAIL" {
                t.Error("answer:", answer)
            }
    
    Lukas Braun's avatar
    Lukas Braun committed
            if faildelay.userlocks[username] == nil {
                t.Error("userlock not in map")
            }
    
        otp, err := get_otp(db, username)
    
    Lukas Braun's avatar
    Lukas Braun committed
        t_err_fatal(t, err)
    
    Lukas Braun's avatar
    Lukas Braun committed
        t.Run("ok", func(t *testing.T) {
            recv := make(chan string)
            delay := mock_conn(interact(username, otp.OTP(), recv))
            answer := <-recv
    
    Lukas Braun's avatar
    Lukas Braun committed
            if delay != nil {
                t.Fail()
            }
            if *faildelay.userlocks[username] != unlocked {
                t.Error("userlock not unlocked")
            }
            if answer != "OK" {
                t.Error("answer:", answer)
            }
        })
    
    Lukas Braun's avatar
    Lukas Braun committed
        otp.Increment()
    
        t.Run("read_error", func(t *testing.T) {
            t.Run("close_immediately",  func(t *testing.T) {
                delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) {
                    w.Close()
                })
                if delay != nil {
                    t.Fail()
                }
                if *faildelay.userlocks[username] != unlocked {
                    t.Error("userlock not unlocked")
                }
            })
    
            t.Run("close_after_username",  func(t *testing.T) {
                delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) {
                    w.Write(append([]byte(username), '\n'))
                    w.Close()
                })
                if delay != nil {
                    t.Fail()
                }
                if *faildelay.userlocks[username] != unlocked {
                    t.Error("userlock not unlocked")
                }
            })
        })
    
        t.Run("write_error", func(t *testing.T) {
            t.Run("FAIL", func(t *testing.T) {
                delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) {
                    r.Close()
                    w.Write(append([]byte(username), '\n'))
                    w.Write(append([]byte("dummy offer"), '\n'))
                })
                if delay == nil {
                    t.Fail()
                } else if *delay == unlocked {
                    t.Error("delay mutex is unlocked")
                } else {
                    delay.Unlock()
                }
            })
    
            t.Run("OK", func(t *testing.T) {
                delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) {
                    r.Close()
                    w.Write(append([]byte(username), '\n'))
                    w.Write(append([]byte(otp.OTP()), '\n'))
                })
                if delay != nil {
                    t.Fail()
                }
                if *faildelay.userlocks[username] != unlocked {
                    t.Error("userlock not unlocked")
                }
            })
            otp.Increment()
        })
    
    Lukas Braun's avatar
    Lukas Braun committed
    }
    
    
    func TestMain(t *testing.T) {
        // multiple connections only talk to the same in-memory database if a shared
        // cache is used, otherwise a new database is created
        cfg.Db_url = "file::memory:?cache=shared"
    
    Lukas Braun's avatar
    Lukas Braun committed
        if db_url, ok := os.LookupEnv("DB_URL"); ok {
            cfg.Db_url = db_url
        }
        cfg.Db_driver = "sqlite3"
        if db_driver, ok := os.LookupEnv("DB_DRIVER"); ok {
            cfg.Db_driver = db_driver
        }
    
    Lukas Braun's avatar
    Lukas Braun committed
        cfg.Lookahead = 3
        cfg.Debug = false
        faildelay.userlocks = make(map[string]*sync.Mutex)
    
        stdin_r, stdin_w := io.Pipe()
    
        stdin_scanner, stdin_writer = bufio.NewScanner(stdin_r), bufio.NewWriter(stdin_w)
    
    Lukas Braun's avatar
    Lukas Braun committed
    
        var err error
    
    Lukas Braun's avatar
    Lukas Braun committed
        db, err = sql.Open(cfg.Db_driver, cfg.Db_url)
    
    Lukas Braun's avatar
    Lukas Braun committed
        t_err_fatal(t, err)
    
    Lukas Braun's avatar
    Lukas Braun committed
    
        t.Run("create_user", create_user_t)
        t.Run("transaction_conflict_handling", transaction_conflict_t)
        t.Run("check_offer", check_offer_t)
        t.Run("handle_conn", handle_conn_t)
    }