diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java index 78c98dd99f..39312e29ad 100644 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java +++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java @@ -21,6 +21,7 @@ import java.security.cert.Certificate; import java.security.cert.X509Certificate; import java.util.Arrays; +import java.util.Set; import javax.net.ssl.SSLPeerUnverifiedException; import org.apache.logging.log4j.LogManager; @@ -55,6 +56,8 @@ public class SecuritySSLRequestHandler implements Tr private final SslExceptionHandler errorHandler; private final SSLConfig SSLConfig; + private static final Set DEFAULT_CHANNEL_TYPES = Set.of("direct", "transport"); + public SecuritySSLRequestHandler( String action, TransportRequestHandler actualHandler, @@ -86,6 +89,11 @@ public final void messageReceived(T request, TransportChannel channel, Task task ThreadContext threadContext = getThreadContext(); + String channelType = channel.getChannelType(); + if (!DEFAULT_CHANNEL_TYPES.contains(channelType)) { + channel = getInnerChannel(channel); + } + threadContext.putTransient( ConfigConstants.USE_JDK_SERIALIZATION, channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION) @@ -97,11 +105,6 @@ public final void messageReceived(T request, TransportChannel channel, Task task throw exception; } - String channelType = channel.getChannelType(); - if (!channelType.equals("direct") && !channelType.equals("transport")) { - channel = getInnerChannel(channel); - } - if (!"transport".equals(channel.getChannelType())) { // netty4 messageReceivedDecorate(request, actualHandler, channel, task); return; diff --git a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java index b6967b0e68..2d10b6f84f 100644 --- a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java +++ b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java @@ -9,12 +9,15 @@ */ package org.opensearch.security.transport; +import java.io.IOException; + import org.junit.Assert; import org.junit.Before; import org.junit.Test; import org.opensearch.Version; import org.opensearch.common.settings.Settings; +import org.opensearch.core.transport.TransportResponse; import org.opensearch.security.ssl.SslExceptionHandler; import org.opensearch.security.ssl.transport.PrincipalExtractor; import org.opensearch.security.ssl.transport.SSLConfig; @@ -27,11 +30,13 @@ import org.opensearch.transport.TransportRequestHandler; import org.mockito.ArgumentMatchers; +import org.mockito.InOrder; import org.mockito.Mock; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyInt; import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -93,4 +98,81 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } + + @Test + public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Exception { + TransportRequest transportRequest = mock(TransportRequest.class); + TransportChannel transportChannel = mock(TransportChannel.class); + TransportChannel wrappedChannel = new WrappedTransportChannel(transportChannel); + Task task = mock(Task.class); + doNothing().when(transportChannel).sendResponse(ArgumentMatchers.any(Exception.class)); + when(transportChannel.getVersion()).thenReturn(Version.V_2_10_0); + when(transportChannel.getChannelType()).thenReturn("other"); + + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.V_3_0_0); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + } + + @Test + public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exception { + TransportRequest transportRequest = mock(TransportRequest.class); + TransportChannel transportChannel = mock(TransportChannel.class); + WrappedTransportChannel wrappedChannel = mock(WrappedTransportChannel.class); + Task task = mock(Task.class); + when(wrappedChannel.getInnerChannel()).thenReturn(transportChannel); + when(wrappedChannel.getChannelType()).thenReturn("other"); + doNothing().when(transportChannel).sendResponse(ArgumentMatchers.any(Exception.class)); + when(transportChannel.getVersion()).thenReturn(Version.V_2_10_0); + + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + InOrder inOrder = inOrder(wrappedChannel, transportChannel); + + inOrder.verify(wrappedChannel).getInnerChannel(); + inOrder.verify(transportChannel).getVersion(); + } + + public class WrappedTransportChannel implements TransportChannel { + + private TransportChannel inner; + + public WrappedTransportChannel(TransportChannel inner) { + this.inner = inner; + } + + @Override + public String getProfileName() { + return "WrappedTransportChannelProfileName"; + } + + public TransportChannel getInnerChannel() { + return this.inner; + } + + @Override + public void sendResponse(TransportResponse response) throws IOException { + inner.sendResponse(response); + } + + @Override + public void sendResponse(Exception e) throws IOException { + + } + + @Override + public String getChannelType() { + return "WrappedTransportChannelType"; + } + } }