diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverRewriteIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverRewriteIT.java new file mode 100644 index 0000000000000..ef6fef4b478f4 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/retriever/RetrieverRewriteIT.java @@ -0,0 +1,257 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.search.retriever; + +import org.apache.lucene.search.TotalHits; +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.action.search.SearchResponse; +import org.elasticsearch.index.query.QueryBuilders; +import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.plugins.SearchPlugin; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.test.ESIntegTestCase; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.junit.Before; + +import java.io.IOException; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.ExecutionException; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.hamcrest.Matchers.equalTo; + +public class RetrieverRewriteIT extends ESIntegTestCase { + @Override + protected Collection> nodePlugins() { + return List.of(AssertingPlugin.class); + } + + private static String INDEX_DOCS = "docs"; + private static String INDEX_QUERIES = "queries"; + private static final String ID_FIELD = "_id"; + private static final String QUERY_FIELD = "query"; + + @Before + public void setup() throws Exception { + createIndex(INDEX_DOCS); + index(INDEX_DOCS, "doc_0", "{}"); + index(INDEX_DOCS, "doc_1", "{}"); + index(INDEX_DOCS, "doc_2", "{}"); + refresh(INDEX_DOCS); + + createIndex(INDEX_QUERIES); + index(INDEX_QUERIES, "query_0", "{ \"" + QUERY_FIELD + "\": \"doc_2\"}"); + index(INDEX_QUERIES, "query_1", "{ \"" + QUERY_FIELD + "\": \"doc_1\"}"); + index(INDEX_QUERIES, "query_2", "{ \"" + QUERY_FIELD + "\": \"doc_0\"}"); + refresh(INDEX_QUERIES); + } + + public void testRewrite() throws ExecutionException, InterruptedException { + SearchSourceBuilder source = new SearchSourceBuilder(); + StandardRetrieverBuilder standard = new StandardRetrieverBuilder(); + standard.queryBuilder = QueryBuilders.termQuery(ID_FIELD, "doc_0"); + source.retriever(new AssertingRetrieverBuilder(standard)); + SearchRequest req = new SearchRequest(INDEX_DOCS, INDEX_QUERIES).source(source); + SearchResponse resp = client().search(req).get(); + assertNull(resp.pointInTimeId()); + assertThat(resp.getHits().getTotalHits().value, equalTo(1L)); + assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_0")); + } + + public void testRewriteCompound() throws ExecutionException, InterruptedException { + SearchSourceBuilder source = new SearchSourceBuilder(); + source.retriever(new AssertingCompoundRetrieverBuilder("query_0")); + SearchRequest req = new SearchRequest(INDEX_DOCS, INDEX_QUERIES).source(source); + SearchResponse resp = client().search(req).get(); + assertNull(resp.pointInTimeId()); + assertThat(resp.getHits().getTotalHits().value, equalTo(1L)); + assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO)); + assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2")); + } + + public static class AssertingPlugin extends Plugin implements SearchPlugin { + public AssertingPlugin() {} + + @Override + public List> getRetrievers() { + return List.of( + new RetrieverSpec(AssertingRetrieverBuilder.NAME, AssertingRetrieverBuilder::fromXContent), + new RetrieverSpec(AssertingCompoundRetrieverBuilder.NAME, AssertingCompoundRetrieverBuilder::fromXContent) + ); + } + } + + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + AssertingCompoundRetrieverBuilder.NAME, + args -> new AssertingRetrieverBuilder((RetrieverBuilder) args[0]) + ); + + public static final ConstructingObjectParser PARSER_COMPOUND = + new ConstructingObjectParser<>( + AssertingCompoundRetrieverBuilder.NAME, + args -> new AssertingCompoundRetrieverBuilder((String) args[0]) + ); + + static { + RetrieverBuilder.declareBaseParserFields(AssertingRetrieverBuilder.NAME, PARSER); + PARSER.declareObject(constructorArg(), RetrieverBuilder::parseInnerRetrieverBuilder, new ParseField("retriever")); + + RetrieverBuilder.declareBaseParserFields(AssertingCompoundRetrieverBuilder.NAME, PARSER_COMPOUND); + PARSER_COMPOUND.declareString(constructorArg(), new ParseField("id")); + } + + private static class AssertingRetrieverBuilder extends RetrieverBuilder { + static final String NAME = "asserting"; + + private final RetrieverBuilder innerRetriever; + + public static AssertingRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException { + return PARSER.apply(parser, context); + } + + private AssertingRetrieverBuilder(RetrieverBuilder innerRetriever) { + this.innerRetriever = innerRetriever; + } + + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + assertNull(ctx.getPointInTimeBuilder()); + assertNull(ctx.convertToInnerHitsRewriteContext()); + assertNull(ctx.convertToCoordinatorRewriteContext()); + assertNull(ctx.convertToIndexMetadataContext()); + assertNull(ctx.convertToSearchExecutionContext()); + assertNull(ctx.convertToDataRewriteContext()); + var newRetriever = innerRetriever.rewrite(ctx); + if (newRetriever != innerRetriever) { + return new AssertingRetrieverBuilder(newRetriever); + } + return this; + } + + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder sourceBuilder, boolean compoundUsed) { + assertNull(sourceBuilder.retriever()); + innerRetriever.extractToSearchSourceBuilder(sourceBuilder, compoundUsed); + } + + @Override + public String getName() { + return "asserting"; + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException {} + + @Override + protected boolean doEquals(Object o) { + return false; + } + + @Override + protected int doHashCode() { + return innerRetriever.doHashCode(); + } + } + + private static class AssertingCompoundRetrieverBuilder extends RetrieverBuilder { + static final String NAME = "asserting_compound"; + + private final String id; + private final SetOnce innerRetriever; + + public static AssertingCompoundRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) + throws IOException { + return PARSER_COMPOUND.apply(parser, context); + } + + private AssertingCompoundRetrieverBuilder(String id) { + this.id = id; + this.innerRetriever = new SetOnce<>(null); + } + + private AssertingCompoundRetrieverBuilder(String id, SetOnce innerRetriever) { + this.id = id; + this.innerRetriever = innerRetriever; + } + + @Override + public boolean isCompound() { + return true; + } + + @Override + public RetrieverBuilder rewrite(QueryRewriteContext ctx) throws IOException { + assertNotNull(ctx.getPointInTimeBuilder()); + assertNull(ctx.convertToInnerHitsRewriteContext()); + assertNull(ctx.convertToCoordinatorRewriteContext()); + assertNull(ctx.convertToIndexMetadataContext()); + assertNull(ctx.convertToSearchExecutionContext()); + assertNull(ctx.convertToDataRewriteContext()); + if (innerRetriever.get() != null) { + return this; + } + SetOnce innerRetriever = new SetOnce<>(); + ctx.registerAsyncAction((client, actionListener) -> { + SearchSourceBuilder source = new SearchSourceBuilder().pointInTimeBuilder(ctx.getPointInTimeBuilder()) + .query(QueryBuilders.termQuery(ID_FIELD, id)) + .fetchField(QUERY_FIELD); + client.search(new SearchRequest().source(source), new ActionListener<>() { + @Override + public void onResponse(SearchResponse response) { + String query = response.getHits().getAt(0).field(QUERY_FIELD).getValue(); + StandardRetrieverBuilder standard = new StandardRetrieverBuilder(); + standard.queryBuilder = QueryBuilders.termQuery(ID_FIELD, query); + innerRetriever.set(standard); + actionListener.onResponse(null); + } + + @Override + public void onFailure(Exception e) { + actionListener.onFailure(e); + } + }); + }); + return new AssertingCompoundRetrieverBuilder(id, innerRetriever); + } + + @Override + public void extractToSearchSourceBuilder(SearchSourceBuilder sourceBuilder, boolean compoundUsed) { + assertNull(sourceBuilder.retriever()); + innerRetriever.get().extractToSearchSourceBuilder(sourceBuilder, compoundUsed); + } + + @Override + public String getName() { + return "asserting"; + } + + @Override + protected void doToXContent(XContentBuilder builder, Params params) throws IOException { + throw new AssertionError("not implemented"); + } + + @Override + protected boolean doEquals(Object o) { + return false; + } + + @Override + protected int doHashCode() { + return id.hashCode(); + } + } +} diff --git a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java index 0db9f3d20d117..ae0a0c40f1267 100644 --- a/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/AbstractSearchAsyncAction.java @@ -709,7 +709,9 @@ public void sendSearchResponse(SearchResponseSections internalSearchResponse, At if (buildPointInTimeFromSearchResults()) { searchContextId = SearchContextId.encode(queryResults.asList(), aliasFilter, minTransportVersion); } else { - if (request.source() != null && request.source().pointInTimeBuilder() != null) { + if (request.source() != null + && request.source().pointInTimeBuilder() != null + && request.source().pointInTimeBuilder().singleSession() == false) { searchContextId = request.source().pointInTimeBuilder().getEncodedId(); } else { searchContextId = null; diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index c2d1cdae85cd9..3fb63591bf3a4 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -66,6 +66,7 @@ import org.elasticsearch.search.SearchPhaseResult; import org.elasticsearch.search.SearchService; import org.elasticsearch.search.aggregations.AggregationReduceContext; +import org.elasticsearch.search.builder.PointInTimeBuilder; import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.search.internal.AliasFilter; import org.elasticsearch.search.internal.SearchContext; @@ -484,11 +485,66 @@ void executeRequest( } } }); + final SearchSourceBuilder source = original.source(); + if (shouldOpenPIT(source)) { + openPIT(original, searchService.getDefaultKeepAliveInMillis(), listener.delegateFailureAndWrap((delegate, resp) -> { + source.pointInTimeBuilder(new PointInTimeBuilder(resp.getPointInTimeId()).setKeepAlive(TimeValue.MINUS_ONE)); + executeRequest(task, original, new ActionListener<>() { + @Override + public void onResponse(SearchResponse response) { + // we need to close the PIT first so we delay the release of the response to after the closing + response.incRef(); + closePIT(original.source().pointInTimeBuilder(), () -> ActionListener.respondAndRelease(listener, response)); + } + + @Override + public void onFailure(Exception e) { + closePIT(original.source().pointInTimeBuilder(), () -> listener.onFailure(e)); + } + }, searchPhaseProvider); + })); + } else { + Rewriteable.rewriteAndFetch( + original, + searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices, original.pointInTimeBuilder()), + rewriteListener + ); + } + } + + /** + * Returns true if the provided source needs to open a shared point in time prior to executing the request. + */ + private boolean shouldOpenPIT(SearchSourceBuilder source) { + if (source == null) { + return false; + } + if (source.pointInTimeBuilder() != null) { + return false; + } + var retriever = source.retriever(); + return retriever != null && retriever.isCompound(); + } + + private void openPIT(SearchRequest request, long keepAliveMillis, ActionListener listener) { + OpenPointInTimeRequest pitReq = new OpenPointInTimeRequest(request.indices()).indicesOptions(request.indicesOptions()) + .preference(request.preference()) + .routing(request.routing()) + .keepAlive(TimeValue.timeValueMillis(keepAliveMillis)); + client.execute(TransportOpenPointInTimeAction.TYPE, pitReq, listener); + } - Rewriteable.rewriteAndFetch( - original, - searchService.getRewriteContext(timeProvider::absoluteStartMillis, resolvedIndices), - rewriteListener + private void closePIT(PointInTimeBuilder pit, Runnable next) { + client.execute( + TransportClosePointInTimeAction.TYPE, + new ClosePointInTimeRequest(pit.getEncodedId()), + ActionListener.runAfter(new ActionListener<>() { + @Override + public void onResponse(ClosePointInTimeResponse closePointInTimeResponse) {} + + @Override + public void onFailure(Exception e) {} + }, next) ); } diff --git a/server/src/main/java/org/elasticsearch/search/builder/PointInTimeBuilder.java b/server/src/main/java/org/elasticsearch/search/builder/PointInTimeBuilder.java index 1966f7eaa1e69..db79feb986b84 100644 --- a/server/src/main/java/org/elasticsearch/search/builder/PointInTimeBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/builder/PointInTimeBuilder.java @@ -136,6 +136,13 @@ public TimeValue getKeepAlive() { return keepAlive; } + /** + * Returns {@code true} if the point in time is explicitly released when returning the response. + */ + public boolean singleSession() { + return keepAlive != null && TimeValue.MINUS_ONE.equals(keepAlive); + } + @Override public boolean equals(Object o) { if (this == o) return true;