Skip to content

Commit

Permalink
WIP: invoke in parallel
Browse files Browse the repository at this point in the history
  • Loading branch information
coyorkdow committed Nov 3, 2024
1 parent a68cdad commit bfbb8ee
Show file tree
Hide file tree
Showing 2 changed files with 159 additions and 65 deletions.
204 changes: 139 additions & 65 deletions promise.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <coroutine>
#include <cstddef>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
Expand All @@ -14,9 +15,12 @@
namespace coro {

template <class Tp>
struct async_awaiter;
class async_awaiter;

struct suspend_always : std::suspend_always {};
template <class Tp>
class parallel_awaiter;

struct always_awaiter : std::suspend_always {};

class static_thread_pool;

Expand All @@ -30,52 +34,64 @@ enum class task_type {
};

namespace this_scheduler {
inline suspend_always yield;
inline always_awaiter yield;

template <class Tp>
parallel_awaiter<Tp> parallel(task<Tp>) noexcept;
} // namespace this_scheduler

namespace details_ {
struct coro_context {
std::queue<std::coroutine_handle<>> ctx_wait_que_;
std::mutex ctx_mu_;
};
template <class Tp>
struct final_awaiter;

template <class Tp>
class task_base;

template <class Tp>
struct final_suspender;
class promise_base : public std::promise<Tp> {
public:
struct shared_ctx_t {
std::mutex mu;
std::coroutine_handle<> wait_coro;
bool done{false};
// Mark this coroutine has been enqueued to the scheduler. Only awaiter will
// use it, no need be protected by mutex.
bool has_scheduled{false};
bool suspend_on_final{true};
static_thread_pool* scheduler{nullptr};
};

template <class Tp>
struct promise_base : std::promise<Tp>, protected coro_context {
task<Tp> get_return_object() noexcept;

std::suspend_always initial_suspend() const noexcept { return {}; }

inline final_suspender<Tp> final_suspend() noexcept;
inline final_awaiter<Tp> final_suspend() noexcept;

template <class Awaiter>
void await_transform(Awaiter) = delete;

suspend_always await_transform(suspend_always) { return suspend_always{}; }
always_awaiter await_transform(always_awaiter) { return always_awaiter{}; }

template <class Up>
inline async_awaiter<Up> await_transform(task<Up>) noexcept;

template <class Up>
inline parallel_awaiter<Up> await_transform(parallel_awaiter<Up>) noexcept;

void unhandled_exception() noexcept {
this->set_exception(std::current_exception());
}

protected:
template <class Up>
friend struct promise_base;
friend class promise_base;

promise_base() : scheduler_(nullptr), suspend_on_final_(true) {}
promise_base() : shared_ctx_(std::make_shared<shared_ctx_t>()) {}

static_thread_pool* scheduler_;
bool suspend_on_final_;
std::shared_ptr<shared_ctx_t> shared_ctx_;

template <class Up>
friend struct final_suspender;
friend struct final_awaiter;
};
} // namespace details_

Expand All @@ -95,7 +111,7 @@ class promise : public details_::promise_base<Tp> {
friend class details_::task_base;

template <class Up>
friend struct async_awaiter;
friend class async_awaiter;

template <class Up>
requires(!std::is_reference_v<Up>)
Expand All @@ -114,7 +130,7 @@ class promise<void> : public details_::promise_base<void> {
friend class details_::task_base;

template <class Up>
friend struct async_awaiter;
friend class async_awaiter;

friend class static_thread_pool;
};
Expand Down Expand Up @@ -147,8 +163,8 @@ class task_base {

void wait() const {
if (typ_ == task_type::deferred) {
while (!done()) {
resume();
while (!handle_.done()) {
handle_.resume();
}
} else {
fu_.wait();
Expand Down Expand Up @@ -181,15 +197,11 @@ class task_base {
}
}

void resume() const { handle_.resume(); }

bool done() const noexcept { return handle_.done(); }

void destroy() const { handle_.destroy(); }

std::coroutine_handle<promise<Tp>> release_coroutine_handle() noexcept {
if (typ_ == task_type::async) {
return nullptr;
}
auto h = handle_;
handle_ = nullptr;
typ_ = task_type::async;
return h;
}
Expand All @@ -200,6 +212,7 @@ class task_base {
: handle_(h), typ_(t) {
if (h) {
fu_ = h.promise().get_future();
shared_ctx_ = h.promise().shared_ctx_;
}
}

Expand All @@ -208,6 +221,7 @@ class task_base {
task_base& operator=(task_base&& rhs) noexcept {
handle_ = rhs.handle_;
fu_ = std::move(rhs.fu_);
shared_ctx_ = std::move(rhs.shared_ctx_);
typ_ = rhs.typ_;
rhs.handle_ = nullptr;
return *this;
Expand All @@ -217,6 +231,7 @@ class task_base {

std::coroutine_handle<promise<Tp>> handle_;
std::future<Tp> fu_;
std::shared_ptr<typename promise<Tp>::shared_ctx_t> shared_ctx_;
task_type typ_;
};
} // namespace details_
Expand All @@ -236,7 +251,7 @@ class task : public details_::task_base<Tp> {
: details_::task_base<Tp>(h, t) {}

template <class Rp>
friend struct details_::promise_base;
friend class details_::promise_base;

friend class static_thread_pool;
};
Expand Down Expand Up @@ -281,13 +296,16 @@ class static_thread_pool {

private:
template <class Tp>
friend struct details_::final_suspender;
friend class async_awaiter;

template <class Tp>
friend struct details_::final_awaiter;

template <class Tp>
void schedule(std::coroutine_handle<Tp> handle) {
if constexpr (is_promise_v<Tp>) {
handle.promise().suspend_on_final_ = false;
handle.promise().scheduler_ = this;
handle.promise().shared_ctx_->suspend_on_final = false;
handle.promise().shared_ctx_->scheduler = this;
}
{
std::unique_lock l(mu_);
Expand Down Expand Up @@ -320,54 +338,92 @@ class static_thread_pool {
};

template <class Tp>
struct async_awaiter : public coro::task<Tp> {
class async_awaiter : protected coro::task<Tp> {
public:
bool await_ready() const noexcept {
if (this->typ_ == task_type::deferred) {
return true; // jump to call this->get() directly.
}
using namespace std::chrono_literals;
return this->wait_for(0s) == std::future_status::ready;
}

void await_suspend(std::coroutine_handle<> h) {
assert(this->typ_ == task_type::async);
promise<Tp>& promise = this->get_promise();
bool await_suspend(std::coroutine_handle<> h) {
// The callee might resume caller in the future, and result in destruction
// of the caller frame. Which means the awaiter will be destructed too.
// Therefore, we cannot use `this` next.
return maybe_suspend(suspend_, this->shared_ctx_, this->handle_, h);
}

Tp await_resume() { return this->get(); }

protected:
template <class Up>
friend class details_::promise_base;

explicit async_awaiter(task<Tp> t, bool suspend = true) noexcept
: task<Tp>(std::move(t)), suspend_(suspend) {}

static bool maybe_suspend(
bool need_suspend,
std::shared_ptr<typename promise<Tp>::shared_ctx_t> callee_ctx,
std::coroutine_handle<promise<Tp>> callee,
std::coroutine_handle<> caller) {
bool has_scheduled = false;
bool done = false;
{
std::unique_lock l(promise.ctx_mu_);
promise.ctx_wait_que_.push(h);
}
if (promise.scheduler_) {
promise.scheduler_->schedule(static_cast<task<Tp>&>(*this));
} else {
while (!this->done()) {
this->resume();
std::unique_lock l(callee_ctx->mu);
if (done = callee_ctx->done; !done && need_suspend) {
callee_ctx->wait_coro = caller;
}
has_scheduled = callee_ctx->has_scheduled;
if (!has_scheduled && callee_ctx->scheduler) {
callee_ctx->has_scheduled = true;
}
this->destroy();
}
if (!has_scheduled && callee_ctx->scheduler) {
callee_ctx->scheduler->schedule(callee);
}
return !done && need_suspend;
}

Tp await_resume() { return this->get(); }
bool suspend_;
};

template <class Tp>
class parallel_awaiter : public async_awaiter<Tp> {
public:
task<Tp> await_resume() noexcept {
return std::move(static_cast<task<Tp>&>(*this));
}

private:
template <class Up>
friend class details_::promise_base;

template <class Up>
friend parallel_awaiter<Up> this_scheduler::parallel(task<Up>) noexcept;

explicit parallel_awaiter(task<Tp> t) noexcept
: async_awaiter<Tp>(std::move(t), false /*not suspend*/) {}
};

template <class Tp>
struct details_::final_suspender {
struct details_::final_awaiter {
constexpr bool await_ready() const noexcept { return false; }

bool await_suspend(std::coroutine_handle<>) const noexcept {
std::unique_lock l(self->ctx_mu_);
auto wait_que = std::move(self->ctx_wait_que_);
auto sch = self->scheduler_;
bool suspend_on_final = self->suspend_on_final_;
// Current coroutine might be destroyed by the coroutines resumed in this
// function. Promise object (which is stored in coroutine frame) is not
// allowed to use.
while (!wait_que.empty()) {
auto handle = wait_que.front();
wait_que.pop();
if (sch) {
sch->schedule(handle);
auto shared_ctx = std::move(self->shared_ctx_);
std::unique_lock l(shared_ctx->mu);
shared_ctx->done = true;
if (shared_ctx->wait_coro) {
if (shared_ctx->scheduler) {
shared_ctx->scheduler->schedule(shared_ctx->wait_coro);
} else {
handle.resume();
shared_ctx->wait_coro.resume();
}
}
return suspend_on_final;
return shared_ctx->suspend_on_final;
}

constexpr void await_resume() const noexcept {}
Expand All @@ -379,15 +435,33 @@ template <class Tp>
template <class Up>
inline async_awaiter<Up> details_::promise_base<Tp>::await_transform(
task<Up> t) noexcept {
t.get_promise().scheduler_ = scheduler_;
t.typ_ = task_type::async;
return async_awaiter<Up>{std::move(t)};
t.shared_ctx_->scheduler = shared_ctx_->scheduler;
if (shared_ctx_->scheduler) {
t.typ_ = task_type::async;
}
return async_awaiter<Up>(std::move(t));
}

template <class Tp>
template <class Up>
inline parallel_awaiter<Up> details_::promise_base<Tp>::await_transform(
parallel_awaiter<Up> awaiter) noexcept {
awaiter.shared_ctx_->scheduler = shared_ctx_->scheduler;
if (shared_ctx_->scheduler) {
awaiter.typ_ = task_type::async;
}
return std::move(awaiter);
}

template <class Tp>
inline details_::final_suspender<Tp>
inline details_::final_awaiter<Tp>
details_::promise_base<Tp>::final_suspend() noexcept {
return {this};
}

template <class Tp>
parallel_awaiter<Tp> this_scheduler::parallel(task<Tp> t) noexcept {
return parallel_awaiter<Tp>(std::move(t));
}

} // namespace coro
Loading

0 comments on commit bfbb8ee

Please sign in to comment.