From c40a7ea0d35a2f692feeffc33049cec12fd75172 Mon Sep 17 00:00:00 2001
From: Michael Eischer <michael.eischer@fau.de>
Date: Sat, 25 Jun 2022 14:56:07 +0200
Subject: [PATCH] add isTOTP data base fields

---
 goatherd.go      | 61 ++++++++++++++++++++++++++++++++----------------
 goatherd_test.go |  2 +-
 2 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/goatherd.go b/goatherd.go
index f9a0933..9269c68 100644
--- a/goatherd.go
+++ b/goatherd.go
@@ -130,6 +130,7 @@ func createTable() {
         CREATE TABLE users(
             name TEXT NOT NULL PRIMARY KEY,
             secret BYTEA NOT NULL,
+            isTOTP BOOLEAN,
             count INTEGER)
         `)
 	errFatalf("Failed to create table: %v\n", err)
@@ -169,37 +170,49 @@ func readUsersFromFile(path string) {
 	errFatal(err)
 	lines := bufio.NewScanner(f)
 
-	usersWant := make(map[string]string)
+	type userInfo struct {
+		sB64   string
+		isTOTP bool
+	}
+	usersWant := make(map[string]userInfo)
 	for lines.Scan() {
 		l := strings.TrimSpace(lines.Text())
 		if len(l) == 0 || l[0] == '#' {
 			continue
 		}
-		tokens := strings.SplitN(l, " ", 2)
-		if len(tokens) != 2 {
+		tokens := strings.SplitN(l, " ", 3)
+		if len(tokens) != 3 {
 			log.Fatalf("unexpected input in userlist: %v\n", l)
 		}
-		usersWant[strings.TrimSpace(tokens[0])] = strings.TrimSpace(tokens[1])
+		t2 := strings.TrimSpace(tokens[2])
+		if t2 != "0" && t2 != "1" {
+			log.Fatalf("unexpected isTOTP flag in userlist: %v\n", l)
+		}
+		usersWant[strings.TrimSpace(tokens[0])] = userInfo{
+			strings.TrimSpace(tokens[1]),
+			t2 == "1",
+		}
 	}
 
 	// update and delete existing users
-	rows, err := db.Query("SELECT name, secret FROM users")
+	rows, err := db.Query("SELECT name, secret, isTOTP FROM users")
 	errFatal(err)
 	defer rows.Close()
 	for rows.Next() {
 		var have struct {
 			name   string
 			secret []byte
+			isTOTP bool
 		}
-		_ = rows.Scan(&have.name, &have.secret)
+		_ = rows.Scan(&have.name, &have.secret, &have.isTOTP)
 		debug("have", have.name)
 
-		if sB64, want := usersWant[have.name]; want {
-			secret, err := base64.StdEncoding.DecodeString(sB64)
+		if s, want := usersWant[have.name]; want {
+			secret, err := base64.StdEncoding.DecodeString(s.sB64)
 			if err != nil {
 				log.Fatalf("%v: %v: %v\n", path, have.name, err)
 			}
-			if !bytes.Equal(secret, have.secret) {
+			if !bytes.Equal(secret, have.secret) || have.isTOTP != s.isTOTP {
 				log.Printf("update %v\n", have.name)
 				_, err = db.Exec("UPDATE users SET count = 1, secret = $1 WHERE name = $2", secret, have.name)
 				errFatal(err)
@@ -213,16 +226,16 @@ func readUsersFromFile(path string) {
 	}
 
 	// add new users
-	for user, sB64 := range usersWant {
+	for user, s := range usersWant {
 		log.Printf("create %v\n", user)
-		createUser(user, sB64)
+		createUser(user, s.sB64, s.isTOTP)
 	}
 }
 
 // global var for testing, not modified during normal execution
 var stdinScanner = bufio.NewScanner(os.Stdin)
 
-func createUser(name string, secretB64 string) {
+func createUser(name string, secretB64 string, isTOTP bool) {
 	debug("Creating user")
 
 	var err error
@@ -242,19 +255,19 @@ func createUser(name string, secretB64 string) {
 
 	debug("Adding user with name", name)
 	_, err = db.Exec(`
-        INSERT INTO users(name, secret, count) values($1, $2, $3)
-        ON CONFLICT (name) DO UPDATE SET secret = $2, count = $3
-        `, name, secret, 1)
+        INSERT INTO users(name, secret, isTOTP, count) values($1, $2, $3, $4)
+        ON CONFLICT (name) DO UPDATE SET secret = $2, isTOTP = $3, count = $4
+        `, name, secret, isTOTP, 1)
 	errFatalf("Failed to create user: %v\n", err)
 }
 
 func getUser(db dbConn, name string) (secret []byte, count uint64, isTOTP bool, err error) {
-	stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = $1")
+	stmt, err := db.Prepare("SELECT secret, isTOTP, count FROM users WHERE name = $1")
 	if err != nil {
 		return
 	}
 
-	err = stmt.QueryRow(name).Scan(&secret, &count)
+	err = stmt.QueryRow(name).Scan(&secret, &isTOTP, &count)
 	stmt.Close()
 
 	return
@@ -674,12 +687,20 @@ func main() {
 			}
 			readUsersFromFile(args[0])
 		case "add-user":
-			if len(args) != 2 {
-				log.Fatalf("add-user: Invalid number of arguments: %v (expecting <username> <secret>)\n", len(args))
+			if len(args) != 3 {
+				log.Fatalf("add-user: Invalid number of arguments: %v (expecting <username> <secret> <isTOTP:0|1>)\n", len(args))
 			}
 			user := args[0]
 			secret := args[1]
-			createUser(user, secret)
+			isTOTP := false
+			if args[2] == "1" {
+				isTOTP = true
+			} else if args[2] == "0" {
+				isTOTP = false
+			} else {
+				log.Fatalf("parameter isTOTP must be 0 or 1")
+			}
+			createUser(user, secret, isTOTP)
 		case "get-counter":
 			if len(args) != 1 {
 				log.Fatalf("get-counter: Invalid number of arguments: %v (expecting <user>)\n", len(args))
diff --git a/goatherd_test.go b/goatherd_test.go
index a0914c5..642e75a 100644
--- a/goatherd_test.go
+++ b/goatherd_test.go
@@ -34,7 +34,7 @@ func tErrFatal(t *testing.T, err error) {
 // always uses global username and secret for comparison, so we can test "-"
 func createAndCheck(user string, sec string, isTOTP bool) func(*testing.T) {
 	return func(t *testing.T) {
-		createUser(user, sec)
+		createUser(user, sec, isTOTP)
 
 		var result struct {
 			secret []byte
-- 
GitLab