Skip to content

Commit

Permalink
address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
jimczi committed Jul 10, 2024
1 parent e8f00b7 commit a370cac
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,32 +12,28 @@
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.search.SearchRequest;
import org.elasticsearch.action.search.SearchRequestBuilder;
import org.elasticsearch.action.search.SearchResponse;
import org.elasticsearch.index.query.QueryBuilders;
import org.elasticsearch.index.query.QueryRewriteContext;
import org.elasticsearch.plugins.Plugin;
import org.elasticsearch.plugins.SearchPlugin;
import org.elasticsearch.search.MockSearchService;
import org.elasticsearch.search.builder.SearchSourceBuilder;
import org.elasticsearch.test.ESIntegTestCase;
import org.elasticsearch.xcontent.ConstructingObjectParser;
import org.elasticsearch.xcontent.ParseField;
import org.elasticsearch.test.hamcrest.ElasticsearchAssertions;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentParser;
import org.junit.Before;

import java.io.IOException;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.ExecutionException;

import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg;
import static org.hamcrest.Matchers.equalTo;

public class RetrieverRewriteIT extends ESIntegTestCase {
@Override
protected Collection<Class<? extends Plugin>> nodePlugins() {
return List.of(AssertingPlugin.class, MockSearchService.TestPlugin.class);
return List.of(MockSearchService.TestPlugin.class);
}

private static String INDEX_DOCS = "docs";
Expand All @@ -60,78 +56,35 @@ public void setup() throws Exception {
refresh(INDEX_QUERIES);
}

public void testRewrite() throws ExecutionException, InterruptedException {
public void testRewrite() {
SearchSourceBuilder source = new SearchSourceBuilder();
StandardRetrieverBuilder standard = new StandardRetrieverBuilder();
standard.queryBuilder = QueryBuilders.termQuery(ID_FIELD, "doc_0");
source.retriever(new AssertingRetrieverBuilder(standard));
SearchRequest req = new SearchRequest(INDEX_DOCS, INDEX_QUERIES).source(source);
SearchResponse resp = client().search(req).get();
try {
SearchRequestBuilder req = client().prepareSearch(INDEX_DOCS, INDEX_QUERIES).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertThat(resp.getHits().getTotalHits().value, equalTo(1L));
assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_0"));
} finally {
resp.decRef();
}
});
}

public void testRewriteCompound() throws ExecutionException, InterruptedException {
public void testRewriteCompound() {
SearchSourceBuilder source = new SearchSourceBuilder();
source.retriever(new AssertingCompoundRetrieverBuilder("query_0"));
SearchRequest req = new SearchRequest(INDEX_DOCS, INDEX_QUERIES).source(source);
SearchResponse resp = client().search(req).get();
try {
SearchRequestBuilder req = client().prepareSearch(INDEX_DOCS, INDEX_QUERIES).setSource(source);
ElasticsearchAssertions.assertResponse(req, resp -> {
assertNull(resp.pointInTimeId());
assertThat(resp.getHits().getTotalHits().value, equalTo(1L));
assertThat(resp.getHits().getTotalHits().relation, equalTo(TotalHits.Relation.EQUAL_TO));
assertThat(resp.getHits().getAt(0).getId(), equalTo("doc_2"));
} finally {
resp.decRef();
}
}

public static class AssertingPlugin extends Plugin implements SearchPlugin {
public AssertingPlugin() {}

@Override
public List<RetrieverSpec<?>> getRetrievers() {
return List.of(
new RetrieverSpec<RetrieverBuilder>(AssertingRetrieverBuilder.NAME, AssertingRetrieverBuilder::fromXContent),
new RetrieverSpec<RetrieverBuilder>(AssertingCompoundRetrieverBuilder.NAME, AssertingCompoundRetrieverBuilder::fromXContent)
);
}
}

public static final ConstructingObjectParser<AssertingRetrieverBuilder, RetrieverParserContext> PARSER = new ConstructingObjectParser<>(
AssertingCompoundRetrieverBuilder.NAME,
args -> new AssertingRetrieverBuilder((RetrieverBuilder) args[0])
);

public static final ConstructingObjectParser<AssertingCompoundRetrieverBuilder, RetrieverParserContext> PARSER_COMPOUND =
new ConstructingObjectParser<>(
AssertingCompoundRetrieverBuilder.NAME,
args -> new AssertingCompoundRetrieverBuilder((String) args[0])
);

static {
RetrieverBuilder.declareBaseParserFields(AssertingRetrieverBuilder.NAME, PARSER);
PARSER.declareObject(constructorArg(), RetrieverBuilder::parseInnerRetrieverBuilder, new ParseField("retriever"));

RetrieverBuilder.declareBaseParserFields(AssertingCompoundRetrieverBuilder.NAME, PARSER_COMPOUND);
PARSER_COMPOUND.declareString(constructorArg(), new ParseField("id"));
});
}

private static class AssertingRetrieverBuilder extends RetrieverBuilder {
static final String NAME = "asserting";

private final RetrieverBuilder innerRetriever;

public static AssertingRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context) throws IOException {
return PARSER.apply(parser, context);
}

private AssertingRetrieverBuilder(RetrieverBuilder innerRetriever) {
this.innerRetriever = innerRetriever;
}
Expand Down Expand Up @@ -177,16 +130,9 @@ protected int doHashCode() {
}

private static class AssertingCompoundRetrieverBuilder extends RetrieverBuilder {
static final String NAME = "asserting_compound";

private final String id;
private final SetOnce<RetrieverBuilder> innerRetriever;

public static AssertingCompoundRetrieverBuilder fromXContent(XContentParser parser, RetrieverParserContext context)
throws IOException {
return PARSER_COMPOUND.apply(parser, context);
}

private AssertingCompoundRetrieverBuilder(String id) {
this.id = id;
this.innerRetriever = new SetOnce<>(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,19 +490,25 @@ void executeRequest(
});
final SearchSourceBuilder source = original.source();
if (shouldOpenPIT(source)) {
openPIT(original, searchService.getDefaultKeepAliveInMillis(), listener.delegateFailureAndWrap((delegate, resp) -> {
openPIT(client, original, searchService.getDefaultKeepAliveInMillis(), listener.delegateFailureAndWrap((delegate, resp) -> {
// We set the keep alive to -1 to indicate that we don't need the pit id in the response.
// This is needed since we delete the pit prior to sending the response so the id doesn't exist anymore.
source.pointInTimeBuilder(new PointInTimeBuilder(resp.getPointInTimeId()).setKeepAlive(TimeValue.MINUS_ONE));
executeRequest(task, original, new ActionListener<>() {
@Override
public void onResponse(SearchResponse response) {
// we need to close the PIT first so we delay the release of the response to after the closing
response.incRef();
closePIT(original.source().pointInTimeBuilder(), () -> ActionListener.respondAndRelease(listener, response));
closePIT(
client,
original.source().pointInTimeBuilder(),
() -> ActionListener.respondAndRelease(listener, response)
);
}

@Override
public void onFailure(Exception e) {
closePIT(original.source().pointInTimeBuilder(), () -> listener.onFailure(e));
closePIT(client, original.source().pointInTimeBuilder(), () -> listener.onFailure(e));
}
}, searchPhaseProvider);
}));
Expand All @@ -529,15 +535,15 @@ private boolean shouldOpenPIT(SearchSourceBuilder source) {
return retriever != null && retriever.isCompound();
}

private void openPIT(SearchRequest request, long keepAliveMillis, ActionListener<OpenPointInTimeResponse> listener) {
static void openPIT(Client client, SearchRequest request, long keepAliveMillis, ActionListener<OpenPointInTimeResponse> listener) {
OpenPointInTimeRequest pitReq = new OpenPointInTimeRequest(request.indices()).indicesOptions(request.indicesOptions())
.preference(request.preference())
.routing(request.routing())
.keepAlive(TimeValue.timeValueMillis(keepAliveMillis));
client.execute(TransportOpenPointInTimeAction.TYPE, pitReq, listener);
}

private void closePIT(PointInTimeBuilder pit, Runnable next) {
static void closePIT(Client client, PointInTimeBuilder pit, Runnable next) {
client.execute(
TransportClosePointInTimeAction.TYPE,
new ClosePointInTimeRequest(pit.getEncodedId()),
Expand Down

0 comments on commit a370cac

Please sign in to comment.