diff --git a/goatherd.go b/goatherd.go index 1cee3ca80f931e4dfe3ad5ed4afdcb154caf964f..94840cb4e02cf563da9589c0d0c998e97bfbdfdd 100644 --- a/goatherd.go +++ b/goatherd.go @@ -261,7 +261,7 @@ func createUser(name string, secretB64 string, isTOTP bool) { errFatalf("Failed to create user: %v\n", err) } -func getUser(db dbConn, name string) (secret []byte, count uint64, isTOTP bool, err error) { +func getEntity(db dbConn, name string) (secret []byte, count uint64, isTOTP bool, err error) { stmt, err := db.Prepare("SELECT secret, isTOTP, count FROM users WHERE name = $1") if err != nil { return @@ -273,6 +273,48 @@ func getUser(db dbConn, name string) (secret []byte, count uint64, isTOTP bool, return } +type otpRecord struct { + name string + secret []byte + isTOTP bool + count uint64 + otp twofactor.OTP +} + +func getOTPWithAlts(db dbConn, name string) (recs map[string]otpRecord, err error) { + stmt, err := db.Prepare("SELECT name, secret, isTOTP, count FROM users WHERE name = $1 or name like CONCAT($1, '/%')") + if err != nil { + return + } + + rows, err := stmt.Query(name) + if err != nil { + return nil, err + } + defer rows.Close() + + recs = make(map[string]otpRecord) + + for rows.Next() { + var rec otpRecord + if err := rows.Scan(&rec.name, &rec.secret, &rec.isTOTP, &rec.count); err != nil { + return nil, err + } + if rec.isTOTP { + rec.otp = twofactor.NewTOTPSHA1(rec.secret, 0, 30, totpLen) + } else { + rec.otp = twofactor.NewHOTP(rec.secret, rec.count, hotpLen) + } + recs[rec.name] = rec + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return +} + func setCount(db dbConn, name string, count uint64) error { inc, err := db.Prepare("UPDATE users SET count = $1 WHERE name = $2") if err != nil { @@ -284,18 +326,6 @@ func setCount(db dbConn, name string, count uint64) error { return err } -func getOTP(db dbConn, name string) (twofactor.OTP, uint64, error) { - secret, count, isTOTP, err := getUser(db, name) - if err != nil { - return nil, 0, err - } - - if isTOTP { - return twofactor.NewTOTPSHA1(secret, 0, 30, totpLen), count, nil - } - return twofactor.NewHOTP(secret, count, hotpLen), count, nil -} - func transactionFailed(err error) bool { if err == nil { return false @@ -399,31 +429,36 @@ func checkTOTP(totp *twofactor.TOTP, minCount uint64, remote string, name string // Retrieve secret and count for given username and try to find a match within // the lookahead range. Update count in DB if match is found. All within a // transaction, retrying if it fails. -func checkOffer(remote string, name string, offer string) (bool, error) { +func checkOffer(remote string, user string, offer string) (bool, error) { for { inner := func(tx *sql.Tx) (bool, error) { - debugf("[%v] looking up data for %v", remote, name) - otp, count, err := getOTP(tx, name) + debugf("[%v] looking up data for %v", remote, user) + userrecords, err := getOTPWithAlts(tx, user) if err != nil { return false, err + } else if len(userrecords) == 0 { + return false, errors.New("no such user") } var ok bool - var nextCtr uint64 - switch otp.Type() { - case twofactor.OATH_HOTP: - ok, nextCtr = checkHOTP(otp.(*twofactor.HOTP), remote, name, offer) - case twofactor.OATH_TOTP: - ok, nextCtr = checkTOTP(otp.(*twofactor.TOTP), count, remote, name, offer) - default: - log.Panicf("unsupported otp type %v", otp.Type()) - } + for name, r := range userrecords { + var nextCtr uint64 + switch r.otp.Type() { + case twofactor.OATH_HOTP: + ok, nextCtr = checkHOTP(r.otp.(*twofactor.HOTP), remote, name, offer) + case twofactor.OATH_TOTP: + ok, nextCtr = checkTOTP(r.otp.(*twofactor.TOTP), r.count, remote, name, offer) + default: + log.Panicf("unsupported otp type %v", r.otp.Type()) + } - if ok { - debugf("[%v] ok, set new count", remote) - err := setCount(tx, name, nextCtr) - if err != nil { - return false, err + if ok { + debugf("[%v] ok, set new count", remote) + err := setCount(tx, user, nextCtr) + if err != nil { + return false, err + } + break } } @@ -468,8 +503,8 @@ func handleConn(remote string, reader *bufio.Reader, } debugf("[%v] name: %v", remote, name) - _, _, _, err = getUser(db, name) - if err == sql.ErrNoRows { + records, err := getOTPWithAlts(db, name) + if len(records) == 0 { log.Printf("[%v] Unknown user: %v", remote, name) return } else if err != nil { @@ -613,7 +648,7 @@ func serve() { func incCount(user string, diff uint64) error { for { inner := func(tx *sql.Tx) error { - _, counter, _, err := getUser(tx, user) + _, counter, _, err := getEntity(tx, user) if err != nil { return err } @@ -707,7 +742,7 @@ func main() { log.Fatalf("get-counter: Invalid number of arguments: %v (expecting <user>)\n", len(args)) } user := args[0] - _, counter, _, err := getUser(db, user) + _, counter, _, err := getEntity(db, user) errFatal(err) fmt.Printf("%v: %v\n", user, counter) case "set-counter": diff --git a/goatherd_test.go b/goatherd_test.go index ea8daa322c3687559975beb7ceefe074e6debe20..3887c21d39d6efc5e00f0cfe03127fcbaaa5caf5 100644 --- a/goatherd_test.go +++ b/goatherd_test.go @@ -44,7 +44,7 @@ func createAndCheck(user string, sec string, isTOTP bool) func(*testing.T) { t.Run("exists", func(t *testing.T) { var err error - result.secret, result.count, result.isTOTP, err = getUser(db, username) + result.secret, result.count, result.isTOTP, err = getEntity(db, username) tErrFatal(t, err) }) @@ -173,12 +173,12 @@ func transactionConflictT(t *testing.T) { func checkOfferT(t *testing.T) { t.Run("noSuchUser", func(t *testing.T) { _, err := checkOffer("mock", "no such user", "dummy offer") - if err != sql.ErrNoRows { + if err.Error() != "no such user" { t.Error("err:", nil) } }) - secret, count, isTOTP, err := getUser(db, username) + secret, count, isTOTP, err := getEntity(db, username) tErrFatal(t, err) if isTOTP { t.Fatalf("broken test setup") @@ -289,11 +289,11 @@ func incCountT(t *testing.T) { }) t.Run("by2", func(t *testing.T) { - _, counter, _, err := getUser(db, username) + _, counter, _, err := getEntity(db, username) tErrFatal(t, err) err = incCount(username, 2) tErrFatal(t, err) - _, newCounter, _, err := getUser(db, username) + _, newCounter, _, err := getEntity(db, username) tErrFatal(t, err) if newCounter != counter+2 { @@ -307,7 +307,7 @@ func createUserTOTPT(t *testing.T) { } func checkOfferTOTPT(t *testing.T) { - secret, _, isTOTP, err := getUser(db, username) + secret, _, isTOTP, err := getEntity(db, username) tErrFatal(t, err) if !isTOTP { t.Fatalf("broken test setup") @@ -454,8 +454,9 @@ func handleConnT(t *testing.T) { } }) - otp, _, err := getOTP(db, username) + records, err := getOTPWithAlts(db, username) tErrFatal(t, err) + otp := records[username].otp t.Run("ok", func(t *testing.T) { recv := make(chan string)