diff --git a/include/async_mqtt/endpoint.hpp b/include/async_mqtt/endpoint.hpp index e9383bcd8..dba6bc4ec 100644 --- a/include/async_mqtt/endpoint.hpp +++ b/include/async_mqtt/endpoint.hpp @@ -2194,24 +2194,51 @@ class basic_endpoint { error_code const& = error_code{} ) { switch (state) { - case initiate: { - ASYNC_MQTT_LOG("mqtt_impl", trace) - << ASYNC_MQTT_ADD_VALUE(address, &ep) - << "close initiate status:" << static_cast(ep.status_); - state = complete; - ep.status_ = connection_status::closing; - auto& a_ep{ep}; - a_ep.stream_->close(force_move(self)); - } break; + case initiate: + switch (ep.status_) { + case connection_status::connecting: + case connection_status::connected: + case connection_status::disconnecting: { + ASYNC_MQTT_LOG("mqtt_impl", trace) + << ASYNC_MQTT_ADD_VALUE(address, &ep) + << "close initiate status:" << static_cast(ep.status_); + state = complete; + ep.status_ = connection_status::closing; + auto& a_ep{ep}; + a_ep.stream_->close(force_move(self)); + } break; + case connection_status::closing: { + ASYNC_MQTT_LOG("mqtt_impl", trace) + << ASYNC_MQTT_ADD_VALUE(address, &ep) + << "already close requested"; + state = complete; + auto& a_ep{ep}; + a_ep.close_queue_.post(force_move(self)); + } break; + case connection_status::closed: + ASYNC_MQTT_LOG("mqtt_impl", trace) + << ASYNC_MQTT_ADD_VALUE(address, &ep) + << "already closed"; + state = complete; + self.complete(); + break; + } + break; case complete: BOOST_ASSERT(ep.strand().running_in_this_thread()); ASYNC_MQTT_LOG("mqtt_impl", trace) << ASYNC_MQTT_ADD_VALUE(address, &ep) - << "close complete status:" << static_cast(ep.status_); + << "close complete status:" << static_cast(ep.status_); ep.tim_pingreq_send_->cancel(); ep.tim_pingreq_recv_->cancel(); ep.tim_pingresp_recv_->cancel(); ep.status_ = connection_status::closed; + while (!ep.close_queue_.stopped()) { + ASYNC_MQTT_LOG("mqtt_impl", trace) + << ASYNC_MQTT_ADD_VALUE(address, &ep) + << "process enqueued close"; + ep.close_queue_.poll_one(); + } self.complete(); break; } @@ -2573,6 +2600,8 @@ class basic_endpoint { std::set publish_recv_; std::deque> publish_queue_; + ioc_queue close_queue_; + std::uint32_t maximum_packet_size_send_{packet_size_no_limit}; std::uint32_t maximum_packet_size_recv_{packet_size_no_limit}; diff --git a/include/async_mqtt/stream.hpp b/include/async_mqtt/stream.hpp index b511238d3..8ca934055 100644 --- a/include/async_mqtt/stream.hpp +++ b/include/async_mqtt/stream.hpp @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -84,8 +85,6 @@ class stream : public std::enable_shared_from_this> { ) ); } - queue_.emplace(); - queue_->stop(); } ~stream() { @@ -372,7 +371,6 @@ class stream : public std::enable_shared_from_this> { std::shared_ptr packet; error_code last_ec = error_code{}; this_type_sp life_keeper = strm.shared_from_this(); - optional> queue_work_guard = nullopt; enum { dispatch, post, write, bind, complete } state = dispatch; template @@ -392,25 +390,18 @@ class stream : public std::enable_shared_from_this> { BOOST_ASSERT(strm.strand_.running_in_this_thread()); state = write; auto& a_strm{strm}; - as::post( - *a_strm.queue_, + a_strm.queue_.post( as::bind_executor( a_strm.strand_, force_move(self) ) ); - if (!a_strm.writing_ && a_strm.queue_->stopped()) { - a_strm.queue_->restart(); - a_strm.queue_->poll_one(); - } } break; case write: { BOOST_ASSERT(strm.strand_.running_in_this_thread()); + strm.queue_.start_work(); if (strm.lowest_layer().is_open()) { state = bind; - BOOST_ASSERT(!strm.writing_); - strm.writing_ = true; - queue_work_guard.emplace(strm.queue_->get_executor()); auto& a_strm{strm}; auto cbs = packet->const_buffer_sequence(); async_write( @@ -449,14 +440,13 @@ class stream : public std::enable_shared_from_this> { ) { if (ec) { BOOST_ASSERT(strm.strand_.running_in_this_thread()); + strm.queue_.stop_work(); auto& a_strm{strm}; as::post( a_strm.strand_, - [&a_strm, &queue = a_strm.queue_, wp = a_strm.weak_from_this()] { + [&a_strm,wp = a_strm.weak_from_this()] { if (auto sp = wp.lock()) { - a_strm.writing_ = false; - if (a_strm.queue_->stopped()) a_strm.queue_->restart(); - queue->poll_one(); + a_strm.queue_.poll_one(); } } ); @@ -476,14 +466,13 @@ class stream : public std::enable_shared_from_this> { switch (state) { case bind: { BOOST_ASSERT(strm.strand_.running_in_this_thread()); + strm.queue_.stop_work(); auto& a_strm{strm}; as::post( a_strm.strand_, - [&a_strm, &queue = a_strm.queue_, wp = a_strm.weak_from_this()] { + [&a_strm, wp = a_strm.weak_from_this()] { if (auto sp = wp.lock()) { - a_strm.writing_ = false; - if (a_strm.queue_->stopped()) a_strm.queue_->restart(); - queue->poll_one(); + a_strm.queue_.poll_one(); } } ); @@ -525,6 +514,7 @@ class stream : public std::enable_shared_from_this> { complete } state = dispatch; error_code last_ec = error_code{}; + this_type_sp life_keeper = strm.shared_from_this(); template void operator()( @@ -743,9 +733,8 @@ class stream : public std::enable_shared_from_this> { private: next_layer_type nl_; strand_type strand_{nl_.get_executor()}; - optional queue_; + ioc_queue queue_; static_vector header_remaining_length_buf_ = static_vector(5); - bool writing_ = false; }; } // namespace async_mqtt diff --git a/include/async_mqtt/util/ioc_queue.hpp b/include/async_mqtt/util/ioc_queue.hpp new file mode 100644 index 000000000..415f9d634 --- /dev/null +++ b/include/async_mqtt/util/ioc_queue.hpp @@ -0,0 +1,63 @@ +// Copyright Takatoshi Kondo 2023 +// +// Distributed under the Boost Software License, Version 1.0. +// (See accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) + +#if !defined(ASYNC_MQTT_UTIL_IOC_QUEUE_HPP) +#define ASYNC_MQTT_UTIL_IOC_QUEUE_HPP + +#include + +#include + +namespace async_mqtt { + +namespace as = boost::asio; + +class ioc_queue { +public: + ioc_queue() { + queue_.stop(); + } + + void start_work() { + working_ = true; + guard_.emplace(queue_.get_executor()); + } + + void stop_work() { + guard_.reset(); + } + + template + void post(CompletionToken&& token) { + as::post( + queue_, + std::forward(token) + ); + if (!working_ && queue_.stopped()) { + queue_.restart(); + queue_.poll_one(); + } + } + + bool stopped() const { + return queue_.stopped(); + } + + void poll_one() { + working_ = false; + if (queue_.stopped()) queue_.restart(); + queue_.poll_one(); + } + +private: + as::io_context queue_; + bool working_ = false; + optional> guard_; +}; + +} // namespace async_mqtt + +#endif // ASYNC_MQTT_UTIL_IOC_QUEUE_HPP