From 827b79aec1066528f35ab403e72830baf0fea7ab Mon Sep 17 00:00:00 2001
From: Florian Fischer <florian.fl.fischer@fau.de>
Date: Thu, 4 Feb 2021 16:03:22 +0100
Subject: [PATCH] [IO] add callback support

Futures can have a registered Callback of type
std::function<void(const int32_t&)> which gets called in a new
Fiber with the result of the IO Request.
Note that the first completion triggers the execution of the callback;
partial completion support must be implemented manually in the callback.

Callbacks are stored in a heap allocated std::function on registration
and are freed by the new Fiber after the callback has returned.

The Future with a registered Callback is not referenced in any way in
the IO subsystem and therefore can be dropped after being submitted.
This also means that a Future with a registered callback will not be
signaled by the IO subsystem on completion.
If signaling is desired one must implement it manually in the registered
callback.
---
 emper/io/Future.hpp          | 38 ++++++++++++--
 emper/io/IoContext.cpp       | 82 +++++++++++++++++++++---------
 emper/lib/TaggedPtr.hpp      | 98 ++++++++++++++++++++++++++++++++++++
 tests/FutureCallbackTest.cpp | 30 +++++++++++
 tests/meson.build            |  7 +++
 5 files changed, 227 insertions(+), 28 deletions(-)
 create mode 100644 emper/lib/TaggedPtr.hpp
 create mode 100644 tests/FutureCallbackTest.cpp

diff --git a/emper/io/Future.hpp b/emper/io/Future.hpp
index bbf29d5d..f94e1b8b 100644
--- a/emper/io/Future.hpp
+++ b/emper/io/Future.hpp
@@ -9,13 +9,16 @@
 #include <cstddef>	// for size_t
 #include <cstdint>	// for int32_t, uint8_t
 #include <cstdlib>	// for abort
+#include <functional>
 #include <ostream>	// for operator<<, ostream, basic_ost...
+#include <utility>
 
 #include "BinaryPrivateSemaphore.hpp"	 // for BPS
 #include "CallerEnvironment.hpp"			 // for CallerEnvironment, ANYWHERE
-#include "Debug.hpp"									 // for LOGD, LogSubsystem, LogSubsyst...
-#include "Emper.hpp"									 // for DEBUG
-#include "io/Operation.hpp"						 // for Operation, operator<<, Operati...
+#include "Common.hpp"
+#include "Debug.hpp"				 // for LOGD, LogSubsystem, LogSubsyst...
+#include "Emper.hpp"				 // for DEBUG
+#include "io/Operation.hpp"	 // for Operation, operator<<, Operati...
 
 struct __kernel_timespec;
 struct io_uring_sqe;
@@ -30,6 +33,9 @@ class Future : public Logger<LogSubsystem::IO> {
 	friend class IoContext;
 	friend class Stats;
 
+ public:
+	using Callback = std::function<void(const int32_t&)>;
+
  protected:
 	using State = struct State {
 		uint8_t submitted : 1 = 0;	/*!< An sqe for this Future was prepared */
@@ -66,6 +72,17 @@ class Future : public Logger<LogSubsystem::IO> {
 	// this and will be linked to this IO request using IOSQE_IO_LINK
 	Future* dependency = nullptr;
 
+	/**
+	 * Callback function which is called by a newly scheduled Fiber on completion.
+	 * A new Callback object is heap allocated in setCallback and freed by the
+	 * executing Fiber after the callback has returned.
+	 * The heap allocation and thus the Callback pointer is required to decouple the lifetimes
+	 * of the Future used for submission and the Callback, which must live for the
+	 * time the request is in the io_uring.
+	 * The callback is called with the value causing the completion.
+	 */
+	Callback* callback = nullptr;
+
 	virtual void prepareSqe(io_uring_sqe* sqe) = 0;
 
 	void setCompletion(int32_t res) {
@@ -100,7 +117,7 @@ class Future : public Logger<LogSubsystem::IO> {
 
  public:
 	virtual ~Future() {
-		if (isForgotten()) {
+		if (isForgotten() || callback) {
 			return;
 		}
 
@@ -181,6 +198,19 @@ class Future : public Logger<LogSubsystem::IO> {
 		this->dependency = &dependency;
 	}
 
+	/**
+	 * @brief register a callback which is executed in a new Fiber on completion
+	 *        A previously registered callback is deleted and replaced by the new one.
+	 * @param callback Function which is called with the value causing the completion
+	 */
+	inline void setCallback(Callback callback) {
+		if (unlikely(this->callback)) {
+			delete this->callback;
+		}
+
+		this->callback = new Callback(std::move(callback));
+	}
+
 	/*
 	 * @brief submit Future for asynchronous completion to the workers IoContext
 	 */
diff --git a/emper/io/IoContext.cpp b/emper/io/IoContext.cpp
index e28262b8..b7c1c599 100644
--- a/emper/io/IoContext.cpp
+++ b/emper/io/IoContext.cpp
@@ -18,26 +18,38 @@
 #include "Common.hpp"							// for unlikely, DIE_MSG_ERRNO, DIE_MSG
 #include "Debug.hpp"							// for LOGD
 #include "Emper.hpp"							// for DEBUG, IO_URING_SQPOLL
-#include "Runtime.hpp"						// for Runtime
-#include "io/Future.hpp"					// for Future, operator<<, Future::State
-#include "io/Stats.hpp"						// for Stats, nanoseconds
+#include "Fiber.hpp"
+#include "Runtime.hpp"		// for Runtime
+#include "io/Future.hpp"	// for Future, operator<<, Future::State
+#include "io/Stats.hpp"		// for Stats, nanoseconds
+#include "lib/TaggedPtr.hpp"
 
 #ifndef EMPER_LOG_OFF
 #include <ostream>	// for basic_osteram::operator<<, operator<<
 #endif
 
-// use the most significant bit of a pointer to differ between an IoContext and a Future
-// in the global IoContext's CQ
-static const uintptr_t IOCONTEXT_TAG = 1L << (sizeof(size_t) * 8 - 1);
-static const uintptr_t IOCONTEXT_TAG_MASK = IOCONTEXT_TAG - 1;
+using emper::lib::TaggedPtr;
 
-static inline auto isIoContext(uintptr_t ptr) -> bool { return (ptr & IOCONTEXT_TAG) != 0; }
+namespace emper::io {
+
+enum class PointerTags : uint16_t { Future, IoContext, Callback };
 
-static inline auto stripIoContextTag(uintptr_t ptr) -> IoContext * {
-	return reinterpret_cast<IoContext *>(ptr & IOCONTEXT_TAG_MASK);
+static inline auto castIfFuture(TaggedPtr ptr) -> Future * {
+	if (ptr.getTag() == static_cast<uint16_t>(PointerTags::Future)) {
+		return ptr.getPtr<Future>();  // pointer part without tag and mark bit
+	}
+
+	return nullptr;  // ptr is not tagged as a Future
+}
+
+static inline auto castIfCallback(TaggedPtr ptr) -> Future::Callback * {
+	if (ptr.getTag() == static_cast<uint16_t>(PointerTags::Callback)) {
+		return ptr.getPtr<Future::Callback>();  // pointer part without tag and mark bit
+	}
+
+	return nullptr;  // ptr is not tagged as a Callback
 }
 
-namespace emper::io {
 thread_local IoContext *IoContext::workerIo = nullptr;
 
 pthread_t IoContext::globalCompleter;
@@ -65,9 +77,15 @@ auto IoContext::prepareFutureChain(Future &future, unsigned chain_length) -> uns
 
 	future.prepareSqe(sqe);
 
-	// Someone wants to be notified about the completion of this Future
-	if (!future.isForgotten()) {
-		io_uring_sqe_set_data(sqe, &future);
+	// we should start a new Fiber executing callback on completion
+	if (future.callback) {
+		LOGD("prepare " << future << " Callback " << future.callback);
+		io_uring_sqe_set_data(sqe,
+													TaggedPtr(future.callback, static_cast<uint16_t>(PointerTags::Callback)));
+
+		// Someone wants to be notified about the completion of this Future
+	} else if (!future.isForgotten()) {
+		io_uring_sqe_set_data(sqe, TaggedPtr(&future, static_cast<uint16_t>(PointerTags::Future)));
 	}
 
 	future.state.submitted = true;
@@ -185,13 +203,31 @@ void IoContext::reapCompletions() {
 	io_uring_for_each_cqe(&ring, head, cqe) {
 		count++;
 
-		auto *future = reinterpret_cast<Future *>(io_uring_cqe_get_data(cqe));
+		TaggedPtr tptr(io_uring_cqe_get_data(cqe));
 
 		// Got a CQE for a forgotten Future
-		if (!future) {
+		if (!tptr) {
 			continue;
 		}
 
+		auto *callback = castIfCallback(tptr);
+		if (callback) {
+			LOGD("Schedule new callback fiber for " << callback);
+			auto *callbackFiber = Fiber::from([&c = *callback, res = cqe->res] {
+				c(res);
+				delete &c;
+			});
+			Runtime *runtime = Runtime::getRuntime();
+			if constexpr (callerEnvironment == CallerEnvironment::EMPER) {
+				runtime->schedule(*callbackFiber);
+			} else {
+				runtime->scheduleFromAnywhere(*callbackFiber);
+			}
+			continue;
+		}
+
+		auto *future = tptr.getPtr<Future>();
+
 		// assert that the future was previously in the uringFutureSet
 		assert(uringFutureSet.erase(future) > 0);
 
@@ -255,12 +291,11 @@ auto IoContext::globalCompleterFunc(void *arg) -> void * {
 			perror("io_uring_wait_cqe");
 		}
 
-		auto data = (uintptr_t)io_uring_cqe_get_data(cqe);
+		TaggedPtr tptr(io_uring_cqe_get_data(cqe));
+		auto *future = castIfFuture(tptr);
 
 		// The cqe is for a completed Future
-		if (unlikely(!isIoContext(data))) {
-			auto *future = reinterpret_cast<Future *>(data);
-
+		if (unlikely(future)) {
 			uint32_t res = cqe->res;
 			io_uring_cqe_seen(&io.ring, cqe);
 
@@ -271,7 +306,7 @@ auto IoContext::globalCompleterFunc(void *arg) -> void * {
 
 		// The cqe is for a IoContext.eventfd read
 		//  -> there are completions on this worker IoContext
-		auto *worker_io = stripIoContextTag(data);
+		auto *worker_io = tptr.getPtr<IoContext>();
 		assert(worker_io);
 
 		io_uring_cqe_seen(&io.ring, cqe);
@@ -283,7 +318,7 @@ auto IoContext::globalCompleterFunc(void *arg) -> void * {
 
 		io_uring_prep_read(sqe, worker_io->ring_eventfd, &worker_io->ring_eventfd_readbuf,
 											 sizeof(worker_io->ring_eventfd_readbuf), 0);
-		io_uring_sqe_set_data(sqe, reinterpret_cast<void *>(data));
+		io_uring_sqe_set_data(sqe, tptr);
 
 		submitted = io_uring_submit(&io.ring);
 
@@ -359,8 +394,7 @@ void IoContext::submit_efd() {
 	assert(sqe);
 
 	io_uring_prep_read(sqe, ring_eventfd, &ring_eventfd_readbuf, sizeof(ring_eventfd_readbuf), 0);
-	auto *tagged_io_ptr = reinterpret_cast<void *>((uintptr_t)this | IOCONTEXT_TAG);
-	io_uring_sqe_set_data(sqe, tagged_io_ptr);
+	io_uring_sqe_set_data(sqe, TaggedPtr(this, static_cast<uint16_t>(PointerTags::IoContext)));
 
 	// The sqe we prepared will be submitted to io_uring when the globalCompleter starts.
 }
diff --git a/emper/lib/TaggedPtr.hpp b/emper/lib/TaggedPtr.hpp
new file mode 100644
index 00000000..fcbeb1d3
--- /dev/null
+++ b/emper/lib/TaggedPtr.hpp
@@ -0,0 +1,98 @@
+// SPDX-License-Identifier: LGPL-3.0-or-later
+// Copyright © 2021 Florian Fischer
+#pragma once
+
+#include <climits>
+#include <cstdint>
+
+#define TPTR_POINTER_BITS 48
+#define TPTR_POINTER_MASK ((1UL << TPTR_POINTER_BITS) - 1)
+
+static_assert(sizeof(uintptr_t) * CHAR_BIT - TPTR_POINTER_BITS == sizeof(uint16_t) * CHAR_BIT,
+							"Tagged pointer assumptions are broken");
+
+namespace emper::lib {
+/**
+ * @brief pointer which uses the 16 unused bits in x86-64 address space to store a uint16_t
+ *
+ * Additionally, the least significant bit, which will likely be 0 because of alignment, can be
+ * used as a mark.
+ */
+class TaggedPtr {
+ private:
+	uintptr_t tptr = 0;
+
+ public:
+	template <typename T>
+	TaggedPtr(T* ptr, uint16_t tag = 0, bool marked = false)
+			: tptr(reinterpret_cast<uintptr_t>(ptr)) {
+		setTag(tag);
+		setMark(marked);
+	}
+
+	TaggedPtr(void* ptr) : tptr(reinterpret_cast<uintptr_t>(ptr)) {}  // raw value is kept as-is
+	TaggedPtr(uintptr_t ptr) : tptr(ptr) {}  // raw value is kept as-is
+
+	/**
+	 * @brief extract the 48-bit pointer part
+	 *
+	 * @return ptr The actual pointer part of tptr with the least significant bit cleared
+	 */
+	template <typename T>
+	[[nodiscard]] inline auto getPtr() const -> T* {
+		// mask out the 16-bit tag as well as the least significant (mark) bit
+		return reinterpret_cast<T*>(tptr & (TPTR_POINTER_MASK - 1));
+	}
+
+	/**
+	 * @brief extract the 16-bit tag part
+	 *
+	 * @return tag The 16-bit tag stored in tptr
+	 */
+	[[nodiscard]] inline auto getTag() const -> uint16_t {
+		return static_cast<uint16_t>(tptr >> TPTR_POINTER_BITS);
+	}
+	/**
+	 * @brief update the 16-bit tag part
+	 *
+	 * @param tag The tag which should be stored in the tagged ptr
+	 */
+	inline void setTag(uint16_t tag) {
+		uintptr_t stripped_ptr = tptr & TPTR_POINTER_MASK;
+		tptr = stripped_ptr | ((uintptr_t)tag << TPTR_POINTER_BITS);
+	}
+
+	/**
+	 * @brief check if the least significant bit is set
+	 *
+	 * @return marked True if the least significant bit is set, False otherwise
+	 */
+	[[nodiscard]] inline auto isMarked() const -> bool { return (tptr & 1) == 1; }
+
+	/**
+	 * @brief set or unset the least significant bit
+	 *
+	 * @param mark True if the least significant bit should be set, False if it should be unset
+	 */
+	inline void setMark(bool mark) { tptr = mark ? (tptr | 1) : tptr & (~1); }
+
+	/**
+	 * @brief unfold both parts of a tagged pointer
+	 *
+	 * @param ptr A pointer where the pointer part should be stored
+	 * @return tag The 16-bit tag stored in tptr
+	 *
+	 */
+	template <typename T>
+	inline auto unfold(T** ptr) const -> uint16_t {
+		*ptr = getPtr<T>();
+		return getTag();
+	}
+
+	inline operator uintptr_t() const { return tptr; }
+
+	inline operator void*() const { return reinterpret_cast<void*>(tptr); }
+
+	inline operator bool() const { return tptr != 0; }
+};
+}	 // namespace emper::lib
diff --git a/tests/FutureCallbackTest.cpp b/tests/FutureCallbackTest.cpp
new file mode 100644
index 00000000..066efe56
--- /dev/null
+++ b/tests/FutureCallbackTest.cpp
@@ -0,0 +1,30 @@
+// SPDX-License-Identifier: LGPL-3.0-or-later
+// Copyright © 2020-2021 Florian Fischer
+#include <liburing.h>
+
+#include <cassert>	// for assert
+#include <cerrno>		// for ETIME
+#include <cstdint>	// for int32_t
+
+#include "BinaryPrivateSemaphore.hpp"
+#include "io/Future.hpp"	// for AlarmFuture
+
+using emper::io::AlarmFuture;
+using emper::io::Future;
+
+void callback(int32_t res, BPS& bps) {
+	assert(res == -ETIME);  // the callback must be invoked with -ETIME as result
+	bps.signal();  // notify emperTest, which waits on bps
+}
+
+void emperTest() {
+	struct __kernel_timespec ts = {.tv_sec = 1, .tv_nsec = 0};
+	AlarmFuture alarm(ts);
+	BPS bps;
+
+	alarm.setCallback(Future::Callback([&bps](int32_t res) { callback(res, bps); }));
+	alarm.submit();
+
+	// wait until the callback has signaled; bps (captured by reference) stays alive until then
+	bps.wait();
+}
diff --git a/tests/meson.build b/tests/meson.build
index a08c6b54..b845e886 100644
--- a/tests/meson.build
+++ b/tests/meson.build
@@ -110,6 +110,13 @@ tests = {
 			'test_runner': 'io',
 		  },
 
+		  'FutureCallbackTest.cpp':
+		  {
+			'description': 'Test Future callback',
+			'test_suite': 'io',
+			'test_runner': 'io',
+		  },
+
 		  'TimeoutWrapperTest.cpp':
 		  {
 			'description': 'Test TimeoutWrapper object based IO request timeouts',
-- 
GitLab