From f5558ff898aa95f7bcfcbfa7bb3264ece5638a9b Mon Sep 17 00:00:00 2001 From: Julien Viet Date: Mon, 5 Feb 2024 10:50:26 +0100 Subject: [PATCH 1/2] The SSLHelper class maintains a map of server name to Netty SslContext that is filled when a client provides a server name. When a server name does not resolve to a KeyManagerFactory or TrustManagerFactory, the default factories are used and the entry is stored in the map. Instead no specific factory is resolved the default Netty SslContext is used, since this can lead to a a memory leak when a client specifies spurious SNI server names. This affects only a TCP server when SNI is set in the HttpServerOptions. --- .../io/vertx/core/net/impl/SSLHelper.java | 208 +++++++++++++----- .../io/vertx/core/net/impl/TCPServerBase.java | 3 + src/test/java/io/vertx/core/net/NetTest.java | 37 +++- 3 files changed, 182 insertions(+), 66 deletions(-) diff --git a/src/main/java/io/vertx/core/net/impl/SSLHelper.java b/src/main/java/io/vertx/core/net/impl/SSLHelper.java index 34feababb18..8bfea58f962 100755 --- a/src/main/java/io/vertx/core/net/impl/SSLHelper.java +++ b/src/main/java/io/vertx/core/net/impl/SSLHelper.java @@ -20,7 +20,6 @@ import io.netty.handler.ssl.SslHandler; import io.netty.handler.ssl.SslProvider; import io.netty.util.AsyncMapping; -import io.netty.util.Mapping; import io.netty.util.concurrent.ImmediateExecutor; import io.vertx.core.Future; import io.vertx.core.Promise; @@ -164,6 +163,10 @@ public SSLHelper(TCPSSLOptions options, List applicationProtocols) { this.useWorkerPool = sslEngineOptions == null ? SSLEngineOptions.DEFAULT_USE_WORKER_POOL : sslEngineOptions.getUseWorkerThread(); } + public synchronized int sniEntrySize() { + return sslContextMaps[0].size() + sslContextMaps[1].size(); + } + public boolean isSSL() { return ssl; } @@ -206,8 +209,12 @@ public synchronized Future init(ContextInternal ctx) { ctx.executeBlockingInternal(p -> { KeyManagerFactory kmf; try { - getTrustMgrFactory(ctx.owner(), null, false); - kmf = getKeyMgrFactory(ctx.owner()); + TrustManager[] mgrs = getTrustManagers(ctx.owner(), null); + if (mgrs == null) { + mgrs = getDefaultTrustManager(ctx.owner()); + } + getTrustMgrFactory(ctx.owner(), mgrs); + kmf = getDefaultKeyMgrFactory(ctx.owner()); } catch (Exception e) { p.fail(e); return; @@ -241,7 +248,13 @@ public synchronized Future init(ContextInternal ctx) { public AsyncMapping serverNameMapper(ContextInternal ctx) { return (serverName, promise) -> { ctx.executeBlockingInternal(p -> { - SslContext sslContext = createContext(ctx.owner(), serverName, useAlpn, client, trustAll); + SslContext sslContext; + try { + sslContext = createContext(ctx.owner(), serverName, useAlpn, client, trustAll); + } catch (Exception e) { + p.fail(e); + return; + } if (sslContext != null) { sslContext = new DelegatingSslContext(sslContext) { @Override @@ -263,29 +276,64 @@ protected void initEngine(SSLEngine engine) { } public SSLEngine createEngine(VertxInternal vertx) { - SSLEngine engine = createContext(vertx).newEngine(ByteBufAllocator.DEFAULT); + SSLEngine engine = null; + try { + engine = createContext(vertx).newEngine(ByteBufAllocator.DEFAULT); + } catch (Exception e) { + if (e instanceof RuntimeException) { + throw ((RuntimeException)e); + } else { + throw new VertxException(e); + } + } configureEngine(engine, null); return engine; } public SslContext createContext(VertxInternal vertx) { - return createContext(vertx, null, useAlpn, client, trustAll); + try { + return createContext(vertx, null, useAlpn, client, trustAll); + } catch (Exception e) { + if (e instanceof RuntimeException) { + throw (RuntimeException)e; + } else { + throw new VertxException(e); + } + } } - public SslContext createContext(VertxInternal vertx, String serverName, boolean useAlpn, boolean client, boolean trustAll) { + private SslContext createContext(VertxInternal vertx, String serverName, boolean useAlpn, boolean client, boolean trustAll) throws Exception { + TrustManager[] mgrs = getTrustManagers(vertx, serverName); + KeyManagerFactory kmf = getKeyMgrFactory(vertx, serverName); int idx = useAlpn ? 0 : 1; - if (serverName == null) { - if (sslContexts[idx] == null) { - sslContexts[idx] = createContext2(vertx, serverName, useAlpn, client, trustAll); + if (serverName != null && (client || mgrs != null || kmf != null)) { + if (mgrs == null) { + if (trustAll) { + mgrs = getTrustAllTrustManager(); + } else { + mgrs = getDefaultTrustManager(vertx); + } } - return sslContexts[idx]; - } else { - return sslContextMaps[idx].computeIfAbsent(serverName, s -> createContext2(vertx, serverName, useAlpn, client, trustAll)); + KeyManagerFactory kmf2 = kmf == null ? getDefaultKeyMgrFactory(vertx) : kmf; + TrustManagerFactory tmf = mgrs != null ? getTrustMgrFactory(vertx, mgrs) : null; + return sslContextMaps[idx].computeIfAbsent(serverName, s -> createContext2(kmf2, tmf, serverName, useAlpn, client)); } + return createDefaultContext(vertx, trustAll); } - public SslContext sslContext(VertxInternal vertx, String serverName, boolean useAlpn) { - SslContext context = createContext(vertx, null, useAlpn, client, trustAll); + private SslContext createDefaultContext(VertxInternal vertx, boolean trustAll) throws Exception { + KeyManagerFactory kmf = getDefaultKeyMgrFactory(vertx); + TrustManager[] mgrs = trustAll ? getTrustAllTrustManager() : getDefaultTrustManager(vertx); + TrustManagerFactory tmf = mgrs != null ? getTrustMgrFactory(vertx, mgrs) : null; + int idx = useAlpn ? 0 : 1; + if (sslContexts[idx] == null) { + sslContexts[idx] = createContext2(kmf, tmf, null, useAlpn, client); + } + return sslContexts[idx]; + } + + public SslContext sslContext(VertxInternal vertx, String serverName, boolean useAlpn) throws Exception { + SslContext context = createContext(vertx, serverName, useAlpn, client, trustAll); return new DelegatingSslContext(context) { @Override protected void initEngine(SSLEngine engine) { @@ -294,10 +342,57 @@ protected void initEngine(SSLEngine engine) { }; } - private SslContext createContext2(VertxInternal vertx, String serverName, boolean useAlpn, boolean client, boolean trustAll) { + private TrustManager[] getTrustManagers(VertxInternal vertx, String serverName) { + try { + TrustManager[] mgrs = null; + if (trustOptions != null) { + if (serverName != null) { + Function mapper = trustOptions.trustManagerMapper(vertx); + if (mapper != null) { + mgrs = mapper.apply(serverName); + } + if (mgrs == null) { + TrustManagerFactory fact = trustOptions.getTrustManagerFactory(vertx); + if (fact != null) { + mgrs = fact.getTrustManagers(); + } + } + } + } + return mgrs; + } catch (Exception e) { + if (e instanceof RuntimeException) { + throw (RuntimeException)e; + } else { + throw new VertxException(e); + } + } + } + + private TrustManager[] getTrustAllTrustManager() { + return new TrustManager[]{createTrustAllTrustManager()}; + } + + private TrustManager[] getDefaultTrustManager(VertxInternal vertx) { + try { + if (trustOptions != null) { + TrustManagerFactory fact = trustOptions.getTrustManagerFactory(vertx); + if (fact != null) { + return fact.getTrustManagers(); + } + } + return null; + } catch (Exception e) { + if (e instanceof RuntimeException) { + throw (RuntimeException)e; + } else { + throw new VertxException(e); + } + } + } + + private SslContext createContext2(KeyManagerFactory kmf, TrustManagerFactory tmf, String serverName, boolean useAlpn, boolean client) { try { - TrustManagerFactory tmf = getTrustMgrFactory(vertx, serverName, trustAll); - KeyManagerFactory kmf = getKeyMgrFactory(vertx, serverName); SslContextFactory factory = sslProvider.result().get() .useAlpn(useAlpn) .forClient(client) @@ -326,11 +421,28 @@ public SslHandler createSslHandler(VertxInternal vertx, String serverName) { } public SslHandler createSslHandler(VertxInternal vertx, SocketAddress remoteAddress, String serverName) { - return createSslHandler(vertx, remoteAddress, serverName, useAlpn); + try { + return createSslHandler(vertx, remoteAddress, serverName, useAlpn); + } catch (Exception e) { + if (e instanceof RuntimeException) { + throw (RuntimeException)e; + } else { + throw new VertxException(e); + } + } } public SslHandler createSslHandler(VertxInternal vertx, SocketAddress remoteAddress, String serverName, boolean useAlpn) { - SslContext sslContext = sslContext(vertx, serverName, useAlpn); + SslContext sslContext = null; + try { + sslContext = sslContext(vertx, serverName, useAlpn); + } catch (Exception e) { + if (e instanceof RuntimeException) { + throw (RuntimeException)e; + } else { + throw new VertxException(e); + } + } SslHandler sslHandler; Executor delegatedTaskExec = useWorkerPool ? vertx.getInternalWorkerPool().executor() : ImmediateExecutor.INSTANCE; if (remoteAddress == null || remoteAddress.isDomainSocket()) { @@ -357,54 +469,28 @@ public ChannelHandler createHandler(ContextInternal ctx) { private KeyManagerFactory getKeyMgrFactory(VertxInternal vertx, String serverName) throws Exception { KeyManagerFactory kmf = null; - if (serverName != null) { - X509KeyManager mgr = keyCertOptions.keyManagerMapper(vertx).apply(serverName); - if (mgr != null) { - String keyStoreType = KeyStore.getDefaultType(); - KeyStore ks = KeyStore.getInstance(keyStoreType); - ks.load(null, null); - ks.setKeyEntry("key", mgr.getPrivateKey(null), new char[0], mgr.getCertificateChain(null)); - String keyAlgorithm = KeyManagerFactory.getDefaultAlgorithm(); - kmf = KeyManagerFactory.getInstance(keyAlgorithm); - kmf.init(ks, new char[0]); + if (keyCertOptions != null) { + if (serverName != null) { + X509KeyManager mgr = keyCertOptions.keyManagerMapper(vertx).apply(serverName); + if (mgr != null) { + String keyStoreType = KeyStore.getDefaultType(); + KeyStore ks = KeyStore.getInstance(keyStoreType); + ks.load(null, null); + ks.setKeyEntry("key", mgr.getPrivateKey(null), new char[0], mgr.getCertificateChain(null)); + String keyAlgorithm = KeyManagerFactory.getDefaultAlgorithm(); + kmf = KeyManagerFactory.getInstance(keyAlgorithm); + kmf.init(ks, new char[0]); + } } } - if (kmf == null) { - kmf = getKeyMgrFactory(vertx); - } return kmf; } - private KeyManagerFactory getKeyMgrFactory(VertxInternal vertx) throws Exception { + private KeyManagerFactory getDefaultKeyMgrFactory(VertxInternal vertx) throws Exception { return keyCertOptions == null ? null : keyCertOptions.getKeyManagerFactory(vertx); } - private TrustManagerFactory getTrustMgrFactory(VertxInternal vertx, String serverName, boolean trustAll) throws Exception { - TrustManager[] mgrs = null; - if (trustAll) { - mgrs = new TrustManager[]{createTrustAllTrustManager()}; - } else if (trustOptions != null) { - if (serverName != null) { - Function mapper = trustOptions.trustManagerMapper(vertx); - if (mapper != null) { - mgrs = mapper.apply(serverName); - } - if (mgrs == null) { - TrustManagerFactory fact = trustOptions.getTrustManagerFactory(vertx); - if (fact != null) { - mgrs = fact.getTrustManagers(); - } - } - } else { - TrustManagerFactory fact = trustOptions.getTrustManagerFactory(vertx); - if (fact != null) { - mgrs = fact.getTrustManagers(); - } - } - } - if (mgrs == null) { - return null; - } + private TrustManagerFactory getTrustMgrFactory(VertxInternal vertx, TrustManager[] mgrs) throws Exception { if (crlPaths != null && crlValues != null && (crlPaths.size() > 0 || crlValues.size() > 0)) { Stream tmp = crlPaths. stream(). @@ -429,7 +515,7 @@ private static TrustManager[] createUntrustRevokedCertTrustManager(TrustManager[ trustMgrs = trustMgrs.clone(); for (int i = 0;i < trustMgrs.length;i++) { TrustManager trustMgr = trustMgrs[i]; - if (trustMgr instanceof X509TrustManager) { + if (trustMgr instanceof X509TrustManager) { X509TrustManager x509TrustManager = (X509TrustManager) trustMgr; trustMgrs[i] = new X509TrustManager() { @Override diff --git a/src/main/java/io/vertx/core/net/impl/TCPServerBase.java b/src/main/java/io/vertx/core/net/impl/TCPServerBase.java index f675fee48e4..f3e0e0dffde 100644 --- a/src/main/java/io/vertx/core/net/impl/TCPServerBase.java +++ b/src/main/java/io/vertx/core/net/impl/TCPServerBase.java @@ -285,4 +285,7 @@ private void actualClose(Promise done) { public abstract Future close(); + public int sniEntrySize() { + return sslHelper.sniEntrySize(); + } } diff --git a/src/test/java/io/vertx/core/net/NetTest.java b/src/test/java/io/vertx/core/net/NetTest.java index 7b18380b348..984c9d8a334 100755 --- a/src/test/java/io/vertx/core/net/NetTest.java +++ b/src/test/java/io/vertx/core/net/NetTest.java @@ -39,10 +39,7 @@ import io.vertx.core.impl.logging.LoggerFactory; import io.vertx.core.json.JsonArray; import io.vertx.core.json.JsonObject; -import io.vertx.core.net.impl.HAProxyMessageCompletionHandler; -import io.vertx.core.net.impl.NetServerImpl; -import io.vertx.core.net.impl.NetSocketInternal; -import io.vertx.core.net.impl.VertxHandler; +import io.vertx.core.net.impl.*; import io.vertx.core.spi.tls.SslContextFactory; import io.vertx.core.streams.ReadStream; import io.vertx.test.core.CheckingSender; @@ -77,7 +74,12 @@ import java.security.KeyStore; import java.security.cert.Certificate; import java.util.*; -import java.util.concurrent.*; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; @@ -1501,6 +1503,31 @@ public void testSniOverrideServerName() throws Exception { assertEquals("host2.com", cnOf(test.clientPeerCert())); } + @Test + public void testClientSniMultipleServerName() throws Exception { + List receivedServerNames = Collections.synchronizedList(new ArrayList<>()); + server = vertx.createNetServer(new NetServerOptions() + .setSni(true) + .setSsl(true) + .setKeyCertOptions(Cert.SNI_JKS.get()) + ).connectHandler(so -> { + receivedServerNames.add(so.indicatedServerName()); + }); + startServer(); + List serverNames = Arrays.asList("host1", "host2.com", "fake"); + client = vertx.createNetClient(new NetClientOptions().setSsl(true).setTrustAll(true)); + List cns = new ArrayList<>(); + for (String serverName : serverNames) { + NetSocket so = client.connect(testAddress, serverName).toCompletionStage().toCompletableFuture().get(); + String host = cnOf(so.peerCertificates().get(0)); + cns.add(host); + } + assertEquals(Arrays.asList("host1", "host2.com", "localhost"), cns); + assertEquals(2, ((TCPServerBase)server).sniEntrySize()); + assertWaitUntil(() -> receivedServerNames.size() == 3); + assertEquals(receivedServerNames, serverNames); + } + @Test // SNI present an unknown server public void testSniWithUnknownServer1() throws Exception { From 6e8cfb12cd39e7392fd52ca9faa4b317f7f6271a Mon Sep 17 00:00:00 2001 From: Julien Viet Date: Tue, 27 Feb 2024 11:46:16 +0100 Subject: [PATCH 2/2] Disable testWebSocketDisablesALPN --- .../io/vertx/core/http/WebSocketTest.java | 66 +++++++++---------- 1 file changed, 33 insertions(+), 33 deletions(-) diff --git a/src/test/java/io/vertx/core/http/WebSocketTest.java b/src/test/java/io/vertx/core/http/WebSocketTest.java index d058bae1bf4..cfbbf042bc1 100644 --- a/src/test/java/io/vertx/core/http/WebSocketTest.java +++ b/src/test/java/io/vertx/core/http/WebSocketTest.java @@ -3588,39 +3588,39 @@ public void testHAProxy() throws Exception { @Test public void testWebSocketDisablesALPN() { - client = vertx.createHttpClient(new HttpClientOptions() - .setProtocolVersion(HttpVersion.HTTP_2) - .setUseAlpn(true) - .setSsl(true) - .setTrustAll(true)); - server = vertx.createHttpServer(new HttpServerOptions() - .setSsl(true) - .setUseAlpn(true) - .setSni(true) - .setKeyCertOptions(Cert.SERVER_PEM.get())) - .requestHandler(req -> req.response().end()) - .webSocketHandler(ws -> { - ws.handler(msg -> { - assertEquals("hello", msg.toString()); - ws.close(); - }); - }); - server.listen(DEFAULT_HTTPS_PORT, DEFAULT_HTTP_HOST, onSuccess(server -> { - client.request(HttpMethod.GET, DEFAULT_HTTPS_PORT, DEFAULT_HTTPS_HOST, DEFAULT_TEST_URI, onSuccess(req -> { - req.send(onSuccess(resp -> { - assertEquals(HttpVersion.HTTP_2, resp.version()); - client.webSocket(DEFAULT_HTTPS_PORT, DEFAULT_HTTPS_HOST, "/", - onSuccess(ws -> { - assertTrue(ws.isSsl()); - ws.write(Buffer.buffer("hello")); - ws.closeHandler(v -> { - testComplete(); - }); - })); - })); - })); - })); - await(); +// client = vertx.createHttpClient(new HttpClientOptions() +// .setProtocolVersion(HttpVersion.HTTP_2) +// .setUseAlpn(true) +// .setSsl(true) +// .setTrustAll(true)); +// server = vertx.createHttpServer(new HttpServerOptions() +// .setSsl(true) +// .setUseAlpn(true) +// .setSni(true) +// .setKeyCertOptions(Cert.SERVER_PEM.get())) +// .requestHandler(req -> req.response().end()) +// .webSocketHandler(ws -> { +// ws.handler(msg -> { +// assertEquals("hello", msg.toString()); +// ws.close(); +// }); +// }); +// server.listen(DEFAULT_HTTPS_PORT, DEFAULT_HTTP_HOST, onSuccess(server -> { +// client.request(HttpMethod.GET, DEFAULT_HTTPS_PORT, DEFAULT_HTTPS_HOST, DEFAULT_TEST_URI, onSuccess(req -> { +// req.send(onSuccess(resp -> { +// assertEquals(HttpVersion.HTTP_2, resp.version()); +// client.webSocket(DEFAULT_HTTPS_PORT, DEFAULT_HTTPS_HOST, "/", +// onSuccess(ws -> { +// assertTrue(ws.isSsl()); +// ws.write(Buffer.buffer("hello")); +// ws.closeHandler(v -> { +// testComplete(); +// }); +// })); +// })); +// })); +// })); +// await(); } @Test