From 89ceb5994b8c4941b9b9a1c43623a408799453d8 Mon Sep 17 00:00:00 2001 From: Georgios Andrianakis Date: Mon, 10 Jun 2024 17:22:10 +0300 Subject: [PATCH] Add TLS Registry configuration to WebSockets Next Client Resolves: #41004 --- extensions/websockets-next/deployment/pom.xml | 9 ++ .../client/MtlsWithP12ClientEndpointTest.java | 130 ++++++++++++++++++ .../test/client/TlsClientEndpointTest.java | 121 ++++++++++++++++ extensions/websockets-next/runtime/pom.xml | 4 + .../next/WebSocketsClientRuntimeConfig.java | 10 ++ .../runtime/BasicWebSocketConnectorImpl.java | 19 +-- .../next/runtime/WebSocketConnectorBase.java | 64 ++++++++- .../next/runtime/WebSocketConnectorImpl.java | 20 +-- 8 files changed, 346 insertions(+), 31 deletions(-) create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/MtlsWithP12ClientEndpointTest.java create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/TlsClientEndpointTest.java diff --git a/extensions/websockets-next/deployment/pom.xml b/extensions/websockets-next/deployment/pom.xml index 78e90a6a61959..7a72a09f793d2 100644 --- a/extensions/websockets-next/deployment/pom.xml +++ b/extensions/websockets-next/deployment/pom.xml @@ -25,6 +25,10 @@ io.quarkus quarkus-jackson-deployment + + io.quarkus + quarkus-tls-registry-deployment + io.quarkus quarkus-websockets-next @@ -56,6 +60,11 @@ assertj-core test + + me.escoffier.certs + certificate-generator-junit5 + test + diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/MtlsWithP12ClientEndpointTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/MtlsWithP12ClientEndpointTest.java new file mode 100644 index 0000000000000..974b29e25d63b --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/MtlsWithP12ClientEndpointTest.java @@ -0,0 +1,130 @@ +package io.quarkus.websockets.next.test.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.File; +import java.net.URI; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.PathParam; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketClient; +import io.quarkus.websockets.next.WebSocketClientConnection; +import io.quarkus.websockets.next.WebSocketConnector; +import me.escoffier.certs.Format; +import me.escoffier.certs.junit5.Certificate; +import me.escoffier.certs.junit5.Certificates; + +@Certificates(baseDir = "target/certs", certificates = @Certificate(name = "mtls-test", password = "secret", formats = { + Format.JKS, Format.PKCS12, Format.PEM }, client = true)) +public class MtlsWithP12ClientEndpointTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addClasses(ServerEndpoint.class, ClientEndpoint.class) + .addAsResource(new File("target/certs/mtls-test-keystore.p12"), "server-keystore.p12") + .addAsResource(new File("target/certs/mtls-test-server-truststore.p12"), "server-truststore.p12") + .addAsResource(new File("target/certs/mtls-test-client-keystore.p12"), "client-keystore.p12") + .addAsResource(new File("target/certs/mtls-test-client-truststore.p12"), "client-truststore.p12")) + + .overrideConfigKey("quarkus.tls.ws-server.key-store.p12.path", "server-keystore.p12") + .overrideConfigKey("quarkus.tls.ws-server.key-store.p12.password", "secret") + .overrideConfigKey("quarkus.tls.ws-server.trust-store.p12.path", "server-truststore.p12") + .overrideConfigKey("quarkus.tls.ws-server.trust-store.p12.password", "secret") + .overrideConfigKey("quarkus.http.tls-configuration-name", "ws-server") + + .overrideConfigKey("quarkus.tls.ws-client.key-store.p12.path", "client-keystore.p12") + .overrideConfigKey("quarkus.tls.ws-client.key-store.p12.password", "secret") + .overrideConfigKey("quarkus.tls.ws-client.trust-store.p12.path", "client-truststore.p12") + .overrideConfigKey("quarkus.tls.ws-client.trust-store.p12.password", "secret") + .overrideConfigKey("quarkus.websockets-next.client.tls-configuration-name", "ws-client"); + + @Inject + WebSocketConnector connector; + + @TestHTTPResource(value = "/", tls = true) + URI uri; + + @Test + void testClient() throws InterruptedException { + WebSocketClientConnection connection = connector + .baseUri(uri) + // The value will be encoded automatically + .pathParam("name", "Lu=") + .connectAndAwait(); + assertTrue(connection.isSecure()); + + assertEquals("Lu=", connection.pathParam("name")); + connection.sendTextAndAwait("Hi!"); + + assertTrue(ClientEndpoint.MESSAGE_LATCH.await(5, TimeUnit.SECONDS)); + assertEquals("Lu=:Hello Lu=!", ClientEndpoint.MESSAGES.get(0)); + assertEquals("Lu=:Hi!", ClientEndpoint.MESSAGES.get(1)); + + connection.closeAndAwait(); + assertTrue(ClientEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + } + + @WebSocket(path = "/endpoint/{name}") + public static class ServerEndpoint { + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + + @OnOpen + String open(@PathParam String name) { + return "Hello " + name + "!"; + } + + @OnTextMessage + String echo(String message) { + return message; + } + + @OnClose + void close() { + CLOSED_LATCH.countDown(); + } + + } + + @WebSocketClient(path = "/endpoint/{name}") + public static class ClientEndpoint { + + static final CountDownLatch MESSAGE_LATCH = new CountDownLatch(2); + + static final List MESSAGES = new CopyOnWriteArrayList<>(); + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + + @OnTextMessage + void onMessage(@PathParam String name, String message, WebSocketClientConnection connection) { + if (!name.equals(connection.pathParam("name"))) { + throw new IllegalArgumentException(); + } + MESSAGES.add(name + ":" + message); + MESSAGE_LATCH.countDown(); + } + + @OnClose + void close() { + CLOSED_LATCH.countDown(); + } + + } +} diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/TlsClientEndpointTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/TlsClientEndpointTest.java new file mode 100644 index 0000000000000..07ec8054d0d2d --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/client/TlsClientEndpointTest.java @@ -0,0 +1,121 @@ +package io.quarkus.websockets.next.test.client; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.File; +import java.net.URI; +import java.util.List; +import java.util.concurrent.CopyOnWriteArrayList; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.OnClose; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.PathParam; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketClient; +import io.quarkus.websockets.next.WebSocketClientConnection; +import io.quarkus.websockets.next.WebSocketConnector; +import me.escoffier.certs.Format; +import me.escoffier.certs.junit5.Certificate; +import me.escoffier.certs.junit5.Certificates; + +@Certificates(baseDir = "target/certs", certificates = @Certificate(name = "ssl-test", password = "secret", formats = { + Format.JKS, Format.PKCS12, Format.PEM })) +public class TlsClientEndpointTest { + + @RegisterExtension + static final QuarkusUnitTest config = new QuarkusUnitTest() + .withApplicationRoot((jar) -> jar + .addClasses(ServerEndpoint.class, ClientEndpoint.class) + .addAsResource(new File("target/certs/ssl-test-keystore.jks"), "keystore.jks") + .addAsResource(new File("target/certs/ssl-test-truststore.jks"), "truststore.jks")) + .overrideConfigKey("quarkus.tls.key-store.jks.path", "keystore.jks") + .overrideConfigKey("quarkus.tls.key-store.jks.password", "secret") + .overrideConfigKey("quarkus.tls.ws-client.trust-store.jks.path", "truststore.jks") + .overrideConfigKey("quarkus.tls.ws-client.trust-store.jks.password", "secret") + .overrideConfigKey("quarkus.websockets-next.client.tls-configuration-name", "ws-client"); + + @Inject + WebSocketConnector connector; + + @TestHTTPResource(value = "/", tls = true) + URI uri; + + @Test + void testClient() throws InterruptedException { + WebSocketClientConnection connection = connector + .baseUri(uri) + // The value will be encoded automatically + .pathParam("name", "Lu=") + .connectAndAwait(); + assertTrue(connection.isSecure()); + + assertEquals("Lu=", connection.pathParam("name")); + connection.sendTextAndAwait("Hi!"); + + assertTrue(ClientEndpoint.MESSAGE_LATCH.await(5, TimeUnit.SECONDS)); + assertEquals("Lu=:Hello Lu=!", ClientEndpoint.MESSAGES.get(0)); + assertEquals("Lu=:Hi!", ClientEndpoint.MESSAGES.get(1)); + + connection.closeAndAwait(); + assertTrue(ClientEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + assertTrue(ServerEndpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS)); + } + + @WebSocket(path = "/endpoint/{name}") + public static class ServerEndpoint { + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + + @OnOpen + String open(@PathParam String name) { + return "Hello " + name + "!"; + } + + @OnTextMessage + String echo(String message) { + return message; + } + + @OnClose + void close() { + CLOSED_LATCH.countDown(); + } + + } + + @WebSocketClient(path = "/endpoint/{name}") + public static class ClientEndpoint { + + static final CountDownLatch MESSAGE_LATCH = new CountDownLatch(2); + + static final List MESSAGES = new CopyOnWriteArrayList<>(); + + static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1); + + @OnTextMessage + void onMessage(@PathParam String name, String message, WebSocketClientConnection connection) { + if (!name.equals(connection.pathParam("name"))) { + throw new IllegalArgumentException(); + } + MESSAGES.add(name + ":" + message); + MESSAGE_LATCH.countDown(); + } + + @OnClose + void close() { + CLOSED_LATCH.countDown(); + } + + } +} diff --git a/extensions/websockets-next/runtime/pom.xml b/extensions/websockets-next/runtime/pom.xml index d913689652388..0d2a68bd5bbbc 100644 --- a/extensions/websockets-next/runtime/pom.xml +++ b/extensions/websockets-next/runtime/pom.xml @@ -26,6 +26,10 @@ io.quarkus quarkus-jackson + + io.quarkus + quarkus-tls-registry + io.quarkus.security diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsClientRuntimeConfig.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsClientRuntimeConfig.java index ecaf0bb169d0d..b79a9de857853 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsClientRuntimeConfig.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketsClientRuntimeConfig.java @@ -48,4 +48,14 @@ public interface WebSocketsClientRuntimeConfig { @WithDefault("close") UnhandledFailureStrategy unhandledFailureStrategy(); + /** + * The name of the TLS configuration to use. + *

+ * If a name is configured, it uses the configuration from {@code quarkus.tls..*} + * If a name is configured, but no TLS configuration is found with that name then an error will be thrown. + *

+ * The default TLS configuration is not used by default. + */ + Optional tlsConfigurationName(); + } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java index 46eca5bd0b36e..d47df577837d3 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/BasicWebSocketConnectorImpl.java @@ -14,6 +14,7 @@ import org.jboss.logging.Logger; +import io.quarkus.tls.TlsConfigurationRegistry; import io.quarkus.virtual.threads.VirtualThreadsRecorder; import io.quarkus.websockets.next.BasicWebSocketConnector; import io.quarkus.websockets.next.CloseReason; @@ -27,7 +28,6 @@ import io.vertx.core.Vertx; import io.vertx.core.buffer.Buffer; import io.vertx.core.http.WebSocketClient; -import io.vertx.core.http.WebSocketClientOptions; import io.vertx.core.http.WebSocketConnectOptions; @Typed(BasicWebSocketConnector.class) @@ -54,8 +54,8 @@ public class BasicWebSocketConnectorImpl extends WebSocketConnectorBase errorHandler; BasicWebSocketConnectorImpl(Vertx vertx, Codecs codecs, ClientConnectionManager connectionManager, - WebSocketsClientRuntimeConfig config) { - super(vertx, codecs, connectionManager, config); + WebSocketsClientRuntimeConfig config, TlsConfigurationRegistry tlsConfigurationRegistry) { + super(vertx, codecs, connectionManager, config, tlsConfigurationRegistry); } @Override @@ -115,18 +115,7 @@ public Uni connect() { // Currently we create a new client for each connection // The client is closed when the connection is closed // TODO would it make sense to share clients? - WebSocketClientOptions clientOptions = new WebSocketClientOptions(); - if (config.offerPerMessageCompression()) { - clientOptions.setTryUsePerMessageCompression(true); - if (config.compressionLevel().isPresent()) { - clientOptions.setCompressionLevel(config.compressionLevel().getAsInt()); - } - } - if (config.maxMessageSize().isPresent()) { - clientOptions.setMaxMessageSize(config.maxMessageSize().getAsInt()); - } - - WebSocketClient client = vertx.createWebSocketClient(); + WebSocketClient client = vertx.createWebSocketClient(populateClientOptions()); WebSocketConnectOptions connectOptions = new WebSocketConnectOptions() .setSsl(baseUri.getScheme().equals("https")) diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java index 728850f3083fd..ee098d6a43d16 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorBase.java @@ -9,13 +9,19 @@ import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; +import io.quarkus.tls.TlsConfiguration; +import io.quarkus.tls.TlsConfigurationRegistry; import io.quarkus.websockets.next.WebSocketClientException; import io.quarkus.websockets.next.WebSocketsClientRuntimeConfig; import io.vertx.core.Vertx; +import io.vertx.core.buffer.Buffer; +import io.vertx.core.http.WebSocketClientOptions; +import io.vertx.core.net.SSLOptions; abstract class WebSocketConnectorBase> { @@ -45,8 +51,11 @@ abstract class WebSocketConnectorBase> protected final WebSocketsClientRuntimeConfig config; + protected final TlsConfigurationRegistry tlsConfigurationRegistry; + WebSocketConnectorBase(Vertx vertx, Codecs codecs, - ClientConnectionManager connectionManager, WebSocketsClientRuntimeConfig config) { + ClientConnectionManager connectionManager, WebSocketsClientRuntimeConfig config, + TlsConfigurationRegistry tlsConfigurationRegistry) { this.headers = new HashMap<>(); this.subprotocols = new HashSet<>(); this.pathParams = new HashMap<>(); @@ -54,6 +63,7 @@ abstract class WebSocketConnectorBase> this.codecs = codecs; this.connectionManager = connectionManager; this.config = config; + this.tlsConfigurationRegistry = tlsConfigurationRegistry; this.path = ""; this.pathParamNames = Set.of(); } @@ -129,4 +139,56 @@ String replacePathParameters(String path) { return path.startsWith("/") ? sb.toString() : "/" + sb.toString(); } + protected WebSocketClientOptions populateClientOptions() { + WebSocketClientOptions clientOptions = new WebSocketClientOptions(); + if (config.offerPerMessageCompression()) { + clientOptions.setTryUsePerMessageCompression(true); + if (config.compressionLevel().isPresent()) { + clientOptions.setCompressionLevel(config.compressionLevel().getAsInt()); + } + } + if (config.maxMessageSize().isPresent()) { + clientOptions.setMaxMessageSize(config.maxMessageSize().getAsInt()); + } + + Optional maybeTlsConfiguration = TlsConfiguration.from(tlsConfigurationRegistry, + config.tlsConfigurationName()); + if (maybeTlsConfiguration.isPresent()) { + clientOptions.setSsl(true); + + TlsConfiguration tlsConfiguration = maybeTlsConfiguration.get(); + if (tlsConfiguration.getTrustStoreOptions() != null) { + clientOptions.setTrustOptions(tlsConfiguration.getTrustStoreOptions()); + } + + // For mTLS: + if (tlsConfiguration.getKeyStoreOptions() != null) { + clientOptions.setKeyCertOptions(tlsConfiguration.getKeyStoreOptions()); + } + + if (tlsConfiguration.isTrustAll()) { + clientOptions.setTrustAll(true); + } + if (tlsConfiguration.getHostnameVerificationAlgorithm().isPresent() + && tlsConfiguration.getHostnameVerificationAlgorithm().get().equals("NONE")) { + // Only disable hostname verification if the algorithm is explicitly set to NONE + clientOptions.setVerifyHost(false); + } + + SSLOptions sslOptions = tlsConfiguration.getSSLOptions(); + if (sslOptions != null) { + clientOptions.setSslHandshakeTimeout(sslOptions.getSslHandshakeTimeout()); + clientOptions.setSslHandshakeTimeoutUnit(sslOptions.getSslHandshakeTimeoutUnit()); + for (String suite : sslOptions.getEnabledCipherSuites()) { + clientOptions.addEnabledCipherSuite(suite); + } + for (Buffer buffer : sslOptions.getCrlValues()) { + clientOptions.addCrlValue(buffer); + } + clientOptions.setEnabledSecureTransportProtocols(sslOptions.getEnabledSecureTransportProtocols()); + clientOptions.setUseAlpn(sslOptions.isUseAlpn()); + } + } + return clientOptions; + } } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java index ceaeab285dd80..be39e41799564 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectorImpl.java @@ -16,6 +16,7 @@ import org.jboss.logging.Logger; import io.quarkus.arc.Arc; +import io.quarkus.tls.TlsConfigurationRegistry; import io.quarkus.websockets.next.WebSocketClientConnection; import io.quarkus.websockets.next.WebSocketClientException; import io.quarkus.websockets.next.WebSocketConnector; @@ -26,7 +27,6 @@ import io.smallrye.mutiny.vertx.UniHelper; import io.vertx.core.Vertx; import io.vertx.core.http.WebSocketClient; -import io.vertx.core.http.WebSocketClientOptions; import io.vertx.core.http.WebSocketConnectOptions; @Typed(WebSocketConnector.class) @@ -41,8 +41,9 @@ public class WebSocketConnectorImpl extends WebSocketConnectorBase connect() { // Currently we create a new client for each connection // The client is closed when the connection is closed // TODO would it make sense to share clients? - WebSocketClientOptions clientOptions = new WebSocketClientOptions(); - if (config.offerPerMessageCompression()) { - clientOptions.setTryUsePerMessageCompression(true); - if (config.compressionLevel().isPresent()) { - clientOptions.setCompressionLevel(config.compressionLevel().getAsInt()); - } - } - if (config.maxMessageSize().isPresent()) { - clientOptions.setMaxMessageSize(config.maxMessageSize().getAsInt()); - } - - WebSocketClient client = vertx.createWebSocketClient(); + WebSocketClient client = vertx.createWebSocketClient(populateClientOptions()); StringBuilder serverEndpoint = new StringBuilder(); if (baseUri != null) {