diff --git a/include/ucoro/awaitable.hpp b/include/ucoro/awaitable.hpp index 415663d..91bbcc2 100644 --- a/include/ucoro/awaitable.hpp +++ b/include/ucoro/awaitable.hpp @@ -72,39 +72,112 @@ namespace ucoro inline constexpr local_storage_t local_storage; ////////////////////////////////////////////////////////////////////////// - namespace detail + namespace concepts { + ////////////////////////////////////////////////////////////////////////// // 用于判定 T 是否是一个 U 的类型 // 比如 // is_instance_of_v,std::vector>; // true // is_instance_of_v,std::list>; // false - template class U> + // + // 首先定义一个接受 is_instance_of_v 这样的一个默认模板萃取 + template typename U> inline constexpr bool is_instance_of_v = std::false_type{}; - template class U, class... Vs> - inline constexpr bool is_instance_of_v,U> = std::true_type{}; - template - struct local_storage_value_type; + // 接着为 is_instance_of_v, class_type> + // 这种定义一个偏特化,于是把符合这个模式的特殊参数给匹配到这个偏特化来了 + template typename class_type, typename... parameters> + inline constexpr bool is_instance_of_v, class_type> = std::true_type{}; - template - struct local_storage_value_type> - { - typedef ValueType value_type; - }; + // 然后把模板偏特化的萃取重新定义为一个 concept + template typename U> + concept is_instance_of = is_instance_of_v::type, U>; + // 再定义一个直接用来测试 local_storage_t<> 的辅助 template - concept is_valid_await_suspend_return_value = - std::convertible_to> || std::is_void_v || std::is_same_v; + concept LocalStorage = is_instance_of; + + // 再定义一个直接用来测试 awaitable<> 的辅助 + template + concept awaitable_type = is_instance_of; + + // 再定义一个直接用来测试 awaitable_promise<> 的辅助 + template + concept awaitable_promise_type = is_instance_of; + + // await_suspend 有三种返回值 + template + concept is_valid_await_suspend_return_value = std::convertible_to> || + std::is_void_v || + std::is_same_v; // 用于判定 T 是否是一个 awaiter 的类型, 即: 拥有 await_ready,await_suspend,await_resume 成员函数的结构或类. template concept is_awaiter_v = requires (T a) { - { a.await_ready() } -> std::convertible_to; + { a.await_ready() } -> std::same_as; { a.await_suspend(std::coroutine_handle<>{}) } -> is_valid_await_suspend_return_value; { a.await_resume() }; }; - } // namespace detail + template + concept has_operator_co_await = requires (T a) + { + { a.operator co_await() } -> is_awaiter_v; + }; + + // 用于判定 T 是可以用在 co_await 后面 + template + concept is_awaitable_v = is_awaiter_v> || + awaitable_type || + has_operator_co_await>; + + + template + concept has_user_defined_await_transformer = requires (T&& a) + { + await_transformer::await_transform(std::move(a)); + }; + + + } // namespace concepts + + namespace traits + { + ////////////////////////////////////////////////////////////////////////// + // 用于从 A = U 类型里提取 T 参数 + // 比如 + // template_parameter_of, local_storage_t>; // int + // template_parameter_of; // void + // + // 首先定义一个接受 template_parameter_of 这样的一个默认模板萃取 + template typename FromTemplate> + struct template_parameter_traits; + + // 接着定义一个偏特化,匹配 template_parameter_traits<模板名<参数>, 模板名> + // 这样,这个偏特化的 template_parameter_traits 就有了一个 + // 名为 template_parameter 的成员类型,其定义的类型就是 _template_parameter + // 于是就把 _template_parameter 这个类型给萃取出来了 + template typename class_template, typename _template_parameter> + struct template_parameter_traits, class_template> + { + using template_parameter = _template_parameter ; + }; + + // 最后,定义一个简化用法的 using 让用户的地方代码变短点 + template typename FromTemplate> + using template_parameter_of = typename template_parameter_traits< + std::decay_t, FromTemplate>::template_parameter; + + // 利用 通用工具 template_parameter_of 萃取 local_storage_t 里的 T + template + using local_storage_value_type = template_parameter_of; + + + // 利用 通用工具 template_parameter_of 萃取 awaitable 里的 T + template + using awaitable_return_type = template_parameter_of; + + } // namespace traits struct debug_coro_promise { @@ -259,19 +332,27 @@ namespace ucoro template auto await_transform(A&& awaiter) const { - if constexpr ( detail::is_instance_of_v, local_storage_t> ) + if constexpr (concepts::is_instance_of) { - return local_storage_awaiter>::value_type>{this}; + // 调用 co_await local_storage_t + return local_storage_awaiter>{this}; } - else if constexpr ( detail::is_awaiter_v> ) + else if constexpr (concepts::is_awaitable_v) { - static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); + // 调用 co_await awaitable; 或者其他有三件套的类型 + static_assert(std::is_rvalue_reference_v, "co_await must be used on rvalue"); return std::forward(awaiter); } - else + else if constexpr ( concepts::has_user_defined_await_transformer ) { + // 调用 co_await 其他写了 await_transformer 的自定义类型. + // 例如包含了 asio_glue.hpp 后,就可以 co_await asio::awaitable; return await_transformer::await_transform(std::move(awaiter)); } + else + { + static_assert(0, "co_await must been used on an awaitable"); + } } std::coroutine_handle<> continuation_; @@ -344,7 +425,7 @@ namespace ucoro template auto await_suspend(std::coroutine_handle continuation) { - if constexpr (detail::is_instance_of_v) + if constexpr (concepts::is_instance_of) { current_coro_handle_.promise().local_ = continuation.promise().local_; } diff --git a/tests/test3/test.cpp b/tests/test3/test.cpp index d151040..3d035c3 100644 --- a/tests/test3/test.cpp +++ b/tests/test3/test.cpp @@ -7,15 +7,19 @@ int main(int argc, char **argv) using CallbackAwaiterType0 = ucoro::CallbackAwaiter; using CallbackAwaiterType1 = ucoro::CallbackAwaiter ; - static_assert(ucoro::detail::is_instance_of_v, ucoro::local_storage_t>, "not a local_storage_t"); + static_assert(ucoro::concepts::is_instance_of_v, ucoro::local_storage_t>, "not a local_storage_t"); - static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType0 >, "not a coroutine"); - static_assert(ucoro::detail::is_awaiter_v < CallbackAwaiterType1 >, "not a coroutine"); + using local_storage_template_parameter = ucoro::traits::template_parameter_of; - static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable >, "not a coroutine"); - static_assert(ucoro::detail::is_awaiter_v < ucoro::awaitable >, "not a coroutine"); + static_assert(std::is_void_v, "local_storage is not local_storage_t"); - static_assert(!ucoro::detail::is_awaiter_v < int >, "not a coroutine"); + static_assert(ucoro::concepts::is_awaiter_v < CallbackAwaiterType0 >, "not a coroutine"); + static_assert(ucoro::concepts::is_awaiter_v < CallbackAwaiterType1 >, "not a coroutine"); + + static_assert(ucoro::concepts::is_awaiter_v < ucoro::awaitable >, "not a coroutine"); + static_assert(ucoro::concepts::is_awaiter_v < ucoro::awaitable >, "not a coroutine"); + + static_assert(!ucoro::concepts::is_awaiter_v < int >, "not a coroutine"); return 0; }