diff --git a/pkg/spooledtempfile/spooled.go b/pkg/spooledtempfile/spooled.go index 6870df9..057a920 100644 --- a/pkg/spooledtempfile/spooled.go +++ b/pkg/spooledtempfile/spooled.go @@ -176,7 +176,7 @@ func (s *spooledTempFile) Write(p []byte) (n int, err error) { // Otherwise, check if system memory usage is above threshold // or if we've exceeded our own in-memory limit, or if user forced on-disk. aboveRAMThreshold := s.isSystemMemoryUsageHigh() - if aboveRAMThreshold || s.fullOnDisk || (s.buf.Len()+len(p) > s.maxInMemorySize) { + if aboveRAMThreshold || s.fullOnDisk || (s.buf.Len()+len(p) > s.maxInMemorySize) || (s.buf.Cap() > s.maxInMemorySize) { // Switch to file if we haven't already s.file, err = os.CreateTemp(s.tempDir, s.filePrefix+"-") if err != nil { @@ -191,10 +191,15 @@ func (s *spooledTempFile) Write(p []byte) (n int, err error) { return 0, err } - // Release the buffer - s.buf.Reset() - spooledPool.Put(s.buf) - s.buf = nil + // If we're above the RAM threshold, we don't want to keep the buffer around. + if s.buf.Cap() > s.maxInMemorySize { + s.buf = nil + } else { + // Release the buffer + s.buf.Reset() + spooledPool.Put(s.buf) + s.buf = nil + } // Write incoming bytes directly to file n, err = s.file.Write(p) @@ -214,7 +219,11 @@ func (s *spooledTempFile) Close() error { s.closed = true s.mem = nil - if s.buf != nil { + // If we're above the RAM threshold, we don't want to keep the buffer around. + if s.buf != nil && s.buf.Cap() > s.maxInMemorySize { + s.buf = nil + } else { + // Release the buffer s.buf.Reset() spooledPool.Put(s.buf) s.buf = nil diff --git a/pkg/spooledtempfile/spooled_test.go b/pkg/spooledtempfile/spooled_test.go index 888ed4f..b61d360 100644 --- a/pkg/spooledtempfile/spooled_test.go +++ b/pkg/spooledtempfile/spooled_test.go @@ -62,11 +62,11 @@ func TestInMemoryBasic(t *testing.T) { // TestThresholdCrossing writes enough data to switch from in-memory to disk. func TestThresholdCrossing(t *testing.T) { - spool := NewSpooledTempFile("test", os.TempDir(), 10, false, -1) + spool := NewSpooledTempFile("test", os.TempDir(), 16, false, -1) defer spool.Close() data1 := []byte("12345") - data2 := []byte("67890ABCD") // total length > 10 + data2 := []byte("67890ABCDEFGHIJKLM") // total length > 16 _, err := spool.Write(data1) if err != nil {