diff --git a/emper/BinaryPrivateSemaphore.cpp b/emper/BinaryPrivateSemaphore.cpp index 4f57db12c92954ab54a39874182fd09eed4461ff..a90f1c64aa876bdf9c30fc5b0153886188bb1774 100644 --- a/emper/BinaryPrivateSemaphore.cpp +++ b/emper/BinaryPrivateSemaphore.cpp @@ -18,8 +18,9 @@ void BinaryPrivateSemaphore::wait() { // If the binary signal was not yet signaled, then we need to // block here. That is, we perform a full context switch. + blockedContext = ContextManager::getCurrentContext(); block([this] { - // N.B. block() will have set the blockedContext at this point. + // N.B. wait() will have set the blockedContext at this point. BpsState newState(blocked); BpsState previousState = bpsState.exchange(newState); if (previousState.state == signaled) { @@ -30,12 +31,12 @@ void BinaryPrivateSemaphore::wait() { newState.state = signaled; bpsState.store(newState, std::memory_order_relaxed); #endif - unblock(); + unblock(blockedContext); } }); } -bool BinaryPrivateSemaphore::signalInternal() { +Context* BinaryPrivateSemaphore::signalInternal() { BpsState currentState = bpsState.load(); assert(currentState.state != signaled); @@ -51,7 +52,7 @@ bool BinaryPrivateSemaphore::signalInternal() { bpsState.store(newState); #endif LOGDD("unblock in fast path"); - return true; + return blockedContext; } BpsState newState(signaled); @@ -61,9 +62,9 @@ bool BinaryPrivateSemaphore::signalInternal() { // The state is still 'initial' which means the other // fiber hasn't called wait() yet. LOGDD("no unblock in slow path"); - return false; + return nullptr; } LOGDD("unblock in slow path"); - return true; + return blockedContext; } diff --git a/emper/BinaryPrivateSemaphore.hpp b/emper/BinaryPrivateSemaphore.hpp index 9d2785bad5ffb26bdde150dfaae0f569198078ac..2f44cafb35f09e17634cfbc2865a8a4dd6cee49b 100644 --- a/emper/BinaryPrivateSemaphore.hpp +++ b/emper/BinaryPrivateSemaphore.hpp @@ -24,8 +24,10 @@ private: std::atomic<BpsState> bpsState; + Context* blockedContext; + protected: - bool signalInternal() override; + Context* signalInternal() override; public: BinaryPrivateSemaphore() { diff --git a/emper/CountingPrivateSemaphore.cpp b/emper/CountingPrivateSemaphore.cpp index 544522af686e5bcb5d4bc2a6a7d9652cacdfb98a..a4c66a4b89f6b502def9bdba037ab0e3883655b3 100644 --- a/emper/CountingPrivateSemaphore.cpp +++ b/emper/CountingPrivateSemaphore.cpp @@ -12,11 +12,41 @@ void CountingPrivateSemaphore::incrementCounterByOne() { counter++; } -bool CountingPrivateSemaphore::signalInternal() { +void CountingPrivateSemaphore::incrementCounter(unsigned int count) { + counter+=count; +} + +void CountingPrivateSemaphore::wait() { + if (counter > 0) { + Context* blockedContext = ContextManager::getCurrentContext(); + block([this, blockedContext] { + this->blockedContext = blockedContext; + if (this->getCounter() == 0) { + Context* readyContext = this->blockedContext.exchange(nullptr); + if (readyContext != nullptr) { + unblock(readyContext); + } + } + }); + } +} + +Context* CountingPrivateSemaphore::signalInternal() { unsigned int oldCounter = counter.fetch_sub(1); assert(oldCounter >= 1); - if (oldCounter == 1) { - return BinaryPrivateSemaphore::signalInternal(); + + // If the counter is still non-zero after the decrement, somebody + // else is responsible for scheduling the fiber. + if (oldCounter > 1) + return nullptr; + + + if (blockedContext.load() != nullptr && counter == 0) { + // Try to swap out a blocked context, it is fine if this + // returns nullptr. In this case the block() function will + // have won the race. + return blockedContext.exchange(nullptr); } - return false; + + return nullptr; } diff --git a/emper/CountingPrivateSemaphore.hpp b/emper/CountingPrivateSemaphore.hpp index e475cf94ce74a424db41b24b9d9f63e481b84c9e..40a994c3be02c0762d86c546596fe346bb4c7a49 100644 --- a/emper/CountingPrivateSemaphore.hpp +++ b/emper/CountingPrivateSemaphore.hpp @@ -2,17 +2,18 @@ #include <atomic> -#include "BinaryPrivateSemaphore.hpp" +#include "PrivateSemaphore.hpp" /** - * A counting private semaphore. Uses wait() from the binary private semaphore super-class. + * A counting private semaphore. */ -class CountingPrivateSemaphore : public BinaryPrivateSemaphore { +class CountingPrivateSemaphore : public PrivateSemaphore { private: std::atomic_uint counter; + std::atomic<Context*> blockedContext; - inline bool signalInternal() override; + inline Context* signalInternal() override; public: CountingPrivateSemaphore(); @@ -24,6 +25,9 @@ public: void incrementCounterByOne(); + void incrementCounter(unsigned int count); + + virtual void wait() override; }; typedef CountingPrivateSemaphore CPS; diff --git a/emper/PrivateSemaphore.hpp b/emper/PrivateSemaphore.hpp index 7c4adccc520de2b82d0d36fb759ddba4bc0ca9c6..f6260205d4aa0a60bca84a95b448fc77db449829 100644 --- a/emper/PrivateSemaphore.hpp +++ b/emper/PrivateSemaphore.hpp @@ -11,8 +11,6 @@ protected: Runtime& runtime; ContextManager& contextManager; - Context* blockedContext; - // cppcheck-suppress uninitMemberVar PrivateSemaphore() : runtime(*Runtime::getRuntime()) , contextManager(runtime.getContextManager()) @@ -21,39 +19,38 @@ protected: } void block(func_t freshContextHook) { - blockedContext = ContextManager::getCurrentContext(); - LOGD("block() blockedContext is " << blockedContext); + LOGD("block() blockedContext is " << ContextManager::getCurrentContext()); contextManager.saveAndStartNew(freshContextHook); } - void unblock() { + void unblock(Context* context) { + assert(context != nullptr); // cppcheck-suppress unsafeClassCanLeak - Fiber* unblockFiber = Fiber::from([this]() { - assert(blockedContext != nullptr); - contextManager.discardAndResume(blockedContext); + Fiber* unblockFiber = Fiber::from([this, context]() { + contextManager.discardAndResume(context); }); runtime.schedule(*unblockFiber); } - [[noreturn]] void unblockAndExit() { - contextManager.discardAndResume(blockedContext); + [[noreturn]] void unblockAndExit(Context* context) { + contextManager.discardAndResume(context); } - virtual bool signalInternal() = 0; + virtual Context* signalInternal() = 0; public: virtual void wait() = 0; void signal() { - if (signalInternal()) { - unblock(); + if (Context* readyContext = signalInternal()) { + unblock(readyContext); } } void signalAndExit() { - if (signalInternal()) { - unblockAndExit(); + if (Context* readyContext = signalInternal()) { + unblockAndExit(readyContext); } } };