diff --git a/mempool/bufpool.go b/mempool/bufpool.go new file mode 100644 index 00000000..b2a2eabb --- /dev/null +++ b/mempool/bufpool.go @@ -0,0 +1,81 @@ +package mempool + +import ( + "bytes" + "sync" +) + +var bufPool = NewBuffer(0) + +// GetBuffer takes a Buffer from the default buffer pool +func GetBuffer() *bytes.Buffer { return bufPool.Get() } + +// PutBuffer returns Buffer to the default buffer pool +func PutBuffer(x *bytes.Buffer) { bufPool.Put(x) } + +type BufferPool interface { + Get() *bytes.Buffer + Put(x *bytes.Buffer) +} + +// NewBuffer returns a buffer pool. The max specify the max capacity of the Buffer the pool will +// return. If the Buffer becoomes large than max, it will no longer be returned to the pool. If +// max <= 0, no limit will be enforced. +func NewBuffer(max int) BufferPool { + if max > 0 { + return newBufferWithCap(max) + } + + return newBuffer() +} + +// Buffer is a Buffer pool. +type Buffer struct { + pool *sync.Pool +} + +func newBuffer() *Buffer { + return &Buffer{ + pool: &sync.Pool{ + New: func() any { return new(bytes.Buffer) }, + }, + } +} + +// Get a Buffer from the pool. +func (b *Buffer) Get() *bytes.Buffer { + return b.pool.Get().(*bytes.Buffer) +} + +// Put the Buffer back into pool. It resets the Buffer for reuse. +func (b *Buffer) Put(x *bytes.Buffer) { + x.Reset() + b.pool.Put(x) +} + +// BufferWithCap is a Buffer pool that +type BufferWithCap struct { + bp *Buffer + max int +} + +func newBufferWithCap(max int) *BufferWithCap { + return &BufferWithCap{ + bp: newBuffer(), + max: max, + } +} + +// Get a Buffer from the pool. +func (b *BufferWithCap) Get() *bytes.Buffer { + return b.bp.Get() +} + +// Put the Buffer back into the pool if the capacity doesn't exceed the limit. It resets the Buffer +// for reuse. +func (b *BufferWithCap) Put(x *bytes.Buffer) { + if x.Cap() > b.max { + return + } + b.bp.Put(x) +} diff --git a/mempool/bufpool_test.go b/mempool/bufpool_test.go new file mode 100644 index 00000000..560f2e93 --- /dev/null +++ b/mempool/bufpool_test.go @@ -0,0 +1,96 @@ +package mempool + +import ( + "bytes" + "reflect" + "runtime/debug" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewBuffer(t *testing.T) { + defer debug.SetGCPercent(debug.SetGCPercent(-1)) + bp := NewBuffer(1000) + require.Equal(t, "*mempool.BufferWithCap", reflect.TypeOf(bp).String()) + + bp = NewBuffer(0) + require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String()) + + bp = NewBuffer(-1) + require.Equal(t, "*mempool.Buffer", reflect.TypeOf(bp).String()) +} + +func TestBuffer(t *testing.T) { + defer debug.SetGCPercent(debug.SetGCPercent(-1)) + Size := 101 + + bp := NewBuffer(0) + buf := bp.Get() + + for i := 0; i < Size; i++ { + buf.WriteByte('a') + } + + bp.Put(buf) + buf = bp.Get() + require.Equal(t, 0, buf.Len()) +} + +func TestBufferWithCap(t *testing.T) { + defer debug.SetGCPercent(debug.SetGCPercent(-1)) + Size := 101 + bp := NewBuffer(100) + buf := bp.Get() + + for i := 0; i < Size; i++ { + buf.WriteByte('a') + } + + bp.Put(buf) + buf = bp.Get() + require.Equal(t, 0, buf.Len()) + require.Equal(t, 0, buf.Cap()) +} + +func BenchmarkBufferPool(b *testing.B) { + bp := NewBuffer(0) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b := bp.Get() + b.WriteString("this is a test") + bp.Put(b) + } +} + +func BenchmarkBufferPoolWithCapLarger(b *testing.B) { + bp := NewBuffer(64 * 1024) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b := bp.Get() + b.WriteString("this is a test") + bp.Put(b) + } +} + +func BenchmarkBufferPoolWithCapLesser(b *testing.B) { + bp := NewBuffer(10) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + b := bp.Get() + b.WriteString("this is a test") + bp.Put(b) + } +} + +func BenchmarkBufferWithoutPool(b *testing.B) { + b.ResetTimer() + for i := 0; i < b.N; i++ { + b := new(bytes.Buffer) + b.WriteString("this is a test") + _ = b + } +} diff --git a/packets/packets.go b/packets/packets.go index ff5930b2..6833473a 100644 --- a/packets/packets.go +++ b/packets/packets.go @@ -12,6 +12,8 @@ import ( "strconv" "strings" "sync" + + "github.com/mochi-mqtt/server/v2/mempool" ) // All valid packet types and their packet identifiers. @@ -298,7 +300,8 @@ func (s *Subscription) decode(b byte) { // ConnectEncode encodes a connect packet. func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { - nb := bytes.NewBuffer([]byte{}) + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) nb.Write(encodeBytes(pk.Connect.ProtocolName)) nb.WriteByte(pk.ProtocolVersion) @@ -315,7 +318,8 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { nb.Write(encodeUint16(pk.Connect.Keepalive)) if pk.ProtocolVersion == 5 { - pb := bytes.NewBuffer([]byte{}) + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) (&pk.Properties).Encode(pk.FixedHeader.Type, pk.Mods, pb, 0) nb.Write(pb.Bytes()) } @@ -324,7 +328,8 @@ func (pk *Packet) ConnectEncode(buf *bytes.Buffer) error { if pk.Connect.WillFlag { if pk.ProtocolVersion == 5 { - pb := bytes.NewBuffer([]byte{}) + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) (&pk.Connect).WillProperties.Encode(WillProperties, pk.Mods, pb, 0) nb.Write(pb.Bytes()) } @@ -493,12 +498,14 @@ func (pk *Packet) ConnectValidate() Code { // ConnackEncode encodes a Connack packet. func (pk *Packet) ConnackEncode(buf *bytes.Buffer) error { - nb := bytes.NewBuffer([]byte{}) + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) nb.WriteByte(encodeBool(pk.SessionPresent)) nb.WriteByte(pk.ReasonCode) if pk.ProtocolVersion == 5 { - pb := bytes.NewBuffer([]byte{}) + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+2) // +SessionPresent +ReasonCode nb.Write(pb.Bytes()) } @@ -536,12 +543,14 @@ func (pk *Packet) ConnackDecode(buf []byte) error { // DisconnectEncode encodes a Disconnect packet. func (pk *Packet) DisconnectEncode(buf *bytes.Buffer) error { - nb := bytes.NewBuffer([]byte{}) + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) if pk.ProtocolVersion == 5 { nb.WriteByte(pk.ReasonCode) - pb := bytes.NewBuffer([]byte{}) + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) nb.Write(pb.Bytes()) } @@ -598,7 +607,8 @@ func (pk *Packet) PingrespDecode(buf []byte) error { // PublishEncode encodes a Publish packet. func (pk *Packet) PublishEncode(buf *bytes.Buffer) error { - nb := bytes.NewBuffer([]byte{}) + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) nb.Write(encodeString(pk.TopicName)) // [MQTT-3.3.2-1] @@ -610,16 +620,16 @@ func (pk *Packet) PublishEncode(buf *bytes.Buffer) error { } if pk.ProtocolVersion == 5 { - pb := bytes.NewBuffer([]byte{}) + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.Payload)) nb.Write(pb.Bytes()) } - nb.Write(pk.Payload) - - pk.FixedHeader.Remaining = nb.Len() + pk.FixedHeader.Remaining = nb.Len() + len(pk.Payload) pk.FixedHeader.Encode(buf) _, _ = nb.WriteTo(buf) + buf.Write(pk.Payload) return nil } @@ -690,11 +700,13 @@ func (pk *Packet) PublishValidate(topicAliasMaximum uint16) Code { // encodePubAckRelRecComp encodes a Puback, Pubrel, Pubrec, or Pubcomp packet. func (pk *Packet) encodePubAckRelRecComp(buf *bytes.Buffer) error { - nb := bytes.NewBuffer([]byte{}) + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) nb.Write(encodeUint16(pk.PacketID)) if pk.ProtocolVersion == 5 { - pb := bytes.NewBuffer([]byte{}) + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) if pk.ReasonCode >= ErrUnspecifiedError.Code || pb.Len() > 1 { nb.WriteByte(pk.ReasonCode) @@ -831,11 +843,13 @@ func (pk *Packet) ReasonCodeValid() bool { // SubackEncode encodes a Suback packet. func (pk *Packet) SubackEncode(buf *bytes.Buffer) error { - nb := bytes.NewBuffer([]byte{}) + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) nb.Write(encodeUint16(pk.PacketID)) if pk.ProtocolVersion == 5 { - pb := bytes.NewBuffer([]byte{}) + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+len(pk.ReasonCodes)) nb.Write(pb.Bytes()) } @@ -878,10 +892,12 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error { return ErrProtocolViolationNoPacketID } - nb := bytes.NewBuffer([]byte{}) + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) nb.Write(encodeUint16(pk.PacketID)) - xb := bytes.NewBuffer([]byte{}) // capture and write filters after length checks + xb := mempool.GetBuffer() // capture and write filters after length checks + defer mempool.PutBuffer(xb) for _, opts := range pk.Filters { xb.Write(encodeString(opts.Filter)) // [MQTT-3.8.3-1] if pk.ProtocolVersion == 5 { @@ -892,7 +908,8 @@ func (pk *Packet) SubscribeEncode(buf *bytes.Buffer) error { } if pk.ProtocolVersion == 5 { - pb := bytes.NewBuffer([]byte{}) + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len()) nb.Write(pb.Bytes()) } @@ -983,11 +1000,13 @@ func (pk *Packet) SubscribeValidate() Code { // UnsubackEncode encodes an Unsuback packet. func (pk *Packet) UnsubackEncode(buf *bytes.Buffer) error { - nb := bytes.NewBuffer([]byte{}) + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) nb.Write(encodeUint16(pk.PacketID)) if pk.ProtocolVersion == 5 { - pb := bytes.NewBuffer([]byte{}) + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) nb.Write(pb.Bytes()) } @@ -1031,16 +1050,19 @@ func (pk *Packet) UnsubscribeEncode(buf *bytes.Buffer) error { return ErrProtocolViolationNoPacketID } - nb := bytes.NewBuffer([]byte{}) + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) nb.Write(encodeUint16(pk.PacketID)) - xb := bytes.NewBuffer([]byte{}) // capture filters and write after length checks + xb := mempool.GetBuffer() // capture filters and write after length checks + defer mempool.PutBuffer(xb) for _, sub := range pk.Filters { xb.Write(encodeString(sub.Filter)) // [MQTT-3.10.3-1] } if pk.ProtocolVersion == 5 { - pb := bytes.NewBuffer([]byte{}) + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()+xb.Len()) nb.Write(pb.Bytes()) } @@ -1100,10 +1122,12 @@ func (pk *Packet) UnsubscribeValidate() Code { // AuthEncode encodes an Auth packet. func (pk *Packet) AuthEncode(buf *bytes.Buffer) error { - nb := bytes.NewBuffer([]byte{}) + nb := mempool.GetBuffer() + defer mempool.PutBuffer(nb) nb.WriteByte(pk.ReasonCode) - pb := bytes.NewBuffer([]byte{}) + pb := mempool.GetBuffer() + defer mempool.PutBuffer(pb) pk.Properties.Encode(pk.FixedHeader.Type, pk.Mods, pb, nb.Len()) nb.Write(pb.Bytes())