Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add ensureCustomSerialization to ensure that headers are serialized correctly with multiple transport hops #29

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.Version;
import org.opensearch.cluster.ClusterChangedEvent;
import org.opensearch.cluster.ClusterStateListener;
import org.opensearch.cluster.node.DiscoveryNode;
Expand Down Expand Up @@ -67,6 +68,17 @@ public boolean isInitialized() {
return initialized;
}

public Version getMinNodeVersion() {
if (nodes == null) {
if (log.isDebugEnabled()) {
log.debug("Cluster Info Holder not initialized yet for 'nodes'");
}
return null;
}

return nodes.getMinNodeVersion();
}

public Boolean hasNode(DiscoveryNode node) {
if (nodes == null) {
if (log.isDebugEnabled()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ private <Request extends ActionRequest, Response extends ActionResponse> void ap
}

if (threadContext.getTransient(ConfigConstants.USE_JDK_SERIALIZATION) == null) {
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, false);
threadContext.putTransient(ConfigConstants.USE_JDK_SERIALIZATION, true);
}

final ComplianceConfig complianceConfig = auditLog.getComplianceConfig();
Expand Down
28 changes: 26 additions & 2 deletions src/main/java/org/opensearch/security/support/Base64Helper.java
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ public static String serializeObject(final Serializable object, final boolean us
}

public static String serializeObject(final Serializable object) {
return serializeObject(object, false);
return serializeObject(object, true);
}

public static Serializable deserializeObject(final String string) {
return deserializeObject(string, false);
return deserializeObject(string, true);
}

public static Serializable deserializeObject(final String string, final boolean useJDKDeserialization) {
Expand Down Expand Up @@ -69,4 +69,28 @@ public static String ensureJDKSerialized(final String string) {
// If we see an exception now, we want the caller to see it -
return Base64Helper.serializeObject(serializable, true);
}

/**
* Ensures that the returned string is custom serialized.
*
* If the supplied string is a JDK serialized representation, will deserialize it and further serialize using
* custom, otherwise returns the string as is.
*
* @param string original string, can be JDK or custom serialized
* @return custom serialized string
*/
public static String ensureCustomSerialized(final String string) {
Serializable serializable;
try {
serializable = Base64Helper.deserializeObject(string, true);
} catch (Exception e) {
// We received an exception when de-serializing the given string. It is probably custom serialized.
// Try to deserialize using custom
Base64Helper.deserializeObject(string, false);
// Since we could deserialize the object using custom, the string is already custom serialized, return as is
return string;
}
// If we see an exception now, we want the caller to see it -
return Base64Helper.serializeObject(serializable, false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import org.opensearch.Version;
import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsAction;
import org.opensearch.action.admin.cluster.shards.ClusterSearchShardsResponse;
import org.opensearch.action.get.GetRequest;
Expand Down Expand Up @@ -231,13 +232,22 @@ && getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_INJECTED_ROL
}

try {
if (serializationFormat == SerializationFormat.JDK) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
if (clusterInfoHolder.getMinNodeVersion() == null || clusterInfoHolder.getMinNodeVersion().before(Version.V_2_14_0)) {
if (serializationFormat == SerializationFormat.JDK) {
Map<String, String> jdkSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> jdkSerializedHeaders.put(k, Base64Helper.ensureJDKSerialized(headerMap.get(k))));
headerMap.putAll(jdkSerializedHeaders);
} else if (serializationFormat == SerializationFormat.CustomSerializer_2_11) {
Map<String, String> customSerializedHeaders = new HashMap<>();
HeaderHelper.getAllSerializedHeaderNames()
.stream()
.filter(k -> headerMap.get(k) != null)
.forEach(k -> customSerializedHeaders.put(k, Base64Helper.ensureCustomSerialized(headerMap.get(k))));
headerMap.putAll(customSerializedHeaders);
}
}
getThreadContext().putHeader(headerMap);
} catch (IllegalArgumentException iae) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,15 @@ public void testEnsureJDKSerialized() {
assertThat(Base64Helper.ensureJDKSerialized(customSerialized), is(jdkSerialized));
}

@Test
public void testEnsureCustomSerialized() {
String test = "string";
String jdkSerialized = Base64Helper.serializeObject(test, true);
String customSerialized = Base64Helper.serializeObject(test, false);
assertThat(Base64Helper.ensureCustomSerialized(jdkSerialized), is(customSerialized));
assertThat(Base64Helper.ensureCustomSerialized(customSerialized), is(customSerialized));
}

@Test
public void testDuplicatedItemSizes() {
var largeObject = new HashMap<String, Object>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,12 @@ public class SecurityInterceptorTests {
private Connection connection3;
private DiscoveryNode otherRemoteNode;
private Connection connection4;
private DiscoveryNode remoteNodeWithCustomSerialization;
private Connection connection5;

private AsyncSender sender;
private AsyncSender serializedSender;
private AsyncSender jdkSerializedSender;
private AsyncSender customSerializedSender;
private AtomicReference<CountDownLatch> senderLatch = new AtomicReference<>(new CountDownLatch(1));

@Before
Expand Down Expand Up @@ -199,7 +202,30 @@ public void setup() {
otherRemoteNode = new DiscoveryNode("remote-node2", new TransportAddress(remoteAddress, 9876), remoteNodeVersion);
connection4 = transportService.getConnection(otherRemoteNode);

serializedSender = new AsyncSender() {
remoteNodeWithCustomSerialization = new DiscoveryNode(
"remote-node-with-custom-serialization",
new TransportAddress(localAddress, 7456),
Version.V_2_12_0
);
connection5 = transportService.getConnection(remoteNodeWithCustomSerialization);

jdkSerializedSender = new AsyncSender() {
@Override
public <T extends TransportResponse> void sendRequest(
Connection connection,
String action,
TransportRequest request,
TransportRequestOptions options,
TransportResponseHandler<T> handler
) {
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
User deserializedUser = (User) Base64Helper.deserializeObject(serializedUserHeader, true);
assertThat(deserializedUser, is(user));
senderLatch.get().countDown();
}
};

customSerializedSender = new AsyncSender() {
@Override
public <T extends TransportResponse> void sendRequest(
Connection connection,
Expand All @@ -209,7 +235,7 @@ public <T extends TransportResponse> void sendRequest(
TransportResponseHandler<T> handler
) {
String serializedUserHeader = threadPool.getThreadContext().getHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER);
assertThat(serializedUserHeader, is(Base64Helper.serializeObject(user, true)));
assertThat(serializedUserHeader, is(Base64Helper.serializeObject(user, false)));
senderLatch.get().countDown();
}
};
Expand Down Expand Up @@ -265,6 +291,27 @@ final void completableRequestDecorate(
senderLatch.set(new CountDownLatch(1));
}

@SuppressWarnings({ "rawtypes", "unchecked" })
final void completableRequestDecorateWithPreviouslyPopulatedHeaders(
AsyncSender sender,
Connection connection,
String action,
TransportRequest request,
TransportRequestOptions options,
TransportResponseHandler handler,
DiscoveryNode localNode
) {
securityInterceptor.sendRequestDecorate(sender, connection, action, request, options, handler, localNode);
try {
senderLatch.get().await(1, TimeUnit.SECONDS);
} catch (final InterruptedException e) {
throw new RuntimeException(e);
}

// Reset the latch so another request can be processed
senderLatch.set(new CountDownLatch(1));
}

@Test
public void testSendRequestDecorateLocalConnection() {

Expand All @@ -278,16 +325,44 @@ public void testSendRequestDecorateLocalConnection() {
public void testSendRequestDecorateRemoteConnection() {

// this is a remote request
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
// this is a remote request where the transport address is different
completableRequestDecorate(serializedSender, connection4, action, request, options, handler, localNode);
completableRequestDecorate(jdkSerializedSender, connection4, action, request, options, handler, localNode);
}

@Test
public void testSendRequestDecorateRemoteConnectionUsesJDKSerialization() {
threadPool.getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(user, false));
completableRequestDecorateWithPreviouslyPopulatedHeaders(
jdkSerializedSender,
connection3,
action,
request,
options,
handler,
localNode
);
}

@Test
public void testSendRequestDecorateRemoteConnectionUsesCustomSerialization() {
threadPool.getThreadContext().putHeader(ConfigConstants.OPENDISTRO_SECURITY_USER_HEADER, Base64Helper.serializeObject(user, true));
completableRequestDecorateWithPreviouslyPopulatedHeaders(
customSerializedSender,
connection5,
action,
request,
options,
handler,
localNode
);
}

@Test
public void testSendNoOriginNodeCausesSerialization() {

// this is a request where the local node is null; have to use the remote connection since the serialization will fail
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, null);
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, null);
}

@Test
Expand All @@ -296,7 +371,7 @@ public void testSendNoConnectionShouldThrowNPE() {
// The completable version swallows the NPE so have to call actual method
assertThrows(
java.lang.NullPointerException.class,
() -> securityInterceptor.sendRequestDecorate(serializedSender, null, action, request, options, handler, localNode)
() -> securityInterceptor.sendRequestDecorate(jdkSerializedSender, null, action, request, options, handler, localNode)
);
}

Expand Down Expand Up @@ -328,7 +403,7 @@ public void testCustomRemoteAddressCausesSerialization() {
ConfigConstants.OPENDISTRO_SECURITY_REMOTE_ADDRESS,
String.valueOf(new TransportAddress(new InetSocketAddress("8.8.8.8", 80)))
);
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
}

@Test
Expand All @@ -351,7 +426,7 @@ public void testFakeHeaderIsIgnored() {
// this is a local request
completableRequestDecorate(sender, connection1, action, request, options, handler, localNode);
// this is a remote request
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
}

@Test
Expand All @@ -363,7 +438,7 @@ public void testNullHeaderIsIgnored() {
// this is a local request
completableRequestDecorate(sender, connection1, action, request, options, handler, localNode);
// this is a remote request
completableRequestDecorate(serializedSender, connection3, action, request, options, handler, localNode);
completableRequestDecorate(jdkSerializedSender, connection3, action, request, options, handler, localNode);
}

@Test
Expand Down
Loading