diff --git a/promise.hpp b/promise.hpp index 72e6e4a..5c0d9ba 100644 --- a/promise.hpp +++ b/promise.hpp @@ -4,6 +4,7 @@ #include #include #include +#include #include #include #include @@ -14,9 +15,12 @@ namespace coro { template -struct async_awaiter; +class async_awaiter; -struct suspend_always : std::suspend_always {}; +template +class parallel_awaiter; + +struct always_awaiter : std::suspend_always {}; class static_thread_pool; @@ -30,52 +34,64 @@ enum class task_type { }; namespace this_scheduler { -inline suspend_always yield; +inline always_awaiter yield; + +template +parallel_awaiter parallel(task) noexcept; } // namespace this_scheduler namespace details_ { -struct coro_context { - std::queue> ctx_wait_que_; - std::mutex ctx_mu_; -}; +template +struct final_awaiter; template class task_base; template -struct final_suspender; +class promise_base : public std::promise { + public: + struct shared_ctx_t { + std::mutex mu; + std::coroutine_handle<> wait_coro; + bool done{false}; + // Mark this coroutine has been enqueued to the scheduler. Only awaiter will + // use it, no need be protected by mutex. + bool has_scheduled{false}; + bool suspend_on_final{true}; + static_thread_pool* scheduler{nullptr}; + }; -template -struct promise_base : std::promise, protected coro_context { task get_return_object() noexcept; std::suspend_always initial_suspend() const noexcept { return {}; } - inline final_suspender final_suspend() noexcept; + inline final_awaiter final_suspend() noexcept; template void await_transform(Awaiter) = delete; - suspend_always await_transform(suspend_always) { return suspend_always{}; } + always_awaiter await_transform(always_awaiter) { return always_awaiter{}; } template inline async_awaiter await_transform(task) noexcept; + template + inline parallel_awaiter await_transform(parallel_awaiter) noexcept; + void unhandled_exception() noexcept { this->set_exception(std::current_exception()); } protected: template - friend struct promise_base; + friend class promise_base; - promise_base() : scheduler_(nullptr), suspend_on_final_(true) {} + promise_base() : shared_ctx_(std::make_shared()) {} - static_thread_pool* scheduler_; - bool suspend_on_final_; + std::shared_ptr shared_ctx_; template - friend struct final_suspender; + friend struct final_awaiter; }; } // namespace details_ @@ -95,7 +111,7 @@ class promise : public details_::promise_base { friend class details_::task_base; template - friend struct async_awaiter; + friend class async_awaiter; template requires(!std::is_reference_v) @@ -114,7 +130,7 @@ class promise : public details_::promise_base { friend class details_::task_base; template - friend struct async_awaiter; + friend class async_awaiter; friend class static_thread_pool; }; @@ -147,8 +163,8 @@ class task_base { void wait() const { if (typ_ == task_type::deferred) { - while (!done()) { - resume(); + while (!handle_.done()) { + handle_.resume(); } } else { fu_.wait(); @@ -181,15 +197,11 @@ class task_base { } } - void resume() const { handle_.resume(); } - - bool done() const noexcept { return handle_.done(); } - - void destroy() const { handle_.destroy(); } - std::coroutine_handle> release_coroutine_handle() noexcept { + if (typ_ == task_type::async) { + return nullptr; + } auto h = handle_; - handle_ = nullptr; typ_ = task_type::async; return h; } @@ -200,6 +212,7 @@ class task_base { : handle_(h), typ_(t) { if (h) { fu_ = h.promise().get_future(); + shared_ctx_ = h.promise().shared_ctx_; } } @@ -208,6 +221,7 @@ class task_base { task_base& operator=(task_base&& rhs) noexcept { handle_ = rhs.handle_; fu_ = std::move(rhs.fu_); + shared_ctx_ = std::move(rhs.shared_ctx_); typ_ = rhs.typ_; rhs.handle_ = nullptr; return *this; @@ -217,6 +231,7 @@ class task_base { std::coroutine_handle> handle_; std::future fu_; + std::shared_ptr::shared_ctx_t> shared_ctx_; task_type typ_; }; } // namespace details_ @@ -236,7 +251,7 @@ class task : public details_::task_base { : details_::task_base(h, t) {} template - friend struct details_::promise_base; + friend class details_::promise_base; friend class static_thread_pool; }; @@ -281,13 +296,16 @@ class static_thread_pool { private: template - friend struct details_::final_suspender; + friend class async_awaiter; + + template + friend struct details_::final_awaiter; template void schedule(std::coroutine_handle handle) { if constexpr (is_promise_v) { - handle.promise().suspend_on_final_ = false; - handle.promise().scheduler_ = this; + handle.promise().shared_ctx_->suspend_on_final = false; + handle.promise().shared_ctx_->scheduler = this; } { std::unique_lock l(mu_); @@ -320,54 +338,92 @@ class static_thread_pool { }; template -struct async_awaiter : public coro::task { +class async_awaiter : protected coro::task { + public: bool await_ready() const noexcept { + if (this->typ_ == task_type::deferred) { + return true; // jump to call this->get() directly. + } using namespace std::chrono_literals; return this->wait_for(0s) == std::future_status::ready; } - void await_suspend(std::coroutine_handle<> h) { - assert(this->typ_ == task_type::async); - promise& promise = this->get_promise(); + bool await_suspend(std::coroutine_handle<> h) { + // The callee might resume caller in the future, and result in destruction + // of the caller frame. Which means the awaiter will be destructed too. + // Therefore, we cannot use `this` next. + return maybe_suspend(suspend_, this->shared_ctx_, this->handle_, h); + } + + Tp await_resume() { return this->get(); } + + protected: + template + friend class details_::promise_base; + + explicit async_awaiter(task t, bool suspend = true) noexcept + : task(std::move(t)), suspend_(suspend) {} + + static bool maybe_suspend( + bool need_suspend, + std::shared_ptr::shared_ctx_t> callee_ctx, + std::coroutine_handle> callee, + std::coroutine_handle<> caller) { + bool has_scheduled = false; + bool done = false; { - std::unique_lock l(promise.ctx_mu_); - promise.ctx_wait_que_.push(h); - } - if (promise.scheduler_) { - promise.scheduler_->schedule(static_cast&>(*this)); - } else { - while (!this->done()) { - this->resume(); + std::unique_lock l(callee_ctx->mu); + if (done = callee_ctx->done; !done && need_suspend) { + callee_ctx->wait_coro = caller; + } + has_scheduled = callee_ctx->has_scheduled; + if (!has_scheduled && callee_ctx->scheduler) { + callee_ctx->has_scheduled = true; } - this->destroy(); } + if (!has_scheduled && callee_ctx->scheduler) { + callee_ctx->scheduler->schedule(callee); + } + return !done && need_suspend; } - Tp await_resume() { return this->get(); } + bool suspend_; +}; + +template +class parallel_awaiter : public async_awaiter { + public: + task await_resume() noexcept { + return std::move(static_cast&>(*this)); + } + + private: + template + friend class details_::promise_base; + + template + friend parallel_awaiter this_scheduler::parallel(task) noexcept; + + explicit parallel_awaiter(task t) noexcept + : async_awaiter(std::move(t), false /*not suspend*/) {} }; template -struct details_::final_suspender { +struct details_::final_awaiter { constexpr bool await_ready() const noexcept { return false; } bool await_suspend(std::coroutine_handle<>) const noexcept { - std::unique_lock l(self->ctx_mu_); - auto wait_que = std::move(self->ctx_wait_que_); - auto sch = self->scheduler_; - bool suspend_on_final = self->suspend_on_final_; - // Current coroutine might be destroyed by the coroutines resumed in this - // function. Promise object (which is stored in coroutine frame) is not - // allowed to use. - while (!wait_que.empty()) { - auto handle = wait_que.front(); - wait_que.pop(); - if (sch) { - sch->schedule(handle); + auto shared_ctx = std::move(self->shared_ctx_); + std::unique_lock l(shared_ctx->mu); + shared_ctx->done = true; + if (shared_ctx->wait_coro) { + if (shared_ctx->scheduler) { + shared_ctx->scheduler->schedule(shared_ctx->wait_coro); } else { - handle.resume(); + shared_ctx->wait_coro.resume(); } } - return suspend_on_final; + return shared_ctx->suspend_on_final; } constexpr void await_resume() const noexcept {} @@ -379,15 +435,33 @@ template template inline async_awaiter details_::promise_base::await_transform( task t) noexcept { - t.get_promise().scheduler_ = scheduler_; - t.typ_ = task_type::async; - return async_awaiter{std::move(t)}; + t.shared_ctx_->scheduler = shared_ctx_->scheduler; + if (shared_ctx_->scheduler) { + t.typ_ = task_type::async; + } + return async_awaiter(std::move(t)); +} + +template +template +inline parallel_awaiter details_::promise_base::await_transform( + parallel_awaiter awaiter) noexcept { + awaiter.shared_ctx_->scheduler = shared_ctx_->scheduler; + if (shared_ctx_->scheduler) { + awaiter.typ_ = task_type::async; + } + return std::move(awaiter); } template -inline details_::final_suspender +inline details_::final_awaiter details_::promise_base::final_suspend() noexcept { return {this}; } +template +parallel_awaiter this_scheduler::parallel(task t) noexcept { + return parallel_awaiter(std::move(t)); +} + } // namespace coro \ No newline at end of file diff --git a/test.cc b/test.cc index 8202529..129a78e 100644 --- a/test.cc +++ b/test.cc @@ -1,5 +1,7 @@ +#include #include +#include #include #include @@ -23,6 +25,18 @@ coro::task foo(std::string v) { co_return v; } +coro::task slow_response(int a, int b) { + using namespace std::chrono_literals; + auto request = [](int v) -> coro::task { + std::this_thread::sleep_for(1s); + co_return v; + }; + coro::task resp1 = co_await coro::this_scheduler::parallel(request(a)); + coro::task resp2 = co_await coro::this_scheduler::parallel(request(b)); + std::this_thread::sleep_for(1s); + co_return co_await std::move(resp1) + co_await std::move(resp2); +} + int main() { auto fib = fibonacci(1); fib.wait(); @@ -51,4 +65,10 @@ int main() { coro::static_thread_pool pool(3); std::cout << pool.schedule(fibonacci(10)).get() << std::endl; + + auto start = std::chrono::steady_clock::now(); + std::cout << pool.schedule(slow_response(1, 2)).get() << std::endl; + auto elapsed = std::chrono::duration_cast( + std::chrono::steady_clock::now() - start); + std::cout << "elapsed time: " << elapsed << '\n'; } \ No newline at end of file