From e0f54d464795beecb0450f7a106cf414ac89f972 Mon Sep 17 00:00:00 2001
From: Florian Fischer <florian.fischer@muhq.space>
Date: Mon, 27 Dec 2021 16:56:18 +0100
Subject: [PATCH] [PipeSleepStrategy] fix specific state and sleeper count race

Introducing a lock for each specific state greatly simplifies the
algorithm and fixes a race, and I expect it to be rather cheap.
Having to check two conditions before sleeping and to prepare
resources depending on their outcome makes the algorithm complex
and racy: we skip sleeping if we were notified specifically or if
the global sleeper count was less than 0.

If we check our local state first and increment the global sleeper count
later, we could receive a notification after the increment, which causes
the worker to skip sleeping, making the increment wrong and the whole
counter unsound.

Checking the local state first, marking ourselves as sleeping, and
preparing the read for the specific pipe has the problem that, if the
subsequent sleeper-count check tells us to skip sleeping, we have already
prepared sqes which we then submit needlessly because we are not actually
sleeping.

And incrementing the global count first has the same problem as the
first variant: the increment is wrong if we skip sleeping afterwards,
breaking the counter.

All this is prevented by locking the specific state while we check both conditions.
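
The resulting sleep fast path then looks roughly like the following
sketch (simplified types; prepareGlobalRead()/prepareSpecificRead() are
hypothetical stand-ins for the io_uring sqe preparation, so this is an
illustration of the locking scheme, not the actual EMPER code):

    #include <atomic>
    #include <cstdint>
    #include <mutex>

    enum class SleeperState { Running, Sleeping, Notified };

    struct SleepState {
        std::mutex lock;
        SleeperState s = SleeperState::Running;
        bool globalReadInflight = false;
    };

    std::atomic<int64_t> sleepers{0};

    // Both conditions are checked under the same per-worker lock, so a
    // concurrent notifySpecific() can no longer slip in between the state
    // check and the sleeper count increment.
    bool trySleep(SleepState& state) {
        std::lock_guard<std::mutex> guard(state.lock);

        if (state.s == SleeperState::Notified) {
            state.s = SleeperState::Running;  // consume the notification
            return false;                     // skip sleeping
        }

        if (!state.globalReadInflight) {
            // Increment the sleeper count; a negative old value means skip sleeping.
            if (sleepers.fetch_add(1, std::memory_order_acquire) < 0) return false;
            state.globalReadInflight = true;
            // prepareGlobalRead();   // hypothetical stand-in for the global pipe read sqe
        }

        if (state.s != SleeperState::Sleeping) {
            state.s = SleeperState::Sleeping;
            // prepareSpecificRead(); // hypothetical stand-in for the specific pipe read sqe
        }

        return true;  // caller now waits for IO completions
    }

notifySpecific() takes the same per-worker lock before marking the state
notified, which is what closes the race described above.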
---
 emper/sleep_strategy/PipeSleepStrategy.cpp | 110 +++++++-------
 emper/sleep_strategy/PipeSleepStrategy.hpp | 166 +++++----------------
 2 files changed, 94 insertions(+), 182 deletions(-)

diff --git a/emper/sleep_strategy/PipeSleepStrategy.cpp b/emper/sleep_strategy/PipeSleepStrategy.cpp
index d855504b..3466425d 100644
--- a/emper/sleep_strategy/PipeSleepStrategy.cpp
+++ b/emper/sleep_strategy/PipeSleepStrategy.cpp
@@ -43,69 +43,62 @@ PipeSleepStrategy::~PipeSleepStrategy() {
 void PipeSleepStrategy::sleep() {
 	IoContext& io = *IoContext::getWorkerIo();
 
-	uint8_t observedState = sleepState._loadAcquire();
+	{
+		std::lock_guard<std::mutex> lock(sleepState.lock);
 
-	// Check if were notified specifically and should skip the sleep attempt
-	if (SleepState::isNotified(observedState)) {
-		sleepState.clearNotified();
-		LOGD("Reset notified state -> skip sleeping");
-		return;
-	}
-
-	if (SleepState::shouldPrepareGlobalRead(observedState)) {
-		// increment the sleeper count if it was negative we should skip sleeping
-		int64_t sleeping = sleepers.fetch_add(1, std::memory_order_acquire);
-		if (sleeping < 0) {
-			LOGD("observed sleeper count as: " << sleeping << " -> skip sleeping");
+		// Check if we were notified specifically and should skip the sleep attempt
+		if (sleepState.isNotified()) {
+			sleepState.markRunning();
+			LOGD("Reset notified state to running -> skip sleeping");
 			return;
 		}
 
-		LOGD("mark state as sleeping");
-		observedState = sleepState.markSleeping();
-		if (SleepState::isNotified(observedState)) {
-			// Preserve a already set specific flag
-			if (!SleepState::shouldPrepareSpecificRead(observedState))
-				sleepState.clearGlobalNotified();
-			else {
-				sleepState.clearAll();
+		if (!sleepState.globalReadInflight) {
+			// Increment the sleeper count; if it was negative we should skip sleeping
+			int64_t sleeping = sleepers.fetch_add(1, std::memory_order_acquire);
+			if (sleeping < 0) {
+				LOGD("observed sleeper count as: " << sleeping << " -> skip sleeping");
+				return;
 			}
-		}
+			assert(sleeping <= workerCount);
 
-		assert(sleeping <= workerCount);
+			sleepState.globalReadInflight = true;
 
-		struct io_uring_sqe* sqe = io.getSqe();
+			struct io_uring_sqe* sqe = io.getSqe();
 
-		// We read directly into the workers dispatchHint
-		io_uring_prep_read(sqe, global.sleepFd, &io.worker->dispatchHint,
-											 sizeof(io.worker->dispatchHint), 0);
+			// We read directly into the worker's dispatchHint
+			io_uring_prep_read(sqe, global.sleepFd, &io.worker->dispatchHint,
+												 sizeof(io.worker->dispatchHint), 0);
 
-		// Mark the sqe as a new work notification to reset the Global flag when reaping the
-		// resulting cqe
-		io_uring_sqe_set_data(
-				sqe, TaggedPtr(0, static_cast<uint16_t>(IoContext::PointerTags::NewWorkNotification)));
+			// Mark the sqe as a new work notification to reset the Global flag when reaping the
+			// resulting cqe
+			io_uring_sqe_set_data(
+					sqe, TaggedPtr(0, static_cast<uint16_t>(IoContext::PointerTags::NewWorkNotification)));
 
-		io.trackReqsInUring(1);
+			io.trackReqsInUring(1);
 
-		LOGD("prepared global.sleepFd read and set sleepers count to: " << sleeping + 1);
-	}
+			LOGD("prepared global.sleepFd read and set sleepers count to: " << sleeping + 1);
+		}
 
-	// If we reach this Global is definitly set
-	if (SleepState::shouldPrepareSpecificRead(observedState)) {
-		LOGD("prepared pipe.sleepFd read");
-		struct io_uring_sqe* sqe = io.getSqe();
+		if (!sleepState.isSleeping()) {
+			sleepState.markSleeping();
 
-		// We read directly into the workers dispatchHint
-		// TODO: Think about the race between the two reads
-		io_uring_prep_read(sqe, pipe.sleepFd, &io.worker->dispatchHint, sizeof(io.worker->dispatchHint),
-											 0);
+			struct io_uring_sqe* sqe = io.getSqe();
 
-		// Mark the sqe as a marked new work notification to reset the Specific flag when reaping the
-		// resulting cqe
-		io_uring_sqe_set_data(
-				sqe, TaggedPtr((void*)nullptr,
-											 static_cast<uint16_t>(IoContext::PointerTags::NewWorkNotification), true));
+			// We read directly into the worker's dispatchHint
+			// TODO: Think about the race between the two reads
+			io_uring_prep_read(sqe, pipe.sleepFd, &io.worker->dispatchHint,
+												 sizeof(io.worker->dispatchHint), 0);
 
-		io.trackReqsInUring(1);
+			// Tag the sqe with a marked new work notification to reset the specific state
+			// when reaping the resulting cqe.
+			io_uring_sqe_set_data(
+					sqe, TaggedPtr((void*)nullptr,
+												 static_cast<uint16_t>(IoContext::PointerTags::NewWorkNotification), true));
+
+			io.trackReqsInUring(1);
+			LOGD("prepared pipe.sleepFd read");
+		}
 	}
 
 	// Wait for IO completions
@@ -119,19 +112,21 @@ void PipeSleepStrategy::sleep() {
 template <CallerEnvironment callerEnvironment>
 void PipeSleepStrategy::notifySpecific(workerid_t workerId) {
 	auto& specificState = *sleepStates[workerId];
-	uint8_t observedState = specificState._loadAcquire();
-	LOGD("Specifically notify " << workerId << " with state: " << (int)observedState << " from "
-															<< callerEnvironment);
-	if (SleepState::isNotified(observedState)) {
+	std::lock_guard<std::mutex> lock(specificState.lock);
+
+	LOGD("Specifically notify worker " << workerId << " from " << callerEnvironment);
+
+	if (specificState.isNotified()) {
 		LOGD(workerId << " already marked notified");
 		return;
 	}
 
-	observedState = specificState.markNotified();
+	const bool isSleeping = specificState.markNotified();
 
-	if (SleepState::shouldNotifySpecific(observedState))
+	if (isSleeping) {
 		LOGD(workerId << " has specific read -> write notification");
-	writeSpecificNotification(workerId);
+		writeSpecificNotification(workerId);
+	}
 }
 
 template void PipeSleepStrategy::notifySpecific<CallerEnvironment::EMPER>(workerid_t workerId);
@@ -146,19 +141,20 @@ void PipeSleepStrategy::onNewWorkNotification(IoContext& io, TaggedPtr data) {
 	assert(data.getTag() == static_cast<uint16_t>(IoContext::PointerTags::NewWorkNotification));
 
 	if (data.isMarked()) {
+		std::lock_guard<std::mutex> lock(sleepState.lock);
 		LOGD("Got specific notification");
 		stats.incWakeupDueToNotify();
 
 		// Reset specific and notified flag to indicate that a new specific notification
 		// was consumed, a new specific read must be prepared and other notifySpecific
 		// calls must notify again.
-		sleepState.clearSpecificNotified();
+		sleepState.markRunning();
 	} else {
 		LOGD("Got new work notification");
 		stats.incWakeupDueToNotify();
 		// Reset global flag to indicate that a new sleep cqe must be prepared
 		// and allow the completer to reap completions again
-		sleepState.clearGlobal();
+		sleepState.globalReadInflight = false;
 	}
 }
 
diff --git a/emper/sleep_strategy/PipeSleepStrategy.hpp b/emper/sleep_strategy/PipeSleepStrategy.hpp
index b5a10985..ccce97be 100644
--- a/emper/sleep_strategy/PipeSleepStrategy.hpp
+++ b/emper/sleep_strategy/PipeSleepStrategy.hpp
@@ -8,6 +8,7 @@
 #include <atomic>
 #include <cstdint>
 #include <iostream>
+#include <mutex>
 #include <vector>
 
 #include "CallerEnvironment.hpp"
@@ -17,6 +18,7 @@
 #include "emper-common.h"
 #include "lib/TaggedPtr.hpp"
 #include "sleep_strategy/AbstractWorkerSleepStrategy.hpp"
+#include "sleep_strategy/SleeperState.hpp"
 #include "sleep_strategy/Stats.hpp"
 
 class Runtime;
@@ -50,26 +52,26 @@ namespace emper::sleep_strategy {
  *     Per worker:
  *             dispatch hint buffer
  *             specific pipe
+ *             state lock
  *             sleep state
  *
  * Sleep:
- *     Acquire state
- *     If notified
- *         Clear notified
- *         Skip sleep
+ *     Lock state
+ *     If state == notified
+ *         Set state = running
+ *         return
  *
- *     If we have no sleep request in flight
+ *     If we have no global sleep request in flight
  *             Atomic increment sleep count
  *             Skip sleeping if sleep count was < 0
  *             Mark sleep requests in flight
- *             If notified
- *                  Clear global and notified
- *                  Skip sleep
+ *             Prepare read cqe from the global pipe to dispatch hint buffer
  *
- *             If global read required
- *                  Prepare read cqe from the global pipe to dispatch hint buffer
- *             If Specific read required
- *                  Prepare read cqe from the specific pipe to dispatch hint buffer
+ *     If state == running
+ *             Set state = sleeping
+ *             Prepare marked read cqe from the specific pipe to dispatch hint buffer
+ *
+ *     Unlock state
  *     Wait until IO completions occurred
  *
  * NotifyEmper(n):
@@ -92,9 +94,20 @@ namespace emper::sleep_strategy {
  *    toWakeup = min(observed sleeping, n)
  *    write toWakeup hints to the hint pipe
  *
+ * NotifySpecific(w):
+ *    Get w's state
+ *    Lock state
+ *    Return if already notified
+ *    Mark state notified
+ *    Write hint if the state was sleeping
+ *
  * onNewWorkCompletion:
- *     reset in flight flag
- *     allow completer to reap completions on this IoContext
+ *     If data is marked
+ *         lock state
+ *         set state = running
+ *         return
+ *
+ *     Reset global read inflight
  *```
  *
  * Notes
@@ -115,9 +128,9 @@ namespace emper::sleep_strategy {
  *   This is a trade-off where we trade slower wakeup - a just awoken worker
  *   will check for local work - against a faster dispatch hot path when
  *   we have work to do in our local WSQ.
- * * The completer tread must not reap completions on the IoContexts of
- *   sleeping workers because this introduces a race for cqes and a possible
- *   lost wakeup if the completer consumes the completions before the worker
+ * * Other threads must not reap new work notifications because this
+ *   would introduce a race for cqes and a possible lost wakeup
+ *   if the other thread consumes the completions before the worker
  *   is actually waiting for them.
  * * When notifying sleeping workers from anywhere we must ensure that all
  *   notifications take effect. This is needed for example when terminating
@@ -135,118 +148,21 @@ class PipeSleepStrategy : AbstractWorkerSleepStrategy<PipeSleepStrategy>,
 	 */
 	class SleepState {
 		friend class PipeSleepStrategy;
-		friend class emper::io::IoContext;
-		CACHE_LINE_EXCLUSIVE(std::atomic<uint8_t>, s);
-		struct State {
-			enum _State {
-				Notified = 1 << 0, /*!< The worker was already notified specifically */
-				Specific = 1 << 1, /*!< The worker is reading from its pipe */
-				Global = 1 << 2,	 /*!< The worker is reading from the global pipe */
-			};
-		};
-
-		/**
-		 * @brief helper to atomically set a bit in the state
-		 *
-		 * @return the old state
-		 */
-		auto _setState(uint8_t bits) -> uint8_t {
-			uint8_t oldState = s.load(std::memory_order_relaxed);
-			do {
-			} while (!s.compare_exchange_weak(oldState, oldState | bits, std::memory_order_release,
-																				std::memory_order_relaxed));
-			return oldState;
-		}
-
-		/**
-		 * @brief helper to atomically clear a bit from the state
-		 *
-		 * @return the old state
-		 */
-		auto _clearState(uint8_t bits) -> uint8_t {
-			uint8_t oldState = s.load(std::memory_order_relaxed);
-			do {
-			} while (!s.compare_exchange_weak(oldState, oldState & ~bits, std::memory_order_release,
-																				std::memory_order_relaxed));
-			return oldState;
-		}
-
-		/**
-		 * @brief helper to load the state with std::memory_order_relaxed
-		 *
-		 * @return the observed state
-		 */
-		auto _loadAcquire() const -> uint8_t { return s.load(std::memory_order_acquire); }
-
-		/**
-		 * @brief Mark the state as notified
-		 *
-		 * @return if we should write to the worker's specific pipe
-		 */
-		[[nodiscard]] auto markNotified() -> bool {
-			return _setState(State::Notified) & State::Specific;
-		}
 
-		/**
-		 * @brief Mark the state as sleeping
-		 *
-		 * Set Specific and Global if Needed
-		 *
-		 * @return return the old state
-		 */
-		[[nodiscard]] auto markSleeping() -> uint8_t {
-			return _setState(State::Specific | State::Global);
-		}
-
-		/**
-		 * @brief Check if the oldState was not reading from the specific pipe
-		 *
-		 * @param oldState State retrieved from a State modifying method
-		 *
-		 * @return oldState & Specific
-		 */
-		[[nodiscard]] static auto shouldPrepareSpecificRead(uint8_t oldState) -> bool {
-			return !(oldState & State::Specific);
-		}
-
-		/**
-		 * @brief Check if the oldState was not reading from the global pipe
-		 *
-		 * @param oldState State retrieved from a State modifying method
-		 *
-		 * @return oldState & Global
-		 */
-		[[nodiscard]] static auto shouldPrepareGlobalRead(uint8_t oldState) -> bool {
-			return !(oldState & State::Global);
-		}
-
-		/**
-		 * @brief Check if the oldState is notified
-		 *
-		 * @param oldState State retrieved from a State loading method
-		 *
-		 * @return oldState & Notified
-		 */
-		[[nodiscard]] static auto isNotified(uint8_t oldState) -> bool {
-			return oldState & State::Notified;
-		}
+		bool globalReadInflight = false;
+		std::mutex lock;
+		emper::sleep_strategy::SleeperState s = emper::sleep_strategy::SleeperState::Running;
 
-		/**
-		 * @brief Check if the oldState has specific read
-		 *
-		 * @param oldState State retrieved from a State loading method
-		 *
-		 * @return oldState & Specific
-		 */
-		[[nodiscard]] static auto shouldNotifySpecific(uint8_t oldState) -> bool {
-			return !(oldState & State::Specific);
+		auto markNotified() -> bool {
+			auto oldS = s;
+			s = emper::sleep_strategy::SleeperState::Notified;
+			return oldS == emper::sleep_strategy::SleeperState::Sleeping;
 		}
+		void markSleeping() { s = emper::sleep_strategy::SleeperState::Sleeping; }
+		void markRunning() { s = emper::sleep_strategy::SleeperState::Running; }
 
-		void clearAll() { s.store(0, std::memory_order_release); }
-		void clearGlobal() { _clearState(State::Global); }
-		void clearNotified() { _clearState(State::Notified); }
-		void clearGlobalNotified() { _clearState(State::Global | State::Notified); }
-		void clearSpecificNotified() { _clearState(State::Specific | State::Notified); }
+		auto isNotified() const -> bool { return s == emper::sleep_strategy::SleeperState::Notified; }
+		auto isSleeping() const -> bool { return s == emper::sleep_strategy::SleeperState::Sleeping; }
 	};
 
 	SleepState** sleepStates;
-- 
GitLab