From db3cfbcfcde88e2c425faf84eb321f1b486b5095 Mon Sep 17 00:00:00 2001 From: Luiz Aoqui Date: Fri, 20 Sep 2024 23:28:45 -0400 Subject: [PATCH] async: ensure connection close error is set The previous implementation had a couple of issues that could cause closeWithError() to not propagate the passed error properly. 1. c.error.Store() is only called after c.close(), which could cause the channel returned by CloseChannel() to be closed before the error is set. Callers that that block on the channel may attempt to read the error before it is set. 2. If c.close() fails, closeError is returned, but not set to c.error.Store(), causing callers to miss it. Additionally, the original error is ignored and not made available to callers at all. To fix these issues, closeWithError() is updated to set c.error before calling c.close() and all errors are returned and set to c.error by using errors.Join(). In order to support different kinds of concrete error types, the type of the error field had to be changed from atomic.Value to error. Signed-off-by: Luiz Aoqui --- async.go | 19 +++++++++++-------- 1 file changed, 11 insertions(+), 8 deletions(-) diff --git a/async.go b/async.go index faf1b8e..2063b7a 100644 --- a/async.go +++ b/async.go @@ -37,7 +37,8 @@ type Async struct { stale []*packet.Packet logger types.Logger wg sync.WaitGroup - error atomic.Value + errorMu sync.RWMutex + error error streamsMu sync.Mutex streams map[uint16]*Stream newStreamHandlerMu sync.Mutex @@ -241,11 +242,9 @@ func (c *Async) Logger() types.Logger { // Error returns the error that caused the frisbee.Async connection to close func (c *Async) Error() error { - err := c.error.Load() - if err == nil { - return nil - } - return err.(error) + c.errorMu.RLock() + defer c.errorMu.RUnlock() + return c.error } // Closed returns whether the frisbee.Async connection is closed @@ -424,12 +423,16 @@ func (c *Async) close() error { } func (c *Async) closeWithError(err error) error { + c.errorMu.Lock() + defer c.errorMu.Unlock() + + c.error = err closeError := c.close() if closeError != nil { c.Logger().Debug().Err(closeError).Msgf("attempted to close connection with error `%s`, but got error while closing", err) - return closeError + c.error = errors.Join(closeError, err) + return c.error } - c.error.Store(err) _ = c.conn.Close() return err }