Skip to content

Commit

Permalink
把 awaitable 和 awaiter 的概念进行分离。
Browse files Browse the repository at this point in the history
  • Loading branch information
microcai committed Oct 18, 2024
1 parent cf01f4e commit 7c0a5a6
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 26 deletions.
79 changes: 55 additions & 24 deletions include/ucoro/awaitable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ namespace ucoro
template<typename T>
struct awaitable;

template<typename T>
struct awaitable_awaiter;

template<typename T>
struct awaitable_promise;

Expand Down Expand Up @@ -107,6 +110,14 @@ namespace ucoro
{ a.await_resume() };
};

template<typename T>
concept has_operator_co_await = requires (T a)
{
{ a.operator co_await() } -> is_awaiter_v;
};
// 用于判定 T 是否是一个 awaitable<>::promise_type 的类型, 即: 拥有 local_ 成员。
template<typename T>
concept is_awaitable_v = is_awaiter_v<T> || has_operator_co_await<T>;
} // namespace detail

struct debug_coro_promise
Expand Down Expand Up @@ -266,15 +277,19 @@ namespace ucoro
{
return local_storage_awaiter<typename detail::local_storage_value_type<std::decay_t<A>>::value_type>{this};
}
else if constexpr ( detail::is_awaiter_v<std::decay_t<A>> )
else if constexpr ( detail::is_awaitable_v<std::decay_t<A>> )
{
static_assert(std::is_rvalue_reference_v<decltype(awaiter)>, "co_await must be used on rvalue");
return std::forward<A>(awaiter);
}
else
else if constexpr ( requires (A ) { await_transformer<A>::await_transform; })
{
return await_transformer<A>::await_transform(std::move(awaiter));
}
else
{
static_assert(0, "co_await must be called on an awaitable type");
}
}

std::coroutine_handle<> continuation_;
Expand Down Expand Up @@ -334,28 +349,6 @@ namespace ucoro
awaitable& operator=(const awaitable&) = delete;
awaitable& operator=(awaitable&) = delete;

constexpr bool await_ready() const noexcept
{
return false;
}

T await_resume()
{
return current_coro_handle_.promise().get_value();
}

template<typename PromiseType>
auto await_suspend(std::coroutine_handle<PromiseType> continuation)
{
if constexpr (detail::is_instance_of_v<PromiseType, awaitable_promise>)
{
current_coro_handle_.promise().local_ = continuation.promise().local_;
}

current_coro_handle_.promise().continuation_ = continuation;
return current_coro_handle_;
}

void set_local(std::any local)
{
assert("local has value" && !current_coro_handle_.promise().local_);
Expand Down Expand Up @@ -383,9 +376,47 @@ namespace ucoro
return launched_coro;
}

awaitable_awaiter<T> operator co_await ()
{
return awaitable_awaiter<T>{this};
}

std::coroutine_handle<promise_type> current_coro_handle_;
};

//////////////////////////////////////////////////////////////////////////
// awaitable 的等待器
template<typename T>
struct awaitable_awaiter
{
awaitable<T>* this_;

constexpr bool await_ready() const noexcept
{
return false;
}

T await_resume()
{
return this_->current_coro_handle_.promise().get_value();
}

template<typename PromiseType>
auto await_suspend(std::coroutine_handle<PromiseType> continuation)
{
if constexpr (detail::is_instance_of_v<PromiseType, awaitable_promise>)
{
auto& calee_promise = this_->current_coro_handle_.promise();
auto& caller_promise = continuation.promise();
calee_promise.local_ = caller_promise.local_;
}

this_->current_coro_handle_.promise().continuation_ = continuation;
return this_->current_coro_handle_;
}

};

//////////////////////////////////////////////////////////////////////////

template<typename T>
Expand Down
4 changes: 2 additions & 2 deletions tests/test3/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ int main(int argc, char **argv)
static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType0 >, "not a coroutine");
static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType1 >, "not a coroutine");

static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable<void> >, "not a coroutine");
static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable<int> >, "not a coroutine");
static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable_awaiter<void> >, "not a coroutine");
static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable_awaiter<int> >, "not a coroutine");

static_assert(!ucoro::detail::is_awaiter_v < int >, "not a coroutine");

Expand Down

0 comments on commit 7c0a5a6

Please sign in to comment.