diff --git a/include/ucoro/awaitable.hpp b/include/ucoro/awaitable.hpp index d641709..63c55c9 100644 --- a/include/ucoro/awaitable.hpp +++ b/include/ucoro/awaitable.hpp @@ -58,6 +58,9 @@ namespace ucoro template struct awaitable; + template + struct awaitable_awaiter; + template struct awaitable_promise; @@ -107,6 +110,14 @@ namespace ucoro { a.await_resume() }; }; + template + concept has_operator_co_await = requires (T a) + { + { a.operator co_await() } -> is_awaiter_v; + }; + // 用于判定 T 是否是一个 awaitable<>::promise_type 的类型, 即: 拥有 local_ 成员。 + template + concept is_awaitable_v = is_awaiter_v || has_operator_co_await; } // namespace detail struct debug_coro_promise @@ -266,15 +277,19 @@ namespace ucoro { return local_storage_awaiter>::value_type>{this}; } - else if constexpr ( detail::is_awaiter_v> ) + else if constexpr ( detail::is_awaitable_v> ) { static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); return std::forward(awaiter); } - else + else if constexpr ( requires (A ) { await_transformer::await_transform; }) { return await_transformer::await_transform(std::move(awaiter)); } + else + { + static_assert(0, "co_await must be called on an awaitable type"); + } } std::coroutine_handle<> continuation_; @@ -334,28 +349,6 @@ namespace ucoro awaitable& operator=(const awaitable&) = delete; awaitable& operator=(awaitable&) = delete; - constexpr bool await_ready() const noexcept - { - return false; - } - - T await_resume() - { - return current_coro_handle_.promise().get_value(); - } - - template - auto await_suspend(std::coroutine_handle continuation) - { - if constexpr (detail::is_instance_of_v) - { - current_coro_handle_.promise().local_ = continuation.promise().local_; - } - - current_coro_handle_.promise().continuation_ = continuation; - return current_coro_handle_; - } - void set_local(std::any local) { assert("local has value" && !current_coro_handle_.promise().local_); @@ -383,9 +376,47 @@ namespace ucoro return launched_coro; } + awaitable_awaiter operator co_await () + { + return awaitable_awaiter{this}; + } + std::coroutine_handle current_coro_handle_; }; + ////////////////////////////////////////////////////////////////////////// + // awaitable 的等待器 + template + struct awaitable_awaiter + { + awaitable* this_; + + constexpr bool await_ready() const noexcept + { + return false; + } + + T await_resume() + { + return this_->current_coro_handle_.promise().get_value(); + } + + template + auto await_suspend(std::coroutine_handle continuation) + { + if constexpr (detail::is_instance_of_v) + { + auto& calee_promise = this_->current_coro_handle_.promise(); + auto& caller_promise = continuation.promise(); + calee_promise.local_ = caller_promise.local_; + } + + this_->current_coro_handle_.promise().continuation_ = continuation; + return this_->current_coro_handle_; + } + + }; + ////////////////////////////////////////////////////////////////////////// template diff --git a/tests/test3/test.cpp b/tests/test3/test.cpp index d151040..65f683f 100644 --- a/tests/test3/test.cpp +++ b/tests/test3/test.cpp @@ -12,8 +12,8 @@ int main(int argc, char **argv) static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType0 >, "not a coroutine"); static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType1 >, "not a coroutine"); - static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable >, "not a coroutine"); - static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable >, "not a coroutine"); + static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable_awaiter >, "not a coroutine"); + static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable_awaiter >, "not a coroutine"); static_assert(!ucoro::detail::is_awaiter_v < int >, "not a coroutine");