diff --git a/include/stdexec/__detail/__as_awaitable.hpp b/include/stdexec/__detail/__as_awaitable.hpp index c209ff8db..d30fae306 100644 --- a/include/stdexec/__detail/__as_awaitable.hpp +++ b/include/stdexec/__detail/__as_awaitable.hpp @@ -126,6 +126,38 @@ namespace STDEXEC return static_cast<__reference_t>(__var::__get<0>(__result_)); } + template + constexpr auto __get_continuation() noexcept -> __std::coroutine_handle<> + { + if (__result_.__is_valueless()) + { + auto __hcoro = STDEXEC::__coroutine_handle_cast<_Promise>(this->__continuation_); + STDEXEC_TRY + { + // The operation completed with set_stopped, so we should not be resuming this + // coroutine at all. Instead, we need to resume the unhandled_stopped + // continuation. + return __hcoro.promise().unhandled_stopped(); + } + STDEXEC_CATCH_ALL + { + if constexpr (noexcept(__hcoro.promise().unhandled_stopped())) + { + __std::unreachable(); + } + else + { + // If unhandled_stopped() threw an exception, we need to resume the current + // coroutine so that the exception can be observed at the suspension point. + __result_.template emplace<1>(std::current_exception()); + } + // fall through to resume the current coroutine + } + } + + return this->__continuation_; + } + __std::coroutine_handle<> __continuation_; __expected_t<_Value> __result_{__no_init}; }; @@ -184,8 +216,8 @@ namespace STDEXEC void set_stopped() noexcept { - // no-op: the __result_ variant will remain engaged with the monostate - // alternative, which signals that the operation was stopped. + // no-op: the __result_ variant will remain empty, which is how we denote + // completion with set_stopped. } // Forward get_env query to the coroutine promise @@ -223,54 +255,58 @@ namespace STDEXEC constexpr void set_stopped() noexcept { - STDEXEC_TRY - { - // Resuming the stopped continuation unwinds the coroutine stack until we reach - // a promise that can handle the stopped signal. The coroutine referred to by - // __continuation_ will never be resumed. - auto __hcoro = STDEXEC::__coroutine_handle_cast<_Promise>( - this->__awaiter_.__continuation_); - __std::coroutine_handle<> __unwind = __hcoro.promise().unhandled_stopped(); - __unwind.resume(); - } - STDEXEC_CATCH_ALL - { - this->__awaiter_.__result_.template emplace<1>(std::current_exception()); - this->__awaiter_.__continuation_.resume(); - } + __done(); } private: void __done() noexcept { + auto& __awaiter = static_cast<__awaiter_t&>(this->__awaiter_); + // If __ready_ is still false when executing the CAS it means the started // operation completed before await_suspend checked whether the operation // completed. In this case resuming execution is handled by await_suspend. // Otherwise, the execution needs to be resumed from here. - auto& __awaiter = static_cast<__awaiter_t&>(this->__awaiter_); - - if (std::this_thread::get_id() != __awaiter.__starting_thread_) - { - // If we're completing on a different thread than the one that started the - // operation, we know we are completing asynchronously, so we need to resume - // the continuation from here. - __awaiter.__continuation_.resume(); - return; - } - bool __expected = false; bool const __was_ready = !__awaiter.__ready_.compare_exchange_strong(__expected, true, __std::memory_order_release, __std::memory_order_acquire); - if (__was_ready) + + // If __ready_ was already true when the CAS was executed, or if the operation is + // completing on a different thread than the one that started the operation, then + // the operation is completing asynchronously. An asynchronous completion means + // that we must resume the continuation from here (since it didn't happen in + // await_suspend()). + // + // Extra context: If __ready_ was alread true, it got set to true in + // await_suspend() immediately after the operation was started, which implies that + // this completion is happening asynchronously. But __ready_ could be false due to + // a race between the CAS in await_suspend and the CAS here, so we also need to + // check if we're completing on a different thread + bool const __async_completion = __was_ready + || std::this_thread::get_id() != __awaiter.__starting_thread_; + + // We also want to resume the continuation from here if the operation completed + // with set_stopped. Resuming the continuation in this case means resuming the + // unhandled_stopped continuation, which immediately tears down the current + // coroutine. There's no point waiting to tear it down later. Just do it. + bool const __is_stopped = __awaiter.__result_.__is_valueless(); + + if (__async_completion || __is_stopped) { - // We get here if __ready_ was true when the CAS was executed. It got set to - // true in await_suspend() immediately after the operation was started, which - // implies that this completion is happening asynchronously, so we need to - // resume the continuation from here. - __awaiter.__continuation_.resume(); + STDEXEC_TRY + { + __awaiter.template __get_continuation<_Promise>().resume(); + } + STDEXEC_CATCH_ALL + { + // In no sane world can resuming a coroutine throw an exception, but resume() + // is not marked noexcept. This call to __std::unreachable() tells the + // compiler to optimize as if it were noexcept. + __std::unreachable(); + } } } }; @@ -306,7 +342,8 @@ namespace STDEXEC } constexpr auto - await_suspend([[maybe_unused]] __std::coroutine_handle<_Promise> __hcoro) noexcept -> bool + await_suspend([[maybe_unused]] __std::coroutine_handle<_Promise> __hcoro) noexcept + -> STDEXEC_PP_IF(STDEXEC_GCC(), bool, __std::coroutine_handle<>) { STDEXEC_ASSERT(this->__continuation_ == __hcoro); @@ -323,19 +360,24 @@ namespace STDEXEC __std::memory_order_acquire); this->__ready_.notify_one(); - if (__was_ready) + if (!__was_ready) { - // The operation completed inline with set_value or set_error, so we can just - // resume the current coroutine. await_resume will either return the value or - // throw as appropriate. - return false; + // If __ready_ was still false when executing the CAS, then the operation did + // not complete inline. The continuation will be resumed when the operation + // completes, so we return a noop_coroutine to suspend the current coroutine. + return STDEXEC_PP_IF(STDEXEC_GCC(), true, __std::noop_coroutine()); } else { - // If __ready_ was still false when executing the CAS, then the operation did not - // complete inline. The continuation will be resumed when the operation - // completes, so we return a noop_coroutine to suspend the current coroutine. - return true; + // The operation completed inline, so return the continuation to resume the + // current coroutine (if the operation completed with set_value or set_error) or + // the unhandled_stopped continuation (if the operation completed with + // set_stopped). + auto const __continuation = this->template __get_continuation<_Promise>(); + return STDEXEC_PP_IF(STDEXEC_GCC(), + this->__result_.__is_valueless() ? (__continuation.resume(), true) + : false, + __continuation); } } @@ -371,23 +413,11 @@ namespace STDEXEC STDEXEC::start(__opstate); } - if (this->__result_.__is_valueless()) - { - // The operation completed with set_stopped, so we need to call - // unhandled_stopped() on the promise to propagate the stop signal. That will - // result in the coroutine being torn down, so beware. We then resume the - // returned coroutine handle (which may be a noop_coroutine). - return STDEXEC_PP_IF(STDEXEC_GCC(), - (__hcoro.promise().unhandled_stopped().resume(), true), - __hcoro.promise().unhandled_stopped()); - } - else - { - // The operation completed with set_value or set_error, so we can just resume - // the current coroutine. await_resume will either return the value or throw as - // appropriate. - return STDEXEC_PP_IF(STDEXEC_GCC(), false, __hcoro); - } + auto const __continuation = this->template __get_continuation<_Promise>(); + return STDEXEC_PP_IF(STDEXEC_GCC(), + this->__result_.__is_valueless() ? (__continuation.resume(), true) + : false, + __continuation); } private: diff --git a/include/stdexec/__detail/__task.hpp b/include/stdexec/__detail/__task.hpp index ac9c1fce8..3b2a91a0d 100644 --- a/include/stdexec/__detail/__task.hpp +++ b/include/stdexec/__detail/__task.hpp @@ -308,8 +308,6 @@ namespace STDEXEC completion_signatures<__single_value_sig_t<_Ty>, set_stopped_t()>, error_types>; - static constexpr void __sink(task) noexcept {} - template [[nodiscard]] static auto __mk_alloc(_Env const & __env) noexcept -> allocator_type @@ -609,19 +607,16 @@ namespace STDEXEC { // Move the errors out of the promise before destroying the coroutine. auto __errors = std::move(this->__errors_); - __sink(static_cast(this->__task_)); __visit(STDEXEC::set_error, std::move(__errors), static_cast<_Rcvr&&>(this->__rcvr_)); } else if constexpr (__same_as<_Ty, void>) { - __sink(static_cast(this->__task_)); STDEXEC::set_value(static_cast<_Rcvr&&>(this->__rcvr_)); } else { // Move the result out of the promise before destroying the coroutine. _Ty __result = static_cast<_Ty&&>(*this->__handle().promise().__result_); - __sink(static_cast(this->__task_)); STDEXEC::set_value(static_cast<_Rcvr&&>(this->__rcvr_), static_cast<_Ty&&>(__result)); } } @@ -630,7 +625,6 @@ namespace STDEXEC if constexpr (!__nothrow_move_constructible<_Ty> || !__nothrow_move_constructible<__error_variant_t>) { - __sink(static_cast(this->__task_)); STDEXEC::set_error(static_cast<_Rcvr&&>(this->__rcvr_), std::current_exception()); } } @@ -640,7 +634,6 @@ namespace STDEXEC auto __canceled() noexcept -> __std::coroutine_handle<> final { this->__reset_callback(); - __sink(static_cast(this->__task_)); STDEXEC::set_stopped(static_cast<_Rcvr&&>(this->__rcvr_)); return std::noop_coroutine(); } diff --git a/test/stdexec/types/test_task.cpp b/test/stdexec/types/test_task.cpp index 0c6fd3ff2..a29e97834 100644 --- a/test/stdexec/types/test_task.cpp +++ b/test/stdexec/types/test_task.cpp @@ -334,8 +334,59 @@ namespace // ex::sync_wait(std::move(t)); // } - // TODO: add tests for stop token support in task + struct inline_affine_stopped_sender + { + using sender_concept = ex::sender_tag; + using completion_signatures = ex::completion_signatures; + + template + struct operation + { + Receiver rcvr_; + + void start() & noexcept + { + ex::set_stopped(std::move(rcvr_)); + } + }; + template + auto connect(Receiver rcvr) && -> operation + { + return {std::move(rcvr)}; + } + + struct attrs + { + [[nodiscard]] + static constexpr auto query(ex::__get_completion_behavior_t) noexcept + { + return ex::__completion_behavior::__inline_completion + | ex::__completion_behavior::__asynchronous_affine; + } + }; + + [[nodiscard]] + auto get_env() const noexcept -> attrs + { + return {}; + } + }; + + TEST_CASE("task co_awaiting inline|async_affine stopped sender does not deadlock", + "[types][task]") + { + auto res = ex::sync_wait( + []() -> ex::task + { + co_await inline_affine_stopped_sender{}; + FAIL("Expected co_awaiting inline_affine_stopped_sender to stop the task"); + co_return 42; + }()); + CHECK(!res.has_value()); + } + + // TODO: add tests for stop token support in task } // anonymous namespace #endif // !STDEXEC_NO_STDCPP_COROUTINES()