diff --git a/buffer.go b/buffer.go new file mode 100644 index 0000000..d42189c --- /dev/null +++ b/buffer.go @@ -0,0 +1,88 @@ +package yamux + +import ( + "errors" + "io" +) + +// buffer is an io.ReadWriteCloser backed by a fixed size buffer. +// It never allocates, but moves old data as new data is written. +type buffer struct { + buf []byte + r, w int +} + +var ( + errWriteFull = errors.New("write: full buffer") +) + +// Read copies bytes from the buffer into p. +// It is an error to read when no data is available. +func (b *buffer) Read(p []byte) (n int, err error) { + n = copy(p, b.buf[b.r:b.w]) + b.r += n + if b.r == b.w && n == 0 { + err = io.EOF + } + return n, nil +} + +// Len returns the number of bytes of the unread portion of the buffer. +func (b *buffer) Len() int { + return b.w - b.r +} + +// Cap returns the capacity of the buffer's underlying byte slice, that is, the +// total space allocated for the buffer's data. +func (b *buffer) Cap() int { + return cap(b.buf) +} + +// Write copies bytes from p into the buffer. +// It is an error to write more data than the buffer can hold. +func (b *buffer) Write(p []byte) (n int, err error) { + // Slide existing data to beginning. + if b.r > 0 && len(p) > len(b.buf)-b.w { + copy(b.buf, b.buf[b.r:b.w]) + b.w -= b.r + b.r = 0 + } + + // Write new data. + n = copy(b.buf[b.w:], p) + b.w += n + if n < len(p) { + err = errWriteFull + } + return n, err +} + +// ReadFrom reads data from r until EOF and appends it to the buffer, growing +// the buffer as needed. The return value n is the number of bytes read. Any +// error except io.EOF encountered during the read is also returned. If the +// buffer becomes too large, ReadFrom will panic with ErrTooLarge. +func (b *buffer) ReadFrom(r io.Reader) (n int64, err error) { + // Slide existing data to beginning. + if b.r > 0 { + copy(b.buf, b.buf[b.r:b.w]) + b.w -= b.r + b.r = 0 + } + + for { + m, e := r.Read(b.buf[b.w:]) + if b.w == len(b.buf) && e != io.EOF { + return n, errWriteFull + } + + n += int64(m) + b.w += m + if e == io.EOF { + break + } + if e != nil { + return n, e + } + } + return n, nil +} diff --git a/session.go b/session.go index e179818..21025fc 100644 --- a/session.go +++ b/session.go @@ -323,8 +323,13 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error { // potential shutdown. Since there's the expectation that sends can happen // in a timely manner, we enforce the connection write timeout here. func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error { - timer := time.NewTimer(s.config.ConnectionWriteTimeout) - defer timer.Stop() + t := timerPool.Get() + timer := t.(*time.Timer) + timer.Reset(s.config.ConnectionWriteTimeout) + defer func() { + timer.Stop() + timerPool.Put(t) + }() ready := sendReady{Hdr: hdr, Body: body, Err: errCh} select { @@ -408,11 +413,19 @@ func (s *Session) recv() { } } +var ( + handlers = []func(*Session, header) error{ + typeData: (*Session).handleStreamMessage, + typeWindowUpdate: (*Session).handleStreamMessage, + typePing: (*Session).handlePing, + typeGoAway: (*Session).handleGoAway, + } +) + // recvLoop continues to receive data until a fatal error is encountered func (s *Session) recvLoop() error { defer close(s.recvDoneCh) hdr := header(make([]byte, headerSize)) - var handler func(header) error for { // Read the header if _, err := io.ReadFull(s.bufRead, hdr); err != nil { @@ -428,22 +441,12 @@ func (s *Session) recvLoop() error { return ErrInvalidVersion } - // Switch on the type - switch hdr.MsgType() { - case typeData: - handler = s.handleStreamMessage - case typeWindowUpdate: - handler = s.handleStreamMessage - case typeGoAway: - handler = s.handleGoAway - case typePing: - handler = s.handlePing - default: + mt := hdr.MsgType() + if mt < typeData || mt > typeGoAway { return ErrInvalidMsgType } - // Invoke the handler - if err := handler(hdr); err != nil { + if err := handlers[mt](s, hdr); err != nil { return err } } diff --git a/stream.go b/stream.go index d216e28..ad45143 100644 --- a/stream.go +++ b/stream.go @@ -33,7 +33,7 @@ type Stream struct { state streamState stateLock sync.Mutex - recvBuf *bytes.Buffer + recvBuf *buffer recvLock sync.Mutex controlHdr header @@ -238,18 +238,25 @@ func (s *Stream) sendWindowUpdate() error { // Determine the delta update max := s.session.config.MaxStreamWindowSize - delta := max - atomic.LoadUint32(&s.recvWindow) + var bufLen uint32 + s.recvLock.Lock() + if s.recvBuf != nil { + bufLen = uint32(s.recvBuf.Len()) + } + delta := (max - bufLen) - s.recvWindow // Determine the flags if any flags := s.sendFlags() // Check if we can omit the update if delta < (max/2) && flags == 0 { + s.recvLock.Unlock() return nil } // Update our window - atomic.AddUint32(&s.recvWindow, delta) + s.recvWindow += delta + s.recvLock.Unlock() // Send the header s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta) @@ -392,20 +399,22 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { if length == 0 { return nil } - if remain := atomic.LoadUint32(&s.recvWindow); length > remain { - s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length) - return ErrRecvWindowExceeded - } // Wrap in a limited reader conn = &io.LimitedReader{R: conn, N: int64(length)} // Copy into buffer s.recvLock.Lock() + + if length > s.recvWindow { + s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length) + return ErrRecvWindowExceeded + } + if s.recvBuf == nil { // Allocate the receive buffer just-in-time to fit the full data frame. // This way we can read in the whole packet without further allocations. - s.recvBuf = bytes.NewBuffer(make([]byte, 0, length)) + s.recvBuf = &buffer{buf: make([]byte, s.session.config.MaxStreamWindowSize)} } if _, err := io.Copy(s.recvBuf, conn); err != nil { s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err) @@ -414,7 +423,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error { } // Decrement the receive window - atomic.AddUint32(&s.recvWindow, ^uint32(length-1)) + s.recvWindow += ^uint32(length - 1) s.recvLock.Unlock() // Unblock any readers diff --git a/util.go b/util.go index 5fe45af..8a73e92 100644 --- a/util.go +++ b/util.go @@ -1,5 +1,20 @@ package yamux +import ( + "sync" + "time" +) + +var ( + timerPool = &sync.Pool{ + New: func() interface{} { + timer := time.NewTimer(time.Hour * 1e6) + timer.Stop() + return timer + }, + } +) + // asyncSendErr is used to try an async send of an error func asyncSendErr(ch chan error, err error) { if ch == nil {