Skip to content

Commit

Permalink
added error_code support and forwarding of stop tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
dietmarkuehl committed Jan 18, 2025
1 parent d584e16 commit 63ed1b6
Show file tree
Hide file tree
Showing 3 changed files with 155 additions and 20 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ default: compile

compile: config
#cmake --build $(BUILDDIR) -j
cmake --workflow --preset=gcc-debug --fresh
cmake --workflow --preset=appleclang-debug

format:
git clang-format main
Expand Down
50 changes: 40 additions & 10 deletions include/beman/lazy/detail/any_scheduler.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,37 +7,41 @@
#include <beman/execution26/execution.hpp>
#include <beman/lazy/detail/poly.hpp>
#include <new>
#include <optional>
#include <utility>

// ----------------------------------------------------------------------------

namespace beman::lazy::detail {

class any_scheduler {
// TODO: add support for forwarding stop_tokens to the type-erased sender
// TODO: other errors than std::exception_ptr should be supported
struct state_base {
virtual ~state_base() = default;
virtual void complete_value() = 0;
virtual void complete_error(std::exception_ptr) = 0;
virtual void complete_error(::std::error_code) = 0;
virtual void complete_error(::std::exception_ptr) = 0;
virtual void complete_stopped() = 0;
virtual ::beman::execution26::inplace_stop_token get_stop_token() = 0;
};

struct inner_state {
struct receiver;
struct env {
state_base* state;
auto query(::beman::execution26::get_stop_token_t) const noexcept { return this->state->get_stop_token(); }
};
struct receiver {
using receiver_concept = ::beman::execution26::receiver_t;
state_base* state;
void set_value() && noexcept { this->state->complete_value(); }
void set_error(std::error_code err) && noexcept { this->state->complete_error(err); }
void set_error(std::exception_ptr ptr) && noexcept { this->state->complete_error(std::move(ptr)); }
template <typename E>
void set_error(E e) {
try {
throw std::move(e);
} catch (...) {
this->state->complete_error(std::current_exception());
}
void set_error(E e) && noexcept {
this->state->complete_error(std::make_exception_ptr(std::move(e)));
}
void set_stopped() && noexcept { this->state->complete_stopped(); }
env get_env() const noexcept { return {this->state}; }
};
static_assert(::beman::execution26::receiver<receiver>);

Expand All @@ -53,7 +57,7 @@ class any_scheduler {
concrete(S&& s, state_base* b) : state(::beman::execution26::connect(std::forward<S>(s), receiver{b})) {}
void start() override { ::beman::execution26::start(state); }
};
::beman::lazy::detail::poly<base, 8u * sizeof(void*)> state;
::beman::lazy::detail::poly<base, 16u * sizeof(void*)> state;
template <::beman::execution26::sender S>
inner_state(S&& s, state_base* b) : state(static_cast<concrete<S>*>(nullptr), std::forward<S>(s), b) {}
void start() { this->state->start(); }
Expand All @@ -62,17 +66,42 @@ class any_scheduler {
template <::beman::execution26::receiver Receiver>
struct state : state_base {
using operation_state_concept = ::beman::execution26::operation_state_t;
struct stopper {
state* st;
void operator()() noexcept {
state* self = this->st;
self->callback.reset();
self->source.request_stop();
}
};
using token_t =
decltype(::beman::execution26::get_stop_token(::beman::execution26::get_env(std::declval<Receiver>())));
using callback_t = ::beman::execution26::stop_callback_for_t<token_t, stopper>;

std::remove_cvref_t<Receiver> receiver;
inner_state s;
::beman::execution26::inplace_stop_source source;
::std::optional<callback_t> callback;

template <::beman::execution26::receiver R, typename PS>
state(R&& r, PS& ps) : receiver(std::forward<R>(r)), s(ps->connect(this)) {}
void start() & noexcept { this->s.start(); }
void complete_value() override { ::beman::execution26::set_value(std::move(this->receiver)); }
void complete_error(std::error_code err) override {
::beman::execution26::set_error(std::move(receiver), err);
}
void complete_error(std::exception_ptr ptr) override {
::beman::execution26::set_error(std::move(receiver), std::move(ptr));
}
void complete_stopped() override { ::beman::execution26::set_stopped(std::move(this->receiver)); }
::beman::execution26::inplace_stop_token get_stop_token() override {
if (not this->callback) {
this->callback.emplace(
::beman::execution26::get_stop_token(::beman::execution26::get_env(this->receiver)),
stopper{this});
}
return this->source.get_token();
}
};

class sender;
Expand Down Expand Up @@ -123,6 +152,7 @@ class any_scheduler {
using sender_concept = ::beman::execution26::sender_t;
using completion_signatures =
::beman::execution26::completion_signatures<::beman::execution26::set_value_t(),
::beman::execution26::set_error_t(std::error_code),
::beman::execution26::set_error_t(std::exception_ptr),
::beman::execution26::set_stopped_t()>;

Expand Down
123 changes: 114 additions & 9 deletions tests/beman/lazy/any_scheduler.test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
#include <beman/lazy/detail/inline_scheduler.hpp>
#include <beman/execution26/execution.hpp>
#include <atomic>
#include <exception>
#include <system_error>
#include <thread>
#include <condition_variable>
#include <mutex>
Expand All @@ -20,6 +22,7 @@ namespace ly = beman::lazy;

namespace {
struct thread_context {
enum class complete { success, failure, exception, never };
struct base {
base* next{};
virtual ~base() = default;
Expand Down Expand Up @@ -63,18 +66,46 @@ struct thread_context {
struct scheduler {
using scheduler_concept = ex::scheduler_t;
thread_context* context;
complete cmpl{complete::success};
bool operator==(const scheduler&) const = default;

template <ex::receiver Receiver>
struct state : base {
struct stopper {
state* st;
void operator()() noexcept {
auto self{this->st};
self->callback.reset();
ex::set_stopped(std::move(self->receiver));
}
};
using operation_state_concept = ex::operation_state_t;
using token_t = decltype(ex::get_stop_token(ex::get_env(std::declval<Receiver>())));
using callback_t = ex::stop_callback_for_t<token_t, stopper>;

thread_context* ctxt;
std::remove_cvref_t<Receiver> receiver;
thread_context::complete cmpl;
std::optional<callback_t> callback;

template <typename R>
state(auto c, R&& r) : ctxt(c), receiver(std::forward<R>(r)) {}
void start() & noexcept { this->ctxt->enqueue(this); }
void complete() override { ex::set_value(std::move(this->receiver)); }
state(auto c, R&& r, thread_context::complete cm) : ctxt(c), receiver(std::forward<R>(r)), cmpl(cm) {}
void start() & noexcept {
callback.emplace(ex::get_stop_token(ex::get_env(this->receiver)), stopper{this});
if (cmpl != thread_context::complete::never)
this->ctxt->enqueue(this);
}
void complete() override {
this->callback.reset();
if (this->cmpl == thread_context::complete::success)
ex::set_value(std::move(this->receiver));
else if (this->cmpl == thread_context::complete::failure)
ex::set_error(std::move(this->receiver), std::make_error_code(std::errc::address_in_use));
else
ex::set_error(
std::move(this->receiver),
std::make_exception_ptr(std::system_error(std::make_error_code(std::errc::address_in_use))));
}
};
struct env {
thread_context* ctxt;
Expand All @@ -83,25 +114,27 @@ struct thread_context {
}
};
struct sender {
using sender_concept = ex::sender_t;
using completion_signatures = ex::completion_signatures<ex::set_value_t()>;
using sender_concept = ex::sender_t;
using completion_signatures =
ex::completion_signatures<ex::set_value_t(), ex::set_error_t(std::error_code)>;

thread_context* ctxt;
thread_context* ctxt;
thread_context::complete cmpl;

template <ex::receiver Receiver>
auto connect(Receiver&& receiver) {
static_assert(ex::operation_state<state<Receiver>>);
return state<Receiver>(this->ctxt, std::forward<Receiver>(receiver));
return state<Receiver>(this->ctxt, std::forward<Receiver>(receiver), this->cmpl);
}
env get_env() const noexcept { return {this->ctxt}; }
};
static_assert(ex::sender<sender>);

sender schedule() noexcept { return sender{this->context}; }
sender schedule() noexcept { return sender{this->context, this->cmpl}; }
};
static_assert(ex::scheduler<scheduler>);

scheduler get_scheduler() { return scheduler{this}; }
scheduler get_scheduler(complete cmpl = complete::success) { return scheduler{this, cmpl}; }
void stop() {
{
std::lock_guard cerberus(this->mutex);
Expand All @@ -110,6 +143,24 @@ struct thread_context {
this->condition.notify_one();
}
};

enum class stop_result { none, success, failure, stopped };
struct stop_env {
ex::inplace_stop_token token;
auto query(ex::get_stop_token_t) const noexcept { return this->token; }
};
struct stop_receiver {
using receiver_concept = ex::receiver_t;
ex::inplace_stop_token token;
stop_result& result;
stop_env get_env() const noexcept { return {this->token}; }

void set_value(auto&&...) && noexcept { this->result = stop_result::success; }
void set_error(auto&&) && noexcept { this->result = stop_result::failure; }
void set_stopped() && noexcept { this->result = stop_result::stopped; }
};
static_assert(ex::receiver<stop_receiver>);

} // namespace

// ----------------------------------------------------------------------------
Expand Down Expand Up @@ -154,4 +205,58 @@ int main() {
ex::then([&id1]() { assert(id1 == std::this_thread::get_id()); }));
ex::sync_wait(ex::schedule(ly::detail::any_scheduler(sched2)) |
ex::then([&id2]() { assert(id2 == std::this_thread::get_id()); }));

{
bool success{false};
bool failed{false};
bool exception{false};
ex::sync_wait(ex::schedule(ctxt1.get_scheduler(thread_context::complete::failure)) |
ex::then([&success] { success = true; }) | ex::upon_error([&failed, &exception](auto err) {
if constexpr (std::same_as<decltype(err), std::error_code>)
failed = true;
else if constexpr (std::same_as<decltype(err), std::exception_ptr>)
exception = true;
}));
assert(not success);
assert(failed);
assert(not exception);
}
{
bool success{false};
bool failed{false};
bool exception{false};
ex::sync_wait(ex::schedule(ctxt1.get_scheduler(thread_context::complete::exception)) |
ex::then([&success] { success = true; }) | ex::upon_error([&failed, &exception](auto err) {
if constexpr (std::same_as<decltype(err), std::error_code>)
failed = true;
else if constexpr (std::same_as<decltype(err), std::exception_ptr>)
exception = true;
}));
assert(not success);
assert(not failed);
assert(exception);
}
{
ex::inplace_stop_source source;
stop_result result{stop_result::none};
auto state{ex::connect(ex::schedule(ctxt1.get_scheduler(thread_context::complete::never)),
stop_receiver(source.get_token(), result))};
assert(result == stop_result::none);
ex::start(state);
assert(result == stop_result::none);
source.request_stop();
assert(result == stop_result::stopped);
}
{
ex::inplace_stop_source source;
stop_result result{stop_result::none};
auto state{
ex::connect(ex::schedule(ly::detail::any_scheduler(ctxt1.get_scheduler(thread_context::complete::never))),
stop_receiver(source.get_token(), result))};
assert(result == stop_result::none);
ex::start(state);
assert(result == stop_result::none);
source.request_stop();
assert(result == stop_result::stopped);
}
}

0 comments on commit 63ed1b6

Please sign in to comment.