Skip to content

Commit

Permalink
wrapcompressor added
Browse files Browse the repository at this point in the history
  • Loading branch information
misvivek committed Nov 5, 2024
1 parent 954f515 commit 49fa96c
Showing 1 changed file with 63 additions and 117 deletions.
180 changes: 63 additions & 117 deletions rpc_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@ package grpc
import (
"bytes"
"compress/gzip"
"errors"
"io"
"math"
"reflect"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/encoding"
protoenc "google.golang.org/grpc/encoding/proto"
Expand Down Expand Up @@ -294,145 +294,91 @@ func BenchmarkGZIPCompressor1MiB(b *testing.B) {
bmCompressor(b, 1024*1024, NewGZIPCompressor())
}

// wrapCompressor wraps the gzip compressor and counts decompression invocations.
type wrapCompressor struct {
encoding.Compressor
compressInvokes int32
decompressedData []byte
errDecompress error
customReader io.Reader
}

func (wc *wrapCompressor) Compress(w io.Writer) (io.WriteCloser, error) {
atomic.AddInt32(&wc.compressInvokes, 1)
return wc.Compressor.Compress(w)
decompressInvokes int32
}

func (wc *wrapCompressor) Decompress(r io.Reader) (io.Reader, error) {
if decompressor, ok := wc.Compressor.(interface {
Decompress(io.Reader) (io.Reader, error)
}); ok {
return decompressor.Decompress(r)
}
return nil, errors.New("Decompress not supported")
atomic.AddInt32(&wc.decompressInvokes, 1)
return wc.Compressor.Decompress(r)
}

func (wc *wrapCompressor) Name() string {
return wc.Compressor.Name()
}

type ErrorReader struct{}

func (e *ErrorReader) Read(p []byte) (n int, err error) {
return 0, errors.New("simulated io.Copy read error")
}
// TestDecompress tests the decompress function.
func TestDecompress(t *testing.T) {
// Setup the wrapCompressor with gzip.
baseCompressor := encoding.GetCompressor("gzip") // Use the registered gzip compressor
wc := &wrapCompressor{Compressor: baseCompressor}

tests := []struct {
name string
compressor encoding.Compressor
input mem.BufferSlice
compressedData []byte
maxReceiveMessageSize int
want mem.BufferSlice
compressedsize int
error error
wantCompressInvokes int32
expectedDecompressed []byte
expectError bool
wantDecompressInvokes int32
}{
{
name: "Successful decompression",
compressor: &wrapCompressor{
decompressedData: []byte("decompressed data"),
},
input: mem.BufferSlice{},
maxReceiveMessageSize: 100,
want: func() mem.BufferSlice {
decompressed := []byte("decompressed data")
return mem.BufferSlice{mem.NewBuffer(&decompressed, nil)}
}(),
compressedsize: 17,
error: nil,
wantCompressInvokes: 0, // We only check decompression, so no compression invocations
name: "Successful Decompression",
compressedData: compressData([]byte("decompressed data")),
maxReceiveMessageSize: 1024,
expectedDecompressed: []byte("decompressed data"),
expectError: false,
wantDecompressInvokes: 1,
},
{
name: "Error during decompression",
compressor: &wrapCompressor{

errDecompress: errors.New("decompression error"),
},
input: mem.BufferSlice{},
maxReceiveMessageSize: 100,
want: nil,
compressedsize: 0,
error: errors.New("decompression error"),
wantCompressInvokes: 0,
name: "Decompression Failure with Corrupt Data",
compressedData: []byte("invalid compressed data"),
maxReceiveMessageSize: 1024,
expectedDecompressed: nil,
expectError: true,
wantDecompressInvokes: 1,
},
{
name: "Buffer overflow",
compressor: &wrapCompressor{

decompressedData: []byte("overflow data"),
},
input: mem.BufferSlice{},
maxReceiveMessageSize: 5,
want: nil,
compressedsize: 6,
error: errors.New("overflow: received message size is larger than the allowed maxReceiveMessageSize (5 bytes)."),
wantCompressInvokes: 0,
},
{
name: "MaxInt64 receive size with small data",
compressor: &wrapCompressor{

decompressedData: []byte("small data"),
},
input: mem.BufferSlice{},
maxReceiveMessageSize: math.MaxInt64,
want: func() mem.BufferSlice {
smallDecompressed := []byte("small data")
return mem.BufferSlice{mem.NewBuffer(&smallDecompressed, nil)}
}(),
compressedsize: 10,
error: nil,
wantCompressInvokes: 0,
},
{
name: "Error during io.Copy",
compressor: &wrapCompressor{

customReader: &ErrorReader{},
},
input: mem.BufferSlice{},
maxReceiveMessageSize: 100,
want: nil,
compressedsize: 0,
error: errors.New("simulated io.Copy read error"),
wantCompressInvokes: 0,
name: "Overflow Check",
compressedData: compressData([]byte("large decompressed data")),
maxReceiveMessageSize: 5, // Intentionally small to trigger overflow
expectedDecompressed: nil,
expectError: true,
wantDecompressInvokes: 1,
},
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
output, size, err := decompress(tt.compressor, tt.input, tt.maxReceiveMessageSize, nil)

// Check errors
if (err != nil) != (tt.error != nil) {
t.Errorf("decompress() error, got err=%v, want err=%v", err, tt.error)
}

// Check compressed size
if size != tt.compressedsize {
t.Errorf("decompress() size, got = %d, want = %d", size, tt.compressedsize)
// Create a BufferSlice with compressed data
var data mem.BufferSlice
buf := make([]byte, len(tt.compressedData))
copy(buf, tt.compressedData)

// Creating a new buffer slice to hold the compressed data
buffer := mem.NewBuffer(&buf, nil) // Assuming mem.NewBuffer is correct in this context
data = append(data, buffer) // Append the buffer to the BufferSlice

// Call the decompress function
result, length, err := decompress(wc, data, tt.maxReceiveMessageSize, nil)

if tt.expectError {
assert.Error(t, err)
} else {
assert.NoError(t, err)
assert.Equal(t, tt.expectedDecompressed, result[:length]) // Compare the resulting byte slice
assert.Equal(t, len(tt.expectedDecompressed), length)
}

// Check output length
if len(tt.want) != len(output) {
t.Errorf("decompress() output length, got = %d, want = %d", len(output), len(tt.want))
}

// Check if compression was invoked (for tracking purposes)
if wc, ok := tt.compressor.(*wrapCompressor); ok {
invokes := atomic.LoadInt32(&wc.compressInvokes)
if invokes != tt.wantCompressInvokes {
t.Errorf("Unexpected compress invokes, got = %d, want = %d", invokes, tt.wantCompressInvokes)
}
}
// Check the number of decompression invokes
decompressInvokes := atomic.LoadInt32(&wc.decompressInvokes)
assert.Equal(t, tt.wantDecompressInvokes, decompressInvokes)
})
}
}

// Helper function to compress data using gzip
func compressData(data []byte) []byte {
var buf bytes.Buffer
writer := gzip.NewWriter(&buf)
_, _ = writer.Write(data)
writer.Close()
return buf.Bytes()
}

0 comments on commit 49fa96c

Please sign in to comment.