
Commit

Add case for null/NaN scores and minor refactoring
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Dec 19, 2024
1 parent 2c5e6d0 commit 86c3263
Showing 5 changed files with 130 additions and 10 deletions.
@@ -111,8 +111,9 @@ public SearchResponse processResponse(
);
}
// Create and set final explanation combining all components
Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore();
Explanation finalExplanation = Explanation.match(
searchHit.getScore(),
finalScore,
// combination level explanation is always a single detail
combinationExplanation.getScoreDetails().get(0).getValue(),
normalizedExplanation
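For context, here is a minimal, self-contained sketch of the NaN fallback this hunk introduces. It is illustrative only, not the plugin's actual code: the helper name and description strings are assumptions, and it relies solely on Lucene's org.apache.lucene.search.Explanation API.

import org.apache.lucene.search.Explanation;

public class NaNScoreFallbackExample {

    // Mirrors the guard added in processResponse: an explanation value should not be NaN,
    // so a NaN hit score is replaced with 0.0f before the final explanation is built.
    static Explanation buildFinalExplanation(float hitScore, String description, Explanation normalizedExplanation) {
        float finalScore = Float.isNaN(hitScore) ? 0.0f : hitScore;
        return Explanation.match(finalScore, description, normalizedExplanation);
    }

    public static void main(String[] args) {
        Explanation normalized = Explanation.match(0.5f, "source scores normalized to [0.0, 1.0] range");
        Explanation result = buildFinalExplanation(Float.NaN, "combined score of:", normalized);
        System.out.println(result.getValue()); // prints 0.0 instead of NaN
    }
}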
@@ -71,7 +71,7 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {

@Override
public String describe() {
return String.format(Locale.ROOT, "%s, rank_constant %s", TECHNIQUE_NAME, rankConstant);
return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, rankConstant);
}

@Override
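For reference, a quick sketch (not part of the commit) of what the updated format string produces; the technique name and the default rank constant of 60 are taken from the tests below, everything else here is illustrative.

import java.util.Locale;

public class DescribeFormatExample {
    public static void main(String[] args) {
        String techniqueName = "rrf"; // TECHNIQUE_NAME
        int rankConstant = 60;        // default rank_constant assumed here
        // Same format string as the updated describe() method
        String description = String.format(Locale.ROOT, "%s, rank_constant [%s]", techniqueName, rankConstant);
        System.out.println(description); // rrf, rank_constant [60]
    }
}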
@@ -37,9 +37,10 @@
import java.util.TreeMap;

import static org.mockito.Mockito.mock;
import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_FLOATS_ASSERTION;
import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION;

public class ExplanationPayloadProcessorTests extends OpenSearchTestCase {
public class ExplanationResponseProcessorTests extends OpenSearchTestCase {
private static final String PROCESSOR_TAG = "mockTag";
private static final String DESCRIPTION = "mockDescription";

@@ -192,6 +193,119 @@ public void testParsingOfExplanations_whenScoreSortingAndExplanations_thenSucces
assertOnExplanationResults(processedResponse, maxScore);
}

@SneakyThrows
public void testProcessResponse_whenNullSearchHits_thenNoOp() {
ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
SearchRequest searchRequest = mock(SearchRequest.class);
SearchResponse searchResponse = getSearchResponse(null);
PipelineProcessingContext context = new PipelineProcessingContext();

SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context);
assertEquals(searchResponse, processedResponse);
}

@SneakyThrows
public void testProcessResponse_whenEmptySearchHits_thenNoOp() {
ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
SearchRequest searchRequest = mock(SearchRequest.class);
SearchHits emptyHits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f);
SearchResponse searchResponse = getSearchResponse(emptyHits);
PipelineProcessingContext context = new PipelineProcessingContext();

SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context);
assertEquals(searchResponse, processedResponse);
}

@SneakyThrows
public void testProcessResponse_whenNullExplanation_thenSkipProcessing() {
ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
SearchRequest searchRequest = mock(SearchRequest.class);
SearchHits searchHits = getSearchHits(1.0f);
for (SearchHit hit : searchHits.getHits()) {
hit.explanation(null);
}
SearchResponse searchResponse = getSearchResponse(searchHits);
PipelineProcessingContext context = new PipelineProcessingContext();

SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context);
assertEquals(searchResponse, processedResponse);
}

@SneakyThrows
public void testProcessResponse_whenInvalidExplanationPayload_thenHandleGracefully() {
ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
SearchRequest searchRequest = mock(SearchRequest.class);
SearchHits searchHits = getSearchHits(1.0f);
SearchResponse searchResponse = getSearchResponse(searchHits);
PipelineProcessingContext context = new PipelineProcessingContext();

// Set invalid payload
Map<ExplanationPayload.PayloadType, Object> invalidPayload = Map.of(
ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR,
"invalid payload"
);
ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(invalidPayload).build();
context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload);

SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context);
assertNotNull(processedResponse);
}

@SneakyThrows
public void testProcessResponse_whenZeroScore_thenProcessCorrectly() {
ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
SearchRequest searchRequest = mock(SearchRequest.class);
SearchHits searchHits = getSearchHits(0.0f);
SearchResponse searchResponse = getSearchResponse(searchHits);
PipelineProcessingContext context = new PipelineProcessingContext();

Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = getCombinedExplainDetails(searchHits);
Map<ExplanationPayload.PayloadType, Object> explainPayload = Map.of(
ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR,
combinedExplainDetails
);
ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build();
context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload);

SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context);
assertNotNull(processedResponse);
assertEquals(0.0f, processedResponse.getHits().getMaxScore(), DELTA_FOR_SCORE_ASSERTION);
}

@SneakyThrows
public void testProcessResponse_whenScoreIsNaN_thenExplanationUsesZero() {
ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false);
SearchRequest searchRequest = mock(SearchRequest.class);

// Create SearchHits with NaN score
SearchHits searchHits = getSearchHits(Float.NaN);
SearchResponse searchResponse = getSearchResponse(searchHits);
PipelineProcessingContext context = new PipelineProcessingContext();

// Setup explanation payload
Map<SearchShard, List<CombinedExplanationDetails>> combinedExplainDetails = getCombinedExplainDetails(searchHits);
Map<ExplanationPayload.PayloadType, Object> explainPayload = Map.of(
ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR,
combinedExplainDetails
);
ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build();
context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload);

// Process response
SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context);

// Verify results
assertNotNull(processedResponse);
SearchHit[] hits = processedResponse.getHits().getHits();
assertNotNull(hits);
assertTrue(hits.length > 0);

// Verify that the explanation uses 0.0f when input score was NaN
Explanation explanation = hits[0].getExplanation();
assertNotNull(explanation);
assertEquals(0.0f, (float) explanation.getValue(), DELTA_FOR_FLOATS_ASSERTION);
}

private static SearchHits getSearchHits(float maxScore) {
int numResponses = 1;
int numIndices = 2;
@@ -30,8 +30,13 @@ public class RRFNormalizationTechniqueTests extends OpenSearchQueryTestCase {
private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678");

public void testDescribe() {
// verify with default values for parameters
RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil);
assertEquals("rrf, rank_constant 60", normalizationTechnique.describe());
assertEquals("rrf, rank_constant [60]", normalizationTechnique.describe());

// verify when parameter values are set
normalizationTechnique = new RRFNormalizationTechnique(Map.of("rank_constant", 25), scoreNormalizationUtil);
assertEquals("rrf, rank_constant [25]", normalizationTechnique.describe());
}

public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() {
@@ -633,7 +633,7 @@ public void testExplain_whenRRFProcessor_thenSuccessful() {
// two sub-queries meaning we do have two detail objects with separate query level details
Map<String, Object> hit1DetailsForHit1 = hit1Details.get(0);
assertTrue((double) hit1DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION);
assertEquals("rrf, rank_constant 60 normalization of:", hit1DetailsForHit1.get("description"));
assertEquals("rrf, rank_constant [60] normalization of:", hit1DetailsForHit1.get("description"));
assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size());

Map<String, Object> explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0);
@@ -643,7 +643,7 @@ public void testExplain_whenRRFProcessor_thenSuccessful() {

Map<String, Object> hit1DetailsForHit2 = hit1Details.get(1);
assertTrue((double) hit1DetailsForHit2.get("value") > 0.0f);
assertEquals("rrf, rank_constant 60 normalization of:", hit1DetailsForHit2.get("description"));
assertEquals("rrf, rank_constant [60] normalization of:", hit1DetailsForHit2.get("description"));
assertEquals(1, ((List) hit1DetailsForHit2.get("details")).size());

Map<String, Object> explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0);
@@ -663,12 +663,12 @@ public void testExplain_whenRRFProcessor_thenSuccessful() {

Map<String, Object> hit2DetailsForHit1 = hit2Details.get(0);
assertTrue((double) hit2DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION);
assertEquals("rrf, rank_constant 60 normalization of:", hit2DetailsForHit1.get("description"));
assertEquals("rrf, rank_constant [60] normalization of:", hit2DetailsForHit1.get("description"));
assertEquals(1, ((List) hit2DetailsForHit1.get("details")).size());

Map<String, Object> hit2DetailsForHit2 = hit2Details.get(1);
assertTrue((double) hit2DetailsForHit2.get("value") > DELTA_FOR_SCORE_ASSERTION);
assertEquals("rrf, rank_constant 60 normalization of:", hit2DetailsForHit2.get("description"));
assertEquals("rrf, rank_constant [60] normalization of:", hit2DetailsForHit2.get("description"));
assertEquals(1, ((List) hit2DetailsForHit2.get("details")).size());

// hit 3
@@ -683,7 +683,7 @@ public void testExplain_whenRRFProcessor_thenSuccessful() {

Map<String, Object> hit3DetailsForHit1 = hit3Details.get(0);
assertTrue((double) hit3DetailsForHit1.get("value") > .0f);
assertEquals("rrf, rank_constant 60 normalization of:", hit3DetailsForHit1.get("description"));
assertEquals("rrf, rank_constant [60] normalization of:", hit3DetailsForHit1.get("description"));
assertEquals(1, ((List) hit3DetailsForHit1.get("details")).size());

Map<String, Object> explanationsHit3 = getListOfValues(hit3DetailsForHit1, "details").get(0);
@@ -703,7 +703,7 @@ public void testExplain_whenRRFProcessor_thenSuccessful() {

Map<String, Object> hit4DetailsForHit1 = hit4Details.get(0);
assertTrue((double) hit4DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION);
assertEquals("rrf, rank_constant 60 normalization of:", hit4DetailsForHit1.get("description"));
assertEquals("rrf, rank_constant [60] normalization of:", hit4DetailsForHit1.get("description"));
assertEquals(1, ((List) hit4DetailsForHit1.get("details")).size());

Map<String, Object> explanationsHit4 = getListOfValues(hit4DetailsForHit1, "details").get(0);
