diff --git a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryResponse.java b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryResponse.java index 169777e8dc..a8d06fa626 100644 --- a/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryResponse.java +++ b/plugin/src/main/java/org/opensearch/sql/plugin/transport/TransportPPLQueryResponse.java @@ -5,10 +5,15 @@ package org.opensearch.sql.plugin.transport; +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; import java.io.IOException; +import java.io.UncheckedIOException; import lombok.Getter; import lombok.RequiredArgsConstructor; import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.InputStreamStreamInput; +import org.opensearch.core.common.io.stream.OutputStreamStreamOutput; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; @@ -25,4 +30,22 @@ public TransportPPLQueryResponse(StreamInput in) throws IOException { public void writeTo(StreamOutput out) throws IOException { out.writeString(result); } + + public static TransportPPLQueryResponse fromActionResponse(ActionResponse actionResponse) { + if (actionResponse instanceof TransportPPLQueryResponse) { + return (TransportPPLQueryResponse) actionResponse; + } + + try (ByteArrayOutputStream baos = new ByteArrayOutputStream(); + OutputStreamStreamOutput osso = new OutputStreamStreamOutput(baos)) { + actionResponse.writeTo(osso); + try (StreamInput input = + new InputStreamStreamInput(new ByteArrayInputStream(baos.toByteArray()))) { + return new TransportPPLQueryResponse(input); + } + } catch (IOException e) { + throw new UncheckedIOException( + "failed to parse ActionResponse into TransportPPLQueryResponse", e); + } + } } diff --git a/plugin/src/test/java/org/opensearch/sql/plugin/transport/TransportPPLQueryResponseTest.java b/plugin/src/test/java/org/opensearch/sql/plugin/transport/TransportPPLQueryResponseTest.java new file mode 100644 index 0000000000..ba065c1ebe --- /dev/null +++ b/plugin/src/test/java/org/opensearch/sql/plugin/transport/TransportPPLQueryResponseTest.java @@ -0,0 +1,82 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.plugin.transport; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; + +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.net.URI; +import java.net.URISyntaxException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.Paths; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; +import org.opensearch.core.action.ActionResponse; + +public class TransportPPLQueryResponseTest { + + @Rule public ExpectedException exceptionRule = ExpectedException.none(); + + @Test + public void testFromActionResponseSameClassloader() { + TransportPPLQueryResponse response1 = new TransportPPLQueryResponse("mock result"); + TransportPPLQueryResponse response2 = TransportPPLQueryResponse.fromActionResponse(response1); + assertEquals(response1.getResult(), response2.getResult()); + } + + @Test + public void testFromActionResponseDifferentClassLoader() + throws ClassNotFoundException, + InstantiationException, + IllegalAccessException, + NoSuchMethodException, + InvocationTargetException, + URISyntaxException { + ClassLoader loader = TransportPPLQueryResponseTest.class.getClassLoader(); + URI resourceURI = + loader + .getResource("org/opensearch/sql/plugin/transport/TransportPPLQueryResponse.class") + .toURI(); + Path classFilePath = Paths.get(resourceURI); + CustomClassLoader classLoader1 = new CustomClassLoader(classFilePath); + CustomClassLoader classLoader2 = new CustomClassLoader(classFilePath); + + Class class1 = + classLoader1.findClass("org.opensearch.sql.plugin.transport.TransportPPLQueryResponse"); + Class class2 = + classLoader2.findClass("org.opensearch.sql.plugin.transport.TransportPPLQueryResponse"); + + assertFalse(class1.isAssignableFrom(class2)); + String result = "mock result"; + TransportPPLQueryResponse response2 = + TransportPPLQueryResponse.fromActionResponse( + (ActionResponse) class1.getDeclaredConstructor(String.class).newInstance(result)); + assertEquals(result, response2.getResult()); + } +} + +class CustomClassLoader extends ClassLoader { + + private final Path classFilePath; + + public CustomClassLoader(Path classFilePath) { + this.classFilePath = classFilePath; + } + + @Override + protected Class findClass(String name) throws ClassNotFoundException { + try { + byte[] classBytes = Files.readAllBytes(classFilePath); + return defineClass(name, classBytes, 0, classBytes.length); + } catch (IOException e) { + throw new ClassNotFoundException("Failed to load class: " + name, e); + } + } +}