From e1faec7bfbc8d9555a64af7c24417c1d769764dd Mon Sep 17 00:00:00 2001
From: Lukas Braun <koomi@moshbit.net>
Date: Fri, 6 Oct 2017 15:17:59 +0200
Subject: [PATCH] Add a connection timeout

TODO: write tests
---
 goatherd.go | 34 ++++++++++++++++++++++++++++------
 1 file changed, 28 insertions(+), 6 deletions(-)

diff --git a/goatherd.go b/goatherd.go
index ac4bb62..85d2653 100644
--- a/goatherd.go
+++ b/goatherd.go
@@ -44,16 +44,36 @@ func (d duration) MarshalJSON() ([]byte, error) {
 	return []byte(fmt.Sprintf(`"%s"`, d.String())), nil
 }
 
+type timeoutConn struct {
+	net.Conn
+	Timeout time.Duration
+}
+
+func (c *timeoutConn) updateDeadline() {
+	c.Conn.SetDeadline(time.Now().Add(c.Timeout))
+}
+
+func (c *timeoutConn) Read(b []byte) (int, error) {
+	c.updateDeadline()
+	return c.Conn.Read(b)
+}
+
+func (c *timeoutConn) Write(b []byte) (int, error) {
+	c.updateDeadline()
+	return c.Conn.Write(b)
+}
+
 type tlsCfg struct {
 	Key, Cert string
 }
 
 var cfg struct {
-	DbURL     string
-	Lookahead uint64
-	Debug     bool
-	Listen    map[string]tlsCfg
-	Faildelay duration
+	DbURL       string
+	Lookahead   uint64
+	Debug       bool
+	Listen      map[string]tlsCfg
+	Faildelay   duration
+	IdleTimeout duration
 }
 
 // state for per-user ratelimiting
@@ -456,7 +476,7 @@ func listen(wg *sync.WaitGroup, listener net.Listener) {
 				debugf("[%v] unlock", remote)
 				delay.Unlock()
 			}
-		}(conn)
+		}(&timeoutConn{conn, cfg.IdleTimeout.Duration})
 	}
 }
 
@@ -497,6 +517,8 @@ func main() {
 	flag.BoolVar(&cfg.Debug, "debug", false, "Enable debug output.")
 	flag.DurationVar(&cfg.Faildelay.Duration, "faildelay", 1*time.Second,
 		"Per-user delay after a failed authentication attempt.")
+	flag.DurationVar(&cfg.IdleTimeout.Duration, "idle-timeout", 5*time.Second,
+		"Time after which an idle connection will be terminated.")
 
 	// 1. parse arguments to get config path
 	flag.Parse()
-- 
GitLab