diff --git a/src/main/java/org/opensearch/sdk/ExtensionRestResponse.java b/src/main/java/org/opensearch/sdk/ExtensionRestResponse.java index 984f8270d..2b915cf31 100644 --- a/src/main/java/org/opensearch/sdk/ExtensionRestResponse.java +++ b/src/main/java/org/opensearch/sdk/ExtensionRestResponse.java @@ -24,6 +24,10 @@ public class ExtensionRestResponse extends BytesRestResponse { * Key passed in {@link BytesRestResponse} headers to identify parameters consumed by the handler. For internal use. */ static final String CONSUMED_PARAMS_KEY = "extension.consumed.parameters"; + /** + * Key passed in {@link BytesRestResponse} headers to identify content consumed by the handler. For internal use. + */ + static final String CONSUMED_CONTENT_KEY = "extension.consumed.content"; /** * Creates a new response based on {@link XContentBuilder}. @@ -34,7 +38,7 @@ public class ExtensionRestResponse extends BytesRestResponse { */ public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, XContentBuilder builder) { super(status, builder); - addConsumedParamHeader(request.consumedParams()); + addConsumedHeaders(request.consumedParams(), request.isContentConsumed()); } /** @@ -46,7 +50,7 @@ public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, XC */ public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String content) { super(status, content); - addConsumedParamHeader(request.consumedParams()); + addConsumedHeaders(request.consumedParams(), request.isContentConsumed()); } /** @@ -59,7 +63,7 @@ public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, St */ public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, String content) { super(status, contentType, content); - addConsumedParamHeader(request.consumedParams()); + addConsumedHeaders(request.consumedParams(), request.isContentConsumed()); } /** @@ -72,7 +76,7 @@ public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, St */ public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, byte[] content) { super(status, contentType, content); - addConsumedParamHeader(request.consumedParams()); + addConsumedHeaders(request.consumedParams(), request.isContentConsumed()); } /** @@ -85,10 +89,11 @@ public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, St */ public ExtensionRestResponse(ExtensionRestRequest request, RestStatus status, String contentType, BytesReference content) { super(status, contentType, content); - addConsumedParamHeader(request.consumedParams()); + addConsumedHeaders(request.consumedParams(), request.isContentConsumed()); } - private void addConsumedParamHeader(List consumedParams) { + private void addConsumedHeaders(List consumedParams, boolean contentConusmed) { consumedParams.stream().forEach(p -> addHeader(CONSUMED_PARAMS_KEY, p)); + addHeader(CONSUMED_CONTENT_KEY, Boolean.toString(contentConusmed)); } } diff --git a/src/test/java/org/opensearch/sdk/TestExtensionRestResponse.java b/src/test/java/org/opensearch/sdk/TestExtensionRestResponse.java index 63ba66936..4bd982da1 100644 --- a/src/test/java/org/opensearch/sdk/TestExtensionRestResponse.java +++ b/src/test/java/org/opensearch/sdk/TestExtensionRestResponse.java @@ -7,6 +7,7 @@ import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.xcontent.XContentBuilder; import org.opensearch.common.xcontent.XContentType; @@ -34,10 +35,12 @@ public void setUp() throws Exception { super.setUp(); testText = "plain text"; testBytes = new byte[] { 1, 2 }; - request = new ExtensionRestRequest(Method.GET, "/foo", Collections.emptyMap(), null); + request = new ExtensionRestRequest(Method.GET, "/foo", Collections.emptyMap(), null, new BytesArray("Text Content"), null); // consume params "foo" and "bar" request.param("foo"); request.param("bar"); + // consume content + request.content(); } @Test @@ -55,6 +58,7 @@ public void testConstructorWithBuilder() throws IOException { for (String param : consumedParams) { assertTrue(request.consumedParams().contains(param)); } + assertTrue(request.isContentConsumed()); } @Test @@ -68,6 +72,7 @@ public void testConstructorWithPlainText() { for (String param : consumedParams) { assertTrue(request.consumedParams().contains(param)); } + assertTrue(request.isContentConsumed()); } @Test @@ -82,6 +87,7 @@ public void testConstructorWithText() { for (String param : consumedParams) { assertTrue(request.consumedParams().contains(param)); } + assertTrue(request.isContentConsumed()); } @Test @@ -95,6 +101,7 @@ public void testConstructorWithByteArray() { for (String param : consumedParams) { assertTrue(request.consumedParams().contains(param)); } + assertTrue(request.isContentConsumed()); } @Test @@ -113,5 +120,6 @@ public void testConstructorWithBytesReference() { for (String param : consumedParams) { assertTrue(request.consumedParams().contains(param)); } + assertTrue(request.isContentConsumed()); } } diff --git a/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java b/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java index 7810e3484..a8b1d93de 100644 --- a/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java +++ b/src/test/java/org/opensearch/sdk/TestExtensionsRunner.java @@ -35,6 +35,7 @@ import org.junit.jupiter.api.Test; import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.io.stream.NamedWriteableRegistryResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.settings.Setting.Property; @@ -167,6 +168,8 @@ public void testHandleExtensionRestRequest() throws Exception { Method.GET, "/foo", Collections.emptyMap(), + null, + new BytesArray("bar"), ext.generateToken(userPrincipal) ); RestExecuteOnExtensionResponse response = extensionsRunner.handleRestExecuteOnExtensionRequest(request); diff --git a/src/test/java/org/opensearch/sdk/sample/helloworld/rest/TestRestHelloAction.java b/src/test/java/org/opensearch/sdk/sample/helloworld/rest/TestRestHelloAction.java index 000544aef..cfe04a18c 100644 --- a/src/test/java/org/opensearch/sdk/sample/helloworld/rest/TestRestHelloAction.java +++ b/src/test/java/org/opensearch/sdk/sample/helloworld/rest/TestRestHelloAction.java @@ -19,6 +19,7 @@ import org.opensearch.identity.PrincipalIdentifierToken; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest.Method; +import org.opensearch.common.bytes.BytesArray; import org.opensearch.common.bytes.BytesReference; import org.opensearch.extensions.rest.ExtensionRestRequest; import org.opensearch.rest.BytesRestResponse; @@ -56,15 +57,31 @@ public void testHandleRequest() { PrincipalIdentifierToken token = extensionTokenProcessor.generateToken(userPrincipal); Map params = Collections.emptyMap(); - ExtensionRestRequest getRequest = new ExtensionRestRequest(Method.GET, "/hello", params, token); + ExtensionRestRequest getRequest = new ExtensionRestRequest(Method.GET, "/hello", params, null, new BytesArray(""), token); ExtensionRestRequest putRequest = new ExtensionRestRequest( Method.PUT, "/hello/Passing+Test", Map.of("name", "Passing+Test"), + null, + new BytesArray(""), + token + ); + ExtensionRestRequest badRequest = new ExtensionRestRequest( + Method.PUT, + "/hello/Bad%Request", + Map.of("name", "Bad%Request"), + null, + new BytesArray(""), + token + ); + ExtensionRestRequest unsuccessfulRequest = new ExtensionRestRequest( + Method.POST, + "/goodbye", + params, + null, + new BytesArray(""), token ); - ExtensionRestRequest badRequest = new ExtensionRestRequest(Method.PUT, "/hello/Bad%Request", Map.of("name", "Bad%Request"), token); - ExtensionRestRequest unsuccessfulRequest = new ExtensionRestRequest(Method.POST, "/goodbye", params, token); RestResponse response = restHelloAction.handleRequest(getRequest); assertEquals(RestStatus.OK, response.status());