From c40a7ea0d35a2f692feeffc33049cec12fd75172 Mon Sep 17 00:00:00 2001 From: Michael Eischer <michael.eischer@fau.de> Date: Sat, 25 Jun 2022 14:56:07 +0200 Subject: [PATCH] add isTOTP data base fields --- goatherd.go | 61 ++++++++++++++++++++++++++++++++---------------- goatherd_test.go | 2 +- 2 files changed, 42 insertions(+), 21 deletions(-) diff --git a/goatherd.go b/goatherd.go index f9a0933..9269c68 100644 --- a/goatherd.go +++ b/goatherd.go @@ -130,6 +130,7 @@ func createTable() { CREATE TABLE users( name TEXT NOT NULL PRIMARY KEY, secret BYTEA NOT NULL, + isTOTP BOOLEAN, count INTEGER) `) errFatalf("Failed to create table: %v\n", err) @@ -169,37 +170,49 @@ func readUsersFromFile(path string) { errFatal(err) lines := bufio.NewScanner(f) - usersWant := make(map[string]string) + type userInfo struct { + sB64 string + isTOTP bool + } + usersWant := make(map[string]userInfo) for lines.Scan() { l := strings.TrimSpace(lines.Text()) if len(l) == 0 || l[0] == '#' { continue } - tokens := strings.SplitN(l, " ", 2) - if len(tokens) != 2 { + tokens := strings.SplitN(l, " ", 3) + if len(tokens) != 3 { log.Fatalf("unexpected input in userlist: %v\n", l) } - usersWant[strings.TrimSpace(tokens[0])] = strings.TrimSpace(tokens[1]) + t2 := strings.TrimSpace(tokens[2]) + if t2 != "0" && t2 != "1" { + log.Fatalf("unexpected isTOTP flag in userlist: %v\n", l) + } + usersWant[strings.TrimSpace(tokens[0])] = userInfo{ + strings.TrimSpace(tokens[1]), + t2 == "1", + } } // update and delete existing users - rows, err := db.Query("SELECT name, secret FROM users") + rows, err := db.Query("SELECT name, secret, isTOTP FROM users") errFatal(err) defer rows.Close() for rows.Next() { var have struct { name string secret []byte + isTOTP bool } - _ = rows.Scan(&have.name, &have.secret) + _ = rows.Scan(&have.name, &have.secret, &have.isTOTP) debug("have", have.name) - if sB64, want := usersWant[have.name]; want { - secret, err := base64.StdEncoding.DecodeString(sB64) + if s, want := usersWant[have.name]; want { + secret, err := base64.StdEncoding.DecodeString(s.sB64) if err != nil { log.Fatalf("%v: %v: %v\n", path, have.name, err) } - if !bytes.Equal(secret, have.secret) { + if !bytes.Equal(secret, have.secret) || have.isTOTP != s.isTOTP { log.Printf("update %v\n", have.name) _, err = db.Exec("UPDATE users SET count = 1, secret = $1 WHERE name = $2", secret, have.name) errFatal(err) @@ -213,16 +226,16 @@ func readUsersFromFile(path string) { } // add new users - for user, sB64 := range usersWant { + for user, s := range usersWant { log.Printf("create %v\n", user) - createUser(user, sB64) + createUser(user, s.sB64, s.isTOTP) } } // global var for testing, not modified during normal execution var stdinScanner = bufio.NewScanner(os.Stdin) -func createUser(name string, secretB64 string) { +func createUser(name string, secretB64 string, isTOTP bool) { debug("Creating user") var err error @@ -242,19 +255,19 @@ func createUser(name string, secretB64 string) { debug("Adding user with name", name) _, err = db.Exec(` - INSERT INTO users(name, secret, count) values($1, $2, $3) - ON CONFLICT (name) DO UPDATE SET secret = $2, count = $3 - `, name, secret, 1) + INSERT INTO users(name, secret, isTOTP, count) values($1, $2, $3, $4) + ON CONFLICT (name) DO UPDATE SET secret = $2, isTOTP = $3, count = $4 + `, name, secret, isTOTP, 1) errFatalf("Failed to create user: %v\n", err) } func getUser(db dbConn, name string) (secret []byte, count uint64, isTOTP bool, err error) { - stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = $1") + stmt, err := db.Prepare("SELECT secret, isTOTP, count FROM users WHERE name = $1") if err != nil { return } - err = stmt.QueryRow(name).Scan(&secret, &count) + err = stmt.QueryRow(name).Scan(&secret, &isTOTP, &count) stmt.Close() return @@ -674,12 +687,20 @@ func main() { } readUsersFromFile(args[0]) case "add-user": - if len(args) != 2 { - log.Fatalf("add-user: Invalid number of arguments: %v (expecting <username> <secret>)\n", len(args)) + if len(args) != 3 { + log.Fatalf("add-user: Invalid number of arguments: %v (expecting <username> <secret> <isTOTP:0|1>)\n", len(args)) } user := args[0] secret := args[1] - createUser(user, secret) + isTOTP := false + if args[2] == "1" { + isTOTP = true + } else if args[2] == "0" { + isTOTP = false + } else { + log.Fatalf("parameter isTOTP must be 0 or 1") + } + createUser(user, secret, isTOTP) case "get-counter": if len(args) != 1 { log.Fatalf("get-counter: Invalid number of arguments: %v (expecting <user>)\n", len(args)) diff --git a/goatherd_test.go b/goatherd_test.go index a0914c5..642e75a 100644 --- a/goatherd_test.go +++ b/goatherd_test.go @@ -34,7 +34,7 @@ func tErrFatal(t *testing.T, err error) { // always uses global username and secret for comparison, so we can test "-" func createAndCheck(user string, sec string, isTOTP bool) func(*testing.T) { return func(t *testing.T) { - createUser(user, sec) + createUser(user, sec, isTOTP) var result struct { secret []byte -- GitLab