From 1d93192ced70c7012a9ef2f13acabf6c19ee7616 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 23 Dec 2024 08:53:43 -0800 Subject: [PATCH] Integrate explainability for hybrid query into RRF processor (#1037) * Integrate explainability for hybrid query into RRF processor Signed-off-by: Martin Gaievski --- .../AbstractScoreHybridizationProcessor.java | 65 +++++++ .../ExplanationResponseProcessor.java | 3 +- .../processor/NormalizationProcessor.java | 36 +--- .../neuralsearch/processor/RRFProcessor.java | 19 +- .../RRFScoreCombinationTechnique.java | 18 +- .../combination/ScoreCombinationFactory.java | 2 +- .../RRFNormalizationTechnique.java | 71 ++++++-- ...tractScoreHybridizationProcessorTests.java | 152 ++++++++++++++++ ...=> ExplanationResponseProcessorTests.java} | 116 ++++++++++++- .../processor/RRFProcessorTests.java | 33 ++++ .../RRFScoreCombinationTechniqueTests.java | 44 ++++- .../RRFNormalizationTechniqueTests.java | 54 ++++++ .../query/HybridQueryExplainIT.java | 162 ++++++++++++++++-- .../neuralsearch/BaseNeuralSearchIT.java | 24 ++- 14 files changed, 708 insertions(+), 91 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java rename src/test/java/org/opensearch/neuralsearch/processor/{ExplanationPayloadProcessorTests.java => ExplanationResponseProcessorTests.java} (76%) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java new file mode 100644 index 000000000..456e8415a --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +import java.util.Optional; + +/** + * Base class for all score hybridization processors. This class is responsible for executing the score hybridization process. + * It is a pipeline processor that is executed after the query phase and before the fetch phase. + */ +public abstract class AbstractScoreHybridizationProcessor implements SearchPhaseResultsProcessor { + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor. This method is called when there is no pipeline context + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + */ + @Override + public void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ) { + hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.empty()); + } + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor. This method is called when there is pipeline context + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline + * @param + */ + @Override + public void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final PipelineProcessingContext requestContext + ) { + hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); + } + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult + * @param searchPhaseContext + * @param requestContextOptional + * @param + */ + abstract void hybridizeScores( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional + ); +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java index 01cdfcb0d..7a61519f8 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -111,8 +111,9 @@ public SearchResponse processResponse( ); } // Create and set final explanation combining all components + Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore(); Explanation finalExplanation = Explanation.match( - searchHit.getScore(), + finalScore, // combination level explanation is always a single detail combinationExplanation.getScoreDetails().get(0).getValue(), normalizedExplanation diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index d2008ae97..2b1a28d01 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -19,9 +19,7 @@ import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.fetch.FetchSearchResult; -import org.opensearch.search.internal.SearchContext; import org.opensearch.search.pipeline.PipelineProcessingContext; -import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.query.QuerySearchResult; import lombok.AllArgsConstructor; @@ -33,7 +31,7 @@ */ @Log4j2 @AllArgsConstructor -public class NormalizationProcessor implements SearchPhaseResultsProcessor { +public class NormalizationProcessor extends AbstractScoreHybridizationProcessor { public static final String TYPE = "normalization-processor"; private final String tag; @@ -42,38 +40,8 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor { private final ScoreCombinationTechnique combinationTechnique; private final NormalizationProcessorWorkflow normalizationWorkflow; - /** - * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage - * are set as part of class constructor. This method is called when there is no pipeline context - * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution - * @param searchPhaseContext {@link SearchContext} - */ @Override - public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext - ) { - prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty()); - } - - /** - * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage - * are set as part of class constructor - * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution - * @param searchPhaseContext {@link SearchContext} - * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline - * @param - */ - @Override - public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext, - final PipelineProcessingContext requestContext - ) { - prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); - } - - private void prepareAndExecuteNormalizationWorkflow( + void hybridizeScores( SearchPhaseResults searchPhaseResult, SearchPhaseContext searchPhaseContext, Optional requestContextOptional diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java index ca67f2d1c..100cf9fc6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java @@ -12,11 +12,12 @@ import java.util.Objects; import java.util.Optional; +import com.google.common.annotations.VisibleForTesting; import lombok.Getter; -import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.action.search.QueryPhaseResultConsumer; @@ -25,7 +26,6 @@ import org.opensearch.action.search.SearchPhaseResults; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.internal.SearchContext; -import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; @@ -39,7 +39,7 @@ */ @Log4j2 @AllArgsConstructor -public class RRFProcessor implements SearchPhaseResultsProcessor { +public class RRFProcessor extends AbstractScoreHybridizationProcessor { public static final String TYPE = "score-ranker-processor"; @Getter @@ -57,9 +57,10 @@ public class RRFProcessor implements SearchPhaseResultsProcessor { * @param searchPhaseContext {@link SearchContext} */ @Override - public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext + void hybridizeScores( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional ) { if (shouldSkipProcessor(searchPhaseResult)) { log.debug("Query results are not compatible with RRF processor"); @@ -67,7 +68,8 @@ public void process( } List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); - + boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain()) + && searchPhaseContext.getRequest().source().explain(); // make data transfer object to pass in, execute will get object with 4 or 5 fields, depending // on coming from NormalizationProcessor or RRFProcessor NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() @@ -75,7 +77,8 @@ public void process( .fetchSearchResultOptional(fetchSearchResult) .normalizationTechnique(normalizationTechnique) .combinationTechnique(combinationTechnique) - .explain(false) + .explain(explain) + .pipelineProcessingContext(requestContextOptional.orElse(null)) .build(); normalizationWorkflow.execute(normalizationExecuteDTO); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java index befe14dda..6d6c94b94 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java @@ -6,27 +6,39 @@ import lombok.ToString; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; -import java.util.Map; +import java.util.List; +import java.util.Objects; + +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique; @Log4j2 /** * Abstracts combination of scores based on reciprocal rank fusion algorithm */ @ToString(onlyExplicitlyIncluded = true) -public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique { +public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "rrf"; // Not currently using weights for RRF, no need to modify or verify these params - public RRFScoreCombinationTechnique(final Map params, final ScoreCombinationUtil combinationUtil) {} + public RRFScoreCombinationTechnique() {} @Override public float combine(final float[] scores) { + if (Objects.isNull(scores)) { + throw new IllegalArgumentException("scores array cannot be null"); + } float sumScores = 0.0f; for (float score : scores) { sumScores += score; } return sumScores; } + + @Override + public String describe() { + return describeCombinationTechnique(TECHNIQUE_NAME, List.of()); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index 1e560342a..3f1996424 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -27,7 +27,7 @@ public class ScoreCombinationFactory { GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME, params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil), RRFScoreCombinationTechnique.TECHNIQUE_NAME, - params -> new RRFScoreCombinationTechnique(params, scoreCombinationUtil) + params -> new RRFScoreCombinationTechnique() ); /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java index 16ef83d05..80fc65eb3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -6,27 +6,34 @@ import java.math.BigDecimal; import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Objects; import java.util.Locale; import java.util.Set; +import java.util.function.BiConsumer; +import java.util.stream.IntStream; import org.apache.commons.lang3.Range; import org.apache.commons.lang3.math.NumberUtils; -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; import lombok.ToString; import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; + +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization; /** * Abstracts calculation of rank scores for each document returned as part of * reciprocal rank fusion. Rank scores are summed across subqueries in combination classes. */ @ToString(onlyExplicitlyIncluded = true) -public class RRFNormalizationTechnique implements ScoreNormalizationTechnique { +public class RRFNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "rrf"; public static final int DEFAULT_RANK_CONSTANT = 60; @@ -58,21 +65,49 @@ public RRFNormalizationTechnique(final Map params, final ScoreNo public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { - if (Objects.isNull(compoundQueryTopDocs)) { - continue; - } - List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); - for (TopDocs topDocs : topDocsPerSubQuery) { - int docsCountPerSubQuery = topDocs.scoreDocs.length; - ScoreDoc[] scoreDocs = topDocs.scoreDocs; - for (int j = 0; j < docsCountPerSubQuery; j++) { - // using big decimal approach to minimize error caused by floating point ops - // score = 1.f / (float) (rankConstant + j + 1)) - scoreDocs[j].score = BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + j + 1), 10, RoundingMode.HALF_UP) - .floatValue(); - } - } + processTopDocs(compoundQueryTopDocs, (docId, score) -> {}); + } + } + + @Override + public String describe() { + return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, rankConstant); + } + + @Override + public Map explain(List queryTopDocs) { + Map> normalizedScores = new HashMap<>(); + + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + processTopDocs( + compoundQueryTopDocs, + (docId, score) -> normalizedScores.computeIfAbsent(docId, k -> new ArrayList<>()).add(score) + ); } + + return getDocIdAtQueryForNormalization(normalizedScores, this); + } + + private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, BiConsumer scoreProcessor) { + if (Objects.isNull(compoundQueryTopDocs)) { + return; + } + + compoundQueryTopDocs.getTopDocs().forEach(topDocs -> { + IntStream.range(0, topDocs.scoreDocs.length).forEach(position -> { + float normalizedScore = calculateNormalizedScore(position); + DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard( + topDocs.scoreDocs[position].doc, + compoundQueryTopDocs.getSearchShard() + ); + scoreProcessor.accept(docIdAtSearchShard, normalizedScore); + topDocs.scoreDocs[position].score = normalizedScore; + }); + }); + } + + private float calculateNormalizedScore(int position) { + return BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + position + 1), 10, RoundingMode.HALF_UP).floatValue(); } private int getRankConstant(final Map params) { @@ -96,7 +131,7 @@ private void validateRankConstant(final int rankConstant) { } } - public static int getParamAsInteger(final Map parameters, final String fieldName) { + private static int getParamAsInteger(final Map parameters, final String fieldName) { try { return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName))); } catch (NumberFormatException e) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java new file mode 100644 index 000000000..4e9ab59e5 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java @@ -0,0 +1,152 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.util.concurrent.AtomicArray; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Optional; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class AbstractScoreHybridizationProcessorTests extends OpenSearchTestCase { + private static final String TEST_TAG = "test_processor"; + private static final String TEST_DESCRIPTION = "Test Processor"; + + private TestScoreHybridizationProcessor processor; + private NormalizationProcessorWorkflow normalizationWorkflow; + + private static class TestScoreHybridizationProcessor extends AbstractScoreHybridizationProcessor { + private final String tag; + private final String description; + private final NormalizationProcessorWorkflow normalizationWorkflow1; + + TestScoreHybridizationProcessor(String tag, String description, NormalizationProcessorWorkflow normalizationWorkflow) { + this.tag = tag; + this.description = description; + normalizationWorkflow1 = normalizationWorkflow; + } + + @Override + void hybridizeScores( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional + ) { + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .pipelineProcessingContext(requestContextOptional.orElse(null)) + .build(); + normalizationWorkflow1.execute(normalizationExecuteDTO); + } + + @Override + public SearchPhaseName getBeforePhase() { + return SearchPhaseName.FETCH; + } + + @Override + public SearchPhaseName getAfterPhase() { + return SearchPhaseName.QUERY; + } + + @Override + public String getType() { + return "my_processor"; + } + + @Override + public String getTag() { + return tag; + } + + @Override + public String getDescription() { + return description; + } + + @Override + public boolean isIgnoreFailure() { + return false; + } + } + + @Before + public void setup() { + normalizationWorkflow = mock(NormalizationProcessorWorkflow.class); + + processor = new TestScoreHybridizationProcessor(TEST_TAG, TEST_DESCRIPTION, normalizationWorkflow); + } + + public void testProcessorMetadata() { + assertEquals(TEST_TAG, processor.getTag()); + assertEquals(TEST_DESCRIPTION, processor.getDescription()); + } + + public void testProcessWithExplanations() { + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + SearchPhaseContext context = mock(SearchPhaseContext.class); + SearchPhaseResults results = mock(SearchPhaseResults.class); + + sourceBuilder.explain(true); + searchRequest.source(sourceBuilder); + when(context.getRequest()).thenReturn(searchRequest); + + AtomicArray resultsArray = new AtomicArray<>(1); + QuerySearchResult queryResult = createQuerySearchResult(); + resultsArray.set(0, queryResult); + when(results.getAtomicArray()).thenReturn(resultsArray); + + TestScoreHybridizationProcessor spyProcessor = spy(processor); + spyProcessor.process(results, context); + + verify(spyProcessor).hybridizeScores(any(SearchPhaseResults.class), any(SearchPhaseContext.class), any(Optional.class)); + verify(normalizationWorkflow).execute(any()); + } + + public void testProcess() { + SearchPhaseResults searchPhaseResult = mock(SearchPhaseResults.class); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + PipelineProcessingContext requestContext = mock(PipelineProcessingContext.class); + + TestScoreHybridizationProcessor spyProcessor = spy(processor); + spyProcessor.process(searchPhaseResult, searchPhaseContext, requestContext); + + verify(spyProcessor).hybridizeScores(any(SearchPhaseResults.class), any(SearchPhaseContext.class), any(Optional.class)); + } + + private QuerySearchResult createQuerySearchResult() { + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("test", 1), + new SearchShardTarget("node1", new ShardId("index", "uuid", 0), null, OriginalIndices.NONE), + null + ); + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(0, 1.0f) }); + result.topDocs(new TopDocsAndMaxScore(topDocs, 1.0f), new DocValueFormat[0]); + return result; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java similarity index 76% rename from src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java index e47ea43d2..530753a96 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java @@ -37,9 +37,10 @@ import java.util.TreeMap; import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_FLOATS_ASSERTION; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; -public class ExplanationPayloadProcessorTests extends OpenSearchTestCase { +public class ExplanationResponseProcessorTests extends OpenSearchTestCase { private static final String PROCESSOR_TAG = "mockTag"; private static final String DESCRIPTION = "mockDescription"; @@ -192,6 +193,119 @@ public void testParsingOfExplanations_whenScoreSortingAndExplanations_thenSucces assertOnExplanationResults(processedResponse, maxScore); } + @SneakyThrows + public void testProcessResponse_whenNullSearchHits_thenNoOp() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchResponse searchResponse = getSearchResponse(null); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenEmptySearchHits_thenNoOp() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits emptyHits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f); + SearchResponse searchResponse = getSearchResponse(emptyHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenNullExplanation_thenSkipProcessing() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(1.0f); + for (SearchHit hit : searchHits.getHits()) { + hit.explanation(null); + } + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenInvalidExplanationPayload_thenHandleGracefully() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(1.0f); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + // Set invalid payload + Map invalidPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + "invalid payload" + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(invalidPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertNotNull(processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenZeroScore_thenProcessCorrectly() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(0.0f); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertNotNull(processedResponse); + assertEquals(0.0f, processedResponse.getHits().getMaxScore(), DELTA_FOR_SCORE_ASSERTION); + } + + @SneakyThrows + public void testProcessResponse_whenScoreIsNaN_thenExplanationUsesZero() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + // Create SearchHits with NaN score + SearchHits searchHits = getSearchHits(Float.NaN); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + // Setup explanation payload + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + // Process response + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + + // Verify results + assertNotNull(processedResponse); + SearchHit[] hits = processedResponse.getHits().getHits(); + assertNotNull(hits); + assertTrue(hits.length > 0); + + // Verify that the explanation uses 0.0f when input score was NaN + Explanation explanation = hits[0].getExplanation(); + assertNotNull(explanation); + assertEquals(0.0f, (float) explanation.getValue(), DELTA_FOR_FLOATS_ASSERTION); + } + private static SearchHits getSearchHits(float maxScore) { int numResponses = 1; int numIndices = 2; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java index b7764128f..753c0b8fe 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java @@ -9,6 +9,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.junit.Before; +import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; import org.opensearch.action.OriginalIndices; @@ -111,6 +112,10 @@ public void testProcess_whenValidHybridInput_thenSucceed() { when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder()); + when(mockSearchPhaseContext.getRequest()).thenReturn(searchRequest); + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); verify(mockNormalizationWorkflow).execute(any(NormalizationProcessorWorkflowExecuteRequest.class)); @@ -177,6 +182,34 @@ public void testGetFetchSearchResults() { assertFalse(result.isPresent()); } + @SneakyThrows + public void testProcess_whenExplainIsTrue_thenExplanationIsAdded() { + QuerySearchResult result = createQuerySearchResult(true); + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.explain(true); + searchRequest.source(sourceBuilder); + when(mockSearchPhaseContext.getRequest()).thenReturn(searchRequest); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + // Capture the actual request + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass( + NormalizationProcessorWorkflowExecuteRequest.class + ); + verify(mockNormalizationWorkflow).execute(requestCaptor.capture()); + + // Verify the captured request + NormalizationProcessorWorkflowExecuteRequest capturedRequest = requestCaptor.getValue(); + assertTrue(capturedRequest.isExplain()); + assertNull(capturedRequest.getPipelineProcessingContext()); + } + private QuerySearchResult createQuerySearchResult(boolean isHybrid) { ShardId shardId = new ShardId("index", "uuid", 0); OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed()); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java index daed466d3..39b4dd4e3 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java @@ -5,26 +5,62 @@ package org.opensearch.neuralsearch.processor.combination; import java.util.List; -import java.util.Map; public class RRFScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { - private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + private static final int RANK_CONSTANT = 60; + private RRFScoreCombinationTechnique combinationTechnique; public RRFScoreCombinationTechniqueTests() { this.expectedScoreFunction = (scores, weights) -> RRF(scores, weights); + combinationTechnique = new RRFScoreCombinationTechnique(); } public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(); testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); } public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { - ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(Map.of(), scoreCombinationUtil); + ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(); testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); } + public void testDescribe() { + String description = combinationTechnique.describe(); + assertEquals("rrf", description); + } + + public void testCombineWithEmptyInput() { + float[] scores = new float[0]; + float result = combinationTechnique.combine(scores); + assertEquals(0.0f, result, 0.001f); + } + + public void testCombineWithSingleScore() { + float[] scores = new float[] { 0.5f }; + float result = combinationTechnique.combine(scores); + assertEquals(0.5f, result, 0.001f); + } + + public void testCombineWithMultipleScores() { + float[] scores = new float[] { 0.8f, 0.6f, 0.4f }; + float result = combinationTechnique.combine(scores); + float expected = 0.8f + 0.6f + 0.4f; + assertEquals(expected, result, 0.001f); + } + + public void testCombineWithZeroScores() { + float[] scores = new float[] { 0.0f, 0.0f }; + float result = combinationTechnique.combine(scores); + assertEquals(0.0f, result, 0.001f); + } + + public void testCombineWithNullInput() { + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> combinationTechnique.combine(null)); + assertEquals("scores array cannot be null", exception.getMessage()); + } + private float RRF(List scores, List weights) { float sumScores = 0.0f; for (float score : scores) { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java index 1e1089846..da6d37bd7 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java @@ -10,10 +10,14 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; import java.math.BigDecimal; import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.Collections; import java.util.List; import java.util.Map; @@ -25,6 +29,16 @@ public class RRFNormalizationTechniqueTests extends OpenSearchQueryTestCase { private ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); + public void testDescribe() { + // verify with default values for parameters + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + assertEquals("rrf, rank_constant [60]", normalizationTechnique.describe()); + + // verify when parameter values are set + normalizationTechnique = new RRFNormalizationTechnique(Map.of("rank_constant", 25), scoreNormalizationUtil); + assertEquals("rrf, rank_constant [25]", normalizationTechnique.describe()); + } + public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); float[] scores = { 0.5f, 0.2f }; @@ -224,6 +238,27 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the } } + public void testExplainWithEmptyAndNullList() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + normalizationTechnique.explain(List.of()); + + List compoundTopDocs = new ArrayList<>(); + compoundTopDocs.add(null); + normalizationTechnique.explain(compoundTopDocs); + } + + public void testExplainWithSingleTopDocs() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + CompoundTopDocs topDocs = createCompoundTopDocs(new float[] { 0.8f }, 1); + List queryTopDocs = Collections.singletonList(topDocs); + + Map explanation = normalizationTechnique.explain(queryTopDocs); + + assertNotNull(explanation); + assertEquals(1, explanation.size()); + assertTrue(explanation.containsKey(new DocIdAtSearchShard(0, new SearchShard("test_index", 0, "uuid")))); + } + private float rrfNorm(int rank) { // 1.0f / (float) (rank + RANK_CONSTANT + 1); return BigDecimal.ONE.divide(BigDecimal.valueOf(rank + RANK_CONSTANT + 1), 10, RoundingMode.HALF_UP).floatValue(); @@ -239,4 +274,23 @@ private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex); } } + + private CompoundTopDocs createCompoundTopDocs(float[] scores, int size) { + ScoreDoc[] scoreDocs = new ScoreDoc[size]; + for (int i = 0; i < size; i++) { + scoreDocs[i] = new ScoreDoc(i, scores[i]); + } + TopDocs singleTopDocs = new TopDocs(new TotalHits(size, TotalHits.Relation.EQUAL_TO), scoreDocs); + + List topDocsList = Collections.singletonList(singleTopDocs); + TopDocs topDocs = new TopDocs(new TotalHits(size, TotalHits.Relation.EQUAL_TO), scoreDocs); + SearchShard searchShard = new SearchShard("test_index", 0, "uuid"); + + return new CompoundTopDocs( + new TotalHits(size, TotalHits.Relation.EQUAL_TO), + topDocsList, + false, // isSortEnabled + searchShard + ); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java index b7e4f753a..c6eaa21ff 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java @@ -58,7 +58,8 @@ public class HybridQueryExplainIT extends BaseNeuralSearchIT { private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); - private static final String SEARCH_PIPELINE = "phase-results-hybrid-pipeline"; + private static final String NORMALIZATION_SEARCH_PIPELINE = "normalization-search-pipeline"; + private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; static final Supplier TEST_VECTOR_SUPPLIER = () -> new float[768]; @@ -78,7 +79,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { try { initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); @@ -95,7 +96,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { hybridQueryBuilderNeuralThenTerm, null, 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // search hits @@ -187,7 +188,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit3Details.get("description")); assertEquals(3, getListOfValues(explanationsHit3Details, "details").size()); } finally { - wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -196,7 +197,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() try { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); createSearchPipeline( - SEARCH_PIPELINE, + NORMALIZATION_SEARCH_PIPELINE, NORMALIZATION_TECHNIQUE_L2, DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f })), @@ -217,7 +218,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() hybridQueryBuilder, null, 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // basic sanity check for search hits @@ -322,7 +323,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() assertEquals(0, getListOfValues(explanationsHit4, "details").size()); assertTrue((double) explanationsHit4.get("value") > 0.0f); } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -331,7 +332,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe try { initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); // create search pipeline with normalization processor, no explanation response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false); TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); @@ -348,7 +349,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe hybridQueryBuilderNeuralThenTerm, null, 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // search hits @@ -463,7 +464,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe assertEquals("boost", explanationsHit3Details.get("description")); assertEquals(0, getListOfValues(explanationsHit3Details, "details").size()); } finally { - wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -472,7 +473,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { try { initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -483,7 +484,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { hybridQueryBuilder, null, MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); List> hitsNestedList = getNestedHits(searchResponseAsMap); @@ -521,7 +522,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { } assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1))); } finally { - wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -530,7 +531,7 @@ public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() { try { initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); hybridQueryBuilder.add(QueryBuilders.multiMatchQuery(TEST_QUERY_TEXT3, TEST_TEXT_FIELD_NAME_1, TEST_TEXT_FIELD_NAME_2)); @@ -543,7 +544,7 @@ public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() { hybridQueryBuilder, null, MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); List> hitsNestedList = getNestedHits(searchResponseAsMap); @@ -581,7 +582,136 @@ public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() { } assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1))); } finally { - wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testExplain_whenRRFProcessor_thenSuccessful() { + try { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createRRFSearchPipeline(RRF_SEARCH_PIPELINE, true); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1) + .vector(createRandomVector(TEST_DIMENSION)) + .k(10) + .build(); + hybridQueryBuilder.add(QueryBuilders.existsQuery(TEST_TEXT_FIELD_NAME_1)); + hybridQueryBuilder.add(knnQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_NAME, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", RRF_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + // Assert + // basic sanity check for search hits + assertEquals(4, getHitCount(searchResponseAsMap)); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + float actualMaxScore = getMaxScore(searchResponseAsMap).get(); + assertTrue(actualMaxScore > 0); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain, hit 1 + List> hitsNestedList = getNestedHits(searchResponseAsMap); + Map searchHit1 = hitsNestedList.get(0); + Map explanationForHit1 = getValueByKey(searchHit1, "_explanation"); + assertNotNull(explanationForHit1); + assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedTopLevelDescription = "rrf combination of:"; + assertEquals(expectedTopLevelDescription, explanationForHit1.get("description")); + List> hit1Details = getListOfValues(explanationForHit1, "details"); + assertEquals(2, hit1Details.size()); + // two sub-queries meaning we do have two detail objects with separate query level details + Map hit1DetailsForHit1 = hit1Details.get(0); + assertTrue((double) hit1DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); + assertEquals("rrf, rank_constant [60] normalization of:", hit1DetailsForHit1.get("description")); + assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); + + Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); + assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit1.get("description")); + assertTrue((double) explanationsHit1.get("value") > 0.5f); + assertEquals(0, ((List) explanationsHit1.get("details")).size()); + + Map hit1DetailsForHit2 = hit1Details.get(1); + assertTrue((double) hit1DetailsForHit2.get("value") > 0.0f); + assertEquals("rrf, rank_constant [60] normalization of:", hit1DetailsForHit2.get("description")); + assertEquals(1, ((List) hit1DetailsForHit2.get("details")).size()); + + Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); + assertEquals("within top 10", explanationsHit2.get("description")); + assertTrue((double) explanationsHit2.get("value") > 0.0f); + assertEquals(0, ((List) explanationsHit2.get("details")).size()); + + // hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map explanationForHit2 = getValueByKey(searchHit2, "_explanation"); + assertNotNull(explanationForHit2); + assertEquals((double) searchHit2.get("_score"), (double) explanationForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit2.get("description")); + List> hit2Details = getListOfValues(explanationForHit2, "details"); + assertEquals(2, hit2Details.size()); + + Map hit2DetailsForHit1 = hit2Details.get(0); + assertTrue((double) hit2DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); + assertEquals("rrf, rank_constant [60] normalization of:", hit2DetailsForHit1.get("description")); + assertEquals(1, ((List) hit2DetailsForHit1.get("details")).size()); + + Map hit2DetailsForHit2 = hit2Details.get(1); + assertTrue((double) hit2DetailsForHit2.get("value") > DELTA_FOR_SCORE_ASSERTION); + assertEquals("rrf, rank_constant [60] normalization of:", hit2DetailsForHit2.get("description")); + assertEquals(1, ((List) hit2DetailsForHit2.get("details")).size()); + + // hit 3 + Map searchHit3 = hitsNestedList.get(2); + Map explanationForHit3 = getValueByKey(searchHit3, "_explanation"); + assertNotNull(explanationForHit3); + assertEquals((double) searchHit3.get("_score"), (double) explanationForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit3.get("description")); + List> hit3Details = getListOfValues(explanationForHit3, "details"); + assertEquals(1, hit3Details.size()); + + Map hit3DetailsForHit1 = hit3Details.get(0); + assertTrue((double) hit3DetailsForHit1.get("value") > .0f); + assertEquals("rrf, rank_constant [60] normalization of:", hit3DetailsForHit1.get("description")); + assertEquals(1, ((List) hit3DetailsForHit1.get("details")).size()); + + Map explanationsHit3 = getListOfValues(hit3DetailsForHit1, "details").get(0); + assertEquals("within top 10", explanationsHit3.get("description")); + assertEquals(0, getListOfValues(explanationsHit3, "details").size()); + assertTrue((double) explanationsHit3.get("value") > 0.0f); + + // hit 4 + Map searchHit4 = hitsNestedList.get(3); + Map explanationForHit4 = getValueByKey(searchHit4, "_explanation"); + assertNotNull(explanationForHit4); + assertEquals((double) searchHit4.get("_score"), (double) explanationForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit4.get("description")); + List> hit4Details = getListOfValues(explanationForHit4, "details"); + assertEquals(1, hit4Details.size()); + + Map hit4DetailsForHit1 = hit4Details.get(0); + assertTrue((double) hit4DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); + assertEquals("rrf, rank_constant [60] normalization of:", hit4DetailsForHit1.get("description")); + assertEquals(1, ((List) hit4DetailsForHit1.get("details")).size()); + + Map explanationsHit4 = getListOfValues(hit4DetailsForHit1, "details").get(0); + assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit4.get("description")); + assertEquals(0, getListOfValues(explanationsHit4, "details").size()); + assertTrue((double) explanationsHit4.get("value") > 0.0f); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, RRF_SEARCH_PIPELINE); } } diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 5107296c3..46e835cd0 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -1472,7 +1472,12 @@ protected enum ProcessorType { @SneakyThrows protected void createDefaultRRFSearchPipeline() { - String requestBody = XContentFactory.jsonBuilder() + createRRFSearchPipeline(RRF_SEARCH_PIPELINE, false); + } + + @SneakyThrows + protected void createRRFSearchPipeline(final String pipelineName, boolean addExplainResponseProcessor) { + XContentBuilder builder = XContentFactory.jsonBuilder() .startObject() .field("description", "Post processor for hybrid search") .startArray("phase_results_processors") @@ -1483,14 +1488,23 @@ protected void createDefaultRRFSearchPipeline() { .endObject() .endObject() .endObject() - .endArray() - .endObject() - .toString(); + .endArray(); + + if (addExplainResponseProcessor) { + builder.startArray("response_processors") + .startObject() + .startObject("explanation_response_processor") + .endObject() + .endObject() + .endArray(); + } + + String requestBody = builder.endObject().toString(); makeRequest( client(), "PUT", - String.format(LOCALE, "/_search/pipeline/%s", RRF_SEARCH_PIPELINE), + String.format(LOCALE, "/_search/pipeline/%s", pipelineName), null, toHttpEntity(String.format(LOCALE, requestBody)), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT))