From 72cd4a92a8b13e722763e6b6a3467163c2028d3d Mon Sep 17 00:00:00 2001 From: Klaus Post Date: Tue, 8 Oct 2024 02:20:28 -0700 Subject: [PATCH] zstd: Fix extra CRC written with multiple Close calls (#1017) * zstd: Fix extra CRC written with multiple Close calls * Also check write/flush after close. Fixes #1016 --- zstd/encoder.go | 26 +++++++++++++++++++++++--- zstd/encoder_test.go | 11 ++++++++++- zstd/zstd.go | 4 ++++ 3 files changed, 37 insertions(+), 4 deletions(-) diff --git a/zstd/encoder.go b/zstd/encoder.go index a79c4a527c..8f8223cd3a 100644 --- a/zstd/encoder.go +++ b/zstd/encoder.go @@ -6,6 +6,7 @@ package zstd import ( "crypto/rand" + "errors" "fmt" "io" "math" @@ -149,6 +150,9 @@ func (e *Encoder) ResetContentSize(w io.Writer, size int64) { // and write CRC if requested. func (e *Encoder) Write(p []byte) (n int, err error) { s := &e.state + if s.eofWritten { + return 0, ErrEncoderClosed + } for len(p) > 0 { if len(p)+len(s.filling) < e.o.blockSize { if e.o.crc { @@ -288,6 +292,9 @@ func (e *Encoder) nextBlock(final bool) error { s.filling, s.current, s.previous = s.previous[:0], s.filling, s.current s.nInput += int64(len(s.current)) s.wg.Add(1) + if final { + s.eofWritten = true + } go func(src []byte) { if debugEncoder { println("Adding block,", len(src), "bytes, final:", final) @@ -303,9 +310,6 @@ func (e *Encoder) nextBlock(final bool) error { blk := enc.Block() enc.Encode(blk, src) blk.last = final - if final { - s.eofWritten = true - } // Wait for pending writes. s.wWg.Wait() if s.writeErr != nil { @@ -401,12 +405,20 @@ func (e *Encoder) Flush() error { if len(s.filling) > 0 { err := e.nextBlock(false) if err != nil { + // Ignore Flush after Close. + if errors.Is(s.err, ErrEncoderClosed) { + return nil + } return err } } s.wg.Wait() s.wWg.Wait() if s.err != nil { + // Ignore Flush after Close. + if errors.Is(s.err, ErrEncoderClosed) { + return nil + } return s.err } return s.writeErr @@ -422,6 +434,9 @@ func (e *Encoder) Close() error { } err := e.nextBlock(true) if err != nil { + if errors.Is(s.err, ErrEncoderClosed) { + return nil + } return err } if s.frameContentSize > 0 { @@ -459,6 +474,11 @@ func (e *Encoder) Close() error { } _, s.err = s.w.Write(frame) } + if s.err == nil { + s.err = ErrEncoderClosed + return nil + } + return s.err } diff --git a/zstd/encoder_test.go b/zstd/encoder_test.go index 4a39474448..9b107d970e 100644 --- a/zstd/encoder_test.go +++ b/zstd/encoder_test.go @@ -6,6 +6,7 @@ package zstd import ( "bytes" + "errors" "fmt" "io" "math/rand" @@ -278,13 +279,21 @@ func TestEncoderRegression(t *testing.T) { if err != nil { t.Error(err) } + err = enc.Close() + if err != nil { + t.Error(err) + } + _, err = enc.Write([]byte{1, 2, 3, 4}) + if !errors.Is(err, ErrEncoderClosed) { + t.Errorf("unexpected error: %v", err) + } encoded = dst.Bytes() if len(encoded) > enc.MaxEncodedSize(len(in)) { t.Errorf("max encoded size for %v: got: %d, want max: %d", len(in), len(encoded), enc.MaxEncodedSize(len(in))) } got, err = dec.DecodeAll(encoded, make([]byte, 0, len(in)/2)) if err != nil { - t.Logf("error: %v\nwant: %v\ngot: %v", err, in, got) + t.Logf("error: %v\nwant: %v\ngot: %v", err, len(in), len(got)) t.Error(err) } }) diff --git a/zstd/zstd.go b/zstd/zstd.go index 4be7cc7367..066bef2a4f 100644 --- a/zstd/zstd.go +++ b/zstd/zstd.go @@ -88,6 +88,10 @@ var ( // Close has been called. ErrDecoderClosed = errors.New("decoder used after Close") + // ErrEncoderClosed will be returned if the Encoder was used after + // Close has been called. + ErrEncoderClosed = errors.New("encoder used after Close") + // ErrDecoderNilInput is returned when a nil Reader was provided // and an operation other than Reset/DecodeAll/Close was attempted. ErrDecoderNilInput = errors.New("nil input provided as reader")