diff --git a/gbn/gbn_conn.go b/gbn/gbn_conn.go index acc5626..7fc4962 100644 --- a/gbn/gbn_conn.go +++ b/gbn/gbn_conn.go @@ -37,10 +37,6 @@ type GoBackNConn struct { sendQueue *queue - // recvSeq keeps track of the latest, correctly sequenced packet - // sequence that we have received. - recvSeq uint8 - resendTicker *time.Ticker recvDataChan chan *PacketData @@ -50,8 +46,6 @@ type GoBackNConn struct { recvTimeout time.Duration timeoutsMu sync.RWMutex - log btclog.Logger - // receivedACKSignal channel is used to signal that the queue size has // been decreased. receivedACKSignal chan struct{} @@ -70,6 +64,8 @@ type GoBackNConn struct { // remoteClosed is closed if the remote party initiated the FIN sequence. remoteClosed chan struct{} + log btclog.Logger + // quit is used to stop the normal operations of the connection. // Once closed, the send and receive streams will still be available // for the FIN sequence. @@ -140,6 +136,13 @@ func (g *GoBackNConn) SetSendTimeout(timeout time.Duration) { g.sendTimeout = timeout } +func (g *GoBackNConn) getSendTimeout() time.Duration { + g.timeoutsMu.RLock() + defer g.timeoutsMu.RUnlock() + + return g.sendTimeout +} + // SetRecvTimeout sets the timeout used in the Recv function. func (g *GoBackNConn) SetRecvTimeout(timeout time.Duration) { g.timeoutsMu.Lock() @@ -148,20 +151,25 @@ func (g *GoBackNConn) SetRecvTimeout(timeout time.Duration) { g.recvTimeout = timeout } +func (g *GoBackNConn) getRecvTimeout() time.Duration { + g.timeoutsMu.RLock() + defer g.timeoutsMu.RUnlock() + + return g.recvTimeout +} + // Send blocks until an ack is received for the packet sent N packets before. func (g *GoBackNConn) Send(data []byte) error { - // Wait for handshake to complete before we can send data. select { case <-g.quit: return io.EOF default: } - g.timeoutsMu.RLock() - ticker := time.NewTimer(g.sendTimeout) - g.timeoutsMu.RUnlock() + ticker := time.NewTimer(g.getSendTimeout()) defer ticker.Stop() + // sendPacket sends the given packet onto the sendDataChan. sendPacket := func(packet *PacketData) error { select { case g.sendDataChan <- packet: @@ -173,8 +181,9 @@ func (g *GoBackNConn) Send(data []byte) error { } } + // If splitting is disabled, then we set the packet's FinalChunk to + // true. if g.cfg.maxChunkSize == 0 { - // Splitting is disabled. return sendPacket(&PacketData{ Payload: data, FinalChunk: true, @@ -187,7 +196,7 @@ func (g *GoBackNConn) Send(data []byte) error { maxChunk = g.cfg.maxChunkSize ) for sentBytes < len(data) { - packet := &PacketData{} + var packet PacketData remainingBytes := len(data) - sentBytes if remainingBytes <= maxChunk { @@ -199,7 +208,7 @@ func (g *GoBackNConn) Send(data []byte) error { sentBytes += maxChunk } - if err := sendPacket(packet); err != nil { + if err := sendPacket(&packet); err != nil { return err } } @@ -221,9 +230,7 @@ func (g *GoBackNConn) Recv() ([]byte, error) { msg *PacketData ) - g.timeoutsMu.RLock() - ticker := time.NewTimer(g.recvTimeout) - g.timeoutsMu.RUnlock() + ticker := time.NewTimer(g.getRecvTimeout()) defer ticker.Stop() for { @@ -365,6 +372,15 @@ func (g *GoBackNConn) sendPacket(ctx context.Context, msg Message) error { return nil } +func (g *GoBackNConn) receivePacket(ctx context.Context) (Message, error) { + b, err := g.cfg.recvFromStream(ctx) + if err != nil { + return nil, fmt.Errorf("error receiving from stream: %w", err) + } + + return Deserialize(b) +} + // sendPacketsForever manages the resending logic. It keeps a cache of up to // N packets and manages the resending of packets if acks are not received for // them or if NACKs are received. It reads new data from sendDataChan only @@ -372,11 +388,6 @@ func (g *GoBackNConn) sendPacket(ctx context.Context, msg Message) error { // // This function must be called in a go routine. func (g *GoBackNConn) sendPacketsForever() error { - // resendQueue re-sends the current contents of the queue. - resendQueue := func() error { - return g.sendQueue.resend() - } - for { // The queue is not empty. If we receive a resend signal // or if the resend timeout passes then we resend the @@ -388,13 +399,13 @@ func (g *GoBackNConn) sendPacketsForever() error { return nil case <-g.resendSignal: - if err := resendQueue(); err != nil { + if err := g.sendQueue.resend(); err != nil { return err } continue case <-g.resendTicker.C: - if err := resendQueue(); err != nil { + if err := g.sendQueue.resend(); err != nil { return err } continue @@ -443,11 +454,11 @@ func (g *GoBackNConn) sendPacketsForever() error { case <-g.receivedACKSignal: break case <-g.resendSignal: - if err := resendQueue(); err != nil { + if err := g.sendQueue.resend(); err != nil { return err } case <-g.resendTicker.C: - if err := resendQueue(); err != nil { + if err := g.sendQueue.resend(); err != nil { return err } } @@ -464,6 +475,10 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo var ( lastNackSeq uint8 lastNackTime time.Time + + // recvSeq keeps track of the latest, correctly sequenced packet + // sequence that we have received. + recvSeq uint8 ) for { @@ -473,13 +488,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo default: } - b, err := g.cfg.recvFromStream(g.ctx) - if err != nil { - return fmt.Errorf("error receiving "+ - "from recvFromStream: %s", err) - } - - msg, err := Deserialize(b) + msg, err := g.receivePacket(g.ctx) if err != nil { return fmt.Errorf("deserialize error: %s", err) } @@ -493,7 +502,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo switch m := msg.(type) { case *PacketData: - switch m.Seq == g.recvSeq { + switch m.Seq == recvSeq { case true: // We received a data packet with the sequence // number we were expecting. So we respond with @@ -510,7 +519,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo return err } - g.recvSeq = (g.recvSeq + 1) % g.cfg.s + recvSeq = (recvSeq + 1) % g.cfg.s // If the packet was a ping, then there is no // data to return to the above layer. @@ -539,19 +548,19 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo // If we recently sent a NACK for the same // sequence number then back off. - if lastNackSeq == g.recvSeq && + if lastNackSeq == recvSeq && time.Since(lastNackTime) < g.cfg.resendTimeout { continue } - g.log.Tracef("Sending NACK %d", g.recvSeq) + g.log.Tracef("Sending NACK %d", recvSeq) // Send a NACK with the expected sequence // number. nack := &PacketNACK{ - Seq: g.recvSeq, + Seq: recvSeq, } if err = g.sendPacket(g.ctx, nack); err != nil {