From 83a0fe4baa0fd7e34df32f5be79f3fe415194c82 Mon Sep 17 00:00:00 2001 From: Jim Ferenczi Date: Thu, 4 Jul 2024 14:52:42 +0100 Subject: [PATCH] Consume retrievers during the query rewrite --- .../action/search/SearchRequest.java | 131 +--------- .../search/builder/SearchSourceBuilder.java | 246 +++++++++++++++--- .../search/rescore/RescorerBuilder.java | 4 +- .../search/retriever/RetrieverBuilder.java | 25 +- .../ltr/LearningToRankRescorerBuilder.java | 10 +- 5 files changed, 240 insertions(+), 176 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java index e1fe6eac7e9c1..514e8d10eeca1 100644 --- a/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java +++ b/server/src/main/java/org/elasticsearch/action/search/SearchRequest.java @@ -17,20 +17,16 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.mapper.SourceLoader; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.Scroll; -import org.elasticsearch.search.SearchService; import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.SearchContext; -import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.search.sort.FieldSortBuilder; -import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.tasks.TaskId; @@ -324,124 +320,15 @@ public void writeTo(StreamOutput out) throws IOException { public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; boolean scroll = scroll() != null; + + if (source != null) { + validationException = source.validate(validationException, scroll); + } if (scroll) { - if (source != null) { - if (source.trackTotalHitsUpTo() != null && source.trackTotalHitsUpTo() != SearchContext.TRACK_TOTAL_HITS_ACCURATE) { - validationException = addValidationError( - "disabling [track_total_hits] is not allowed in a scroll context", - validationException - ); - } - if (source.from() > 0) { - validationException = addValidationError("using [from] is not allowed in a scroll context", validationException); - } - if (source.size() == 0) { - validationException = addValidationError("[size] cannot be [0] in a scroll context", validationException); - } - if (source.rescores() != null && source.rescores().isEmpty() == false) { - validationException = addValidationError("using [rescore] is not allowed in a scroll context", validationException); - } - if (CollectionUtils.isEmpty(source.searchAfter()) == false) { - validationException = addValidationError("[search_after] cannot be used in a scroll context", validationException); - } - if (source.collapse() != null) { - validationException = addValidationError("cannot use `collapse` in a scroll context", validationException); - } - } if (requestCache != null && requestCache) { validationException = addValidationError("[request_cache] cannot be used in a scroll context", validationException); } } - if (source != null) { - if (source.slice() != null) { - if (source.pointInTimeBuilder() == null && (scroll == false)) { - validationException = addValidationError( - "[slice] can only be used with [scroll] or [point-in-time] requests", - validationException - ); - } - } - if (source.from() > 0 && CollectionUtils.isEmpty(source.searchAfter()) == false) { - validationException = addValidationError( - "[from] parameter must be set to 0 when [search_after] is used", - validationException - ); - } - if (source.storedFields() != null) { - if (source.storedFields().fetchFields() == false) { - if (source.fetchSource() != null && source.fetchSource().fetchSource()) { - validationException = addValidationError( - "[stored_fields] cannot be disabled if [_source] is requested", - validationException - ); - } - if (source.fetchFields() != null) { - validationException = addValidationError( - "[stored_fields] cannot be disabled when using the [fields] option", - validationException - ); - } - - } - } - if (source.subSearches().size() >= 2 && source.rankBuilder() == null) { - validationException = addValidationError("[sub_searches] requires [rank]", validationException); - } - if (source.aggregations() != null) { - validationException = source.aggregations().validate(validationException); - } - if (source.rankBuilder() != null) { - int size = source.size() == -1 ? SearchService.DEFAULT_SIZE : source.size(); - if (size == 0) { - validationException = addValidationError("[rank] requires [size] greater than [0]", validationException); - } - if (size > source.rankBuilder().rankWindowSize()) { - validationException = addValidationError( - "[rank] requires [rank_window_size: " - + source.rankBuilder().rankWindowSize() - + "]" - + " be greater than or equal to [size: " - + size - + "]", - validationException - ); - } - int queryCount = source.subSearches().size() + source.knnSearch().size(); - if (source.rankBuilder().isCompoundBuilder() && queryCount < 2) { - validationException = addValidationError( - "[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches", - validationException - ); - } - if (scroll) { - validationException = addValidationError("[rank] cannot be used in a scroll context", validationException); - } - if (source.rescores() != null && source.rescores().isEmpty() == false) { - validationException = addValidationError("[rank] cannot be used with [rescore]", validationException); - } - if (source.sorts() != null && source.sorts().isEmpty() == false) { - validationException = addValidationError("[rank] cannot be used with [sort]", validationException); - } - if (source.collapse() != null) { - validationException = addValidationError("[rank] cannot be used with [collapse]", validationException); - } - if (source.suggest() != null && source.suggest().getSuggestions().isEmpty() == false) { - validationException = addValidationError("[rank] cannot be used with [suggest]", validationException); - } - if (source.highlighter() != null) { - validationException = addValidationError("[rank] cannot be used with [highlighter]", validationException); - } - if (source.pointInTimeBuilder() != null) { - validationException = addValidationError("[rank] cannot be used with [point in time]", validationException); - } - } - if (source.rescores() != null) { - for (@SuppressWarnings("rawtypes") - RescorerBuilder rescoreBuilder : source.rescores()) { - validationException = rescoreBuilder.validate(this, validationException); - } - } - } if (pointInTimeBuilder() != null) { if (scroll) { validationException = addValidationError("using [point in time] is not allowed in a scroll context", validationException); @@ -461,16 +348,6 @@ public ActionRequestValidationException validate() { if (preference() != null) { validationException = addValidationError("[preference] cannot be used with point in time", validationException); } - } else if (source != null && source.sorts() != null) { - for (SortBuilder sortBuilder : source.sorts()) { - if (sortBuilder instanceof FieldSortBuilder - && ShardDocSortField.NAME.equals(((FieldSortBuilder) sortBuilder).getFieldName())) { - validationException = addValidationError( - "[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]", - validationException - ); - } - } } if (minCompatibleShardNode() != null) { if (isCcsMinimizeRoundtrips()) { diff --git a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java index cb2c53a97fbc3..27f9147de9b36 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/SearchSourceBuilder.java @@ -10,12 +10,16 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.logging.DeprecationLogger; +import org.elasticsearch.common.util.CollectionUtils; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Booleans; import org.elasticsearch.core.Nullable; @@ -28,6 +32,7 @@ import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.script.Script; import org.elasticsearch.search.SearchExtBuilder; +import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.AggregationBuilder; import org.elasticsearch.search.aggregations.AggregatorFactories; import org.elasticsearch.search.aggregations.PipelineAggregationBuilder; @@ -43,7 +48,9 @@ import org.elasticsearch.search.retriever.RetrieverParserContext; import org.elasticsearch.search.searchafter.SearchAfterBuilder; import org.elasticsearch.search.slice.SliceBuilder; +import org.elasticsearch.search.sort.FieldSortBuilder; import org.elasticsearch.search.sort.ScoreSortBuilder; +import org.elasticsearch.search.sort.ShardDocSortField; import org.elasticsearch.search.sort.SortBuilder; import org.elasticsearch.search.sort.SortBuilders; import org.elasticsearch.search.sort.SortOrder; @@ -71,6 +78,7 @@ import java.util.stream.Collectors; import static java.util.Collections.emptyMap; +import static org.elasticsearch.action.ValidateActions.addValidationError; import static org.elasticsearch.index.query.AbstractQueryBuilder.parseTopLevelQuery; import static org.elasticsearch.search.internal.SearchContext.DEFAULT_TERMINATE_AFTER; import static org.elasticsearch.search.internal.SearchContext.TRACK_TOTAL_HITS_ACCURATE; @@ -78,10 +86,9 @@ /** * A search source builder allowing to easily build search source. Simple - * construction using - * {@link org.elasticsearch.search.builder.SearchSourceBuilder#searchSource()}. + * construction using {@link SearchSourceBuilder#searchSource()}. * - * @see org.elasticsearch.action.search.SearchRequest#source(SearchSourceBuilder) + * @see SearchRequest#source(SearchSourceBuilder) */ public final class SearchSourceBuilder implements Writeable, ToXContentObject, Rewriteable { private static final DeprecationLogger deprecationLogger = DeprecationLogger.getLogger(SearchSourceBuilder.class); @@ -141,6 +148,8 @@ public static HighlightBuilder highlight() { return new HighlightBuilder(); } + private transient RetrieverBuilder retrieverBuilder; + private List subSearchSourceBuilders = new ArrayList<>(); private QueryBuilder postQueryBuilder; @@ -283,6 +292,9 @@ public SearchSourceBuilder(StreamInput in) throws IOException { @Override public void writeTo(StreamOutput out) throws IOException { + if (retrieverBuilder != null) { + throw new IllegalStateException("SearchSourceBuilder should be rewritten first"); + } out.writeOptionalWriteable(aggregations); out.writeOptionalBoolean(explain); out.writeOptionalWriteable(fetchSourceContext); @@ -367,6 +379,18 @@ public void writeTo(StreamOutput out) throws IOException { } } + /** + * Sets the retriever for this request. + */ + public SearchSourceBuilder retriever(RetrieverBuilder retrieverBuilder) { + this.retrieverBuilder = retrieverBuilder; + return this; + } + + public RetrieverBuilder retriever() { + return retrieverBuilder; + } + /** * Sets the query for this request. */ @@ -1134,6 +1158,21 @@ public SearchSourceBuilder rewrite(QueryRewriteContext context) throws IOExcepti highlightBuilder ) )); + if (retrieverBuilder != null) { + var newRetriever = retrieverBuilder.rewrite(context); + if (newRetriever != retrieverBuilder) { + var rewritten = shallowCopy(); + rewritten.retrieverBuilder = newRetriever; + return rewritten; + } else { + // retriever is transient, the rewritten version is extracted in this source. + var retriever = retrieverBuilder; + retrieverBuilder = null; + retriever.extractToSearchSourceBuilder(this, false); + validate(); + } + } + List subSearchSourceBuilders = Rewriteable.rewrite(this.subSearchSourceBuilders, context); QueryBuilder postQueryBuilder = null; if (this.postQueryBuilder != null) { @@ -1293,7 +1332,6 @@ private SearchSourceBuilder parseXContent( } List knnBuilders = new ArrayList<>(); - RetrieverBuilder retrieverBuilder = null; SearchUsage searchUsage = new SearchUsage(); while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) { if (token == XContentParser.Token.FIELD_NAME) { @@ -1627,39 +1665,6 @@ private SearchSourceBuilder parseXContent( } knnSearch = knnBuilders.stream().map(knnBuilder -> knnBuilder.build(size())).collect(Collectors.toList()); - - if (retrieverBuilder != null) { - List specified = new ArrayList<>(); - if (subSearchSourceBuilders.isEmpty() == false) { - specified.add(QUERY_FIELD.getPreferredName()); - } - if (knnSearch.isEmpty() == false) { - specified.add(KNN_FIELD.getPreferredName()); - } - if (searchAfterBuilder != null) { - specified.add(SEARCH_AFTER.getPreferredName()); - } - if (terminateAfter != DEFAULT_TERMINATE_AFTER) { - specified.add(TERMINATE_AFTER_FIELD.getPreferredName()); - } - if (sorts != null) { - specified.add(SORT_FIELD.getPreferredName()); - } - if (rescoreBuilders != null) { - specified.add(RESCORE_FIELD.getPreferredName()); - } - if (minScore != null) { - specified.add(MIN_SCORE_FIELD.getPreferredName()); - } - if (rankBuilder != null) { - specified.add(RANK_FIELD.getPreferredName()); - } - if (specified.isEmpty() == false) { - throw new IllegalArgumentException("cannot specify [" + RETRIEVER.getPreferredName() + "] and " + specified); - } - retrieverBuilder.extractToSearchSourceBuilder(this, false); - } - searchUsageConsumer.accept(searchUsage); return this; } @@ -1689,6 +1694,10 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t builder.field(TERMINATE_AFTER_FIELD.getPreferredName(), terminateAfter); } + if (retrieverBuilder != null) { + builder.field(RETRIEVER.getPreferredName(), retrieverBuilder); + } + if (subSearchSourceBuilders.isEmpty() == false) { if (subSearchSourceBuilders.size() == 1) { builder.field(QUERY_FIELD.getPreferredName(), subSearchSourceBuilders.get(0).getQueryBuilder()); @@ -2183,4 +2192,167 @@ public boolean supportsParallelCollection(ToLongFunction fieldCardinalit return collapse == null && (aggregations == null || aggregations.supportsParallelCollection(fieldCardinality)); } + + private void validate() throws ValidationException { + var exceptions = validate(null, false); + if (exceptions != null) { + throw exceptions; + } + } + + public ActionRequestValidationException validate(ActionRequestValidationException validationException, boolean isScroll) { + if (retriever() != null) { + List specified = new ArrayList<>(); + if (subSearches().isEmpty() == false) { + specified.add(QUERY_FIELD.getPreferredName()); + } + if (knnSearch().isEmpty() == false) { + specified.add(KNN_FIELD.getPreferredName()); + } + if (searchAfter() != null) { + specified.add(SEARCH_AFTER.getPreferredName()); + } + if (terminateAfter() != DEFAULT_TERMINATE_AFTER) { + specified.add(TERMINATE_AFTER_FIELD.getPreferredName()); + } + if (sorts() != null) { + specified.add(SORT_FIELD.getPreferredName()); + } + if (rescores() != null) { + specified.add(RESCORE_FIELD.getPreferredName()); + } + if (minScore() != null) { + specified.add(MIN_SCORE_FIELD.getPreferredName()); + } + if (rankBuilder() != null) { + specified.add(RANK_FIELD.getPreferredName()); + } + if (specified.isEmpty() == false) { + validationException = addValidationError("cannot specify [" + RETRIEVER.getPreferredName() + "] and " + specified, validationException); + } + } + if (isScroll) { + if (trackTotalHitsUpTo() != null && trackTotalHitsUpTo() != SearchContext.TRACK_TOTAL_HITS_ACCURATE) { + validationException = addValidationError( + "disabling [track_total_hits] is not allowed in a scroll context", + validationException + ); + } + if (from() > 0) { + validationException = addValidationError("using [from] is not allowed in a scroll context", validationException); + } + if (size() == 0) { + validationException = addValidationError("[size] cannot be [0] in a scroll context", validationException); + } + if (rescores() != null && rescores().isEmpty() == false) { + validationException = addValidationError("using [rescore] is not allowed in a scroll context", validationException); + } + if (CollectionUtils.isEmpty(searchAfter()) == false) { + validationException = addValidationError("[search_after] cannot be used in a scroll context", validationException); + } + if (collapse() != null) { + validationException = addValidationError("cannot use `collapse` in a scroll context", validationException); + } + } + if (slice() != null) { + if (pointInTimeBuilder() == null && (isScroll == false)) { + validationException = addValidationError( + "[slice] can only be used with [scroll] or [point-in-time] requests", + validationException + ); + } + } + if (from() > 0 && CollectionUtils.isEmpty(searchAfter()) == false) { + validationException = addValidationError("[from] parameter must be set to 0 when [search_after] is used", validationException); + } + if (storedFields() != null) { + if (storedFields().fetchFields() == false) { + if (fetchSource() != null && fetchSource().fetchSource()) { + validationException = addValidationError( + "[stored_fields] cannot be disabled if [_source] is requested", + validationException + ); + } + if (fetchSource() != null) { + validationException = addValidationError( + "[stored_fields] cannot be disabled when using the [fields] option", + validationException + ); + } + + } + } + if (subSearches().size() >= 2 && rankBuilder() == null) { + validationException = addValidationError("[sub_searches] requires [rank]", validationException); + } + if (aggregations() != null) { + validationException = aggregations().validate(validationException); + } + + if (rankBuilder() != null) { + int s = size() == -1 ? SearchService.DEFAULT_SIZE : size(); + if (s == 0) { + validationException = addValidationError("[rank] requires [size] greater than [0]", validationException); + } + if (s > rankBuilder().rankWindowSize()) { + validationException = addValidationError( + "[rank] requires [rank_window_size: " + + rankBuilder().rankWindowSize() + + "]" + + " be greater than or equal to [size: " + + size() + + "]", + validationException + ); + } + int queryCount = subSearches().size() + knnSearch().size(); + if (rankBuilder().isCompoundBuilder() && queryCount < 2) { + validationException = addValidationError( + "[rank] requires a minimum of [2] result sets using a combination of sub searches and/or knn searches", + validationException + ); + } + if (isScroll) { + validationException = addValidationError("[rank] cannot be used in a scroll context", validationException); + } + if (rescores() != null && rescores().isEmpty() == false) { + validationException = addValidationError("[rank] cannot be used with [rescore]", validationException); + } + if (sorts() != null && sorts().isEmpty() == false) { + validationException = addValidationError("[rank] cannot be used with [sort]", validationException); + } + if (collapse() != null) { + validationException = addValidationError("[rank] cannot be used with [collapse]", validationException); + } + if (suggest() != null && suggest().getSuggestions().isEmpty() == false) { + validationException = addValidationError("[rank] cannot be used with [suggest]", validationException); + } + if (highlight() != null) { + validationException = addValidationError("[rank] cannot be used with [highlighter]", validationException); + } + if (pointInTimeBuilder() != null) { + validationException = addValidationError("[rank] cannot be used with [point in time]", validationException); + } + } + + if (rescores() != null) { + for (@SuppressWarnings("rawtypes") + var rescorer : rescores()) { + validationException = rescorer.validate(this, validationException); + } + } + + if (pointInTimeBuilder() == null && sorts() != null) { + for (var sortBuilder : sorts()) { + if (sortBuilder instanceof FieldSortBuilder fieldSortBuilder + && ShardDocSortField.NAME.equals(fieldSortBuilder.getFieldName())) { + validationException = addValidationError( + "[" + FieldSortBuilder.SHARD_DOC_FIELD_NAME + "] sort field cannot be used without [point in time]", + validationException + ); + } + } + } + return validationException; + } } diff --git a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java index 5cd5a888581c8..946fd46fe6aec 100644 --- a/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/rescore/RescorerBuilder.java @@ -9,7 +9,6 @@ package org.elasticsearch.search.rescore; import org.elasticsearch.action.ActionRequestValidationException; -import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.ParsingException; import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; @@ -17,6 +16,7 @@ import org.elasticsearch.common.io.stream.VersionedNamedWriteable; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; @@ -120,7 +120,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws return builder; } - public ActionRequestValidationException validate(SearchRequest searchRequest, ActionRequestValidationException validationException) { + public ActionRequestValidationException validate(SearchSourceBuilder source, ActionRequestValidationException validationException) { return validationException; } diff --git a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java index 6e3d2a58dbd5d..3a9979030683a 100644 --- a/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/retriever/RetrieverBuilder.java @@ -14,6 +14,8 @@ import org.elasticsearch.features.NodeFeature; import org.elasticsearch.index.query.AbstractQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.xcontent.AbstractObjectParser; import org.elasticsearch.xcontent.FilterXContentParserWrapper; @@ -33,16 +35,17 @@ /** * A retriever represents an API element that returns an ordered list of top * documents. These can be obtained from a query, from another retriever, etc. - * Internally, a {@link RetrieverBuilder} is just a wrapper for other search - * elements that are extracted into a {@link SearchSourceBuilder}. The advantage - * retrievers have is in the API they appear as a tree-like structure enabling + * Internally, a {@link RetrieverBuilder} is first rewritten into its simplest + * form and then its elements are extracted into a {@link SearchSourceBuilder}. + * + * The advantage retrievers have is in the API they appear as a tree-like structure enabling * easier reasoning about what a search does. * * This is the base class for all other retrievers. This class does not support * serialization and is expected to be fully extracted to a {@link SearchSourceBuilder} * prior to any transport calls. */ -public abstract class RetrieverBuilder implements ToXContent { +public abstract class RetrieverBuilder implements Rewriteable, ToXContent { public static final NodeFeature RETRIEVERS_SUPPORTED = new NodeFeature("retrievers_supported"); @@ -181,6 +184,13 @@ protected static RetrieverBuilder parseInnerRetrieverBuilder(XContentParser pars protected String retrieverName; + /** + * Determines if this retriever contains sub-retrievers that need to be executed prior to search. + */ + public boolean isCompound() { + return false; + } + /** * Gets the filters for this retriever. */ @@ -188,8 +198,13 @@ public List getPreFilterQueryBuilders() { return preFilterQueryBuilders; } + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + return this; + } + /** - * This method is called at the end of parsing on behalf of a {@link SearchSourceBuilder}. + * This method is called at the end of rewriting on behalf of a {@link SearchSourceBuilder}. * Elements from retrievers are expected to be "extracted" into the {@link SearchSourceBuilder}. */ public abstract void extractToSearchSourceBuilder(SearchSourceBuilder searchSourceBuilder, boolean compoundUsed); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java index f8f9caf365918..d2179a69ebc24 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ltr/LearningToRankRescorerBuilder.java @@ -11,12 +11,12 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionRequestValidationException; -import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.rescore.RescorerBuilder; import org.elasticsearch.xcontent.ObjectParser; import org.elasticsearch.xcontent.ParseField; @@ -134,10 +134,10 @@ public RescorerBuilder rewrite(QueryRewriteContex } @Override - public ActionRequestValidationException validate(SearchRequest searchRequest, ActionRequestValidationException validationException) { - validationException = super.validate(searchRequest, validationException); + public ActionRequestValidationException validate(SearchSourceBuilder source, ActionRequestValidationException validationException) { + validationException = super.validate(source, validationException); - int searchRequestPaginationSize = searchRequest.source().from() + searchRequest.source().size(); + int searchRequestPaginationSize = source.from() + source.size(); if (windowSize() < searchRequestPaginationSize) { return addValidationError( @@ -151,7 +151,7 @@ public ActionRequestValidationException validate(SearchRequest searchRequest, Ac } @SuppressWarnings("rawtypes") - List rescorers = searchRequest.source().rescores(); + List rescorers = source.rescores(); assert rescorers != null && rescorers.contains(this); for (int i = rescorers.indexOf(this) + 1; i < rescorers.size(); i++) {