From ddf346e8e94a6a168fd1a15e9522c8ba62863287 Mon Sep 17 00:00:00 2001 From: microcai Date: Sun, 20 Oct 2024 09:58:57 +0800 Subject: [PATCH] =?UTF-8?q?=E6=8A=8A=20awaitable=20=E5=92=8C=20awaiter=20?= =?UTF-8?q?=E7=9A=84=E6=A6=82=E5=BF=B5=E8=BF=9B=E8=A1=8C=E5=88=86=E7=A6=BB?= =?UTF-8?q?=E3=80=82?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/ucoro/awaitable.hpp | 79 ++++++++++++++++++++++++++----------- tests/test3/test.cpp | 4 +- 2 files changed, 57 insertions(+), 26 deletions(-) diff --git a/include/ucoro/awaitable.hpp b/include/ucoro/awaitable.hpp index 415663d..e9bcb64 100644 --- a/include/ucoro/awaitable.hpp +++ b/include/ucoro/awaitable.hpp @@ -58,6 +58,9 @@ namespace ucoro template struct awaitable; + template + struct awaitable_awaiter; + template struct awaitable_promise; @@ -104,6 +107,14 @@ namespace ucoro { a.await_resume() }; }; + template + concept has_operator_co_await = requires (T a) + { + { a.operator co_await() } -> is_awaiter_v; + }; + // 用于判定 T 是否是一个 awaitable<>::promise_type 的类型, 即: 拥有 local_ 成员。 + template + concept is_awaitable_v = is_awaiter_v || has_operator_co_await; } // namespace detail struct debug_coro_promise @@ -263,15 +274,19 @@ namespace ucoro { return local_storage_awaiter>::value_type>{this}; } - else if constexpr ( detail::is_awaiter_v> ) + else if constexpr ( detail::is_awaitable_v> ) { static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); return std::forward(awaiter); } - else + else if constexpr ( requires (A ) { await_transformer::await_transform; }) { return await_transformer::await_transform(std::move(awaiter)); } + else + { + static_assert(0, "co_await must be called on an awaitable type"); + } } std::coroutine_handle<> continuation_; @@ -331,28 +346,6 @@ namespace ucoro awaitable& operator=(const awaitable&) = delete; awaitable& operator=(awaitable&) = delete; - constexpr bool await_ready() const noexcept - { - return false; - } - - T await_resume() - { - return current_coro_handle_.promise().get_value(); - } - - template - auto await_suspend(std::coroutine_handle continuation) - { - if constexpr (detail::is_instance_of_v) - { - current_coro_handle_.promise().local_ = continuation.promise().local_; - } - - current_coro_handle_.promise().continuation_ = continuation; - return current_coro_handle_; - } - void set_local(std::any local) { assert("local has value" && !current_coro_handle_.promise().local_); @@ -380,9 +373,47 @@ namespace ucoro return launched_coro; } + awaitable_awaiter operator co_await () + { + return awaitable_awaiter{this}; + } + std::coroutine_handle current_coro_handle_; }; + ////////////////////////////////////////////////////////////////////////// + // awaitable 的等待器 + template + struct awaitable_awaiter + { + awaitable* this_; + + constexpr bool await_ready() const noexcept + { + return false; + } + + T await_resume() + { + return this_->current_coro_handle_.promise().get_value(); + } + + template + auto await_suspend(std::coroutine_handle continuation) + { + if constexpr (detail::is_instance_of_v) + { + auto& calee_promise = this_->current_coro_handle_.promise(); + auto& caller_promise = continuation.promise(); + calee_promise.local_ = caller_promise.local_; + } + + this_->current_coro_handle_.promise().continuation_ = continuation; + return this_->current_coro_handle_; + } + + }; + ////////////////////////////////////////////////////////////////////////// template diff --git a/tests/test3/test.cpp b/tests/test3/test.cpp index d151040..65f683f 100644 --- a/tests/test3/test.cpp +++ b/tests/test3/test.cpp @@ -12,8 +12,8 @@ int main(int argc, char **argv) static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType0 >, "not a coroutine"); static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType1 >, "not a coroutine"); - static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable >, "not a coroutine"); - static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable >, "not a coroutine"); + static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable_awaiter >, "not a coroutine"); + static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable_awaiter >, "not a coroutine"); static_assert(!ucoro::detail::is_awaiter_v < int >, "not a coroutine");