Skip to content

Commit

Permalink
Fixed logic for getting scorers for sub queries in HQ
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 22, 2024
1 parent 23a2967 commit e232396
Show file tree
Hide file tree
Showing 6 changed files with 425 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

Compatible with OpenSearch 2.18.0

### Bug Fixes
- Fixed incorrect document order for nested aggregations in hybrid query ([#956](https://github.com/opensearch-project/neural-search/pull/956))
### Enhancements
- Implement `ignore_missing` field in text chunking processors ([#907](https://github.com/opensearch-project/neural-search/pull/907))
- Added rescorer in hybrid query ([#917](https://github.com/opensearch-project/neural-search/pull/917))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,13 @@ public int advanceShallow(int target) throws IOException {
*/
@Override
public float score() throws IOException {
    // Score only the sub-scorers that are actually positioned on the current
    // document: getSubMatches() yields the linked list (via DisiWrapper.next)
    // of wrappers matching this doc. The previous implementation iterated the
    // whole subScorersPQ priority queue, which could include scorers advanced
    // past (or not yet on) the current doc — the bug this commit fixes.
    return score(getSubMatches());
}

private float score(DisiWrapper topList) throws IOException {
float totalScore = 0.0f;
for (DisiWrapper disiWrapper : subScorersPQ) {
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,44 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;
import org.opensearch.script.Script;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.BucketOrder;
import org.opensearch.search.aggregations.PipelineAggregatorBuilders;
import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval;
import org.opensearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.BucketMetricsPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.MinBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.SumBucketPipelineAggregationBuilder;
import org.opensearch.search.sort.SortBuilders;
import org.opensearch.search.sort.SortOrder;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION;
import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationBuckets;
import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue;
import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues;
import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations;
import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits;
import static org.opensearch.neuralsearch.util.TestUtils.assertHitResultsFromQuery;
import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE;
import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION;

/**
* Integration tests for base scenarios when aggregations are combined with hybrid query
*/
public class HybridQueryAggregationsIT extends BaseNeuralSearchIT {
private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = "test-hybrid-aggs-multi-doc-index-multiple-shards";
private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-hybrid-aggs-multi-doc-index-single-shard";
private static final String TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS = "test-hybrid-nested-aggs-multi-doc-index";
private static final String TEST_QUERY_TEXT3 = "hello";
private static final String TEST_QUERY_TEXT4 = "everyone";
private static final String TEST_QUERY_TEXT5 = "welcome";
Expand Down Expand Up @@ -464,6 +472,166 @@ public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchEnabled_th
testPostFilterWithComplexHybridQuery(true, true);
}

@SneakyThrows
public void testNestedAggs_whenMultipleShardsAndConcurrentSearchDisabled_thenSuccessful() {
    // Regression test for sub-query scorer selection in hybrid query: runs a
    // hybrid (match + range) query with nested aggregations (terms -> top_hits
    // + scripted max on _score) on a 3-shard index, with concurrent segment
    // search explicitly disabled so the sequential collector path is exercised.
    updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
    try {
        // Create the index and bulk-ingest the fixture only on first use;
        // subsequent runs in the same suite reuse the existing index.
        if (!indexExists(TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS)) {
            createIndexWithConfiguration(
                TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS,
                // KNN field "location", numeric field "imdb", keyword field
                // "actor", 3 shards. NOTE(review): positional List.of() args —
                // presumably (knnFields, ..., intFields, keywordFields, ...);
                // confirm against buildIndexConfiguration's signature.
                buildIndexConfiguration(
                    List.of(new KNNFieldConfig("location", 2, TEST_SPACE_TYPE)),
                    List.of(),
                    List.of(),
                    List.of("imdb"),
                    List.of("actor"),
                    List.of(),
                    3
                ),
                ""
            );

            // Load the shared bulk payload and retarget its index name
            // placeholder at this test's index before ingesting.
            String ingestBulkPayload = Files.readString(Path.of(classLoader.getResource("processor/ingest_bulk.json").toURI()))
                .replace("\"{indexname}\"", "\"" + TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS + "\"");

            bulkIngest(ingestBulkPayload, null);
        }
        createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);

        /* constructing following search query
        {
            "from": 0,
            "aggs": {
                "cardinality_of_unique_names": {
                    "cardinality": {
                        "field": "actor"
                    }
                },
                "unique_names": {
                    "terms": {
                        "field": "actor",
                        "size": 10,
                        "order": {
                            "max_score": "desc"
                        }
                    },
                    "aggs": {
                        "top_doc": {
                            "top_hits": {
                                "size": 1,
                                "sort": [
                                    {
                                        "_score": {
                                            "order": "desc"
                                        }
                                    }
                                ]
                            }
                        },
                        "max_score": {
                            "max": {
                                "script": {
                                    "source": "_score"
                                }
                            }
                        }
                    }
                }
            },
            "query": {
                "hybrid": {
                    "queries": [
                        {
                            "match": {
                                "actor": "anil"
                            }
                        },
                        {
                            "range": {
                                "imdb": {
                                    "gte": 1.0,
                                    "lte": 10.0
                                }
                            }
                        }
                    ]}}}
         */

        // Hybrid query with two sub-queries: a scoring match query and a
        // broad range query that matches (effectively) every document.
        QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery("imdb").gte(1.0).lte(10.0);
        QueryBuilder matchQuery = QueryBuilders.matchQuery("actor", "anil");
        HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
        hybridQueryBuilder.add(matchQuery).add(rangeFilterQuery);

        // Terms buckets ordered by the scripted max(_score) descending, so
        // the "anil" bucket (boosted by the match sub-query) must come first.
        AggregationBuilder aggsBuilderCardinality = AggregationBuilders.cardinality("cardinality_of_unique_names").field("actor");
        AggregationBuilder aggsBuilderUniqueNames = AggregationBuilders.terms("unique_names")
            .field("actor")
            .size(10)
            .order(BucketOrder.aggregation("max_score", false))
            .subAggregation(AggregationBuilders.topHits("top_doc").size(1).sort(SortBuilders.scoreSort().order(SortOrder.DESC)))
            .subAggregation(AggregationBuilders.max("max_score").script(new Script("_score")));

        // NOTE(review): positional args after the aggregation list — the
        // repeated rangeFilterQuery looks like a post_filter; confirm against
        // BaseNeuralSearchIT.search(...)'s signature.
        Map<String, Object> searchResponseAsMap = search(
            TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS,
            hybridQueryBuilder,
            null,
            10,
            Map.of("search_pipeline", SEARCH_PIPELINE),
            List.of(aggsBuilderCardinality, aggsBuilderUniqueNames),
            rangeFilterQuery,
            null,
            false,
            null,
            0
        );
        assertNotNull(searchResponseAsMap);

        // assert actual results
        // aggregations — all expected counts below (7 actors, bucket doc
        // counts, 92 total hits) come from the ingest_bulk.json fixture.
        Map<String, Object> aggregations = getAggregations(searchResponseAsMap);
        assertNotNull(aggregations);

        int cardinalityValue = getAggregationValue(aggregations, "cardinality_of_unique_names");
        assertEquals(7, cardinalityValue);

        Map<String, Object> uniqueAggValue = getAggregationValues(aggregations, "unique_names");
        assertEquals(3, uniqueAggValue.size());
        assertEquals(0, uniqueAggValue.get("doc_count_error_upper_bound"));
        assertEquals(0, uniqueAggValue.get("sum_other_doc_count"));

        List<Map<String, Object>> buckets = getAggregationBuckets(aggregations, "unique_names");
        assertNotNull(buckets);
        assertEquals(7, buckets.size());

        // check content of few buckets
        // First bucket: "anil" matches the match sub-query, so its hybrid
        // score exceeds the range-only baseline (> 1.0).
        Map<String, Object> firstBucket = buckets.get(0);
        assertEquals(4, firstBucket.size());
        assertEquals("anil", firstBucket.get(KEY));
        assertEquals(42, firstBucket.get("doc_count"));
        assertNotNull(getAggregationValue(firstBucket, "max_score"));
        assertTrue((double) getAggregationValue(firstBucket, "max_score") > 1.0f);

        // Remaining buckets match only the range sub-query, so their max
        // score is the normalized baseline of 1.0.
        Map<String, Object> secondBucket = buckets.get(1);
        assertEquals(4, secondBucket.size());
        assertEquals("abhishek", secondBucket.get(KEY));
        assertEquals(8, secondBucket.get("doc_count"));
        assertNotNull(getAggregationValue(secondBucket, "max_score"));
        assertEquals(1.0, getAggregationValue(secondBucket, "max_score"), DELTA_FOR_SCORE_ASSERTION);

        Map<String, Object> lastBucket = buckets.get(buckets.size() - 1);
        assertEquals(4, lastBucket.size());
        assertEquals("sanjay", lastBucket.get(KEY));
        assertEquals(7, lastBucket.get("doc_count"));
        assertNotNull(getAggregationValue(lastBucket, "max_score"));
        assertEquals(1.0, getAggregationValue(lastBucket, "max_score"), DELTA_FOR_SCORE_ASSERTION);

        // assert the hybrid query scores
        assertHitResultsFromQuery(10, 92, searchResponseAsMap);

    } finally {
        // Always clean up the index and search pipeline, even on failure.
        wipeOfTestResources(TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
    }
}

private void testMaxAggsOnSingleShardCluster() throws Exception {
try {
prepareResourcesForSingleShardIndex(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, SEARCH_PIPELINE);
Expand Down Expand Up @@ -501,8 +669,6 @@ private void testDateRange() throws IOException {
try {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
// try {
// prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE);

AggregationBuilder aggsBuilder = AggregationBuilders.dateRange(DATE_AGGREGATION_NAME)
.field(DATE_FIELD_1)
Expand Down
Loading

0 comments on commit e232396

Please sign in to comment.