From ba908c8853caf6408df31438eaee9dea3e4dc590 Mon Sep 17 00:00:00 2001
From: Lukas Braun <lukas.braun@fau.de>
Date: Mon, 23 Jan 2017 17:52:32 +0100
Subject: [PATCH] ratelimiting 2/2: do it

---
 goatherd.go | 108 ++++++++++++++++++++++++++++++++++++++++++++--------
 1 file changed, 92 insertions(+), 16 deletions(-)

diff --git a/goatherd.go b/goatherd.go
index 07ed3fb..6891a9e 100644
--- a/goatherd.go
+++ b/goatherd.go
@@ -12,19 +12,56 @@ import (
     "log"
     "net"
     "os"
+    "sync"
+    "time"
 
     "github.com/gokyle/hotp"
     "github.com/mattn/go-sqlite3"
 )
 
+// wrapper because time.Duration doesn't implement UnmarshalJSON
+type duration struct {
+    time.Duration
+}
+func (d *duration) UnmarshalJSON(b []byte) (err error) {
+    if b[0] == '"' {
+        d.Duration, err = time.ParseDuration(string(b[1 : len(b)-1]))
+        return
+    }
+
+    i, err := json.Number(string(b)).Int64()
+    d.Duration = time.Duration(i) * time.Second
+
+    return
+}
+
+
 var cfg struct {
     Db_url string
     Lookahead uint64
     Debug bool
     Listen string
     Tls struct { Key, Cert string }
+    Faildelay duration
+}
+
+
+// state for per-user ratelimiting
+//
+// Each user has a corresponding Mutex in faildelay.userlocks. These locks
+// protect the _sending of responses_, not database access or anything else.
+// The reason is that we want to avoid keeping state for non-existent users and
+// only after talking to the database do we know if a user exists.
+//
+// Access to the map itself is synchronized by faildelay.lock. It is only
+// neccessary to grab the lock for writing when a user is inserted into the map
+// (i.e first login attempt), otherwise a read lock suffices.
+var faildelay struct {
+    lock sync.RWMutex
+    userlocks map[string]*sync.Mutex
 }
 
+
 func debug(v ...interface{}) {
     if cfg.Debug { log.Println(v...) }
 }
@@ -147,8 +184,8 @@ func check_offer(db *sql.DB, remote net.Addr, name string, offer string) (ok boo
             goto retry
         } else if err == sql.ErrNoRows {
             log.Printf("Unkown user: %v", name)
-            ok = false
-            goto commit
+            tx.Rollback()
+            return false, err
         } else { err_panic(err) }
 
         for i = 0; i < cfg.Lookahead; i++ {
@@ -187,49 +224,86 @@ retry:
     }
 }
 
-func handle(db *sql.DB, remote net.Addr, reader *bufio.Reader, writer *bufio.Writer) {
+func close(conn net.Conn, remote net.Addr) {
+    debugf("[%v] closing", remote)
+    conn.Close() // XXX: check err?
+}
+
+func handle_conn(db *sql.DB, conn net.Conn) {
+    remote := conn.RemoteAddr()
+    reader := bufio.NewReader(conn)
+    writer := bufio.NewWriter(conn)
+
     debugf("[%v] reading name", remote)
     name, err := get_line(reader)
     if err != nil {
         log.Printf("[%v] %v", remote, err)
+        close(conn, remote)
         return
     }
     debugf("[%v] name: %v", remote, name)
 
-    // XXX: ratelimiting: wait x secs after fail per user
-
     debugf("[%v] reading offer", remote)
     offer, err := get_line(reader)
     if err != nil {
         log.Printf("[%v] %v", remote, err)
+        close(conn, remote)
         return
     }
 
     debugf("[%v] checking for match", remote)
     result := "FAIL"
-    if match, err := check_offer(db, remote, name, offer); err != nil {
+    match, err := check_offer(db, remote, name, offer)
+    if err == sql.ErrNoRows {
+        close(conn, remote)
+        return
+    } else if err != nil {
         log.Panic(err)
     } else if match {
         result = "OK"
     }
     log.Printf("%v: %v", name, result)
+
+    // name exists, get or create its lock
+    faildelay.lock.RLock()
+    delay, exists := faildelay.userlocks[name]
+    faildelay.lock.RUnlock()
+    if !exists {
+        debugf("[%v] not yet in faildelay.userlocks", remote)
+
+        // no atomic upgrade with sync.RWMutex, so we have to do the lookup again
+        faildelay.lock.Lock()
+        delay, exists = faildelay.userlocks[name]
+        if !exists {
+            delay = new(sync.Mutex)
+            faildelay.userlocks[name] = delay
+        }
+        faildelay.lock.Unlock()
+    }
+
+    delay.Lock()
+
     _, err = fmt.Fprintln(writer, result)
     if err != nil {
         log.Printf("[%v] %v", remote, err)
-        return
+        goto out
     }
     err = writer.Flush()
     if err != nil {
         log.Printf("[%v] %v", remote, err)
-        return
+        goto out
     }
-}
 
-func handle_conn(db *sql.DB, conn net.Conn) {
-    remote := conn.RemoteAddr()
-    handle(db, remote, bufio.NewReader(conn), bufio.NewWriter(conn))
-    debugf("[%v] closing", remote)
-    conn.Close() // XXX: check err?
+out:
+    close(conn, remote)
+
+    if !match {
+        debugf("[%v] delaying for %v", remote, cfg.Faildelay.Duration)
+        time.Sleep(cfg.Faildelay.Duration)
+    }
+
+    debugf("[%v] unlock", remote)
+    delay.Unlock()
 }
 
 
@@ -247,6 +321,8 @@ func main() {
     flag.StringVar(&cfg.Listen, "listen", "127.0.0.1:9999", "Address to listen on.")
     flag.StringVar(&cfg.Tls.Key, "tls-key", "", "Use TLS.")
     flag.StringVar(&cfg.Tls.Cert, "tls-cert", "", "Use TLS.")
+    flag.DurationVar(&cfg.Faildelay.Duration, "faildelay", 1*time.Second,
+        "Per-user delay after a failed authentication attempt.")
 
     // 1. parse arguments to get config path
     flag.Parse()
@@ -272,6 +348,8 @@ func main() {
     }
 
     if *flag_serve {
+        faildelay.userlocks = make(map[string]*sync.Mutex)
+
         var listener net.Listener
         listen_addr, err := net.ResolveTCPAddr("tcp", cfg.Listen);
         err_fatal(err)
@@ -300,6 +378,4 @@ func main() {
             go handle_conn(db, conn)
         }
     }
-
-    // handle(db, bufio.NewReader(os.Stdin), bufio.NewWriter(os.Stdout))
 }
-- 
GitLab