From c4a08767151a40d9c2c7c9533b7ee89eda6e78ab Mon Sep 17 00:00:00 2001 From: Elle Mouton Date: Fri, 1 Dec 2023 16:52:22 +0200 Subject: [PATCH] gbn: modularise --- gbn/config.go | 10 - gbn/gbn_client.go | 7 +- gbn/gbn_conn.go | 664 ++++++++++------------------------------- gbn/gbn_conn_test.go | 7 +- gbn/gbn_server.go | 9 +- gbn/queue.go | 2 +- gbn/receiver.go | 216 ++++++++++++++ gbn/sender.go | 390 ++++++++++++++++++++++++ gbn/timeouts.go | 33 ++ mailbox/server_conn.go | 2 +- 10 files changed, 818 insertions(+), 522 deletions(-) create mode 100644 gbn/receiver.go create mode 100644 gbn/sender.go create mode 100644 gbn/timeouts.go diff --git a/gbn/config.go b/gbn/config.go index 4d14f7f2..3f21ef14 100644 --- a/gbn/config.go +++ b/gbn/config.go @@ -10,15 +10,6 @@ type config struct { // GoBN handshake. n uint8 - // s is the maximum sequence number used to label packets. Packets - // are labelled with incrementing sequence numbers modulo s. - // s must be strictly larger than the window size, n. This - // is so that the receiver can tell if the sender is resending the - // previous window (maybe the sender did not receive the acks) or if - // they are sending the next window. If s <= n then there would be - // no way to tell. - s uint8 - // maxChunkSize is the maximum payload size in bytes allowed per // message. If the payload to be sent is larger than maxChunkSize then // the payload will be split between multiple packets. @@ -53,7 +44,6 @@ func newConfig(sendFunc sendBytesFunc, recvFunc recvBytesFunc, return &config{ n: n, - s: n + 1, recvFromStream: recvFunc, sendToStream: sendFunc, resendTimeout: defaultResendTimeout, diff --git a/gbn/gbn_client.go b/gbn/gbn_client.go index c1032e26..5359a4bc 100644 --- a/gbn/gbn_client.go +++ b/gbn/gbn_client.go @@ -14,7 +14,7 @@ import ( // The resendTimeout parameter defines the duration to wait before resending data // if the corresponding ACK for the data is not received. func NewClientConn(ctx context.Context, n uint8, sendFunc sendBytesFunc, - receiveFunc recvBytesFunc, opts ...Option) (*GoBackNConn, error) { + receiveFunc recvBytesFunc, opts ...Option) (GBN, error) { if n == math.MaxUint8 { return nil, fmt.Errorf("n must be smaller than %d", @@ -28,7 +28,7 @@ func NewClientConn(ctx context.Context, n uint8, sendFunc sendBytesFunc, o(cfg) } - conn := newGoBackNConn(ctx, cfg, "client") + conn := newGBN(ctx, cfg, "client") if err := conn.clientHandshake(); err != nil { if err := conn.Close(); err != nil { @@ -36,6 +36,7 @@ func NewClientConn(ctx context.Context, n uint8, sendFunc sendBytesFunc, } return nil, err } + conn.start() return conn, nil @@ -50,7 +51,7 @@ func NewClientConn(ctx context.Context, n uint8, sendFunc sendBytesFunc, // SYNACK. // 3b. If the client does not receive SYN from the server within a given // timeout, then the client restarts the handshake from step 1. -func (g *GoBackNConn) clientHandshake() error { +func (g *gbn) clientHandshake() error { // Spin off the recv function in a goroutine so that we can use // a select to choose to timeout waiting for data from the receive // stream. This is needed instead of a context timeout because the diff --git a/gbn/gbn_conn.go b/gbn/gbn_conn.go index 7fc4962c..4d592882 100644 --- a/gbn/gbn_conn.go +++ b/gbn/gbn_conn.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - "io" - "math" "sync" "time" @@ -25,297 +23,238 @@ const ( defaultHandshakeTimeout = 100 * time.Millisecond defaultResendTimeout = 100 * time.Millisecond finSendTimeout = 1000 * time.Millisecond - DefaultSendTimeout = math.MaxInt64 - DefaultRecvTimeout = math.MaxInt64 ) type sendBytesFunc func(ctx context.Context, b []byte) error type recvBytesFunc func(ctx context.Context) ([]byte, error) -type GoBackNConn struct { - cfg *config - - sendQueue *queue - - resendTicker *time.Ticker - - recvDataChan chan *PacketData - sendDataChan chan *PacketData - - sendTimeout time.Duration - recvTimeout time.Duration - timeoutsMu sync.RWMutex - - // receivedACKSignal channel is used to signal that the queue size has - // been decreased. - receivedACKSignal chan struct{} +type GBN interface { + Send([]byte) error + Recv() ([]byte, error) + SetRecvTimeout(timeout time.Duration) + SetSendTimeout(timeout time.Duration) + Close() error +} - // resendSignal is used to signal that normal operation sending should - // stop and the current queue contents should first be resent. Note - // that this channel should only be listened on in one place. - resendSignal chan struct{} +type gbn struct { + cfg *config - pingTicker *IntervalAwareForceTicker - pongTicker *IntervalAwareForceTicker + // s is the maximum sequence number used to label packets. Packets + // are labelled with incrementing sequence numbers modulo s. + // s must be strictly larger than the window size, n. This + // is so that the receiver can tell if the sender is resending the + // previous window (maybe the sender did not receive the acks) or if + // they are sending the next window. If s <= n then there would be + // no way to tell. + s uint8 ctx context.Context //nolint:containedctx cancel func() + sender *sender + senderErr chan error + receiver *receiver + receiverErr chan error + // remoteClosed is closed if the remote party initiated the FIN sequence. remoteClosed chan struct{} log btclog.Logger - // quit is used to stop the normal operations of the connection. - // Once closed, the send and receive streams will still be available - // for the FIN sequence. quit chan struct{} closeOnce sync.Once wg sync.WaitGroup + errChan chan error } -// newGoBackNConn creates a GoBackNConn instance with all the members which -// are common between client and server initialised. -func newGoBackNConn(ctx context.Context, cfg *config, - loggerPrefix string) *GoBackNConn { - +func newGBN(ctx context.Context, cfg *config, loggerPrefix string) *gbn { ctxc, cancel := context.WithCancel(ctx) // Construct a new prefixed logger. prefix := fmt.Sprintf("(%s)", loggerPrefix) plog := build.NewPrefixLog(prefix, log) - g := &GoBackNConn{ - cfg: cfg, - recvDataChan: make(chan *PacketData, cfg.n), - sendDataChan: make(chan *PacketData), - recvTimeout: DefaultRecvTimeout, - sendTimeout: DefaultSendTimeout, - receivedACKSignal: make(chan struct{}), - resendSignal: make(chan struct{}, 1), - remoteClosed: make(chan struct{}), - ctx: ctxc, - cancel: cancel, - log: plog, - quit: make(chan struct{}), + senderErr := make(chan error, 1) + receiverErr := make(chan error, 1) + + g := &gbn{ + cfg: cfg, + ctx: ctxc, + cancel: cancel, + log: plog, + senderErr: senderErr, + receiverErr: receiverErr, + remoteClosed: make(chan struct{}), + errChan: make(chan error, 1), + quit: make(chan struct{}), } - g.sendQueue = newQueue(&queueCfg{ - s: cfg.n + 1, - timeout: cfg.resendTimeout, - log: plog, - sendPkt: func(packet *PacketData) error { - return g.sendPacket(g.ctx, packet) - }, - }) + g.sender = newSender( + cfg.n, g.sendPacket, senderErr, plog, cfg.resendTimeout, + ) + + g.receiver = newReceiver( + cfg.n+1, g.sendPacket, receiverErr, cfg.resendTimeout, plog, + ) return g } -// setN sets the current N to use. This _must_ be set before the handshake is -// completed. -func (g *GoBackNConn) setN(n uint8) { +func (g *gbn) setN(n uint8) { g.cfg.n = n - g.cfg.s = n + 1 - g.recvDataChan = make(chan *PacketData, n) - g.sendQueue = newQueue(&queueCfg{ - s: n + 1, - timeout: g.cfg.resendTimeout, - log: g.log, - sendPkt: func(packet *PacketData) error { - return g.sendPacket(g.ctx, packet) - }, - }) -} - -// SetSendTimeout sets the timeout used in the Send function. -func (g *GoBackNConn) SetSendTimeout(timeout time.Duration) { - g.timeoutsMu.Lock() - defer g.timeoutsMu.Unlock() - - g.sendTimeout = timeout + g.s = n + 1 + g.sender = newSender( + g.cfg.n, g.sendPacket, g.senderErr, g.log, g.cfg.resendTimeout, + ) + g.receiver = newReceiver( + g.cfg.n+1, g.sendPacket, g.receiverErr, g.cfg.resendTimeout, + g.log, + ) } -func (g *GoBackNConn) getSendTimeout() time.Duration { - g.timeoutsMu.RLock() - defer g.timeoutsMu.RUnlock() - - return g.sendTimeout +func (g *gbn) Send(data []byte) error { + return g.sender.Send(data) } -// SetRecvTimeout sets the timeout used in the Recv function. -func (g *GoBackNConn) SetRecvTimeout(timeout time.Duration) { - g.timeoutsMu.Lock() - defer g.timeoutsMu.Unlock() - - g.recvTimeout = timeout +func (g *gbn) Recv() ([]byte, error) { + return g.receiver.Receive() } -func (g *GoBackNConn) getRecvTimeout() time.Duration { - g.timeoutsMu.RLock() - defer g.timeoutsMu.RUnlock() +// start kicks off the various goroutines needed by GoBackNConn. +// start should only be called once the handshake has been completed. +func (g *gbn) start() { + g.log.Debugf("Starting") - return g.recvTimeout -} + g.wg.Add(1) + go g.packetDistributor() -// Send blocks until an ack is received for the packet sent N packets before. -func (g *GoBackNConn) Send(data []byte) error { - select { - case <-g.quit: - return io.EOF - default: - } + g.sender.start() + g.receiver.start() - ticker := time.NewTimer(g.getSendTimeout()) - defer ticker.Stop() + go func() { + var ( + err error + errProducer string + ) - // sendPacket sends the given packet onto the sendDataChan. - sendPacket := func(packet *PacketData) error { select { - case g.sendDataChan <- packet: - return nil - case <-ticker.C: - return errSendTimeout + case <-g.senderErr: + errProducer = "sender" + case <-g.receiverErr: + errProducer = "receiver" + case <-g.errChan: + errProducer = "gbn" case <-g.quit: - return fmt.Errorf("cannot send, gbn exited") + return } - } - - // If splitting is disabled, then we set the packet's FinalChunk to - // true. - if g.cfg.maxChunkSize == 0 { - return sendPacket(&PacketData{ - Payload: data, - FinalChunk: true, - }) - } - // Splitting is enabled. Split into packets no larger than maxChunkSize. - var ( - sentBytes = 0 - maxChunk = g.cfg.maxChunkSize - ) - for sentBytes < len(data) { - var packet PacketData - - remainingBytes := len(data) - sentBytes - if remainingBytes <= maxChunk { - packet.Payload = data[sentBytes:] - sentBytes += remainingBytes - packet.FinalChunk = true - } else { - packet.Payload = data[sentBytes : sentBytes+maxChunk] - sentBytes += maxChunk - } + g.log.Errorf("Error from %s: %v", errProducer, err) - if err := sendPacket(&packet); err != nil { - return err + if err := g.Close(); err != nil { + g.log.Errorf("Error closing gbn: %v", err) } - } - - return nil + }() } -// Recv blocks until it gets a recv with the correct sequence it was expecting. -func (g *GoBackNConn) Recv() ([]byte, error) { - // Wait for handshake to complete - select { - case <-g.quit: - return nil, io.EOF - default: +func (g *gbn) receivePacket() (Message, error) { + b, err := g.cfg.recvFromStream(g.ctx) + if err != nil { + return nil, fmt.Errorf("error receiving from stream: %w", err) } - var ( - b []byte - msg *PacketData - ) + m, err := Deserialize(b) + if err != nil { + return nil, err + } - ticker := time.NewTimer(g.getRecvTimeout()) - defer ticker.Stop() + g.sender.AnyReceive() - for { - select { - case <-g.quit: - return nil, fmt.Errorf("cannot receive, gbn exited") - case <-ticker.C: - return nil, errRecvTimeout - case msg = <-g.recvDataChan: - } + return m, nil +} - b = append(b, msg.Payload...) +func (g *gbn) sendPacket(msg Message) error { + b, err := msg.Serialize() + if err != nil { + return fmt.Errorf("serialize error: %s", err) + } - if msg.FinalChunk { - break - } + err = g.cfg.sendToStream(g.ctx, b) + if err != nil { + return fmt.Errorf("error calling sendToStream: %s", err) } - return b, nil + return nil } -// start kicks off the various goroutines needed by GoBackNConn. -// start should only be called once the handshake has been completed. -func (g *GoBackNConn) start() { - g.log.Debugf("Starting") - - pingTime := time.Duration(math.MaxInt64) - if g.cfg.pingTime != 0 { - pingTime = g.cfg.pingTime +func (g *gbn) sendPacketWithCtx(ctx context.Context, msg Message) error { + b, err := msg.Serialize() + if err != nil { + return fmt.Errorf("serialize error: %s", err) } - g.pingTicker = NewIntervalAwareForceTicker(pingTime) - g.pingTicker.Resume() - - pongTime := time.Duration(math.MaxInt64) - if g.cfg.pongTime != 0 { - pongTime = g.cfg.pongTime + err = g.cfg.sendToStream(ctx, b) + if err != nil { + return fmt.Errorf("error calling sendToStream: %s", err) } - g.pongTicker = NewIntervalAwareForceTicker(pongTime) + return nil +} - g.resendTicker = time.NewTicker(g.cfg.resendTimeout) +func (g *gbn) packetDistributor() { + defer g.wg.Done() - g.wg.Add(1) - go func() { - defer func() { - g.wg.Done() - if err := g.Close(); err != nil { - g.log.Errorf("Error closing GoBackNConn: %v", - err) - } - }() + for { + select { + case <-g.quit: + return + default: + } - err := g.receivePacketsForever() + msg, err := g.receivePacket() if err != nil { - g.log.Debugf("Error in receivePacketsForever: %v", err) + g.errChan <- fmt.Errorf("deserialize error: %s", err) + + return } - g.log.Debugf("receivePacketsForever stopped") - }() + switch m := msg.(type) { + case *PacketData: + // Send DATA packets to the receiver. + g.receiver.GotData(m) - g.wg.Add(1) - go func() { - defer func() { - g.wg.Done() - if err := g.Close(); err != nil { - g.log.Errorf("Error closing GoBackNConn: %v", - err) - } - }() + case *PacketACK: + // ACKs go to the sender. + g.sender.ACK(m.Seq) - err := g.sendPacketsForever() - if err != nil { - g.log.Debugf("Error in sendPacketsForever: %v", err) + case *PacketNACK: + // NACKs go to the sender. + g.sender.NACK(m.Seq) - } + case *PacketFIN: + // A FIN packet indicates that the peer would like to + // close the connection. + g.log.Tracef("Received a FIN packet") - g.log.Debugf("sendPacketsForever stopped") - }() + close(g.remoteClosed) + g.errChan <- errTransportClosing + + default: + g.errChan <- fmt.Errorf("received unhandled message: "+ + "%T", msg) + + return + } + } } -// Close attempts to cleanly close the connection by sending a FIN message. -func (g *GoBackNConn) Close() error { +func (g *gbn) Close() error { g.closeOnce.Do(func() { - g.log.Debugf("Closing GoBackNConn") + g.log.Debugf("Closing GBN") + + // Canceling the context will ensure that we are not hanging on + // the receive or send functions passed to the server on + // initialisation. + g.cancel() // We close the quit channel to stop the usual operations of the // server. @@ -332,307 +271,30 @@ func (g *GoBackNConn) Close() error { g.ctx, finSendTimeout, ) defer cancel() - if err := g.sendPacket(ctxc, &PacketFIN{}); err != nil { + + err := g.sendPacketWithCtx(ctxc, &PacketFIN{}) + if err != nil { g.log.Errorf("Error sending FIN: %v", err) } } - // Canceling the context will ensure that we are not hanging on - // the receive or send functions passed to the server on - // initialisation. - g.cancel() + g.receiver.stop() + g.sender.stop() g.wg.Wait() - if g.pingTicker != nil { - g.pingTicker.Stop() - } - if g.resendTicker != nil { - g.resendTicker.Stop() - } - g.log.Debugf("GBN is closed") }) return nil } -// sendPacket serializes a message and writes it to the underlying send stream. -func (g *GoBackNConn) sendPacket(ctx context.Context, msg Message) error { - b, err := msg.Serialize() - if err != nil { - return fmt.Errorf("serialize error: %s", err) - } - - err = g.cfg.sendToStream(ctx, b) - if err != nil { - return fmt.Errorf("error calling sendToStream: %s", err) - } - - return nil -} - -func (g *GoBackNConn) receivePacket(ctx context.Context) (Message, error) { - b, err := g.cfg.recvFromStream(ctx) - if err != nil { - return nil, fmt.Errorf("error receiving from stream: %w", err) - } - - return Deserialize(b) +func (g *gbn) SetRecvTimeout(t time.Duration) { + g.receiver.SetTimeout(t) } -// sendPacketsForever manages the resending logic. It keeps a cache of up to -// N packets and manages the resending of packets if acks are not received for -// them or if NACKs are received. It reads new data from sendDataChan only -// when there is space in the queue. -// -// This function must be called in a go routine. -func (g *GoBackNConn) sendPacketsForever() error { - for { - // The queue is not empty. If we receive a resend signal - // or if the resend timeout passes then we resend the - // current contents of the queue. Otherwise, wait for - // more data to arrive on sendDataChan. - var packet *PacketData - select { - case <-g.quit: - return nil - - case <-g.resendSignal: - if err := g.sendQueue.resend(); err != nil { - return err - } - continue - - case <-g.resendTicker.C: - if err := g.sendQueue.resend(); err != nil { - return err - } - continue - - case <-g.pingTicker.Ticks(): - - // Start the pong timer. - g.pongTicker.Reset() - g.pongTicker.Resume() - - g.log.Tracef("Sending a PING packet") - - packet = &PacketData{ - IsPing: true, - } - - case <-g.pongTicker.Ticks(): - return errKeepaliveTimeout - - case packet = <-g.sendDataChan: - } - - // New data has arrived that we need to add to the queue and - // send. - g.sendQueue.addPacket(packet) - - g.log.Tracef("Sending data %d", packet.Seq) - if err := g.sendPacket(g.ctx, packet); err != nil { - return err - } - - for { - // If the queue size is still less than N, we can - // continue to add more packets to the queue. - if g.sendQueue.size() < g.cfg.n { - break - } - - g.log.Tracef("The queue is full.") - - // The queue is full. We wait for a ACKs to arrive or - // resend the queue after a timeout. - select { - case <-g.quit: - return nil - case <-g.receivedACKSignal: - break - case <-g.resendSignal: - if err := g.sendQueue.resend(); err != nil { - return err - } - case <-g.resendTicker.C: - if err := g.sendQueue.resend(); err != nil { - return err - } - } - } - } +func (g *gbn) SetSendTimeout(t time.Duration) { + g.sender.SetTimeout(t) } -// receivePacketsForever uses the provided recvFromStream to get new data -// from the underlying transport. It then checks to see if what was received is -// data, an ACK, NACK or FIN signal and then processes the packet accordingly. -// -// This function must be called in a go routine. -func (g *GoBackNConn) receivePacketsForever() error { // nolint:gocyclo - var ( - lastNackSeq uint8 - lastNackTime time.Time - - // recvSeq keeps track of the latest, correctly sequenced packet - // sequence that we have received. - recvSeq uint8 - ) - - for { - select { - case <-g.quit: - return nil - default: - } - - msg, err := g.receivePacket(g.ctx) - if err != nil { - return fmt.Errorf("deserialize error: %s", err) - } - - // Reset the ping & pong timer if any packet is received. - // If ping/pong is disabled, this is a no-op. - g.pingTicker.Reset() - if g.pongTicker.IsActive() { - g.pongTicker.Pause() - } - - switch m := msg.(type) { - case *PacketData: - switch m.Seq == recvSeq { - case true: - // We received a data packet with the sequence - // number we were expecting. So we respond with - // an ACK message with that sequence number - // and we bump the sequence number that we - // expect of the next data packet. - g.log.Tracef("Got expected data %d", m.Seq) - - ack := &PacketACK{ - Seq: m.Seq, - } - - if err = g.sendPacket(g.ctx, ack); err != nil { - return err - } - - recvSeq = (recvSeq + 1) % g.cfg.s - - // If the packet was a ping, then there is no - // data to return to the above layer. - if m.IsPing { - continue - } - - // Pass the returned packet to the layer above - // GBN. - select { - case g.recvDataChan <- m: - case <-g.quit: - return nil - } - - case false: - // We received a data packet with a sequence - // number that we were not expecting. This - // could be a packet that we have already - // received and that is being resent because - // the ACK for it was not received in time or - // it could be that we missed a previous packet. - // In either case, we send a NACK with the - // sequence number that we were expecting. - g.log.Tracef("Got unexpected data %d", m.Seq) - - // If we recently sent a NACK for the same - // sequence number then back off. - if lastNackSeq == recvSeq && - time.Since(lastNackTime) < - g.cfg.resendTimeout { - - continue - } - - g.log.Tracef("Sending NACK %d", recvSeq) - - // Send a NACK with the expected sequence - // number. - nack := &PacketNACK{ - Seq: recvSeq, - } - - if err = g.sendPacket(g.ctx, nack); err != nil { - return err - } - - lastNackTime = time.Now() - lastNackSeq = nack.Seq - } - - case *PacketACK: - gotValidACK := g.sendQueue.processACK(m.Seq) - if gotValidACK { - g.resendTicker.Reset(g.cfg.resendTimeout) - - // Send a signal to indicate that new - // ACKs have been received. - select { - case g.receivedACKSignal <- struct{}{}: - default: - } - } - - case *PacketNACK: - // We received a NACK packet. This means that the - // receiver got a data packet that they were not - // expecting. This likely means that a packet that we - // sent was dropped, or maybe we sent a duplicate - // message. The NACK message contains the sequence - // number that the receiver was expecting. - inQueue, bumped := g.sendQueue.processNACK(m.Seq) - - // If the NACK sequence number is not in our queue - // then we ignore it. We must have received the ACK - // for the sequence number in the meantime. - if !inQueue { - g.log.Tracef("NACK seq %d is not in the "+ - "queue. Ignoring", m.Seq) - - continue - } - - // If the base was bumped, then the queue is now smaller - // and so we can send a signal to indicate this. - if bumped { - select { - case g.receivedACKSignal <- struct{}{}: - default: - } - } - - g.log.Tracef("Sending a resend signal") - - // Send a signal to indicate that new sends should pause - // and the current queue should be resent instead. - select { - case g.resendSignal <- struct{}{}: - default: - } - - case *PacketFIN: - // A FIN packet indicates that the peer would like to - // close the connection. - g.log.Tracef("Received a FIN packet") - - close(g.remoteClosed) - - return errTransportClosing - - default: - return fmt.Errorf("received unexpected message: %T", - msg) - } - } -} +var _ GBN = (*gbn)(nil) diff --git a/gbn/gbn_conn_test.go b/gbn/gbn_conn_test.go index 94cd8429..1092a33f 100644 --- a/gbn/gbn_conn_test.go +++ b/gbn/gbn_conn_test.go @@ -152,7 +152,7 @@ func TestServerHandshakeTimeout(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) var ( - server *GoBackNConn + server GBN wg sync.WaitGroup ) defer func() { @@ -189,6 +189,7 @@ func TestServerHandshakeTimeout(t *testing.T) { cancel() wg.Wait() + } func TestDroppedMessage(t *testing.T) { @@ -1211,12 +1212,12 @@ func TestPayloadSplitting(t *testing.T) { func setUpClientServerConns(t *testing.T, n uint8, cRead, sRead func(ctx context.Context) ([]byte, error), cWrite, sWrite func(ctx context.Context, b []byte) error, - opts ...Option) (*GoBackNConn, *GoBackNConn, func()) { + opts ...Option) (GBN, GBN, func()) { t.Helper() var ( - server *GoBackNConn + server GBN err error wg sync.WaitGroup ) diff --git a/gbn/gbn_server.go b/gbn/gbn_server.go index 68e3b5b7..dbbef5c7 100644 --- a/gbn/gbn_server.go +++ b/gbn/gbn_server.go @@ -2,6 +2,7 @@ package gbn import ( "context" + "fmt" "io" "time" ) @@ -12,7 +13,7 @@ import ( // The resendTimeout parameter defines the duration to wait before resending data // if the corresponding ACK for the data is not received. func NewServerConn(ctx context.Context, sendFunc sendBytesFunc, - recvFunc recvBytesFunc, opts ...Option) (*GoBackNConn, error) { + recvFunc recvBytesFunc, opts ...Option) (GBN, error) { cfg := newConfig(sendFunc, recvFunc, DefaultN) @@ -21,7 +22,7 @@ func NewServerConn(ctx context.Context, sendFunc sendBytesFunc, o(cfg) } - conn := newGoBackNConn(ctx, cfg, "server") + conn := newGBN(ctx, cfg, "server") if err := conn.serverHandshake(); err != nil { if err := conn.Close(); err != nil { @@ -30,6 +31,7 @@ func NewServerConn(ctx context.Context, sendFunc sendBytesFunc, return nil, err } + conn.start() return conn, nil @@ -44,7 +46,7 @@ func NewServerConn(ctx context.Context, sendFunc sendBytesFunc, // handshake is considered complete. // 4b. If SYNACK is not received before a certain resendTimeout, then the // handshake is aborted and the process is started from step 1 again. -func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo +func (g *gbn) serverHandshake() error { // nolint:gocyclo recvChan := make(chan []byte) recvNext := make(chan int, 1) errChan := make(chan error, 1) @@ -111,6 +113,7 @@ func (g *GoBackNConn) serverHandshake() error { // nolint:gocyclo switch msg.(type) { case *PacketSYN: default: + fmt.Printf("got something else %T\n", msg) g.log.Tracef("Expected SYN, got %T", msg) continue } diff --git a/gbn/queue.go b/gbn/queue.go index e7a3077e..6da7af5c 100644 --- a/gbn/queue.go +++ b/gbn/queue.go @@ -21,7 +21,7 @@ type queueCfg struct { log btclog.Logger - sendPkt func(packet *PacketData) error + sendPkt func(Message) error } // queue is a fixed size queue with a sliding window that has a base and a top diff --git a/gbn/receiver.go b/gbn/receiver.go new file mode 100644 index 00000000..1885c1d8 --- /dev/null +++ b/gbn/receiver.go @@ -0,0 +1,216 @@ +package gbn + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/btcsuite/btclog" +) + +type Receiver interface { + Receive() ([]byte, error) + GotData(data *PacketData) + SetTimeout(duration time.Duration) +} + +// receiver is only interested in receiving DATA packets, and sending ACKs and +// NACKs for them. +type receiver struct { + // s is the maximum sequence number used to label packets. Packets + // are labelled with incrementing sequence numbers modulo s. + s uint8 + + sendPkt func(Message) error + + timeout *safeTimeout + + expectedSeq uint8 + lastNackSeq uint8 + lastNackTime time.Time + + resendTimeout time.Duration + + newDataChan chan *PacketData + processedDataChan chan *PacketData + + log btclog.Logger + + errChan chan error + quit chan struct{} + wg sync.WaitGroup +} + +var _ Receiver = (*receiver)(nil) + +func newReceiver(s uint8, sendFn func(Message) error, + errChan chan error, resendTimeout time.Duration, + logger btclog.Logger) *receiver { + + return &receiver{ + s: s, + sendPkt: sendFn, + timeout: newSafeTimeout(), + resendTimeout: resendTimeout, + newDataChan: make(chan *PacketData, s), + processedDataChan: make(chan *PacketData, s), + log: logger, + errChan: errChan, + quit: make(chan struct{}), + } +} + +func (r *receiver) start() { + r.wg.Add(1) + go r.receiveForever() +} + +func (r *receiver) receiveForever() { + defer r.wg.Done() + + var err error + for { + select { + case pkt := <-r.newDataChan: + if pkt.Seq != r.expectedSeq { + err = r.handleUnexpectedPkt(pkt) + if err != nil { + r.errChan <- err + + return + } + + continue + } + + err = r.handleExpectedPkt(pkt) + if err != nil { + r.errChan <- err + + return + } + + case <-r.quit: + return + } + } +} + +func (r *receiver) handleExpectedPkt(pkt *PacketData) error { + // We received a data packet with the sequence number we were expecting. + // So we respond with an ACK message with that sequence number and we + // bump the sequence number that we expect of the next data packet. + r.log.Tracef("Got expected data %d", pkt.Seq) + + ack := &PacketACK{ + Seq: pkt.Seq, + } + + if err := r.sendPkt(ack); err != nil { + return err + } + + r.expectedSeq = (r.expectedSeq + 1) % r.s + + // If the packet was a ping, then there is no data to return to the + // above layer. + if pkt.IsPing { + return nil + } + + // Pass the returned packet to the layer above + // GBN. + select { + case r.processedDataChan <- pkt: + case <-r.quit: + return nil + } + + return nil +} + +func (r *receiver) handleUnexpectedPkt(pkt *PacketData) error { + // We received a data packet with a sequence number that we were not + // expecting. This could be a packet that we have already received and + // that is being resent because the ACK for it was not received in time + // or it could be that we missed a previous packet. In either case, we + // send a NACK with the sequence number that we were expecting. + r.log.Tracef("Got unexpected data %d", pkt.Seq) + + // If we recently sent a NACK for the same sequence number then back + // off. + if r.lastNackSeq == r.expectedSeq && time.Since(r.lastNackTime) < + r.resendTimeout { + + return nil + } + + r.log.Tracef("Sending NACK %d", r.expectedSeq) + + // Send a NACK with the expected sequence + // number. + nack := &PacketNACK{ + Seq: r.expectedSeq, + } + + if err := r.sendPkt(nack); err != nil { + return err + } + + r.lastNackTime = time.Now() + r.lastNackSeq = nack.Seq + + return nil +} + +func (r *receiver) Receive() ([]byte, error) { + select { + case <-r.quit: + return nil, io.EOF + default: + } + + var ( + b []byte + msg *PacketData + ) + + ticker := time.NewTimer(r.timeout.get()) + defer ticker.Stop() + + for { + select { + case <-r.quit: + return nil, fmt.Errorf("cannot receive, receiver " + + "exited") + case <-ticker.C: + return nil, errRecvTimeout + case msg = <-r.processedDataChan: + } + + b = append(b, msg.Payload...) + + if msg.FinalChunk { + break + } + } + + return b, nil +} + +func (r *receiver) GotData(data *PacketData) { + select { + case <-r.quit: + case r.newDataChan <- data: + } +} + +func (r *receiver) SetTimeout(duration time.Duration) { + r.timeout.set(duration) +} + +func (r *receiver) stop() { + close(r.quit) + r.wg.Wait() +} diff --git a/gbn/sender.go b/gbn/sender.go new file mode 100644 index 00000000..c5587c7f --- /dev/null +++ b/gbn/sender.go @@ -0,0 +1,390 @@ +package gbn + +import ( + "fmt" + "io" + "math" + "sync" + "time" + + "github.com/btcsuite/btclog" +) + +type Sender interface { + Send([]byte) error + ACK(uint8) + NACK(uint8) + SetTimeout(time.Duration) + AnyReceive() +} + +// sender is only concerned with sending DATA packets, and handling received +// ACK and NACK packets. +type sender struct { + // n is the window size. The sender can send a maximum of n packets + // before requiring an ack from the receiver for the first packet in + // the window. + n uint8 + + timeout *safeTimeout + + sendQueue *queue + + sendFn func(Message) error + + log btclog.Logger + + resendTicker *time.Ticker + resendTimeout time.Duration + + pingTicker *IntervalAwareForceTicker + pingTickerMu sync.Mutex + pongTicker *IntervalAwareForceTicker + + pingTime time.Duration + pongTime time.Duration + + // maxChunkSize is the maximum payload size in bytes allowed per + // message. If the payload to be sent is larger than maxChunkSize then + // the payload will be split between multiple packets. + // If maxChunkSize is zero then it is disabled and data won't be split + // between packets. + maxChunkSize int + + newData chan *PacketData + ackChan chan uint8 + nackChan chan uint8 + + // receivedACKSignal channel is used to signal that the queue size has + // potentially been decreased. + receivedACKSignal chan struct{} + + // resendSignal is used to signal that normal operation sending should + // stop and the current queue contents should first be resent. Note + // that this channel should only be listened on in one place. + resendSignal chan struct{} + + quit chan struct{} + wg sync.WaitGroup + errChan chan error +} + +func (s *sender) AnyReceive() { + select { + case <-s.quit: + return + default: + } + + // Reset the ping & pong timer if any packet is received. + // If ping/pong is disabled, this is a no-op. + s.pingTickerMu.Lock() + s.pingTicker.Reset() + if s.pongTicker.IsActive() { + s.pongTicker.Pause() + } + s.pingTickerMu.Unlock() +} + +func newSender(n uint8, sendFn func(Message) error, errChan chan error, + logger btclog.Logger, resendTimeout time.Duration) *sender { + + return &sender{ + n: n, + timeout: newSafeTimeout(), + newData: make(chan *PacketData), + ackChan: make(chan uint8, n), + nackChan: make(chan uint8, n), + resendTimeout: resendTimeout, + log: logger, + sendFn: sendFn, + sendQueue: newQueue(&queueCfg{ + s: n + 1, + timeout: resendTimeout, + log: logger, + sendPkt: sendFn, + }), + errChan: errChan, + receivedACKSignal: make(chan struct{}), + resendSignal: make(chan struct{}, 1), + quit: make(chan struct{}), + } +} + +func (s *sender) stop() { + close(s.quit) + s.wg.Wait() + + if s.resendTicker != nil { + s.resendTicker.Stop() + } + s.pingTickerMu.Lock() + if s.pingTicker != nil { + s.pingTicker.Stop() + } + s.pingTickerMu.Unlock() +} + +func (s *sender) start() { + pingTime := time.Duration(math.MaxInt64) + if s.pingTime != 0 { + pingTime = s.pingTime + } + + s.pingTicker = NewIntervalAwareForceTicker(pingTime) + s.pingTicker.Resume() + + pongTime := time.Duration(math.MaxInt64) + if s.pongTime != 0 { + pongTime = s.pongTime + } + + s.pongTicker = NewIntervalAwareForceTicker(pongTime) + s.resendTicker = time.NewTicker(s.resendTimeout) + + s.wg.Add(1) + go s.handleAcksAndNacks() + + s.wg.Add(1) + go func() { + defer s.wg.Done() + + err := s.handleQueueContents() + if err != nil { + s.errChan <- err + } + }() +} + +func (s *sender) handleQueueContents() error { + for { + // The queue is not full. If we receive a resend signal or if + // the resend timeout passes then we resend the current contents + // of the queue. Otherwise, wait for more data to arrive on + // sendDataChan. + var packet *PacketData + select { + case <-s.quit: + return nil + + case <-s.resendSignal: + if err := s.sendQueue.resend(); err != nil { + return err + } + + continue + + case <-s.resendTicker.C: + if err := s.sendQueue.resend(); err != nil { + return err + } + continue + + case <-s.pingTicker.Ticks(): + + // Start the pong timer. + s.pongTicker.Reset() + s.pongTicker.Resume() + + s.log.Tracef("Sending a PING packet") + + packet = &PacketData{ + IsPing: true, + } + + case <-s.pongTicker.Ticks(): + return errKeepaliveTimeout + + case packet = <-s.newData: + } + + // New data has arrived that we need to add to the queue and + // send. + s.sendQueue.addPacket(packet) + + s.log.Tracef("Sending data %d", packet.Seq) + + if err := s.sendFn(packet); err != nil { + return err + } + + for { + // If the queue size is still less than N, we can + // continue to add more packets to the queue. + if s.sendQueue.size() < s.n { + break + } + + s.log.Tracef("The queue is full.") + + // The queue is full. We wait for a ACKs to arrive or + // resend the queue after a timeout. + select { + case <-s.quit: + return nil + + case <-s.receivedACKSignal: + break + + case <-s.resendSignal: + if err := s.sendQueue.resend(); err != nil { + return err + } + + case <-s.resendTicker.C: + if err := s.sendQueue.resend(); err != nil { + return err + } + } + } + } +} + +func (s *sender) handleAcksAndNacks() { + defer s.wg.Done() + + var ackSeq, nackSeq uint8 + for { + select { + case ackSeq = <-s.ackChan: + s.handleACK(ackSeq) + + case nackSeq = <-s.nackChan: + s.handleNACK(nackSeq) + + case <-s.quit: + return + } + } +} + +func (s *sender) handleACK(seq uint8) { + gotValidACK := s.sendQueue.processACK(seq) + if gotValidACK { + s.resendTicker.Reset(s.resendTimeout) + + // Send a signal to indicate that new + // ACKs have been received. + select { + case s.receivedACKSignal <- struct{}{}: + default: + } + } +} + +func (s *sender) handleNACK(seq uint8) { + // We received a NACK packet. This means that the receiver got a data + // packet that they were not expecting. This likely means that a packet + // that we sent was dropped, or maybe we sent a duplicate message. The + // NACK message contains the sequence number that the receiver was + // expecting. + inQueue, bumped := s.sendQueue.processNACK(seq) + + // If the NACK sequence number is not in our queue + // then we ignore it. We must have received the ACK + // for the sequence number in the meantime. + if !inQueue { + s.log.Tracef("NACK seq %d is not in the queue. Ignoring", seq) + + return + } + + // If the base was bumped, then the queue is now smaller + // and so we can send a signal to indicate this. + if bumped { + select { + case s.receivedACKSignal <- struct{}{}: + default: + } + } + + s.log.Tracef("Sending a resend signal") + + // Send a signal to indicate that new sends should pause + // and the current queue should be resent instead. + select { + case s.resendSignal <- struct{}{}: + default: + } +} + +func (s *sender) Send(data []byte) error { + select { + case <-s.quit: + return io.EOF + default: + } + + ticker := time.NewTimer(s.timeout.get()) + defer ticker.Stop() + + // sendData sends the given data message onto the dataToSend channel. + sendData := func(packet *PacketData) error { + select { + case s.newData <- packet: + return nil + case <-ticker.C: + return errSendTimeout + case <-s.quit: + return fmt.Errorf("cannot send, sender has exited") + } + } + + // If splitting is disabled, then we set the packet's FinalChunk to + // true. + if s.maxChunkSize == 0 { + return sendData(&PacketData{ + Payload: data, + FinalChunk: true, + }) + } + + // Splitting is enabled. Split into packets no larger than maxChunkSize. + var ( + sentBytes = 0 + maxChunk = s.maxChunkSize + ) + for sentBytes < len(data) { + var msg PacketData + + remainingBytes := len(data) - sentBytes + if remainingBytes <= maxChunk { + msg.Payload = data[sentBytes:] + msg.FinalChunk = true + + sentBytes += remainingBytes + } else { + msg.Payload = data[sentBytes : sentBytes+maxChunk] + + sentBytes += maxChunk + } + + if err := sendData(&msg); err != nil { + return err + } + } + + return nil +} + +func (s *sender) ACK(u uint8) { + select { + case <-s.quit: + return + case s.ackChan <- u: + } +} + +func (s *sender) NACK(u uint8) { + select { + case <-s.quit: + return + case s.nackChan <- u: + } +} + +func (s *sender) SetTimeout(duration time.Duration) { + s.timeout.set(duration) +} + +var _ Sender = (*sender)(nil) diff --git a/gbn/timeouts.go b/gbn/timeouts.go new file mode 100644 index 00000000..287053b2 --- /dev/null +++ b/gbn/timeouts.go @@ -0,0 +1,33 @@ +package gbn + +import ( + "math" + "sync" + "time" +) + +const DefaultTimout = math.MaxInt64 + +type safeTimeout struct { + t time.Duration + mu sync.RWMutex +} + +func newSafeTimeout() *safeTimeout { + return &safeTimeout{ + t: DefaultTimout, + } +} + +func (t *safeTimeout) set(timeout time.Duration) { + t.mu.Lock() + t.t = timeout + t.mu.Unlock() +} + +func (t *safeTimeout) get() time.Duration { + t.mu.RLock() + defer t.mu.RUnlock() + + return t.t +} diff --git a/mailbox/server_conn.go b/mailbox/server_conn.go index ac75055c..4f564025 100644 --- a/mailbox/server_conn.go +++ b/mailbox/server_conn.go @@ -41,7 +41,7 @@ type ServerConn struct { client hashmailrpc.HashMailClient - gbnConn *gbn.GoBackNConn + gbnConn gbn.GBN gbnOptions []gbn.Option receiveBoxCreated bool