diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java index e1fe6eac7e9c1..514e8d10eeca1 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java @@ -17,20 +17,16 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.Scroll; -import org.elasticsearch.search.SearchService; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.SearchContext; -import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; -import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.tasks.TaskId; @@ -324,124 +320,15 @@ public void writeTo(StreamOutput out) throws IOException { public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; boolean scroll = scroll() != null; + + if (source != null) { + validationException = source.validate(validationException, scroll); + } if (scroll) { - if (source != null) { - if (source.trackTotalHitsUpTo() != null && source.trackTotalHitsUpTo() != SearchContext.TRACK_TOTAL_HITS_ACCURATE) { - validationException = addValidationError( - "disabling [track_total_hits] is not allowed in a scroll context", - validationException - ); - } - if (source.from() > 0) { - validationException = addValidationError("using [from] is not allowed in a scroll context", validationException); - } - if (source.size() == 0) { - validationException = addValidationError("[size] cannot be [0] in a scroll context", validationException); - } - if (source.rescores() != null && source.rescores().isEmpty() == false) { - validationException = addValidationError("using [rescore] is not allowed in a scroll context", validationException); - } - if (CollectionUtils.isEmpty(source.searchAfter()) == false) { - validationException = addValidationError("[search_after] cannot be used in a scroll context", validationException); - } - if (source.collapse() != null) { - validationException = addValidationError("cannot use `collapse` in a scroll context", validationException); - } - } if (requestCache != null && requestCache) { validationException = addValidationError("[request_cache] cannot be used in a scroll context", validationException); } } - if (source != null) { - if (source.slice() != null) { - if (source.pointInTimeBuilder() == null && (scroll == false)) { - validationException = addValidationError( - "[slice] can only be used with [scroll] or [point-in-time] requests", - validationException - ); - } - } - if (source.from() > 0 && CollectionUtils.isEmpty(source.searchAfter()) == false) { - validationException = addValidationError( - "[from] parameter must be set to 0 when [search_after] is used", - validationException - ); - } - if (source.storedFields() != null) { - if (source.storedFields().fetchFields() == false) { - if (source.fetchSource() != null && source.fetchSource().fetchSource()) { - validationException = addValidationError( - "[stored_fields] cannot be disabled if [_source] is requested", - validationException - ); - } - if (source.fetchFields() != null) { - validationException = addValidationError( - "[stored_fields] cannot be disabled when using the [fields] option", - validationException - ); - } - - } - } - if (source.subSearches().size() >= 2 && source.rankBuilder() == null) { - validationException = addValidationError("[sub_searches] requires [rank]", validationException); - } - if (source.aggregations() != null) { - validationException = source.aggregations().validate(validationException); - } - if (source.rankBuilder() != null) { - int size = source.size() == -1 ? SearchService.DEFAULT_SIZE : source.size(); - if (size == 0) { - validationException = addValidationError("[rank] requires [size] greater than [0]", validationException); - } - if (size > source.rankBuilder().rankWindowSize()) { - validationException = addValidationError( - "[rank] requires [rank_window_size: " - + source.rankBuilder().rankWindowSize() - + "]" - + " be greater than or equal to [size: " - + size - + "]", - validationException - ); - } - int queryCount = source.subSearches().size() + source.knnSearch().size(); - if (source.rankBuilder().isCompoundBuilder() && queryCount < 2) { - validationException = addValidationError( - "[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches", - validationException - ); - } - if (scroll) { - validationException = addValidationError("[rank] cannot be used in a scroll context", validationException); - } - if (source.rescores() != null && source.rescores().isEmpty() == false) { - validationException = addValidationError("[rank] cannot be used with [rescore]", validationException); - } - if (source.sorts() != null && source.sorts().isEmpty() == false) { - validationException = addValidationError("[rank] cannot be used with [sort]", validationException); - } - if (source.collapse() != null) { - validationException = addValidationError("[rank] cannot be used with [collapse]", validationException); - } - if (source.suggest() != null && source.suggest().getSuggestions().isEmpty() == false) { - validationException = addValidationError("[rank] cannot be used with [suggest]", validationException); - } - if (source.highlighter() != null) { - validationException = addValidationError("[rank] cannot be used with [highlighter]", validationException); - } - if (source.pointInTimeBuilder() != null) { - validationException = addValidationError("[rank] cannot be used with [point in time]", validationException); - } - } - if (source.rescores() != null) { - for (@SuppressWarnings("rawtypes") - RescorerBuilder rescoreBuilder : source.rescores()) { - validationException = rescoreBuilder.validate(this, validationException); - } - } - } if (pointInTimeBuilder() != null) { if (scroll) { validationException = addValidationError("using [point in time] is not allowed in a scroll context", validationException); @@ -461,16 +348,6 @@ public ActionRequestValidationException validate() { if (preference() != null) { validationException = addValidationError("[preference] cannot be used with point in time", validationException); } - } else if (source != null && source.sorts() != null) { - for (SortBuilder sortBuilder : source.sorts()) { - if (sortBuilder instanceof FieldSortBuilder - && ShardDocSortField.NAME.equals(((FieldSortBuilder) sortBuilder).getFieldName())) { - validationException = addValidationError( - "[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]", - validationException - ); - } - } } if (minCompatibleShardNode() != null) { if (isCcsMinimizeRoundtrips()) { diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index cb2c53a97fbc3..a80378db99348 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -10,12 +10,16 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.logging.DeprecationLogger; +import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.Nullable; @@ -28,6 +32,7 @@ import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.script.Script; import org.elasticsearch.search.SearchExtBuilder; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; @@ -43,7 +48,9 @@ import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.searchafter.SearchAfterBuilder; import org.elasticsearch.search.slice.SliceBuilder; +import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; +import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortOrder; @@ -71,6 +78,7 @@ import java.util.stream.Collectors; import static java.util.Collections.emptyMap; +import static org.elasticsearch.action.ValidateActions.addValidationError; import static org.elasticsearch.index.query.AbstractQueryBuilder.parseTopLevelQuery; import static org.elasticsearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER; import static org.elasticsearch.search.internal.SearchContext.TRACK_TOTAL_HITS_ACCURATE; @@ -78,10 +86,9 @@ /** * A search source builder allowing to easily build search source. Simple - * construction using - * {@link org.elasticsearch.search.builder.SearchSourceBuilder#searchSource()}. + * construction using {@link SearchSourceBuilder#searchSource()}. * - * @see org.elasticsearch.action.search.SearchRequest#source(SearchSourceBuilder) + * @see SearchRequest#source(SearchSourceBuilder) */ public final class SearchSourceBuilder implements Writeable, ToXContentObject, Rewriteable { private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(SearchSourceBuilder.class); @@ -141,6 +148,8 @@ public static HighlightBuilder highlight() { return new HighlightBuilder(); } + private transient RetrieverBuilder retrieverBuilder; + private List subSearchSourceBuilders = new ArrayList<>(); private QueryBuilder postQueryBuilder; @@ -283,6 +292,9 @@ public SearchSourceBuilder(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { + if (retrieverBuilder != null) { + throw new IllegalStateException("SearchSourceBuilder should be rewritten first"); + } out.writeOptionalWriteable(aggregations); out.writeOptionalBoolean(explain); out.writeOptionalWriteable(fetchSourceContext); @@ -367,6 +379,18 @@ public void writeTo(StreamOutput out) throws IOException { } } + /** + * Sets the retriever for this request. + */ + public SearchSourceBuilder retriever(RetrieverBuilder retrieverBuilder) { + this.retrieverBuilder = retrieverBuilder; + return this; + } + + public RetrieverBuilder retriever() { + return retrieverBuilder; + } + /** * Sets the query for this request. */ @@ -1134,6 +1158,21 @@ public SearchSourceBuilder rewrite(QueryRewriteContext context) throws IOExcepti highlightBuilder ) )); + if (retrieverBuilder != null) { + var newRetriever = retrieverBuilder.rewrite(context); + if (newRetriever != retrieverBuilder) { + var rewritten = shallowCopy(); + rewritten.retrieverBuilder = newRetriever; + return rewritten; + } else { + // retriever is transient, the rewritten version is extracted in this source. + var retriever = retrieverBuilder; + retrieverBuilder = null; + retriever.extractToSearchSourceBuilder(this, false); + validate(); + } + } + List subSearchSourceBuilders = Rewriteable.rewrite(this.subSearchSourceBuilders, context); QueryBuilder postQueryBuilder = null; if (this.postQueryBuilder != null) { @@ -1293,7 +1332,6 @@ private SearchSourceBuilder parseXContent( } List knnBuilders = new ArrayList<>(); - RetrieverBuilder retrieverBuilder = null; SearchUsage searchUsage = new SearchUsage(); while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -1627,39 +1665,6 @@ private SearchSourceBuilder parseXContent( } knnSearch = knnBuilders.stream().map(knnBuilder -> knnBuilder.build(size())).collect(Collectors.toList()); - - if (retrieverBuilder != null) { - List specified = new ArrayList<>(); - if (subSearchSourceBuilders.isEmpty() == false) { - specified.add(QUERY_FIELD.getPreferredName()); - } - if (knnSearch.isEmpty() == false) { - specified.add(KNN_FIELD.getPreferredName()); - } - if (searchAfterBuilder != null) { - specified.add(SEARCH_AFTER.getPreferredName()); - } - if (terminateAfter != DEFAULT_TERMINATE_AFTER) { - specified.add(TERMINATE_AFTER_FIELD.getPreferredName()); - } - if (sorts != null) { - specified.add(SORT_FIELD.getPreferredName()); - } - if (rescoreBuilders != null) { - specified.add(RESCORE_FIELD.getPreferredName()); - } - if (minScore != null) { - specified.add(MIN_SCORE_FIELD.getPreferredName()); - } - if (rankBuilder != null) { - specified.add(RANK_FIELD.getPreferredName()); - } - if (specified.isEmpty() == false) { - throw new IllegalArgumentException("cannot specify [" + RETRIEVER.getPreferredName() + "] and " + specified); - } - retrieverBuilder.extractToSearchSourceBuilder(this, false); - } - searchUsageConsumer.accept(searchUsage); return this; } @@ -1689,6 +1694,10 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t builder.field(TERMINATE_AFTER_FIELD.getPreferredName(), terminateAfter); } + if (retrieverBuilder != null) { + builder.field(RETRIEVER.getPreferredName(), retrieverBuilder); + } + if (subSearchSourceBuilders.isEmpty() == false) { if (subSearchSourceBuilders.size() == 1) { builder.field(QUERY_FIELD.getPreferredName(), subSearchSourceBuilders.get(0).getQueryBuilder()); @@ -2183,4 +2192,169 @@ public boolean supportsParallelCollection(ToLongFunction fieldCardinalit return collapse == null && (aggregations == null || aggregations.supportsParallelCollection(fieldCardinality)); } + + private void validate() throws ValidationException { + var exceptions = validate(null, false); + if (exceptions != null) { + throw exceptions; + } + } + + public ActionRequestValidationException validate(ActionRequestValidationException validationException, boolean isScroll) { + if (retriever() != null) { + List specified = new ArrayList<>(); + if (subSearches().isEmpty() == false) { + specified.add(QUERY_FIELD.getPreferredName()); + } + if (knnSearch().isEmpty() == false) { + specified.add(KNN_FIELD.getPreferredName()); + } + if (searchAfter() != null) { + specified.add(SEARCH_AFTER.getPreferredName()); + } + if (terminateAfter() != DEFAULT_TERMINATE_AFTER) { + specified.add(TERMINATE_AFTER_FIELD.getPreferredName()); + } + if (sorts() != null) { + specified.add(SORT_FIELD.getPreferredName()); + } + if (rescores() != null) { + specified.add(RESCORE_FIELD.getPreferredName()); + } + if (minScore() != null) { + specified.add(MIN_SCORE_FIELD.getPreferredName()); + } + if (rankBuilder() != null) { + specified.add(RANK_FIELD.getPreferredName()); + } + if (specified.isEmpty() == false) { + validationException = addValidationError( + "cannot specify [" + RETRIEVER.getPreferredName() + "] and " + specified, + validationException + ); + } + } + if (isScroll) { + if (trackTotalHitsUpTo() != null && trackTotalHitsUpTo() != SearchContext.TRACK_TOTAL_HITS_ACCURATE) { + validationException = addValidationError( + "disabling [track_total_hits] is not allowed in a scroll context", + validationException + ); + } + if (from() > 0) { + validationException = addValidationError("using [from] is not allowed in a scroll context", validationException); + } + if (size() == 0) { + validationException = addValidationError("[size] cannot be [0] in a scroll context", validationException); + } + if (rescores() != null && rescores().isEmpty() == false) { + validationException = addValidationError("using [rescore] is not allowed in a scroll context", validationException); + } + if (CollectionUtils.isEmpty(searchAfter()) == false) { + validationException = addValidationError("[search_after] cannot be used in a scroll context", validationException); + } + if (collapse() != null) { + validationException = addValidationError("cannot use `collapse` in a scroll context", validationException); + } + } + if (slice() != null) { + if (pointInTimeBuilder() == null && (isScroll == false)) { + validationException = addValidationError( + "[slice] can only be used with [scroll] or [point-in-time] requests", + validationException + ); + } + } + if (from() > 0 && CollectionUtils.isEmpty(searchAfter()) == false) { + validationException = addValidationError("[from] parameter must be set to 0 when [search_after] is used", validationException); + } + if (storedFields() != null) { + if (storedFields().fetchFields() == false) { + if (fetchSource() != null && fetchSource().fetchSource()) { + validationException = addValidationError( + "[stored_fields] cannot be disabled if [_source] is requested", + validationException + ); + } + if (fetchFields() != null) { + validationException = addValidationError( + "[stored_fields] cannot be disabled when using the [fields] option", + validationException + ); + } + } + } + if (subSearches().size() >= 2 && rankBuilder() == null) { + validationException = addValidationError("[sub_searches] requires [rank]", validationException); + } + if (aggregations() != null) { + validationException = aggregations().validate(validationException); + } + + if (rankBuilder() != null) { + int s = size() == -1 ? SearchService.DEFAULT_SIZE : size(); + if (s == 0) { + validationException = addValidationError("[rank] requires [size] greater than [0]", validationException); + } + if (s > rankBuilder().rankWindowSize()) { + validationException = addValidationError( + "[rank] requires [rank_window_size: " + + rankBuilder().rankWindowSize() + + "]" + + " be greater than or equal to [size: " + + s + + "]", + validationException + ); + } + int queryCount = subSearches().size() + knnSearch().size(); + if (rankBuilder().isCompoundBuilder() && queryCount < 2) { + validationException = addValidationError( + "[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches", + validationException + ); + } + if (isScroll) { + validationException = addValidationError("[rank] cannot be used in a scroll context", validationException); + } + if (rescores() != null && rescores().isEmpty() == false) { + validationException = addValidationError("[rank] cannot be used with [rescore]", validationException); + } + if (sorts() != null && sorts().isEmpty() == false) { + validationException = addValidationError("[rank] cannot be used with [sort]", validationException); + } + if (collapse() != null) { + validationException = addValidationError("[rank] cannot be used with [collapse]", validationException); + } + if (suggest() != null && suggest().getSuggestions().isEmpty() == false) { + validationException = addValidationError("[rank] cannot be used with [suggest]", validationException); + } + if (highlighter() != null) { + validationException = addValidationError("[rank] cannot be used with [highlighter]", validationException); + } + if (pointInTimeBuilder() != null) { + validationException = addValidationError("[rank] cannot be used with [point in time]", validationException); + } + } + + if (rescores() != null) { + for (@SuppressWarnings("rawtypes") + var rescorer : rescores()) { + validationException = rescorer.validate(this, validationException); + } + } + + if (pointInTimeBuilder() == null && sorts() != null) { + for (var sortBuilder : sorts()) { + if (sortBuilder instanceof FieldSortBuilder fieldSortBuilder + && ShardDocSortField.NAME.equals(fieldSortBuilder.getFieldName())) { + validationException = addValidationError( + "[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]", + validationException + ); + } + } + } + return validationException; + } } diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java index 5cd5a888581c8..946fd46fe6aec 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java @@ -9,7 +9,6 @@ package org.elasticsearch.search.rescore; import org.elasticsearch.action.ActionRequestValidationException; -import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -17,6 +16,7 @@ import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -120,7 +120,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public ActionRequestValidationException validate(SearchRequest searchRequest, ActionRequestValidationException validationException) { + public ActionRequestValidationException validate(SearchSourceBuilder source, ActionRequestValidationException validationException) { return validationException; } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index 6e3d2a58dbd5d..3a9979030683a 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -14,6 +14,8 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xcontent.AbstractObjectParser; import org.elasticsearch.xcontent.FilterXContentParserWrapper; @@ -33,16 +35,17 @@ /** * A retriever represents an API element that returns an ordered list of top * documents. These can be obtained from a query, from another retriever, etc. - * Internally, a {@link RetrieverBuilder} is just a wrapper for other search - * elements that are extracted into a {@link SearchSourceBuilder}. The advantage - * retrievers have is in the API they appear as a tree-like structure enabling + * Internally, a {@link RetrieverBuilder} is first rewritten into its simplest + * form and then its elements are extracted into a {@link SearchSourceBuilder}. + * + * The advantage retrievers have is in the API they appear as a tree-like structure enabling * easier reasoning about what a search does. * * This is the base class for all other retrievers. This class does not support * serialization and is expected to be fully extracted to a {@link SearchSourceBuilder} * prior to any transport calls. */ -public abstract class RetrieverBuilder implements ToXContent { +public abstract class RetrieverBuilder implements Rewriteable, ToXContent { public static final NodeFeature RETRIEVERS_SUPPORTED = new NodeFeature("retrievers_supported"); @@ -181,6 +184,13 @@ protected static RetrieverBuilder parseInnerRetrieverBuilder(XContentParser pars protected String retrieverName; + /** + * Determines if this retriever contains sub-retrievers that need to be executed prior to search. + */ + public boolean isCompound() { + return false; + } + /** * Gets the filters for this retriever. */ @@ -188,8 +198,13 @@ public List getPreFilterQueryBuilders() { return preFilterQueryBuilders; } + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + return this; + } + /** - * This method is called at the end of parsing on behalf of a {@link SearchSourceBuilder}. + * This method is called at the end of rewriting on behalf of a {@link SearchSourceBuilder}. * Elements from retrievers are expected to be "extracted" into the {@link SearchSourceBuilder}. */ public abstract void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 601c55293418d..348a65d0c4960 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -298,6 +298,10 @@ public String getField() { return field; } + public List getFilterQueries() { + return filterQueries; + } + public KnnSearchBuilder addFilterQuery(QueryBuilder filterQuery) { Objects.requireNonNull(filterQuery); this.filterQueries.add(filterQuery); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index cbbbe7d86f4e2..de35d765a1551 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -10,9 +10,14 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.RandomQueryBuilder; +import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -23,6 +28,10 @@ import java.util.List; import static org.elasticsearch.search.vectors.KnnSearchBuilderTests.randomVector; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; public class KnnRetrieverBuilderParsingTests extends AbstractXContentTestCase { @@ -34,7 +43,7 @@ public class KnnRetrieverBuilderParsingTests extends AbstractXContentTestCase ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [query]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [query]")); } try ( @@ -44,26 +49,35 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [knn]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [knn]")); } try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"search_after\": [1], \"retriever\":{\"standard\":{}}}")) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [search_after]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [search_after]")); + } try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"terminate_after\": 1, \"retriever\":{\"standard\":{}}}")) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [terminate_after]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [terminate_after]")); } try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"sort\": [\"field\"], \"retriever\":{\"standard\":{}}}")) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [sort]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [sort]")); } try ( @@ -73,14 +87,18 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [rescore]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [rescore]")); } try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"min_score\": 2, \"retriever\":{\"standard\":{}}}")) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [min_score]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [min_score]")); } try ( @@ -90,8 +108,10 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [query, terminate_after, min_score]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [query, terminate_after, min_score]")); } } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java index bec534d89cc03..cd0d8f8d50c1e 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java @@ -11,8 +11,15 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.RandomQueryBuilder; +import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilderTests; import org.elasticsearch.search.searchafter.SearchAfterBuilderTests; import org.elasticsearch.search.sort.SortBuilderTests; @@ -27,6 +34,11 @@ import java.util.List; import java.util.function.BiFunction; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; + public class StandardRetrieverBuilderParsingTests extends AbstractXContentTestCase { /** @@ -59,7 +71,7 @@ public static StandardRetrieverBuilder createRandomStandardRetrieverBuilder( } if (randomBoolean()) { - standardRetrieverBuilder.sortBuilders = SortBuilderTests.randomSortBuilderList(); + standardRetrieverBuilder.sortBuilders = SortBuilderTests.randomSortBuilderList(false); } if (randomBoolean()) { @@ -109,4 +121,52 @@ protected String[] getShuffleFieldsExceptions() { protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, List.of()).getNamedXContents()); } + + public void testRewrite() throws IOException { + for (int i = 0; i < 10; i++) { + StandardRetrieverBuilder standardRetriever = createTestInstance(); + SearchSourceBuilder source = new SearchSourceBuilder().retriever(standardRetriever); + QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); + source = Rewriteable.rewrite(source, queryRewriteContext); + assertNull(source.retriever()); + assertTrue(source.knnSearch().isEmpty()); + if (standardRetriever.queryBuilder != null) { + assertNotNull(source.query()); + if (standardRetriever.preFilterQueryBuilders.size() > 0) { + if (source.query() instanceof MatchAllQueryBuilder == false + && source.query() instanceof MatchNoneQueryBuilder == false) { + assertThat(source.query(), instanceOf(BoolQueryBuilder.class)); + BoolQueryBuilder bq = (BoolQueryBuilder) source.query(); + assertFalse(bq.must().isEmpty()); + assertThat(bq.must().size(), equalTo(1)); + assertThat(bq.must().get(0), equalTo(standardRetriever.queryBuilder)); + for (int j = 0; j < bq.filter().size(); j++) { + assertEqualQueryOrMatchAllNone(bq.filter().get(j), standardRetriever.preFilterQueryBuilders.get(j)); + } + } + } else { + assertEqualQueryOrMatchAllNone(source.query(), standardRetriever.queryBuilder); + } + } else if (standardRetriever.preFilterQueryBuilders.size() > 0) { + if (source.query() instanceof MatchAllQueryBuilder == false && source.query() instanceof MatchNoneQueryBuilder == false) { + assertNotNull(source.query()); + assertThat(source.query(), instanceOf(BoolQueryBuilder.class)); + BoolQueryBuilder bq = (BoolQueryBuilder) source.query(); + assertTrue(bq.must().isEmpty()); + for (int j = 0; j < bq.filter().size(); j++) { + assertEqualQueryOrMatchAllNone(bq.filter().get(j), standardRetriever.preFilterQueryBuilders.get(j)); + } + } + } else { + assertNull(source.query()); + } + if (standardRetriever.sortBuilders != null) { + assertThat(source.sorts().size(), equalTo(standardRetriever.sortBuilders.size())); + } + } + } + + private static void assertEqualQueryOrMatchAllNone(QueryBuilder actual, QueryBuilder expected) { + assertThat(actual, anyOf(instanceOf(MatchAllQueryBuilder.class), instanceOf(MatchNoneQueryBuilder.class), equalTo(expected))); + } } diff --git a/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java b/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java index eee98297c7a13..84f87b3f01881 100644 --- a/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java @@ -119,7 +119,7 @@ public void testSingleFieldSort() throws IOException { public void testRandomSortBuilders() throws IOException { for (int runs = 0; runs < NUMBER_OF_RUNS; runs++) { Set expectedWarningHeaders = new HashSet<>(); - List> testBuilders = randomSortBuilderList(); + List> testBuilders = randomSortBuilderList(randomBoolean()); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); xContentBuilder.startObject(); if (testBuilders.size() > 1) { @@ -171,7 +171,7 @@ public void testRandomSortBuilders() throws IOException { } } - public static List> randomSortBuilderList() { + public static List> randomSortBuilderList(boolean hasPIT) { int size = randomIntBetween(1, 5); List> list = new ArrayList<>(size); for (int i = 0; i < size; i++) { @@ -181,7 +181,7 @@ public static List> randomSortBuilderList() { case 2 -> SortBuilders.fieldSort(FieldSortBuilder.DOC_FIELD_NAME); case 3 -> GeoDistanceSortBuilderTests.randomGeoDistanceSortBuilder(); case 4 -> ScriptSortBuilderTests.randomScriptSortBuilder(); - case 5 -> SortBuilders.pitTiebreaker(); + case 5 -> hasPIT ? SortBuilders.pitTiebreaker() : ScriptSortBuilderTests.randomScriptSortBuilder(); default -> throw new IllegalStateException("unexpected randomization in tests"); }); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java index f8f9caf365918..d2179a69ebc24 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java @@ -11,12 +11,12 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; -import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -134,10 +134,10 @@ public RescorerBuilder rewrite(QueryRewriteContex } @Override - public ActionRequestValidationException validate(SearchRequest searchRequest, ActionRequestValidationException validationException) { - validationException = super.validate(searchRequest, validationException); + public ActionRequestValidationException validate(SearchSourceBuilder source, ActionRequestValidationException validationException) { + validationException = super.validate(source, validationException); - int searchRequestPaginationSize = searchRequest.source().from() + searchRequest.source().size(); + int searchRequestPaginationSize = source.from() + source.size(); if (windowSize() < searchRequestPaginationSize) { return addValidationError( @@ -151,7 +151,7 @@ public ActionRequestValidationException validate(SearchRequest searchRequest, Ac } @SuppressWarnings("rawtypes") - List rescorers = searchRequest.source().rescores(); + List rescorers = source.rescores(); assert rescorers != null && rescorers.contains(this); for (int i = rescorers.indexOf(this) + 1; i < rescorers.size(); i++) { diff --git a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java index 229f900ef3d15..f5a9f4e9b0c3e 100644 --- a/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java +++ b/x-pack/plugin/rank-rrf/src/test/java/org/elasticsearch/xpack/rank/rrf/RRFRetrieverBuilderTests.java @@ -48,7 +48,10 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) + ); assertEquals("[search_after] cannot be used in children of compound retrievers", iae.getMessage()); } @@ -60,7 +63,10 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) + ); assertEquals("[terminate_after] cannot be used in children of compound retrievers", iae.getMessage()); } @@ -71,7 +77,10 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) + ); assertEquals("[sort] cannot be used in children of compound retrievers", iae.getMessage()); } @@ -82,7 +91,10 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) + ); assertEquals("[min_score] cannot be used in children of compound retrievers", iae.getMessage()); } @@ -94,7 +106,10 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) + ); assertEquals("[collapse] cannot be used in children of compound retrievers", iae.getMessage()); } @@ -105,7 +120,10 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) + ); assertEquals("[rank] cannot be used in children of compound retrievers", iae.getMessage()); } } @@ -119,7 +137,10 @@ public void testRetrieverBuilderParsingMaxDepth() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); + IllegalArgumentException iae = expectThrows( + IllegalArgumentException.class, + () -> ssb.parseXContent(parser, true, nf -> true).rewrite(null) + ); assertEquals("[1:65] [rrf] failed to parse field [retrievers]", iae.getMessage()); assertEquals( "the nested depth of the [standard] retriever exceeds the maximum nested depth [2] for retrievers",