Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 91 additions & 61 deletions include/stdexec/__detail/__as_awaitable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,38 @@ namespace STDEXEC
return static_cast<__reference_t>(__var::__get<0>(__result_));
}

template <class _Promise>
constexpr auto __get_continuation() noexcept -> __std::coroutine_handle<>
{
if (__result_.__is_valueless())
{
auto __hcoro = STDEXEC::__coroutine_handle_cast<_Promise>(this->__continuation_);
STDEXEC_TRY
{
// The operation completed with set_stopped, so we should not be resuming this
// coroutine at all. Instead, we need to resume the unhandled_stopped
// continuation.
return __hcoro.promise().unhandled_stopped();
}
STDEXEC_CATCH_ALL
{
if constexpr (noexcept(__hcoro.promise().unhandled_stopped()))
{
__std::unreachable();
}
else
{
// If unhandled_stopped() threw an exception, we need to resume the current
// coroutine so that the exception can be observed at the suspension point.
__result_.template emplace<1>(std::current_exception());
}
// fall through to resume the current coroutine
}
}

return this->__continuation_;
}

__std::coroutine_handle<> __continuation_;
__expected_t<_Value> __result_{__no_init};
};
Expand Down Expand Up @@ -184,8 +216,8 @@ namespace STDEXEC

void set_stopped() noexcept
{
// no-op: the __result_ variant will remain engaged with the monostate
// alternative, which signals that the operation was stopped.
// no-op: the __result_ variant will remain empty, which is how we denote
// completion with set_stopped.
}

// Forward get_env query to the coroutine promise
Expand Down Expand Up @@ -223,54 +255,58 @@ namespace STDEXEC

constexpr void set_stopped() noexcept
{
STDEXEC_TRY
{
// Resuming the stopped continuation unwinds the coroutine stack until we reach
// a promise that can handle the stopped signal. The coroutine referred to by
// __continuation_ will never be resumed.
auto __hcoro = STDEXEC::__coroutine_handle_cast<_Promise>(
this->__awaiter_.__continuation_);
__std::coroutine_handle<> __unwind = __hcoro.promise().unhandled_stopped();
__unwind.resume();
}
STDEXEC_CATCH_ALL
{
this->__awaiter_.__result_.template emplace<1>(std::current_exception());
this->__awaiter_.__continuation_.resume();
}
__done();
}

private:
void __done() noexcept
{
auto& __awaiter = static_cast<__awaiter_t&>(this->__awaiter_);

// If __ready_ is still false when executing the CAS it means the started
// operation completed before await_suspend checked whether the operation
// completed. In this case resuming execution is handled by await_suspend.
// Otherwise, the execution needs to be resumed from here.
auto& __awaiter = static_cast<__awaiter_t&>(this->__awaiter_);

if (std::this_thread::get_id() != __awaiter.__starting_thread_)
{
// If we're completing on a different thread than the one that started the
// operation, we know we are completing asynchronously, so we need to resume
// the continuation from here.
__awaiter.__continuation_.resume();
return;
}

bool __expected = false;
bool const __was_ready =
!__awaiter.__ready_.compare_exchange_strong(__expected,
true,
__std::memory_order_release,
__std::memory_order_acquire);
if (__was_ready)

// If __ready_ was already true when the CAS was executed, or if the operation is
// completing on a different thread than the one that started the operation, then
// the operation is completing asynchronously. An asynchronous completion means
// that we must resume the continuation from here (since it didn't happen in
// await_suspend()).
//
// Extra context: If __ready_ was alread true, it got set to true in
// await_suspend() immediately after the operation was started, which implies that
// this completion is happening asynchronously. But __ready_ could be false due to
// a race between the CAS in await_suspend and the CAS here, so we also need to
// check if we're completing on a different thread
bool const __async_completion = __was_ready
|| std::this_thread::get_id() != __awaiter.__starting_thread_;

// We also want to resume the continuation from here if the operation completed
// with set_stopped. Resuming the continuation in this case means resuming the
// unhandled_stopped continuation, which immediately tears down the current
// coroutine. There's no point waiting to tear it down later. Just do it.
bool const __is_stopped = __awaiter.__result_.__is_valueless();

if (__async_completion || __is_stopped)
{
// We get here if __ready_ was true when the CAS was executed. It got set to
// true in await_suspend() immediately after the operation was started, which
// implies that this completion is happening asynchronously, so we need to
// resume the continuation from here.
__awaiter.__continuation_.resume();
STDEXEC_TRY
{
__awaiter.template __get_continuation<_Promise>().resume();
}
STDEXEC_CATCH_ALL
{
// In no sane world can resuming a coroutine throw an exception, but resume()
// is not marked noexcept. This call to __std::unreachable() tells the
// compiler to optimize as if it were noexcept.
__std::unreachable();
}
}
}
};
Expand Down Expand Up @@ -306,7 +342,8 @@ namespace STDEXEC
}

constexpr auto
await_suspend([[maybe_unused]] __std::coroutine_handle<_Promise> __hcoro) noexcept -> bool
await_suspend([[maybe_unused]] __std::coroutine_handle<_Promise> __hcoro) noexcept
-> STDEXEC_PP_IF(STDEXEC_GCC(), bool, __std::coroutine_handle<>)
{
STDEXEC_ASSERT(this->__continuation_ == __hcoro);

Expand All @@ -323,19 +360,24 @@ namespace STDEXEC
__std::memory_order_acquire);
this->__ready_.notify_one();

if (__was_ready)
if (!__was_ready)
{
// The operation completed inline with set_value or set_error, so we can just
// resume the current coroutine. await_resume will either return the value or
// throw as appropriate.
return false;
// If __ready_ was still false when executing the CAS, then the operation did
// not complete inline. The continuation will be resumed when the operation
// completes, so we return a noop_coroutine to suspend the current coroutine.
return STDEXEC_PP_IF(STDEXEC_GCC(), true, __std::noop_coroutine());
}
else
{
// If __ready_ was still false when executing the CAS, then the operation did not
// complete inline. The continuation will be resumed when the operation
// completes, so we return a noop_coroutine to suspend the current coroutine.
return true;
// The operation completed inline, so return the continuation to resume the
// current coroutine (if the operation completed with set_value or set_error) or
// the unhandled_stopped continuation (if the operation completed with
// set_stopped).
auto const __continuation = this->template __get_continuation<_Promise>();
return STDEXEC_PP_IF(STDEXEC_GCC(),
this->__result_.__is_valueless() ? (__continuation.resume(), true)
: false,
__continuation);
}
}

Expand Down Expand Up @@ -371,23 +413,11 @@ namespace STDEXEC
STDEXEC::start(__opstate);
}

if (this->__result_.__is_valueless())
{
// The operation completed with set_stopped, so we need to call
// unhandled_stopped() on the promise to propagate the stop signal. That will
// result in the coroutine being torn down, so beware. We then resume the
// returned coroutine handle (which may be a noop_coroutine).
return STDEXEC_PP_IF(STDEXEC_GCC(),
(__hcoro.promise().unhandled_stopped().resume(), true),
__hcoro.promise().unhandled_stopped());
}
else
{
// The operation completed with set_value or set_error, so we can just resume
// the current coroutine. await_resume will either return the value or throw as
// appropriate.
return STDEXEC_PP_IF(STDEXEC_GCC(), false, __hcoro);
}
auto const __continuation = this->template __get_continuation<_Promise>();
return STDEXEC_PP_IF(STDEXEC_GCC(),
this->__result_.__is_valueless() ? (__continuation.resume(), true)
: false,
__continuation);
}

private:
Expand Down
7 changes: 0 additions & 7 deletions include/stdexec/__detail/__task.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -308,8 +308,6 @@ namespace STDEXEC
completion_signatures<__single_value_sig_t<_Ty>, set_stopped_t()>,
error_types>;

static constexpr void __sink(task) noexcept {}

template <class _Env>
[[nodiscard]]
static auto __mk_alloc(_Env const & __env) noexcept -> allocator_type
Expand Down Expand Up @@ -609,19 +607,16 @@ namespace STDEXEC
{
// Move the errors out of the promise before destroying the coroutine.
auto __errors = std::move(this->__errors_);
__sink(static_cast<task&&>(this->__task_));
__visit(STDEXEC::set_error, std::move(__errors), static_cast<_Rcvr&&>(this->__rcvr_));
}
else if constexpr (__same_as<_Ty, void>)
{
__sink(static_cast<task&&>(this->__task_));
STDEXEC::set_value(static_cast<_Rcvr&&>(this->__rcvr_));
}
else
{
// Move the result out of the promise before destroying the coroutine.
_Ty __result = static_cast<_Ty&&>(*this->__handle().promise().__result_);
__sink(static_cast<task&&>(this->__task_));
STDEXEC::set_value(static_cast<_Rcvr&&>(this->__rcvr_), static_cast<_Ty&&>(__result));
}
}
Expand All @@ -630,7 +625,6 @@ namespace STDEXEC
if constexpr (!__nothrow_move_constructible<_Ty>
|| !__nothrow_move_constructible<__error_variant_t>)
{
__sink(static_cast<task&&>(this->__task_));
STDEXEC::set_error(static_cast<_Rcvr&&>(this->__rcvr_), std::current_exception());
}
}
Expand All @@ -640,7 +634,6 @@ namespace STDEXEC
auto __canceled() noexcept -> __std::coroutine_handle<> final
{
this->__reset_callback();
__sink(static_cast<task&&>(this->__task_));
STDEXEC::set_stopped(static_cast<_Rcvr&&>(this->__rcvr_));
return std::noop_coroutine();
}
Expand Down
53 changes: 52 additions & 1 deletion test/stdexec/types/test_task.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,8 +334,59 @@ namespace
// ex::sync_wait(std::move(t));
// }

// TODO: add tests for stop token support in task
struct inline_affine_stopped_sender
{
using sender_concept = ex::sender_tag;
using completion_signatures = ex::completion_signatures<ex::set_stopped_t()>;

template <class Receiver>
struct operation
{
Receiver rcvr_;

void start() & noexcept
{
ex::set_stopped(std::move(rcvr_));
}
};

template <class Receiver>
auto connect(Receiver rcvr) && -> operation<Receiver>
{
return {std::move(rcvr)};
}

struct attrs
{
[[nodiscard]]
static constexpr auto query(ex::__get_completion_behavior_t<ex::set_stopped_t>) noexcept
{
return ex::__completion_behavior::__inline_completion
| ex::__completion_behavior::__asynchronous_affine;
}
};

[[nodiscard]]
auto get_env() const noexcept -> attrs
{
return {};
}
};

TEST_CASE("task co_awaiting inline|async_affine stopped sender does not deadlock",
"[types][task]")
{
auto res = ex::sync_wait(
[]() -> ex::task<int>
{
co_await inline_affine_stopped_sender{};
FAIL("Expected co_awaiting inline_affine_stopped_sender to stop the task");
co_return 42;
}());
CHECK(!res.has_value());
}

// TODO: add tests for stop token support in task
} // anonymous namespace

#endif // !STDEXEC_NO_STDCPP_COROUTINES()
Loading