Skip to content

Commit

Permalink
Fix the rewrite method for MatchOnlyText field query (#14248)
Browse files Browse the repository at this point in the history
Signed-off-by: Rishabh Maurya <[email protected]>
  • Loading branch information
rishabhmaurya authored Jun 13, 2024
1 parent d25b64d commit 679ccac
Show file tree
Hide file tree
Showing 2 changed files with 97 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ public void visit(QueryVisitor visitor) {

@Override
public Query rewrite(IndexSearcher indexSearcher) throws IOException {
Query rewritten = indexSearcher.rewrite(delegateQuery);
Query rewritten = delegateQuery.rewrite(indexSearcher);
if (rewritten == delegateQuery) {
return this;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.MatchAllDocsQuery;
import org.apache.lucene.search.MatchNoDocsQuery;
import org.apache.lucene.search.PhraseQuery;
import org.apache.lucene.search.Pruning;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.Sort;
Expand All @@ -48,12 +49,18 @@
import org.opensearch.action.search.SearchShardTask;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.xcontent.json.JsonXContent;
import org.opensearch.core.index.Index;
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.index.mapper.DocumentMapper;
import org.opensearch.index.mapper.MatchOnlyTextFieldMapper;
import org.opensearch.index.mapper.NumberFieldMapper.NumberFieldType;
import org.opensearch.index.mapper.NumberFieldMapper.NumberType;
import org.opensearch.index.mapper.SourceFieldMapper;
import org.opensearch.index.mapper.TextSearchInfo;
import org.opensearch.index.query.ParsedQuery;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.index.query.SourceFieldMatchQuery;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.IndexShardTestCase;
import org.opensearch.lucene.queries.MinDocQuery;
Expand All @@ -62,6 +69,9 @@
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.ScrollContext;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.lookup.LeafSearchLookup;
import org.opensearch.search.lookup.SearchLookup;
import org.opensearch.search.lookup.SourceLookup;
import org.opensearch.search.profile.ProfileResult;
import org.opensearch.search.profile.ProfileShardResult;
import org.opensearch.search.profile.SearchProfileShardResults;
Expand All @@ -80,6 +90,7 @@
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
Expand All @@ -94,6 +105,7 @@
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

Expand Down Expand Up @@ -1526,6 +1538,90 @@ public void testCollapseQuerySearchResults() throws Exception {
dir.close();
}

public void testSourceFieldMatchQueryWithProfile() throws Exception {
Directory dir = newDirectory();
IndexWriterConfig iwc = newIndexWriterConfig();
RandomIndexWriter w = new RandomIndexWriter(random(), dir, iwc);
w.close();
IndexReader reader = DirectoryReader.open(dir);
QueryShardContext queryShardContext = mock(QueryShardContext.class);
DocumentMapper mockDocumentMapper = mock(DocumentMapper.class);
SourceFieldMapper mockSourceMapper = mock(SourceFieldMapper.class);
SearchLookup searchLookup = mock(SearchLookup.class);
LeafSearchLookup leafSearchLookup = mock(LeafSearchLookup.class);

when(queryShardContext.sourcePath("foo")).thenReturn(Set.of("bar"));
when(queryShardContext.index()).thenReturn(new Index("test_index", "uuid"));
when(searchLookup.getLeafSearchLookup(any())).thenReturn(leafSearchLookup);
when(leafSearchLookup.source()).thenReturn(new SourceLookup());
when(mockSourceMapper.enabled()).thenReturn(true);
when(mockDocumentMapper.sourceMapper()).thenReturn(mockSourceMapper);
when(queryShardContext.documentMapper(any())).thenReturn(mockDocumentMapper);
when(queryShardContext.lookup()).thenReturn(searchLookup);

TestSearchContext context = new TestSearchContext(queryShardContext, indexShard, newContextSearcher(reader, executor));
context.parsedQuery(
new ParsedQuery(
new SourceFieldMatchQuery(
new TermQuery(new Term("foo", "bar")),
new PhraseQuery("foo", "bar", "baz"),
new MatchOnlyTextFieldMapper.MatchOnlyTextFieldType(
"user",
true,
true,
TextSearchInfo.WHITESPACE_MATCH_ONLY,
Collections.emptyMap()
),
queryShardContext
)
)
);

context.setTask(new SearchShardTask(123L, "", "", "", null, Collections.emptyMap()));
context.setSize(1);
context.trackTotalHitsUpTo(5);
QueryPhase.executeInternal(context.withCleanQueryResult().withProfilers(), queryPhaseSearcher);
assertProfileData(context, "SourceFieldMatchQuery", query -> {
assertThat(query.getTimeBreakdown().keySet(), not(empty()));
assertThat(query.getTimeBreakdown().get("score"), equalTo(0L));
assertThat(query.getTimeBreakdown().get("score_count"), equalTo(0L));
if (executor != null) {
long maxScore = query.getTimeBreakdown().get("max_score");
long minScore = query.getTimeBreakdown().get("min_score");
long avgScore = query.getTimeBreakdown().get("avg_score");
long maxScoreCount = query.getTimeBreakdown().get("max_score_count");
long minScoreCount = query.getTimeBreakdown().get("min_score_count");
long avgScoreCount = query.getTimeBreakdown().get("avg_score_count");
assertThat(maxScore, equalTo(0L));
assertThat(minScore, equalTo(0L));
assertThat(avgScore, equalTo(0L));
assertThat(maxScore, equalTo(avgScore));
assertThat(avgScore, equalTo(minScore));
assertThat(maxScoreCount, equalTo(0L));
assertThat(minScoreCount, equalTo(0L));
assertThat(avgScoreCount, equalTo(0L));
assertThat(maxScoreCount, equalTo(avgScoreCount));
assertThat(avgScoreCount, equalTo(minScoreCount));
}
assertThat(query.getTimeBreakdown().get("create_weight"), greaterThan(0L));
assertThat(query.getTimeBreakdown().get("create_weight_count"), equalTo(1L));
assertThat(query.getProfiledChildren(), empty());
}, collector -> {
assertThat(collector.getReason(), equalTo("search_top_hits"));
assertThat(collector.getTime(), greaterThan(0L));
if (collector.getName().contains("CollectorManager")) {
assertThat(collector.getReduceTime(), greaterThan(0L));
}
assertThat(collector.getMaxSliceTime(), greaterThan(0L));
assertThat(collector.getMinSliceTime(), greaterThan(0L));
assertThat(collector.getAvgSliceTime(), greaterThan(0L));
assertThat(collector.getSliceCount(), greaterThanOrEqualTo(1));
assertThat(collector.getProfiledChildren(), empty());
});
reader.close();
dir.close();
}

private void assertProfileData(SearchContext context, String type, Consumer<ProfileResult> query, Consumer<CollectorResult> collector)
throws IOException {
assertProfileData(context, collector, (profileResult) -> {
Expand Down

0 comments on commit 679ccac

Please sign in to comment.