Skip to content

Commit

Permalink
async: ensure connection close error is set
Browse files Browse the repository at this point in the history
The previous implementation had a couple of issues that could cause
closeWithError() to not propagate the passed error properly.

1. c.error.Store() is only called after c.close(), which could cause the
   channel returned by CloseChannel() to be closed before the error is
   set. Callers that that block on the channel may attempt to read the
   error before it is set.
2. If c.close() fails, closeError is returned, but not set to
   c.error.Store(), causing callers to miss it. Additionally, the
   original error is ignored and not made available to callers at all.

To fix these issues, closeWithError() is updated to set c.error before
calling c.close() and all errors are returned and set to c.error by
using errors.Join().

In order to support different kinds of concrete error types, the type of
the error field had to be changed from atomic.Value to error.

Signed-off-by: Luiz Aoqui <[email protected]>
  • Loading branch information
lgfa29 committed Sep 21, 2024
1 parent 6e1bae7 commit db3cfbc
Showing 1 changed file with 11 additions and 8 deletions.
19 changes: 11 additions & 8 deletions async.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@ type Async struct {
stale []*packet.Packet
logger types.Logger
wg sync.WaitGroup
error atomic.Value
errorMu sync.RWMutex
error error
streamsMu sync.Mutex
streams map[uint16]*Stream
newStreamHandlerMu sync.Mutex
Expand Down Expand Up @@ -241,11 +242,9 @@ func (c *Async) Logger() types.Logger {

// Error returns the error that caused the frisbee.Async connection to close
func (c *Async) Error() error {
err := c.error.Load()
if err == nil {
return nil
}
return err.(error)
c.errorMu.RLock()
defer c.errorMu.RUnlock()
return c.error
}

// Closed returns whether the frisbee.Async connection is closed
Expand Down Expand Up @@ -424,12 +423,16 @@ func (c *Async) close() error {
}

func (c *Async) closeWithError(err error) error {
c.errorMu.Lock()
defer c.errorMu.Unlock()

c.error = err
closeError := c.close()
if closeError != nil {
c.Logger().Debug().Err(closeError).Msgf("attempted to close connection with error `%s`, but got error while closing", err)
return closeError
c.error = errors.Join(closeError, err)
return c.error
}
c.error.Store(err)
_ = c.conn.Close()
return err
}
Expand Down

0 comments on commit db3cfbc

Please sign in to comment.