Skip to content

Commit

Permalink
Fixed logic for getting scorers for sub queries in HQ
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Oct 22, 2024
1 parent 23a2967 commit c5c74a7
Show file tree
Hide file tree
Showing 6 changed files with 451 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

Compatible with OpenSearch 2.18.0

### Bug Fixes
- Fixed incorrect document order for nested aggregations in hybrid query ([#956](https://github.com/opensearch-project/neural-search/pull/956))
### Enhancements
- Implement `ignore_missing` field in text chunking processors ([#907](https://github.com/opensearch-project/neural-search/pull/907))
- Added rescorer in hybrid query ([#917](https://github.com/opensearch-project/neural-search/pull/917))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,13 @@ public int advanceShallow(int target) throws IOException {
*/
@Override
public float score() throws IOException {
return score(getSubMatches());
}

private float score(DisiWrapper topList) throws IOException {
float totalScore = 0.0f;
for (DisiWrapper disiWrapper : subScorersPQ) {
for (DisiWrapper disiWrapper = topList; disiWrapper != null; disiWrapper = disiWrapper.next) {
// check if this doc has match in the subQuery. If not, add score as 0.0 and continue
if (disiWrapper.scorer.docID() == DocIdSetIterator.NO_MORE_DOCS) {
continue;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,36 +12,44 @@
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.index.query.TermQueryBuilder;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;
import org.opensearch.script.Script;
import org.opensearch.search.aggregations.AggregationBuilder;
import org.opensearch.search.aggregations.AggregationBuilders;
import org.opensearch.search.aggregations.BucketOrder;
import org.opensearch.search.aggregations.PipelineAggregatorBuilders;
import org.opensearch.search.aggregations.bucket.histogram.DateHistogramInterval;
import org.opensearch.search.aggregations.pipeline.AvgBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.BucketMetricsPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.MaxBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.MinBucketPipelineAggregationBuilder;
import org.opensearch.search.aggregations.pipeline.SumBucketPipelineAggregationBuilder;
import org.opensearch.search.sort.SortBuilders;
import org.opensearch.search.sort.SortOrder;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;

import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION;
import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationBuckets;
import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValue;
import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregationValues;
import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getAggregations;
import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits;
import static org.opensearch.neuralsearch.util.TestUtils.assertHitResultsFromQuery;
import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE;
import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION;

/**
* Integration tests for base scenarios when aggregations are combined with hybrid query
*/
public class HybridQueryAggregationsIT extends BaseNeuralSearchIT {
private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS = "test-hybrid-aggs-multi-doc-index-multiple-shards";
private static final String TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD = "test-hybrid-aggs-multi-doc-index-single-shard";
private static final String TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS = "test-hybrid-nested-aggs-multi-doc-index";
private static final String TEST_QUERY_TEXT3 = "hello";
private static final String TEST_QUERY_TEXT4 = "everyone";
private static final String TEST_QUERY_TEXT5 = "welcome";
Expand Down Expand Up @@ -86,6 +94,12 @@ public class HybridQueryAggregationsIT extends BaseNeuralSearchIT {
private static final String BUCKETS_AGGREGATION_NAME_2 = "date_buckets_2";
private static final String BUCKETS_AGGREGATION_NAME_3 = "date_buckets_3";
private static final String BUCKETS_AGGREGATION_NAME_4 = "date_buckets_4";
protected static final String FLOAT_FIELD_NAME_IMDB = "imdb";
protected static final String KEYWORD_FIELD_NAME_ACTOR = "actor";
protected static final String CARDINALITY_OF_UNIQUE_NAMES = "cardinality_of_unique_names";
protected static final String UNIQUE_NAMES = "unique_names";
protected static final String AGGREGATION_NAME_MAX_SCORE = "max_score";
protected static final String AGGREGATION_NAME_TOP_DOC = "top_doc";

@Before
public void setUp() throws Exception {
Expand Down Expand Up @@ -464,6 +478,186 @@ public void testPostFilterOnIndexWithSingleShards_WhenConcurrentSearchEnabled_th
testPostFilterWithComplexHybridQuery(true, true);
}

@SneakyThrows
public void testNestedAggs_whenMultipleShardsAndConcurrentSearchDisabled_thenSuccessful() {
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false);
try {
prepareResourcesForNestegAggregationsScenario(TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS);
assertNestedAggregations(TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS);
} finally {
wipeOfTestResources(TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
}
}

@SneakyThrows
public void testNestedAggs_whenMultipleShardsAndConcurrentSearchEnabled_thenSuccessful() {
updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, true);
try {
prepareResourcesForNestegAggregationsScenario(TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS);
assertNestedAggregations(TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS);
} finally {
wipeOfTestResources(TEST_MULTI_DOC_INDEX_FOR_NESTED_AGGS_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE);
}
}

private void prepareResourcesForNestegAggregationsScenario(String index) throws Exception {
if (!indexExists(index)) {
createIndexWithConfiguration(
index,
buildIndexConfiguration(
List.of(new KNNFieldConfig("location", 2, TEST_SPACE_TYPE)),
List.of(),
List.of(),
List.of(FLOAT_FIELD_NAME_IMDB),
List.of(KEYWORD_FIELD_NAME_ACTOR),
List.of(),
3
),
""
);

String ingestBulkPayload = Files.readString(Path.of(classLoader.getResource("processor/ingest_bulk.json").toURI()))
.replace("\"{indexname}\"", "\"" + index + "\"");

bulkIngest(ingestBulkPayload, null);
}
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
}

private void assertNestedAggregations(String index) {
/* constructing following search query
{
"from": 0,
"aggs": {
"cardinality_of_unique_names": {
"cardinality": {
"field": "actor"
}
},
"unique_names": {
"terms": {
"field": "actor",
"size": 10,
"order": {
"max_score": "desc"
}
},
"aggs": {
"top_doc": {
"top_hits": {
"size": 1,
"sort": [
{
"_score": {
"order": "desc"
}
}
]
}
},
"max_score": {
"max": {
"script": {
"source": "_score"
}
}
}
}
}
},
"query": {
"hybrid": {
"queries": [
{
"match": {
"actor": "anil"
}
},
{
"range": {
"imdb": {
"gte": 1.0,
"lte": 10.0
}
}
}
]}}}
*/

QueryBuilder rangeFilterQuery = QueryBuilders.rangeQuery(FLOAT_FIELD_NAME_IMDB).gte(1.0).lte(10.0);
QueryBuilder matchQuery = QueryBuilders.matchQuery(KEYWORD_FIELD_NAME_ACTOR, "anil");
HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
hybridQueryBuilder.add(matchQuery).add(rangeFilterQuery);

AggregationBuilder aggsBuilderCardinality = AggregationBuilders.cardinality(CARDINALITY_OF_UNIQUE_NAMES)
.field(KEYWORD_FIELD_NAME_ACTOR);
AggregationBuilder aggsBuilderUniqueNames = AggregationBuilders.terms(UNIQUE_NAMES)
.field(KEYWORD_FIELD_NAME_ACTOR)
.size(10)
.order(BucketOrder.aggregation(AGGREGATION_NAME_MAX_SCORE, false))
.subAggregation(
AggregationBuilders.topHits(AGGREGATION_NAME_TOP_DOC).size(1).sort(SortBuilders.scoreSort().order(SortOrder.DESC))
)
.subAggregation(AggregationBuilders.max(AGGREGATION_NAME_MAX_SCORE).script(new Script("_score")));

Map<String, Object> searchResponseAsMap = search(
index,
hybridQueryBuilder,
null,
10,
Map.of("search_pipeline", SEARCH_PIPELINE),
List.of(aggsBuilderCardinality, aggsBuilderUniqueNames),
rangeFilterQuery,
null,
false,
null,
0
);
assertNotNull(searchResponseAsMap);

// assert actual results
// aggregations
Map<String, Object> aggregations = getAggregations(searchResponseAsMap);
assertNotNull(aggregations);

int cardinalityValue = getAggregationValue(aggregations, CARDINALITY_OF_UNIQUE_NAMES);
assertEquals(7, cardinalityValue);

Map<String, Object> uniqueAggValue = getAggregationValues(aggregations, UNIQUE_NAMES);
assertEquals(3, uniqueAggValue.size());
assertEquals(0, uniqueAggValue.get("doc_count_error_upper_bound"));
assertEquals(0, uniqueAggValue.get("sum_other_doc_count"));

List<Map<String, Object>> buckets = getAggregationBuckets(aggregations, UNIQUE_NAMES);
assertNotNull(buckets);
assertEquals(7, buckets.size());

// check content of few buckets
Map<String, Object> firstBucket = buckets.get(0);
assertEquals(4, firstBucket.size());
assertEquals("anil", firstBucket.get(KEY));
assertEquals(42, firstBucket.get(BUCKET_AGG_DOC_COUNT_FIELD));
assertNotNull(getAggregationValue(firstBucket, AGGREGATION_NAME_MAX_SCORE));
assertTrue((double) getAggregationValue(firstBucket, AGGREGATION_NAME_MAX_SCORE) > 1.0f);

Map<String, Object> secondBucket = buckets.get(1);
assertEquals(4, secondBucket.size());
assertEquals("abhishek", secondBucket.get(KEY));
assertEquals(8, secondBucket.get(BUCKET_AGG_DOC_COUNT_FIELD));
assertNotNull(getAggregationValue(secondBucket, AGGREGATION_NAME_MAX_SCORE));
assertEquals(1.0, getAggregationValue(secondBucket, AGGREGATION_NAME_MAX_SCORE), DELTA_FOR_SCORE_ASSERTION);

Map<String, Object> lastBucket = buckets.get(buckets.size() - 1);
assertEquals(4, lastBucket.size());
assertEquals("sanjay", lastBucket.get(KEY));
assertEquals(7, lastBucket.get(BUCKET_AGG_DOC_COUNT_FIELD));
assertNotNull(getAggregationValue(lastBucket, AGGREGATION_NAME_MAX_SCORE));
assertEquals(1.0, getAggregationValue(lastBucket, AGGREGATION_NAME_MAX_SCORE), DELTA_FOR_SCORE_ASSERTION);

// assert the hybrid query scores
assertHitResultsFromQuery(10, 92, searchResponseAsMap);
}

private void testMaxAggsOnSingleShardCluster() throws Exception {
try {
prepareResourcesForSingleShardIndex(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_SINGLE_SHARD, SEARCH_PIPELINE);
Expand Down Expand Up @@ -501,8 +695,6 @@ private void testDateRange() throws IOException {
try {
initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS);
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE);
// try {
// prepareResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SEARCH_PIPELINE);

AggregationBuilder aggsBuilder = AggregationBuilders.dateRange(DATE_AGGREGATION_NAME)
.field(DATE_FIELD_1)
Expand Down
Loading

0 comments on commit c5c74a7

Please sign in to comment.