Skip to content

Commit

Permalink
Add ensureCustomSerialization to ensure that headers are serialized c…
Browse files Browse the repository at this point in the history
…orrectly with multiple transport hops (#4741)

Signed-off-by: Craig Perkins <[email protected]>
  • Loading branch information
cwperks authored Sep 19, 2024
1 parent 7ddbf6a commit 8ae88a7
Show file tree
Hide file tree
Showing 6 changed files with 150 additions and 20 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.Version;
import org.opensearch.cluster.ClusterChangedEvent;
import org.opensearch.cluster.ClusterStateListener;
import org.opensearch.cluster.node.DiscoveryNode;
Expand Down Expand Up @@ -67,6 +68,17 @@ public boolean isInitialized() {
return initialized;
}

public Version getMinNodeVersion() {
if (nodes == null) {
if (log.isDebugEnabled()) {
log.debug("Cluster Info Holder not initialized yet for 'nodes'");
}
return null;
}

return nodes.getMinNodeVersion();
}

public Boolean hasNode(DiscoveryNode node) {
if (nodes == null) {
if (log.isDebugEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ private <Request extends ActionRequest, Response extends ActionResponse> void ap
}

if (threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) == null) {
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, false);
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, true);
}

final ComplianceConfig complianceConfig = auditLog.getComplianceConfig();
Expand Down
28 changes: 26 additions & 2 deletions src/main/java/org/opensearch/security/support/Base64Helper.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ public static String serializeObject(final Serializable object, final boolean us
}

public static String serializeObject(final Serializable object) {
return serializeObject(object, false);
return serializeObject(object, true);
}

public static Serializable deserializeObject(final String string) {
return deserializeObject(string, false);
return deserializeObject(string, true);
}

public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) {
Expand Down Expand Up @@ -69,4 +69,28 @@ public static String ensureJDKSerialized(final String string) {
// If we see an exception now, we want the caller to see it -
return Base64Helper.serializeObject(serializable, true);
}

/**
* Ensures that the returned string is custom serialized.
*
* If the supplied string is a JDK serialized representation, will deserialize it and further serialize using
* custom, otherwise returns the string as is.
*
* @param string original string, can be JDK or custom serialized
* @return custom serialized string
*/
public static String ensureCustomSerialized(final String string) {
Serializable serializable;
try {
serializable = Base64Helper.deserializeObject(string, true);
} catch (Exception e) {
// We received an exception when de-serializing the given string. It is probably custom serialized.
// Try to deserialize using custom
Base64Helper.deserializeObject(string, false);
// Since we could deserialize the object using custom, the string is already custom serialized, return as is
return string;
}
// If we see an exception now, we want the caller to see it -
return Base64Helper.serializeObject(serializable, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.Version;
import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsAction;
import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
import org.opensearch.action.get.GetRequest;
Expand Down Expand Up @@ -231,13 +232,22 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL
}

try {
if (serializationFormat == SerializationFormat.JDK) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
if (clusterInfoHolder.getMinNodeVersion() == null || clusterInfoHolder.getMinNodeVersion().before(Version.V_2_14_0)) {
if (serializationFormat == SerializationFormat.JDK) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
} else if (serializationFormat == SerializationFormat.CustomSerializer_2_11) {
Map<String, String> customSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> customSerializedHeaders.put(k, Base64Helper.ensureCustomSerialized(headerMap.get(k))));
headerMap.putAll(customSerializedHeaders);
}
}
getThreadContext().putHeader(headerMap);
} catch (IllegalArgumentException iae) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ public void testEnsureJDKSerialized() {
assertThat(Base64Helper.ensureJDKSerialized(customSerialized), is(jdkSerialized));
}

@Test
public void testEnsureCustomSerialized() {
String test = "string";
String jdkSerialized = Base64Helper.serializeObject(test, true);
String customSerialized = Base64Helper.serializeObject(test, false);
assertThat(Base64Helper.ensureCustomSerialized(jdkSerialized), is(customSerialized));
assertThat(Base64Helper.ensureCustomSerialized(customSerialized), is(customSerialized));
}

@Test
public void testDuplicatedItemSizes() {
var largeObject = new HashMap<String, Object>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,12 @@ public class SecurityInterceptorTests {
private Connection connection3;
private DiscoveryNode otherRemoteNode;
private Connection connection4;
private DiscoveryNode remoteNodeWithCustomSerialization;
private Connection connection5;

private AsyncSender sender;
private AsyncSender serializedSender;
private AsyncSender jdkSerializedSender;
private AsyncSender customSerializedSender;
private AtomicReference<CountDownLatch> senderLatch = new AtomicReference<>(new CountDownLatch(1));

@Before
Expand Down Expand Up @@ -199,7 +202,30 @@ public void setup() {
otherRemoteNode = new DiscoveryNode("remote-node2", new TransportAddress(remoteAddress, 9876), remoteNodeVersion);
connection4 = transportService.getConnection(otherRemoteNode);

serializedSender = new AsyncSender() {
remoteNodeWithCustomSerialization = new DiscoveryNode(
"remote-node-with-custom-serialization",
new TransportAddress(localAddress, 7456),
Version.V_2_12_0
);
connection5 = transportService.getConnection(remoteNodeWithCustomSerialization);

jdkSerializedSender = new AsyncSender() {
@Override
public <T extends TransportResponse> void sendRequest(
Connection connection,
String action,
TransportRequest request,
TransportRequestOptions options,
TransportResponseHandler<T> handler
) {
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
User deserializedUser = (User) Base64Helper.deserializeObject(serializedUserHeader, true);
assertThat(deserializedUser, is(user));
senderLatch.get().countDown();
}
};

customSerializedSender = new AsyncSender() {
@Override
public <T extends TransportResponse> void sendRequest(
Connection connection,
Expand All @@ -209,7 +235,7 @@ public <T extends TransportResponse> void sendRequest(
TransportResponseHandler<T> handler
) {
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
assertThat(serializedUserHeader, is(Base64Helper.serializeObject(user, true)));
assertThat(serializedUserHeader, is(Base64Helper.serializeObject(user, false)));
senderLatch.get().countDown();
}
};
Expand Down Expand Up @@ -265,6 +291,27 @@ final void completableRequestDecorate(
senderLatch.set(new CountDownLatch(1));
}

@SuppressWarnings({ "rawtypes", "unchecked" })
final void completableRequestDecorateWithPreviouslyPopulatedHeaders(
AsyncSender sender,
Connection connection,
String action,
TransportRequest request,
TransportRequestOptions options,
TransportResponseHandler handler,
DiscoveryNode localNode
) {
securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode);
try {
senderLatch.get().await(1, TimeUnit.SECONDS);
} catch (final InterruptedException e) {
throw new RuntimeException(e);
}

// Reset the latch so another request can be processed
senderLatch.set(new CountDownLatch(1));
}

@Test
public void testSendRequestDecorateLocalConnection() {

Expand All @@ -278,16 +325,44 @@ public void testSendRequestDecorateLocalConnection() {
public void testSendRequestDecorateRemoteConnection() {

// this is a remote request
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
// this is a remote request where the transport address is different
completableRequestDecorate(serializedSender, connection4, action, request, options, handler, localNode);
completableRequestDecorate(jdkSerializedSender, connection4, action, request, options, handler, localNode);
}

@Test
public void testSendRequestDecorateRemoteConnectionUsesJDKSerialization() {
threadPool.getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(user, false));
completableRequestDecorateWithPreviouslyPopulatedHeaders(
jdkSerializedSender,
connection3,
action,
request,
options,
handler,
localNode
);
}

@Test
public void testSendRequestDecorateRemoteConnectionUsesCustomSerialization() {
threadPool.getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(user, true));
completableRequestDecorateWithPreviouslyPopulatedHeaders(
customSerializedSender,
connection5,
action,
request,
options,
handler,
localNode
);
}

@Test
public void testSendNoOriginNodeCausesSerialization() {

// this is a request where the local node is null; have to use the remote connection since the serialization will fail
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, null);
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, null);
}

@Test
Expand All @@ -296,7 +371,7 @@ public void testSendNoConnectionShouldThrowNPE() {
// The completable version swallows the NPE so have to call actual method
assertThrows(
java.lang.NullPointerException.class,
() -> securityInterceptor.sendRequestDecorate(serializedSender, null, action, request, options, handler, localNode)
() -> securityInterceptor.sendRequestDecorate(jdkSerializedSender, null, action, request, options, handler, localNode)
);
}

Expand Down Expand Up @@ -328,7 +403,7 @@ public void testCustomRemoteAddressCausesSerialization() {
ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS,
String.valueOf(new TransportAddress(new InetSocketAddress("8.8.8.8", 80)))
);
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
}

@Test
Expand All @@ -351,7 +426,7 @@ public void testFakeHeaderIsIgnored() {
// this is a local request
completableRequestDecorate(sender, connection1, action, request, options, handler, localNode);
// this is a remote request
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
}

@Test
Expand All @@ -363,7 +438,7 @@ public void testNullHeaderIsIgnored() {
// this is a local request
completableRequestDecorate(sender, connection1, action, request, options, handler, localNode);
// this is a remote request
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
}

@Test
Expand Down

0 comments on commit 8ae88a7

Please sign in to comment.