Skip to content

Commit

Permalink
aligning err handling more closely with io.Pipe, and breaking blockin…
Browse files Browse the repository at this point in the history
…g calls on Close(), fixes #4
djherbis committed Jan 24, 2017
1 parent fe07db0 commit bfaed0c
Showing 2 changed files with 97 additions and 24 deletions.
53 changes: 53 additions & 0 deletions nio_test.go
Original file line number Diff line number Diff line change
@@ -3,11 +3,14 @@ package nio
import (
"bufio"
"bytes"
"errors"
"io"
"testing"

"io/ioutil"

"time"

"github.com/djherbis/buffer"
)

@@ -60,10 +63,16 @@ func TestPipeCloseEarly(t *testing.T) {
buf := buffer.New(1024)
r, w := Pipe(buf)
r.Close()

_, err := w.Write([]byte("hello world"))
if err != io.ErrClosedPipe {
t.Errorf("expected closed pipe")
}

_, err = io.Copy(ioutil.Discard, r)
if err != io.ErrClosedPipe {
t.Errorf("expected closed pipe")
}
}

func TestPipe(t *testing.T) {
@@ -89,7 +98,51 @@ func TestPipe(t *testing.T) {
if !bytes.Equal(data, result) {
t.Errorf("exp [%s]\ngot[%s]", string(data), string(result))
}
}

func TestEarlyCloseWrite(t *testing.T) {
buf := buffer.New(1)
r, w := Pipe(buf)

testerr := errors.New("test err")

w.CloseWithError(testerr)

_, err := w.Write([]byte("ab")) // too big for buffer

if err != io.ErrClosedPipe {
t.Errorf("expected %s but got %s.", testerr, err)
}

_, err = io.Copy(ioutil.Discard, r)
if err != testerr {
t.Errorf("expected %s but got %s.", testerr, err)
}
}

func TestUnblockWrite(t *testing.T) {
buf := buffer.New(1)
r, w := Pipe(buf)

testerr := errors.New("test err")

go func() {
<-time.After(100 * time.Millisecond)
if er := w.CloseWithError(testerr); er != nil {
t.Error(er)
}
}()

_, err := w.Write([]byte("ab")) // too big for buffer

if err != io.ErrClosedPipe {
t.Errorf("expected %s but got %s.", testerr, err)
}

_, err = io.Copy(ioutil.Discard, r)
if err != testerr {
t.Errorf("expected %s but got %s.", testerr, err)
}
}

type badBuffer struct{}
68 changes: 44 additions & 24 deletions sync.go
Original file line number Diff line number Diff line change
@@ -17,9 +17,10 @@ func (r *PipeReader) CloseWithError(err error) error {
}
r.bufpipe.l.Lock()
defer r.bufpipe.l.Unlock()
if r.bufpipe.err == nil {
r.bufpipe.err = err
r.bufpipe.c.Signal()
if r.bufpipe.rerr == nil {
r.bufpipe.rerr = err
r.bufpipe.rwait.Signal()
r.bufpipe.wwait.Signal()
}
return nil
}
@@ -42,9 +43,10 @@ func (w *PipeWriter) CloseWithError(err error) error {
}
w.bufpipe.l.Lock()
defer w.bufpipe.l.Unlock()
if w.bufpipe.err == nil {
w.bufpipe.err = err
w.bufpipe.c.Signal()
if w.bufpipe.werr == nil {
w.bufpipe.werr = err
w.bufpipe.rwait.Signal()
w.bufpipe.wwait.Signal()
}
return nil
}
@@ -56,19 +58,22 @@ func (w *PipeWriter) Close() error {
}

type bufpipe struct {
rl sync.Mutex
wl sync.Mutex
l sync.Mutex
c *sync.Cond
b Buffer
err error
rl sync.Mutex
wl sync.Mutex
l sync.Mutex
rwait sync.Cond
wwait sync.Cond
b Buffer
rerr error // if reader closed, error to give writes
werr error // if writer closed, error to give reads
}

func newBufferedPipe(buf Buffer) *bufpipe {
s := &bufpipe{
b: buf,
}
s.c = sync.NewCond(&s.l)
s.rwait.L = &s.l
s.wwait.L = &s.l
return s
}

@@ -85,16 +90,20 @@ func (r *PipeReader) Read(p []byte) (n int, err error) {
defer r.rl.Unlock()

r.l.Lock()
defer r.c.Signal()
defer r.wwait.Signal()
defer r.l.Unlock()

for empty(r.b) {
if r.err != nil {
return 0, r.err
if r.rerr != nil {
return 0, io.ErrClosedPipe
}

r.c.Signal()
r.c.Wait()
if r.werr != nil {
return 0, r.werr
}

r.wwait.Signal()
r.rwait.Wait()
}

n, err = r.b.Read(p)
@@ -115,21 +124,32 @@ func (w *PipeWriter) Write(p []byte) (int, error) {
defer w.wl.Unlock()

w.l.Lock()
defer w.c.Signal()
defer w.rwait.Signal()
defer w.l.Unlock()

if w.err != nil {
if w.werr != nil {
return 0, io.ErrClosedPipe
}

// while there is data to write
for writeLen := sliceLen; writeLen > 0 && err == nil; writeLen = sliceLen - n {

// wait for some buffer space to become available
for space = gap(w.b); space == 0; space = gap(w.b) {
w.c.Signal()
w.c.Wait()
// wait for some buffer space to become available (while no errs)
for space = gap(w.b); space == 0 && w.rerr == nil && w.werr == nil; space = gap(w.b) {
w.rwait.Signal()
w.wwait.Wait()
}

if w.rerr != nil {
err = w.rerr
break
}

if w.werr != nil {
err = io.ErrClosedPipe
break
}

// space > 0, and locked

var nn int64

0 comments on commit bfaed0c

Please sign in to comment.