Skip to content

Commit

Permalink
add latch
Browse files Browse the repository at this point in the history
  • Loading branch information
coyorkdow committed Nov 4, 2024
1 parent 0e3c000 commit b53e8b2
Show file tree
Hide file tree
Showing 2 changed files with 106 additions and 11 deletions.
90 changes: 82 additions & 8 deletions cosched.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
#pragma once
#include <atomic>
#include <cassert>
#include <concepts>
#include <condition_variable>
#include <coroutine>
#include <cstddef>
Expand Down Expand Up @@ -43,6 +45,8 @@ enum class task_type {
async,
};

class latch;

namespace this_scheduler {
inline always_awaiter yield;

Expand All @@ -51,6 +55,30 @@ parallel_awaiter<Tp> parallel(task<Tp>) noexcept;
} // namespace this_scheduler

namespace details_ {

struct cutex_wait_context {
static_thread_pool* scheduler{nullptr};
std::deque<std::coroutine_handle<>> wait_ques;
std::mutex mu;
};

template <class Pred>
requires requires(Pred f, std::coroutine_handle<> h) {
{ f(h) } -> std::same_as<bool>;
}
class condition_awaiter : public std::suspend_always {
public:
explicit condition_awaiter(Pred f) : f_(f) {}

bool await_suspend(std::coroutine_handle<> h) const { return f_(h); }

private:
Pred f_;
};

template <class Tp>
condition_awaiter(Tp) -> condition_awaiter<Tp>;

template <class Tp>
struct final_awaiter;

Expand Down Expand Up @@ -78,7 +106,9 @@ class promise_base : public std::promise<Tp> {
template <class Awaiter>
void await_transform(Awaiter) = delete;

always_awaiter await_transform(always_awaiter) {
inline auto await_transform(latch&) noexcept;

always_awaiter await_transform(always_awaiter) noexcept {
return always_awaiter{shared_ctx_->scheduler};
}

Expand Down Expand Up @@ -253,7 +283,7 @@ class task_base {
};
} // namespace details_

template <class Tp>
template <class Tp = void>
requires(!std::is_reference_v<Tp>)
class task : public details_::task_base<Tp> {
public:
Expand Down Expand Up @@ -320,6 +350,8 @@ class static_thread_pool {
template <class Tp>
friend struct details_::final_awaiter;

friend class latch;

template <class Tp>
void schedule(std::coroutine_handle<Tp> handle) {
if constexpr (is_promise_v<Tp>) {
Expand Down Expand Up @@ -456,6 +488,54 @@ struct details_::final_awaiter {
promise_base<Tp>* self;
};

template <class Tp>
inline details_::final_awaiter<Tp>
details_::promise_base<Tp>::final_suspend() noexcept {
return {this};
}

class latch {
public:
explicit latch(std::ptrdiff_t countdown) : countdown_(countdown) {}
latch(const latch&) = delete;
latch& operator=(const latch&) = delete;

void count_down(std::ptrdiff_t n = 1) {
auto before = countdown_.fetch_sub(n, std::memory_order_acq_rel);
if (before > 0 && before - n <= 0) {
std::unique_lock l(wait_ctx_.mu);
if (!wait_ctx_.scheduler) return;
for (auto h : wait_ctx_.wait_ques) {
wait_ctx_.scheduler->schedule(h);
}
}
}

private:
template <class Tp>
friend class details_::promise_base;

auto wait_this(static_thread_pool* scheduler) {
return [scheduler, this](std::coroutine_handle<> h) -> bool {
std::unique_lock l(wait_ctx_.mu);
if (countdown_.load(std::memory_order_acquire) <= 0) {
return false;
}
wait_ctx_.scheduler = scheduler;
wait_ctx_.wait_ques.push_back(h);
return true;
};
}

std::atomic<ptrdiff_t> countdown_;
details_::cutex_wait_context wait_ctx_;
};

template <class Tp>
inline auto details_::promise_base<Tp>::await_transform(latch& l) noexcept {
return condition_awaiter(l.wait_this(shared_ctx_->scheduler));
}

template <class Tp>
template <class Up>
inline async_awaiter<Up> details_::promise_base<Tp>::await_transform(
Expand All @@ -478,12 +558,6 @@ inline parallel_awaiter<Up> details_::promise_base<Tp>::await_transform(
return std::move(awaiter);
}

template <class Tp>
inline details_::final_awaiter<Tp>
details_::promise_base<Tp>::final_suspend() noexcept {
return {this};
}

template <class Tp>
parallel_awaiter<Tp> this_scheduler::parallel(task<Tp> t) noexcept {
return parallel_awaiter<Tp>(std::move(t));
Expand Down
27 changes: 24 additions & 3 deletions test.cc
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@

#include <asm-generic/errno.h>
#include <gtest/gtest.h>
#include <sys/types.h>

#include <atomic>
#include <chrono>
#include <coroutine>
#include <iostream>
#include <mutex>
#include <ostream>
Expand Down Expand Up @@ -118,3 +117,25 @@ TEST(StaticThreadPoolTest, Yield) {
auto task = pool.schedule(yield_some(10));
EXPECT_EQ("complete", task.get());
}

TEST(StaticThreadPoolTest, Latch) {
using namespace std::chrono_literals;
std::atomic<int> cnt{0};
coro::latch l(2);
auto count_up = [](std::atomic<int>& cnt, coro::latch& l) -> coro::task<> {
co_await l;
cnt.fetch_add(1);
co_return;
};

coro::static_thread_pool pool(1);
pool.schedule(count_up(cnt, l));
pool.schedule(count_up(cnt, l));
std::this_thread::sleep_for(5ms);
EXPECT_EQ(0, cnt.load());
l.count_down();
EXPECT_EQ(0, cnt.load());
l.count_down();
std::this_thread::sleep_for(5ms);
EXPECT_EQ(2, cnt.load());
}

0 comments on commit b53e8b2

Please sign in to comment.