Skip to content

Commit

Permalink
Adding aggregations to hybrid query
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Jan 6, 2024
1 parent ff38622 commit cb1f929
Show file tree
Hide file tree
Showing 6 changed files with 321 additions and 14 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ protected boolean isThresholdReached() {
return hitCount >= getTotalHitsThreshold();
}

protected ScoreMode scoreMode() {
/**
 * Score mode used by this collector. TOP_SCORES tells Lucene that only the
 * top-scoring hits are needed, which permits skipping of non-competitive
 * documents during collection.
 */
public ScoreMode scoreMode() {
return ScoreMode.TOP_SCORES;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

Expand All @@ -26,6 +27,7 @@

import lombok.Getter;
import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.search.HitsThresholdChecker;

/**
* Collects the TopDocs after executing hybrid query. Uses HybridQueryTopDocs as DTO to handle each sub query results
Expand Down Expand Up @@ -56,11 +58,38 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOExcept
/**
 * Captures the {@link HybridQueryScorer} for this leaf so that per-sub-query
 * scores can be read in {@code collect}. When the hybrid query is wrapped by
 * other queries (e.g. when aggregations add extra collectors/weights), the
 * scorer passed in is not the hybrid scorer itself, so we search its child
 * scorers for it.
 *
 * @param scorer the scorer for the current leaf, possibly a wrapper
 * @throws IOException if reading child scorers fails
 */
@Override
public void setScorer(Scorable scorer) throws IOException {
    super.setScorer(scorer);
    // getHybridQueryScorer already returns the scorer itself when it is a
    // HybridQueryScorer, so no separate instanceof branch is needed here.
    compoundQueryScorer = getHybridQueryScorer(scorer);
}

@Override
/**
 * Depth-first search through the scorer tree for the first
 * {@link HybridQueryScorer}, or {@code null} if none is present.
 *
 * @param scorer root of the scorer tree to inspect; may be {@code null}
 * @return the first hybrid scorer found in pre-order, or {@code null}
 * @throws IOException if retrieving child scorers fails
 */
private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOException {
    if (scorer == null) {
        return null;
    }
    if (scorer instanceof HybridQueryScorer) {
        return (HybridQueryScorer) scorer;
    }
    HybridQueryScorer found = null;
    for (Scorable.ChildScorable childScorable : scorer.getChildren()) {
        found = getHybridQueryScorer(childScorable.child);
        if (found != null) {
            break;
        }
    }
    return found;
}



@Override
public void collect(int doc) throws IOException {
if (compoundQueryScorer == null) {
scorer.score();
return;
}
float[] subScoresByQuery = compoundQueryScorer.hybridScores();
// iterate over results for each query
if (compoundScores == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
import java.util.List;
import java.util.Objects;

import com.google.common.base.Throwables;
import org.apache.lucene.index.IndexReader;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Collector;
import org.apache.lucene.search.FieldExistsQuery;
import org.apache.lucene.search.MultiCollector;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
Expand Down Expand Up @@ -210,11 +213,17 @@ protected boolean searchWithCollector(

final QuerySearchResult queryResult = searchContext.queryResult();

final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector(
Collector collector = new HybridTopScoreDocCollector(
numDocs,
new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo()))
);

// cannot use streams here as reassignment of the collector variable inside the lambda is not possible
for (int idx = 1; idx < collectors.size(); idx++) {
QueryCollectorContext collectorContext = collectors.get(idx);
collector = collectorContext.create(collector);
}

searcher.search(query, collector);

if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) {
Expand All @@ -223,20 +232,35 @@ protected boolean searchWithCollector(

setTopDocsInQueryResult(queryResult, collector, searchContext);

collectors.stream().skip(1).forEach(ctx -> {
try {
ctx.postProcess(queryResult);
} catch (IOException e) {
Throwables.throwIfUnchecked(e);
}
});

return shouldRescore;
}

/**
 * Writes the hybrid query's top docs into the query result. The collector may
 * be the hybrid collector itself or a {@link MultiCollector} wrapping it (when
 * aggregation collectors were chained on top), in which case the wrapper is
 * unwrapped recursively. Collectors of any other type are ignored.
 *
 * @param queryResult   destination for the top docs and max score
 * @param collector     collector (possibly wrapped) that gathered the hits
 * @param searchContext current search context, used for shard count and sort
 */
private void setTopDocsInQueryResult(
    final QuerySearchResult queryResult,
    final Collector collector,
    final SearchContext searchContext
) {
    if (collector instanceof MultiCollector) {
        // unwrap and process each inner collector; exactly one of them is the
        // hybrid collector, the rest (e.g. aggregations) are skipped below
        for (Collector innerCollector : ((MultiCollector) collector).getCollectors()) {
            setTopDocsInQueryResult(queryResult, innerCollector, searchContext);
        }
        return;
    }
    if (!(collector instanceof HybridTopScoreDocCollector)) {
        return;
    }
    final HybridTopScoreDocCollector hybridCollector = (HybridTopScoreDocCollector) collector;
    final List<TopDocs> topDocs = hybridCollector.topDocs();
    final float maxScore = getMaxScore(topDocs);
    final boolean isSingleShard = searchContext.numberOfShards() == 1;
    final TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs);
    final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
    queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort()));
}

private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
Expand Down
119 changes: 119 additions & 0 deletions src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
import com.google.common.primitives.Floats;

import lombok.SneakyThrows;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregationBuilders;

public class HybridQueryIT extends BaseNeuralSearchIT {
private static final String TEST_BASIC_INDEX_NAME = "test-neural-basic-index";
Expand All @@ -44,6 +46,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index";
private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD =
"test-neural-multi-doc-nested-type--single-shard-index";
private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT =
"test-neural-multi-doc-text-and-int-index";
private static final String TEST_QUERY_TEXT = "greetings";
private static final String TEST_QUERY_TEXT2 = "salute";
private static final String TEST_QUERY_TEXT3 = "hello";
Expand All @@ -60,6 +64,9 @@ public class HybridQueryIT extends BaseNeuralSearchIT {
private static final String NESTED_FIELD_2 = "lastname";
private static final String NESTED_FIELD_1_VALUE = "john";
private static final String NESTED_FIELD_2_VALUE = "black";
private static final String INTEGER_FIELD_1 = "doc_index";
private static final int INTEGER_FIELD_1_VALUE = 1234;
private static final int INTEGER_FIELD_2_VALUE = 2345;
private final float[] testVector1 = createRandomVector(TEST_DIMENSION);
private final float[] testVector2 = createRandomVector(TEST_DIMENSION);
private final float[] testVector3 = createRandomVector(TEST_DIMENSION);
Expand Down Expand Up @@ -378,6 +385,78 @@ public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess(
assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

/**
* Tests complex query with multiple nested sub-queries:
* {
* "query": {
* "hybrid": {
* "queries": [
* {
* "term": {
* "text": "word1"
* }
* },
* {
* "term": {
* "text": "word3"
* }
* }
* ]
* }
* },
* "aggs": {
* "max_index": {
* "max": {
* "field": "doc_index"
* }
* }
* }
* }
*/
@SneakyThrows
public void testAggregations_whenMetricAggregationsInQuery_thenSuccessful() {
    initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT);

    // hybrid query made of two term sub-queries over the same text field
    TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
    TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5);

    HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
    hybridQueryBuilder.add(termQueryBuilder1);
    hybridQueryBuilder.add(termQueryBuilder2);

    // max metric aggregation on the integer field, executed alongside the hybrid query
    AggregationBuilder aggsBuilder = AggregationBuilders.max("max_aggs").field(INTEGER_FIELD_1);
    Map<String, Object> searchResponseAsMap = search(
        TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT,
        hybridQueryBuilder,
        null,
        10,
        Map.of("search_pipeline", SEARCH_PIPELINE),
        aggsBuilder
    );

    assertEquals(1, getHitCount(searchResponseAsMap));

    List<Map<String, Object>> hitsNestedList = getNestedHits(searchResponseAsMap);
    List<String> ids = new ArrayList<>();
    List<Double> scores = new ArrayList<>();
    for (Map<String, Object> oneHit : hitsNestedList) {
        ids.add((String) oneHit.get("_id"));
        scores.add((Double) oneHit.get("_score"));
    }

    // verify that scores are in desc order
    assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
    // verify that all ids are unique
    assertEquals(Set.copyOf(ids).size(), ids.size());

    Map<String, Object> total = getTotalHits(searchResponseAsMap);
    assertNotNull(total.get("value"));
    assertEquals(1, total.get("value"));
    assertNotNull(total.get("relation"));
    assertEquals(RELATION_EQUAL_TO, total.get("relation"));
}

@SneakyThrows
private void initializeIndexIfNotExist(String indexName) throws IOException {
if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) {
Expand Down Expand Up @@ -469,6 +548,46 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
List.of(Map.of(NESTED_FIELD_1, NESTED_FIELD_1_VALUE, NESTED_FIELD_2, NESTED_FIELD_2_VALUE))
);
}

if (TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT.equals(indexName)
&& !indexExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT)) {
createIndexWithConfiguration(
indexName,
buildIndexConfiguration(
List.of(),
List.of(),
List.of(INTEGER_FIELD_1),
1
),
""
);

addKnnDoc(
TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT,
"1",
List.of(),
List.of(),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT1),
List.of(),
List.of(),
List.of(INTEGER_FIELD_1),
List.of(INTEGER_FIELD_1_VALUE)
);

addKnnDoc(
TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT,
"2",
List.of(),
List.of(),
Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
Collections.singletonList(TEST_DOC_TEXT3),
List.of(),
List.of(),
List.of(INTEGER_FIELD_1),
List.of(INTEGER_FIELD_2_VALUE)
);
}
}

private void addDocsToIndex(final String testMultiDocIndexName) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,13 @@
import org.opensearch.search.SearchShardTarget;
import org.opensearch.search.internal.ContextIndexSearcher;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;

import com.carrotsearch.randomizedtesting.RandomizedTest;

import lombok.SneakyThrows;
import org.opensearch.search.query.QueryCollectorContext;
import org.opensearch.search.query.QuerySearchResult;

public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase {
private static final String VECTOR_FIELD_NAME = "vectorField";
Expand Down Expand Up @@ -831,6 +832,82 @@ public void testBoolQuery_whenTooManyNestedLevels_thenSuccess() {
releaseResources(directory, w, reader);
}

/**
 * Verifies that executing a hybrid query through the query phase searcher
 * delegates to {@code searchWithCollector} when collector contexts (as used
 * for aggregations) may be present.
 */
@SneakyThrows
public void testAggregations_whenMetricAggregation_thenSuccessful() {
    HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher());
    QueryShardContext mockQueryShardContext = mock(QueryShardContext.class);
    KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class);
    when(mockQueryShardContext.index()).thenReturn(dummyIndex);
    when(mockKNNVectorField.getDimension()).thenReturn(4);
    when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField);
    MapperService mapperService = createMapperService();
    TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME);
    when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType);

    // index three documents into a fresh in-memory directory
    Directory directory = newDirectory();
    IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random())));
    FieldType ft = new FieldType(TextField.TYPE_NOT_STORED);
    ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS);
    ft.setOmitNorms(random().nextBoolean());
    ft.freeze();

    w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT1, ft));
    w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT2, ft));
    w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT3, ft));
    w.commit();

    IndexReader reader = DirectoryReader.open(w);
    SearchContext searchContext = mock(SearchContext.class);

    ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher(
        reader,
        IndexSearcher.getDefaultSimilarity(),
        IndexSearcher.getDefaultQueryCache(),
        IndexSearcher.getDefaultQueryCachingPolicy(),
        true,
        null,
        searchContext
    );

    ShardId shardId = new ShardId(dummyIndex, 1);
    SearchShardTarget shardTarget = new SearchShardTarget(
        randomAlphaOfLength(10),
        shardId,
        randomAlphaOfLength(10),
        OriginalIndices.NONE
    );
    when(searchContext.shardTarget()).thenReturn(shardTarget);
    // original stubbed searcher() twice; once is sufficient
    when(searchContext.searcher()).thenReturn(contextIndexSearcher);
    when(searchContext.numberOfShards()).thenReturn(1);
    IndexShard indexShard = mock(IndexShard.class);
    when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0));
    when(searchContext.indexShard()).thenReturn(indexShard);
    when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR);
    when(searchContext.mapperService()).thenReturn(mapperService);

    LinkedList<QueryCollectorContext> collectors = new LinkedList<>();

    boolean hasFilterCollector = randomBoolean();
    boolean hasTimeout = randomBoolean();

    // hybrid query with a single term sub-query
    HybridQueryBuilder queryBuilder = new HybridQueryBuilder();
    TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1);
    queryBuilder.add(termSubQuery);

    Query query = queryBuilder.toQuery(mockQueryShardContext);
    when(searchContext.query()).thenReturn(query);
    QuerySearchResult querySearchResult = new QuerySearchResult();
    when(searchContext.queryResult()).thenReturn(querySearchResult);

    hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout);

    releaseResources(directory, w, reader);

    verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean());
}

@SneakyThrows
private void assertQueryResults(TopDocs subQueryTopDocs, List<Integer> expectedDocIds, IndexReader reader) {
assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value);
Expand Down
Loading

0 comments on commit cb1f929

Please sign in to comment.