diff --git a/goatherd.go b/goatherd.go index 29117d731380b51a8393a6f197156787fccb194c..f9a0933c7ab55c75e2f07988621e6f5bf8d6dfef 100644 --- a/goatherd.go +++ b/goatherd.go @@ -23,7 +23,8 @@ import ( "github.com/lib/pq" ) -const otpLen = 8 +const hotpLen = 8 +const totpLen = 6 // wrapper because time.Duration doesn't implement UnmarshalJSON type duration struct { @@ -52,6 +53,7 @@ type tlsCfg struct { var cfg struct { DbURL string `json:"db_url"` Lookahead uint64 `json:"lookahead"` + Lookaround uint64 `json:"lookaround"` AutoresyncRepeat uint64 `json:"autoresync_repeat"` AutoresyncLookahead uint64 `json:"autoresync_lookahead"` AutoresyncTime uint64 `json:"autoresync_time"` @@ -246,7 +248,7 @@ func createUser(name string, secretB64 string) { errFatalf("Failed to create user: %v\n", err) } -func getUser(db dbConn, name string) (secret []byte, count uint64, err error) { +func getUser(db dbConn, name string) (secret []byte, count uint64, isTOTP bool, err error) { stmt, err := db.Prepare("SELECT secret, count FROM users WHERE name = $1") if err != nil { return @@ -269,13 +271,16 @@ func setCount(db dbConn, name string, count uint64) error { return err } -func getOTP(db dbConn, name string) (*twofactor.HOTP, error) { - secret, count, err := getUser(db, name) +func getOTP(db dbConn, name string) (twofactor.OTP, uint64, error) { + secret, count, isTOTP, err := getUser(db, name) if err != nil { - return nil, err + return nil, 0, err } - return twofactor.NewHOTP(secret, count, otpLen), nil + if isTOTP { + return twofactor.NewTOTPSHA1(secret, 0, 30, totpLen), count, nil + } + return twofactor.NewHOTP(secret, count, hotpLen), count, nil } func transactionFailed(err error) bool { @@ -304,7 +309,7 @@ func matchingOTP(expected string, offer string) bool { return subtle.ConstantTimeCompare([]byte(offer), []byte(expected)) == 1 } -func checkHOTP(hotp *twofactor.HOTP, remote string, name string, offer string) bool { +func checkHOTP(hotp *twofactor.HOTP, remote string, name string, offer string) (bool, uint64) { // garbage collect old autoresync entries autoresyncListLock.Lock() if s, found := autoresyncList[name]; found && uint64(time.Now().Unix()-s.Time) > cfg.AutoresyncTime { @@ -318,7 +323,7 @@ func checkHOTP(hotp *twofactor.HOTP, remote string, name string, offer string) b debugf("[%v] checking for match (offset %v)", remote, i) // OTP always increments counter if matchingOTP(hotp.OTP(), offer) { - return true + return true, hotp.Counter() } } @@ -346,12 +351,36 @@ func checkHOTP(hotp *twofactor.HOTP, remote string, name string, offer string) b if entry.Num >= cfg.AutoresyncRepeat { // resync if the user had a sufficient number of consecutive tries that were not within // standard lookahead range but within cfg.AutoresyncLookahead within cfg.AutoresyncTime seconds - return true + return true, hotp.Counter() } break } } - return false + return false, 0 +} + +func checkTOTP(totp *twofactor.TOTP, lastCount uint64, remote string, name string, offer string) (bool, uint64) { + base := totp.OTPCounter() + + for i := uint64(0); i <= 2*cfg.Lookaround; i++ { + // "zig-zag" offset which increases distance from zero and alternates between positive an negative + // e.g. 0, -1, 1, -2, 2, ... + offset := (i + 1) / 2 + ctr := base + offset + if (i & 1) != 0 { + ctr = base - offset + } + if ctr <= lastCount { + // skip already used tokens + continue + } + + debugf("[%v] checking for match (ctr %v)", remote, ctr) + if matchingOTP(totp.OATH.OTP(ctr), offer) { + return true, ctr + 1 + } + } + return false, 0 } // Retrieve secret and count for given username and try to find a match within @@ -361,15 +390,25 @@ func checkOffer(remote string, name string, offer string) (bool, error) { for { inner := func(tx *sql.Tx) (bool, error) { debugf("[%v] looking up data for %v", remote, name) - hotp, err := getOTP(tx, name) + otp, count, err := getOTP(tx, name) if err != nil { return false, err } - ok := checkHOTP(hotp, remote, name, offer) + var ok bool + var nextCtr uint64 + switch otp.Type() { + case twofactor.OATH_HOTP: + ok, nextCtr = checkHOTP(otp.(*twofactor.HOTP), remote, name, offer) + case twofactor.OATH_TOTP: + ok, nextCtr = checkTOTP(otp.(*twofactor.TOTP), count, remote, name, offer) + default: + log.Panicf("unsupported otp type %v", otp.Type()) + } + if ok { debugf("[%v] ok, set new count", remote) - err := setCount(tx, name, hotp.Counter()) + err := setCount(tx, name, nextCtr) if err != nil { return false, err } @@ -416,7 +455,7 @@ func handleConn(remote string, reader *bufio.Reader, } debugf("[%v] name: %v", remote, name) - _, _, err = getUser(db, name) + _, _, _, err = getUser(db, name) if err == sql.ErrNoRows { log.Printf("[%v] Unknown user: %v", remote, name) return @@ -561,7 +600,7 @@ func serve() { func incCount(user string, diff uint64) error { for { inner := func(tx *sql.Tx) error { - _, counter, err := getUser(tx, user) + _, counter, _, err := getUser(tx, user) if err != nil { return err } @@ -646,7 +685,7 @@ func main() { log.Fatalf("get-counter: Invalid number of arguments: %v (expecting <user>)\n", len(args)) } user := args[0] - _, counter, err := getUser(db, user) + _, counter, _, err := getUser(db, user) errFatal(err) fmt.Printf("%v: %v\n", user, counter) case "set-counter": diff --git a/goatherd_test.go b/goatherd_test.go index 1d6c9e113f96316a9f97d3c0ba6548c9543feec6..a0914c520d9537c641063ebb98e8b7a08b45d66a 100644 --- a/goatherd_test.go +++ b/goatherd_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/benbjohnson/clock" "github.com/gokyle/twofactor" ) @@ -31,18 +32,19 @@ func tErrFatal(t *testing.T, err error) { } // always uses global username and secret for comparison, so we can test "-" -func createAndCheck(user string, sec string) func(*testing.T) { +func createAndCheck(user string, sec string, isTOTP bool) func(*testing.T) { return func(t *testing.T) { createUser(user, sec) var result struct { secret []byte count uint64 + isTOTP bool } t.Run("exists", func(t *testing.T) { var err error - result.secret, result.count, err = getUser(db, username) + result.secret, result.count, result.isTOTP, err = getUser(db, username) tErrFatal(t, err) }) @@ -57,18 +59,24 @@ func createAndCheck(user string, sec string) func(*testing.T) { t.Errorf("initial count: %v", result.count) } }) + + t.Run("correctType", func(t *testing.T) { + if result.isTOTP != isTOTP { + t.Errorf("expIsTOTP: %v, isTOTP: %v", isTOTP, result.isTOTP) + } + }) } } func createUserT(t *testing.T) { - t.Run("args", createAndCheck(username, secretB64)) + t.Run("args", createAndCheck(username, secretB64, false)) go func() { fmt.Fprintln(stdinWriter, username) fmt.Fprintln(stdinWriter, secretB64) stdinWriter.Flush() }() - t.Run("stdin", createAndCheck("-", "-")) + t.Run("stdin", createAndCheck("-", "-", false)) } func aborted(t *testing.T, err error, aborted ...interface{}) bool { @@ -170,8 +178,11 @@ func checkOfferT(t *testing.T) { } }) - secret, count, err := getUser(db, username) + secret, count, isTOTP, err := getUser(db, username) tErrFatal(t, err) + if isTOTP { + t.Fatalf("broken test setup") + } t.Run("fail", func(t *testing.T) { ok, err := checkOffer("mock", username, "dummy offer") @@ -181,7 +192,7 @@ func checkOfferT(t *testing.T) { } t.Run("tooFarOut", func(t *testing.T) { - ahead := twofactor.NewHOTP(secret, count+cfg.Lookahead+1, otpLen) + ahead := twofactor.NewHOTP(secret, count+cfg.Lookahead+1, hotpLen) ok, err = checkOffer("mock", username, ahead.OTP()) if ok { t.Fail() @@ -190,7 +201,7 @@ func checkOfferT(t *testing.T) { }) t.Run("ok", func(t *testing.T) { - cur := twofactor.NewHOTP(secret, count, otpLen) + cur := twofactor.NewHOTP(secret, count, hotpLen) ok, err := checkOffer("mock", username, cur.OTP()) tErrFatal(t, err) if !ok { @@ -208,7 +219,7 @@ func checkOfferT(t *testing.T) { t.Run("lookahead", func(t *testing.T) { ok, err := checkOffer("mock", username, - twofactor.NewHOTP(secret, cur.Counter()+cfg.Lookahead, otpLen).OTP()) + twofactor.NewHOTP(secret, cur.Counter()+cfg.Lookahead, hotpLen).OTP()) tErrFatal(t, err) if !ok { t.Fail() @@ -218,7 +229,7 @@ func checkOfferT(t *testing.T) { t.Run("autoresync", func(t *testing.T) { t.Run("ok", func(t *testing.T) { - cur := twofactor.NewHOTP(secret, count+cfg.Lookahead+27, otpLen) + cur := twofactor.NewHOTP(secret, count+cfg.Lookahead+27, hotpLen) // temporal assumption: loop iteration takes less than cfg.AutoresyncTime seconds for i := uint64(0); i < cfg.AutoresyncRepeat; i++ { ok, err := checkOffer("autoresync-good", username, cur.OTP()) @@ -234,7 +245,7 @@ func checkOfferT(t *testing.T) { } }) t.Run("fail-timeout", func(t *testing.T) { - cur := twofactor.NewHOTP(secret, count+cfg.Lookahead+57, otpLen) + cur := twofactor.NewHOTP(secret, count+cfg.Lookahead+57, hotpLen) // temporal assumption: loop iteration takes less than cfg.AutoresyncTime seconds for i := uint64(0); i < cfg.AutoresyncRepeat; i++ { ok, err := checkOffer("autoresync-bad-timeout", username, cur.OTP()) @@ -251,7 +262,7 @@ func checkOfferT(t *testing.T) { } }) t.Run("fail-range", func(t *testing.T) { - cur := twofactor.NewHOTP(secret, count+cfg.Lookahead+317, otpLen) + cur := twofactor.NewHOTP(secret, count+cfg.Lookahead+317, hotpLen) // temporal assumption: loop iteration takes less than cfg.AutoresyncTime seconds for i := uint64(0); i < cfg.AutoresyncRepeat; i++ { ok, err := checkOffer("autoresync-bad-range", username, cur.OTP()) @@ -278,11 +289,11 @@ func incCountT(t *testing.T) { }) t.Run("by2", func(t *testing.T) { - _, counter, err := getUser(db, username) + _, counter, _, err := getUser(db, username) tErrFatal(t, err) err = incCount(username, 2) tErrFatal(t, err) - _, newCounter, err := getUser(db, username) + _, newCounter, _, err := getUser(db, username) tErrFatal(t, err) if newCounter != counter+2 { @@ -291,6 +302,93 @@ func incCountT(t *testing.T) { }) } +func createUserTOTPT(t *testing.T) { + t.Run("args", createAndCheck(username, secretB64, true)) +} + +func checkOfferTOTPT(t *testing.T) { + secret, _, isTOTP, err := getUser(db, username) + tErrFatal(t, err) + if !isTOTP { + t.Fatalf("broken test setup") + } + + // setup mock clock which does not advance on its own + mc := clock.NewMock() + mc.Set(time.Now()) + twofactor.SetClock(mc) + + defer func() { + // restore clock + twofactor.SetClock(clock.New()) + }() + + t.Run("fail", func(t *testing.T) { + ok, err := checkOffer("mock", username, "dummy offer") + tErrFatal(t, err) + if ok { + t.Fail() + } + + t.Run("tooFarOut", func(t *testing.T) { + ahead := twofactor.NewTOTPSHA1(secret, 0, 30, totpLen) + ok, err = checkOffer("mock", username, ahead.OATH.OTP(ahead.OTPCounter()+cfg.Lookaround+1)) + if ok { + t.Fail() + } + }) + }) + + t.Run("ok", func(t *testing.T) { + cur := twofactor.NewTOTPSHA1(secret, 0, 30, totpLen) + + t.Run("lookbehind", func(t *testing.T) { + ok, err := checkOffer("mock", username, cur.OATH.OTP(cur.OTPCounter()-1)) + tErrFatal(t, err) + if !ok { + t.Fail() + } + }) + + otp := cur.OTP() + t.Run("normal", func(t *testing.T) { + ok, err := checkOffer("mock", username, otp) + tErrFatal(t, err) + if !ok { + t.Fail() + } + }) + + t.Run("failOnRetry", func(t *testing.T) { + ok, err := checkOffer("mock", username, otp) + tErrFatal(t, err) + if ok { + t.Fail() + } + }) + + t.Run("lookahead", func(t *testing.T) { + ok, err := checkOffer("mock", username, cur.OATH.OTP(cur.OTPCounter()+1)) + tErrFatal(t, err) + if !ok { + t.Fail() + } + }) + }) + + t.Run("otpexpiry", func(t *testing.T) { + mc.Add(5 * time.Minute) + cur := twofactor.NewTOTPSHA1(secret, 0, 30, totpLen) + otp := cur.OTP() + mc.Add(5 * time.Minute) + ok, err := checkOffer("mock", username, otp) + tErrFatal(t, err) + if ok { + t.Fail() + } + }) +} + type mockClient func(r *io.PipeReader, w *io.PipeWriter) func mockConn(client mockClient) (delay *sync.Mutex) { @@ -356,7 +454,7 @@ func handleConnT(t *testing.T) { } }) - otp, err := getOTP(db, username) + otp, _, err := getOTP(db, username) tErrFatal(t, err) t.Run("ok", func(t *testing.T) { @@ -461,6 +559,7 @@ func TestMain(t *testing.T) { t.Error("No DB_URL specified") } cfg.Lookahead = 3 + cfg.Lookaround = 2 cfg.Debug = false cfg.ReadTimeout.Duration = 1 * time.Second cfg.AutoresyncRepeat = 5 @@ -480,6 +579,8 @@ func TestMain(t *testing.T) { t.Run("checkOffer", checkOfferT) t.Run("incCount", incCountT) t.Run("handleConn", handleConnT) + t.Run("createUserTOTP", createUserTOTPT) + t.Run("checkOfferTOTP", checkOfferTOTPT) listener, err := net.Listen("unix", "@goatherd_test") tErrFatal(t, err)