Skip to content

Commit

Permalink
support yield
Browse files Browse the repository at this point in the history
  • Loading branch information
coyorkdow committed Nov 3, 2024
1 parent 447e549 commit ed41f78
Show file tree
Hide file tree
Showing 2 changed files with 55 additions and 6 deletions.
37 changes: 31 additions & 6 deletions cosched.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,25 @@

namespace coro {

class static_thread_pool;

template <class Tp>
class async_awaiter;

template <class Tp>
class parallel_awaiter;

struct always_awaiter : std::suspend_always {};
class always_awaiter : public std::suspend_always {
public:
always_awaiter() noexcept : scheduler_(nullptr) {}
explicit always_awaiter(static_thread_pool* pool) noexcept
: scheduler_(pool) {}

class static_thread_pool;
inline void await_suspend(std::coroutine_handle<>) const;

private:
static_thread_pool* scheduler_;
};

template <class Tp>
requires(!std::is_reference_v<Tp>)
Expand Down Expand Up @@ -54,8 +64,6 @@ class promise_base : public std::promise<Tp> {
std::mutex mu;
std::coroutine_handle<> wait_coro;
bool done{false};
// Mark this coroutine has been enqueued to the scheduler. Only awaiter will
// use it, no need be protected by mutex.
bool has_scheduled{false};
bool suspend_on_final{true};
static_thread_pool* scheduler{nullptr};
Expand All @@ -70,7 +78,9 @@ class promise_base : public std::promise<Tp> {
template <class Awaiter>
void await_transform(Awaiter) = delete;

always_awaiter await_transform(always_awaiter) { return always_awaiter{}; }
always_awaiter await_transform(always_awaiter) {
return always_awaiter{shared_ctx_->scheduler};
}

template <class Up>
inline async_awaiter<Up> await_transform(task<Up>) noexcept;
Expand Down Expand Up @@ -101,7 +111,8 @@ class promise : public details_::promise_base<Tp> {
public:
template <class Up>
void return_value(Up&& value)
requires(std::is_same_v<Tp, Up> || std::is_same_v<const Tp&, Up>)
requires(std::is_same_v<Tp, Up> || std::is_same_v<const Tp&, Up> ||
std::is_constructible_v<Tp, Up>)
{
this->set_value(std::forward<Up>(value));
}
Expand Down Expand Up @@ -206,6 +217,12 @@ class task_base {
return h;
}

~task_base() {
if (this->typ_ == task_type::deferred && this->fu_.valid()) {
get();
}
}

protected:
task_base(std::coroutine_handle<promise<Tp>> h = nullptr,
task_type t = task_type::deferred) noexcept
Expand Down Expand Up @@ -295,6 +312,8 @@ class static_thread_pool {
}

private:
friend class always_awaiter;

template <class Tp>
friend class async_awaiter;

Expand Down Expand Up @@ -337,6 +356,12 @@ class static_thread_pool {
bool exit_;
};

inline void always_awaiter::await_suspend(std::coroutine_handle<> h) const {
if (scheduler_) {
scheduler_->schedule(h);
}
}

template <class Tp>
class async_awaiter : protected coro::task<Tp> {
public:
Expand Down
24 changes: 24 additions & 0 deletions test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include <mutex>
#include <ostream>
#include <set>
#include <string>
#include <thread>

#include "cosched.hpp"
Expand Down Expand Up @@ -46,6 +47,9 @@ coro::task<int> slow_response(int a, int b) {
coro::task<int> resp1 = co_await coro::this_scheduler::parallel(request(a));
coro::task<int> resp2 = co_await coro::this_scheduler::parallel(request(b));
std::this_thread::sleep_for(1s);
auto immediate = co_await coro::this_scheduler::parallel(
[]() -> coro::task<void> { co_return; }());
co_await std::move(immediate);
co_return co_await std::move(resp1) + co_await std::move(resp2);
}

Expand Down Expand Up @@ -94,3 +98,23 @@ TEST(StaticThreadPoolTest, Parallel) {
std::cout << "elapsed time: " << elapsed.count() << "ms\n";
EXPECT_LE(elapsed.count(), 1100);
}

TEST(StaticThreadPoolTest, Yield) {
auto yield_some = [](int n) -> coro::task<std::string> {
while (n--) {
co_await coro::this_scheduler::yield;
}
co_return "complete";
};
EXPECT_EQ("complete", yield_some(10).get());
int n = 0;
for (auto h = yield_some(10).release_coroutine_handle();
!h.done() || (h.destroy(), false); h.resume()) {
n++;
}
EXPECT_EQ(11, n);

coro::static_thread_pool pool(1);
auto task = pool.schedule(yield_some(10));
EXPECT_EQ("complete", task.get());
}

0 comments on commit ed41f78

Please sign in to comment.