Skip to content

Commit

Permalink
Add cl.IsTakenOver and switch cl.isTakenOver to atomic.Bool (#446)
Browse files Browse the repository at this point in the history
* OnPublish CodeSuccessIgnore, use debug instead of error log

* Suppress OnPublish CodeSuccessIgnore error log

* Add cl.IsTakenOver and switch to use atomic.Bool

---------

Co-authored-by: JB <[email protected]>
  • Loading branch information
thedevop and mochi-co authored Jan 30, 2025
1 parent dcb814c commit 0439068
Show file tree
Hide file tree
Showing 5 changed files with 26 additions and 6 deletions.
6 changes: 5 additions & 1 deletion clients.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ type ClientState struct {
disconnected int64 // the time the client disconnected in unix time, for calculating expiry
outbound chan *packets.Packet // queue for pending outbound packets
endOnce sync.Once // only end once
isTakenOver uint32 // used to identify orphaned clients
isTakenOver atomic.Bool // used to identify orphaned clients
packetID uint32 // the current highest packetID
open context.Context // indicate that the client is open for packet exchange
cancelOpen context.CancelFunc // cancel function for open context
Expand Down Expand Up @@ -427,6 +427,10 @@ func (cl *Client) Closed() bool {
return cl.State.open == nil || cl.State.open.Err() != nil
}

func (cl *Client) IsTakenOver() bool {
return cl.State.isTakenOver.Load()
}

// ReadFixedHeader reads in the values of the next packet's fixed header.
func (cl *Client) ReadFixedHeader(fh *packets.FixedHeader) error {
if cl.Net.bconn == nil {
Expand Down
7 changes: 7 additions & 0 deletions clients_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -599,6 +599,13 @@ func TestClientClosed(t *testing.T) {
require.True(t, cl.Closed())
}

func TestClientIsTakenOver(t *testing.T) {
cl, _, _ := newTestClient()
require.False(t, cl.IsTakenOver())
cl.State.isTakenOver.Store(true)
require.True(t, cl.IsTakenOver())
}

func TestClientReadFixedHeaderError(t *testing.T) {
cl, r, _ := newTestClient()
defer cl.Stop(errClientStop)
Expand Down
2 changes: 2 additions & 0 deletions hooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,8 @@ func (h *Hooks) OnPublish(cl *Client, pk packets.Packet) (pkx packets.Packet, er
"hook", hook.ID(),
"packet", pkx)
return pk, err
} else if errors.Is(err, packets.CodeSuccessIgnore) {
return pk, err
}
h.Log.Error("publish packet error",
"error", err,
Expand Down
10 changes: 5 additions & 5 deletions server.go
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,7 @@ func (s *Server) attachClient(cl *Client, listener string) error {
expire := (cl.Properties.ProtocolVersion == 5 && cl.Properties.Props.SessionExpiryInterval == 0) || (cl.Properties.ProtocolVersion < 5 && cl.Properties.Clean)
s.hooks.OnDisconnect(cl, err, expire)

if expire && atomic.LoadUint32(&cl.State.isTakenOver) == 0 {
if expire && !cl.IsTakenOver() {
cl.ClearInflights()
s.UnsubscribeClient(cl)
s.Clients.Delete(cl.ID) // [MQTT-4.1.0-2] ![MQTT-3.1.2-23]
Expand Down Expand Up @@ -565,11 +565,11 @@ func (s *Server) inheritClientSession(pk packets.Packet, cl *Client) bool {
if pk.Connect.Clean || (existing.Properties.Clean && existing.Properties.ProtocolVersion < 5) { // [MQTT-3.1.2-4] [MQTT-3.1.4-4]
s.UnsubscribeClient(existing)
existing.ClearInflights()
atomic.StoreUint32(&existing.State.isTakenOver, 1) // only set isTakenOver after unsubscribe has occurred
return false // [MQTT-3.2.2-3]
existing.State.isTakenOver.Store(true) // only set isTakenOver after unsubscribe has occurred
return false // [MQTT-3.2.2-3]
}

atomic.StoreUint32(&existing.State.isTakenOver, 1)
existing.State.isTakenOver.Store(true)
if existing.State.Inflight.Len() > 0 {
cl.State.Inflight = existing.State.Inflight.Clone() // [MQTT-3.1.2-5]
if cl.State.Inflight.maximumReceiveQuota == 0 && cl.ops.options.Capabilities.ReceiveMaximum != 0 {
Expand Down Expand Up @@ -1358,7 +1358,7 @@ func (s *Server) UnsubscribeClient(cl *Client) {
cl.State.Subscriptions.Delete(k)
}

if atomic.LoadUint32(&cl.State.isTakenOver) == 1 {
if cl.IsTakenOver() {
return
}

Expand Down
7 changes: 7 additions & 0 deletions server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -667,6 +667,7 @@ func TestEstablishConnectionInheritExisting(t *testing.T) {
clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
require.True(t, ok)
require.NotEmpty(t, clw.State.Subscriptions)
require.True(t, cl.IsTakenOver())

// Prevent sequential takeover memory-bloom.
require.Empty(t, cl.State.Subscriptions.GetAll())
Expand Down Expand Up @@ -761,6 +762,9 @@ func TestEstablishConnectionInheritExistingTrueTakeover(t *testing.T) {

_, _ = w2.Write(packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes)
require.NoError(t, <-o2)

require.True(t, clp1.IsTakenOver())
require.False(t, clp2.IsTakenOver())
}

func TestEstablishConnectionResentPendingInflightsError(t *testing.T) {
Expand Down Expand Up @@ -848,12 +852,15 @@ func TestEstablishConnectionInheritExistingClean(t *testing.T) {
require.Equal(t, packets.TPacketData[packets.Connack].Get(packets.TConnackAcceptedNoSession).RawBytes, <-recv)
require.Equal(t, packets.TPacketData[packets.Disconnect].Get(packets.TDisconnect).RawBytes, <-takeover)

require.True(t, cl.IsTakenOver())

_ = w.Close()
_ = r.Close()

clw, ok := s.Clients.Get(packets.TPacketData[packets.Connect].Get(packets.TConnectMqtt311).Packet.Connect.ClientIdentifier)
require.True(t, ok)
require.Equal(t, 0, clw.State.Subscriptions.Len())

}

func TestEstablishConnectionBadAuthentication(t *testing.T) {
Expand Down

0 comments on commit 0439068

Please sign in to comment.