diff --git a/goatherd.go b/goatherd.go index c1ba52d9c8a1766bb08121c2f0194407a15ce1da..809396908f8a0954d77ac095effc4b372bbc84a7 100644 --- a/goatherd.go +++ b/goatherd.go @@ -225,21 +225,12 @@ retry: } } -func close(conn net.Conn, remote net.Addr) { - debugf("[%v] closing", remote) - conn.Close() // XXX: check err? -} - -func handle_conn(db *sql.DB, conn net.Conn) { - remote := conn.RemoteAddr() - reader := bufio.NewReader(conn) - writer := bufio.NewWriter(conn) - +func handle_conn(db *sql.DB, remote string, reader *bufio.Reader, + writer *bufio.Writer) (delay *sync.Mutex) { debugf("[%v] reading name", remote) name, err := get_line(reader) if err != nil { log.Printf("[%v] %v", remote, err) - close(conn, remote) return } debugf("[%v] name: %v", remote, name) @@ -248,7 +239,6 @@ func handle_conn(db *sql.DB, conn net.Conn) { offer, err := get_line(reader) if err != nil { log.Printf("[%v] %v", remote, err) - close(conn, remote) return } @@ -256,7 +246,6 @@ func handle_conn(db *sql.DB, conn net.Conn) { result := "FAIL" match, err := check_offer(db, remote, name, offer) if err == sql.ErrNoRows { - close(conn, remote) return } else if err != nil { log.Panic(err) @@ -271,43 +260,41 @@ func handle_conn(db *sql.DB, conn net.Conn) { // zeroed mutex if we used a concurrent map without a temporary variable // and further memory barriers) faildelay.RLock() - delay, exists := faildelay.userlocks[name] + userlock, exists := faildelay.userlocks[name] faildelay.RUnlock() if !exists { debugf("[%v] not yet in faildelay.userlocks", remote) faildelay.Lock() - delay, exists = faildelay.userlocks[name] + userlock, exists = faildelay.userlocks[name] if !exists { - delay = new(sync.Mutex) - faildelay.userlocks[name] = delay + userlock = new(sync.Mutex) + faildelay.userlocks[name] = userlock } faildelay.Unlock() } - delay.Lock() + userlock.Lock() + debugf("[%v] have lock", remote) + if !match { + // invalid offer, unlocked in serve after delaying further attempts + delay = userlock + } else { + defer userlock.Unlock() + } _, err = fmt.Fprintln(writer, result) if err != nil { log.Printf("[%v] %v", remote, err) - goto out + return } err = writer.Flush() if err != nil { log.Printf("[%v] %v", remote, err) - goto out - } - -out: - close(conn, remote) - - if !match { - debugf("[%v] delaying for %v", remote, cfg.Faildelay.Duration) - time.Sleep(cfg.Faildelay.Duration) + return } - debugf("[%v] unlock", remote) - delay.Unlock() + return } @@ -339,7 +326,22 @@ func serve(db *sql.DB) { log.Printf("new connection: %v\n", conn.RemoteAddr()) // XXX: recover from database failure - go handle_conn(db, conn) + go func(conn net.Conn) { + remote := conn.RemoteAddr().String() + reader := bufio.NewReader(conn) + writer := bufio.NewWriter(conn) + + delay := handle_conn(db, remote, reader, writer) + debugf("[%v] closing", remote) + conn.Close() // XXX: check err? + + if delay != nil { + debugf("[%v] delaying for %v", remote, cfg.Faildelay.Duration) + time.Sleep(cfg.Faildelay.Duration) + debugf("[%v] unlock", remote) + delay.Unlock() + } + }(conn) } }