diff --git a/src/main/java/org/opensearch/sdk/ExtensionRestHandler.java b/src/main/java/org/opensearch/sdk/ExtensionRestHandler.java index c62095e8..4f96b980 100644 --- a/src/main/java/org/opensearch/sdk/ExtensionRestHandler.java +++ b/src/main/java/org/opensearch/sdk/ExtensionRestHandler.java @@ -9,6 +9,7 @@ import java.util.List; +import org.opensearch.extensions.rest.ExtensionRestRequest; import org.opensearch.rest.BaseRestHandler; import org.opensearch.rest.RestHandler; import org.opensearch.rest.RestHandler.Route; @@ -29,12 +30,11 @@ public interface ExtensionRestHandler { /** * Handles REST Requests forwarded from OpenSearch for a configured route on an extension. - * Parameters are components of the {@link RestRequest} received from a user. + * Parameter contains components of the {@link RestRequest} received from a user. * This method corresponds to the {@link BaseRestHandler#prepareRequest} method. - * As in that method, consumed parameters must be tracked and returned in the response. * - * @param restRequest a REST request object for a request to be forwarded to extensions + * @param request a REST request object for a request to be forwarded to extensions * @return An {@link ExtensionRestResponse} to the request. */ - ExtensionRestResponse handleRequest(ExtensionRestRequest restRequest); + ExtensionRestResponse handleRequest(ExtensionRestRequest request); } diff --git a/src/main/java/org/opensearch/sdk/ExtensionRestPathRegistry.java b/src/main/java/org/opensearch/sdk/ExtensionRestPathRegistry.java index 26d0c556..56bd850d 100644 --- a/src/main/java/org/opensearch/sdk/ExtensionRestPathRegistry.java +++ b/src/main/java/org/opensearch/sdk/ExtensionRestPathRegistry.java @@ -28,11 +28,11 @@ public class ExtensionRestPathRegistry { * Register a REST handler to handle a method and route in this extension's path registry. * * @param method The method to register. - * @param uri The URI to register. May include named wildcards. + * @param path The path to register. May include named wildcards. * @param extensionRestHandler The RestHandler to handle this route */ - public void registerHandler(Method method, String uri, ExtensionRestHandler extensionRestHandler) { - String restPath = restPathToString(method, uri); + public void registerHandler(Method method, String path, ExtensionRestHandler extensionRestHandler) { + String restPath = restPathToString(method, path); pathTrie.insert(restPath, extensionRestHandler); registeredPaths.add(restPath); } @@ -41,11 +41,11 @@ public void registerHandler(Method method, String uri, ExtensionRestHandler exte * Get the registered REST handler for the specified method and URI. * * @param method the registered method. - * @param uri the registered URI. + * @param path the registered path. * @return The REST handler registered to handle this method and URI combination if found, null otherwise. */ - public ExtensionRestHandler getHandler(Method method, String uri) { - return pathTrie.retrieve(restPathToString(method, uri)); + public ExtensionRestHandler getHandler(Method method, String path) { + return pathTrie.retrieve(restPathToString(method, path)); } /** @@ -58,13 +58,13 @@ public List getRegisteredPaths() { } /** - * Converts a REST method and URI to a string. + * Converts a REST method and path to a space delimited string to be used as a map lookup key. * * @param method the method. - * @param uri the URI. - * @return A string appending the method and URI. + * @param path the path. + * @return A string appending the method and path. */ - public static String restPathToString(Method method, String uri) { - return method.name() + " " + uri; + public static String restPathToString(Method method, String path) { + return method.name() + " " + path; } } diff --git a/src/main/java/org/opensearch/sdk/ExtensionRestRequest.java b/src/main/java/org/opensearch/sdk/ExtensionRestRequest.java deleted file mode 100644 index 5f170a59..00000000 --- a/src/main/java/org/opensearch/sdk/ExtensionRestRequest.java +++ /dev/null @@ -1,121 +0,0 @@ -/* - * SPDX-License-Identifier: Apache-2.0 - * - * The OpenSearch Contributors require contributions made to - * this file be licensed under the Apache-2.0 license or a - * compatible open source license. - */ -package org.opensearch.sdk; - -import org.opensearch.common.io.stream.StreamInput; -import org.opensearch.common.io.stream.StreamOutput; -import org.opensearch.extensions.rest.RestExecuteOnExtensionRequest; -import org.opensearch.identity.PrincipalIdentifierToken; -import org.opensearch.rest.RestRequest.Method; -import org.opensearch.transport.TransportRequest; - -import java.io.IOException; -import java.util.Objects; - -/** - * A subclass of {@link TransportRequest} which contains request relevant information - * to be utilised in ExtensionRestHandler implementation - */ -public class ExtensionRestRequest extends TransportRequest { - private Method method; - private String uri; - /** - * The owner of this request object - */ - private PrincipalIdentifierToken principalIdentifierToken; - - /** - * This object can be instantiated given method, uri and identifier - * @param method of type {@link Method} - * @param uri url string - * @param principalIdentifier the owner of this request - */ - public ExtensionRestRequest(Method method, String uri, PrincipalIdentifierToken principalIdentifier) { - this.method = method; - this.uri = uri; - this.principalIdentifierToken = principalIdentifier; - } - - /** - * The object to be created from rest request object incoming from OpenSearch - * @param request incoming object from OpenSearch - * @throws IllegalArgumentException when request is null - */ - protected ExtensionRestRequest(RestExecuteOnExtensionRequest request) throws IllegalArgumentException { - if (request == null) throw new IllegalArgumentException("Request object can't be null"); - this.method = request.getMethod(); - this.uri = request.getUri(); - this.principalIdentifierToken = request.getRequestIssuerIdentity(); - } - - /** - * Object generated from input stream - * @param in Input stream - * @throws IOException if there's an error in generating object from input stream - */ - public ExtensionRestRequest(StreamInput in) throws IOException { - super(in); - method = in.readEnum(Method.class); - uri = in.readString(); - principalIdentifierToken = in.readNamedWriteable(PrincipalIdentifierToken.class); - } - - /** - * Write this object to output stream - * @param out the writeable output stream - * @throws IOException if there's an error in generating object from output stream - */ - @Override - public void writeTo(StreamOutput out) throws IOException { - super.writeTo(out); - out.writeEnum(method); - out.writeString(uri); - out.writeNamedWriteable(principalIdentifierToken); - } - - /** - * @return This REST request {@link Method} type - */ - public Method method() { - return method; - } - - /** - * @return This REST request's uri - */ - public String uri() { - return uri; - } - - /** - * @return This REST request issuer's identity token - */ - public PrincipalIdentifierToken getRequestIssuerIdentity() { - return principalIdentifierToken; - } - - @Override - public String toString() { - return "ExtensionRestRequest{method=" + method + ", uri=" + uri + ", requester = " + principalIdentifierToken.getToken() + "}"; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) return true; - if (obj == null || getClass() != obj.getClass()) return false; - ExtensionRestRequest that = (ExtensionRestRequest) obj; - return Objects.equals(method, that.method) - && Objects.equals(uri, that.uri) - && Objects.equals(principalIdentifierToken, that.principalIdentifierToken); - } - - @Override - public int hashCode() { - return Objects.hash(method, uri, principalIdentifierToken); - } -} diff --git a/src/main/java/org/opensearch/sdk/ExtensionRestResponse.java b/src/main/java/org/opensearch/sdk/ExtensionRestResponse.java index c623c3ec..2b915cf3 100644 --- a/src/main/java/org/opensearch/sdk/ExtensionRestResponse.java +++ b/src/main/java/org/opensearch/sdk/ExtensionRestResponse.java @@ -11,6 +11,7 @@ import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.xcontent.XContentBuilder; +import org.opensearch.extensions.rest.ExtensionRestRequest; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestStatus; @@ -23,71 +24,76 @@ public class ExtensionRestResponse extends BytesRestResponse { * Key passed in {@link BytesRestResponse} headers to identify parameters consumed by the handler. For internal use. */ static final String CONSUMED_PARAMS_KEY = "extension.consumed.parameters"; + /** + * Key passed in {@link BytesRestResponse} headers to identify content consumed by the handler. For internal use. + */ + static final String CONSUMED_CONTENT_KEY = "extension.consumed.content"; /** * Creates a new response based on {@link XContentBuilder}. * + * @param request the REST request being responded to. * @param status The REST status. * @param builder The builder for the response. - * @param consumedParams Parameters consumed by the handler. */ - public ExtensionRestResponse(RestStatus status, XContentBuilder builder, List consumedParams) { + public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, XContentBuilder builder) { super(status, builder); - addConsumedParamHeader(consumedParams); + addConsumedHeaders(request.consumedParams(), request.isContentConsumed()); } /** * Creates a new plain text response. * + * @param request the REST request being responded to. * @param status The REST status. * @param content A plain text response string. - * @param consumedParams Parameters consumed by the handler. */ - public ExtensionRestResponse(RestStatus status, String content, List consumedParams) { + public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String content) { super(status, content); - addConsumedParamHeader(consumedParams); + addConsumedHeaders(request.consumedParams(), request.isContentConsumed()); } /** * Creates a new plain text response. * + * @param request the REST request being responded to. * @param status The REST status. * @param contentType The content type of the response string. * @param content A response string. - * @param consumedParams Parameters consumed by the handler. */ - public ExtensionRestResponse(RestStatus status, String contentType, String content, List consumedParams) { + public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, String content) { super(status, contentType, content); - addConsumedParamHeader(consumedParams); + addConsumedHeaders(request.consumedParams(), request.isContentConsumed()); } /** * Creates a binary response. * + * @param request the REST request being responded to. * @param status The REST status. * @param contentType The content type of the response bytes. * @param content Response bytes. - * @param consumedParams Parameters consumed by the handler. */ - public ExtensionRestResponse(RestStatus status, String contentType, byte[] content, List consumedParams) { + public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, byte[] content) { super(status, contentType, content); - addConsumedParamHeader(consumedParams); + addConsumedHeaders(request.consumedParams(), request.isContentConsumed()); } /** * Creates a binary response. * + * @param request the REST request being responded to. * @param status The REST status. * @param contentType The content type of the response bytes. * @param content Response bytes. - * @param consumedParams Parameters consumed by the handler. */ - public ExtensionRestResponse(RestStatus status, String contentType, BytesReference content, List consumedParams) { + public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, BytesReference content) { super(status, contentType, content); - addConsumedParamHeader(consumedParams); + addConsumedHeaders(request.consumedParams(), request.isContentConsumed()); } - private void addConsumedParamHeader(List consumedParams) { + private void addConsumedHeaders(List consumedParams, boolean contentConusmed) { consumedParams.stream().forEach(p -> addHeader(CONSUMED_PARAMS_KEY, p)); + addHeader(CONSUMED_CONTENT_KEY, Boolean.toString(contentConusmed)); } } diff --git a/src/main/java/org/opensearch/sdk/ExtensionsRunner.java b/src/main/java/org/opensearch/sdk/ExtensionsRunner.java index 6ad1d1e9..062deef2 100644 --- a/src/main/java/org/opensearch/sdk/ExtensionsRunner.java +++ b/src/main/java/org/opensearch/sdk/ExtensionsRunner.java @@ -18,8 +18,8 @@ import org.opensearch.common.io.stream.NamedWriteableRegistry; import org.opensearch.common.io.stream.NamedWriteableRegistryParseRequest; import org.opensearch.extensions.OpenSearchRequest; +import org.opensearch.extensions.rest.ExtensionRestRequest; import org.opensearch.extensions.rest.RegisterRestActionsRequest; -import org.opensearch.extensions.rest.RestExecuteOnExtensionRequest; import org.opensearch.extensions.settings.RegisterCustomSettingsRequest; import org.opensearch.common.network.NetworkModule; import org.opensearch.common.network.NetworkService; @@ -357,7 +357,7 @@ public void startTransportService(TransportService transportService) { ThreadPool.Names.GENERIC, false, false, - RestExecuteOnExtensionRequest::new, + ExtensionRestRequest::new, ((request, channel, task) -> channel.sendResponse(extensionsRestRequestHandler.handleRestExecuteOnExtensionRequest(request))) ); @@ -477,7 +477,7 @@ public void sendLocalNodeRequest(TransportService transportService) { * Requests the ActionListener onFailure method to be run by OpenSearch. The result will be handled by a {@link ActionListenerOnFailureResponseHandler}. * * @param transportService The TransportService defining the connection to OpenSearch. - * @param failureException The exception to be sent to OpenSearch + * @param failureException The exception to be sent to OpenSearch */ public void sendActionListenerOnFailureRequest(TransportService transportService, Exception failureException) { logger.info("Sending ActionListener onFailure request to OpenSearch"); diff --git a/src/main/java/org/opensearch/sdk/handlers/ExtensionsRestRequestHandler.java b/src/main/java/org/opensearch/sdk/handlers/ExtensionsRestRequestHandler.java index 5dbe0792..f777347c 100644 --- a/src/main/java/org/opensearch/sdk/handlers/ExtensionsRestRequestHandler.java +++ b/src/main/java/org/opensearch/sdk/handlers/ExtensionsRestRequestHandler.java @@ -10,13 +10,12 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.bytes.BytesReference; -import org.opensearch.extensions.rest.RestExecuteOnExtensionRequest; +import org.opensearch.extensions.rest.ExtensionRestRequest; import org.opensearch.extensions.rest.RestExecuteOnExtensionResponse; import org.opensearch.rest.RestStatus; import org.opensearch.sdk.ExtensionRestHandler; import org.opensearch.sdk.ExtensionsRunner; import org.opensearch.sdk.ExtensionRestPathRegistry; -import org.opensearch.sdk.ExtensionRestRequest; import org.opensearch.sdk.ExtensionRestResponse; /** @@ -33,24 +32,18 @@ public class ExtensionsRestRequestHandler { * @param request The REST request to execute. * @return A response acknowledging the request. */ - public RestExecuteOnExtensionResponse handleRestExecuteOnExtensionRequest(RestExecuteOnExtensionRequest request) { + public RestExecuteOnExtensionResponse handleRestExecuteOnExtensionRequest(ExtensionRestRequest request) { - ExtensionRestHandler restHandler = extensionRestPathRegistry.getHandler(request.getMethod(), request.getUri()); + ExtensionRestHandler restHandler = extensionRestPathRegistry.getHandler(request.method(), request.path()); if (restHandler == null) { return new RestExecuteOnExtensionResponse( RestStatus.NOT_FOUND, - "No handler for " + ExtensionRestPathRegistry.restPathToString(request.getMethod(), request.getUri()) + "No handler for " + ExtensionRestPathRegistry.restPathToString(request.method(), request.path()) ); } - // ExtensionRestRequest restRequest = new ExtensionRestRequest(request); - ExtensionRestRequest restRequest = new ExtensionRestRequest( - request.getMethod(), - request.getUri(), - request.getRequestIssuerIdentity() - ); // Get response from extension - ExtensionRestResponse response = restHandler.handleRequest(restRequest); + ExtensionRestResponse response = restHandler.handleRequest(request); logger.info("Sending extension response to OpenSearch: " + response.status()); return new RestExecuteOnExtensionResponse( response.status(), diff --git a/src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestHelloAction.java b/src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestHelloAction.java index 185fcff0..4e5ce476 100644 --- a/src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestHelloAction.java +++ b/src/main/java/org/opensearch/sdk/sample/helloworld/rest/RestHelloAction.java @@ -7,15 +7,14 @@ */ package org.opensearch.sdk.sample.helloworld.rest; +import org.opensearch.extensions.rest.ExtensionRestRequest; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest.Method; import org.opensearch.sdk.ExtensionRestHandler; -import org.opensearch.sdk.ExtensionRestRequest; import org.opensearch.sdk.ExtensionRestResponse; import java.net.URLDecoder; import java.nio.charset.StandardCharsets; -import java.util.ArrayList; import java.util.List; import static org.opensearch.rest.RestRequest.Method.GET; @@ -39,31 +38,32 @@ public List routes() { @Override public ExtensionRestResponse handleRequest(ExtensionRestRequest request) { - // We need to track which parameters are consumed to pass back to OpenSearch - List consumedParams = new ArrayList<>(); Method method = request.method(); - String uri = request.uri(); - if (Method.GET.equals(method) && "/hello".equals(uri)) { - return new ExtensionRestResponse(OK, String.format(GREETING, worldName), consumedParams); - } else if (Method.PUT.equals(method) && uri.startsWith("/hello/")) { - // Placeholder code here for parameters in named wildcard paths - // Full implementation based on params() will be implemented as part of - // https://github.com/opensearch-project/opensearch-sdk-java/issues/111 - String name = uri.substring("/hello/".length()); - consumedParams.add("name"); - try { - worldName = URLDecoder.decode(name, StandardCharsets.UTF_8); - } catch (IllegalArgumentException e) { - return new ExtensionRestResponse(BAD_REQUEST, e.getMessage(), consumedParams); - } - return new ExtensionRestResponse(OK, "Updated the world's name to " + worldName, consumedParams); + if (Method.GET.equals(method)) { + return handleGetRequest(request); + } else if (Method.PUT.equals(method)) { + return handlePutRequest(request); } - return new ExtensionRestResponse( - NOT_FOUND, - "Extension REST action improperly configured to handle " + method.name() + " " + uri, - consumedParams - ); + return handleBadRequest(request); + } + + private ExtensionRestResponse handleGetRequest(ExtensionRestRequest request) { + return new ExtensionRestResponse(request, OK, String.format(GREETING, worldName)); + } + + private ExtensionRestResponse handlePutRequest(ExtensionRestRequest request) { + String name = request.param("name"); + try { + worldName = URLDecoder.decode(name, StandardCharsets.UTF_8); + } catch (IllegalArgumentException e) { + return new ExtensionRestResponse(request, BAD_REQUEST, e.getMessage()); + } + return new ExtensionRestResponse(request, OK, "Updated the world's name to " + worldName); + } + + private ExtensionRestResponse handleBadRequest(ExtensionRestRequest request) { + return new ExtensionRestResponse(request, NOT_FOUND, "Extension REST action improperly configured to handle " + request.toString()); } } diff --git a/src/test/java/org/opensearch/sdk/TestExtensionRestPathRegistry.java b/src/test/java/org/opensearch/sdk/TestExtensionRestPathRegistry.java index 8a0a2cb7..d6f2ae38 100644 --- a/src/test/java/org/opensearch/sdk/TestExtensionRestPathRegistry.java +++ b/src/test/java/org/opensearch/sdk/TestExtensionRestPathRegistry.java @@ -4,6 +4,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.opensearch.extensions.rest.ExtensionRestRequest; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest.Method; import org.opensearch.test.OpenSearchTestCase; diff --git a/src/test/java/org/opensearch/sdk/TestExtensionRestRequest.java b/src/test/java/org/opensearch/sdk/TestExtensionRestRequest.java deleted file mode 100644 index 0187dd62..00000000 --- a/src/test/java/org/opensearch/sdk/TestExtensionRestRequest.java +++ /dev/null @@ -1,57 +0,0 @@ -package org.opensearch.sdk; - -import org.junit.jupiter.api.Test; -import org.opensearch.common.bytes.BytesReference; -import org.opensearch.common.io.stream.BytesStreamInput; -import org.opensearch.common.io.stream.BytesStreamOutput; -import org.opensearch.common.io.stream.NamedWriteableAwareStreamInput; -import org.opensearch.common.io.stream.NamedWriteableRegistry; -import org.opensearch.identity.ExtensionTokenProcessor; -import org.opensearch.identity.PrincipalIdentifierToken; -import org.opensearch.rest.RestRequest; -import org.opensearch.test.OpenSearchTestCase; - -import java.security.Principal; - -public class TestExtensionRestRequest extends OpenSearchTestCase { - - @Test - public void testExtensionRestRequest() throws Exception { - RestRequest.Method expectedMethod = RestRequest.Method.GET; - String expectedUri = "/test/uri"; - String extensionUniqueId1 = "ext_1"; - Principal userPrincipal = () -> "user1"; - ExtensionTokenProcessor extensionTokenProcessor = new ExtensionTokenProcessor(extensionUniqueId1); - PrincipalIdentifierToken expectedRequestIssuerIdentity = extensionTokenProcessor.generateToken(userPrincipal); - NamedWriteableRegistry registry = new NamedWriteableRegistry( - org.opensearch.common.collect.List.of( - new NamedWriteableRegistry.Entry( - PrincipalIdentifierToken.class, - PrincipalIdentifierToken.NAME, - PrincipalIdentifierToken::new - ) - ) - ); - - ExtensionRestRequest request = new ExtensionRestRequest(expectedMethod, expectedUri, expectedRequestIssuerIdentity); - - assertEquals(expectedMethod, request.method()); - assertEquals(expectedUri, request.uri()); - assertEquals(expectedRequestIssuerIdentity, request.getRequestIssuerIdentity()); - - try (BytesStreamOutput out = new BytesStreamOutput()) { - request.writeTo(out); - out.flush(); - try (BytesStreamInput in = new BytesStreamInput(BytesReference.toBytes(out.bytes()))) { - try (NamedWriteableAwareStreamInput nameWritableAwareIn = new NamedWriteableAwareStreamInput(in, registry)) { - request = new ExtensionRestRequest(nameWritableAwareIn); - } - - assertEquals(expectedMethod, request.method()); - assertEquals(expectedUri, request.uri()); - assertEquals(expectedRequestIssuerIdentity, request.getRequestIssuerIdentity()); - } - } - } - -} diff --git a/src/test/java/org/opensearch/sdk/TestExtensionRestResponse.java b/src/test/java/org/opensearch/sdk/TestExtensionRestResponse.java index e11aed5e..4bd982da 100644 --- a/src/test/java/org/opensearch/sdk/TestExtensionRestResponse.java +++ b/src/test/java/org/opensearch/sdk/TestExtensionRestResponse.java @@ -2,13 +2,17 @@ import java.io.IOException; import java.nio.ByteBuffer; +import java.util.Collections; import java.util.List; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentType; +import org.opensearch.extensions.rest.ExtensionRestRequest; +import org.opensearch.rest.RestRequest.Method; import org.opensearch.test.OpenSearchTestCase; import static org.opensearch.rest.BytesRestResponse.TEXT_CONTENT_TYPE; @@ -23,7 +27,7 @@ public class TestExtensionRestResponse extends OpenSearchTestCase { private String testText; private byte[] testBytes; - private List testConsumedParams; + private ExtensionRestRequest request; @Override @BeforeEach @@ -31,7 +35,12 @@ public void setUp() throws Exception { super.setUp(); testText = "plain text"; testBytes = new byte[] { 1, 2 }; - testConsumedParams = List.of("foo", "bar"); + request = new ExtensionRestRequest(Method.GET, "/foo", Collections.emptyMap(), null, new BytesArray("Text Content"), null); + // consume params "foo" and "bar" + request.param("foo"); + request.param("bar"); + // consume content + request.content(); } @Test @@ -40,33 +49,35 @@ public void testConstructorWithBuilder() throws IOException { builder.startObject(); builder.field("status", ACCEPTED); builder.endObject(); - ExtensionRestResponse response = new ExtensionRestResponse(OK, builder, testConsumedParams); + ExtensionRestResponse response = new ExtensionRestResponse(request, OK, builder); assertEquals(OK, response.status()); assertEquals(JSON_CONTENT_TYPE, response.contentType()); assertEquals("{\"status\":\"ACCEPTED\"}", response.content().utf8ToString()); List consumedParams = response.getHeaders().get(CONSUMED_PARAMS_KEY); for (String param : consumedParams) { - assertTrue(testConsumedParams.contains(param)); + assertTrue(request.consumedParams().contains(param)); } + assertTrue(request.isContentConsumed()); } @Test public void testConstructorWithPlainText() { - ExtensionRestResponse response = new ExtensionRestResponse(OK, testText, testConsumedParams); + ExtensionRestResponse response = new ExtensionRestResponse(request, OK, testText); assertEquals(OK, response.status()); assertEquals(TEXT_CONTENT_TYPE, response.contentType()); assertEquals(testText, response.content().utf8ToString()); List consumedParams = response.getHeaders().get(CONSUMED_PARAMS_KEY); for (String param : consumedParams) { - assertTrue(testConsumedParams.contains(param)); + assertTrue(request.consumedParams().contains(param)); } + assertTrue(request.isContentConsumed()); } @Test public void testConstructorWithText() { - ExtensionRestResponse response = new ExtensionRestResponse(OK, TEXT_CONTENT_TYPE, testText, testConsumedParams); + ExtensionRestResponse response = new ExtensionRestResponse(request, OK, TEXT_CONTENT_TYPE, testText); assertEquals(OK, response.status()); assertEquals(TEXT_CONTENT_TYPE, response.contentType()); @@ -74,30 +85,32 @@ public void testConstructorWithText() { List consumedParams = response.getHeaders().get(CONSUMED_PARAMS_KEY); for (String param : consumedParams) { - assertTrue(testConsumedParams.contains(param)); + assertTrue(request.consumedParams().contains(param)); } + assertTrue(request.isContentConsumed()); } @Test public void testConstructorWithByteArray() { - ExtensionRestResponse response = new ExtensionRestResponse(OK, OCTET_CONTENT_TYPE, testBytes, testConsumedParams); + ExtensionRestResponse response = new ExtensionRestResponse(request, OK, OCTET_CONTENT_TYPE, testBytes); assertEquals(OK, response.status()); assertEquals(OCTET_CONTENT_TYPE, response.contentType()); assertArrayEquals(testBytes, BytesReference.toBytes(response.content())); List consumedParams = response.getHeaders().get(CONSUMED_PARAMS_KEY); for (String param : consumedParams) { - assertTrue(testConsumedParams.contains(param)); + assertTrue(request.consumedParams().contains(param)); } + assertTrue(request.isContentConsumed()); } @Test public void testConstructorWithBytesReference() { ExtensionRestResponse response = new ExtensionRestResponse( + request, OK, OCTET_CONTENT_TYPE, - BytesReference.fromByteBuffer(ByteBuffer.wrap(testBytes, 0, 2)), - testConsumedParams + BytesReference.fromByteBuffer(ByteBuffer.wrap(testBytes, 0, 2)) ); assertEquals(OK, response.status()); @@ -105,7 +118,8 @@ public void testConstructorWithBytesReference() { assertArrayEquals(testBytes, BytesReference.toBytes(response.content())); List consumedParams = response.getHeaders().get(CONSUMED_PARAMS_KEY); for (String param : consumedParams) { - assertTrue(testConsumedParams.contains(param)); + assertTrue(request.consumedParams().contains(param)); } + assertTrue(request.isContentConsumed()); } } diff --git a/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java b/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java index 54688a88..30d0ab80 100644 --- a/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java +++ b/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java @@ -35,6 +35,7 @@ import org.junit.jupiter.api.Test; import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.io.stream.NamedWriteableRegistryResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.settings.Setting.Property; @@ -45,7 +46,7 @@ import org.opensearch.extensions.ExtensionBooleanResponse; import org.opensearch.extensions.ExtensionsOrchestrator.OpenSearchRequestType; import org.opensearch.extensions.OpenSearchRequest; -import org.opensearch.extensions.rest.RestExecuteOnExtensionRequest; +import org.opensearch.extensions.rest.ExtensionRestRequest; import org.opensearch.extensions.rest.RestExecuteOnExtensionResponse; import org.opensearch.identity.ExtensionTokenProcessor; import org.opensearch.rest.BytesRestResponse; @@ -73,7 +74,6 @@ public class TestExtensionsRunner extends OpenSearchTestCase { private ExtensionsInitRequestHandler extensionsInitRequestHandler = new ExtensionsInitRequestHandler(); private OpensearchRequestHandler opensearchRequestHandler = new OpensearchRequestHandler(); private ExtensionsRestRequestHandler extensionsRestRequestHandler = new ExtensionsRestRequestHandler(); - private ExtensionsRunner extensionsRunner; private TransportService transportService; @@ -168,11 +168,18 @@ public void testHandleOpenSearchRequest() throws Exception { } @Test - public void testHandleRestExecuteOnExtensionRequest() throws Exception { + public void testHandleExtensionRestRequest() throws Exception { ExtensionTokenProcessor ext = new ExtensionTokenProcessor(EXTENSION_NAME); Principal userPrincipal = () -> "user1"; - RestExecuteOnExtensionRequest request = new RestExecuteOnExtensionRequest(Method.GET, "/foo", ext.generateToken(userPrincipal)); + ExtensionRestRequest request = new ExtensionRestRequest( + Method.GET, + "/foo", + Collections.emptyMap(), + null, + new BytesArray("bar"), + ext.generateToken(userPrincipal) + ); RestExecuteOnExtensionResponse response = extensionsRestRequestHandler.handleRestExecuteOnExtensionRequest(request); // this will fail in test environment with no registered actions assertEquals(RestStatus.NOT_FOUND, response.getStatus()); diff --git a/src/test/java/org/opensearch/sdk/sample/helloworld/rest/TestRestHelloAction.java b/src/test/java/org/opensearch/sdk/sample/helloworld/rest/TestRestHelloAction.java index 42a366b4..cfe04a18 100644 --- a/src/test/java/org/opensearch/sdk/sample/helloworld/rest/TestRestHelloAction.java +++ b/src/test/java/org/opensearch/sdk/sample/helloworld/rest/TestRestHelloAction.java @@ -9,7 +9,9 @@ import java.nio.charset.StandardCharsets; import java.security.Principal; +import java.util.Collections; import java.util.List; +import java.util.Map; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -17,12 +19,13 @@ import org.opensearch.identity.PrincipalIdentifierToken; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest.Method; +import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.bytes.BytesReference; +import org.opensearch.extensions.rest.ExtensionRestRequest; import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestResponse; import org.opensearch.rest.RestStatus; import org.opensearch.sdk.ExtensionRestHandler; -import org.opensearch.sdk.ExtensionRestRequest; import org.opensearch.test.OpenSearchTestCase; public class TestRestHelloAction extends OpenSearchTestCase { @@ -52,12 +55,33 @@ public void testHandleRequest() { Principal userPrincipal = () -> "user1"; ExtensionTokenProcessor extensionTokenProcessor = new ExtensionTokenProcessor(EXTENSION_NAME); PrincipalIdentifierToken token = extensionTokenProcessor.generateToken(userPrincipal); + Map params = Collections.emptyMap(); - ExtensionRestRequest getRequest = new ExtensionRestRequest(Method.GET, "/hello", token); - ExtensionRestRequest putRequest = new ExtensionRestRequest(Method.PUT, "/hello", token); - ExtensionRestRequest updateRequest = new ExtensionRestRequest(Method.PUT, "/hello/Passing+Test", token); - ExtensionRestRequest badRequest = new ExtensionRestRequest(Method.PUT, "/hello/Bad%Request", token); - ExtensionRestRequest unsuccessfulRequest = new ExtensionRestRequest(Method.GET, "/goodbye", token); + ExtensionRestRequest getRequest = new ExtensionRestRequest(Method.GET, "/hello", params, null, new BytesArray(""), token); + ExtensionRestRequest putRequest = new ExtensionRestRequest( + Method.PUT, + "/hello/Passing+Test", + Map.of("name", "Passing+Test"), + null, + new BytesArray(""), + token + ); + ExtensionRestRequest badRequest = new ExtensionRestRequest( + Method.PUT, + "/hello/Bad%Request", + Map.of("name", "Bad%Request"), + null, + new BytesArray(""), + token + ); + ExtensionRestRequest unsuccessfulRequest = new ExtensionRestRequest( + Method.POST, + "/goodbye", + params, + null, + new BytesArray(""), + token + ); RestResponse response = restHelloAction.handleRequest(getRequest); assertEquals(RestStatus.OK, response.status()); @@ -66,12 +90,6 @@ public void testHandleRequest() { assertEquals("Hello, World!", responseStr); response = restHelloAction.handleRequest(putRequest); - assertEquals(RestStatus.NOT_FOUND, response.status()); - assertEquals(BytesRestResponse.TEXT_CONTENT_TYPE, response.contentType()); - responseStr = new String(BytesReference.toBytes(response.content()), StandardCharsets.UTF_8); - assertTrue(responseStr.contains("PUT")); - - response = restHelloAction.handleRequest(updateRequest); assertEquals(RestStatus.OK, response.status()); assertEquals(BytesRestResponse.TEXT_CONTENT_TYPE, response.contentType()); responseStr = new String(BytesReference.toBytes(response.content()), StandardCharsets.UTF_8);