From e19845b0fb3a3d415e6de2c4b59bec4788e33886 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Tue, 9 Jul 2024 15:03:27 +0100 Subject: [PATCH] Move Retriever Handling to Rewrite Phase This change moves the handling of the retriever to the rewrite phase. It also adds validation of the search source builder after extracting the retriever into the source builder. Relates #110482 --- .../action/search/SearchRequest.java | 131 +-------- .../search/builder/SearchSourceBuilder.java | 248 +++++++++++++++--- .../search/rescore/RescorerBuilder.java | 4 +- .../search/retriever/RetrieverBuilder.java | 25 +- .../search/vectors/KnnSearchBuilder.java | 4 + .../KnnRetrieverBuilderParsingTests.java | 34 ++- .../retriever/RetrieverBuilderErrorTests.java | 52 ++-- .../StandardRetrieverBuilderParsingTests.java | 62 ++++- .../search/sort/SortBuilderTests.java | 6 +- .../ltr/LearningToRankRescorerBuilder.java | 10 +- 10 files changed, 379 insertions(+), 197 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java index e1fe6eac7e9c1..514e8d10eeca1 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java @@ -17,20 +17,16 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.Scroll; -import org.elasticsearch.search.SearchService; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.SearchContext; -import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; -import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.tasks.TaskId; @@ -324,124 +320,15 @@ public void writeTo(StreamOutput out) throws IOException { public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; boolean scroll = scroll() != null; + + if (source != null) { + validationException = source.validate(validationException, scroll); + } if (scroll) { - if (source != null) { - if (source.trackTotalHitsUpTo() != null && source.trackTotalHitsUpTo() != SearchContext.TRACK_TOTAL_HITS_ACCURATE) { - validationException = addValidationError( - "disabling [track_total_hits] is not allowed in a scroll context", - validationException - ); - } - if (source.from() > 0) { - validationException = addValidationError("using [from] is not allowed in a scroll context", validationException); - } - if (source.size() == 0) { - validationException = addValidationError("[size] cannot be [0] in a scroll context", validationException); - } - if (source.rescores() != null && source.rescores().isEmpty() == false) { - validationException = addValidationError("using [rescore] is not allowed in a scroll context", validationException); - } - if (CollectionUtils.isEmpty(source.searchAfter()) == false) { - validationException = addValidationError("[search_after] cannot be used in a scroll context", validationException); - } - if (source.collapse() != null) { - validationException = addValidationError("cannot use `collapse` in a scroll context", validationException); - } - } if (requestCache != null && requestCache) { validationException = addValidationError("[request_cache] cannot be used in a scroll context", validationException); } } - if (source != null) { - if (source.slice() != null) { - if (source.pointInTimeBuilder() == null && (scroll == false)) { - validationException = addValidationError( - "[slice] can only be used with [scroll] or [point-in-time] requests", - validationException - ); - } - } - if (source.from() > 0 && CollectionUtils.isEmpty(source.searchAfter()) == false) { - validationException = addValidationError( - "[from] parameter must be set to 0 when [search_after] is used", - validationException - ); - } - if (source.storedFields() != null) { - if (source.storedFields().fetchFields() == false) { - if (source.fetchSource() != null && source.fetchSource().fetchSource()) { - validationException = addValidationError( - "[stored_fields] cannot be disabled if [_source] is requested", - validationException - ); - } - if (source.fetchFields() != null) { - validationException = addValidationError( - "[stored_fields] cannot be disabled when using the [fields] option", - validationException - ); - } - - } - } - if (source.subSearches().size() >= 2 && source.rankBuilder() == null) { - validationException = addValidationError("[sub_searches] requires [rank]", validationException); - } - if (source.aggregations() != null) { - validationException = source.aggregations().validate(validationException); - } - if (source.rankBuilder() != null) { - int size = source.size() == -1 ? SearchService.DEFAULT_SIZE : source.size(); - if (size == 0) { - validationException = addValidationError("[rank] requires [size] greater than [0]", validationException); - } - if (size > source.rankBuilder().rankWindowSize()) { - validationException = addValidationError( - "[rank] requires [rank_window_size: " - + source.rankBuilder().rankWindowSize() - + "]" - + " be greater than or equal to [size: " - + size - + "]", - validationException - ); - } - int queryCount = source.subSearches().size() + source.knnSearch().size(); - if (source.rankBuilder().isCompoundBuilder() && queryCount < 2) { - validationException = addValidationError( - "[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches", - validationException - ); - } - if (scroll) { - validationException = addValidationError("[rank] cannot be used in a scroll context", validationException); - } - if (source.rescores() != null && source.rescores().isEmpty() == false) { - validationException = addValidationError("[rank] cannot be used with [rescore]", validationException); - } - if (source.sorts() != null && source.sorts().isEmpty() == false) { - validationException = addValidationError("[rank] cannot be used with [sort]", validationException); - } - if (source.collapse() != null) { - validationException = addValidationError("[rank] cannot be used with [collapse]", validationException); - } - if (source.suggest() != null && source.suggest().getSuggestions().isEmpty() == false) { - validationException = addValidationError("[rank] cannot be used with [suggest]", validationException); - } - if (source.highlighter() != null) { - validationException = addValidationError("[rank] cannot be used with [highlighter]", validationException); - } - if (source.pointInTimeBuilder() != null) { - validationException = addValidationError("[rank] cannot be used with [point in time]", validationException); - } - } - if (source.rescores() != null) { - for (@SuppressWarnings("rawtypes") - RescorerBuilder rescoreBuilder : source.rescores()) { - validationException = rescoreBuilder.validate(this, validationException); - } - } - } if (pointInTimeBuilder() != null) { if (scroll) { validationException = addValidationError("using [point in time] is not allowed in a scroll context", validationException); @@ -461,16 +348,6 @@ public ActionRequestValidationException validate() { if (preference() != null) { validationException = addValidationError("[preference] cannot be used with point in time", validationException); } - } else if (source != null && source.sorts() != null) { - for (SortBuilder sortBuilder : source.sorts()) { - if (sortBuilder instanceof FieldSortBuilder - && ShardDocSortField.NAME.equals(((FieldSortBuilder) sortBuilder).getFieldName())) { - validationException = addValidationError( - "[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]", - validationException - ); - } - } } if (minCompatibleShardNode() != null) { if (isCcsMinimizeRoundtrips()) { diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index cb2c53a97fbc3..a80378db99348 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -10,12 +10,16 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.logging.DeprecationLogger; +import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.Nullable; @@ -28,6 +32,7 @@ import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.script.Script; import org.elasticsearch.search.SearchExtBuilder; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; @@ -43,7 +48,9 @@ import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.searchafter.SearchAfterBuilder; import org.elasticsearch.search.slice.SliceBuilder; +import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; +import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortOrder; @@ -71,6 +78,7 @@ import java.util.stream.Collectors; import static java.util.Collections.emptyMap; +import static org.elasticsearch.action.ValidateActions.addValidationError; import static org.elasticsearch.index.query.AbstractQueryBuilder.parseTopLevelQuery; import static org.elasticsearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER; import static org.elasticsearch.search.internal.SearchContext.TRACK_TOTAL_HITS_ACCURATE; @@ -78,10 +86,9 @@ /** * A search source builder allowing to easily build search source. Simple - * construction using - * {@link org.elasticsearch.search.builder.SearchSourceBuilder#searchSource()}. + * construction using {@link SearchSourceBuilder#searchSource()}. * - * @see org.elasticsearch.action.search.SearchRequest#source(SearchSourceBuilder) + * @see SearchRequest#source(SearchSourceBuilder) */ public final class SearchSourceBuilder implements Writeable, ToXContentObject, Rewriteable { private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(SearchSourceBuilder.class); @@ -141,6 +148,8 @@ public static HighlightBuilder highlight() { return new HighlightBuilder(); } + private transient RetrieverBuilder retrieverBuilder; + private List subSearchSourceBuilders = new ArrayList<>(); private QueryBuilder postQueryBuilder; @@ -283,6 +292,9 @@ public SearchSourceBuilder(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { + if (retrieverBuilder != null) { + throw new IllegalStateException("SearchSourceBuilder should be rewritten first"); + } out.writeOptionalWriteable(aggregations); out.writeOptionalBoolean(explain); out.writeOptionalWriteable(fetchSourceContext); @@ -367,6 +379,18 @@ public void writeTo(StreamOutput out) throws IOException { } } + /** + * Sets the retriever for this request. + */ + public SearchSourceBuilder retriever(RetrieverBuilder retrieverBuilder) { + this.retrieverBuilder = retrieverBuilder; + return this; + } + + public RetrieverBuilder retriever() { + return retrieverBuilder; + } + /** * Sets the query for this request. */ @@ -1134,6 +1158,21 @@ public SearchSourceBuilder rewrite(QueryRewriteContext context) throws IOExcepti highlightBuilder ) )); + if (retrieverBuilder != null) { + var newRetriever = retrieverBuilder.rewrite(context); + if (newRetriever != retrieverBuilder) { + var rewritten = shallowCopy(); + rewritten.retrieverBuilder = newRetriever; + return rewritten; + } else { + // retriever is transient, the rewritten version is extracted in this source. + var retriever = retrieverBuilder; + retrieverBuilder = null; + retriever.extractToSearchSourceBuilder(this, false); + validate(); + } + } + List subSearchSourceBuilders = Rewriteable.rewrite(this.subSearchSourceBuilders, context); QueryBuilder postQueryBuilder = null; if (this.postQueryBuilder != null) { @@ -1293,7 +1332,6 @@ private SearchSourceBuilder parseXContent( } List knnBuilders = new ArrayList<>(); - RetrieverBuilder retrieverBuilder = null; SearchUsage searchUsage = new SearchUsage(); while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -1627,39 +1665,6 @@ private SearchSourceBuilder parseXContent( } knnSearch = knnBuilders.stream().map(knnBuilder -> knnBuilder.build(size())).collect(Collectors.toList()); - - if (retrieverBuilder != null) { - List specified = new ArrayList<>(); - if (subSearchSourceBuilders.isEmpty() == false) { - specified.add(QUERY_FIELD.getPreferredName()); - } - if (knnSearch.isEmpty() == false) { - specified.add(KNN_FIELD.getPreferredName()); - } - if (searchAfterBuilder != null) { - specified.add(SEARCH_AFTER.getPreferredName()); - } - if (terminateAfter != DEFAULT_TERMINATE_AFTER) { - specified.add(TERMINATE_AFTER_FIELD.getPreferredName()); - } - if (sorts != null) { - specified.add(SORT_FIELD.getPreferredName()); - } - if (rescoreBuilders != null) { - specified.add(RESCORE_FIELD.getPreferredName()); - } - if (minScore != null) { - specified.add(MIN_SCORE_FIELD.getPreferredName()); - } - if (rankBuilder != null) { - specified.add(RANK_FIELD.getPreferredName()); - } - if (specified.isEmpty() == false) { - throw new IllegalArgumentException("cannot specify [" + RETRIEVER.getPreferredName() + "] and " + specified); - } - retrieverBuilder.extractToSearchSourceBuilder(this, false); - } - searchUsageConsumer.accept(searchUsage); return this; } @@ -1689,6 +1694,10 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t builder.field(TERMINATE_AFTER_FIELD.getPreferredName(), terminateAfter); } + if (retrieverBuilder != null) { + builder.field(RETRIEVER.getPreferredName(), retrieverBuilder); + } + if (subSearchSourceBuilders.isEmpty() == false) { if (subSearchSourceBuilders.size() == 1) { builder.field(QUERY_FIELD.getPreferredName(), subSearchSourceBuilders.get(0).getQueryBuilder()); @@ -2183,4 +2192,169 @@ public boolean supportsParallelCollection(ToLongFunction fieldCardinalit return collapse == null && (aggregations == null || aggregations.supportsParallelCollection(fieldCardinality)); } + + private void validate() throws ValidationException { + var exceptions = validate(null, false); + if (exceptions != null) { + throw exceptions; + } + } + + public ActionRequestValidationException validate(ActionRequestValidationException validationException, boolean isScroll) { + if (retriever() != null) { + List specified = new ArrayList<>(); + if (subSearches().isEmpty() == false) { + specified.add(QUERY_FIELD.getPreferredName()); + } + if (knnSearch().isEmpty() == false) { + specified.add(KNN_FIELD.getPreferredName()); + } + if (searchAfter() != null) { + specified.add(SEARCH_AFTER.getPreferredName()); + } + if (terminateAfter() != DEFAULT_TERMINATE_AFTER) { + specified.add(TERMINATE_AFTER_FIELD.getPreferredName()); + } + if (sorts() != null) { + specified.add(SORT_FIELD.getPreferredName()); + } + if (rescores() != null) { + specified.add(RESCORE_FIELD.getPreferredName()); + } + if (minScore() != null) { + specified.add(MIN_SCORE_FIELD.getPreferredName()); + } + if (rankBuilder() != null) { + specified.add(RANK_FIELD.getPreferredName()); + } + if (specified.isEmpty() == false) { + validationException = addValidationError( + "cannot specify [" + RETRIEVER.getPreferredName() + "] and " + specified, + validationException + ); + } + } + if (isScroll) { + if (trackTotalHitsUpTo() != null && trackTotalHitsUpTo() != SearchContext.TRACK_TOTAL_HITS_ACCURATE) { + validationException = addValidationError( + "disabling [track_total_hits] is not allowed in a scroll context", + validationException + ); + } + if (from() > 0) { + validationException = addValidationError("using [from] is not allowed in a scroll context", validationException); + } + if (size() == 0) { + validationException = addValidationError("[size] cannot be [0] in a scroll context", validationException); + } + if (rescores() != null && rescores().isEmpty() == false) { + validationException = addValidationError("using [rescore] is not allowed in a scroll context", validationException); + } + if (CollectionUtils.isEmpty(searchAfter()) == false) { + validationException = addValidationError("[search_after] cannot be used in a scroll context", validationException); + } + if (collapse() != null) { + validationException = addValidationError("cannot use `collapse` in a scroll context", validationException); + } + } + if (slice() != null) { + if (pointInTimeBuilder() == null && (isScroll == false)) { + validationException = addValidationError( + "[slice] can only be used with [scroll] or [point-in-time] requests", + validationException + ); + } + } + if (from() > 0 && CollectionUtils.isEmpty(searchAfter()) == false) { + validationException = addValidationError("[from] parameter must be set to 0 when [search_after] is used", validationException); + } + if (storedFields() != null) { + if (storedFields().fetchFields() == false) { + if (fetchSource() != null && fetchSource().fetchSource()) { + validationException = addValidationError( + "[stored_fields] cannot be disabled if [_source] is requested", + validationException + ); + } + if (fetchFields() != null) { + validationException = addValidationError( + "[stored_fields] cannot be disabled when using the [fields] option", + validationException + ); + } + } + } + if (subSearches().size() >= 2 && rankBuilder() == null) { + validationException = addValidationError("[sub_searches] requires [rank]", validationException); + } + if (aggregations() != null) { + validationException = aggregations().validate(validationException); + } + + if (rankBuilder() != null) { + int s = size() == -1 ? SearchService.DEFAULT_SIZE : size(); + if (s == 0) { + validationException = addValidationError("[rank] requires [size] greater than [0]", validationException); + } + if (s > rankBuilder().rankWindowSize()) { + validationException = addValidationError( + "[rank] requires [rank_window_size: " + + rankBuilder().rankWindowSize() + + "]" + + " be greater than or equal to [size: " + + s + + "]", + validationException + ); + } + int queryCount = subSearches().size() + knnSearch().size(); + if (rankBuilder().isCompoundBuilder() && queryCount < 2) { + validationException = addValidationError( + "[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches", + validationException + ); + } + if (isScroll) { + validationException = addValidationError("[rank] cannot be used in a scroll context", validationException); + } + if (rescores() != null && rescores().isEmpty() == false) { + validationException = addValidationError("[rank] cannot be used with [rescore]", validationException); + } + if (sorts() != null && sorts().isEmpty() == false) { + validationException = addValidationError("[rank] cannot be used with [sort]", validationException); + } + if (collapse() != null) { + validationException = addValidationError("[rank] cannot be used with [collapse]", validationException); + } + if (suggest() != null && suggest().getSuggestions().isEmpty() == false) { + validationException = addValidationError("[rank] cannot be used with [suggest]", validationException); + } + if (highlighter() != null) { + validationException = addValidationError("[rank] cannot be used with [highlighter]", validationException); + } + if (pointInTimeBuilder() != null) { + validationException = addValidationError("[rank] cannot be used with [point in time]", validationException); + } + } + + if (rescores() != null) { + for (@SuppressWarnings("rawtypes") + var rescorer : rescores()) { + validationException = rescorer.validate(this, validationException); + } + } + + if (pointInTimeBuilder() == null && sorts() != null) { + for (var sortBuilder : sorts()) { + if (sortBuilder instanceof FieldSortBuilder fieldSortBuilder + && ShardDocSortField.NAME.equals(fieldSortBuilder.getFieldName())) { + validationException = addValidationError( + "[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]", + validationException + ); + } + } + } + return validationException; + } } diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java index 5cd5a888581c8..946fd46fe6aec 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java @@ -9,7 +9,6 @@ package org.elasticsearch.search.rescore; import org.elasticsearch.action.ActionRequestValidationException; -import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -17,6 +16,7 @@ import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -120,7 +120,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public ActionRequestValidationException validate(SearchRequest searchRequest, ActionRequestValidationException validationException) { + public ActionRequestValidationException validate(SearchSourceBuilder source, ActionRequestValidationException validationException) { return validationException; } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index 6e3d2a58dbd5d..3a9979030683a 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -14,6 +14,8 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xcontent.AbstractObjectParser; import org.elasticsearch.xcontent.FilterXContentParserWrapper; @@ -33,16 +35,17 @@ /** * A retriever represents an API element that returns an ordered list of top * documents. These can be obtained from a query, from another retriever, etc. - * Internally, a {@link RetrieverBuilder} is just a wrapper for other search - * elements that are extracted into a {@link SearchSourceBuilder}. The advantage - * retrievers have is in the API they appear as a tree-like structure enabling + * Internally, a {@link RetrieverBuilder} is first rewritten into its simplest + * form and then its elements are extracted into a {@link SearchSourceBuilder}. + * + * The advantage retrievers have is in the API they appear as a tree-like structure enabling * easier reasoning about what a search does. * * This is the base class for all other retrievers. This class does not support * serialization and is expected to be fully extracted to a {@link SearchSourceBuilder} * prior to any transport calls. */ -public abstract class RetrieverBuilder implements ToXContent { +public abstract class RetrieverBuilder implements Rewriteable, ToXContent { public static final NodeFeature RETRIEVERS_SUPPORTED = new NodeFeature("retrievers_supported"); @@ -181,6 +184,13 @@ protected static RetrieverBuilder parseInnerRetrieverBuilder(XContentParser pars protected String retrieverName; + /** + * Determines if this retriever contains sub-retrievers that need to be executed prior to search. + */ + public boolean isCompound() { + return false; + } + /** * Gets the filters for this retriever. */ @@ -188,8 +198,13 @@ public List getPreFilterQueryBuilders() { return preFilterQueryBuilders; } + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + return this; + } + /** - * This method is called at the end of parsing on behalf of a {@link SearchSourceBuilder}. + * This method is called at the end of rewriting on behalf of a {@link SearchSourceBuilder}. * Elements from retrievers are expected to be "extracted" into the {@link SearchSourceBuilder}. */ public abstract void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 601c55293418d..348a65d0c4960 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -298,6 +298,10 @@ public String getField() { return field; } + public List getFilterQueries() { + return filterQueries; + } + public KnnSearchBuilder addFilterQuery(QueryBuilder filterQuery) { Objects.requireNonNull(filterQuery); this.filterQueries.add(filterQuery); diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index cbbbe7d86f4e2..de35d765a1551 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -10,9 +10,14 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.RandomQueryBuilder; +import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -23,6 +28,10 @@ import java.util.List; import static org.elasticsearch.search.vectors.KnnSearchBuilderTests.randomVector; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; public class KnnRetrieverBuilderParsingTests extends AbstractXContentTestCase { @@ -34,7 +43,7 @@ public class KnnRetrieverBuilderParsingTests extends AbstractXContentTestCase ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [query]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [query]")); } try ( @@ -44,26 +49,35 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [knn]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [knn]")); } try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"search_after\": [1], \"retriever\":{\"standard\":{}}}")) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [search_after]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [search_after]")); + } try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"terminate_after\": 1, \"retriever\":{\"standard\":{}}}")) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [terminate_after]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [terminate_after]")); } try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"sort\": [\"field\"], \"retriever\":{\"standard\":{}}}")) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [sort]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [sort]")); } try ( @@ -73,14 +87,18 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [rescore]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [rescore]")); } try (XContentParser parser = createParser(JsonXContent.jsonXContent, "{\"min_score\": 2, \"retriever\":{\"standard\":{}}}")) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [min_score]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [min_score]")); } try ( @@ -90,8 +108,10 @@ public void testRetrieverExtractionErrors() throws IOException { ) ) { SearchSourceBuilder ssb = new SearchSourceBuilder(); - IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> ssb.parseXContent(parser, true, nf -> true)); - assertEquals("cannot specify [retriever] and [query, terminate_after, min_score]", iae.getMessage()); + ssb.parseXContent(parser, true, nf -> true); + ActionRequestValidationException iae = ssb.validate(null, false); + assertNotNull(iae); + assertThat(iae.getMessage(), containsString("cannot specify [retriever] and [query, terminate_after, min_score]")); } } diff --git a/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java index bec534d89cc03..cd0d8f8d50c1e 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/StandardRetrieverBuilderParsingTests.java @@ -11,8 +11,15 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.BoolQueryBuilder; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.RandomQueryBuilder; +import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.collapse.CollapseBuilderTests; import org.elasticsearch.search.searchafter.SearchAfterBuilderTests; import org.elasticsearch.search.sort.SortBuilderTests; @@ -27,6 +34,11 @@ import java.util.List; import java.util.function.BiFunction; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; + public class StandardRetrieverBuilderParsingTests extends AbstractXContentTestCase { /** @@ -59,7 +71,7 @@ public static StandardRetrieverBuilder createRandomStandardRetrieverBuilder( } if (randomBoolean()) { - standardRetrieverBuilder.sortBuilders = SortBuilderTests.randomSortBuilderList(); + standardRetrieverBuilder.sortBuilders = SortBuilderTests.randomSortBuilderList(false); } if (randomBoolean()) { @@ -109,4 +121,52 @@ protected String[] getShuffleFieldsExceptions() { protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, List.of()).getNamedXContents()); } + + public void testRewrite() throws IOException { + for (int i = 0; i < 10; i++) { + StandardRetrieverBuilder standardRetriever = createTestInstance(); + SearchSourceBuilder source = new SearchSourceBuilder().retriever(standardRetriever); + QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); + source = Rewriteable.rewrite(source, queryRewriteContext); + assertNull(source.retriever()); + assertTrue(source.knnSearch().isEmpty()); + if (standardRetriever.queryBuilder != null) { + assertNotNull(source.query()); + if (standardRetriever.preFilterQueryBuilders.size() > 0) { + if (source.query() instanceof MatchAllQueryBuilder == false + && source.query() instanceof MatchNoneQueryBuilder == false) { + assertThat(source.query(), instanceOf(BoolQueryBuilder.class)); + BoolQueryBuilder bq = (BoolQueryBuilder) source.query(); + assertFalse(bq.must().isEmpty()); + assertThat(bq.must().size(), equalTo(1)); + assertThat(bq.must().get(0), equalTo(standardRetriever.queryBuilder)); + for (int j = 0; j < bq.filter().size(); j++) { + assertEqualQueryOrMatchAllNone(bq.filter().get(j), standardRetriever.preFilterQueryBuilders.get(j)); + } + } + } else { + assertEqualQueryOrMatchAllNone(source.query(), standardRetriever.queryBuilder); + } + } else if (standardRetriever.preFilterQueryBuilders.size() > 0) { + if (source.query() instanceof MatchAllQueryBuilder == false && source.query() instanceof MatchNoneQueryBuilder == false) { + assertNotNull(source.query()); + assertThat(source.query(), instanceOf(BoolQueryBuilder.class)); + BoolQueryBuilder bq = (BoolQueryBuilder) source.query(); + assertTrue(bq.must().isEmpty()); + for (int j = 0; j < bq.filter().size(); j++) { + assertEqualQueryOrMatchAllNone(bq.filter().get(j), standardRetriever.preFilterQueryBuilders.get(j)); + } + } + } else { + assertNull(source.query()); + } + if (standardRetriever.sortBuilders != null) { + assertThat(source.sorts().size(), equalTo(standardRetriever.sortBuilders.size())); + } + } + } + + private static void assertEqualQueryOrMatchAllNone(QueryBuilder actual, QueryBuilder expected) { + assertThat(actual, anyOf(instanceOf(MatchAllQueryBuilder.class), instanceOf(MatchNoneQueryBuilder.class), equalTo(expected))); + } } diff --git a/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java b/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java index eee98297c7a13..84f87b3f01881 100644 --- a/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java @@ -119,7 +119,7 @@ public void testSingleFieldSort() throws IOException { public void testRandomSortBuilders() throws IOException { for (int runs = 0; runs < NUMBER_OF_RUNS; runs++) { Set expectedWarningHeaders = new HashSet<>(); - List> testBuilders = randomSortBuilderList(); + List> testBuilders = randomSortBuilderList(randomBoolean()); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); xContentBuilder.startObject(); if (testBuilders.size() > 1) { @@ -171,7 +171,7 @@ public void testRandomSortBuilders() throws IOException { } } - public static List> randomSortBuilderList() { + public static List> randomSortBuilderList(boolean hasPIT) { int size = randomIntBetween(1, 5); List> list = new ArrayList<>(size); for (int i = 0; i < size; i++) { @@ -181,7 +181,7 @@ public static List> randomSortBuilderList() { case 2 -> SortBuilders.fieldSort(FieldSortBuilder.DOC_FIELD_NAME); case 3 -> GeoDistanceSortBuilderTests.randomGeoDistanceSortBuilder(); case 4 -> ScriptSortBuilderTests.randomScriptSortBuilder(); - case 5 -> SortBuilders.pitTiebreaker(); + case 5 -> hasPIT ? SortBuilders.pitTiebreaker() : ScriptSortBuilderTests.randomScriptSortBuilder(); default -> throw new IllegalStateException("unexpected randomization in tests"); }); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java index f8f9caf365918..d2179a69ebc24 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java @@ -11,12 +11,12 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; -import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -134,10 +134,10 @@ public RescorerBuilder rewrite(QueryRewriteContex } @Override - public ActionRequestValidationException validate(SearchRequest searchRequest, ActionRequestValidationException validationException) { - validationException = super.validate(searchRequest, validationException); + public ActionRequestValidationException validate(SearchSourceBuilder source, ActionRequestValidationException validationException) { + validationException = super.validate(source, validationException); - int searchRequestPaginationSize = searchRequest.source().from() + searchRequest.source().size(); + int searchRequestPaginationSize = source.from() + source.size(); if (windowSize() < searchRequestPaginationSize) { return addValidationError( @@ -151,7 +151,7 @@ public ActionRequestValidationException validate(SearchRequest searchRequest, Ac } @SuppressWarnings("rawtypes") - List rescorers = searchRequest.source().rescores(); + List rescorers = source.rescores(); assert rescorers != null && rescorers.contains(this); for (int i = rescorers.indexOf(this) + 1; i < rescorers.size(); i++) {