diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
new file mode 100644
index 0000000..44fca7a
--- /dev/null
+++ b/.github/workflows/test.yaml
@@ -0,0 +1,42 @@
+name: CI Tests
+on:
+  pull_request:
+    paths-ignore:
+      - 'README.md'
+  push:
+    branches:
+      - 'master'
+    paths-ignore:
+      - 'README.md'
+
+permissions:
+  contents: read
+
+jobs:
+  go-fmt-and-vet:
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@ac593985615ec2ede58e132d2e21d2b1cbd6127c # v3.3.0
+      - uses: actions/setup-go@6edd4406fa81c3da01a34fa6f6343087c207a568 # v3.5.0
+        with:
+          go-version: '1.20'
+          cache: true
+      - run: |
+          files=$(go fmt ./...)
+          if [ -n "$files" ]; then
+            echo "The following file(s) do not conform to go fmt:"
+            echo "$files"
+            exit 1
+          fi
+          go vet ./...
+  go-test:
+    needs: go-fmt-and-vet
+    runs-on: ubuntu-22.04
+    steps:
+      - uses: actions/checkout@ac593985615ec2ede58e132d2e21d2b1cbd6127c # v3.3.0
+      - uses: actions/setup-go@6edd4406fa81c3da01a34fa6f6343087c207a568 # v3.5.0
+        with:
+          go-version: '1.20'
+          cache: true
+      - run: |
+          go test -race ./...
diff --git a/bench_test.go b/bench_test.go
index d6b7348..2f890ce 100644
--- a/bench_test.go
+++ b/bench_test.go
@@ -1,17 +1,13 @@
 package yamux
 
 import (
+	"fmt"
 	"io"
-	"io/ioutil"
 	"testing"
 )
 
 func BenchmarkPing(b *testing.B) {
-	client, server := testClientServer()
-	defer func() {
-		client.Close()
-		server.Close()
-	}()
+	client, _ := testClientServer(b)
 
 	b.ReportAllocs()
 	b.ResetTimer()
@@ -28,11 +24,7 @@ func BenchmarkPing(b *testing.B) {
 }
 
 func BenchmarkAccept(b *testing.B) {
-	client, server := testClientServer()
-	defer func() {
-		client.Close()
-		server.Close()
-	}()
+	client, server := testClientServer(b)
 
 	doneCh := make(chan struct{})
 	b.ReportAllocs()
@@ -107,25 +99,20 @@ func BenchmarkSendRecvLarge(b *testing.B) {
 }
 
 func benchmarkSendRecv(b *testing.B, sendSize, recvSize int) {
-	client, server := testClientServer()
-	defer func() {
-		client.Close()
-		server.Close()
-	}()
+	client, server := testClientServer(b)
 
 	sendBuf := make([]byte, sendSize)
 	recvBuf := make([]byte, recvSize)
-	doneCh := make(chan struct{})
+	errCh := make(chan error, 1)
 
 	b.SetBytes(int64(sendSize))
 	b.ReportAllocs()
 	b.ResetTimer()
 
 	go func() {
-		defer close(doneCh)
-
 		stream, err := server.AcceptStream()
 		if err != nil {
+			errCh <- err
 			return
 		}
 		defer stream.Close()
@@ -134,23 +121,27 @@ func benchmarkSendRecv(b *testing.B, sendSize, recvSize int) {
 		case sendSize == recvSize:
 			for i := 0; i < b.N; i++ {
 				if _, err := stream.Read(recvBuf); err != nil {
-					b.Fatalf("err: %v", err)
+					errCh <- err
+					return
 				}
 			}
 		case recvSize > sendSize:
-			b.Fatalf("bad test case; recvSize was: %d and sendSize was: %d, but recvSize must be <= sendSize!", recvSize, sendSize)
+			errCh <- fmt.Errorf("bad test case; recvSize was: %d and sendSize was: %d, but recvSize must be <= sendSize!", recvSize, sendSize)
+			return
 		default:
 			chunks := sendSize / recvSize
 			for i := 0; i < b.N; i++ {
 				for j := 0; j < chunks; j++ {
 					if _, err := stream.Read(recvBuf); err != nil {
-						b.Fatalf("err: %v", err)
+						errCh <- err
+						return
 					}
 				}
 			}
 		}
+		errCh <- nil
 	}()
 
 	stream, err := client.Open()
@@ -164,7 +155,8 @@ func benchmarkSendRecv(b *testing.B, sendSize, recvSize int) {
 			b.Fatalf("err: %v", err)
 		}
 	}
-	<-doneCh
+
+	drainErrorsUntil(b, errCh, 1, 0, "")
 }
 
 func BenchmarkSendRecvParallel32(b *testing.B) {
@@ -208,33 +200,29 @@ func BenchmarkSendRecvParallel4096(b *testing.B) {
 }
 
 func benchmarkSendRecvParallel(b *testing.B, sendSize int) {
-	client, server := testClientServer()
-	defer func() {
-		client.Close()
-		server.Close()
-	}()
+	client, server := testClientServer(b)
 
 	sendBuf := make([]byte, sendSize)
-	discarder := ioutil.Discard.(io.ReaderFrom)
+	discarder := io.Discard.(io.ReaderFrom)
 
 	b.SetBytes(int64(sendSize))
 	b.ReportAllocs()
 	b.ResetTimer()
 
 	b.RunParallel(func(pb *testing.PB) {
-		doneCh := make(chan struct{})
-
+		errCh := make(chan error, 1)
 		go func() {
-			defer close(doneCh)
-
 			stream, err := server.AcceptStream()
 			if err != nil {
+				errCh <- err
 				return
 			}
 			defer stream.Close()
 
 			if _, err := discarder.ReadFrom(stream); err != nil {
-				b.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
+			errCh <- nil
 		}()
 
 		stream, err := client.Open()
@@ -249,6 +237,7 @@ func benchmarkSendRecvParallel(b *testing.B, sendSize int) {
 		}
 
 		stream.Close()
-		<-doneCh
+
+		drainErrorsUntil(b, errCh, 1, 0, "")
 	})
 }
diff --git a/go.mod b/go.mod
index dd8974d..57c4453 100644
--- a/go.mod
+++ b/go.mod
@@ -1,3 +1,3 @@
 module github.com/hashicorp/yamux
 
-go 1.15
+go 1.20
diff --git a/go.sum b/go.sum
new file mode 100644
index 0000000..e69de29
diff --git a/mux.go b/mux.go
index 28c7313..d3d5b3f 100644
--- a/mux.go
+++ b/mux.go
@@ -53,6 +53,13 @@ type Config struct {
 	Logger Logger
 }
 
+// Clone returns a shallow copy of this config, so callers can tune one
+// side of a connection without mutating a Config shared with the other.
+func (c *Config) Clone() *Config {
+	c2 := *c
+	return &c2
+}
+
 // DefaultConfig is used to return a default configuration
 func DefaultConfig() *Config {
 	return &Config{
diff --git a/session.go b/session.go
index ba09f7e..c08c4da 100644
--- a/session.go
+++ b/session.go
@@ -356,7 +356,7 @@ func (s *Session) Ping() (time.Duration, error) {
 	}
 
 	// Compute the RTT
-	return time.Now().Sub(start), nil
+	return time.Since(start), nil
 }
 
 // keepalive is a long running goroutine that periodically does
diff --git a/session_test.go b/session_test.go
index 124d0af..46e9233 100644
--- a/session_test.go
+++ b/session_test.go
@@ -3,20 +3,40 @@ package yamux
 import (
 	"bytes"
 	"context"
+	"errors"
 	"fmt"
 	"io"
-	"io/ioutil"
 	"log"
 	"net"
 	"reflect"
 	"runtime"
 	"strings"
 	"sync"
+	"sync/atomic"
 	"testing"
 	"time"
 )
 
-type logCapture struct{ bytes.Buffer }
+type logCapture struct {
+	mu  sync.Mutex
+	buf *bytes.Buffer
+}
+
+var _ io.Writer = (*logCapture)(nil)
+
+func (l *logCapture) Write(p []byte) (n int, err error) {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	if l.buf == nil {
+		l.buf = &bytes.Buffer{}
+	}
+	return l.buf.Write(p)
+}
+func (l *logCapture) String() string {
+	l.mu.Lock()
+	defer l.mu.Unlock()
+	return l.buf.String()
+}
 
 func (l *logCapture) logs() []string {
 	return strings.Split(strings.TrimSpace(l.String()), "\n")
@@ -26,12 +46,6 @@ func (l *logCapture) match(expect []string) bool {
 	return reflect.DeepEqual(l.logs(), expect)
 }
 
-func captureLogs(s *Session) *logCapture {
-	buf := new(logCapture)
-	s.logger = log.New(buf, "", 0)
-	return buf
-}
-
 type pipeConn struct {
 	reader *io.PipeReader
 	writer *io.PipeWriter
@@ -69,27 +83,42 @@ func testConf() *Config {
 	return conf
 }
 
+func captureLogs(conf *Config) *logCapture {
+	buf := new(logCapture)
+	conf.Logger = log.New(buf, "", 0)
+	conf.LogOutput = nil
+	return buf
+}
+
 func testConfNoKeepAlive() *Config {
 	conf := testConf()
 	conf.EnableKeepAlive = false
 	return conf
 }
 
-func testClientServer() (*Session, *Session) {
-	return testClientServerConfig(testConf())
+func testClientServer(t testing.TB) (*Session, *Session) {
+	return testClientServerConfig(t, testConf(), testConf())
 }
 
-func testClientServerConfig(conf *Config) (*Session, *Session) {
+func testClientServerConfig(t testing.TB, serverConf, clientConf *Config) (*Session, *Session) {
 	conn1, conn2 := testConn()
-	client, _ := Client(conn1, conf)
-	server, _ := Server(conn2, conf)
+
+	client, err := Client(conn1, clientConf)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	t.Cleanup(func() { _ = client.Close() })
+
+	server, err := Server(conn2, serverConf)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
+	t.Cleanup(func() { _ = server.Close() })
 
 	return client, server
 }
 
 func TestPing(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	rtt, err := client.Ping()
 	if err != nil {
@@ -109,9 +138,8 @@ func TestPing(t *testing.T) {
 }
 
 func TestPing_Timeout(t *testing.T) {
-	client, server := testClientServerConfig(testConfNoKeepAlive())
-	defer client.Close()
-	defer server.Close()
+	conf := testConfNoKeepAlive()
+	client, server := testClientServerConfig(t, conf.Clone(), conf.Clone())
 
 	// Prevent the client from responding
 	clientConn := client.conn.(*pipeConn)
@@ -153,10 +181,7 @@ func TestPing_Timeout(t *testing.T) {
 func TestCloseBeforeAck(t *testing.T) {
 	cfg := testConf()
 	cfg.AcceptBacklog = 8
-	client, server := testClientServerConfig(cfg)
-
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServerConfig(t, cfg, cfg.Clone())
 
 	for i := 0; i < 8; i++ {
 		s, err := client.OpenStream()
@@ -174,27 +199,22 @@ func TestCloseBeforeAck(t *testing.T) {
 		s.Close()
 	}
 
-	done := make(chan struct{})
+	errCh := make(chan error, 1)
 	go func() {
-		defer close(done)
 		s, err := client.OpenStream()
 		if err != nil {
-			t.Fatal(err)
+			errCh <- err
+			return
 		}
 		s.Close()
+		errCh <- nil
 	}()
 
-	select {
-	case <-done:
-	case <-time.After(time.Second * 5):
-		t.Fatal("timed out trying to open stream")
-	}
+	drainErrorsUntil(t, errCh, 1, time.Second*5, "timed out trying to open stream")
 }
 
 func TestAccept(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	if client.NumStreams() != 0 {
 		t.Fatalf("bad")
@@ -203,89 +223,42 @@ func TestAccept(t *testing.T) {
 		t.Fatalf("bad")
 	}
 
-	wg := &sync.WaitGroup{}
-	wg.Add(4)
-
-	go func() {
-		defer wg.Done()
-		stream, err := server.AcceptStream()
-		if err != nil {
-			t.Fatalf("err: %v", err)
-		}
-		if id := stream.StreamID(); id != 1 {
-			t.Fatalf("bad: %v", id)
-		}
-		if err := stream.Close(); err != nil {
-			t.Fatalf("err: %v", err)
-		}
-	}()
-
-	go func() {
-		defer wg.Done()
-		stream, err := client.AcceptStream()
-		if err != nil {
-			t.Fatalf("err: %v", err)
-		}
-		if id := stream.StreamID(); id != 2 {
-			t.Fatalf("bad: %v", id)
-		}
-		if err := stream.Close(); err != nil {
-			t.Fatalf("err: %v", err)
-		}
-	}()
-
-	go func() {
-		defer wg.Done()
-		stream, err := server.OpenStream()
-		if err != nil {
-			t.Fatalf("err: %v", err)
-		}
-		if id := stream.StreamID(); id != 2 {
-			t.Fatalf("bad: %v", id)
-		}
-		if err := stream.Close(); err != nil {
-			t.Fatalf("err: %v", err)
-		}
-	}()
-
-	go func() {
-		defer wg.Done()
-		stream, err := client.OpenStream()
+	errCh := make(chan error, 4)
+	acceptOne := func(streamFunc func() (*Stream, error), expectID uint32) {
+		stream, err := streamFunc()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
-		if id := stream.StreamID(); id != 1 {
-			t.Fatalf("bad: %v", id)
+		if id := stream.StreamID(); id != expectID {
+			errCh <- fmt.Errorf("bad: %v", id)
+			return
 		}
 		if err := stream.Close(); err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
-	}()
+		errCh <- nil
+	}
 
-	doneCh := make(chan struct{})
-	go func() {
-		wg.Wait()
-		close(doneCh)
-	}()
+	go acceptOne(server.AcceptStream, 1)
+	go acceptOne(client.AcceptStream, 2)
+	go acceptOne(server.OpenStream, 2)
+	go acceptOne(client.OpenStream, 1)
 
-	select {
-	case <-doneCh:
-	case <-time.After(time.Second):
-		panic("timeout")
-	}
+	drainErrorsUntil(t, errCh, 4, time.Second, "timeout")
 }
 
 func TestOpenStreamTimeout(t *testing.T) {
 	const timeout = 25 * time.Millisecond
 
-	cfg := testConf()
-	cfg.StreamOpenTimeout = timeout
+	serverConf := testConf()
+	serverConf.StreamOpenTimeout = timeout
 
-	client, server := testClientServerConfig(cfg)
-	defer client.Close()
-	defer server.Close()
+	clientConf := serverConf.Clone()
+	clientLogs := captureLogs(clientConf)
 
-	clientLogs := captureLogs(client)
+	client, _ := testClientServerConfig(t, serverConf, clientConf)
 
 	// Open a single stream without a server to acknowledge it.
 	s, err := client.OpenStream()
@@ -300,7 +273,12 @@ func TestOpenStreamTimeout(t *testing.T) {
 	if !clientLogs.match([]string{"[ERR] yamux: aborted stream open (destination=yamux:remote): i/o deadline reached"}) {
 		t.Fatalf("client log incorrect: %v", clientLogs.logs())
 	}
-	if s.state != streamClosed {
+
+	s.stateLock.Lock()
+	state := s.state
+	s.stateLock.Unlock()
+
+	if state != streamClosed {
 		t.Fatalf("stream should have been closed")
 	}
 	if !client.IsClosed() {
@@ -311,9 +289,7 @@ func TestOpenStreamTimeout(t *testing.T) {
 func TestClose_closeTimeout(t *testing.T) {
 	conf := testConf()
 	conf.StreamCloseTimeout = 10 * time.Millisecond
-	client, server := testClientServerConfig(conf)
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServerConfig(t, conf, conf.Clone())
 
 	if client.NumStreams() != 0 {
 		t.Fatalf("bad")
@@ -322,44 +298,32 @@ func TestClose_closeTimeout(t *testing.T) {
 		t.Fatalf("bad")
 	}
 
-	wg := &sync.WaitGroup{}
-	wg.Add(2)
+	errCh := make(chan error, 2)
 
 	// Open a stream on the client but only close it on the server.
 	// We want to see if the stream ever gets cleaned up on the client.
 	var clientStream *Stream
 	go func() {
-		defer wg.Done()
 		var err error
 		clientStream, err = client.OpenStream()
-		if err != nil {
-			t.Fatalf("err: %v", err)
-		}
+		errCh <- err
 	}()
 
 	go func() {
-		defer wg.Done()
 		stream, err := server.AcceptStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		if err := stream.Close(); err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
+		errCh <- nil
 	}()
 
-	doneCh := make(chan struct{})
-	go func() {
-		wg.Wait()
-		close(doneCh)
-	}()
-
-	select {
-	case <-doneCh:
-	case <-time.After(time.Second):
-		panic("timeout")
-	}
+	drainErrorsUntil(t, errCh, 2, time.Second, "timeout")
 
 	// We should have zero streams after our timeout period
 	time.Sleep(100 * time.Millisecond)
@@ -379,107 +343,110 @@ func TestClose_closeTimeout(t *testing.T) {
 }
 
 func TestNonNilInterface(t *testing.T) {
-	_, server := testClientServer()
+	_, server := testClientServer(t)
 	server.Close()
 
 	conn, err := server.Accept()
+	if err == nil || !errors.Is(err, ErrSessionShutdown) || conn != nil {
+		t.Fatal("bad: accept should return a shutdown error and a connection of nil value")
+	}
 	if err != nil && conn != nil {
 		t.Error("bad: accept should return a connection of nil value")
 	}
 
 	conn, err = server.Open()
-	if err != nil && conn != nil {
-		t.Error("bad: open should return a connection of nil value")
+	if err == nil || !errors.Is(err, ErrSessionShutdown) || conn != nil {
+		t.Fatal("bad: open should return a shutdown error and a connection of nil value")
 	}
 }
 
 func TestSendData_Small(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
-	wg := &sync.WaitGroup{}
-	wg.Add(2)
+	errCh := make(chan error, 2)
 
 	go func() {
-		defer wg.Done()
 		stream, err := server.AcceptStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 
 		if server.NumStreams() != 1 {
-			t.Fatalf("bad")
+			errCh <- fmt.Errorf("bad")
+			return
 		}
 
 		buf := make([]byte, 4)
 		for i := 0; i < 1000; i++ {
 			n, err := stream.Read(buf)
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if n != 4 {
-				t.Fatalf("short read: %d", n)
+				errCh <- fmt.Errorf("short read: %d", n)
+				return
 			}
 			if string(buf) != "test" {
-				t.Fatalf("bad: %s", buf)
+				errCh <- fmt.Errorf("bad: %s", buf)
+				return
 			}
 		}
 
 		if err := stream.Close(); err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
+		errCh <- nil
 	}()
 
 	go func() {
-		defer wg.Done()
 		stream, err := client.Open()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 
 		if client.NumStreams() != 1 {
-			t.Fatalf("bad")
+			errCh <- fmt.Errorf("bad")
+			return
 		}
 
 		for i := 0; i < 1000; i++ {
 			n, err := stream.Write([]byte("test"))
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if n != 4 {
-				t.Fatalf("short write %d", n)
+				errCh <- fmt.Errorf("short write %d", n)
+				return
 			}
 		}
 
 		if err := stream.Close(); err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
+		errCh <- nil
 	}()
 
-	doneCh := make(chan struct{})
-	go func() {
-		wg.Wait()
-		close(doneCh)
-	}()
-	select {
-	case <-doneCh:
-		if client.NumStreams() != 0 {
-			t.Fatalf("bad")
-		}
-		if server.NumStreams() != 0 {
-			t.Fatalf("bad")
-		}
-		return
-	case <-time.After(time.Second):
-		panic("timeout")
+	drainErrorsUntil(t, errCh, 2, time.Second, "timeout")
+
+	if client.NumStreams() != 0 {
+		t.Fatalf("bad")
+	}
+	if server.NumStreams() != 0 {
+		t.Fatalf("bad")
 	}
 }
 
 func TestSendData_Large(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	if testing.Short() {
+		t.Skip("skipping slow test that may time out on the race detector")
+	}
+	client, server := testClientServer(t)
 
 	const (
 		sendSize = 250 * 1024 * 1024
@@ -491,81 +458,81 @@ func TestSendData_Large(t *testing.T) {
 		data[idx] = byte(idx % 256)
 	}
 
-	wg := &sync.WaitGroup{}
-	wg.Add(2)
+	errCh := make(chan error, 2)
 
 	go func() {
-		defer wg.Done()
 		stream, err := server.AcceptStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 
 		var sz int
 		buf := make([]byte, recvSize)
 		for i := 0; i < sendSize/recvSize; i++ {
 			n, err := stream.Read(buf)
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if n != recvSize {
-				t.Fatalf("short read: %d", n)
+				errCh <- fmt.Errorf("short read: %d", n)
+				return
 			}
 			sz += n
 			for idx := range buf {
 				if buf[idx] != byte(idx%256) {
-					t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
+					errCh <- fmt.Errorf("bad: %v %v %v", i, idx, buf[idx])
+					return
 				}
 			}
 		}
 
 		if err := stream.Close(); err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 
 		t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
+		errCh <- nil
 	}()
 
 	go func() {
-		defer wg.Done()
 		stream, err := client.Open()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 
 		n, err := stream.Write(data)
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		if n != len(data) {
-			t.Fatalf("short write %d", n)
+			errCh <- fmt.Errorf("short write %d", n)
+			return
 		}
 
 		if err := stream.Close(); err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
+		errCh <- nil
 	}()
 
-	doneCh := make(chan struct{})
-	go func() {
-		wg.Wait()
-		close(doneCh)
-	}()
-	select {
-	case <-doneCh:
-		return
-	case <-time.After(5 * time.Second):
-		panic("timeout")
-	}
+
+	drainErrorsUntil(t, errCh, 2, 10*time.Second, "timeout")
 }
 
 func TestGoAway(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	if err := server.GoAway(); err != nil {
 		t.Fatalf("err: %v", err)
 	}
 
+	// Give the other side time to process the goaway after receiving it.
+	time.Sleep(100 * time.Millisecond)
+
 	_, err := client.Open()
 	if err != ErrRemoteGoAway {
 		t.Fatalf("err: %v", err)
@@ -573,17 +540,17 @@ func TestGoAway(t *testing.T) {
 }
 
 func TestManyStreams(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
+
+	const streams = 50
 
-	wg := &sync.WaitGroup{}
+	errCh := make(chan error, 2*streams)
 
-	acceptor := func(i int) {
-		defer wg.Done()
+	acceptor := func() {
 		stream, err := server.AcceptStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
 
@@ -591,60 +558,65 @@ func TestManyStreams(t *testing.T) {
 		for {
 			n, err := stream.Read(buf)
 			if err == io.EOF {
+				errCh <- nil
 				return
 			}
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if n == 0 {
-				t.Fatalf("err: %v", err)
+				errCh <- fmt.Errorf("no bytes read")
+				return
 			}
 		}
 	}
 
-	sender := func(i int) {
-		defer wg.Done()
+	sender := func(id int) {
 		stream, err := client.Open()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
 
-		msg := fmt.Sprintf("%08d", i)
+		msg := fmt.Sprintf("%08d", id)
 		for i := 0; i < 1000; i++ {
 			n, err := stream.Write([]byte(msg))
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if n != len(msg) {
-				t.Fatalf("short write %d", n)
+				errCh <- fmt.Errorf("short write %d", n)
+				return
 			}
 		}
+		errCh <- nil
 	}
 
-	for i := 0; i < 50; i++ {
-		wg.Add(2)
-		go acceptor(i)
+	for i := 0; i < streams; i++ {
+		go acceptor()
 		go sender(i)
 	}
 
-	wg.Wait()
+	drainErrorsUntil(t, errCh, 2*streams, 0, "")
 }
 
 func TestManyStreams_PingPong(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
+
+	const streams = 50
 
-	wg := &sync.WaitGroup{}
+	errCh := make(chan error, 2*streams)
 
 	ping := []byte("ping")
 	pong := []byte("pong")
 
-	acceptor := func(i int) {
-		defer wg.Done()
+	acceptor := func() {
 		stream, err := server.AcceptStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
 
@@ -653,16 +625,20 @@ func TestManyStreams_PingPong(t *testing.T) {
 			// Read the 'ping'
 			n, err := stream.Read(buf)
 			if err == io.EOF {
+				errCh <- nil
 				return
 			}
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if n != 4 {
-				t.Fatalf("err: %v", err)
+				errCh <- fmt.Errorf("short read %d", n)
+				return
 			}
 			if !bytes.Equal(buf, ping) {
-				t.Fatalf("bad: %s", buf)
+				errCh <- fmt.Errorf("bad: %s", buf)
+				return
 			}
 
 			// Shrink the internal buffer!
@@ -671,18 +647,20 @@ func TestManyStreams_PingPong(t *testing.T) {
 			// Write out the 'pong'
 			n, err = stream.Write(pong)
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if n != 4 {
-				t.Fatalf("err: %v", err)
+				errCh <- fmt.Errorf("short write %d", n)
+				return
 			}
 		}
 	}
 
-	sender := func(i int) {
-		defer wg.Done()
+	sender := func() {
 		stream, err := client.OpenStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
 
@@ -691,42 +669,45 @@ func TestManyStreams_PingPong(t *testing.T) {
 			// Send the 'ping'
 			n, err := stream.Write(ping)
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if n != 4 {
-				t.Fatalf("short write %d", n)
+				errCh <- fmt.Errorf("short write %d", n)
+				return
 			}
 
 			// Read the 'pong'
 			n, err = stream.Read(buf)
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if n != 4 {
-				t.Fatalf("err: %v", err)
+				errCh <- fmt.Errorf("short read %d", n)
+				return
 			}
 			if !bytes.Equal(buf, pong) {
-				t.Fatalf("bad: %s", buf)
+				errCh <- fmt.Errorf("bad: %s", buf)
+				return
 			}
 
 			// Shrink the buffer
 			stream.Shrink()
 		}
+		errCh <- nil
 	}
 
-	for i := 0; i < 50; i++ {
-		wg.Add(2)
-		go acceptor(i)
-		go sender(i)
+	for i := 0; i < streams; i++ {
+		go acceptor()
+		go sender()
 	}
 
-	wg.Wait()
+	drainErrorsUntil(t, errCh, 2*streams, 0, "")
 }
 
 func TestHalfClose(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	stream, err := client.Open()
 	if err != nil {
@@ -777,9 +758,7 @@ func TestHalfClose(t *testing.T) {
 }
 
 func TestHalfCloseSessionShutdown(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	// dataSize must be large enough to ensure the server will send a window
 	// update
@@ -833,9 +812,7 @@ func TestHalfCloseSessionShutdown(t *testing.T) {
 }
 
 func TestReadDeadline(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	stream, err := client.Open()
 	if err != nil {
@@ -878,9 +855,7 @@ func TestReadDeadline(t *testing.T) {
 }
 
 func TestReadDeadline_BlockedRead(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	stream, err := client.Open()
 	if err != nil {
@@ -922,9 +897,7 @@ func TestReadDeadline_BlockedRead(t *testing.T) {
 }
 
 func TestWriteDeadline(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	stream, err := client.Open()
 	if err != nil {
@@ -955,9 +928,7 @@ func TestWriteDeadline(t *testing.T) {
 }
 
 func TestWriteDeadline_BlockedWrite(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	stream, err := client.Open()
 	if err != nil {
@@ -1008,9 +979,7 @@ func TestWriteDeadline_BlockedWrite(t *testing.T) {
 }
 
 func TestBacklogExceeded(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	// Fill the backlog
 	max := client.config.AcceptBacklog
@@ -1050,9 +1019,7 @@ func TestBacklogExceeded(t *testing.T) {
 }
 
 func TestKeepAlive(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	time.Sleep(200 * time.Millisecond)
 
@@ -1076,15 +1043,21 @@ func TestKeepAlive_Timeout(t *testing.T) {
 	clientConf := testConf()
 	clientConf.ConnectionWriteTimeout = time.Hour // We're testing keep alives, not connection writes
 	clientConf.EnableKeepAlive = false            // Just test one direction, so it's deterministic who hangs up on whom
-	client, _ := Client(conn1, clientConf)
+	_ = captureLogs(clientConf) // Client logs aren't part of the test
+	client, err := Client(conn1, clientConf)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
 	defer client.Close()
 
-	server, _ := Server(conn2, testConf())
+	serverConf := testConf()
+	serverLogs := captureLogs(serverConf)
+	server, err := Server(conn2, serverConf)
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
 	defer server.Close()
 
-	_ = captureLogs(client) // Client logs aren't part of the test
-	serverLogs := captureLogs(server)
-
 	errCh := make(chan error, 1)
 	go func() {
 		_, err := server.Accept() // Wait until server closes
@@ -1119,9 +1092,7 @@
 func TestLargeWindow(t *testing.T) {
 	conf := DefaultConfig()
 	conf.MaxStreamWindowSize *= 2
 
-	client, server := testClientServerConfig(conf)
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServerConfig(t, conf, conf.Clone())
 
 	stream, err := client.Open()
 	if err != nil {
@@ -1135,7 +1106,10 @@ func TestLargeWindow(t *testing.T) {
 	}
 	defer stream2.Close()
 
-	stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
+	err = stream.SetWriteDeadline(time.Now().Add(10 * time.Millisecond))
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
 	buf := make([]byte, conf.MaxStreamWindowSize)
 	n, err := stream.Write(buf)
 	if err != nil {
@@ -1154,93 +1128,97 @@ func (u *UnlimitedReader) Read(p []byte) (int, error) {
 }
 
 func TestSendData_VeryLarge(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	if testing.Short() {
+		t.Skip("skipping slow test that may time out on the race detector")
+	}
+	client, server := testClientServer(t)
 
 	var n int64 = 1 * 1024 * 1024 * 1024
 	var workers int = 16
 
-	wg := &sync.WaitGroup{}
-	wg.Add(workers * 2)
+	errCh := make(chan error, workers*2)
 
 	for i := 0; i < workers; i++ {
 		go func() {
-			defer wg.Done()
 			stream, err := server.AcceptStream()
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			defer stream.Close()
 
 			buf := make([]byte, 4)
 			_, err = stream.Read(buf)
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if !bytes.Equal(buf, []byte{0, 1, 2, 3}) {
-				t.Fatalf("bad header")
+				errCh <- errors.New("bad header")
+				return
 			}
 
-			recv, err := io.Copy(ioutil.Discard, stream)
+			recv, err := io.Copy(io.Discard, stream)
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if recv != n {
-				t.Fatalf("bad: %v", recv)
+				errCh <- fmt.Errorf("bad: %v", recv)
+				return
 			}
+
+			errCh <- nil
 		}()
 	}
 	for i := 0; i < workers; i++ {
 		go func() {
-			defer wg.Done()
 			stream, err := client.Open()
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			defer stream.Close()
 
 			_, err = stream.Write([]byte{0, 1, 2, 3})
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 
 			unlimited := &UnlimitedReader{}
 			sent, err := io.Copy(stream, io.LimitReader(unlimited, n))
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			if sent != n {
-				t.Fatalf("bad: %v", sent)
+				errCh <- fmt.Errorf("bad: %v", sent)
+				return
 			}
+
+			errCh <- nil
 		}()
 	}
 
-	doneCh := make(chan struct{})
-	go func() {
-		wg.Wait()
-		close(doneCh)
-	}()
-	select {
-	case <-doneCh:
-	case <-time.After(20 * time.Second):
-		panic("timeout")
-	}
+	drainErrorsUntil(t, errCh, workers*2, 120*time.Second, "timeout")
 }
 
 func TestBacklogExceeded_Accept(t *testing.T) {
-	client, server := testClientServer()
-	defer client.Close()
-	defer server.Close()
+	client, server := testClientServer(t)
 
 	max := 5 * client.config.AcceptBacklog
+
+	errCh := make(chan error, max)
 	go func() {
 		for i := 0; i < max; i++ {
 			stream, err := server.Accept()
 			if err != nil {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 			defer stream.Close()
+			errCh <- nil
 		}
 	}()
 
@@ -1256,47 +1234,49 @@ func TestBacklogExceeded_Accept(t *testing.T) {
 			t.Fatalf("err: %v", err)
 		}
 	}
+
+	drainErrorsUntil(t, errCh, max, 0, "")
 }
 
 func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
-	client, server := testClientServerConfig(testConfNoKeepAlive())
-	defer client.Close()
-	defer server.Close()
+	conf := testConfNoKeepAlive()
 
-	var wg sync.WaitGroup
-	wg.Add(2)
+	client, server := testClientServerConfig(t, conf, conf.Clone())
 
 	// Choose a huge flood size that we know will result in a window update.
 	flood := int64(client.config.MaxStreamWindowSize) - 1
 
+	errCh := make(chan error, 2)
+
 	// The server will accept a new stream and then flood data to it.
 	go func() {
-		defer wg.Done()
-
 		stream, err := server.AcceptStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
 
 		n, err := stream.Write(make([]byte, flood))
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		if int64(n) != flood {
-			t.Fatalf("short write: %d", n)
+			errCh <- fmt.Errorf("short write: %d", n)
 		}
+
+		errCh <- nil
 	}()
 
 	// The client will open a stream, block outbound writes, and then
 	// listen to the flood from the server, which should time out since
 	// it won't be able to send the window update.
 	go func() {
-		defer wg.Done()
-
 		stream, err := client.OpenStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
 
@@ -1306,20 +1286,22 @@ func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
 
 		_, err = stream.Read(make([]byte, flood))
 		if err != ErrConnectionWriteTimeout {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
+
+		errCh <- nil
 	}()
 
-	wg.Wait()
+	drainErrorsUntil(t, errCh, 2, 0, "")
 }
 
 func TestSession_PartialReadWindowUpdate(t *testing.T) {
-	client, server := testClientServerConfig(testConfNoKeepAlive())
-	defer client.Close()
-	defer server.Close()
+	conf := testConfNoKeepAlive()
+
+	client, server := testClientServerConfig(t, conf, conf.Clone())
 
-	var wg sync.WaitGroup
-	wg.Add(1)
+	errCh := make(chan error, 1)
 
 	// Choose a huge flood size that we know will result in a window update.
 	flood := int64(client.config.MaxStreamWindowSize)
 	var wr *Stream
@@ -1327,29 +1309,35 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) {
 	// The server will accept a new stream and then flood data to it.
 	go func() {
-		defer wg.Done()
-
 		var err error
 		wr, err = server.AcceptStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer wr.Close()
 
-		if wr.sendWindow != client.config.MaxStreamWindowSize {
-			t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
+		window := atomic.LoadUint32(&wr.sendWindow)
+		if window != client.config.MaxStreamWindowSize {
+			errCh <- fmt.Errorf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, window)
+			return
 		}
 
 		n, err := wr.Write(make([]byte, flood))
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		if int64(n) != flood {
-			t.Fatalf("short write: %d", n)
+			errCh <- fmt.Errorf("short write: %d", n)
+			return
 		}
-		if wr.sendWindow != 0 {
-			t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
+		window = atomic.LoadUint32(&wr.sendWindow)
+		if window != 0 {
+			errCh <- fmt.Errorf("sendWindow: exp=%d, got=%d", 0, window)
+			return
 		}
+		errCh <- err
 	}()
 
 	stream, err := client.OpenStream()
@@ -1358,41 +1346,43 @@ func TestSession_PartialReadWindowUpdate(t *testing.T) {
 	}
 	defer stream.Close()
 
-	wg.Wait()
+	drainErrorsUntil(t, errCh, 1, 0, "")
 
 	_, err = stream.Read(make([]byte, flood/2+1))
+	if err != nil {
+		t.Fatalf("err: %v", err)
+	}
 
-	if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
-		t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
+	window := atomic.LoadUint32(&wr.sendWindow)
+	if exp := uint32(flood/2 + 1); window != exp {
+		t.Fatalf("sendWindow: exp=%d, got=%d", exp, window)
 	}
 }
 
 func TestSession_sendNoWait_Timeout(t *testing.T) {
-	client, server := testClientServerConfig(testConfNoKeepAlive())
-	defer client.Close()
-	defer server.Close()
+	conf := testConfNoKeepAlive()
 
-	var wg sync.WaitGroup
-	wg.Add(2)
+	client, server := testClientServerConfig(t, conf, conf.Clone())
 
-	go func() {
-		defer wg.Done()
+	errCh := make(chan error, 2)
 
+	go func() {
 		stream, err := server.AcceptStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
+		errCh <- nil
 	}()
 
 	// The client will open the stream and then block outbound writes, we'll
 	// probe sendNoWait once it gets into that state.
 	go func() {
-		defer wg.Done()
-
 		stream, err := client.OpenStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
 
@@ -1409,21 +1399,22 @@ func TestSession_sendNoWait_Timeout(t *testing.T) {
 			} else if err == ErrConnectionWriteTimeout {
 				break
 			} else {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 		}
+		errCh <- nil
 	}()
 
-	wg.Wait()
+	drainErrorsUntil(t, errCh, 2, 0, "")
 }
 
 func TestSession_PingOfDeath(t *testing.T) {
-	client, server := testClientServerConfig(testConfNoKeepAlive())
-	defer client.Close()
-	defer server.Close()
+	conf := testConfNoKeepAlive()
+
+	client, server := testClientServerConfig(t, conf, conf.Clone())
 
-	var wg sync.WaitGroup
-	wg.Add(2)
+	errCh := make(chan error, 2)
 
 	var doPingOfDeath sync.Mutex
 	doPingOfDeath.Lock()
@@ -1434,11 +1425,10 @@ func TestSession_PingOfDeath(t *testing.T) {
 	// The server will accept a stream, block outbound writes, and then
 	// flood its send channel so that no more headers can be queued.
 	go func() {
-		defer wg.Done()
-
 		stream, err := server.AcceptStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
 
@@ -1452,68 +1442,70 @@ func TestSession_PingOfDeath(t *testing.T) {
 			} else if err == ErrConnectionWriteTimeout {
 				break
 			} else {
-				t.Fatalf("err: %v", err)
+				errCh <- err
+				return
 			}
 		}
 		doPingOfDeath.Unlock()
+		errCh <- nil
 	}()
 
 	// The client will open a stream and then send the server a ping once it
 	// can no longer write. This makes sure the server doesn't deadlock reads
 	// while trying to reply to the ping with no ability to write.
 	go func() {
-		defer wg.Done()
-
 		stream, err := client.OpenStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
 
 		// This ping will never unblock because the ping id will never
 		// show up in a response.
 		doPingOfDeath.Lock()
-		go func() { client.Ping() }()
+		go func() { _, _ = client.Ping() }()
 
 		// Wait for a while to make sure the previous ping times out,
 		// then turn writes back on and make sure a ping works again.
 		time.Sleep(2 * server.config.ConnectionWriteTimeout)
 		conn.writeBlocker.Unlock()
 		if _, err = client.Ping(); err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
+
+		errCh <- nil
 	}()
 
-	wg.Wait()
+	drainErrorsUntil(t, errCh, 2, 0, "")
 }
 
 func TestSession_ConnectionWriteTimeout(t *testing.T) {
-	client, server := testClientServerConfig(testConfNoKeepAlive())
-	defer client.Close()
-	defer server.Close()
+	conf := testConfNoKeepAlive()
 
-	var wg sync.WaitGroup
-	wg.Add(2)
+	client, server := testClientServerConfig(t, conf, conf.Clone())
 
-	go func() {
-		defer wg.Done()
+	errCh := make(chan error, 2)
 
+	go func() {
 		stream, err := server.AcceptStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
+		errCh <- nil
 	}()
 
 	// The client will open the stream and then block outbound writes, we'll
 	// tee up a write and make sure it eventually times out.
 	go func() {
-		defer wg.Done()
-
 		stream, err := client.OpenStream()
 		if err != nil {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		defer stream.Close()
 
@@ -1526,39 +1518,66 @@ func TestSession_ConnectionWriteTimeout(t *testing.T) {
 		// worked.
 		n, err := stream.Write([]byte("hello"))
 		if err != ErrConnectionWriteTimeout {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		if n != 0 {
-			t.Fatalf("lied about writes: %d", n)
+			errCh <- fmt.Errorf("lied about writes: %d", n)
 		}
+		errCh <- nil
 	}()
 
-	wg.Wait()
+	drainErrorsUntil(t, errCh, 2, 0, "")
 }
 
 func TestCancelAccept(t *testing.T) {
-	_, server := testClientServer()
-	defer server.Close()
+	_, server := testClientServer(t)
 
 	ctx, cancel := context.WithCancel(context.Background())
+	t.Cleanup(cancel)
 
-	var wg sync.WaitGroup
+	errCh := make(chan error, 1)
 
-	wg.Add(1)
 	go func() {
-		defer wg.Done()
-
 		stream, err := server.AcceptStreamWithContext(ctx)
 		if err != context.Canceled {
-			t.Fatalf("err: %v", err)
+			errCh <- err
+			return
 		}
 		if stream != nil {
 			defer stream.Close()
 		}
+		errCh <- nil
 	}()
 
 	cancel()
 
-	wg.Wait()
+	drainErrorsUntil(t, errCh, 1, 0, "")
+}
+
+// drainErrorsUntil receives exactly expect results from errCh, failing the
+// test on the first non-nil error. If timeout is non-zero and elapses before
+// all results arrive, the test fails with msg.
+func drainErrorsUntil(t testing.TB, errCh chan error, expect int, timeout time.Duration, msg string) {
+	t.Helper()
+	start := time.Now()
+	var timerC <-chan time.Time
+	if timeout > 0 {
+		timerC = time.After(timeout)
+	}
+
+	for found := 0; found < expect; {
+		select {
+		case <-timerC:
+			t.Fatalf(msg+" (timeout was %v)", timeout)
+		case err := <-errCh:
+			if err != nil {
+				t.Fatalf("err: %v", err)
+			} else {
+				found++
+			}
+		}
+	}
+	t.Logf("drain took %v (timeout was %v)", time.Since(start), timeout)
+}
diff --git a/stream.go b/stream.go
index 23d08fc..365b718 100644
--- a/stream.go
+++ b/stream.go
@@ -138,7 +138,7 @@ WAIT:
 	var timer *time.Timer
 	readDeadline := s.readDeadline.Load().(time.Time)
 	if !readDeadline.IsZero() {
-		delay := readDeadline.Sub(time.Now())
+		delay := time.Until(readDeadline)
 		timer = time.NewTimer(delay)
 		timeout = timer.C
 	}
@@ -221,7 +221,7 @@ WAIT:
 	var timeout <-chan time.Time
 	writeDeadline := s.writeDeadline.Load().(time.Time)
 	if !writeDeadline.IsZero() {
-		delay := writeDeadline.Sub(time.Now())
+		delay := time.Until(writeDeadline)
 		timeout = time.After(delay)
 	}
 	select {
@@ -230,7 +230,6 @@ WAIT:
 	case <-timeout:
 		return 0, ErrTimeout
 	}
-	return 0, nil
 }
 
 // sendFlags determines any flags that are appropriate
@@ -380,7 +379,7 @@ func (s *Stream) closeTimeout() {
 	defer s.sendLock.Unlock()
 	hdr := header(make([]byte, headerSize))
 	hdr.encode(typeWindowUpdate, flagRST, s.id, 0)
-	s.session.sendNoWait(hdr)
+	_ = s.session.sendNoWait(hdr)
 }
 
 // forceClose is used for when the session is exiting
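
The common thread in this patch: t.Fatalf is only safe on the goroutine running the test, so worker goroutines now report into a buffered error channel that the test goroutine drains with the new drainErrorsUntil helper. A minimal sketch of the pattern under the patch above (TestWorkerPattern and doWork are hypothetical names introduced only for illustration; drainErrorsUntil is the helper added to session_test.go):

// doWork is a hypothetical stand-in for whatever a worker goroutine exercises.
func doWork(id int) error {
	if id < 0 {
		return fmt.Errorf("bad worker id: %d", id)
	}
	return nil
}

func TestWorkerPattern(t *testing.T) {
	const workers = 4
	// Buffer the channel so every worker can report and exit without blocking.
	errCh := make(chan error, workers)

	for i := 0; i < workers; i++ {
		go func(id int) {
			// Send the error (or nil on success) instead of calling
			// t.Fatalf, which must not run off the test goroutine.
			errCh <- doWork(id)
		}(i)
	}

	// Fail on the first non-nil error, or when the timeout elapses.
	drainErrorsUntil(t, errCh, workers, time.Second, "timeout waiting for workers")
}

Sizing the buffer to the exact number of expected sends matters: if drainErrorsUntil fails the test early, the remaining workers can still send and exit rather than leaking goroutines blocked on the channel.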