Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 120 additions & 60 deletions include/stdexec/__detail/__when_all.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,17 +51,15 @@ namespace STDEXEC
//! their value datums.
//!
//! @c when_all is the canonical *parallel composition* primitive in the
//! sender model. You give it one or more senders; it returns a single
//! sender model. You give it zero or more senders; it returns a single
//! sender that, when connected and started, starts *all* of the input
//! senders concurrently. When every input has completed, @c when_all's
//! sender completes with a value tuple that is the concatenation of every
//! input's value datums.
//!
//! If any one input fails or is stopped, @c when_all requests stop on the
//! others (via an internal @c inplace_stop_source) and completes with
//! that error (or with @c set_stopped). This makes @c when_all naturally
//! fail-fast: as soon as one branch has gone bad, the rest are asked to
//! wind down.
//! If any one input can fail or be stopped, @c when_all uses an internal
//! @c inplace_stop_source. An unhappy completion requests stop on the
//! other inputs before the combined operation completes.
//!
//! @code{.cpp}
//! auto s = stdexec::when_all(
Expand Down Expand Up @@ -101,8 +99,7 @@ namespace STDEXEC
//! set_value_t(V1..., V2..., ..., Vn...) // concatenation of every input
//! set_error_t(Eij)... // union across all inputs
//! set_error_t(std::exception_ptr) // added if any decay-copy may throw
//! set_stopped_t() // added if any input has it,
//! // or if cancellation may happen
//! set_stopped_t() // added if any input has it
//! @endcode
//!
//! The value datums of each input are decay-copied into the resulting
Expand Down Expand Up @@ -163,19 +160,34 @@ namespace STDEXEC
//! input has completed.
//!
//! @tparam _Senders A pack of types each satisfying @c stdexec::sender.
//! Must be non-empty. Each must have exactly one
//! @c set_value_t completion signature in the
//! ambient environment.
//! Each must have exactly one @c set_value_t completion
//! signature in the ambient environment.
//!
//! @param __sndrs The senders to compose. Forwarded into the result.
//!
//! @returns A sender that, when connected and started, concurrently
//! starts every input and value-completes with the
//! concatenation of the input's value datums.
template <sender... _Senders>
constexpr auto operator()(_Senders&&... __sndrs) const -> __well_formed_sender auto
//! @returns @c just() for no inputs, the input sender for one input, or a
//! sender that concurrently starts every input and concatenates
//! their value datums for two or more inputs.
constexpr auto operator()() const noexcept
{
return just();
}

template <sender _Sender>
constexpr auto operator()(_Sender&& __sndr) const noexcept(__nothrow_decay_copyable<_Sender>)
{
return __make_sexpr<when_all_t>(__(), static_cast<_Senders&&>(__sndrs)...);
return static_cast<_Sender&&>(__sndr);
}

template <sender _Sender0, sender _Sender1, sender... _Senders>
constexpr auto operator()(_Sender0&& __sndr0, _Sender1&& __sndr1, _Senders&&... __sndrs) const
noexcept(__nothrow_decay_copyable<_Sender0, _Sender1, _Senders...>) -> __well_formed_sender
auto
{
return __make_sexpr<when_all_t>(__(),
static_cast<_Sender0&&>(__sndr0),
static_cast<_Sender1&&>(__sndr1),
static_cast<_Senders&&>(__sndrs)...);
}
};

Expand Down Expand Up @@ -394,8 +406,8 @@ namespace STDEXEC
}

template <class _Env>
using __env_t = decltype(__when_all::__mk_env(__declval<_Env>(),
__declval<inplace_stop_source&>()));
using __stoppable_env_t = decltype(__when_all::__mk_env(__declval<_Env>(),
__declval<inplace_stop_source&>()));

template <class _Sender, class _Env>
concept __max1_sender =
Expand Down Expand Up @@ -423,6 +435,17 @@ namespace STDEXEC
using __nothrow_decay_copyable_results_t =
STDEXEC::__nothrow_decay_copyable_results_t<__completion_signatures_of_t<_Sender, _Env...>>;

template <class _Sender, class _Env>
inline constexpr bool __can_fail = !__never_sends<set_error_t, _Sender, _Env>
|| sends_stopped<_Sender, _Env>
|| !__nothrow_decay_copyable_results_t<_Sender, _Env>::value;

template <class _Env, class... _Senders>
inline constexpr bool __uses_stop_source = (__can_fail<_Senders, _Env> || ...);

template <class _Env, class... _Senders>
using __env_t = __if_c<__uses_stop_source<_Env, _Senders...>, __stoppable_env_t<_Env>, _Env>;

template <class... _Env>
struct __completions
{
Expand Down Expand Up @@ -460,6 +483,23 @@ namespace STDEXEC
__concat_completion_signatures_t>...>;
};

template <class... _Env>
struct __completions_for;

template <>
struct __completions_for<>
{
template <class... _Senders>
using __f = __completions<>::template __f<_Senders...>;
};

template <class _Env>
struct __completions_for<_Env>
{
template <class... _Senders>
using __f = __completions<__env_t<_Env, _Senders...>>::template __f<_Senders...>;
};

template <class _Receiver, class _ValuesTuple>
constexpr void __set_values(_Receiver& __rcvr, _ValuesTuple& __values) noexcept
{
Expand All @@ -472,29 +512,33 @@ namespace STDEXEC
static_cast<_ValuesTuple&&>(__values));
}

template <class _Env, class _Sender>
using __values_opt_tuple_t =
value_types_of_t<_Sender, __env_t<_Env>, __decayed_tuple, __optional>;
template <class _ChildEnv, class _Sender>
using __values_opt_tuple_t = value_types_of_t<_Sender, _ChildEnv, __decayed_tuple, __optional>;

template <class _Env, __max1_sender<__env_t<_Env>>... _Senders>
template <class _Env, class... _Senders>
requires(__max1_sender<_Senders, __env_t<_Env, _Senders...>> && ...)
struct __traits
{
using __child_env = __env_t<_Env, _Senders...>;

// tuple<optional<tuple<Vs1...>>, optional<tuple<Vs2...>>, ...>
using __values_tuple = __minvoke<
__mwith_default<__mtransform<__mbind_front_q<__values_opt_tuple_t, _Env>, __q<__tuple>>,
__ignore>,
_Senders...>;
using __values_tuple =
__minvoke<__mwith_default<
__mtransform<__mbind_front_q<__values_opt_tuple_t, __child_env>, __q<__tuple>>,
__ignore>,
_Senders...>;

using __collect_errors = __mbind_front_q<__mset_insert, __mset<>>;

using __errors_list =
__minvoke<__mconcat<>,
__if<__mand<__nothrow_decay_copyable_results_t<_Senders, _Env>...>,
__if<__mand<__nothrow_decay_copyable_results_t<_Senders, __child_env>...>,
__mlist<>,
__mlist<std::exception_ptr>>,
__error_types_of_t<_Senders, __env_t<_Env>, __q<__mlist>>...>;
__error_types_of_t<_Senders, __child_env, __q<__mlist>>...>;

using __errors_variant = __mapply<__q<__uniqued_variant>, __errors_list>;
using __errors_variant = __mapply<__q<__uniqued_variant>, __errors_list>;
static constexpr bool __uses_stop_source = __when_all::__uses_stop_source<_Env, _Senders...>;
};

struct _INVALID_ARGUMENTS_TO_WHEN_ALL_
Expand All @@ -515,7 +559,10 @@ namespace STDEXEC
// error state, which trumps cancellation.)
if (__state_->__state_.compare_exchange_strong(__expected, __stopped))
{
__state_->__stop_source_.request_stop();
if constexpr (_State::__uses_stop_source)
{
__state_->__stop_source_.request_stop();
}
}

// Arrive in order to decrement the count again and complete if needed.
Expand All @@ -525,12 +572,19 @@ namespace STDEXEC
_State* __state_;
};

template <class _ErrorsVariant, class _ValuesTuple, class _Receiver, bool _SendsStopped>
template <class _ErrorsVariant,
class _ValuesTuple,
class _Receiver,
bool _SendsStopped,
bool _UsesStopSource>
struct __state
{
using __receiver_t = _Receiver;
using __receiver_t = _Receiver;
static constexpr bool __uses_stop_source = _UsesStopSource;
using __stop_callback_t =
stop_callback_for_t<stop_token_of_t<env_of_t<_Receiver>>, __forward_stop_request<__state>>;
using __stop_source_t = __if_c<_UsesStopSource, inplace_stop_source, __empty>;
using __on_stop_t = __if_c<_UsesStopSource, __optional<__stop_callback_t>, __empty>;

constexpr void __arrive() noexcept
{
Expand All @@ -543,7 +597,10 @@ namespace STDEXEC
constexpr void __complete() noexcept
{
// Stop callback is no longer needed. Destroy it.
__on_stop_.reset();
if constexpr (_UsesStopSource)
{
__on_stop_.reset();
}
// All child operations have completed and arrived at the barrier.
switch (__state_.load(__std::memory_order_relaxed))
{
Expand Down Expand Up @@ -579,13 +636,15 @@ namespace STDEXEC
STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS
_Receiver __rcvr_;
__std::atomic<std::size_t> __count_;
inplace_stop_source __stop_source_{};
STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS
__stop_source_t __stop_source_{};
// Could be non-atomic here and atomic_ref everywhere except __completion_fn
__std::atomic<__state_t> __state_{__started};
_ErrorsVariant __errors_{__no_init};
__std::atomic<__state_t> __state_{__started};
_ErrorsVariant __errors_{__no_init};
STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS
_ValuesTuple __values_{};
STDEXEC_IMMOVABLE_NO_UNIQUE_ADDRESS
_ValuesTuple __values_{};
__optional<__stop_callback_t> __on_stop_{};
__on_stop_t __on_stop_{};
};

template <class... _Senders>
Expand Down Expand Up @@ -622,18 +681,12 @@ namespace STDEXEC
}
};

// A when_all with no senders completes inline with no values.
template <>
struct __attrs<>
{
template <class _Tag>
[[nodiscard]]
constexpr auto query(__get_completion_behavior_t<set_value_t>) const noexcept
{
return __completion_behavior::__inline_completion;
}

[[nodiscard]]
constexpr auto query(__get_completion_behavior_t<set_stopped_t>) const noexcept
constexpr auto query(__get_completion_behavior_t<_Tag>) const noexcept
{
return __completion_behavior::__inline_completion;
}
Expand All @@ -642,17 +695,18 @@ namespace STDEXEC
template <class _Receiver>
static constexpr auto __mk_state_fn(_Receiver&& __rcvr) noexcept
{
return [&]<__max1_sender<__env_t<env_of_t<_Receiver>>>... _Child>(__ignore,
__ignore,
_Child&&...) noexcept
return [&]<class... _Child>(__ignore, __ignore, _Child&&...) noexcept
requires(__max1_sender<_Child, __env_t<env_of_t<_Receiver>, _Child...>> && ...)
{
using _Traits = __traits<env_of_t<_Receiver>, _Child...>;
using _ErrorsVariant = _Traits::__errors_variant;
using _ValuesTuple = _Traits::__values_tuple;
using _ChildEnv = _Traits::__child_env;
using _State = __state<_ErrorsVariant,
_ValuesTuple,
_Receiver,
(sends_stopped<_Child, env_of_t<_Receiver>> || ...)>;
(sends_stopped<_Child, _ChildEnv> || ...),
_Traits::__uses_stop_source>;
return _State{static_cast<_Receiver&&>(__rcvr), sizeof...(_Child)};
};
}
Expand All @@ -663,7 +717,7 @@ namespace STDEXEC
struct __when_all_impl : __sexpr_defaults
{
template <class _Self, class... _Env>
using __completions_t = __children_of<_Self, __when_all::__completions<__env_t<_Env>...>>;
using __completions_t = __children_of<_Self, __when_all::__completions_for<_Env...>>;

static constexpr auto __get_attrs =
[]<class... _Child>(__ignore, __ignore, _Child const &...) noexcept
Expand Down Expand Up @@ -694,9 +748,15 @@ namespace STDEXEC
}

static constexpr auto __get_env = []<class _State>(__ignore, _State const & __state) noexcept
-> __env_t<env_of_t<typename _State::__receiver_t const &>>
{
return __when_all::__mk_env(STDEXEC::get_env(__state.__rcvr_), __state.__stop_source_);
if constexpr (_State::__uses_stop_source)
{
return __when_all::__mk_env(STDEXEC::get_env(__state.__rcvr_), __state.__stop_source_);
}
else
{
return STDEXEC::get_env(__state.__rcvr_);
}
};

static constexpr auto __get_state =
Expand All @@ -711,19 +771,18 @@ namespace STDEXEC
[]<class _State, class... _Operations>(_State& __state,
_Operations&... __child_ops) noexcept -> void
{
// register stop callback:
__state.__on_stop_.emplace(get_stop_token(STDEXEC::get_env(__state.__rcvr_)),
__forward_stop_request<_State>{&__state});
(STDEXEC::start(__child_ops), ...);
if constexpr (sizeof...(__child_ops) == 0)
if constexpr (_State::__uses_stop_source)
{
__state.__complete();
__state.__on_stop_.emplace(get_stop_token(STDEXEC::get_env(__state.__rcvr_)),
__forward_stop_request<_State>{&__state});
}
(STDEXEC::start(__child_ops), ...);
};

template <class _State, class _Error>
static constexpr void __set_error(_State& __state, _Error&& __err) noexcept
{
static_assert(_State::__uses_stop_source);
// Transition to the "error" state and switch on the prior state.
// TODO: What memory orderings are actually needed here?
switch (__state.__state_.exchange(__error))
Expand Down Expand Up @@ -769,6 +828,7 @@ namespace STDEXEC
}
else if constexpr (__same_as<_Set, set_stopped_t>)
{
static_assert(_State::__uses_stop_source);
__state_t __expected = __started;
// Transition to the "stopped" state if and only if we're in the
// "started" state. (If this fails, it's because we're in an
Expand Down
Loading