diff --git a/goatherd.go b/goatherd.go index 8c605467702a9b4a2cd6091a72b0f6dfdf201af4..ac4bb6248c5f6f0baff0e5899980cc9d2ca48e2a 100644 --- a/goatherd.go +++ b/goatherd.go @@ -44,15 +44,15 @@ func (d duration) MarshalJSON() ([]byte, error) { return []byte(fmt.Sprintf(`"%s"`, d.String())), nil } -type tls_cfg struct { +type tlsCfg struct { Key, Cert string } var cfg struct { - Db_url string + DbURL string Lookahead uint64 Debug bool - Listen map[string]tls_cfg + Listen map[string]tlsCfg Faildelay duration } @@ -81,23 +81,23 @@ func debugf(fmt string, v ...interface{}) { } } -func err_panic(err error) { +func errPanic(err error) { if err != nil { log.Panic(err) } } -func err_fatalf(fmt string, err error) { +func errFatalf(fmt string, err error) { if err != nil { log.Fatalf(fmt, err) } } -func err_fatal(err error) { - err_fatalf("%v", err) +func errFatal(err error) { + errFatalf("%v", err) } -func get_line(s *bufio.Scanner) (string, error) { +func getLine(s *bufio.Scanner) (string, error) { if !s.Scan() { if err := s.Err(); err != nil { return "", err @@ -111,13 +111,13 @@ func get_line(s *bufio.Scanner) (string, error) { return line, nil } -type db_conn interface { +type dbConn interface { Exec(query string, args ...interface{}) (sql.Result, error) QueryRow(query string, args ...interface{}) *sql.Row Prepare(query string) (*sql.Stmt, error) } -func create_table() { +func createTable() { debug("Creating table 'users' in DB") _, err := db.Exec(` @@ -126,44 +126,44 @@ func create_table() { secret BYTEA NOT NULL, count INTEGER) `) - err_fatalf("Failed to create table: %v\n", err) + errFatalf("Failed to create table: %v\n", err) } -func pg_set_default_isolation() { +func pgSetDefaultIsolation() { var dbname string err := db.QueryRow("SELECT current_database()").Scan(&dbname) - err_fatalf("SELECT current_database(): %v\n", err) + errFatalf("SELECT current_database(): %v\n", err) _, err = db.Exec(` ALTER DATABASE ` + dbname + ` SET default_transaction_isolation TO "serializable" `) - err_fatalf("Failed to set default_transaction_isolation for DB: %v\n", err) + errFatalf("Failed to set default_transaction_isolation for DB: %v\n", err) // database settings only take effect on future sessions -> reconnect - err_fatal(db.Close()) - db, err = sql.Open("postgres", cfg.Db_url) - err_fatal(err) + errFatal(db.Close()) + db, err = sql.Open("postgres", cfg.DbURL) + errFatal(err) } -func init_db() { - pg_set_default_isolation() +func initDB() { + pgSetDefaultIsolation() - create_table() + createTable() } -func read_users_from_file(path string) { +func readUsersFromFile(path string) { fi, err := os.Stat(path) - err_fatal(err) + errFatal(err) if fi.Mode().Perm()&0077 != 0 { log.Fatalf("%v: group/other accessible, refusing operation\n", path) } f, err := os.Open(path) - err_fatal(err) + errFatal(err) lines := bufio.NewScanner(f) - users_want := make(map[string]string) + usersWant := make(map[string]string) for lines.Scan() { l := strings.TrimSpace(lines.Text()) if len(l) == 0 || l[0] == '#' { @@ -173,12 +173,12 @@ func read_users_from_file(path string) { if len(tokens) != 2 { log.Fatalf("unexpected input in userlist: %v\n", l) } - users_want[strings.TrimSpace(tokens[0])] = strings.TrimSpace(tokens[1]) + usersWant[strings.TrimSpace(tokens[0])] = strings.TrimSpace(tokens[1]) } // update and delete existing users rows, err := db.Query("SELECT name, secret FROM users") - err_fatal(err) + errFatal(err) defer rows.Close() for rows.Next() { var have struct { @@ -188,61 +188,61 @@ func read_users_from_file(path string) { err = rows.Scan(&have.name, &have.secret) debug("have", have.name) - if s_b64, want := users_want[have.name]; want { - secret, err := base64.StdEncoding.DecodeString(s_b64) + if sB64, want := usersWant[have.name]; want { + secret, err := base64.StdEncoding.DecodeString(sB64) if err != nil { log.Fatalf("%v: %v: %v\n", path, have.name, err) } if !bytes.Equal(secret, have.secret) { log.Printf("update %v\n", have.name) _, err = db.Exec("UPDATE users SET count = 1, secret = $1 WHERE name = $2", secret, have.name) - err_fatal(err) + errFatal(err) } - delete(users_want, have.name) + delete(usersWant, have.name) } else { log.Printf("remove %v\n", have.name) _, err = db.Exec("DELETE FROM users WHERE name = $1", have.name) - err_fatal(err) + errFatal(err) } } // add new users - for user, s_b64 := range users_want { + for user, sB64 := range usersWant { log.Printf("create %v\n", user) - create_user(user, s_b64) + createUser(user, sB64) } } // global var for testing, not modified during normal execution -var stdin_scanner = bufio.NewScanner(os.Stdin) +var stdinScanner = bufio.NewScanner(os.Stdin) -func create_user(name string, secret_b64 string) { +func createUser(name string, secretB64 string) { debug("Creating user") var err error if name == "-" { fmt.Printf("Enter username: ") - name, err = get_line(stdin_scanner) - err_fatalf("Can't read username: %v\n", err) + name, err = getLine(stdinScanner) + errFatalf("Can't read username: %v\n", err) } - if secret_b64 == "-" { + if secretB64 == "-" { fmt.Printf("Enter secret: ") - secret_b64, err = get_line(stdin_scanner) - err_fatalf("Can't read secret: %v\n", err) + secretB64, err = getLine(stdinScanner) + errFatalf("Can't read secret: %v\n", err) } - secret, err := base64.StdEncoding.DecodeString(secret_b64) - err_fatalf("Can't decode secret: %v\n", err) + secret, err := base64.StdEncoding.DecodeString(secretB64) + errFatalf("Can't decode secret: %v\n", err) debug("Adding user with name", name) _, err = db.Exec(` INSERT INTO users(name, secret, count) values($1, $2, $3) ON CONFLICT (name) DO UPDATE SET secret = $2, count = $3 `, name, secret, 1) - err_fatalf("Failed to create user: %v\n", err) + errFatalf("Failed to create user: %v\n", err) } -func get_user(db db_conn, name string) (secret []byte, count uint64, err error) { +func getUser(db dbConn, name string) (secret []byte, count uint64, err error) { stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = $1") if err != nil { return @@ -254,7 +254,7 @@ func get_user(db db_conn, name string) (secret []byte, count uint64, err error) return } -func set_count(db db_conn, name string, count uint64) error { +func setCount(db dbConn, name string, count uint64) error { inc, err := db.Prepare("UPDATE users SET count = $1 WHERE name = $2") if err != nil { return err @@ -265,8 +265,8 @@ func set_count(db db_conn, name string, count uint64) error { return err } -func get_otp(db db_conn, name string) (*hotp.HOTP, error) { - secret, count, err := get_user(db, name) +func getOTP(db dbConn, name string) (*hotp.HOTP, error) { + secret, count, err := getUser(db, name) if err != nil { return nil, err } @@ -274,7 +274,7 @@ func get_otp(db db_conn, name string) (*hotp.HOTP, error) { return hotp.NewHOTP(secret, count, otpLen), nil } -func transaction_failed(err error) bool { +func transactionFailed(err error) bool { if err == nil { return false } @@ -290,23 +290,23 @@ func transaction_failed(err error) bool { // Retrieve secret and count for given username and try to find a match within // the lookahead range. Update count in DB if match is found. All within a // transaction, retrying if it fails. -func check_offer(remote string, name string, offer string) (bool, error) { +func checkOffer(remote string, name string, offer string) (bool, error) { for { debugf("[%v] begin transaction", remote) tx, err := db.Begin() - err_panic(err) + errPanic(err) ok := false debugf("[%v] looking up data for %v", remote, name) - hotp, err := get_otp(tx, name) - if transaction_failed(err) { + hotp, err := getOTP(tx, name) + if transactionFailed(err) { goto retry } else if err == sql.ErrNoRows { tx.Rollback() return false, err } else { - err_panic(err) + errPanic(err) } for i := uint64(0); i <= cfg.Lookahead; i++ { @@ -315,11 +315,11 @@ func check_offer(remote string, name string, offer string) (bool, error) { // otherwise do it explicitly if hotp.Check(offer) { debugf("[%v] ok, set new count", remote) - err = set_count(tx, name, hotp.Counter()) - if transaction_failed(err) { + err = setCount(tx, name, hotp.Counter()) + if transactionFailed(err) { goto retry } else { - err_panic(err) + errPanic(err) } ok = true @@ -332,10 +332,10 @@ func check_offer(remote string, name string, offer string) (bool, error) { commit: debugf("[%v] commiting", remote) err = tx.Commit() - if transaction_failed(err) { + if transactionFailed(err) { goto retry } else { - err_panic(err) + errPanic(err) } return ok, nil @@ -348,21 +348,21 @@ func check_offer(remote string, name string, offer string) (bool, error) { } } -func handle_conn(remote string, reader *bufio.Reader, +func handleConn(remote string, reader *bufio.Reader, writer *bufio.Writer) (delay *sync.Mutex) { s := bufio.NewScanner(reader) b := make([]byte, 80) s.Buffer(b, len(b)) debugf("[%v] reading name", remote) - name, err := get_line(s) + name, err := getLine(s) if err != nil { log.Printf("[%v] %v", remote, err) return } debugf("[%v] name: %v", remote, name) - _, _, err = get_user(db, name) + _, _, err = getUser(db, name) if err == sql.ErrNoRows { log.Printf("[%v] Unkown user: %v", remote, name) return @@ -371,7 +371,7 @@ func handle_conn(remote string, reader *bufio.Reader, } debugf("[%v] reading offer", remote) - offer, err := get_line(s) + offer, err := getLine(s) if err != nil { log.Printf("[%v] %v", remote, err) return @@ -402,7 +402,7 @@ func handle_conn(remote string, reader *bufio.Reader, debugf("[%v] checking for match", remote) result := "FAIL" - match, err := check_offer(remote, name, offer) + match, err := checkOffer(remote, name, offer) if err != nil { log.Panic(err) } else if match { @@ -437,7 +437,7 @@ func listen(wg *sync.WaitGroup, listener net.Listener) { for { debugf("Accepting on %v", listener.Addr()) conn, err := listener.Accept() - err_fatal(err) + errFatal(err) log.Printf("new connection: %v\n", conn.RemoteAddr()) // XXX: recover from database failure @@ -446,7 +446,7 @@ func listen(wg *sync.WaitGroup, listener net.Listener) { reader := bufio.NewReader(conn) writer := bufio.NewWriter(conn) - delay := handle_conn(remote, reader, writer) + delay := handleConn(remote, reader, writer) debugf("[%v] closing", remote) conn.Close() // XXX: check err? @@ -465,24 +465,24 @@ func serve() { var wg sync.WaitGroup - for addr, tls_cfg := range cfg.Listen { + for addr, tlsCfg := range cfg.Listen { var listener net.Listener - listen_addr, err := net.ResolveTCPAddr("tcp", addr) - err_fatal(err) - listener, err = net.ListenTCP("tcp", listen_addr) - err_fatal(err) + listenAddr, err := net.ResolveTCPAddr("tcp", addr) + errFatal(err) + listener, err = net.ListenTCP("tcp", listenAddr) + errFatal(err) - if tls_cfg.Key != "" { - log.Printf("Using TLS: cert %v, key %v\n", tls_cfg.Cert, tls_cfg.Key) - cert, err := tls.LoadX509KeyPair(tls_cfg.Cert, tls_cfg.Key) - err_fatalf("Error loading key pair: %v\n", err) + if tlsCfg.Key != "" { + log.Printf("Using TLS: cert %v, key %v\n", tlsCfg.Cert, tlsCfg.Key) + cert, err := tls.LoadX509KeyPair(tlsCfg.Cert, tlsCfg.Key) + errFatalf("Error loading key pair: %v\n", err) listener = tls.NewListener(listener, &tls.Config{ Certificates: []tls.Certificate{cert}, }) } - log.Println("Listening on", listen_addr) + log.Println("Listening on", listenAddr) wg.Add(1) go listen(&wg, listener) } @@ -491,8 +491,8 @@ func serve() { } func main() { - flag_config := flag.String("config", "/etc/goatherd.conf", "Path to config file") - flag.StringVar(&cfg.Db_url, "db-url", "", "URL used to connect to the database.") + flagConfig := flag.String("config", "/etc/goatherd.conf", "Path to config file") + flag.StringVar(&cfg.DbURL, "db-url", "", "URL used to connect to the database.") flag.Uint64Var(&cfg.Lookahead, "lookahead", 10, "Counter range to check for matching OTPs.") flag.BoolVar(&cfg.Debug, "debug", false, "Enable debug output.") flag.DurationVar(&cfg.Faildelay.Duration, "faildelay", 1*time.Second, @@ -501,17 +501,17 @@ func main() { // 1. parse arguments to get config path flag.Parse() // 2. parse config - debug("Using config file", *flag_config) - cfg_file, err := os.Open(*flag_config) - err_fatalf("Can't read config: %v\n", err) - err = json.NewDecoder(cfg_file).Decode(&cfg) - err_fatalf("Error while parsing config file: %v\n", err) + debug("Using config file", *flagConfig) + cfgFile, err := os.Open(*flagConfig) + errFatalf("Can't read config: %v\n", err) + err = json.NewDecoder(cfgFile).Decode(&cfg) + errFatalf("Error while parsing config file: %v\n", err) // 3. parse arguments again to override values from config file, defaults // are not set again flag.Parse() - db, err = sql.Open("postgres", cfg.Db_url) - err_fatal(err) + db, err = sql.Open("postgres", cfg.DbURL) + errFatal(err) if len(flag.Args()) < 1 { serve() @@ -523,45 +523,45 @@ func main() { case "dump-config": enc := json.NewEncoder(os.Stdout) enc.SetIndent("", "\t") - err_fatal(enc.Encode(cfg)) + errFatal(enc.Encode(cfg)) case "init-db": - init_db() + initDB() case "sync-users": if len(args) != 1 { log.Fatalf("sync-users: Invalid number of arguments: %v (expecting <path>)\n", len(args)) } - read_users_from_file(args[0]) + readUsersFromFile(args[0]) case "add-user": if len(args) != 2 { log.Fatalf("add-user: Invalid number of arguments: %v (expecting <username> <secret>)\n", len(args)) } user := args[0] secret := args[1] - create_user(user, secret) + createUser(user, secret) case "set-counter": if len(args) != 2 { log.Fatalf("set-counter: Invalid number of arguments: %v (expecting <username> <counter>)\n", len(args)) } user := args[0] counter, err := strconv.ParseUint(args[1], 0, 64) - err_fatal(err) - if err = set_count(db, user, counter); err != nil { - err_fatal(err) + errFatal(err) + if err = setCount(db, user, counter); err != nil { + errFatal(err) } case "serve": fs := flag.NewFlagSet("serve", flag.ExitOnError) - flag_addr := fs.String("addr", "", "Address to listen on. (default 127.0.0.1:9999)") - flag_tls_key := fs.String("tls-key", "", "Use TLS.") - flag_tls_cert := fs.String("tls-cert", "", "Use TLS.") + flagAddr := fs.String("addr", "", "Address to listen on. (default 127.0.0.1:9999)") + flagTLSKey := fs.String("tls-key", "", "Use TLS.") + flagTLSCert := fs.String("tls-cert", "", "Use TLS.") fs.Parse(args) // copy -addr option to config - if *flag_addr != "" { - cfg.Listen = make(map[string]tls_cfg) - cfg.Listen[*flag_addr] = tls_cfg{Key: *flag_tls_key, Cert: *flag_tls_cert} + if *flagAddr != "" { + cfg.Listen = make(map[string]tlsCfg) + cfg.Listen[*flagAddr] = tlsCfg{Key: *flagTLSKey, Cert: *flagTLSCert} } else if cfg.Listen == nil { - cfg.Listen = make(map[string]tls_cfg) - cfg.Listen["127.0.0.1:9999"] = tls_cfg{} + cfg.Listen = make(map[string]tlsCfg) + cfg.Listen["127.0.0.1:9999"] = tlsCfg{} } serve() diff --git a/goatherd_test.go b/goatherd_test.go index 19d7d07d851b5c26dd508b76bcfe36ce4df44e36..99bf277b8fd1770e6e188357efff3f611991a6c6 100644 --- a/goatherd_test.go +++ b/goatherd_test.go @@ -14,24 +14,24 @@ import ( "github.com/gokyle/hotp" ) -var stdin_writer *bufio.Writer +var stdinWriter *bufio.Writer var ( - username = "foobar" - secret = []byte("foobar") - secret_b64 = base64.StdEncoding.EncodeToString(secret) + username = "foobar" + secret = []byte("foobar") + secretB64 = base64.StdEncoding.EncodeToString(secret) ) -func t_err_fatal(t *testing.T, err error) { +func tErrFatal(t *testing.T, err error) { if err != nil { t.Fatal(err) } } // always uses global username and secret for comparison, so we can test "-" -func create_and_check(user string, sec string) func(*testing.T) { +func createAndCheck(user string, sec string) func(*testing.T) { return func(t *testing.T) { - create_user(user, sec) + createUser(user, sec) var result struct { secret []byte @@ -40,17 +40,17 @@ func create_and_check(user string, sec string) func(*testing.T) { t.Run("exists", func(t *testing.T) { var err error - result.secret, result.count, err = get_user(db, username) - t_err_fatal(t, err) + result.secret, result.count, err = getUser(db, username) + tErrFatal(t, err) }) - t.Run("correct_secret", func(t *testing.T) { + t.Run("correctSecret", func(t *testing.T) { if bytes.Compare(secret, result.secret) != 0 { t.Errorf("secret: %v; result: %v", string(secret), string(result.secret)) } }) - t.Run("correct_count", func(t *testing.T) { + t.Run("correctCount", func(t *testing.T) { if result.count != 1 { t.Errorf("initial count: %v", result.count) } @@ -58,34 +58,34 @@ func create_and_check(user string, sec string) func(*testing.T) { } } -func create_user_t(t *testing.T) { - t.Run("args", create_and_check(username, secret_b64)) +func createUserT(t *testing.T) { + t.Run("args", createAndCheck(username, secretB64)) go func() { - fmt.Fprintln(stdin_writer, username) - fmt.Fprintln(stdin_writer, secret_b64) - stdin_writer.Flush() + fmt.Fprintln(stdinWriter, username) + fmt.Fprintln(stdinWriter, secretB64) + stdinWriter.Flush() }() - t.Run("stdin", create_and_check("-", "-")) + t.Run("stdin", createAndCheck("-", "-")) } func aborted(t *testing.T, err error, aborted ...interface{}) bool { - if transaction_failed(err) { + if transactionFailed(err) { t.Log(aborted...) return true } - t_err_fatal(t, err) + tErrFatal(t, err) return false } -func interleaved_transactions_t(t *testing.T) { +func interleavedTransactionsT(t *testing.T) { var err error var txs [2]*sql.Tx for i := range txs { txs[i], err = db.Begin() - t_err_fatal(t, err) + tErrFatal(t, err) defer func(i int) { txs[i].Rollback() }(i) } @@ -114,16 +114,16 @@ func interleaved_transactions_t(t *testing.T) { t.Error("No transaction failure despite interleaved transactions!") } -func nested_transactions_t(t *testing.T) { +func nestedTransactionsT(t *testing.T) { outer, err := db.Begin() - t_err_fatal(t, err) + tErrFatal(t, err) defer outer.Rollback() inner, err := db.Begin() - t_err_fatal(t, err) + tErrFatal(t, err) defer inner.Rollback() var c uint64 - t_err_fatal(t, outer.QueryRow("SELECT count FROM users WHERE name = $1", username).Scan(&c)) + tErrFatal(t, outer.QueryRow("SELECT count FROM users WHERE name = $1", username).Scan(&c)) err = inner.QueryRow("SELECT count FROM users WHERE name = $1", username).Scan(&c) if aborted(t, err, "rollback after inner.Query") { @@ -155,32 +155,32 @@ func nested_transactions_t(t *testing.T) { t.Error("No transaction failure despite nested transactions") } -func transaction_conflict_t(t *testing.T) { - t.Run("interleaved_transactions_t", interleaved_transactions_t) - t.Run("nested_transactions_t", nested_transactions_t) +func transactionConflictT(t *testing.T) { + t.Run("interleavedTransactionsT", interleavedTransactionsT) + t.Run("nestedTransactionsT", nestedTransactionsT) } -func check_offer_t(t *testing.T) { - t.Run("no_such_user", func(t *testing.T) { - _, err := check_offer("mock", "no such user", "dummy offer") +func checkOfferT(t *testing.T) { + t.Run("noSuchUser", func(t *testing.T) { + _, err := checkOffer("mock", "no such user", "dummy offer") if err != sql.ErrNoRows { t.Error("err:", nil) } }) - secret, count, err := get_user(db, username) - t_err_fatal(t, err) + secret, count, err := getUser(db, username) + tErrFatal(t, err) t.Run("fail", func(t *testing.T) { - ok, err := check_offer("mock", username, "dummy offer") - t_err_fatal(t, err) + ok, err := checkOffer("mock", username, "dummy offer") + tErrFatal(t, err) if ok { t.Fail() } - t.Run("too_far_out", func(t *testing.T) { + t.Run("tooFarOut", func(t *testing.T) { ahead := hotp.NewHOTP(secret, count+cfg.Lookahead+1, otpLen) - ok, err = check_offer("mock", username, ahead.OTP()) + ok, err = checkOffer("mock", username, ahead.OTP()) if ok { t.Fail() } @@ -189,25 +189,25 @@ func check_offer_t(t *testing.T) { t.Run("ok", func(t *testing.T) { cur := hotp.NewHOTP(secret, count, otpLen) - ok, err := check_offer("mock", username, cur.OTP()) - t_err_fatal(t, err) + ok, err := checkOffer("mock", username, cur.OTP()) + tErrFatal(t, err) if !ok { t.Fail() } cur.Increment() t.Run("incremented", func(t *testing.T) { - ok, err := check_offer("mock", username, cur.OTP()) - t_err_fatal(t, err) + ok, err := checkOffer("mock", username, cur.OTP()) + tErrFatal(t, err) if !ok { t.Fail() } }) t.Run("lookahead", func(t *testing.T) { - ok, err := check_offer("mock", username, + ok, err := checkOffer("mock", username, hotp.NewHOTP(secret, cur.Counter()+cfg.Lookahead, otpLen).OTP()) - t_err_fatal(t, err) + tErrFatal(t, err) if !ok { t.Fail() } @@ -216,40 +216,40 @@ func check_offer_t(t *testing.T) { } -type client_t func(r *io.PipeReader, w *io.PipeWriter) +type mockClient func(r *io.PipeReader, w *io.PipeWriter) -func mock_conn(client client_t) (delay *sync.Mutex) { - server_r, client_w := io.Pipe() - server_r_buf := bufio.NewReader(server_r) - defer server_r.Close() - client_r, server_w := io.Pipe() - server_w_buf := bufio.NewWriter(server_w) - defer client_r.Close() +func mockConn(client mockClient) (delay *sync.Mutex) { + serverR, clientW := io.Pipe() + serverRBuf := bufio.NewReader(serverR) + defer serverR.Close() + clientR, serverW := io.Pipe() + serverWBuf := bufio.NewWriter(serverW) + defer clientR.Close() - go client(client_r, client_w) + go client(clientR, clientW) - delay = handle_conn("handle_conn_t", server_r_buf, server_w_buf) + delay = handleConn("handleConnT", serverRBuf, serverWBuf) return } -func interact(name string, offer string, recv chan string) client_t { +func interact(name string, offer string, recv chan string) mockClient { return func(r *io.PipeReader, w *io.PipeWriter) { w.Write(append([]byte(name), '\n')) w.Write(append([]byte(offer), '\n')) - answer, _ := get_line(bufio.NewScanner(r)) + answer, _ := getLine(bufio.NewScanner(r)) recv <- answer } } -func handle_conn_t(t *testing.T) { +func handleConnT(t *testing.T) { // zero value Mutex is unlocked, can be compared to a given Mutex to // (unreliably?) check if it is locked unlocked := sync.Mutex{} - t.Run("dummy_user", func(t *testing.T) { + t.Run("dummyUser", func(t *testing.T) { recv := make(chan string) - delay, answer := mock_conn(interact("dummy user", "dummy offer", recv)), <-recv + delay, answer := mockConn(interact("dummy user", "dummy offer", recv)), <-recv if delay != nil { t.Fail() } @@ -261,9 +261,9 @@ func handle_conn_t(t *testing.T) { } }) - t.Run("dummy_offer", func(t *testing.T) { + t.Run("dummyOffer", func(t *testing.T) { recv := make(chan string) - delay, answer := mock_conn(interact(username, "dummy offer", recv)), <-recv + delay, answer := mockConn(interact(username, "dummy offer", recv)), <-recv if delay == nil { t.Fail() } else if *delay == unlocked { @@ -279,12 +279,12 @@ func handle_conn_t(t *testing.T) { } }) - otp, err := get_otp(db, username) - t_err_fatal(t, err) + otp, err := getOTP(db, username) + tErrFatal(t, err) t.Run("ok", func(t *testing.T) { recv := make(chan string) - delay := mock_conn(interact(username, otp.OTP(), recv)) + delay := mockConn(interact(username, otp.OTP(), recv)) answer := <-recv if delay != nil { t.Fail() @@ -298,9 +298,9 @@ func handle_conn_t(t *testing.T) { }) otp.Increment() - t.Run("read_error", func(t *testing.T) { - t.Run("close_immediately", func(t *testing.T) { - delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) { + t.Run("readError", func(t *testing.T) { + t.Run("closeImmediately", func(t *testing.T) { + delay := mockConn(func(r *io.PipeReader, w *io.PipeWriter) { w.Close() }) if delay != nil { @@ -311,8 +311,8 @@ func handle_conn_t(t *testing.T) { } }) - t.Run("close_after_username", func(t *testing.T) { - delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) { + t.Run("closeAfterUsername", func(t *testing.T) { + delay := mockConn(func(r *io.PipeReader, w *io.PipeWriter) { w.Write(append([]byte(username), '\n')) w.Close() }) @@ -325,9 +325,9 @@ func handle_conn_t(t *testing.T) { }) }) - t.Run("write_error", func(t *testing.T) { + t.Run("writeError", func(t *testing.T) { t.Run("FAIL", func(t *testing.T) { - delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) { + delay := mockConn(func(r *io.PipeReader, w *io.PipeWriter) { r.Close() w.Write(append([]byte(username), '\n')) w.Write(append([]byte("dummy offer"), '\n')) @@ -342,7 +342,7 @@ func handle_conn_t(t *testing.T) { }) t.Run("OK", func(t *testing.T) { - delay := mock_conn(func(r *io.PipeReader, w *io.PipeWriter) { + delay := mockConn(func(r *io.PipeReader, w *io.PipeWriter) { r.Close() w.Write(append([]byte(username), '\n')) w.Write(append([]byte(otp.OTP()), '\n')) @@ -359,24 +359,24 @@ func handle_conn_t(t *testing.T) { } func TestMain(t *testing.T) { - if db_url, ok := os.LookupEnv("DB_URL"); ok { - cfg.Db_url = db_url + if dbURL, ok := os.LookupEnv("DB_URL"); ok { + cfg.DbURL = dbURL } else { t.Error("No DB_URL specified") } cfg.Lookahead = 3 cfg.Debug = false faildelay.userlocks = make(map[string]*sync.Mutex) - stdin_r, stdin_w := io.Pipe() - stdin_scanner, stdin_writer = bufio.NewScanner(stdin_r), bufio.NewWriter(stdin_w) + stdinR, stdinW := io.Pipe() + stdinScanner, stdinWriter = bufio.NewScanner(stdinR), bufio.NewWriter(stdinW) var err error - db, err = sql.Open("postgres", cfg.Db_url) - t_err_fatal(t, err) - init_db() - - t.Run("create_user", create_user_t) - t.Run("transaction_conflict_handling", transaction_conflict_t) - t.Run("check_offer", check_offer_t) - t.Run("handle_conn", handle_conn_t) + db, err = sql.Open("postgres", cfg.DbURL) + tErrFatal(t, err) + initDB() + + t.Run("createUser", createUserT) + t.Run("transactionConflict", transactionConflictT) + t.Run("checkOffer", checkOfferT) + t.Run("handleConn", handleConnT) }