Skip to content
Snippets Groups Projects
Commit 0a3530c8 authored by Florian Schmaus's avatar Florian Schmaus
Browse files

Add WaitFreeCountingSemaphore

parent 105e4078
No related branches found
No related tags found
No related merge requests found
Pipeline #95208 passed
// SPDX-License-Identifier: LGPL-3.0-or-later
// Copyright © 2020-2022 Florian Schmaus
#include "WaitFreeCountingPrivateSemaphore.hpp"
#include <cassert>
#include <limits>
#include <ostream>
#include "Context.hpp"
#include "Debug.hpp"
WaitFreeCountingPrivateSemaphore::WaitFreeCountingPrivateSemaphore()
: WaitFreeCountingPrivateSemaphore(0) {}
WaitFreeCountingPrivateSemaphore::WaitFreeCountingPrivateSemaphore(unsigned int counter)
: requiredSignalCount(counter),
counter(std::numeric_limits<unsigned int>::max()),
blockedContext(nullptr) {}
void WaitFreeCountingPrivateSemaphore::incrementCounterByOne() { requiredSignalCount++; }
void WaitFreeCountingPrivateSemaphore::incrementCounter(unsigned int count) {
requiredSignalCount += count;
}
void WaitFreeCountingPrivateSemaphore::wait() {
unsigned int delta = std::numeric_limits<unsigned int>::max() - requiredSignalCount;
LOGD("wait() counter: " << counter << " delta: " << delta);
if (counter - delta == 0) {
return;
}
Context* blockedContext = Context::getCurrentContext();
block([this, blockedContext, delta] {
assert(blockedContext > (Context*)4096);
// this->blockedContext.store(blockedContext); // MO relaxed possible
this->blockedContext = blockedContext;
unsigned int oldCounter = counter.fetch_sub(delta); // MO release
LOGD("wait() block oldCounter: " << oldCounter << " delta: " << delta);
if (oldCounter == delta) {
unblock(blockedContext);
}
});
}
auto WaitFreeCountingPrivateSemaphore::signalInternal() -> Context* {
unsigned int oldCounter = counter.fetch_sub(1);
assert(oldCounter >= 1);
LOGD("signalInternal(): oldCounter: " << oldCounter);
// If the counter is still non-zero after the decrement, somebody
// else is responsible for scheduling the fiber.
if (oldCounter > 1) {
return nullptr;
}
// Context* context = blockedContext.load();
Context* context = blockedContext;
EMPER_ASSERT_MSG(!context || context > (Context*)4096, "Unexpected context value: " << context);
return context;
}
// SPDX-License-Identifier: LGPL-3.0-or-later
// Copyright © 2020-2022 Florian Schmaus
#pragma once
#include <atomic>
#include <limits>
#include "Common.hpp"
#include "PrivateSemaphore.hpp"
class Context;
/**
* A wait-free counting private semaphore, based on the Nowa approach by Schmaus et. al.
*/
class WaitFreeCountingPrivateSemaphore : public PrivateSemaphore {
private:
unsigned int requiredSignalCount;
ALIGN_TO_CACHE_LINE std::atomic_uint counter;
Context* blockedContext; // std::atomic?
inline auto signalInternal() -> Context* override;
public:
WaitFreeCountingPrivateSemaphore();
explicit WaitFreeCountingPrivateSemaphore(unsigned int counter);
inline auto getCounter() -> unsigned int {
unsigned int delta = std::numeric_limits<unsigned int>::max() - requiredSignalCount;
return counter - delta;
};
void incrementCounterByOne();
void incrementCounter(unsigned int count);
void wait() override;
};
using WFCPS = WaitFreeCountingPrivateSemaphore;
...@@ -32,6 +32,7 @@ emper_cpp_sources = [ ...@@ -32,6 +32,7 @@ emper_cpp_sources = [
'CountingPrivateSemaphore.cpp', 'CountingPrivateSemaphore.cpp',
'Semaphore.cpp', 'Semaphore.cpp',
'StealingMode.cpp', 'StealingMode.cpp',
'WaitFreeCountingPrivateSemaphore.cpp',
'WakeupStrategy.cpp', 'WakeupStrategy.cpp',
'Worker.cpp', 'Worker.cpp',
] ]
......
// SPDX-License-Identifier: LGPL-3.0-or-later
// Copyright © 2022 Florian Schmaus
#include <array>
#include <cstdint>
#include <iostream>
#include <numeric>
#include "Debug.hpp"
#include "WaitFreeCountingPrivateSemaphore.hpp"
#include "emper-config.h"
#include "emper.hpp"
#include "fixtures/assert.hpp"
static const uint64_t DIVISOR = 10;
static const uint64_t SIZE = EMPER_LOG_LEVEL > Info ? 1000 : 1000000;
static void recurseSum(WFCPS& wfcps, uint64_t* res, uint64_t num, uint64_t size) {
if (size == 1) {
*res = num;
wfcps.signal();
return;
}
WFCPS myWfcps(DIVISOR);
std::array<uint64_t, DIVISOR> myRes;
uint64_t newSize = size / DIVISOR;
for (uint64_t i = 0; i < DIVISOR; ++i) {
uint64_t newNum = num + (i * newSize);
uint64_t* newRes = &myRes[i];
async([&, newRes, newNum, newSize] { recurseSum(myWfcps, newRes, newNum, newSize); });
}
myWfcps.wait();
*res = std::reduce(myRes.begin(), myRes.end());
wfcps.signal();
}
static auto kleinerGauß(uint64_t num) -> uint64_t {
uint64_t num_squared = num * num;
return (num_squared + num) / 2;
}
void emperTest() {
WFCPS wfcps(1);
uint64_t res;
recurseSum(wfcps, &res, 1, SIZE);
wfcps.wait();
uint64_t expected = kleinerGauß(SIZE);
std::cout << "res: " << res << " expected: " << expected << std::endl;
ASSERT(res == expected);
}
...@@ -94,6 +94,12 @@ tests = [ ...@@ -94,6 +94,12 @@ tests = [
'test_runner': 'emper', 'test_runner': 'emper',
}, },
{
'source': files('WaitFreeCountingPrivateSemaphoreTest.cpp'),
'description': 'Concurrent test for WaitFreeCountingPrivateSemaphores',
'test_runner': 'emper',
},
{ {
'source': files('SignalPrivateSemaphoreFromAnywhereTest.cpp'), 'source': files('SignalPrivateSemaphoreFromAnywhereTest.cpp'),
'description': 'Simple test for PrivateSemaphore:signalFromAnywhere()', 'description': 'Simple test for PrivateSemaphore:signalFromAnywhere()',
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment