diff --git a/include/stdexec/__detail/__when_all.hpp b/include/stdexec/__detail/__when_all.hpp index ea6844419..6f5be1edf 100644 --- a/include/stdexec/__detail/__when_all.hpp +++ b/include/stdexec/__detail/__when_all.hpp @@ -51,17 +51,15 @@ namespace STDEXEC //! their value datums. //! //! @c when_all is the canonical *parallel composition* primitive in the - //! sender model. You give it one or more senders; it returns a single + //! sender model. You give it zero or more senders; it returns a single //! sender that, when connected and started, starts *all* of the input //! senders concurrently. When every input has completed, @c when_all's //! sender completes with a value tuple that is the concatenation of every //! input's value datums. //! - //! If any one input fails or is stopped, @c when_all requests stop on the - //! others (via an internal @c inplace_stop_source) and completes with - //! that error (or with @c set_stopped). This makes @c when_all naturally - //! fail-fast: as soon as one branch has gone bad, the rest are asked to - //! wind down. + //! If any one input can fail or be stopped, @c when_all uses an internal + //! @c inplace_stop_source. An unhappy completion requests stop on the + //! other inputs before the combined operation completes. //! //! @code{.cpp} //! auto s = stdexec::when_all( @@ -101,8 +99,7 @@ namespace STDEXEC //! set_value_t(V1..., V2..., ..., Vn...) // concatenation of every input //! set_error_t(Eij)... // union across all inputs //! set_error_t(std::exception_ptr) // added if any decay-copy may throw - //! set_stopped_t() // added if any input has it, - //! // or if cancellation may happen + //! set_stopped_t() // added if any input has it //! @endcode //! //! The value datums of each input are decay-copied into the resulting @@ -163,19 +160,34 @@ namespace STDEXEC //! input has completed. //! //! @tparam _Senders A pack of types each satisfying @c stdexec::sender. - //! Must be non-empty. Each must have exactly one - //! @c set_value_t completion signature in the - //! ambient environment. + //! Each must have exactly one @c set_value_t completion + //! signature in the ambient environment. //! //! @param __sndrs The senders to compose. Forwarded into the result. //! - //! @returns A sender that, when connected and started, concurrently - //! starts every input and value-completes with the - //! concatenation of the input's value datums. - template - constexpr auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto + //! @returns @c just() for no inputs, the input sender for one input, or a + //! sender that concurrently starts every input and concatenates + //! their value datums for two or more inputs. + constexpr auto operator()() const noexcept + { + return just(); + } + + template + constexpr auto operator()(_Sender&& __sndr) const noexcept(__nothrow_decay_copyable<_Sender>) { - return __make_sexpr(__(), static_cast<_Senders&&>(__sndrs)...); + return static_cast<_Sender&&>(__sndr); + } + + template + constexpr auto operator()(_Sender0&& __sndr0, _Sender1&& __sndr1, _Senders&&... __sndrs) const + noexcept(__nothrow_decay_copyable<_Sender0, _Sender1, _Senders...>) -> __well_formed_sender + auto + { + return __make_sexpr(__(), + static_cast<_Sender0&&>(__sndr0), + static_cast<_Sender1&&>(__sndr1), + static_cast<_Senders&&>(__sndrs)...); } }; @@ -394,8 +406,8 @@ namespace STDEXEC } template - using __env_t = decltype(__when_all::__mk_env(__declval<_Env>(), - __declval())); + using __stoppable_env_t = decltype(__when_all::__mk_env(__declval<_Env>(), + __declval())); template concept __max1_sender = @@ -423,6 +435,17 @@ namespace STDEXEC using __nothrow_decay_copyable_results_t = STDEXEC::__nothrow_decay_copyable_results_t<__completion_signatures_of_t<_Sender, _Env...>>; + template + inline constexpr bool __can_fail = !__never_sends + || sends_stopped<_Sender, _Env> + || !__nothrow_decay_copyable_results_t<_Sender, _Env>::value; + + template + inline constexpr bool __uses_stop_source = (__can_fail<_Senders, _Env> || ...); + + template + using __env_t = __if_c<__uses_stop_source<_Env, _Senders...>, __stoppable_env_t<_Env>, _Env>; + template struct __completions { @@ -460,6 +483,23 @@ namespace STDEXEC __concat_completion_signatures_t>...>; }; + template + struct __completions_for; + + template <> + struct __completions_for<> + { + template + using __f = __completions<>::template __f<_Senders...>; + }; + + template + struct __completions_for<_Env> + { + template + using __f = __completions<__env_t<_Env, _Senders...>>::template __f<_Senders...>; + }; + template constexpr void __set_values(_Receiver& __rcvr, _ValuesTuple& __values) noexcept { @@ -472,29 +512,33 @@ namespace STDEXEC static_cast<_ValuesTuple&&>(__values)); } - template - using __values_opt_tuple_t = - value_types_of_t<_Sender, __env_t<_Env>, __decayed_tuple, __optional>; + template + using __values_opt_tuple_t = value_types_of_t<_Sender, _ChildEnv, __decayed_tuple, __optional>; - template >... _Senders> + template + requires(__max1_sender<_Senders, __env_t<_Env, _Senders...>> && ...) struct __traits { + using __child_env = __env_t<_Env, _Senders...>; + // tuple>, optional>, ...> - using __values_tuple = __minvoke< - __mwith_default<__mtransform<__mbind_front_q<__values_opt_tuple_t, _Env>, __q<__tuple>>, - __ignore>, - _Senders...>; + using __values_tuple = + __minvoke<__mwith_default< + __mtransform<__mbind_front_q<__values_opt_tuple_t, __child_env>, __q<__tuple>>, + __ignore>, + _Senders...>; using __collect_errors = __mbind_front_q<__mset_insert, __mset<>>; using __errors_list = __minvoke<__mconcat<>, - __if<__mand<__nothrow_decay_copyable_results_t<_Senders, _Env>...>, + __if<__mand<__nothrow_decay_copyable_results_t<_Senders, __child_env>...>, __mlist<>, __mlist>, - __error_types_of_t<_Senders, __env_t<_Env>, __q<__mlist>>...>; + __error_types_of_t<_Senders, __child_env, __q<__mlist>>...>; - using __errors_variant = __mapply<__q<__uniqued_variant>, __errors_list>; + using __errors_variant = __mapply<__q<__uniqued_variant>, __errors_list>; + static constexpr bool __uses_stop_source = __when_all::__uses_stop_source<_Env, _Senders...>; }; struct _INVALID_ARGUMENTS_TO_WHEN_ALL_ @@ -515,7 +559,10 @@ namespace STDEXEC // error state, which trumps cancellation.) if (__state_->__state_.compare_exchange_strong(__expected, __stopped)) { - __state_->__stop_source_.request_stop(); + if constexpr (_State::__uses_stop_source) + { + __state_->__stop_source_.request_stop(); + } } // Arrive in order to decrement the count again and complete if needed. @@ -525,12 +572,19 @@ namespace STDEXEC _State* __state_; }; - template + template struct __state { - using __receiver_t = _Receiver; + using __receiver_t = _Receiver; + static constexpr bool __uses_stop_source = _UsesStopSource; using __stop_callback_t = stop_callback_for_t>, __forward_stop_request<__state>>; + using __stop_source_t = __if_c<_UsesStopSource, inplace_stop_source, __empty>; + using __on_stop_t = __if_c<_UsesStopSource, __optional<__stop_callback_t>, __empty>; constexpr void __arrive() noexcept { @@ -543,7 +597,10 @@ namespace STDEXEC constexpr void __complete() noexcept { // Stop callback is no longer needed. Destroy it. - __on_stop_.reset(); + if constexpr (_UsesStopSource) + { + __on_stop_.reset(); + } // All child operations have completed and arrived at the barrier. switch (__state_.load(__std::memory_order_relaxed)) { @@ -579,13 +636,15 @@ namespace STDEXEC STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS _Receiver __rcvr_; __std::atomic __count_; - inplace_stop_source __stop_source_{}; + STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS + __stop_source_t __stop_source_{}; // Could be non-atomic here and atomic_ref everywhere except __completion_fn - __std::atomic<__state_t> __state_{__started}; - _ErrorsVariant __errors_{__no_init}; + __std::atomic<__state_t> __state_{__started}; + _ErrorsVariant __errors_{__no_init}; + STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS + _ValuesTuple __values_{}; STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS - _ValuesTuple __values_{}; - __optional<__stop_callback_t> __on_stop_{}; + __on_stop_t __on_stop_{}; }; template @@ -622,18 +681,12 @@ namespace STDEXEC } }; - // A when_all with no senders completes inline with no values. template <> struct __attrs<> { + template [[nodiscard]] - constexpr auto query(__get_completion_behavior_t) const noexcept - { - return __completion_behavior::__inline_completion; - } - - [[nodiscard]] - constexpr auto query(__get_completion_behavior_t) const noexcept + constexpr auto query(__get_completion_behavior_t<_Tag>) const noexcept { return __completion_behavior::__inline_completion; } @@ -642,17 +695,18 @@ namespace STDEXEC template static constexpr auto __mk_state_fn(_Receiver&& __rcvr) noexcept { - return [&]<__max1_sender<__env_t>>... _Child>(__ignore, - __ignore, - _Child&&...) noexcept + return [&](__ignore, __ignore, _Child&&...) noexcept + requires(__max1_sender<_Child, __env_t, _Child...>> && ...) { using _Traits = __traits, _Child...>; using _ErrorsVariant = _Traits::__errors_variant; using _ValuesTuple = _Traits::__values_tuple; + using _ChildEnv = _Traits::__child_env; using _State = __state<_ErrorsVariant, _ValuesTuple, _Receiver, - (sends_stopped<_Child, env_of_t<_Receiver>> || ...)>; + (sends_stopped<_Child, _ChildEnv> || ...), + _Traits::__uses_stop_source>; return _State{static_cast<_Receiver&&>(__rcvr), sizeof...(_Child)}; }; } @@ -663,7 +717,7 @@ namespace STDEXEC struct __when_all_impl : __sexpr_defaults { template - using __completions_t = __children_of<_Self, __when_all::__completions<__env_t<_Env>...>>; + using __completions_t = __children_of<_Self, __when_all::__completions_for<_Env...>>; static constexpr auto __get_attrs = [](__ignore, __ignore, _Child const &...) noexcept @@ -694,9 +748,15 @@ namespace STDEXEC } static constexpr auto __get_env = [](__ignore, _State const & __state) noexcept - -> __env_t> { - return __when_all::__mk_env(STDEXEC::get_env(__state.__rcvr_), __state.__stop_source_); + if constexpr (_State::__uses_stop_source) + { + return __when_all::__mk_env(STDEXEC::get_env(__state.__rcvr_), __state.__stop_source_); + } + else + { + return STDEXEC::get_env(__state.__rcvr_); + } }; static constexpr auto __get_state = @@ -711,19 +771,18 @@ namespace STDEXEC [](_State& __state, _Operations&... __child_ops) noexcept -> void { - // register stop callback: - __state.__on_stop_.emplace(get_stop_token(STDEXEC::get_env(__state.__rcvr_)), - __forward_stop_request<_State>{&__state}); - (STDEXEC::start(__child_ops), ...); - if constexpr (sizeof...(__child_ops) == 0) + if constexpr (_State::__uses_stop_source) { - __state.__complete(); + __state.__on_stop_.emplace(get_stop_token(STDEXEC::get_env(__state.__rcvr_)), + __forward_stop_request<_State>{&__state}); } + (STDEXEC::start(__child_ops), ...); }; template static constexpr void __set_error(_State& __state, _Error&& __err) noexcept { + static_assert(_State::__uses_stop_source); // Transition to the "error" state and switch on the prior state. // TODO: What memory orderings are actually needed here? switch (__state.__state_.exchange(__error)) @@ -769,6 +828,7 @@ namespace STDEXEC } else if constexpr (__same_as<_Set, set_stopped_t>) { + static_assert(_State::__uses_stop_source); __state_t __expected = __started; // Transition to the "stopped" state if and only if we're in the // "started" state. (If this fails, it's because we're in an diff --git a/test/stdexec/algos/adaptors/test_when_all.cpp b/test/stdexec/algos/adaptors/test_when_all.cpp index a5f55df97..bd64442d3 100644 --- a/test/stdexec/algos/adaptors/test_when_all.cpp +++ b/test/stdexec/algos/adaptors/test_when_all.cpp @@ -33,6 +33,40 @@ namespace ex = STDEXEC; namespace { + struct stop_sensitive_sender + { + using sender_concept = ex::sender_tag; + + template + static consteval auto get_completion_signatures() + { + if constexpr (ex::unstoppable_token>) + { + return ex::completion_signatures{}; + } + else + { + return ex::completion_signatures{}; + } + } + + template + struct operation + { + Receiver receiver_; + + void start() & noexcept + { + ex::set_value(std::move(receiver_)); + } + }; + + template + auto connect(Receiver receiver) const noexcept -> operation + { + return {std::move(receiver)}; + } + }; TEST_CASE("when_all returns a sender", "[adaptors][when_all]") { @@ -42,6 +76,34 @@ namespace (void) snd; } + TEST_CASE("when_all coalesces empty and unary calls", "[adaptors][when_all]") + { + using empty_t = decltype(ex::when_all()); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(!exec::sender_for); + STATIC_REQUIRE(noexcept(ex::when_all())); + + auto child = ex::just(42); + using unary_t = decltype(ex::when_all(child)); + STATIC_REQUIRE(std::same_as); + STATIC_REQUIRE(!exec::sender_for); + STATIC_REQUIRE(noexcept(ex::when_all(child))); + + auto multi = ex::when_all(ex::just(1), ex::just(2)); + STATIC_REQUIRE(exec::sender_for); + } + + TEST_CASE("unary when_all preserves every completion channel", "[adaptors][when_all]") + { + wait_for_value(ex::when_all(ex::just(42)), 42); + + auto error_op = ex::connect(ex::when_all(ex::just_error(42)), expect_error_receiver{42}); + ex::start(error_op); + + auto stopped_op = ex::connect(ex::when_all(ex::just_stopped()), expect_stopped_receiver{}); + ex::start(stopped_op); + } + TEST_CASE("when_all with environment returns a sender", "[adaptors][when_all]") { auto snd = ex::when_all(ex::just(3), ex::just(0.1415)); @@ -453,6 +515,80 @@ namespace ex::start(op); } + TEST_CASE("infallible when_all children retain the receiver stop token", "[adaptors][when_all]") + { + auto observes_stop_possible = ex::read_env(ex::get_stop_token) + | ex::then([](auto token) noexcept + { return token.stop_possible(); }); + + auto unstoppable = ex::when_all(observes_stop_possible, observes_stop_possible); + static_assert(set_equivalent>, + ex::completion_signatures>); + wait_for_value(std::move(unstoppable), false, false); + + ex::inplace_stop_source source; + auto stoppable = ex::when_all(observes_stop_possible, observes_stop_possible); + auto env = ex::prop(ex::get_stop_token, source.get_token()); + static_assert(set_equivalent, + ex::completion_signatures>); + auto op = ex::connect(std::move(stoppable), expect_value_receiver{env_tag{}, env, true, true}); + ex::start(op); + } + + TEST_CASE("when_all publishes environment-sensitive completion signatures", + "[adaptors][when_all]") + { + auto snd = ex::when_all(stop_sensitive_sender{}, stop_sensitive_sender{}); + static_assert(set_equivalent>, + ex::completion_signatures>); + + ex::inplace_stop_source source; + auto env = ex::prop(ex::get_stop_token, source.get_token()); + static_assert( + set_equivalent, + ex::completion_signatures>); + } + + TEST_CASE("when_all publishes storage failure as exception_ptr", "[adaptors][when_all]") + { + auto snd = ex::when_all(ex::just(), ex::just(potentially_throwing{})); + static_assert(set_equivalent>, + ex::completion_signatures>); + + ex::inplace_stop_source source; + auto env = ex::prop(ex::get_stop_token, source.get_token()); + static_assert(set_equivalent, + ex::completion_signatures>); + } + + TEST_CASE("fallible when_all children receive an internal stop token", "[adaptors][when_all]") + { + bool observed_stop_possible = false; + auto observer = ex::read_env(ex::get_stop_token) + | ex::then([&](auto token) noexcept + { observed_stop_possible = token.stop_possible(); }); + auto snd = ex::when_all(std::move(observer), ex::just_error(42)); + auto op = ex::connect(std::move(snd), expect_error_receiver{42}); + ex::start(op); + CHECK(observed_stop_possible); + } + + TEST_CASE("potentially throwing when_all result storage uses an internal stop token", + "[adaptors][when_all]") + { + bool observed_stop_possible = false; + auto observer = ex::read_env(ex::get_stop_token) + | ex::then([&](auto token) noexcept + { observed_stop_possible = token.stop_possible(); }); + auto snd = ex::when_all(std::move(observer), ex::just(potentially_throwing{})) + | ex::then([](potentially_throwing) noexcept {}); + auto op = ex::connect(std::move(snd), expect_void_receiver{}); + ex::start(op); + CHECK(observed_stop_possible); + } + TEST_CASE("when_all handles stop requests from the environment correctly", "[adaptors][when_all]") { auto snd = ex::when_all(completes_if(false), completes_if(false));