Skip to content

Commit

Permalink
More work on stateful pipeline processors
Browse files Browse the repository at this point in the history
Added "context_prefix" convention to scope variables to avoid
collisions.

Let script processor have access to the request context.

Added more unit tests.

Signed-off-by: Michael Froh <[email protected]>
  • Loading branch information
msfroh committed Aug 23, 2023
1 parent 4d0d93e commit 923b12d
Show file tree
Hide file tree
Showing 11 changed files with 393 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.StatefulSearchRequestProcessor;
import org.opensearch.search.pipeline.common.helpers.ContextUtils;

import java.util.Map;

import static org.opensearch.search.pipeline.common.helpers.ContextUtils.applyContextPrefix;

/**
* Multiplies the "size" parameter on the {@link SearchRequest} by the given scaling factor, storing the original value
* in the request context as "original_size".
Expand All @@ -27,20 +30,22 @@ public class OversampleRequestProcessor extends AbstractProcessor implements Sta
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "oversample";
private static final String SAMPLE_FACTOR = "sample_factor";
static final String SAMPLE_FACTOR = "sample_factor";
static final String ORIGINAL_SIZE = "original_size";
private final double sampleFactor;
private final String contextPrefix;

private OversampleRequestProcessor(String tag, String description, boolean ignoreFailure, double sampleFactor) {
private OversampleRequestProcessor(String tag, String description, boolean ignoreFailure, double sampleFactor, String contextPrefix) {
super(tag, description, ignoreFailure);
this.sampleFactor = sampleFactor;
this.contextPrefix = contextPrefix;
}

@Override
public SearchRequest processRequest(SearchRequest request, Map<String, Object> requestContext) {
if (request.source() != null) {
int originalSize = request.source().size();
requestContext.put(ORIGINAL_SIZE, originalSize);
requestContext.put(applyContextPrefix(contextPrefix, ORIGINAL_SIZE), originalSize);
int newSize = (int) Math.ceil(originalSize * sampleFactor);
request.source().size(newSize);
}
Expand All @@ -53,7 +58,6 @@ public String getType() {
}

static class Factory implements Processor.Factory<SearchRequestProcessor> {

@Override
public OversampleRequestProcessor create(
Map<String, Processor.Factory<SearchRequestProcessor>> processorFactories,
Expand All @@ -67,7 +71,8 @@ public OversampleRequestProcessor create(
if (sampleFactor < 1.0) {
throw ConfigurationUtils.newConfigurationException(TYPE, tag, SAMPLE_FACTOR, "Value must be >= 1.0");
}
return new OversampleRequestProcessor(tag, description, ignoreFailure, sampleFactor);
String contextPrefix = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, ContextUtils.CONTEXT_PREFIX_PARAMETER);
return new OversampleRequestProcessor(tag, description, ignoreFailure, sampleFactor, contextPrefix);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import org.opensearch.search.pipeline.AbstractProcessor;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.StatefulSearchRequestProcessor;
import org.opensearch.search.pipeline.common.helpers.SearchRequestMap;

import java.io.InputStream;
Expand All @@ -38,7 +39,7 @@
* Processor that evaluates a script with a search request in its context
* and then returns the modified search request.
*/
public final class ScriptRequestProcessor extends AbstractProcessor implements SearchRequestProcessor {
public final class ScriptRequestProcessor extends AbstractProcessor implements StatefulSearchRequestProcessor {
/**
* Key to reference this processor type from a search pipeline.
*/
Expand Down Expand Up @@ -72,15 +73,8 @@ public final class ScriptRequestProcessor extends AbstractProcessor implements S
this.scriptService = scriptService;
}

/**
* Executes the script with the search request in context.
*
* @param request The search request passed into the script context.
* @return The modified search request.
* @throws Exception if an error occurs while processing the request.
*/
@Override
public SearchRequest processRequest(SearchRequest request) throws Exception {
public SearchRequest processRequest(SearchRequest request, Map<String, Object> requestContext) throws Exception {
// assert request is not null and source is not null
if (request == null || request.source() == null) {
throw new IllegalArgumentException("search request must not be null");
Expand All @@ -93,7 +87,7 @@ public SearchRequest processRequest(SearchRequest request) throws Exception {
searchScript = precompiledSearchScript;
}
// execute the script with the search request in context
searchScript.execute(Map.of("_source", new SearchRequestMap(request)));
searchScript.execute(Map.of("_source", new SearchRequestMap(request), "request_context", requestContext));
return request;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,13 @@
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.pipeline.StatefulSearchResponseProcessor;
import org.opensearch.search.pipeline.common.helpers.ContextUtils;
import org.opensearch.search.pipeline.common.helpers.SearchResponseUtil;

import java.util.Map;

import static org.opensearch.search.pipeline.common.helpers.ContextUtils.applyContextPrefix;

/**
* Truncates the returned search hits from the {@link SearchResponse}. If no target size is specified in the pipeline, then
* we try using the "original_size" value from the request context, which may have been set by {@link OversampleRequestProcessor}.
Expand All @@ -30,25 +33,32 @@ public class TruncateHitsResponseProcessor extends AbstractProcessor implements
* Key to reference this processor type from a search pipeline.
*/
public static final String TYPE = "truncate_hits";
private static final String TARGET_SIZE = "target_size";
static final String TARGET_SIZE = "target_size";
private final int targetSize;
private final String contextPrefix;

@Override
public String getType() {
return TYPE;
}

private TruncateHitsResponseProcessor(String tag, String description, boolean ignoreFailure, int targetSize) {
private TruncateHitsResponseProcessor(String tag, String description, boolean ignoreFailure, int targetSize, String contextPrefix) {
super(tag, description, ignoreFailure);
this.targetSize = targetSize;
this.contextPrefix = contextPrefix;
}

@Override
public SearchResponse processResponse(SearchRequest request, SearchResponse response, Map<String, Object> requestContext) {

int size;
if (targetSize < 0) {
size = (int) requestContext.get(OversampleRequestProcessor.ORIGINAL_SIZE);
String key = applyContextPrefix(contextPrefix, OversampleRequestProcessor.ORIGINAL_SIZE);
Object o = requestContext.get(key);
if (o == null) {
throw new IllegalStateException("Must specify target_size unless an earlier processor set " + key);
}
size = (int) o;
} else {
size = targetSize;
}
Expand All @@ -71,16 +81,17 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp
static class Factory implements Processor.Factory<SearchResponseProcessor> {

@Override
public SearchResponseProcessor create(
public TruncateHitsResponseProcessor create(
Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
String tag,
String description,
boolean ignoreFailure,
Map<String, Object> config,
PipelineContext pipelineContext
) throws Exception {
) {
int targetSize = ConfigurationUtils.readIntProperty(TYPE, tag, config, TARGET_SIZE, -1);
return new TruncateHitsResponseProcessor(tag, description, ignoreFailure, targetSize);
String contextPrefix = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, ContextUtils.CONTEXT_PREFIX_PARAMETER);
return new TruncateHitsResponseProcessor(tag, description, ignoreFailure, targetSize, contextPrefix);
}
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.pipeline.common.helpers;

/**
* Helpers for working with request-scoped context.
*/
public final class ContextUtils {
private ContextUtils() {}

/**
* Parameter that can be passed to a stateful processor to avoid collisions between contextual variables by
* prefixing them with distinct qualifiers.
*/
public static final String CONTEXT_PREFIX_PARAMETER = "context_prefix";

/**
* Replaces a "global" variable name with one scoped to a given context prefix (unless prefix is null or empty).
* @param contextPrefix the prefix qualifier for the variable
* @param variableName the generic "global" form of the context variable
* @return the variableName prefixed with contextPrefix followed by ".", or just variableName if contextPrefix is null or empty
*/
public static String applyContextPrefix(String contextPrefix, String variableName) {
String contextVariable;
if (contextPrefix != null && contextPrefix.isEmpty() == false) {
contextVariable = contextPrefix + "." + variableName;
} else {
contextVariable = variableName;
}
return contextVariable;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,9 @@
package org.opensearch.search.pipeline.common.helpers;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.action.search.SearchResponseSections;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
import org.opensearch.search.internal.InternalSearchResponse;
import org.opensearch.search.profile.SearchProfileShardResults;

/**
Expand All @@ -29,13 +30,13 @@ private SearchResponseUtil() {
*/
public static SearchResponse replaceHits(SearchHits newHits, SearchResponse response) {
return new SearchResponse(
new SearchResponseSections(
new InternalSearchResponse(
newHits,
response.getAggregations(),
(InternalAggregations) response.getAggregations(),
response.getSuggest(),
new SearchProfileShardResults(response.getProfileResults()),
response.isTimedOut(),
response.isTerminatedEarly(),
new SearchProfileShardResults(response.getProfileResults()),
response.getNumReducePhases()
),
response.getScrollId(),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.search.pipeline.common;

import org.opensearch.action.search.SearchRequest;
import org.opensearch.search.builder.SearchSourceBuilder;
import org.opensearch.search.pipeline.common.helpers.ContextUtils;
import org.opensearch.test.OpenSearchTestCase;

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

public class OversampleRequestProcessorTests extends OpenSearchTestCase {

public void testEmptySource() {
OversampleRequestProcessor.Factory factory = new OversampleRequestProcessor.Factory();
Map<String, Object> config = new HashMap<>(Map.of(OversampleRequestProcessor.SAMPLE_FACTOR, 3.0));
OversampleRequestProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null);

SearchRequest request = new SearchRequest();
Map<String, Object> context = new HashMap<>();
SearchRequest transformedRequest = processor.processRequest(request, context);
assertEquals(request, transformedRequest);
assertTrue(context.isEmpty());
}

public void testBasicBehavior() {
OversampleRequestProcessor.Factory factory = new OversampleRequestProcessor.Factory();
Map<String, Object> config = new HashMap<>(Map.of(OversampleRequestProcessor.SAMPLE_FACTOR, 3.0));
OversampleRequestProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null);

SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(10);
SearchRequest request = new SearchRequest().source(sourceBuilder);
Map<String, Object> context = new HashMap<>();
SearchRequest transformedRequest = processor.processRequest(request, context);
assertEquals(30, transformedRequest.source().size());
assertEquals(1, context.size());
assertEquals(10, context.get("original_size"));
}

public void testContextPrefix() {
OversampleRequestProcessor.Factory factory = new OversampleRequestProcessor.Factory();
Map<String, Object> config = new HashMap<>(
Map.of(OversampleRequestProcessor.SAMPLE_FACTOR, 3.0, ContextUtils.CONTEXT_PREFIX_PARAMETER, "foo")
);
OversampleRequestProcessor processor = factory.create(Collections.emptyMap(), null, null, false, config, null);

SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().size(10);
SearchRequest request = new SearchRequest().source(sourceBuilder);
Map<String, Object> context = new HashMap<>();
SearchRequest transformedRequest = processor.processRequest(request, context);
assertEquals(30, transformedRequest.source().size());
assertEquals(1, context.size());
assertEquals(10, context.get("foo.original_size"));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,6 @@
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.core.Is.is;

public class ScriptRequestProcessorTests extends OpenSearchTestCase {

private ScriptService scriptService;
Expand Down Expand Up @@ -87,7 +85,7 @@ public void testScriptingWithoutPrecompiledScriptFactory() throws Exception {
searchRequest.source(createSearchSourceBuilder());

assertNotNull(searchRequest);
processor.processRequest(searchRequest);
processor.processRequest(searchRequest, new HashMap<>());
assertSearchRequest(searchRequest);
}

Expand All @@ -104,7 +102,7 @@ public void testScriptingWithPrecompiledIngestScript() throws Exception {
searchRequest.source(createSearchSourceBuilder());

assertNotNull(searchRequest);
processor.processRequest(searchRequest);
processor.processRequest(searchRequest, new HashMap<>());
assertSearchRequest(searchRequest);
}

Expand All @@ -124,15 +122,15 @@ private SearchSourceBuilder createSearchSourceBuilder() {
}

private void assertSearchRequest(SearchRequest searchRequest) {
assertThat(searchRequest.source().from(), is(20));
assertThat(searchRequest.source().size(), is(30));
assertThat(searchRequest.source().explain(), is(false));
assertThat(searchRequest.source().version(), is(false));
assertThat(searchRequest.source().seqNoAndPrimaryTerm(), is(false));
assertThat(searchRequest.source().trackScores(), is(false));
assertThat(searchRequest.source().trackTotalHitsUpTo(), is(4));
assertThat(searchRequest.source().minScore(), is(2.0f));
assertThat(searchRequest.source().timeout(), is(new TimeValue(60, TimeUnit.SECONDS)));
assertThat(searchRequest.source().terminateAfter(), is(6));
assertEquals(20, searchRequest.source().from());
assertEquals(30, searchRequest.source().size());
assertFalse(searchRequest.source().explain());
assertFalse(searchRequest.source().version());
assertFalse(searchRequest.source().seqNoAndPrimaryTerm());
assertFalse(searchRequest.source().trackScores());
assertEquals(4, searchRequest.source().trackTotalHitsUpTo().intValue());
assertEquals(2.0f, searchRequest.source().minScore(), 0.0001);
assertEquals(new TimeValue(60, TimeUnit.SECONDS), searchRequest.source().timeout());
assertEquals(6, searchRequest.source().terminateAfter());
}
}
Loading

0 comments on commit 923b12d

Please sign in to comment.