From c4bbd3511be52e0da4f6522a9f33bc4df5705b14 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 2 Sep 2015 23:05:59 +0100 Subject: [PATCH 1/3] session.go: AcceptStream handles GoAway --- session.go | 3 +++ session_test.go | 18 +++++++++++++++++- 2 files changed, 20 insertions(+), 1 deletion(-) diff --git a/session.go b/session.go index f914537..9ecb457 100644 --- a/session.go +++ b/session.go @@ -188,6 +188,9 @@ func (s *Session) Accept() (net.Conn, error) { // AcceptStream is used to block until the next available stream // is ready to be accepted. func (s *Session) AcceptStream() (*Stream, error) { + if atomic.LoadInt32(&s.remoteGoAway) == 1 { + return nil, ErrRemoteGoAway + } select { case stream := <-s.acceptCh: if err := stream.sendWindowUpdate(); err != nil { diff --git a/session_test.go b/session_test.go index 88d726e..bb55b63 100644 --- a/session_test.go +++ b/session_test.go @@ -346,7 +346,23 @@ func TestGoAway(t *testing.T) { t.Fatalf("err: %v", err) } } - +func TestGoAwayClient(t *testing.T) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + done := make(chan struct{}, 1) + go func(){ + if err := client.GoAway(); err != nil { + t.Fatalf("err: %v", err) + } + close(done) + }() + <- done + _, err := server.Accept() + if err != ErrRemoteGoAway { + t.Errorf("err: %v", err) + } +} func TestManyStreams(t *testing.T) { client, server := testClientServer() defer client.Close() From e065b5d0101e8bacc97f6aa0dd4fbfa6010da79e Mon Sep 17 00:00:00 2001 From: root Date: Thu, 3 Sep 2015 01:27:58 +0100 Subject: [PATCH 2/3] session.go: AcceptStream handles GoAway while is running --- session.go | 10 ++++++++++ session_test.go | 27 +++++++++++++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/session.go b/session.go index 9ecb457..30122b2 100644 --- a/session.go +++ b/session.go @@ -58,6 +58,9 @@ type Session struct { // acceptCh is used to pass ready streams to the client acceptCh chan *Stream + // goAwayCh is used to notify AcceptStream of GoAway requests + goAwayCh chan struct{} + // sendCh is used to mark a stream as ready to send, // or to send a header out directly. sendCh chan sendReady @@ -92,6 +95,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { streams: make(map[uint32]*Stream), synCh: make(chan struct{}, config.AcceptBacklog), acceptCh: make(chan *Stream, config.AcceptBacklog), + goAwayCh: make(chan struct{}, 1), sendCh: make(chan sendReady, 64), recvDoneCh: make(chan struct{}), shutdownCh: make(chan struct{}), @@ -199,6 +203,8 @@ func (s *Session) AcceptStream() (*Stream, error) { return stream, nil case <-s.shutdownCh: return nil, s.shutdownErr + case <- s.goAwayCh: + return nil, ErrRemoteGoAway } } @@ -419,6 +425,10 @@ func (s *Session) recvLoop() error { handler = s.handleStreamMessage case typeGoAway: handler = s.handleGoAway + select{ + case s.goAwayCh <- struct{}{}: + default: + } case typePing: handler = s.handlePing default: diff --git a/session_test.go b/session_test.go index bb55b63..d4f73d9 100644 --- a/session_test.go +++ b/session_test.go @@ -346,6 +346,7 @@ func TestGoAway(t *testing.T) { t.Fatalf("err: %v", err) } } + func TestGoAwayClient(t *testing.T) { client, server := testClientServer() defer client.Close() @@ -362,6 +363,32 @@ func TestGoAwayClient(t *testing.T) { if err != ErrRemoteGoAway { t.Errorf("err: %v", err) } + // Test GoAway while Accept is running. + client2, server2 := testClientServer() + defer client2.Close() + defer server2.Close() + done = make(chan struct{}, 1) + go func(){ + <- done + time.Sleep(500 *time.Millisecond) + if err := client2.GoAway(); err != nil { + t.Fatalf("err: %v", err) + } + }() + errCh := make(chan error, 1) + go func(){ + close(done) + _, err := server2.Accept() + errCh <- err + }() + select{ + case err = <- errCh: + if err != ErrRemoteGoAway { + t.Errorf("err: %v", err) + } + case <- time.After(2 *time.Second): + t.Errorf("Timeout awaiting ErrRemoteGoAway") + } } func TestManyStreams(t *testing.T) { client, server := testClientServer() From ef7d65eb4423ba3e90a64d6fa6cee025caa780ce Mon Sep 17 00:00:00 2001 From: root Date: Sat, 5 Sep 2015 18:54:20 +0100 Subject: [PATCH 3/3] session.go: fmt & create goAwayCh only on server --- session.go | 12 ++++++------ session_test.go | 18 +++++++++--------- 2 files changed, 15 insertions(+), 15 deletions(-) diff --git a/session.go b/session.go index 30122b2..ecb6985 100644 --- a/session.go +++ b/session.go @@ -95,7 +95,6 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { streams: make(map[uint32]*Stream), synCh: make(chan struct{}, config.AcceptBacklog), acceptCh: make(chan *Stream, config.AcceptBacklog), - goAwayCh: make(chan struct{}, 1), sendCh: make(chan sendReady, 64), recvDoneCh: make(chan struct{}), shutdownCh: make(chan struct{}), @@ -103,6 +102,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session { if client { s.nextStreamID = 1 } else { + s.goAwayCh = make(chan struct{}) s.nextStreamID = 2 } go s.recv() @@ -203,7 +203,7 @@ func (s *Session) AcceptStream() (*Stream, error) { return stream, nil case <-s.shutdownCh: return nil, s.shutdownErr - case <- s.goAwayCh: + case <-s.goAwayCh: return nil, ErrRemoteGoAway } } @@ -425,10 +425,6 @@ func (s *Session) recvLoop() error { handler = s.handleStreamMessage case typeGoAway: handler = s.handleGoAway - select{ - case s.goAwayCh <- struct{}{}: - default: - } case typePing: handler = s.handlePing default: @@ -529,6 +525,10 @@ func (s *Session) handleGoAway(hdr header) error { switch code { case goAwayNormal: atomic.SwapInt32(&s.remoteGoAway, 1) + select { + case s.goAwayCh <- struct{}{}: + default: + } case goAwayProtoErr: s.logger.Printf("[ERR] yamux: received protocol error go away") return fmt.Errorf("yamux protocol error") diff --git a/session_test.go b/session_test.go index d4f73d9..bafe2f4 100644 --- a/session_test.go +++ b/session_test.go @@ -352,13 +352,13 @@ func TestGoAwayClient(t *testing.T) { defer client.Close() defer server.Close() done := make(chan struct{}, 1) - go func(){ + go func() { if err := client.GoAway(); err != nil { t.Fatalf("err: %v", err) } close(done) }() - <- done + <-done _, err := server.Accept() if err != ErrRemoteGoAway { t.Errorf("err: %v", err) @@ -368,25 +368,25 @@ func TestGoAwayClient(t *testing.T) { defer client2.Close() defer server2.Close() done = make(chan struct{}, 1) - go func(){ - <- done - time.Sleep(500 *time.Millisecond) + go func() { + <-done + time.Sleep(500 * time.Millisecond) if err := client2.GoAway(); err != nil { t.Fatalf("err: %v", err) } }() errCh := make(chan error, 1) - go func(){ + go func() { close(done) _, err := server2.Accept() errCh <- err }() - select{ - case err = <- errCh: + select { + case err = <-errCh: if err != ErrRemoteGoAway { t.Errorf("err: %v", err) } - case <- time.After(2 *time.Second): + case <-time.After(2 * time.Second): t.Errorf("Timeout awaiting ErrRemoteGoAway") } }