From 72841eeffedaaba241e45060c94024e63b22c467 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Fri, 24 Mar 2023 19:38:51 +0000 Subject: [PATCH] Execute remote actions on another extension (#588) * Add ProxyAction with TransportAction and handlers Signed-off-by: Daniel Widdis * Give SDKActionModule a copy of ExtensionsRunner to use with transport Signed-off-by: Daniel Widdis * Add new ProxyActionRequest Signed-off-by: Daniel Widdis * Add SDKTransportService wrapper accessible to actions Signed-off-by: Daniel Widdis * Implement ProxyTransportAction Signed-off-by: Daniel Widdis * Add test case to HelloWorldExtension Signed-off-by: Daniel Widdis * Better naming of ExtensionActionResponse and correct action name Signed-off-by: Daniel Widdis * Refactoring with TransportService and latest OpenSearch PR updates Signed-off-by: Daniel Widdis * Add ExtensionsActionRequestHandler Signed-off-by: Daniel Widdis * Instantiate Proxy Action Request Signed-off-by: Daniel Widdis * Working test case! Signed-off-by: Daniel Widdis * Properly parse returned byte array into a response Signed-off-by: Daniel Widdis * Add sequence diagram to DESIGN.md Signed-off-by: Daniel Widdis * Typoo fix Signed-off-by: Daniel Widdis * Update with latest changes on companion PR Signed-off-by: Daniel Widdis * Rename ProxyFoo to RemoteExtensionFoo Signed-off-by: Daniel Widdis * Better handling of response bytes Signed-off-by: Daniel Widdis * Handle plugin remote action requests Signed-off-by: Daniel Widdis * Address code review comments Signed-off-by: Daniel Widdis * Update sequence diagram Signed-off-by: Daniel Widdis --------- Signed-off-by: Daniel Widdis (cherry picked from commit ebc684ae223e87f1e05712e086fc1dab911a3b5f) Signed-off-by: github-actions[bot] --- DESIGN.md | 6 + Docs/RemoteActionExecution.svg | 1 + .../org/opensearch/sdk/BaseExtension.java | 8 +- .../java/org/opensearch/sdk/Extension.java | 7 + .../org/opensearch/sdk/ExtensionsRunner.java | 45 +++++- .../java/org/opensearch/sdk/SDKClient.java | 16 ++ .../opensearch/sdk/SDKTransportService.java | 132 ++++++++++++++++ .../sdk/action/RemoteExtensionAction.java | 32 ++++ .../action/RemoteExtensionActionRequest.java | 137 +++++++++++++++++ .../RemoteExtensionTransportAction.java | 57 +++++++ .../sdk/action/SDKActionModule.java | 72 +++------ .../ExtensionActionRequestHandler.java | 144 ++++++++++++++++++ .../ExtensionActionResponseHandler.java | 84 ++++++++++ .../ExtensionsInitRequestHandler.java | 13 +- .../helloworld/HelloWorldExtension.java | 4 +- .../rest/RestRemoteHelloAction.java | 94 ++++++++++++ .../sdk/TestExtensionInterfaces.java | 4 +- .../opensearch/sdk/TestExtensionsRunner.java | 2 +- .../sdk/TestSDKTransportService.java | 104 +++++++++++++ .../sdk/action/TestProxyActionRequest.java | 98 ++++++++++++ .../sdk/action/TestSDKActionModule.java | 73 ++------- .../helloworld/TestHelloWorldExtension.java | 16 +- 22 files changed, 1017 insertions(+), 132 deletions(-) create mode 100644 Docs/RemoteActionExecution.svg create mode 100644 src/main/java/org/opensearch/sdk/SDKTransportService.java create mode 100644 src/main/java/org/opensearch/sdk/action/RemoteExtensionAction.java create mode 100644 src/main/java/org/opensearch/sdk/action/RemoteExtensionActionRequest.java create mode 100644 src/main/java/org/opensearch/sdk/action/RemoteExtensionTransportAction.java create mode 100644 src/main/java/org/opensearch/sdk/handlers/ExtensionActionRequestHandler.java create mode 100644 src/main/java/org/opensearch/sdk/handlers/ExtensionActionResponseHandler.java create mode 100644 src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestRemoteHelloAction.java create mode 100644 src/test/java/org/opensearch/sdk/TestSDKTransportService.java create mode 100644 src/test/java/org/opensearch/sdk/action/TestProxyActionRequest.java diff --git a/DESIGN.md b/DESIGN.md index 130be287..8f9db10b 100644 --- a/DESIGN.md +++ b/DESIGN.md @@ -119,6 +119,12 @@ The `ExtensionsManager` reads a list of extensions present in `extensions.yml`. (27) The User receives the response. +#### Remote Action Execution on another Extension + +Extensions may invoke actions on other extensions using the `ProxyAction` and `ProxyActionRequest`. The code sequence is shown below. + +![](Docs/RemoteActionExecution.svg) + #### Extension Point Implementation Walk Through An example of a more complex extension point, `getNamedXContent()` is shown below. A similar pattern can be followed for most extension points. diff --git a/Docs/RemoteActionExecution.svg b/Docs/RemoteActionExecution.svg new file mode 100644 index 00000000..e7c9c3eb --- /dev/null +++ b/Docs/RemoteActionExecution.svg @@ -0,0 +1 @@ +title%20Remote%20Action%20Execution%0A%0Aparticipant%20%22Local%20Extension%22%20as%20a%0Aparticipant%20%22OpenSearch%22%20as%20os%0Aparticipant%20%22Remote%20Extension%22%20as%20b%0A%0Aos%3C-b%3AIterate%20getActions()%20and%20register%5Cnaction%20handlers%20in%20OpenSearch%20using%5CnExtensionTransportActionsHandler%5CnregisterAction()%20method%0Aos-%3Eos%3Aregister%20actions%20in%20ActionModule%5CnDynamicActionRegistry%20linking%5CnAction%20class%20name%20to%20extension%0A%0Aa-%3Ea%3ACreate%20RemoteExtensionActionRequest%20with%5CnRemote%20Extension%20ActionType%20class%2C%5CnActionRequest%20class%2C%20and%20bytes%20from%5CnActionRequest%20param%20serialization%0A%0Aa-%3Ea%3Aclient.execute()%20invokes%20doExecute()%5Cnon%20RemoteExtensionTransportAction%0A%0Aa-%3Eos%3AsendRemoteExtensionActionRequest()%20method%20in%5CnSDKTransportService%20generates%5CnTransportActionRequestFromExtension%5Cnand%20sends%20to%20OpenSearch%5CnExtensionTransportActionsHandler%0A%0Aos-%3Eos%3AhandleTransportActionRequestFromExtension()%5Cnidentifies%20remote%20Extension%20node%20and%20calls%5Cnclient.execute()%20on%20Dynamic%20action%20from%5CnDynamicActionRegistry%0A%0Aos-%3Eb%3AsendTransportRequestToExtension()%5Cngenerates%20ExtensionActionRequest%5Cnand%20sends%20to%20the%20remote%20extension%0A%0Ab%3C-b%3AExtensionActionRequestHandler%5Cnreconstructs%20ActionType%20instance%5Cnand%20ActionRequest%20class%20and%5Cninvokes%20client.execute()%20on%20the%5Cndesired%20remote%20action%0A%0Aos%3C-b%3ARemoteExtensionActionResponse%20handled%5Cnby%20sendTransportRequestToExtension()%0A%0Aa%3C-os%3ARemoteExtensionActionResponse%20handled%5Cnby%20sendRemoteExtensionActionRequest()%0A%0Aa-%3Ea%3AResponse%20bytes%20deserialized%20based%20on%20remote%5Cnextension's%20ActionResponse%20serializationRemote Action ExecutionLocal ExtensionOpenSearchRemote ExtensionIterate getActions() and registeraction handlers in OpenSearch usingExtensionTransportActionsHandlerregisterAction() methodregister actions in ActionModuleDynamicActionRegistry linkingAction class name to extensionCreate RemoteExtensionActionRequest withRemote Extension ActionType class,ActionRequest class, and bytes fromActionRequest param serializationclient.execute() invokes doExecute()on RemoteExtensionTransportActionsendRemoteExtensionActionRequest() method inSDKTransportService generatesTransportActionRequestFromExtensionand sends to OpenSearchExtensionTransportActionsHandlerhandleTransportActionRequestFromExtension()identifies remote Extension node and callsclient.execute() on Dynamic action fromDynamicActionRegistrysendTransportRequestToExtension()generates ExtensionActionRequestand sends to the remote extensionExtensionActionRequestHandlerreconstructs ActionType instanceand ActionRequest class andinvokes client.execute() on thedesired remote actionRemoteExtensionActionResponse handledby sendTransportRequestToExtension()RemoteExtensionActionResponse handledby sendRemoteExtensionActionRequest()Response bytes deserialized based on remoteextension's ActionResponse serialization diff --git a/src/main/java/org/opensearch/sdk/BaseExtension.java b/src/main/java/org/opensearch/sdk/BaseExtension.java index 0a109c8e..fde742d5 100644 --- a/src/main/java/org/opensearch/sdk/BaseExtension.java +++ b/src/main/java/org/opensearch/sdk/BaseExtension.java @@ -11,8 +11,6 @@ import java.io.IOException; -import com.google.inject.Inject; - /** * An abstract class that simplifies extension initialization and provides an instance of the runner. */ @@ -20,7 +18,6 @@ public abstract class BaseExtension implements Extension { /** * The {@link ExtensionsRunner} instance running this extension */ - @Inject private ExtensionsRunner extensionsRunner; /** @@ -56,6 +53,11 @@ public ExtensionSettings getExtensionSettings() { return this.settings; } + @Override + public void setExtensionsRunner(ExtensionsRunner runner) { + this.extensionsRunner = runner; + } + /** * Gets the {@link ExtensionsRunner} of this extension. * diff --git a/src/main/java/org/opensearch/sdk/Extension.java b/src/main/java/org/opensearch/sdk/Extension.java index cd7bdc8a..980f91cc 100644 --- a/src/main/java/org/opensearch/sdk/Extension.java +++ b/src/main/java/org/opensearch/sdk/Extension.java @@ -25,6 +25,13 @@ */ public interface Extension { + /** + * Set the instance of {@link ExtensionsRunner} for this extension. + * + * @param runner The ExtensionsRunner instance. + */ + public void setExtensionsRunner(ExtensionsRunner runner); + /** * Gets the {@link ExtensionSettings} of this extension. * diff --git a/src/main/java/org/opensearch/sdk/ExtensionsRunner.java b/src/main/java/org/opensearch/sdk/ExtensionsRunner.java index 4e1b034a..5c1119b2 100644 --- a/src/main/java/org/opensearch/sdk/ExtensionsRunner.java +++ b/src/main/java/org/opensearch/sdk/ExtensionsRunner.java @@ -25,6 +25,7 @@ import org.opensearch.extensions.DiscoveryExtensionNode; import org.opensearch.extensions.AddSettingsUpdateConsumerRequest; import org.opensearch.extensions.UpdateSettingsRequest; +import org.opensearch.extensions.action.ExtensionActionRequest; import org.opensearch.extensions.ExtensionsManager.RequestType; import org.opensearch.extensions.ExtensionRequest; import org.opensearch.extensions.ExtensionsManager; @@ -33,6 +34,7 @@ import org.opensearch.sdk.handlers.ClusterSettingsResponseHandler; import org.opensearch.sdk.handlers.ClusterStateResponseHandler; import org.opensearch.sdk.handlers.EnvironmentSettingsResponseHandler; +import org.opensearch.sdk.handlers.ExtensionActionRequestHandler; import org.opensearch.sdk.action.SDKActionModule; import org.opensearch.sdk.handlers.AcknowledgedResponseHandler; import org.opensearch.sdk.handlers.ExtensionDependencyResponseHandler; @@ -132,13 +134,15 @@ public class ExtensionsRunner { private final SDKNamedXContentRegistry sdkNamedXContentRegistry; private final SDKClient sdkClient; private final SDKClusterService sdkClusterService; + private final SDKTransportService sdkTransportService; private final SDKActionModule sdkActionModule; - private ExtensionsInitRequestHandler extensionsInitRequestHandler = new ExtensionsInitRequestHandler(this); - private ExtensionsIndicesModuleRequestHandler extensionsIndicesModuleRequestHandler = new ExtensionsIndicesModuleRequestHandler(); - private ExtensionsIndicesModuleNameRequestHandler extensionsIndicesModuleNameRequestHandler = + private final ExtensionsInitRequestHandler extensionsInitRequestHandler = new ExtensionsInitRequestHandler(this); + private final ExtensionsIndicesModuleRequestHandler extensionsIndicesModuleRequestHandler = new ExtensionsIndicesModuleRequestHandler(); + private final ExtensionsIndicesModuleNameRequestHandler extensionsIndicesModuleNameRequestHandler = new ExtensionsIndicesModuleNameRequestHandler(); - private ExtensionsRestRequestHandler extensionsRestRequestHandler = new ExtensionsRestRequestHandler(extensionRestPathRegistry); + private final ExtensionsRestRequestHandler extensionsRestRequestHandler = new ExtensionsRestRequestHandler(extensionRestPathRegistry); + private final ExtensionActionRequestHandler extensionsActionRequestHandler; /** * Instantiates a new update settings request handler @@ -152,7 +156,10 @@ public class ExtensionsRunner { * @throws IOException if the runner failed to read settings or API. */ protected ExtensionsRunner(Extension extension) throws IOException { + // Link these classes together this.extension = extension; + extension.setExtensionsRunner(this); + // Initialize concrete classes needed by extensions // These must have getters from this class to be accessible via createComponents // If they require later initialization, create a concrete wrapper class and update the internals @@ -175,6 +182,8 @@ protected ExtensionsRunner(Extension extension) throws IOException { this.sdkClient = new SDKClient(extensionSettings); // initialize SDKClusterService. Must happen after extension field assigned this.sdkClusterService = new SDKClusterService(this); + // initialize SDKTransportService. Must happen after extension field assigned + this.sdkTransportService = new SDKTransportService(); // Create Guice modules for injection List modules = new ArrayList<>(); @@ -189,6 +198,7 @@ protected ExtensionsRunner(Extension extension) throws IOException { b.bind(SDKClient.class).toInstance(getSdkClient()); b.bind(SDKClusterService.class).toInstance(getSdkClusterService()); + b.bind(SDKTransportService.class).toInstance(getSdkTransportService()); }); // Bind the return values from create components modules.add(this::injectComponents); @@ -202,6 +212,8 @@ protected ExtensionsRunner(Extension extension) throws IOException { // initialize SDKClient action map initializeSdkClient(); + extensionsActionRequestHandler = new ExtensionActionRequestHandler(getSdkClient()); + if (extension instanceof ActionExtension) { // store REST handlers in the registry for (ExtensionRestHandler extensionRestHandler : ((ActionExtension) extension).getExtensionRestHandlers()) { @@ -391,6 +403,25 @@ public void startTransportService(TransportService transportService) { ((request, channel, task) -> channel.sendResponse(updateSettingsRequestHandler.handleUpdateSettingsRequest(request))) ); + // This handles a remote extension request from OpenSearch or a plugin, sending an ExtensionActionResponse + transportService.registerRequestHandler( + ExtensionsManager.REQUEST_EXTENSION_HANDLE_TRANSPORT_ACTION, + ThreadPool.Names.GENERIC, + false, + false, + ExtensionActionRequest::new, + ((request, channel, task) -> channel.sendResponse(extensionsActionRequestHandler.handleExtensionActionRequest(request))) + ); + + // This handles a remote extension request from another extension, sending a RemoteExtensionActionResponse + transportService.registerRequestHandler( + ExtensionsManager.REQUEST_EXTENSION_HANDLE_REMOTE_TRANSPORT_ACTION, + ThreadPool.Names.GENERIC, + false, + false, + ExtensionActionRequest::new, + ((request, channel, task) -> channel.sendResponse(extensionsActionRequestHandler.handleRemoteExtensionActionRequest(request))) + ); } /** @@ -638,6 +669,10 @@ public TransportService getExtensionTransportService() { return extensionTransportService; } + public SDKTransportService getSdkTransportService() { + return sdkTransportService; + } + /** * Starts an ActionListener. * @@ -660,6 +695,8 @@ public static void run(Extension extension) throws IOException { // initialize the transport service NettyTransport nettyTransport = new NettyTransport(runner); runner.extensionTransportService = nettyTransport.initializeExtensionTransportService(runner.getSettings(), runner.getThreadPool()); + // TODO: merge above line with below line when refactoring out extensionTransportService + runner.getSdkTransportService().setTransportService(runner.extensionTransportService); runner.startActionListener(0); } diff --git a/src/main/java/org/opensearch/sdk/SDKClient.java b/src/main/java/org/opensearch/sdk/SDKClient.java index 5647351a..58a2aaca 100644 --- a/src/main/java/org/opensearch/sdk/SDKClient.java +++ b/src/main/java/org/opensearch/sdk/SDKClient.java @@ -13,6 +13,7 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; +import java.util.stream.Collectors; import com.fasterxml.jackson.annotation.JsonTypeInfo; import com.fasterxml.jackson.databind.DeserializationFeature; @@ -91,6 +92,9 @@ public SDKClient(ExtensionSettings extensionSettings) { // Used by client.execute, populated by initialize method @SuppressWarnings("rawtypes") private Map actions = Collections.emptyMap(); + // Used by remote client execution where we get a string for the class name + @SuppressWarnings("rawtypes") + private Map actionClassToInstanceMap = Collections.emptyMap(); /** * Initialize this client. @@ -100,6 +104,7 @@ public SDKClient(ExtensionSettings extensionSettings) { @SuppressWarnings("rawtypes") public void initialize(Map actions) { this.actions = actions; + this.actionClassToInstanceMap = actions.keySet().stream().collect(Collectors.toMap(a -> a.getClass().getName(), a -> a)); } /** @@ -259,6 +264,17 @@ public void close() throws IOException { doCloseHighLevelClient(); } + /** + * Gets an instance of {@link ActionType} from its corresponding class name, suitable for using as the first parameter in {@link #execute(ActionType, ActionRequest, ActionListener)}. + * + * @param className The class name of the action type + * @return The instance corresponding to the class name + */ + @SuppressWarnings("unchecked") + public ActionType getActionFromClassName(String className) { + return actionClassToInstanceMap.get(className); + } + /** * Executes a generic action, denoted by an {@link ActionType}. * diff --git a/src/main/java/org/opensearch/sdk/SDKTransportService.java b/src/main/java/org/opensearch/sdk/SDKTransportService.java new file mode 100644 index 00000000..20d8377a --- /dev/null +++ b/src/main/java/org/opensearch/sdk/SDKTransportService.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sdk; + +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeoutException; +import java.util.stream.Collectors; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.extensions.ExtensionsManager; +import org.opensearch.extensions.action.RegisterTransportActionsRequest; +import org.opensearch.extensions.action.RemoteExtensionActionResponse; +import org.opensearch.extensions.action.TransportActionRequestFromExtension; +import org.opensearch.sdk.ActionExtension.ActionHandler; +import org.opensearch.sdk.action.RemoteExtensionActionRequest; +import org.opensearch.sdk.action.SDKActionModule; +import org.opensearch.sdk.handlers.AcknowledgedResponseHandler; +import org.opensearch.sdk.handlers.ExtensionActionResponseHandler; +import org.opensearch.transport.TransportService; + +/** + * Wrapper class for {@link TransportService} and associated methods. + * + * TODO: Move all the sendFooRequest() methods here + * TODO: Replace usages of getExtensionTransportService with this class + * https://github.com/opensearch-project/opensearch-sdk-java/issues/585 + */ +public class SDKTransportService { + private final Logger logger = LogManager.getLogger(SDKTransportService.class); + + private TransportService transportService; + private DiscoveryNode opensearchNode; + private String uniqueId; + + /** + * Requests that OpenSearch register the Transport Actions for this extension. + * + * @param actions The map of registered actions from {@link SDKActionModule#getActions()} + */ + public void sendRegisterTransportActionsRequest(Map> actions) { + logger.info("Sending Register Transport Actions request to OpenSearch"); + Set actionNameSet = actions.values() + .stream() + .filter(h -> !h.getAction().name().startsWith("internal")) + .map(h -> h.getAction().getClass().getName()) + .collect(Collectors.toSet()); + AcknowledgedResponseHandler registerTransportActionsResponseHandler = new AcknowledgedResponseHandler(); + try { + transportService.sendRequest( + opensearchNode, + ExtensionsManager.REQUEST_EXTENSION_REGISTER_TRANSPORT_ACTIONS, + new RegisterTransportActionsRequest(uniqueId, actionNameSet), + registerTransportActionsResponseHandler + ); + } catch (Exception e) { + logger.error("Failed to send Register Transport Actions request to OpenSearch", e); + } + } + + /** + * Requests that OpenSearch execute a Transport Actions on another extension. + * + * @param request The request to send + * @return A buffer serializing the response from the remote action if successful, otherwise null + */ + public RemoteExtensionActionResponse sendRemoteExtensionActionRequest(RemoteExtensionActionRequest request) { + logger.info("Sending Remote Extension Action request to OpenSearch for [" + request.getAction() + "]"); + // Combine class name string and request bytes + byte[] requestClassBytes = request.getRequestClass().getBytes(StandardCharsets.UTF_8); + byte[] proxyRequestBytes = ByteBuffer.allocate(requestClassBytes.length + 1 + request.getRequestBytes().length) + .put(requestClassBytes) + .put(RemoteExtensionActionRequest.UNIT_SEPARATOR) + .put(request.getRequestBytes()) + .array(); + ExtensionActionResponseHandler extensionActionResponseHandler = new ExtensionActionResponseHandler(); + try { + transportService.sendRequest( + opensearchNode, + ExtensionsManager.TRANSPORT_ACTION_REQUEST_FROM_EXTENSION, + new TransportActionRequestFromExtension(request.getAction(), proxyRequestBytes, uniqueId), + extensionActionResponseHandler + ); + // Wait on response + extensionActionResponseHandler.awaitResponse(); + } catch (TimeoutException e) { + logger.error("Failed to receive Remote Extension Action response from OpenSearch", e); + } catch (Exception e) { + logger.error("Failed to send Remote Extension Action request to OpenSearch", e); + } + // At this point, response handler has read in the response bytes + return new RemoteExtensionActionResponse( + extensionActionResponseHandler.isSuccess(), + extensionActionResponseHandler.getResponseBytes() + ); + } + + public TransportService getTransportService() { + return transportService; + } + + public DiscoveryNode getOpensearchNode() { + return opensearchNode; + } + + public String getUniqueId() { + return uniqueId; + } + + public void setTransportService(TransportService transportService) { + this.transportService = transportService; + } + + public void setOpensearchNode(DiscoveryNode opensearchNode) { + this.opensearchNode = opensearchNode; + } + + public void setUniqueId(String uniqueId) { + this.uniqueId = uniqueId; + } +} diff --git a/src/main/java/org/opensearch/sdk/action/RemoteExtensionAction.java b/src/main/java/org/opensearch/sdk/action/RemoteExtensionAction.java new file mode 100644 index 00000000..4cb8d00d --- /dev/null +++ b/src/main/java/org/opensearch/sdk/action/RemoteExtensionAction.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sdk.action; + +import org.opensearch.action.ActionType; +import org.opensearch.extensions.action.RemoteExtensionActionResponse; + +/** + * The {@link ActionType} used as they key for the {@link RemoteExtensionTransportAction}. + */ +public class RemoteExtensionAction extends ActionType { + + /** + * The name to look up this action with + */ + public static final String NAME = "internal:remote-extension-action"; + /** + * The singleton instance of this class + */ + public static final RemoteExtensionAction INSTANCE = new RemoteExtensionAction(); + + private RemoteExtensionAction() { + super(NAME, RemoteExtensionActionResponse::new); + } +} diff --git a/src/main/java/org/opensearch/sdk/action/RemoteExtensionActionRequest.java b/src/main/java/org/opensearch/sdk/action/RemoteExtensionActionRequest.java new file mode 100644 index 00000000..b3242b4e --- /dev/null +++ b/src/main/java/org/opensearch/sdk/action/RemoteExtensionActionRequest.java @@ -0,0 +1,137 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sdk.action; + +import java.io.IOException; +import java.util.Objects; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.ActionType; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.extensions.action.ExtensionTransportActionsHandler; + +/** + * A request class to request an action be executed on another extension + */ +public class RemoteExtensionActionRequest extends ActionRequest { + /** + * The Unicode UNIT SEPARATOR used to separate the Request class name and parameter bytes + */ + public static final byte UNIT_SEPARATOR = (byte) '\u001F'; + /** + * action is the TransportAction intended to be invoked which is registered by an extension via {@link ExtensionTransportActionsHandler}. + */ + private final String action; + /** + * requestClass is the ActionRequest class associated with the TransportAction + */ + private final String requestClass; + /** + * requestBytes is the raw bytes being transported between extensions. + *

+ * This array is the serialized bytes used to instantiate the {@link #requestClass} instance using its StreamInput constructor. + */ + private final byte[] requestBytes; + + /** + * RemoteExtensionActionRequest constructor with an ActionType and Request class. Requires a dependency on the remote extension code. + * + * @param instance An instance of {@link ActionType} registered with the remote extension's getActions registry + * @param request A class extending {@link ActionRequest} associated with an action to be executed on another extension. + */ + public RemoteExtensionActionRequest(ActionType instance, ActionRequest request) { + this.action = instance.getClass().getName(); + this.requestClass = request.getClass().getName(); + byte[] bytes = new byte[0]; + try (BytesStreamOutput out = new BytesStreamOutput()) { + request.writeTo(out); + bytes = BytesReference.toBytes(out.bytes()); + } catch (IOException e) { + throw new IllegalStateException("Writing an OutputStream to memory should never result in an IOException."); + } + this.requestBytes = bytes; + } + + /** + * RemoteExtensionActionRequest constructor with class names and request bytes. Does not require a dependency on the remote extension code. + * + * @param action A string representing the fully qualified class name of the remote ActionType instance + * @param requestClass A string representing the fully qualified class name of the remote ActionRequest class + * @param requestBytes Bytes representing the serialized parameters to be used in the ActionRequest class StreamInput constructor + */ + public RemoteExtensionActionRequest(String action, String requestClass, byte[] requestBytes) { + this.action = action; + this.requestClass = requestClass; + this.requestBytes = requestBytes; + } + + /** + * RemoteExtensionActionRequest constructor from {@link StreamInput}. + * + * @param in bytes stream input used to de-serialize the message. + * @throws IOException when message de-serialization fails. + */ + public RemoteExtensionActionRequest(StreamInput in) throws IOException { + super(in); + this.action = in.readString(); + this.requestClass = in.readString(); + this.requestBytes = in.readByteArray(); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(action); + out.writeString(requestClass); + out.writeByteArray(requestBytes); + } + + public String getAction() { + return this.action; + } + + public String getRequestClass() { + return this.requestClass; + } + + public byte[] getRequestBytes() { + return this.requestBytes; + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + + @Override + public String toString() { + return "RemoteExtensionActionRequest{action=" + action + ", requestClass=" + requestClass + ", requestBytes=" + requestBytes + "}"; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) return true; + if (obj == null || getClass() != obj.getClass()) return false; + RemoteExtensionActionRequest that = (RemoteExtensionActionRequest) obj; + return Objects.equals(action, that.action) + && Objects.equals(requestClass, that.requestClass) + && Objects.equals(requestBytes, that.requestBytes); + } + + @Override + public int hashCode() { + return Objects.hash(action, requestClass, requestBytes); + } +} diff --git a/src/main/java/org/opensearch/sdk/action/RemoteExtensionTransportAction.java b/src/main/java/org/opensearch/sdk/action/RemoteExtensionTransportAction.java new file mode 100644 index 00000000..fa9ce8ed --- /dev/null +++ b/src/main/java/org/opensearch/sdk/action/RemoteExtensionTransportAction.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sdk.action; + +import org.opensearch.action.ActionListener; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.TransportAction; +import org.opensearch.extensions.action.RemoteExtensionActionResponse; +import org.opensearch.sdk.SDKTransportService; +import org.opensearch.tasks.Task; +import org.opensearch.tasks.TaskManager; + +import com.google.inject.Inject; + +/** + * Sends a request to OpenSearch for a remote extension to execute an action. + */ +public class RemoteExtensionTransportAction extends TransportAction { + + private SDKTransportService sdkTransportService; + + /** + * Instantiate this action + * + * @param actionName The action name + * @param actionFilters Action filters + * @param taskManager The task manager + * @param sdkTransportService The SDK transport service + */ + @Inject + protected RemoteExtensionTransportAction( + String actionName, + ActionFilters actionFilters, + TaskManager taskManager, + SDKTransportService sdkTransportService + ) { + super(actionName, actionFilters, taskManager); + this.sdkTransportService = sdkTransportService; + } + + @Override + protected void doExecute(Task task, RemoteExtensionActionRequest request, ActionListener listener) { + RemoteExtensionActionResponse response = sdkTransportService.sendRemoteExtensionActionRequest(request); + if (response.getResponseBytes().length > 0) { + listener.onResponse(response); + } else { + listener.onFailure(new RuntimeException("No response received from remote extension.")); + } + } +} diff --git a/src/main/java/org/opensearch/sdk/action/SDKActionModule.java b/src/main/java/org/opensearch/sdk/action/SDKActionModule.java index ef8f5868..4b6cf228 100644 --- a/src/main/java/org/opensearch/sdk/action/SDKActionModule.java +++ b/src/main/java/org/opensearch/sdk/action/SDKActionModule.java @@ -13,19 +13,12 @@ import java.util.Map; import java.util.stream.Collectors; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionType; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.TransportAction; -import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.NamedRegistry; -import org.opensearch.extensions.ExtensionsManager; -import org.opensearch.extensions.RegisterTransportActionsRequest; import org.opensearch.sdk.ActionExtension.ActionHandler; import org.opensearch.sdk.Extension; -import org.opensearch.sdk.handlers.AcknowledgedResponseHandler; -import org.opensearch.transport.TransportService; import com.google.inject.AbstractModule; import com.google.inject.multibindings.MapBinder; @@ -38,15 +31,13 @@ * A module for injecting getActions classes into Guice. */ public class SDKActionModule extends AbstractModule { - private final Logger logger = LogManager.getLogger(SDKActionModule.class); - private final Map> actions; private final ActionFilters actionFilters; /** * Instantiate this module * - * @param extension An instance of {@link ActionExtension}. + * @param extension The extension */ public SDKActionModule(Extension extension) { this.actions = setupActions(extension); @@ -62,24 +53,33 @@ public ActionFilters getActionFilters() { } private static Map> setupActions(Extension extension) { - if (extension instanceof ActionExtension) { - // Subclass NamedRegistry for easy registration - class ActionRegistry extends NamedRegistry> { - ActionRegistry() { - super("action"); - } - - public void register(ActionHandler handler) { - register(handler.getAction().name(), handler); - } + /** + * Subclass of NamedRegistry permitting easier action registration + */ + class ActionRegistry extends NamedRegistry> { + ActionRegistry() { + super("action"); } - ActionRegistry actions = new ActionRegistry(); - // Register getActions in it - ((ActionExtension) extension).getActions().stream().forEach(actions::register); - return unmodifiableMap(actions.getRegistry()); + /** + * Register an action handler pairing an ActionType and TransportAction + * + * @param handler The ActionHandler to register + */ + public void register(ActionHandler handler) { + register(handler.getAction().name(), handler); + } } - return Collections.emptyMap(); + ActionRegistry actions = new ActionRegistry(); + + // Register SDK actions + actions.register(new ActionHandler<>(RemoteExtensionAction.INSTANCE, RemoteExtensionTransportAction.class)); + + // Register actions from getActions extension point + if (extension instanceof ActionExtension) { + ((ActionExtension) extension).getActions().stream().forEach(actions::register); + } + return unmodifiableMap(actions.getRegistry()); } private static ActionFilters setupActionFilters(Extension extension) { @@ -108,26 +108,4 @@ protected void configure() { transportActionsBinder.addBinding(action.getAction()).to(action.getTransportAction()).asEagerSingleton(); } } - - /** - * Requests that OpenSearch register the Transport Actions for this extension. - * - * @param transportService The TransportService defining the connection to OpenSearch. - * @param opensearchNode The OpenSearch node where transport actions being registered. - * @param uniqueId The identity used to - */ - public void sendRegisterTransportActionsRequest(TransportService transportService, DiscoveryNode opensearchNode, String uniqueId) { - logger.info("Sending Register Transport Actions request to OpenSearch"); - AcknowledgedResponseHandler registerTransportActionsResponseHandler = new AcknowledgedResponseHandler(); - try { - transportService.sendRequest( - opensearchNode, - ExtensionsManager.REQUEST_EXTENSION_REGISTER_TRANSPORT_ACTIONS, - new RegisterTransportActionsRequest(uniqueId, getActions().keySet()), - registerTransportActionsResponseHandler - ); - } catch (Exception e) { - logger.info("Failed to send Register Transport Actions request to OpenSearch", e); - } - } } diff --git a/src/main/java/org/opensearch/sdk/handlers/ExtensionActionRequestHandler.java b/src/main/java/org/opensearch/sdk/handlers/ExtensionActionRequestHandler.java new file mode 100644 index 00000000..7511cf9e --- /dev/null +++ b/src/main/java/org/opensearch/sdk/handlers/ExtensionActionRequestHandler.java @@ -0,0 +1,144 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sdk.handlers; + +import java.io.IOException; +import java.lang.reflect.Constructor; +import java.nio.charset.StandardCharsets; +import java.util.Arrays; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.action.ActionListener; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.ActionType; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.extensions.ExtensionsManager; +import org.opensearch.extensions.action.ExtensionActionRequest; +import org.opensearch.extensions.action.ExtensionActionResponse; +import org.opensearch.extensions.action.RemoteExtensionActionResponse; +import org.opensearch.sdk.SDKClient; +import org.opensearch.sdk.SDKTransportService; +import org.opensearch.sdk.action.RemoteExtensionActionRequest; + +/** + * This class handles a request from OpenSearch from another extension's {@link SDKTransportService#sendProxyActionRequest()} call. + */ +public class ExtensionActionRequestHandler { + private static final Logger logger = LogManager.getLogger(ExtensionActionRequestHandler.class); + + private final SDKClient sdkClient; + + /** + * Instantiate this handler + * + * @param sdkClient An initialized SDKClient with the registered actions + */ + public ExtensionActionRequestHandler(SDKClient sdkClient) { + this.sdkClient = sdkClient; + } + + /** + * Handles a request from OpenSearch to execute a TransportAction on the extension. These requests originated from OpenSearch or a plugin. + * + * @param request The request to execute + * @return The response from the TransportAction + */ + public ExtensionActionResponse handleExtensionActionRequest(ExtensionActionRequest request) { + // For now we just delegate to the remote actions. + // There is potential in the future for handling these requests differently + RemoteExtensionActionResponse response = handleRemoteExtensionActionRequest(request); + // Discard the success bit and just return the bytes + return new ExtensionActionResponse(response.getResponseBytes()); + } + + /** + * Handles a request from OpenSearch to execute a TransportAction on the extension. These requests originated from another extension. + * + * @param request The request to execute + * @return The response from the TransportAction + */ + public RemoteExtensionActionResponse handleRemoteExtensionActionRequest(ExtensionActionRequest request) { + logger.debug("Received request to execute action [" + request.getAction() + "]"); + final RemoteExtensionActionResponse response = new RemoteExtensionActionResponse(false, new byte[0]); + + // Find matching ActionType instance + ActionType action = sdkClient.getActionFromClassName(request.getAction()); + if (action == null) { + response.setResponseBytesAsString("No action [" + request.getAction() + "] is registered."); + return response; + } + logger.debug("Found matching action [" + action.name() + "], an instance of [" + action.getClass().getName() + "]"); + + // Extract request class name from bytes and instantiate request + int nullPos = indexOf(request.getRequestBytes(), RemoteExtensionActionRequest.UNIT_SEPARATOR); + String requestClassName = new String(Arrays.copyOfRange(request.getRequestBytes(), 0, nullPos + 1), StandardCharsets.UTF_8) + .stripTrailing(); + ActionRequest actionRequest = null; + try { + Class clazz = Class.forName(requestClassName); + Constructor constructor = clazz.getConstructor(StreamInput.class); + StreamInput requestByteStream = StreamInput.wrap( + Arrays.copyOfRange(request.getRequestBytes(), nullPos + 1, request.getRequestBytes().length) + ); + actionRequest = (ActionRequest) constructor.newInstance(requestByteStream); + } catch (Exception e) { + response.setResponseBytesAsString("No request class [" + requestClassName + "] is available: " + e.getMessage()); + return response; + } + + // Execute the action + // TODO: We need async client.execute to hide these action listener details and return the future directly + // https://github.com/opensearch-project/opensearch-sdk-java/issues/584 + CompletableFuture futureResponse = new CompletableFuture<>(); + sdkClient.execute(action, actionRequest, ActionListener.wrap(r -> { + byte[] bytes = new byte[0]; + try (BytesStreamOutput out = new BytesStreamOutput()) { + ((ActionResponse) r).writeTo(out); + bytes = BytesReference.toBytes(out.bytes()); + } catch (IOException e) { + throw new IllegalStateException("Writing an OutputStream to memory should never result in an IOException."); + } + response.setSuccess(true); + response.setResponseBytes(bytes); + futureResponse.complete(response); + }, e -> futureResponse.completeExceptionally(e))); + + logger.debug("Waiting for response to action [" + request.getAction() + "]"); + try { + RemoteExtensionActionResponse actionResponse = futureResponse.orTimeout( + ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, + TimeUnit.SECONDS + ).get(); + response.setSuccess(true); + response.setResponseBytes(actionResponse.getResponseBytes()); + logger.debug("Response successful to [" + request.getAction() + "]"); + } catch (Exception e) { + response.setResponseBytesAsString("Action failed: " + e.getMessage()); + logger.debug("Response failed to [" + request.getAction() + "]"); + } + logger.debug("Sending action response to OpenSearch: " + response.getResponseBytes().length + " bytes"); + return response; + } + + private static int indexOf(byte[] bytes, byte value) { + for (int offset = 0; offset < bytes.length; ++offset) { + if (bytes[offset] == value) { + return offset; + } + } + return -1; + } +} diff --git a/src/main/java/org/opensearch/sdk/handlers/ExtensionActionResponseHandler.java b/src/main/java/org/opensearch/sdk/handlers/ExtensionActionResponseHandler.java new file mode 100644 index 00000000..caa82a50 --- /dev/null +++ b/src/main/java/org/opensearch/sdk/handlers/ExtensionActionResponseHandler.java @@ -0,0 +1,84 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sdk.handlers; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.extensions.ExtensionsManager; +import org.opensearch.extensions.action.RemoteExtensionActionResponse; +import org.opensearch.sdk.SDKTransportService; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportException; +import org.opensearch.transport.TransportResponseHandler; + +import java.io.IOException; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; + +/** + * This class handles the response from OpenSearch to a {@link SDKTransportService#sendProxyActionRequest()} call. + */ +public class ExtensionActionResponseHandler implements TransportResponseHandler { + + private static final Logger logger = LogManager.getLogger(ExtensionActionResponseHandler.class); + private final CompletableFuture inProgressFuture; + private boolean success = false; + private byte[] responseBytes = new byte[0]; + + /** + * Instantiates a new ExtensionActionResponseHandler + */ + public ExtensionActionResponseHandler() { + this.inProgressFuture = new CompletableFuture<>(); + } + + @Override + public void handleResponse(RemoteExtensionActionResponse response) { + logger.info("Received response bytes: " + response.getResponseBytes().length + " bytes"); + // Set ExtensionActionResponse from response + this.success = response.isSuccess(); + this.responseBytes = response.getResponseBytes(); + inProgressFuture.complete(response); + } + + @Override + public void handleException(TransportException exp) { + logger.error("ExtensionActionResponseRequest failed", exp); + inProgressFuture.completeExceptionally(exp); + } + + @Override + public String executor() { + return ThreadPool.Names.GENERIC; + } + + @Override + public RemoteExtensionActionResponse read(StreamInput in) throws IOException { + return new RemoteExtensionActionResponse(in); + } + + /** + * Waits for the ExtensionActionResponseHandler future to complete + * @throws Exception + * if the response times out + */ + public void awaitResponse() throws Exception { + inProgressFuture.orTimeout(ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, TimeUnit.SECONDS).get(); + } + + public boolean isSuccess() { + return success; + } + + public byte[] getResponseBytes() { + return this.responseBytes; + } +} diff --git a/src/main/java/org/opensearch/sdk/handlers/ExtensionsInitRequestHandler.java b/src/main/java/org/opensearch/sdk/handlers/ExtensionsInitRequestHandler.java index e995e63e..90444878 100644 --- a/src/main/java/org/opensearch/sdk/handlers/ExtensionsInitRequestHandler.java +++ b/src/main/java/org/opensearch/sdk/handlers/ExtensionsInitRequestHandler.java @@ -16,6 +16,7 @@ import org.opensearch.discovery.InitializeExtensionRequest; import org.opensearch.discovery.InitializeExtensionResponse; import org.opensearch.sdk.ExtensionsRunner; +import org.opensearch.sdk.SDKTransportService; import org.opensearch.transport.TransportService; import static org.opensearch.sdk.ExtensionsRunner.NODE_NAME_SETTING; @@ -48,6 +49,10 @@ public InitializeExtensionResponse handleExtensionInitRequest(InitializeExtensio logger.info("Registering Extension Request received from OpenSearch"); extensionsRunner.opensearchNode = extensionInitRequest.getSourceNode(); extensionsRunner.setUniqueId(extensionInitRequest.getExtension().getId()); + // TODO: Remove above two lines in favor of the below when refactoring + SDKTransportService sdkTransportService = extensionsRunner.getSdkTransportService(); + sdkTransportService.setOpensearchNode(extensionInitRequest.getSourceNode()); + sdkTransportService.setUniqueId(extensionInitRequest.getExtension().getId()); // Successfully initialized. Send the response. try { return new InitializeExtensionResponse( @@ -58,16 +63,12 @@ public InitializeExtensionResponse handleExtensionInitRequest(InitializeExtensio // After sending successful response to initialization, send the REST API and Settings extensionsRunner.setOpensearchNode(extensionsRunner.opensearchNode); extensionsRunner.setExtensionNode(extensionInitRequest.getExtension()); + // TODO: replace with sdkTransportService.getTransportService() TransportService extensionTransportService = extensionsRunner.getExtensionTransportService(); extensionTransportService.connectToNode(extensionsRunner.opensearchNode); extensionsRunner.sendRegisterRestActionsRequest(extensionTransportService); extensionsRunner.sendRegisterCustomSettingsRequest(extensionTransportService); - extensionsRunner.getSdkActionModule() - .sendRegisterTransportActionsRequest( - extensionTransportService, - extensionsRunner.opensearchNode, - extensionsRunner.getUniqueId() - ); + sdkTransportService.sendRegisterTransportActionsRequest(extensionsRunner.getSdkActionModule().getActions()); // Get OpenSearch Settings and set values on ExtensionsRunner Settings settings = extensionsRunner.sendEnvironmentSettingsRequest(extensionTransportService); extensionsRunner.setEnvironmentSettings(settings); diff --git a/src/main/java/org/opensearch/sdk/sample/helloworld/HelloWorldExtension.java b/src/main/java/org/opensearch/sdk/sample/helloworld/HelloWorldExtension.java index faf48e03..edbcc7e2 100644 --- a/src/main/java/org/opensearch/sdk/sample/helloworld/HelloWorldExtension.java +++ b/src/main/java/org/opensearch/sdk/sample/helloworld/HelloWorldExtension.java @@ -22,8 +22,8 @@ import org.opensearch.sdk.ExtensionSettings; import org.opensearch.sdk.ExtensionsRunner; import org.opensearch.sdk.ActionExtension; -import org.opensearch.sdk.ActionExtension.ActionHandler; import org.opensearch.sdk.sample.helloworld.rest.RestHelloAction; +import org.opensearch.sdk.sample.helloworld.rest.RestRemoteHelloAction; import org.opensearch.sdk.sample.helloworld.transport.SampleAction; import org.opensearch.sdk.sample.helloworld.transport.SampleTransportAction; @@ -56,7 +56,7 @@ public HelloWorldExtension() { @Override public List getExtensionRestHandlers() { - return List.of(new RestHelloAction()); + return List.of(new RestHelloAction(), new RestRemoteHelloAction(extensionsRunner())); } @Override diff --git a/src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestRemoteHelloAction.java b/src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestRemoteHelloAction.java new file mode 100644 index 00000000..9ce93adf --- /dev/null +++ b/src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestRemoteHelloAction.java @@ -0,0 +1,94 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sdk.sample.helloworld.rest; + +import org.opensearch.action.ActionListener; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.extensions.ExtensionsManager; +import org.opensearch.extensions.action.RemoteExtensionActionResponse; +import org.opensearch.extensions.rest.ExtensionRestRequest; +import org.opensearch.extensions.rest.ExtensionRestResponse; +import org.opensearch.sdk.BaseExtensionRestHandler; +import org.opensearch.sdk.ExtensionsRunner; +import org.opensearch.sdk.RouteHandler; +import org.opensearch.sdk.SDKClient; +import org.opensearch.sdk.action.RemoteExtensionAction; +import org.opensearch.sdk.action.RemoteExtensionActionRequest; +import org.opensearch.sdk.sample.helloworld.transport.SampleAction; +import org.opensearch.sdk.sample.helloworld.transport.SampleRequest; +import org.opensearch.sdk.sample.helloworld.transport.SampleResponse; + +import java.util.List; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; +import java.util.function.Function; + +import static org.opensearch.rest.RestRequest.Method.GET; +import static org.opensearch.rest.RestStatus.OK; + +/** + * Sample REST Handler demostrating proxy actions to another extension + */ +public class RestRemoteHelloAction extends BaseExtensionRestHandler { + + private ExtensionsRunner extensionsRunner; + + /** + * Instantiate this action + * + * @param runner The ExtensionsRunner instance + */ + public RestRemoteHelloAction(ExtensionsRunner runner) { + this.extensionsRunner = runner; + } + + @Override + public List routeHandlers() { + return List.of(new RouteHandler(GET, "/hello/{name}", handleRemoteGetRequest)); + } + + private Function handleRemoteGetRequest = (request) -> { + SDKClient client = extensionsRunner.getSdkClient(); + + String name = request.param("name"); + // Create a request using class on remote + // This class happens to be local for simplicity but is a class on the remote extension + SampleRequest sampleRequest = new SampleRequest(name); + + // Serialize this request in a proxy action request + // This requires that the remote extension has a corresponding transport action registered + // This Action class happens to be local for simplicity but is a class on the remote extension + RemoteExtensionActionRequest proxyActionRequest = new RemoteExtensionActionRequest(SampleAction.INSTANCE, sampleRequest); + + // TODO: We need async client.execute to hide these action listener details and return the future directly + // https://github.com/opensearch-project/opensearch-sdk-java/issues/584 + CompletableFuture futureResponse = new CompletableFuture<>(); + client.execute( + RemoteExtensionAction.INSTANCE, + proxyActionRequest, + ActionListener.wrap(r -> futureResponse.complete(r), e -> futureResponse.completeExceptionally(e)) + ); + try { + RemoteExtensionActionResponse response = futureResponse.orTimeout( + ExtensionsManager.EXTENSION_REQUEST_WAIT_TIMEOUT, + TimeUnit.SECONDS + ).get(); + if (!response.isSuccess()) { + return new ExtensionRestResponse(request, OK, "Remote extension reponse failed: " + response.getResponseBytesAsString()); + } + // Parse out the expected response class from the bytes + SampleResponse sampleResponse = new SampleResponse(StreamInput.wrap(response.getResponseBytes())); + return new ExtensionRestResponse(request, OK, "Received greeting from remote extension: " + sampleResponse.getGreeting()); + } catch (Exception e) { + return exceptionalRequest(request, e); + } + }; + +} diff --git a/src/test/java/org/opensearch/sdk/TestExtensionInterfaces.java b/src/test/java/org/opensearch/sdk/TestExtensionInterfaces.java index 6e2a281f..debef1b4 100644 --- a/src/test/java/org/opensearch/sdk/TestExtensionInterfaces.java +++ b/src/test/java/org/opensearch/sdk/TestExtensionInterfaces.java @@ -24,7 +24,6 @@ import org.opensearch.ingest.Processor; import org.opensearch.test.OpenSearchTestCase; -import java.util.Map; import java.util.function.Predicate; public class TestExtensionInterfaces extends OpenSearchTestCase { @@ -36,6 +35,9 @@ void testExtension() { public ExtensionSettings getExtensionSettings() { return null; } + + @Override + public void setExtensionsRunner(ExtensionsRunner runner) {} }; assertTrue(extension.getSettings().isEmpty()); diff --git a/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java b/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java index 40447330..93bf1b61 100644 --- a/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java +++ b/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java @@ -100,7 +100,7 @@ public void testStartTransportService() { verify(transportService, times(1)).start(); // cannot verify acceptIncomingRequests as it is a final method // test registerRequestHandlers - verify(transportService, times(5)).registerRequestHandler(anyString(), anyString(), anyBoolean(), anyBoolean(), any(), any()); + verify(transportService, times(7)).registerRequestHandler(anyString(), anyString(), anyBoolean(), anyBoolean(), any(), any()); } @Test diff --git a/src/test/java/org/opensearch/sdk/TestSDKTransportService.java b/src/test/java/org/opensearch/sdk/TestSDKTransportService.java new file mode 100644 index 00000000..60a1bc2a --- /dev/null +++ b/src/test/java/org/opensearch/sdk/TestSDKTransportService.java @@ -0,0 +1,104 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sdk; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.settings.Settings; +import org.opensearch.common.transport.TransportAddress; +import org.opensearch.extensions.ExtensionsManager; +import org.opensearch.extensions.action.RegisterTransportActionsRequest; +import org.opensearch.sdk.action.RemoteExtensionAction; +import org.opensearch.sdk.action.SDKActionModule; +import org.opensearch.sdk.action.TestSDKActionModule; +import org.opensearch.sdk.handlers.AcknowledgedResponseHandler; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.Transport; +import org.opensearch.transport.TransportService; + +import java.net.InetAddress; +import java.util.Collections; + +import static java.util.Collections.emptyMap; +import static java.util.Collections.emptySet; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class TestSDKTransportService extends OpenSearchTestCase { + + private static final String TEST_UNIQUE_ID = "test-extension"; + + private TransportService transportService; + private DiscoveryNode opensearchNode; + private SDKActionModule sdkActionModule; + private SDKTransportService sdkTransportService; + + @Override + @BeforeEach + public void setUp() throws Exception { + super.setUp(); + this.transportService = spy( + new TransportService( + Settings.EMPTY, + mock(Transport.class), + null, + TransportService.NOOP_TRANSPORT_INTERCEPTOR, + x -> null, + null, + Collections.emptySet() + ) + ); + this.opensearchNode = new DiscoveryNode( + "test_node", + new TransportAddress(InetAddress.getByName("localhost"), 9876), + emptyMap(), + emptySet(), + Version.CURRENT + ); + sdkActionModule = new SDKActionModule(new TestSDKActionModule.TestActionExtension()); + + sdkTransportService = new SDKTransportService(); + sdkTransportService.setTransportService(transportService); + sdkTransportService.setOpensearchNode(opensearchNode); + sdkTransportService.setUniqueId(TEST_UNIQUE_ID); + } + + @Test + public void testRegisterTransportAction() { + ArgumentCaptor registerTransportActionsRequestCaptor = ArgumentCaptor.forClass( + RegisterTransportActionsRequest.class + ); + + sdkTransportService.sendRegisterTransportActionsRequest(sdkActionModule.getActions()); + verify(transportService, times(1)).sendRequest( + any(), + eq(ExtensionsManager.REQUEST_EXTENSION_REGISTER_TRANSPORT_ACTIONS), + registerTransportActionsRequestCaptor.capture(), + any(AcknowledgedResponseHandler.class) + ); + assertEquals(TEST_UNIQUE_ID, registerTransportActionsRequestCaptor.getValue().getUniqueId()); + // Should contain the TestAction, but since it's mocked the name may change + assertTrue( + registerTransportActionsRequestCaptor.getValue() + .getTransportActions() + .stream() + .anyMatch(s -> s.startsWith("org.opensearch.action.ActionType$MockitoMock$")) + ); + // Internal action should be filtered out + assertFalse(registerTransportActionsRequestCaptor.getValue().getTransportActions().contains(RemoteExtensionAction.class.getName())); + } +} diff --git a/src/test/java/org/opensearch/sdk/action/TestProxyActionRequest.java b/src/test/java/org/opensearch/sdk/action/TestProxyActionRequest.java new file mode 100644 index 00000000..417fbdac --- /dev/null +++ b/src/test/java/org/opensearch/sdk/action/TestProxyActionRequest.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.sdk.action; + +import java.io.IOException; + +import org.junit.jupiter.api.Test; +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.ActionResponse; +import org.opensearch.action.ActionType; +import org.opensearch.common.bytes.BytesReference; +import org.opensearch.common.io.stream.BytesStreamInput; +import org.opensearch.common.io.stream.BytesStreamOutput; +import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.common.io.stream.StreamOutput; +import org.opensearch.test.OpenSearchTestCase; + +public class TestProxyActionRequest extends OpenSearchTestCase { + + @Test + public void testProxyActionRequest() throws Exception { + TestRequest testRequest = new TestRequest("test-action"); + + String expectedAction = TestAction.class.getName(); + String expectedRequestClass = testRequest.getClass().getName(); + byte[] expectedRequestBytes; + try (BytesStreamOutput out = new BytesStreamOutput()) { + testRequest.writeTo(out); + expectedRequestBytes = BytesReference.toBytes(out.bytes()); + } + + RemoteExtensionActionRequest request = new RemoteExtensionActionRequest(TestAction.INSTANCE, testRequest); + assertEquals(expectedAction, request.getAction()); + assertEquals(expectedRequestClass, request.getRequestClass()); + assertArrayEquals(expectedRequestBytes, request.getRequestBytes()); + + request = new RemoteExtensionActionRequest(expectedAction, expectedRequestClass, expectedRequestBytes); + + try (BytesStreamOutput out = new BytesStreamOutput()) { + request.writeTo(out); + out.flush(); + try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { + request = new RemoteExtensionActionRequest(in); + + assertEquals(expectedAction, request.getAction()); + assertEquals(expectedRequestClass, request.getRequestClass()); + assertArrayEquals(expectedRequestBytes, request.getRequestBytes()); + } + } + } + + static class TestRequest extends ActionRequest { + + private String data; + + public TestRequest(String data) { + this.data = data; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(data); + } + + @Override + public ActionRequestValidationException validate() { + return null; + } + } + + static class TestResponse extends ActionResponse { + public TestResponse(StreamInput in) {} + + @Override + public void writeTo(StreamOutput out) throws IOException {} + } + + static class TestAction extends ActionType { + + public static final String NAME = "test"; + public static final TestAction INSTANCE = new TestAction(); + + private TestAction() { + super(NAME, TestResponse::new); + } + + } + +} diff --git a/src/test/java/org/opensearch/sdk/action/TestSDKActionModule.java b/src/test/java/org/opensearch/sdk/action/TestSDKActionModule.java index 1c37989a..83e34fdf 100644 --- a/src/test/java/org/opensearch/sdk/action/TestSDKActionModule.java +++ b/src/test/java/org/opensearch/sdk/action/TestSDKActionModule.java @@ -11,52 +11,29 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; -import org.mockito.ArgumentCaptor; -import org.opensearch.Version; import org.opensearch.action.ActionRequest; import org.opensearch.action.ActionResponse; import org.opensearch.action.ActionType; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.settings.Settings; -import org.opensearch.common.transport.TransportAddress; -import org.opensearch.extensions.ExtensionsManager; -import org.opensearch.extensions.RegisterTransportActionsRequest; import org.opensearch.sdk.ActionExtension; -import org.opensearch.sdk.Extension; +import org.opensearch.sdk.BaseExtension; import org.opensearch.sdk.ExtensionSettings; -import org.opensearch.sdk.handlers.AcknowledgedResponseHandler; import org.opensearch.test.OpenSearchTestCase; -import org.opensearch.transport.Transport; -import org.opensearch.transport.TransportService; -import java.net.InetAddress; import java.util.Arrays; -import java.util.Collections; import java.util.List; -import java.util.Set; -import static java.util.Collections.emptyMap; -import static java.util.Collections.emptySet; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; public class TestSDKActionModule extends OpenSearchTestCase { - private static final String TEST_UNIQUE_ID = "test-extension"; - private static final String TEST_ACTION_NAME = "testAction"; + public static final String TEST_ACTION_NAME = "testAction"; - private TransportService transportService; - private DiscoveryNode opensearchNode; + private SDKActionModule sdkActionModule; - private static class TestActionExtension implements Extension, ActionExtension { - @Override - public ExtensionSettings getExtensionSettings() { - return null; + public static class TestActionExtension extends BaseExtension implements ActionExtension { + public TestActionExtension() { + super(mock(ExtensionSettings.class)); } @Override @@ -69,45 +46,17 @@ public ExtensionSettings getExtensionSettings() { } } - private SDKActionModule sdkActionModule = new SDKActionModule(new TestActionExtension()); - @Override @BeforeEach public void setUp() throws Exception { super.setUp(); - this.transportService = spy( - new TransportService( - Settings.EMPTY, - mock(Transport.class), - null, - TransportService.NOOP_TRANSPORT_INTERCEPTOR, - x -> null, - null, - Collections.emptySet() - ) - ); - this.opensearchNode = new DiscoveryNode( - "test_node", - new TransportAddress(InetAddress.getByName("localhost"), 9876), - emptyMap(), - emptySet(), - Version.CURRENT - ); + sdkActionModule = new SDKActionModule(new TestActionExtension()); } @Test - public void testRegisterTransportAction() { - ArgumentCaptor registerTransportActionsRequestCaptor = ArgumentCaptor.forClass( - RegisterTransportActionsRequest.class - ); - sdkActionModule.sendRegisterTransportActionsRequest(transportService, opensearchNode, TEST_UNIQUE_ID); - verify(transportService, times(1)).sendRequest( - any(), - eq(ExtensionsManager.REQUEST_EXTENSION_REGISTER_TRANSPORT_ACTIONS), - registerTransportActionsRequestCaptor.capture(), - any(AcknowledgedResponseHandler.class) - ); - assertEquals(TEST_UNIQUE_ID, registerTransportActionsRequestCaptor.getValue().getUniqueId()); - assertEquals(Set.of(TEST_ACTION_NAME), registerTransportActionsRequestCaptor.getValue().getTransportActions()); + public void testGetActions() { + assertEquals(2, sdkActionModule.getActions().size()); + assertTrue(sdkActionModule.getActions().containsKey(RemoteExtensionAction.NAME)); + assertTrue(sdkActionModule.getActions().containsKey(TEST_ACTION_NAME)); } } diff --git a/src/test/java/org/opensearch/sdk/sample/helloworld/TestHelloWorldExtension.java b/src/test/java/org/opensearch/sdk/sample/helloworld/TestHelloWorldExtension.java index 27d8ce70..564ac50e 100644 --- a/src/test/java/org/opensearch/sdk/sample/helloworld/TestHelloWorldExtension.java +++ b/src/test/java/org/opensearch/sdk/sample/helloworld/TestHelloWorldExtension.java @@ -25,7 +25,6 @@ import org.opensearch.action.ActionType; import org.opensearch.action.support.TransportAction; import org.opensearch.common.settings.Settings; -import org.opensearch.rest.RestHandler.Route; import org.opensearch.sdk.ActionExtension.ActionHandler; import org.opensearch.sdk.sample.helloworld.transport.SampleAction; import org.opensearch.sdk.sample.helloworld.transport.SampleRequest; @@ -46,6 +45,7 @@ import static org.opensearch.sdk.sample.helloworld.ExampleCustomSettingConfig.VALIDATED_SETTING; +@SuppressWarnings("deprecation") public class TestHelloWorldExtension extends OpenSearchTestCase { private HelloWorldExtension extension; @@ -110,9 +110,9 @@ public void testExtensionSettings() { @Test public void testExtensionRestHandlers() { List extensionRestHandlers = extension.getExtensionRestHandlers(); - assertEquals(1, extensionRestHandlers.size()); - List routes = extensionRestHandlers.get(0).routes(); - assertEquals(4, routes.size()); + assertEquals(2, extensionRestHandlers.size()); + assertEquals(4, extensionRestHandlers.get(0).routes().size()); + assertEquals(1, extensionRestHandlers.get(1).routes().size()); } @Test @@ -121,6 +121,12 @@ public void testGetActions() { assertEquals(1, actions.size()); } + @Test + public void testClientGetActionFromClassName() { + ActionType action = SampleAction.INSTANCE; + assertEquals(action, sdkClient.getActionFromClassName(action.getClass().getName())); + } + @Test public void testClientExecuteSampleActions() throws Exception { String expectedName = "world"; @@ -144,7 +150,6 @@ public void onFailure(Exception e) { assertEquals(expectedGreeting, response.getGreeting()); } - @SuppressWarnings("deprecation") @Test public void testRestClientExecuteSampleActions() throws Exception { String expectedName = "world"; @@ -192,7 +197,6 @@ public void onFailure(Exception e) { assertEquals("The request name is blank.", cause.getMessage()); } - @SuppressWarnings("deprecation") @Test public void testExceptionalRestClientExecuteSampleActions() throws Exception { String expectedName = "";