From 64c29cf746378d95132d4008dda50ae91022607c Mon Sep 17 00:00:00 2001
From: "opensearch-trigger-bot[bot]"
 <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com>
Date: Tue, 28 May 2024 21:00:05 -0700
Subject: [PATCH] Optimize the max score tracking in the Query Phase of Hybrid
 Search (#765) (#767)

(cherry picked from commit 806042cc59a32138bb34990b8d2c8e328822bbaa)

Co-authored-by: Varun Jain <varunudr@amazon.com>
---
 CHANGELOG.md                                  |  1 +
 .../search/HybridTopScoreDocCollector.java    |  3 ++
 .../search/query/HybridCollectorManager.java  | 15 +---------
 .../neuralsearch/query/HybridQueryIT.java     | 30 +++++++++++++++++++
 4 files changed, 35 insertions(+), 14 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ec655eae4..452f3664c 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 - Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731))
 - Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733))
 - Use lazy initialization for priority queue of hits and scores to improve latencies by 20% ([#746](https://github.com/opensearch-project/neural-search/pull/746))
+- Optimize max score calculation in the Query Phase of the Hybrid Search ([765](https://github.com/opensearch-project/neural-search/pull/765))
 ### Bug Fixes
 - Total hit count fix in Hybrid Query ([756](https://github.com/opensearch-project/neural-search/pull/756))
 ### Infrastructure
diff --git a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
index 308756909..85fc15bf4 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/HybridTopScoreDocCollector.java
@@ -39,6 +39,8 @@ public class HybridTopScoreDocCollector implements Collector {
     private int[] collectedHitsPerSubQuery;
     private final int numOfHits;
     private PriorityQueue<ScoreDoc>[] compoundScores;
+    @Getter
+    private float maxScore = 0.0f;
 
     public HybridTopScoreDocCollector(int numHits, HitsThresholdChecker hitsThresholdChecker) {
         numOfHits = numHits;
@@ -115,6 +117,7 @@ public void collect(int doc) throws IOException {
                     collectedHitsPerSubQuery[i]++;
                     PriorityQueue<ScoreDoc> pq = compoundScores[i];
                     ScoreDoc currentDoc = new ScoreDoc(doc + docBase, score);
+                    maxScore = Math.max(currentDoc.score, maxScore);
                     // this way we're inserting into heap and do nothing else unless we reach the capacity
                     // after that we pull out the lowest score element on each insert
                     pq.insertWithOverflow(currentDoc);
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
index 120cd1428..456aa2def 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridCollectorManager.java
@@ -146,8 +146,7 @@ public ReduceableSearchResult reduce(Collection<Collector> collectors) {
                 getTotalHits(this.trackTotalHitsUpTo, topDocs, isSingleShard, hybridTopScoreDocCollector.getTotalHits()),
                 topDocs
             );
-            float maxScore = getMaxScore(topDocs);
-            TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
+            TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, hybridTopScoreDocCollector.getMaxScore());
             return (QuerySearchResult result) -> { result.topDocs(topDocsAndMaxScore, getSortValueFormats(sortAndFormats)); };
         }
         throw new IllegalStateException("cannot collect results of hybrid search query, there are no proper score collectors");
@@ -212,18 +211,6 @@ private TotalHits getTotalHits(
         return new TotalHits(maxTotalHits, relation);
     }
 
-    private float getMaxScore(final List<TopDocs> topDocs) {
-        if (topDocs.isEmpty()) {
-            return 0.0f;
-        } else {
-            return topDocs.stream()
-                .map(docs -> docs.scoreDocs.length == 0 ? new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
-                .map(scoreDoc -> scoreDoc.score)
-                .max(Float::compare)
-                .get();
-        }
-    }
-
     private DocValueFormat[] getSortValueFormats(final SortAndFormats sortAndFormats) {
         return sortAndFormats == null ? null : sortAndFormats.formats;
     }
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
index 15e941ff2..f99710c51 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
@@ -206,6 +206,36 @@ public void testTotalHits_whenResultSizeIsLessThenDefaultSize_thenSuccessful() {
         assertEquals(RELATION_EQUAL_TO, total.get("relation"));
     }
 
+    @SneakyThrows
+    public void testMaxScoreCalculation_whenMaxScoreIsTrackedAtCollectorLevel_thenSuccessful() {
+        initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME);
+        TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3);
+        TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
+        TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5);
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3);
+
+        HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder();
+        hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1);
+        hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder);
+        Map<String, Object> searchResponseAsMap = search(
+            TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
+            hybridQueryBuilderNeuralThenTerm,
+            null,
+            10,
+            null
+        );
+
+        double maxScore = getMaxScore(searchResponseAsMap).get();
+        List<Map<String, Object>> hits = getNestedHits(searchResponseAsMap);
+        double maxScoreExpected = 0.0;
+        for (Map<String, Object> hit : hits) {
+            double score = (double) hit.get("_score");
+            maxScoreExpected = Math.max(score, maxScoreExpected);
+        }
+        assertEquals(maxScoreExpected, maxScore, 0.0000001);
+    }
+
     /**
      * Tests complex query with multiple nested sub-queries, where some sub-queries are same
      * {