From ad5010e2190dcc6cef9fdf51a346df970e5fddd0 Mon Sep 17 00:00:00 2001 From: Lukas Braun <lukas.braun@fau.de> Date: Mon, 27 Feb 2017 18:52:27 +0100 Subject: [PATCH] use Scanner for reading lines Scanner aborts when lines are too long and is more flexible wrt. line separators. --- goatherd.go | 29 ++++++++++++++++++----------- goatherd_test.go | 4 ++-- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/goatherd.go b/goatherd.go index c7e0dbe..01f94a8 100644 --- a/goatherd.go +++ b/goatherd.go @@ -85,13 +85,16 @@ func err_fatal(err error) { err_fatalf("%v", err) } -func get_line(reader *bufio.Reader) (string, error) { - ln, err := reader.ReadString('\n') - if err != nil { return "", err } - if len(ln) <= 1 { return "", errors.New("Empty line") } +func get_line(s *bufio.Scanner) (string, error) { + if !s.Scan() { + if err := s.Err(); err != nil { + return "", err + } + } + line := s.Text() + if len(line) <= 0 { return "", errors.New("Empty line") } - // strip \n - return ln[:len(ln)-1], nil + return line, nil } @@ -115,19 +118,19 @@ func create_table(db *sql.DB) { } // global var for testing, not modified during normal execution -var stdin_reader = bufio.NewReader(os.Stdin) +var stdin_scanner = bufio.NewScanner(os.Stdin) func create_user(db *sql.DB, name string, secret_b64 string) { debug("Creating user") var err error if name == "-" { fmt.Printf("Enter username: ") - name, err = get_line(stdin_reader) + name, err = get_line(stdin_scanner) err_fatalf("Can't read username: %v\n", err) } if secret_b64 == "-" { fmt.Printf("Enter secret: ") - secret_b64, err = get_line(stdin_reader) + secret_b64, err = get_line(stdin_scanner) err_fatalf("Can't read secret: %v\n", err) } @@ -238,8 +241,12 @@ retry: func handle_conn(db *sql.DB, remote string, reader *bufio.Reader, writer *bufio.Writer) (delay *sync.Mutex) { + s := bufio.NewScanner(reader) + b := make([]byte, 80) + s.Buffer(b, len(b)) + debugf("[%v] reading name", remote) - name, err := get_line(reader) + name, err := get_line(s) if err != nil { log.Printf("[%v] %v", remote, err) return @@ -247,7 +254,7 @@ func handle_conn(db *sql.DB, remote string, reader *bufio.Reader, debugf("[%v] name: %v", remote, name) debugf("[%v] reading offer", remote) - offer, err := get_line(reader) + offer, err := get_line(s) if err != nil { log.Printf("[%v] %v", remote, err) return diff --git a/goatherd_test.go b/goatherd_test.go index 310ad11..60627c2 100644 --- a/goatherd_test.go +++ b/goatherd_test.go @@ -165,7 +165,7 @@ func interact(name string, offer string, recv chan string) client_t { return func(r *io.PipeReader, w *io.PipeWriter) { w.Write(append([]byte(name), '\n')) w.Write(append([]byte(offer), '\n')) - answer, _ := get_line(bufio.NewReader(r)) + answer, _ := get_line(bufio.NewScanner(r)) recv <- answer } } @@ -296,7 +296,7 @@ func TestMain(t *testing.T) { cfg.Debug = false faildelay.userlocks = make(map[string]*sync.Mutex) stdin_r, stdin_w := io.Pipe() - stdin_reader, stdin_writer = bufio.NewReader(stdin_r), bufio.NewWriter(stdin_w) + stdin_scanner, stdin_writer = bufio.NewScanner(stdin_r), bufio.NewWriter(stdin_w) var err error db, err = sql.Open("sqlite3", cfg.Db_url) -- GitLab