From 0c547a717cfb79aaf669d825125fe99db421fff3 Mon Sep 17 00:00:00 2001 From: Cory Benfield Date: Thu, 23 Jan 2025 13:22:18 +0000 Subject: [PATCH] Fix the HappyEyeballsResolver and core Bootstraps under strict concurrency (#3062) ### Motivation: The HappyEyeballsResolver, being an old part of our stack, has a lot of code in it that fails to pass strict concurrency checking. That's deeply suboptimal. ### Modifications: - Clean up the happy eyeballs resolver under strict concurrency - Further cleanups to the bootstraps ### Result: Another step taken on the road to strict concurrency. --- .../5.10/NIOPosixBenchmarks.TCPEcho.p90.json | 2 +- .../5.9/NIOPosixBenchmarks.TCPEcho.p90.json | 2 +- .../6.0/NIOPosixBenchmarks.TCPEcho.p90.json | 2 +- .../NIOPosixBenchmarks.TCPEcho.p90.json | 2 +- .../NIOPosixBenchmarks.TCPEcho.p90.json | 2 +- Sources/NIOPosix/Bootstrap.swift | 166 +++++++++- Sources/NIOPosix/GetaddrinfoResolver.swift | 2 +- Sources/NIOPosix/HappyEyeballs.swift | 296 ++++++++++-------- Tests/NIOPosixTests/HappyEyeballsTest.swift | 123 ++++++-- 9 files changed, 413 insertions(+), 184 deletions(-) diff --git a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json index c9e4dbd4c6..c6a93680d0 100644 --- a/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/5.10/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 107 + "mallocCountTotal" : 108 } diff --git a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json index 82d63a322b..248bd96061 100644 --- a/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/5.9/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 109 + "mallocCountTotal" : 110 } diff --git a/Benchmarks/Thresholds/6.0/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/6.0/NIOPosixBenchmarks.TCPEcho.p90.json index c9e4dbd4c6..c6a93680d0 100644 --- a/Benchmarks/Thresholds/6.0/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/6.0/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 107 + "mallocCountTotal" : 108 } diff --git a/Benchmarks/Thresholds/nightly-6.1/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/nightly-6.1/NIOPosixBenchmarks.TCPEcho.p90.json index c9e4dbd4c6..c6a93680d0 100644 --- a/Benchmarks/Thresholds/nightly-6.1/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/nightly-6.1/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 107 + "mallocCountTotal" : 108 } diff --git a/Benchmarks/Thresholds/nightly-main/NIOPosixBenchmarks.TCPEcho.p90.json b/Benchmarks/Thresholds/nightly-main/NIOPosixBenchmarks.TCPEcho.p90.json index c9e4dbd4c6..c6a93680d0 100644 --- a/Benchmarks/Thresholds/nightly-main/NIOPosixBenchmarks.TCPEcho.p90.json +++ b/Benchmarks/Thresholds/nightly-main/NIOPosixBenchmarks.TCPEcho.p90.json @@ -1,3 +1,3 @@ { - "mallocCountTotal" : 107 + "mallocCountTotal" : 108 } diff --git a/Sources/NIOPosix/Bootstrap.swift b/Sources/NIOPosix/Bootstrap.swift index 1ea1d311fb..b7705dc01f 100644 --- a/Sources/NIOPosix/Bootstrap.swift +++ b/Sources/NIOPosix/Bootstrap.swift @@ -823,7 +823,7 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { @usableFromInline internal var _channelOptions: ChannelOptions.Storage private var connectTimeout: TimeAmount = TimeAmount.seconds(10) - private var resolver: Optional + private var resolver: Optional private var bindTarget: Optional private var enableMPTCP: Bool @@ -924,7 +924,8 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { /// /// - Parameters: /// - resolver: The resolver that will be used during the connection attempt. - public func resolver(_ resolver: Resolver?) -> Self { + @preconcurrency + public func resolver(_ resolver: (Resolver & Sendable)?) -> Self { self.resolver = resolver return self } @@ -967,11 +968,23 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { func makeSocketChannel( eventLoop: EventLoop, protocolFamily: NIOBSDSocket.ProtocolFamily + ) throws -> SocketChannel { + try Self.makeSocketChannel( + eventLoop: eventLoop, + protocolFamily: protocolFamily, + enableMPTCP: self.enableMPTCP + ) + } + + static func makeSocketChannel( + eventLoop: EventLoop, + protocolFamily: NIOBSDSocket.ProtocolFamily, + enableMPTCP: Bool ) throws -> SocketChannel { try SocketChannel( eventLoop: eventLoop as! SelectableEventLoop, protocolFamily: protocolFamily, - enableMPTCP: self.enableMPTCP + enableMPTCP: enableMPTCP ) } @@ -992,6 +1005,11 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { aiSocktype: .stream, aiProtocol: .tcp ) + let enableMPTCP = self.enableMPTCP + let channelInitializer = self.channelInitializer + let channelOptions = self._channelOptions + let bindTarget = self.bindTarget + let connector = HappyEyeballsConnector( resolver: resolver, loop: loop, @@ -999,7 +1017,14 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { port: port, connectTimeout: self.connectTimeout ) { eventLoop, protocolFamily in - self.initializeAndRegisterNewChannel(eventLoop: eventLoop, protocolFamily: protocolFamily) { + Self.initializeAndRegisterNewChannel( + eventLoop: eventLoop, + protocolFamily: protocolFamily, + enableMPTCP: enableMPTCP, + channelInitializer: channelInitializer, + channelOptions: channelOptions, + bindTarget: bindTarget + ) { $0.eventLoop.makeSucceededFuture(()) } } @@ -1148,24 +1173,67 @@ public final class ClientBootstrap: NIOClientTCPBootstrapProtocol { eventLoop: EventLoop, protocolFamily: NIOBSDSocket.ProtocolFamily, _ body: @escaping @Sendable (Channel) -> EventLoopFuture + ) -> EventLoopFuture { + Self.initializeAndRegisterNewChannel( + eventLoop: eventLoop, + protocolFamily: protocolFamily, + enableMPTCP: self.enableMPTCP, + channelInitializer: self.channelInitializer, + channelOptions: self._channelOptions, + bindTarget: self.bindTarget, + body + ) + } + + private static func initializeAndRegisterNewChannel( + eventLoop: EventLoop, + protocolFamily: NIOBSDSocket.ProtocolFamily, + enableMPTCP: Bool, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, + channelOptions: ChannelOptions.Storage, + bindTarget: SocketAddress?, + _ body: @escaping @Sendable (Channel) -> EventLoopFuture ) -> EventLoopFuture { let channel: SocketChannel do { - channel = try self.makeSocketChannel(eventLoop: eventLoop, protocolFamily: protocolFamily) + channel = try Self.makeSocketChannel( + eventLoop: eventLoop, + protocolFamily: protocolFamily, + enableMPTCP: enableMPTCP + ) } catch { return eventLoop.makeFailedFuture(error) } - return self.initializeAndRegisterChannel(channel, body) + return Self.initializeAndRegisterChannel( + channel, + channelInitializer: channelInitializer, + channelOptions: channelOptions, + bindTarget: bindTarget, + body + ) } private func initializeAndRegisterChannel( _ channel: SocketChannel, _ body: @escaping @Sendable (Channel) -> EventLoopFuture ) -> EventLoopFuture { - let channelInitializer = self.channelInitializer - let channelOptions = self._channelOptions + Self.initializeAndRegisterChannel( + channel, + channelInitializer: self.channelInitializer, + channelOptions: self._channelOptions, + bindTarget: self.bindTarget, + body + ) + } + + private static func initializeAndRegisterChannel( + _ channel: SocketChannel, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, + channelOptions: ChannelOptions.Storage, + bindTarget: SocketAddress?, + _ body: @escaping @Sendable (Channel) -> EventLoopFuture + ) -> EventLoopFuture { let eventLoop = channel.eventLoop - let bindTarget = self.bindTarget @inline(__always) @Sendable @@ -1352,6 +1420,11 @@ extension ClientBootstrap { aiProtocol: .tcp ) + let enableMPTCP = self.enableMPTCP + let bootstrapChannelInitializer = self.channelInitializer + let channelOptions = self._channelOptions + let bindTarget = self.bindTarget + let connector = HappyEyeballsConnector( resolver: resolver, loop: eventLoop, @@ -1359,9 +1432,13 @@ extension ClientBootstrap { port: port, connectTimeout: self.connectTimeout ) { eventLoop, protocolFamily in - self.initializeAndRegisterNewChannel( + Self.initializeAndRegisterNewChannel( eventLoop: eventLoop, protocolFamily: protocolFamily, + enableMPTPCP: enableMPTCP, + bootstrapChannelInitializer: bootstrapChannelInitializer, + channelOptions: channelOptions, + bindTarget: bindTarget, channelInitializer: channelInitializer, postRegisterTransformation: postRegisterTransformation ) { @@ -1426,6 +1503,46 @@ extension ClientBootstrap { ).map { (channel, $0) } } + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + private static func initializeAndRegisterNewChannel< + ChannelInitializerResult: Sendable, + PostRegistrationTransformationResult: Sendable + >( + eventLoop: EventLoop, + protocolFamily: NIOBSDSocket.ProtocolFamily, + enableMPTPCP: Bool, + bootstrapChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, + channelOptions: ChannelOptions.Storage, + bindTarget: SocketAddress?, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + >, + _ body: @escaping @Sendable (Channel) -> EventLoopFuture + ) -> EventLoopFuture<(Channel, PostRegistrationTransformationResult)> { + let channel: SocketChannel + do { + channel = try Self.makeSocketChannel( + eventLoop: eventLoop, + protocolFamily: protocolFamily, + enableMPTCP: enableMPTPCP + ) + } catch { + return eventLoop.makeFailedFuture(error) + } + return Self.initializeAndRegisterChannel( + channel: channel, + bootstrapChannelInitializer: bootstrapChannelInitializer, + channelOptions: channelOptions, + bindTarget: bindTarget, + channelInitializer: channelInitializer, + registration: { channel in + channel.registerAndDoSynchronously(body) + }, + postRegisterTransformation: postRegisterTransformation + ).map { (channel, $0) } + } + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) private func initializeAndRegisterChannel< ChannelInitializerResult: Sendable, @@ -1438,16 +1555,39 @@ extension ClientBootstrap { PostRegistrationTransformationResult > ) -> EventLoopFuture { - let bootstrapChannelInitializer = self.channelInitializer + Self.initializeAndRegisterChannel( + channel: channel, + bootstrapChannelInitializer: self.channelInitializer, + channelOptions: self._channelOptions, + bindTarget: self.bindTarget, + channelInitializer: channelInitializer, + registration: registration, + postRegisterTransformation: postRegisterTransformation + ) + } + + @available(macOS 10.15, iOS 13, tvOS 13, watchOS 6, *) + private static func initializeAndRegisterChannel< + ChannelInitializerResult: Sendable, + PostRegistrationTransformationResult: Sendable + >( + channel: SocketChannel, + bootstrapChannelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, + channelOptions: ChannelOptions.Storage, + bindTarget: SocketAddress?, + channelInitializer: @escaping @Sendable (Channel) -> EventLoopFuture, + registration: @escaping @Sendable (SocketChannel) -> EventLoopFuture, + postRegisterTransformation: @escaping @Sendable (ChannelInitializerResult, EventLoop) -> EventLoopFuture< + PostRegistrationTransformationResult + > + ) -> EventLoopFuture { let channelInitializer = { @Sendable channel in bootstrapChannelInitializer(channel).hop(to: channel.eventLoop) .assumeIsolated() .flatMap { channelInitializer(channel) } .nonisolated() } - let channelOptions = self._channelOptions let eventLoop = channel.eventLoop - let bindTarget = self.bindTarget @inline(__always) @Sendable diff --git a/Sources/NIOPosix/GetaddrinfoResolver.swift b/Sources/NIOPosix/GetaddrinfoResolver.swift index b24b0a9a2f..3ec44ec6ff 100644 --- a/Sources/NIOPosix/GetaddrinfoResolver.swift +++ b/Sources/NIOPosix/GetaddrinfoResolver.swift @@ -50,7 +50,7 @@ import struct WinSDK.SOCKADDR_IN6 // A thread-specific variable where we store the offload queue if we're on an `SelectableEventLoop`. let offloadQueueTSV = ThreadSpecificVariable() -internal class GetaddrinfoResolver: Resolver { +internal final class GetaddrinfoResolver: Resolver, Sendable { private let loop: EventLoop private let v4Future: EventLoopPromise<[SocketAddress]> private let v6Future: EventLoopPromise<[SocketAddress]> diff --git a/Sources/NIOPosix/HappyEyeballs.swift b/Sources/NIOPosix/HappyEyeballs.swift index 25071ec32a..f1d2e9e3e3 100644 --- a/Sources/NIOPosix/HappyEyeballs.swift +++ b/Sources/NIOPosix/HappyEyeballs.swift @@ -126,25 +126,142 @@ private struct TargetIterator: IteratorProtocol { /// Given a DNS resolver and an event loop, attempts to establish a connection to /// the target host over both IPv4 and IPv6. /// -/// This class provides the code that implements RFC 8305: Happy Eyeballs 2. This +/// This type provides the code that implements RFC 8305: Happy Eyeballs 2. This /// is a connection establishment strategy that attempts to efficiently and quickly /// establish connections to a host that has multiple IP addresses available to it, /// potentially over two different IP protocol versions (4 and 6). /// -/// This class should be created when a connection attempt is made and will remain +/// This type should be created when a connection attempt is made and will remain /// active until a connection is established. It is responsible for actually creating /// connections and managing timeouts. /// -/// This class's public API is thread-safe: the constructor and `resolveAndConnect` can +/// This type's public API is thread-safe: the constructor and `resolveAndConnect` can /// be called from any thread. `resolveAndConnect` will dispatch itself to the event /// loop to force serialization. /// -/// This class's private API is *not* thread-safe, and expects to be called from the -/// event loop thread of the `loop` it is passed. -/// /// The `ChannelBuilderResult` generic type can used to tunnel an arbitrary type /// from the `channelBuilderCallback` to the `resolve` methods return value. -internal final class HappyEyeballsConnector { +internal struct HappyEyeballsConnector: Sendable { + /// The DNS resolver provided by the user. + fileprivate let resolver: Resolver & Sendable + + /// The event loop this connector will run on. + fileprivate let loop: EventLoop + + /// The host name we're connecting to. + fileprivate let host: String + + /// The port we're connecting to. + fileprivate let port: Int + + /// A callback, provided by the user, that is used to build a channel. + /// + /// This callback is expected to build *and register* a channel with the event loop that + /// was used with this resolver. It is free to set up the channel asynchronously, but note + /// that the time taken to set the channel up will be counted against the connection delay, + /// meaning that long channel setup times may cause more connections to be outstanding + /// than intended. + /// + /// The channel builder callback takes an event loop and a protocol family as arguments. + fileprivate let channelBuilderCallback: + @Sendable (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<(Channel, ChannelBuilderResult)> + + /// The amount of time to wait for an AAAA response to come in after a A response is + /// received. By default this is 50ms. + fileprivate let resolutionDelay: TimeAmount + + /// The amount of time to wait for a connection to succeed before beginning a new connection + /// attempt. By default this is 250ms. + fileprivate let connectionDelay: TimeAmount + + /// The amount of time to allow for the overall connection process before timing it out. + fileprivate let connectTimeout: TimeAmount + + /// The promise that will hold the final connected channel. + fileprivate let resolutionPromise: EventLoopPromise<(Channel, ChannelBuilderResult)> + + @inlinable + init( + resolver: Resolver & Sendable, + loop: EventLoop, + host: String, + port: Int, + connectTimeout: TimeAmount, + resolutionDelay: TimeAmount = .milliseconds(50), + connectionDelay: TimeAmount = .milliseconds(250), + channelBuilderCallback: @escaping @Sendable (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture< + (Channel, ChannelBuilderResult) + > + ) { + self.resolver = resolver + self.loop = loop + self.host = host + self.port = port + self.connectTimeout = connectTimeout + self.channelBuilderCallback = channelBuilderCallback + + self.resolutionPromise = self.loop.makePromise() + + precondition( + resolutionDelay.nanoseconds > 0, + "Resolution delay must be greater than zero, got \(resolutionDelay)." + ) + self.resolutionDelay = resolutionDelay + + precondition( + connectionDelay >= .milliseconds(100) && connectionDelay <= .milliseconds(2000), + "Connection delay must be between 100 and 2000 ms, got \(connectionDelay)" + ) + self.connectionDelay = connectionDelay + } + + @inlinable + init( + resolver: Resolver & Sendable, + loop: EventLoop, + host: String, + port: Int, + connectTimeout: TimeAmount, + resolutionDelay: TimeAmount = .milliseconds(50), + connectionDelay: TimeAmount = .milliseconds(250), + channelBuilderCallback: @escaping @Sendable (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture + ) where ChannelBuilderResult == Void { + self.init( + resolver: resolver, + loop: loop, + host: host, + port: port, + connectTimeout: connectTimeout, + resolutionDelay: resolutionDelay, + connectionDelay: connectionDelay + ) { loop, protocolFamily in + channelBuilderCallback(loop, protocolFamily).map { ($0, ()) } + } + } + + /// Initiate a DNS resolution attempt using Happy Eyeballs 2. + /// + /// returns: An `EventLoopFuture` that fires with a connected `Channel`. + @inlinable + func resolveAndConnect() -> EventLoopFuture<(Channel, ChannelBuilderResult)> { + // We dispatch ourselves onto the event loop, rather than do all the rest of our processing from outside it. + self.loop.execute { + let runner = HappyEyeballsConnectorRunner(connector: self) + runner.resolveAndConnect() + } + return self.resolutionPromise.futureResult + } + + /// Initiate a DNS resolution attempt using Happy Eyeballs 2. + /// + /// returns: An `EventLoopFuture` that fires with a connected `Channel`. + @inlinable + func resolveAndConnect() -> EventLoopFuture where ChannelBuilderResult == Void { + self.resolveAndConnect().map { $0.0 } + } +} + +private final class HappyEyeballsConnectorRunner { /// An enum for keeping track of connection state. private enum ConnectionState { /// Initial state. No work outstanding. @@ -205,57 +322,22 @@ internal final class HappyEyeballsConnector { case noTargetsRemaining } - /// The DNS resolver provided by the user. - private let resolver: Resolver - - /// The event loop this connector will run on. - private let loop: EventLoop - - /// The host name we're connecting to. - private let host: String - - /// The port we're connecting to. - private let port: Int - - /// A callback, provided by the user, that is used to build a channel. - /// - /// This callback is expected to build *and register* a channel with the event loop that - /// was used with this resolver. It is free to set up the channel asynchronously, but note - /// that the time taken to set the channel up will be counted against the connection delay, - /// meaning that long channel setup times may cause more connections to be outstanding - /// than intended. - /// - /// The channel builder callback takes an event loop and a protocol family as arguments. - private let channelBuilderCallback: - (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture<(Channel, ChannelBuilderResult)> - - /// The amount of time to wait for an AAAA response to come in after a A response is - /// received. By default this is 50ms. - private let resolutionDelay: TimeAmount + /// The thread-safe state across resolutions. + private let connector: HappyEyeballsConnector /// A reference to the task that will execute after the resolution delay expires, if /// one is scheduled. This is held to ensure that we can cancel this task if the AAAA /// response comes in before the resolution delay expires. private var resolutionTask: Optional> - /// The amount of time to wait for a connection to succeed before beginning a new connection - /// attempt. By default this is 250ms. - private let connectionDelay: TimeAmount - /// A reference to the task that will execute after the connection delay expires, if one /// is scheduled. This is held to ensure that we can cancel this task if a connection /// succeeds before the connection delay expires. private var connectionTask: Optional> - /// The amount of time to allow for the overall connection process before timing it out. - private let connectTimeout: TimeAmount - /// A reference to the task that will time us out. private var timeoutTask: Optional> - /// The promise that will hold the final connected channel. - private let resolutionPromise: EventLoopPromise<(Channel, ChannelBuilderResult)> - /// Our state machine state. private var state: ConnectionState @@ -280,89 +362,24 @@ internal final class HappyEyeballsConnector { @inlinable init( - resolver: Resolver, - loop: EventLoop, - host: String, - port: Int, - connectTimeout: TimeAmount, - resolutionDelay: TimeAmount = .milliseconds(50), - connectionDelay: TimeAmount = .milliseconds(250), - channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture< - (Channel, ChannelBuilderResult) - > + connector: HappyEyeballsConnector ) { - self.resolver = resolver - self.loop = loop - self.host = host - self.port = port - self.connectTimeout = connectTimeout - self.channelBuilderCallback = channelBuilderCallback + self.connector = connector self.resolutionTask = nil self.connectionTask = nil self.timeoutTask = nil self.state = .idle - self.resolutionPromise = self.loop.makePromise() - self.error = NIOConnectionError(host: host, port: port) - - precondition( - resolutionDelay.nanoseconds > 0, - "Resolution delay must be greater than zero, got \(resolutionDelay)." - ) - self.resolutionDelay = resolutionDelay - - precondition( - connectionDelay >= .milliseconds(100) && connectionDelay <= .milliseconds(2000), - "Connection delay must be between 100 and 2000 ms, got \(connectionDelay)" - ) - self.connectionDelay = connectionDelay - } - - @inlinable - convenience init( - resolver: Resolver, - loop: EventLoop, - host: String, - port: Int, - connectTimeout: TimeAmount, - resolutionDelay: TimeAmount = .milliseconds(50), - connectionDelay: TimeAmount = .milliseconds(250), - channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture - ) where ChannelBuilderResult == Void { - self.init( - resolver: resolver, - loop: loop, - host: host, - port: port, - connectTimeout: connectTimeout, - resolutionDelay: resolutionDelay, - connectionDelay: connectionDelay - ) { loop, protocolFamily in - channelBuilderCallback(loop, protocolFamily).map { ($0, ()) } - } + self.error = NIOConnectionError(host: self.connector.host, port: self.connector.port) } /// Initiate a DNS resolution attempt using Happy Eyeballs 2. - /// - /// returns: An `EventLoopFuture` that fires with a connected `Channel`. @inlinable - func resolveAndConnect() -> EventLoopFuture<(Channel, ChannelBuilderResult)> { - // We dispatch ourselves onto the event loop, rather than do all the rest of our processing from outside it. - self.loop.execute { - self.timeoutTask = self.loop.scheduleTask(in: self.connectTimeout) { - self.processInput(.connectTimeoutElapsed) - } - self.processInput(.resolve) + func resolveAndConnect() { + self.timeoutTask = self.connector.loop.assumeIsolated().scheduleTask(in: self.connector.connectTimeout) { + self.processInput(.connectTimeoutElapsed) } - return resolutionPromise.futureResult - } - - /// Initiate a DNS resolution attempt using Happy Eyeballs 2. - /// - /// returns: An `EventLoopFuture` that fires with a connected `Channel`. - @inlinable - func resolveAndConnect() -> EventLoopFuture where ChannelBuilderResult == Void { - self.resolveAndConnect().map { $0.0 } + self.processInput(.resolve) } /// Spin the state machine. @@ -490,11 +507,18 @@ internal final class HappyEyeballsConnector { // query. // // We hop back to `self.loop` because there's no guarantee the resolver runs - // on our event loop. - let aaaaLookup = self.resolver.initiateAAAAQuery(host: self.host, port: self.port).hop(to: self.loop) + // on our event loop. That hop then makes it safe for us to assumeIsolatedUnsafeUnchecked. + self.connector.loop.assertInEventLoop() + let aaaaLookup = self.connector.resolver.initiateAAAAQuery( + host: self.connector.host, + port: self.connector.port + ).hop(to: self.connector.loop).assumeIsolatedUnsafeUnchecked() self.whenAAAALookupComplete(future: aaaaLookup) - let aLookup = self.resolver.initiateAQuery(host: self.host, port: self.port).hop(to: self.loop) + let aLookup = self.connector.resolver.initiateAQuery( + host: self.connector.host, + port: self.connector.port + ).hop(to: self.connector.loop).assumeIsolatedUnsafeUnchecked() self.whenALookupComplete(future: aLookup) } @@ -506,7 +530,10 @@ internal final class HappyEyeballsConnector { /// /// This method sets off a scheduled task for the resolution delay. private func beginResolutionDelay() { - resolutionTask = loop.scheduleTask(in: resolutionDelay, resolutionDelayComplete) + resolutionTask = self.connector.loop.assumeIsolated().scheduleTask( + in: self.connector.resolutionDelay, + resolutionDelayComplete + ) } /// Called when we're ready to start connecting to targets. @@ -521,7 +548,11 @@ internal final class HappyEyeballsConnector { return } - connectionTask = loop.scheduleTask(in: connectionDelay) { self.processInput(.connectDelayElapsed) } + connectionTask = self.connector.loop.assumeIsolated().scheduleTask( + in: self.connector.connectionDelay + ) { + self.processInput(.connectDelayElapsed) + } connectToTarget(target) } @@ -573,7 +604,9 @@ internal final class HappyEyeballsConnector { /// Cleans up internal state and fails the connection promise. private func timedOut() { cleanUp() - self.resolutionPromise.fail(ChannelError.connectTimeout(self.connectTimeout)) + self.connector.resolutionPromise.fail( + ChannelError.connectTimeout(self.connector.connectTimeout) + ) } /// Called when we've attempted to connect to all our resolved targets, @@ -584,7 +617,7 @@ internal final class HappyEyeballsConnector { private func failed() { precondition(pendingConnections.isEmpty, "failed with pending connections") cleanUp() - self.resolutionPromise.fail(self.error) + self.connector.resolutionPromise.fail(self.error) } /// Called to connect to a given target. @@ -592,17 +625,18 @@ internal final class HappyEyeballsConnector { /// - Parameters: /// - target: The address to connect to. private func connectToTarget(_ target: SocketAddress) { - let channelFuture = channelBuilderCallback(self.loop, target.protocol) + let channelFuture = self.connector.channelBuilderCallback(self.connector.loop, target.protocol) + let isolatedChannelFuture = channelFuture.hop(to: self.connector.loop).assumeIsolated() pendingConnections.append(channelFuture) - channelFuture.whenSuccess { (channel, result) in + isolatedChannelFuture.whenSuccess { (channel, result) in // If we are in the complete state then we want to abandon this channel. Otherwise, begin // connecting. if case .complete = self.state { self.pendingConnections.removeAll { $0 === channelFuture } channel.close(promise: nil) } else { - channel.connect(to: target).map { + channel.connect(to: target).assumeIsolated().map { // The channel has connected. If we are in the complete state we want to abandon this channel. // Otherwise, fire the channel connected event. Either way we don't want the channel future to // be in our list of pending connections, so we don't either double close or close the connection @@ -613,7 +647,7 @@ internal final class HappyEyeballsConnector { channel.close(promise: nil) } else { self.processInput(.connectSuccess) - self.resolutionPromise.succeed((channel, result)) + self.connector.resolutionPromise.succeed((channel, result)) } }.whenFailure { err in // The connection attempt failed. If we're in the complete state then there's nothing @@ -631,7 +665,7 @@ internal final class HappyEyeballsConnector { } } } - channelFuture.whenFailure { error in + isolatedChannelFuture.whenFailure { error in self.error.connectionErrors.append(SingleConnectionFailure(target: target, error: error)) self.pendingConnections.removeAll { $0 === channelFuture } self.processInput(.connectFailed) @@ -644,7 +678,7 @@ internal final class HappyEyeballsConnector { assert(self.state == .complete, "Clean up in invalid state \(self.state)") if dnsResolutions < 2 { - resolver.cancelQueries() + self.connector.resolver.cancelQueries() } if let resolutionTask = self.resolutionTask { @@ -670,7 +704,7 @@ internal final class HappyEyeballsConnector { } /// A future callback that fires when a DNS A lookup completes. - private func whenALookupComplete(future: EventLoopFuture<[SocketAddress]>) { + private func whenALookupComplete(future: EventLoopFuture<[SocketAddress]>.Isolated) { future.map { results in self.targets.aResultsAvailable(results) }.recover { err in @@ -682,7 +716,7 @@ internal final class HappyEyeballsConnector { } /// A future callback that fires when a DNS AAAA lookup completes. - private func whenAAAALookupComplete(future: EventLoopFuture<[SocketAddress]>) { + private func whenAAAALookupComplete(future: EventLoopFuture<[SocketAddress]>.Isolated) { future.map { results in self.targets.aaaaResultsAvailable(results) }.recover { err in diff --git a/Tests/NIOPosixTests/HappyEyeballsTest.swift b/Tests/NIOPosixTests/HappyEyeballsTest.swift index 7e91b60a6b..bebf3d48f0 100644 --- a/Tests/NIOPosixTests/HappyEyeballsTest.swift +++ b/Tests/NIOPosixTests/HappyEyeballsTest.swift @@ -13,6 +13,7 @@ //===----------------------------------------------------------------------===// import CNIOLinux +import NIOConcurrencyHelpers import NIOEmbedded import XCTest @@ -240,9 +241,12 @@ private class DummyResolver: Resolver { extension DummyResolver.Event: Equatable { } +@Sendable private func defaultChannelBuilder(loop: EventLoop, family: NIOBSDSocket.ProtocolFamily) -> EventLoopFuture { let channel = EmbeddedChannel(loop: loop as! EmbeddedEventLoop) - XCTAssertNoThrow(try channel.pipeline.addHandler(ConnectRecorder(), name: CONNECT_RECORDER).wait()) + XCTAssertNoThrow( + try channel.pipeline.syncOperations.addHandler(ConnectRecorder(), name: CONNECT_RECORDER) + ) return loop.makeSucceededFuture(channel) } @@ -250,7 +254,7 @@ private func buildEyeballer( host: String, port: Int, connectTimeout: TimeAmount = .seconds(10), - channelBuilderCallback: @escaping (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture = + channelBuilderCallback: @escaping @Sendable (EventLoop, NIOBSDSocket.ProtocolFamily) -> EventLoopFuture = defaultChannelBuilder ) -> (eyeballer: HappyEyeballsConnector, resolver: DummyResolver, loop: EmbeddedEventLoop) { let loop = EmbeddedEventLoop() @@ -586,7 +590,7 @@ public final class HappyEyeballsTest: XCTestCase { } func testMaximalConnectionDelay() throws { - var channels: [Channel] = [] + let channels = ChannelSet() defer { channels.finishAll() } @@ -655,7 +659,7 @@ public final class HappyEyeballsTest: XCTestCase { } func testAllConnectionsFail() throws { - var channels: [Channel] = [] + let channels = ChannelSet() defer { channels.finishAll() } @@ -725,7 +729,7 @@ public final class HappyEyeballsTest: XCTestCase { } func testDelayedAAAAResult() throws { - var channels: [Channel] = [] + let channels = ChannelSet() defer { channels.finishAll() } @@ -812,7 +816,7 @@ public final class HappyEyeballsTest: XCTestCase { } func testTimeoutAfterAQuery() throws { - var channels: [Channel] = [] + let channels = ChannelSet() defer { channels.finishAll() } @@ -861,7 +865,7 @@ public final class HappyEyeballsTest: XCTestCase { } func testAConnectFailsWaitingForAAAA() throws { - var channels: [Channel] = [] + let channels = ChannelSet() defer { channels.finishAll() } @@ -920,7 +924,7 @@ public final class HappyEyeballsTest: XCTestCase { } func testDelayedAResult() throws { - var channels: [Channel] = [] + let channels = ChannelSet() defer { channels.finishAll() } @@ -966,7 +970,7 @@ public final class HappyEyeballsTest: XCTestCase { } func testTimeoutBeforeAResponse() throws { - var channels: [Channel] = [] + let channels = ChannelSet() defer { channels.finishAll() } @@ -1015,7 +1019,7 @@ public final class HappyEyeballsTest: XCTestCase { } func testAllConnectionsFailImmediately() throws { - var channels: [Channel] = [] + let channels = ChannelSet() defer { channels.finishAll() } @@ -1065,7 +1069,7 @@ public final class HappyEyeballsTest: XCTestCase { } func testLaterConnections() throws { - var channels: [Channel] = [] + let channels = ChannelSet() defer { channels.finishAll() } @@ -1110,10 +1114,15 @@ public final class HappyEyeballsTest: XCTestCase { } func testDelayedChannelCreation() throws { - var ourChannelFutures: [EventLoopPromise] = [] + // This lock isn't really needed as the test is single-threaded, but because + // the buildEyeballer function constructs the event loop we can't use a LoopBoundBox. + // This is fine anyway. + let ourChannelFutures: NIOLockedValueBox<[EventLoopPromise]> = NIOLockedValueBox([]) let (eyeballer, resolver, loop) = buildEyeballer(host: "example.com", port: 80) { loop, _ in - ourChannelFutures.append(loop.makePromise()) - return ourChannelFutures.last!.futureResult + ourChannelFutures.withLockedValue { + $0.append(loop.makePromise()) + return $0.last!.futureResult + } } let channelFuture = eyeballer.resolveAndConnect() let expectedQueries: [DummyResolver.Event] = [ @@ -1127,37 +1136,51 @@ public final class HappyEyeballsTest: XCTestCase { // Return the IPv6 results and observe the channel creation attempts. resolver.v6Promise.succeed(MANY_IPv6_RESULTS) for channelCount in 1...10 { - XCTAssertEqual(ourChannelFutures.count, channelCount) + XCTAssertEqual( + ourChannelFutures.withLockedValue { $0.count }, + channelCount + ) loop.advanceTime(by: .milliseconds(250)) } XCTAssertFalse(channelFuture.isFulfilled) // Succeed the first channel future, which will connect because the default // channel builder always does. - defaultChannelBuilder(loop: loop, family: .inet6).whenSuccess { - ourChannelFutures.first!.succeed($0) - XCTAssertEqual($0.state(), .connected) + defaultChannelBuilder(loop: loop, family: .inet6).whenSuccess { result in + ourChannelFutures.withLockedValue { + $0.first!.succeed(result) + } + XCTAssertEqual(result.state(), .connected) } XCTAssertTrue(channelFuture.isFulfilled) // Ok, now succeed the second channel future. This should cause the channel to immediately be closed. - defaultChannelBuilder(loop: loop, family: .inet6).whenSuccess { - ourChannelFutures[1].succeed($0) - XCTAssertEqual($0.state(), .closed) + defaultChannelBuilder(loop: loop, family: .inet6).whenSuccess { result in + ourChannelFutures.withLockedValue { + $0[1].succeed(result) + } + XCTAssertEqual(result.state(), .closed) } - // Ok, now fail the third channel future. Nothing bad should happen here. - ourChannelFutures[2].fail(DummyError()) + try ourChannelFutures.withLockedValue { + // Ok, now fail the third channel future. Nothing bad should happen here. + $0[2].fail(DummyError()) - // Verify that the first channel is the one listed as connected. - XCTAssertTrue((try ourChannelFutures.first!.futureResult.wait()) === (try channelFuture.wait())) + // Verify that the first channel is the one listed as connected. + XCTAssertTrue((try $0.first!.futureResult.wait()) === (try channelFuture.wait())) + } } func testChannelCreationFails() throws { - var errors: [DummyError] = [] + // This lock isn't really needed as the test is single-threaded, but because + // the buildEyeballer function constructs the event loop we can't use a LoopBoundBox. + // This is fine anyway. + let errors: NIOLockedValueBox<[DummyError]> = NIOLockedValueBox([]) let (eyeballer, resolver, loop) = buildEyeballer(host: "example.com", port: 80) { loop, _ in - errors.append(DummyError()) - return loop.makeFailedFuture(errors.last!) + errors.withLockedValue { + $0.append(DummyError()) + return loop.makeFailedFuture($0.last!) + } } let channelFuture = eyeballer.resolveAndConnect() let expectedQueries: [DummyResolver.Event] = [ @@ -1171,21 +1194,21 @@ public final class HappyEyeballsTest: XCTestCase { // Here the AAAA and A results return. We are going to fail the channel creation // instantly, which should cause all 20 to appear. resolver.v6Promise.succeed(MANY_IPv6_RESULTS) - XCTAssertEqual(errors.count, 10) + XCTAssertEqual(errors.withLockedValue { $0.count }, 10) XCTAssertFalse(channelFuture.isFulfilled) resolver.v4Promise.succeed(MANY_IPv4_RESULTS) - XCTAssertEqual(errors.count, 20) + XCTAssertEqual(errors.withLockedValue { $0.count }, 20) XCTAssertTrue(channelFuture.isFulfilled) if let error = channelFuture.getError() as? NIOConnectionError { - XCTAssertEqual(error.connectionErrors.map { $0.error as! DummyError }, errors) + XCTAssertEqual(error.connectionErrors.map { $0.error as! DummyError }, errors.withLockedValue { $0 }) } else { XCTFail("Got unexpected error: \(String(describing: channelFuture.getError()))") } } func testCancellationSyncWithConnectDelay() throws { - var channels: [Channel] = [] + let channels = ChannelSet() defer { channels.finishAll() } @@ -1233,7 +1256,7 @@ public final class HappyEyeballsTest: XCTestCase { } func testCancellationSyncWithResolutionDelay() throws { - var channels: [Channel] = [] + let channels = ChannelSet() defer { channels.finishAll() } @@ -1311,7 +1334,7 @@ public final class HappyEyeballsTest: XCTestCase { } func testResolutionTimeoutAndResolutionInSameTick() throws { - var channels: [Channel] = [] + let channels = ChannelSet() let (eyeballer, resolver, loop) = buildEyeballer(host: "example.com", port: 80) { let channelFuture = defaultChannelBuilder(loop: $0, family: $1) channelFuture.whenSuccess { channel in @@ -1357,3 +1380,35 @@ public final class HappyEyeballsTest: XCTestCase { XCTAssertEqual(resolver.events, expectedQueries) } } + +struct ChannelSet: Sendable, Sequence { + private let channels: NIOLockedValueBox<[Channel]> = .init([]) + + func append(_ channel: Channel) { + self.channels.withLockedValue { $0.append(channel) } + } + + var first: Channel? { + self.channels.withLockedValue { $0.first } + } + + var last: Channel? { + self.channels.withLockedValue { $0.last } + } + + var count: Int { + self.channels.withLockedValue { $0.count } + } + + subscript(index: Int) -> Channel { + self.channels.withLockedValue { $0[index] } + } + + func makeIterator() -> some IteratorProtocol { + self.channels.withLockedValue { $0.makeIterator() } + } + + func finishAll() { + self.channels.withLockedValue { $0 }.finishAll() + } +}