Skip to content

Commit

Permalink
Stop instantiating RankFeaturePhase unnecessarily (elastic#115724)
Browse files Browse the repository at this point in the history
We should not create the phase instance when we know up-front that we won't be doing any rank feature
execution. An instance of this phase isn't free: it entails creating an array of searched_shard_count size,
which alone is non-trivial. Also, this needlessly obscured the threading logic for fetch, which has already led to a bug
before.
  • Loading branch information
original-brownbear authored Oct 29, 2024
1 parent 2522c98 commit d0f71fc
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 134 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,17 @@ public class RankFeaturePhase extends SearchPhase {
final SearchPhaseResults<SearchPhaseResult> rankPhaseResults;
private final AggregatedDfs aggregatedDfs;
private final SearchProgressListener progressListener;
private final Client client;
private final RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext;

RankFeaturePhase(
SearchPhaseResults<SearchPhaseResult> queryPhaseResults,
AggregatedDfs aggregatedDfs,
SearchPhaseContext context,
Client client
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext
) {
super("rank-feature");
assert rankFeaturePhaseRankCoordinatorContext != null;
this.rankFeaturePhaseRankCoordinatorContext = rankFeaturePhaseRankCoordinatorContext;
if (context.getNumShards() != queryPhaseResults.getNumShards()) {
throw new IllegalStateException(
"number of shards must match the length of the query results but doesn't:"
Expand All @@ -65,17 +67,10 @@ public class RankFeaturePhase extends SearchPhase {
this.rankPhaseResults = new ArraySearchPhaseResults<>(context.getNumShards());
context.addReleasable(rankPhaseResults);
this.progressListener = context.getTask().getProgressListener();
this.client = client;
}

@Override
public void run() {
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext = coordinatorContext(context.getRequest().source());
if (rankFeaturePhaseRankCoordinatorContext == null) {
moveToNextPhase(queryPhaseResults, null);
return;
}

context.execute(new AbstractRunnable() {
@Override
protected void doRun() throws Exception {
Expand Down Expand Up @@ -122,7 +117,7 @@ void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordin
}
}

private RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source) {
static RankFeaturePhaseRankCoordinatorContext coordinatorContext(SearchSourceBuilder source, Client client) {
return source == null || source.rankBuilder() == null
? null
: source.rankBuilder().buildRankFeaturePhaseCoordinatorContext(source.size(), source.from(), client);
Expand Down Expand Up @@ -175,7 +170,6 @@ private void onPhaseDone(
RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext,
SearchPhaseController.ReducedQueryPhase reducedQueryPhase
) {
assert rankFeaturePhaseRankCoordinatorContext != null;
ThreadedActionListener<RankFeatureDoc[]> rankResultListener = new ThreadedActionListener<>(context, new ActionListener<>() {
@Override
public void onResponse(RankFeatureDoc[] docsWithUpdatedScores) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ protected SearchPhase getNextPhase(final SearchPhaseResults<DfsSearchResult> res
aggregatedDfs,
mergedKnnResults,
queryPhaseResultConsumer,
(queryResults) -> new RankFeaturePhase(queryResults, aggregatedDfs, context, client),
(queryResults) -> SearchQueryThenFetchAsyncAction.nextPhase(client, context, queryResults, aggregatedDfs),
context
);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.search.SearchPhaseResult;
import org.elasticsearch.search.SearchShardTarget;
import org.elasticsearch.search.dfs.AggregatedDfs;
import org.elasticsearch.search.internal.AliasFilter;
import org.elasticsearch.search.internal.SearchContext;
import org.elasticsearch.search.internal.ShardSearchRequest;
Expand Down Expand Up @@ -125,9 +126,22 @@ && getRequest().scroll() == null
super.onShardResult(result, shardIt);
}

/**
 * Chooses the phase that runs after the query phase.
 * <p>
 * The rank feature coordinator context is resolved from the request source first. A
 * {@link RankFeaturePhase} is only instantiated when such a context exists; otherwise a
 * {@link FetchSearchPhase} is returned directly.
 *
 * @param client        client passed on to {@code RankFeaturePhase.coordinatorContext}
 * @param context       search phase context supplying the search request
 * @param queryResults  results produced by the query phase
 * @param aggregatedDfs aggregated DFS data forwarded to the chosen phase, may be {@code null}
 * @return the phase to execute next
 */
static SearchPhase nextPhase(
    Client client,
    SearchPhaseContext context,
    SearchPhaseResults<SearchPhaseResult> queryResults,
    AggregatedDfs aggregatedDfs
) {
    final var requestSource = context.getRequest().source();
    final var coordinatorContext = RankFeaturePhase.coordinatorContext(requestSource, client);
    // a null coordinator context means there is no rank feature work, so go straight to fetch
    return coordinatorContext != null
        ? new RankFeaturePhase(queryResults, aggregatedDfs, context, coordinatorContext)
        : new FetchSearchPhase(queryResults, aggregatedDfs, context, null);
}

@Override
protected SearchPhase getNextPhase(final SearchPhaseResults<SearchPhaseResult> results, SearchPhaseContext context) {
return new RankFeaturePhase(results, null, this, client);
return nextPhase(client, this, results, null);
}

private ShardSearchRequest rewriteShardSearchRequest(ShardSearchRequest request) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -287,95 +287,6 @@ public void sendExecuteRankFeature(
}
}

/**
 * Runs a {@link RankFeaturePhase} over a single shard whose rank builder is created with {@code null}
 * for its last two arguments, and verifies that the phase completes without ever succeeding through
 * {@code sendExecuteRankFeature}: the stubbed transport fails the listener on every path, yet the test
 * expects no recorded failures. The query results are expected to be passed through unchanged (no
 * rank feature result attached) and the final docs to carry the negated scores produced by
 * {@code negatingScoresQueryFeaturePhaseRankCoordinatorContext}.
 */
public void testRankFeaturePhaseNoNeedForFetchingFieldData() {
AtomicBoolean phaseDone = new AtomicBoolean(false);
final ScoreDoc[][] finalResults = new ScoreDoc[1][1];

// build the appropriate RankBuilder: a default query-phase shard context, a query-phase coordinator
// context that negates scores, and null for the two trailing arguments
// NOTE(review): the trailing nulls are presumably the rank-feature coordinator and shard contexts - confirm
// against the rankBuilder(...) helper
RankBuilder rankBuilder = rankBuilder(
DEFAULT_RANK_WINDOW_SIZE,
defaultQueryPhaseRankShardContext(Collections.emptyList(), DEFAULT_RANK_WINDOW_SIZE),
negatingScoresQueryFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE),
null,
null
);
// create a SearchSource to attach to the request
SearchSourceBuilder searchSourceBuilder = searchSourceWithRankBuilder(rankBuilder);

SearchPhaseController controller = searchPhaseController();
SearchShardTarget shard1Target = new SearchShardTarget("node0", new ShardId("test", "na", 0), null);

// single-shard search context
MockSearchPhaseContext mockSearchPhaseContext = new MockSearchPhaseContext(1);
mockSearchPhaseContext.getRequest().source(searchSourceBuilder);
try (SearchPhaseResults<SearchPhaseResult> results = searchPhaseResults(controller, mockSearchPhaseContext)) {
// generate the QuerySearchResults that the RankFeaturePhase would have received from QueryPhase
// here we have 2 results, with doc ids 1 and 2
final ShardSearchContextId ctx = new ShardSearchContextId(UUIDs.base64UUID(), 123);
QuerySearchResult queryResult = new QuerySearchResult(ctx, shard1Target, null);

try {
queryResult.setShardIndex(shard1Target.getShardId().getId());
int totalHits = randomIntBetween(2, 100);
final ScoreDoc[] shard1Docs = new ScoreDoc[] { new ScoreDoc(1, 10.0F), new ScoreDoc(2, 9.0F) };
populateQuerySearchResult(queryResult, totalHits, shard1Docs);
results.consumeResult(queryResult, () -> {});
// do not make an actual http request; this stub fails the listener on every path because
// the test expects sendExecuteRankFeature never to be needed
mockSearchPhaseContext.searchTransport = new SearchTransportService(null, null, null) {
@Override
public void sendExecuteRankFeature(
Transport.Connection connection,
final RankFeatureShardRequest request,
SearchTask task,
final ActionListener<RankFeatureResult> listener
) {
// both branches report a failure; matching the context id and doc ids generated above
// only selects which exception type is reported
if (request.contextId().getId() == 123 && Arrays.equals(request.getDocIds(), new int[] { 1, 2 })) {
listener.onFailure(new UnsupportedOperationException("should not have reached here"));
} else {
listener.onFailure(new MockDirectoryWrapper.FakeIOException());
}
}
};
} finally {
// drop this test's own reference to the query result once it has been handed to the results holder
queryResult.decRef();
}
// override the RankFeaturePhase to skip moving to next phase
RankFeaturePhase rankFeaturePhase = rankFeaturePhase(results, mockSearchPhaseContext, finalResults, phaseDone);
try {
rankFeaturePhase.run();
mockSearchPhaseContext.assertNoFailure();
assertTrue(mockSearchPhaseContext.failures.isEmpty());
assertTrue(phaseDone.get());

// in this case there was no additional "RankFeature" results on shards, so we shortcut directly to queryPhaseResults
SearchPhaseResults<SearchPhaseResult> rankPhaseResults = rankFeaturePhase.queryPhaseResults;
assertNotNull(rankPhaseResults.getAtomicArray());
assertEquals(1, rankPhaseResults.getAtomicArray().length());
assertEquals(1, rankPhaseResults.getSuccessfulResults().count());

// the only shard entry must still be the plain query result, with no rank feature result attached
SearchPhaseResult shardResult = rankPhaseResults.getAtomicArray().get(0);
assertTrue(shardResult instanceof QuerySearchResult);
QuerySearchResult rankResult = (QuerySearchResult) shardResult;
assertNull(rankResult.rankFeatureResult());
assertNotNull(rankResult.queryResult());

// scores 10 and 9 are negated by the coordinator context, so doc 2 (-9) is expected ahead of doc 1 (-10)
// NOTE(review): ExpectedRankFeatureDoc arguments look like (doc, rank, score, featureData) - confirm
List<ExpectedRankFeatureDoc> expectedFinalResults = List.of(
new ExpectedRankFeatureDoc(2, 1, -9.0F, null),
new ExpectedRankFeatureDoc(1, 2, -10.0F, null)
);
assertFinalResults(finalResults[0], expectedFinalResults);
} finally {
rankFeaturePhase.rankPhaseResults.close();
}
} finally {
// release the search response if one was produced during the run
if (mockSearchPhaseContext.searchResponse.get() != null) {
mockSearchPhaseContext.searchResponse.get().decRef();
}
}
}

public void testRankFeaturePhaseOneShardFails() {
AtomicBoolean phaseDone = new AtomicBoolean(false);
final ScoreDoc[][] finalResults = new ScoreDoc[1][1];
Expand Down Expand Up @@ -534,7 +445,12 @@ public void sendExecuteRankFeature(
queryResult.decRef();
}
// override the RankFeaturePhase to raise an exception
RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(results, null, mockSearchPhaseContext, null) {
RankFeaturePhase rankFeaturePhase = new RankFeaturePhase(
results,
null,
mockSearchPhaseContext,
defaultRankFeaturePhaseRankCoordinatorContext(DEFAULT_SIZE, DEFAULT_FROM, DEFAULT_RANK_WINDOW_SIZE)
) {
@Override
void innerRun(RankFeaturePhaseRankCoordinatorContext rankFeaturePhaseRankCoordinatorContext) {
throw new IllegalArgumentException("simulated failure");
Expand Down Expand Up @@ -890,36 +806,6 @@ public RankFeatureDoc[] rankAndPaginate(RankFeatureDoc[] rankFeatureDocs) {
};
}

/**
 * Creates a query-phase coordinator context for tests that flips the sign of every score before ranking.
 * <p>
 * The returned context collects the top docs of all shard results, negates each score in place, orders
 * the docs by descending (negated) score, truncates to {@code rankWindowSize}, and returns the page
 * starting at {@code from} with at most {@code size} entries as {@link RankFeatureDoc}s with 1-based ranks.
 * It also records the page length in {@code topDocsStats.fetchHits}.
 *
 * @param size           maximum number of docs in the returned page
 * @param from           offset of the first doc of the page within the ranked window
 * @param rankWindowSize maximum number of ranked docs considered
 * @return the coordinator context described above
 */
private QueryPhaseRankCoordinatorContext negatingScoresQueryFeaturePhaseRankCoordinatorContext(int size, int from, int rankWindowSize) {
    return new QueryPhaseRankCoordinatorContext(rankWindowSize) {
        @Override
        public ScoreDoc[] rankQueryPhaseResults(
            List<QuerySearchResult> rankSearchResults,
            SearchPhaseController.TopDocsStats topDocsStats
        ) {
            // gather the top docs of every shard result into one list
            final List<ScoreDoc> collected = new ArrayList<>();
            for (QuerySearchResult shardResult : rankSearchResults) {
                for (ScoreDoc scoreDoc : shardResult.topDocs().topDocs.scoreDocs) {
                    collected.add(scoreDoc);
                }
            }

            // flip the sign of each score in place
            ScoreDoc[] ranked = collected.toArray(new ScoreDoc[0]);
            for (ScoreDoc scoreDoc : ranked) {
                scoreDoc.score *= -1;
            }

            // highest (negated) score first, then keep only the rank window
            Arrays.sort(ranked, (left, right) -> Float.compare(right.score, left.score));
            ranked = Arrays.stream(ranked).limit(rankWindowSize).toArray(ScoreDoc[]::new);

            // cut out the requested page and assign 1-based ranks
            final int pageLength = Math.max(0, Math.min(size, ranked.length - from));
            final RankFeatureDoc[] page = new RankFeatureDoc[pageLength];
            for (int i = 0; i < pageLength; i++) {
                final int position = from + i;
                final ScoreDoc hit = ranked[position];
                final RankFeatureDoc featureDoc = new RankFeatureDoc(hit.doc, hit.score, hit.shardIndex);
                featureDoc.rank = position + 1;
                page[i] = featureDoc;
            }
            topDocsStats.fetchHits = page.length;
            return page;
        }
    };
}

private RankFeaturePhaseRankShardContext defaultRankFeaturePhaseRankShardContext(String field) {
return new RankFeaturePhaseRankShardContext(field) {
@Override
Expand Down Expand Up @@ -1134,7 +1020,12 @@ private RankFeaturePhase rankFeaturePhase(
AtomicBoolean phaseDone
) {
// override the RankFeaturePhase to skip moving to next phase
return new RankFeaturePhase(results, null, mockSearchPhaseContext, null) {
return new RankFeaturePhase(
results,
null,
mockSearchPhaseContext,
RankFeaturePhase.coordinatorContext(mockSearchPhaseContext.getRequest().source(), null)
) {
@Override
public void moveToNextPhase(
SearchPhaseResults<SearchPhaseResult> phaseResults,
Expand Down

0 comments on commit d0f71fc

Please sign in to comment.