diff --git a/server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java b/server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java index 5b887b48f696e..f16d7d1e7d6a3 100644 --- a/server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java +++ b/server/src/main/java/org/opensearch/action/search/MultiSearchRequest.java @@ -310,6 +310,10 @@ public static void readMultiLineFormat( ) { consumer.accept(searchRequest, parser); } + + if (searchRequest.source() != null && searchRequest.source().pipeline() != null) { + searchRequest.pipeline(searchRequest.source().pipeline()); + } // move pointers from = nextMarker + 1; } diff --git a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java index e4ec85db4d331..012d6695c042b 100644 --- a/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java +++ b/server/src/main/java/org/opensearch/search/pipeline/SearchPipelineService.java @@ -395,9 +395,6 @@ public PipelinedRequest resolvePipeline(SearchRequest searchRequest, IndexNameEx if (searchRequest.pipeline() != null) { // Named pipeline specified for the request pipelineId = searchRequest.pipeline(); - } else if (searchRequest.source() != null && searchRequest.source().pipeline() != null) { - // Inline pipeline specified for the request - pipelineId = searchRequest.source().pipeline(); } else if (state != null && searchRequest.indices() != null && searchRequest.indices().length != 0) { try { // Check for index default pipeline diff --git a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java index 2e91a0ad25188..c0639607de8d1 100644 --- a/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java +++ b/server/src/test/java/org/opensearch/search/pipeline/SearchPipelineServiceTests.java @@ -18,6 +18,7 @@ import org.opensearch.Version; import org.opensearch.action.search.DeleteSearchPipelineRequest; import org.opensearch.action.search.MockSearchPhaseContext; +import org.opensearch.action.search.MultiSearchRequest; import org.opensearch.action.search.PutSearchPipelineRequest; import org.opensearch.action.search.QueryPhaseResultConsumer; import org.opensearch.action.search.SearchPhaseContext; @@ -75,6 +76,8 @@ import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import static org.opensearch.search.RandomSearchRequestGenerator.randomSearchRequest; +import static org.hamcrest.Matchers.containsString; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -972,7 +975,8 @@ public void testInlinePipeline() throws Exception { /** * Tests a pipeline name defined in the search request source. */ - public void testInlineDefinedPipeline() throws Exception { + public void testInlineDefinedPipelineForMultiSearch() throws Exception { + int numberOfSearchRequests = randomIntBetween(0, 32); SearchPipelineService searchPipelineService = createWithProcessors(); SearchPipelineMetadata metadata = new SearchPipelineMetadata( @@ -988,7 +992,6 @@ public void testInlineDefinedPipeline() throws Exception { ), MediaTypeRegistry.JSON ) - ) ); @@ -999,34 +1002,49 @@ public void testInlineDefinedPipeline() throws Exception { .build(); searchPipelineService.applyClusterState(new ClusterChangedEvent("", clusterState, previousState)); - SearchSourceBuilder sourceBuilder = SearchSourceBuilder.searchSource().size(100).pipeline("p1"); - SearchRequest searchRequest = new SearchRequest().source(sourceBuilder); - - // Verify pipeline - PipelinedRequest pipelinedRequest = syncTransformRequest( - searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver) - ); - Pipeline pipeline = pipelinedRequest.getPipeline(); - assertEquals("p1", pipeline.getId()); - assertEquals(1, pipeline.getSearchRequestProcessors().size()); - assertEquals(1, pipeline.getSearchResponseProcessors().size()); - - // Verify that pipeline transforms request - assertEquals(200, pipelinedRequest.source().size()); + MultiSearchRequest multiSearchRequest = new MultiSearchRequest(); + for (int i = 0; i < numberOfSearchRequests; i++) { + SearchRequest searchRequest = randomSearchRequest(() -> { + // No need to return a very complex SearchSourceBuilder here, that is tested + // elsewhere + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); + searchSourceBuilder.from(randomInt(10)); + searchSourceBuilder.size(randomIntBetween(20, 100)); + searchSourceBuilder.pipeline("p1"); + return searchSourceBuilder; + }); + multiSearchRequest.add(searchRequest); + + // Verify pipeline + PipelinedRequest pipelinedRequest = syncTransformRequest( + searchPipelineService.resolvePipeline(searchRequest, indexNameExpressionResolver) + ); + Pipeline pipeline = pipelinedRequest.getPipeline(); + assertEquals("p1", pipeline.getId()); + assertEquals(1, pipeline.getSearchRequestProcessors().size()); + assertEquals(1, pipeline.getSearchResponseProcessors().size()); + + // Verify that pipeline transforms request + assertEquals(200, pipelinedRequest.source().size()); + + int size = 10; + SearchHit[] hits = new SearchHit[size]; + for (int j = 0; j < size; j++) { + hits[j] = new SearchHit(j, "doc" + j, Collections.emptyMap(), Collections.emptyMap()); + hits[j].score(j); + } + SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size); + SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); - int size = 10; - SearchHit[] hits = new SearchHit[size]; - for (int i = 0; i < size; i++) { - hits[i] = new SearchHit(i, "doc" + i, Collections.emptyMap(), Collections.emptyMap()); - hits[i].score(i); + SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse); + for (int j = 0; j < size; j++) { + assertEquals(2.0, transformedResponse.getHits().getHits()[j].getScore(), 0.0001); + } } - SearchHits searchHits = new SearchHits(hits, new TotalHits(size * 2, TotalHits.Relation.EQUAL_TO), size); - SearchResponseSections searchResponseSections = new SearchResponseSections(searchHits, null, null, false, false, null, 0); - SearchResponse searchResponse = new SearchResponse(searchResponseSections, null, 1, 1, 0, 10, null, null); - SearchResponse transformedResponse = syncTransformResponse(pipelinedRequest, searchResponse); - for (int i = 0; i < size; i++) { - assertEquals(2.0, transformedResponse.getHits().getHits()[i].getScore(), 0.0001); + for (SearchRequest subReq : multiSearchRequest.requests()) { + assertThat(multiSearchRequest.toString(), containsString(subReq.toString())); } }