From 63f7d6d5c5ad0a5bf8050755948ada096c095f64 Mon Sep 17 00:00:00 2001 From: Martin Kouba Date: Wed, 9 Oct 2024 17:46:37 +0200 Subject: [PATCH] WebSockets Next: make it possible to store user data in a connection - resolves #43772 --- .../test/connection/ConnectionDataTest.java | 81 ++++++++++++++ .../quarkus/websockets/next/Connection.java | 102 ++++++++++++++++++ .../io/quarkus/websockets/next/UserData.java | 55 ++++++++++ .../next/WebSocketClientConnection.java | 80 +------------- .../websockets/next/WebSocketConnection.java | 89 +-------------- .../websockets/next/runtime/UserDataImpl.java | 44 ++++++++ .../next/runtime/WebSocketConnectionBase.java | 27 ++++- 7 files changed, 311 insertions(+), 167 deletions(-) create mode 100644 extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ConnectionDataTest.java create mode 100644 extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/Connection.java create mode 100644 extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/UserData.java create mode 100644 extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/UserDataImpl.java diff --git a/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ConnectionDataTest.java b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ConnectionDataTest.java new file mode 100644 index 00000000000000..7073f4bf25f177 --- /dev/null +++ b/extensions/websockets-next/deployment/src/test/java/io/quarkus/websockets/next/test/connection/ConnectionDataTest.java @@ -0,0 +1,81 @@ +package io.quarkus.websockets.next.test.connection; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +import java.net.URI; +import java.util.List; + +import jakarta.inject.Inject; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +import io.quarkus.test.QuarkusUnitTest; +import io.quarkus.test.common.http.TestHTTPResource; +import io.quarkus.websockets.next.UserData.TypedKey; +import io.quarkus.websockets.next.OnOpen; +import io.quarkus.websockets.next.OnTextMessage; +import io.quarkus.websockets.next.WebSocket; +import io.quarkus.websockets.next.WebSocketConnection; +import io.quarkus.websockets.next.test.utils.WSClient; +import io.vertx.core.Vertx; + +public class ConnectionDataTest { + + @RegisterExtension + public static final QuarkusUnitTest test = new QuarkusUnitTest() + .withApplicationRoot(root -> { + root.addClasses(MyEndpoint.class, WSClient.class); + }); + + @Inject + Vertx vertx; + + @TestHTTPResource("/end") + URI baseUri; + + @Test + void testConnectionData() { + try (WSClient client = WSClient.create(vertx).connect(baseUri)) { + assertEquals("4", client.sendAndAwaitReply("bar").toString()); + assertEquals("FOOMartin", client.sendAndAwaitReply("foo").toString()); + assertEquals("0", client.sendAndAwaitReply("bar").toString()); + } + } + + @WebSocket(path = "/end") + public static class MyEndpoint { + + @OnOpen + void onOpen(WebSocketConnection connection) { + connection.userData().put(TypedKey.forLong("foo"), 42l); + connection.userData().put(TypedKey.forString("username"), "Martin"); + connection.userData().put(TypedKey.forBoolean("isActive"), true); + connection.userData().put(new TypedKey>("list"), List.of()); + } + + @OnTextMessage + public String onMessage(String message, WebSocketConnection connection) { + if ("bar".equals(message)) { + return connection.userData().size() + ""; + } + try { + connection.userData().get(TypedKey.forString("foo")).toString(); + throw new IllegalStateException(); + } catch (ClassCastException expected) { + } + if (!connection.userData().get(TypedKey.forBoolean("isActive")) + || !connection.userData().get(new TypedKey>("list")).isEmpty()) { + return "NOK"; + } + if (connection.userData().remove(TypedKey.forLong("foo")) != 42l) { + throw new IllegalStateException(); + } + String ret = message.toUpperCase() + connection.userData().get(TypedKey.forString("username")); + connection.userData().clear(); + return ret; + } + + } + +} diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/Connection.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/Connection.java new file mode 100644 index 00000000000000..dd987a2a7661ec --- /dev/null +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/Connection.java @@ -0,0 +1,102 @@ +package io.quarkus.websockets.next; + +import java.time.Instant; + +import io.smallrye.common.annotation.CheckReturnValue; +import io.smallrye.mutiny.Uni; + +/** + * + * @see WebSocketConnection + * @see WebSocketClientConnection + */ +public interface Connection extends BlockingSender { + + /** + * + * @return the unique identifier assigned to this connection + */ + String id(); + + /** + * + * @param name + * @return the value of the path parameter or {@code null} + * @see WebSocketClient#path() + */ + String pathParam(String name); + + /** + * @return {@code true} if the HTTP connection is encrypted via SSL/TLS + */ + boolean isSecure(); + + /** + * @return {@code true} if the WebSocket is closed + */ + boolean isClosed(); + + /** + * + * @return the close reason or {@code null} if the connection is not closed + */ + CloseReason closeReason(); + + /** + * + * @return {@code true} if the WebSocket is open + */ + default boolean isOpen() { + return !isClosed(); + } + + /** + * Close the connection. + * + * @return a new {@link Uni} with a {@code null} item + */ + @CheckReturnValue + default Uni close() { + return close(CloseReason.NORMAL); + } + + /** + * Close the connection with a specific reason. + * + * @param reason + * @return a new {@link Uni} with a {@code null} item + */ + Uni close(CloseReason reason); + + /** + * Close the connection and wait for the completion. + */ + default void closeAndAwait() { + close().await().indefinitely(); + } + + /** + * Close the connection with a specific reason and wait for the completion. + */ + default void closeAndAwait(CloseReason reason) { + close(reason).await().indefinitely(); + } + + /** + * + * @return the handshake request + */ + HandshakeRequest handshakeRequest(); + + /** + * + * @return the time when this connection was created + */ + Instant creationTime(); + + /** + * + * @return the user data associated with this connection + */ + UserData userData(); +} diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/UserData.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/UserData.java new file mode 100644 index 00000000000000..12c6ecfaeb3394 --- /dev/null +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/UserData.java @@ -0,0 +1,55 @@ +package io.quarkus.websockets.next; + +/** + * Mutable user data associated with a connection. Implementations must be thread-safe. + */ +public interface UserData { + + /** + * + * @param + * @param key + * @return the value or {@code null} if no mapping is found + */ + VALUE get(TypedKey key); + + /** + * Associates the specified value with the specified key. An old value is replaced by the specified value. + * + * @param + * @param key + * @param value + * @return the previous value associated with {@code key}, or {@code null} if no mapping exists + */ + VALUE put(TypedKey key, VALUE value); + + /** + * + * @param + * @param key + */ + VALUE remove(TypedKey key); + + int size(); + + void clear(); + + /** + * @param The type this key is used for. + */ + record TypedKey(String value) { + + public static TypedKey forLong(String val) { + return new TypedKey<>(val); + } + + public static TypedKey forString(String val) { + return new TypedKey<>(val); + } + + public static TypedKey forBoolean(String val) { + return new TypedKey<>(val); + } + } + +} \ No newline at end of file diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClientConnection.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClientConnection.java index 393ba422b73518..e33f95bea1e541 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClientConnection.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketClientConnection.java @@ -1,8 +1,6 @@ package io.quarkus.websockets.next; -import io.smallrye.common.annotation.CheckReturnValue; import io.smallrye.common.annotation.Experimental; -import io.smallrye.mutiny.Uni; /** * This interface represents a client connection to a WebSocket endpoint. @@ -11,87 +9,11 @@ * endpoint and used to interact with the connected server. */ @Experimental("This API is experimental and may change in the future") -public interface WebSocketClientConnection extends Sender, BlockingSender { - - /** - * - * @return the unique identifier assigned to this connection - */ - String id(); +public interface WebSocketClientConnection extends Connection { /* * @return the client id */ String clientId(); - /** - * - * @param name - * @return the value of the path parameter or {@code null} - * @see WebSocketClient#path() - */ - String pathParam(String name); - - /** - * @return {@code true} if the HTTP connection is encrypted via SSL/TLS - */ - boolean isSecure(); - - /** - * @return {@code true} if the WebSocket is closed - */ - boolean isClosed(); - - /** - * - * @return the close reason or {@code null} if the connection is not closed - */ - CloseReason closeReason(); - - /** - * - * @return {@code true} if the WebSocket is open - */ - default boolean isOpen() { - return !isClosed(); - } - - /** - * Close the connection. - * - * @return a new {@link Uni} with a {@code null} item - */ - @CheckReturnValue - default Uni close() { - return close(CloseReason.NORMAL); - } - - /** - * Close the connection with a specific reason. - * - * @param reason - * @return a new {@link Uni} with a {@code null} item - */ - Uni close(CloseReason reason); - - /** - * Close the connection and wait for the completion. - */ - default void closeAndAwait() { - close().await().indefinitely(); - } - - /** - * Close the connection with a specific reason and wait for the completion. - */ - default void closeAndAwait(CloseReason reason) { - close(reason).await().indefinitely(); - } - - /** - * - * @return the handshake request - */ - HandshakeRequest handshakeRequest(); - } diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnection.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnection.java index a63a3e2e5772eb..c5deaa339b216e 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnection.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/WebSocketConnection.java @@ -1,12 +1,9 @@ package io.quarkus.websockets.next; -import java.time.Instant; import java.util.Set; import java.util.function.Predicate; -import io.smallrye.common.annotation.CheckReturnValue; import io.smallrye.common.annotation.Experimental; -import io.smallrye.mutiny.Uni; /** * This interface represents a connection from a client to a specific {@link WebSocket} endpoint on the server. @@ -19,13 +16,7 @@ * {@link BlockingSender} and {@link Sender} respectively. */ @Experimental("This API is experimental and may change in the future") -public interface WebSocketConnection extends Sender, BlockingSender { - - /** - * - * @return the unique identifier assigned to this connection - */ - String id(); +public interface WebSocketConnection extends Connection { /** * @@ -34,14 +25,6 @@ public interface WebSocketConnection extends Sender, BlockingSender { */ String endpointId(); - /** - * - * @param name - * @return the decoded value of the path parameter or {@code null} - * @see WebSocket#path() - */ - String pathParam(String name); - /** * Sends messages to all open clients connected to the same WebSocket endpoint. * @@ -57,86 +40,18 @@ public interface WebSocketConnection extends Sender, BlockingSender { */ Set getOpenConnections(); - /** - * @return {@code true} if the HTTP connection is encrypted via SSL/TLS - */ - boolean isSecure(); - - /** - * @return {@code true} if the WebSocket is closed - */ - boolean isClosed(); - - /** - * - * @return the close reason or {@code null} if the connection is not closed - */ - CloseReason closeReason(); - - /** - * - * @return {@code true} if the WebSocket is open - */ - default boolean isOpen() { - return !isClosed(); - } - - /** - * Close the connection. - * - * @return a new {@link Uni} with a {@code null} item - */ - @CheckReturnValue - default Uni close() { - return close(CloseReason.NORMAL); - } - - /** - * Close the connection with a specific reason. - * - * @param reason - * @return a new {@link Uni} with a {@code null} item - */ - Uni close(CloseReason reason); - - /** - * Close the connection and wait for the completion. - */ - default void closeAndAwait() { - close().await().indefinitely(); - } - - /** - * Close the connection and wait for the completion. - */ - default void closeAndAwait(CloseReason reason) { - close(reason).await().indefinitely(); - } - - /** - * - * @return the handshake request - */ - HandshakeRequest handshakeRequest(); - /** * * @return the subprotocol selected by the handshake */ String subprotocol(); - /** - * - * @return the time when this connection was created - */ - Instant creationTime(); - /** * Makes it possible to send messages to all clients connected to the same WebSocket endpoint. * * @see WebSocketConnection#getOpenConnections() */ - interface BroadcastSender extends Sender, BlockingSender { + interface BroadcastSender extends BlockingSender { /** * diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/UserDataImpl.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/UserDataImpl.java new file mode 100644 index 00000000000000..a92f5192b2c3c5 --- /dev/null +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/UserDataImpl.java @@ -0,0 +1,44 @@ +package io.quarkus.websockets.next.runtime; + +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import io.quarkus.websockets.next.UserData; + +final class UserDataImpl implements UserData { + + private final ConcurrentMap data; + + UserDataImpl() { + this.data = new ConcurrentHashMap<>(); + } + + @SuppressWarnings("unchecked") + @Override + public VALUE get(TypedKey key) { + return (VALUE) data.get(key.value()); + } + + @SuppressWarnings("unchecked") + @Override + public VALUE put(TypedKey key, VALUE value) { + return (VALUE) data.put(key.value(), value); + } + + @SuppressWarnings("unchecked") + @Override + public VALUE remove(TypedKey key) { + return (VALUE) data.remove(key.value()); + } + + @Override + public void clear() { + data.clear(); + } + + @Override + public int size() { + return data.size(); + } + +} diff --git a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionBase.java b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionBase.java index 4febc7792d8133..5f22f1e9f29fd4 100644 --- a/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionBase.java +++ b/extensions/websockets-next/runtime/src/main/java/io/quarkus/websockets/next/runtime/WebSocketConnectionBase.java @@ -8,6 +8,8 @@ import io.quarkus.vertx.utils.NoBoundChecksBuffer; import io.quarkus.websockets.next.CloseReason; +import io.quarkus.websockets.next.Connection; +import io.quarkus.websockets.next.UserData; import io.quarkus.websockets.next.HandshakeRequest; import io.quarkus.websockets.next.WebSocketConnection.BroadcastSender; import io.smallrye.mutiny.Uni; @@ -17,7 +19,7 @@ import io.vertx.core.json.JsonArray; import io.vertx.core.json.JsonObject; -public abstract class WebSocketConnectionBase { +public abstract class WebSocketConnectionBase implements Connection { private static final Logger LOG = Logger.getLogger(WebSocketConnectionBase.class); @@ -33,6 +35,8 @@ public abstract class WebSocketConnectionBase { protected final TrafficLogger trafficLogger; + private final UserData data; + WebSocketConnectionBase(Map pathParams, Codecs codecs, HandshakeRequest handshakeRequest, TrafficLogger trafficLogger) { this.identifier = UUID.randomUUID().toString(); @@ -41,18 +45,22 @@ public abstract class WebSocketConnectionBase { this.handshakeRequest = handshakeRequest; this.creationTime = Instant.now(); this.trafficLogger = trafficLogger; + this.data = new UserDataImpl(); } abstract WebSocketBase webSocket(); + @Override public String id() { return identifier; } + @Override public String pathParam(String name) { return pathParams.get(name); } + @Override public Uni sendText(String message) { Uni uni = Uni.createFrom().completionStage(() -> webSocket().writeTextMessage(message).toCompletionStage()); return trafficLogger == null ? uni : uni.invoke(() -> { @@ -60,11 +68,13 @@ public Uni sendText(String message) { }); } + @Override public Uni sendBinary(Buffer message) { Uni uni = Uni.createFrom().completionStage(() -> webSocket().writeBinaryMessage(message).toCompletionStage()); return trafficLogger == null ? uni : uni.invoke(() -> trafficLogger.binaryMessageSent(this, message)); } + @Override public Uni sendText(M message) { String text; // Use the same conversion rules as defined for the OnTextMessage @@ -79,6 +89,7 @@ public Uni sendText(M message) { return sendText(text); } + @Override public Uni sendPing(Buffer data) { return Uni.createFrom().completionStage(() -> webSocket().writePing(data).toCompletionStage()); } @@ -91,14 +102,17 @@ void sendAutoPing() { }); } + @Override public Uni sendPong(Buffer data) { return Uni.createFrom().completionStage(() -> webSocket().writePong(data).toCompletionStage()); } + @Override public Uni close() { return close(CloseReason.NORMAL); } + @Override public Uni close(CloseReason reason) { if (isClosed()) { LOG.warnf("Connection already closed: %s", this); @@ -108,18 +122,22 @@ public Uni close(CloseReason reason) { .completionStage(() -> webSocket().close((short) reason.getCode(), reason.getMessage()).toCompletionStage()); } + @Override public boolean isSecure() { return webSocket().isSsl(); } + @Override public boolean isClosed() { return webSocket().isClosed(); } + @Override public HandshakeRequest handshakeRequest() { return handshakeRequest; } + @Override public Instant creationTime() { return creationTime; } @@ -128,6 +146,7 @@ public BroadcastSender broadcast() { throw new UnsupportedOperationException(); } + @Override public CloseReason closeReason() { WebSocketBase ws = webSocket(); if (ws.isClosed()) { @@ -140,4 +159,10 @@ public CloseReason closeReason() { } return null; } + + @Override + public UserData userData() { + return data; + } + }