From cb1f929a0b4758dae489c4f0e88c064417e6e945 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 26 Dec 2023 15:05:57 -0800 Subject: [PATCH] Adding aggregations to hybrid query Signed-off-by: Martin Gaievski --- .../search/HitsThresholdChecker.java | 2 +- .../search/HybridTopScoreDocCollector.java | 33 ++++- .../query/HybridQueryPhaseSearcher.java | 40 ++++-- .../neuralsearch/query/HybridQueryIT.java | 119 ++++++++++++++++++ .../query/HybridQueryPhaseSearcherTests.java | 81 +++++++++++- .../neuralsearch/BaseNeuralSearchIT.java | 60 ++++++++- 6 files changed, 321 insertions(+), 14 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java index 1299537bb..2e8b365e2 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HitsThresholdChecker.java @@ -34,7 +34,7 @@ protected boolean isThresholdReached() { return hitCount >= getTotalHitsThreshold(); } - protected ScoreMode scoreMode() { + public ScoreMode scoreMode() { return ScoreMode.TOP_SCORES; } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java index 8b7a12d29..6bc3c7be0 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java +++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Locale; +import java.util.Optional; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -26,6 +27,7 @@ import lombok.Getter; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.search.HitsThresholdChecker; /** * Collects the TopDocs after executing hybrid query. 
Uses HybridQueryTopDocs as DTO to handle each sub query results @@ -56,11 +58,38 @@ public LeafCollector getLeafCollector(LeafReaderContext context) throws IOExcept @Override public void setScorer(Scorable scorer) throws IOException { super.setScorer(scorer); - compoundQueryScorer = (HybridQueryScorer) scorer; + if (scorer instanceof HybridQueryScorer) { + compoundQueryScorer = (HybridQueryScorer) scorer; + } + else { + compoundQueryScorer = getHybridQueryScorer(scorer); + } } - @Override + private HybridQueryScorer getHybridQueryScorer(final Scorable scorer) throws IOException { + if (scorer == null) { + return null; + } + if (scorer instanceof HybridQueryScorer) { + return (HybridQueryScorer) scorer; + } + for (Scorable.ChildScorable childScorable : scorer.getChildren()) { + HybridQueryScorer hybridQueryScorer = getHybridQueryScorer(childScorable.child); + if (hybridQueryScorer != null) { + return hybridQueryScorer; + } + } + return null; + } + + + + @Override public void collect(int doc) throws IOException { + if (compoundQueryScorer == null) { + scorer.score(); + return; + } float[] subScoresByQuery = compoundQueryScorer.hybridScores(); // iterate over results for each query if (compoundScores == null) { diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java index bf05fdc9d..616b1a652 100644 --- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java +++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java @@ -15,10 +15,13 @@ import java.util.List; import java.util.Objects; +import com.google.common.base.Throwables; import org.apache.lucene.index.IndexReader; import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.Collector; import org.apache.lucene.search.FieldExistsQuery; +import 
org.apache.lucene.search.MultiCollector; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -210,11 +213,17 @@ protected boolean searchWithCollector( final QuerySearchResult queryResult = searchContext.queryResult(); - final HybridTopScoreDocCollector collector = new HybridTopScoreDocCollector( + Collector collector = new HybridTopScoreDocCollector( numDocs, new HitsThresholdChecker(Math.max(numDocs, searchContext.trackTotalHitsUpTo())) ); + // cannot use streams here as assignment of global variable inside the lambda will not be possible + for (int idx = 1; idx < collectors.size(); idx++) { + QueryCollectorContext collectorContext = collectors.get(idx); + collector = collectorContext.create(collector); + } + searcher.search(query, collector); if (searchContext.terminateAfter() != SearchContext.DEFAULT_TERMINATE_AFTER && queryResult.terminatedEarly() == null) { @@ -223,20 +232,35 @@ protected boolean searchWithCollector( setTopDocsInQueryResult(queryResult, collector, searchContext); + collectors.stream().skip(1).forEach(ctx -> { + try { + ctx.postProcess(queryResult); + } catch (IOException e) { + Throwables.throwIfUnchecked(e); + } + }); + return shouldRescore; } private void setTopDocsInQueryResult( final QuerySearchResult queryResult, - final HybridTopScoreDocCollector collector, + final Collector collector, final SearchContext searchContext ) { - final List topDocs = collector.topDocs(); - final float maxScore = getMaxScore(topDocs); - final boolean isSingleShard = searchContext.numberOfShards() == 1; - final TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs); - final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); - queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort())); + if (collector instanceof HybridTopScoreDocCollector) { + List topDocs = ((HybridTopScoreDocCollector)
collector).topDocs(); + float maxScore = getMaxScore(topDocs); + boolean isSingleShard = searchContext.numberOfShards() == 1; + TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs); + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore); + queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort())); + } else if (collector instanceof MultiCollector) { + MultiCollector multiCollector = (MultiCollector) collector; + for (Collector subCollector : multiCollector.getCollectors()) { + setTopDocsInQueryResult(queryResult, subCollector, searchContext); + } + } } private TopDocs getNewTopDocs(final TotalHits totalHits, final List topDocs) { diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 7e80d7fda..2f279a4cc 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -36,6 +36,8 @@ import com.google.common.primitives.Floats; import lombok.SneakyThrows; +import org.opensearch.search.aggregations.AggregationBuilder; +import org.opensearch.search.aggregations.AggregationBuilders; public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_BASIC_INDEX_NAME = "test-neural-basic-index"; @@ -44,6 +46,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-neural-multi-doc-single-shard-index"; private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD = "test-neural-multi-doc-nested-type--single-shard-index"; + private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT = + "test-neural-multi-doc-text-and-int-index"; private static final String TEST_QUERY_TEXT = "greetings"; private static final String TEST_QUERY_TEXT2 = "salute"; private static final 
String TEST_QUERY_TEXT3 = "hello"; @@ -60,6 +64,9 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String NESTED_FIELD_2 = "lastname"; private static final String NESTED_FIELD_1_VALUE = "john"; private static final String NESTED_FIELD_2_VALUE = "black"; + private static final String INTEGER_FIELD_1 = "doc_index"; + private static final int INTEGER_FIELD_1_VALUE = 1234; + private static final int INTEGER_FIELD_2_VALUE = 2345; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); @@ -378,6 +385,78 @@ public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess( assertEquals(RELATION_EQUAL_TO, total.get("relation")); } + /** + * Tests complex query with multiple nested sub-queries: + * { + * "query": { + * "hybrid": { + * "queries": [ + * { + * "term": { + * "text": "word1" + * } + * }, + * { + * "term": { + * "text": "word3" + * } + * } + * ] + * } + * }, + * "aggs": { + * "max_index": { + * "max": { + * "field": "doc_index" + * } + * } + * } + * } + */ + @SneakyThrows + public void testAggregations_whenMetricAggregationsInQuery_thenSuccessful() { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder2); + + AggregationBuilder aggsBuilder = AggregationBuilders.max("max_aggs").field(INTEGER_FIELD_1); + //AggregationBuilder aggsBuilder = null; + Map searchResponseAsMap1 = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT, + 
hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE), + aggsBuilder + ); + + assertEquals(1, getHitCount(searchResponseAsMap1)); + + List> hits1NestedList = getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hits1NestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(1, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { @@ -469,6 +548,46 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { List.of(Map.of(NESTED_FIELD_1, NESTED_FIELD_1_VALUE, NESTED_FIELD_2, NESTED_FIELD_2_VALUE)) ); } + + if (TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT.equals(indexName) + && !indexExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + 1 + ), + "" + ); + + addKnnDoc( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT, + "1", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_1_VALUE) + ); + + addKnnDoc( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT, + "2", + List.of(), + List.of(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + 
Collections.singletonList(TEST_DOC_TEXT3), + List.of(), + List.of(), + List.of(INTEGER_FIELD_1), + List.of(INTEGER_FIELD_2_VALUE) + ); + } } private void addDocsToIndex(final String testMultiDocIndexName) { diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java index e609eec05..56bce8962 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java @@ -63,12 +63,13 @@ import org.opensearch.search.SearchShardTarget; import org.opensearch.search.internal.ContextIndexSearcher; import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.query.QueryCollectorContext; -import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import com.carrotsearch.randomizedtesting.RandomizedTest; import lombok.SneakyThrows; +import org.opensearch.search.query.QueryCollectorContext; +import org.opensearch.search.query.QuerySearchResult; public class HybridQueryPhaseSearcherTests extends OpenSearchQueryTestCase { private static final String VECTOR_FIELD_NAME = "vectorField"; @@ -831,6 +832,82 @@ public void testBoolQuery_whenTooManyNestedLevels_thenSuccess() { releaseResources(directory, w, reader); } + @SneakyThrows + public void testAggregations_whenMetricAggregation_thenSuccessful() { + HybridQueryPhaseSearcher hybridQueryPhaseSearcher = spy(new HybridQueryPhaseSearcher()); + QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); + KNNVectorFieldMapper.KNNVectorFieldType mockKNNVectorField = mock(KNNVectorFieldMapper.KNNVectorFieldType.class); + when(mockQueryShardContext.index()).thenReturn(dummyIndex); + when(mockKNNVectorField.getDimension()).thenReturn(4); + 
when(mockQueryShardContext.fieldMapper(eq(VECTOR_FIELD_NAME))).thenReturn(mockKNNVectorField); + MapperService mapperService = createMapperService(); + TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) mapperService.fieldType(TEXT_FIELD_NAME); + when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); + + Directory directory = newDirectory(); + IndexWriter w = new IndexWriter(directory, newIndexWriterConfig(new MockAnalyzer(random()))); + FieldType ft = new FieldType(TextField.TYPE_NOT_STORED); + ft.setIndexOptions(random().nextBoolean() ? IndexOptions.DOCS : IndexOptions.DOCS_AND_FREQS); + ft.setOmitNorms(random().nextBoolean()); + ft.freeze(); + + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT1, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT2, ft)); + w.addDocument(getDocument(TEXT_FIELD_NAME, RandomizedTest.randomInt(), TEST_DOC_TEXT3, ft)); + w.commit(); + + IndexReader reader = DirectoryReader.open(w); + SearchContext searchContext = mock(SearchContext.class); + + ContextIndexSearcher contextIndexSearcher = new ContextIndexSearcher( + reader, + IndexSearcher.getDefaultSimilarity(), + IndexSearcher.getDefaultQueryCache(), + IndexSearcher.getDefaultQueryCachingPolicy(), + true, + null, + searchContext + ); + + ShardId shardId = new ShardId(dummyIndex, 1); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLength(10), + shardId, + randomAlphaOfLength(10), + OriginalIndices.NONE + ); + when(searchContext.shardTarget()).thenReturn(shardTarget); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + when(searchContext.numberOfShards()).thenReturn(1); + when(searchContext.searcher()).thenReturn(contextIndexSearcher); + IndexShard indexShard = mock(IndexShard.class); + when(indexShard.shardId()).thenReturn(new ShardId("test", "test", 0)); + when(searchContext.indexShard()).thenReturn(indexShard); + 
when(searchContext.bucketCollectorProcessor()).thenReturn(SearchContext.NO_OP_BUCKET_COLLECTOR_PROCESSOR); + when(searchContext.mapperService()).thenReturn(mapperService); + + LinkedList collectors = new LinkedList<>(); + + boolean hasFilterCollector = randomBoolean(); + boolean hasTimeout = randomBoolean(); + + HybridQueryBuilder queryBuilder = new HybridQueryBuilder(); + + TermQueryBuilder termSubQuery = QueryBuilders.termQuery(TEXT_FIELD_NAME, QUERY_TEXT1); + queryBuilder.add(termSubQuery); + + Query query = queryBuilder.toQuery(mockQueryShardContext); + when(searchContext.query()).thenReturn(query); + QuerySearchResult querySearchResult = new QuerySearchResult(); + when(searchContext.queryResult()).thenReturn(querySearchResult); + + hybridQueryPhaseSearcher.searchWith(searchContext, contextIndexSearcher, query, collectors, hasFilterCollector, hasTimeout); + + releaseResources(directory, w, reader); + + verify(hybridQueryPhaseSearcher, atLeastOnce()).searchWithCollector(any(), any(), any(), any(), anyBoolean(), anyBoolean()); + } + @SneakyThrows private void assertQueryResults(TopDocs subQueryTopDocs, List expectedDocIds, IndexReader reader) { assertEquals(expectedDocIds.size(), subQueryTopDocs.totalHits.value); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 680d90b65..f8d274a15 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch; import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray; +import static org.opensearch.search.aggregations.Aggregations.AGGREGATIONS_FIELD; import java.io.IOException; import java.nio.file.Files; @@ -46,6 +47,7 @@ import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; import 
org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; +import org.opensearch.search.aggregations.AggregationBuilder; import org.opensearch.test.ClusterServiceUtils; import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; @@ -374,6 +376,18 @@ protected Map search( QueryBuilder rescorer, int resultSize, Map requestParams + ) { + return search(index, queryBuilder, rescorer, resultSize, requestParams, null); + } + + @SneakyThrows + protected Map search( + String index, + QueryBuilder queryBuilder, + QueryBuilder rescorer, + int resultSize, + Map requestParams, + AggregationBuilder aggsBuilder ) { XContentBuilder builder = XContentFactory.jsonBuilder().startObject().field("query"); queryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS); @@ -383,6 +397,16 @@ protected Map search( rescorer.toXContent(builder, ToXContent.EMPTY_PARAMS); builder.endObject().endObject(); } + if (aggsBuilder != null) { + builder.startObject("aggs").value(aggsBuilder).endObject(); + /*builder.startObject("aggs") + .startObject("max_index") + .startObject("max") + .field("field", "test-text-field-1") + .endObject() + .endObject() + .endObject();*/ + } builder.endObject(); @@ -425,6 +449,20 @@ protected void addKnnDoc( addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, Collections.emptyList(), Collections.emptyList()); } + @SneakyThrows + protected void addKnnDoc( + String index, + String docId, + List vectorFieldNames, + List vectors, + List textFieldNames, + List texts, + List nestedFieldNames, + List> nestedFields + ) { + addKnnDoc(index, docId, vectorFieldNames, vectors, textFieldNames, texts, nestedFieldNames, nestedFields, Collections.emptyList(), Collections.emptyList()); + } + /** * Add a set of knn vectors and text to an index * @@ -446,7 +484,9 @@ protected void addKnnDoc( List textFieldNames, List texts, List nestedFieldNames, - List> nestedFields + List> nestedFields, + List integerFieldNames, + List 
integerFieldValues ) { Request request = new Request("POST", "/" + index + "/_doc/" + docId + "?refresh=true"); XContentBuilder builder = XContentFactory.jsonBuilder().startObject(); @@ -467,6 +507,10 @@ protected void addKnnDoc( } builder.endObject(); } + + for (int i = 0; i < integerFieldNames.size(); i++) { + builder.field(integerFieldNames.get(i), integerFieldValues.get(i)); + } builder.endObject(); request.setJsonEntity(builder.toString()); @@ -553,10 +597,20 @@ protected String buildIndexConfiguration(final List knnFieldConf return buildIndexConfiguration(knnFieldConfigs, Collections.emptyList(), numberOfShards); } + @SneakyThrows + protected String buildIndexConfiguration( + final List knnFieldConfigs, + final List nestedFields, + final int numberOfShards + ) { + return buildIndexConfiguration(knnFieldConfigs, nestedFields, Collections.emptyList(), numberOfShards); + } + @SneakyThrows protected String buildIndexConfiguration( final List knnFieldConfigs, final List nestedFields, + final List intFields, final int numberOfShards ) { XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() @@ -584,6 +638,10 @@ protected String buildIndexConfiguration( xContentBuilder.startObject(nestedField).field("type", "nested").endObject(); } + for (String intField : intFields) { + xContentBuilder.startObject(intField).field("type", "integer").endObject(); + } + xContentBuilder.endObject().endObject().endObject(); return xContentBuilder.toString(); }