diff --git a/include/async_mqtt/broker/endpoint_variant.hpp b/include/async_mqtt/broker/endpoint_variant.hpp index 1fc39a71b..f1815759f 100644 --- a/include/async_mqtt/broker/endpoint_variant.hpp +++ b/include/async_mqtt/broker/endpoint_variant.hpp @@ -119,7 +119,6 @@ class epsp_wrap { ); } -#if 0 decltype(auto) strand() const { return visit( [&](auto& ep) -> decltype(auto) { @@ -127,6 +126,7 @@ class epsp_wrap { } ); } + decltype(auto) strand() { return visit( [&](auto& ep) -> decltype(auto) { @@ -134,7 +134,15 @@ class epsp_wrap { } ); } -#endif + + bool in_strand() const { + return visit( + [&](auto& ep) { + return ep.in_strand(); + } + ); + } + // async functions template diff --git a/include/async_mqtt/endpoint.hpp b/include/async_mqtt/endpoint.hpp index 45c66cbe4..1354da4da 100644 --- a/include/async_mqtt/endpoint.hpp +++ b/include/async_mqtt/endpoint.hpp @@ -148,14 +148,23 @@ class basic_endpoint : public std::enable_shared_from_thisstrand(); } + /** * @brief strand getter - * @return eference of the strand + * @return reference of the strand */ strand_type& strand() { return stream_->strand(); } + /** + * @brief strand checker + * @return true if the current context running in the strand, otherwise false + */ + bool in_strand() const { + return stream_->in_strand(); + } + /** * @brief next_layer getter * @return const reference of the next_layer @@ -586,7 +595,7 @@ class basic_endpoint : public std::enable_shared_from_this acquire_unique_packet_id() { - BOOST_ASSERT(strand().running_in_this_thread()); + BOOST_ASSERT(in_strand()); auto pid = pid_man_.acquire_unique_id(); if (pid) { ASYNC_MQTT_LOG("mqtt_api", info) @@ -608,7 +617,7 @@ class basic_endpoint : public std::enable_shared_from_this get_qos2_publish_handled_pids() const { - BOOST_ASSERT(strand().running_in_this_thread()); + BOOST_ASSERT(in_strand()); ASYNC_MQTT_LOG("mqtt_api", info) << ASYNC_MQTT_ADD_VALUE(address, this) << "get_qos2_publish_handled_pids"; @@ -650,7 +659,7 @@ class basic_endpoint : public std::enable_shared_from_this pids) { - BOOST_ASSERT(strand().running_in_this_thread()); + BOOST_ASSERT(in_strand()); ASYNC_MQTT_LOG("mqtt_api", info) << ASYNC_MQTT_ADD_VALUE(address, this) << "restore_qos2_publish_handled_pids"; @@ -666,7 +675,7 @@ class basic_endpoint : public std::enable_shared_from_this> pvs ) { - BOOST_ASSERT(strand().running_in_this_thread()); + BOOST_ASSERT(in_strand()); ASYNC_MQTT_LOG("mqtt_api", info) << ASYNC_MQTT_ADD_VALUE(address, this) << "restore_packets"; @@ -697,7 +706,7 @@ class basic_endpoint : public std::enable_shared_from_this> get_stored_packets() const { - BOOST_ASSERT(strand().running_in_this_thread()); + BOOST_ASSERT(in_strand()); ASYNC_MQTT_LOG("mqtt_api", info) << ASYNC_MQTT_ADD_VALUE(address, this) << "get_stored_packets"; @@ -710,7 +719,7 @@ class basic_endpoint : public std::enable_shared_from_this& packet) const { - BOOST_ASSERT(strand().running_in_this_thread()); + BOOST_ASSERT(in_strand()); ASYNC_MQTT_LOG("mqtt_api", info) << ASYNC_MQTT_ADD_VALUE(address, this) << "regulate_for_store:" << packet; @@ -757,14 +766,14 @@ class basic_endpoint : public std::enable_shared_from_thiscancel(); tim_pingreq_recv_->cancel(); tim_pingresp_recv_->cancel(); } void set_pingreq_send_interval_ms_for_test(std::size_t ms) { - BOOST_ASSERT(strand().running_in_this_thread()); + BOOST_ASSERT(in_strand()); pingreq_send_interval_ms_ = ms; } @@ -806,12 +815,12 @@ class basic_endpoint : public std::enable_shared_from_thisraw_strand(), force_move(self) ); } break; case complete: - BOOST_ASSERT(ep.strand().running_in_this_thread()); + BOOST_ASSERT(ep.in_strand()); self.complete(ep.pid_man_.acquire_unique_id()); break; } @@ -832,12 +841,12 @@ class basic_endpoint : public std::enable_shared_from_thisraw_strand(), force_move(self) ); } break; case complete: - BOOST_ASSERT(ep.strand().running_in_this_thread()); + BOOST_ASSERT(ep.in_strand()); self.complete(ep.pid_man_.register_id(packet_id)); break; } @@ -858,12 +867,12 @@ class basic_endpoint : public std::enable_shared_from_thisraw_strand(), force_move(self) ); } break; case complete: - BOOST_ASSERT(ep.strand().running_in_this_thread()); + BOOST_ASSERT(ep.in_strand()); ep.pid_man_.release_id(packet_id); self.complete(); break; @@ -898,12 +907,12 @@ class basic_endpoint : public std::enable_shared_from_thisraw_strand(), force_move(self) ); } break; case write: { - BOOST_ASSERT(ep.strand().running_in_this_thread()); + BOOST_ASSERT(ep.in_strand()); state = complete; if constexpr( std::is_same_v, basic_packet_variant> || @@ -950,7 +959,7 @@ class basic_endpoint : public std::enable_shared_from_this optional validate_topic_alias(Self& self, optional ta_opt) { - BOOST_ASSERT(ep.strand().running_in_this_thread()); + BOOST_ASSERT(ep.in_strand()); if (!ta_opt) { self.complete( make_error( @@ -1509,7 +1518,7 @@ class basic_endpoint : public std::enable_shared_from_thisread_packet(force_move(self)); } break; case complete: { - BOOST_ASSERT(ep.strand().running_in_this_thread()); + BOOST_ASSERT(ep.in_strand()); if (buf.size() > ep.maximum_packet_size_recv_) { // on v3.1.1 maximum_packet_size_recv_ is initialized as packet_size_no_limit BOOST_ASSERT(ep.protocol_version_ == protocol_version::v5); @@ -2093,7 +2102,7 @@ class basic_endpoint : public std::enable_shared_from_thisraw_strand(), force_move(self) ); } @@ -2136,7 +2145,7 @@ class basic_endpoint : public std::enable_shared_from_thisraw_strand(), force_move(self) ); } break; @@ -2261,7 +2270,7 @@ class basic_endpoint : public std::enable_shared_from_this(ep.status_); @@ -2294,12 +2303,12 @@ class basic_endpoint : public std::enable_shared_from_thisraw_strand(), force_move(self) ); } break; case complete: - BOOST_ASSERT(ep.strand().running_in_this_thread()); + BOOST_ASSERT(ep.in_strand()); ep.restore_packets(force_move(pvs)); self.complete(); break; @@ -2320,12 +2329,12 @@ class basic_endpoint : public std::enable_shared_from_thisraw_strand(), force_move(self) ); } break; case complete: - BOOST_ASSERT(ep.strand().running_in_this_thread()); + BOOST_ASSERT(ep.in_strand()); self.complete(ep.get_stored_packets()); break; } @@ -2346,12 +2355,12 @@ class basic_endpoint : public std::enable_shared_from_thisraw_strand(), force_move(self) ); } break; case complete: - BOOST_ASSERT(ep.strand().running_in_this_thread()); + BOOST_ASSERT(ep.in_strand()); ep.regulate_for_store(packet); self.complete(force_move(packet)); break; @@ -2383,7 +2392,7 @@ class basic_endpoint : public std::enable_shared_from_this& packet) { - BOOST_ASSERT(strand().running_in_this_thread()); + BOOST_ASSERT(in_strand()); if (packet.opts().get_qos() == qos::at_least_once || packet.opts().get_qos() == qos::exactly_once ) { @@ -2403,7 +2412,7 @@ class basic_endpoint : public std::enable_shared_from_this const& pv) { if (pv.size() > maximum_packet_size_send_) { @@ -2447,7 +2456,7 @@ class basic_endpoint : public std::enable_shared_from_thiscancel(); if (status_ == connection_status::disconnecting || @@ -2501,7 +2510,7 @@ class basic_endpoint : public std::enable_shared_from_thiscancel(); if (status_ == connection_status::disconnecting || @@ -2551,7 +2560,7 @@ class basic_endpoint : public std::enable_shared_from_thiscancel(); if (status_ == connection_status::disconnecting || @@ -2618,7 +2627,7 @@ class basic_endpoint : public std::enable_shared_from_this pid_pubcomp_; bool need_store_ = false; - store store_{strand()}; + store> store_{stream_->raw_strand()}; bool auto_pub_response_ = false; bool auto_ping_response_ = false; @@ -2646,9 +2655,9 @@ class basic_endpoint : public std::enable_shared_from_this pingreq_recv_timeout_ms_; optional pingresp_recv_timeout_ms_; - std::shared_ptr tim_pingreq_send_{std::make_shared(strand())}; - std::shared_ptr tim_pingreq_recv_{std::make_shared(strand())}; - std::shared_ptr tim_pingresp_recv_{std::make_shared(strand())}; + std::shared_ptr tim_pingreq_send_{std::make_shared(stream_->raw_strand())}; + std::shared_ptr tim_pingreq_recv_{std::make_shared(stream_->raw_strand())}; + std::shared_ptr tim_pingresp_recv_{std::make_shared(stream_->raw_strand())}; std::set qos2_publish_handled_; diff --git a/include/async_mqtt/stream.hpp b/include/async_mqtt/stream.hpp index bbb39d050..179db7628 100644 --- a/include/async_mqtt/stream.hpp +++ b/include/async_mqtt/stream.hpp @@ -65,7 +65,8 @@ class stream : public std::enable_shared_from_this> { using this_type_sp = std::shared_ptr; using next_layer_type = typename std::remove_reference::type; using executor_type = async_mqtt::executor_type; - using strand_type = as::strand; + using raw_strand_type = as::strand; + using strand_type = as::strand; template friend class make_shared_helper; @@ -150,10 +151,23 @@ class stream : public std::enable_shared_from_this> { strand_type const& strand() const { return strand_; } + strand_type& strand() { return strand_; } + raw_strand_type const& raw_strand() const { + return raw_strand_; + }; + + raw_strand_type& raw_strand() { + return raw_strand_; + }; + + bool in_strand() const { + return raw_strand().running_in_this_thread(); + } + template typename as::async_result, void(error_code)>::return_type close(CompletionToken&& token) { @@ -215,12 +229,12 @@ class stream : public std::enable_shared_from_this> { state = header; auto& a_strm{strm}; as::dispatch( - a_strm.strand_, + a_strm.raw_strand_, force_move(self) ); } break; case header: { - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); // read fixed_header auto address = &strm.header_remaining_length_buf_[received]; auto& a_strm{strm}; @@ -228,7 +242,7 @@ class stream : public std::enable_shared_from_this> { a_strm.nl_, as::buffer(address, 1), as::bind_executor( - a_strm.strand_, + a_strm.raw_strand_, force_move(self) ) ); @@ -255,7 +269,7 @@ class stream : public std::enable_shared_from_this> { std::size_t bytes_transferred ) { if (ec) { - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); auto exe = as::get_associated_executor(self); if constexpr (is_strand>()) { state = complete; @@ -272,7 +286,7 @@ class stream : public std::enable_shared_from_this> { switch (state) { case header: - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); BOOST_ASSERT(bytes_transferred == 1); state = remaining_length; ++received; @@ -284,14 +298,14 @@ class stream : public std::enable_shared_from_this> { a_strm.nl_, as::buffer(address, 1), as::bind_executor( - a_strm.strand_, + a_strm.raw_strand_, force_move(self) ) ); } break; case remaining_length: - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); BOOST_ASSERT(bytes_transferred == 1); ++received; if (strm.header_remaining_length_buf_[received - 1] & 0b10000000) { @@ -314,7 +328,7 @@ class stream : public std::enable_shared_from_this> { a_strm.nl_, as::buffer(address, 1), as::bind_executor( - a_strm.strand_, + a_strm.raw_strand_, force_move(self) ) ); @@ -350,7 +364,7 @@ class stream : public std::enable_shared_from_this> { a_strm.nl_, as::buffer(address, rl), as::bind_executor( - a_strm.strand_, + a_strm.raw_strand_, force_move(self) ) ); @@ -358,7 +372,7 @@ class stream : public std::enable_shared_from_this> { } break; case bind: { - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); auto exe = as::get_associated_executor(self); if constexpr (is_strand>()) { state = complete; @@ -396,23 +410,23 @@ class stream : public std::enable_shared_from_this> { state = post; auto& a_strm{strm}; as::dispatch( - a_strm.strand_, + a_strm.raw_strand_, force_move(self) ); } break; case post: { - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); state = write; auto& a_strm{strm}; a_strm.queue_.post( as::bind_executor( - a_strm.strand_, + a_strm.raw_strand_, force_move(self) ) ); } break; case write: { - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); strm.queue_.start_work(); if (strm.lowest_layer().is_open()) { state = bind; @@ -422,7 +436,7 @@ class stream : public std::enable_shared_from_this> { a_strm.nl_, cbs, as::bind_executor( - a_strm.strand_, + a_strm.raw_strand_, force_move(self) ) ); @@ -431,7 +445,7 @@ class stream : public std::enable_shared_from_this> { state = bind; auto& a_strm{strm}; as::dispatch( - a_strm.strand_, + a_strm.raw_strand_, as::append( force_move(self), errc::make_error_code(errc::connection_reset), @@ -453,11 +467,11 @@ class stream : public std::enable_shared_from_this> { std::size_t bytes_transferred ) { if (ec) { - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); strm.queue_.stop_work(); auto& a_strm{strm}; as::post( - a_strm.strand_, + a_strm.raw_strand_, [&a_strm,wp = a_strm.weak_from_this()] { if (auto sp = wp.lock()) { a_strm.queue_.poll_one(); @@ -479,11 +493,11 @@ class stream : public std::enable_shared_from_this> { } switch (state) { case bind: { - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); strm.queue_.stop_work(); auto& a_strm{strm}; as::post( - a_strm.strand_, + a_strm.raw_strand_, [&a_strm, wp = a_strm.weak_from_this()] { if (auto sp = wp.lock()) { a_strm.queue_.poll_one(); @@ -503,7 +517,7 @@ class stream : public std::enable_shared_from_this> { self.complete(ec, bytes_transferred); } break; case complete: - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); if (last_ec) { self.complete(last_ec, 0); } @@ -538,7 +552,7 @@ class stream : public std::enable_shared_from_this> { state = close; auto& a_strm{strm}; as::dispatch( - a_strm.strand_, + a_strm.raw_strand_, as::append( force_move(self), error_code{}, @@ -571,7 +585,7 @@ class stream : public std::enable_shared_from_this> { state = close; auto& a_strm{strm}; as::dispatch( - a_strm.strand_, + a_strm.raw_strand_, as::append( force_move(self), error_code{}, @@ -585,7 +599,7 @@ class stream : public std::enable_shared_from_this> { stream.get().async_read( *buffer, as::bind_executor( - a_strm.strand_, + a_strm.raw_strand_, as::append( as::consign( force_move(self), @@ -608,7 +622,7 @@ class stream : public std::enable_shared_from_this> { ) { switch (state) { case close: { - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); if constexpr(is_ws::value) { if (stream.get().is_open()) { state = drop1; @@ -616,7 +630,7 @@ class stream : public std::enable_shared_from_this> { stream.get().async_close( bs::websocket::close_code::none, as::bind_executor( - a_strm.strand_, + a_strm.raw_strand_, as::append( force_move(self), force_move(stream) @@ -628,7 +642,7 @@ class stream : public std::enable_shared_from_this> { state = close; auto& a_strm{strm}; as::dispatch( - a_strm.strand_, + a_strm.raw_strand_, as::append( force_move(self), error_code{}, @@ -642,10 +656,10 @@ class stream : public std::enable_shared_from_this> { ASYNC_MQTT_LOG("mqtt_impl", info) << ASYNC_MQTT_ADD_VALUE(address, this) << "TLS async_shutdown start with timeout"; - auto tim = std::make_shared(a_strm.strand_, shutdown_timeout); + auto tim = std::make_shared(a_strm.raw_strand_, shutdown_timeout); tim->async_wait( as::bind_executor( - a_strm.strand_, + a_strm.raw_strand_, [this, &next_layer = stream.get().next_layer()] (error_code const& ec) { if (!ec) { ASYNC_MQTT_LOG("mqtt_impl", info) @@ -659,7 +673,7 @@ class stream : public std::enable_shared_from_this> { ); stream.get().async_shutdown( as::bind_executor( - a_strm.strand_, + a_strm.raw_strand_, as::append( as::consign( force_move(self), @@ -706,7 +720,7 @@ class stream : public std::enable_shared_from_this> { state = close; auto& a_strm{strm}; as::dispatch( - a_strm.strand_, + a_strm.raw_strand_, as::append( force_move(self), error_code{}, @@ -721,7 +735,7 @@ class stream : public std::enable_shared_from_this> { stream.get().async_read( *buffer, as::bind_executor( - a_strm.strand_, + a_strm.raw_strand_, as::append( as::consign( force_move(self), @@ -734,7 +748,7 @@ class stream : public std::enable_shared_from_this> { } } break; case complete: - BOOST_ASSERT(strm.strand_.running_in_this_thread()); + BOOST_ASSERT(strm.in_strand()); self.complete(last_ec); break; default: @@ -746,7 +760,8 @@ class stream : public std::enable_shared_from_this> { private: next_layer_type nl_; - strand_type strand_{nl_.get_executor()}; + raw_strand_type raw_strand_{nl_.get_executor()}; + strand_type strand_{as::any_io_executor{raw_strand_}}; ioc_queue queue_; static_vector header_remaining_length_buf_ = static_vector(5); };