Fix window update, reduce memory utilization and improve performance #50

Closed
wants to merge 7 commits
43 changes: 43 additions & 0 deletions bench_test.go
@@ -79,3 +79,46 @@ func BenchmarkSendRecv(b *testing.B) {
}
<-doneCh
}

func BenchmarkSendRecvLarge(b *testing.B) {
client, server := testClientServer()
defer client.Close()
defer server.Close()

const sendSize = 100 * 1024 * 1024
const recvSize = 4 * 1024

sendBuf := make([]byte, sendSize)
recvBuf := make([]byte, recvSize)

b.ResetTimer()
recvDone := make(chan struct{})

go func() {
stream, err := server.AcceptStream()
if err != nil {
return
}
defer stream.Close()
for i := 0; i < b.N; i++ {
for j := 0; j < sendSize/recvSize; j++ {
if _, err := stream.Read(recvBuf); err != nil {
b.Fatalf("err: %v", err)
}
}
}
close(recvDone)
}()

stream, err := client.Open()
if err != nil {
b.Fatalf("err: %v", err)
}
defer stream.Close()
for i := 0; i < b.N; i++ {
if _, err := stream.Write(sendBuf); err != nil {
b.Fatalf("err: %v", err)
}
}
<-recvDone
}
35 changes: 19 additions & 16 deletions session.go
@@ -323,8 +323,13 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error {
// potential shutdown. Since there's the expectation that sends can happen
// in a timely manner, we enforce the connection write timeout here.
func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
timer := time.NewTimer(s.config.ConnectionWriteTimeout)
defer timer.Stop()
t := timerPool.Get()
timer := t.(*time.Timer)
timer.Reset(s.config.ConnectionWriteTimeout)
defer func() {
timer.Stop()
timerPool.Put(t)
}()

Review comment:
seems like sendNoWait could use the same timer pool too?

Contributor Author:
Yes, we could use the timer pool for sendNoWait too
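
For reference, a minimal sketch of what that could look like, mirroring the waitForSendErr change above. The sendCh/shutdownCh fields and the ErrSessionShutdown/ErrConnectionWriteTimeout values are assumed from the existing session.go and are not part of this diff:

func (s *Session) sendNoWait(hdr header) error {
	// Borrow a timer from the pool instead of allocating one per call.
	t := timerPool.Get()
	timer := t.(*time.Timer)
	timer.Reset(s.config.ConnectionWriteTimeout)
	defer func() {
		timer.Stop()
		timerPool.Put(t)
	}()

	// Queue the header without a body and without waiting on an error channel.
	select {
	case s.sendCh <- sendReady{Hdr: hdr}:
		return nil
	case <-s.shutdownCh:
		return ErrSessionShutdown
	case <-timer.C:
		return ErrConnectionWriteTimeout
	}
}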

ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
select {
@@ -408,11 +413,19 @@ func (s *Session) recv() {
}
}

var (
handlers = []func(*Session, header) error{
typeData: (*Session).handleStreamMessage,
typeWindowUpdate: (*Session).handleStreamMessage,
typePing: (*Session).handlePing,
typeGoAway: (*Session).handleGoAway,
}
)

// recvLoop continues to receive data until a fatal error is encountered
func (s *Session) recvLoop() error {
defer close(s.recvDoneCh)
hdr := header(make([]byte, headerSize))
var handler func(header) error
for {
// Read the header
if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
@@ -428,22 +441,12 @@ func (s *Session) recvLoop() error {
return ErrInvalidVersion
}

// Switch on the type
switch hdr.MsgType() {
case typeData:
handler = s.handleStreamMessage
case typeWindowUpdate:
handler = s.handleStreamMessage
case typeGoAway:
handler = s.handleGoAway
case typePing:
handler = s.handlePing
default:
mt := hdr.MsgType()
if mt < typeData || mt > typeGoAway {
return ErrInvalidMsgType
}

// Invoke the handler
if err := handler(hdr); err != nil {
if err := handlers[mt](s, hdr); err != nil {
Contributor:
I don't really like this change. Yes it's less lines of code but it's also much more susceptible to subtle bugs where you forget to change the if condition above. With the switch statement adding or removing a case only affects 1 place while after this change you need to update both the if statement and the lookup table.

Contributor Author (@stuartcarnie, Sep 13, 2017):

I made this change to avoid allocations for every message that is handled.

I tried this approach first, however it still resulted in allocations (which was surprising):

switch mt {
case typeData:
    handler = (*Session).handleStreamMessage
case typeWindowUpdate:
    handler = (*Session).handleStreamMessage
...
}

if err := handler(s, hdr)

Contributor Author:

See commit e7f9152

Contributor:

Are you sure it allocates? I can't seem to produce any similar code that actually allocates: https://gist.github.com/erikdubbelboer/53f4bc902563293ffa9e3a351ff4a149

If it really allocates I would have turned handlers into a map so you can easily add and remove things while keeping the if the same. But of course this would be a bit slower than what you write now. As long as there are enough tests I guess your typeMax solution is also good.

return err
}
}
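
For comparison, a rough sketch of the map-based dispatch suggested in the thread above (assuming hdr.MsgType() returns a uint8 as in the existing header code); this is not part of the PR, which keeps the slice indexed by message type plus the range check to avoid a map lookup on the hot path:

var handlerMap = map[uint8]func(*Session, header) error{
	typeData:         (*Session).handleStreamMessage,
	typeWindowUpdate: (*Session).handleStreamMessage,
	typePing:         (*Session).handlePing,
	typeGoAway:       (*Session).handleGoAway,
}

// Inside recvLoop, the range check collapses into the map lookup:
handler, ok := handlerMap[hdr.MsgType()]
if !ok {
	return ErrInvalidMsgType
}
if err := handler(s, hdr); err != nil {
	return err
}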
74 changes: 68 additions & 6 deletions session_test.go
@@ -376,7 +376,12 @@ func TestSendData_Large(t *testing.T) {
defer client.Close()
defer server.Close()

data := make([]byte, 512*1024)
const (
sendSize = 250 * 1024 * 1024
recvSize = 4 * 1024
)

data := make([]byte, sendSize)
for idx := range data {
data[idx] = byte(idx % 256)
}
@@ -390,16 +395,17 @@
if err != nil {
t.Fatalf("err: %v", err)
}

buf := make([]byte, 4*1024)
for i := 0; i < 128; i++ {
var sz int
buf := make([]byte, recvSize)
for i := 0; i < sendSize/recvSize; i++ {
n, err := stream.Read(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 4*1024 {
if n != recvSize {
t.Fatalf("short read: %d", n)
}
sz += n
for idx := range buf {
if buf[idx] != byte(idx%256) {
t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
@@ -410,6 +416,8 @@
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}

t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
}()

go func() {
@@ -439,7 +447,7 @@ func TestSendData_Large(t *testing.T) {
}()
select {
case <-doneCh:
case <-time.After(time.Second):
case <-time.After(5 * time.Second):
panic("timeout")
}
}
@@ -1026,6 +1034,60 @@ func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
wg.Wait()
}

func TestSession_PartialReadWindowUpdate(t *testing.T) {
client, server := testClientServerConfig(testConfNoKeepAlive())
defer client.Close()
defer server.Close()

var wg sync.WaitGroup
wg.Add(1)

// Choose a huge flood size that we know will result in a window update.
flood := int64(client.config.MaxStreamWindowSize)
var wr *Stream

// The server will accept a new stream and then flood data to it.
go func() {
defer wg.Done()

var err error
wr, err = server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer wr.Close()

if wr.sendWindow != client.config.MaxStreamWindowSize {
t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
}

n, err := wr.Write(make([]byte, flood))
if err != nil {
t.Fatalf("err: %v", err)
}
if int64(n) != flood {
t.Fatalf("short write: %d", n)
}
if wr.sendWindow != 0 {
t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
}
}()

stream, err := client.OpenStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()

wg.Wait()

_, err = stream.Read(make([]byte, flood/2+1))

if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
}
}

func TestSession_sendNoWait_Timeout(t *testing.T) {
client, server := testClientServerConfig(testConfNoKeepAlive())
defer client.Close()
23 changes: 16 additions & 7 deletions stream.go
@@ -238,18 +238,25 @@ func (s *Stream) sendWindowUpdate() error {

// Determine the delta update
max := s.session.config.MaxStreamWindowSize
delta := max - atomic.LoadUint32(&s.recvWindow)

Review comment (@preetapan):
Why did the atomic load go away? I don't have too much context around this code, but seems like the receive lock recvLock has a lot more contention since many other places use it, vs previously here it was using atomic load to get a potentially faster lower level primitive to load recvWindow. Can this new logic for bufLen be rewritten to use atomic.loadUint32 for both recvWindow and recvBuf.len() so we don't need the recvLock mutex?

Contributor Author:

Hi @preetapan, thanks for the feedback.

We must synchronize access to the recvWindow and recvBuf values, because they are related. The previous (incorrect) logic only looked at the recvWindow for determining the value of the window update after a Read operation. The bug surfaced when a Read does not completely drain the recvBuf. We must then take remaining bytes into consideration when calculating the final window update value and therefore take the lock.
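
As a quick worked example, assuming the default 256 KiB window: after the peer floods the full 256 KiB, recvWindow is 0 and recvBuf holds 256 KiB. If a Read then drains only 4 KiB, bufLen is 252 KiB, so the new formula advertises delta = (256 − 252) − 0 = 4 KiB, matching what was actually consumed. The old formula, max − recvWindow, would have advertised the full 256 KiB while 252 KiB were still buffered, letting the sender overrun the intended receive-buffer bound.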

var bufLen uint32
s.recvLock.Lock()
if s.recvBuf != nil {
bufLen = uint32(s.recvBuf.Len())
}
delta := (max - bufLen) - s.recvWindow

// Determine the flags if any
flags := s.sendFlags()

// Check if we can omit the update
if delta < (max/2) && flags == 0 {
s.recvLock.Unlock()
return nil
}

// Update our window
atomic.AddUint32(&s.recvWindow, delta)
s.recvWindow += delta
s.recvLock.Unlock()

// Send the header
s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
@@ -392,16 +399,18 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
if length == 0 {
return nil
}
if remain := atomic.LoadUint32(&s.recvWindow); length > remain {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length)
return ErrRecvWindowExceeded
}

// Wrap in a limited reader
conn = &io.LimitedReader{R: conn, N: int64(length)}

// Copy into buffer
s.recvLock.Lock()

if length > s.recvWindow {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
return ErrRecvWindowExceeded
}

if s.recvBuf == nil {
// Allocate the receive buffer just-in-time to fit the full data frame.
// This way we can read in the whole packet without further allocations.
Expand All @@ -414,7 +423,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
}

// Decrement the receive window
atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
s.recvWindow += ^uint32(length - 1)
s.recvLock.Unlock()

// Unblock any readers
15 changes: 15 additions & 0 deletions util.go
@@ -1,5 +1,20 @@
package yamux

import (
"sync"
"time"
)

var (
timerPool = &sync.Pool{
Contributor:
This doesn't need to be a pointer but that doesn't really matter.

New: func() interface{} {
timer := time.NewTimer(time.Hour * 1e6)
timer.Stop()
return timer
},
}
)
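
As a side note, a hypothetical test (not part of this PR, it would live in a _test.go file in the same package) could check that the pooled timer is reused without allocating in steady state, using testing.AllocsPerRun:

func TestTimerPoolReuse(t *testing.T) {
	allocs := testing.AllocsPerRun(1000, func() {
		tm := timerPool.Get().(*time.Timer)
		tm.Reset(time.Millisecond)
		tm.Stop()
		timerPool.Put(tm)
	})
	// Expect ~0 in steady state; a GC during the run can evict the pool
	// and force the New function to allocate a fresh timer.
	t.Logf("allocs per Get/Put cycle: %f", allocs)
}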

// asyncSendErr is used to try an async send of an error
func asyncSendErr(ch chan error, err error) {
if ch == nil {