diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java
index 74ae0621a..01c1516d2 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java
@@ -89,9 +89,9 @@ public SearchResponse processResponse(
                 for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) {
                     normalizedExplanation[i] = Explanation.match(
                         // normalized score
-                        normalizationExplanation.scoreDetails().get(i).getKey(),
+                        normalizationExplanation.getScoreDetails().get(i).getKey(),
                         // description of normalized score
-                        normalizationExplanation.scoreDetails().get(i).getValue(),
+                        normalizationExplanation.getScoreDetails().get(i).getValue(),
                         // shard level details
                         queryLevelExplanation.getDetails()[i]
                     );
@@ -99,7 +99,7 @@ public SearchResponse processResponse(
                 Explanation finalExplanation = Explanation.match(
                     searchHit.getScore(),
                     // combination level explanation is always a single detail
-                    combinationExplanation.scoreDetails().get(0).getValue(),
+                    combinationExplanation.getScoreDetails().get(0).getValue(),
                     normalizedExplanation
                 );
                 searchHit.explanation(finalExplanation);
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index 1a958676a..078c68aff 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -125,7 +125,7 @@ private void explain(NormalizationProcessorWorkflowExecuteRequest request, List<
         Map<SearchShard, List<CombinedExplanationDetails>> combinedExplanations = combinationExplain.entrySet()
             .stream()
             .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().stream().map(explainDetail -> {
-                DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.docId(), entry.getKey());
+                DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.getDocId(), entry.getKey());
                 return CombinedExplanationDetails.builder()
                     .normalizationExplanations(normalizationExplain.get(docIdAtSearchShard))
                     .combinationExplanations(explainDetail)
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java
index 505b19ae0..c875eab55 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java
@@ -4,12 +4,19 @@
  */
 package org.opensearch.neuralsearch.processor;
 
+import lombok.AllArgsConstructor;
+import lombok.Value;
 import org.opensearch.search.SearchShardTarget;
 
 /**
  * DTO class to store index, shardId and nodeId for a search shard.
  */
-public record SearchShard(String index, int shardId, String nodeId) {
+@Value
+@AllArgsConstructor
+public class SearchShard {
+    String index;
+    int shardId;
+    String nodeId;
 
     /**
      * Create SearchShard from SearchShardTarget
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java
index 9ce4ebf97..51550e523 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java
@@ -4,13 +4,15 @@
  */
 package org.opensearch.neuralsearch.processor.explain;
 
+import lombok.Value;
 import org.opensearch.neuralsearch.processor.SearchShard;
 
 /**
  * DTO class to store docId and search shard for a query.
  * Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards.
- * @param docId
- * @param searchShard
  */
-public record DocIdAtSearchShard(int docId, SearchShard searchShard) {
+@Value
+public class DocIdAtSearchShard {
+    int docId;
+    SearchShard searchShard;
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
index fe009f383..e577e6f43 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java
@@ -4,6 +4,8 @@
  */
 package org.opensearch.neuralsearch.processor.explain;
 
+import lombok.AllArgsConstructor;
+import lombok.Value;
 import org.apache.commons.lang3.tuple.Pair;
 
 import java.util.List;
@@ -11,10 +13,12 @@
 /**
  * DTO class to store value and description for explain details.
  * Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards.
- * @param docId iterator based id of the document
- * @param scoreDetails list of score details for the document, each Pair object contains score and description of the score
  */
-public record ExplanationDetails(int docId, List<Pair<Float, String>> scoreDetails) {
+@Value
+@AllArgsConstructor
+public class ExplanationDetails {
+    int docId;
+    List<Pair<Float, String>> scoreDetails;
 
     public ExplanationDetails(List<Pair<Float, String>> scoreDetails) {
         this(-1, scoreDetails);
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java
index 499ce77cf..b4c5cd557 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java
@@ -9,6 +9,7 @@
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.stream.Collectors;
 
@@ -45,6 +46,9 @@ public static Map<DocIdAtSearchShard, ExplanationDetails> getDocIdAtQueryForNorm
      * @return a string describing the combination technique and its parameters
      */
     public static String describeCombinationTechnique(final String techniqueName, final List<Float> weights) {
+        if (Objects.isNull(techniqueName)) {
+            throw new IllegalArgumentException("combination technique name cannot be null");
+        }
         return Optional.ofNullable(weights)
             .filter(w -> !w.isEmpty())
             .map(w -> String.format(Locale.ROOT, "%s, weights %s", techniqueName, weights))
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtilsTests.java
new file mode 100644
index 000000000..becab3860
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtilsTests.java
@@ -0,0 +1,115 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.processor.explain;
+
+import org.apache.commons.lang3.tuple.Pair;
+import org.junit.Before;
+
+import org.opensearch.neuralsearch.processor.SearchShard;
+import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique;
+import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+public class ExplanationUtilsTests extends OpenSearchQueryTestCase {
+
+    private DocIdAtSearchShard docId1;
+    private DocIdAtSearchShard docId2;
+    private Map<DocIdAtSearchShard, List<Float>> normalizedScores;
+    private final MinMaxScoreNormalizationTechnique MIN_MAX_TECHNIQUE = new MinMaxScoreNormalizationTechnique();
+
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        SearchShard searchShard = new SearchShard("test_index", 0, "abcdefg");
+        docId1 = new DocIdAtSearchShard(1, searchShard);
+        docId2 = new DocIdAtSearchShard(2, searchShard);
+        normalizedScores = new HashMap<>();
+    }
+
+    public void testGetDocIdAtQueryForNormalization() {
+        // Setup
+        normalizedScores.put(docId1, Arrays.asList(1.0f, 0.5f));
+        normalizedScores.put(docId2, Arrays.asList(0.8f));
+        // Act
+        Map<DocIdAtSearchShard, ExplanationDetails> result = ExplanationUtils.getDocIdAtQueryForNormalization(
+            normalizedScores,
+            MIN_MAX_TECHNIQUE
+        );
+        // Assert
+        assertNotNull(result);
+        assertEquals(2, result.size());
+
+        // Assert first document
+        ExplanationDetails details1 = result.get(docId1);
+        assertNotNull(details1);
+        List<Pair<Float, String>> explanations1 = details1.getScoreDetails();
+        assertEquals(2, explanations1.size());
+        assertEquals(1.0f, explanations1.get(0).getLeft(), 0.001);
+        assertEquals(0.5f, explanations1.get(1).getLeft(), 0.001);
+        assertEquals("min_max normalization of:", explanations1.get(0).getRight());
+        assertEquals("min_max normalization of:", explanations1.get(1).getRight());
+
+        // Assert second document
+        ExplanationDetails details2 = result.get(docId2);
+        assertNotNull(details2);
+        List<Pair<Float, String>> explanations2 = details2.getScoreDetails();
+        assertEquals(1, explanations2.size());
+        assertEquals(0.8f, explanations2.get(0).getLeft(), 0.001);
+        assertEquals("min_max normalization of:", explanations2.get(0).getRight());
+    }
+
+    public void testGetDocIdAtQueryForNormalizationWithEmptyScores() {
+        // Setup
+        // Using empty normalizedScores from setUp
+        // Act
+        Map<DocIdAtSearchShard, ExplanationDetails> result = ExplanationUtils.getDocIdAtQueryForNormalization(
+            normalizedScores,
+            MIN_MAX_TECHNIQUE
+        );
+        // Assert
+        assertNotNull(result);
+        assertTrue(result.isEmpty());
+    }
+
+    public void testDescribeCombinationTechniqueWithWeights() {
+        // Setup
+        String techniqueName = "test_technique";
+        List<Float> weights = Arrays.asList(0.3f, 0.7f);
+        // Act
+        String result = ExplanationUtils.describeCombinationTechnique(techniqueName, weights);
+        // Assert
+        assertEquals("test_technique, weights [0.3, 0.7]", result);
+    }
+
+    public void testDescribeCombinationTechniqueWithoutWeights() {
+        // Setup
+        String techniqueName = "test_technique";
+        // Act
+        String result = ExplanationUtils.describeCombinationTechnique(techniqueName, null);
+        // Assert
+        assertEquals("test_technique", result);
+    }
+
+    public void testDescribeCombinationTechniqueWithEmptyWeights() {
+        // Setup
+        String techniqueName = "test_technique";
+        List<Float> weights = Arrays.asList();
+        // Act
+        String result = ExplanationUtils.describeCombinationTechnique(techniqueName, weights);
+        // Assert
+        assertEquals("test_technique", result);
+    }
+
+    public void testDescribeCombinationTechniqueWithNullTechnique() {
+        // Setup
+        List<Float> weights = Arrays.asList(1.0f);
+        // Act & Assert
+        expectThrows(IllegalArgumentException.class, () -> ExplanationUtils.describeCombinationTechnique(null, weights));
+    }
+}
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactoryTests.java
new file mode 100644
index 000000000..453cc471c
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactoryTests.java
@@ -0,0 +1,112 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.processor.factory;
+
+import lombok.SneakyThrows;
+import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
+import org.opensearch.search.pipeline.Processor;
+import org.opensearch.search.pipeline.SearchResponseProcessor;
+import org.opensearch.test.OpenSearchTestCase;
+
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.mockito.Mockito.mock;
+
+public class ExplanationResponseProcessorFactoryTests extends OpenSearchTestCase {
+
+    @SneakyThrows
+    public void testDefaults_whenNoParams_thenSuccessful() {
+        // Setup
+        ExplanationResponseProcessorFactory explanationResponseProcessorFactory = new ExplanationResponseProcessorFactory();
+        final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories = new HashMap<>();
+        String tag = "tag";
+        String description = "description";
+        boolean ignoreFailure = false;
+        Map<String, Object> config = new HashMap<>();
+        Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);
+        // Act
+        SearchResponseProcessor responseProcessor = explanationResponseProcessorFactory.create(
+            processorFactories,
+            tag,
+            description,
+            ignoreFailure,
+            config,
+            pipelineContext
+        );
+        // Assert
+        assertProcessor(responseProcessor, tag, description, ignoreFailure);
+    }
+
+    @SneakyThrows
+    public void testInvalidInput_whenParamsPassedToFactory_thenSuccessful() {
+        // Setup
+        ExplanationResponseProcessorFactory explanationResponseProcessorFactory = new ExplanationResponseProcessorFactory();
+        final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories = new HashMap<>();
+        String tag = "tag";
+        String description = "description";
+        boolean ignoreFailure = false;
+        // create map of random parameters
+        Map<String, Object> config = new HashMap<>();
+        for (int i = 0; i < randomInt(1_000); i++) {
+            config.put(randomAlphaOfLength(10) + i, randomAlphaOfLength(100));
+        }
+        Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);
+        // Act
+        SearchResponseProcessor responseProcessor = explanationResponseProcessorFactory.create(
+            processorFactories,
+            tag,
+            description,
+            ignoreFailure,
+            config,
+            pipelineContext
+        );
+        // Assert
+        assertProcessor(responseProcessor, tag, description, ignoreFailure);
+    }
+
+    @SneakyThrows
+    public void testNewInstanceCreation_whenCreateMultipleTimes_thenNewInstanceReturned() {
+        // Setup
+        ExplanationResponseProcessorFactory explanationResponseProcessorFactory = new ExplanationResponseProcessorFactory();
+        final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories = new HashMap<>();
+        String tag = "tag";
+        String description = "description";
+        boolean ignoreFailure = false;
+        Map<String, Object> config = new HashMap<>();
+        Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class);
+        // Act
+        SearchResponseProcessor responseProcessorOne = explanationResponseProcessorFactory.create(
+            processorFactories,
+            tag,
+            description,
+            ignoreFailure,
+            config,
+            pipelineContext
+        );
+
+        SearchResponseProcessor responseProcessorTwo = explanationResponseProcessorFactory.create(
+            processorFactories,
+            tag,
+            description,
+            ignoreFailure,
+            config,
+            pipelineContext
+        );
+
+        // Assert
+        assertNotEquals(responseProcessorOne, responseProcessorTwo);
+    }
+
+    private static void assertProcessor(SearchResponseProcessor responseProcessor, String tag, String description, boolean ignoreFailure) {
+        assertNotNull(responseProcessor);
+        assertTrue(responseProcessor instanceof ExplanationResponseProcessor);
+        ExplanationResponseProcessor explanationResponseProcessor = (ExplanationResponseProcessor) responseProcessor;
+        assertEquals("explanation_response_processor", explanationResponseProcessor.getType());
+        assertEquals(tag, explanationResponseProcessor.getTag());
+        assertEquals(description, explanationResponseProcessor.getDescription());
+        assertEquals(ignoreFailure, explanationResponseProcessor.isIgnoreFailure());
+    }
+}