From 0a3530c84aa51424a7fa09e830f6e044a55ef272 Mon Sep 17 00:00:00 2001
From: Florian Schmaus <flow@cs.fau.de>
Date: Wed, 14 Dec 2022 19:14:43 +0100
Subject: [PATCH] Add WaitFreeCountingSemaphore

---
 emper/WaitFreeCountingPrivateSemaphore.cpp    | 63 +++++++++++++++++++
 emper/WaitFreeCountingPrivateSemaphore.hpp    | 40 ++++++++++++
 emper/meson.build                             |  1 +
 .../WaitFreeCountingPrivateSemaphoreTest.cpp  | 54 ++++++++++++++++
 tests/meson.build                             |  6 ++
 5 files changed, 164 insertions(+)
 create mode 100644 emper/WaitFreeCountingPrivateSemaphore.cpp
 create mode 100644 emper/WaitFreeCountingPrivateSemaphore.hpp
 create mode 100644 tests/WaitFreeCountingPrivateSemaphoreTest.cpp

diff --git a/emper/WaitFreeCountingPrivateSemaphore.cpp b/emper/WaitFreeCountingPrivateSemaphore.cpp
new file mode 100644
index 00000000..1664acc6
--- /dev/null
+++ b/emper/WaitFreeCountingPrivateSemaphore.cpp
@@ -0,0 +1,63 @@
+// SPDX-License-Identifier: LGPL-3.0-or-later
+// Copyright © 2020-2022 Florian Schmaus
+#include "WaitFreeCountingPrivateSemaphore.hpp"
+
+#include <cassert>
+#include <limits>
+#include <ostream>
+
+#include "Context.hpp"
+#include "Debug.hpp"
+
+WaitFreeCountingPrivateSemaphore::WaitFreeCountingPrivateSemaphore()
+		: WaitFreeCountingPrivateSemaphore(0) {}
+
+WaitFreeCountingPrivateSemaphore::WaitFreeCountingPrivateSemaphore(unsigned int counter)
+		: requiredSignalCount(counter),
+			counter(std::numeric_limits<unsigned int>::max()),
+			blockedContext(nullptr) {}
+
+void WaitFreeCountingPrivateSemaphore::incrementCounterByOne() { requiredSignalCount++; }
+
+void WaitFreeCountingPrivateSemaphore::incrementCounter(unsigned int count) {
+	requiredSignalCount += count;
+}
+
+void WaitFreeCountingPrivateSemaphore::wait() {
+	unsigned int delta = std::numeric_limits<unsigned int>::max() - requiredSignalCount;
+	LOGD("wait() counter: " << counter << " delta: " << delta);
+
+	if (counter - delta == 0) {
+		return;
+	}
+
+	Context* blockedContext = Context::getCurrentContext();
+	block([this, blockedContext, delta] {
+		assert(blockedContext > (Context*)4096);
+
+		// this->blockedContext.store(blockedContext); // MO relaxed possible
+		this->blockedContext = blockedContext;
+		unsigned int oldCounter = counter.fetch_sub(delta);	 // MO release
+		LOGD("wait() block oldCounter: " << oldCounter << " delta: " << delta);
+		if (oldCounter == delta) {
+			unblock(blockedContext);
+		}
+	});
+}
+
+auto WaitFreeCountingPrivateSemaphore::signalInternal() -> Context* {
+	unsigned int oldCounter = counter.fetch_sub(1);
+	assert(oldCounter >= 1);
+	LOGD("signalInternal(): oldCounter: " << oldCounter);
+
+	// If the counter is still non-zero after the decrement, somebody
+	// else is responsible for scheduling the fiber.
+	if (oldCounter > 1) {
+		return nullptr;
+	}
+
+	//	Context* context = blockedContext.load();
+	Context* context = blockedContext;
+	EMPER_ASSERT_MSG(!context || context > (Context*)4096, "Unexpected context value: " << context);
+	return context;
+}
diff --git a/emper/WaitFreeCountingPrivateSemaphore.hpp b/emper/WaitFreeCountingPrivateSemaphore.hpp
new file mode 100644
index 00000000..66d16b43
--- /dev/null
+++ b/emper/WaitFreeCountingPrivateSemaphore.hpp
@@ -0,0 +1,40 @@
+// SPDX-License-Identifier: LGPL-3.0-or-later
+// Copyright © 2020-2022 Florian Schmaus
+#pragma once
+
+#include <atomic>
+#include <limits>
+
+#include "Common.hpp"
+#include "PrivateSemaphore.hpp"
+
+class Context;
+
+/**
+ * A wait-free counting private semaphore, based on the Nowa approach by Schmaus et. al.
+ */
+class WaitFreeCountingPrivateSemaphore : public PrivateSemaphore {
+ private:
+	unsigned int requiredSignalCount;
+	ALIGN_TO_CACHE_LINE std::atomic_uint counter;
+	Context* blockedContext;	// std::atomic?
+
+	inline auto signalInternal() -> Context* override;
+
+ public:
+	WaitFreeCountingPrivateSemaphore();
+	explicit WaitFreeCountingPrivateSemaphore(unsigned int counter);
+
+	inline auto getCounter() -> unsigned int {
+		unsigned int delta = std::numeric_limits<unsigned int>::max() - requiredSignalCount;
+		return counter - delta;
+	};
+
+	void incrementCounterByOne();
+
+	void incrementCounter(unsigned int count);
+
+	void wait() override;
+};
+
+using WFCPS = WaitFreeCountingPrivateSemaphore;
diff --git a/emper/meson.build b/emper/meson.build
index 1f957d07..f5fc87ac 100644
--- a/emper/meson.build
+++ b/emper/meson.build
@@ -32,6 +32,7 @@ emper_cpp_sources = [
   'CountingPrivateSemaphore.cpp',
   'Semaphore.cpp',
   'StealingMode.cpp',
+  'WaitFreeCountingPrivateSemaphore.cpp',
   'WakeupStrategy.cpp',
   'Worker.cpp',
 ]
diff --git a/tests/WaitFreeCountingPrivateSemaphoreTest.cpp b/tests/WaitFreeCountingPrivateSemaphoreTest.cpp
new file mode 100644
index 00000000..672df0f6
--- /dev/null
+++ b/tests/WaitFreeCountingPrivateSemaphoreTest.cpp
@@ -0,0 +1,54 @@
+// SPDX-License-Identifier: LGPL-3.0-or-later
+// Copyright © 2022 Florian Schmaus
+#include <array>
+#include <cstdint>
+#include <iostream>
+#include <numeric>
+
+#include "Debug.hpp"
+#include "WaitFreeCountingPrivateSemaphore.hpp"
+#include "emper-config.h"
+#include "emper.hpp"
+#include "fixtures/assert.hpp"
+
+static const uint64_t DIVISOR = 10;
+static const uint64_t SIZE = EMPER_LOG_LEVEL > Info ? 1000 : 1000000;
+
+static void recurseSum(WFCPS& wfcps, uint64_t* res, uint64_t num, uint64_t size) {
+	if (size == 1) {
+		*res = num;
+		wfcps.signal();
+		return;
+	}
+
+	WFCPS myWfcps(DIVISOR);
+	std::array<uint64_t, DIVISOR> myRes;
+	uint64_t newSize = size / DIVISOR;
+	for (uint64_t i = 0; i < DIVISOR; ++i) {
+		uint64_t newNum = num + (i * newSize);
+		uint64_t* newRes = &myRes[i];
+		async([&, newRes, newNum, newSize] { recurseSum(myWfcps, newRes, newNum, newSize); });
+	}
+
+	myWfcps.wait();
+
+	*res = std::reduce(myRes.begin(), myRes.end());
+	wfcps.signal();
+}
+
+static auto kleinerGauß(uint64_t num) -> uint64_t {
+	uint64_t num_squared = num * num;
+	return (num_squared + num) / 2;
+}
+
+void emperTest() {
+	WFCPS wfcps(1);
+	uint64_t res;
+
+	recurseSum(wfcps, &res, 1, SIZE);
+
+	wfcps.wait();
+	uint64_t expected = kleinerGauß(SIZE);
+	std::cout << "res: " << res << " expected: " << expected << std::endl;
+	ASSERT(res == expected);
+}
diff --git a/tests/meson.build b/tests/meson.build
index b15d76f1..9da80770 100644
--- a/tests/meson.build
+++ b/tests/meson.build
@@ -94,6 +94,12 @@ tests = [
 		'test_runner': 'emper',
 	},
 
+	{
+		'source': files('WaitFreeCountingPrivateSemaphoreTest.cpp'),
+		'description': 'Concurrent test for WaitFreeCountingPrivateSemaphores',
+		'test_runner': 'emper',
+	},
+
 	{
 		'source': files('SignalPrivateSemaphoreFromAnywhereTest.cpp'),
 		'description': 'Simple test for PrivateSemaphore:signalFromAnywhere()',
-- 
GitLab