package main

import (
    "bufio"
    "bytes"
    "database/sql"
    "encoding/base64"
    "fmt"
    "io"
    "sync"
    "testing"

    "github.com/gokyle/hotp"
)

// Package-level fixtures shared by every test in this file.
var (
    // db is the database handle used by all tests; it is assigned in TestMain.
    db *sql.DB
    // stdin_writer is the write side of the fake standard input; it is
    // assigned in TestMain and fed by create_user_t.
    stdin_writer *bufio.Writer

    // username and secret identify the single test user; secret_b64 is the
    // standard base64 encoding of secret.
    username   = "foobar"
    secret     = []byte("foobar")
    secret_b64 = base64.StdEncoding.EncodeToString(secret)
)

// t_err_fatal aborts the running test via t.Fatal when err is non-nil and
// does nothing otherwise.
//
// It marks itself as a test helper so that the failure is attributed to the
// caller's line rather than to the t.Fatal call inside this function.
func t_err_fatal(t *testing.T, err error) {
    t.Helper()
    if err != nil {
        t.Fatal(err)
    }
}

// create_and_check returns a subtest that calls create_user with the given
// user and sec arguments and then verifies the stored record.
//
// The verification always looks up the global username and compares against
// the global secret, regardless of the arguments, so that the "-" placeholder
// arguments can be tested as well.
func create_and_check(user string, sec string) func(*testing.T) {
    return func(t *testing.T) {
        // Filled in by the "exists" subtest, inspected by the later ones.
        var (
            stored_secret []byte
            stored_count  uint64
        )

        create_user(db, user, sec)

        t.Run("exists", func(t *testing.T) {
            var err error
            stored_secret, stored_count, err = get_user(db, username)
            t_err_fatal(t, err)
        })

        t.Run("correct_secret", func(t *testing.T) {
            if bytes.Equal(secret, stored_secret) {
                return
            }
            t.Errorf("secret: %v; result: %v", string(secret), string(stored_secret))
        })

        t.Run("correct_count", func(t *testing.T) {
            // A freshly created user is expected to start with counter 1.
            if stored_count == 1 {
                return
            }
            t.Errorf("initial count: %v", stored_count)
        })
    }
}

// create_user_t exercises user creation twice: once with the user name and
// secret passed as arguments, and once with "-" for both while the values are
// written to the fake standard input.
func create_user_t(t *testing.T) {
    t.Run("args", create_and_check(username, secret_b64))

    // Feed the fake stdin from a separate goroutine while the "stdin"
    // subtest below runs.
    feed_stdin := func() {
        for _, line := range []string{username, secret_b64} {
            fmt.Fprintln(stdin_writer, line)
        }
        stdin_writer.Flush()
    }
    go feed_stdin()

    t.Run("stdin", create_and_check("-", "-"))
}

// transaction_conflict_t opens two transactions that both update the test
// user's row. The test passes as soon as transaction_failed reports true for
// an update or a commit; it fails if both transactions update and commit
// without such a failure.
func transaction_conflict_t(t *testing.T) {
    var txs [2]*sql.Tx
    for i := range txs {
        tx, err := db.Begin()
        txs[i] = tx
        t_err_fatal(t, err)

        // Roll back on exit; the result of Rollback is ignored.
        idx := i
        defer func() {
            t.Log("rollback", idx)
            txs[idx].Rollback()
        }()
    }

    for i, tx := range txs {
        t.Log("update", i)
        _, err := tx.Exec("UPDATE users SET count = ? WHERE name == ?", uint64(i), username)
        if transaction_failed(err) {
            return
        }
        t_err_fatal(t, err)
    }

    for i, tx := range txs {
        t.Log("commit", i)
        err := tx.Commit()
        if transaction_failed(err) {
            return
        }
        t_err_fatal(t, err)
    }

    t.Error("No transaction failure despite concurrent updates!")
}


// check_offer_t verifies check_offer against the test user:
//
//   - an unknown user must yield sql.ErrNoRows,
//   - a garbage offer and an OTP beyond the lookahead window must be rejected,
//   - the current OTP, the next OTP and an OTP cfg.Lookahead counters ahead
//     must be accepted.
//
// The secret and counter are read once up front, so the subtests depend on
// running in the order written.
func check_offer_t(t *testing.T) {
    t.Run("no_such_user", func(t *testing.T) {
        _, err := check_offer(db, "mock", "no such user", "dummy offer")
        if err != sql.ErrNoRows {
            // Report the error actually returned, not a literal nil.
            t.Error("err:", err)
        }
    })

    secret, count, err := get_user(db, username)
    t_err_fatal(t, err)

    t.Run("fail", func(t *testing.T) {
        ok, err := check_offer(db, "mock", username, "dummy offer")
        t_err_fatal(t, err)
        if ok {
            t.Fail()
        }

        t.Run("too_far_out", func(t *testing.T) {
            // One counter past the accepted lookahead window.
            ahead := hotp.NewHOTP(secret, count+cfg.Lookahead+1, 6)
            // Use fresh locals so the enclosing subtest's results are not
            // overwritten, and check the error like the sibling subtests do.
            ok, err := check_offer(db, "mock", username, ahead.OTP())
            t_err_fatal(t, err)
            if ok {
                t.Fail()
            }
        })
    })

    t.Run("ok", func(t *testing.T) {
        cur := hotp.NewHOTP(secret, count, 6)
        ok, err := check_offer(db, "mock", username, cur.OTP())
        t_err_fatal(t, err)
        if !ok {
            t.Fail()
        }

        // Advance the local counter after the accepted offer.
        cur.Increment()
        t.Run("incremented", func(t *testing.T) {
            ok, err := check_offer(db, "mock", username, cur.OTP())
            t_err_fatal(t, err)
            if !ok {
                t.Fail()
            }
        })

        t.Run("lookahead", func(t *testing.T) {
            ok, err := check_offer(db, "mock", username,
                hotp.NewHOTP(secret, cur.Counter()+cfg.Lookahead, 6).OTP())
            t_err_fatal(t, err)
            if !ok {
                t.Fail()
            }
        })
    })
}

// client_t is the client side of a mocked connection. mock_conn runs it in
// its own goroutine, passing r, the read end of the pipe the server writes
// to, and w, the write end of the pipe the server reads from.
type client_t func(r *io.PipeReader, w *io.PipeWriter)

// mock_conn connects client to handle_conn through two in-memory pipes, one
// per direction. The client runs in its own goroutine while handle_conn runs
// synchronously; whatever handle_conn returns is passed back as delay.
//
// Both read ends are closed when mock_conn returns.
func mock_conn(client client_t) (delay *sync.Mutex) {
    // client -> server
    server_r, client_w := io.Pipe()
    defer server_r.Close()

    // server -> client
    client_r, server_w := io.Pipe()
    defer client_r.Close()

    go client(client_r, client_w)

    return handle_conn(db, "handle_conn_t", bufio.NewReader(server_r), bufio.NewWriter(server_w))
}

// interact builds a client that sends name and offer, each terminated by a
// newline, then reads one line via get_line and delivers it on recv.
//
// Write errors and the error from get_line are ignored; the send on recv
// blocks until the test receives it.
func interact(name string, offer string, recv chan string) client_t {
    return func(r *io.PipeReader, w *io.PipeWriter) {
        for _, line := range []string{name, offer} {
            w.Write([]byte(line + "\n"))
        }

        reply, _ := get_line(bufio.NewScanner(r))
        recv <- reply
    }
}


// handle_conn_t drives handle_conn (through mock_conn) with a series of
// scripted clients. After each connection it checks some or all of: the delay
// mutex returned by handle_conn, the answer line the client received, and the
// entry for the user in faildelay.userlocks.
//
// The subtests share state (the database, faildelay.userlocks and the local
// otp value) and appear to depend on running in the order written.
func handle_conn_t(t *testing.T) {
    // zero value Mutex is unlocked, can be compared to a given Mutex to
    // (unreliably?) check if it is locked
    unlocked := sync.Mutex{}

    // Unknown user: expects no delay mutex, an empty answer and no userlock
    // entry for that name.
    t.Run("dummy_user", func(t *testing.T) {
        recv := make(chan string)
        delay, answer := mock_conn(interact("dummy user", "dummy offer", recv)), <-recv
        if delay != nil {
            t.Fail()
        }
        if answer != "" {
            t.Error("answer:", answer)
        }
        if faildelay.userlocks["dummy user"] != nil {
            t.Error("userlock created for dummy user")
        }
    })

    // Known user with a bad offer: expects a delay mutex that differs from
    // the zero value (which the test then unlocks), the answer "FAIL" and a
    // userlock entry for the user.
    t.Run("dummy_offer", func(t *testing.T) {
        recv := make(chan string)
        delay, answer := mock_conn(interact(username, "dummy offer", recv)), <-recv
        if delay == nil {
            t.Fail()
        } else if *delay == unlocked {
            t.Error("delay mutex is unlocked")
        } else {
            delay.Unlock()
        }
        if answer != "FAIL" {
            t.Error("answer:", answer)
        }
        if faildelay.userlocks[username] == nil {
            t.Error("userlock not in map")
        }
    })

    // otp is reused by the "ok" and "write_error/OK" subtests below.
    otp, err := get_otp(db, username)
    t_err_fatal(t, err)

    // Known user with the OTP from get_otp: expects no delay mutex, the
    // user's lock equal to the zero-value mutex and the answer "OK".
    t.Run("ok", func(t *testing.T) {
        recv := make(chan string)
        delay := mock_conn(interact(username, otp.OTP(), recv))
        answer := <-recv
        if delay != nil {
            t.Fail()
        }
        if *faildelay.userlocks[username] != unlocked {
            t.Error("userlock not unlocked")
        }
        if answer != "OK" {
            t.Error("answer:", answer)
        }
    })
    // Advance the local OTP after it has been offered; presumably this
    // mirrors the counter update done by the server on success.
    otp.Increment()

    // The client closes its write end before a full request was sent:
    // expects no delay mutex and the user's lock left unlocked.
    t.Run("read_error", func(t *testing.T) {
        t.Run("close_immediately",  func(t *testing.T) {
            delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) {
                w.Close()
            })
            if delay != nil {
                t.Fail()
            }
            if *faildelay.userlocks[username] != unlocked {
                t.Error("userlock not unlocked")
            }
        })

        t.Run("close_after_username",  func(t *testing.T) {
            delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) {
                w.Write(append([]byte(username), '\n'))
                w.Close()
            })
            if delay != nil {
                t.Fail()
            }
            if *faildelay.userlocks[username] != unlocked {
                t.Error("userlock not unlocked")
            }
        })
    })

    // The client closes its read end before sending the request, so it never
    // reads an answer. The expectations on the delay mutex match the
    // "dummy_offer" and "ok" subtests above.
    t.Run("write_error", func(t *testing.T) {
        t.Run("FAIL", func(t *testing.T) {
            delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) {
                r.Close()
                w.Write(append([]byte(username), '\n'))
                w.Write(append([]byte("dummy offer"), '\n'))
            })
            if delay == nil {
                t.Fail()
            } else if *delay == unlocked {
                t.Error("delay mutex is unlocked")
            } else {
                delay.Unlock()
            }
        })

        t.Run("OK", func(t *testing.T) {
            delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) {
                r.Close()
                w.Write(append([]byte(username), '\n'))
                w.Write(append([]byte(otp.OTP()), '\n'))
            })
            if delay != nil {
                t.Fail()
            }
            if *faildelay.userlocks[username] != unlocked {
                t.Error("userlock not unlocked")
            }
        })
        // Advance the local OTP again after it has been offered a second time.
        otp.Increment()
    })
}


// TestMain is the single entry point of the test suite. It configures the
// package globals, replaces standard input with an in-memory pipe, opens the
// database and then runs the test groups in a fixed order.
//
// Note that it takes *testing.T, i.e. it has the signature of an ordinary
// test function rather than of a testing.M main hook.
func TestMain(t *testing.T) {
    // multiple connections only talk to the same in-memory database if a shared
    // cache is used, otherwise a new database is created
    cfg.Db_url = "file::memory:?cache=shared"
    cfg.Lookahead = 3
    cfg.Debug = false
    faildelay.userlocks = make(map[string]*sync.Mutex)

    // Fake standard input: the program reads from stdin_scanner while the
    // tests write through stdin_writer.
    stdin_r, stdin_w := io.Pipe()
    stdin_scanner = bufio.NewScanner(stdin_r)
    stdin_writer = bufio.NewWriter(stdin_w)

    conn, err := sql.Open("sqlite3", cfg.Db_url)
    db = conn
    t_err_fatal(t, err)
    create_table(db)

    // Later groups rely on the user created by the first one, so the order
    // is significant.
    stages := []struct {
        name string
        run  func(*testing.T)
    }{
        {"create_user", create_user_t},
        {"transaction_conflict_handling", transaction_conflict_t},
        {"check_offer", check_offer_t},
        {"handle_conn", handle_conn_t},
    }
    for _, stage := range stages {
        t.Run(stage.name, stage.run)
    }
}