Skip to content

Commit

Permalink
删除 awaitable_detached<> 类。减少代码重复
Browse files Browse the repository at this point in the history
awaitable_detached 自从支持自身被 co_await 后,代码变得越来越像 awaitable<T>
相似度极高。

既然如此,不如合并代码。 detach 的时候,创建一个新的 awaitable<>。

有没有 detach 其实根本原因在于有没有被 co_await.

因此,在 析构函数里,判断是不是 done() 即可知道有没有被 co_await 过了。
如果没有被 co_await 过,则是 detach 模式。在析构里把协程给 resume 起来。

由于析构的时候没有进行 destroy 操作,因此需要在 final_awaiter 里进行清理。
现在 final_awaiter<T> 的部分清理逻辑本身就是从
awaitable_detached::promise_type::final_awaiter 里合并进来的。

同时detach 模式使用 awaitable<T> 后,也修正了 detach 模式下吞异常的问题。
  • Loading branch information
microcai committed Oct 18, 2024
1 parent 5053c27 commit a3c3853
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 189 deletions.
265 changes: 76 additions & 189 deletions include/ucoro/awaitable.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,15 @@ namespace ucoro
template<typename T>
struct local_storage_t
{
typedef T value_type;
typedef void local_storage_type_detect_tag;
};

inline constexpr local_storage_t<void> local_storage;

template <typename T>
concept is_a_local_storage_t = std::is_same_v<typename T::local_storage_type_detect_tag, typename T::local_storage_type_detect_tag>;

//////////////////////////////////////////////////////////////////////////
namespace detail
{
Expand Down Expand Up @@ -135,6 +141,16 @@ namespace ucoro
value_.template emplace<std::exception_ptr>(std::current_exception());
}

T get_value() const
{
if (std::holds_alternative<std::exception_ptr>(value_))
{
std::rethrow_exception(std::get<std::exception_ptr>(value_));
}

return std::get<T>(value_);
}

std::variant<std::exception_ptr, T> value_{nullptr};
};

Expand All @@ -153,126 +169,14 @@ namespace ucoro
{
exception_ = std::current_exception();
}
};

//////////////////////////////////////////////////////////////////////////
// 使用 .detach() 后创建的独立的协程的入口点
// 由它开始链式使用 awaitable<>
template<typename T = void>
struct awaitable_detached
{
awaitable_detached(const awaitable_detached&) = delete;

struct promise_type : public awaitable_promise_value<T>, public debug_coro_promise
{
awaitable_detached get_return_object() noexcept
{
return awaitable_detached{std::coroutine_handle<promise_type>::from_promise(*this)};
}

struct final_awaiter : std::suspend_always
{
std::coroutine_handle<> await_suspend(std::coroutine_handle<promise_type> final_coro_handle) const noexcept
{
if (final_coro_handle.promise().continuation_)
{
// continuation_ 不为空,则 说明 .detach() 被 co_await
// 因此,awaitable_detached 析构的时候会顺便撤销自己,所以这里不用 destory
// 返回 continuation_,以便让协程框架调用 continuation_.resume()
// 这样就把等它的协程唤醒了.
return final_coro_handle.promise().continuation_;
}
// 如果 continuation_ 为空,则说明 .detach() 没有被 co_await
// 因此,awaitable_detached 对象其实已经析构
// 所以必须主动调用 destroy() 以免内存泄漏.
final_coro_handle.destroy();
return std::noop_coroutine();
}
};

auto initial_suspend() noexcept
{
return std::suspend_always{};
}

auto final_suspend() noexcept
{
return final_awaiter{};
}

// 对 detached 的 coro 调用 co_await 相当于 thread.join()
// 因此记录这个 continuation 以便在 final awaiter 里唤醒
std::coroutine_handle<> continuation_;
};

explicit awaitable_detached(std::coroutine_handle<promise_type> promise_handle) noexcept
: current_coro_handle_(promise_handle)
{
}

awaitable_detached(awaitable_detached&& other) noexcept
: current_coro_handle_(other.current_coro_handle_)
{
other.current_coro_handle_ = nullptr;
}

~awaitable_detached() noexcept
void get_value()
{
if (current_coro_handle_)
if (exception_)
{
if (current_coro_handle_.done())
{
current_coro_handle_.destroy();
}
else
{
// 由于 initial_supend 为 suspend_always
// 因此 如果不对 .detach() 的返回值调用 co_await
// 此协程将不会运行。
// 因此,在本对象析构时,协程其实完全没运行过。
// 正因为本对象析构的时候,协程都没有运行,就意味着
// 其实用户只是调用了 .detach() 并没有对返回值进行
// co_await 操作。
// 因此为了能把协程运行起来,这里强制调用 resume
current_coro_handle_.resume();
}
std::rethrow_exception(exception_);
}
}

bool await_ready() noexcept
{
return false;
}

auto await_suspend(std::coroutine_handle<> continuation) noexcept
{
current_coro_handle_.promise().continuation_ = continuation;
return current_coro_handle_;
}

T await_resume()
{
if constexpr (std::is_void_v<T>)
{
auto exception = current_coro_handle_.promise().exception_;
if (exception)
{
std::rethrow_exception(exception);
}
}
else
{
auto ret = std::move(current_coro_handle_.promise().value_);
if (std::holds_alternative<std::exception_ptr>(ret))
{
std::rethrow_exception(std::get<std::exception_ptr>(ret));
}

return std::get<T>(ret);
}
}

std::coroutine_handle<promise_type> current_coro_handle_;
};

//////////////////////////////////////////////////////////////////////////
Expand All @@ -284,8 +188,20 @@ namespace ucoro
{
if (h.promise().continuation_)
{
// continuation_ 不为空,则 说明 .detach() 被 co_await
// 因此,awaitable_detached 析构的时候会顺便撤销自己,所以这里不用 destory
// 返回 continuation_,以便让协程框架调用 continuation_.resume()
// 这样就把等它的协程唤醒了.
return h.promise().continuation_;
}
// 并且,如果协程处于 .detach() 而没有被 co_await
// 则异常一直存储在 promise 里,并没有代码会去调用他的 await_resume() 重抛异常
// 所以这里重新抛出来,避免有被静默吞并的异常
h.promise().get_value();
// 如果 continuation_ 为空,则说明 .detach() 没有被 co_await
// 因此,awaitable_detached 对象其实已经析构
// 所以必须主动调用 destroy() 以免内存泄漏.
h.destroy();
return std::noop_coroutine();
}
};
Expand All @@ -309,36 +225,18 @@ namespace ucoro
return std::suspend_always{};
}

template<typename A> requires (detail::is_awaiter_v<std::decay_t<A>>)
auto await_transform(A&& awaiter) const
{
static_assert(std::is_rvalue_reference_v<decltype(awaiter)>, "co_await must be used on rvalue");
return std::forward<A>(awaiter);
}

template<typename A> requires (!detail::is_awaiter_v<std::decay_t<A>>)
auto await_transform(A&& awaiter) const
{
return await_transformer<A>::await_transform(std::move(awaiter));
}

void set_local(std::any local)
{
local_ = std::make_shared<std::any>(local);
local_ = std::make_shared<std::any>(std::move(local));
}

template<typename localtype>
struct local_storage_awaiter
{
awaitable_promise* this_;
const awaitable_promise* this_;

constexpr bool await_ready() const noexcept
{
return true;
}
void await_suspend(std::coroutine_handle<void>) noexcept
{
}
constexpr bool await_ready() const noexcept { return true; }
constexpr void await_suspend(std::coroutine_handle<>) const noexcept {}

auto await_resume() const noexcept
{
Expand All @@ -353,10 +251,22 @@ namespace ucoro
}
};

template<typename localtype>
auto await_transform(local_storage_t<localtype>)
template<typename A>
auto await_transform(A&& awaiter) const
{
return local_storage_awaiter<localtype>{this};
if constexpr ( is_a_local_storage_t<std::decay_t<A>> )
{
return local_storage_awaiter<typename std::decay_t<A>::value_type>{this};
}
else if constexpr ( detail::is_awaiter_v<std::decay_t<A>> )
{
static_assert(std::is_rvalue_reference_v<decltype(awaiter)>, "co_await must be used on rvalue");
return std::forward<A>(awaiter);
}
else
{
return await_transformer<A>::await_transform(std::move(awaiter));
}
}

std::coroutine_handle<> continuation_;
Expand All @@ -371,19 +281,28 @@ namespace ucoro
{
using promise_type = awaitable_promise<T>;

explicit awaitable(std::coroutine_handle<promise_type> h) : current_coro_handle_(h)
explicit awaitable(std::coroutine_handle<promise_type> h)
: current_coro_handle_(h)
{
}

~awaitable()
{
if (current_coro_handle_ && current_coro_handle_.done())
if (current_coro_handle_)
{
current_coro_handle_.destroy();
if (current_coro_handle_.done())
{
current_coro_handle_.destroy();
}
else
{
current_coro_handle_.resume();
}
}
}

awaitable(awaitable&& t) noexcept : current_coro_handle_(t.current_coro_handle_)
awaitable(awaitable&& t) noexcept
: current_coro_handle_(t.current_coro_handle_)
{
t.current_coro_handle_ = nullptr;
}
Expand All @@ -407,60 +326,24 @@ namespace ucoro
awaitable& operator=(const awaitable&) = delete;
awaitable& operator=(awaitable&) = delete;

T operator()()
{
return get();
}

T get()
{
if constexpr (!std::is_void_v<T>)
{
return std::move(current_coro_handle_.promise().value_);
}
}

constexpr bool await_ready() const noexcept
{
return false;
}

T await_resume()
{
if constexpr (std::is_void_v<T>)
{
auto exception = current_coro_handle_.promise().exception_;
if (exception)
{
std::rethrow_exception(exception);
}

current_coro_handle_.destroy();
current_coro_handle_ = nullptr;
}
else
{
auto ret = std::move(current_coro_handle_.promise().value_);
if (std::holds_alternative<std::exception_ptr>(ret))
{
auto exception = std::get<std::exception_ptr>(ret);
assert(exception && "The exception must not be nullptr!");
std::rethrow_exception(exception);
}

current_coro_handle_.destroy();
current_coro_handle_ = nullptr;

return std::get<T>(ret);
}
return current_coro_handle_.promise().get_value();
}

template<typename PromiseType>
auto await_suspend(std::coroutine_handle<PromiseType> continuation)
{
if constexpr (detail::is_awaitable_promise_type_v<PromiseType>)
{
current_coro_handle_.promise().local_ = continuation.promise().local_;
auto& calee_promise = current_coro_handle_.promise();
auto& caller_promise = continuation.promise();
calee_promise.local_ = caller_promise.local_;
}

current_coro_handle_.promise().continuation_ = continuation;
Expand All @@ -470,24 +353,28 @@ namespace ucoro
void set_local(std::any local)
{
assert("local has value" && !current_coro_handle_.promise().local_);
current_coro_handle_.promise().set_local(local);
current_coro_handle_.promise().set_local(std::move(local));
}

auto detach()
{
auto launch_coro = [](awaitable<T> lazy) -> awaitable_detached<T> { co_return co_await lazy; };
auto launch_coro = [](awaitable<T> lazy) -> awaitable<T> { co_return co_await std::move(lazy); };
return launch_coro(std::move(*this));
}

auto detach(std::any local)
{
auto launched_coro = [](awaitable<T> lazy) mutable -> awaitable<T>
{
co_return co_await std::move(lazy);
}(std::move(*this));

if (local.has_value())
{
set_local(local);
launched_coro.set_local(local);
}

auto launch_coro = [](awaitable<T> lazy) -> awaitable_detached<T> { co_return co_await lazy; };
return launch_coro(std::move(*this));
return launched_coro;
}

std::coroutine_handle<promise_type> current_coro_handle_;
Expand Down
2 changes: 2 additions & 0 deletions tests/test3/test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ int main(int argc, char **argv)
using CallbackAwaiterType0 = ucoro::CallbackAwaiter<void, decltype([](auto h) {}) >;
using CallbackAwaiterType1 = ucoro::CallbackAwaiter<int, decltype([](auto h) {}) > ;

static_assert(ucoro::is_a_local_storage_t<ucoro::local_storage_t<void>>, "not a local_storage_t");

static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType0 >, "not a coroutine");
static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType1 >, "not a coroutine");

Expand Down

0 comments on commit a3c3853

Please sign in to comment.