From 52635c803535653aef9cc2c7ba488e69b9ed98ec Mon Sep 17 00:00:00 2001 From: Bo Xu Date: Fri, 17 May 2024 05:07:06 +0000 Subject: [PATCH] Round to 5 decimal for embedding score to prevent flakiness (#4249) --- nl_server/search.py | 8 ++++++-- server/integration_tests/explore_test.py | 19 ++++++++----------- .../debug_info.json | 4 ++-- .../compareobesityvs.poverty/debug_info.json | 4 ++-- .../debug_info.json | 8 ++++---- .../debug_info.json | 8 ++++---- .../detection_api_reranking/debug_info.json | 2 +- 7 files changed, 27 insertions(+), 26 deletions(-) diff --git a/nl_server/search.py b/nl_server/search.py index db0f7e76c2..76adfb6588 100644 --- a/nl_server/search.py +++ b/nl_server/search.py @@ -88,8 +88,12 @@ def _rank_vars(candidates: EmbeddingsResult, dvars.SentenceScore(sentence=c.sentence, score=c.score)) sv2sentences[dcid].add(c.sentence) - for sv, score in sorted(sv2score.items(), - key=lambda item: (-item[1], item[0])): + # TODO: truncate the score based on model parameters from yaml + # Same model would produce different scores after certain decimals, so we want + # to round to 6 decimal places to make the score and rank stable. + sorted_score = sorted(sv2score.items(), + key=lambda item: (-round(item[1], 6), item[0])) + for sv, score in sorted_score: result.svs.append(sv) result.scores.append(score) diff --git a/server/integration_tests/explore_test.py b/server/integration_tests/explore_test.py index 8c41c5e33a..77aa585f3f 100644 --- a/server/integration_tests/explore_test.py +++ b/server/integration_tests/explore_test.py @@ -63,7 +63,6 @@ def run_detection(self, d = '' else: d = re.sub(r'[ ?"]', '', q).lower() - print(d) self.handle_response(q, resp, test_dir, d, failure, check_detection) def run_detect_and_fulfill(self, @@ -316,16 +315,14 @@ def test_detection_bio(self): check_detection=True) def test_detection_multivar(self): - self.run_detection( - 'detection_api_multivar', - [ - 'number of poor hispanic women with phd', - # 'compare obesity vs. poverty', - 'show me the impact of climate change on drought', - 'how are factors like obesity, blood pressure and asthma impacted by climate change', - 'Compare "Male population" with "Female Population"', - ], - check_detection=True) + self.run_detection('detection_api_multivar', [ + 'number of poor hispanic women with phd', + 'compare obesity vs. poverty', + 'show me the impact of climate change on drought', + 'how are factors like obesity, blood pressure and asthma impacted by climate change', + 'Compare "Male population" with "Female Population"', + ], + check_detection=True) def test_detection_context(self): self.run_detection('detection_api_context', [ diff --git a/server/integration_tests/test_data/detection_api_multivar/comparemalepopulationwithfemalepopulation/debug_info.json b/server/integration_tests/test_data/detection_api_multivar/comparemalepopulationwithfemalepopulation/debug_info.json index c44935458e..71605661ae 100644 --- a/server/integration_tests/test_data/detection_api_multivar/comparemalepopulationwithfemalepopulation/debug_info.json +++ b/server/integration_tests/test_data/detection_api_multivar/comparemalepopulationwithfemalepopulation/debug_info.json @@ -90,7 +90,7 @@ { "CosineScore": [ 0.8874201774597168, - 0.8723466992378235 + 0.8723466396331787 ], "QueryPart": "population", "SV": [ @@ -123,7 +123,7 @@ }, { "CosineScore": [ - 0.9170327186584473 + 0.917032778263092 ], "QueryPart": "population female population", "SV": [ diff --git a/server/integration_tests/test_data/detection_api_multivar/compareobesityvs.poverty/debug_info.json b/server/integration_tests/test_data/detection_api_multivar/compareobesityvs.poverty/debug_info.json index 35aa9ce57a..77d741517a 100644 --- a/server/integration_tests/test_data/detection_api_multivar/compareobesityvs.poverty/debug_info.json +++ b/server/integration_tests/test_data/detection_api_multivar/compareobesityvs.poverty/debug_info.json @@ -26,8 +26,8 @@ 0.7139475345611572, 0.7126951217651367, 0.7112098932266235, - 0.7085551023483276, 0.7085550427436829, + 0.7085551023483276, 0.7080719470977783, 0.7078822255134583, 0.7053750157356262, @@ -89,8 +89,8 @@ "dc/topic/sdg_1", "Count_Household_FamilyHousehold_BelowPovertyLevelInThePast12Months", "dc/4lvmzr1h1ylk1", - "dc/topic/sdg_2.1.2", "dc/topic/FoodInsecurity", + "dc/topic/sdg_2.1.2", "WHO/NCD_BMI_18A", "Count_Household_WithoutFoodStampsInThePast12Months_AbovePovertyLevelInThePast12Months", "WHO/NUTOVERWEIGHTPREV", diff --git a/server/integration_tests/test_data/detection_api_multivar/numberofpoorhispanicwomenwithphd/debug_info.json b/server/integration_tests/test_data/detection_api_multivar/numberofpoorhispanicwomenwithphd/debug_info.json index 93376d4267..88e50545fd 100644 --- a/server/integration_tests/test_data/detection_api_multivar/numberofpoorhispanicwomenwithphd/debug_info.json +++ b/server/integration_tests/test_data/detection_api_multivar/numberofpoorhispanicwomenwithphd/debug_info.json @@ -86,8 +86,8 @@ }, { "CosineScore": [ - 0.8318259716033936, - 0.8029923439025879 + 0.8318256735801697, + 0.8029924631118774 ], "QueryPart": "women phd", "SV": [ @@ -103,8 +103,8 @@ "Parts": [ { "CosineScore": [ - 0.8433718085289001, - 0.8265039920806885 + 0.8433718681335449, + 0.8265039324760437 ], "QueryPart": "number poor hispanic women", "SV": [ diff --git a/server/integration_tests/test_data/detection_api_multivar/showmetheimpactofclimatechangeondrought/debug_info.json b/server/integration_tests/test_data/detection_api_multivar/showmetheimpactofclimatechangeondrought/debug_info.json index d8122377c8..a92bf6e5a3 100644 --- a/server/integration_tests/test_data/detection_api_multivar/showmetheimpactofclimatechangeondrought/debug_info.json +++ b/server/integration_tests/test_data/detection_api_multivar/showmetheimpactofclimatechangeondrought/debug_info.json @@ -41,7 +41,7 @@ "Parts": [ { "CosineScore": [ - 0.8787485361099243 + 0.8787486553192139 ], "QueryPart": "show climate change", "SV": [ @@ -67,8 +67,8 @@ "CosineScore": [ 0.794236958026886, 0.7929013967514038, - 0.7586166858673096, - 0.7583958506584167 + 0.7586168050765991, + 0.758395791053772 ], "QueryPart": "show climate", "SV": [ @@ -80,7 +80,7 @@ }, { "CosineScore": [ - 0.8665127754211426 + 0.866512656211853 ], "QueryPart": "change drought", "SV": [ diff --git a/server/integration_tests/test_data/detection_api_reranking/debug_info.json b/server/integration_tests/test_data/detection_api_reranking/debug_info.json index b7461002eb..fe8b9b4c25 100644 --- a/server/integration_tests/test_data/detection_api_reranking/debug_info.json +++ b/server/integration_tests/test_data/detection_api_reranking/debug_info.json @@ -378,7 +378,7 @@ "sv_detection_query_index_type": "base_uae_mem", "sv_detection_query_input": "population that is rich in", "sv_detection_query_stop_words_removal": "population rich", - "time_var_reranking": 3.320078134536743 + "time_var_reranking": 3.4432520866394043 } } } \ No newline at end of file