diff --git a/bench_test.go b/bench_test.go index 4377b83..9d3b2a1 100644 --- a/bench_test.go +++ b/bench_test.go @@ -79,3 +79,46 @@ func BenchmarkSendRecv(b *testing.B) { } <-doneCh } + +func BenchmarkSendRecvLarge(b *testing.B) { + client, server := testClientServer() + defer client.Close() + defer server.Close() + + const sendSize = 100 * 1024 * 1024 + const recvSize = 4 * 1024 + + sendBuf := make([]byte, sendSize) + recvBuf := make([]byte, recvSize) + + b.ResetTimer() + recvDone := make(chan struct{}) + + go func() { + stream, err := server.AcceptStream() + if err != nil { + return + } + defer stream.Close() + for i := 0; i < b.N; i++ { + for j := 0; j < sendSize/recvSize; j++ { + if _, err := stream.Read(recvBuf); err != nil { + b.Fatalf("err: %v", err) + } + } + } + close(recvDone) + }() + + stream, err := client.Open() + if err != nil { + b.Fatalf("err: %v", err) + } + defer stream.Close() + for i := 0; i < b.N; i++ { + if _, err := stream.Write(sendBuf); err != nil { + b.Fatalf("err: %v", err) + } + } + <-recvDone +} diff --git a/const.go b/const.go index 4f52938..fb5bb21 100644 --- a/const.go +++ b/const.go @@ -76,6 +76,10 @@ const ( // GoAway is sent to terminate a session. The StreamID // should be 0 and the length is an error code. typeGoAway + + // typeMax defines the upper bound of valid message types + // and should always be the last constant. + typeMax ) const ( diff --git a/session.go b/session.go index e179818..bf19715 100644 --- a/session.go +++ b/session.go @@ -175,11 +175,7 @@ GET_ID: // Send the window update to create if err := stream.sendWindowUpdate(); err != nil { - select { - case <-s.synCh: - default: - s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore") - } + s.closeStream(id) return nil, err } return stream, nil @@ -323,8 +319,13 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error { // potential shutdown. Since there's the expectation that sends can happen // in a timely manner, we enforce the connection write timeout here. func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error { - timer := time.NewTimer(s.config.ConnectionWriteTimeout) - defer timer.Stop() + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.config.ConnectionWriteTimeout) + defer func() { + timer.Stop() + timerPool.Put(t) + }() ready := sendReady{Hdr: hdr, Body: body, Err: errCh} select { @@ -373,7 +374,7 @@ func (s *Session) send() { for sent < len(ready.Hdr) { n, err := s.conn.Write(ready.Hdr[sent:]) if err != nil { - s.logger.Printf("[ERR] yamux: Failed to write header: %v", err) + s.logger.Printf("[WARN] yamux: Failed to write header: %v", err) asyncSendErr(ready.Err, err) s.exitErr(err) return @@ -386,7 +387,7 @@ func (s *Session) send() { if ready.Body != nil { _, err := io.Copy(s.conn, ready.Body) if err != nil { - s.logger.Printf("[ERR] yamux: Failed to write body: %v", err) + s.logger.Printf("[WARN] yamux: Failed to write body: %v", err) asyncSendErr(ready.Err, err) s.exitErr(err) return @@ -408,11 +409,19 @@ func (s *Session) recv() { } } +var ( + handlers = []func(*Session, header) error{ + typeData: (*Session).handleStreamMessage, + typeWindowUpdate: (*Session).handleStreamMessage, + typePing: (*Session).handlePing, + typeGoAway: (*Session).handleGoAway, + } +) + // recvLoop continues to receive data until a fatal error is encountered func (s *Session) recvLoop() error { defer close(s.recvDoneCh) hdr := header(make([]byte, headerSize)) - var handler func(header) error for { // Read the header if _, err := io.ReadFull(s.bufRead, hdr); err != nil { @@ -428,22 +437,12 @@ func (s *Session) recvLoop() error { return ErrInvalidVersion } - // Switch on the type - switch hdr.MsgType() { - case typeData: - handler = s.handleStreamMessage - case typeWindowUpdate: - handler = s.handleStreamMessage - case typeGoAway: - handler = s.handleGoAway - case typePing: - handler = s.handlePing - default: + mt := hdr.MsgType() + if mt < typeData || mt >= typeMax { return ErrInvalidMsgType } - // Invoke the handler - if err := handler(hdr); err != nil { + if err := handlers[mt](s, hdr); err != nil { return err } } @@ -465,7 +464,7 @@ func (s *Session) handleStreamMessage(hdr header) error { stream := s.streams[id] s.streamLock.Unlock() - // If we do not have a stream, likely we sent a RST + // If we do not have a stream, likely we sent a RST or an error occurred sending a SYN if stream == nil { // Drain any data on the wire if hdr.MsgType() == typeData && hdr.Length() > 0 { @@ -595,6 +594,7 @@ func (s *Session) incomingStream(id uint32) error { func (s *Session) closeStream(id uint32) { s.streamLock.Lock() if _, ok := s.inflight[id]; ok { + delete(s.inflight, id) select { case <-s.synCh: default: diff --git a/session_test.go b/session_test.go index 0b4200e..17304a3 100644 --- a/session_test.go +++ b/session_test.go @@ -376,7 +376,12 @@ func TestSendData_Large(t *testing.T) { defer client.Close() defer server.Close() - data := make([]byte, 512*1024) + const ( + sendSize = 250 * 1024 * 1024 + recvSize = 4 * 1024 + ) + + data := make([]byte, sendSize) for idx := range data { data[idx] = byte(idx % 256) } @@ -390,16 +395,17 @@ func TestSendData_Large(t *testing.T) { if err != nil { t.Fatalf("err: %v", err) } - - buf := make([]byte, 4*1024) - for i := 0; i < 128; i++ { + var sz int + buf := make([]byte, recvSize) + for i := 0; i < sendSize/recvSize; i++ { n, err := stream.Read(buf) if err != nil { t.Fatalf("err: %v", err) } - if n != 4*1024 { + if n != recvSize { t.Fatalf("short read: %d", n) } + sz += n for idx := range buf { if buf[idx] != byte(idx%256) { t.Fatalf("bad: %v %v %v", i, idx, buf[idx]) @@ -410,6 +416,8 @@ func TestSendData_Large(t *testing.T) { if err := stream.Close(); err != nil { t.Fatalf("err: %v", err) } + + t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz) }() go func() { @@ -439,7 +447,7 @@ func TestSendData_Large(t *testing.T) { }() select { case <-doneCh: - case <-time.After(time.Second): + case <-time.After(5 * time.Second): panic("timeout") } } @@ -972,6 +980,48 @@ func TestBacklogExceeded_Accept(t *testing.T) { } } +func TestSessionOpenStream_WindowUpdateSYNTimeout(t *testing.T) { + client, server := testClientServerConfig(testConfNoKeepAlive()) + defer client.Close() + defer server.Close() + + // Prevent the client from initially writing SYN + clientConn := client.conn.(*pipeConn) + clientConn.writeBlocker.Lock() + + var wg sync.WaitGroup + wg.Add(1) + + // server + go func() { + defer wg.Done() + + stream, err := server.Accept() + if err != nil { + t.Fatalf("err: %v", err) + } + stream.Close() + }() + + stream, err := client.OpenStream() + if err == nil { + t.Fatal("expected err: connection write timeout") + } + + // release lock + clientConn.writeBlocker.Unlock() + + if stream != nil { + t.Fatal("expected stream to be nil") + } + + wg.Wait() + + if exp, got := 0, len(client.streams); got != exp { + t.Errorf("invalid streams length; exp=%d, got=%d", exp, got) + } +} + func TestSession_WindowUpdateWriteDuringRead(t *testing.T) { client, server := testClientServerConfig(testConfNoKeepAlive()) defer client.Close() @@ -1026,6 +1076,60 @@ func TestSession_WindowUpdateWriteDuringRead(t *testing.T) { wg.Wait() } +func TestSession_PartialReadWindowUpdate(t *testing.T) { + client, server := testClientServerConfig(testConfNoKeepAlive()) + defer client.Close() + defer server.Close() + + var wg sync.WaitGroup + wg.Add(1) + + // Choose a huge flood size that we know will result in a window update. + flood := int64(client.config.MaxStreamWindowSize) + var wr *Stream + + // The server will accept a new stream and then flood data to it. + go func() { + defer wg.Done() + + var err error + wr, err = server.AcceptStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer wr.Close() + + if wr.sendWindow != client.config.MaxStreamWindowSize { + t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow) + } + + n, err := wr.Write(make([]byte, flood)) + if err != nil { + t.Fatalf("err: %v", err) + } + if int64(n) != flood { + t.Fatalf("short write: %d", n) + } + if wr.sendWindow != 0 { + t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow) + } + }() + + stream, err := client.OpenStream() + if err != nil { + t.Fatalf("err: %v", err) + } + defer stream.Close() + + wg.Wait() + + _, err = stream.Read(make([]byte, flood/2+1)) + + if exp := uint32(flood/2 + 1); wr.sendWindow != exp { + t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow) + } +} + func TestSession_sendNoWait_Timeout(t *testing.T) { client, server := testClientServerConfig(testConfNoKeepAlive()) defer client.Close() diff --git a/stream.go b/stream.go index d216e28..c3255a6 100644 --- a/stream.go +++ b/stream.go @@ -238,18 +238,25 @@ func (s *Stream) sendWindowUpdate() error { // Determine the delta update max := s.session.config.MaxStreamWindowSize - delta := max - atomic.LoadUint32(&s.recvWindow) + var bufLen uint32 + s.recvLock.Lock() + if s.recvBuf != nil { + bufLen = uint32(s.recvBuf.Len()) + } + delta := (max - bufLen) - s.recvWindow // Determine the flags if any flags := s.sendFlags() // Check if we can omit the update if delta < (max/2) && flags == 0 { + s.recvLock.Unlock() return nil } // Update our window - atomic.AddUint32(&s.recvWindow, delta) + s.recvWindow += delta + s.recvLock.Unlock() // Send the header s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) @@ -392,16 +399,18 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { if length == 0 { return nil } - if remain := atomic.LoadUint32(&s.recvWindow); length > remain { - s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length) - return ErrRecvWindowExceeded - } // Wrap in a limited reader conn = &io.LimitedReader{R: conn, N: int64(length)} // Copy into buffer s.recvLock.Lock() + + if length > s.recvWindow { + s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length) + return ErrRecvWindowExceeded + } + if s.recvBuf == nil { // Allocate the receive buffer just-in-time to fit the full data frame. // This way we can read in the whole packet without further allocations. @@ -414,7 +423,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { } // Decrement the receive window - atomic.AddUint32(&s.recvWindow, ^uint32(length-1)) + s.recvWindow += ^uint32(length - 1) s.recvLock.Unlock() // Unblock any readers diff --git a/util.go b/util.go index 5fe45af..8a73e92 100644 --- a/util.go +++ b/util.go @@ -1,5 +1,20 @@ package yamux +import ( + "sync" + "time" +) + +var ( + timerPool = &sync.Pool{ + New: func() interface{} { + timer := time.NewTimer(time.Hour * 1e6) + timer.Stop() + return timer + }, + } +) + // asyncSendErr is used to try an async send of an error func asyncSendErr(ch chan error, err error) { if ch == nil {