From cfb6be5bd05e9c377e67c660c4ad45acd869380d Mon Sep 17 00:00:00 2001
From: Simon Schuster <git@rationality.eu>
Date: Mon, 30 May 2022 14:56:54 +0200
Subject: [PATCH] Add ability to store multiple secrets per user

Alternative secrets are stored in the same database table, with the
username schema: "$USERNAME/suffix", where suffix can be any unique
string, but should probably be chosen in a meaningful way to identify
which device the key belongs to.

Please note that this linearly decreases the security of the 2fa
solution, because it accepts twice as many HOTP-keys when a second key
is added.
---
 goatherd.go      | 184 +++++++++++++++++++++++++++--------------------
 goatherd_test.go |  13 ++--
 2 files changed, 115 insertions(+), 82 deletions(-)

diff --git a/goatherd.go b/goatherd.go
index ce3255a..d4260cf 100644
--- a/goatherd.go
+++ b/goatherd.go
@@ -60,6 +60,13 @@ var cfg struct {
 	ReadTimeout         duration          `json:"read_timeout"`
 }
 
+type otpRecord struct {
+	name   string
+	secret []byte
+	count  uint64
+	hotp   *hotp.HOTP
+}
+
 // state for per-user ratelimiting
 //
 // Each user has a corresponding Mutex in faildelay.userlocks. A lock is
@@ -245,18 +252,48 @@ func createUser(name string, secretB64 string) {
 	errFatalf("Failed to create user: %v\n", err)
 }
 
-func getUser(db dbConn, name string) (secret []byte, count uint64, err error) {
+func getEntity(db dbConn, name string) (secret []byte, count uint64, err error) {
 	stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = $1")
 	if err != nil {
 		return
 	}
-
 	err = stmt.QueryRow(name).Scan(&secret, &count)
 	stmt.Close()
 
 	return
 }
 
+
+func getOTPWithAlts(db dbConn, name string) (recs map[string]otpRecord, err error) {
+	stmt, err := db.Prepare("SELECT name, secret, count FROM users WHERE name = $1 or name like CONCAT($1, '/%')")
+	if err != nil {
+		return
+	}
+
+	rows, err := stmt.Query(name)
+	if err != nil {
+		return nil, err
+	}
+	defer rows.Close()
+
+	recs = make(map[string]otpRecord)
+
+	for rows.Next() {
+		var rec otpRecord
+		if err := rows.Scan(&rec.name, &rec.secret, &rec.count); err != nil {
+			return nil, err
+		}
+		rec.hotp = hotp.NewHOTP(rec.secret, rec.count, otpLen)
+		recs[rec.name] = rec
+	}
+
+	if err = rows.Err(); err != nil {
+		return nil, err
+	}
+
+	return
+}
+
 func setCount(db dbConn, name string, count uint64) error {
 	inc, err := db.Prepare("UPDATE users SET count = $1 WHERE name = $2")
 	if err != nil {
@@ -268,15 +305,6 @@ func setCount(db dbConn, name string, count uint64) error {
 	return err
 }
 
-func getOTP(db dbConn, name string) (*hotp.HOTP, error) {
-	secret, count, err := getUser(db, name)
-	if err != nil {
-		return nil, err
-	}
-
-	return hotp.NewHOTP(secret, count, otpLen), nil
-}
-
 func transactionFailed(err error) bool {
 	if err == nil {
 		return false
@@ -302,7 +330,7 @@ type autoresyncEntry struct {
 // Retrieve secret and count for given username and try to find a match within
 // the lookahead range. Update count in DB if match is found. All within a
 // transaction, retrying if it fails.
-func checkOffer(remote string, name string, offer string) (bool, error) {
+func checkOffer(remote string, user string, offer string) (bool, error) {
 	for {
 		var lookahead uint64 = cfg.Lookahead
 		var i uint64
@@ -312,82 +340,86 @@ func checkOffer(remote string, name string, offer string) (bool, error) {
 
 		ok := false
 
-		debugf("[%v] looking up data for %v", remote, name)
-		hotp, err := getOTP(tx, name)
+		debugf("[%v] looking up data for %v", remote, user)
+		userrecords, err := getOTPWithAlts(tx, user)
 		if transactionFailed(err) {
 			goto retry
-		} else if err == sql.ErrNoRows {
+		} else if len(userrecords) == 0 {
 			_ = tx.Rollback()
-			return false, err
+			return false, errors.New("No such user")
 		} else {
 			errPanic(err)
 		}
 
-		autoresyncListLock.RLock()
-		if s, found := autoresyncList[name]; found {
-			autoresyncListLock.RUnlock()
-			debugf("[%v] autoresync in progress: %v", remote, s)
-			if uint64(time.Now().Unix()-s.Time) <= cfg.AutoresyncTime {
-				if s.Num >= cfg.AutoresyncRepeat && s.Counter-hotp.Counter() < cfg.AutoresyncLookahead {
-					// if the user had a sufficient number of successful tries that were not within
-					// standard lookahead range but within cfg.AutoresyncLookahead within cfg.AutoresyncTime seconds,
-					// temporarily increase lookahead to authenticate and resync.
-					debugf("[%v] autoresync conditions: increasing lookahead to ", cfg.AutoresyncLookahead)
-					lookahead = cfg.AutoresyncLookahead
+		for name, record := range userrecords {
+			hotp := record.hotp
+
+			autoresyncListLock.RLock()
+			if s, found := autoresyncList[name]; found {
+				autoresyncListLock.RUnlock()
+				debugf("[%v] autoresync in progress: %v", remote, s)
+				if uint64(time.Now().Unix()-s.Time) <= cfg.AutoresyncTime {
+					if s.Num >= cfg.AutoresyncRepeat && s.Counter-hotp.Counter() < cfg.AutoresyncLookahead {
+						// if the user had a sufficient number of successful tries that were not within
+						// standard lookahead range but within cfg.AutoresyncLookahead within cfg.AutoresyncTime seconds,
+						// temporarily increase lookahead to authenticate and resync.
+						debugf("[%v] autoresync conditions: increasing lookahead to ", cfg.AutoresyncLookahead)
+						lookahead = cfg.AutoresyncLookahead
+					}
+				} else {
+					// timeout
+					debugf("[%v] autoresync timeout: %v", remote, name)
+					autoresyncListLock.Lock()
+					delete(autoresyncList, name)
+					autoresyncListLock.Unlock()
 				}
 			} else {
-				// timeout
-				debugf("[%v] autoresync timeout: %v", remote, name)
-				autoresyncListLock.Lock()
-				delete(autoresyncList, name)
-				autoresyncListLock.Unlock()
+				autoresyncListLock.RUnlock()
 			}
-		} else {
-			autoresyncListLock.RUnlock()
-		}
 
-		for i = uint64(0); i <= lookahead; i++ {
-			debugf("[%v] checking for match (offset %v)", remote, i)
-			// .Check increments .Counter if successfull
-			// otherwise do it explicitly
-			if hotp.Check(offer) {
-				debugf("[%v] ok, set new count", remote)
-				err = setCount(tx, name, hotp.Counter())
-				if transactionFailed(err) {
-					goto retry
+			for i = uint64(0); i <= lookahead; i++ {
+				debugf("[%v] checking for match (offset %v)", remote, i)
+				// .Check increments .Counter if successfull
+				// otherwise do it explicitly
+				if hotp.Check(offer) {
+					debugf("[%v] ok, set new count", remote)
+					err = setCount(tx, name, hotp.Counter())
+					if transactionFailed(err) {
+						goto retry
+					} else {
+						errPanic(err)
+					}
+
+					ok = true
+					goto commit
 				} else {
-					errPanic(err)
+					hotp.Increment()
 				}
-
-				ok = true
-				goto commit
-			} else {
-				hotp.Increment()
 			}
-		}
-		// check failed, try extended range for autoresync
-		for ; i <= cfg.AutoresyncLookahead; i++ {
-			debugf("[%v] autoresync checking for match (offset %v counter %v)", remote, i, hotp.Counter())
-			if hotp.Check(offer) {
-				autoresyncListLock.Lock()
-				debugf("[%v] autoresync repeat count increase hotp.Counter %v, %v", remote, hotp.Counter(), autoresyncList[name])
-				if s, found := autoresyncList[name]; found && hotp.Counter()-1 == s.Counter {
-					s.Time = time.Now().Unix()
-					s.Num++
-					s.Counter++
-					debugf("[%v] autoresync repeat count increase for %v", remote, name)
-				} else {
-					autoresyncList[name] = &autoresyncEntry{
-						Time:    time.Now().Unix(),
-						Counter: hotp.Counter(),
-						Num:     1,
+			// check failed, try extended range for autoresync
+			for ; i <= cfg.AutoresyncLookahead; i++ {
+				debugf("[%v] autoresync checking for match (offset %v counter %v)", remote, i, hotp.Counter())
+				if hotp.Check(offer) {
+					autoresyncListLock.Lock()
+					debugf("[%v] autoresync repeat count increase hotp.Counter %v, %v", remote, hotp.Counter(), autoresyncList[name])
+					if s, found := autoresyncList[name]; found && hotp.Counter()-1 == s.Counter {
+						s.Time = time.Now().Unix()
+						s.Num++
+						s.Counter++
+						debugf("[%v] autoresync repeat count increase for %v", remote, name)
+					} else {
+						autoresyncList[name] = &autoresyncEntry{
+							Time:    time.Now().Unix(),
+							Counter: hotp.Counter(),
+							Num:     1,
+						}
+						debugf("[%v] autoresync repeat count init for %v", remote, name)
 					}
-					debugf("[%v] autoresync repeat count init for %v", remote, name)
+					autoresyncListLock.Unlock()
+					break
+				} else {
+					hotp.Increment()
 				}
-				autoresyncListLock.Unlock()
-				break
-			} else {
-				hotp.Increment()
 			}
 		}
 
@@ -424,8 +456,8 @@ func handleConn(remote string, reader *bufio.Reader,
 	}
 	debugf("[%v] name: %v", remote, name)
 
-	_, _, err = getUser(db, name)
-	if err == sql.ErrNoRows {
+	records, err := getOTPWithAlts(db, name)
+	if len(records) == 0 {
 		log.Printf("[%v] Unknown user: %v", remote, name)
 		return
 	} else if err != nil {
@@ -573,7 +605,7 @@ func incCount(user string, diff uint64) error {
 			return err
 		}
 
-		_, counter, err := getUser(tx, user)
+		_, counter, err := getEntity(tx, user)
 		if transactionFailed(err) {
 			goto retry
 		} else if err != nil {
@@ -651,7 +683,7 @@ func main() {
 				log.Fatalf("get-counter: Invalid number of arguments: %v (expecting <user>)\n", len(args))
 			}
 			user := args[0]
-			_, counter, err := getUser(db, user)
+			_, counter, err := getEntity(db, user)
 			errFatal(err)
 			fmt.Printf("%v: %v\n", user, counter)
 		case "set-counter":
diff --git a/goatherd_test.go b/goatherd_test.go
index 3bdf1e7..0f2b780 100644
--- a/goatherd_test.go
+++ b/goatherd_test.go
@@ -42,7 +42,7 @@ func createAndCheck(user string, sec string) func(*testing.T) {
 
 		t.Run("exists", func(t *testing.T) {
 			var err error
-			result.secret, result.count, err = getUser(db, username)
+			result.secret, result.count, err = getEntity(db, username)
 			tErrFatal(t, err)
 		})
 
@@ -165,12 +165,12 @@ func transactionConflictT(t *testing.T) {
 func checkOfferT(t *testing.T) {
 	t.Run("noSuchUser", func(t *testing.T) {
 		_, err := checkOffer("mock", "no such user", "dummy offer")
-		if err != sql.ErrNoRows {
+		if err.Error() != "No such user" {
 			t.Error("err:", nil)
 		}
 	})
 
-	secret, count, err := getUser(db, username)
+	secret, count, err := getEntity(db, username)
 	tErrFatal(t, err)
 
 	t.Run("fail", func(t *testing.T) {
@@ -278,11 +278,11 @@ func incCountT(t *testing.T) {
 	})
 
 	t.Run("by2", func(t *testing.T) {
-		_, counter, err := getUser(db, username)
+		_, counter, err := getEntity(db, username)
 		tErrFatal(t, err)
 		err = incCount(username, 2)
 		tErrFatal(t, err)
-		_, newCounter, err := getUser(db, username)
+		_, newCounter, err := getEntity(db, username)
 		tErrFatal(t, err)
 
 		if newCounter != counter+2 {
@@ -356,8 +356,9 @@ func handleConnT(t *testing.T) {
 		}
 	})
 
-	otp, err := getOTP(db, username)
+	records, err := getOTPWithAlts(db, username)
 	tErrFatal(t, err)
+	otp := records[username].hotp
 
 	t.Run("ok", func(t *testing.T) {
 		recv := make(chan string)
-- 
GitLab