Skip to content

Commit

Permalink
WebSockets Next: introduce OpenConnections
Browse files Browse the repository at this point in the history
- also add WebSocket#endpointId() and WebSocketConnection#endpointId()
  • Loading branch information
mkouba committed Apr 4, 2024
1 parent cc07077 commit 637a473
Show file tree
Hide file tree
Showing 13 changed files with 308 additions and 26 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
*/
public final class GeneratedEndpointBuildItem extends MultiBuildItem {

public final String endpointId;
public final String endpointClassName;
public final String generatedClassName;
public final String path;

GeneratedEndpointBuildItem(String endpointClassName, String generatedClassName, String path) {
GeneratedEndpointBuildItem(String endpointId, String endpointClassName, String generatedClassName, String path) {
this.endpointId = endpointId;
this.endpointClassName = endpointClassName;
this.generatedClassName = generatedClassName;
this.path = path;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ public final class WebSocketEndpointBuildItem extends MultiBuildItem {

public final BeanInfo bean;
public final String path;
public final String endpointId;
public final WebSocket.ExecutionMode executionMode;
public final Callback onOpen;
public final Callback onTextMessage;
Expand All @@ -45,11 +46,13 @@ public final class WebSocketEndpointBuildItem extends MultiBuildItem {
public final Callback onClose;
public final List<Callback> onErrors;

WebSocketEndpointBuildItem(BeanInfo bean, String path, WebSocket.ExecutionMode executionMode, Callback onOpen,
WebSocketEndpointBuildItem(BeanInfo bean, String path, String endpointId, WebSocket.ExecutionMode executionMode,
Callback onOpen,
Callback onTextMessage, Callback onBinaryMessage, Callback onPongMessage, Callback onClose,
List<Callback> onErrors) {
this.bean = bean;
this.path = path;
this.endpointId = endpointId;
this.executionMode = executionMode;
this.onOpen = onOpen;
this.onTextMessage = onTextMessage;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,6 +149,7 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex,
globalErrorHandlers.produce(new GlobalErrorHandlersBuildItem(List.copyOf(globalErrors.values())));

// Collect WebSocket endpoints
Map<String, DotName> idToEndpoint = new HashMap<>();
Map<String, DotName> pathToEndpoint = new HashMap<>();
for (BeanInfo bean : beanDiscoveryFinished.beanStream().classBeans()) {
ClassInfo beanClass = bean.getTarget().get().asClass();
Expand All @@ -159,10 +160,23 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex,
// Sub-websocket - merge the path from the enclosing classes
path = mergePath(getPathPrefix(index, beanClass.enclosingClass()), path);
}
DotName previous = pathToEndpoint.put(path, beanClass.name());
if (previous != null) {
DotName prevPath = pathToEndpoint.put(path, beanClass.name());
if (prevPath != null) {
throw new WebSocketServerException(
String.format("Multiple endpoints [%s, %s] define the same path: %s", previous, beanClass, path));
String.format("Multiple endpoints [%s, %s] define the same path: %s", prevPath, beanClass, path));
}
String endpointId;
AnnotationValue endpointIdValue = webSocketAnnotation.value("endpointId");
if (endpointIdValue == null) {
endpointId = beanClass.name().toString();
} else {
endpointId = endpointIdValue.asString();
}
DotName prevId = idToEndpoint.put(endpointId, beanClass.name());
if (prevId != null) {
throw new WebSocketServerException(
String.format("Multiple endpoints [%s, %s] define the same endpoint id: %s", prevId, beanClass,
endpointId));
}
Callback onOpen = findCallback(beanArchiveIndex.getIndex(), beanClass, WebSocketDotNames.ON_OPEN,
callbackArguments, transformedAnnotations, path);
Expand All @@ -182,7 +196,7 @@ public void collectEndpoints(BeanArchiveIndexBuildItem beanArchiveIndex,
+ beanClass);
}
AnnotationValue executionMode = webSocketAnnotation.value("executionMode");
endpoints.produce(new WebSocketEndpointBuildItem(bean, path,
endpoints.produce(new WebSocketEndpointBuildItem(bean, path, endpointId,
executionMode != null ? WebSocket.ExecutionMode.valueOf(executionMode.asEnum())
: WebSocket.ExecutionMode.SERIAL,
onOpen,
Expand Down Expand Up @@ -234,8 +248,9 @@ public String apply(String name) {
String generatedName = generateEndpoint(endpoint, argumentProviders, transformedAnnotations,
index.getIndex(), classOutput, globalErrorHandlers);
reflectiveClasses.produce(ReflectiveClassBuildItem.builder(generatedName).constructors().build());
generatedEndpoints.produce(new GeneratedEndpointBuildItem(endpoint.bean.getImplClazz().name().toString(),
generatedName, endpoint.path));
generatedEndpoints
.produce(new GeneratedEndpointBuildItem(endpoint.endpointId, endpoint.bean.getImplClazz().name().toString(),
generatedName, endpoint.path));
}
}

Expand All @@ -250,7 +265,7 @@ public void registerRoutes(WebSocketServerRecorder recorder, HttpRootPathBuildIt
.route(httpRootPath.relativePath(endpoint.path))
.displayOnNotFoundPage("WebSocket Endpoint")
.handlerType(HandlerType.NORMAL)
.handler(recorder.createEndpointHandler(endpoint.generatedClassName));
.handler(recorder.createEndpointHandler(endpoint.generatedClassName, endpoint.endpointId));
routes.produce(builder.build());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ public class ConnectionArgumentTest {
void testArgument() {
String message = "ok";
String header = "fool";
WSClient client = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader("X-Test", header),
testUri);
JsonObject reply = client.sendAndAwaitReply(message).toJsonObject();
assertEquals(header, reply.getString("header"), reply.toString());
assertEquals(message, reply.getString("message"), reply.toString());
try (WSClient client = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader("X-Test", header),
testUri)) {
JsonObject reply = client.sendAndAwaitReply(message).toJsonObject();
assertEquals(header, reply.getString("header"), reply.toString());
assertEquals(message, reply.getString("message"), reply.toString());
}
}

@WebSocket(path = "/echo")
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package io.quarkus.websockets.next.test.endpoints;

import static org.junit.jupiter.api.Assertions.fail;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketServerException;

public class AmbiguousEndpointIdTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint1.class, Endpoint2.class);
})
.setExpectedException(WebSocketServerException.class);

@Test
public void testEndpointIds() {
fail();
}

@WebSocket(path = "/ws1", endpointId = "foo")
public static class Endpoint1 {

@OnOpen
void open() {
}

}

@WebSocket(path = "/ws2", endpointId = "foo")
public static class Endpoint2 {

@OnOpen
void open() {
}

}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
package io.quarkus.websockets.next.test.openconnections;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

import java.net.URI;
import java.util.Collection;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import jakarta.inject.Inject;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.RegisterExtension;

import io.quarkus.test.QuarkusUnitTest;
import io.quarkus.test.common.http.TestHTTPResource;
import io.quarkus.websockets.next.OnClose;
import io.quarkus.websockets.next.OnOpen;
import io.quarkus.websockets.next.OpenConnections;
import io.quarkus.websockets.next.WebSocket;
import io.quarkus.websockets.next.WebSocketConnection;
import io.quarkus.websockets.next.test.utils.WSClient;
import io.vertx.core.Vertx;
import io.vertx.core.http.WebSocketConnectOptions;

public class OpenConnectionsTest {

@RegisterExtension
public static final QuarkusUnitTest test = new QuarkusUnitTest()
.withApplicationRoot(root -> {
root.addClasses(Endpoint.class, WSClient.class);
});

@Inject
Vertx vertx;

@TestHTTPResource("endpoint")
URI endUri;

@Inject
OpenConnections connections;

@Test
void testOpenConnections() throws Exception {
String headerName = "X-Test";
String header2 = "foo";
String header3 = "bar";

for (WebSocketConnection c : connections) {
fail("No connection should be found: " + c);
}

try (WSClient client1 = WSClient.create(vertx).connect(endUri);
WSClient client2 = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader(headerName, header2),
endUri);
WSClient client3 = WSClient.create(vertx).connect(new WebSocketConnectOptions().addHeader(headerName, header3),
endUri)) {

client1.waitForMessages(1);
String client1Id = client1.getMessages().get(0).toString();

client2.waitForMessages(1);
String client2Id = client2.getMessages().get(0).toString();

client3.waitForMessages(1);
String client3Id = client3.getMessages().get(0).toString();

assertNotNull(connections.findByConnectionId(client1Id).orElse(null));
Collection<WebSocketConnection> found = connections.stream()
.filter(c -> header3.equals(c.handshakeRequest().header(headerName)))
.toList();
assertEquals(1, found.size());
assertEquals(client3Id, found.iterator().next().id());

found = connections.listAll();
assertEquals(3, found.size());
for (WebSocketConnection c : found) {
assertTrue(c.id().equals(client1Id) || c.id().equals(client2Id) || c.id().equals(client3Id));
}

client2.disconnect();
assertTrue(Endpoint.CLOSED_LATCH.await(5, TimeUnit.SECONDS));

assertEquals(2, connections.listAll().size());
assertNull(connections.stream().filter(c -> c.id().equals(client2Id)).findFirst().orElse(null));

found = connections.findByEndpointId("end");
assertEquals(2, found.size());
}
}

@WebSocket(endpointId = "end", path = "/endpoint")
public static class Endpoint {

static final CountDownLatch CLOSED_LATCH = new CountDownLatch(1);

@OnOpen
String open(WebSocketConnection connection) {
return connection.id();
}

@OnClose
void close() {
CLOSED_LATCH.countDown();
}

}

}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import io.vertx.core.http.WebSocketClient;
import io.vertx.core.http.WebSocketConnectOptions;

public class WSClient {
public class WSClient implements AutoCloseable {

private final WebSocketClient client;
private AtomicReference<WebSocket> socket = new AtomicReference<>();
Expand Down Expand Up @@ -124,4 +124,10 @@ public Buffer sendAndAwaitReply(String message) {
public boolean isClosed() {
return socket.get().isClosed();
}

@Override
public void close() {
disconnect();
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
package io.quarkus.websockets.next;

import java.util.Collection;
import java.util.Optional;
import java.util.stream.Stream;

import io.smallrye.common.annotation.Experimental;

/**
* Provides convenient access to all open connections from clients to {@link WebSocket} endpoints on the server.
* <p>
* Quarkus provides a built-in CDI bean with the {@link jakarta.inject.Singleton} scope that implements this interface.
*/
@Experimental("This API is experimental and may change in the future")
public interface OpenConnections extends Iterable<WebSocketConnection> {

/**
* Returns an immutable snapshot of all open connections at the given time.
*
* @return an immutable collection of all open connections
*/
default Collection<WebSocketConnection> listAll() {
return stream().toList();
}

/**
* Returns an immutable snapshot of all open connections for the given endpoint id.
*
* @param endpointId
* @return an immutable collection of all open connections for the given endpoint id
* @see WebSocket#endpointId()
*/
default Collection<WebSocketConnection> findByEndpointId(String endpointId) {
return stream().filter(c -> c.endpointId().equals(endpointId)).toList();
}

/**
* Returns the open connection with the given id.
*
* @param connectionId
* @return the open connection or empty {@link Optional} if no open connection with the given id exists
* @see WebSocketConnection#id()
*/
default Optional<WebSocketConnection> findByConnectionId(String connectionId) {
return stream().filter(c -> c.id().equals(connectionId)).findFirst();
}

/**
* Returns the stream of all open connections at the given time.
*
* @return the stream of open connections
*/
Stream<WebSocketConnection> stream();

}
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,24 @@
*/
public String path();

/**
* By default, the fully qualified name of the annotated class is used.
*
* @return the endpoint id
* @see WebSocketConnection#endpointId()
*/
public String endpointId() default FCQN_NAME;

/**
* The execution mode used to process incoming messages for a specific connection.
*/
public ExecutionMode executionMode() default ExecutionMode.SERIAL;

/**
* Constant value for {@link #endpointId()} indicating that the fully qualified name of the annotated class should be used.
*/
String FCQN_NAME = "<<fcqn name>>";

/**
* Defines the execution mode used to process incoming messages for a specific connection.
*
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,13 @@ public interface WebSocketConnection extends Sender, BlockingSender {
*/
String id();

/**
*
* @return the endpoint id
* @see WebSocket#endpointId()
*/
String endpointId();

/**
*
* @param name
Expand Down
Loading

0 comments on commit 637a473

Please sign in to comment.