Skip to content

Commit

Permalink
improve memory utilization in receive buffer, fix flow control
Browse files Browse the repository at this point in the history
* flow control was assuming a `Read` consumed the entire buffer
* introduced fixed receive buffer on streams to reduce memory
  utilization when receiving large streams of data
* use timer pool to reduce allocations
* use static handler function pointer to avoid closure allocations
  for every frame
  • Loading branch information
stuartcarnie committed Sep 12, 2017
1 parent b82d425 commit d483bcb
Show file tree
Hide file tree
Showing 4 changed files with 140 additions and 25 deletions.
88 changes: 88 additions & 0 deletions buffer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
package yamux

import (
"errors"
"io"
)

// buffer is an io.ReadWriteCloser backed by a fixed size buffer.
// It never allocates, but moves old data as new data is written.
type buffer struct {
buf []byte
r, w int
}

var (
errWriteFull = errors.New("write: full buffer")
)

// Read copies bytes from the buffer into p.
// It is an error to read when no data is available.
func (b *buffer) Read(p []byte) (n int, err error) {
n = copy(p, b.buf[b.r:b.w])
b.r += n
if b.r == b.w && n == 0 {
err = io.EOF
}
return n, nil
}

// Len returns the number of bytes of the unread portion of the buffer.
func (b *buffer) Len() int {
return b.w - b.r
}

// Cap returns the capacity of the buffer's underlying byte slice, that is, the
// total space allocated for the buffer's data.
func (b *buffer) Cap() int {
return cap(b.buf)
}

// Write copies bytes from p into the buffer.
// It is an error to write more data than the buffer can hold.
func (b *buffer) Write(p []byte) (n int, err error) {
// Slide existing data to beginning.
if b.r > 0 && len(p) > len(b.buf)-b.w {
copy(b.buf, b.buf[b.r:b.w])
b.w -= b.r
b.r = 0
}

// Write new data.
n = copy(b.buf[b.w:], p)
b.w += n
if n < len(p) {
err = errWriteFull
}
return n, err
}

// ReadFrom reads data from r until EOF and appends it to the buffer, growing
// the buffer as needed. The return value n is the number of bytes read. Any
// error except io.EOF encountered during the read is also returned. If the
// buffer becomes too large, ReadFrom will panic with ErrTooLarge.
func (b *buffer) ReadFrom(r io.Reader) (n int64, err error) {
// Slide existing data to beginning.
if b.r > 0 {
copy(b.buf, b.buf[b.r:b.w])
b.w -= b.r
b.r = 0
}

for {
m, e := r.Read(b.buf[b.w:])
if b.w == len(b.buf) && e != io.EOF {
return n, errWriteFull
}

n += int64(m)
b.w += m
if e == io.EOF {
break
}
if e != nil {
return n, e
}
}
return n, nil
}
35 changes: 19 additions & 16 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,13 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error {
// potential shutdown. Since there's the expectation that sends can happen
// in a timely manner, we enforce the connection write timeout here.
func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
timer := time.NewTimer(s.config.ConnectionWriteTimeout)
defer timer.Stop()
t := timerPool.Get()
timer := t.(*time.Timer)
timer.Reset(s.config.ConnectionWriteTimeout)
defer func() {
timer.Stop()
timerPool.Put(t)
}()

ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
select {
Expand Down Expand Up @@ -408,11 +413,19 @@ func (s *Session) recv() {
}
}

var (
handlers = []func(*Session, header) error{
typeData: (*Session).handleStreamMessage,
typeWindowUpdate: (*Session).handleStreamMessage,
typePing: (*Session).handlePing,
typeGoAway: (*Session).handleGoAway,
}
)

// recvLoop continues to receive data until a fatal error is encountered
func (s *Session) recvLoop() error {
defer close(s.recvDoneCh)
hdr := header(make([]byte, headerSize))
var handler func(header) error
for {
// Read the header
if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
Expand All @@ -428,22 +441,12 @@ func (s *Session) recvLoop() error {
return ErrInvalidVersion
}

// Switch on the type
switch hdr.MsgType() {
case typeData:
handler = s.handleStreamMessage
case typeWindowUpdate:
handler = s.handleStreamMessage
case typeGoAway:
handler = s.handleGoAway
case typePing:
handler = s.handlePing
default:
mt := hdr.MsgType()
if mt < typeData || mt > typeGoAway {
return ErrInvalidMsgType
}

// Invoke the handler
if err := handler(hdr); err != nil {
if err := handlers[mt](s, hdr); err != nil {
return err
}
}
Expand Down
27 changes: 18 additions & 9 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Stream struct {
state streamState
stateLock sync.Mutex

recvBuf *bytes.Buffer
recvBuf *buffer
recvLock sync.Mutex

controlHdr header
Expand Down Expand Up @@ -238,18 +238,25 @@ func (s *Stream) sendWindowUpdate() error {

// Determine the delta update
max := s.session.config.MaxStreamWindowSize
delta := max - atomic.LoadUint32(&s.recvWindow)
var bufLen uint32
s.recvLock.Lock()
if s.recvBuf != nil {
bufLen = uint32(s.recvBuf.Len())
}
delta := (max - bufLen) - s.recvWindow

// Determine the flags if any
flags := s.sendFlags()

// Check if we can omit the update
if delta < (max/2) && flags == 0 {
s.recvLock.Unlock()
return nil
}

// Update our window
atomic.AddUint32(&s.recvWindow, delta)
s.recvWindow += delta
s.recvLock.Unlock()

// Send the header
s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
Expand Down Expand Up @@ -392,20 +399,22 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
if length == 0 {
return nil
}
if remain := atomic.LoadUint32(&s.recvWindow); length > remain {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length)
return ErrRecvWindowExceeded
}

// Wrap in a limited reader
conn = &io.LimitedReader{R: conn, N: int64(length)}

// Copy into buffer
s.recvLock.Lock()

if length > s.recvWindow {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
return ErrRecvWindowExceeded
}

if s.recvBuf == nil {
// Allocate the receive buffer just-in-time to fit the full data frame.
// This way we can read in the whole packet without further allocations.
s.recvBuf = bytes.NewBuffer(make([]byte, 0, length))
s.recvBuf = &buffer{buf: make([]byte, s.session.config.MaxStreamWindowSize)}
}
if _, err := io.Copy(s.recvBuf, conn); err != nil {
s.session.logger.Printf("[ERR] yamux: Failed to read stream data: %v", err)
Expand All @@ -414,7 +423,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
}

// Decrement the receive window
atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
s.recvWindow += ^uint32(length - 1)
s.recvLock.Unlock()

// Unblock any readers
Expand Down
15 changes: 15 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
package yamux

import (
"sync"
"time"
)

var (
timerPool = &sync.Pool{
New: func() interface{} {
timer := time.NewTimer(time.Hour * 1e6)
timer.Stop()
return timer
},
}
)

// asyncSendErr is used to try an async send of an error
func asyncSendErr(ch chan error, err error) {
if ch == nil {
Expand Down

0 comments on commit d483bcb

Please sign in to comment.