From b276bd542997b5351a17a35954e1e1a8252872d3 Mon Sep 17 00:00:00 2001
From: Lukas Braun <lukas.braun@fau.de>
Date: Fri, 6 Oct 2017 17:00:42 +0200
Subject: [PATCH] Add get-counter and inc-counter subcommands

Also add some cli tests, to be extended for existing subcommands.
---
 Makefile         |  2 +-
 goatherd.go      | 48 ++++++++++++++++++++++++++++++++++++++++++++++++
 goatherd_test.go | 23 +++++++++++++++++++++++
 test_pg.sh       | 37 ++++++++++++++++++++++++++++++++++++-
 4 files changed, 108 insertions(+), 2 deletions(-)

diff --git a/Makefile b/Makefile
index c179bc0..3a1ecb6 100644
--- a/Makefile
+++ b/Makefile
@@ -7,7 +7,7 @@ all: goatherd pam_goatherd.so
 goatherd: goatherd.go
 	go build goatherd.go
 
-test_pg: goatherd.go goatherd_test.go test_pg.sh
+test_pg: goatherd.go goatherd_test.go test_pg.sh goatherd
 	sh test_pg.sh
 
 test_pam_goatherd: goatherd pam_goatherd.so test_pam_goatherd.sh
diff --git a/goatherd.go b/goatherd.go
index 68811ad..c17819c 100644
--- a/goatherd.go
+++ b/goatherd.go
@@ -510,6 +510,38 @@ func serve() {
 	wg.Wait()
 }
 
+// incCount increments a users counter atomically.
+func incCount(user string, diff uint64) error {
+	for {
+		tx, err := db.Begin()
+		if err != nil {
+			return err
+		}
+
+		_, counter, err := getUser(tx, user)
+		if transactionFailed(err) {
+			goto retry
+		} else if err != nil {
+			tx.Rollback()
+			return err
+		}
+
+		err = setCount(db, user, counter+diff)
+		if transactionFailed(err) {
+			goto retry
+		} else if err != nil {
+			tx.Rollback()
+			return err
+		}
+
+		tx.Commit()
+		return nil
+
+	retry:
+		tx.Rollback()
+	}
+}
+
 func main() {
 	flagConfig := flag.String("config", "/etc/goatherd.conf", "Path to config file")
 	flag.StringVar(&cfg.DbURL, "db-url", "", "URL used to connect to the database.")
@@ -560,6 +592,14 @@ func main() {
 			user := args[0]
 			secret := args[1]
 			createUser(user, secret)
+		case "get-counter":
+			if len(args) != 1 {
+				log.Fatalf("get-counter: Invalid number of arguments: %v (expecting <user>)\n", len(args))
+			}
+			user := args[0]
+			_, counter, err := getUser(db, user)
+			errFatal(err)
+			fmt.Printf("%v: %v\n", user, counter)
 		case "set-counter":
 			if len(args) != 2 {
 				log.Fatalf("set-counter: Invalid number of arguments: %v (expecting <username> <counter>)\n", len(args))
@@ -568,6 +608,14 @@ func main() {
 			counter, err := strconv.ParseUint(args[1], 0, 64)
 			errFatal(err)
 			errFatal(setCount(db, user, counter))
+		case "inc-counter":
+			if len(args) != 2 {
+				log.Fatalf("set-counter: Invalid number of arguments: %v (expecting <username> <counter>)\n", len(args))
+			}
+			user := args[0]
+			diff, err := strconv.ParseUint(args[1], 0, 64)
+			errFatal(err)
+			errFatal(incCount(user, diff))
 		case "serve":
 			fs := flag.NewFlagSet("serve", flag.ExitOnError)
 			flagAddr := fs.String("addr", "", "Address to listen on. (default 127.0.0.1:9999)")
diff --git a/goatherd_test.go b/goatherd_test.go
index 99bf277..74d1dd2 100644
--- a/goatherd_test.go
+++ b/goatherd_test.go
@@ -216,6 +216,28 @@ func checkOfferT(t *testing.T) {
 
 }
 
+func incCountT(t *testing.T) {
+    t.Run("noSuchUser", func(t *testing.T) {
+        err := incCount("noSuchUser", 1)
+		if err != sql.ErrNoRows {
+			t.Error("err:", nil)
+		}
+	})
+
+    t.Run("by2", func(t *testing.T) {
+        _, counter, err := getUser(db, username)
+        tErrFatal(t, err)
+        err = incCount(username, 2)
+        tErrFatal(t, err)
+        _, newCounter, err := getUser(db, username)
+        tErrFatal(t, err)
+
+        if newCounter != counter + 2 {
+            t.Errorf("Expected %v, is %v\n", counter + 2, newCounter)
+        }
+    })
+}
+
 type mockClient func(r *io.PipeReader, w *io.PipeWriter)
 
 func mockConn(client mockClient) (delay *sync.Mutex) {
@@ -378,5 +400,6 @@ func TestMain(t *testing.T) {
 	t.Run("createUser", createUserT)
 	t.Run("transactionConflict", transactionConflictT)
 	t.Run("checkOffer", checkOfferT)
+	t.Run("incCount", incCountT)
 	t.Run("handleConn", handleConnT)
 }
diff --git a/test_pg.sh b/test_pg.sh
index b575631..e8836e2 100644
--- a/test_pg.sh
+++ b/test_pg.sh
@@ -20,6 +20,41 @@ echo Waiting for compile to finish
 wait %1
 
 echo
-echo ===== Tests =====
+echo ===== Go Tests =====
 
 DB_URL="host=$PGHOST dbname=goatherd_test" ./goatherd.test "$@"
+
+
+echo ===== Command-Line Tests =====
+
+gh() {
+    ./goatherd -config "$PWD/pam_goatherd_test.conf" -db-url "host=$PGHOST dbname=goatherd_test" "$@"
+}
+
+expecting() {
+    expected="$1"
+    got="$2"
+    msg="$3"
+    if [ "$expected" != "$got" ]; then
+        echo "$msg: expected '$expected', got '$got'" >&2
+        exit 1
+    fi
+}
+
+expecting_count() {
+    cnt="$1"
+    msg="$2"
+    expecting "cmdtest: $cnt" "$(gh get-counter cmdtest)" "$msg"
+}
+
+gh add-user cmdtest 'Hq05cjBHVVAN/sCAcN81p5uwUzI='
+expecting_count 1 "Unexpected counter after initialization"
+
+gh set-counter cmdtest 10
+expecting_count 10 "Unexpected counter after setting to 10"
+
+gh inc-counter cmdtest 10
+expecting_count 20 "Unexpected counter after incrementing by 10"
+
+
+echo PASS
-- 
GitLab