diff --git a/goatherd.go b/goatherd.go index cc22b233248ff17813991feb13330a4346d34ed7..da00c55f9c7608b3353f796f6a8aaa3929107788 100644 --- a/goatherd.go +++ b/goatherd.go @@ -41,12 +41,14 @@ func (d duration) MarshalJSON() ([]byte, error) { } +type tls_cfg struct { + Key, Cert string +} var cfg struct { Db_url string Lookahead uint64 Debug bool - Listen string - Tls struct { Key, Cert string } + Listen map[string]tls_cfg Faildelay duration } @@ -324,30 +326,11 @@ func handle_conn(db *sql.DB, remote string, reader *bufio.Reader, return } - -func serve(db *sql.DB) { - faildelay.userlocks = make(map[string]*sync.Mutex) - - var listener net.Listener - listen_addr, err := net.ResolveTCPAddr("tcp", cfg.Listen); - err_fatal(err) - listener, err = net.ListenTCP("tcp", listen_addr) - err_fatal(err) - - if (cfg.Tls.Key != "") { - log.Printf("Using TLS: cert %v, key %v\n", cfg.Tls.Cert, cfg.Tls.Key) - cert, err := tls.LoadX509KeyPair(cfg.Tls.Cert, cfg.Tls.Key) - err_fatalf("Error loading key pair: %v\n", err) - - listener = tls.NewListener(listener, &tls.Config{ - Certificates: []tls.Certificate{ cert }, - }) - } - - log.Println("Listening on", listen_addr) +func listen(db *sql.DB, wg sync.WaitGroup, listener net.Listener) { + defer wg.Done() for { - debug("Accept...") + debugf("Accepting on %v", listener.Addr()) conn, err := listener.Accept() err_fatal(err) log.Printf("new connection: %v\n", conn.RemoteAddr()) @@ -372,6 +355,36 @@ func serve(db *sql.DB) { } } +func serve(db *sql.DB) { + faildelay.userlocks = make(map[string]*sync.Mutex) + + var wg sync.WaitGroup + + for addr, tls_cfg := range cfg.Listen { + var listener net.Listener + listen_addr, err := net.ResolveTCPAddr("tcp", addr); + err_fatal(err) + listener, err = net.ListenTCP("tcp", listen_addr) + err_fatal(err) + + if (tls_cfg.Key != "") { + log.Printf("Using TLS: cert %v, key %v\n", tls_cfg.Cert, tls_cfg.Key) + cert, err := tls.LoadX509KeyPair(tls_cfg.Cert, tls_cfg.Key) + err_fatalf("Error loading key pair: %v\n", err) + + listener = tls.NewListener(listener, &tls.Config{ + Certificates: []tls.Certificate{ cert }, + }) + } + + log.Println("Listening on", listen_addr) + wg.Add(1) + go listen(db, wg, listener) + } + + wg.Wait() +} + func main() { flag_config := flag.String("config", "/etc/goatherd.conf", "Path to config file") @@ -380,14 +393,14 @@ func main() { flag_secret := flag.String("secret", "-", "Secret for the new user. If '-' read from stdin.") flag_serve := flag.Bool("serve", false, "Start daemon.") flag_dump_config := flag.Bool("dump-config", false, "Dump the effective config to stdout.") + flag_listen := flag.String("listen", "", "Address to listen on. (default 127.0.0.1:9999)") + flag_tls_key := flag.String("tls-key", "", "Use TLS.") + flag_tls_cert := flag.String("tls-cert", "", "Use TLS.") // also settable in config file flag.StringVar(&cfg.Db_url, "db-url", ":memory:", "URL used to connect to the database.") flag.Uint64Var(&cfg.Lookahead, "lookahead", 10, "Counter range to check for matching OTPs.") flag.BoolVar(&cfg.Debug, "debug", false, "Enable debug output.") - flag.StringVar(&cfg.Listen, "listen", "127.0.0.1:9999", "Address to listen on.") - flag.StringVar(&cfg.Tls.Key, "tls-key", "", "Use TLS.") - flag.StringVar(&cfg.Tls.Cert, "tls-cert", "", "Use TLS.") flag.DurationVar(&cfg.Faildelay.Duration, "faildelay", 1*time.Second, "Per-user delay after a failed authentication attempt.") @@ -409,6 +422,15 @@ func main() { // default action is to serve, but not if one of the other actions is given serve_default := true + // copy -listen option to config + if *flag_listen != "" { + cfg.Listen = make(map[string]tls_cfg) + cfg.Listen[*flag_listen] = tls_cfg{ Key: *flag_tls_key, Cert: *flag_tls_cert } + } else if cfg.Listen == nil { + cfg.Listen = make(map[string]tls_cfg) + cfg.Listen["127.0.0.1:9999"] = tls_cfg{} + } + if *flag_dump_config { serve_default = false enc := json.NewEncoder(os.Stdout)