Skip to content

Commit

Permalink
Added mutex guard to timers.
Browse files Browse the repository at this point in the history
  • Loading branch information
redboltz committed Apr 7, 2023
1 parent beb1323 commit 2c99522
Show file tree
Hide file tree
Showing 5 changed files with 138 additions and 85 deletions.
33 changes: 21 additions & 12 deletions include/mqtt/broker/session_state.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ struct session_state {
<< MQTT_ADD_VALUE(address, this)
<< "session expiry interval timer set";

std::lock_guard<mutex> g(mtx_tim_session_expiry_);
tim_session_expiry_ = std::make_shared<as::steady_timer>(timer_ioc_, session_expiry_interval_.value());
tim_session_expiry_->async_wait(
[this, wp = std::weak_ptr<as::steady_timer>(tim_session_expiry_), h = std::forward<SessionExpireHandler>(h)]
Expand All @@ -188,6 +189,7 @@ struct session_state {
<< MQTT_ADD_VALUE(address, this)
<< "renew_session expiry";
session_expiry_interval_ = force_move(v);
std::lock_guard<mutex> g(mtx_tim_session_expiry_);
tim_session_expiry_.reset();
}

Expand Down Expand Up @@ -383,6 +385,7 @@ struct session_state {
as::io_context& timer_ioc,
optional<MQTT_NS::will> will,
optional<std::chrono::steady_clock::duration> will_expiry_interval) {
std::lock_guard<mutex> g(mtx_tim_will_expiry_);
tim_will_expiry_.reset();
will_value_ = force_move(will);

Expand All @@ -405,6 +408,7 @@ struct session_state {
MQTT_LOG("mqtt_broker", trace)
<< MQTT_ADD_VALUE(address, this)
<< "clear will. cid:" << client_id_;
std::lock_guard<mutex> g(mtx_tim_will_expiry_);
tim_will_expiry_.reset();
will_value_ = nullopt;
}
Expand Down Expand Up @@ -543,18 +547,21 @@ struct session_state {
auto opts = will_value_.value().get_qos() | will_value_.value().get_retain();
auto props = force_move(will_value_.value().props());
will_value_ = nullopt;
if (tim_will_expiry_) {
auto d =
std::chrono::duration_cast<std::chrono::seconds>(
tim_will_expiry_->expiry() - std::chrono::steady_clock::now()
).count();
if (d < 0) d = 0;
set_property<v5::property::message_expiry_interval>(
props,
v5::property::message_expiry_interval(
static_cast<uint32_t>(d)
)
);
{
std::shared_lock<mutex> g(mtx_tim_will_expiry_);
if (tim_will_expiry_) {
auto d =
std::chrono::duration_cast<std::chrono::seconds>(
tim_will_expiry_->expiry() - std::chrono::steady_clock::now()
).count();
if (d < 0) d = 0;
set_property<v5::property::message_expiry_interval>(
props,
v5::property::message_expiry_interval(
static_cast<uint32_t>(d)
)
);
}
}
if (will_sender_) {
will_sender_(
Expand All @@ -571,6 +578,7 @@ struct session_state {
friend class session_states;

as::io_context& timer_ioc_;
mutex mtx_tim_will_expiry_;
std::shared_ptr<as::steady_timer> tim_will_expiry_;
optional<MQTT_NS::will> will_value_;

Expand All @@ -584,6 +592,7 @@ struct session_state {
std::string username_;

optional<std::chrono::steady_clock::duration> session_expiry_interval_;
mutex mtx_tim_session_expiry_;
std::shared_ptr<as::steady_timer> tim_session_expiry_;

mutable mutex mtx_inflight_messages_;
Expand Down
147 changes: 90 additions & 57 deletions include/mqtt/client.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -1020,7 +1020,7 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
*/
void set_keep_alive_sec(std::uint16_t keep_alive_sec, std::chrono::steady_clock::duration ping) {
if ((ping_duration_ != std::chrono::steady_clock::duration::zero()) && base::connected() && (ping == std::chrono::steady_clock::duration::zero())) {
tim_ping_.cancel();
cancel_ping_timer();
}
keep_alive_sec_ = keep_alive_sec;
ping_duration_ = force_move(ping);
Expand Down Expand Up @@ -1442,23 +1442,26 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
v5::disconnect_reason_code reason_code = v5::disconnect_reason_code::normal_disconnection,
v5::properties props = {}
) {
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) tim_ping_.cancel();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) cancel_ping_timer();
if (base::connected()) {
std::weak_ptr<this_type> wp(std::static_pointer_cast<this_type>(this->shared_from_this()));
tim_close_.expires_after(force_move(timeout));
tim_close_.async_wait(
[wp = force_move(wp)](error_code ec) {
if (auto sp = wp.lock()) {
if (!ec) {
sp->socket()->post(
[sp] {
sp->force_disconnect();
}
);
{
std::lock_guard<std::mutex> g(mtx_tim_close_);
tim_close_.expires_after(force_move(timeout));
tim_close_.async_wait(
[wp = force_move(wp)](error_code ec) {
if (auto sp = wp.lock()) {
if (!ec) {
sp->socket()->post(
[sp] {
sp->force_disconnect();
}
);
}
}
}
}
);
);
}
base::disconnect(reason_code, force_move(props));
}
}
Expand All @@ -1482,7 +1485,7 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
v5::disconnect_reason_code reason_code = v5::disconnect_reason_code::normal_disconnection,
v5::properties props = {}
) {
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) tim_ping_.cancel();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) cancel_ping_timer();
if (base::connected()) {
base::disconnect(reason_code, force_move(props));
}
Expand All @@ -1500,23 +1503,26 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
void async_disconnect(
std::chrono::steady_clock::duration timeout,
async_handler_t func = async_handler_t()) {
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) tim_ping_.cancel();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) cancel_ping_timer();
if (base::connected()) {
std::weak_ptr<this_type> wp(std::static_pointer_cast<this_type>(this->shared_from_this()));
tim_close_.expires_after(force_move(timeout));
tim_close_.async_wait(
[wp = force_move(wp)](error_code ec) {
if (auto sp = wp.lock()) {
if (!ec) {
sp->socket()->post(
[sp] {
sp->force_disconnect();
}
);
{
std::lock_guard<std::mutex> g(mtx_tim_close_);
tim_close_.expires_after(force_move(timeout));
tim_close_.async_wait(
[wp = force_move(wp)](error_code ec) {
if (auto sp = wp.lock()) {
if (!ec) {
sp->socket()->post(
[sp] {
sp->force_disconnect();
}
);
}
}
}
}
);
);
}
base::async_disconnect(force_move(func));
}
}
Expand All @@ -1543,23 +1549,26 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
v5::disconnect_reason_code reason_code,
v5::properties props,
async_handler_t func = async_handler_t()) {
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) tim_ping_.cancel();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) cancel_ping_timer();
if (base::connected()) {
std::weak_ptr<this_type> wp(std::static_pointer_cast<this_type>(this->shared_from_this()));
tim_close_.expires_after(force_move(timeout));
tim_close_.async_wait(
[wp = force_move(wp)](error_code ec) {
if (auto sp = wp.lock()) {
if (!ec) {
sp->socket()->post(
[sp] {
sp->force_disconnect();
}
);
{
std::lock_guard<std::mutex> g(mtx_tim_close_);
tim_close_.expires_after(force_move(timeout));
tim_close_.async_wait(
[wp = force_move(wp)](error_code ec) {
if (auto sp = wp.lock()) {
if (!ec) {
sp->socket()->post(
[sp] {
sp->force_disconnect();
}
);
}
}
}
}
);
);
}
base::async_disconnect(reason_code, force_move(props), force_move(func));
}
}
Expand All @@ -1574,7 +1583,7 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
*/
void async_disconnect(
async_handler_t func = async_handler_t()) {
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) tim_ping_.cancel();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) cancel_ping_timer();
if (base::connected()) {
base::async_disconnect(force_move(func));
}
Expand All @@ -1600,7 +1609,7 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
v5::disconnect_reason_code reason_code,
v5::properties props,
async_handler_t func = async_handler_t()) {
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) tim_ping_.cancel();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) cancel_ping_timer();
if (base::connected()) {
base::async_disconnect(reason_code, force_move(props), force_move(func));
}
Expand All @@ -1612,8 +1621,11 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
* When the endpoint disconnects using force_disconnect(), a will will send.<BR>
*/
void force_disconnect() {
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) tim_ping_.cancel();
tim_close_.cancel();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) cancel_ping_timer();
{
std::lock_guard<std::mutex> g(mtx_tim_close_);
tim_close_.cancel();
}
base::force_disconnect();
}

Expand All @@ -1625,8 +1637,11 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
*/
void async_force_disconnect(
async_handler_t func = async_handler_t()) {
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) tim_ping_.cancel();
tim_close_.cancel();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) cancel_ping_timer();
{
std::lock_guard<std::mutex> g(mtx_tim_close_);
tim_close_.cancel();
}
base::async_force_disconnect(force_move(func));
}

Expand Down Expand Up @@ -1989,7 +2004,7 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
}
base::set_connect();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) {
set_timer();
set_ping_timer();
}
handshake_socket(*socket_, force_move(props), force_move(session_life_keeper), underlying_connected);
}
Expand All @@ -2008,7 +2023,7 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
}
base::set_connect();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) {
set_timer();
set_ping_timer();
}
handshake_socket(*socket_, force_move(props), force_move(session_life_keeper), ec, underlying_connected);
}
Expand All @@ -2021,7 +2036,7 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
if (underlying_connected) {
base::set_connect();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) {
set_timer();
set_ping_timer();
}
async_handshake_socket(*socket_, force_move(props), force_move(session_life_keeper), force_move(func), underlying_connected);
}
Expand Down Expand Up @@ -2065,7 +2080,7 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
}
base::set_connect();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) {
set_timer();
set_ping_timer();
}
async_handshake_socket(*socket_, force_move(props), force_move(session_life_keeper), force_move(func), underlying_connected);
}
Expand All @@ -2078,7 +2093,7 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
protected:
void on_pre_send() noexcept override {
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) {
reset_timer();
reset_ping_timer();
}
}

Expand All @@ -2094,7 +2109,7 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
}
}

void set_timer() {
void set_ping_timer_no_lock() {
tim_ping_.expires_after(ping_duration_);
std::weak_ptr<this_type> wp(std::static_pointer_cast<this_type>(this->shared_from_this()));
tim_ping_.async_wait(
Expand All @@ -2106,19 +2121,35 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {
);
}

void reset_timer() {
void cancel_timer_no_lock() {
tim_ping_.cancel();
set_timer();
}


void set_ping_timer() {
std::lock_guard<std::mutex> g(mtx_tim_ping_);
set_ping_timer_no_lock();
}

void cancel_ping_timer() {
std::lock_guard<std::mutex> g(mtx_tim_ping_);
cancel_timer_no_lock();
}

void reset_ping_timer() {
std::lock_guard<std::mutex> g(mtx_tim_ping_);
cancel_timer_no_lock();
set_ping_timer_no_lock();
}

protected:
void on_close() noexcept override {
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) tim_ping_.cancel();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) cancel_ping_timer();
}

void on_error(error_code ec) noexcept override {
(void)ec;
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) tim_ping_.cancel();
if (ping_duration_ != std::chrono::steady_clock::duration::zero()) cancel_ping_timer();
}

// Ensure that only code that knows the *exact* type of an object
Expand Down Expand Up @@ -2152,7 +2183,9 @@ class client : public endpoint<std::mutex, std::lock_guard, PacketIdBytes> {

std::shared_ptr<Socket> socket_;
as::io_context& ioc_;
std::mutex mtx_tim_ping_;
as::steady_timer tim_ping_;
std::mutex mtx_tim_close_;
as::steady_timer tim_close_;
std::string host_;
std::string port_;
Expand Down
Loading

0 comments on commit 2c99522

Please sign in to comment.