diff --git a/session.go b/session.go index f914537..ecb6985 100644 --- a/session.go +++ b/session.go @@ -58,6 +58,9 @@ type Session struct { // acceptCh is used to pass ready streams to the client acceptCh chan *Stream + // goAwayCh is used to notify AcceptStream of GoAway requests + goAwayCh chan struct{} + // sendCh is used to mark a stream as ready to send, // or to send a header out directly. sendCh chan sendReady @@ -99,6 +102,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { if client { s.nextStreamID = 1 } else { + s.goAwayCh = make(chan struct{}) s.nextStreamID = 2 } go s.recv() @@ -188,6 +192,9 @@ func (s *Session) Accept() (net.Conn, error) { // AcceptStream is used to block until the next available stream // is ready to be accepted. func (s *Session) AcceptStream() (*Stream, error) { + if atomic.LoadInt32(&s.remoteGoAway) == 1 { + return nil, ErrRemoteGoAway + } select { case stream := <-s.acceptCh: if err := stream.sendWindowUpdate(); err != nil { @@ -196,6 +203,8 @@ func (s *Session) AcceptStream() (*Stream, error) { return stream, nil case <-s.shutdownCh: return nil, s.shutdownErr + case <-s.goAwayCh: + return nil, ErrRemoteGoAway } } @@ -516,6 +525,10 @@ func (s *Session) handleGoAway(hdr header) error { switch code { case goAwayNormal: atomic.SwapInt32(&s.remoteGoAway, 1) + select { + case s.goAwayCh <- struct{}{}: + default: + } case goAwayProtoErr: s.logger.Printf("[ERR] yamux: received protocol error go away") return fmt.Errorf("yamux protocol error") diff --git a/session_test.go b/session_test.go index 88d726e..bafe2f4 100644 --- a/session_test.go +++ b/session_test.go @@ -347,6 +347,49 @@ func TestGoAway(t *testing.T) { } } +func TestGoAwayClient(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + done := make(chan struct{}, 1) + go func() { + if err := client.GoAway(); err != nil { + t.Fatalf("err: %v", err) + } + close(done) + }() + <-done + _, err := server.Accept() + if err != ErrRemoteGoAway { + t.Errorf("err: %v", err) + } + // Test GoAway while Accept is running. + client2, server2 := testClientServer() + defer client2.Close() + defer server2.Close() + done = make(chan struct{}, 1) + go func() { + <-done + time.Sleep(500 * time.Millisecond) + if err := client2.GoAway(); err != nil { + t.Fatalf("err: %v", err) + } + }() + errCh := make(chan error, 1) + go func() { + close(done) + _, err := server2.Accept() + errCh <- err + }() + select { + case err = <-errCh: + if err != ErrRemoteGoAway { + t.Errorf("err: %v", err) + } + case <-time.After(2 * time.Second): + t.Errorf("Timeout awaiting ErrRemoteGoAway") + } +} func TestManyStreams(t *testing.T) { client, server := testClientServer() defer client.Close()