diff --git a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/security/WebsocketBasicAuthTestCase.java b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/security/WebsocketBasicAuthTestCase.java index a155c164bf..0497723a63 100644 --- a/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/security/WebsocketBasicAuthTestCase.java +++ b/websockets-jsr/src/test/java/io/undertow/websockets/jsr/test/security/WebsocketBasicAuthTestCase.java @@ -43,7 +43,6 @@ import jakarta.websocket.ContainerProvider; import jakarta.websocket.Endpoint; import jakarta.websocket.EndpointConfig; -import jakarta.websocket.MessageHandler; import jakarta.websocket.OnOpen; import jakarta.websocket.Session; import jakarta.websocket.server.ServerEndpoint; @@ -147,12 +146,7 @@ public static void cleanup() throws ServletException { @Test public void testAuthenticatedWebsocket() throws Exception { ProgramaticClientEndpoint endpoint = new ProgramaticClientEndpoint(); - ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().configurator(new ClientConfigurator(){ - @Override - public void beforeRequest(Map> headers) { - headers.put(AUTHORIZATION.toString(), Collections.singletonList(BASIC + " " + FlexBase64.encodeString("user1:password1".getBytes(), false))); - } - }).build(); + ClientEndpointConfig clientEndpointConfig = ClientEndpointConfig.Builder.create().configurator(new CustomClientConfigurator()).build(); ContainerProvider.getWebSocketContainer().connectToServer(endpoint, clientEndpointConfig, new URI("ws://" + DefaultServer.getHostAddress("default") + ":" + DefaultServer.getHostPort("default") + "/servletContext/secured")); assertEquals("user1", endpoint.getResponses().poll(15, TimeUnit.SECONDS)); endpoint.session.close(); @@ -179,13 +173,7 @@ public static class ProgramaticClientEndpoint extends Endpoint { @Override public void onOpen(Session session, EndpointConfig config) { this.session = session; - session.addMessageHandler(new MessageHandler.Whole() { - - @Override - public void onMessage(String message) { - responses.add(message); - } - }); + session.addMessageHandler(String.class, (message) -> responses.add(message)); } @Override @@ -217,12 +205,7 @@ public void init(FilterConfig filterConfig) { @Override public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain) throws IOException, ServletException { - filterChain.doFilter(new HttpServletRequestWrapper((HttpServletRequest) servletRequest) { - @Override - public Principal getUserPrincipal() { - return () -> "wrapped"; - } - }, servletResponse); + filterChain.doFilter(new ServletRequestWrapper((HttpServletRequest) servletRequest), servletResponse); } @Override @@ -231,4 +214,23 @@ public void destroy() { } } + private static class ServletRequestWrapper extends HttpServletRequestWrapper { + + ServletRequestWrapper(HttpServletRequest request) { + super(request); + } + + @Override + public Principal getUserPrincipal() { + return () -> "wrapped"; + } + } + + private static class CustomClientConfigurator extends ClientConfigurator { + + @Override + public void beforeRequest(Map> headers) { + headers.put(AUTHORIZATION.toString(), Collections.singletonList(BASIC + " " + FlexBase64.encodeString("user1:password1".getBytes(), false))); + } + } }