diff --git a/goatherd.go b/goatherd.go index ce3255a6ed28168077df03cc464f9b8e760c686a..d4260cfdfd570cd704fa3513834b183f9621411b 100644 --- a/goatherd.go +++ b/goatherd.go @@ -60,6 +60,13 @@ var cfg struct { ReadTimeout duration `json:"read_timeout"` } +type otpRecord struct { + name string + secret []byte + count uint64 + hotp *hotp.HOTP +} + // state for per-user ratelimiting // // Each user has a corresponding Mutex in faildelay.userlocks. A lock is @@ -245,18 +252,48 @@ func createUser(name string, secretB64 string) { errFatalf("Failed to create user: %v\n", err) } -func getUser(db dbConn, name string) (secret []byte, count uint64, err error) { +func getEntity(db dbConn, name string) (secret []byte, count uint64, err error) { stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = $1") if err != nil { return } - err = stmt.QueryRow(name).Scan(&secret, &count) stmt.Close() return } + +func getOTPWithAlts(db dbConn, name string) (recs map[string]otpRecord, err error) { + stmt, err := db.Prepare("SELECT name, secret, count FROM users WHERE name = $1 or name like CONCAT($1, '/%')") + if err != nil { + return + } + + rows, err := stmt.Query(name) + if err != nil { + return nil, err + } + defer rows.Close() + + recs = make(map[string]otpRecord) + + for rows.Next() { + var rec otpRecord + if err := rows.Scan(&rec.name, &rec.secret, &rec.count); err != nil { + return nil, err + } + rec.hotp = hotp.NewHOTP(rec.secret, rec.count, otpLen) + recs[rec.name] = rec + } + + if err = rows.Err(); err != nil { + return nil, err + } + + return +} + func setCount(db dbConn, name string, count uint64) error { inc, err := db.Prepare("UPDATE users SET count = $1 WHERE name = $2") if err != nil { @@ -268,15 +305,6 @@ func setCount(db dbConn, name string, count uint64) error { return err } -func getOTP(db dbConn, name string) (*hotp.HOTP, error) { - secret, count, err := getUser(db, name) - if err != nil { - return nil, err - } - - return hotp.NewHOTP(secret, count, otpLen), nil -} - func transactionFailed(err error) bool { if err == nil { return false @@ -302,7 +330,7 @@ type autoresyncEntry struct { // Retrieve secret and count for given username and try to find a match within // the lookahead range. Update count in DB if match is found. All within a // transaction, retrying if it fails. -func checkOffer(remote string, name string, offer string) (bool, error) { +func checkOffer(remote string, user string, offer string) (bool, error) { for { var lookahead uint64 = cfg.Lookahead var i uint64 @@ -312,82 +340,86 @@ func checkOffer(remote string, name string, offer string) (bool, error) { ok := false - debugf("[%v] looking up data for %v", remote, name) - hotp, err := getOTP(tx, name) + debugf("[%v] looking up data for %v", remote, user) + userrecords, err := getOTPWithAlts(tx, user) if transactionFailed(err) { goto retry - } else if err == sql.ErrNoRows { + } else if len(userrecords) == 0 { _ = tx.Rollback() - return false, err + return false, errors.New("No such user") } else { errPanic(err) } - autoresyncListLock.RLock() - if s, found := autoresyncList[name]; found { - autoresyncListLock.RUnlock() - debugf("[%v] autoresync in progress: %v", remote, s) - if uint64(time.Now().Unix()-s.Time) <= cfg.AutoresyncTime { - if s.Num >= cfg.AutoresyncRepeat && s.Counter-hotp.Counter() < cfg.AutoresyncLookahead { - // if the user had a sufficient number of successful tries that were not within - // standard lookahead range but within cfg.AutoresyncLookahead within cfg.AutoresyncTime seconds, - // temporarily increase lookahead to authenticate and resync. - debugf("[%v] autoresync conditions: increasing lookahead to ", cfg.AutoresyncLookahead) - lookahead = cfg.AutoresyncLookahead + for name, record := range userrecords { + hotp := record.hotp + + autoresyncListLock.RLock() + if s, found := autoresyncList[name]; found { + autoresyncListLock.RUnlock() + debugf("[%v] autoresync in progress: %v", remote, s) + if uint64(time.Now().Unix()-s.Time) <= cfg.AutoresyncTime { + if s.Num >= cfg.AutoresyncRepeat && s.Counter-hotp.Counter() < cfg.AutoresyncLookahead { + // if the user had a sufficient number of successful tries that were not within + // standard lookahead range but within cfg.AutoresyncLookahead within cfg.AutoresyncTime seconds, + // temporarily increase lookahead to authenticate and resync. + debugf("[%v] autoresync conditions: increasing lookahead to ", cfg.AutoresyncLookahead) + lookahead = cfg.AutoresyncLookahead + } + } else { + // timeout + debugf("[%v] autoresync timeout: %v", remote, name) + autoresyncListLock.Lock() + delete(autoresyncList, name) + autoresyncListLock.Unlock() } } else { - // timeout - debugf("[%v] autoresync timeout: %v", remote, name) - autoresyncListLock.Lock() - delete(autoresyncList, name) - autoresyncListLock.Unlock() + autoresyncListLock.RUnlock() } - } else { - autoresyncListLock.RUnlock() - } - for i = uint64(0); i <= lookahead; i++ { - debugf("[%v] checking for match (offset %v)", remote, i) - // .Check increments .Counter if successfull - // otherwise do it explicitly - if hotp.Check(offer) { - debugf("[%v] ok, set new count", remote) - err = setCount(tx, name, hotp.Counter()) - if transactionFailed(err) { - goto retry + for i = uint64(0); i <= lookahead; i++ { + debugf("[%v] checking for match (offset %v)", remote, i) + // .Check increments .Counter if successfull + // otherwise do it explicitly + if hotp.Check(offer) { + debugf("[%v] ok, set new count", remote) + err = setCount(tx, name, hotp.Counter()) + if transactionFailed(err) { + goto retry + } else { + errPanic(err) + } + + ok = true + goto commit } else { - errPanic(err) + hotp.Increment() } - - ok = true - goto commit - } else { - hotp.Increment() } - } - // check failed, try extended range for autoresync - for ; i <= cfg.AutoresyncLookahead; i++ { - debugf("[%v] autoresync checking for match (offset %v counter %v)", remote, i, hotp.Counter()) - if hotp.Check(offer) { - autoresyncListLock.Lock() - debugf("[%v] autoresync repeat count increase hotp.Counter %v, %v", remote, hotp.Counter(), autoresyncList[name]) - if s, found := autoresyncList[name]; found && hotp.Counter()-1 == s.Counter { - s.Time = time.Now().Unix() - s.Num++ - s.Counter++ - debugf("[%v] autoresync repeat count increase for %v", remote, name) - } else { - autoresyncList[name] = &autoresyncEntry{ - Time: time.Now().Unix(), - Counter: hotp.Counter(), - Num: 1, + // check failed, try extended range for autoresync + for ; i <= cfg.AutoresyncLookahead; i++ { + debugf("[%v] autoresync checking for match (offset %v counter %v)", remote, i, hotp.Counter()) + if hotp.Check(offer) { + autoresyncListLock.Lock() + debugf("[%v] autoresync repeat count increase hotp.Counter %v, %v", remote, hotp.Counter(), autoresyncList[name]) + if s, found := autoresyncList[name]; found && hotp.Counter()-1 == s.Counter { + s.Time = time.Now().Unix() + s.Num++ + s.Counter++ + debugf("[%v] autoresync repeat count increase for %v", remote, name) + } else { + autoresyncList[name] = &autoresyncEntry{ + Time: time.Now().Unix(), + Counter: hotp.Counter(), + Num: 1, + } + debugf("[%v] autoresync repeat count init for %v", remote, name) } - debugf("[%v] autoresync repeat count init for %v", remote, name) + autoresyncListLock.Unlock() + break + } else { + hotp.Increment() } - autoresyncListLock.Unlock() - break - } else { - hotp.Increment() } } @@ -424,8 +456,8 @@ func handleConn(remote string, reader *bufio.Reader, } debugf("[%v] name: %v", remote, name) - _, _, err = getUser(db, name) - if err == sql.ErrNoRows { + records, err := getOTPWithAlts(db, name) + if len(records) == 0 { log.Printf("[%v] Unknown user: %v", remote, name) return } else if err != nil { @@ -573,7 +605,7 @@ func incCount(user string, diff uint64) error { return err } - _, counter, err := getUser(tx, user) + _, counter, err := getEntity(tx, user) if transactionFailed(err) { goto retry } else if err != nil { @@ -651,7 +683,7 @@ func main() { log.Fatalf("get-counter: Invalid number of arguments: %v (expecting <user>)\n", len(args)) } user := args[0] - _, counter, err := getUser(db, user) + _, counter, err := getEntity(db, user) errFatal(err) fmt.Printf("%v: %v\n", user, counter) case "set-counter": diff --git a/goatherd_test.go b/goatherd_test.go index 3bdf1e72f10c41e3806082da147dc08cfa1cf70b..0f2b78020f1bd7d3e5fa7ce35b32e4f42b643a36 100644 --- a/goatherd_test.go +++ b/goatherd_test.go @@ -42,7 +42,7 @@ func createAndCheck(user string, sec string) func(*testing.T) { t.Run("exists", func(t *testing.T) { var err error - result.secret, result.count, err = getUser(db, username) + result.secret, result.count, err = getEntity(db, username) tErrFatal(t, err) }) @@ -165,12 +165,12 @@ func transactionConflictT(t *testing.T) { func checkOfferT(t *testing.T) { t.Run("noSuchUser", func(t *testing.T) { _, err := checkOffer("mock", "no such user", "dummy offer") - if err != sql.ErrNoRows { + if err.Error() != "No such user" { t.Error("err:", nil) } }) - secret, count, err := getUser(db, username) + secret, count, err := getEntity(db, username) tErrFatal(t, err) t.Run("fail", func(t *testing.T) { @@ -278,11 +278,11 @@ func incCountT(t *testing.T) { }) t.Run("by2", func(t *testing.T) { - _, counter, err := getUser(db, username) + _, counter, err := getEntity(db, username) tErrFatal(t, err) err = incCount(username, 2) tErrFatal(t, err) - _, newCounter, err := getUser(db, username) + _, newCounter, err := getEntity(db, username) tErrFatal(t, err) if newCounter != counter+2 { @@ -356,8 +356,9 @@ func handleConnT(t *testing.T) { } }) - otp, err := getOTP(db, username) + records, err := getOTPWithAlts(db, username) tErrFatal(t, err) + otp := records[username].hotp t.Run("ok", func(t *testing.T) { recv := make(chan string)