From 0b95678ac040e718dc708e621ea677bd095ad0ed Mon Sep 17 00:00:00 2001 From: Owais Date: Thu, 12 Sep 2024 17:16:55 -0700 Subject: [PATCH 1/8] Added support for search pipeline name in multi search API Signed-off-by: Owais --- .../search/builder/SearchSourceBuilder.java | 21 +++++++ .../pipeline/SearchPipelineService.java | 3 + .../pipeline/SearchPipelineServiceTests.java | 61 +++++++++++++++++++ 3 files changed, 85 insertions(+) diff --git a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java index 8a9704b04566f..e949b145421c9 100644 --- a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java @@ -224,6 +224,8 @@ public static HighlightBuilder highlight() { private Map searchPipelineSource = null; + private String searchPipeline; + /** * Constructs a new search source builder. */ @@ -273,6 +275,7 @@ public SearchSourceBuilder(StreamInput in) throws IOException { seqNoAndPrimaryTerm = in.readOptionalBoolean(); extBuilders = in.readNamedWriteableList(SearchExtBuilder.class); profile = in.readBoolean(); + searchPipeline = in.readOptionalString(); searchAfterBuilder = in.readOptionalWriteable(SearchAfterBuilder::new); sliceBuilder = in.readOptionalWriteable(SliceBuilder::new); collapse = in.readOptionalWriteable(CollapseBuilder::new); @@ -347,6 +350,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalBoolean(seqNoAndPrimaryTerm); out.writeNamedWriteableList(extBuilders); out.writeBoolean(profile); + out.writeOptionalString(searchPipeline); out.writeOptionalWriteable(searchAfterBuilder); out.writeOptionalWriteable(sliceBuilder); out.writeOptionalWriteable(collapse); @@ -1111,6 +1115,13 @@ public Map searchPipelineSource() { return searchPipelineSource; } + /** + * @return a search pipeline name defined within the search source (see {@link org.opensearch.search.pipeline.SearchPipelineService}) + */ + public String pipeline() { + return searchPipeline; + } + /** * Define a search pipeline to process this search request and/or its response. See {@link org.opensearch.search.pipeline.SearchPipelineService}. */ @@ -1119,6 +1130,14 @@ public SearchSourceBuilder searchPipelineSource(Map searchPipeli return this; } + /** + * Define a search pipeline name to process this search request and/or its response. See {@link org.opensearch.search.pipeline.SearchPipelineService}. + */ + public SearchSourceBuilder pipeline(String searchPipeline) { + this.searchPipeline = searchPipeline; + return this; + } + /** * Rewrites this search source builder into its primitive form. e.g. by * rewriting the QueryBuilder. If the builder did not change the identity @@ -1283,6 +1302,8 @@ public void parseXContent(XContentParser parser, boolean checkTrailingTokens) th sort(parser.text()); } else if (PROFILE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { profile = parser.booleanValue(); + } else if (SEARCH_PIPELINE.match(currentFieldName, parser.getDeprecationHandler())) { + searchPipeline = parser.text(); } else { throw new ParsingException( parser.getTokenLocation(), diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java index 012d6695c042b..e4ec85db4d331 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java @@ -395,6 +395,9 @@ public PipelinedRequest resolvePipeline(SearchRequest searchRequest, IndexNameEx if (searchRequest.pipeline() != null) { // Named pipeline specified for the request pipelineId = searchRequest.pipeline(); + } else if (searchRequest.source() != null && searchRequest.source().pipeline() != null) { + // Inline pipeline specified for the request + pipelineId = searchRequest.source().pipeline(); } else if (state != null && searchRequest.indices() != null && searchRequest.indices().length != 0) { try { // Check for index default pipeline diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index f5857922fdff2..2e91a0ad25188 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -969,6 +969,67 @@ public void testInlinePipeline() throws Exception { } } + /** + * Tests a pipeline name defined in the search request source. + */ + public void testInlineDefinedPipeline() throws Exception { + SearchPipelineService searchPipelineService = createWithProcessors(); + + SearchPipelineMetadata metadata = new SearchPipelineMetadata( + Map.of( + "p1", + new PipelineConfiguration( + "p1", + new BytesArray( + "{" + + "\"request_processors\": [{ \"scale_request_size\": { \"scale\" : 2 } }]," + + "\"response_processors\": [{ \"fixed_score\": { \"score\" : 2 } }]" + + "}" + ), + MediaTypeRegistry.JSON + ) + + ) + + ); + ClusterState clusterState = ClusterState.builder(new ClusterName("_name")).build(); + ClusterState previousState = clusterState; + clusterState = ClusterState.builder(clusterState) + .metadata(Metadata.builder().putCustom(SearchPipelineMetadata.TYPE, metadata)) + .build(); + searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState)); + + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource().size(100).pipeline("p1"); + SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); + + // Verify pipeline + PipelinedRequest pipelinedRequest = syncTransformRequest( + searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver) + ); + Pipeline pipeline = pipelinedRequest.getPipeline(); + assertEquals("p1", pipeline.getId()); + assertEquals(1, pipeline.getSearchRequestProcessors().size()); + assertEquals(1, pipeline.getSearchResponseProcessors().size()); + + // Verify that pipeline transforms request + assertEquals(200, pipelinedRequest.source().size()); + + int size = 10; + SearchHit[] hits = new SearchHit[size]; + for (int i = 0; i < size; i++) { + hits[i] = new SearchHit(i, "doc" + i, Collections.emptyMap(), Collections.emptyMap()); + hits[i].score(i); + } + SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + + SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse); + for (int i = 0; i < size; i++) { + assertEquals(2.0, transformedResponse.getHits().getHits()[i].getScore(), 0.0001); + } + } + public void testInfo() { SearchPipelineService searchPipelineService = createWithProcessors(); SearchPipelineInfo info = searchPipelineService.info(); From cae8b4ce118832bb61cda798416e77b4aa5a87af Mon Sep 17 00:00:00 2001 From: Owais Date: Thu, 12 Sep 2024 17:22:23 -0700 Subject: [PATCH 2/8] Updated CHANGELOG Signed-off-by: Owais --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index eebebec672058..da0c7760ae4ae 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Implement WithFieldName interface in ValuesSourceAggregationBuilder & FieldSortBuilder ([#15916](https://github.com/opensearch-project/OpenSearch/pull/15916)) - Add successfulSearchShardIndices in searchRequestContext ([#15967](https://github.com/opensearch-project/OpenSearch/pull/15967)) - Remove identity-related feature flagged code from the RestController ([#15430](https://github.com/opensearch-project/OpenSearch/pull/15430)) +- Add support for msearch API to pass search pipeline name - ([#15923](https://github.com/opensearch-project/OpenSearch/pull/15923)) ### Dependencies - Bump `com.azure:azure-identity` from 1.13.0 to 1.13.2 ([#15578](https://github.com/opensearch-project/OpenSearch/pull/15578)) From fb7aae9c0d4d947dcba7f4cd555a480c95a228ef Mon Sep 17 00:00:00 2001 From: Owais Date: Wed, 18 Sep 2024 14:17:22 -0700 Subject: [PATCH 3/8] Pulled search pipeline in MultiSearchRequest and updated test Signed-off-by: Owais --- .../action/search/MultiSearchRequest.java | 4 ++ .../pipeline/SearchPipelineService.java | 3 - .../pipeline/SearchPipelineServiceTests.java | 72 ++++++++++++------- 3 files changed, 49 insertions(+), 30 deletions(-) diff --git a/server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java b/server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java index 5b887b48f696e..f16d7d1e7d6a3 100644 --- a/server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java +++ b/server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java @@ -310,6 +310,10 @@ public static void readMultiLineFormat( ) { consumer.accept(searchRequest, parser); } + + if (searchRequest.source() != null && searchRequest.source().pipeline() != null) { + searchRequest.pipeline(searchRequest.source().pipeline()); + } // move pointers from = nextMarker + 1; } diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java index e4ec85db4d331..012d6695c042b 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java @@ -395,9 +395,6 @@ public PipelinedRequest resolvePipeline(SearchRequest searchRequest, IndexNameEx if (searchRequest.pipeline() != null) { // Named pipeline specified for the request pipelineId = searchRequest.pipeline(); - } else if (searchRequest.source() != null && searchRequest.source().pipeline() != null) { - // Inline pipeline specified for the request - pipelineId = searchRequest.source().pipeline(); } else if (state != null && searchRequest.indices() != null && searchRequest.indices().length != 0) { try { // Check for index default pipeline diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index 2e91a0ad25188..c0639607de8d1 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -18,6 +18,7 @@ import org.opensearch.Version; import org.opensearch.action.search.DeleteSearchPipelineRequest; import org.opensearch.action.search.MockSearchPhaseContext; +import org.opensearch.action.search.MultiSearchRequest; import org.opensearch.action.search.PutSearchPipelineRequest; import org.opensearch.action.search.QueryPhaseResultConsumer; import org.opensearch.action.search.SearchPhaseContext; @@ -75,6 +76,8 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import static org.opensearch.search.RandomSearchRequestGenerator.randomSearchRequest; +import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -972,7 +975,8 @@ public void testInlinePipeline() throws Exception { /** * Tests a pipeline name defined in the search request source. */ - public void testInlineDefinedPipeline() throws Exception { + public void testInlineDefinedPipelineForMultiSearch() throws Exception { + int numberOfSearchRequests = randomIntBetween(0, 32); SearchPipelineService searchPipelineService = createWithProcessors(); SearchPipelineMetadata metadata = new SearchPipelineMetadata( @@ -988,7 +992,6 @@ public void testInlineDefinedPipeline() throws Exception { ), MediaTypeRegistry.JSON ) - ) ); @@ -999,34 +1002,49 @@ public void testInlineDefinedPipeline() throws Exception { .build(); searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState)); - SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource().size(100).pipeline("p1"); - SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); - - // Verify pipeline - PipelinedRequest pipelinedRequest = syncTransformRequest( - searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver) - ); - Pipeline pipeline = pipelinedRequest.getPipeline(); - assertEquals("p1", pipeline.getId()); - assertEquals(1, pipeline.getSearchRequestProcessors().size()); - assertEquals(1, pipeline.getSearchResponseProcessors().size()); - - // Verify that pipeline transforms request - assertEquals(200, pipelinedRequest.source().size()); + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + for (int i = 0; i < numberOfSearchRequests; i++) { + SearchRequest searchRequest = randomSearchRequest(() -> { + // No need to return a very complex SearchSourceBuilder here, that is tested + // elsewhere + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(randomInt(10)); + searchSourceBuilder.size(randomIntBetween(20, 100)); + searchSourceBuilder.pipeline("p1"); + return searchSourceBuilder; + }); + multiSearchRequest.add(searchRequest); + + // Verify pipeline + PipelinedRequest pipelinedRequest = syncTransformRequest( + searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver) + ); + Pipeline pipeline = pipelinedRequest.getPipeline(); + assertEquals("p1", pipeline.getId()); + assertEquals(1, pipeline.getSearchRequestProcessors().size()); + assertEquals(1, pipeline.getSearchResponseProcessors().size()); + + // Verify that pipeline transforms request + assertEquals(200, pipelinedRequest.source().size()); + + int size = 10; + SearchHit[] hits = new SearchHit[size]; + for (int j = 0; j < size; j++) { + hits[j] = new SearchHit(j, "doc" + j, Collections.emptyMap(), Collections.emptyMap()); + hits[j].score(j); + } + SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); - int size = 10; - SearchHit[] hits = new SearchHit[size]; - for (int i = 0; i < size; i++) { - hits[i] = new SearchHit(i, "doc" + i, Collections.emptyMap(), Collections.emptyMap()); - hits[i].score(i); + SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse); + for (int j = 0; j < size; j++) { + assertEquals(2.0, transformedResponse.getHits().getHits()[j].getScore(), 0.0001); + } } - SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size); - SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); - SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); - SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse); - for (int i = 0; i < size; i++) { - assertEquals(2.0, transformedResponse.getHits().getHits()[i].getScore(), 0.0001); + for (SearchRequest subReq : multiSearchRequest.requests()) { + assertThat(multiSearchRequest.toString(), containsString(subReq.toString())); } } From c0163941154583f54636d68685d237e4d008b39f Mon Sep 17 00:00:00 2001 From: Owais Date: Wed, 18 Sep 2024 16:21:31 -0700 Subject: [PATCH 4/8] Updated test Signed-off-by: Owais --- .../pipeline/SearchPipelineServiceTests.java | 75 +++++++------------ 1 file changed, 27 insertions(+), 48 deletions(-) diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index c0639607de8d1..b52205996f34b 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -18,7 +18,6 @@ import org.opensearch.Version; import org.opensearch.action.search.DeleteSearchPipelineRequest; import org.opensearch.action.search.MockSearchPhaseContext; -import org.opensearch.action.search.MultiSearchRequest; import org.opensearch.action.search.PutSearchPipelineRequest; import org.opensearch.action.search.QueryPhaseResultConsumer; import org.opensearch.action.search.SearchPhaseContext; @@ -76,8 +75,6 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; -import static org.opensearch.search.RandomSearchRequestGenerator.randomSearchRequest; -import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -972,11 +969,7 @@ public void testInlinePipeline() throws Exception { } } - /** - * Tests a pipeline name defined in the search request source. - */ - public void testInlineDefinedPipelineForMultiSearch() throws Exception { - int numberOfSearchRequests = randomIntBetween(0, 32); + public void testInlineDefinedPipeline() throws Exception { SearchPipelineService searchPipelineService = createWithProcessors(); SearchPipelineMetadata metadata = new SearchPipelineMetadata( @@ -1002,49 +995,35 @@ public void testInlineDefinedPipelineForMultiSearch() throws Exception { .build(); searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState)); - MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); - for (int i = 0; i < numberOfSearchRequests; i++) { - SearchRequest searchRequest = randomSearchRequest(() -> { - // No need to return a very complex SearchSourceBuilder here, that is tested - // elsewhere - SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); - searchSourceBuilder.from(randomInt(10)); - searchSourceBuilder.size(randomIntBetween(20, 100)); - searchSourceBuilder.pipeline("p1"); - return searchSourceBuilder; - }); - multiSearchRequest.add(searchRequest); - - // Verify pipeline - PipelinedRequest pipelinedRequest = syncTransformRequest( - searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver) - ); - Pipeline pipeline = pipelinedRequest.getPipeline(); - assertEquals("p1", pipeline.getId()); - assertEquals(1, pipeline.getSearchRequestProcessors().size()); - assertEquals(1, pipeline.getSearchResponseProcessors().size()); - - // Verify that pipeline transforms request - assertEquals(200, pipelinedRequest.source().size()); - - int size = 10; - SearchHit[] hits = new SearchHit[size]; - for (int j = 0; j < size; j++) { - hits[j] = new SearchHit(j, "doc" + j, Collections.emptyMap(), Collections.emptyMap()); - hits[j].score(j); - } - SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size); - SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); - SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); + SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource().size(100).pipeline("p1"); + SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); + searchRequest.pipeline(searchRequest.source().pipeline()); - SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse); - for (int j = 0; j < size; j++) { - assertEquals(2.0, transformedResponse.getHits().getHits()[j].getScore(), 0.0001); - } + // Verify pipeline + PipelinedRequest pipelinedRequest = syncTransformRequest( + searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver) + ); + Pipeline pipeline = pipelinedRequest.getPipeline(); + assertEquals("p1", pipeline.getId()); + assertEquals(1, pipeline.getSearchRequestProcessors().size()); + assertEquals(1, pipeline.getSearchResponseProcessors().size()); + + // Verify that pipeline transforms request + assertEquals(200, pipelinedRequest.source().size()); + + int size = 10; + SearchHit[] hits = new SearchHit[size]; + for (int i = 0; i < size; i++) { + hits[i] = new SearchHit(i, "doc" + i, Collections.emptyMap(), Collections.emptyMap()); + hits[i].score(i); } + SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); - for (SearchRequest subReq : multiSearchRequest.requests()) { - assertThat(multiSearchRequest.toString(), containsString(subReq.toString())); + SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse); + for (int i = 0; i < size; i++) { + assertEquals(2.0, transformedResponse.getHits().getHits()[i].getScore(), 0.0001); } } From fe74c218af6bb79990bb600c62f9ec500c375ac4 Mon Sep 17 00:00:00 2001 From: Owais Date: Thu, 19 Sep 2024 18:03:24 -0700 Subject: [PATCH 5/8] Updated SearchRequest with search pipeline from source Signed-off-by: Owais --- .../org/opensearch/rest/action/search/RestSearchAction.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java index 3a6b45013e892..05465e32631fd 100644 --- a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java @@ -210,7 +210,7 @@ public static void parseSearchRequest( searchRequest.routing(request.param("routing")); searchRequest.preference(request.param("preference")); searchRequest.indicesOptions(IndicesOptions.fromRequest(request, searchRequest.indicesOptions())); - searchRequest.pipeline(request.param("search_pipeline")); + searchRequest.pipeline(request.param("search_pipeline", searchRequest.source().pipeline())); checkRestTotalHits(request, searchRequest); request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false); From b2c200100bf9f6612aeb68c3c057e944af1629b9 Mon Sep 17 00:00:00 2001 From: Owais Date: Thu, 19 Sep 2024 18:45:02 -0700 Subject: [PATCH 6/8] Added tests for parseSearchRequest Signed-off-by: Owais --- .../action/search/SearchRequestTests.java | 63 +++++++++++++++++-- 1 file changed, 59 insertions(+), 4 deletions(-) diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java index 40514c526f190..a916ea522ca30 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java @@ -42,6 +42,8 @@ import org.opensearch.geometry.LinearRing; import org.opensearch.index.query.GeoShapeQueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.search.RestSearchAction; import org.opensearch.search.AbstractSearchTestCase; import org.opensearch.search.Scroll; import org.opensearch.search.builder.PointInTimeBuilder; @@ -50,14 +52,19 @@ import org.opensearch.search.rescore.QueryRescorerBuilder; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.VersionUtils; +import org.opensearch.test.rest.FakeRestRequest; import java.io.IOException; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.function.IntConsumer; import static java.util.Collections.emptyMap; +import static org.opensearch.action.search.SearchType.DFS_QUERY_THEN_FETCH; import static org.opensearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; import static org.hamcrest.Matchers.equalTo; +import static org.mockito.Mockito.mock; public class SearchRequestTests extends AbstractSearchTestCase { @@ -242,6 +249,57 @@ public void testCopyConstructor() throws IOException { assertNotSame(deserializedRequest, searchRequest); } + public void testParseSearchRequest() throws IOException { + RestRequest restRequest = new FakeRestRequest(); + SearchRequest searchRequest = createSearchRequest(); + IntConsumer setSize = mock(IntConsumer.class); + + restRequest.params().put("index", "index1,index2"); + restRequest.params().put("batched_reduce_size", "512"); + restRequest.params().put("pre_filter_shard_size", "128"); + restRequest.params().put("max_concurrent_shard_requests", "10"); + restRequest.params().put("allow_partial_search_results", "true"); + restRequest.params().put("phase_took", "false"); + restRequest.params().put("search_type", "dfs_query_then_fetch"); + restRequest.params().put("request_cache", "true"); + restRequest.params().put("scroll", "1m"); + restRequest.params().put("routing", "routing_value"); + restRequest.params().put("preference", "preference_value"); + restRequest.params().put("search_pipeline", "pipeline_value"); + restRequest.params().put("ccs_minimize_roundtrips", "true"); + restRequest.params().put("cancel_after_time_interval", "5s"); + + RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize); + + assertEquals(Arrays.asList("index1", "index2"), Arrays.asList(searchRequest.indices())); + assertEquals(512, searchRequest.getBatchedReduceSize()); + assertEquals(Integer.valueOf(128), searchRequest.getPreFilterShardSize()); + assertEquals(10, searchRequest.getMaxConcurrentShardRequests()); + assertTrue(searchRequest.allowPartialSearchResults()); + assertFalse(searchRequest.isPhaseTook()); + assertEquals(DFS_QUERY_THEN_FETCH, searchRequest.searchType()); + assertEquals(true, searchRequest.requestCache()); + assertEquals(TimeValue.timeValueMinutes(1), searchRequest.scroll().keepAlive()); + assertEquals("routing_value", searchRequest.routing()); + assertEquals("preference_value", searchRequest.preference()); + assertEquals("pipeline_value", searchRequest.pipeline()); + assertTrue(searchRequest.isCcsMinimizeRoundtrips()); + assertEquals(TimeValue.timeValueSeconds(5), searchRequest.getCancelAfterTimeInterval()); + } + + public void testParseSearchRequestWithUnsupportedSearchType() throws IOException { + RestRequest restRequest = new FakeRestRequest(); + SearchRequest searchRequest = createSearchRequest(); + IntConsumer setSize = mock(IntConsumer.class); + restRequest.params().put("search_type", "query_and_fetch"); + + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize) + ); + assertEquals("Unsupported search type [query_and_fetch]", exception.getMessage()); + } + public void testEqualsAndHashcode() throws IOException { checkEqualsAndHashCode(createSearchRequest(), SearchRequest::new, this::mutate); } @@ -268,10 +326,7 @@ private SearchRequest mutate(SearchRequest searchRequest) { ); mutators.add( () -> mutation.searchType( - randomValueOtherThan( - searchRequest.searchType(), - () -> randomFrom(SearchType.DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH) - ) + randomValueOtherThan(searchRequest.searchType(), () -> randomFrom(DFS_QUERY_THEN_FETCH, SearchType.QUERY_THEN_FETCH)) ) ); mutators.add(() -> mutation.source(randomValueOtherThan(searchRequest.source(), this::createSearchSourceBuilder))); From a1bc82b231be551a4e0fc5f97f411151e4f03313 Mon Sep 17 00:00:00 2001 From: Owais Date: Thu, 19 Sep 2024 19:33:54 -0700 Subject: [PATCH 7/8] Guard serialization with version check Signed-off-by: Owais --- .../opensearch/search/builder/SearchSourceBuilder.java | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java index e949b145421c9..73cfca3ec912f 100644 --- a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java @@ -275,7 +275,6 @@ public SearchSourceBuilder(StreamInput in) throws IOException { seqNoAndPrimaryTerm = in.readOptionalBoolean(); extBuilders = in.readNamedWriteableList(SearchExtBuilder.class); profile = in.readBoolean(); - searchPipeline = in.readOptionalString(); searchAfterBuilder = in.readOptionalWriteable(SearchAfterBuilder::new); sliceBuilder = in.readOptionalWriteable(SliceBuilder::new); collapse = in.readOptionalWriteable(CollapseBuilder::new); @@ -300,6 +299,9 @@ public SearchSourceBuilder(StreamInput in) throws IOException { derivedFields = in.readList(DerivedField::new); } } + if (in.getVersion().onOrAfter(Version.V_2_18_0)) { + searchPipeline = in.readOptionalString(); + } } @Override @@ -350,7 +352,6 @@ public void writeTo(StreamOutput out) throws IOException { out.writeOptionalBoolean(seqNoAndPrimaryTerm); out.writeNamedWriteableList(extBuilders); out.writeBoolean(profile); - out.writeOptionalString(searchPipeline); out.writeOptionalWriteable(searchAfterBuilder); out.writeOptionalWriteable(sliceBuilder); out.writeOptionalWriteable(collapse); @@ -381,6 +382,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeList(derivedFields); } } + if (out.getVersion().onOrAfter(Version.V_2_18_0)) { + out.writeOptionalString(searchPipeline); + } } /** From a6a64eeb358615b6125109250770f7f783fc189e Mon Sep 17 00:00:00 2001 From: Owais Date: Fri, 20 Sep 2024 16:48:05 -0700 Subject: [PATCH 8/8] Updated version and added another test for serialization Signed-off-by: Owais --- .../search/builder/SearchSourceBuilder.java | 15 +++++-- .../action/search/SearchRequestTests.java | 39 ------------------- .../builder/SearchSourceBuilderTests.java | 21 ++++++++++ 3 files changed, 32 insertions(+), 43 deletions(-) diff --git a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java index 73cfca3ec912f..dd4e4d073cb1b 100644 --- a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java @@ -299,7 +299,7 @@ public SearchSourceBuilder(StreamInput in) throws IOException { derivedFields = in.readList(DerivedField::new); } } - if (in.getVersion().onOrAfter(Version.V_2_18_0)) { + if (in.getVersion().onOrAfter(Version.V_3_0_0)) { searchPipeline = in.readOptionalString(); } } @@ -382,7 +382,7 @@ public void writeTo(StreamOutput out) throws IOException { out.writeList(derivedFields); } } - if (out.getVersion().onOrAfter(Version.V_2_18_0)) { + if (out.getVersion().onOrAfter(Version.V_3_0_0)) { out.writeOptionalString(searchPipeline); } } @@ -1239,6 +1239,7 @@ private SearchSourceBuilder shallowCopy( rewrittenBuilder.pointInTimeBuilder = pointInTimeBuilder; rewrittenBuilder.derivedFieldsObject = derivedFieldsObject; rewrittenBuilder.derivedFields = derivedFields; + rewrittenBuilder.searchPipeline = searchPipeline; return rewrittenBuilder; } @@ -1637,6 +1638,10 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t } + if (searchPipeline != null) { + builder.field(SEARCH_PIPELINE.getPreferredName(), searchPipeline); + } + return builder; } @@ -1914,7 +1919,8 @@ public int hashCode() { trackTotalHitsUpTo, pointInTimeBuilder, derivedFieldsObject, - derivedFields + derivedFields, + searchPipeline ); } @@ -1959,7 +1965,8 @@ public boolean equals(Object obj) { && Objects.equals(trackTotalHitsUpTo, other.trackTotalHitsUpTo) && Objects.equals(pointInTimeBuilder, other.pointInTimeBuilder) && Objects.equals(derivedFieldsObject, other.derivedFieldsObject) - && Objects.equals(derivedFields, other.derivedFields); + && Objects.equals(derivedFields, other.derivedFields) + && Objects.equals(searchPipeline, other.searchPipeline); } @Override diff --git a/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java b/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java index a916ea522ca30..acda1445bacbb 100644 --- a/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java +++ b/server/src/test/java/org/opensearch/action/search/SearchRequestTests.java @@ -56,7 +56,6 @@ import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.List; import java.util.function.IntConsumer; @@ -249,44 +248,6 @@ public void testCopyConstructor() throws IOException { assertNotSame(deserializedRequest, searchRequest); } - public void testParseSearchRequest() throws IOException { - RestRequest restRequest = new FakeRestRequest(); - SearchRequest searchRequest = createSearchRequest(); - IntConsumer setSize = mock(IntConsumer.class); - - restRequest.params().put("index", "index1,index2"); - restRequest.params().put("batched_reduce_size", "512"); - restRequest.params().put("pre_filter_shard_size", "128"); - restRequest.params().put("max_concurrent_shard_requests", "10"); - restRequest.params().put("allow_partial_search_results", "true"); - restRequest.params().put("phase_took", "false"); - restRequest.params().put("search_type", "dfs_query_then_fetch"); - restRequest.params().put("request_cache", "true"); - restRequest.params().put("scroll", "1m"); - restRequest.params().put("routing", "routing_value"); - restRequest.params().put("preference", "preference_value"); - restRequest.params().put("search_pipeline", "pipeline_value"); - restRequest.params().put("ccs_minimize_roundtrips", "true"); - restRequest.params().put("cancel_after_time_interval", "5s"); - - RestSearchAction.parseSearchRequest(searchRequest, restRequest, null, namedWriteableRegistry, setSize); - - assertEquals(Arrays.asList("index1", "index2"), Arrays.asList(searchRequest.indices())); - assertEquals(512, searchRequest.getBatchedReduceSize()); - assertEquals(Integer.valueOf(128), searchRequest.getPreFilterShardSize()); - assertEquals(10, searchRequest.getMaxConcurrentShardRequests()); - assertTrue(searchRequest.allowPartialSearchResults()); - assertFalse(searchRequest.isPhaseTook()); - assertEquals(DFS_QUERY_THEN_FETCH, searchRequest.searchType()); - assertEquals(true, searchRequest.requestCache()); - assertEquals(TimeValue.timeValueMinutes(1), searchRequest.scroll().keepAlive()); - assertEquals("routing_value", searchRequest.routing()); - assertEquals("preference_value", searchRequest.preference()); - assertEquals("pipeline_value", searchRequest.pipeline()); - assertTrue(searchRequest.isCcsMinimizeRoundtrips()); - assertEquals(TimeValue.timeValueSeconds(5), searchRequest.getCancelAfterTimeInterval()); - } - public void testParseSearchRequestWithUnsupportedSearchType() throws IOException { RestRequest restRequest = new FakeRestRequest(); SearchRequest searchRequest = createSearchRequest(); diff --git a/server/src/test/java/org/opensearch/search/builder/SearchSourceBuilderTests.java b/server/src/test/java/org/opensearch/search/builder/SearchSourceBuilderTests.java index 9697f4cee0d58..da8ccc9e121e0 100644 --- a/server/src/test/java/org/opensearch/search/builder/SearchSourceBuilderTests.java +++ b/server/src/test/java/org/opensearch/search/builder/SearchSourceBuilderTests.java @@ -421,6 +421,27 @@ public void testDerivedFieldsParsingAndSerializationObjectType() throws IOExcept } } + public void testSearchPipelineParsingAndSerialization() throws IOException { + String restContent = "{ \"query\": { \"match_all\": {} }, \"from\": 0, \"size\": 10, \"search_pipeline\": \"my_pipeline\" }"; + String expectedContent = "{\"from\":0,\"size\":10,\"query\":{\"match_all\":{\"boost\":1.0}},\"search_pipeline\":\"my_pipeline\"}"; + + try (XContentParser parser = createParser(JsonXContent.jsonXContent, restContent)) { + SearchSourceBuilder searchSourceBuilder = SearchSourceBuilder.fromXContent(parser); + searchSourceBuilder = rewrite(searchSourceBuilder); + + try (BytesStreamOutput output = new BytesStreamOutput()) { + searchSourceBuilder.writeTo(output); + try (StreamInput in = new NamedWriteableAwareStreamInput(output.bytes().streamInput(), namedWriteableRegistry)) { + SearchSourceBuilder deserializedBuilder = new SearchSourceBuilder(in); + String actualContent = deserializedBuilder.toString(); + assertEquals(expectedContent, actualContent); + assertEquals(searchSourceBuilder.hashCode(), deserializedBuilder.hashCode()); + assertNotSame(searchSourceBuilder, deserializedBuilder); + } + } + } + } + public void testAggsParsing() throws IOException { { String restContent = "{\n"