Skip to content

Commit

Permalink
gbn: misc cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
ellemouton committed Dec 1, 2023
1 parent d84c090 commit c2216e4
Showing 1 changed file with 46 additions and 37 deletions.
83 changes: 46 additions & 37 deletions gbn/gbn_conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ type GoBackNConn struct {

sendQueue *queue

// recvSeq keeps track of the latest, correctly sequenced packet
// sequence that we have received.
recvSeq uint8

resendTicker *time.Ticker

recvDataChan chan *PacketData
Expand All @@ -50,8 +46,6 @@ type GoBackNConn struct {
recvTimeout time.Duration
timeoutsMu sync.RWMutex

log btclog.Logger

// receivedACKSignal channel is used to signal that the queue size has
// been decreased.
receivedACKSignal chan struct{}
Expand All @@ -70,6 +64,8 @@ type GoBackNConn struct {
// remoteClosed is closed if the remote party initiated the FIN sequence.
remoteClosed chan struct{}

log btclog.Logger

// quit is used to stop the normal operations of the connection.
// Once closed, the send and receive streams will still be available
// for the FIN sequence.
Expand Down Expand Up @@ -140,6 +136,13 @@ func (g *GoBackNConn) SetSendTimeout(timeout time.Duration) {
g.sendTimeout = timeout
}

func (g *GoBackNConn) getSendTimeout() time.Duration {
g.timeoutsMu.RLock()
defer g.timeoutsMu.RUnlock()

return g.sendTimeout
}

// SetRecvTimeout sets the timeout used in the Recv function.
func (g *GoBackNConn) SetRecvTimeout(timeout time.Duration) {
g.timeoutsMu.Lock()
Expand All @@ -148,20 +151,25 @@ func (g *GoBackNConn) SetRecvTimeout(timeout time.Duration) {
g.recvTimeout = timeout
}

func (g *GoBackNConn) getRecvTimeout() time.Duration {
g.timeoutsMu.RLock()
defer g.timeoutsMu.RUnlock()

return g.recvTimeout
}

// Send blocks until an ack is received for the packet sent N packets before.
func (g *GoBackNConn) Send(data []byte) error {
// Wait for handshake to complete before we can send data.
select {
case <-g.quit:
return io.EOF
default:
}

g.timeoutsMu.RLock()
ticker := time.NewTimer(g.sendTimeout)
g.timeoutsMu.RUnlock()
ticker := time.NewTimer(g.getSendTimeout())
defer ticker.Stop()

// sendPacket sends the given packet onto the sendDataChan.
sendPacket := func(packet *PacketData) error {
select {
case g.sendDataChan <- packet:
Expand All @@ -173,8 +181,9 @@ func (g *GoBackNConn) Send(data []byte) error {
}
}

// If splitting is disabled, then we set the packet's FinalChunk to
// true.
if g.cfg.maxChunkSize == 0 {
// Splitting is disabled.
return sendPacket(&PacketData{
Payload: data,
FinalChunk: true,
Expand All @@ -187,7 +196,7 @@ func (g *GoBackNConn) Send(data []byte) error {
maxChunk = g.cfg.maxChunkSize
)
for sentBytes < len(data) {
packet := &PacketData{}
var packet PacketData

remainingBytes := len(data) - sentBytes
if remainingBytes <= maxChunk {
Expand All @@ -199,7 +208,7 @@ func (g *GoBackNConn) Send(data []byte) error {
sentBytes += maxChunk
}

if err := sendPacket(packet); err != nil {
if err := sendPacket(&packet); err != nil {
return err
}
}
Expand All @@ -221,9 +230,7 @@ func (g *GoBackNConn) Recv() ([]byte, error) {
msg *PacketData
)

g.timeoutsMu.RLock()
ticker := time.NewTimer(g.recvTimeout)
g.timeoutsMu.RUnlock()
ticker := time.NewTimer(g.getRecvTimeout())
defer ticker.Stop()

for {
Expand Down Expand Up @@ -365,18 +372,22 @@ func (g *GoBackNConn) sendPacket(ctx context.Context, msg Message) error {
return nil
}

func (g *GoBackNConn) receivePacket(ctx context.Context) (Message, error) {
b, err := g.cfg.recvFromStream(ctx)
if err != nil {
return nil, fmt.Errorf("error receiving from stream: %w", err)
}

return Deserialize(b)
}

// sendPacketsForever manages the resending logic. It keeps a cache of up to
// N packets and manages the resending of packets if acks are not received for
// them or if NACKs are received. It reads new data from sendDataChan only
// when there is space in the queue.
//
// This function must be called in a go routine.
func (g *GoBackNConn) sendPacketsForever() error {
// resendQueue re-sends the current contents of the queue.
resendQueue := func() error {
return g.sendQueue.resend()
}

for {
// The queue is not empty. If we receive a resend signal
// or if the resend timeout passes then we resend the
Expand All @@ -388,13 +399,13 @@ func (g *GoBackNConn) sendPacketsForever() error {
return nil

case <-g.resendSignal:
if err := resendQueue(); err != nil {
if err := g.sendQueue.resend(); err != nil {
return err
}
continue

case <-g.resendTicker.C:
if err := resendQueue(); err != nil {
if err := g.sendQueue.resend(); err != nil {
return err
}
continue
Expand Down Expand Up @@ -443,11 +454,11 @@ func (g *GoBackNConn) sendPacketsForever() error {
case <-g.receivedACKSignal:
break
case <-g.resendSignal:
if err := resendQueue(); err != nil {
if err := g.sendQueue.resend(); err != nil {
return err
}
case <-g.resendTicker.C:
if err := resendQueue(); err != nil {
if err := g.sendQueue.resend(); err != nil {
return err
}
}
Expand All @@ -464,6 +475,10 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
var (
lastNackSeq uint8
lastNackTime time.Time

// recvSeq keeps track of the latest, correctly sequenced packet
// sequence that we have received.
recvSeq uint8
)

for {
Expand All @@ -473,13 +488,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
default:
}

b, err := g.cfg.recvFromStream(g.ctx)
if err != nil {
return fmt.Errorf("error receiving "+
"from recvFromStream: %s", err)
}

msg, err := Deserialize(b)
msg, err := g.receivePacket(g.ctx)
if err != nil {
return fmt.Errorf("deserialize error: %s", err)
}
Expand All @@ -493,7 +502,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo

switch m := msg.(type) {
case *PacketData:
switch m.Seq == g.recvSeq {
switch m.Seq == recvSeq {
case true:
// We received a data packet with the sequence
// number we were expecting. So we respond with
Expand All @@ -510,7 +519,7 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo
return err
}

g.recvSeq = (g.recvSeq + 1) % g.cfg.s
recvSeq = (recvSeq + 1) % g.cfg.s

// If the packet was a ping, then there is no
// data to return to the above layer.
Expand Down Expand Up @@ -539,19 +548,19 @@ func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo

// If we recently sent a NACK for the same
// sequence number then back off.
if lastNackSeq == g.recvSeq &&
if lastNackSeq == recvSeq &&
time.Since(lastNackTime) <
g.cfg.resendTimeout {

continue
}

g.log.Tracef("Sending NACK %d", g.recvSeq)
g.log.Tracef("Sending NACK %d", recvSeq)

// Send a NACK with the expected sequence
// number.
nack := &PacketNACK{
Seq: g.recvSeq,
Seq: recvSeq,
}

if err = g.sendPacket(g.ctx, nack); err != nil {
Expand Down

0 comments on commit c2216e4

Please sign in to comment.