From d52c3e49962b30e3744cdd01b00a3f8df9acf8bc Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Mon, 16 Dec 2024 14:13:56 -0800 Subject: [PATCH] Explainability in hybrid query (#970) (#1014) (#1023) * Added Explainability support for hybrid query (cherry picked from commit 393d49ab8d2f2a045f24700ee3de6c1acc98aca8) Signed-off-by: Martin Gaievski Co-authored-by: Martin Gaievski --- CHANGELOG.md | 1 + .../neuralsearch/plugin/NeuralSearch.java | 7 +- .../processor/CompoundTopDocs.java | 25 +- .../ExplanationResponseProcessor.java | 132 ++++ .../processor/NormalizationProcessor.java | 40 +- .../NormalizationProcessorWorkflow.java | 80 +- ...zationProcessorWorkflowExecuteRequest.java | 32 + .../neuralsearch/processor/SearchShard.java | 29 + ...ithmeticMeanScoreCombinationTechnique.java | 12 +- ...eometricMeanScoreCombinationTechnique.java | 12 +- ...HarmonicMeanScoreCombinationTechnique.java | 12 +- .../combination/ScoreCombinationUtil.java | 4 +- .../processor/combination/ScoreCombiner.java | 85 ++- .../explain/CombinedExplanationDetails.java | 20 + .../processor/explain/DocIdAtSearchShard.java | 18 + .../explain/ExplainableTechnique.java | 34 + .../processor/explain/ExplanationDetails.java | 28 + .../processor/explain/ExplanationPayload.java | 25 + .../processor/explain/ExplanationUtils.java | 60 ++ .../ExplanationResponseProcessorFactory.java | 29 + .../L2ScoreNormalizationTechnique.java | 38 +- .../MinMaxScoreNormalizationTechnique.java | 86 ++- .../normalization/ScoreNormalizer.java | 21 + .../neuralsearch/query/HybridQueryWeight.java | 29 +- .../processor/CompoundTopDocsTests.java | 18 +- .../ExplanationPayloadProcessorTests.java | 446 +++++++++++ .../NormalizationProcessorTests.java | 6 +- .../ScoreCombinationTechniqueTests.java | 11 +- .../ScoreNormalizationTechniqueTests.java | 20 +- ...ticMeanScoreCombinationTechniqueTests.java | 4 +- ...ricMeanScoreCombinationTechniqueTests.java | 4 +- ...nicMeanScoreCombinationTechniqueTests.java | 4 +- .../explain/ExplanationUtilsTests.java | 115 +++ ...lanationResponseProcessorFactoryTests.java | 112 +++ .../L2ScoreNormalizationTechniqueTests.java | 26 +- ...inMaxScoreNormalizationTechniqueTests.java | 26 +- .../query/HybridQueryExplainIT.java | 722 ++++++++++++++++++ .../neuralsearch/query/HybridQuerySortIT.java | 137 ++++ .../query/HybridQueryWeightTests.java | 6 +- .../neuralsearch/BaseNeuralSearchIT.java | 31 +- .../neuralsearch/util/TestUtils.java | 15 +- 41 files changed, 2467 insertions(+), 95 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplanationDetails.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java create mode 100644 
src/main/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactory.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtilsTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 63aca0c29..3be97248d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x) ### Features ### Enhancements +- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970)) ### Bug Fixes - Address inconsistent scoring in hybrid query results ([#998](https://github.com/opensearch-project/neural-search/pull/998)) ### Infrastructure diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 8b173ba81..1350a7963 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -32,12 +32,14 @@ import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.TextChunkingProcessor; import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; @@ -80,6 +82,7 @@ public class NeuralSearch extends Plugin implements ActionPlugin, SearchPlugin, private NormalizationProcessorWorkflow normalizationProcessorWorkflow; private final ScoreNormalizationFactory scoreNormalizationFactory = new ScoreNormalizationFactory(); private final ScoreCombinationFactory scoreCombinationFactory = new ScoreCombinationFactory(); + public static final String EXPLANATION_RESPONSE_KEY = "explanation_response"; @Override public Collection createComponents( @@ -181,7 +184,9 @@ public Map topDocs; @Setter private List scoreDocs; + @Getter + private SearchShard searchShard; - public CompoundTopDocs(final TotalHits totalHits, final List topDocs, final boolean isSortEnabled) { - initialize(totalHits, topDocs, isSortEnabled); + public CompoundTopDocs( + final TotalHits totalHits, + final List topDocs, + final boolean isSortEnabled, + final SearchShard searchShard + ) { + initialize(totalHits, topDocs, isSortEnabled, searchShard); } - private void initialize(TotalHits 
totalHits, List topDocs, boolean isSortEnabled) { + private void initialize(TotalHits totalHits, List topDocs, boolean isSortEnabled, SearchShard searchShard) { this.totalHits = totalHits; this.topDocs = topDocs; scoreDocs = cloneLargestScoreDocs(topDocs, isSortEnabled); + this.searchShard = searchShard; } /** @@ -72,14 +82,17 @@ private void initialize(TotalHits totalHits, List topDocs, boolean isSo * 6, 0.15 * 0, 9549511920.4881596047 */ - public CompoundTopDocs(final TopDocs topDocs) { + public CompoundTopDocs(final QuerySearchResult querySearchResult) { + final TopDocs topDocs = querySearchResult.topDocs().topDocs; + final SearchShardTarget searchShardTarget = querySearchResult.getSearchShardTarget(); + SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget); boolean isSortEnabled = false; if (topDocs instanceof TopFieldDocs) { isSortEnabled = true; } ScoreDoc[] scoreDocs = topDocs.scoreDocs; if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) { - initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled); + initialize(topDocs.totalHits, new ArrayList<>(), isSortEnabled, searchShard); return; } // skipping first two elements, it's a start-stop element and delimiter for first series @@ -103,7 +116,7 @@ public CompoundTopDocs(final TopDocs topDocs) { scoreDocList.add(scoreDoc); } } - initialize(topDocs.totalHits, topDocsList, isSortEnabled); + initialize(topDocs.totalHits, topDocsList, isSortEnabled, searchShard); } private List cloneLargestScoreDocs(final List docs, boolean isSortEnabled) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java new file mode 100644 index 000000000..01cdfcb0d --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/ExplanationResponseProcessor.java @@ -0,0 +1,132 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.AllArgsConstructor; +import lombok.Getter; +import org.apache.lucene.search.Explanation; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY; +import static org.opensearch.neuralsearch.processor.explain.ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR; + +/** + * Processor to add explanation details to search response + */ +@Getter +@AllArgsConstructor +public class ExplanationResponseProcessor implements SearchResponseProcessor { + + public static final String TYPE = "explanation_response_processor"; + + private final String description; + private final String tag; + private final boolean ignoreFailure; + + /** + * Add explanation details to search response if it is present in request context + */ + @Override + 
public SearchResponse processResponse(SearchRequest request, SearchResponse response) { + return processResponse(request, response, null); + } + + /** + * Combines explanation from processor with search hits level explanations and adds it to search response + */ + @Override + public SearchResponse processResponse( + final SearchRequest request, + final SearchResponse response, + final PipelineProcessingContext requestContext + ) { + if (Objects.isNull(requestContext) + || (Objects.isNull(requestContext.getAttribute(EXPLANATION_RESPONSE_KEY))) + || requestContext.getAttribute(EXPLANATION_RESPONSE_KEY) instanceof ExplanationPayload == false) { + return response; + } + // Extract explanation payload from context + ExplanationPayload explanationPayload = (ExplanationPayload) requestContext.getAttribute(EXPLANATION_RESPONSE_KEY); + Map explainPayload = explanationPayload.getExplainPayload(); + if (explainPayload.containsKey(NORMALIZATION_PROCESSOR)) { + // for score normalization, processor level explanations will be sorted in scope of each shard, + // and we are merging both into a single sorted list + SearchHits searchHits = response.getHits(); + SearchHit[] searchHitsArray = searchHits.getHits(); + // create a map of searchShard and list of indexes of search hit objects in search hits array + // the list will keep original order of sorting as per final search results + Map> searchHitsByShard = new HashMap<>(); + // we keep index for each shard, where index is a position in searchHitsByShard list + Map explainsByShardCount = new HashMap<>(); + // Build initial shard mappings + for (int i = 0; i < searchHitsArray.length; i++) { + SearchHit searchHit = searchHitsArray[i]; + SearchShardTarget searchShardTarget = searchHit.getShard(); + SearchShard searchShard = SearchShard.createSearchShard(searchShardTarget); + searchHitsByShard.computeIfAbsent(searchShard, k -> new ArrayList<>()).add(i); + explainsByShardCount.putIfAbsent(searchShard, -1); + } + // Process normalization details if available in correct format + if (explainPayload.get(NORMALIZATION_PROCESSOR) instanceof Map) { + @SuppressWarnings("unchecked") + Map> combinedExplainDetails = (Map< + SearchShard, + List>) explainPayload.get(NORMALIZATION_PROCESSOR); + // Process each search hit to add processor level explanations + for (SearchHit searchHit : searchHitsArray) { + SearchShard searchShard = SearchShard.createSearchShard(searchHit.getShard()); + int explanationIndexByShard = explainsByShardCount.get(searchShard) + 1; + CombinedExplanationDetails combinedExplainDetail = combinedExplainDetails.get(searchShard).get(explanationIndexByShard); + // Extract various explanation components + Explanation queryLevelExplanation = searchHit.getExplanation(); + ExplanationDetails normalizationExplanation = combinedExplainDetail.getNormalizationExplanations(); + ExplanationDetails combinationExplanation = combinedExplainDetail.getCombinationExplanations(); + // Create normalized explanations for each detail + Explanation[] normalizedExplanation = new Explanation[queryLevelExplanation.getDetails().length]; + for (int i = 0; i < queryLevelExplanation.getDetails().length; i++) { + normalizedExplanation[i] = Explanation.match( + // normalized score + normalizationExplanation.getScoreDetails().get(i).getKey(), + // description of normalized score + normalizationExplanation.getScoreDetails().get(i).getValue(), + // shard level details + queryLevelExplanation.getDetails()[i] + ); + } + // Create and set final explanation combining all components + 
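// Illustrative sketch only (added for readability, not part of this patch; scores and sub-query
// descriptions are invented): assuming min_max normalization and arithmetic_mean combination,
// the merged explanation tree for a single hit is expected to look roughly like
//   0.63  "arithmetic_mean, weights [0.3, 0.7] combination of:"      <- final hit score
//     0.5    "min_max normalization of:"                             <- normalized sub-query score
//       5.2    "weight(title:horse in 0) ..."                        <- original lexical sub-query explanation
//     0.75   "min_max normalization of:"
//       0.015  "within top 5 docs"                                   <- original k-NN sub-query explanation
// i.e. the outer node keeps the combined score and the combination description, each child keeps a
// normalized sub-query score, and the shard-level sub-query explanation produced by HybridQueryWeight
// sits one level deeper.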
Explanation finalExplanation = Explanation.match( + searchHit.getScore(), + // combination level explanation is always a single detail + combinationExplanation.getScoreDetails().get(0).getValue(), + normalizedExplanation + ); + searchHit.explanation(finalExplanation); + explainsByShardCount.put(searchShard, explanationIndexByShard); + } + } + } + return response; + } + + @Override + public String getType() { + return TYPE; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java index 0563c92a0..d2008ae97 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java @@ -20,6 +20,7 @@ import org.opensearch.search.SearchPhaseResult; import org.opensearch.search.fetch.FetchSearchResult; import org.opensearch.search.internal.SearchContext; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.query.QuerySearchResult; @@ -43,7 +44,7 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor { /** * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage - * are set as part of class constructor + * are set as part of class constructor. This method is called when there is no pipeline context * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution * @param searchPhaseContext {@link SearchContext} */ @@ -51,6 +52,31 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor { public void process( final SearchPhaseResults searchPhaseResult, final SearchPhaseContext searchPhaseContext + ) { + prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty()); + } + + /** + * Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage + * are set as part of class constructor + * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. 
Results will be mutated as part of this method execution + * @param searchPhaseContext {@link SearchContext} + * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline + * @param + */ + @Override + public void process( + final SearchPhaseResults searchPhaseResult, + final SearchPhaseContext searchPhaseContext, + final PipelineProcessingContext requestContext + ) { + prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext)); + } + + private void prepareAndExecuteNormalizationWorkflow( + SearchPhaseResults searchPhaseResult, + SearchPhaseContext searchPhaseContext, + Optional requestContextOptional ) { if (shouldSkipProcessor(searchPhaseResult)) { log.debug("Query results are not compatible with normalization processor"); @@ -58,7 +84,17 @@ public void process( } List querySearchResults = getQueryPhaseSearchResults(searchPhaseResult); Optional fetchSearchResult = getFetchSearchResults(searchPhaseResult); - normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique); + boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain()) + && searchPhaseContext.getRequest().source().explain(); + NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder() + .querySearchResults(querySearchResults) + .fetchSearchResultOptional(fetchSearchResult) + .normalizationTechnique(normalizationTechnique) + .combinationTechnique(combinationTechnique) + .explain(explain) + .pipelineProcessingContext(requestContextOptional.orElse(null)) + .build(); + normalizationWorkflow.execute(request); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index c64f1c1f4..f2699d967 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -4,6 +4,7 @@ */ package org.opensearch.neuralsearch.processor; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.List; @@ -22,15 +23,23 @@ import org.opensearch.neuralsearch.processor.combination.CombineScoresDto; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.search.SearchHit; import org.opensearch.search.SearchHits; import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.pipeline.PipelineProcessingContext; import org.opensearch.search.query.QuerySearchResult; import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; + +import static org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY; import static 
org.opensearch.neuralsearch.processor.combination.ScoreCombiner.MAX_SCORE_WHEN_NO_HITS_FOUND;
 import static org.opensearch.neuralsearch.search.util.HybridSearchSortUtil.evaluateSortCriteria;
@@ -57,22 +66,35 @@ public void execute(
         final ScoreNormalizationTechnique normalizationTechnique,
         final ScoreCombinationTechnique combinationTechnique
     ) {
+        NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
+            .querySearchResults(querySearchResults)
+            .fetchSearchResultOptional(fetchSearchResultOptional)
+            .normalizationTechnique(normalizationTechnique)
+            .combinationTechnique(combinationTechnique)
+            .explain(false)
+            .build();
+        execute(request);
+    }
+
+    public void execute(final NormalizationProcessorWorkflowExecuteRequest request) {
         // save original state
-        List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);
+        List<Integer> unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults());
         // pre-process data
         log.debug("Pre-process query results");
-        List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);
+        List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(request.getQuerySearchResults());
+
+        explain(request, queryTopDocs);
         // normalize
         log.debug("Do score normalization");
-        scoreNormalizer.normalizeScores(queryTopDocs, normalizationTechnique);
+        scoreNormalizer.normalizeScores(queryTopDocs, request.getNormalizationTechnique());
         CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
             .queryTopDocs(queryTopDocs)
-            .scoreCombinationTechnique(combinationTechnique)
-            .querySearchResults(querySearchResults)
-            .sort(evaluateSortCriteria(querySearchResults, queryTopDocs))
+            .scoreCombinationTechnique(request.getCombinationTechnique())
+            .querySearchResults(request.getQuerySearchResults())
+            .sort(evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs))
             .build();
         // combine
@@ -82,7 +104,50 @@ public void execute(
         // post-process data
         log.debug("Post-process query results after score normalization and combination");
         updateOriginalQueryResults(combineScoresDTO);
-        updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional, unprocessedDocIds);
+        updateOriginalFetchResults(request.getQuerySearchResults(), request.getFetchSearchResultOptional(), unprocessedDocIds);
+    }
+
+    /**
+     * Collects explanations from normalization and combination techniques and saves them into the pipeline context.
Later that + * information will be read by the response processor to add it to search response + */ + private void explain(NormalizationProcessorWorkflowExecuteRequest request, List queryTopDocs) { + if (!request.isExplain()) { + return; + } + // build final result object with all explain related information + if (Objects.nonNull(request.getPipelineProcessingContext())) { + Sort sortForQuery = evaluateSortCriteria(request.getQuerySearchResults(), queryTopDocs); + Map normalizationExplain = scoreNormalizer.explain( + queryTopDocs, + (ExplainableTechnique) request.getNormalizationTechnique() + ); + Map> combinationExplain = scoreCombiner.explain( + queryTopDocs, + request.getCombinationTechnique(), + sortForQuery + ); + Map> combinedExplanations = new HashMap<>(); + for (Map.Entry> entry : combinationExplain.entrySet()) { + List combinedDetailsList = new ArrayList<>(); + for (ExplanationDetails explainDetail : entry.getValue()) { + DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(explainDetail.getDocId(), entry.getKey()); + CombinedExplanationDetails combinedDetail = CombinedExplanationDetails.builder() + .normalizationExplanations(normalizationExplain.get(docIdAtSearchShard)) + .combinationExplanations(explainDetail) + .build(); + combinedDetailsList.add(combinedDetail); + } + combinedExplanations.put(entry.getKey(), combinedDetailsList); + } + + ExplanationPayload explanationPayload = ExplanationPayload.builder() + .explainPayload(Map.of(ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, combinedExplanations)) + .build(); + // store explain object to pipeline context + PipelineProcessingContext pipelineProcessingContext = request.getPipelineProcessingContext(); + pipelineProcessingContext.setAttribute(EXPLANATION_RESPONSE_KEY, explanationPayload); + } } /** @@ -93,7 +158,6 @@ public void execute( private List getQueryTopDocs(final List querySearchResults) { List queryTopDocs = querySearchResults.stream() .filter(searchResult -> Objects.nonNull(searchResult.topDocs())) - .map(querySearchResult -> querySearchResult.topDocs().topDocs) .map(CompoundTopDocs::new) .collect(Collectors.toList()); if (queryTopDocs.size() != querySearchResults.size()) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java new file mode 100644 index 000000000..ea0b54b9c --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowExecuteRequest.java @@ -0,0 +1,32 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.search.query.QuerySearchResult; + +import java.util.List; +import java.util.Optional; + +@Builder +@AllArgsConstructor +@Getter +/** + * DTO class to hold request parameters for normalization and combination + */ +public class NormalizationProcessorWorkflowExecuteRequest { + final List querySearchResults; + final Optional fetchSearchResultOptional; + final ScoreNormalizationTechnique normalizationTechnique; 
+ final ScoreCombinationTechnique combinationTechnique; + boolean explain; + final PipelineProcessingContext pipelineProcessingContext; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java new file mode 100644 index 000000000..c875eab55 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/SearchShard.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.AllArgsConstructor; +import lombok.Value; +import org.opensearch.search.SearchShardTarget; + +/** + * DTO class to store index, shardId and nodeId for a search shard. + */ +@Value +@AllArgsConstructor +public class SearchShard { + String index; + int shardId; + String nodeId; + + /** + * Create SearchShard from SearchShardTarget + * @param searchShardTarget + * @return SearchShard + */ + public static SearchShard createSearchShard(final SearchShardTarget searchShardTarget) { + return new SearchShard(searchShardTarget.getIndex(), searchShardTarget.getShardId().id(), searchShardTarget.getNodeId()); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index 001f1670d..5ad79e75a 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -9,15 +9,18 @@ import java.util.Set; import lombok.ToString; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; + +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on arithmetic mean method */ @ToString(onlyExplicitlyIncluded = true) -public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique { +public class ArithmeticMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "arithmetic_mean"; - public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; @@ -54,4 +57,9 @@ public float combine(final float[] scores) { } return combinedScore / sumOfWeights; } + + @Override + public String describe() { + return describeCombinationTechnique(TECHNIQUE_NAME, weights); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java index c4b6dfb3f..b5bdabb43 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -9,15 +9,18 @@ import java.util.Set; import lombok.ToString; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; + +import static 
org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on geometrical mean method */ @ToString(onlyExplicitlyIncluded = true) -public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique { +public class GeometricMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "geometric_mean"; - public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; @@ -54,4 +57,9 @@ public float combine(final float[] scores) { } return sumOfWeights == 0 ? ZERO_SCORE : (float) Math.exp(weightedLnSum / sumOfWeights); } + + @Override + public String describe() { + return describeCombinationTechnique(TECHNIQUE_NAME, weights); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java index f5195f79f..eeb5950f1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java @@ -9,15 +9,18 @@ import java.util.Set; import lombok.ToString; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; + +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique; /** * Abstracts combination of scores based on harmonic mean method */ @ToString(onlyExplicitlyIncluded = true) -public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique { +public class HarmonicMeanScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "harmonic_mean"; - public static final String PARAM_NAME_WEIGHTS = "weights"; private static final Set SUPPORTED_PARAMS = Set.of(PARAM_NAME_WEIGHTS); private static final Float ZERO_SCORE = 0.0f; private final List weights; @@ -51,4 +54,9 @@ public float combine(final float[] scores) { } return sumOfHarmonics > 0 ? 
sumOfWeights / sumOfHarmonics : ZERO_SCORE; } + + @Override + public String describe() { + return describeCombinationTechnique(TECHNIQUE_NAME, weights); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java index a915057df..5f18baf09 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java @@ -23,8 +23,8 @@ * Collection of utility methods for score combination technique classes */ @Log4j2 -class ScoreCombinationUtil { - private static final String PARAM_NAME_WEIGHTS = "weights"; +public class ScoreCombinationUtil { + public static final String PARAM_NAME_WEIGHTS = "weights"; private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f; /** diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index a4e39f448..1779f20f7 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -8,6 +8,7 @@ import java.util.Collection; import java.util.HashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Set; import java.util.Objects; @@ -16,6 +17,7 @@ import java.util.stream.Collectors; +import org.apache.commons.lang3.tuple.Pair; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TopFieldDocs; import org.apache.lucene.search.ScoreDoc; @@ -26,6 +28,9 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import lombok.extern.log4j.Log4j2; +import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; /** * Abstracts combination of scores in query search results. @@ -96,14 +101,9 @@ private void combineShardScores( // - sort documents by scores and take first "max number" of docs // create a collection of doc ids that are sorted by their combined scores - Collection sortedDocsIds; - if (sort != null) { - sortedDocsIds = getSortedDocIdsBySortCriteria(getTopFieldDocs(sort, topDocsPerSubQuery), sort); - } else { - sortedDocsIds = getSortedDocIds(combinedNormalizedScoresByDocId); - } + Collection sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId); - // - update query search results with normalized scores + // - update query search results with combined scores updateQueryTopDocsWithCombinedScores( compoundQueryTopDocs, topDocsPerSubQuery, @@ -318,4 +318,75 @@ private TotalHits getTotalHits(final List topDocsPerSubQuery, final lon } return new TotalHits(maxHits, totalHits); } + + /** + * Explain the score combination technique for each document in the given queryTopDocs. 
+     * @param queryTopDocs per-shard query results with normalized sub-query scores
+     * @param combinationTechnique technique used to combine the normalized scores
+     * @param sort sort criteria, or null when sorting by score
+     * @return a map of SearchShard and List of ExplanationDetails for each document
+     */
+    public Map<SearchShard, List<ExplanationDetails>> explain(
+        final List<CompoundTopDocs> queryTopDocs,
+        final ScoreCombinationTechnique combinationTechnique,
+        final Sort sort
+    ) {
+        // In case of duplicate keys, keep the first value
+        Map<SearchShard, List<ExplanationDetails>> explanations = new HashMap<>();
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            explanations.putIfAbsent(
+                compoundQueryTopDocs.getSearchShard(),
+                explainByShard(combinationTechnique, compoundQueryTopDocs, sort)
+            );
+        }
+        return explanations;
+    }
+
+    private List<ExplanationDetails> explainByShard(
+        final ScoreCombinationTechnique scoreCombinationTechnique,
+        final CompoundTopDocs compoundQueryTopDocs,
+        final Sort sort
+    ) {
+        if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) {
+            return List.of();
+        }
+        // create map of normalized scores results returned from the single shard
+        Map<Integer, float[]> normalizedScoresPerDoc = getNormalizedScoresPerDocument(compoundQueryTopDocs.getTopDocs());
+        // combine scores
+        Map<Integer, Float> combinedNormalizedScoresByDocId = normalizedScoresPerDoc.entrySet()
+            .stream()
+            .collect(Collectors.toMap(Map.Entry::getKey, entry -> scoreCombinationTechnique.combine(entry.getValue())));
+        // sort combined scores as per sorting criteria - either score desc or field sorting
+        Collection<Integer> sortedDocsIds = getSortedDocsIds(compoundQueryTopDocs, sort, combinedNormalizedScoresByDocId);
+
+        List<ExplanationDetails> listOfExplanations = new ArrayList<>();
+        String combinationDescription = String.format(
+            Locale.ROOT,
+            "%s combination of:",
+            ((ExplainableTechnique) scoreCombinationTechnique).describe()
+        );
+        for (int docId : sortedDocsIds) {
+            ExplanationDetails explanation = new ExplanationDetails(
+                docId,
+                List.of(Pair.of(combinedNormalizedScoresByDocId.get(docId), combinationDescription))
+            );
+            listOfExplanations.add(explanation);
+        }
+        return listOfExplanations;
+    }
+
+    private Collection<Integer> getSortedDocsIds(
+        final CompoundTopDocs compoundQueryTopDocs,
+        final Sort sort,
+        final Map<Integer, Float> combinedNormalizedScoresByDocId
+    ) {
+        Collection<Integer> sortedDocsIds;
+        if (sort != null) {
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
+            sortedDocsIds = getSortedDocIdsBySortCriteria(getTopFieldDocs(sort, topDocsPerSubQuery), sort);
+        } else {
+            sortedDocsIds = getSortedDocIds(combinedNormalizedScoresByDocId);
+        }
+        return sortedDocsIds;
+    }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplanationDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplanationDetails.java
new file mode 100644
index 000000000..c2e1b61e5
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/CombinedExplanationDetails.java
@@ -0,0 +1,20 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+package org.opensearch.neuralsearch.processor.explain;
+
+import lombok.AllArgsConstructor;
+import lombok.Builder;
+import lombok.Getter;
+
+/**
+ * DTO class to hold explain details for normalization and combination
+ */
+@AllArgsConstructor
+@Builder
+@Getter
+public class CombinedExplanationDetails {
+    private ExplanationDetails normalizationExplanations;
+    private ExplanationDetails combinationExplanations;
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java
new file mode 100644
index 
000000000..51550e523 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/DocIdAtSearchShard.java @@ -0,0 +1,18 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.explain; + +import lombok.Value; +import org.opensearch.neuralsearch.processor.SearchShard; + +/** + * DTO class to store docId and search shard for a query. + * Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards. + */ +@Value +public class DocIdAtSearchShard { + int docId; + SearchShard searchShard; +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java new file mode 100644 index 000000000..cc2fab6c6 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplainableTechnique.java @@ -0,0 +1,34 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.explain; + +import org.opensearch.neuralsearch.processor.CompoundTopDocs; + +import java.util.List; +import java.util.Map; + +/** + * Abstracts explanation of score combination or normalization technique. + */ +public interface ExplainableTechnique { + + String GENERIC_DESCRIPTION_OF_TECHNIQUE = "generic score processing technique"; + + /** + * Returns a string with general description of the technique + */ + default String describe() { + return GENERIC_DESCRIPTION_OF_TECHNIQUE; + } + + /** + * Returns a map with explanation for each document id + * @param queryTopDocs collection of CompoundTopDocs for each shard result + * @return map of document per shard and corresponding explanation object + */ + default Map explain(final List queryTopDocs) { + return Map.of(); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java new file mode 100644 index 000000000..c55db4426 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationDetails.java @@ -0,0 +1,28 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.explain; + +import lombok.AllArgsConstructor; +import lombok.Value; +import org.apache.commons.lang3.tuple.Pair; + +import java.util.List; + +/** + * DTO class to store value and description for explain details. + * Used in {@link org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow} to normalize scores across shards. 
+ */ +@Value +@AllArgsConstructor +public class ExplanationDetails { + int docId; + List> scoreDetails; + + public ExplanationDetails(List> scoreDetails) { + // pass docId as -1 to match docId in SearchHit + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/search/SearchHit.java#L170 + this(-1, scoreDetails); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java new file mode 100644 index 000000000..708f655c0 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationPayload.java @@ -0,0 +1,25 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.explain; + +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Getter; + +import java.util.Map; + +/** + * DTO class to hold explain details for normalization and combination + */ +@AllArgsConstructor +@Builder +@Getter +public class ExplanationPayload { + private final Map explainPayload; + + public enum PayloadType { + NORMALIZATION_PROCESSOR + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java new file mode 100644 index 000000000..c6ac0500b --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtils.java @@ -0,0 +1,60 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.explain; + +import org.apache.commons.lang3.tuple.Pair; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * Utility class for explain functionality + */ +public class ExplanationUtils { + + /** + * Creates map of DocIdAtQueryPhase to String containing source and normalized scores + * @param normalizedScores map of DocIdAtQueryPhase to normalized scores + * @return map of DocIdAtQueryPhase to String containing source and normalized scores + */ + public static Map getDocIdAtQueryForNormalization( + final Map> normalizedScores, + final ExplainableTechnique technique + ) { + Map explain = new HashMap<>(); + for (Map.Entry> entry : normalizedScores.entrySet()) { + List normScores = normalizedScores.get(entry.getKey()); + List> explanations = new ArrayList<>(); + for (float score : normScores) { + String description = String.format(Locale.ROOT, "%s normalization of:", technique.describe()); + explanations.add(Pair.of(score, description)); + } + explain.put(entry.getKey(), new ExplanationDetails(explanations)); + } + + return explain; + } + + /** + * Creates a string describing the combination technique and its parameters + * @param techniqueName the name of the combination technique + * @param weights the weights used in the combination technique + * @return a string describing the combination technique and its parameters + */ + public static String describeCombinationTechnique(final String techniqueName, final List weights) { + if (Objects.isNull(techniqueName)) { + throw new IllegalArgumentException("combination technique name cannot be null"); + } + return Optional.ofNullable(weights) + .filter(w -> !w.isEmpty()) + .map(w -> String.format(Locale.ROOT, "%s, weights %s", 
techniqueName, weights)) + .orElse(String.format(Locale.ROOT, "%s", techniqueName)); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactory.java new file mode 100644 index 000000000..ffe1da12f --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactory.java @@ -0,0 +1,29 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; + +import java.util.Map; + +/** + * Factory class for creating ExplanationResponseProcessor + */ +public class ExplanationResponseProcessorFactory implements Processor.Factory { + + @Override + public SearchResponseProcessor create( + Map> processorFactories, + String tag, + String description, + boolean ignoreFailure, + Map config, + Processor.PipelineContext pipelineContext + ) throws Exception { + return new ExplanationResponseProcessor(description, tag, ignoreFailure); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index 2bb6bbed7..e7fbf658c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -5,7 +5,10 @@ package org.opensearch.neuralsearch.processor.normalization; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.Objects; import org.apache.lucene.search.ScoreDoc; @@ -13,12 +16,17 @@ import org.opensearch.neuralsearch.processor.CompoundTopDocs; import lombok.ToString; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; + +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization; /** * Abstracts normalization of scores based on L2 method */ @ToString(onlyExplicitlyIncluded = true) -public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique { +public class L2ScoreNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "l2"; private static final float MIN_SCORE = 0.0f; @@ -50,6 +58,34 @@ public void normalize(final List queryTopDocs) { } } + @Override + public String describe() { + return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME); + } + + @Override + public Map explain(List queryTopDocs) { + Map> normalizedScores = new HashMap<>(); + List normsPerSubquery = getL2Norm(queryTopDocs); + + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); + for (ScoreDoc scoreDoc 
: subQueryTopDoc.scoreDocs) { + DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard()); + float normalizedScore = normalizeSingleScore(scoreDoc.score, normsPerSubquery.get(j)); + normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore); + scoreDoc.score = normalizedScore; + } + } + } + return getDocIdAtQueryForNormalization(normalizedScores, this); + } + private List getL2Norm(final List queryTopDocs) { // find any non-empty compound top docs, it's either empty if shard does not have any results for all of sub-queries, // or it has results for all the sub-queries. In edge case of shard having results only for one sub-query, there will be TopDocs for diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index 4fdf3c0a6..0e54919d2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -4,10 +4,16 @@ */ package org.opensearch.neuralsearch.processor.normalization; +import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.Objects; +import lombok.AllArgsConstructor; +import lombok.Getter; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; @@ -15,12 +21,17 @@ import com.google.common.primitives.Floats; import lombok.ToString; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; + +import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization; /** * Abstracts normalization of scores based on min-max method */ @ToString(onlyExplicitlyIncluded = true) -public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTechnique { +public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "min_max"; private static final float MIN_SCORE = 0.001f; @@ -35,20 +46,45 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech */ @Override public void normalize(final List queryTopDocs) { - int numOfSubqueries = queryTopDocs.stream() - .filter(Objects::nonNull) - .filter(topDocs -> topDocs.getTopDocs().size() > 0) - .findAny() - .get() - .getTopDocs() - .size(); + MinMaxScores minMaxScores = getMinMaxScoresResult(queryTopDocs); + // do normalization using actual score and min and max scores for corresponding sub query + for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { + if (Objects.isNull(compoundQueryTopDocs)) { + continue; + } + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); + for (int j = 0; j < topDocsPerSubQuery.size(); j++) { + TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); + for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { + scoreDoc.score = normalizeSingleScore( + scoreDoc.score, + minMaxScores.getMinScoresPerSubquery()[j], + 
minMaxScores.getMaxScoresPerSubquery()[j] + ); + } + } + } + } + + private MinMaxScores getMinMaxScoresResult(final List queryTopDocs) { + int numOfSubqueries = getNumOfSubqueries(queryTopDocs); // get min scores for each sub query float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries); - // get max scores for each sub query float[] maxScoresPerSubquery = getMaxScores(queryTopDocs, numOfSubqueries); + return new MinMaxScores(minScoresPerSubquery, maxScoresPerSubquery); + } - // do normalization using actual score and min and max scores for corresponding sub query + @Override + public String describe() { + return String.format(Locale.ROOT, "%s", TECHNIQUE_NAME); + } + + @Override + public Map explain(final List queryTopDocs) { + MinMaxScores minMaxScores = getMinMaxScoresResult(queryTopDocs); + + Map> normalizedScores = new HashMap<>(); for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { if (Objects.isNull(compoundQueryTopDocs)) { continue; @@ -57,10 +93,28 @@ public void normalize(final List queryTopDocs) { for (int j = 0; j < topDocsPerSubQuery.size(); j++) { TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { - scoreDoc.score = normalizeSingleScore(scoreDoc.score, minScoresPerSubquery[j], maxScoresPerSubquery[j]); + DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard()); + float normalizedScore = normalizeSingleScore( + scoreDoc.score, + minMaxScores.getMinScoresPerSubquery()[j], + minMaxScores.getMaxScoresPerSubquery()[j] + ); + normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore); + scoreDoc.score = normalizedScore; } } } + return getDocIdAtQueryForNormalization(normalizedScores, this); + } + + private int getNumOfSubqueries(final List queryTopDocs) { + return queryTopDocs.stream() + .filter(Objects::nonNull) + .filter(topDocs -> !topDocs.getTopDocs().isEmpty()) + .findAny() + .get() + .getTopDocs() + .size(); } private float[] getMaxScores(final List queryTopDocs, final int numOfSubqueries) { @@ -113,4 +167,14 @@ private float normalizeSingleScore(final float score, final float minScore, fina float normalizedScore = (score - minScore) / (maxScore - minScore); return normalizedScore == 0.0f ? 
MIN_SCORE : normalizedScore; } + + /** + * Result class to hold min and max scores for each sub query + */ + @AllArgsConstructor + @Getter + private class MinMaxScores { + float[] minScoresPerSubquery; + float[] maxScoresPerSubquery; + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java index 263115f8f..67a17fda2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java @@ -5,9 +5,13 @@ package org.opensearch.neuralsearch.processor.normalization; import java.util.List; +import java.util.Map; import java.util.Objects; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique; public class ScoreNormalizer { @@ -25,4 +29,21 @@ public void normalizeScores(final List queryTopDocs, final Scor private boolean canQueryResultsBeNormalized(final List queryTopDocs) { return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getTopDocs().size() > 0); } + + /** + * Explain normalized scores based on input normalization technique. Does not mutate input object. + * @param queryTopDocs original query results from multiple shards and multiple sub-queries + * @param queryTopDocs + * @param scoreNormalizationTechnique + * @return map of doc id to explanation details + */ + public Map explain( + final List queryTopDocs, + final ExplainableTechnique scoreNormalizationTechnique + ) { + if (canQueryResultsBeNormalized(queryTopDocs)) { + return scoreNormalizationTechnique.explain(queryTopDocs); + } + return Map.of(); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java index dc1f5e112..bad8fda74 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java +++ b/src/main/java/org/opensearch/neuralsearch/query/HybridQueryWeight.java @@ -140,16 +140,35 @@ public boolean isCacheable(LeafReaderContext ctx) { } /** - * Explain is not supported for hybrid query - * + * Returns a shard level {@link Explanation} that describes how the weight and scoring are calculated. * @param context the readers context to create the {@link Explanation} for. 
- * @param doc the document's id relative to the given context's reader - * @return + * @param doc the document's id relative to the given context's reader + * @return shard level {@link Explanation}, each sub-query explanation is a single nested element * @throws IOException */ @Override public Explanation explain(LeafReaderContext context, int doc) throws IOException { - throw new UnsupportedOperationException("Explain is not supported"); + boolean match = false; + double max = 0; + List<Explanation> subsOnNoMatch = new ArrayList<>(); + List<Explanation> subsOnMatch = new ArrayList<>(); + for (Weight wt : weights) { + Explanation e = wt.explain(context, doc); + if (e.isMatch()) { + match = true; + double score = e.getValue().doubleValue(); + subsOnMatch.add(e); + max = Math.max(max, score); + } else if (!match) { + subsOnNoMatch.add(e); + } + } + if (match) { + final String desc = "combined score of:"; + return Explanation.match(max, desc, subsOnMatch); + } else { + return Explanation.noMatch("no matching clause", subsOnNoMatch); + } } @RequiredArgsConstructor diff --git a/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java index 3b2f64063..eabc69894 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java @@ -14,6 +14,7 @@ import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; public class CompoundTopDocsTests extends OpenSearchQueryTestCase { + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); public void testBasics_whenCreateWithTopDocsArray_thenSuccessful() { TopDocs topDocs1 = new TopDocs( @@ -28,7 +29,7 @@ public void testBasics_whenCreateWithTopDocsArray_thenSuccessful() { new ScoreDoc(5, RandomUtils.nextFloat()) } ); List<TopDocs> topDocs = List.of(topDocs1, topDocs2); - CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocs, false); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocs, false, SEARCH_SHARD); assertNotNull(compoundTopDocs); assertEquals(topDocs, compoundTopDocs.getTopDocs()); } @@ -45,7 +46,8 @@ public void testBasics_whenCreateWithoutTopDocs_thenTopDocsIsNull() { new ScoreDoc(5, RandomUtils.nextFloat()) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(hybridQueryScoreTopDocs); assertNotNull(hybridQueryScoreTopDocs.getScoreDocs()); @@ -59,21 +61,27 @@ public void testBasics_whenMultipleTopDocsOfDifferentLength_thenReturnTopDocsWit new ScoreDoc[] { new ScoreDoc(2, RandomUtils.nextFloat()), new ScoreDoc(4, RandomUtils.nextFloat()) } ); List<TopDocs> topDocs = List.of(topDocs1, topDocs2); - CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), topDocs, false); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), topDocs, false, SEARCH_SHARD); assertNotNull(compoundTopDocs); assertNotNull(compoundTopDocs.getScoreDocs()); assertEquals(2, compoundTopDocs.getScoreDocs().size()); } public void testBasics_whenMultipleTopDocsIsNull_thenScoreDocsIsNull() { - CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), (List<TopDocs>) null, false); + CompoundTopDocs compoundTopDocs = new CompoundTopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + (List<TopDocs>) null, + false, + SEARCH_SHARD + );
assertNotNull(compoundTopDocs); assertNull(compoundTopDocs.getScoreDocs()); CompoundTopDocs compoundTopDocsWithNullArray = new CompoundTopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), Arrays.asList(null, null), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocsWithNullArray); assertNotNull(compoundTopDocsWithNullArray.getScoreDocs()); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java new file mode 100644 index 000000000..e47ea43d2 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/ExplanationPayloadProcessorTests.java @@ -0,0 +1,446 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.SneakyThrows; +import org.apache.commons.lang3.Range; +import org.apache.commons.lang3.tuple.Pair; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.SortField; +import org.apache.lucene.search.TotalHits; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.search.ShardSearchFailure; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationDetails; +import org.opensearch.neuralsearch.processor.explain.ExplanationPayload; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.aggregations.InternalAggregations; +import org.opensearch.search.internal.InternalSearchResponse; +import org.opensearch.search.pipeline.PipelineProcessingContext; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.transport.RemoteClusterAware; + +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.PriorityQueue; +import java.util.TreeMap; + +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; + +public class ExplanationPayloadProcessorTests extends OpenSearchTestCase { + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + + public void testClassFields_whenCreateNewObject_thenAllFieldsPresent() { + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + + assertEquals(DESCRIPTION, explanationResponseProcessor.getDescription()); + assertEquals(PROCESSOR_TAG, explanationResponseProcessor.getTag()); + assertFalse(explanationResponseProcessor.isIgnoreFailure()); + } + + @SneakyThrows + public void testPipelineContext_whenPipelineContextHasNoExplanationInfo_thenProcessorIsNoOp() { + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + SearchResponse searchResponse = new SearchResponse( + null, + null, + 1, + 1, + 0, + 1000, + new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY + ); + + SearchResponse processedResponse = 
explanationResponseProcessor.processResponse(searchRequest, searchResponse); + assertEquals(searchResponse, processedResponse); + + SearchResponse processedResponse2 = explanationResponseProcessor.processResponse(searchRequest, searchResponse, null); + assertEquals(searchResponse, processedResponse2); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + SearchResponse processedResponse3 = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + assertEquals(searchResponse, processedResponse3); + } + + @SneakyThrows + public void testParsingOfExplanations_whenResponseHasExplanations_thenSuccessful() { + // Setup + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + float maxScore = 1.0f; + SearchHits searchHits = getSearchHits(maxScore); + SearchResponse searchResponse = getSearchResponse(searchHits); + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + pipelineProcessingContext.setAttribute( + org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, + explanationPayload + ); + + // Act + SearchResponse processedResponse = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + + // Assert + assertOnExplanationResults(processedResponse, maxScore); + } + + @SneakyThrows + public void testParsingOfExplanations_whenFieldSortingAndExplanations_thenSuccessful() { + // Setup + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + float maxScore = 1.0f; + SearchHits searchHitsWithoutSorting = getSearchHits(maxScore); + for (SearchHit searchHit : searchHitsWithoutSorting.getHits()) { + Explanation explanation = Explanation.match(1.0f, "combined score of:", Explanation.match(1.0f, "field1:[0 TO 100]")); + searchHit.explanation(explanation); + } + TotalHits.Relation totalHitsRelation = randomFrom(TotalHits.Relation.values()); + TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); + final SortField[] sortFields = new SortField[] { + new SortField("random-text-field-1", SortField.Type.INT, randomBoolean()), + new SortField("random-text-field-2", SortField.Type.STRING, randomBoolean()) }; + SearchHits searchHits = new SearchHits(searchHitsWithoutSorting.getHits(), totalHits, maxScore, sortFields, null, null); + + SearchResponse searchResponse = getSearchResponse(searchHits); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + pipelineProcessingContext.setAttribute( + org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, + explanationPayload + 
); + + // Act + SearchResponse processedResponse = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + + // Assert + assertOnExplanationResults(processedResponse, maxScore); + } + + @SneakyThrows + public void testParsingOfExplanations_whenScoreSortingAndExplanations_thenSuccessful() { + // Setup + ExplanationResponseProcessor explanationResponseProcessor = new ExplanationResponseProcessor(DESCRIPTION, PROCESSOR_TAG, false); + SearchRequest searchRequest = mock(SearchRequest.class); + + float maxScore = 1.0f; + + SearchHits searchHits = getSearchHits(maxScore); + + SearchResponse searchResponse = getSearchResponse(searchHits); + + PipelineProcessingContext pipelineProcessingContext = new PipelineProcessingContext(); + + Map> combinedExplainDetails = getCombinedExplainDetails(searchHits); + Map explainPayload = Map.of( + ExplanationPayload.PayloadType.NORMALIZATION_PROCESSOR, + combinedExplainDetails + ); + ExplanationPayload explanationPayload = ExplanationPayload.builder().explainPayload(explainPayload).build(); + pipelineProcessingContext.setAttribute( + org.opensearch.neuralsearch.plugin.NeuralSearch.EXPLANATION_RESPONSE_KEY, + explanationPayload + ); + + // Act + SearchResponse processedResponse = explanationResponseProcessor.processResponse( + searchRequest, + searchResponse, + pipelineProcessingContext + ); + + // Assert + assertOnExplanationResults(processedResponse, maxScore); + } + + private static SearchHits getSearchHits(float maxScore) { + int numResponses = 1; + int numIndices = 2; + Iterator> indicesIterator = randomRealisticIndices(numIndices, numResponses).entrySet().iterator(); + Map.Entry entry = indicesIterator.next(); + String clusterAlias = entry.getKey(); + Index[] indices = entry.getValue(); + + int requestedSize = 2; + PriorityQueue priorityQueue = new PriorityQueue<>(new SearchHitComparator(null)); + TotalHits.Relation totalHitsRelation = randomFrom(TotalHits.Relation.values()); + TotalHits totalHits = new TotalHits(randomLongBetween(0, 1000), totalHitsRelation); + + final int numDocs = totalHits.value >= requestedSize ? 
requestedSize : (int) totalHits.value; + int scoreFactor = randomIntBetween(1, numResponses); + + SearchHit[] searchHitArray = randomSearchHitArray( + numDocs, + numResponses, + clusterAlias, + indices, + maxScore, + scoreFactor, + null, + priorityQueue + ); + for (SearchHit searchHit : searchHitArray) { + Explanation explanation = Explanation.match(1.0f, "combined score of:", Explanation.match(1.0f, "field1:[0 TO 100]")); + searchHit.explanation(explanation); + } + + SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(numResponses, TotalHits.Relation.EQUAL_TO), maxScore); + return searchHits; + } + + private static SearchResponse getSearchResponse(SearchHits searchHits) { + InternalSearchResponse internalSearchResponse = new InternalSearchResponse( + searchHits, + InternalAggregations.EMPTY, + null, + null, + false, + null, + 1 + ); + SearchResponse searchResponse = new SearchResponse( + internalSearchResponse, + null, + 1, + 1, + 0, + 1000, + new ShardSearchFailure[0], + SearchResponse.Clusters.EMPTY + ); + return searchResponse; + } + + private static Map> getCombinedExplainDetails(SearchHits searchHits) { + Map> combinedExplainDetails = Map.of( + SearchShard.createSearchShard(searchHits.getHits()[0].getShard()), + List.of( + CombinedExplanationDetails.builder() + .normalizationExplanations(new ExplanationDetails(List.of(Pair.of(1.0f, "min_max normalization of:")))) + .combinationExplanations(new ExplanationDetails(List.of(Pair.of(0.5f, "arithmetic_mean combination of:")))) + .build() + ), + SearchShard.createSearchShard(searchHits.getHits()[1].getShard()), + List.of( + CombinedExplanationDetails.builder() + .normalizationExplanations(new ExplanationDetails(List.of(Pair.of(0.5f, "min_max normalization of:")))) + .combinationExplanations(new ExplanationDetails(List.of(Pair.of(0.25f, "arithmetic_mean combination of:")))) + .build() + ) + ); + return combinedExplainDetails; + } + + private static void assertOnExplanationResults(SearchResponse processedResponse, float maxScore) { + assertNotNull(processedResponse); + Explanation hit1TopLevelExplanation = processedResponse.getHits().getHits()[0].getExplanation(); + assertNotNull(hit1TopLevelExplanation); + assertEquals("arithmetic_mean combination of:", hit1TopLevelExplanation.getDescription()); + assertEquals(maxScore, (float) hit1TopLevelExplanation.getValue(), DELTA_FOR_SCORE_ASSERTION); + + Explanation[] hit1SecondLevelDetails = hit1TopLevelExplanation.getDetails(); + assertEquals(1, hit1SecondLevelDetails.length); + assertEquals("min_max normalization of:", hit1SecondLevelDetails[0].getDescription()); + assertEquals(1.0f, (float) hit1SecondLevelDetails[0].getValue(), DELTA_FOR_SCORE_ASSERTION); + + assertNotNull(hit1SecondLevelDetails[0].getDetails()); + assertEquals(1, hit1SecondLevelDetails[0].getDetails().length); + Explanation hit1ShardLevelExplanation = hit1SecondLevelDetails[0].getDetails()[0]; + + assertEquals(1.0f, (float) hit1ShardLevelExplanation.getValue(), DELTA_FOR_SCORE_ASSERTION); + assertEquals("field1:[0 TO 100]", hit1ShardLevelExplanation.getDescription()); + + Explanation hit2TopLevelExplanation = processedResponse.getHits().getHits()[1].getExplanation(); + assertNotNull(hit2TopLevelExplanation); + assertEquals("arithmetic_mean combination of:", hit2TopLevelExplanation.getDescription()); + assertEquals(0.0f, (float) hit2TopLevelExplanation.getValue(), DELTA_FOR_SCORE_ASSERTION); + + Explanation[] hit2SecondLevelDetails = hit2TopLevelExplanation.getDetails(); + assertEquals(1, 
hit2SecondLevelDetails.length); + assertEquals("min_max normalization of:", hit2SecondLevelDetails[0].getDescription()); + assertEquals(.5f, (float) hit2SecondLevelDetails[0].getValue(), DELTA_FOR_SCORE_ASSERTION); + + assertNotNull(hit2SecondLevelDetails[0].getDetails()); + assertEquals(1, hit2SecondLevelDetails[0].getDetails().length); + Explanation hit2ShardLevelExplanation = hit2SecondLevelDetails[0].getDetails()[0]; + + assertEquals(1.0f, (float) hit2ShardLevelExplanation.getValue(), DELTA_FOR_SCORE_ASSERTION); + assertEquals("field1:[0 TO 100]", hit2ShardLevelExplanation.getDescription()); + + Explanation explanationHit2 = processedResponse.getHits().getHits()[1].getExplanation(); + assertNotNull(explanationHit2); + assertEquals("arithmetic_mean combination of:", explanationHit2.getDescription()); + assertTrue(Range.of(0.0f, maxScore).contains((float) explanationHit2.getValue())); + + } + + private static Map randomRealisticIndices(int numIndices, int numClusters) { + String[] indicesNames = new String[numIndices]; + for (int i = 0; i < numIndices; i++) { + indicesNames[i] = randomAlphaOfLengthBetween(5, 10); + } + Map indicesPerCluster = new TreeMap<>(); + for (int i = 0; i < numClusters; i++) { + Index[] indices = new Index[indicesNames.length]; + for (int j = 0; j < indices.length; j++) { + String indexName = indicesNames[j]; + String indexUuid = frequently() ? randomAlphaOfLength(10) : indexName; + indices[j] = new Index(indexName, indexUuid); + } + String clusterAlias; + if (frequently() || indicesPerCluster.containsKey(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY)) { + clusterAlias = randomAlphaOfLengthBetween(5, 10); + } else { + clusterAlias = RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY; + } + indicesPerCluster.put(clusterAlias, indices); + } + return indicesPerCluster; + } + + private static SearchHit[] randomSearchHitArray( + int numDocs, + int numResponses, + String clusterAlias, + Index[] indices, + float maxScore, + int scoreFactor, + SortField[] sortFields, + PriorityQueue priorityQueue + ) { + SearchHit[] hits = new SearchHit[numDocs]; + + int[] sortFieldFactors = new int[sortFields == null ? 0 : sortFields.length]; + for (int j = 0; j < sortFieldFactors.length; j++) { + sortFieldFactors[j] = randomIntBetween(1, numResponses); + } + + for (int j = 0; j < numDocs; j++) { + ShardId shardId = new ShardId(randomFrom(indices), randomIntBetween(0, 10)); + SearchShardTarget shardTarget = new SearchShardTarget( + randomAlphaOfLengthBetween(3, 8), + shardId, + clusterAlias, + OriginalIndices.NONE + ); + SearchHit hit = new SearchHit(randomIntBetween(0, Integer.MAX_VALUE)); + + float score = Float.NaN; + if (!Float.isNaN(maxScore)) { + score = (maxScore - j) * scoreFactor; + hit.score(score); + } + + hit.shard(shardTarget); + if (sortFields != null) { + Object[] rawSortValues = new Object[sortFields.length]; + DocValueFormat[] docValueFormats = new DocValueFormat[sortFields.length]; + for (int k = 0; k < sortFields.length; k++) { + SortField sortField = sortFields[k]; + if (sortField == SortField.FIELD_SCORE) { + hit.score(score); + rawSortValues[k] = score; + } else { + rawSortValues[k] = sortField.getReverse() ? 
numDocs * sortFieldFactors[k] - j : j; + } + docValueFormats[k] = DocValueFormat.RAW; + } + hit.sortValues(rawSortValues, docValueFormats); + } + hits[j] = hit; + priorityQueue.add(hit); + } + return hits; + } + + private static final class SearchHitComparator implements Comparator { + + private final SortField[] sortFields; + + SearchHitComparator(SortField[] sortFields) { + this.sortFields = sortFields; + } + + @Override + public int compare(SearchHit a, SearchHit b) { + if (sortFields == null) { + int scoreCompare = Float.compare(b.getScore(), a.getScore()); + if (scoreCompare != 0) { + return scoreCompare; + } + } else { + for (int i = 0; i < sortFields.length; i++) { + SortField sortField = sortFields[i]; + if (sortField == SortField.FIELD_SCORE) { + int scoreCompare = Float.compare(b.getScore(), a.getScore()); + if (scoreCompare != 0) { + return scoreCompare; + } + } else { + Integer aSortValue = (Integer) a.getRawSortValues()[i]; + Integer bSortValue = (Integer) b.getRawSortValues()[i]; + final int compare; + if (sortField.getReverse()) { + compare = Integer.compare(bSortValue, aSortValue); + } else { + compare = Integer.compare(aSortValue, bSortValue); + } + if (compare != 0) { + return compare; + } + } + } + } + SearchShardTarget aShard = a.getShard(); + SearchShardTarget bShard = b.getShard(); + int shardIdCompareTo = aShard.getShardId().compareTo(bShard.getShardId()); + if (shardIdCompareTo != 0) { + return shardIdCompareTo; + } + int clusterAliasCompareTo = aShard.getClusterAlias().compareTo(bShard.getClusterAlias()); + if (clusterAliasCompareTo != 0) { + return clusterAliasCompareTo; + } + return Integer.compare(a.docId(), b.docId()); + } + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index e93c9b9ec..5f45b14fe 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -179,6 +179,7 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio } SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -247,6 +248,7 @@ public void testScoreCorrectness_whenCompoundDocs_thenDoNormalizationCombination SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); when(searchPhaseContext.getNumShards()).thenReturn(1); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -408,6 +410,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); List querySearchResults = queryPhaseResultConsumer.getAtomicArray() @@ -417,7 +420,7 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz .collect(Collectors.toList()); 
TestUtils.assertQueryResultScores(querySearchResults); - verify(normalizationProcessorWorkflow).execute(any(), any(), any(), any()); + verify(normalizationProcessorWorkflow).execute(any()); } public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() { @@ -495,6 +498,7 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getRequest()).thenReturn(searchRequest); IllegalStateException exception = expectThrows( IllegalStateException.class, () -> normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext) diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java index 918f3f45b..6ff6d9174 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -19,6 +19,8 @@ public class ScoreCombinationTechniqueTests extends OpenSearchTestCase { + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); + public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { ScoreCombiner scoreCombiner = new ScoreCombiner(); scoreCombiner.combineScores( @@ -46,7 +48,8 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(5, 0.001f) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(4, TotalHits.Relation.EQUAL_TO), @@ -57,7 +60,8 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc new ScoreDoc[] { new ScoreDoc(2, 0.9f), new ScoreDoc(4, 0.6f), new ScoreDoc(7, 0.5f), new ScoreDoc(9, 0.01f) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), @@ -65,7 +69,8 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) ), - false + false, + SEARCH_SHARD ) ); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java index 67abd552f..b2b0007f6 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java @@ -18,6 +18,8 @@ public class ScoreNormalizationTechniqueTests extends OpenSearchTestCase { + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); + public void testEmptyResults_whenEmptyResultsAndDefaultMethod_thenNoProcessing() { ScoreNormalizer scoreNormalizationMethod = new ScoreNormalizer(); scoreNormalizationMethod.normalizeScores(List.of(), ScoreNormalizationFactory.DEFAULT_METHOD); @@ -30,7 +32,8 @@ public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenSco new CompoundTopDocs( new TotalHits(1, TotalHits.Relation.EQUAL_TO), List.of(new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(1, 2.0f) })), - false 
+ false, + SEARCH_SHARD ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); @@ -61,7 +64,8 @@ public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMe new ScoreDoc[] { new ScoreDoc(1, 10.0f), new ScoreDoc(2, 2.5f), new ScoreDoc(4, 0.1f) } ) ), - false + false, + SEARCH_SHARD ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); @@ -98,7 +102,8 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } ) ), - false + false, + SEARCH_SHARD ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); @@ -147,7 +152,8 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn new ScoreDoc[] { new ScoreDoc(3, 0.8f), new ScoreDoc(5, 0.5f) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(4, TotalHits.Relation.EQUAL_TO), @@ -158,7 +164,8 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn new ScoreDoc[] { new ScoreDoc(2, 2.2f), new ScoreDoc(4, 1.8f), new ScoreDoc(7, 0.9f), new ScoreDoc(9, 0.01f) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), @@ -166,7 +173,8 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]), new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]) ), - false + false, + SEARCH_SHARD ) ); scoreNormalizationMethod.normalizeScores(queryTopDocs, ScoreNormalizationFactory.DEFAULT_METHOD); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index 7d3b3fb61..deac02933 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -4,13 +4,13 @@ */ package org.opensearch.neuralsearch.processor.combination; -import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; - import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; + public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java index 495e2f4cd..d46705902 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java @@ -4,13 +4,13 @@ */ package org.opensearch.neuralsearch.processor.combination; -import static 
org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; - import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; + public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java index 0c6e1f81d..0cfdeb4c4 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -4,13 +4,13 @@ */ package org.opensearch.neuralsearch.processor.combination; -import static org.opensearch.neuralsearch.processor.combination.ArithmeticMeanScoreCombinationTechnique.PARAM_NAME_WEIGHTS; - import java.util.List; import java.util.Map; import java.util.stream.Collectors; import java.util.stream.IntStream; +import static org.opensearch.neuralsearch.processor.combination.ScoreCombinationUtil.PARAM_NAME_WEIGHTS; + public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtilsTests.java b/src/test/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtilsTests.java new file mode 100644 index 000000000..becab3860 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/explain/ExplanationUtilsTests.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.explain; + +import org.apache.commons.lang3.tuple.Pair; +import org.junit.Before; + +import org.opensearch.neuralsearch.processor.SearchShard; +import org.opensearch.neuralsearch.processor.normalization.MinMaxScoreNormalizationTechnique; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +public class ExplanationUtilsTests extends OpenSearchQueryTestCase { + + private DocIdAtSearchShard docId1; + private DocIdAtSearchShard docId2; + private Map> normalizedScores; + private final MinMaxScoreNormalizationTechnique MIN_MAX_TECHNIQUE = new MinMaxScoreNormalizationTechnique(); + + @Before + public void setUp() throws Exception { + super.setUp(); + SearchShard searchShard = new SearchShard("test_index", 0, "abcdefg"); + docId1 = new DocIdAtSearchShard(1, searchShard); + docId2 = new DocIdAtSearchShard(2, searchShard); + normalizedScores = new HashMap<>(); + } + + public void testGetDocIdAtQueryForNormalization() { + // Setup + normalizedScores.put(docId1, Arrays.asList(1.0f, 0.5f)); + normalizedScores.put(docId2, Arrays.asList(0.8f)); + // Act + Map result = ExplanationUtils.getDocIdAtQueryForNormalization( + normalizedScores, + MIN_MAX_TECHNIQUE + ); + // Assert + assertNotNull(result); + assertEquals(2, result.size()); + + // Assert first document + ExplanationDetails details1 = 
result.get(docId1); + assertNotNull(details1); + List> explanations1 = details1.getScoreDetails(); + assertEquals(2, explanations1.size()); + assertEquals(1.0f, explanations1.get(0).getLeft(), 0.001); + assertEquals(0.5f, explanations1.get(1).getLeft(), 0.001); + assertEquals("min_max normalization of:", explanations1.get(0).getRight()); + assertEquals("min_max normalization of:", explanations1.get(1).getRight()); + + // Assert second document + ExplanationDetails details2 = result.get(docId2); + assertNotNull(details2); + List> explanations2 = details2.getScoreDetails(); + assertEquals(1, explanations2.size()); + assertEquals(0.8f, explanations2.get(0).getLeft(), 0.001); + assertEquals("min_max normalization of:", explanations2.get(0).getRight()); + } + + public void testGetDocIdAtQueryForNormalizationWithEmptyScores() { + // Setup + // Using empty normalizedScores from setUp + // Act + Map result = ExplanationUtils.getDocIdAtQueryForNormalization( + normalizedScores, + MIN_MAX_TECHNIQUE + ); + // Assert + assertNotNull(result); + assertTrue(result.isEmpty()); + } + + public void testDescribeCombinationTechniqueWithWeights() { + // Setup + String techniqueName = "test_technique"; + List weights = Arrays.asList(0.3f, 0.7f); + // Act + String result = ExplanationUtils.describeCombinationTechnique(techniqueName, weights); + // Assert + assertEquals("test_technique, weights [0.3, 0.7]", result); + } + + public void testDescribeCombinationTechniqueWithoutWeights() { + // Setup + String techniqueName = "test_technique"; + // Act + String result = ExplanationUtils.describeCombinationTechnique(techniqueName, null); + // Assert + assertEquals("test_technique", result); + } + + public void testDescribeCombinationTechniqueWithEmptyWeights() { + // Setup + String techniqueName = "test_technique"; + List weights = Arrays.asList(); + // Act + String result = ExplanationUtils.describeCombinationTechnique(techniqueName, weights); + // Assert + assertEquals("test_technique", result); + } + + public void testDescribeCombinationTechniqueWithNullTechnique() { + // Setup + List weights = Arrays.asList(1.0f); + // Act & Assert + expectThrows(IllegalArgumentException.class, () -> ExplanationUtils.describeCombinationTechnique(null, weights)); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactoryTests.java new file mode 100644 index 000000000..453cc471c --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/ExplanationResponseProcessorFactoryTests.java @@ -0,0 +1,112 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor.factory; + +import lombok.SneakyThrows; +import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; +import org.opensearch.search.pipeline.Processor; +import org.opensearch.search.pipeline.SearchResponseProcessor; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.HashMap; +import java.util.Map; + +import static org.mockito.Mockito.mock; + +public class ExplanationResponseProcessorFactoryTests extends OpenSearchTestCase { + + @SneakyThrows + public void testDefaults_whenNoParams_thenSuccessful() { + // Setup + ExplanationResponseProcessorFactory explanationResponseProcessorFactory = new ExplanationResponseProcessorFactory(); + final Map> processorFactories = new HashMap<>(); + String tag = 
"tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + // Act + SearchResponseProcessor responseProcessor = explanationResponseProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + // Assert + assertProcessor(responseProcessor, tag, description, ignoreFailure); + } + + @SneakyThrows + public void testInvalidInput_whenParamsPassedToFactory_thenSuccessful() { + // Setup + ExplanationResponseProcessorFactory explanationResponseProcessorFactory = new ExplanationResponseProcessorFactory(); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + // create map of random parameters + Map config = new HashMap<>(); + for (int i = 0; i < randomInt(1_000); i++) { + config.put(randomAlphaOfLength(10) + i, randomAlphaOfLength(100)); + } + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + // Act + SearchResponseProcessor responseProcessor = explanationResponseProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + // Assert + assertProcessor(responseProcessor, tag, description, ignoreFailure); + } + + @SneakyThrows + public void testNewInstanceCreation_whenCreateMultipleTimes_thenNewInstanceReturned() { + // Setup + ExplanationResponseProcessorFactory explanationResponseProcessorFactory = new ExplanationResponseProcessorFactory(); + final Map> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + // Act + SearchResponseProcessor responseProcessorOne = explanationResponseProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + + SearchResponseProcessor responseProcessorTwo = explanationResponseProcessorFactory.create( + processorFactories, + tag, + description, + ignoreFailure, + config, + pipelineContext + ); + + // Assert + assertNotEquals(responseProcessorOne, responseProcessorTwo); + } + + private static void assertProcessor(SearchResponseProcessor responseProcessor, String tag, String description, boolean ignoreFailure) { + assertNotNull(responseProcessor); + assertTrue(responseProcessor instanceof ExplanationResponseProcessor); + ExplanationResponseProcessor explanationResponseProcessor = (ExplanationResponseProcessor) responseProcessor; + assertEquals("explanation_response_processor", explanationResponseProcessor.getType()); + assertEquals(tag, explanationResponseProcessor.getTag()); + assertEquals(description, explanationResponseProcessor.getDescription()); + assertEquals(ignoreFailure, explanationResponseProcessor.isIgnoreFailure()); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java index ba4bfee0d..734f9bb57 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java @@ -11,6 +11,7 @@ import 
org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.SearchShard; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; /** @@ -18,6 +19,7 @@ */ public class L2ScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase { private static final float DELTA_FOR_ASSERTION = 0.0001f; + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { L2ScoreNormalizationTechnique normalizationTechnique = new L2ScoreNormalizationTechnique(); @@ -31,7 +33,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new ScoreDoc[] { new ScoreDoc(2, scores[0]), new ScoreDoc(4, scores[1]) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -46,7 +49,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new ScoreDoc(4, l2Norm(scores[1], Arrays.asList(scores))) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -78,7 +82,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new ScoreDoc(2, scoresQuery2[2]) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -101,7 +106,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new ScoreDoc(2, l2Norm(scoresQuery2[2], Arrays.asList(scoresQuery2))) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -133,7 +139,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(2, scoresShard1and2Query3[2]) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(4, TotalHits.Relation.EQUAL_TO), @@ -152,7 +159,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(15, scoresShard1and2Query3[6]) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -175,7 +183,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(2, l2Norm(scoresShard1and2Query3[2], Arrays.asList(scoresShard1and2Query3))) } ) ), - false + false, + SEARCH_SHARD ); CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs( @@ -197,7 +206,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc(15, l2Norm(scoresShard1and2Query3[6], Arrays.asList(scoresShard1and2Query3))) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index d0445f0ca..c7692b407 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -10,6 +10,7 @@ import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.SearchShard; import 
org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; /** @@ -17,6 +18,7 @@ */ public class MinMaxScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase { private static final float DELTA_FOR_ASSERTION = 0.0001f; + private static final SearchShard SEARCH_SHARD = new SearchShard("my_index", 0, "12345678"); public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() { MinMaxScoreNormalizationTechnique normalizationTechnique = new MinMaxScoreNormalizationTechnique(); @@ -29,7 +31,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -42,7 +45,8 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, 0.001f) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -69,7 +73,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -87,7 +92,8 @@ public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(4, 0.75f), new ScoreDoc(2, 0.001f) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); @@ -113,7 +119,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) } ) ), - false + false, + SEARCH_SHARD ), new CompoundTopDocs( new TotalHits(2, TotalHits.Relation.EQUAL_TO), @@ -124,7 +131,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc[] { new ScoreDoc(7, 2.9f), new ScoreDoc(9, 0.7f) } ) ), - false + false, + SEARCH_SHARD ) ); normalizationTechnique.normalize(compoundTopDocs); @@ -142,7 +150,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc[] { new ScoreDoc(3, 1.0f), new ScoreDoc(4, 0.75f), new ScoreDoc(2, 0.001f) } ) ), - false + false, + SEARCH_SHARD ); CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs( @@ -154,7 +163,8 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the new ScoreDoc[] { new ScoreDoc(7, 1.0f), new ScoreDoc(9, 0.001f) } ) ), - false + false, + SEARCH_SHARD ); assertNotNull(compoundTopDocs); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java new file mode 100644 index 000000000..b7e4f753a --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryExplainIT.java @@ -0,0 +1,722 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.query; + +import com.google.common.primitives.Floats; +import lombok.SneakyThrows; +import org.junit.Before; +import org.opensearch.index.query.BoolQueryBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermQueryBuilder; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; + +import 
java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Set; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.PARAM_NAME_WEIGHTS; +import static org.opensearch.neuralsearch.util.TestUtils.RELATION_EQUAL_TO; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_DIMENSION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; +import static org.opensearch.neuralsearch.util.TestUtils.createRandomVector; +import static org.opensearch.neuralsearch.util.TestUtils.getMaxScore; +import static org.opensearch.neuralsearch.util.TestUtils.getNestedHits; +import static org.opensearch.neuralsearch.util.TestUtils.getTotalHits; +import static org.opensearch.neuralsearch.util.TestUtils.getValueByKey; + +public class HybridQueryExplainIT extends BaseNeuralSearchIT { + private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-hybrid-vector-doc-field-index"; + private static final String TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME = "test-hybrid-multi-doc-nested-fields-index"; + private static final String TEST_MULTI_DOC_INDEX_NAME = "test-hybrid-multi-doc-index"; + private static final String TEST_LARGE_DOCS_INDEX_NAME = "test-hybrid-large-docs-index"; + + private static final String TEST_QUERY_TEXT3 = "hello"; + private static final String TEST_QUERY_TEXT4 = "place"; + private static final String TEST_QUERY_TEXT5 = "welcome"; + private static final String TEST_DOC_TEXT1 = "Hello world"; + private static final String TEST_DOC_TEXT2 = "Hi to this place"; + private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; + private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; + private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final String TEST_TEXT_FIELD_NAME_2 = "test-text-field-2"; + private static final String TEST_NESTED_TYPE_FIELD_NAME_1 = "user"; + private static final String NORMALIZATION_TECHNIQUE_L2 = "l2"; + private static final int MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX = 2_000; + private final float[] testVector1 = createRandomVector(TEST_DIMENSION); + private final float[] testVector2 = createRandomVector(TEST_DIMENSION); + private final float[] testVector3 = createRandomVector(TEST_DIMENSION); + private static final String SEARCH_PIPELINE = "phase-results-hybrid-pipeline"; + + static final Supplier TEST_VECTOR_SUPPLIER = () -> new float[768]; + + @Before + public void setUp() throws Exception { + super.setUp(); + updateClusterSettings(); + } + + @Override + protected boolean preserveClusterUponCompletion() { + return true; + } + + @SneakyThrows + public void testExplain_whenMultipleSubqueriesAndOneShard_thenSuccessful() { + try { + initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); + // create search pipeline with both normalization processor and explain response processor + createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + + TermQueryBuilder termQueryBuilder1 = 
QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder); + + Map searchResponseAsMap1 = search( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + // Assert + // search hits + assertEquals(3, getHitCount(searchResponseAsMap1)); + + List> hitsNestedList = getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(3, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain + Map searchHit1 = hitsNestedList.get(0); + Map topLevelExplanationsHit1 = getValueByKey(searchHit1, "_explanation"); + assertNotNull(topLevelExplanationsHit1); + assertEquals((double) searchHit1.get("_score"), (double) topLevelExplanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedTopLevelDescription = "arithmetic_mean combination of:"; + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit1.get("description")); + List> normalizationExplanationHit1 = getListOfValues(topLevelExplanationsHit1, "details"); + assertEquals(1, normalizationExplanationHit1.size()); + Map hit1DetailsForHit1 = normalizationExplanationHit1.get(0); + assertEquals(1.0, hit1DetailsForHit1.get("value")); + assertEquals("min_max normalization of:", hit1DetailsForHit1.get("description")); + assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); + + Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); + assertEquals("sum of:", explanationsHit1.get("description")); + assertEquals(0.754f, (double) explanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1, ((List) explanationsHit1.get("details")).size()); + + // search hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map topLevelExplanationsHit2 = getValueByKey(searchHit2, "_explanation"); + assertNotNull(topLevelExplanationsHit2); + assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit2.get("description")); + List> normalizationExplanationHit2 = getListOfValues(topLevelExplanationsHit2, "details"); + assertEquals(1, normalizationExplanationHit2.size()); + + Map hit1DetailsForHit2 = normalizationExplanationHit2.get(0); + assertEquals(1.0, hit1DetailsForHit2.get("value")); + assertEquals("min_max normalization of:", hit1DetailsForHit2.get("description")); + assertEquals(1, 
getListOfValues(hit1DetailsForHit2, "details").size()); + + Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); + assertEquals(0.287f, (double) explanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit2.get("description")); + assertEquals(1, getListOfValues(explanationsHit2, "details").size()); + + Map explanationsHit2Details = getListOfValues(explanationsHit2, "details").get(0); + assertEquals(0.287f, (double) explanationsHit2Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit2Details.get("description")); + assertEquals(3, getListOfValues(explanationsHit2Details, "details").size()); + + // search hit 3 + Map searchHit3 = hitsNestedList.get(1); + Map topLevelExplanationsHit3 = getValueByKey(searchHit3, "_explanation"); + assertNotNull(topLevelExplanationsHit3); + assertEquals((double) searchHit2.get("_score"), (double) topLevelExplanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit3.get("description")); + List> normalizationExplanationHit3 = getListOfValues(topLevelExplanationsHit3, "details"); + assertEquals(1, normalizationExplanationHit3.size()); + + Map hit1DetailsForHit3 = normalizationExplanationHit3.get(0); + assertEquals(1.0, hit1DetailsForHit3.get("value")); + assertEquals("min_max normalization of:", hit1DetailsForHit3.get("description")); + assertEquals(1, getListOfValues(hit1DetailsForHit3, "details").size()); + + Map explanationsHit3 = getListOfValues(hit1DetailsForHit3, "details").get(0); + assertEquals(0.287f, (double) explanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", explanationsHit3.get("description")); + assertEquals(1, getListOfValues(explanationsHit3, "details").size()); + + Map explanationsHit3Details = getListOfValues(explanationsHit3, "details").get(0); + assertEquals(0.287f, (double) explanationsHit3Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit3Details.get("description")); + assertEquals(3, getListOfValues(explanationsHit3Details, "details").size()); + } finally { + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testExplain_whenMultipleSubqueriesAndMultipleShards_thenSuccessful() { + try { + initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); + createSearchPipeline( + SEARCH_PIPELINE, + NORMALIZATION_TECHNIQUE_L2, + DEFAULT_COMBINATION_METHOD, + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.7f })), + true + ); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + KNNQueryBuilder knnQueryBuilder = KNNQueryBuilder.builder() + .fieldName(TEST_KNN_VECTOR_FIELD_NAME_1) + .vector(createRandomVector(TEST_DIMENSION)) + .k(10) + .build(); + hybridQueryBuilder.add(QueryBuilders.existsQuery(TEST_TEXT_FIELD_NAME_1)); + hybridQueryBuilder.add(knnQueryBuilder); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_NAME, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + // Assert + // basic sanity check for search hits + assertEquals(4, getHitCount(searchResponseAsMap)); + 
assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + float actualMaxScore = getMaxScore(searchResponseAsMap).get(); + assertTrue(actualMaxScore > 0); + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(4, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain, hit 1 + List> hitsNestedList = getNestedHits(searchResponseAsMap); + Map searchHit1 = hitsNestedList.get(0); + Map explanationForHit1 = getValueByKey(searchHit1, "_explanation"); + assertNotNull(explanationForHit1); + assertEquals((double) searchHit1.get("_score"), (double) explanationForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedTopLevelDescription = "arithmetic_mean, weights [0.3, 0.7] combination of:"; + assertEquals(expectedTopLevelDescription, explanationForHit1.get("description")); + List> hit1Details = getListOfValues(explanationForHit1, "details"); + assertEquals(2, hit1Details.size()); + // two sub-queries meaning we do have two detail objects with separate query level details + Map hit1DetailsForHit1 = hit1Details.get(0); + assertTrue((double) hit1DetailsForHit1.get("value") > 0.5f); + assertEquals("l2 normalization of:", hit1DetailsForHit1.get("description")); + assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); + + Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); + assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit1.get("description")); + assertTrue((double) explanationsHit1.get("value") > 0.5f); + assertEquals(0, ((List) explanationsHit1.get("details")).size()); + + Map hit1DetailsForHit2 = hit1Details.get(1); + assertTrue((double) hit1DetailsForHit2.get("value") > 0.0f); + assertEquals("l2 normalization of:", hit1DetailsForHit2.get("description")); + assertEquals(1, ((List) hit1DetailsForHit2.get("details")).size()); + + Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); + assertEquals("within top 10", explanationsHit2.get("description")); + assertTrue((double) explanationsHit2.get("value") > 0.0f); + assertEquals(0, ((List) explanationsHit2.get("details")).size()); + + // hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map explanationForHit2 = getValueByKey(searchHit2, "_explanation"); + assertNotNull(explanationForHit2); + assertEquals((double) searchHit2.get("_score"), (double) explanationForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit2.get("description")); + List> hit2Details = getListOfValues(explanationForHit2, "details"); + assertEquals(2, hit2Details.size()); + + Map hit2DetailsForHit1 = hit2Details.get(0); + assertTrue((double) hit2DetailsForHit1.get("value") > 0.5f); + assertEquals("l2 normalization of:", hit2DetailsForHit1.get("description")); + assertEquals(1, ((List) hit2DetailsForHit1.get("details")).size()); + + Map hit2DetailsForHit2 = hit2Details.get(1); + assertTrue((double) hit2DetailsForHit2.get("value") > 0.0f); + assertEquals("l2 normalization of:", hit2DetailsForHit2.get("description")); + assertEquals(1, ((List) hit2DetailsForHit2.get("details")).size()); + + // hit 3 + Map searchHit3 = hitsNestedList.get(2); + Map explanationForHit3 = getValueByKey(searchHit3, "_explanation"); + assertNotNull(explanationForHit3); + assertEquals((double) searchHit3.get("_score"), (double) explanationForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + 
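The assertions above walk the `_explanation` object as a nested map of `value`/`description`/`details`. The following illustrative walker (not part of the patch) mirrors what `getValueByKey`/`getListOfValues` extract, assuming the same map-based response parsing the tests use.

```java
// Illustrative helper, not part of the patch: depth-first dump of an "_explanation" node.
// Every node in the tree carries "value", "description" and a "details" list of child nodes.
import java.util.List;
import java.util.Map;

final class ExplanationTreePrinter {
    @SuppressWarnings("unchecked")
    static void dump(Map<String, Object> node, int depth) {
        if (node == null) {
            return;
        }
        System.out.printf("%s%.4f  %s%n",
            "  ".repeat(depth),
            ((Number) node.get("value")).doubleValue(),
            node.get("description"));
        List<Map<String, Object>> details =
            (List<Map<String, Object>>) node.getOrDefault("details", List.of());
        for (Map<String, Object> child : details) {
            dump(child, depth + 1);
        }
    }
}
```

For the weighted multi-shard case above this would print the `arithmetic_mean, weights [0.3, 0.7] combination of:` node at depth 0, one `l2 normalization of:` node per matching sub-query at depth 1, and the raw Lucene explanations below that.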
assertEquals(expectedTopLevelDescription, explanationForHit3.get("description")); + List> hit3Details = getListOfValues(explanationForHit3, "details"); + assertEquals(1, hit3Details.size()); + + Map hit3DetailsForHit1 = hit3Details.get(0); + assertTrue((double) hit3DetailsForHit1.get("value") > 0.5f); + assertEquals("l2 normalization of:", hit3DetailsForHit1.get("description")); + assertEquals(1, ((List) hit3DetailsForHit1.get("details")).size()); + + Map explanationsHit3 = getListOfValues(hit3DetailsForHit1, "details").get(0); + assertEquals("within top 10", explanationsHit3.get("description")); + assertEquals(0, getListOfValues(explanationsHit3, "details").size()); + assertTrue((double) explanationsHit3.get("value") > 0.0f); + + // hit 4 + Map searchHit4 = hitsNestedList.get(3); + Map explanationForHit4 = getValueByKey(searchHit4, "_explanation"); + assertNotNull(explanationForHit4); + assertEquals((double) searchHit4.get("_score"), (double) explanationForHit4.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, explanationForHit4.get("description")); + List> hit4Details = getListOfValues(explanationForHit4, "details"); + assertEquals(1, hit4Details.size()); + + Map hit4DetailsForHit1 = hit4Details.get(0); + assertTrue((double) hit4DetailsForHit1.get("value") > 0.5f); + assertEquals("l2 normalization of:", hit4DetailsForHit1.get("description")); + assertEquals(1, ((List) hit4DetailsForHit1.get("details")).size()); + + Map explanationsHit4 = getListOfValues(hit4DetailsForHit1, "details").get(0); + assertEquals("ConstantScore(FieldExistsQuery [field=test-text-field-1])", explanationsHit4.get("description")); + assertEquals(0, getListOfValues(explanationsHit4, "details").size()); + assertTrue((double) explanationsHit4.get("value") > 0.0f); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testExplanationResponseProcessor_whenProcessorIsNotConfigured_thenResponseHasQueryExplanations() { + try { + initializeIndexIfNotExist(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME); + // create search pipeline with normalization processor, no explanation response processor + createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false); + + TermQueryBuilder termQueryBuilder1 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4); + TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5); + BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder(); + boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3); + + HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); + hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); + hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder); + + Map searchResponseAsMap1 = search( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + // Assert + // search hits + assertEquals(3, getHitCount(searchResponseAsMap1)); + + List> hitsNestedList = getNestedHits(searchResponseAsMap1); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + 
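The only functional difference between this test and the earlier ones is the last argument of the `createSearchPipeline` overload added in this patch, which controls whether the explanation response processor is registered. For contrast, both calls below use that overload, and the resulting top-level descriptions come from the assertions in these tests.

```java
// With the flag set to true the pipeline body gains a "response_processors" section holding the
// explanation response processor (ExplanationResponseProcessor.TYPE), so hits are explained as
// "arithmetic_mean combination of:" -> "min_max normalization of:" -> raw Lucene details.
createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true);

// With false only the "phase_results_processors" normalization processor is configured, so the
// response keeps the default shard-level explain output ("combined score of:" at the top level).
createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), false);
```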
assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(searchResponseAsMap1); + assertNotNull(total.get("value")); + assertEquals(3, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // explain + Map searchHit1 = hitsNestedList.get(0); + Map topLevelExplanationsHit1 = getValueByKey(searchHit1, "_explanation"); + assertNotNull(topLevelExplanationsHit1); + assertEquals(0.754f, (double) topLevelExplanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + String expectedTopLevelDescription = "combined score of:"; + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit1.get("description")); + List> normalizationExplanationHit1 = getListOfValues(topLevelExplanationsHit1, "details"); + assertEquals(1, normalizationExplanationHit1.size()); + Map hit1DetailsForHit1 = normalizationExplanationHit1.get(0); + assertEquals(0.754f, (double) hit1DetailsForHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("sum of:", hit1DetailsForHit1.get("description")); + assertEquals(1, ((List) hit1DetailsForHit1.get("details")).size()); + + Map explanationsHit1 = getListOfValues(hit1DetailsForHit1, "details").get(0); + assertEquals("weight(test-text-field-1:place in 0) [PerFieldSimilarity], result of:", explanationsHit1.get("description")); + assertEquals(0.754f, (double) explanationsHit1.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(1, ((List) explanationsHit1.get("details")).size()); + + Map explanationsHit1Details = getListOfValues(explanationsHit1, "details").get(0); + assertEquals(0.754f, (double) explanationsHit1Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit1Details.get("description")); + assertEquals(3, getListOfValues(explanationsHit1Details, "details").size()); + + Map explanationsDetails1Hit1Details = getListOfValues(explanationsHit1Details, "details").get(0); + assertEquals(2.2f, (double) explanationsDetails1Hit1Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("boost", explanationsDetails1Hit1Details.get("description")); + assertEquals(0, getListOfValues(explanationsDetails1Hit1Details, "details").size()); + + Map explanationsDetails2Hit1Details = getListOfValues(explanationsHit1Details, "details").get(1); + assertEquals(0.693f, (double) explanationsDetails2Hit1Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:", explanationsDetails2Hit1Details.get("description")); + assertFalse(getListOfValues(explanationsDetails2Hit1Details, "details").isEmpty()); + + Map explanationsDetails3Hit1Details = getListOfValues(explanationsHit1Details, "details").get(2); + assertEquals(0.495f, (double) explanationsDetails3Hit1Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals( + "tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:", + explanationsDetails3Hit1Details.get("description") + ); + assertFalse(getListOfValues(explanationsDetails3Hit1Details, "details").isEmpty()); + + // search hit 2 + Map searchHit2 = hitsNestedList.get(1); + Map topLevelExplanationsHit2 = getValueByKey(searchHit2, "_explanation"); + assertNotNull(topLevelExplanationsHit2); + assertEquals(0.287f, (double) topLevelExplanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + + 
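As a quick sanity check on the BM25 figures asserted above, the leaf terms multiply back to the asserted total; the worked example below uses only the constants from those assertions.

```java
// Worked check of the asserted breakdown: score = boost * idf * tf
double boost = 2.2;   // "boost"
double idf = 0.693;   // "idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:"
double tf = 0.495;    // "tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:"
double score = boost * idf * tf;   // ~= 0.7547, matching the 0.754 assertion within DELTA_FOR_SCORE_ASSERTION
```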
assertEquals(expectedTopLevelDescription, topLevelExplanationsHit2.get("description")); + List> normalizationExplanationHit2 = getListOfValues(topLevelExplanationsHit2, "details"); + assertEquals(1, normalizationExplanationHit2.size()); + + Map hit1DetailsForHit2 = normalizationExplanationHit2.get(0); + assertEquals(0.287f, (double) hit1DetailsForHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", hit1DetailsForHit2.get("description")); + assertEquals(1, getListOfValues(hit1DetailsForHit2, "details").size()); + + Map explanationsHit2 = getListOfValues(hit1DetailsForHit2, "details").get(0); + assertEquals(0.287f, (double) explanationsHit2.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit2.get("description")); + assertEquals(3, getListOfValues(explanationsHit2, "details").size()); + + Map explanationsHit2Details = getListOfValues(explanationsHit2, "details").get(0); + assertEquals(2.2f, (double) explanationsHit2Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("boost", explanationsHit2Details.get("description")); + assertEquals(0, getListOfValues(explanationsHit2Details, "details").size()); + + // search hit 3 + Map searchHit3 = hitsNestedList.get(1); + Map topLevelExplanationsHit3 = getValueByKey(searchHit3, "_explanation"); + assertNotNull(topLevelExplanationsHit3); + assertEquals(0.287f, (double) topLevelExplanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + + assertEquals(expectedTopLevelDescription, topLevelExplanationsHit3.get("description")); + List> normalizationExplanationHit3 = getListOfValues(topLevelExplanationsHit3, "details"); + assertEquals(1, normalizationExplanationHit3.size()); + + Map hit1DetailsForHit3 = normalizationExplanationHit3.get(0); + assertEquals(0.287, (double) hit1DetailsForHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("weight(test-text-field-1:hello in 0) [PerFieldSimilarity], result of:", hit1DetailsForHit3.get("description")); + assertEquals(1, getListOfValues(hit1DetailsForHit3, "details").size()); + + Map explanationsHit3 = getListOfValues(hit1DetailsForHit3, "details").get(0); + assertEquals(0.287f, (double) explanationsHit3.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("score(freq=1.0), computed as boost * idf * tf from:", explanationsHit3.get("description")); + assertEquals(3, getListOfValues(explanationsHit3, "details").size()); + + Map explanationsHit3Details = getListOfValues(explanationsHit3, "details").get(0); + assertEquals(2.2f, (double) explanationsHit3Details.get("value"), DELTA_FOR_SCORE_ASSERTION); + assertEquals("boost", explanationsHit3Details.get("description")); + assertEquals(0, getListOfValues(explanationsHit3Details, "details").size()); + } finally { + wipeOfTestResources(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testExplain_whenLargeNumberOfDocuments_thenSuccessful() { + try { + initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME); + // create search pipeline with both normalization processor and explain response processor + createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + + TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(termQueryBuilder); + + Map 
searchResponseAsMap = search( + TEST_LARGE_DOCS_INDEX_NAME, + hybridQueryBuilder, + null, + MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + + List> hitsNestedList = getNestedHits(searchResponseAsMap); + assertNotNull(hitsNestedList); + assertFalse(hitsNestedList.isEmpty()); + + // Verify total hits + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertTrue((int) total.get("value") > 0); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // Sanity checks for each hit's explanation + for (Map hit : hitsNestedList) { + // Verify score is positive + double score = (double) hit.get("_score"); + assertTrue("Score should be positive", score > 0.0); + + // Basic explanation structure checks + Map explanation = getValueByKey(hit, "_explanation"); + assertNotNull(explanation); + assertEquals("arithmetic_mean combination of:", explanation.get("description")); + Map hitDetailsForHit = getListOfValues(explanation, "details").get(0); + assertTrue((double) hitDetailsForHit.get("value") > 0.0f); + assertEquals("min_max normalization of:", hitDetailsForHit.get("description")); + Map subQueryDetailsForHit = getListOfValues(hitDetailsForHit, "details").get(0); + assertTrue((double) subQueryDetailsForHit.get("value") > 0.0f); + assertFalse(subQueryDetailsForHit.get("description").toString().isEmpty()); + assertEquals(1, getListOfValues(subQueryDetailsForHit, "details").size()); + } + // Verify scores are properly ordered + List scores = new ArrayList<>(); + for (Map hit : hitsNestedList) { + scores.add((Double) hit.get("_score")); + } + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1))); + } finally { + wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testSpecificQueryTypes_whenMultiMatchAndKnn_thenSuccessful() { + try { + initializeIndexIfNotExist(TEST_LARGE_DOCS_INDEX_NAME); + // create search pipeline with both normalization processor and explain response processor + createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(QueryBuilders.multiMatchQuery(TEST_QUERY_TEXT3, TEST_TEXT_FIELD_NAME_1, TEST_TEXT_FIELD_NAME_2)); + hybridQueryBuilder.add( + KNNQueryBuilder.builder().k(10).fieldName(TEST_KNN_VECTOR_FIELD_NAME_1).vector(TEST_VECTOR_SUPPLIER.get()).build() + ); + + Map searchResponseAsMap = search( + TEST_LARGE_DOCS_INDEX_NAME, + hybridQueryBuilder, + null, + MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()) + ); + + List> hitsNestedList = getNestedHits(searchResponseAsMap); + assertNotNull(hitsNestedList); + assertFalse(hitsNestedList.isEmpty()); + + // Verify total hits + Map total = getTotalHits(searchResponseAsMap); + assertNotNull(total.get("value")); + assertTrue((int) total.get("value") > 0); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // Sanity checks for each hit's explanation + for (Map hit : hitsNestedList) { + // Verify score is positive + double score = (double) hit.get("_score"); + assertTrue("Score should be positive", score > 0.0); + + // Basic explanation structure checks + Map explanation = getValueByKey(hit, "_explanation"); + assertNotNull(explanation); + assertEquals("arithmetic_mean combination of:", 
explanation.get("description")); + Map hitDetailsForHit = getListOfValues(explanation, "details").get(0); + assertTrue((double) hitDetailsForHit.get("value") > 0.0f); + assertEquals("min_max normalization of:", hitDetailsForHit.get("description")); + Map subQueryDetailsForHit = getListOfValues(hitDetailsForHit, "details").get(0); + assertTrue((double) subQueryDetailsForHit.get("value") > 0.0f); + assertFalse(subQueryDetailsForHit.get("description").toString().isEmpty()); + assertNotNull(getListOfValues(subQueryDetailsForHit, "details")); + } + // Verify scores are properly ordered + List scores = new ArrayList<>(); + for (Map hit : hitsNestedList) { + scores.add((Double) hit.get("_score")); + } + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(i -> scores.get(i) < scores.get(i + 1))); + } finally { + wipeOfTestResources(TEST_LARGE_DOCS_INDEX_NAME, null, null, SEARCH_PIPELINE); + } + } + + @SneakyThrows + private void initializeIndexIfNotExist(String indexName) { + if (TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)) { + prepareKnnIndex( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + List.of( + new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE), + new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_2, TEST_DIMENSION, TEST_SPACE_TYPE) + ) + ); + addKnnDoc( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + "1", + List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2), + List.of(Floats.asList(testVector1).toArray(), Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + "2", + List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2), + List.of(Floats.asList(testVector2).toArray(), Floats.asList(testVector2).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + "3", + List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2), + List.of(Floats.asList(testVector3).toArray(), Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + assertEquals(3, getDocCount(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)); + } + + if (TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME.equals(indexName) && !indexExists(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + List.of(TEST_NESTED_TYPE_FIELD_NAME_1), + 1 + ), + "" + ); + addDocsToIndex(TEST_MULTI_DOC_WITH_NESTED_FIELDS_INDEX_NAME); + } + + if (TEST_MULTI_DOC_INDEX_NAME.equals(indexName) && !indexExists(TEST_MULTI_DOC_INDEX_NAME)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration( + Collections.singletonList(new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE)), + List.of(), + 1 + ), + "" + ); + addDocsToIndex(TEST_MULTI_DOC_INDEX_NAME); + } + + if (TEST_LARGE_DOCS_INDEX_NAME.equals(indexName) && !indexExists(TEST_LARGE_DOCS_INDEX_NAME)) { + prepareKnnIndex( + TEST_LARGE_DOCS_INDEX_NAME, + List.of( + new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE), + new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_2, TEST_DIMENSION, TEST_SPACE_TYPE) + ) + ); + 
+ // Index large number of documents + for (int i = 0; i < MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX; i++) { + String docText; + if (i % 5 == 0) { + docText = TEST_DOC_TEXT1; // "Hello world" + } else if (i % 7 == 0) { + docText = TEST_DOC_TEXT2; // "Hi to this place" + } else if (i % 11 == 0) { + docText = TEST_DOC_TEXT3; // "We would like to welcome everyone" + } else { + docText = String.format(Locale.ROOT, "Document %d with random content", i); + } + + addKnnDoc( + TEST_LARGE_DOCS_INDEX_NAME, + String.valueOf(i), + List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2), + List.of( + Floats.asList(createRandomVector(TEST_DIMENSION)).toArray(), + Floats.asList(createRandomVector(TEST_DIMENSION)).toArray() + ), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(docText) + ); + } + assertEquals(MAX_NUMBER_OF_DOCS_IN_LARGE_INDEX, getDocCount(TEST_LARGE_DOCS_INDEX_NAME)); + } + } + + private void addDocsToIndex(final String testMultiDocIndexName) { + addKnnDoc( + testMultiDocIndexName, + "1", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector1).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT1) + ); + addKnnDoc( + testMultiDocIndexName, + "2", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector2).toArray()) + ); + addKnnDoc( + testMultiDocIndexName, + "3", + Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), + Collections.singletonList(Floats.asList(testVector3).toArray()), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT2) + ); + addKnnDoc( + testMultiDocIndexName, + "4", + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(TEST_TEXT_FIELD_NAME_1), + Collections.singletonList(TEST_DOC_TEXT3) + ); + assertEquals(4, getDocCount(testMultiDocIndexName)); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java index b5e812780..875c66310 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQuerySortIT.java @@ -21,6 +21,8 @@ import org.opensearch.neuralsearch.BaseNeuralSearchIT; import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; import static org.opensearch.neuralsearch.util.TestUtils.assertHitResultsFromQueryWhenSortIsEnabled; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_NORMALIZATION_METHOD; +import static org.opensearch.neuralsearch.util.TestUtils.DEFAULT_COMBINATION_METHOD; import org.opensearch.search.sort.SortOrder; import org.opensearch.search.sort.SortBuilder; import org.opensearch.search.sort.SortBuilders; @@ -531,6 +533,141 @@ public void testSortingWithRescoreWhenConcurrentSegmentSearchEnabledAndDisabled_ } } + @SneakyThrows + public void testExplainAndSort_whenIndexWithMultipleShards_thenSuccessful() { + try { + // Setup + updateClusterSettings(CONCURRENT_SEGMENT_SEARCH_ENABLED, false); + + initializeIndexIfNotExists(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, SHARDS_COUNT_IN_MULTI_NODE_CLUSTER); + createSearchPipeline(SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, Map.of(), true); + // Assert + // scores for search hits + HybridQueryBuilder hybridQueryBuilder = 
createHybridQueryBuilderWithMatchTermAndRangeQuery( + "mission", + "part", + LTE_OF_RANGE_IN_HYBRID_QUERY, + GTE_OF_RANGE_IN_HYBRID_QUERY + ); + + Map fieldSortOrderMap = new HashMap<>(); + fieldSortOrderMap.put("stock", SortOrder.DESC); + + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "explain", Boolean.TRUE.toString()), + null, + null, + createSortBuilders(fieldSortOrderMap, false), + false, + null, + 0 + ); + List> nestedHits = validateHitsCountAndFetchNestedHits(searchResponseAsMap, 6, 6); + assertStockValueWithSortOrderInHybridQueryResults(nestedHits, SortOrder.DESC, LARGEST_STOCK_VALUE_IN_QUERY_RESULT, true, true); + + // explain + Map searchHit1 = nestedHits.get(0); + Map explanationForHit1 = (Map) searchHit1.get("_explanation"); + assertNotNull(explanationForHit1); + assertNull(searchHit1.get("_score")); + String expectedGeneralCombineScoreDescription = "arithmetic_mean combination of:"; + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit1.get("description")); + List> hit1Details = getListOfValues(explanationForHit1, "details"); + assertEquals(2, hit1Details.size()); + Map hit1DetailsForHit1 = hit1Details.get(0); + assertEquals(1.0, hit1DetailsForHit1.get("value")); + assertEquals("min_max normalization of:", hit1DetailsForHit1.get("description")); + List> hit1DetailsForHit1Details = getListOfValues(hit1DetailsForHit1, "details"); + assertEquals(1, hit1DetailsForHit1Details.size()); + + Map hit1DetailsForHit1DetailsForHit1 = hit1DetailsForHit1Details.get(0); + assertEquals("weight(name:mission in 0) [PerFieldSimilarity], result of:", hit1DetailsForHit1DetailsForHit1.get("description")); + assertTrue((double) hit1DetailsForHit1DetailsForHit1.get("value") > 0.0f); + assertEquals(1, getListOfValues(hit1DetailsForHit1DetailsForHit1, "details").size()); + + Map hit1DetailsForHit1DetailsForHit1DetailsForHit1 = getListOfValues( + hit1DetailsForHit1DetailsForHit1, + "details" + ).get(0); + assertEquals( + "score(freq=1.0), computed as boost * idf * tf from:", + hit1DetailsForHit1DetailsForHit1DetailsForHit1.get("description") + ); + assertTrue((double) hit1DetailsForHit1DetailsForHit1DetailsForHit1.get("value") > 0.0f); + assertEquals(3, getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").size()); + + assertEquals("boost", getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(0).get("description")); + assertTrue((double) getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(0).get("value") > 0.0f); + assertEquals( + "idf, computed as log(1 + (N - n + 0.5) / (n + 0.5)) from:", + getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(1).get("description") + ); + assertTrue((double) getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(1).get("value") > 0.0f); + assertEquals( + "tf, computed as freq / (freq + k1 * (1 - b + b * dl / avgdl)) from:", + getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(2).get("description") + ); + assertTrue((double) getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit1, "details").get(2).get("value") > 0.0f); + + // hit 4 + Map searchHit4 = nestedHits.get(3); + Map explanationForHit4 = (Map) searchHit4.get("_explanation"); + assertNotNull(explanationForHit4); + assertNull(searchHit4.get("_score")); + assertEquals(expectedGeneralCombineScoreDescription, 
explanationForHit4.get("description")); + List> hit4Details = getListOfValues(explanationForHit4, "details"); + assertEquals(2, hit4Details.size()); + Map hit1DetailsForHit4 = hit4Details.get(0); + assertEquals(1.0, hit1DetailsForHit4.get("value")); + assertEquals("min_max normalization of:", hit1DetailsForHit4.get("description")); + assertEquals(1, ((List) hit1DetailsForHit4.get("details")).size()); + List> hit1DetailsForHit4Details = getListOfValues(hit1DetailsForHit4, "details"); + assertEquals(1, hit1DetailsForHit4Details.size()); + + Map hit1DetailsForHit1DetailsForHit4 = hit1DetailsForHit4Details.get(0); + assertEquals("weight(name:part in 0) [PerFieldSimilarity], result of:", hit1DetailsForHit1DetailsForHit4.get("description")); + assertTrue((double) hit1DetailsForHit1DetailsForHit4.get("value") > 0.0f); + assertEquals(1, getListOfValues(hit1DetailsForHit1DetailsForHit4, "details").size()); + + Map hit1DetailsForHit1DetailsForHit1DetailsForHit4 = getListOfValues( + hit1DetailsForHit1DetailsForHit4, + "details" + ).get(0); + assertEquals( + "score(freq=1.0), computed as boost * idf * tf from:", + hit1DetailsForHit1DetailsForHit1DetailsForHit4.get("description") + ); + assertTrue((double) hit1DetailsForHit1DetailsForHit1DetailsForHit4.get("value") > 0.0f); + assertEquals(3, getListOfValues(hit1DetailsForHit1DetailsForHit1DetailsForHit4, "details").size()); + + // hit 6 + Map searchHit6 = nestedHits.get(5); + Map explanationForHit6 = (Map) searchHit6.get("_explanation"); + assertNotNull(explanationForHit6); + assertNull(searchHit6.get("_score")); + assertEquals(expectedGeneralCombineScoreDescription, explanationForHit6.get("description")); + List> hit6Details = getListOfValues(explanationForHit6, "details"); + assertEquals(1, hit6Details.size()); + Map hit1DetailsForHit6 = hit6Details.get(0); + assertEquals(1.0, hit1DetailsForHit6.get("value")); + assertEquals("min_max normalization of:", hit1DetailsForHit6.get("description")); + assertEquals(1, ((List) hit1DetailsForHit6.get("details")).size()); + List> hit1DetailsForHit6Details = getListOfValues(hit1DetailsForHit6, "details"); + assertEquals(1, hit1DetailsForHit6Details.size()); + + Map hit1DetailsForHit1DetailsForHit6 = hit1DetailsForHit6Details.get(0); + assertEquals("weight(name:part in 0) [PerFieldSimilarity], result of:", hit1DetailsForHit1DetailsForHit4.get("description")); + assertTrue((double) hit1DetailsForHit1DetailsForHit6.get("value") > 0.0f); + assertEquals(0, getListOfValues(hit1DetailsForHit1DetailsForHit6, "details").size()); + } finally { + wipeOfTestResources(TEST_MULTI_DOC_INDEX_WITH_TEXT_AND_INT_MULTIPLE_SHARDS, null, null, SEARCH_PIPELINE); + } + } + private HybridQueryBuilder createHybridQueryBuilderWithMatchTermAndRangeQuery(String text, String value, int lte, int gte) { MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(TEXT_FIELD_1_NAME, text); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEXT_FIELD_1_NAME, value); diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java index 10d480475..0e32b5e78 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryWeightTests.java @@ -19,6 +19,7 @@ import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; 
import org.apache.lucene.search.IndexSearcher; import org.apache.lucene.search.Matches; import org.apache.lucene.search.MatchesIterator; @@ -146,7 +147,7 @@ public void testSubQueries_whenMultipleEqualSubQueries_thenSuccessful() { } @SneakyThrows - public void testExplain_whenCallExplain_thenFail() { + public void testExplain_whenCallExplain_thenSuccessful() { QueryShardContext mockQueryShardContext = mock(QueryShardContext.class); TextFieldMapper.TextFieldType fieldType = (TextFieldMapper.TextFieldType) createMapperService().fieldType(TEXT_FIELD_NAME); when(mockQueryShardContext.fieldMapper(eq(TEXT_FIELD_NAME))).thenReturn(fieldType); @@ -171,7 +172,8 @@ public void testExplain_whenCallExplain_thenFail() { assertNotNull(weight); LeafReaderContext leafReaderContext = reader.getContext().leaves().get(0); - expectThrows(UnsupportedOperationException.class, () -> weight.explain(leafReaderContext, docId)); + Explanation explanation = weight.explain(leafReaderContext, docId); + assertNotNull(explanation); w.close(); reader.close(); diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 5d8e79e72..ef6d519e5 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -49,6 +49,8 @@ import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.knn.index.SpaceType; +import org.opensearch.neuralsearch.processor.NormalizationProcessor; +import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.neuralsearch.util.TokenWeightUtil; import org.opensearch.search.sort.SortBuilder; @@ -858,6 +860,10 @@ protected List getNormalizationScoreList(final Map searc return scores; } + protected List> getListOfValues(Map searchResponseAsMap, String key) { + return (List>) searchResponseAsMap.get(key); + } + /** * Create a k-NN index from a list of KNNFieldConfigs * @@ -1161,11 +1167,24 @@ protected void createSearchPipeline( final String normalizationMethod, String combinationMethod, final Map combinationParams + ) { + createSearchPipeline(pipelineId, normalizationMethod, combinationMethod, combinationParams, false); + } + + @SneakyThrows + protected void createSearchPipeline( + final String pipelineId, + final String normalizationMethod, + final String combinationMethod, + final Map combinationParams, + boolean addExplainResponseProcessor ) { StringBuilder stringBuilderForContentBody = new StringBuilder(); stringBuilderForContentBody.append("{\"description\": \"Post processor pipeline\",") .append("\"phase_results_processors\": [{ ") - .append("\"normalization-processor\": {") + .append("\"") + .append(NormalizationProcessor.TYPE) + .append("\": {") .append("\"normalization\": {") .append("\"technique\": \"%s\"") .append("},") @@ -1178,7 +1197,15 @@ protected void createSearchPipeline( } stringBuilderForContentBody.append(" }"); } - stringBuilderForContentBody.append("}").append("}}]}"); + stringBuilderForContentBody.append("}").append("}}]"); + if (addExplainResponseProcessor) { + stringBuilderForContentBody.append(", \"response_processors\": [ ") + .append("{\"") + .append(ExplanationResponseProcessor.TYPE) + .append("\": {}}") + .append("]"); + } + stringBuilderForContentBody.append("}"); makeRequest( client(), "PUT", diff --git 
a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java index c10380e87..bb072ab55 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/util/TestUtils.java @@ -9,8 +9,6 @@ import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; -import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getNestedHits; -import static org.opensearch.neuralsearch.util.AggregationsTestUtils.getTotalHits; import static org.opensearch.test.OpenSearchTestCase.randomFloat; import java.util.ArrayList; @@ -352,17 +350,17 @@ public static void assertHitResultsFromQueryWhenSortIsEnabled( assertEquals(RELATION_EQUAL_TO, total.get("relation")); } - private static List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) { + public static List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) { Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits"); return (List<Map<String, Object>>) hitsMap.get("hits"); } - private static Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) { + public static Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) { Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits"); return (Map<String, Object>) hitsMap.get("total"); } - private static Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) { + public static Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) { Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits"); return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); } @@ -383,6 +381,13 @@ public static String getModelId(Map<String, Object> pipeline, String processor) return modelId; } + @SuppressWarnings("unchecked") + public static <T> T getValueByKey(Map<String, Object> map, String key) { + assertNotNull(map); + Object value = map.get(key); + return (T) value; + } + public static String generateModelId() { return "public_model_" + RandomizedTest.randomAsciiAlphanumOfLength(8); }
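Finally, the `HybridQueryWeightTests` change earlier in this patch flips `HybridQueryWeight.explain` from throwing `UnsupportedOperationException` to returning a Lucene `Explanation`. The sketch below shows one way to consume that at the Lucene level; searcher, query, and docId setup are assumed to exist, and the shape of the nested details follows the hybrid implementation rather than anything quoted verbatim from this patch.

```java
// Sketch only: reading the Explanation that HybridQueryWeight.explain(...) now returns
// (previously it threw UnsupportedOperationException). Searcher/query/docId are assumed.
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Query;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Weight;

final class HybridExplainAtLuceneLevel {
    static void print(IndexSearcher searcher, Query hybridQuery, int docId) throws Exception {
        Weight weight = searcher.createWeight(searcher.rewrite(hybridQuery), ScoreMode.COMPLETE, 1.0f);
        LeafReaderContext leaf = searcher.getIndexReader().leaves().get(0);
        Explanation explanation = weight.explain(leaf, docId);   // non-null instead of throwing
        System.out.println(explanation.getValue() + "  " + explanation.getDescription());
        for (Explanation detail : explanation.getDetails()) {    // nested detail nodes
            System.out.println("  " + detail.getValue() + "  " + detail.getDescription());
        }
    }
}
```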