Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Streaming perf improv #53

Merged
merged 5 commits into from
Mar 14, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions bench_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package yamux

import (
"fmt"
"testing"
)

Expand Down Expand Up @@ -85,7 +84,6 @@ func BenchmarkSendRecvLarge(b *testing.B) {
client, server := testClientServer()
defer client.Close()
defer server.Close()

const sendSize = 512 * 1024 * 1024
const recvSize = 4 * 1024

Expand All @@ -107,9 +105,6 @@ func BenchmarkSendRecvLarge(b *testing.B) {
b.Fatalf("err: %v", err)
}
}

fmt.Printf("Capacity of rcv buffer = %v, length of rcv window = %v\n", stream.recvBuf.Cap(), stream.recvWindow)

}
close(recvDone)
}()
Expand Down
53 changes: 35 additions & 18 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -323,8 +323,17 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error {
// potential shutdown. Since there's the expectation that sends can happen
// in a timely manner, we enforce the connection write timeout here.
func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
timer := time.NewTimer(s.config.ConnectionWriteTimeout)
defer timer.Stop()
t := timerPool.Get()
timer := t.(*time.Timer)
timer.Reset(s.config.ConnectionWriteTimeout)
defer func() {
timer.Stop()
select {
case <-timer.C:
default:
}
timerPool.Put(t)
}()

ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
select {
Expand All @@ -349,8 +358,17 @@ func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) e
// the send happens right here, we enforce the connection write timeout if we
// can't queue the header to be sent.
func (s *Session) sendNoWait(hdr header) error {
timer := time.NewTimer(s.config.ConnectionWriteTimeout)
defer timer.Stop()
t := timerPool.Get()
timer := t.(*time.Timer)
timer.Reset(s.config.ConnectionWriteTimeout)
defer func() {
timer.Stop()
select {
case <-timer.C:
default:
}
timerPool.Put(t)
}()

select {
case s.sendCh <- sendReady{Hdr: hdr}:
Expand Down Expand Up @@ -408,11 +426,20 @@ func (s *Session) recv() {
}
}

// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type
var (
handlers = []func(*Session, header) error{
typeData: (*Session).handleStreamMessage,
typeWindowUpdate: (*Session).handleStreamMessage,
typePing: (*Session).handlePing,
typeGoAway: (*Session).handleGoAway,
}
)

// recvLoop continues to receive data until a fatal error is encountered
func (s *Session) recvLoop() error {
defer close(s.recvDoneCh)
hdr := header(make([]byte, headerSize))
var handler func(header) error
for {
// Read the header
if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
Expand All @@ -428,22 +455,12 @@ func (s *Session) recvLoop() error {
return ErrInvalidVersion
}

// Switch on the type
switch hdr.MsgType() {
case typeData:
handler = s.handleStreamMessage
case typeWindowUpdate:
handler = s.handleStreamMessage
case typeGoAway:
handler = s.handleGoAway
case typePing:
handler = s.handlePing
default:
mt := hdr.MsgType()
if mt < typeData || mt > typeGoAway {
return ErrInvalidMsgType
}

// Invoke the handler
if err := handler(hdr); err != nil {
if err := handlers[mt](s, hdr); err != nil {
return err
}
}
Expand Down
74 changes: 68 additions & 6 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,12 @@ func TestSendData_Large(t *testing.T) {
defer client.Close()
defer server.Close()

data := make([]byte, 512*1024)
const (
sendSize = 250 * 1024 * 1024
recvSize = 4 * 1024
)

data := make([]byte, sendSize)
for idx := range data {
data[idx] = byte(idx % 256)
}
Expand All @@ -390,16 +395,17 @@ func TestSendData_Large(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}

buf := make([]byte, 4*1024)
for i := 0; i < 128; i++ {
var sz int
buf := make([]byte, recvSize)
for i := 0; i < sendSize/recvSize; i++ {
n, err := stream.Read(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 4*1024 {
if n != recvSize {
t.Fatalf("short read: %d", n)
}
sz += n
for idx := range buf {
if buf[idx] != byte(idx%256) {
t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
Expand All @@ -410,6 +416,8 @@ func TestSendData_Large(t *testing.T) {
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}

t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
}()

go func() {
Expand Down Expand Up @@ -439,7 +447,7 @@ func TestSendData_Large(t *testing.T) {
}()
select {
case <-doneCh:
case <-time.After(time.Second):
case <-time.After(5 * time.Second):
panic("timeout")
}
}
Expand Down Expand Up @@ -1026,6 +1034,60 @@ func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
wg.Wait()
}

func TestSession_PartialReadWindowUpdate(t *testing.T) {
client, server := testClientServerConfig(testConfNoKeepAlive())
defer client.Close()
defer server.Close()

var wg sync.WaitGroup
wg.Add(1)

// Choose a huge flood size that we know will result in a window update.
flood := int64(client.config.MaxStreamWindowSize)
var wr *Stream

// The server will accept a new stream and then flood data to it.
go func() {
defer wg.Done()

var err error
wr, err = server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer wr.Close()

if wr.sendWindow != client.config.MaxStreamWindowSize {
t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
}

n, err := wr.Write(make([]byte, flood))
if err != nil {
t.Fatalf("err: %v", err)
}
if int64(n) != flood {
t.Fatalf("short write: %d", n)
}
if wr.sendWindow != 0 {
t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
}
}()

stream, err := client.OpenStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()

wg.Wait()

_, err = stream.Read(make([]byte, flood/2+1))

if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
}
}

func TestSession_sendNoWait_Timeout(t *testing.T) {
client, server := testClientServerConfig(testConfNoKeepAlive())
defer client.Close()
Expand Down
23 changes: 16 additions & 7 deletions stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -238,18 +238,25 @@ func (s *Stream) sendWindowUpdate() error {

// Determine the delta update
max := s.session.config.MaxStreamWindowSize
delta := max - atomic.LoadUint32(&s.recvWindow)
var bufLen uint32
s.recvLock.Lock()
if s.recvBuf != nil {
bufLen = uint32(s.recvBuf.Len())
}
delta := (max - bufLen) - s.recvWindow

// Determine the flags if any
flags := s.sendFlags()

// Check if we can omit the update
if delta < (max/2) && flags == 0 {
s.recvLock.Unlock()
return nil
}

// Update our window
atomic.AddUint32(&s.recvWindow, delta)
s.recvWindow += delta
s.recvLock.Unlock()

// Send the header
s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
Expand Down Expand Up @@ -392,16 +399,18 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
if length == 0 {
return nil
}
if remain := atomic.LoadUint32(&s.recvWindow); length > remain {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length)
return ErrRecvWindowExceeded
}

// Wrap in a limited reader
conn = &io.LimitedReader{R: conn, N: int64(length)}

// Copy into buffer
s.recvLock.Lock()

if length > s.recvWindow {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
return ErrRecvWindowExceeded
}

if s.recvBuf == nil {
// Allocate the receive buffer just-in-time to fit the full data frame.
// This way we can read in the whole packet without further allocations.
Expand All @@ -414,7 +423,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
}

// Decrement the receive window
atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
s.recvWindow -= length
s.recvLock.Unlock()

// Unblock any readers
Expand Down
15 changes: 15 additions & 0 deletions util.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
package yamux

import (
"sync"
"time"
)

var (
timerPool = &sync.Pool{
New: func() interface{} {
timer := time.NewTimer(time.Hour * 1e6)
timer.Stop()
return timer
},
}
)

// asyncSendErr is used to try an async send of an error
func asyncSendErr(ch chan error, err error) {
if ch == nil {
Expand Down