diff --git a/src/main/java/org/opensearch/security/filter/SecurityFilter.java b/src/main/java/org/opensearch/security/filter/SecurityFilter.java index b9d4a73967..12f7e3c5e0 100644 --- a/src/main/java/org/opensearch/security/filter/SecurityFilter.java +++ b/src/main/java/org/opensearch/security/filter/SecurityFilter.java @@ -102,7 +102,7 @@ public class SecurityFilter implements ActionFilter { protected final Logger log = LogManager.getLogger(this.getClass()); private final PrivilegesEvaluator evalp; private final AdminDNs adminDns; - private DlsFlsRequestValve dlsFlsValve; + private final DlsFlsRequestValve dlsFlsValve; private final AuditLog auditLog; private final ThreadContext threadContext; private final ClusterService cs; @@ -184,7 +184,7 @@ private void ap } if (threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) == null) { - threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, false); + threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, true); } final ComplianceConfig complianceConfig = auditLog.getComplianceConfig(); @@ -255,7 +255,7 @@ private void ap ); threadContext.putHeader( - "_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID().toString(), + "_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID(), Thread.currentThread().getName() + " FILTER -> " + "Node " @@ -481,11 +481,7 @@ public void onFailure(Exception e) { } private static boolean isUserAdmin(User user, final AdminDNs adminDns) { - if (user != null && adminDns.isAdmin(user)) { - return true; - } - - return false; + return user != null && adminDns.isAdmin(user); } private void attachSourceFieldContext(ActionRequest request) { diff --git a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java index 39312e29ad..8b0b35a8a0 100644 --- a/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java +++ b/src/main/java/org/opensearch/security/ssl/transport/SecuritySSLRequestHandler.java @@ -46,6 +46,8 @@ import io.netty.handler.ssl.SslHandler; +import static org.opensearch.security.support.Base64Helper.shouldUseJDKSerialization; + public class SecuritySSLRequestHandler implements TransportRequestHandler { private final String action; @@ -94,10 +96,7 @@ public final void messageReceived(T request, TransportChannel channel, Task task channel = getInnerChannel(channel); } - threadContext.putTransient( - ConfigConstants.USE_JDK_SERIALIZATION, - channel.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION) - ); + threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, shouldUseJDKSerialization(channel.getVersion())); if (SSLRequestHelper.containsBadHeader(threadContext, "_opendistro_security_ssl_")) { final Exception exception = ExceptionUtils.createBadHeaderException(); diff --git a/src/main/java/org/opensearch/security/support/Base64Helper.java b/src/main/java/org/opensearch/security/support/Base64Helper.java index a5fbab8515..f3872338e1 100644 --- a/src/main/java/org/opensearch/security/support/Base64Helper.java +++ b/src/main/java/org/opensearch/security/support/Base64Helper.java @@ -28,6 +28,8 @@ import java.io.Serializable; +import org.opensearch.Version; + public class Base64Helper { public static String serializeObject(final Serializable object, final boolean useJDKSerialization) { @@ -35,11 +37,11 @@ public static String serializeObject(final Serializable object, final boolean us } public static String serializeObject(final Serializable object) { - return serializeObject(object, false); + return serializeObject(object, true); } public static Serializable deserializeObject(final String string) { - return deserializeObject(string, false); + return deserializeObject(string, true); } public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) { @@ -69,4 +71,32 @@ public static String ensureJDKSerialized(final String string) { // If we see an exception now, we want the caller to see it - return Base64Helper.serializeObject(serializable, true); } + + /** + * Ensures that the returned string is custom serialized. + * + * If the supplied string is a JDK serialized representation, will deserialize it and further serialize using + * custom, otherwise returns the string as is. + * + * @param string original string, can be JDK or custom serialized + * @return custom serialized string + */ + public static String ensureCustomSerialized(final String string) { + Serializable serializable; + try { + serializable = Base64Helper.deserializeObject(string, true); + } catch (Exception e) { + // We received an exception when de-serializing the given string. It is probably custom serialized. + // Try to deserialize using custom + Base64Helper.deserializeObject(string, false); + // Since we could deserialize the object using custom, the string is already custom serialized, return as is + return string; + } + // If we see an exception now, we want the caller to see it - + return Base64Helper.serializeObject(serializable, false); + } + + public static boolean shouldUseJDKSerialization(Version remoteVersion) { + return !remoteVersion.equals(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); + } } diff --git a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java index fe1094c411..ec26a70192 100644 --- a/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java +++ b/src/main/java/org/opensearch/security/transport/SecurityInterceptor.java @@ -72,12 +72,13 @@ import org.opensearch.transport.TransportResponseHandler; import static org.opensearch.security.OpenSearchSecurityPlugin.isActionTraceEnabled; +import static org.opensearch.security.support.Base64Helper.shouldUseJDKSerialization; public class SecurityInterceptor { protected final Logger log = LogManager.getLogger(getClass()); - private BackendRegistry backendRegistry; - private AuditLog auditLog; + private final BackendRegistry backendRegistry; + private final AuditLog auditLog; private final ThreadPool threadPool; private final PrincipalExtractor principalExtractor; private final InterClusterRequestEvaluator requestEvalProvider; @@ -148,7 +149,7 @@ public void sendRequestDecorate( final String origCCSTransientMf = getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_MASKED_FIELD_CCS); final boolean isDebugEnabled = log.isDebugEnabled(); - final boolean useJDKSerialization = connection.getVersion().before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); + final boolean useJDKSerialization = shouldUseJDKSerialization(connection.getVersion()); final boolean isSameNodeRequest = localNode != null && localNode.equals(connection.getNode()); try (ThreadContext.StoredContext stashedContext = getThreadContext().stashContext()) { @@ -226,13 +227,13 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL ); } - if (useJDKSerialization) { - Map jdkSerializedHeaders = new HashMap<>(); + if (!useJDKSerialization) { + Map customSerializedHeaders = new HashMap<>(); HeaderHelper.getAllSerializedHeaderNames() .stream() .filter(k -> headerMap.get(k) != null) - .forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k)))); - headerMap.putAll(jdkSerializedHeaders); + .forEach(k -> customSerializedHeaders.put(k, Base64Helper.ensureCustomSerialized(headerMap.get(k)))); + headerMap.putAll(customSerializedHeaders); } getThreadContext().putHeader(headerMap); @@ -249,7 +250,7 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL if (isActionTraceEnabled()) { getThreadContext().putHeader( - "_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID().toString(), + "_opendistro_security_trace" + System.currentTimeMillis() + "#" + UUID.randomUUID(), Thread.currentThread().getName() + " IC -> " + action diff --git a/src/test/java/org/opensearch/security/support/Base64HelperTest.java b/src/test/java/org/opensearch/security/support/Base64HelperTest.java index 3bc81aaebc..70f7a1e4d1 100644 --- a/src/test/java/org/opensearch/security/support/Base64HelperTest.java +++ b/src/test/java/org/opensearch/security/support/Base64HelperTest.java @@ -38,6 +38,11 @@ public void testSerde() { String test = "string"; Assert.assertEquals(test, ds(test)); Assert.assertEquals(test, dsJDK(test)); + + // verify that default methods use JDK serialization + Assert.assertEquals(serializeObject(test), serializeObject(test, true)); + String serialized = serializeObject(test); + Assert.assertEquals(deserializeObject(serialized), deserializeObject(serialized, true)); } @Test @@ -48,4 +53,13 @@ public void testEnsureJDKSerialized() { Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(jdkSerialized)); Assert.assertEquals(jdkSerialized, Base64Helper.ensureJDKSerialized(customSerialized)); } + + @Test + public void testEnsureCustomSerialized() { + String test = "string"; + String jdkSerialized = Base64Helper.serializeObject(test, true); + String customSerialized = Base64Helper.serializeObject(test, false); + Assert.assertEquals(customSerialized, Base64Helper.ensureCustomSerialized(jdkSerialized)); + Assert.assertEquals(customSerialized, Base64Helper.ensureCustomSerialized(customSerialized)); + } } diff --git a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java index 903ad89eac..ae4edccc92 100644 --- a/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java +++ b/src/test/java/org/opensearch/security/transport/SecurityInterceptorTests.java @@ -21,6 +21,7 @@ import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; +import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.common.transport.TransportAddress; import org.opensearch.core.transport.TransportResponse; import org.opensearch.extensions.ExtensionsManager; @@ -51,6 +52,7 @@ import static java.util.Collections.emptySet; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -108,8 +110,7 @@ public void setup() { ); } - private void testSendRequestDecorate(Version remoteNodeVersion) { - boolean useJDKSerialization = remoteNodeVersion.before(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); + private void testSendRequestDecorate(DiscoveryNode localNode, DiscoveryNode otherNode, boolean shouldUseJDKSerialization) { ClusterName clusterName = ClusterName.DEFAULT; when(clusterService.getClusterName()).thenReturn(clusterName); @@ -143,17 +144,7 @@ private void testSendRequestDecorate(Version remoteNodeVersion) { @SuppressWarnings("unchecked") TransportResponseHandler handler = mock(TransportResponseHandler.class); - InetAddress localAddress = null; - try { - localAddress = InetAddress.getByName("0.0.0.0"); - } catch (final UnknownHostException uhe) { - throw new RuntimeException(uhe); - } - - DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(localAddress, 1234), Version.CURRENT); Connection connection1 = transportService.getConnection(localNode); - - DiscoveryNode otherNode = new DiscoveryNode("remote-node", new TransportAddress(localAddress, 4321), remoteNodeVersion); Connection connection2 = transportService.getConnection(otherNode); // from thread context inside sendRequestDecorate @@ -176,7 +167,7 @@ public void sendRequest( // from original context User transientUser = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); assertEquals(transientUser, user); - assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null); + assertNull(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER)); // checking thread context inside sendRequestDecorate sender = new AsyncSender() { @@ -189,7 +180,7 @@ public void sendRequest( TransportResponseHandler handler ) { String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER); - assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, useJDKSerialization)); + assertEquals(serializedUserHeader, Base64Helper.serializeObject(user, shouldUseJDKSerialization)); } }; // isSameNodeRequest = false @@ -198,20 +189,52 @@ public void sendRequest( // from original context User transientUser2 = threadPool.getThreadContext().getTransient(ConfigConstants.OPENDISTRO_SECURITY_USER); assertEquals(transientUser2, user); - assertEquals(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER), null); + assertNull(threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER)); } + /** + * Tests the scenario when remote node is on same OS version + */ @Test public void testSendRequestDecorate() { - testSendRequestDecorate(Version.CURRENT); + DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT); + DiscoveryNode otherNode = new DiscoveryNode("other-node", new TransportAddress(getLocalAddress(), 3456), Version.CURRENT); + testSendRequestDecorate(localNode, otherNode, true); } /** - * Tests the scenario when remote node does not implement custom serialization protocol and uses JDK serialization + * Tests the scenarios for mixed node versions */ @Test - public void testSendRequestDecorateWhenRemoteNodeUsesJDKSerde() { - testSendRequestDecorate(Version.V_2_0_0); + public void testSendRequestDecorateWithMixedNodeVersions() { + + // local on latest version, remote on 2.11.0 - should use custom + + try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) { + DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT); + DiscoveryNode otherNode = new DiscoveryNode( + "other-node", + new TransportAddress(getLocalAddress(), 3456), + ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION + ); + testSendRequestDecorate(localNode, otherNode, false); + } + + // remote node is on a version > 2.11.1 while local node is on version 2.11.1 - should use JDK + try (ThreadContext.StoredContext ignore = threadPool.getThreadContext().stashContext()) { + DiscoveryNode localNode = new DiscoveryNode("local-node", new TransportAddress(getLocalAddress(), 1234), Version.CURRENT); + DiscoveryNode otherNode = new DiscoveryNode("other-node", new TransportAddress(getLocalAddress(), 3456), Version.V_2_11_1); + testSendRequestDecorate(localNode, otherNode, true); + } + + } + + private static InetAddress getLocalAddress() { + try { + return InetAddress.getByName("0.0.0.0"); + } catch (final UnknownHostException uhe) { + throw new RuntimeException(uhe); + } } } diff --git a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java index ba791c2494..b63fff45cd 100644 --- a/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java +++ b/src/test/java/org/opensearch/security/transport/SecuritySSLRequestHandlerTests.java @@ -89,9 +89,15 @@ public void testUseJDKSerializationHeaderIsSetOnMessageReceived() throws Excepti Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); threadPool.getThreadContext().stashContext(); - when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0); + when(transportChannel.getVersion()).thenReturn(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.CURRENT); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, transportChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + } @Test @@ -108,9 +114,14 @@ public void testUseJDKSerializationHeaderIsSetWithWrapperChannel() throws Except Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); threadPool.getThreadContext().stashContext(); - when(transportChannel.getVersion()).thenReturn(Version.V_2_11_0); + when(transportChannel.getVersion()).thenReturn(ConfigConstants.FIRST_CUSTOM_SERIALIZATION_SUPPORTED_OS_VERSION); Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); Assert.assertFalse(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); + + threadPool.getThreadContext().stashContext(); + when(transportChannel.getVersion()).thenReturn(Version.CURRENT); + Assert.assertThrows(Exception.class, () -> securitySSLRequestHandler.messageReceived(transportRequest, wrappedChannel, task)); + Assert.assertTrue(threadPool.getThreadContext().getTransient(ConfigConstants.USE_JDK_SERIALIZATION)); } @Test @@ -135,7 +146,7 @@ public void testUseJDKSerializationHeaderIsSetAfterGetInnerChannel() throws Exce public class WrappedTransportChannel implements TransportChannel { - private TransportChannel inner; + private final TransportChannel inner; public WrappedTransportChannel(TransportChannel inner) { this.inner = inner;