libs/capy/include/boost/capy/when_all.hpp

97.0% Lines (97/100) 91.2% Functions (302/331) 96.3% Branches (26/27)
libs/capy/include/boost/capy/when_all.hpp
Line Branch Hits Source Code
1 //
2 // Copyright (c) 2026 Steve Gerbino
3 //
4 // Distributed under the Boost Software License, Version 1.0. (See accompanying
5 // file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6 //
7 // Official repository: https://github.com/cppalliance/capy
8 //
9
10 #ifndef BOOST_CAPY_WHEN_ALL_HPP
11 #define BOOST_CAPY_WHEN_ALL_HPP
12
13 #include <boost/capy/detail/config.hpp>
14 #include <boost/capy/concept/executor.hpp>
15 #include <boost/capy/concept/io_launchable_task.hpp>
16 #include <boost/capy/coro.hpp>
17 #include <boost/capy/ex/executor_ref.hpp>
18 #include <boost/capy/ex/frame_allocator.hpp>
19 #include <boost/capy/task.hpp>
20
21 #include <array>
22 #include <atomic>
23 #include <exception>
24 #include <optional>
25 #include <stop_token>
26 #include <tuple>
27 #include <type_traits>
28 #include <utility>
29
30 namespace boost {
31 namespace capy {
32
33 namespace detail {
34
/** Maps a type to a tuple holding it, or to an empty tuple for void.

    `wrap_non_void_t<int>` is `std::tuple<int>` and
    `wrap_non_void_t<void>` is `std::tuple<>`. It is the per-type
    building block concatenated by `filter_void_tuple_t` below.
*/
template<typename T>
using wrap_non_void_t = std::conditional_t<std::is_void_v<T>, std::tuple<>, std::tuple<T>>;

/** Tuple of the non-void types in `Ts...`, in their original order.

    Void-returning tasks do not contribute a value to the result tuple;
    this alias computes that filtered tuple type by concatenating the
    wrappers produced by `wrap_non_void_t`.

    Example: `filter_void_tuple_t<int, void, std::string>` is
    `std::tuple<int, std::string>`, and `filter_void_tuple_t<>` is
    `std::tuple<>`.
*/
template<typename... Ts>
using filter_void_tuple_t = decltype(std::tuple_cat(std::declval<wrap_non_void_t<Ts>>()...));
47
/** Slot for the value produced by one child task of when_all.

    The slot starts empty. `set` stores the child's value; `get`, which
    may only be called on an rvalue, moves the stored value back out.
    `get` does not check that a value is present.
*/
template<typename T>
struct result_holder
{
    std::optional<T> value_;

    void set(T v)
    {
        value_ = std::move(v);
    }

    T get() &&
    {
        // Unchecked access: callers must have called set() first.
        T& stored = *value_;
        return std::move(stored);
    }
};

/** A void child yields nothing, so its slot carries no storage.
*/
template<>
struct result_holder<void>
{
};
72
73 /** Shared state for when_all operation.
74
75 @tparam Ts The result types of the tasks.
76 */
77 template<typename... Ts>
78 struct when_all_state
79 {
80 static constexpr std::size_t task_count = sizeof...(Ts);
81
82 // Completion tracking - when_all waits for all children
83 std::atomic<std::size_t> remaining_count_;
84
85 // Result storage in input order
86 std::tuple<result_holder<Ts>...> results_;
87
88 // Runner handles - destroyed in await_resume while allocator is valid
89 std::array<coro, task_count> runner_handles_{};
90
91 // Exception storage - first error wins, others discarded
92 std::atomic<bool> has_exception_{false};
93 std::exception_ptr first_exception_;
94
95 // Stop propagation - on error, request stop for siblings
96 std::stop_source stop_source_;
97
98 // Connects parent's stop_token to our stop_source
99 struct stop_callback_fn
100 {
101 std::stop_source* source_;
102 2 void operator()() const { source_->request_stop(); }
103 };
104 using stop_callback_t = std::stop_callback<stop_callback_fn>;
105 std::optional<stop_callback_t> parent_stop_callback_;
106
107 // Parent resumption
108 coro continuation_;
109 executor_ref caller_ex_;
110
111 28 when_all_state()
112
1/1
✓ Branch 5 taken 28 times.
28 : remaining_count_(task_count)
113 {
114 28 }
115
116 28 ~when_all_state()
117 {
118
2/2
✓ Branch 0 taken 68 times.
✓ Branch 1 taken 28 times.
96 for(auto h : runner_handles_)
119
1/2
✓ Branch 1 taken 68 times.
✗ Branch 2 not taken.
68 if(h)
120 68 h.destroy();
121 28 }
122
123 /** Capture an exception (first one wins).
124 */
125 11 void capture_exception(std::exception_ptr ep)
126 {
127 11 bool expected = false;
128
2/2
✓ Branch 1 taken 8 times.
✓ Branch 2 taken 3 times.
11 if(has_exception_.compare_exchange_strong(
129 expected, true, std::memory_order_relaxed))
130 8 first_exception_ = ep;
131 11 }
132
133 /** Signal that a task has completed.
134
135 The last child to complete triggers resumption of the parent.
136 Dispatch handles thread affinity: resumes inline if on same
137 thread, otherwise posts to the caller's executor.
138 */
139 68 coro signal_completion()
140 {
141 68 auto remaining = remaining_count_.fetch_sub(1, std::memory_order_acq_rel);
142
2/2
✓ Branch 0 taken 28 times.
✓ Branch 1 taken 40 times.
68 if(remaining == 1)
143 28 caller_ex_.dispatch(continuation_);
144 68 return std::noop_coroutine();
145 }
146
147 };
148
149 /** Wrapper coroutine that intercepts task completion.
150
151 This runner awaits its assigned task and stores the result in
152 the shared state, or captures the exception and requests stop.
153 */
154 template<typename T, typename... Ts>
155 struct when_all_runner
156 {
157 struct promise_type // : frame_allocating_base // DISABLED FOR TESTING
158 {
159 when_all_state<Ts...>* state_ = nullptr;
160 executor_ref ex_;
161 std::stop_token stop_token_;
162
163 68 when_all_runner get_return_object()
164 {
165 68 return when_all_runner(std::coroutine_handle<promise_type>::from_promise(*this));
166 }
167
168 68 std::suspend_always initial_suspend() noexcept
169 {
170 68 return {};
171 }
172
173 68 auto final_suspend() noexcept
174 {
175 struct awaiter
176 {
177 promise_type* p_;
178
179 8 bool await_ready() const noexcept
180 {
181 8 return false;
182 }
183
184 8 coro await_suspend(coro) noexcept
185 {
186 // Signal completion; last task resumes parent
187 8 return p_->state_->signal_completion();
188 }
189
190 void await_resume() const noexcept
191 {
192 }
193 };
194 68 return awaiter{this};
195 }
196
197 57 void return_void()
198 {
199 57 }
200
201 11 void unhandled_exception()
202 {
203 11 state_->capture_exception(std::current_exception());
204 // Request stop for sibling tasks
205 11 state_->stop_source_.request_stop();
206 11 }
207
208 template<class Awaitable>
209 struct transform_awaiter
210 {
211 std::decay_t<Awaitable> a_;
212 promise_type* p_;
213
214 68 bool await_ready()
215 {
216 68 return a_.await_ready();
217 }
218
219 68 decltype(auto) await_resume()
220 {
221 68 return a_.await_resume();
222 }
223
224 template<class Promise>
225 68 auto await_suspend(std::coroutine_handle<Promise> h)
226 {
227
1/1
✓ Branch 3 taken 54 times.
68 return a_.await_suspend(h, p_->ex_, p_->stop_token_);
228 }
229 };
230
231 template<class Awaitable>
232 68 auto await_transform(Awaitable&& a)
233 {
234 using A = std::decay_t<Awaitable>;
235 if constexpr (IoAwaitable<A>)
236 {
237 return transform_awaiter<Awaitable>{
238 136 std::forward<Awaitable>(a), this};
239 }
240 else
241 {
242 static_assert(sizeof(A) == 0, "requires IoAwaitable");
243 }
244 68 }
245 };
246
247 std::coroutine_handle<promise_type> h_;
248
249 68 explicit when_all_runner(std::coroutine_handle<promise_type> h)
250 68 : h_(h)
251 {
252 68 }
253
254 // Enable move for all clang versions - some versions need it
255 when_all_runner(when_all_runner&& other) noexcept : h_(std::exchange(other.h_, nullptr)) {}
256
257 // Non-copyable
258 when_all_runner(when_all_runner const&) = delete;
259 when_all_runner& operator=(when_all_runner const&) = delete;
260 when_all_runner& operator=(when_all_runner&&) = delete;
261
262 68 auto release() noexcept
263 {
264 68 return std::exchange(h_, nullptr);
265 }
266 };
267
268 /** Create a runner coroutine for a single task.
269
270 Task is passed directly to ensure proper coroutine frame storage.
271 */
272 template<std::size_t Index, typename T, typename... Ts>
273 when_all_runner<T, Ts...>
274
1/1
✓ Branch 1 taken 68 times.
68 make_when_all_runner(task<T> inner, when_all_state<Ts...>* state)
275 {
276 if constexpr (std::is_void_v<T>)
277 {
278 co_await std::move(inner);
279 }
280 else
281 {
282 std::get<Index>(state->results_).set(co_await std::move(inner));
283 }
284 136 }
285
286 /** Internal awaitable that launches all runner coroutines and waits.
287
288 This awaitable is used inside the when_all coroutine to handle
289 the concurrent execution of child tasks.
290 */
291 template<typename... Ts>
292 class when_all_launcher
293 {
294 std::tuple<task<Ts>...>* tasks_;
295 when_all_state<Ts...>* state_;
296
297 public:
298 28 when_all_launcher(
299 std::tuple<task<Ts>...>* tasks,
300 when_all_state<Ts...>* state)
301 28 : tasks_(tasks)
302 28 , state_(state)
303 {
304 28 }
305
306 28 bool await_ready() const noexcept
307 {
308 28 return sizeof...(Ts) == 0;
309 }
310
311 28 coro await_suspend(coro continuation, executor_ref caller_ex, std::stop_token parent_token = {})
312 {
313 28 state_->continuation_ = continuation;
314 28 state_->caller_ex_ = caller_ex;
315
316 // Forward parent's stop requests to children
317
2/2
✓ Branch 1 taken 6 times.
✓ Branch 2 taken 22 times.
28 if(parent_token.stop_possible())
318 {
319 12 state_->parent_stop_callback_.emplace(
320 parent_token,
321 6 typename when_all_state<Ts...>::stop_callback_fn{&state_->stop_source_});
322
323
2/2
✓ Branch 1 taken 2 times.
✓ Branch 2 taken 4 times.
6 if(parent_token.stop_requested())
324 2 state_->stop_source_.request_stop();
325 }
326
327 // CRITICAL: If the last task finishes synchronously then the parent
328 // coroutine resumes, destroying its frame, and destroying this object
329 // prior to the completion of await_suspend. Therefore, await_suspend
330 // must ensure `this` cannot be referenced after calling `launch_one`
331 // for the last time.
332 28 auto token = state_->stop_source_.get_token();
333 [&]<std::size_t... Is>(std::index_sequence<Is...>) {
334
2/2
✓ Branch 2 taken 4 times.
✓ Branch 6 taken 4 times.
4 (..., launch_one<Is>(caller_ex, token));
335
2/2
✓ Branch 1 taken 24 times.
✓ Branch 1 taken 4 times.
28 }(std::index_sequence_for<Ts...>{});
336
337 // Let signal_completion() handle resumption
338 56 return std::noop_coroutine();
339 28 }
340
341 28 void await_resume() const noexcept
342 {
343 // Results are extracted by the when_all coroutine from state
344 28 }
345
346 private:
347 template<std::size_t I>
348 68 void launch_one(executor_ref caller_ex, std::stop_token token)
349 {
350
1/1
✓ Branch 2 taken 68 times.
68 auto runner = make_when_all_runner<I>(
351 68 std::move(std::get<I>(*tasks_)), state_);
352
353 68 auto h = runner.release();
354 68 h.promise().state_ = state_;
355 68 h.promise().ex_ = caller_ex;
356 68 h.promise().stop_token_ = token;
357
358 68 coro ch{h};
359 68 state_->runner_handles_[I] = ch;
360
1/1
✓ Branch 1 taken 68 times.
68 state_->caller_ex_.dispatch(ch);
361 68 }
362 };
363
364 /** Compute the result type for when_all.
365
366 Returns void when all tasks are void (P2300 aligned),
367 otherwise returns a tuple with void types filtered out.
368 */
369 template<typename... Ts>
370 using when_all_result_t = std::conditional_t<
371 std::is_same_v<filter_void_tuple_t<Ts...>, std::tuple<>>,
372 void,
373 filter_void_tuple_t<Ts...>>;
374
375 /** Helper to extract a single result, returning empty tuple for void.
376 This is a separate function to work around a GCC-11 ICE that occurs
377 when using nested immediately-invoked lambdas with pack expansion.
378 */
379 template<std::size_t I, typename... Ts>
380 47 auto extract_single_result(when_all_state<Ts...>& state)
381 {
382 using T = std::tuple_element_t<I, std::tuple<Ts...>>;
383 if constexpr (std::is_void_v<T>)
384 2 return std::tuple<>();
385 else
386
1/1
✓ Branch 4 taken 45 times.
45 return std::make_tuple(std::move(std::get<I>(state.results_)).get());
387 }
388
389 /** Extract results from state, filtering void types.
390 */
391 template<typename... Ts>
392 19 auto extract_results(when_all_state<Ts...>& state)
393 {
394 19 return [&]<std::size_t... Is>(std::index_sequence<Is...>) {
395
3/3
✓ Branch 1 taken 4 times.
✓ Branch 4 taken 4 times.
✓ Branch 7 taken 4 times.
4 return std::tuple_cat(extract_single_result<Is>(state)...);
396
1/1
✓ Branch 1 taken 19 times.
38 }(std::index_sequence_for<Ts...>{});
397 }
398
399 } // namespace detail
400
401 /** Execute multiple tasks concurrently and collect their results.
402
403 Launches all tasks simultaneously and waits for all to complete
404 before returning. Results are collected in input order. If any
405 task throws, cancellation is requested for siblings and the first
406 exception is rethrown after all tasks complete.
407
408 @li All child tasks run concurrently on the caller's executor
409 @li Results are returned as a tuple in input order
410 @li Void-returning tasks do not contribute to the result tuple
411 @li If all tasks return void, `when_all` returns `task<void>`
412 @li First exception wins; subsequent exceptions are discarded
413 @li Stop is requested for siblings on first error
414 @li Completes only after all children have finished
415
416 @par Thread Safety
417 The returned task must be awaited from a single execution context.
418 Child tasks execute concurrently but complete through the caller's
419 executor.
420
421 @param tasks The tasks to execute concurrently. Each task is
422 consumed (moved-from) when `when_all` is awaited.
423
424 @return A task yielding a tuple of non-void results. Returns
425 `task<void>` when all input tasks return void.
426
427 @par Example
428
429 @code
430 task<> example()
431 {
432 // Concurrent fetch, results collected in order
433 auto [user, posts] = co_await when_all(
434 fetch_user( id ), // task<User>
435 fetch_posts( id ) // task<std::vector<Post>>
436 );
437
438 // Void tasks don't contribute to result
439 co_await when_all(
440 log_event( "start" ), // task<void>
441 notify_user( id ) // task<void>
442 );
443 // Returns task<void>, no result tuple
444 }
445 @endcode
446
447 @see task
448 */
449 template<typename... Ts>
450 [[nodiscard]] task<detail::when_all_result_t<Ts...>>
451
1/1
✓ Branch 1 taken 28 times.
28 when_all(task<Ts>... tasks)
452 {
453 using result_type = detail::when_all_result_t<Ts...>;
454
455 // State is stored in the coroutine frame, using the frame allocator
456 detail::when_all_state<Ts...> state;
457
458 // Store tasks in the frame
459 std::tuple<task<Ts>...> task_tuple(std::move(tasks)...);
460
461 // Launch all tasks and wait for completion
462 co_await detail::when_all_launcher<Ts...>(&task_tuple, &state);
463
464 // Propagate first exception if any.
465 // Safe without explicit acquire: capture_exception() is sequenced-before
466 // signal_completion()'s acq_rel fetch_sub, which synchronizes-with the
467 // last task's decrement that resumes this coroutine.
468 if(state.first_exception_)
469 std::rethrow_exception(state.first_exception_);
470
471 // Extract and return results
472 if constexpr (std::is_void_v<result_type>)
473 co_return;
474 else
475 co_return detail::extract_results(state);
476 56 }
477
478 /// Compute the result type of `when_all` for the given task types.
479 template<typename... Ts>
480 using when_all_result_type = detail::when_all_result_t<Ts...>;
481
482 } // namespace capy
483 } // namespace boost
484
485 #endif
486