From 0b0e376a4b2af5ff3238eb4148fccbeb3fa83c4f Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Mon, 13 Jan 2025 16:59:10 -0800 Subject: [PATCH] Adding Reciprocal Rank Fusion (RRF) in hybrid query (#1086) * Reciprocal Rank Fusion (RRF) normalization technique in hybrid query (#874) --------- Signed-off-by: Isaac Johnson Signed-off-by: Ryan Bogan Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../neuralsearch/bwc/KnnRadialSearchIT.java | 1 - .../neuralsearch/plugin/NeuralSearch.java | 10 +- .../AbstractScoreHybridizationProcessor.java | 65 ++++ .../ExplanationResponseProcessor.java | 3 +- .../processor/InferenceProcessor.java | 7 +- .../processor/NormalizationExecuteDTO.java | 35 +++ .../processor/NormalizationProcessor.java | 36 +-- .../NormalizationProcessorWorkflow.java | 38 +-- .../processor/NormalizeScoresDTO.java | 26 ++ .../neuralsearch/processor/RRFProcessor.java | 146 +++++++++ .../RRFScoreCombinationTechnique.java | 44 +++ .../combination/ScoreCombinationFactory.java | 4 +- .../combination/ScoreCombinationUtil.java | 5 +- .../factory/RRFProcessorFactory.java | 79 +++++ .../L2ScoreNormalizationTechnique.java | 4 +- .../MinMaxScoreNormalizationTechnique.java | 4 +- .../RRFNormalizationTechnique.java | 141 +++++++++ .../ScoreNormalizationFactory.java | 18 +- .../ScoreNormalizationTechnique.java | 14 +- .../normalization/ScoreNormalizationUtil.java | 57 ++++ .../normalization/ScoreNormalizer.java | 15 +- .../plugin/NeuralSearchTests.java | 11 +- ...tractScoreHybridizationProcessorTests.java | 152 +++++++++ ...=> ExplanationResponseProcessorTests.java} | 116 ++++++- .../NormalizationProcessorTests.java | 6 +- .../NormalizationProcessorWorkflowTests.java | 100 +++--- .../processor/RRFProcessorIT.java | 93 ++++++ .../processor/RRFProcessorTests.java | 259 +++++++++++++++ .../ScoreNormalizationTechniqueTests.java | 31 +- .../TextEmbeddingProcessorTests.java | 113 +++---- .../RRFScoreCombinationTechniqueTests.java | 71 +++++ 
.../ScoreCombinationFactoryTests.java | 8 + ....java => ScoreNormalizationUtilTests.java} | 2 +- .../NormalizationProcessorFactoryTests.java | 21 +- .../factory/RRFProcessorFactoryTests.java | 214 +++++++++++++ .../L2ScoreNormalizationTechniqueTests.java | 21 +- ...inMaxScoreNormalizationTechniqueTests.java | 19 +- .../RRFNormalizationTechniqueTests.java | 296 ++++++++++++++++++ .../ScoreNormalizationFactoryTests.java | 8 + .../query/HybridQueryExplainIT.java | 162 +++++++++- .../query/OpenSearchQueryTestCase.java | 2 + .../query/HybridCollectorManagerTests.java | 1 - .../HybridQueryScoreDocsMergerTests.java | 2 - .../search/query/TopDocsMergerTests.java | 2 - .../neuralsearch/BaseNeuralSearchIT.java | 42 +++ 46 files changed, 2245 insertions(+), 260 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NormalizationExecuteDTO.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java rename src/test/java/org/opensearch/neuralsearch/processor/{ExplanationPayloadProcessorTests.java => ExplanationResponseProcessorTests.java} (76%) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java create mode 100644 
src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java rename src/test/java/org/opensearch/neuralsearch/processor/combination/{ScoreCombinationUtilTests.java => ScoreNormalizationUtilTests.java} (97%) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 12248ccad..3b5c2ca0d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features - Pagination in Hybrid query ([#1048](https://github.com/opensearch-project/neural-search/pull/1048)) +- Implement Reciprocal Rank Fusion score normalization/combination technique in hybrid query ([#874](https://github.com/opensearch-project/neural-search/pull/874)) ### Enhancements - Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970)) - Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988)) diff --git a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java index c3f461871..838d7ae8a 100644 --- a/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java +++ b/qa/restart-upgrade/src/test/java/org/opensearch/neuralsearch/bwc/KnnRadialSearchIT.java @@ -69,7 +69,6 @@ private void validateIndexQuery(final String modelId) { .modelId(modelId) .maxDistance(100000f) .build(); - Map 
responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1); assertNotNull(responseWithMaxDistanceQuery); } diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 1350a7963..f7ac5d19f 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -30,22 +30,24 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; -import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.TextChunkingProcessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory; -import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import 
org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; @@ -157,7 +159,9 @@ public Map void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext + ) { + hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.empty()); + } + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor. This method is called when there is pipeline context + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline + * @param + */ + @Override + public void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final PipelineProcessingContext requestContext + ) { + hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); + } + + /** + * Method abstracts functional aspect of score normalization and score combination. 
Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult + * @param searchPhaseContext + * @param requestContextOptional + * @param + */ + abstract void hybridizeScores( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional + ); +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java index 1cdd69b15..3423a2e29 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -111,8 +111,9 @@ public SearchResponse processResponse( ); } // Create and set final explanation combining all components + Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore(); Explanation finalExplanation = Explanation.match( - searchHit.getScore(), + finalScore, // combination level explanation is always a single detail combinationExplanation.getScoreDetails().get(0).getValue(), normalizedExplanation diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index 3fb45ceeb..6ee54afe7 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -319,7 +319,7 @@ Map buildMapWithTargetKeys(IngestDocument ingestDocument) { buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes); mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); } else { - mapWithProcessorKeys.put(String.valueOf(targetKey), normalizeSourceValue(sourceAndMetadataMap.get(originalKey))); + mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); } } 
return mapWithProcessorKeys; @@ -357,9 +357,8 @@ void buildNestedMap(String parentKey, Object processorKey, Map s } treeRes.merge(parentKey, next, REMAPPING_FUNCTION); } else { - Object parentValue = sourceAndMetadataMap.get(parentKey); String key = String.valueOf(processorKey); - treeRes.put(key, normalizeSourceValue(parentValue)); + treeRes.put(key, sourceAndMetadataMap.get(parentKey)); } } @@ -404,7 +403,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { indexName, clusterService, environment, - true + false ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationExecuteDTO.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationExecuteDTO.java new file mode 100644 index 000000000..260c06d67 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationExecuteDTO.java @@ -0,0 +1,35 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.query.QuerySearchResult; + +import java.util.List; +import java.util.Optional; + +/** + * DTO object to hold data in NormalizationProcessorWorkflow class + * in NormalizationProcessorWorkflow. 
+ */ +@AllArgsConstructor +@Builder +@Getter +public class NormalizationExecuteDTO { + @NonNull + private List querySearchResults; + @NonNull + private Optional fetchSearchResultOptional; + @NonNull + private ScoreNormalizationTechnique normalizationTechnique; + @NonNull + private ScoreCombinationTechnique combinationTechnique; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index d2fa03fde..80499543e 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -19,9 +19,7 @@ import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.fetch.FetchSearchResult; -import org.opensearch.search.internal.SearchContext; import org.opensearch.search.pipeline.PipelineProcessingContext; -import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.query.QuerySearchResult; import lombok.AllArgsConstructor; @@ -33,7 +31,7 @@ */ @Log4j2 @AllArgsConstructor -public class NormalizationProcessor implements SearchPhaseResultsProcessor { +public class NormalizationProcessor extends AbstractScoreHybridizationProcessor { public static final String TYPE = "normalization-processor"; private final String tag; @@ -42,38 +40,8 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor { private final ScoreCombinationTechnique combinationTechnique; private final NormalizationProcessorWorkflow normalizationWorkflow; - /** - * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage - * are set as part of class constructor. 
This method is called when there is no pipeline context - * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution - * @param searchPhaseContext {@link SearchContext} - */ @Override - public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext - ) { - prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty()); - } - - /** - * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage - * are set as part of class constructor - * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution - * @param searchPhaseContext {@link SearchContext} - * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline - * @param - */ - @Override - public void process( - final SearchPhaseResults searchPhaseResult, - final SearchPhaseContext searchPhaseContext, - final PipelineProcessingContext requestContext - ) { - prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); - } - - private void prepareAndExecuteNormalizationWorkflow( + void hybridizeScores( SearchPhaseResults searchPhaseResult, SearchPhaseContext searchPhaseContext, Optional requestContextOptional diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index db3747a13..51f30f842 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -22,14 +22,12 @@ import org.opensearch.action.search.SearchPhaseContext; import 
org.opensearch.common.lucene.search.TopDocsAndMaxScore; import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; -import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; -import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; @@ -57,34 +55,14 @@ public class NormalizationProcessorWorkflow { /** * Start execution of this workflow - * @param querySearchResults input data with QuerySearchResult from multiple shards - * @param normalizationTechnique technique for score normalization - * @param combinationTechnique technique for score combination + * @param request contains querySearchResults input data with QuerySearchResult + * from multiple shards, fetchSearchResultOptional, normalizationTechnique technique for score normalization + * combinationTechnique technique for score combination, and nullable rankConstant only used in RRF technique */ - public void execute( - final List querySearchResults, - final Optional fetchSearchResultOptional, - final ScoreNormalizationTechnique normalizationTechnique, - final ScoreCombinationTechnique combinationTechnique, - final SearchPhaseContext searchPhaseContext - ) { - NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder() - .querySearchResults(querySearchResults) - 
.fetchSearchResultOptional(fetchSearchResultOptional) - .normalizationTechnique(normalizationTechnique) - .combinationTechnique(combinationTechnique) - .explain(false) - .searchPhaseContext(searchPhaseContext) - .build(); - execute(request); - } - public void execute(final NormalizationProcessorWorkflowExecuteRequest request) { List querySearchResults = request.getQuerySearchResults(); Optional fetchSearchResultOptional = request.getFetchSearchResultOptional(); - - // save original state - List unprocessedDocIds = unprocessedDocIds(querySearchResults); + List unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults()); // pre-process data log.debug("Pre-process query results"); @@ -92,9 +70,15 @@ public void execute(final NormalizationProcessorWorkflowExecuteRequest request) explain(request, queryTopDocs); + // Data transfer object for score normalization used to pass nullable rankConstant which is only used in RRF + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(request.getNormalizationTechnique()) + .build(); + // normalize log.debug("Do score normalization"); - scoreNormalizer.normalizeScores(queryTopDocs, request.getNormalizationTechnique()); + scoreNormalizer.normalizeScores(normalizeScoresDTO); CombineScoresDto combineScoresDTO = CombineScoresDto.builder() .queryTopDocs(queryTopDocs) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java new file mode 100644 index 000000000..c932a157d --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizeScoresDTO.java @@ -0,0 +1,26 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import lombok.NonNull; +import 
org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; + +import java.util.List; + +/** + * DTO object to hold data required for score normalization. + */ +@AllArgsConstructor +@Builder +@Getter +public class NormalizeScoresDTO { + @NonNull + private List queryTopDocs; + @NonNull + private ScoreNormalizationTechnique normalizationTechnique; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java new file mode 100644 index 000000000..cf9e3b820 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java @@ -0,0 +1,146 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; + +import java.util.stream.Collectors; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; + +import com.google.common.annotations.VisibleForTesting; +import lombok.Getter; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.query.QuerySearchResult; + +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.internal.SearchContext; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +/** + * Processor for implementing reciprocal rank fusion technique on post + * query search 
results. Updates query results with + * normalized and combined scores for next phase (typically it's FETCH) + * by using ranks from individual subqueries to calculate 'normalized' + * scores before combining results from subqueries into final results + */ +@Log4j2 +@AllArgsConstructor +public class RRFProcessor extends AbstractScoreHybridizationProcessor { + public static final String TYPE = "score-ranker-processor"; + + @Getter + private final String tag; + @Getter + private final String description; + private final ScoreNormalizationTechnique normalizationTechnique; + private final ScoreCombinationTechnique combinationTechnique; + private final NormalizationProcessorWorkflow normalizationWorkflow; + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + */ + @Override + void hybridizeScores( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional + ) { + if (shouldSkipProcessor(searchPhaseResult)) { + log.debug("Query results are not compatible with RRF processor"); + return; + } + List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); + Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); + boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain()) + && searchPhaseContext.getRequest().source().explain(); + // make data transfer object to pass in, execute will get object with 4 or 5 fields, depending + // on coming from NormalizationProcessor or RRFProcessor + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + 
.fetchSearchResultOptional(fetchSearchResult) + .normalizationTechnique(normalizationTechnique) + .combinationTechnique(combinationTechnique) + .explain(explain) + .pipelineProcessingContext(requestContextOptional.orElse(null)) + .searchPhaseContext(searchPhaseContext) + .build(); + normalizationWorkflow.execute(normalizationExecuteDTO); + } + + @Override + public SearchPhaseName getBeforePhase() { + return SearchPhaseName.QUERY; + } + + @Override + public SearchPhaseName getAfterPhase() { + return SearchPhaseName.FETCH; + } + + @Override + public String getType() { + return TYPE; + } + + @Override + public boolean isIgnoreFailure() { + return false; + } + + @VisibleForTesting + boolean shouldSkipProcessor(SearchPhaseResults searchPhaseResult) { + if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) { + return true; + } + + return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery); + } + + /** + * Return true if results are from hybrid query. + * @param searchPhaseResult + * @return true if results are from hybrid query + */ + @VisibleForTesting + boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { + // check for delimiter at the end of the score docs. + return Objects.nonNull(searchPhaseResult.queryResult()) + && Objects.nonNull(searchPhaseResult.queryResult().topDocs()) + && Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs) + && searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0 + && isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]); + } + + List getQueryPhaseSearchResults(final SearchPhaseResults results) { + return results.getAtomicArray() + .asList() + .stream() + .map(result -> result == null ? 
null : result.queryResult()) + .collect(Collectors.toList()); + } + + @VisibleForTesting + Optional getFetchSearchResults( + final SearchPhaseResults searchPhaseResults + ) { + Optional optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst(); + return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java new file mode 100644 index 000000000..0f43688a6 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechnique.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.combination; + +import lombok.ToString; +import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; + +import java.util.List; +import java.util.Objects; + +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique; + +/** + * Abstracts combination of scores based on reciprocal rank fusion algorithm + */ +@Log4j2 +@ToString(onlyExplicitlyIncluded = true) +public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "rrf"; + + // Not currently using weights for RRF, no need to modify or verify these params + public RRFScoreCombinationTechnique() {} + + @Override + public float combine(final float[] scores) { + if (Objects.isNull(scores)) { + throw new IllegalArgumentException("scores array cannot be null"); + } + float sumScores = 0.0f; + for (float score : scores) { + sumScores += score; + } + return sumScores; + } + + @Override + public String describe() { + return 
describeCombinationTechnique(TECHNIQUE_NAME, List.of()); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java index 23d8e01be..3f1996424 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactory.java @@ -25,7 +25,9 @@ public class ScoreCombinationFactory { HarmonicMeanScoreCombinationTechnique.TECHNIQUE_NAME, params -> new HarmonicMeanScoreCombinationTechnique(params, scoreCombinationUtil), GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME, - params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil) + params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil), + RRFScoreCombinationTechnique.TECHNIQUE_NAME, + params -> new RRFScoreCombinationTechnique() ); /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java index 5f18baf09..99d0401d2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -26,6 +26,7 @@ public class ScoreCombinationUtil { public static final String PARAM_NAME_WEIGHTS = "weights"; private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; + private static final float DELTA_FOR_WEIGHTS_ASSERTION = 0.01f; /** * Get collection of weights based on user provided config @@ -117,7 +118,7 @@ protected void validateIfWeightsMatchScores(final float[] scores, final List weightsList) { - boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight)); + boolean isOutOfRange = 
weightsList.stream().anyMatch(weight -> !Range.of(0.0f, 1.0f).contains(weight)); if (isOutOfRange) { throw new IllegalArgumentException( String.format( @@ -128,7 +129,7 @@ private void validateWeights(final List weightsList) { ); } float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum); - if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) { + if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_WEIGHTS_ASSERTION)) { throw new IllegalArgumentException( String.format( Locale.ROOT, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java new file mode 100644 index 000000000..fa4f39942 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactory.java @@ -0,0 +1,79 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import java.util.Map; +import java.util.Objects; + +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.combination.RRFScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.RRFNormalizationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; + +import lombok.AllArgsConstructor; +import lombok.extern.log4j.Log4j2; + +import static 
org.opensearch.ingest.ConfigurationUtils.readOptionalMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; + +/** + * Factory class to instantiate RRF processor based on user provided input. + */ +@AllArgsConstructor +@Log4j2 +public class RRFProcessorFactory implements Processor.Factory { + public static final String COMBINATION_CLAUSE = "combination"; + public static final String TECHNIQUE = "technique"; + public static final String PARAMETERS = "parameters"; + + private final NormalizationProcessorWorkflow normalizationProcessorWorkflow; + private ScoreNormalizationFactory scoreNormalizationFactory; + private ScoreCombinationFactory scoreCombinationFactory; + + @Override + public SearchPhaseResultsProcessor create( + final Map> processorFactories, + final String tag, + final String description, + final boolean ignoreFailure, + final Map config, + final Processor.PipelineContext pipelineContext + ) throws Exception { + // assign defaults + ScoreNormalizationTechnique normalizationTechnique = scoreNormalizationFactory.createNormalization( + RRFNormalizationTechnique.TECHNIQUE_NAME + ); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination( + RRFScoreCombinationTechnique.TECHNIQUE_NAME + ); + Map combinationClause = readOptionalMap(RRFProcessor.TYPE, tag, config, COMBINATION_CLAUSE); + if (Objects.nonNull(combinationClause)) { + String combinationTechnique = readStringProperty( + RRFProcessor.TYPE, + tag, + combinationClause, + TECHNIQUE, + RRFScoreCombinationTechnique.TECHNIQUE_NAME + ); + // check for optional combination params + Map params = readOptionalMap(RRFProcessor.TYPE, tag, combinationClause, PARAMETERS); + normalizationTechnique = scoreNormalizationFactory.createNormalization(RRFNormalizationTechnique.TECHNIQUE_NAME, params); + scoreCombinationTechnique = scoreCombinationFactory.createCombination(combinationTechnique); + } + log.info( + "Creating search phase results processor of 
type [{}] with normalization [{}] and combination [{}]", + RRFProcessor.TYPE, + normalizationTechnique, + scoreCombinationTechnique + ); + return new RRFProcessor(tag, description, normalizationTechnique, scoreCombinationTechnique, normalizationProcessorWorkflow); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index e7fbf658c..c9472938d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -14,6 +14,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import lombok.ToString; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; @@ -39,7 +40,8 @@ public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechniqu * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query */ @Override - public void normalize(final List queryTopDocs) { + public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { + List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); // get l2 norms for each sub-query List normsPerSubquery = getL2Norm(queryTopDocs); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 0e54919d2..7da4c4330 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -21,6 +21,7 @@ import com.google.common.primitives.Floats; import lombok.ToString; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; @@ -45,7 +46,8 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech * - iterate over each result and update score as per formula above where "score" is raw score returned by Hybrid query */ @Override - public void normalize(final List queryTopDocs) { + public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { + final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); MinMaxScores minMaxScores = getMinMaxScoresResult(queryTopDocs); // do normalization using actual score and min and max scores for corresponding sub query for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java new file mode 100644 index 000000000..80fc65eb3 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechnique.java @@ -0,0 +1,141 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.normalization; + +import java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Locale; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.stream.IntStream; + +import 
org.apache.commons.lang3.Range; +import org.apache.commons.lang3.math.NumberUtils; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; + +import lombok.ToString; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; + +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization; + +/** + * Abstracts calculation of rank scores for each document returned as part of + * reciprocal rank fusion. Rank scores are summed across subqueries in combination classes. + */ +@ToString(onlyExplicitlyIncluded = true) +public class RRFNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique { + @ToString.Include + public static final String TECHNIQUE_NAME = "rrf"; + public static final int DEFAULT_RANK_CONSTANT = 60; + public static final String PARAM_NAME_RANK_CONSTANT = "rank_constant"; + private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_RANK_CONSTANT); + private static final int MIN_RANK_CONSTANT = 1; + private static final int MAX_RANK_CONSTANT = 10_000; + private static final Range RANK_CONSTANT_RANGE = Range.of(MIN_RANK_CONSTANT, MAX_RANK_CONSTANT); + @ToString.Include + private final int rankConstant; + + public RRFNormalizationTechnique(final Map params, final ScoreNormalizationUtil scoreNormalizationUtil) { + scoreNormalizationUtil.validateParams(params, SUPPORTED_PARAMS); + rankConstant = getRankConstant(params); + } + + /** + * Reciprocal Rank Fusion normalization technique + * @param normalizeScoresDTO is a data transfer object that contains queryTopDocs + * original query results from multiple shards and multiple sub-queries, ScoreNormalizationTechnique, + * and nullable rankConstant, which has a default value of 60 if not specified by user 
+ * algorithm as follows, where document_n_score is the new score for each document in queryTopDocs + * and subquery_result_rank is the position in the array of documents returned for each subquery + * (j + 1 is used to adjust for 0 indexing) + * document_n_score = 1 / (rankConstant + subquery_result_rank) + * document scores are summed in combination step + */ + @Override + public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { + final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + processTopDocs(compoundQueryTopDocs, (docId, score) -> {}); + } + } + + @Override + public String describe() { + return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, rankConstant); + } + + @Override + public Map explain(List queryTopDocs) { + Map> normalizedScores = new HashMap<>(); + + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + processTopDocs( + compoundQueryTopDocs, + (docId, score) -> normalizedScores.computeIfAbsent(docId, k -> new ArrayList<>()).add(score) + ); + } + + return getDocIdAtQueryForNormalization(normalizedScores, this); + } + + private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, BiConsumer scoreProcessor) { + if (Objects.isNull(compoundQueryTopDocs)) { + return; + } + + compoundQueryTopDocs.getTopDocs().forEach(topDocs -> { + IntStream.range(0, topDocs.scoreDocs.length).forEach(position -> { + float normalizedScore = calculateNormalizedScore(position); + DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard( + topDocs.scoreDocs[position].doc, + compoundQueryTopDocs.getSearchShard() + ); + scoreProcessor.accept(docIdAtSearchShard, normalizedScore); + topDocs.scoreDocs[position].score = normalizedScore; + }); + }); + } + + private float calculateNormalizedScore(int position) { + return BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + position + 1), 10, RoundingMode.HALF_UP).floatValue(); + } + + private int 
getRankConstant(final Map params) { + if (Objects.isNull(params) || !params.containsKey(PARAM_NAME_RANK_CONSTANT)) { + return DEFAULT_RANK_CONSTANT; + } + int rankConstant = getParamAsInteger(params, PARAM_NAME_RANK_CONSTANT); + validateRankConstant(rankConstant); + return rankConstant; + } + + private void validateRankConstant(final int rankConstant) { + if (!RANK_CONSTANT_RANGE.contains(rankConstant)) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "rank constant must be in the interval between 1 and 10000, submitted rank constant: %d", + rankConstant + ) + ); + } + } + + private static int getParamAsInteger(final Map parameters, final String fieldName) { + try { + return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName))); + } catch (NumberFormatException e) { + throw new IllegalArgumentException(String.format(Locale.ROOT, "parameter [%s] must be an integer", fieldName)); + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java index ca6ad20d6..7c62893a5 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java @@ -6,19 +6,24 @@ import java.util.Map; import java.util.Optional; +import java.util.function.Function; /** * Abstracts creation of exact score normalization method based on technique name */ public class ScoreNormalizationFactory { + private static final ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); + public static final ScoreNormalizationTechnique DEFAULT_METHOD = new MinMaxScoreNormalizationTechnique(); - private final Map scoreNormalizationMethodsMap = Map.of( + private final Map, ScoreNormalizationTechnique>> scoreNormalizationMethodsMap = Map.of( 
MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME, - new MinMaxScoreNormalizationTechnique(), + params -> new MinMaxScoreNormalizationTechnique(), L2ScoreNormalizationTechnique.TECHNIQUE_NAME, - new L2ScoreNormalizationTechnique() + params -> new L2ScoreNormalizationTechnique(), + RRFNormalizationTechnique.TECHNIQUE_NAME, + params -> new RRFNormalizationTechnique(params, scoreNormalizationUtil) ); /** @@ -27,7 +32,12 @@ public class ScoreNormalizationFactory { * @return instance of ScoreNormalizationMethod for technique name */ public ScoreNormalizationTechnique createNormalization(final String technique) { + return createNormalization(technique, Map.of()); + } + + public ScoreNormalizationTechnique createNormalization(final String technique, final Map params) { return Optional.ofNullable(scoreNormalizationMethodsMap.get(technique)) - .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported")); + .orElseThrow(() -> new IllegalArgumentException("provided normalization technique is not supported")) + .apply(params); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java index 0b784c678..f8190a728 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java @@ -4,9 +4,7 @@ */ package org.opensearch.neuralsearch.processor.normalization; -import java.util.List; - -import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; /** * Abstracts normalization of scores in query search results. @@ -14,8 +12,12 @@ public interface ScoreNormalizationTechnique { /** - * Performs score normalization based on input normalization technique. 
Mutates input object by updating normalized scores. - @param queryTopDocs original query results from multiple shards and multiple sub-queries + * Performs score normalization based on input normalization technique. + * Mutates input object by updating normalized scores. + * @param normalizeScoresDTO is a data transfer object that contains queryTopDocs + * original query results from multiple shards and multiple sub-queries, ScoreNormalizationTechnique, + * and nullable rankConstant that is only used in RRF technique */ - void normalize(final List queryTopDocs); + void normalize(final NormalizeScoresDTO normalizeScoresDTO); + } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java new file mode 100644 index 000000000..ad24b0aaa --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationUtil.java @@ -0,0 +1,57 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.normalization; + +import lombok.extern.log4j.Log4j2; + +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; +import java.util.Set; + +/** + * Collection of utility methods for score normalization technique classes + */ +@Log4j2 +class ScoreNormalizationUtil { + private static final String PARAM_NAME_WEIGHTS = "weights"; + private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; + + /** + * Validate config parameters for this technique + * @param actualParams map of parameters in form of name-value + * @param supportedParams collection of parameters that we should validate against, typically that's what is supported by exact technique + */ + public void validateParams(final Map actualParams, final Set supportedParams) { + if (Objects.isNull(actualParams) ||
actualParams.isEmpty()) { + return; + } + // check if only supported params are passed + Optional optionalNotSupportedParam = actualParams.keySet() + .stream() + .filter(paramName -> !supportedParams.contains(paramName)) + .findFirst(); + if (optionalNotSupportedParam.isPresent()) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "provided parameter for combination technique is not supported. supported parameters are [%s]", + String.join(",", supportedParams) + ) + ); + } + + // check param types + if (actualParams.keySet().stream().anyMatch(PARAM_NAME_WEIGHTS::equalsIgnoreCase)) { + if (!(actualParams.get(PARAM_NAME_WEIGHTS) instanceof List)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "parameter [%s] must be a collection of numbers", PARAM_NAME_WEIGHTS) + ); + } + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java index 67a17fda2..381ec9b9a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java @@ -12,17 +12,22 @@ import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; public class ScoreNormalizer { /** - * Performs score normalization based on input normalization technique. Mutates input object by updating normalized scores. - * @param queryTopDocs original query results from multiple shards and multiple sub-queries - * @param scoreNormalizationTechnique exact normalization technique that should be applied + * Performs score normalization based on input normalization technique. 
+ * Mutates input object by updating normalized scores. + * @param normalizeScoresDTO used as data transfer object to pass in queryTopDocs, original query results + * from multiple shards and multiple sub-queries, scoreNormalizationTechnique exact normalization technique + * that should be applied, and nullable rankConstant that is only used in RRF technique */ - public void normalizeScores(final List queryTopDocs, final ScoreNormalizationTechnique scoreNormalizationTechnique) { + public void normalizeScores(final NormalizeScoresDTO normalizeScoresDTO) { + final List queryTopDocs = normalizeScoresDTO.getQueryTopDocs(); + final ScoreNormalizationTechnique scoreNormalizationTechnique = normalizeScoresDTO.getNormalizationTechnique(); if (canQueryResultsBeNormalized(queryTopDocs)) { - scoreNormalizationTechnique.normalize(queryTopDocs); + scoreNormalizationTechnique.normalize(normalizeScoresDTO); } } diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index 9a969e71b..a4ad9f2d4 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -27,8 +27,10 @@ import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.RRFProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory; import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; @@ -143,12 +145,19 @@ 
public void testSearchPhaseResultsProcessors() { Map> searchPhaseResultsProcessors = plugin .getSearchPhaseResultsProcessors(searchParameters); assertNotNull(searchPhaseResultsProcessors); - assertEquals(1, searchPhaseResultsProcessors.size()); + assertEquals(2, searchPhaseResultsProcessors.size()); + // assert normalization processor conditions assertTrue(searchPhaseResultsProcessors.containsKey("normalization-processor")); org.opensearch.search.pipeline.Processor.Factory scoringProcessor = searchPhaseResultsProcessors.get( NormalizationProcessor.TYPE ); assertTrue(scoringProcessor instanceof NormalizationProcessorFactory); + // assert rrf processor conditions + assertTrue(searchPhaseResultsProcessors.containsKey("score-ranker-processor")); + org.opensearch.search.pipeline.Processor.Factory rankingProcessor = searchPhaseResultsProcessors.get( + RRFProcessor.TYPE + ); + assertTrue(rankingProcessor instanceof RRFProcessorFactory); } public void testGetSettings() { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java new file mode 100644 index 000000000..4e9ab59e5 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/AbstractScoreHybridizationProcessorTests.java @@ -0,0 +1,152 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; 
+import org.opensearch.common.util.concurrent.AtomicArray; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.Optional; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class AbstractScoreHybridizationProcessorTests extends OpenSearchTestCase { + private static final String TEST_TAG = "test_processor"; + private static final String TEST_DESCRIPTION = "Test Processor"; + + private TestScoreHybridizationProcessor processor; + private NormalizationProcessorWorkflow normalizationWorkflow; + + private static class TestScoreHybridizationProcessor extends AbstractScoreHybridizationProcessor { + private final String tag; + private final String description; + private final NormalizationProcessorWorkflow normalizationWorkflow1; + + TestScoreHybridizationProcessor(String tag, String description, NormalizationProcessorWorkflow normalizationWorkflow) { + this.tag = tag; + this.description = description; + normalizationWorkflow1 = normalizationWorkflow; + } + + @Override + void hybridizeScores( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional + ) { + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .pipelineProcessingContext(requestContextOptional.orElse(null)) + .build(); + 
normalizationWorkflow1.execute(normalizationExecuteDTO); + } + + @Override + public SearchPhaseName getBeforePhase() { + return SearchPhaseName.FETCH; + } + + @Override + public SearchPhaseName getAfterPhase() { + return SearchPhaseName.QUERY; + } + + @Override + public String getType() { + return "my_processor"; + } + + @Override + public String getTag() { + return tag; + } + + @Override + public String getDescription() { + return description; + } + + @Override + public boolean isIgnoreFailure() { + return false; + } + } + + @Before + public void setup() { + normalizationWorkflow = mock(NormalizationProcessorWorkflow.class); + + processor = new TestScoreHybridizationProcessor(TEST_TAG, TEST_DESCRIPTION, normalizationWorkflow); + } + + public void testProcessorMetadata() { + assertEquals(TEST_TAG, processor.getTag()); + assertEquals(TEST_DESCRIPTION, processor.getDescription()); + } + + public void testProcessWithExplanations() { + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + SearchPhaseContext context = mock(SearchPhaseContext.class); + SearchPhaseResults results = mock(SearchPhaseResults.class); + + sourceBuilder.explain(true); + searchRequest.source(sourceBuilder); + when(context.getRequest()).thenReturn(searchRequest); + + AtomicArray resultsArray = new AtomicArray<>(1); + QuerySearchResult queryResult = createQuerySearchResult(); + resultsArray.set(0, queryResult); + when(results.getAtomicArray()).thenReturn(resultsArray); + + TestScoreHybridizationProcessor spyProcessor = spy(processor); + spyProcessor.process(results, context); + + verify(spyProcessor).hybridizeScores(any(SearchPhaseResults.class), any(SearchPhaseContext.class), any(Optional.class)); + verify(normalizationWorkflow).execute(any()); + } + + public void testProcess() { + SearchPhaseResults searchPhaseResult = mock(SearchPhaseResults.class); + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + 
PipelineProcessingContext requestContext = mock(PipelineProcessingContext.class); + + TestScoreHybridizationProcessor spyProcessor = spy(processor); + spyProcessor.process(searchPhaseResult, searchPhaseContext, requestContext); + + verify(spyProcessor).hybridizeScores(any(SearchPhaseResults.class), any(SearchPhaseContext.class), any(Optional.class)); + } + + private QuerySearchResult createQuerySearchResult() { + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("test", 1), + new SearchShardTarget("node1", new ShardId("index", "uuid", 0), null, OriginalIndices.NONE), + null + ); + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(0, 1.0f) }); + result.topDocs(new TopDocsAndMaxScore(topDocs, 1.0f), new DocValueFormat[0]); + return result; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java similarity index 76% rename from src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java index 2e603d078..bfcd14251 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessorTests.java @@ -37,9 +37,10 @@ import java.util.TreeMap; import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_FLOATS_ASSERTION; import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; -public class ExplanationPayloadProcessorTests extends OpenSearchTestCase { +public class ExplanationResponseProcessorTests extends OpenSearchTestCase { private static final String PROCESSOR_TAG = "mockTag"; private static final String DESCRIPTION = 
"mockDescription"; @@ -192,6 +193,119 @@ public void testParsingOfExplanations_whenScoreSortingAndExplanations_thenSucces assertOnExplanationResults(processedResponse, maxScore); } + @SneakyThrows + public void testProcessResponse_whenNullSearchHits_thenNoOp() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchResponse searchResponse = getSearchResponse(null); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenEmptySearchHits_thenNoOp() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits emptyHits = new SearchHits(new SearchHit[0], new TotalHits(0, TotalHits.Relation.EQUAL_TO), 0.0f); + SearchResponse searchResponse = getSearchResponse(emptyHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenNullExplanation_thenSkipProcessing() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(1.0f); + for (SearchHit hit : searchHits.getHits()) { + hit.explanation(null); + } + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + SearchResponse processedResponse = processor.processResponse(searchRequest, 
searchResponse, context); + assertEquals(searchResponse, processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenInvalidExplanationPayload_thenHandleGracefully() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(1.0f); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + // Set invalid payload + Map invalidPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + "invalid payload" + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(invalidPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + assertNotNull(processedResponse); + } + + @SneakyThrows + public void testProcessResponse_whenZeroScore_thenProcessCorrectly() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchHits searchHits = getSearchHits(0.0f); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + SearchResponse processedResponse = 
processor.processResponse(searchRequest, searchResponse, context); + assertNotNull(processedResponse); + assertEquals(0.0f, processedResponse.getHits().getMaxScore(), DELTA_FOR_SCORE_ASSERTION); + } + + @SneakyThrows + public void testProcessResponse_whenScoreIsNaN_thenExplanationUsesZero() { + ExplanationResponseProcessor processor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + // Create SearchHits with NaN score + SearchHits searchHits = getSearchHits(Float.NaN); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext context = new PipelineProcessingContext(); + + // Setup explanation payload + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + context.setAttribute(org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, explanationPayload); + + // Process response + SearchResponse processedResponse = processor.processResponse(searchRequest, searchResponse, context); + + // Verify results + assertNotNull(processedResponse); + SearchHit[] hits = processedResponse.getHits().getHits(); + assertNotNull(hits); + assertTrue(hits.length > 0); + + // Verify that the explanation uses 0.0f when input score was NaN + Explanation explanation = hits[0].getExplanation(); + assertNotNull(explanation); + assertEquals(0.0f, (float) explanation.getValue(), DELTA_FOR_FLOATS_ASSERTION); + } + private static SearchHits getSearchHits(float maxScore) { int numResponses = 1; int numIndices = 2; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index 87dac8674..9f67327f1 
100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -273,8 +273,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl ); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); normalizationProcessor.process(null, searchPhaseContext); - - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any()); } public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { @@ -329,8 +328,7 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); - - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any()); } public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormalization() { diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 9969081a6..61828d822 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -81,13 +81,15 @@ public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationC searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - 
normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); } @@ -123,19 +125,22 @@ public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { querySearchResult.setShardIndex(shardId); querySearchResults.add(querySearchResult); } + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); SearchRequest searchRequest = mock(SearchRequest.class); SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder(); searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.empty(), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.empty()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScoresWithNoMatches(querySearchResults); } 
@@ -194,13 +199,15 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -260,13 +267,15 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + 
normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); @@ -318,16 +327,15 @@ public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchRe searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - expectThrows( - IllegalStateException.class, - () -> normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ) - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + expectThrows(IllegalStateException.class, () -> normalizationProcessorWorkflow.execute(normalizationExecuteDTO)); } public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful() { @@ -376,13 +384,15 @@ public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResu searchSourceBuilder.from(0); when(searchPhaseContext.getRequest()).thenReturn(searchRequest); when(searchRequest.source()).thenReturn(searchSourceBuilder); - normalizationProcessorWorkflow.execute( - querySearchResults, - Optional.of(fetchSearchResult), - ScoreNormalizationFactory.DEFAULT_METHOD, - ScoreCombinationFactory.DEFAULT_METHOD, - searchPhaseContext - ); + NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + 
.fetchSearchResultOptional(Optional.of(fetchSearchResult)) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .combinationTechnique(ScoreCombinationFactory.DEFAULT_METHOD) + .searchPhaseContext(searchPhaseContext) + .build(); + + normalizationProcessorWorkflow.execute(normalizationExecuteDTO); TestUtils.assertQueryResultScores(querySearchResults); TestUtils.assertFetchResultScores(fetchSearchResult, 4); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java new file mode 100644 index 000000000..fccabab5c --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.SneakyThrows; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; + +public class RRFProcessorIT extends BaseNeuralSearchIT { + + private int currentDoc = 1; + private static final String RRF_INDEX_NAME = "rrf-index"; + private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; + private static final String RRF_INGEST_PIPELINE = "rrf-ingest-pipeline"; + + private static final int RRF_DIMENSION = 5; + + @SneakyThrows + public void testRRF_whenValidInput_thenSucceed() { + try { + createPipelineProcessor(null, RRF_INGEST_PIPELINE, ProcessorType.TEXT_EMBEDDING); + prepareKnnIndex( + RRF_INDEX_NAME, + Collections.singletonList(new 
KNNFieldConfig("passage_embedding", RRF_DIMENSION, TEST_SPACE_TYPE)) + ); + addDocuments(); + createDefaultRRFSearchPipeline(); + + HybridQueryBuilder hybridQueryBuilder = getHybridQueryBuilder(); + + Map results = search( + RRF_INDEX_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", RRF_SEARCH_PIPELINE) + ); + Map hits = (Map) results.get("hits"); + ArrayList> hitsList = (ArrayList>) hits.get("hits"); + assertEquals(3, hitsList.size()); + assertEquals(0.016393442, (Double) hitsList.getFirst().get("_score"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.016129032, (Double) hitsList.get(1).get("_score"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.015873017, (Double) hitsList.getLast().get("_score"), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(RRF_INDEX_NAME, RRF_INGEST_PIPELINE, null, RRF_SEARCH_PIPELINE); + } + } + + private HybridQueryBuilder getHybridQueryBuilder() { + MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", "cowboy rodeo bronco"); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder.Builder().fieldName("passage_embedding") + .k(5) + .vector(new float[] { 0.1f, 1.2f, 2.3f, 3.4f, 4.5f }) + .build(); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(matchQueryBuilder); + hybridQueryBuilder.add(knnQueryBuilder); + return hybridQueryBuilder; + } + + @SneakyThrows + private void addDocuments() { + addDocument( + "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena .", + "4319130149.jpg" + ); + addDocument("A wild animal races across an uncut field with a minimal amount of trees .", "1775029934.jpg"); + addDocument( + "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco .", + "2664027527.jpg" + ); + addDocument("A man who is riding a wild horse in the rodeo is very near to falling off .", "4427058951.jpg"); + addDocument("A 
rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse .", "2691147709.jpg"); + } + + @SneakyThrows + private void addDocument(String description, String imageText) { + addDocument(RRF_INDEX_NAME, String.valueOf(currentDoc++), "text", description, "image_text", imageText); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java new file mode 100644 index 000000000..753c0b8fe --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java @@ -0,0 +1,259 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.SneakyThrows; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.util.concurrent.AtomicArray; +import org.opensearch.core.common.Strings; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil; +import org.opensearch.search.DocValueFormat; +import 
org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.internal.AliasFilter; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Optional; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class RRFProcessorTests extends OpenSearchTestCase { + + @Mock + private ScoreNormalizationTechnique mockNormalizationTechnique; + @Mock + private ScoreCombinationTechnique mockCombinationTechnique; + @Mock + private NormalizationProcessorWorkflow mockNormalizationWorkflow; + @Mock + private SearchPhaseResults mockSearchPhaseResults; + @Mock + private SearchPhaseContext mockSearchPhaseContext; + @Mock + private QueryPhaseResultConsumer mockQueryPhaseResultConsumer; + + private RRFProcessor rrfProcessor; + private static final String TAG = "tag"; + private static final String DESCRIPTION = "description"; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + MockitoAnnotations.openMocks(this); + rrfProcessor = new RRFProcessor(TAG, DESCRIPTION, mockNormalizationTechnique, mockCombinationTechnique, mockNormalizationWorkflow); + } + + @SneakyThrows + public void testGetType() { + assertEquals(RRFProcessor.TYPE, rrfProcessor.getType()); + } + + @SneakyThrows + public void testGetBeforePhase() { + assertEquals(SearchPhaseName.QUERY, rrfProcessor.getBeforePhase()); + } + + @SneakyThrows + public void testGetAfterPhase() { + assertEquals(SearchPhaseName.FETCH, rrfProcessor.getAfterPhase()); + } + + @SneakyThrows + public void testIsIgnoreFailure() { 
+ assertFalse(rrfProcessor.isIgnoreFailure()); + } + + @SneakyThrows + public void testProcess_whenNullSearchPhaseResult_thenSkipWorkflow() { + rrfProcessor.process(null, mockSearchPhaseContext); + verify(mockNormalizationWorkflow, never()).execute(any()); + } + + @SneakyThrows + public void testProcess_whenNonQueryPhaseResultConsumer_thenSkipWorkflow() { + rrfProcessor.process(mockSearchPhaseResults, mockSearchPhaseContext); + verify(mockNormalizationWorkflow, never()).execute(any()); + } + + @SneakyThrows + public void testProcess_whenValidHybridInput_thenSucceed() { + QuerySearchResult result = createQuerySearchResult(true); + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + SearchRequest searchRequest = new SearchRequest(); + searchRequest.source(new SearchSourceBuilder()); + when(mockSearchPhaseContext.getRequest()).thenReturn(searchRequest); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + verify(mockNormalizationWorkflow).execute(any(NormalizationProcessorWorkflowExecuteRequest.class)); + } + + @SneakyThrows + public void testProcess_whenValidNonHybridInput_thenSucceed() { + QuerySearchResult result = createQuerySearchResult(false); + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + verify(mockNormalizationWorkflow, never()).execute(any(NormalizationProcessorWorkflowExecuteRequest.class)); + } + + @SneakyThrows + public void testGetTag() { + assertEquals(TAG, rrfProcessor.getTag()); + } + + @SneakyThrows + public void testGetDescription() { + assertEquals(DESCRIPTION, rrfProcessor.getDescription()); + } + + @SneakyThrows + public void testShouldSkipProcessor() { + assertTrue(rrfProcessor.shouldSkipProcessor(null)); + 
assertTrue(rrfProcessor.shouldSkipProcessor(mockSearchPhaseResults)); + + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, createQuerySearchResult(false)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + assertTrue(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer)); + + atomicArray.set(0, createQuerySearchResult(true)); + assertFalse(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer)); + } + + @SneakyThrows + public void testGetQueryPhaseSearchResults() { + AtomicArray atomicArray = new AtomicArray<>(2); + atomicArray.set(0, createQuerySearchResult(true)); + atomicArray.set(1, createQuerySearchResult(false)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + List results = rrfProcessor.getQueryPhaseSearchResults(mockQueryPhaseResultConsumer); + assertEquals(2, results.size()); + assertNotNull(results.get(0)); + assertNotNull(results.get(1)); + } + + @SneakyThrows + public void testGetFetchSearchResults() { + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, createQuerySearchResult(true)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + Optional result = rrfProcessor.getFetchSearchResults(mockQueryPhaseResultConsumer); + assertFalse(result.isPresent()); + } + + @SneakyThrows + public void testProcess_whenExplainIsTrue_thenExplanationIsAdded() { + QuerySearchResult result = createQuerySearchResult(true); + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + SearchRequest searchRequest = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + sourceBuilder.explain(true); + searchRequest.source(sourceBuilder); + when(mockSearchPhaseContext.getRequest()).thenReturn(searchRequest); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + // 
Capture the actual request + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass( + NormalizationProcessorWorkflowExecuteRequest.class + ); + verify(mockNormalizationWorkflow).execute(requestCaptor.capture()); + + // Verify the captured request + NormalizationProcessorWorkflowExecuteRequest capturedRequest = requestCaptor.getValue(); + assertTrue(capturedRequest.isExplain()); + assertNull(capturedRequest.getPipelineProcessingContext()); + } + + private QuerySearchResult createQuerySearchResult(boolean isHybrid) { + ShardId shardId = new ShardId("index", "uuid", 0); + OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed()); + SearchRequest searchRequest = new SearchRequest("index"); + searchRequest.source(new SearchSourceBuilder()); + searchRequest.allowPartialSearchResults(true); + + int numberOfShards = 1; + AliasFilter aliasFilter = new AliasFilter(null, Strings.EMPTY_ARRAY); + float indexBoost = 1.0f; + long nowInMillis = System.currentTimeMillis(); + String clusterAlias = null; + String[] indexRoutings = Strings.EMPTY_ARRAY; + + ShardSearchRequest shardSearchRequest = new ShardSearchRequest( + originalIndices, + searchRequest, + shardId, + numberOfShards, + aliasFilter, + indexBoost, + nowInMillis, + clusterAlias, + indexRoutings + ); + + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("test", 1), + new SearchShardTarget("node1", shardId, clusterAlias, originalIndices), + shardSearchRequest + ); + result.from(0).size(10); + + ScoreDoc[] scoreDocs; + if (isHybrid) { + scoreDocs = new ScoreDoc[] { HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(0) }; + } else { + scoreDocs = new ScoreDoc[] { new ScoreDoc(0, 1.0f) }; + } + + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), scoreDocs); + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, 1.0f); + result.topDocs(topDocsAndMaxScore, new 
DocValueFormat[0]); + + return result; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java index b2b0007f6..fe7192ecd 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java @@ -12,6 +12,7 @@ import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; + import org.opensearch.test.OpenSearchTestCase; import lombok.SneakyThrows; @@ -22,7 +23,11 @@ public class ScoreNormalizationTechniqueTests extends OpenSearchTestCase { public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); - scoreNormalizationMethod.normalizeScores(List.of(), ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(List.of()) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); } @SneakyThrows @@ -36,7 +41,11 @@ public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenSco SEARCH_SHARD ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); CompoundTopDocs resultDoc = queryTopDocs.get(0); @@ -68,7 +77,11 @@ public void 
testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMe SEARCH_SHARD ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); CompoundTopDocs resultDoc = queryTopDocs.get(0); @@ -106,7 +119,11 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe SEARCH_SHARD ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); @@ -177,7 +194,11 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn SEARCH_SHARD ) ); - scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(queryTopDocs) + .normalizationTechnique(ScoreNormalizationFactory.DEFAULT_METHOD) + .build(); + scoreNormalizationMethod.normalizeScores(normalizeScoresDTO); assertNotNull(queryTopDocs); assertEquals(3, queryTopDocs.size()); // shard one diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 8fedd1fca..e42c9023b 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ 
b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -29,7 +29,6 @@ import java.util.function.Supplier; import java.util.stream.IntStream; -import org.apache.commons.lang3.StringUtils; import org.apache.commons.lang3.tuple.Pair; import org.junit.Before; import org.mockito.ArgumentCaptor; @@ -246,6 +245,31 @@ public void testExecute_withListTypeInput_successful() { verify(handler).accept(any(IngestDocument.class), isNull()); } + public void testExecute_SimpleTypeWithEmptyStringValue_throwIllegalArgumentException() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("key1", " "); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_listHasEmptyStringValue_throwIllegalArgumentException() { + List list1 = ImmutableList.of("", "test2", "test3"); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("key1", list1); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextEmbeddingProcessor processor = createInstanceWithLevel1MapConfig(); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + public void testExecute_listHasNonStringValue_throwIllegalArgumentException() { List list2 = ImmutableList.of(1, 2, 3); Map sourceAndMetadata = new HashMap<>(); @@ -614,6 +638,20 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } + public 
void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { + Map map1 = ImmutableMap.of("test1", "test2"); + Map map2 = ImmutableMap.of("test3", " "); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index"); + sourceAndMetadata.put("key1", map1); + sourceAndMetadata.put("key2", map2); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextEmbeddingProcessor processor = createInstanceWithLevel2MapConfig(); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { Map ret = createMaxDepthLimitExceedMap(() -> 1); Map sourceAndMetadata = new HashMap<>(); @@ -995,79 +1033,6 @@ public void testBuildVectorOutput_withNestedListHasNotForEmbeddingField_Level2_s assertNotNull(nestedObj.get(1).get("vectorField")); } - @SuppressWarnings("unchecked") - public void testBuildVectorOutput_withPlainString_EmptyString_skipped() { - Map config = createPlainStringConfiguration(); - IngestDocument ingestDocument = createPlainIngestDocument(); - Map sourceAndMetadata = ingestDocument.getSourceAndMetadata(); - sourceAndMetadata.put("oriKey1", StringUtils.EMPTY); - - TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); - List> modelTensorList = createRandomOneDimensionalMockVector(6, 100, 0.0f, 1.0f); - processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); - - /** IngestDocument - * "oriKey1": "", - * "oriKey2": "oriValue2", - * "oriKey3": "oriValue3", - * "oriKey4": "oriValue4", - * "oriKey5": "oriValue5", - * "oriKey6": [ - * "oriValue6", - * "oriValue7" - * ] - * - */ - assertEquals(11, sourceAndMetadata.size()); - assertFalse(sourceAndMetadata.containsKey("oriKey1_knn")); - } - - 
@SuppressWarnings("unchecked") - public void testBuildVectorOutput_withNestedField_EmptyString_skipped() { - Map config = createNestedMapConfiguration(); - IngestDocument ingestDocument = createNestedMapIngestDocument(); - Map favorites = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); - Map favorite = (Map) favorites.get("favorite"); - favorite.put("movie", StringUtils.EMPTY); - - TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); - Map knnMap = processor.buildMapWithTargetKeys(ingestDocument); - List> modelTensorList = createRandomOneDimensionalMockVector(1, 100, 0.0f, 1.0f); - processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); - - /** - * "favorites": { - * "favorite": { - * "movie": "", - * "actor": "Charlie Chaplin", - * "games" : { - * "adventure": { - * "action": "overwatch", - * "rpg": "elden ring" - * } - * } - * } - * } - */ - Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); - assertNotNull(favoritesMap); - Map favoriteMap = (Map) favoritesMap.get("favorite"); - assertNotNull(favoriteMap); - - Map favoriteGames = (Map) favoriteMap.get("games"); - assertNotNull(favoriteGames); - Map adventure = (Map) favoriteGames.get("adventure"); - List adventureKnnVector = (List) adventure.get("with_action_knn"); - assertNotNull(adventureKnnVector); - assertEquals(100, adventureKnnVector.size()); - for (float vector : adventureKnnVector) { - assertTrue(vector >= 0.0f && vector <= 1.0f); - } - - assertFalse(favoriteMap.containsKey("favorite_movie_knn")); - } - public void test_updateDocument_appendVectorFieldsToDocument_successful() { Map config = createPlainStringConfiguration(); IngestDocument ingestDocument = createPlainIngestDocument(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java 
b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java new file mode 100644 index 000000000..39b4dd4e3 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/RRFScoreCombinationTechniqueTests.java @@ -0,0 +1,71 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.combination; + +import java.util.List; + +public class RRFScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { + + private static final int RANK_CONSTANT = 60; + private RRFScoreCombinationTechnique combinationTechnique; + + public RRFScoreCombinationTechniqueTests() { + this.expectedScoreFunction = (scores, weights) -> RRF(scores, weights); + combinationTechnique = new RRFScoreCombinationTechnique(); + } + + public void testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(); + testLogic_whenAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { + ScoreCombinationTechnique technique = new RRFScoreCombinationTechnique(); + testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores(technique); + } + + public void testDescribe() { + String description = combinationTechnique.describe(); + assertEquals("rrf", description); + } + + public void testCombineWithEmptyInput() { + float[] scores = new float[0]; + float result = combinationTechnique.combine(scores); + assertEquals(0.0f, result, 0.001f); + } + + public void testCombineWithSingleScore() { + float[] scores = new float[] { 0.5f }; + float result = combinationTechnique.combine(scores); + assertEquals(0.5f, result, 0.001f); + } + + public void testCombineWithMultipleScores() { + float[] scores = new float[] { 0.8f, 0.6f, 0.4f }; + float result = combinationTechnique.combine(scores); + float 
expected = 0.8f + 0.6f + 0.4f; + assertEquals(expected, result, 0.001f); + } + + public void testCombineWithZeroScores() { + float[] scores = new float[] { 0.0f, 0.0f }; + float result = combinationTechnique.combine(scores); + assertEquals(0.0f, result, 0.001f); + } + + public void testCombineWithNullInput() { + IllegalArgumentException exception = expectThrows(IllegalArgumentException.class, () -> combinationTechnique.combine(null)); + assertEquals("scores array cannot be null", exception.getMessage()); + } + + private float RRF(List scores, List weights) { + float sumScores = 0.0f; + for (float score : scores) { + sumScores += score; + } + return sumScores; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java index b36a6b492..5ca534dac 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationFactoryTests.java @@ -34,6 +34,14 @@ public void testGeometricWeightedMean_whenCreatingByName_thenReturnCorrectInstan assertTrue(scoreCombinationTechnique instanceof GeometricMeanScoreCombinationTechnique); } + public void testRRF_whenCreatingByName_thenReturnCorrectInstance() { + ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + ScoreCombinationTechnique scoreCombinationTechnique = scoreCombinationFactory.createCombination("rrf"); + + assertNotNull(scoreCombinationTechnique); + assertTrue(scoreCombinationTechnique instanceof RRFScoreCombinationTechnique); + } + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); IllegalArgumentException illegalArgumentException = expectThrows( diff --git 
a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java similarity index 97% rename from src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java index 9e00e3833..009681116 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreNormalizationUtilTests.java @@ -12,7 +12,7 @@ import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -public class ScoreCombinationUtilTests extends OpenSearchQueryTestCase { +public class ScoreNormalizationUtilTests extends OpenSearchQueryTestCase { public void testCombinationWeights_whenEmptyInputPassed_thenCreateEmptyWeightCollection() { ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java index 9895d5b97..4cf1457ac 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java @@ -167,22 +167,21 @@ public void testWeightsParams_whenInvalidValues_thenFail() { String tag = "tag"; String description = "description"; boolean ignoreFailure = false; + + // First value is always 0.5 + double first = 0.5; + // Second value is random between 0.3 and 1.0 + double second = 0.3 + (RandomizedTest.randomDouble() * 0.7); + // Third value is random between 0.3 and 1.0 + double third = 0.3 + (RandomizedTest.randomDouble() * 0.7); + // This ensures minimum 
sum of 1.1 (0.5 + 0.3 + 0.3) + Map config = new HashMap<>(); config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max"))); config.put( COMBINATION_CLAUSE, new HashMap<>( - Map.of( - TECHNIQUE, - "arithmetic_mean", - PARAMETERS, - new HashMap<>( - Map.of( - "weights", - Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble(), RandomizedTest.randomDouble()) - ) - ) - ) + Map.of(TECHNIQUE, "arithmetic_mean", PARAMETERS, new HashMap<>(Map.of("weights", Arrays.asList(first, second, third)))) ) ); Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java new file mode 100644 index 000000000..3097402a0 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/RRFProcessorFactoryTests.java @@ -0,0 +1,214 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import lombok.SneakyThrows; +import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.RRFProcessor; +import org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; +import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +import static 
org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory.NORMALIZATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.PARAMETERS; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.COMBINATION_CLAUSE; +import static org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory.TECHNIQUE; + +public class RRFProcessorFactoryTests extends OpenSearchTestCase { + + @SneakyThrows + public void testDefaults_whenNoValuesPassed_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + @SneakyThrows + public void testCombinationParams_whenValidValues_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor 
searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsNegative_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", -1))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue( + exception.getMessage().contains("rank constant must be in the interval between 1 and 10000, submitted rank constant: -1") + ); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsTooLarge_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 50_000))))); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + 
IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue( + exception.getMessage().contains("rank constant must be in the interval between 1 and 10000, submitted rank constant: 50000") + ); + } + + @SneakyThrows + public void testInvalidCombinationParams_whenRankIsNotNumeric_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put( + COMBINATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", "string")))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("parameter [rank_constant] must be an integer")); + } + + @SneakyThrows + public void testInvalidCombinationName_whenUnsupportedFunction_thenFail() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put( + COMBINATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, "my_function", PARAMETERS, new 
HashMap<>(Map.of("rank_constant", 100)))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> rrfProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("provided combination technique is not supported")); + } + + @SneakyThrows + public void testInvalidTechniqueType_whenPassingNormalization_thenSuccessful() { + RRFProcessorFactory rrfProcessorFactory = new RRFProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + + Map config = new HashMap<>(); + config.put(COMBINATION_CLAUSE, new HashMap<>(Map.of(TECHNIQUE, "rrf", PARAMETERS, new HashMap<>(Map.of("rank_constant", 100))))); + config.put( + NORMALIZATION_CLAUSE, + new HashMap<>(Map.of(TECHNIQUE, ArithmeticMeanScoreCombinationTechnique.TECHNIQUE_NAME, PARAMETERS, new HashMap<>(Map.of()))) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + SearchPhaseResultsProcessor searchPhaseResultsProcessor = rrfProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + assertRRFProcessor(searchPhaseResultsProcessor); + } + + private static void assertRRFProcessor(SearchPhaseResultsProcessor searchPhaseResultsProcessor) { + assertNotNull(searchPhaseResultsProcessor); + assertTrue(searchPhaseResultsProcessor instanceof RRFProcessor); + RRFProcessor rrfProcessor = (RRFProcessor) searchPhaseResultsProcessor; + assertEquals("score-ranker-processor", rrfProcessor.getType()); + } +} diff --git 
a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java index 734f9bb57..fc1663d75 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java @@ -13,9 +13,10 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.processor.SearchShard; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; /** - * Abstracts normalization of scores based on min-max method + * Abstracts normalization of scores based on L2 method */ public class L2ScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase { private static final float DELTA_FOR_ASSERTION = 0.0001f; @@ -37,7 +38,11 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() SEARCH_SHARD ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -86,7 +91,11 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce SEARCH_SHARD ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(3, 
TotalHits.Relation.EQUAL_TO), @@ -163,7 +172,11 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the SEARCH_SHARD ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index c7692b407..85c54ea3a 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -11,6 +11,7 @@ import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; /** @@ -35,7 +36,11 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() SEARCH_SHARD ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -77,7 +82,11 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce SEARCH_SHARD ) ); - 
normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), @@ -135,7 +144,11 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the SEARCH_SHARD ) ); - normalizationTechnique.normalize(compoundTopDocs); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java new file mode 100644 index 000000000..da6d37bd7 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/RRFNormalizationTechniqueTests.java @@ -0,0 +1,296 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.normalization; + +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.NormalizeScoresDTO; +import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import 
java.math.BigDecimal; +import java.math.RoundingMode; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; + +/** + * Abstracts testing of normalization of scores based on RRF method + */ +public class RRFNormalizationTechniqueTests extends OpenSearchQueryTestCase { + static final int RANK_CONSTANT = 60; + private ScoreNormalizationUtil scoreNormalizationUtil = new ScoreNormalizationUtil(); + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); + + public void testDescribe() { + // verify with default values for parameters + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + assertEquals("rrf, rank_constant [60]", normalizationTechnique.describe()); + + // verify when parameter values are set + normalizationTechnique = new RRFNormalizationTechnique(Map.of("rank_constant", 25), scoreNormalizationUtil); + assertEquals("rrf, rank_constant [25]", normalizationTechnique.describe()); + } + + public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + float[] scores = { 0.5f, 0.2f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scores[0]), new ScoreDoc(4, scores[1]) } + ) + ), + false, + SEARCH_SHARD + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, 
TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ) + ), + false, + SEARCH_SHARD + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + assertCompoundTopDocs( + new TopDocs(expectedCompoundDocs.getTotalHits(), expectedCompoundDocs.getScoreDocs().toArray(new ScoreDoc[0])), + compoundTopDocs.get(0).getTopDocs().get(0) + ); + } + + public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + float[] scoresQuery1 = { 0.5f, 0.2f }; + float[] scoresQuery2 = { 0.9f, 0.7f, 0.1f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scoresQuery1[0]), new ScoreDoc(4, scoresQuery1[1]) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresQuery2[0]), + new ScoreDoc(4, scoresQuery2[1]), + new ScoreDoc(2, scoresQuery2[2]) } + ) + ), + false, + SEARCH_SHARD + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new 
ScoreDoc[] { new ScoreDoc(3, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)), new ScoreDoc(2, rrfNorm(2)) } + ) + ), + false, + SEARCH_SHARD + ); + assertNotNull(compoundTopDocs); + assertEquals(1, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocs.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocs.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); + } + } + + public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + float[] scoresShard1Query1 = { 0.5f, 0.2f }; + float[] scoresShard1and2Query3 = { 0.9f, 0.7f, 0.1f, 0.8f, 0.7f, 0.6f, 0.5f }; + float[] scoresShard2Query2 = { 2.9f, 0.7f }; + List compoundTopDocs = List.of( + new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, scoresShard1Query1[0]), new ScoreDoc(4, scoresShard1Query1[1]) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresShard1and2Query3[0]), + new ScoreDoc(4, scoresShard1and2Query3[1]), + new ScoreDoc(2, scoresShard1and2Query3[2]) } + ) + ), + false, + SEARCH_SHARD + ), + new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(7, scoresShard2Query2[0]), new ScoreDoc(9, scoresShard2Query2[1]) } + ), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, scoresShard1and2Query3[3]), + new ScoreDoc(9, scoresShard1and2Query3[4]), + new 
ScoreDoc(10, scoresShard1and2Query3[5]), + new ScoreDoc(15, scoresShard1and2Query3[6]) } + ) + ), + false, + SEARCH_SHARD + ) + ); + NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder() + .queryTopDocs(compoundTopDocs) + .normalizationTechnique(normalizationTechnique) + .build(); + normalizationTechnique.normalize(normalizeScoresDTO); + + CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(2, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)) } + ), + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(3, rrfNorm(0)), new ScoreDoc(4, rrfNorm(1)), new ScoreDoc(2, rrfNorm(2)) } + ) + ), + false, + SEARCH_SHARD + ); + + CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + List.of( + new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), + new TopDocs( + new TotalHits(2, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { new ScoreDoc(7, rrfNorm(0)), new ScoreDoc(9, rrfNorm(1)) } + ), + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(3, rrfNorm(3)), + new ScoreDoc(9, rrfNorm(4)), + new ScoreDoc(10, rrfNorm(5)), + new ScoreDoc(15, rrfNorm(6)) } + ) + ), + false, + SEARCH_SHARD + ); + + assertNotNull(compoundTopDocs); + assertEquals(2, compoundTopDocs.size()); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard1.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocsShard1.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); + } + assertNotNull(compoundTopDocs.get(1).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard2.getTopDocs().size(); i++) { + 
assertCompoundTopDocs(expectedCompoundDocsShard2.getTopDocs().get(i), compoundTopDocs.get(1).getTopDocs().get(i)); + } + } + + public void testExplainWithEmptyAndNullList() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + normalizationTechnique.explain(List.of()); + + List compoundTopDocs = new ArrayList<>(); + compoundTopDocs.add(null); + normalizationTechnique.explain(compoundTopDocs); + } + + public void testExplainWithSingleTopDocs() { + RRFNormalizationTechnique normalizationTechnique = new RRFNormalizationTechnique(Map.of(), scoreNormalizationUtil); + CompoundTopDocs topDocs = createCompoundTopDocs(new float[] { 0.8f }, 1); + List queryTopDocs = Collections.singletonList(topDocs); + + Map explanation = normalizationTechnique.explain(queryTopDocs); + + assertNotNull(explanation); + assertEquals(1, explanation.size()); + assertTrue(explanation.containsKey(new DocIdAtSearchShard(0, new SearchShard("test_index", 0, "uuid")))); + } + + private float rrfNorm(int rank) { + // 1.0f / (float) (rank + RANK_CONSTANT + 1); + return BigDecimal.ONE.divide(BigDecimal.valueOf(rank + RANK_CONSTANT + 1), 10, RoundingMode.HALF_UP).floatValue(); + } + + private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) { + assertEquals(expected.totalHits.value, actual.totalHits.value); + assertEquals(expected.totalHits.relation, actual.totalHits.relation); + assertEquals(expected.scoreDocs.length, actual.scoreDocs.length); + for (int i = 0; i < expected.scoreDocs.length; i++) { + assertEquals(expected.scoreDocs[i].score, actual.scoreDocs[i].score, DELTA_FOR_ASSERTION); + assertEquals(expected.scoreDocs[i].doc, actual.scoreDocs[i].doc); + assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex); + } + } + + private CompoundTopDocs createCompoundTopDocs(float[] scores, int size) { + ScoreDoc[] scoreDocs = new ScoreDoc[size]; + for (int i = 0; i < size; i++) { + scoreDocs[i] = new 
ScoreDoc(i, scores[i]); + } + TopDocs singleTopDocs = new TopDocs(new TotalHits(size, TotalHits.Relation.EQUAL_TO), scoreDocs); + + List topDocsList = Collections.singletonList(singleTopDocs); + TopDocs topDocs = new TopDocs(new TotalHits(size, TotalHits.Relation.EQUAL_TO), scoreDocs); + SearchShard searchShard = new SearchShard("test_index", 0, "uuid"); + + return new CompoundTopDocs( + new TotalHits(size, TotalHits.Relation.EQUAL_TO), + topDocsList, + false, // isSortEnabled + searchShard + ); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java index d9dcd5540..cecdf8779 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactoryTests.java @@ -26,6 +26,14 @@ public void testL2Norm_whenCreatingByName_thenReturnCorrectInstance() { assertTrue(scoreNormalizationTechnique instanceof L2ScoreNormalizationTechnique); } + public void testRRFNorm_whenCreatingByName_thenReturnCorrectInstance() { + ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); + ScoreNormalizationTechnique scoreNormalizationTechnique = scoreNormalizationFactory.createNormalization("rrf"); + + assertNotNull(scoreNormalizationTechnique); + assertTrue(scoreNormalizationTechnique instanceof RRFNormalizationTechnique); + } + public void testUnsupportedTechnique_whenPassingInvalidName_thenFail() { ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); IllegalArgumentException illegalArgumentException = expectThrows( diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java index b7e4f753a..c6eaa21ff 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java @@ -58,7 +58,8 @@ public class HybridQueryExplainIT extends BaseNeuralSearchIT { private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); - private static final String SEARCH_PIPELINE = "phase-results-hybrid-pipeline"; + private static final String NORMALIZATION_SEARCH_PIPELINE = "normalization-search-pipeline"; + private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; static final Supplier TEST_VECTOR_SUPPLIER = () -> new float[768]; @@ -78,7 +79,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { try { initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); @@ -95,7 +96,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { hybridQueryBuilderNeuralThenTerm, null, 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // search hits @@ -187,7 +188,7 @@ public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { assertEquals("score(freq=1.0), computed as boost * idf * tf from:", 
explanationsHit3Details.get("description")); assertEquals(3, getListOfValues(explanationsHit3Details, "details").size()); } finally { - wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -196,7 +197,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() try { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); createSearchPipeline( - SEARCH_PIPELINE, + NORMALIZATION_SEARCH_PIPELINE, NORMALIZATION_TECHNIQUE_L2, DEFAULT_COMBINATION_METHOD, Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f })), @@ -217,7 +218,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() hybridQueryBuilder, null, 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // basic sanity check for search hits @@ -322,7 +323,7 @@ public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() assertEquals(0, getListOfValues(explanationsHit4, "details").size()); assertTrue((double) explanationsHit4.get("value") > 0.0f); } finally { - wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -331,7 +332,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe try { initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); // create search pipeline with normalization processor, no explanation response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false); 
TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); @@ -348,7 +349,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe hybridQueryBuilderNeuralThenTerm, null, 10, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); // Assert // search hits @@ -463,7 +464,7 @@ public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenRe assertEquals("boost", explanationsHit3Details.get("description")); assertEquals(0, getListOfValues(explanationsHit3Details, "details").size()); } finally { - wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -472,7 +473,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { try { initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -483,7 +484,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { hybridQueryBuilder, null, MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) 
); List> hitsNestedList = getNestedHits(searchResponseAsMap); @@ -521,7 +522,7 @@ public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { } assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1))); } finally { - wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); } } @@ -530,7 +531,7 @@ public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() { try { initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME); // create search pipeline with both normalization processor and explain response processor - createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + createSearchPipeline(NORMALIZATION_SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); hybridQueryBuilder.add(QueryBuilders.multiMatchQuery(TEST_QUERY_TEXT3, TEST_TEXT_FIELD_NAME_1, TEST_TEXT_FIELD_NAME_2)); @@ -543,7 +544,7 @@ public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() { hybridQueryBuilder, null, MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX, - Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + Map.of("search_pipeline", NORMALIZATION_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) ); List> hitsNestedList = getNestedHits(searchResponseAsMap); @@ -581,7 +582,136 @@ public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() { } assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1))); } finally { - wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE); + wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, NORMALIZATION_SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testExplain_whenRRFProcessor_thenSuccessful() { + try { 
+ initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createRRFSearchPipeline(RRF_SEARCH_PIPELINE, true); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1) + .vector(createRandomVector(TEST_DIMENSION)) + .k(10) + .build(); + hybridQueryBuilder.add(QueryBuilders.existsQuery(TEST_TEXT_FIELD_NAME_1)); + hybridQueryBuilder.add(knnQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_NAME, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", RRF_SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + // Assert + // basic sanity check for search hits + assertEquals(4, getHitCount(searchResponseAsMap)); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + float actualMaxScore = getMaxScore(searchResponseAsMap).get(); + assertTrue(actualMaxScore > 0); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain, hit 1 + List> hitsNestedList = getNestedHits(searchResponseAsMap); + Map searchHit1 = hitsNestedList.get(0); + Map explanationForHit1 = getValueByKey(searchHit1, "_explanation"); + assertNotNull(explanationForHit1); + assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedTopLevelDescription = "rrf combination of:"; + assertEquals(expectedTopLevelDescription, explanationForHit1.get("description")); + List> hit1Details = getListOfValues(explanationForHit1, "details"); + assertEquals(2, hit1Details.size()); + // two sub-queries meaning we do have two detail objects with separate query level details + Map hit1DetailsForHit1 = hit1Details.get(0); + assertTrue((double) hit1DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); + 
assertEquals("rrf, rank_constant [60] normalization of:", hit1DetailsForHit1.get("description")); + assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); + + Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); + assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit1.get("description")); + assertTrue((double) explanationsHit1.get("value") > 0.5f); + assertEquals(0, ((List) explanationsHit1.get("details")).size()); + + Map hit1DetailsForHit2 = hit1Details.get(1); + assertTrue((double) hit1DetailsForHit2.get("value") > 0.0f); + assertEquals("rrf, rank_constant [60] normalization of:", hit1DetailsForHit2.get("description")); + assertEquals(1, ((List) hit1DetailsForHit2.get("details")).size()); + + Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); + assertEquals("within top 10", explanationsHit2.get("description")); + assertTrue((double) explanationsHit2.get("value") > 0.0f); + assertEquals(0, ((List) explanationsHit2.get("details")).size()); + + // hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map explanationForHit2 = getValueByKey(searchHit2, "_explanation"); + assertNotNull(explanationForHit2); + assertEquals((double) searchHit2.get("_score"), (double) explanationForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit2.get("description")); + List> hit2Details = getListOfValues(explanationForHit2, "details"); + assertEquals(2, hit2Details.size()); + + Map hit2DetailsForHit1 = hit2Details.get(0); + assertTrue((double) hit2DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); + assertEquals("rrf, rank_constant [60] normalization of:", hit2DetailsForHit1.get("description")); + assertEquals(1, ((List) hit2DetailsForHit1.get("details")).size()); + + Map hit2DetailsForHit2 = hit2Details.get(1); + assertTrue((double) hit2DetailsForHit2.get("value") > DELTA_FOR_SCORE_ASSERTION); + assertEquals("rrf, 
rank_constant [60] normalization of:", hit2DetailsForHit2.get("description")); + assertEquals(1, ((List) hit2DetailsForHit2.get("details")).size()); + + // hit 3 + Map searchHit3 = hitsNestedList.get(2); + Map explanationForHit3 = getValueByKey(searchHit3, "_explanation"); + assertNotNull(explanationForHit3); + assertEquals((double) searchHit3.get("_score"), (double) explanationForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit3.get("description")); + List> hit3Details = getListOfValues(explanationForHit3, "details"); + assertEquals(1, hit3Details.size()); + + Map hit3DetailsForHit1 = hit3Details.get(0); + assertTrue((double) hit3DetailsForHit1.get("value") > .0f); + assertEquals("rrf, rank_constant [60] normalization of:", hit3DetailsForHit1.get("description")); + assertEquals(1, ((List) hit3DetailsForHit1.get("details")).size()); + + Map explanationsHit3 = getListOfValues(hit3DetailsForHit1, "details").get(0); + assertEquals("within top 10", explanationsHit3.get("description")); + assertEquals(0, getListOfValues(explanationsHit3, "details").size()); + assertTrue((double) explanationsHit3.get("value") > 0.0f); + + // hit 4 + Map searchHit4 = hitsNestedList.get(3); + Map explanationForHit4 = getValueByKey(searchHit4, "_explanation"); + assertNotNull(explanationForHit4); + assertEquals((double) searchHit4.get("_score"), (double) explanationForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit4.get("description")); + List> hit4Details = getListOfValues(explanationForHit4, "details"); + assertEquals(1, hit4Details.size()); + + Map hit4DetailsForHit1 = hit4Details.get(0); + assertTrue((double) hit4DetailsForHit1.get("value") > DELTA_FOR_SCORE_ASSERTION); + assertEquals("rrf, rank_constant [60] normalization of:", hit4DetailsForHit1.get("description")); + assertEquals(1, ((List) hit4DetailsForHit1.get("details")).size()); + + Map explanationsHit4 = 
getListOfValues(hit4DetailsForHit1, "details").get(0); + assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit4.get("description")); + assertEquals(0, getListOfValues(explanationsHit4, "details").size()); + assertTrue((double) explanationsHit4.get("value") > 0.0f); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, RRF_SEARCH_PIPELINE); } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java index a1e8210e6..9c162ce11 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java @@ -53,6 +53,8 @@ public abstract class OpenSearchQueryTestCase extends OpenSearchTestCase { + protected static final float DELTA_FOR_ASSERTION = 0.001f; + protected final MapperService createMapperService(Version version, XContentBuilder mapping) throws IOException { IndexMetadata meta = IndexMetadata.builder("index") .settings(Settings.builder().put("index.version.created", version)) diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java index e0d95f24e..f6948e3e3 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridCollectorManagerTests.java @@ -85,7 +85,6 @@ public class HybridCollectorManagerTests extends OpenSearchQueryTestCase { private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; private static final String QUERY1 = "hello"; private static final String QUERY2 = "hi"; - private static final float DELTA_FOR_ASSERTION = 0.001f; protected static final String QUERY3 = "everyone"; @SneakyThrows diff --git 
a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java index 196014220..f91dae327 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryScoreDocsMergerTests.java @@ -21,8 +21,6 @@ public class HybridQueryScoreDocsMergerTests extends OpenSearchQueryTestCase { - private static final float DELTA_FOR_ASSERTION = 0.001f; - public void testIncorrectInput_whenScoreDocsAreNullOrNotEnoughElements_thenFail() { HybridQueryScoreDocsMerger scoreDocsMerger = new HybridQueryScoreDocsMerger<>(); TopDocsMerger topDocsMerger = new TopDocsMerger(null); diff --git a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java index 9c2718687..2e064913f 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/search/query/TopDocsMergerTests.java @@ -27,8 +27,6 @@ public class TopDocsMergerTests extends OpenSearchQueryTestCase { - private static final float DELTA_FOR_ASSERTION = 0.001f; - @SneakyThrows public void testMergeScoreDocs_whenBothTopDocsHasHits_thenSuccessful() { TopDocsMerger topDocsMerger = new TopDocsMerger(null); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 509527aeb..3d5929767 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -93,6 +93,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { ); private static final Set SUCCESS_STATUSES = Set.of(RestStatus.CREATED, 
RestStatus.OK); protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled"; + protected static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); @@ -1555,4 +1556,45 @@ protected enum ProcessorType { SPARSE_ENCODING, SPARSE_ENCODING_PRUNE } + + @SneakyThrows + protected void createDefaultRRFSearchPipeline() { + createRRFSearchPipeline(RRF_SEARCH_PIPELINE, false); + } + + @SneakyThrows + protected void createRRFSearchPipeline(final String pipelineName, boolean addExplainResponseProcessor) { + XContentBuilder builder = XContentFactory.jsonBuilder() + .startObject() + .field("description", "Post processor for hybrid search") + .startArray("phase_results_processors") + .startObject() + .startObject("score-ranker-processor") + .startObject("combination") + .field("technique", "rrf") + .endObject() + .endObject() + .endObject() + .endArray(); + + if (addExplainResponseProcessor) { + builder.startArray("response_processors") + .startObject() + .startObject("hybrid_score_explanation") + .endObject() + .endObject() + .endArray(); + } + + String requestBody = builder.endObject().toString(); + + makeRequest( + client(), + "PUT", + String.format(LOCALE, "/_search/pipeline/%s", pipelineName), + null, + toHttpEntity(String.format(LOCALE, requestBody)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } }