diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/OversampleRequestProcessor.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/OversampleRequestProcessor.java index 7ac002b163962..1ad0a221d6edb 100644 --- a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/OversampleRequestProcessor.java +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/OversampleRequestProcessor.java @@ -14,9 +14,12 @@ import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchRequestProcessor; import org.opensearch.search.pipeline.StatefulSearchRequestProcessor; +import org.opensearch.search.pipeline.common.helpers.ContextUtils; import java.util.Map; +import static org.opensearch.search.pipeline.common.helpers.ContextUtils.applyContextPrefix; + /** * Multiplies the "size" parameter on the {@link SearchRequest} by the given scaling factor, storing the original value * in the request context as "original_size". @@ -27,20 +30,22 @@ public class OversampleRequestProcessor extends AbstractProcessor implements Sta * Key to reference this processor type from a search pipeline. */ public static final String TYPE = "oversample"; - private static final String SAMPLE_FACTOR = "sample_factor"; + static final String SAMPLE_FACTOR = "sample_factor"; static final String ORIGINAL_SIZE = "original_size"; private final double sampleFactor; + private final String contextPrefix; - private OversampleRequestProcessor(String tag, String description, boolean ignoreFailure, double sampleFactor) { + private OversampleRequestProcessor(String tag, String description, boolean ignoreFailure, double sampleFactor, String contextPrefix) { super(tag, description, ignoreFailure); this.sampleFactor = sampleFactor; + this.contextPrefix = contextPrefix; } @Override public SearchRequest processRequest(SearchRequest request, Map requestContext) { if (request.source() != null) { int originalSize = request.source().size(); - requestContext.put(ORIGINAL_SIZE, originalSize); + requestContext.put(applyContextPrefix(contextPrefix, ORIGINAL_SIZE), originalSize); int newSize = (int) Math.ceil(originalSize * sampleFactor); request.source().size(newSize); } @@ -53,7 +58,6 @@ public String getType() { } static class Factory implements Processor.Factory { - @Override public OversampleRequestProcessor create( Map> processorFactories, @@ -67,7 +71,8 @@ public OversampleRequestProcessor create( if (sampleFactor < 1.0) { throw ConfigurationUtils.newConfigurationException(TYPE, tag, SAMPLE_FACTOR, "Value must be >= 1.0"); } - return new OversampleRequestProcessor(tag, description, ignoreFailure, sampleFactor); + String contextPrefix = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, ContextUtils.CONTEXT_PREFIX_PARAMETER); + return new OversampleRequestProcessor(tag, description, ignoreFailure, sampleFactor, contextPrefix); } } } diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/ScriptRequestProcessor.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/ScriptRequestProcessor.java index 90f71fd1754e4..8551f8e9f180c 100644 --- a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/ScriptRequestProcessor.java +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/ScriptRequestProcessor.java @@ -25,6 +25,7 @@ import org.opensearch.search.pipeline.AbstractProcessor; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchRequestProcessor; +import org.opensearch.search.pipeline.StatefulSearchRequestProcessor; import org.opensearch.search.pipeline.common.helpers.SearchRequestMap; import java.io.InputStream; @@ -38,7 +39,7 @@ * Processor that evaluates a script with a search request in its context * and then returns the modified search request. */ -public final class ScriptRequestProcessor extends AbstractProcessor implements SearchRequestProcessor { +public final class ScriptRequestProcessor extends AbstractProcessor implements StatefulSearchRequestProcessor { /** * Key to reference this processor type from a search pipeline. */ @@ -72,15 +73,8 @@ public final class ScriptRequestProcessor extends AbstractProcessor implements S this.scriptService = scriptService; } - /** - * Executes the script with the search request in context. - * - * @param request The search request passed into the script context. - * @return The modified search request. - * @throws Exception if an error occurs while processing the request. - */ @Override - public SearchRequest processRequest(SearchRequest request) throws Exception { + public SearchRequest processRequest(SearchRequest request, Map requestContext) throws Exception { // assert request is not null and source is not null if (request == null || request.source() == null) { throw new IllegalArgumentException("search request must not be null"); @@ -93,7 +87,7 @@ public SearchRequest processRequest(SearchRequest request) throws Exception { searchScript = precompiledSearchScript; } // execute the script with the search request in context - searchScript.execute(Map.of("_source", new SearchRequestMap(request))); + searchScript.execute(Map.of("_source", new SearchRequestMap(request), "request_context", requestContext)); return request; } diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessor.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessor.java index 9282f0235bc50..8fed8f67bd572 100644 --- a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessor.java +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessor.java @@ -17,10 +17,13 @@ import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; import org.opensearch.search.pipeline.StatefulSearchResponseProcessor; +import org.opensearch.search.pipeline.common.helpers.ContextUtils; import org.opensearch.search.pipeline.common.helpers.SearchResponseUtil; import java.util.Map; +import static org.opensearch.search.pipeline.common.helpers.ContextUtils.applyContextPrefix; + /** * Truncates the returned search hits from the {@link SearchResponse}. If no target size is specified in the pipeline, then * we try using the "original_size" value from the request context, which may have been set by {@link OversampleRequestProcessor}. @@ -30,17 +33,19 @@ public class TruncateHitsResponseProcessor extends AbstractProcessor implements * Key to reference this processor type from a search pipeline. */ public static final String TYPE = "truncate_hits"; - private static final String TARGET_SIZE = "target_size"; + static final String TARGET_SIZE = "target_size"; private final int targetSize; + private final String contextPrefix; @Override public String getType() { return TYPE; } - private TruncateHitsResponseProcessor(String tag, String description, boolean ignoreFailure, int targetSize) { + private TruncateHitsResponseProcessor(String tag, String description, boolean ignoreFailure, int targetSize, String contextPrefix) { super(tag, description, ignoreFailure); this.targetSize = targetSize; + this.contextPrefix = contextPrefix; } @Override @@ -48,7 +53,12 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp int size; if (targetSize < 0) { - size = (int) requestContext.get(OversampleRequestProcessor.ORIGINAL_SIZE); + String key = applyContextPrefix(contextPrefix, OversampleRequestProcessor.ORIGINAL_SIZE); + Object o = requestContext.get(key); + if (o == null) { + throw new IllegalStateException("Must specify target_size unless an earlier processor set " + key); + } + size = (int) o; } else { size = targetSize; } @@ -71,16 +81,17 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp static class Factory implements Processor.Factory { @Override - public SearchResponseProcessor create( + public TruncateHitsResponseProcessor create( Map> processorFactories, String tag, String description, boolean ignoreFailure, Map config, PipelineContext pipelineContext - ) throws Exception { + ) { int targetSize = ConfigurationUtils.readIntProperty(TYPE, tag, config, TARGET_SIZE, -1); - return new TruncateHitsResponseProcessor(tag, description, ignoreFailure, targetSize); + String contextPrefix = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, ContextUtils.CONTEXT_PREFIX_PARAMETER); + return new TruncateHitsResponseProcessor(tag, description, ignoreFailure, targetSize, contextPrefix); } } diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/ContextUtils.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/ContextUtils.java new file mode 100644 index 0000000000000..9697da85dbecf --- /dev/null +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/ContextUtils.java @@ -0,0 +1,38 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common.helpers; + +/** + * Helpers for working with request-scoped context. + */ +public final class ContextUtils { + private ContextUtils() {} + + /** + * Parameter that can be passed to a stateful processor to avoid collisions between contextual variables by + * prefixing them with distinct qualifiers. + */ + public static final String CONTEXT_PREFIX_PARAMETER = "context_prefix"; + + /** + * Replaces a "global" variable name with one scoped to a given context prefix (unless prefix is null or empty). + * @param contextPrefix the prefix qualifier for the variable + * @param variableName the generic "global" form of the context variable + * @return the variableName prefixed with contextPrefix followed by ".", or just variableName if contextPrefix is null or empty + */ + public static String applyContextPrefix(String contextPrefix, String variableName) { + String contextVariable; + if (contextPrefix != null && contextPrefix.isEmpty() == false) { + contextVariable = contextPrefix + "." + variableName; + } else { + contextVariable = variableName; + } + return contextVariable; + } +} diff --git a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchResponseUtil.java b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchResponseUtil.java index e2679e9c5e3f4..f3ff458caa264 100644 --- a/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchResponseUtil.java +++ b/modules/search-pipeline-common/src/main/java/org/opensearch/search/pipeline/common/helpers/SearchResponseUtil.java @@ -9,8 +9,9 @@ package org.opensearch.search.pipeline.common.helpers; import org.opensearch.action.search.SearchResponse; -import org.opensearch.action.search.SearchResponseSections; import org.opensearch.search.SearchHits; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; import org.opensearch.search.profile.SearchProfileShardResults; /** @@ -29,13 +30,13 @@ private SearchResponseUtil() { */ public static SearchResponse replaceHits(SearchHits newHits, SearchResponse response) { return new SearchResponse( - new SearchResponseSections( + new InternalSearchResponse( newHits, - response.getAggregations(), + (InternalAggregations) response.getAggregations(), response.getSuggest(), + new SearchProfileShardResults(response.getProfileResults()), response.isTimedOut(), response.isTerminatedEarly(), - new SearchProfileShardResults(response.getProfileResults()), response.getNumReducePhases() ), response.getScrollId(), diff --git a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/OversampleRequestProcessorTests.java b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/OversampleRequestProcessorTests.java new file mode 100644 index 0000000000000..56165035ee778 --- /dev/null +++ b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/OversampleRequestProcessorTests.java @@ -0,0 +1,63 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common; + +import org.opensearch.action.search.SearchRequest; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.pipeline.common.helpers.ContextUtils; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class OversampleRequestProcessorTests extends OpenSearchTestCase { + + public void testEmptySource() { + OversampleRequestProcessor.Factory factory = new OversampleRequestProcessor.Factory(); + Map config = new HashMap<>(Map.of(OversampleRequestProcessor.SAMPLE_FACTOR, 3.0)); + OversampleRequestProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null); + + SearchRequest request = new SearchRequest(); + Map context = new HashMap<>(); + SearchRequest transformedRequest = processor.processRequest(request, context); + assertEquals(request, transformedRequest); + assertTrue(context.isEmpty()); + } + + public void testBasicBehavior() { + OversampleRequestProcessor.Factory factory = new OversampleRequestProcessor.Factory(); + Map config = new HashMap<>(Map.of(OversampleRequestProcessor.SAMPLE_FACTOR, 3.0)); + OversampleRequestProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(10); + SearchRequest request = new SearchRequest().source(sourceBuilder); + Map context = new HashMap<>(); + SearchRequest transformedRequest = processor.processRequest(request, context); + assertEquals(30, transformedRequest.source().size()); + assertEquals(1, context.size()); + assertEquals(10, context.get("original_size")); + } + + public void testContextPrefix() { + OversampleRequestProcessor.Factory factory = new OversampleRequestProcessor.Factory(); + Map config = new HashMap<>( + Map.of(OversampleRequestProcessor.SAMPLE_FACTOR, 3.0, ContextUtils.CONTEXT_PREFIX_PARAMETER, "foo") + ); + OversampleRequestProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null); + + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(10); + SearchRequest request = new SearchRequest().source(sourceBuilder); + Map context = new HashMap<>(); + SearchRequest transformedRequest = processor.processRequest(request, context); + assertEquals(30, transformedRequest.source().size()); + assertEquals(1, context.size()); + assertEquals(10, context.get("foo.original_size")); + } +} diff --git a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/ScriptRequestProcessorTests.java b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/ScriptRequestProcessorTests.java index fde9757312e30..6f7dbc6390db1 100644 --- a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/ScriptRequestProcessorTests.java +++ b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/ScriptRequestProcessorTests.java @@ -27,8 +27,6 @@ import java.util.Map; import java.util.concurrent.TimeUnit; -import static org.hamcrest.core.Is.is; - public class ScriptRequestProcessorTests extends OpenSearchTestCase { private ScriptService scriptService; @@ -87,7 +85,7 @@ public void testScriptingWithoutPrecompiledScriptFactory() throws Exception { searchRequest.source(createSearchSourceBuilder()); assertNotNull(searchRequest); - processor.processRequest(searchRequest); + processor.processRequest(searchRequest, new HashMap<>()); assertSearchRequest(searchRequest); } @@ -104,7 +102,7 @@ public void testScriptingWithPrecompiledIngestScript() throws Exception { searchRequest.source(createSearchSourceBuilder()); assertNotNull(searchRequest); - processor.processRequest(searchRequest); + processor.processRequest(searchRequest, new HashMap<>()); assertSearchRequest(searchRequest); } @@ -124,15 +122,15 @@ private SearchSourceBuilder createSearchSourceBuilder() { } private void assertSearchRequest(SearchRequest searchRequest) { - assertThat(searchRequest.source().from(), is(20)); - assertThat(searchRequest.source().size(), is(30)); - assertThat(searchRequest.source().explain(), is(false)); - assertThat(searchRequest.source().version(), is(false)); - assertThat(searchRequest.source().seqNoAndPrimaryTerm(), is(false)); - assertThat(searchRequest.source().trackScores(), is(false)); - assertThat(searchRequest.source().trackTotalHitsUpTo(), is(4)); - assertThat(searchRequest.source().minScore(), is(2.0f)); - assertThat(searchRequest.source().timeout(), is(new TimeValue(60, TimeUnit.SECONDS))); - assertThat(searchRequest.source().terminateAfter(), is(6)); + assertEquals(20, searchRequest.source().from()); + assertEquals(30, searchRequest.source().size()); + assertFalse(searchRequest.source().explain()); + assertFalse(searchRequest.source().version()); + assertFalse(searchRequest.source().seqNoAndPrimaryTerm()); + assertFalse(searchRequest.source().trackScores()); + assertEquals(4, searchRequest.source().trackTotalHitsUpTo().intValue()); + assertEquals(2.0f, searchRequest.source().minScore(), 0.0001); + assertEquals(new TimeValue(60, TimeUnit.SECONDS), searchRequest.source().timeout()); + assertEquals(6, searchRequest.source().terminateAfter()); } } diff --git a/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessorTests.java b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessorTests.java new file mode 100644 index 0000000000000..d82c302b98b70 --- /dev/null +++ b/modules/search-pipeline-common/src/test/java/org/opensearch/search/pipeline/common/TruncateHitsResponseProcessorTests.java @@ -0,0 +1,87 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.search.pipeline.common; + +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.pipeline.common.helpers.ContextUtils; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +public class TruncateHitsResponseProcessorTests extends OpenSearchTestCase { + + public void testBasicBehavior() { + int targetSize = randomInt(50); + TruncateHitsResponseProcessor.Factory factory = new TruncateHitsResponseProcessor.Factory(); + Map config = new HashMap<>(Map.of(TruncateHitsResponseProcessor.TARGET_SIZE, targetSize)); + TruncateHitsResponseProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null); + + int numHits = randomInt(100); + SearchResponse response = constructResponse(numHits); + SearchResponse transformedResponse = processor.processResponse(new SearchRequest(), response, Collections.emptyMap()); + assertEquals(Math.min(targetSize, numHits), transformedResponse.getHits().getHits().length); + } + + public void testTargetSizePassedViaContext() { + TruncateHitsResponseProcessor.Factory factory = new TruncateHitsResponseProcessor.Factory(); + TruncateHitsResponseProcessor processor = factory.create(Collections.emptyMap(), null, null, false, Collections.emptyMap(), null); + + int targetSize = randomInt(50); + int numHits = randomInt(100); + SearchResponse response = constructResponse(numHits); + SearchResponse transformedResponse = processor.processResponse(new SearchRequest(), response, Map.of("original_size", targetSize)); + assertEquals(Math.min(targetSize, numHits), transformedResponse.getHits().getHits().length); + } + + public void testTargetSizePassedViaContextWithPrefix() { + TruncateHitsResponseProcessor.Factory factory = new TruncateHitsResponseProcessor.Factory(); + Map config = new HashMap<>(Map.of(ContextUtils.CONTEXT_PREFIX_PARAMETER, "foo")); + TruncateHitsResponseProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null); + + int targetSize = randomInt(50); + int numHits = randomInt(100); + SearchResponse response = constructResponse(numHits); + SearchResponse transformedResponse = processor.processResponse( + new SearchRequest(), + response, + Map.of("foo.original_size", targetSize) + ); + assertEquals(Math.min(targetSize, numHits), transformedResponse.getHits().getHits().length); + } + + public void testTargetSizeMissing() { + TruncateHitsResponseProcessor.Factory factory = new TruncateHitsResponseProcessor.Factory(); + TruncateHitsResponseProcessor processor = factory.create(Collections.emptyMap(), null, null, false, Collections.emptyMap(), null); + + int numHits = randomInt(100); + SearchResponse response = constructResponse(numHits); + assertThrows(IllegalStateException.class, () -> processor.processResponse(new SearchRequest(), response, Collections.emptyMap())); + } + + private static SearchResponse constructResponse(int numHits) { + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + hitsArray[i] = new SearchHit(i, Integer.toString(i), Collections.emptyMap(), Collections.emptyMap()); + } + SearchHits searchHits = new SearchHits( + hitsArray, + new TotalHits(Math.max(numHits, 1000), TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO), + 1.0f + ); + InternalSearchResponse internalSearchResponse = new InternalSearchResponse(searchHits, null, null, null, false, false, 0); + return new SearchResponse(internalSearchResponse, null, 1, 1, 0, 10, null, null); + } +} diff --git a/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/70_script_truncate.yml b/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/70_script_truncate.yml new file mode 100644 index 0000000000000..9c9f6747e9bdc --- /dev/null +++ b/modules/search-pipeline-common/src/yamlRestTest/resources/rest-api-spec/test/search_pipeline/70_script_truncate.yml @@ -0,0 +1,70 @@ +--- +teardown: + - do: + search_pipeline.delete: + id: "my_pipeline" + ignore: 404 + +--- +"Test state propagating from script request to truncate_hits processor": + - do: + search_pipeline.put: + id: "my_pipeline" + body: > + { + "description": "_description", + "request_processors": [ + { + "script" : { + "source" : "ctx.request_context['foo.original_size'] = 2" + } + } + ], + "response_processors": [ + { + "truncate_hits" : { + "context_prefix" : "foo" + } + } + ] + } + - match: { acknowledged: true } + + - do: + index: + index: test + id: 1 + body: {} + - do: + index: + index: test + id: 2 + body: {} + - do: + index: + index: test + id: 3 + body: {} + - do: + index: + index: test + id: 4 + body: {} + - do: + indices.refresh: + index: test + + - do: + search: + body: { + } + - match: { hits.total.value: 4 } + - length: { hits.hits: 4 } + + - do: + search: + search_pipeline: my_pipeline + body: { + } + - match: { hits.total.value: 4 } + - length: { hits.hits: 2 } diff --git a/server/src/main/java/org/opensearch/search/pipeline/Processor.java b/server/src/main/java/org/opensearch/search/pipeline/Processor.java index fb33f46acada4..f9b62416d25ab 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/Processor.java +++ b/server/src/main/java/org/opensearch/search/pipeline/Processor.java @@ -21,13 +21,6 @@ * @opensearch.internal */ public interface Processor { - /** - * Processor configuration key to let the factory know the context for pipeline creation. - *

- * See {@link PipelineSource}. - */ - String PIPELINE_SOURCE = "pipeline_source"; - /** * Gets the type of processor */ diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index 660d0f4e7c686..a467a6dcd4334 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -41,6 +41,7 @@ import org.opensearch.common.settings.Settings; import org.opensearch.common.util.concurrent.AtomicArray; import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.xcontent.XContentType; import org.opensearch.core.common.breaker.CircuitBreaker; import org.opensearch.core.common.breaker.NoopCircuitBreaker; import org.opensearch.core.common.bytes.BytesArray; @@ -67,6 +68,7 @@ import java.util.List; import java.util.Map; import java.util.concurrent.ExecutorService; +import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; import static org.mockito.ArgumentMatchers.anyString; @@ -1336,4 +1338,89 @@ public void testExtraParameterInProcessorConfig() { fail("Wrong exception type: " + e.getClass()); } } + + private static class FakeStatefulRequestProcessor extends AbstractProcessor implements StatefulSearchRequestProcessor { + private final String type; + private final Consumer> stateConsumer; + + public FakeStatefulRequestProcessor(String type, Consumer> stateConsumer) { + super(null, null, false); + this.type = type; + this.stateConsumer = stateConsumer; + } + + @Override + public String getType() { + return type; + } + + @Override + public SearchRequest processRequest(SearchRequest request, Map requestContext) throws Exception { + stateConsumer.accept(requestContext); + return request; + } + } + + private static class FakeStatefulResponseProcessor extends AbstractProcessor implements StatefulSearchResponseProcessor { + private final String type; + private final Consumer> stateConsumer; + + public FakeStatefulResponseProcessor(String type, Consumer> stateConsumer) { + super(null, null, false); + this.type = type; + this.stateConsumer = stateConsumer; + } + + @Override + public String getType() { + return type; + } + + @Override + public SearchResponse processResponse(SearchRequest request, SearchResponse response, Map requestContext) + throws Exception { + stateConsumer.accept(requestContext); + return response; + } + } + + public void testStatefulProcessors() { + AtomicReference contextHolder = new AtomicReference<>(); + SearchPipelineService searchPipelineService = createWithProcessors( + Map.of("write_context", (pf, t, d, igf, cfg, ctx) -> new FakeStatefulRequestProcessor("write_context", (c) -> c.put("a", "b"))), + Map.of( + "read_context", + (pf, t, d, igf, cfg, ctx) -> new FakeStatefulResponseProcessor( + "read_context", + (c) -> contextHolder.set((String) c.get("a")) + ) + ), + Collections.emptyMap() + ); + + SearchPipelineMetadata metadata = new SearchPipelineMetadata( + Map.of( + "p1", + new PipelineConfiguration( + "p1", + new BytesArray( + "{\"request_processors\" : [ { \"write_context\": {} } ], \"response_processors\": [ { \"read_context\": {} }] }" + ), + XContentType.JSON + ) + ) + ); + ClusterState clusterState = ClusterState.builder(new ClusterName("_name")).build(); + ClusterState previousState = clusterState; + clusterState = ClusterState.builder(clusterState) + .metadata(Metadata.builder().putCustom(SearchPipelineMetadata.TYPE, metadata)) + .build(); + searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState)); + + PipelinedRequest request = searchPipelineService.resolvePipeline(new SearchRequest().pipeline("p1")); + assertNull(contextHolder.get()); + request.transformResponse(new SearchResponse(null, null, 0, 0, 0, 0, null, null)); + assertNotNull(contextHolder.get()); + assertEquals("b", contextHolder.get()); + } }