From b276bd542997b5351a17a35954e1e1a8252872d3 Mon Sep 17 00:00:00 2001 From: Lukas Braun <lukas.braun@fau.de> Date: Fri, 6 Oct 2017 17:00:42 +0200 Subject: [PATCH] Add get-counter and inc-counter subcommands Also add some cli tests, to be extended for existing subcommands. --- Makefile | 2 +- goatherd.go | 48 ++++++++++++++++++++++++++++++++++++++++++++++++ goatherd_test.go | 23 +++++++++++++++++++++++ test_pg.sh | 37 ++++++++++++++++++++++++++++++++++++- 4 files changed, 108 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index c179bc0..3a1ecb6 100644 --- a/Makefile +++ b/Makefile @@ -7,7 +7,7 @@ all: goatherd pam_goatherd.so goatherd: goatherd.go go build goatherd.go -test_pg: goatherd.go goatherd_test.go test_pg.sh +test_pg: goatherd.go goatherd_test.go test_pg.sh goatherd sh test_pg.sh test_pam_goatherd: goatherd pam_goatherd.so test_pam_goatherd.sh diff --git a/goatherd.go b/goatherd.go index 68811ad..c17819c 100644 --- a/goatherd.go +++ b/goatherd.go @@ -510,6 +510,38 @@ func serve() { wg.Wait() } +// incCount increments a users counter atomically. +func incCount(user string, diff uint64) error { + for { + tx, err := db.Begin() + if err != nil { + return err + } + + _, counter, err := getUser(tx, user) + if transactionFailed(err) { + goto retry + } else if err != nil { + tx.Rollback() + return err + } + + err = setCount(db, user, counter+diff) + if transactionFailed(err) { + goto retry + } else if err != nil { + tx.Rollback() + return err + } + + tx.Commit() + return nil + + retry: + tx.Rollback() + } +} + func main() { flagConfig := flag.String("config", "/etc/goatherd.conf", "Path to config file") flag.StringVar(&cfg.DbURL, "db-url", "", "URL used to connect to the database.") @@ -560,6 +592,14 @@ func main() { user := args[0] secret := args[1] createUser(user, secret) + case "get-counter": + if len(args) != 1 { + log.Fatalf("get-counter: Invalid number of arguments: %v (expecting <user>)\n", len(args)) + } + user := args[0] + _, counter, err := getUser(db, user) + errFatal(err) + fmt.Printf("%v: %v\n", user, counter) case "set-counter": if len(args) != 2 { log.Fatalf("set-counter: Invalid number of arguments: %v (expecting <username> <counter>)\n", len(args)) @@ -568,6 +608,14 @@ func main() { counter, err := strconv.ParseUint(args[1], 0, 64) errFatal(err) errFatal(setCount(db, user, counter)) + case "inc-counter": + if len(args) != 2 { + log.Fatalf("set-counter: Invalid number of arguments: %v (expecting <username> <counter>)\n", len(args)) + } + user := args[0] + diff, err := strconv.ParseUint(args[1], 0, 64) + errFatal(err) + errFatal(incCount(user, diff)) case "serve": fs := flag.NewFlagSet("serve", flag.ExitOnError) flagAddr := fs.String("addr", "", "Address to listen on. (default 127.0.0.1:9999)") diff --git a/goatherd_test.go b/goatherd_test.go index 99bf277..74d1dd2 100644 --- a/goatherd_test.go +++ b/goatherd_test.go @@ -216,6 +216,28 @@ func checkOfferT(t *testing.T) { } +func incCountT(t *testing.T) { + t.Run("noSuchUser", func(t *testing.T) { + err := incCount("noSuchUser", 1) + if err != sql.ErrNoRows { + t.Error("err:", nil) + } + }) + + t.Run("by2", func(t *testing.T) { + _, counter, err := getUser(db, username) + tErrFatal(t, err) + err = incCount(username, 2) + tErrFatal(t, err) + _, newCounter, err := getUser(db, username) + tErrFatal(t, err) + + if newCounter != counter + 2 { + t.Errorf("Expected %v, is %v\n", counter + 2, newCounter) + } + }) +} + type mockClient func(r *io.PipeReader, w *io.PipeWriter) func mockConn(client mockClient) (delay *sync.Mutex) { @@ -378,5 +400,6 @@ func TestMain(t *testing.T) { t.Run("createUser", createUserT) t.Run("transactionConflict", transactionConflictT) t.Run("checkOffer", checkOfferT) + t.Run("incCount", incCountT) t.Run("handleConn", handleConnT) } diff --git a/test_pg.sh b/test_pg.sh index b575631..e8836e2 100644 --- a/test_pg.sh +++ b/test_pg.sh @@ -20,6 +20,41 @@ echo Waiting for compile to finish wait %1 echo -echo ===== Tests ===== +echo ===== Go Tests ===== DB_URL="host=$PGHOST dbname=goatherd_test" ./goatherd.test "$@" + + +echo ===== Command-Line Tests ===== + +gh() { + ./goatherd -config "$PWD/pam_goatherd_test.conf" -db-url "host=$PGHOST dbname=goatherd_test" "$@" +} + +expecting() { + expected="$1" + got="$2" + msg="$3" + if [ "$expected" != "$got" ]; then + echo "$msg: expected '$expected', got '$got'" >&2 + exit 1 + fi +} + +expecting_count() { + cnt="$1" + msg="$2" + expecting "cmdtest: $cnt" "$(gh get-counter cmdtest)" "$msg" +} + +gh add-user cmdtest 'Hq05cjBHVVAN/sCAcN81p5uwUzI=' +expecting_count 1 "Unexpected counter after initialization" + +gh set-counter cmdtest 10 +expecting_count 10 "Unexpected counter after setting to 10" + +gh inc-counter cmdtest 10 +expecting_count 20 "Unexpected counter after incrementing by 10" + + +echo PASS -- GitLab