From 3fc4056cc22e469dc43fd25146925a07c3d9aa74 Mon Sep 17 00:00:00 2001 From: Stuart Carnie Date: Tue, 12 Sep 2017 14:13:24 -0700 Subject: [PATCH] improve memory utilization in receive buffer, fix flow control * flow control was assuming a `Read` consumed the entire buffer * flow control fix reduces memory utilization when receiving large streams of data * use timer pool to reduce allocations * use static handler function pointer to avoid closure allocations for every frame --- session.go | 35 +++++++++++++++++++---------------- stream.go | 23 ++++++++++++++++------- util.go | 15 +++++++++++++++ 3 files changed, 50 insertions(+), 23 deletions(-) diff --git a/session.go b/session.go index e179818..21025fc 100644 --- a/session.go +++ b/session.go @@ -323,8 +323,13 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error { // potential shutdown. Since there's the expectation that sends can happen // in a timely manner, we enforce the connection write timeout here. func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error { - timer := time.NewTimer(s.config.ConnectionWriteTimeout) - defer timer.Stop() + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.config.ConnectionWriteTimeout) + defer func() { + timer.Stop() + timerPool.Put(t) + }() ready := sendReady{Hdr: hdr, Body: body, Err: errCh} select { @@ -408,11 +413,19 @@ func (s *Session) recv() { } } +var ( + handlers = []func(*Session, header) error{ + typeData: (*Session).handleStreamMessage, + typeWindowUpdate: (*Session).handleStreamMessage, + typePing: (*Session).handlePing, + typeGoAway: (*Session).handleGoAway, + } +) + // recvLoop continues to receive data until a fatal error is encountered func (s *Session) recvLoop() error { defer close(s.recvDoneCh) hdr := header(make([]byte, headerSize)) - var handler func(header) error for { // Read the header if _, err := io.ReadFull(s.bufRead, hdr); err != nil { @@ -428,22 +441,12 @@ func (s *Session) recvLoop() error { return ErrInvalidVersion } - // Switch on the type - switch hdr.MsgType() { - case typeData: - handler = s.handleStreamMessage - case typeWindowUpdate: - handler = s.handleStreamMessage - case typeGoAway: - handler = s.handleGoAway - case typePing: - handler = s.handlePing - default: + mt := hdr.MsgType() + if mt < typeData || mt > typeGoAway { return ErrInvalidMsgType } - // Invoke the handler - if err := handler(hdr); err != nil { + if err := handlers[mt](s, hdr); err != nil { return err } } diff --git a/stream.go b/stream.go index d216e28..c3255a6 100644 --- a/stream.go +++ b/stream.go @@ -238,18 +238,25 @@ func (s *Stream) sendWindowUpdate() error { // Determine the delta update max := s.session.config.MaxStreamWindowSize - delta := max - atomic.LoadUint32(&s.recvWindow) + var bufLen uint32 + s.recvLock.Lock() + if s.recvBuf != nil { + bufLen = uint32(s.recvBuf.Len()) + } + delta := (max - bufLen) - s.recvWindow // Determine the flags if any flags := s.sendFlags() // Check if we can omit the update if delta < (max/2) && flags == 0 { + s.recvLock.Unlock() return nil } // Update our window - atomic.AddUint32(&s.recvWindow, delta) + s.recvWindow += delta + s.recvLock.Unlock() // Send the header s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) @@ -392,16 +399,18 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { if length == 0 { return nil } - if remain := atomic.LoadUint32(&s.recvWindow); length > remain { - s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length) - return ErrRecvWindowExceeded - } // Wrap in a limited reader conn = &io.LimitedReader{R: conn, N: int64(length)} // Copy into buffer s.recvLock.Lock() + + if length > s.recvWindow { + s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length) + return ErrRecvWindowExceeded + } + if s.recvBuf == nil { // Allocate the receive buffer just-in-time to fit the full data frame. // This way we can read in the whole packet without further allocations. @@ -414,7 +423,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { } // Decrement the receive window - atomic.AddUint32(&s.recvWindow, ^uint32(length-1)) + s.recvWindow += ^uint32(length - 1) s.recvLock.Unlock() // Unblock any readers diff --git a/util.go b/util.go index 5fe45af..8a73e92 100644 --- a/util.go +++ b/util.go @@ -1,5 +1,20 @@ package yamux +import ( + "sync" + "time" +) + +var ( + timerPool = &sync.Pool{ + New: func() interface{} { + timer := time.NewTimer(time.Hour * 1e6) + timer.Stop() + return timer + }, + } +) + // asyncSendErr is used to try an async send of an error func asyncSendErr(ch chan error, err error) { if ch == nil {