1  
//
1  
//
2  
// Copyright (c) 2025 Vinnie Falco (vinnie dot falco at gmail dot com)
2  
// Copyright (c) 2025 Vinnie Falco (vinnie dot falco at gmail dot com)
3  
//
3  
//
4  
// Distributed under the Boost Software License, Version 1.0. (See accompanying
4  
// Distributed under the Boost Software License, Version 1.0. (See accompanying
5  
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
5  
// file LICENSE_1_0.txt or copy at http://www.boost.org/LICENSE_1_0.txt)
6  
//
6  
//
7  
// Official repository: https://github.com/cppalliance/capy
7  
// Official repository: https://github.com/cppalliance/capy
8  
//
8  
//
9  

9  

10  
#include "src/ex/detail/strand_queue.hpp"
10  
#include "src/ex/detail/strand_queue.hpp"
11  
#include <boost/capy/ex/detail/strand_service.hpp>
11  
#include <boost/capy/ex/detail/strand_service.hpp>
12  
#include <boost/capy/coro.hpp>
12  
#include <boost/capy/coro.hpp>
13  

13  

14  
#include <atomic>
14  
#include <atomic>
15  
#include <coroutine>
15  
#include <coroutine>
16  
#include <mutex>
16  
#include <mutex>
17  
#include <thread>
17  
#include <thread>
18  
#include <utility>
18  
#include <utility>
19  

19  

20  
namespace boost {
20  
namespace boost {
21  
namespace capy {
21  
namespace capy {
22  
namespace detail {
22  
namespace detail {
23  

23  

24  
//----------------------------------------------------------
24  
//----------------------------------------------------------
25  

25  

26  
/** Implementation state for a strand.
26  
/** Implementation state for a strand.
27  

27  

28  
    Each strand_impl provides serialization for coroutines
28  
    Each strand_impl provides serialization for coroutines
29  
    dispatched through strands that share it.
29  
    dispatched through strands that share it.
30  
*/
30  
*/
31  
struct strand_impl
31  
struct strand_impl
32  
{
32  
{
33  
    std::mutex mutex_;
33  
    std::mutex mutex_;
34  
    strand_queue pending_;
34  
    strand_queue pending_;
35  
    bool locked_ = false;
35  
    bool locked_ = false;
36  
    std::atomic<std::thread::id> dispatch_thread_{};
36  
    std::atomic<std::thread::id> dispatch_thread_{};
37  
    void* cached_frame_ = nullptr;
37  
    void* cached_frame_ = nullptr;
38  
};
38  
};
39  

39  

40  
//----------------------------------------------------------
40  
//----------------------------------------------------------
41  

41  

42  
/** Invoker coroutine for strand dispatch.
42  
/** Invoker coroutine for strand dispatch.
43  

43  

44  
    Uses custom allocator to recycle frame - one allocation
44  
    Uses custom allocator to recycle frame - one allocation
45  
    per strand_impl lifetime, stored in trailer for recovery.
45  
    per strand_impl lifetime, stored in trailer for recovery.
46  
*/
46  
*/
47  
struct strand_invoker
47  
struct strand_invoker
48  
{
48  
{
49  
    struct promise_type
49  
    struct promise_type
50  
    {
50  
    {
51  
        void* operator new(std::size_t n, strand_impl& impl)
51  
        void* operator new(std::size_t n, strand_impl& impl)
52  
        {
52  
        {
53  
            constexpr auto A = alignof(strand_impl*);
53  
            constexpr auto A = alignof(strand_impl*);
54  
            std::size_t padded = (n + A - 1) & ~(A - 1);
54  
            std::size_t padded = (n + A - 1) & ~(A - 1);
55  
            std::size_t total = padded + sizeof(strand_impl*);
55  
            std::size_t total = padded + sizeof(strand_impl*);
56  

56  

57  
            void* p = impl.cached_frame_
57  
            void* p = impl.cached_frame_
58  
                ? std::exchange(impl.cached_frame_, nullptr)
58  
                ? std::exchange(impl.cached_frame_, nullptr)
59  
                : ::operator new(total);
59  
                : ::operator new(total);
60  

60  

61  
            // Trailer lets delete recover impl
61  
            // Trailer lets delete recover impl
62  
            *reinterpret_cast<strand_impl**>(
62  
            *reinterpret_cast<strand_impl**>(
63  
                static_cast<char*>(p) + padded) = &impl;
63  
                static_cast<char*>(p) + padded) = &impl;
64  
            return p;
64  
            return p;
65  
        }
65  
        }
66  

66  

67  
        void operator delete(void* p, std::size_t n) noexcept
67  
        void operator delete(void* p, std::size_t n) noexcept
68  
        {
68  
        {
69  
            constexpr auto A = alignof(strand_impl*);
69  
            constexpr auto A = alignof(strand_impl*);
70  
            std::size_t padded = (n + A - 1) & ~(A - 1);
70  
            std::size_t padded = (n + A - 1) & ~(A - 1);
71  

71  

72  
            auto* impl = *reinterpret_cast<strand_impl**>(
72  
            auto* impl = *reinterpret_cast<strand_impl**>(
73  
                static_cast<char*>(p) + padded);
73  
                static_cast<char*>(p) + padded);
74  

74  

75  
            if (!impl->cached_frame_)
75  
            if (!impl->cached_frame_)
76  
                impl->cached_frame_ = p;
76  
                impl->cached_frame_ = p;
77  
            else
77  
            else
78  
                ::operator delete(p);
78  
                ::operator delete(p);
79  
        }
79  
        }
80  

80  

81  
        strand_invoker get_return_object() noexcept
81  
        strand_invoker get_return_object() noexcept
82  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
82  
        { return {std::coroutine_handle<promise_type>::from_promise(*this)}; }
83  

83  

84  
        std::suspend_always initial_suspend() noexcept { return {}; }
84  
        std::suspend_always initial_suspend() noexcept { return {}; }
85  
        std::suspend_never final_suspend() noexcept { return {}; }
85  
        std::suspend_never final_suspend() noexcept { return {}; }
86  
        void return_void() noexcept {}
86  
        void return_void() noexcept {}
87  
        void unhandled_exception() { std::terminate(); }
87  
        void unhandled_exception() { std::terminate(); }
88  
    };
88  
    };
89  

89  

90  
    std::coroutine_handle<promise_type> h_;
90  
    std::coroutine_handle<promise_type> h_;
91  
};
91  
};
92  

92  

93  
//----------------------------------------------------------
93  
//----------------------------------------------------------
94  

94  

95  
/** Concrete implementation of strand_service.
95  
/** Concrete implementation of strand_service.
96  

96  

97  
    Holds the fixed pool of strand_impl objects.
97  
    Holds the fixed pool of strand_impl objects.
98  
*/
98  
*/
99  
class strand_service_impl : public strand_service
99  
class strand_service_impl : public strand_service
100  
{
100  
{
101  
    static constexpr std::size_t num_impls = 211;
101  
    static constexpr std::size_t num_impls = 211;
102  

102  

103  
    strand_impl impls_[num_impls];
103  
    strand_impl impls_[num_impls];
104  
    std::size_t salt_ = 0;
104  
    std::size_t salt_ = 0;
105  
    std::mutex mutex_;
105  
    std::mutex mutex_;
106  

106  

107  
public:
107  
public:
108  
    explicit
108  
    explicit
109  
    strand_service_impl(execution_context&)
109  
    strand_service_impl(execution_context&)
110  
    {
110  
    {
111  
    }
111  
    }
112  

112  

113  
    strand_impl*
113  
    strand_impl*
114  
    get_implementation() override
114  
    get_implementation() override
115  
    {
115  
    {
116  
        std::lock_guard<std::mutex> lock(mutex_);
116  
        std::lock_guard<std::mutex> lock(mutex_);
117  
        std::size_t index = salt_++;
117  
        std::size_t index = salt_++;
118  
        index = index % num_impls;
118  
        index = index % num_impls;
119  
        return &impls_[index];
119  
        return &impls_[index];
120  
    }
120  
    }
121  

121  

122  
protected:
122  
protected:
123  
    void
123  
    void
124  
    shutdown() override
124  
    shutdown() override
125  
    {
125  
    {
126  
        for(std::size_t i = 0; i < num_impls; ++i)
126  
        for(std::size_t i = 0; i < num_impls; ++i)
127  
        {
127  
        {
128  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
128  
            std::lock_guard<std::mutex> lock(impls_[i].mutex_);
129  
            impls_[i].locked_ = true;
129  
            impls_[i].locked_ = true;
130  

130  

131  
            if(impls_[i].cached_frame_)
131  
            if(impls_[i].cached_frame_)
132  
            {
132  
            {
133  
                ::operator delete(impls_[i].cached_frame_);
133  
                ::operator delete(impls_[i].cached_frame_);
134  
                impls_[i].cached_frame_ = nullptr;
134  
                impls_[i].cached_frame_ = nullptr;
135  
            }
135  
            }
136  
        }
136  
        }
137  
    }
137  
    }
138  

138  

139  
private:
139  
private:
140  
    static bool
140  
    static bool
141  
    enqueue(strand_impl& impl, coro h)
141  
    enqueue(strand_impl& impl, coro h)
142  
    {
142  
    {
143  
        std::lock_guard<std::mutex> lock(impl.mutex_);
143  
        std::lock_guard<std::mutex> lock(impl.mutex_);
144  
        impl.pending_.push(h);
144  
        impl.pending_.push(h);
145  
        if(!impl.locked_)
145  
        if(!impl.locked_)
146  
        {
146  
        {
147  
            impl.locked_ = true;
147  
            impl.locked_ = true;
148  
            return true;
148  
            return true;
149  
        }
149  
        }
150  
        return false;
150  
        return false;
151  
    }
151  
    }
152  

152  

153  
    static void
153  
    static void
154  
    dispatch_pending(strand_impl& impl)
154  
    dispatch_pending(strand_impl& impl)
155  
    {
155  
    {
156  
        strand_queue::taken_batch batch;
156  
        strand_queue::taken_batch batch;
157  
        {
157  
        {
158  
            std::lock_guard<std::mutex> lock(impl.mutex_);
158  
            std::lock_guard<std::mutex> lock(impl.mutex_);
159  
            batch = impl.pending_.take_all();
159  
            batch = impl.pending_.take_all();
160  
        }
160  
        }
161  
        impl.pending_.dispatch_batch(batch);
161  
        impl.pending_.dispatch_batch(batch);
162  
    }
162  
    }
163  

163  

164  
    static bool
164  
    static bool
165  
    try_unlock(strand_impl& impl)
165  
    try_unlock(strand_impl& impl)
166  
    {
166  
    {
167  
        std::lock_guard<std::mutex> lock(impl.mutex_);
167  
        std::lock_guard<std::mutex> lock(impl.mutex_);
168  
        if(impl.pending_.empty())
168  
        if(impl.pending_.empty())
169  
        {
169  
        {
170  
            impl.locked_ = false;
170  
            impl.locked_ = false;
171  
            return true;
171  
            return true;
172  
        }
172  
        }
173  
        return false;
173  
        return false;
174  
    }
174  
    }
175  

175  

176  
    static void
176  
    static void
177  
    set_dispatch_thread(strand_impl& impl) noexcept
177  
    set_dispatch_thread(strand_impl& impl) noexcept
178  
    {
178  
    {
179  
        impl.dispatch_thread_.store(std::this_thread::get_id());
179  
        impl.dispatch_thread_.store(std::this_thread::get_id());
180  
    }
180  
    }
181  

181  

182  
    static void
182  
    static void
183  
    clear_dispatch_thread(strand_impl& impl) noexcept
183  
    clear_dispatch_thread(strand_impl& impl) noexcept
184  
    {
184  
    {
185  
        impl.dispatch_thread_.store(std::thread::id{});
185  
        impl.dispatch_thread_.store(std::thread::id{});
186  
    }
186  
    }
187  

187  

188  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
188  
    // Loops until queue empty (aggressive). Alternative: per-batch fairness
189  
    // (repost after each batch to let other work run) - explore if starvation observed.
189  
    // (repost after each batch to let other work run) - explore if starvation observed.
190  
    static strand_invoker
190  
    static strand_invoker
191  
    make_invoker(strand_impl& impl)
191  
    make_invoker(strand_impl& impl)
192  
    {
192  
    {
193  
        strand_impl* p = &impl;
193  
        strand_impl* p = &impl;
194  
        for(;;)
194  
        for(;;)
195  
        {
195  
        {
196  
            set_dispatch_thread(*p);
196  
            set_dispatch_thread(*p);
197  
            dispatch_pending(*p);
197  
            dispatch_pending(*p);
198  
            if(try_unlock(*p))
198  
            if(try_unlock(*p))
199  
            {
199  
            {
200  
                clear_dispatch_thread(*p);
200  
                clear_dispatch_thread(*p);
201  
                co_return;
201  
                co_return;
202  
            }
202  
            }
203  
        }
203  
        }
204  
    }
204  
    }
205  

205  

206  
    friend class strand_service;
206  
    friend class strand_service;
207  
};
207  
};
208  

208  

209  
//----------------------------------------------------------
209  
//----------------------------------------------------------
210  

210  

211  
strand_service::
211  
strand_service::
212  
strand_service()
212  
strand_service()
213  
    : service()
213  
    : service()
214  
{
214  
{
215  
}
215  
}
216  

216  

217  
strand_service::
217  
strand_service::
218  
~strand_service() = default;
218  
~strand_service() = default;
219  

219  

220  
bool
220  
bool
221  
strand_service::
221  
strand_service::
222  
running_in_this_thread(strand_impl& impl) noexcept
222  
running_in_this_thread(strand_impl& impl) noexcept
223  
{
223  
{
224  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
224  
    return impl.dispatch_thread_.load() == std::this_thread::get_id();
225  
}
225  
}
226  

226  

227  
coro
227  
coro
228  
strand_service::
228  
strand_service::
229  
dispatch(strand_impl& impl, executor_ref ex, coro h)
229  
dispatch(strand_impl& impl, executor_ref ex, coro h)
230  
{
230  
{
231  
    if(running_in_this_thread(impl))
231  
    if(running_in_this_thread(impl))
232  
        return h;
232  
        return h;
233  

233  

234  
    if(strand_service_impl::enqueue(impl, h))
234  
    if(strand_service_impl::enqueue(impl, h))
235  
        ex.post(strand_service_impl::make_invoker(impl).h_);
235  
        ex.post(strand_service_impl::make_invoker(impl).h_);
236  

236  

237  
    return std::noop_coroutine();
237  
    return std::noop_coroutine();
238  
}
238  
}
239  

239  

240  
void
240  
void
241  
strand_service::
241  
strand_service::
242  
post(strand_impl& impl, executor_ref ex, coro h)
242  
post(strand_impl& impl, executor_ref ex, coro h)
243  
{
243  
{
244  
    if(strand_service_impl::enqueue(impl, h))
244  
    if(strand_service_impl::enqueue(impl, h))
245  
        ex.post(strand_service_impl::make_invoker(impl).h_);
245  
        ex.post(strand_service_impl::make_invoker(impl).h_);
246  
}
246  
}
247  

247  

248  
strand_service&
248  
strand_service&
249  
get_strand_service(execution_context& ctx)
249  
get_strand_service(execution_context& ctx)
250  
{
250  
{
251  
    return ctx.use_service<strand_service_impl>();
251  
    return ctx.use_service<strand_service_impl>();
252  
}
252  
}
253  

253  

254  
} // namespace detail
254  
} // namespace detail
255  
} // namespace capy
255  
} // namespace capy
256  
} // namespace boost
256  
} // namespace boost