8#include <batmat/assume.hpp>
17#ifndef CYQLONE_SANITY_CHECKS_BARRIER
19#define CYQLONE_SANITY_CHECKS_BARRIER 1
21#define CYQLONE_SANITY_CHECKS_BARRIER 0
45template <
typename CompletionFn = EmptyCompletion,
class PhaseType = u
int32_t>
72 static_assert(
sizeof(T) <=
sizeof(
payload));
73 static_assert(std::is_trivially_copyable_v<T>);
74 std::memcpy(
payload.data(), &t,
sizeof(T));
78 static_assert(
sizeof(T) <=
sizeof(
payload));
79 static_assert(std::is_trivially_copyable_v<T>);
81 std::memcpy(&t,
payload.data(),
sizeof(T));
86 sizeof(T) <=
sizeof(
payload) && std::is_trivially_copyable_v<T>;
91 struct alignas(cache_line_size)
State {
95 atomic_word::is_always_lock_free && !atomic_byte::is_always_lock_free;
98 using ticket_t = std::conditional_t<only_word_lock_free, atomic_word, atomic_byte>;
111#if CYQLONE_SANITY_CHECKS_BARRIER
119 if (
get_local_phase(thread_id).fetch_add(1, std::memory_order_relaxed) !=
121 BATMAT_ASSERT(!
"This thread has already arrived in this phase");
129 static constexpr auto acq_rel = std::memory_order_acq_rel;
134 for (
size_t level = 0;; ++level) {
139 auto &ticket =
state[thread_id].tickets[level];
140 const uint32_t end_node = (level_size + 1) >> 1;
141 const bool last_odd = thread_id + 1 == end_node && (level_size & 1) == 1;
142 const auto target = last_odd ? first_of_one : second_of_two;
143 const auto old_value = ticket.fetch_add(1, acq_rel);
144 if (old_value != target)
146 level_size = end_node;
158 template <
class T,
class F>
160 static constexpr auto acq_rel = std::memory_order_acq_rel;
172 for (
size_t level = 0;; ++level) {
173 if (level_size <= 1) {
178 auto offset =
size_t{1} << level;
179 auto write = thread_id << level;
182 auto &ticket =
state[thread_id].tickets[level];
183 const uint32_t end_node = (level_size + 1) >> 1;
184 const bool last_odd = thread_id + 1 == end_node && (level_size & 1) == 1;
185 const auto target = last_odd ? first_of_one : second_of_two;
186 const auto old_value = ticket.fetch_add(1, acq_rel);
187 if (old_value != target)
191 storage[write | +offset].
template load<T>());
192 level_size = end_node;
201 template <
class A,
class C>
203 C &&custom_completion) {
205 const auto cur_phase =
phase.load(std::memory_order_relaxed);
206#if CYQLONE_SANITY_CHECKS_BARRIER
209 if (arrival(cur_phase, thread_id)) {
210 std::invoke(std::forward<C>(custom_completion));
211 auto next_phase =
static_cast<BarrierPhase>(
static_cast<PhaseType
>(cur_phase) + 1);
212 phase.store(next_phase, std::memory_order_release);
215 return arrival_token{cur_phase};
220 static constexpr uint32_t
max() {
221#if CYQLONE_SANITY_CHECKS_BARRIER
226 return num_levels > 31 ? 0xFFFFFFFF : uint32_t{1} << num_levels;
235 const size_t leaf_count = (
expected + 1) >> 1;
236 state = std::make_unique<State[]>(leaf_count);
250 auto arrival = [
this](
BarrierPhase cur_phase, uint32_t thread_id) {
259 [[nodiscard]] arrival_token
arrive(uint32_t thread_id) {
269 [[nodiscard]] arrival_token
arrive(uint32_t thread_id, [[maybe_unused]]
int line) {
270#if CYQLONE_SANITY_CHECKS_BARRIER
272 std::memory_order_relaxed);
274 for (uint32_t i = 0; i <
expected; ++i)
287 return phase.load(std::memory_order_relaxed);
301 return phase.load(std::memory_order_relaxed) == token.get();
308 void wait(arrival_token &&token)
const {
309 const auto old_phase = token.get();
311 if (
phase.load(std::memory_order_acquire) != old_phase) [[likely]]
314 for (
auto spin = this->spin_count; spin-- > 0;)
315 if (
phase.load(std::memory_order_acquire) != old_phase) [[unlikely]]
317 phase.wait(old_phase, std::memory_order_acquire);
326 requires std::is_void_v<std::invoke_result_t<C &&>>
333 requires(!std::is_void_v<std::invoke_result_t<C &&>> &&
334 !std::is_reference_v<std::invoke_result_t<C &&>> &&
335 Storage::template is_compatible<std::invoke_result_t<C &&>>)
337 using ret_t = std::invoke_result_t<C &&>;
339 [
this, c{std::forward<C>(custom_completion)}]
mutable {
347 template <
class T,
class F>
358 wait(std::move(token));
364 template <
class T,
class F>
372 [[nodiscard]] T
broadcast(uint32_t thread_id, T &&x, uint32_t src = 0) {
373 if (thread_id == src)
374 storage[thread_id].store(std::forward<T>(x));
377 auto custom_completion = [
this, src] {
BarrierPhase get() const noexcept
arrival_token & operator=(arrival_token &&phase)=default
arrival_token(BarrierPhase phase)
arrival_token & operator=(const arrival_token &phase)=delete
arrival_token(arrival_token &&phase)=default
arrival_token(const arrival_token &phase)=delete
static constexpr uint32_t max()
Maximum number of threads supported by this barrier implementation.
TreeBarrier(uint32_t expected, CompletionFn completion)
Create a barrier with expected participating threads and a completion function that is called by the last thread to arrive in each phase.
void arrive_and_wait_with_completion(uint32_t thread_id, C &&custom_completion)
Convenience function to arrive and wait in a single call (with custom completion).
bool wait_may_block(const arrival_token &token) const noexcept
Check if wait() may block.
std::unique_ptr< State[]> state
State::ticket_t & get_local_line(uint32_t thread_id) noexcept
bool arrive_impl(BarrierPhase old_phase, uint32_t thread_id, T value, F reduce)
Fused implementation of the combining tree arrival and a reduction operation.
static constexpr size_t cache_line_size
T wait_reduce(arrival_token_typed< T > &&token)
Wait for the result of an arrive_reduce call and obtain the reduced value.
BarrierPhase current_phase() const
Query the current barrier phase.
void arrive_and_wait(uint32_t thread_id, int line)
Convenience function to arrive and wait in a single call (with optional sanity check).
TreeBarrier(const TreeBarrier &)=delete
arrival_token_typed< T > arrive_reduce(uint32_t thread_id, T x, F reduce)
Combining tree reduction across all threads.
arrival_token arrive(uint32_t thread_id)
Arrive at the barrier.
completion_type completion
TreeBarrier(TreeBarrier &&)=default
arrival_token arrive_with_completion(uint32_t thread_id, C &&custom_completion)
Arrive at the barrier with a custom completion function that is called by the last thread that arrives.
void wait(arrival_token &&token) const
Wait for the barrier to complete after an arrival, using the given token.
auto arrive_and_wait_with_completion(uint32_t thread_id, C &&custom_completion)
Convenience function to arrive and wait in a single call (with custom completion).
T reduce(uint32_t thread_id, T x, F reduce)
arrival_token arrive_with_completion(uint32_t thread_id, A arrival, C &&custom_completion)
Generic implementation of arrive with custom completion function.
void sanity_check_arrival(uint32_t thread_id, BarrierPhase cur_phase) noexcept
arrival_token arrive(uint32_t thread_id, int line)
Arrive at the barrier, recording the given line number for sanity checking to make sure that all threads arrive at the same source line in each phase.
TreeBarrier & operator=(TreeBarrier &&)=default
void arrive_and_wait(uint32_t thread_id)
Convenience function to arrive and wait in a single call.
TreeBarrier & operator=(const TreeBarrier &)=delete
std::atomic< BarrierPhase > phase
State::ticket_t & get_local_phase(uint32_t thread_id) noexcept
Storage broadcast_storage
typename State::ticket_t::value_type ticket_value_type
T broadcast(uint32_t thread_id, T &&x, uint32_t src=0)
Broadcast a value from the source thread to all other threads.
bool arrive_impl(BarrierPhase old_phase, uint32_t thread_id)
Combining tree arrival.
std::unique_ptr< Storage[]> storage
No-op completion function for the TreeBarrier.
void operator()() const noexcept
Does nothing.
Atomic counters for each level of the combining tree.
std::atomic< uint32_t > atomic_word
std::array< ticket_t, num_levels > tickets
std::atomic< unsigned char > atomic_byte
static constexpr size_t num_levels
std::conditional_t< only_word_lock_free, atomic_word, atomic_byte > ticket_t
static constexpr bool only_word_lock_free
Storage for small values used in reductions and broadcasts.
static constexpr bool is_compatible
std::array< std::byte, cache_line_size > payload