From d12f480e74a851b4f88b6f57733e4c30ff5fb624 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Thu, 24 Aug 2023 20:56:54 +0200
Subject: [PATCH 1/6] Fixed failing tests after ml-commons added model_group_id
 feature (#262)

Signed-off-by: Martin Gaievski
---
 .../common/BaseNeuralSearchIT.java            | 29 +++++++++++++++++++
 .../CreateModelGroupRequestBody.json          |  5 ++++
 .../processor/UploadModelRequestBody.json     |  1 +
 3 files changed, 35 insertions(+)
 create mode 100644 src/test/resources/processor/CreateModelGroupRequestBody.json

diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
index db5dd1fa6..fdf2459df 100644
--- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
@@ -8,6 +8,7 @@ import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray;
 
 import java.io.IOException;
+import java.net.URISyntaxException;
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.util.Collections;
@@ -76,6 +77,7 @@ protected void updateClusterSettings() {
         updateClusterSettings("plugins.ml_commons.only_run_on_ml_node", false);
         // default threshold for native circuit breaker is 90, it may be not enough on test runner machine
         updateClusterSettings("plugins.ml_commons.native_memory_threshold", 100);
+        updateClusterSettings("plugins.ml_commons.allow_registering_model_via_url", true);
     }
 
     @SneakyThrows
@@ -99,6 +101,10 @@ protected void updateClusterSettings(String settingKey, Object value) {
     }
 
     protected String uploadModel(String requestBody) throws Exception {
+        String modelGroupId = registerModelGroup();
+        // model group id is dynamically generated, we need to update the model upload request body after the group is registered
+        requestBody = requestBody.replace("<MODEL_GROUP_ID>", modelGroupId);
+
         Response uploadResponse = makeRequest(
             client(),
             "POST",
@@ -677,4 +683,27 @@ protected String getDeployedModelId() {
         assertEquals(1, modelIds.size());
         return modelIds.iterator().next();
     }
+
+    @SneakyThrows
+    private String registerModelGroup() throws IOException, URISyntaxException {
+        String modelGroupRegisterRequestBody = Files.readString(
+            Path.of(classLoader.getResource("processor/CreateModelGroupRequestBody.json").toURI())
+        );
+        Response modelGroupResponse = makeRequest(
+            client(),
+            "POST",
+            "/_plugins/_ml/model_groups/_register",
+            null,
+            toHttpEntity(modelGroupRegisterRequestBody),
+            ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana"))
+        );
+        Map<String, Object> modelGroupResJson = XContentHelper.convertToMap(
+            XContentType.JSON.xContent(),
+            EntityUtils.toString(modelGroupResponse.getEntity()),
+            false
+        );
+        String modelGroupId = modelGroupResJson.get("model_group_id").toString();
+        assertNotNull(modelGroupId);
+        return modelGroupId;
+    }
 }
diff --git a/src/test/resources/processor/CreateModelGroupRequestBody.json b/src/test/resources/processor/CreateModelGroupRequestBody.json
new file mode 100644
index 000000000..d6d398c76
--- /dev/null
+++ b/src/test/resources/processor/CreateModelGroupRequestBody.json
@@ -0,0 +1,5 @@
+{
+    "name": "test_model_group_public",
+    "description": "This is a public model group",
+    "access_mode": "public"
+}
\ No newline at end of file
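The round trip that the new test helper performs can be sketched outside the test framework for readers who want to try it by hand. The model-group endpoint, the request body, and the `model_group_id` response field below come from the patch itself; the host/port, the model upload path, and the naive JSON handling are assumptions made for this illustration:

```java
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class ModelGroupRoundTrip {
    private static final HttpClient CLIENT = HttpClient.newHttpClient();

    // POST a JSON body to the cluster and return the raw response body
    private static String post(String path, String json) throws Exception {
        HttpRequest request = HttpRequest.newBuilder()
            .uri(URI.create("http://localhost:9200" + path)) // assumed local test cluster
            .header("Content-Type", "application/json")
            .POST(HttpRequest.BodyPublishers.ofString(json))
            .build();
        return CLIENT.send(request, HttpResponse.BodyHandlers.ofString()).body();
    }

    public static void main(String[] args) throws Exception {
        // 1. register a model group (same body as CreateModelGroupRequestBody.json above)
        String groupResponse = post(
            "/_plugins/_ml/model_groups/_register",
            "{\"name\":\"test_model_group_public\",\"description\":\"This is a public model group\",\"access_mode\":\"public\"}"
        );
        // naive extraction of "model_group_id" for illustration; the test class uses XContentHelper instead
        String modelGroupId = groupResponse.replaceAll("(?s).*\"model_group_id\"\\s*:\\s*\"([^\"]+)\".*", "$1");

        // 2. substitute the dynamically generated id into the upload request body
        String uploadBody = "{\"name\":\"my-model\",\"model_group_id\":\"<MODEL_GROUP_ID>\"}".replace("<MODEL_GROUP_ID>", modelGroupId);

        // 3. register the model itself (endpoint path is an assumption for this sketch)
        System.out.println(post("/_plugins/_ml/models/_upload", uploadBody));
    }
}
```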
diff --git a/src/test/resources/processor/UploadModelRequestBody.json b/src/test/resources/processor/UploadModelRequestBody.json
index 9fc53f3b9..95f9c9cb5 100644
--- a/src/test/resources/processor/UploadModelRequestBody.json
+++ b/src/test/resources/processor/UploadModelRequestBody.json
@@ -4,6 +4,7 @@
     "model_format": "TORCH_SCRIPT",
     "model_task_type": "text_embedding",
     "model_content_hash_value": "e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021",
+    "model_group_id": "<MODEL_GROUP_ID>",
     "model_config": {
         "model_type": "bert",
         "embedding_dimension": 768,

From 75b59cd5f0f45c83f7e6d970f9ecc218d84cdcd7 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Tue, 29 Aug 2023 23:19:18 +0200
Subject: [PATCH 2/6] Changed format for hybrid query results to a single list
 of scores with delimiter (#259)

* Changed approach for storing hybrid query results from compound top docs to single list of scores with delimiter

Signed-off-by: Martin Gaievski
---
 CHANGELOG.md                                  |   1 +
 .../processor/CompoundTopDocs.java            | 120 +++++++++++++
 .../processor/NormalizationProcessor.java     |  38 ++--
 .../NormalizationProcessorWorkflow.java       |  98 +++++++++--
 .../processor/combination/ScoreCombiner.java  |  19 +-
 .../L2ScoreNormalizationTechnique.java        |  12 +-
 .../MinMaxScoreNormalizationTechnique.java    |  12 +-
 .../ScoreNormalizationTechnique.java          |   2 +-
 .../normalization/ScoreNormalizer.java        |   4 +-
 .../neuralsearch/search/CompoundTopDocs.java  |  55 ------
 .../query/HybridQueryPhaseSearcher.java       |  71 +++++++-
 .../util/HybridSearchResultFormatUtil.java    |  56 ++++++
 .../opensearch/neuralsearch/TestUtils.java    |  90 +++++++++-
 .../CompoundTopDocsTests.java                 |  30 ++--
 .../processor/NormalizationProcessorIT.java   |  30 +++-
 .../NormalizationProcessorTests.java          |  98 +++++++++--
 .../NormalizationProcessorWorkflowTests.java  | 131 +++++++++++++-
 .../processor/ScoreCombinationIT.java         |   8 +-
 .../ScoreCombinationTechniqueTests.java       |  37 ++--
 .../processor/ScoreNormalizationIT.java       |   2 +-
 .../ScoreNormalizationTechniqueTests.java     |  39 ++---
 .../L2ScoreNormalizationTechniqueTests.java   |  33 ++--
 ...inMaxScoreNormalizationTechniqueTests.java |  33 ++--
 .../neuralsearch/query/HybridQueryIT.java     | 165 +++---------------
 .../query/HybridQueryPhaseSearcherTests.java  |  42 ++++-
 .../HybridSearchResultFormatUtilTests.java    |  61 +++++++
 26 files changed, 902 insertions(+), 385 deletions(-)
 create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java
 delete mode 100644 src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java
 create mode 100644 src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java
 rename src/test/java/org/opensearch/neuralsearch/{search => processor}/CompoundTopDocsTests.java (74%)
 create mode 100644 src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 1824d209f..17f47690d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
 ### Features
 * Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/))
 ### Enhancements
+* Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259))
 ### Bug Fixes
 ### Infrastructure
 ### Documentation
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java
new file mode 100644
index 000000000..7ab497825
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/processor/CompoundTopDocs.java
@@ -0,0 +1,120 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor;
+
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement;
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+import lombok.AllArgsConstructor;
+import lombok.Getter;
+import lombok.Setter;
+import lombok.ToString;
+import lombok.extern.log4j.Log4j2;
+
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+
+/**
+ * Class that stores a collection of TopDocs, one per sub-query of a hybrid query. The collection of results is at shard level. We store
+ * the list of TopDocs and the list of ScoreDoc, as well as the total hits for the shard.
+ */
+@AllArgsConstructor
+@Getter
+@ToString(includeFieldNames = true)
+@Log4j2
+public class CompoundTopDocs {
+
+    @Setter
+    private TotalHits totalHits;
+    private List<TopDocs> topDocs;
+    @Setter
+    private List<ScoreDoc> scoreDocs;
+
+    public CompoundTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
+        initialize(totalHits, topDocs);
+    }
+
+    private void initialize(TotalHits totalHits, List<TopDocs> topDocs) {
+        this.totalHits = totalHits;
+        this.topDocs = topDocs;
+        scoreDocs = cloneLargestScoreDocs(topDocs);
+    }
+
+    /**
+     * Create a new instance from TopDocs by parsing the scores of sub-queries. The final format looks like:
+     *  doc_id | magic_number_1
+     *  doc_id | magic_number_2
+     *  ...
+     *  doc_id | magic_number_2
+     *  ...
+     *  doc_id | magic_number_2
+     *  ...
+     *  doc_id | magic_number_1
+     *
+     * where doc_id is one of the valid ids from the result. For example, this is a list with results for three sub-queries:
+     *
+     *  0, 9549511920.4881596047
+     *  0, 4422440593.9791198149
+     *  0, 0.8
+     *  2, 0.5
+     *  0, 4422440593.9791198149
+     *  0, 4422440593.9791198149
+     *  2, 0.7
+     *  5, 0.65
+     *  6, 0.15
+     *  0, 9549511920.4881596047
+     */
+    public CompoundTopDocs(final TopDocs topDocs) {
+        ScoreDoc[] scoreDocs = topDocs.scoreDocs;
+        if (Objects.isNull(scoreDocs) || scoreDocs.length < 2) {
+            initialize(topDocs.totalHits, new ArrayList<>());
+            return;
+        }
+        // skipping first two elements, it's a start-stop element and the delimiter for the first series
+        List<TopDocs> topDocsList = new ArrayList<>();
+        List<ScoreDoc> scoreDocList = new ArrayList<>();
+        for (int index = 2; index < scoreDocs.length; index++) {
+            // getting the next element of the score series
+            ScoreDoc scoreDoc = scoreDocs[index];
+            if (isHybridQueryDelimiterElement(scoreDoc) || isHybridQueryStartStopElement(scoreDoc)) {
+                ScoreDoc[] subQueryScores = scoreDocList.toArray(new ScoreDoc[0]);
+                TotalHits totalHits = new TotalHits(subQueryScores.length, TotalHits.Relation.EQUAL_TO);
+                TopDocs subQueryTopDocs = new TopDocs(totalHits, subQueryScores);
+                topDocsList.add(subQueryTopDocs);
+                scoreDocList.clear();
+            } else {
+                scoreDocList.add(scoreDoc);
+            }
+        }
+        initialize(topDocs.totalHits, topDocsList);
+    }
+
+    private List<ScoreDoc> cloneLargestScoreDocs(final List<TopDocs> docs) {
+        if (docs == null) {
+            return null;
+        }
+        ScoreDoc[] maxScoreDocs = new ScoreDoc[0];
+        int maxLength = -1;
+        for (TopDocs topDoc : docs) {
+            if (topDoc == null || topDoc.scoreDocs == null) {
+                continue;
+            }
+            if (topDoc.scoreDocs.length > maxLength) {
+                maxLength = topDoc.scoreDocs.length;
+                maxScoreDocs = topDoc.scoreDocs;
+            }
+        }
+        // do deep copy
+        return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).collect(Collectors.toList());
+    }
+}
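To make the delimiter encoding concrete, the following self-contained sketch reproduces the parsing that the `CompoundTopDocs(TopDocs)` constructor above performs, using plain Lucene types. The magic-number constants are copied from `HybridSearchResultFormatUtil`, introduced later in this patch; the sample input mirrors the three-sub-query example from the javadoc:

```java
import java.util.ArrayList;
import java.util.List;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

public class HybridFormatParsingSketch {
    // constants copied from HybridSearchResultFormatUtil in this patch
    static final float MAGIC_NUMBER_START_STOP = -9549511920.4881596047f;
    static final float MAGIC_NUMBER_DELIMITER = -4422440593.9791198149f;

    /** Split a flat, delimiter-encoded ScoreDoc array back into per-sub-query TopDocs. */
    static List<TopDocs> parse(ScoreDoc[] scoreDocs) {
        List<TopDocs> result = new ArrayList<>();
        List<ScoreDoc> current = new ArrayList<>();
        // index 0 is the start element, index 1 is the delimiter opening the first series
        for (int i = 2; i < scoreDocs.length; i++) {
            float score = scoreDocs[i].score;
            if (score == MAGIC_NUMBER_DELIMITER || score == MAGIC_NUMBER_START_STOP) {
                // a special element closes the current sub-query series
                result.add(new TopDocs(new TotalHits(current.size(), TotalHits.Relation.EQUAL_TO),
                    current.toArray(new ScoreDoc[0])));
                current.clear();
            } else {
                current.add(scoreDocs[i]);
            }
        }
        return result;
    }

    public static void main(String[] args) {
        ScoreDoc[] encoded = {
            new ScoreDoc(0, MAGIC_NUMBER_START_STOP),
            new ScoreDoc(0, MAGIC_NUMBER_DELIMITER),
            new ScoreDoc(0, 0.8f), new ScoreDoc(2, 0.5f),   // sub-query 1
            new ScoreDoc(0, MAGIC_NUMBER_DELIMITER),        // sub-query 2 is empty
            new ScoreDoc(0, MAGIC_NUMBER_DELIMITER),
            new ScoreDoc(2, 0.7f), new ScoreDoc(5, 0.65f), new ScoreDoc(6, 0.15f), // sub-query 3
            new ScoreDoc(0, MAGIC_NUMBER_START_STOP) };
        parse(encoded).forEach(topDocs -> System.out.println(topDocs.scoreDocs.length + " hit(s)"));
        // prints: 2 hit(s), 0 hit(s), 3 hit(s)
    }
}
```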
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
index 570d3b9e1..997f08854 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessor.java
@@ -5,6 +5,8 @@
 package org.opensearch.neuralsearch.processor;
 
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;
+
 import java.util.List;
 import java.util.Objects;
 import java.util.Optional;
@@ -19,8 +21,8 @@
 import org.opensearch.action.search.SearchPhaseResults;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
-import org.opensearch.neuralsearch.search.CompoundTopDocs;
 import org.opensearch.search.SearchPhaseResult;
+import org.opensearch.search.fetch.FetchSearchResult;
 import org.opensearch.search.internal.SearchContext;
 import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
 import org.opensearch.search.query.QuerySearchResult;
@@ -56,7 +58,8 @@ public void process(
             return;
         }
         List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
-        normalizationWorkflow.execute(querySearchResults, normalizationTechnique, combinationTechnique);
+        Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);
+        normalizationWorkflow.execute(querySearchResults, fetchSearchResult, normalizationTechnique, combinationTechnique);
     }
 
     @Override
@@ -95,19 +98,21 @@ private boolean shouldSkipProcessor(SearchPhaseResults<SearchPhaseResult> searchPhaseResult) {
         }
         QueryPhaseResultConsumer queryPhaseResultConsumer = (QueryPhaseResultConsumer) searchPhaseResult;
-        Optional<SearchPhaseResult> optionalSearchPhaseResult = queryPhaseResultConsumer.getAtomicArray()
-            .asList()
-            .stream()
-            .filter(Objects::nonNull)
-            .findFirst();
-        return isNotHybridQuery(optionalSearchPhaseResult);
+        return queryPhaseResultConsumer.getAtomicArray().asList().stream().filter(Objects::nonNull).noneMatch(this::isHybridQuery);
     }
 
-    private boolean isNotHybridQuery(final Optional<SearchPhaseResult> optionalSearchPhaseResult) {
-        return optionalSearchPhaseResult.isEmpty()
-            || Objects.isNull(optionalSearchPhaseResult.get().queryResult())
-            || Objects.isNull(optionalSearchPhaseResult.get().queryResult().topDocs())
-            || !(optionalSearchPhaseResult.get().queryResult().topDocs().topDocs instanceof CompoundTopDocs);
+    /**
+     * Return true if results are from a hybrid query.
+     * @param searchPhaseResult
+     * @return true if results are from a hybrid query
+     */
+    private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
+        // check for the start-stop element at the beginning of the score docs
+        return Objects.nonNull(searchPhaseResult.queryResult())
+            && Objects.nonNull(searchPhaseResult.queryResult().topDocs())
+            && Objects.nonNull(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs)
+            && searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs.length > 0
+            && isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
     }
 
     private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
@@ -119,4 +124,11 @@
             .map(result -> result == null ? null : result.queryResult())
             .collect(Collectors.toList());
     }
+
+    private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
+        final SearchPhaseResults<Result> searchPhaseResults
+    ) {
+        Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
+        return optionalFirstSearchPhaseResult.map(SearchPhaseResult::fetchResult);
+    }
 }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index fda095773..23fbac002 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -5,19 +5,28 @@
 package org.opensearch.neuralsearch.processor;
 
+import java.util.Arrays;
 import java.util.List;
+import java.util.Locale;
+import java.util.Map;
 import java.util.Objects;
+import java.util.Optional;
+import java.util.function.Function;
 import java.util.stream.Collectors;
 
 import lombok.AllArgsConstructor;
 import lombok.extern.log4j.Log4j2;
 
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
 import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
 import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
-import org.opensearch.neuralsearch.search.CompoundTopDocs;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHits;
+import org.opensearch.search.fetch.FetchSearchResult;
 import org.opensearch.search.query.QuerySearchResult;
 
 /**
@@ -39,6 +48,7 @@
      */
     public void execute(
         final List<QuerySearchResult> querySearchResults,
+        final 
Optional fetchSearchResultOptional, final ScoreNormalizationTechnique normalizationTechnique, final ScoreCombinationTechnique combinationTechnique ) { @@ -57,6 +67,7 @@ public void execute( // post-process data log.debug("Post-process query results after score normalization and combination"); updateOriginalQueryResults(querySearchResults, queryTopDocs); + updateOriginalFetchResults(querySearchResults, fetchSearchResultOptional); } /** @@ -67,22 +78,87 @@ public void execute( private List getQueryTopDocs(final List querySearchResults) { List queryTopDocs = querySearchResults.stream() .filter(searchResult -> Objects.nonNull(searchResult.topDocs())) - .filter(searchResult -> searchResult.topDocs().topDocs instanceof CompoundTopDocs) - .map(searchResult -> (CompoundTopDocs) searchResult.topDocs().topDocs) + .map(querySearchResult -> querySearchResult.topDocs().topDocs) + .map(CompoundTopDocs::new) .collect(Collectors.toList()); + if (queryTopDocs.size() != querySearchResults.size()) { + throw new IllegalStateException( + String.format( + Locale.ROOT, + "query results were not formatted correctly by the hybrid query; sizes of querySearchResults [%d] and queryTopDocs [%d] must match", + querySearchResults.size(), + queryTopDocs.size() + ) + ); + } return queryTopDocs; } private void updateOriginalQueryResults(final List querySearchResults, final List queryTopDocs) { - for (int i = 0; i < querySearchResults.size(); i++) { - QuerySearchResult querySearchResult = querySearchResults.get(i); - if (!(querySearchResult.topDocs().topDocs instanceof CompoundTopDocs) || Objects.isNull(queryTopDocs.get(i))) { - continue; - } - CompoundTopDocs updatedTopDocs = queryTopDocs.get(i); - float maxScore = updatedTopDocs.totalHits.value > 0 ? updatedTopDocs.scoreDocs[0].score : 0.0f; - TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(updatedTopDocs, maxScore); + if (querySearchResults.size() != queryTopDocs.size()) { + throw new IllegalStateException( + String.format( + Locale.ROOT, + "query results were not formatted correctly by the hybrid query; sizes of querySearchResults [%d] and queryTopDocs [%d] must match", + querySearchResults.size(), + queryTopDocs.size() + ) + ); + } + for (int index = 0; index < querySearchResults.size(); index++) { + QuerySearchResult querySearchResult = querySearchResults.get(index); + CompoundTopDocs updatedTopDocs = queryTopDocs.get(index); + float maxScore = updatedTopDocs.getTotalHits().value > 0 ? updatedTopDocs.getScoreDocs().get(0).score : 0.0f; + + // create final version of top docs with all updated values + TopDocs topDocs = new TopDocs(updatedTopDocs.getTotalHits(), updatedTopDocs.getScoreDocs().toArray(new ScoreDoc[0])); + + TopDocsAndMaxScore updatedTopDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, maxScore); querySearchResult.topDocs(updatedTopDocsAndMaxScore, null); } } + + /** + * A workaround for a single shard case, fetch has happened, and we need to update both fetch and query results + */ + private void updateOriginalFetchResults( + final List querySearchResults, + final Optional fetchSearchResultOptional + ) { + if (fetchSearchResultOptional.isEmpty()) { + return; + } + // fetch results have list of document content, that includes start/stop and + // delimiter elements. list is in original order from query searcher. We need to: + // 1. filter out start/stop and delimiter elements + // 2. filter out duplicates from different sub-queries + // 3. update original scores to normalized and combined values + // 4. 
order scores based on normalized and combined values + FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get(); + SearchHits searchHits = fetchSearchResult.hits(); + + // create map of docId to index of search hits. This solves (2), duplicates are from + // delimiter and start/stop elements, they all have same valid doc_id. For this map + // we use doc_id as a key, and all those special elements are collapsed into a single + // key-value pair. + Map docIdToSearchHit = Arrays.stream(searchHits.getHits()) + .collect(Collectors.toMap(SearchHit::docId, Function.identity(), (a1, a2) -> a1)); + + QuerySearchResult querySearchResult = querySearchResults.get(0); + TopDocs topDocs = querySearchResult.topDocs().topDocs; + // iterate over the normalized/combined scores, that solves (1) and (3) + SearchHit[] updatedSearchHitArray = Arrays.stream(topDocs.scoreDocs).map(scoreDoc -> { + // get fetched hit content by doc_id + SearchHit searchHit = docIdToSearchHit.get(scoreDoc.doc); + // update score to normalized/combined value (3) + searchHit.score(scoreDoc.score); + return searchHit; + }).toArray(SearchHit[]::new); + SearchHits updatedSearchHits = new SearchHits( + updatedSearchHitArray, + querySearchResult.getTotalHits(), + querySearchResult.getMaxScore() + ); + fetchSearchResult.hits(updatedSearchHits); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java index 67e776d77..0293efae6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java @@ -18,7 +18,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; -import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; /** * Abstracts combination of scores in query search results. 
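For context on the combiner changes below: the arithmetic that the normalization and combination steps apply to these per-sub-query scores is small enough to sketch in isolation. Min-max normalization and arithmetic-mean combination are used here because they are the defaults referenced by the tests in this patch; the class and method names in this sketch are illustrative, not part of the plugin:

```java
import java.util.Arrays;

public class NormalizeAndCombineSketch {
    /** Min-max normalize one sub-query's scores into [0, 1]. */
    static float[] minMaxNormalize(float[] scores) {
        float min = Float.MAX_VALUE;
        float max = -Float.MAX_VALUE;
        for (float score : scores) {
            min = Math.min(min, score);
            max = Math.max(max, score);
        }
        float[] normalized = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            // guard against a sub-query where all scores are identical
            normalized[i] = max == min ? 1.0f : (scores[i] - min) / (max - min);
        }
        return normalized;
    }

    /** Combine one document's normalized per-sub-query scores by arithmetic mean. */
    static float combine(float[] normalizedScores) {
        float sum = 0.0f;
        for (float score : normalizedScores) {
            sum += score;
        }
        return sum / normalizedScores.length;
    }

    public static void main(String[] args) {
        float[] bm25Scores = { 25.4f, 21.6f };   // raw lexical scores from one shard
        System.out.println(Arrays.toString(minMaxNormalize(bm25Scores))); // [1.0, 0.0]
        // a doc that tops the lexical sub-query but is absent from the vector one
        System.out.println(combine(new float[] { 1.0f, 0.0f }));          // 0.5
    }
}
```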
@@ -48,10 +48,10 @@ public void combineScores(final List queryTopDocs, final ScoreC } private void combineShardScores(final ScoreCombinationTechnique scoreCombinationTechnique, final CompoundTopDocs compoundQueryTopDocs) { - if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.totalHits.value == 0) { + if (Objects.isNull(compoundQueryTopDocs) || compoundQueryTopDocs.getTotalHits().value == 0) { return; } - List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); // - create map of normalized scores results returned from the single shard Map normalizedScoresPerDoc = getNormalizedScoresPerDocument(topDocsPerSubQuery); @@ -76,7 +76,7 @@ private List getSortedDocIds(final Map combinedNormaliz return sortedDocsIds; } - private ScoreDoc[] getCombinedScoreDocs( + private List getCombinedScoreDocs( final CompoundTopDocs compoundQueryTopDocs, final Map combinedNormalizedScoresByDocId, final List sortedScores, @@ -84,12 +84,12 @@ private ScoreDoc[] getCombinedScoreDocs( ) { ScoreDoc[] finalScoreDocs = new ScoreDoc[maxHits]; - int shardId = compoundQueryTopDocs.scoreDocs[0].shardIndex; + int shardId = compoundQueryTopDocs.getScoreDocs().get(0).shardIndex; for (int j = 0; j < maxHits && j < sortedScores.size(); j++) { int docId = sortedScores.get(j); finalScoreDocs[j] = new ScoreDoc(docId, combinedNormalizedScoresByDocId.get(docId), shardId); } - return finalScoreDocs; + return Arrays.stream(finalScoreDocs).collect(Collectors.toList()); } public Map getNormalizedScoresPerDocument(final List topDocsPerSubQuery) { @@ -100,7 +100,6 @@ public Map getNormalizedScoresPerDocument(final List normalizedScoresPerDoc.computeIfAbsent(scoreDoc.doc, key -> { float[] scores = new float[topDocsPerSubQuery.size()]; // we initialize with -1.0, as after normalization it's possible that score is 0.0 - Arrays.fill(scores, -1.0f); return scores; }); normalizedScoresPerDoc.get(scoreDoc.doc)[j] = scoreDoc.score; @@ -127,8 +126,10 @@ private void updateQueryTopDocsWithCombinedScores( // - count max number of hits among sub-queries int maxHits = getMaxHits(topDocsPerSubQuery); // - update query search results with normalized scores - compoundQueryTopDocs.scoreDocs = getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits); - compoundQueryTopDocs.totalHits = getTotalHits(topDocsPerSubQuery, maxHits); + compoundQueryTopDocs.setScoreDocs( + getCombinedScoreDocs(compoundQueryTopDocs, combinedNormalizedScoresByDocId, sortedScores, maxHits) + ); + compoundQueryTopDocs.setTotalHits(getTotalHits(topDocsPerSubQuery, maxHits)); } protected int getMaxHits(final List topDocsPerSubQuery) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java index 0e55e7231..a8c0a5953 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechnique.java @@ -13,7 +13,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; -import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; /** * Abstracts normalization of scores based on L2 method @@ -22,7 +22,7 @@ public class L2ScoreNormalizationTechnique implements 
ScoreNormalizationTechnique { @ToString.Include public static final String TECHNIQUE_NAME = "l2"; - private static final float MIN_SCORE = 0.001f; + private static final float MIN_SCORE = 0.0f; /** * L2 normalization method. @@ -41,7 +41,7 @@ public void normalize(final List queryTopDocs) { if (Objects.isNull(compoundQueryTopDocs)) { continue; } - List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); for (int j = 0; j < topDocsPerSubQuery.size(); j++) { TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { @@ -57,17 +57,17 @@ private List getL2Norm(final List queryTopDocs) { // rest of sub-queries with zero total hits int numOfSubqueries = queryTopDocs.stream() .filter(Objects::nonNull) - .filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0) + .filter(topDocs -> topDocs.getTopDocs().size() > 0) .findAny() .get() - .getCompoundTopDocs() + .getTopDocs() .size(); float[] l2Norms = new float[numOfSubqueries]; for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) { if (Objects.isNull(compoundQueryTopDocs)) { continue; } - List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); int bound = topDocsPerSubQuery.size(); for (int index = 0; index < bound; index++) { for (ScoreDoc scoreDocs : topDocsPerSubQuery.get(index).scoreDocs) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index e32dbb033..6452e6dcd 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -13,7 +13,7 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; -import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; import com.google.common.primitives.Floats; @@ -38,10 +38,10 @@ public class MinMaxScoreNormalizationTechnique implements ScoreNormalizationTech public void normalize(final List queryTopDocs) { int numOfSubqueries = queryTopDocs.stream() .filter(Objects::nonNull) - .filter(topDocs -> topDocs.getCompoundTopDocs().size() > 0) + .filter(topDocs -> topDocs.getTopDocs().size() > 0) .findAny() .get() - .getCompoundTopDocs() + .getTopDocs() .size(); // get min scores for each sub query float[] minScoresPerSubquery = getMinScores(queryTopDocs, numOfSubqueries); @@ -54,7 +54,7 @@ public void normalize(final List queryTopDocs) { if (Objects.isNull(compoundQueryTopDocs)) { continue; } - List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); for (int j = 0; j < topDocsPerSubQuery.size(); j++) { TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j); for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { @@ -71,7 +71,7 @@ private float[] getMaxScores(final List queryTopDocs, final int if (Objects.isNull(compoundQueryTopDocs)) { continue; } - List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); for (int j = 0; j < topDocsPerSubQuery.size(); j++) { maxScores[j] = Math.max( maxScores[j], @@ -92,7 +92,7 @@ private float[] 
getMinScores(final List queryTopDocs, final int if (Objects.isNull(compoundQueryTopDocs)) { continue; } - List topDocsPerSubQuery = compoundQueryTopDocs.getCompoundTopDocs(); + List topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs(); for (int j = 0; j < topDocsPerSubQuery.size(); j++) { minScores[j] = Math.min( minScores[j], diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java index fdaeb85d8..8dd124804 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationTechnique.java @@ -7,7 +7,7 @@ import java.util.List; -import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; /** * Abstracts normalization of scores in query search results. diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java index 5b8b7b1ca..e2e3385eb 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizer.java @@ -8,7 +8,7 @@ import java.util.List; import java.util.Objects; -import org.opensearch.neuralsearch.search.CompoundTopDocs; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; public class ScoreNormalizer { @@ -24,6 +24,6 @@ public void normalizeScores(final List queryTopDocs, final Scor } private boolean canQueryResultsBeNormalized(final List queryTopDocs) { - return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getCompoundTopDocs().size() > 0); + return queryTopDocs.stream().filter(Objects::nonNull).anyMatch(topDocs -> topDocs.getTopDocs().size() > 0); } } diff --git a/src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java b/src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java deleted file mode 100644 index fbc820d8b..000000000 --- a/src/main/java/org/opensearch/neuralsearch/search/CompoundTopDocs.java +++ /dev/null @@ -1,55 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.search; - -import java.util.Arrays; -import java.util.List; - -import lombok.Getter; -import lombok.ToString; - -import org.apache.lucene.search.ScoreDoc; -import org.apache.lucene.search.TopDocs; -import org.apache.lucene.search.TotalHits; - -/** - * Class stores collection of TodDocs for each sub query from hybrid query - */ -@ToString(includeFieldNames = true) -public class CompoundTopDocs extends TopDocs { - - @Getter - private List compoundTopDocs; - - public CompoundTopDocs(TotalHits totalHits, ScoreDoc[] scoreDocs) { - super(totalHits, scoreDocs); - } - - public CompoundTopDocs(TotalHits totalHits, List docs) { - // we pass clone of score docs from the sub-query that has most hits - super(totalHits, cloneLargestScoreDocs(docs)); - this.compoundTopDocs = docs; - } - - private static ScoreDoc[] cloneLargestScoreDocs(List docs) { - if (docs == null) { - return null; - } - ScoreDoc[] maxScoreDocs = new ScoreDoc[0]; - int maxLength = -1; - for (TopDocs topDoc : docs) { - if (topDoc == null || topDoc.scoreDocs == null) { - continue; - } - if (topDoc.scoreDocs.length 
> maxLength) {
-                maxLength = topDoc.scoreDocs.length;
-                maxScoreDocs = topDoc.scoreDocs;
-            }
-        }
-        // do deep copy
-        return Arrays.stream(maxScoreDocs).map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
-    }
-}
diff --git a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
index 81b6b7ebd..abf1b8813 100644
--- a/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
+++ b/src/main/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcher.java
@@ -5,11 +5,16 @@
 package org.opensearch.neuralsearch.search.query;
 
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;
 import static org.opensearch.search.query.TopDocsCollectorContext.createTopDocsCollectorContext;
 
 import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.LinkedList;
 import java.util.List;
+import java.util.Objects;
 
 import lombok.extern.log4j.Log4j2;
 
@@ -21,7 +26,6 @@
 import org.apache.lucene.search.TotalHits;
 import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
 import org.opensearch.neuralsearch.query.HybridQuery;
-import org.opensearch.neuralsearch.search.CompoundTopDocs;
 import org.opensearch.neuralsearch.search.HitsThresholdChecker;
 import org.opensearch.neuralsearch.search.HybridTopScoreDocCollector;
 import org.opensearch.search.DocValueFormat;
@@ -110,29 +114,84 @@ private void setTopDocsInQueryResult(
     ) {
         final List<TopDocs> topDocs = collector.topDocs();
         final float maxScore = getMaxScore(topDocs);
-        final TopDocs newTopDocs = new CompoundTopDocs(getTotalHits(searchContext, topDocs), topDocs);
+        final boolean isSingleShard = searchContext.numberOfShards() == 1;
+        final TopDocs newTopDocs = getNewTopDocs(getTotalHits(searchContext, topDocs, isSingleShard), topDocs);
         final TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(newTopDocs, maxScore);
         queryResult.topDocs(topDocsAndMaxScore, getSortValueFormats(searchContext.sort()));
     }
 
-    private TotalHits getTotalHits(final SearchContext searchContext, final List<TopDocs> topDocs) {
+    private TopDocs getNewTopDocs(final TotalHits totalHits, final List<TopDocs> topDocs) {
+        ScoreDoc[] scoreDocs = new ScoreDoc[0];
+        if (Objects.nonNull(topDocs)) {
+            // for a single shard case we need to do score processing at coordinator level.
+            // this is a workaround for current core behaviour: for a single shard the fetch phase is executed
+            // right after the query phase, so processors are called after the actual fetch is done
+            // find any valid doc Id, or set it to -1 if there is not a single match
+            int delimiterDocId = topDocs.stream()
+                .filter(Objects::nonNull)
+                .filter(topDoc -> Objects.nonNull(topDoc.scoreDocs))
+                .map(topDoc -> topDoc.scoreDocs)
+                .filter(scoreDoc -> scoreDoc.length > 0)
+                .map(scoreDoc -> scoreDoc[0].doc)
+                .findFirst()
+                .orElse(-1);
+            if (delimiterDocId == -1) {
+                return new TopDocs(totalHits, scoreDocs);
+            }
+            // format scores using following template:
+            // doc_id | magic_number_1
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_2
+            // ...
+            // doc_id | magic_number_1
+            List<ScoreDoc> result = new ArrayList<>();
+            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
+            for (TopDocs topDoc : topDocs) {
+                if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) {
+                    result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
+                    continue;
+                }
+                result.add(createDelimiterElementForHybridSearchResults(delimiterDocId));
+                result.addAll(Arrays.asList(topDoc.scoreDocs));
+            }
+            result.add(createStartStopElementForHybridSearchResults(delimiterDocId));
+            scoreDocs = result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
+        }
+        return new TopDocs(totalHits, scoreDocs);
+    }
+
+    private TotalHits getTotalHits(final SearchContext searchContext, final List<TopDocs> topDocs, final boolean isSingleShard) {
         int trackTotalHitsUpTo = searchContext.trackTotalHitsUpTo();
         final TotalHits.Relation relation = trackTotalHitsUpTo == SearchContext.TRACK_TOTAL_HITS_DISABLED
             ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO
             : TotalHits.Relation.EQUAL_TO;
-        if (topDocs == null || topDocs.size() == 0) {
+        if (topDocs == null || topDocs.isEmpty()) {
             return new TotalHits(0, relation);
         }
         long maxTotalHits = topDocs.get(0).totalHits.value;
+        int totalSize = 0;
         for (TopDocs topDoc : topDocs) {
             maxTotalHits = Math.max(maxTotalHits, topDoc.totalHits.value);
+            if (isSingleShard) {
+                totalSize += topDoc.totalHits.value + 1;
+            }
         }
+        // add one delimiter element per sub-query, plus 2 for the start and stop elements
+        totalSize += 2;
+        if (isSingleShard) {
+            // for single shard we need to update total size as this is how many docs are fetched in Fetch phase
+            searchContext.size(totalSize);
+        }
+
         return new TotalHits(maxTotalHits, relation);
     }
 
     private float getMaxScore(final List<TopDocs> topDocs) {
-        if (topDocs.size() == 0) {
-            return Float.NaN;
+        if (topDocs.isEmpty()) {
+            return 0.0f;
         } else {
             return topDocs.stream()
                 .map(docs -> docs.scoreDocs.length == 0 ? 
new ScoreDoc(-1, 0.0f) : docs.scoreDocs[0])
diff --git a/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java
new file mode 100644
index 000000000..b345a105a
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtil.java
@@ -0,0 +1,56 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.search.util;
+
+import java.util.Objects;
+
+import org.apache.lucene.search.ScoreDoc;
+
+/**
+ * Utility class for handling the format of hybrid search query results
+ */
+public class HybridSearchResultFormatUtil {
+    // both magic numbers are randomly generated; there should be no collision because the whole part of the score is huge
+    // and, by OpenSearch convention, document scores are positive numbers
+    public static final Float MAGIC_NUMBER_START_STOP = -9549511920.4881596047f;
+    public static final Float MAGIC_NUMBER_DELIMITER = -4422440593.9791198149f;
+
+    /**
+     * Create a ScoreDoc object that is a start/stop element in hybrid search query results
+     * @param docId id of one of the docs from the actual result object, or -1 if there are no matches
+     * @return
+     */
+    public static ScoreDoc createStartStopElementForHybridSearchResults(final int docId) {
+        return new ScoreDoc(docId, MAGIC_NUMBER_START_STOP);
+    }
+
+    /**
+     * Create a ScoreDoc object that is a delimiter element between sub-query results in hybrid search query results
+     * @param docId id of one of the docs from the actual result object, or -1 if there are no matches
+     * @return
+     */
+    public static ScoreDoc createDelimiterElementForHybridSearchResults(final int docId) {
+        return new ScoreDoc(docId, MAGIC_NUMBER_DELIMITER);
+    }
+
+    /**
+     * Checking if the passed scoreDoc object is a start/stop element in the list of hybrid query result scores
+     * @param scoreDoc
+     * @return true if it is a start/stop element
+     */
+    public static boolean isHybridQueryStartStopElement(final ScoreDoc scoreDoc) {
+        return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_START_STOP) == 0;
+    }
+
+    /**
+     * Checking if the passed scoreDoc object is a delimiter element in the list of hybrid query result scores
+     * @param scoreDoc
+     * @return true if it is a delimiter element
+     */
+    public static boolean isHybridQueryDelimiterElement(final ScoreDoc scoreDoc) {
+        return Objects.nonNull(scoreDoc) && scoreDoc.doc >= 0 && Float.compare(scoreDoc.score, MAGIC_NUMBER_DELIMITER) == 0;
+    }
+}
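Putting the new utility to work, this is roughly how a shard formats two sub-query results into the single flat list that the coordinator later parses back. The two helper methods exist in the class above; the doc ids and scores are made up:

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;

import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults;
import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults;

public class HybridFormatProducerSketch {
    public static void main(String[] args) {
        // two sub-query results from one shard; ids and scores are made up
        TopDocs lexical = new TopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO),
            new ScoreDoc[] { new ScoreDoc(0, 12.5f), new ScoreDoc(3, 7.1f) });
        TopDocs vector = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO),
            new ScoreDoc[] { new ScoreDoc(3, 0.87f) });

        int anyValidDocId = 0; // any doc id that exists in the result
        List<ScoreDoc> flat = new ArrayList<>();
        flat.add(createStartStopElementForHybridSearchResults(anyValidDocId));
        for (TopDocs subQuery : List.of(lexical, vector)) {
            // each sub-query series is opened by a delimiter element
            flat.add(createDelimiterElementForHybridSearchResults(anyValidDocId));
            flat.addAll(Arrays.asList(subQuery.scoreDocs));
        }
        flat.add(createStartStopElementForHybridSearchResults(anyValidDocId));
        // flat now holds: start, delim, 12.5, 7.1, delim, 0.87, stop
        flat.forEach(doc -> System.out.println(doc.doc + " | " + doc.score));
    }
}
```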
diff --git a/src/test/java/org/opensearch/neuralsearch/TestUtils.java b/src/test/java/org/opensearch/neuralsearch/TestUtils.java
index ff221bf20..3b131b886 100644
--- a/src/test/java/org/opensearch/neuralsearch/TestUtils.java
+++ b/src/test/java/org/opensearch/neuralsearch/TestUtils.java
@@ -6,6 +6,7 @@
 package org.opensearch.neuralsearch;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertTrue;
 import static org.opensearch.test.OpenSearchTestCase.randomFloat;
@@ -14,19 +15,25 @@
 import java.util.Arrays;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Optional;
 import java.util.Set;
 import java.util.stream.IntStream;
 
 import org.apache.commons.lang3.Range;
+import org.apache.lucene.search.TotalHits;
 import org.opensearch.common.xcontent.XContentHelper;
 import org.opensearch.core.common.bytes.BytesReference;
 import org.opensearch.core.xcontent.XContentBuilder;
+import org.opensearch.search.SearchHit;
+import org.opensearch.search.SearchHits;
+import org.opensearch.search.fetch.FetchSearchResult;
 import org.opensearch.search.query.QuerySearchResult;
 
 public class TestUtils {
 
-    private final static String RELATION_EQUAL_TO = "eq";
+    private static final String RELATION_EQUAL_TO = "eq";
+    public static final float DELTA_FOR_SCORE_ASSERTION = 0.001f;
 
     /**
      * Convert an xContentBuilder to a map
@@ -66,7 +73,7 @@ public static float[] createRandomVector(int dimension) {
     }
 
     /**
-     * Assert results of hyrdid query after score normalization and combination
+     * Assert results of hybrid query after score normalization and combination
      * @param querySearchResults collection of query search results after they processed by normalization processor
      */
     public static void assertQueryResultScores(List<QuerySearchResult> querySearchResults) {
@@ -75,12 +82,12 @@
             .map(searchResult -> searchResult.topDocs().maxScore)
             .max(Float::compare)
             .orElse(Float.MAX_VALUE);
-        assertEquals(1.0f, maxScore, 0.0f);
+        assertEquals(1.0f, maxScore, DELTA_FOR_SCORE_ASSERTION);
         float totalMaxScore = querySearchResults.stream()
             .map(searchResult -> searchResult.getMaxScore())
             .max(Float::compare)
             .orElse(Float.MAX_VALUE);
-        assertEquals(1.0f, totalMaxScore, 0.0f);
+        assertEquals(1.0f, totalMaxScore, DELTA_FOR_SCORE_ASSERTION);
         float maxScoreScoreFromScoreDocs = querySearchResults.stream()
             .map(
                 searchResult -> Arrays.stream(searchResult.topDocs().topDocs.scoreDocs)
@@ -90,7 +97,7 @@
             )
             .max(Float::compare)
             .orElse(Float.MAX_VALUE);
-        assertEquals(1.0f, maxScoreScoreFromScoreDocs, 0.0f);
+        assertEquals(1.0f, maxScoreScoreFromScoreDocs, DELTA_FOR_SCORE_ASSERTION);
         float minScoreScoreFromScoreDocs = querySearchResults.stream()
             .map(
                 searchResult -> Arrays.stream(searchResult.topDocs().topDocs.scoreDocs)
@@ -100,7 +107,50 @@
             )
             .min(Float::compare)
             .orElse(Float.MAX_VALUE);
-        assertEquals(0.001f, minScoreScoreFromScoreDocs, 0.0f);
+        assertEquals(0.001f, minScoreScoreFromScoreDocs, DELTA_FOR_SCORE_ASSERTION);
+    }
+
+    /**
+     * Assert results of hybrid query after score normalization and combination
+     * @param querySearchResults collection of query search results after they processed by normalization processor
+     */
+    public static void assertQueryResultScoresWithNoMatches(List<QuerySearchResult> querySearchResults) {
+        assertNotNull(querySearchResults);
+        float maxScore = querySearchResults.stream()
+            .map(searchResult -> searchResult.topDocs().maxScore)
+            .max(Float::compare)
+            .orElse(Float.MAX_VALUE);
+        assertEquals(0.0f, maxScore, DELTA_FOR_SCORE_ASSERTION);
+        float totalMaxScore = querySearchResults.stream().map(QuerySearchResult::getMaxScore).max(Float::compare).orElse(Float.MAX_VALUE);
+        assertEquals(0.0f, totalMaxScore, DELTA_FOR_SCORE_ASSERTION);
+        float maxScoreScoreFromScoreDocs = querySearchResults.stream()
+            .map(
+                searchResult -> Arrays.stream(searchResult.topDocs().topDocs.scoreDocs)
+                    .map(scoreDoc -> scoreDoc.score)
+                    .max(Float::compare)
+                    .orElse(0.0f)
+            )
+            .max(Float::compare)
+            .orElse(0.0f);
+        assertEquals(0.0f, maxScoreScoreFromScoreDocs, DELTA_FOR_SCORE_ASSERTION);
+        float minScoreScoreFromScoreDocs = querySearchResults.stream()
+            .map(
+                searchResult -> Arrays.stream(searchResult.topDocs().topDocs.scoreDocs)
+                    .map(scoreDoc -> 
scoreDoc.score) + .min(Float::compare) + .orElse(0.0f) + ) + .min(Float::compare) + .orElse(0.0f); + assertEquals(0.001f, minScoreScoreFromScoreDocs, DELTA_FOR_SCORE_ASSERTION); + + assertFalse( + querySearchResults.stream() + .map(searchResult -> searchResult.topDocs().topDocs.totalHits) + .filter(totalHits -> Objects.isNull(totalHits.relation)) + .filter(totalHits -> TotalHits.Relation.EQUAL_TO != totalHits.relation) + .anyMatch(totalHits -> 0 != totalHits.value) + ); } /** @@ -176,6 +226,34 @@ public static void assertHybridSearchResults( assertEquals(Set.copyOf(ids).size(), ids.size()); } + /** + * Assert results of a fetch phase for hybrid query + * @param fetchSearchResult results produced by fetch phase + * @param expectedNumberOfHits expected number of hits that should be in the fetch result object + */ + public static void assertFetchResultScores(FetchSearchResult fetchSearchResult, int expectedNumberOfHits) { + assertNotNull(fetchSearchResult); + assertNotNull(fetchSearchResult.hits()); + SearchHits searchHits = fetchSearchResult.hits(); + float maxScore = searchHits.getMaxScore(); + assertEquals(1.0f, maxScore, DELTA_FOR_SCORE_ASSERTION); + TotalHits totalHits = searchHits.getTotalHits(); + assertNotNull(totalHits); + assertEquals(expectedNumberOfHits, totalHits.value); + assertNotNull(searchHits.getHits()); + assertEquals(expectedNumberOfHits, searchHits.getHits().length); + float maxScoreScoreFromScoreDocs = Arrays.stream(searchHits.getHits()) + .map(SearchHit::getScore) + .max(Float::compare) + .orElse(Float.MAX_VALUE); + assertEquals(1.0f, maxScoreScoreFromScoreDocs, DELTA_FOR_SCORE_ASSERTION); + float minScoreScoreFromScoreDocs = Arrays.stream(searchHits.getHits()) + .map(SearchHit::getScore) + .min(Float::compare) + .orElse(Float.MAX_VALUE); + assertEquals(0.001f, minScoreScoreFromScoreDocs, DELTA_FOR_SCORE_ASSERTION); + } + private static List> getNestedHits(Map searchResponseAsMap) { Map hitsMap = (Map) searchResponseAsMap.get("hits"); return (List>) hitsMap.get("hits"); diff --git a/src/test/java/org/opensearch/neuralsearch/search/CompoundTopDocsTests.java b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java similarity index 74% rename from src/test/java/org/opensearch/neuralsearch/search/CompoundTopDocsTests.java rename to src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java index 0c79d7f73..a5bdda1e3 100644 --- a/src/test/java/org/opensearch/neuralsearch/search/CompoundTopDocsTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/CompoundTopDocsTests.java @@ -3,7 +3,7 @@ * SPDX-License-Identifier: Apache-2.0 */ -package org.opensearch.neuralsearch.search; +package org.opensearch.neuralsearch.processor; import java.util.Arrays; import java.util.List; @@ -31,19 +31,25 @@ public void testBasics_whenCreateWithTopDocsArray_thenSuccessful() { List topDocs = List.of(topDocs1, topDocs2); CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(3, TotalHits.Relation.EQUAL_TO), topDocs); assertNotNull(compoundTopDocs); - assertEquals(topDocs, compoundTopDocs.getCompoundTopDocs()); + assertEquals(topDocs, compoundTopDocs.getTopDocs()); } public void testBasics_whenCreateWithoutTopDocs_thenTopDocsIsNull() { CompoundTopDocs hybridQueryScoreTopDocs = new CompoundTopDocs( new TotalHits(3, TotalHits.Relation.EQUAL_TO), - new ScoreDoc[] { - new ScoreDoc(2, RandomUtils.nextFloat()), - new ScoreDoc(4, RandomUtils.nextFloat()), - new ScoreDoc(5, RandomUtils.nextFloat()) } + List.of( + new TopDocs( + new 
TotalHits(3, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + new ScoreDoc(2, RandomUtils.nextFloat()), + new ScoreDoc(4, RandomUtils.nextFloat()), + new ScoreDoc(5, RandomUtils.nextFloat()) } + ) + ) ); assertNotNull(hybridQueryScoreTopDocs); - assertNull(hybridQueryScoreTopDocs.getCompoundTopDocs()); + assertNotNull(hybridQueryScoreTopDocs.getScoreDocs()); + assertNotNull(hybridQueryScoreTopDocs.getTopDocs()); } public void testBasics_whenMultipleTopDocsOfDifferentLength_thenReturnTopDocsWithMostHits() { @@ -55,21 +61,21 @@ public void testBasics_whenMultipleTopDocsOfDifferentLength_thenReturnTopDocsWit List topDocs = List.of(topDocs1, topDocs2); CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), topDocs); assertNotNull(compoundTopDocs); - assertNotNull(compoundTopDocs.scoreDocs); - assertEquals(2, compoundTopDocs.scoreDocs.length); + assertNotNull(compoundTopDocs.getScoreDocs()); + assertEquals(2, compoundTopDocs.getScoreDocs().size()); } public void testBasics_whenMultipleTopDocsIsNull_thenScoreDocsIsNull() { CompoundTopDocs compoundTopDocs = new CompoundTopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), (List) null); assertNotNull(compoundTopDocs); - assertNull(compoundTopDocs.scoreDocs); + assertNull(compoundTopDocs.getScoreDocs()); CompoundTopDocs compoundTopDocsWithNullArray = new CompoundTopDocs( new TotalHits(0, TotalHits.Relation.EQUAL_TO), Arrays.asList(null, null) ); assertNotNull(compoundTopDocsWithNullArray); - assertNotNull(compoundTopDocsWithNullArray.scoreDocs); - assertEquals(0, compoundTopDocsWithNullArray.scoreDocs.length); + assertNotNull(compoundTopDocsWithNullArray.getScoreDocs()); + assertEquals(0, compoundTopDocsWithNullArray.getScoreDocs().size()); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index a9b1fc9bf..3cd71e5a1 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -42,8 +42,11 @@ public class NormalizationProcessorIT extends BaseNeuralSearchIT { private static final String TEST_DOC_TEXT3 = "We would like to welcome everyone"; private static final String TEST_DOC_TEXT4 = "Hello, I'm glad to you see you pal"; private static final String TEST_DOC_TEXT5 = "Say hello and enter my friend"; + private static final String TEST_DOC_TEXT6 = "This tale grew in the telling"; + private static final String TEST_DOC_TEXT7 = "They do not and did not understand or like machines"; private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1"; + private static final String TEST_TEXT_FIELD_NAME_2 = "test-text-field-2"; private static final int TEST_DIMENSION = 768; private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2; private static final String SEARCH_PIPELINE = "phase-results-pipeline"; @@ -172,7 +175,7 @@ public void testResultProcessor_whenMultipleShardsAndQueryMatches_thenSuccessful assertNotNull(total.get("relation")); assertEquals(RELATION_EQUAL_TO, total.get("relation")); assertTrue(getMaxScore(searchResponseAsMap).isPresent()); - assertTrue(Range.between(.75f, 1.0f).contains(getMaxScore(searchResponseAsMap).get())); + assertTrue(Range.between(.5f, 1.0f).contains(getMaxScore(searchResponseAsMap).get())); List> hitsNestedList = 
getNestedHits(searchResponseAsMap); List ids = new ArrayList<>(); List scores = new ArrayList<>(); @@ -187,7 +190,7 @@ public void testResultProcessor_whenMultipleShardsAndQueryMatches_thenSuccessful // based on random vectors and return results for every doc. In some cases that may affect 1.0 score from term query and make it // lower. float highestScore = scores.stream().max(Double::compare).get().floatValue(); - assertTrue(Range.between(.75f, 1.0f).contains(highestScore)); + assertTrue(Range.between(.5f, 1.0f).contains(highestScore)); float lowestScore = scores.stream().min(Double::compare).get().floatValue(); assertTrue(Range.between(.0f, .5f).contains(lowestScore)); @@ -231,7 +234,7 @@ public void testResultProcessor_whenMultipleShardsAndPartialMatches_thenSuccessf 5, Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertQueryResults(searchResponseAsMap, 4, true); + assertQueryResults(searchResponseAsMap, 4, true, Range.between(0.33f, 1.0f)); } private void initializeIndexIfNotExist(String indexName) throws IOException { @@ -293,8 +296,8 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { "1", Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), Collections.singletonList(Floats.asList(testVector1).toArray()), - Collections.singletonList(TEST_TEXT_FIELD_NAME_1), - Collections.singletonList(TEST_DOC_TEXT1) + List.of(TEST_TEXT_FIELD_NAME_1, TEST_TEXT_FIELD_NAME_2), + List.of(TEST_DOC_TEXT1, TEST_DOC_TEXT6) ); addKnnDoc( TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, @@ -307,8 +310,8 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { "3", Collections.singletonList(TEST_KNN_VECTOR_FIELD_NAME_1), Collections.singletonList(Floats.asList(testVector3).toArray()), - Collections.singletonList(TEST_TEXT_FIELD_NAME_1), - Collections.singletonList(TEST_DOC_TEXT2) + List.of(TEST_TEXT_FIELD_NAME_1, TEST_TEXT_FIELD_NAME_2), + List.of(TEST_DOC_TEXT2, TEST_DOC_TEXT7) ); addKnnDoc( TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, @@ -354,6 +357,15 @@ private Optional getMaxScore(Map searchResponseAsMap) { } private void assertQueryResults(Map searchResponseAsMap, int totalExpectedDocQty, boolean assertMinScore) { + assertQueryResults(searchResponseAsMap, totalExpectedDocQty, assertMinScore, Range.between(0.5f, 1.0f)); + } + + private void assertQueryResults( + Map searchResponseAsMap, + int totalExpectedDocQty, + boolean assertMinScore, + Range maxScoreRange + ) { assertNotNull(searchResponseAsMap); Map total = getTotalHits(searchResponseAsMap); assertNotNull(total.get("value")); @@ -362,7 +374,7 @@ private void assertQueryResults(Map searchResponseAsMap, int tot assertEquals(RELATION_EQUAL_TO, total.get("relation")); assertTrue(getMaxScore(searchResponseAsMap).isPresent()); if (totalExpectedDocQty > 0) { - assertEquals(1.0, getMaxScore(searchResponseAsMap).get(), 0.001f); + assertTrue(maxScoreRange.contains(getMaxScore(searchResponseAsMap).get())); } else { assertEquals(0.0, getMaxScore(searchResponseAsMap).get(), 0.001f); } @@ -378,7 +390,7 @@ private void assertQueryResults(Map searchResponseAsMap, int tot assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); // verify the scores are normalized if (totalExpectedDocQty > 0) { - assertEquals(1.0, (double) scores.stream().max(Double::compare).get(), 0.001); + assertTrue(maxScoreRange.contains(scores.stream().max(Double::compare).get().floatValue())); if (assertMinScore) { assertEquals(0.001, (double) scores.stream().min(Double::compare).get(), 0.001); } 
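A note on the assertion ranges widened above from [0.75, 1.0] to [0.5, 1.0]: with two sub-queries and arithmetic-mean combination, a document that is the top hit of only one sub-query lands exactly on the new lower bound. That reading of the change is an inference, not stated in the patch; the arithmetic itself is just:

```java
public class CombinedScoreLowerBound {
    public static void main(String[] args) {
        // top hit of sub-query 1 normalizes to 1.0; missing from sub-query 2 counts as 0.0
        float combined = (1.0f + 0.0f) / 2;
        System.out.println(combined); // 0.5 -> the lower bound asserted by the IT
    }
}
```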
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index 397642007..41348ec49 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -10,6 +10,9 @@ import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; import java.util.List; import java.util.concurrent.CountDownLatch; @@ -41,7 +44,6 @@ import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; -import org.opensearch.neuralsearch.search.CompoundTopDocs; import org.opensearch.search.DocValueFormat; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.aggregations.InternalAggregation; @@ -151,14 +153,18 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio OriginalIndices.NONE ); QuerySearchResult querySearchResult = new QuerySearchResult(); - CompoundTopDocs topDocs = new CompoundTopDocs( + TopDocs topDocs = new TopDocs( new TotalHits(4, TotalHits.Relation.EQUAL_TO), - List.of( - new TopDocs( - new TotalHits(4, TotalHits.Relation.EQUAL_TO), - new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f), new ScoreDoc(4, 0.25f), new ScoreDoc(10, 0.2f) } - ) - ) + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(4), + createDelimiterElementForHybridSearchResults(4), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(4) } + ); querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); querySearchResult.setSearchShardTarget(searchShardTarget); @@ -179,6 +185,74 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio TestUtils.assertQueryResultScores(querySearchResults); } + public void testScoreCorrectness_whenCompoundDocs_thenDoNormalizationCombination() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + NormalizationProcessor normalizationProcessor = new NormalizationProcessor( + PROCESSOR_TAG, + DESCRIPTION, + new ScoreNormalizationFactory().createNormalization(NORMALIZATION_METHOD), + new ScoreCombinationFactory().createCombination(COMBINATION_METHOD), + normalizationProcessorWorkflow + ); + + SearchRequest searchRequest = new SearchRequest(INDEX_NAME); + searchRequest.setBatchedReduceSize(4); + AtomicReference onPartialMergeFailure = new AtomicReference<>(); + QueryPhaseResultConsumer queryPhaseResultConsumer = new QueryPhaseResultConsumer( + searchRequest, + executor, + new NoopCircuitBreaker(CircuitBreaker.REQUEST), + searchPhaseController, + SearchProgressListener.NOOP, + writableRegistry(), + 10, + e -> onPartialMergeFailure.accumulateAndGet(e, (prev, curr) -> { + curr.addSuppressed(prev); + return curr; + 
}) + ); + CountDownLatch partialReduceLatch = new CountDownLatch(1); + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget("node", new ShardId("index", "uuid", 0), null, OriginalIndices.NONE); + QuerySearchResult querySearchResult = new QuerySearchResult(); + TopDocs topDocs = new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(10), + createDelimiterElementForHybridSearchResults(10), + new ScoreDoc(2429, 0.028685084f), + new ScoreDoc(14, 0.025785536f), + new ScoreDoc(10, 0.024871103f), + createDelimiterElementForHybridSearchResults(10), + new ScoreDoc(2429, 25.438505f), + new ScoreDoc(10, 25.226639f), + new ScoreDoc(14, 24.935198f), + new ScoreDoc(2428, 21.614073f), + createStartStopElementForHybridSearchResults(10) } + + ); + querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 25.438505f), new DocValueFormat[0]); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + + queryPhaseResultConsumer.consumeResult(querySearchResult, partialReduceLatch::countDown); + + SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getNumShards()).thenReturn(1); + normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); + + List querySearchResults = queryPhaseResultConsumer.getAtomicArray() + .asList() + .stream() + .map(result -> result == null ? null : result.queryResult()) + .collect(Collectors.toList()); + + TestUtils.assertQueryResultScores(querySearchResults); + } + public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkflow() { NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) @@ -193,7 +267,7 @@ public void testEmptySearchResults_whenEmptySearchResults_thenDoNotExecuteWorkfl SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); normalizationProcessor.process(null, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); } public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResult_thenDoNotExecuteWorkflow() { @@ -225,7 +299,8 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul }) ); CountDownLatch partialReduceLatch = new CountDownLatch(5); - for (int shardId = 0; shardId < 4; shardId++) { + int numberOfShards = 4; + for (int shardId = 0; shardId < numberOfShards; shardId++) { SearchShardTarget searchShardTarget = new SearchShardTarget( "node", new ShardId("index", "uuid", shardId), @@ -245,8 +320,9 @@ public void testNotHybridSearchResult_whenResultsNotEmptyAndNotHybridSearchResul } SearchPhaseContext searchPhaseContext = mock(SearchPhaseContext.class); + when(searchPhaseContext.getNumShards()).thenReturn(numberOfShards); normalizationProcessor.process(queryPhaseResultConsumer, searchPhaseContext); - verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any()); + verify(normalizationProcessorWorkflow, never()).execute(any(), any(), any(), any()); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index 453725a0d..a74fb53f2 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -6,9 +6,13 @@ package org.opensearch.neuralsearch.processor; import static org.mockito.Mockito.spy; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; import java.util.ArrayList; import java.util.List; +import java.util.Map; +import java.util.Optional; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; @@ -21,15 +25,17 @@ import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; -import org.opensearch.neuralsearch.search.CompoundTopDocs; import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchHit; +import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.fetch.FetchSearchResult; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; public class NormalizationProcessorWorkflowTests extends OpenSearchTestCase { - public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombination() { + public void testSearchResultTypes_whenResultsOfHybridSearch_thenDoNormalizationCombination() { NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) ); @@ -43,16 +49,23 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio OriginalIndices.NONE ); QuerySearchResult querySearchResult = new QuerySearchResult(); - CompoundTopDocs topDocs = new CompoundTopDocs( - new TotalHits(4, TotalHits.Relation.EQUAL_TO), - List.of( + querySearchResult.topDocs( + new TopDocsAndMaxScore( new TopDocs( new TotalHits(4, TotalHits.Relation.EQUAL_TO), - new ScoreDoc[] { new ScoreDoc(0, 0.5f), new ScoreDoc(2, 0.3f), new ScoreDoc(4, 0.25f), new ScoreDoc(10, 0.2f) } - ) - ) + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(0) } + ), + 0.5f + ), + new DocValueFormat[0] ); - querySearchResult.topDocs(new TopDocsAndMaxScore(topDocs, 0.5f), new DocValueFormat[0]); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); querySearchResults.add(querySearchResult); @@ -60,10 +73,110 @@ public void testSearchResultTypes_whenCompoundDocs_thenDoNormalizationCombinatio normalizationProcessorWorkflow.execute( querySearchResults, + Optional.empty(), ScoreNormalizationFactory.DEFAULT_METHOD, ScoreCombinationFactory.DEFAULT_METHOD ); TestUtils.assertQueryResultScores(querySearchResults); } + + public void testSearchResultTypes_whenNoMatches_thenReturnZeroResults() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List querySearchResults = new ArrayList<>(); + for (int shardId = 0; shardId < 4; shardId++) { + 
SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(0, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(-1), + createDelimiterElementForHybridSearchResults(-1), + createStartStopElementForHybridSearchResults(-1) } + ), + 0.0f + ), + new DocValueFormat[0] + ); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + querySearchResults.add(querySearchResult); + } + + normalizationProcessorWorkflow.execute( + querySearchResults, + Optional.empty(), + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD + ); + + TestUtils.assertQueryResultScoresWithNoMatches(querySearchResults); + } + + public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNormalizationCombination() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List querySearchResults = new ArrayList<>(); + FetchSearchResult fetchSearchResult = new FetchSearchResult(); + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(0) } + ), + 0.5f + ), + new DocValueFormat[0] + ); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + querySearchResults.add(querySearchResult); + SearchHit[] searchHitArray = new SearchHit[] { + new SearchHit(0, "10", Map.of(), Map.of()), + new SearchHit(0, "10", Map.of(), Map.of()), + new SearchHit(0, "10", Map.of(), Map.of()), + new SearchHit(2, "1", Map.of(), Map.of()), + new SearchHit(4, "2", Map.of(), Map.of()), + new SearchHit(10, "3", Map.of(), Map.of()), + new SearchHit(0, "10", Map.of(), Map.of()), }; + SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); + fetchSearchResult.hits(searchHits); + + normalizationProcessorWorkflow.execute( + querySearchResults, + Optional.of(fetchSearchResult), + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD + ); + + TestUtils.assertQueryResultScores(querySearchResults); + TestUtils.assertFetchResultScores(fetchSearchResult, 4); + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index e56532b52..df15f40fd 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -112,7 +112,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - 
assertWeightedScores(searchResponseWithWeights1AsMap, 1.0, 1.0, 0.001); + assertWeightedScores(searchResponseWithWeights1AsMap, 0.375, 0.3125, 0.001); // delete existing pipeline and create a new one with another set of weights deleteSearchPipeline(SEARCH_PIPELINE); @@ -131,7 +131,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertWeightedScores(searchResponseWithWeights2AsMap, 1.0, 1.0, 0.001); + assertWeightedScores(searchResponseWithWeights2AsMap, 0.606, 0.242, 0.001); // check case when number of weights is less than number of sub-queries // delete existing pipeline and create a new one with another set of weights @@ -151,7 +151,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertWeightedScores(searchResponseWithWeights3AsMap, 1.0, 1.0, 0.001); + assertWeightedScores(searchResponseWithWeights3AsMap, 0.357, 0.285, 0.001); // check case when number of weights is more than number of sub-queries // delete existing pipeline and create a new one with another set of weights @@ -171,7 +171,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertWeightedScores(searchResponseWithWeights4AsMap, 1.0, 1.0, 0.001); + assertWeightedScores(searchResponseWithWeights4AsMap, 0.375, 0.3125, 0.001); } /** diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java index 1a7f895cd..78c4e1139 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationTechniqueTests.java @@ -5,6 +5,8 @@ package org.opensearch.neuralsearch.processor; +import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; + import java.util.List; import org.apache.lucene.search.ScoreDoc; @@ -12,7 +14,6 @@ import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; -import org.opensearch.neuralsearch.search.CompoundTopDocs; import org.opensearch.test.OpenSearchTestCase; public class ScoreCombinationTechniqueTests extends OpenSearchTestCase { @@ -63,24 +64,24 @@ public void testCombination_whenMultipleSubqueriesResultsAndDefaultMethod_thenSc assertNotNull(queryTopDocs); assertEquals(3, queryTopDocs.size()); - assertEquals(3, queryTopDocs.get(0).scoreDocs.length); - assertEquals(1.0, queryTopDocs.get(0).scoreDocs[0].score, 0.001f); - assertEquals(1, queryTopDocs.get(0).scoreDocs[0].doc); - assertEquals(1.0, queryTopDocs.get(0).scoreDocs[1].score, 0.001f); - assertEquals(3, queryTopDocs.get(0).scoreDocs[1].doc); - assertEquals(0.25, queryTopDocs.get(0).scoreDocs[2].score, 0.001f); - assertEquals(2, queryTopDocs.get(0).scoreDocs[2].doc); + assertEquals(3, queryTopDocs.get(0).getScoreDocs().size()); + assertEquals(.5, queryTopDocs.get(0).getScoreDocs().get(0).score, DELTA_FOR_SCORE_ASSERTION); + assertEquals(1, queryTopDocs.get(0).getScoreDocs().get(0).doc); + assertEquals(.5, queryTopDocs.get(0).getScoreDocs().get(1).score, DELTA_FOR_SCORE_ASSERTION); + assertEquals(3, queryTopDocs.get(0).getScoreDocs().get(1).doc); + assertEquals(0.125, queryTopDocs.get(0).getScoreDocs().get(2).score, 
DELTA_FOR_SCORE_ASSERTION); + assertEquals(2, queryTopDocs.get(0).getScoreDocs().get(2).doc); - assertEquals(4, queryTopDocs.get(1).scoreDocs.length); - assertEquals(0.9, queryTopDocs.get(1).scoreDocs[0].score, 0.001f); - assertEquals(2, queryTopDocs.get(1).scoreDocs[0].doc); - assertEquals(0.6, queryTopDocs.get(1).scoreDocs[1].score, 0.001f); - assertEquals(4, queryTopDocs.get(1).scoreDocs[1].doc); - assertEquals(0.5, queryTopDocs.get(1).scoreDocs[2].score, 0.001f); - assertEquals(7, queryTopDocs.get(1).scoreDocs[2].doc); - assertEquals(0.01, queryTopDocs.get(1).scoreDocs[3].score, 0.001f); - assertEquals(9, queryTopDocs.get(1).scoreDocs[3].doc); + assertEquals(4, queryTopDocs.get(1).getScoreDocs().size()); + assertEquals(0.45, queryTopDocs.get(1).getScoreDocs().get(0).score, DELTA_FOR_SCORE_ASSERTION); + assertEquals(2, queryTopDocs.get(1).getScoreDocs().get(0).doc); + assertEquals(0.3, queryTopDocs.get(1).getScoreDocs().get(1).score, DELTA_FOR_SCORE_ASSERTION); + assertEquals(4, queryTopDocs.get(1).getScoreDocs().get(1).doc); + assertEquals(0.25, queryTopDocs.get(1).getScoreDocs().get(2).score, DELTA_FOR_SCORE_ASSERTION); + assertEquals(7, queryTopDocs.get(1).getScoreDocs().get(2).doc); + assertEquals(0.005, queryTopDocs.get(1).getScoreDocs().get(3).score, DELTA_FOR_SCORE_ASSERTION); + assertEquals(9, queryTopDocs.get(1).getScoreDocs().get(3).doc); - assertEquals(0, queryTopDocs.get(2).scoreDocs.length); + assertEquals(0, queryTopDocs.get(2).getScoreDocs().size()); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 64b3fe07f..36af1a712 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -195,7 +195,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { 5, Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertHybridSearchResults(searchResponseAsMapArithmeticMean, 5, new float[] { 1.0f, 1.0f }); + assertHybridSearchResults(searchResponseAsMapArithmeticMean, 5, new float[] { 0.5f, 1.0f }); deleteSearchPipeline(SEARCH_PIPELINE); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java index 6188a7ef5..8e31d619c 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationTechniqueTests.java @@ -15,7 +15,6 @@ import org.apache.lucene.search.TotalHits; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; -import org.opensearch.neuralsearch.search.CompoundTopDocs; import org.opensearch.test.OpenSearchTestCase; public class ScoreNormalizationTechniqueTests extends OpenSearchTestCase { @@ -38,9 +37,9 @@ public void testNormalization_whenOneSubqueryAndOneShardAndDefaultMethod_thenSco assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); CompoundTopDocs resultDoc = queryTopDocs.get(0); - assertNotNull(resultDoc.getCompoundTopDocs()); - assertEquals(1, resultDoc.getCompoundTopDocs().size()); - TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0); + assertNotNull(resultDoc.getTopDocs()); + assertEquals(1, resultDoc.getTopDocs().size()); + TopDocs 
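The new combination expectations above all follow from one weighted-average rule. After min-max normalization the best hit of each sub-query scores 1.0; the arithmetic-mean combiner then averages over all sub-queries of the hybrid query, counting a sub-query that did not match the document as 0.0 and falling back to a weight of 1.0 for any sub-query without an explicit weight. A sketch consistent with the asserted numbers (variable names are illustrative, not the plugin's verbatim code):

    // combined = sum(w_i * s_i) / sum(w_i), with a missed sub-query contributing s_i = 0.0
    float combined = 0.0f;
    float sumOfWeights = 0.0f;
    for (int i = 0; i < scores.length; i++) {
        float weight = i < weights.size() ? weights.get(i) : 1.0f; // default weight for unweighted sub-queries
        combined += weight * scores[i];
        sumOfWeights += weight;
    }
    combined = sumOfWeights > 0.0f ? combined / sumOfWeights : 0.0f;

Worked through the weighted integration test: weights { 0.6, 0.5, 0.5 } sum to 1.6, so a document matched only by the first sub-query at a normalized 1.0 combines to 0.6 / 1.6 = 0.375, and one matched only by the second to 0.5 / 1.6 = 0.3125; with the single weight { 0.8 } and two defaulted weights the expectations become 1.0 / 2.8 ≈ 0.357 and 0.8 / 2.8 ≈ 0.285. The unweighted unit test above works the same way: (0.9 + 0.0) / 2 = 0.45, (1.0 + 0.0) / 2 = 0.5, and (0.25 + 0.0) / 2 = 0.125.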
topDoc = resultDoc.getTopDocs().get(0); assertEquals(1, topDoc.totalHits.value); assertEquals(TotalHits.Relation.EQUAL_TO, topDoc.totalHits.relation); assertNotNull(topDoc.scoreDocs); @@ -68,9 +67,9 @@ public void testNormalization_whenOneSubqueryMultipleHitsAndOneShardAndDefaultMe assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); CompoundTopDocs resultDoc = queryTopDocs.get(0); - assertNotNull(resultDoc.getCompoundTopDocs()); - assertEquals(1, resultDoc.getCompoundTopDocs().size()); - TopDocs topDoc = resultDoc.getCompoundTopDocs().get(0); + assertNotNull(resultDoc.getTopDocs()); + assertEquals(1, resultDoc.getTopDocs().size()); + TopDocs topDoc = resultDoc.getTopDocs().get(0); assertEquals(3, topDoc.totalHits.value); assertEquals(TotalHits.Relation.EQUAL_TO, topDoc.totalHits.relation); assertNotNull(topDoc.scoreDocs); @@ -105,10 +104,10 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe assertNotNull(queryTopDocs); assertEquals(1, queryTopDocs.size()); CompoundTopDocs resultDoc = queryTopDocs.get(0); - assertNotNull(resultDoc.getCompoundTopDocs()); - assertEquals(2, resultDoc.getCompoundTopDocs().size()); + assertNotNull(resultDoc.getTopDocs()); + assertEquals(2, resultDoc.getTopDocs().size()); // sub-query one - TopDocs topDocSubqueryOne = resultDoc.getCompoundTopDocs().get(0); + TopDocs topDocSubqueryOne = resultDoc.getTopDocs().get(0); assertEquals(3, topDocSubqueryOne.totalHits.value); assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryOne.totalHits.relation); assertNotNull(topDocSubqueryOne.scoreDocs); @@ -120,7 +119,7 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsAndOneShardAndDe assertEquals(0.0, lowScoreDoc.score, 0.001f); assertEquals(4, lowScoreDoc.doc); // sub-query two - TopDocs topDocSubqueryTwo = resultDoc.getCompoundTopDocs().get(1); + TopDocs topDocSubqueryTwo = resultDoc.getTopDocs().get(1); assertEquals(2, topDocSubqueryTwo.totalHits.value); assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryTwo.totalHits.relation); assertNotNull(topDocSubqueryTwo.scoreDocs); @@ -170,9 +169,9 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn assertEquals(3, queryTopDocs.size()); // shard one CompoundTopDocs resultDocShardOne = queryTopDocs.get(0); - assertEquals(2, resultDocShardOne.getCompoundTopDocs().size()); + assertEquals(2, resultDocShardOne.getTopDocs().size()); // sub-query one - TopDocs topDocSubqueryOne = resultDocShardOne.getCompoundTopDocs().get(0); + TopDocs topDocSubqueryOne = resultDocShardOne.getTopDocs().get(0); assertEquals(3, topDocSubqueryOne.totalHits.value); assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryOne.totalHits.relation); assertNotNull(topDocSubqueryOne.scoreDocs); @@ -184,7 +183,7 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn assertEquals(0.0, lowScoreDoc.score, 0.001f); assertEquals(4, lowScoreDoc.doc); // sub-query two - TopDocs topDocSubqueryTwo = resultDocShardOne.getCompoundTopDocs().get(1); + TopDocs topDocSubqueryTwo = resultDocShardOne.getTopDocs().get(1); assertEquals(2, topDocSubqueryTwo.totalHits.value); assertEquals(TotalHits.Relation.EQUAL_TO, topDocSubqueryTwo.totalHits.relation); assertNotNull(topDocSubqueryTwo.scoreDocs); @@ -196,15 +195,15 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn // shard two CompoundTopDocs resultDocShardTwo = queryTopDocs.get(1); - assertEquals(2, resultDocShardTwo.getCompoundTopDocs().size()); + assertEquals(2, 
resultDocShardTwo.getTopDocs().size()); // sub-query one - TopDocs topDocShardTwoSubqueryOne = resultDocShardTwo.getCompoundTopDocs().get(0); + TopDocs topDocShardTwoSubqueryOne = resultDocShardTwo.getTopDocs().get(0); assertEquals(0, topDocShardTwoSubqueryOne.totalHits.value); assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardTwoSubqueryOne.totalHits.relation); assertNotNull(topDocShardTwoSubqueryOne.scoreDocs); assertEquals(0, topDocShardTwoSubqueryOne.scoreDocs.length); // sub-query two - TopDocs topDocShardTwoSubqueryTwo = resultDocShardTwo.getCompoundTopDocs().get(1); + TopDocs topDocShardTwoSubqueryTwo = resultDocShardTwo.getTopDocs().get(1); assertEquals(4, topDocShardTwoSubqueryTwo.totalHits.value); assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardTwoSubqueryTwo.totalHits.relation); assertNotNull(topDocShardTwoSubqueryTwo.scoreDocs); @@ -216,14 +215,14 @@ public void testNormalization_whenMultipleSubqueriesMultipleHitsMultipleShardsAn // shard three CompoundTopDocs resultDocShardThree = queryTopDocs.get(2); - assertEquals(2, resultDocShardThree.getCompoundTopDocs().size()); + assertEquals(2, resultDocShardThree.getTopDocs().size()); // sub-query one - TopDocs topDocShardThreeSubqueryOne = resultDocShardThree.getCompoundTopDocs().get(0); + TopDocs topDocShardThreeSubqueryOne = resultDocShardThree.getTopDocs().get(0); assertEquals(0, topDocShardThreeSubqueryOne.totalHits.value); assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardThreeSubqueryOne.totalHits.relation); assertEquals(0, topDocShardThreeSubqueryOne.scoreDocs.length); // sub-query two - TopDocs topDocShardThreeSubqueryTwo = resultDocShardThree.getCompoundTopDocs().get(1); + TopDocs topDocShardThreeSubqueryTwo = resultDocShardThree.getTopDocs().get(1); assertEquals(0, topDocShardThreeSubqueryTwo.totalHits.value); assertEquals(TotalHits.Relation.EQUAL_TO, topDocShardThreeSubqueryTwo.totalHits.relation); assertEquals(0, topDocShardThreeSubqueryTwo.scoreDocs.length); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java index c5f8c4860..5732687ca 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/L2ScoreNormalizationTechniqueTests.java @@ -11,8 +11,8 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -import org.opensearch.neuralsearch.search.CompoundTopDocs; /** * Abstracts normalization of scores based on min-max method @@ -49,8 +49,11 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); - assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); - assertCompoundTopDocs(expectedCompoundDocs, compoundTopDocs.get(0).getCompoundTopDocs().get(0)); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + assertCompoundTopDocs( + new TopDocs(expectedCompoundDocs.getTotalHits(), expectedCompoundDocs.getScoreDocs().toArray(new ScoreDoc[0])), + compoundTopDocs.get(0).getTopDocs().get(0) + ); } public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() { @@ -99,9 +102,9 @@ public 
void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); - assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); - for (int i = 0; i < expectedCompoundDocs.getCompoundTopDocs().size(); i++) { - assertCompoundTopDocs(expectedCompoundDocs.getCompoundTopDocs().get(i), compoundTopDocs.get(0).getCompoundTopDocs().get(i)); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocs.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocs.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); } } @@ -192,19 +195,13 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the assertNotNull(compoundTopDocs); assertEquals(2, compoundTopDocs.size()); - assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); - for (int i = 0; i < expectedCompoundDocsShard1.getCompoundTopDocs().size(); i++) { - assertCompoundTopDocs( - expectedCompoundDocsShard1.getCompoundTopDocs().get(i), - compoundTopDocs.get(0).getCompoundTopDocs().get(i) - ); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard1.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocsShard1.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); } - assertNotNull(compoundTopDocs.get(1).getCompoundTopDocs()); - for (int i = 0; i < expectedCompoundDocsShard2.getCompoundTopDocs().size(); i++) { - assertCompoundTopDocs( - expectedCompoundDocsShard2.getCompoundTopDocs().get(i), - compoundTopDocs.get(1).getCompoundTopDocs().get(i) - ); + assertNotNull(compoundTopDocs.get(1).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard2.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocsShard2.getTopDocs().get(i), compoundTopDocs.get(1).getTopDocs().get(i)); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java index c7de1fdb5..1b38b7bb4 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechniqueTests.java @@ -10,8 +10,8 @@ import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TotalHits; +import org.opensearch.neuralsearch.processor.CompoundTopDocs; import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; -import org.opensearch.neuralsearch.search.CompoundTopDocs; /** * Abstracts normalization of scores based on min-max method @@ -45,8 +45,11 @@ public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); - assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); - assertCompoundTopDocs(expectedCompoundDocs, compoundTopDocs.get(0).getCompoundTopDocs().get(0)); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + assertCompoundTopDocs( + new TopDocs(expectedCompoundDocs.getTotalHits(), expectedCompoundDocs.getScoreDocs().toArray(new ScoreDoc[0])), + compoundTopDocs.get(0).getTopDocs().get(0) + ); } public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() { @@ -85,9 +88,9 @@ public void 
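The L2 and min-max technique tests in this area assert normalized values such as 1.0 for the best hit and 0.0 for the worst, which follow directly from the standard formulas: min-max maps each score to (score - min) / (max - min) within its sub-query, and L2 divides each score by the Euclidean norm of all scores of that sub-query. A sketch under that assumption (standard formulas; the plugin's handling of edge cases such as min == max may differ):

    // min-max normalization of one sub-query's scores
    static float[] minMaxNormalize(float[] scores) {
        float min = Float.MAX_VALUE, max = -Float.MAX_VALUE;
        for (float s : scores) {
            min = Math.min(min, s);
            max = Math.max(max, s);
        }
        float[] normalized = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            // guard for a single-score sub-query where max == min (an assumption, not the plugin's exact rule)
            normalized[i] = max == min ? 1.0f : (scores[i] - min) / (max - min);
        }
        return normalized;
    }

    // L2 normalization of one sub-query's scores
    static float[] l2Normalize(float[] scores) {
        float sumOfSquares = 0.0f;
        for (float s : scores) {
            sumOfSquares += s * s;
        }
        float norm = (float) Math.sqrt(sumOfSquares);
        float[] normalized = new float[scores.length];
        for (int i = 0; i < scores.length; i++) {
            normalized[i] = norm == 0.0f ? 0.0f : scores[i] / norm;
        }
        return normalized;
    }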
testNormalization_whenResultFromOneShardMultipleSubQueries_thenSucce ); assertNotNull(compoundTopDocs); assertEquals(1, compoundTopDocs.size()); - assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); - for (int i = 0; i < expectedCompoundDocs.getCompoundTopDocs().size(); i++) { - assertCompoundTopDocs(expectedCompoundDocs.getCompoundTopDocs().get(i), compoundTopDocs.get(0).getCompoundTopDocs().get(i)); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocs.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocs.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); } } @@ -149,19 +152,13 @@ public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_the assertNotNull(compoundTopDocs); assertEquals(2, compoundTopDocs.size()); - assertNotNull(compoundTopDocs.get(0).getCompoundTopDocs()); - for (int i = 0; i < expectedCompoundDocsShard1.getCompoundTopDocs().size(); i++) { - assertCompoundTopDocs( - expectedCompoundDocsShard1.getCompoundTopDocs().get(i), - compoundTopDocs.get(0).getCompoundTopDocs().get(i) - ); + assertNotNull(compoundTopDocs.get(0).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard1.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocsShard1.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i)); } - assertNotNull(compoundTopDocs.get(1).getCompoundTopDocs()); - for (int i = 0; i < expectedCompoundDocsShard2.getCompoundTopDocs().size(); i++) { - assertCompoundTopDocs( - expectedCompoundDocsShard2.getCompoundTopDocs().get(i), - compoundTopDocs.get(1).getCompoundTopDocs().get(i) - ); + assertNotNull(compoundTopDocs.get(1).getTopDocs()); + for (int i = 0; i < expectedCompoundDocsShard2.getTopDocs().size(); i++) { + assertCompoundTopDocs(expectedCompoundDocsShard2.getTopDocs().get(i), compoundTopDocs.get(1).getTopDocs().get(i)); } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 59c90f495..eec6955ff 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -5,6 +5,7 @@ package org.opensearch.neuralsearch.query; +import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION; import static org.opensearch.neuralsearch.TestUtils.createRandomVector; import java.io.IOException; @@ -32,7 +33,6 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_BASIC_INDEX_NAME = "test-neural-basic-index"; private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index"; private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index"; - private static final int MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX = 3; private static final String TEST_QUERY_TEXT = "greetings"; private static final String TEST_QUERY_TEXT2 = "salute"; private static final String TEST_QUERY_TEXT3 = "hello"; @@ -51,18 +51,21 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); private final static String RELATION_EQUAL_TO = "eq"; + private static final String SEARCH_PIPELINE = "phase-results-pipeline"; @Before public void setUp() throws Exception { super.setUp(); updateClusterSettings(); prepareModel(); + 
createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); } @After @SneakyThrows public void tearDown() { super.tearDown(); + deleteSearchPipeline(SEARCH_PIPELINE); /* this is required to minimize chance of model not being deployed due to open memory CB, * this happens in case we leave model from previous test case. We use new model for every test, and old model * can be undeployed and deleted to free resources after each test case execution. @@ -80,61 +83,6 @@ protected boolean preserveClusterUponCompletion() { return true; } - /** - * Tests basic query, example of query structure: - * { - * "query": { - * "hybrid": { - * "queries": [ - * { - * "neural": { - * "text_knn": { - * "query_text": "Hello world", - * "model_id": "dcsdcasd", - * "k": 1 - * } - * } - * } - * ] - * } - * } - * } - */ - @SneakyThrows - public void testBasicQuery_whenOneSubQuery_thenSuccessful() { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); - String modelId = getDeployedModelId(); - - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); - - HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); - hybridQueryBuilder.add(neuralQueryBuilder); - - Map searchResponseAsMap1 = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilder, 10); - - assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, getHitCount(searchResponseAsMap1)); - - List> hits1NestedList = getNestedHits(searchResponseAsMap1); - List ids = new ArrayList<>(); - List scores = new ArrayList<>(); - for (Map oneHit : hits1NestedList) { - ids.add((String) oneHit.get("_id")); - scores.add((Double) oneHit.get("_score")); - } - - // verify that scores are in desc order - assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); - // verify that all ids are unique - assertEquals(Set.copyOf(ids).size(), ids.size()); - - Map total = getTotalHits(searchResponseAsMap1); - assertNotNull(total.get("value")); - assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); - assertTrue(getMaxScore(searchResponseAsMap1).isPresent()); - } - /** * Tests complex query with multiple nested sub-queries: * { @@ -181,7 +129,13 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() { hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder1); hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder); - Map searchResponseAsMap1 = search(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 10); + Map searchResponseAsMap1 = search( + TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME, + hybridQueryBuilderNeuralThenTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); assertEquals(3, getHitCount(searchResponseAsMap1)); @@ -205,92 +159,6 @@ public void testComplexQuery_whenMultipleSubqueries_thenSuccessful() { assertEquals(RELATION_EQUAL_TO, total.get("relation")); } - /** - * Using queries similar to below to test sub-queries order: - * { - * "query": { - * "hybrid": { - * "queries": [ - * { - * "neural": { - * "text_knn": { - * "query_text": "Hello world", - * "model_id": "dcsdcasd", - * "k": 1 - * } - * } - * }, - * { - * "term": { - * "text": "word" - * } - * } - * ] - * } - * } - * } - */ - @SneakyThrows - public void testSubQuery_whenSubqueriesInDifferentOrder_thenResultIsSame() { - initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); - String modelId = getDeployedModelId(); - - NeuralQueryBuilder 
neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); - TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT); - - HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder(); - hybridQueryBuilderNeuralThenTerm.add(neuralQueryBuilder); - hybridQueryBuilderNeuralThenTerm.add(termQueryBuilder); - - Map searchResponseAsMap1 = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 10); - - assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, getHitCount(searchResponseAsMap1)); - - List> hits1NestedList = getNestedHits(searchResponseAsMap1); - List ids1 = new ArrayList<>(); - List scores1 = new ArrayList<>(); - for (Map oneHit : hits1NestedList) { - ids1.add((String) oneHit.get("_id")); - scores1.add((Double) oneHit.get("_score")); - } - - Map total = getTotalHits(searchResponseAsMap1); - assertNotNull(total.get("value")); - assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, total.get("value")); - assertNotNull(total.get("relation")); - assertEquals(RELATION_EQUAL_TO, total.get("relation")); - - // verify that scores are in desc order - assertTrue(IntStream.range(0, scores1.size() - 1).noneMatch(idx -> scores1.get(idx) < scores1.get(idx + 1))); - // verify that all ids are unique - assertEquals(Set.copyOf(ids1).size(), ids1.size()); - - // check similar query when sub-queries are in reverse order, results must be same as in previous test case - HybridQueryBuilder hybridQueryBuilderTermThenNeural = new HybridQueryBuilder(); - hybridQueryBuilderTermThenNeural.add(termQueryBuilder); - hybridQueryBuilderTermThenNeural.add(neuralQueryBuilder); - - Map searchResponseAsMap2 = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilderNeuralThenTerm, 10); - - assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, getHitCount(searchResponseAsMap2)); - - List ids2 = new ArrayList<>(); - List scores2 = new ArrayList<>(); - for (Map oneHit : hits1NestedList) { - ids2.add((String) oneHit.get("_id")); - scores2.add((Double) oneHit.get("_score")); - } - - Map total2 = getTotalHits(searchResponseAsMap2); - assertNotNull(total.get("value")); - assertEquals(MAX_NUMBER_OF_DOCS_IN_MULTI_DOC_INDEX, total2.get("value")); - assertNotNull(total2.get("relation")); - assertEquals(RELATION_EQUAL_TO, total2.get("relation")); - // doc ids must match same from the previous query, order of sub-queries doesn't change the result - assertEquals(ids1, ids2); - assertEquals(scores1, scores2); - } - @SneakyThrows public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult() { initializeIndexIfNotExist(TEST_MULTI_DOC_INDEX_NAME); @@ -301,10 +169,17 @@ public void testNoMatchResults_whenOnlyTermSubQueryWithoutMatch_thenEmptyResult( hybridQueryBuilderOnlyTerm.add(termQueryBuilder); hybridQueryBuilderOnlyTerm.add(termQuery2Builder); - Map searchResponseAsMap = search(TEST_MULTI_DOC_INDEX_NAME, hybridQueryBuilderOnlyTerm, 10); + Map searchResponseAsMap = search( + TEST_MULTI_DOC_INDEX_NAME, + hybridQueryBuilderOnlyTerm, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE) + ); assertEquals(0, getHitCount(searchResponseAsMap)); - assertTrue(getMaxScore(searchResponseAsMap).isEmpty()); + assertTrue(getMaxScore(searchResponseAsMap).isPresent()); + assertEquals(0.0f, getMaxScore(searchResponseAsMap).get(), DELTA_FOR_SCORE_ASSERTION); Map total = getTotalHits(searchResponseAsMap); assertNotNull(total.get("value")); diff --git 
a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
index f63df42c9..e9c55cc54 100644
--- a/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/search/query/HybridQueryPhaseSearcherTests.java
@@ -14,8 +14,11 @@ import static org.mockito.Mockito.spy;
 import static org.mockito.Mockito.verify;
 import static org.mockito.Mockito.when;
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement;
+import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement;
 import java.io.IOException;
+import java.util.ArrayList;
 import java.util.LinkedList;
 import java.util.List;
@@ -48,7 +51,6 @@ import org.opensearch.knn.index.mapper.KNNVectorFieldMapper;
 import org.opensearch.neuralsearch.query.HybridQueryBuilder;
 import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
-import org.opensearch.neuralsearch.search.CompoundTopDocs;
 import org.opensearch.search.SearchShardTarget;
 import org.opensearch.search.internal.ContextIndexSearcher;
 import org.opensearch.search.internal.SearchContext;
@@ -282,8 +284,12 @@ public void testQueryResult_whenOneSubQueryWithHits_thenHybridResultsAreSet() {
         TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
         TopDocs topDocs = topDocsAndMaxScore.topDocs;
         assertEquals(1, topDocs.totalHits.value);
-        assertTrue(topDocs instanceof CompoundTopDocs);
-        List<TopDocs> compoundTopDocs = ((CompoundTopDocs) topDocs).getCompoundTopDocs();
+        ScoreDoc[] scoreDocs = topDocs.scoreDocs;
+        assertNotNull(scoreDocs);
+        assertEquals(4, scoreDocs.length);
+        assertTrue(isHybridQueryStartStopElement(scoreDocs[0]));
+        assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1]));
+        List<TopDocs> compoundTopDocs = getSubQueryResultsForSingleShard(topDocs);
         assertNotNull(compoundTopDocs);
         assertEquals(1, compoundTopDocs.size());
         TopDocs subQueryTopDocs = compoundTopDocs.get(0);
@@ -374,8 +380,12 @@ public void testQueryResult_whenMultipleTextSubQueriesWithSomeHits_thenHybridRes
         TopDocsAndMaxScore topDocsAndMaxScore = querySearchResult.topDocs();
         TopDocs topDocs = topDocsAndMaxScore.topDocs;
         assertEquals(4, topDocs.totalHits.value);
-        assertTrue(topDocs instanceof CompoundTopDocs);
-        List<TopDocs> compoundTopDocs = ((CompoundTopDocs) topDocs).getCompoundTopDocs();
+        ScoreDoc[] scoreDocs = topDocs.scoreDocs;
+        assertNotNull(scoreDocs);
+        assertEquals(10, scoreDocs.length);
+        assertTrue(isHybridQueryStartStopElement(scoreDocs[0]));
+        assertTrue(isHybridQueryStartStopElement(scoreDocs[scoreDocs.length - 1]));
+        List<TopDocs> compoundTopDocs = getSubQueryResultsForSingleShard(topDocs);
         assertNotNull(compoundTopDocs);
         assertEquals(3, compoundTopDocs.size());
@@ -415,4 +425,26 @@ private void releaseResources(Directory directory, IndexWriter w, IndexReader re
         reader.close();
         directory.close();
     }
+
+    private List<TopDocs> getSubQueryResultsForSingleShard(final TopDocs topDocs) {
+        assertNotNull(topDocs);
+        List<TopDocs> topDocsList = new ArrayList<>();
+        ScoreDoc[] scoreDocs = topDocs.scoreDocs;
+        // elements 0 and 1 are the leading start-stop marker and the first delimiter, so iteration starts at index 2
+        List<ScoreDoc> scoreDocList = new ArrayList<>();
+        for (int index = 2; index < scoreDocs.length; index++) {
+            // read the next element of the current sub-query's score series
+            ScoreDoc scoreDoc = scoreDocs[index];
+            if (isHybridQueryDelimiterElement(scoreDoc) || 
isHybridQueryStartStopElement(scoreDoc)) { + ScoreDoc[] subQueryScores = scoreDocList.toArray(new ScoreDoc[0]); + TotalHits totalHits = new TotalHits(subQueryScores.length, TotalHits.Relation.EQUAL_TO); + TopDocs subQueryTopDocs = new TopDocs(totalHits, subQueryScores); + topDocsList.add(subQueryTopDocs); + scoreDocList.clear(); + } else { + scoreDocList.add(scoreDoc); + } + } + return topDocsList; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java b/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java new file mode 100644 index 000000000..65971b6d6 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/search/util/HybridSearchResultFormatUtilTests.java @@ -0,0 +1,61 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.search.util; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_DELIMITER; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.MAGIC_NUMBER_START_STOP; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryDelimiterElement; +import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.isHybridQueryStartStopElement; + +import org.apache.lucene.search.ScoreDoc; +import org.opensearch.common.Randomness; +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +public class HybridSearchResultFormatUtilTests extends OpenSearchQueryTestCase { + + public void testScoreDocsListElements_whenTestingListElements_thenCheckResultsAreCorrect() { + ScoreDoc validStartStopElement = new ScoreDoc(0, MAGIC_NUMBER_START_STOP); + assertTrue(isHybridQueryStartStopElement(validStartStopElement)); + + ScoreDoc validStartStopElement1 = new ScoreDoc(-1, MAGIC_NUMBER_START_STOP); + assertFalse(isHybridQueryStartStopElement(validStartStopElement1)); + + ScoreDoc validStartStopElement2 = new ScoreDoc(0, Randomness.get().nextFloat()); + assertFalse(isHybridQueryStartStopElement(validStartStopElement2)); + + assertFalse(isHybridQueryStartStopElement(null)); + + ScoreDoc validDelimiterElement = new ScoreDoc(0, MAGIC_NUMBER_DELIMITER); + assertTrue(isHybridQueryDelimiterElement(validDelimiterElement)); + + ScoreDoc validDelimiterElement1 = new ScoreDoc(-1, MAGIC_NUMBER_DELIMITER); + assertFalse(isHybridQueryDelimiterElement(validDelimiterElement1)); + + ScoreDoc validDelimiterElement2 = new ScoreDoc(0, Randomness.get().nextFloat()); + assertFalse(isHybridQueryDelimiterElement(validDelimiterElement2)); + + assertFalse(isHybridQueryDelimiterElement(null)); + } + + public void testCreateElements_whenCreateStartStopAndDelimiterElements_thenSuccessful() { + int docId = 1; + ScoreDoc startStopElement = createStartStopElementForHybridSearchResults(docId); + assertNotNull(startStopElement); + assertEquals(docId, startStopElement.doc); + assertEquals(MAGIC_NUMBER_START_STOP, startStopElement.score, 0.0f); + + ScoreDoc delimiterElement = 
createDelimiterElementForHybridSearchResults(docId); + assertNotNull(delimiterElement); + assertEquals(docId, delimiterElement.doc); + assertEquals(MAGIC_NUMBER_DELIMITER, delimiterElement.score, 0.0f); + } +} From 685d5d66d7070622420f8749dbccec0c3b9aa97f Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 30 Aug 2023 17:51:29 +0200 Subject: [PATCH 3/6] Added validations for score combination weights in Hybrid Search (#265) * Added strong check on number of weights equals number of sub-queries Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + ...ithmeticMeanScoreCombinationTechnique.java | 1 + ...eometricMeanScoreCombinationTechnique.java | 1 + ...HarmonicMeanScoreCombinationTechnique.java | 1 + .../combination/ScoreCombinationUtil.java | 64 ++++++++++++++++++- .../processor/ScoreCombinationIT.java | 61 ++++++++++-------- .../processor/ScoreNormalizationIT.java | 12 ++-- ...ticMeanScoreCombinationTechniqueTests.java | 16 ++--- ...ricMeanScoreCombinationTechniqueTests.java | 20 ++---- ...nicMeanScoreCombinationTechniqueTests.java | 18 ++---- .../ScoreCombinationUtilTests.java | 64 +++++++++++++++++++ .../NormalizationProcessorFactoryTests.java | 46 +++++++++++-- 12 files changed, 231 insertions(+), 74 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 17f47690d..7c504d813 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), * Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/)) ### Enhancements * Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259)) +* Added validations for score combination weights in Hybrid Search ([#265](https://github.com/opensearch-project/neural-search/pull/265)) ### Bug Fixes ### Infrastructure ### Documentation diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java index cfafeb3e5..e656beca3 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechnique.java @@ -38,6 +38,7 @@ public ArithmeticMeanScoreCombinationTechnique(final Map params, */ @Override public float combine(final float[] scores) { + scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights); float combinedScore = 0.0f; float sumOfWeights = 0; for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java index 4e7a8ca9e..2a78d5ac6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechnique.java @@ -40,6 +40,7 @@ public GeometricMeanScoreCombinationTechnique(final Map params, */ @Override public float combine(final float[] scores) { + 
scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
         float weightedLnSum = 0;
         float sumOfWeights = 0;
         for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java
index 9f913b2ef..0b45fb616 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechnique.java
@@ -38,6 +38,7 @@ public HarmonicMeanScoreCombinationTechnique(final Map<String, Object> params, f
      */
     @Override
     public float combine(final float[] scores) {
+        scoreCombinationUtil.validateIfWeightsMatchScores(scores, weights);
         float sumOfWeights = 0;
         float sumOfHarmonics = 0;
         for (int indexOfSubQuery = 0; indexOfSubQuery < scores.length; indexOfSubQuery++) {
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java
index 35e097f7f..ed82a62ea 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtil.java
@@ -5,6 +5,7 @@
 package org.opensearch.neuralsearch.processor.combination;
+import java.util.Arrays;
 import java.util.List;
 import java.util.Locale;
 import java.util.Map;
@@ -13,11 +14,19 @@
 import java.util.Set;
 import java.util.stream.Collectors;
+import lombok.extern.log4j.Log4j2;
+
+import org.apache.commons.lang3.Range;
+
+import com.google.common.math.DoubleMath;
+
 /**
  * Collection of utility methods for score combination technique classes
  */
+@Log4j2
 class ScoreCombinationUtil {
     private static final String PARAM_NAME_WEIGHTS = "weights";
+    private static final float DELTA_FOR_SCORE_ASSERTION = 0.01f;
     /**
      * Get collection of weights based on user provided config
@@ -29,9 +38,11 @@ public List<Float> getWeights(final Map<String, Object> params) {
             return List.of();
         }
         // get weights, we don't need to check for instance as it's done during validation
-        return ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
+        List<Float> weightsList = ((List<Double>) params.getOrDefault(PARAM_NAME_WEIGHTS, List.of())).stream()
             .map(Double::floatValue)
             .collect(Collectors.toUnmodifiableList());
+        validateWeights(weightsList);
+        return weightsList;
     }
     /**
@@ -77,4 +88,55 @@ public void validateParams(final Map actualParams, final Set
 weights, final int indexOfSubQuery) {
         return indexOfSubQuery < weights.size() ? weights.get(indexOfSubQuery) : 1.0f;
     }
+
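Judging by the accumulator names visible in these hunks (weightedLnSum, sumOfHarmonics, sumOfWeights), the geometric and harmonic techniques compute the classic weighted means. A compact sketch under that assumption (an inference from the visible context, not the verbatim plugin code; zero scores are skipped here to keep the math defined, and java.util.List is assumed imported):

    // weighted geometric mean: exp(sum(w_i * ln(s_i)) / sum(w_i))
    static float geometricMean(float[] scores, List<Float> weights) {
        float weightedLnSum = 0.0f, sumOfWeights = 0.0f;
        for (int i = 0; i < scores.length; i++) {
            float w = i < weights.size() ? weights.get(i) : 1.0f;
            if (scores[i] > 0.0f) { // ln(0) is undefined, so non-matches are skipped
                weightedLnSum += w * (float) Math.log(scores[i]);
                sumOfWeights += w;
            }
        }
        return sumOfWeights == 0.0f ? 0.0f : (float) Math.exp(weightedLnSum / sumOfWeights);
    }

    // weighted harmonic mean: sum(w_i) / sum(w_i / s_i)
    static float harmonicMean(float[] scores, List<Float> weights) {
        float sumOfWeights = 0.0f, sumOfHarmonics = 0.0f;
        for (int i = 0; i < scores.length; i++) {
            float w = i < weights.size() ? weights.get(i) : 1.0f;
            if (scores[i] > 0.0f) { // avoid division by zero for non-matches
                sumOfWeights += w;
                sumOfHarmonics += w / scores[i];
            }
        }
        return sumOfHarmonics == 0.0f ? 0.0f : sumOfWeights / sumOfHarmonics;
    }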
+    /**
+     * Check that the number of weights matches the number of sub-queries. The check is skipped
+     * when weights were not provided, as an empty weights list is a valid default.
+     * @param scores collection of scores from all sub-queries of a single hybrid search query
+     * @param weights score combination weights that are defined as part of search result processor
+     */
+    protected void validateIfWeightsMatchScores(final float[] scores, final List<Float> weights) {
+        if (weights.isEmpty()) {
+            return;
+        }
+        if (scores.length != weights.size()) {
+            throw new IllegalArgumentException(
+                String.format(
+                    Locale.ROOT,
+                    "number of weights [%d] must match number of sub-queries [%d] in hybrid query",
+                    weights.size(),
+                    scores.length
+                )
+            );
+        }
+    }
+
+    /**
+     * Check that the provided weights are valid for combination. The following conditions are checked:
+     * - every weight is between 0.0 and 1.0
+     * - the sum of all weights must be equal to 1.0
+     * @param weightsList collection of weights to validate
+     */
+    private void validateWeights(final List<Float> weightsList) {
+        boolean isOutOfRange = weightsList.stream().anyMatch(weight -> !Range.between(0.0f, 1.0f).contains(weight));
+        if (isOutOfRange) {
+            throw new IllegalArgumentException(
+                String.format(
+                    Locale.ROOT,
+                    "all weights must be in range [0.0 ... 1.0], submitted weights: %s",
+                    Arrays.toString(weightsList.toArray(new Float[0]))
+                )
+            );
+        }
+        float sumOfWeights = weightsList.stream().reduce(0.0f, Float::sum);
+        if (!DoubleMath.fuzzyEquals(1.0f, sumOfWeights, DELTA_FOR_SCORE_ASSERTION)) {
+            throw new IllegalArgumentException(
+                String.format(
+                    Locale.ROOT,
+                    "sum of weights for combination must be equal to 1.0, submitted weights: %s",
+                    Arrays.toString(weightsList.toArray(new Float[0]))
+                )
+            );
+        }
+    }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java
index df15f40fd..03b77549a 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java
@@ -5,6 +5,8 @@
 package org.opensearch.neuralsearch.processor;
+import static org.hamcrest.Matchers.allOf;
+import static org.hamcrest.Matchers.containsString;
 import static org.opensearch.neuralsearch.TestUtils.assertHybridSearchResults;
 import static org.opensearch.neuralsearch.TestUtils.assertWeightedScores;
 import static org.opensearch.neuralsearch.TestUtils.createRandomVector;
@@ -18,6 +20,7 @@
 import org.junit.After;
 import org.junit.Before;
+import org.opensearch.client.ResponseException;
 import org.opensearch.index.query.QueryBuilders;
 import org.opensearch.knn.index.SpaceType;
 import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
@@ -96,7 +99,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
             SEARCH_PIPELINE,
             DEFAULT_NORMALIZATION_METHOD,
             DEFAULT_COMBINATION_METHOD,
-            Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f }))
+            Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.4f, 0.3f, 0.3f }))
         );
         HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder();
@@ -112,7 +115,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() {
             Map.of("search_pipeline", SEARCH_PIPELINE)
         );
-        assertWeightedScores(searchResponseWithWeights1AsMap, 0.375, 0.3125, 0.001);
+        assertWeightedScores(searchResponseWithWeights1AsMap, 0.4, 0.3, 0.001);
         // delete existing pipeline and create a new one with another set of weights
         deleteSearchPipeline(SEARCH_PIPELINE);
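Taken together, these checks accept only weight lists that form (approximately) a convex combination, which is why the integration tests below move from weights such as { 0.8f, 0.7f } to { 0.533f, 0.466f }: the sum only has to land within the 0.01 tolerance that DoubleMath.fuzzyEquals allows around 1.0. For illustration (hypothetical inputs, paired with the messages thrown above):

    validateWeights(List.of(0.4f, 0.3f, 0.3f));   // accepted: sum = 1.0
    validateWeights(List.of(0.533f, 0.466f));     // accepted: sum = 0.999, inside the tolerance
    validateWeights(List.of(0.8f, 0.7f));         // rejected: sum = 1.5, "sum of weights for combination must be equal to 1.0, ..."
    validateWeights(List.of(0.8f, 2.0f, 0.5f));   // rejected by the range check first: "all weights must be in range [0.0 ... 1.0], ..."

Independently of the weight values, validateIfWeightsMatchScores rejects the query at combination time when the number of weights differs from the number of sub-queries, which is exactly what the two expectThrows cases added to the test below verify.

@@ -120,7 +123,7 @@ public void 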
testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 2.0f, 0.5f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.233f, 0.666f, 0.1f })) ); Map searchResponseWithWeights2AsMap = search( @@ -131,7 +134,7 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { Map.of("search_pipeline", SEARCH_PIPELINE) ); - assertWeightedScores(searchResponseWithWeights2AsMap, 0.606, 0.242, 0.001); + assertWeightedScores(searchResponseWithWeights2AsMap, 0.6666, 0.2332, 0.001); // check case when number of weights is less than number of sub-queries // delete existing pipeline and create a new one with another set of weights @@ -140,18 +143,21 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 1.0f })) ); - Map searchResponseWithWeights3AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) + ResponseException exception1 = expectThrows( + ResponseException.class, + () -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE)) + ); + org.hamcrest.MatcherAssert.assertThat( + exception1.getMessage(), + allOf( + containsString("number of weights"), + containsString("must match number of sub-queries"), + containsString("in hybrid query") + ) ); - - assertWeightedScores(searchResponseWithWeights3AsMap, 0.357, 0.285, 0.001); // check case when number of weights is more than number of sub-queries // delete existing pipeline and create a new one with another set of weights @@ -160,18 +166,21 @@ public void testArithmeticWeightedMean_whenWeightsPassed_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.6f, 0.5f, 0.5f, 1.5f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.3f, 0.25f, 0.25f, 0.2f })) ); - Map searchResponseWithWeights4AsMap = search( - TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, - hybridQueryBuilder, - null, - 5, - Map.of("search_pipeline", SEARCH_PIPELINE) + ResponseException exception2 = expectThrows( + ResponseException.class, + () -> search(TEST_MULTI_DOC_INDEX_THREE_SHARDS_NAME, hybridQueryBuilder, null, 5, Map.of("search_pipeline", SEARCH_PIPELINE)) + ); + org.hamcrest.MatcherAssert.assertThat( + exception2.getMessage(), + allOf( + containsString("number of weights"), + containsString("must match number of sub-queries"), + containsString("in hybrid query") + ) ); - - assertWeightedScores(searchResponseWithWeights4AsMap, 0.375, 0.3125, 0.001); } /** @@ -199,7 +208,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -223,7 +232,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, 
Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); @@ -265,7 +274,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -289,7 +298,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 36af1a712..6f98e8d5e 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -91,7 +91,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -115,7 +115,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); @@ -138,7 +138,7 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, L2_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); @@ -180,7 +180,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, DEFAULT_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); String modelId = getDeployedModelId(); @@ -204,7 +204,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, HARMONIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); @@ -227,7 +227,7 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { SEARCH_PIPELINE, DEFAULT_NORMALIZATION_METHOD, GEOMETRIC_MEAN_COMBINATION_METHOD, - Map.of(PARAM_NAME_WEIGHTS, Arrays.toString(new float[] { 0.8f, 0.7f })) + Map.of(PARAM_NAME_WEIGHTS, 
Arrays.toString(new float[] { 0.533f, 0.466f })) ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java index 842df736d..125930007 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ArithmeticMeanScoreCombinationTechniqueTests.java @@ -12,8 +12,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import com.carrotsearch.randomizedtesting.RandomizedTest; - public class ArithmeticMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); @@ -33,9 +31,7 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { } public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { - List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); @@ -44,20 +40,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List<Float> scores = List.of(1.0f, -1.0f, 0.6f); - List<Double> weights = List.of(0.9, 0.2, 0.7); + List<Float> scores = List.of(1.0f, 0.0f, 0.6f); + List<Double> weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.825f; + float expectedScore = 0.69f; testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new ArithmeticMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java index fe0d962ca..3f70c229f 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/GeometricMeanScoreCombinationTechniqueTests.java @@ -12,8 +12,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import com.carrotsearch.randomizedtesting.RandomizedTest; - public class GeometricMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new 
ScoreCombinationUtil(); @@ -34,19 +32,17 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { List<Float> scores = List.of(1.0f, 0.5f, 0.3f); - List<Double> weights = List.of(0.9, 0.2, 0.7); + List<Double> weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.5797f; + float expectedScore = 0.5567f; testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() { - List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil @@ -55,20 +51,18 @@ public void testRandomValues_whenAllScoresAndWeightsPresent_thenCorrectScores() } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List<Float> scores = List.of(1.0f, -1.0f, 0.6f); - List<Double> weights = List.of(0.9, 0.2, 0.7); + List<Float> scores = List.of(1.0f, 0.0f, 0.6f); + List<Double> weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.7997f; + float expectedScore = 0.7863f; testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new GeometricMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java index 8187123a1..7b1b07f64 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/HarmonicMeanScoreCombinationTechniqueTests.java @@ -12,8 +12,6 @@ import java.util.stream.Collectors; import java.util.stream.IntStream; -import com.carrotsearch.randomizedtesting.RandomizedTest; - public class HarmonicMeanScoreCombinationTechniqueTests extends BaseScoreCombinationTechniqueTests { private ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); @@ -34,30 +32,28 @@ public void testLogic_whenNotAllScoresPresentAndNoWeights_thenCorrectScores() { public void testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores() { List<Float> scores = List.of(1.0f, 0.5f, 0.3f); - List<Double> weights = List.of(0.9, 0.2, 0.7); + List<Double> weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float 
expecteScore = 0.4954f; - testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expecteScore); + float expectedScore = 0.48f; + testLogic_whenAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List<Float> scores = List.of(1.0f, -1.0f, 0.6f); - List<Double> weights = List.of(0.9, 0.2, 0.7); + List<Float> scores = List.of(1.0f, 0.0f, 0.6f); + List<Double> weights = List.of(0.45, 0.15, 0.4); ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil ); - float expectedScore = 0.7741f; + float expectedScore = 0.7611f; testLogic_whenNotAllScoresAndWeightsPresent_thenCorrectScores(technique, scores, expectedScore); } public void testRandomValues_whenNotAllScoresAndWeightsPresent_thenCorrectScores() { - List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE) - .mapToObj(i -> RandomizedTest.randomDouble()) - .collect(Collectors.toList()); + List<Double> weights = IntStream.range(0, RANDOM_SCORES_SIZE).mapToObj(i -> 1.0 / RANDOM_SCORES_SIZE).collect(Collectors.toList()); ScoreCombinationTechnique technique = new HarmonicMeanScoreCombinationTechnique( Map.of(PARAM_NAME_WEIGHTS, weights), scoreCombinationUtil diff --git a/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java new file mode 100644 index 000000000..ca13cb2eb --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/combination/ScoreCombinationUtilTests.java @@ -0,0 +1,64 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.combination; + +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; + +import java.util.List; +import java.util.Map; + +import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase; + +public class ScoreCombinationUtilTests extends OpenSearchQueryTestCase { + + public void testCombinationWeights_whenEmptyInputPassed_thenCreateEmptyWeightCollection() { + ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + List<Float> weights = scoreCombinationUtil.getWeights(Map.of()); + assertNotNull(weights); + assertTrue(weights.isEmpty()); + } + + public void testCombinationWeights_whenWeightsArePassed_thenSuccessful() { + ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + List<Float> weights = scoreCombinationUtil.getWeights(Map.of("weights", List.of(0.4, 0.6))); + assertNotNull(weights); + assertEquals(2, weights.size()); + assertTrue(weights.containsAll(List.of(0.4f, 0.6f))); + } + + public void testCombinationWeights_whenInvalidWeightsArePassed_thenFail() { + ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + + IllegalArgumentException exception1 = expectThrows( + IllegalArgumentException.class, + () -> scoreCombinationUtil.getWeights(Map.of("weights", List.of(2.4))) + ); + assertTrue(exception1.getMessage().contains("all weights must be in range")); + + IllegalArgumentException exception2 = expectThrows( + IllegalArgumentException.class, + () -> scoreCombinationUtil.getWeights(Map.of("weights", List.of(0.4, 0.5, 0.6))) + ); + assertTrue(exception2.getMessage().contains("sum of weights for combination must be equal to 1.0")); + } + + public void 
testWeightsValidation_whenNumberOfScoresDifferentFromNumberOfWeights_thenFail() { + ScoreCombinationUtil scoreCombinationUtil = new ScoreCombinationUtil(); + IllegalArgumentException exception1 = expectThrows( + IllegalArgumentException.class, + () -> scoreCombinationUtil.validateIfWeightsMatchScores(new float[] { 0.6f, 0.5f }, List.of(0.4f, 0.2f, 0.4f)) + ); + org.hamcrest.MatcherAssert.assertThat( + exception1.getMessage(), + allOf( + containsString("number of weights"), + containsString("must match number of sub-queries"), + containsString("in hybrid query") + ) + ); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java index 83bb0e7bb..a1ddefe16 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/NormalizationProcessorFactoryTests.java @@ -132,17 +132,14 @@ public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful( String tag = "tag"; String description = "description"; boolean ignoreFailure = false; + double weight1 = RandomizedTest.randomDouble(); + double weight2 = 1.0f - weight1; Map<String, Object> config = new HashMap<>(); config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max"))); config.put( COMBINATION_CLAUSE, new HashMap<>( - Map.of( - TECHNIQUE, - "arithmetic_mean", - PARAMETERS, - new HashMap<>(Map.of("weights", Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble()))) - ) + Map.of(TECHNIQUE, "arithmetic_mean", PARAMETERS, new HashMap<>(Map.of("weights", Arrays.asList(weight1, weight2)))) ) ); Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); @@ -160,6 +157,43 @@ public void testNormalizationProcessor_whenWithCombinationParams_thenSuccessful( assertEquals("normalization-processor", normalizationProcessor.getType()); } + @SneakyThrows + public void testWeightsParams_whenInvalidValues_thenFail() { + NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()), + new ScoreNormalizationFactory(), + new ScoreCombinationFactory() + ); + final Map<String, Processor.Factory<SearchPhaseResultsProcessor>> processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map<String, Object> config = new HashMap<>(); + config.put(NORMALIZATION_CLAUSE, new HashMap<>(Map.of("technique", "min_max"))); + config.put( + COMBINATION_CLAUSE, + new HashMap<>( + Map.of( + TECHNIQUE, + "arithmetic_mean", + PARAMETERS, + new HashMap<>( + Map.of( + "weights", + Arrays.asList(RandomizedTest.randomDouble(), RandomizedTest.randomDouble(), RandomizedTest.randomDouble()) + ) + ) + ) + ) + ); + Processor.PipelineContext pipelineContext = mock(Processor.PipelineContext.class); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> normalizationProcessorFactory.create(processorFactories, tag, description, ignoreFailure, config, pipelineContext) + ); + assertTrue(exception.getMessage().contains("sum of weights for combination must be equal to 1.0")); + } + public void testInputValidation_whenInvalidNormalizationClause_thenFail() { NormalizationProcessorFactory normalizationProcessorFactory = new NormalizationProcessorFactory( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new 
ScoreCombiner()), From 2208c0569cebde536ad9f9bbc2e68a3c5ea3eb7d Mon Sep 17 00:00:00 2001 From: Navneet Verma Date: Tue, 5 Sep 2023 16:41:30 -0700 Subject: [PATCH 4/6] Added release notes for 2.10 release (#277) Signed-off-by: Navneet Verma --- CHANGELOG.md | 7 ++----- ...opensearch-neural-search.release-notes-2.10.0.0.md | 11 +++++++++++ 2 files changed, 13 insertions(+), 5 deletions(-) create mode 100644 release-notes/opensearch-neural-search.release-notes-2.10.0.0.md diff --git a/CHANGELOG.md b/CHANGELOG.md index 7c504d813..1b3ea389e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,14 +12,11 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Maintenance ### Refactoring -## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.7...2.x) +## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.10...2.x) ### Features -* Added Score Normalization and Combination feature ([#241](https://github.com/opensearch-project/neural-search/pull/241/)) ### Enhancements -* Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259)) -* Added validations for score combination weights in Hybrid Search ([#265](https://github.com/opensearch-project/neural-search/pull/265)) ### Bug Fixes ### Infrastructure ### Documentation ### Maintenance -### Refactoring \ No newline at end of file +### Refactoring diff --git a/release-notes/opensearch-neural-search.release-notes-2.10.0.0.md b/release-notes/opensearch-neural-search.release-notes-2.10.0.0.md new file mode 100644 index 000000000..5c86a24dd --- /dev/null +++ b/release-notes/opensearch-neural-search.release-notes-2.10.0.0.md @@ -0,0 +1,11 @@ +## Version 2.10.0.0 Release Notes + +Compatible with OpenSearch 2.10.0 + +### Features +* Improved Hybrid Search relevancy by Score Normalization and Combination ([#241](https://github.com/opensearch-project/neural-search/pull/241/)) + +### Enhancements +* Changed format for hybrid query results to a single list of scores with delimiter ([#259](https://github.com/opensearch-project/neural-search/pull/259)) +* Added validations for score combination weights in Hybrid Search ([#265](https://github.com/opensearch-project/neural-search/pull/265)) +* Made hybrid search active by default ([#274](https://github.com/opensearch-project/neural-search/pull/274)) From 5bf36ed78152ca6b0e2fab49904dc58404a295d9 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 7 Sep 2023 16:49:53 -0700 Subject: [PATCH 5/6] Fixed compilation errors after recent changes in ml-commons (#285) * Fixed compilation errors after recent changes in ml-commons Signed-off-by: Martin Gaievski --- .../neuralsearch/common/BaseNeuralSearchIT.java | 8 ++++---- .../neuralsearch/ml/MLCommonsClientAccessorTests.java | 5 ++++- .../resources/processor/CreateModelGroupRequestBody.json | 5 ++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index fdf2459df..b144ade6c 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -8,7 +8,6 @@ import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray; import java.io.IOException; -import java.net.URISyntaxException; import java.nio.file.Files; import 
java.nio.file.Path; import java.util.Collections; @@ -50,6 +49,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.neuralsearch.OpenSearchSecureRestTestCase; +import com.carrotsearch.randomizedtesting.RandomizedTest; import com.google.common.collect.ImmutableList; public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { @@ -137,7 +137,7 @@ protected void loadModel(String modelId) throws Exception { Response uploadResponse = makeRequest( client(), "POST", - String.format(LOCALE, "/_plugins/_ml/models/%s/_load", modelId), + String.format(LOCALE, "/_plugins/_ml/models/%s/_deploy", modelId), null, toHttpEntity(""), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) @@ -685,10 +685,10 @@ protected String getDeployedModelId() { } @SneakyThrows - private String registerModelGroup() throws IOException, URISyntaxException { + private String registerModelGroup() { String modelGroupRegisterRequestBody = Files.readString( Path.of(classLoader.getResource("processor/CreateModelGroupRequestBody.json").toURI()) - ); + ).replace("", "public_model_" + RandomizedTest.randomAsciiAlphanumOfLength(8)); Response modelGroupResponse = makeRequest( client(), "POST", diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 350394250..3ef5431b3 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -13,6 +13,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; +import java.util.Map; import org.junit.Before; import org.mockito.InjectMocks; @@ -168,7 +169,9 @@ private ModelTensorOutput createModelTensorOutput(final Float[] output) { output, new long[] { 1, 2 }, MLResultDataType.FLOAT64, - ByteBuffer.wrap(new byte[12]) + ByteBuffer.wrap(new byte[12]), + "someValue", + Map.of() ); mlModelTensorList.add(tensor); final ModelTensors modelTensors = new ModelTensors(mlModelTensorList); diff --git a/src/test/resources/processor/CreateModelGroupRequestBody.json b/src/test/resources/processor/CreateModelGroupRequestBody.json index d6d398c76..91f68e222 100644 --- a/src/test/resources/processor/CreateModelGroupRequestBody.json +++ b/src/test/resources/processor/CreateModelGroupRequestBody.json @@ -1,5 +1,4 @@ { - "name": "test_model_group_public", - "description": "This is a public model group", - "access_mode": "public" + "name": "", + "description": "This is a public model group" } \ No newline at end of file From 8484be9b319a09dbdd9a1c952ca1b8515dedf571 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 7 Sep 2023 17:08:59 -0700 Subject: [PATCH 6/6] Made hybrid search active by default by flipping feature flag default (#292) Signed-off-by: Martin Gaievski (cherry picked from commit 174f2c9579b9d374592142de190eec794ff33f2e) --- build.gradle | 4 ---- .../neuralsearch/plugin/NeuralSearch.java | 24 ++++++++++--------- .../settings/NeuralSearchSettings.java | 4 ++-- .../plugin/NeuralSearchTests.java | 14 +++++------ .../query/OpenSearchQueryTestCase.java | 4 ++-- 5 files changed, 24 insertions(+), 26 deletions(-) diff --git a/build.gradle b/build.gradle index c050f4705..1d8eca483 100644 --- a/build.gradle +++ b/build.gradle @@ -253,10 +253,6 @@ testClusters.integTest { // Increase heap size from default of 512mb to 1gb. 
When heap size is 512mb, our integ tests sporadically fail due // to ml-commons memory circuit breaker exception jvmArgs("-Xms1g", "-Xmx1g") - - // enable features for testing - // hybrid search - systemProperty('plugins.neural_search.hybrid_search_enabled', 'true') } // Remote Integration Tests diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index b46d2bc6d..e94a2957d 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -5,7 +5,7 @@ package org.opensearch.neuralsearch.plugin; -import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_ENABLED; +import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED; import java.util.Arrays; import java.util.Collection; @@ -99,16 +99,18 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet @Override public Optional<QueryPhaseSearcher> getQueryPhaseSearcher() { - if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getKey())) { - log.info("Registering hybrid query phase searcher with feature flag [{}]", NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getKey()); - return Optional.of(new HybridQueryPhaseSearcher()); + // we're using the "is_disabled" flag as there is no proper implementation of FeatureFlags.isDisabled(). Both + // cases, when the flag is not set or when it is "false", are interpreted in the same way. In such a case core reads + // the actual value from settings. + if (FeatureFlags.isEnabled(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED.getKey())) { + log.info( + "Not registering hybrid query phase searcher because feature flag [{}] is disabled", + NEURAL_SEARCH_HYBRID_SEARCH_DISABLED.getKey() + ); + return Optional.empty(); } - log.info( - "Not registering hybrid query phase searcher because feature flag [{}] is disabled", - NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getKey() - ); - // we want feature be disabled by default due to risk of colliding and breaking concurrent search in core - return Optional.empty(); + log.info("Registering hybrid query phase searcher with feature flag [{}]", NEURAL_SEARCH_HYBRID_SEARCH_DISABLED.getKey()); + return Optional.of(new HybridQueryPhaseSearcher()); } @Override @@ -123,6 +125,6 @@ public List<Setting<?>> getSettings() { - return List.of(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED); + return List.of(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED); } } diff --git a/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java b/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java index 995f0c0fa..54edf8745 100644 --- a/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java +++ b/src/main/java/org/opensearch/neuralsearch/settings/NeuralSearchSettings.java @@ -21,8 +21,8 @@ public final class NeuralSearchSettings { * Currently query phase searcher added with hybrid search will conflict with concurrent search in core. * Once that problem is resolved this feature flag can be removed. 
*/ - public static final Setting<Boolean> NEURAL_SEARCH_HYBRID_SEARCH_ENABLED = Setting.boolSetting( - "plugins.neural_search.hybrid_search_enabled", + public static final Setting<Boolean> NEURAL_SEARCH_HYBRID_SEARCH_DISABLED = Setting.boolSetting( + "plugins.neural_search.hybrid_search_disabled", false, Setting.Property.NodeScope ); diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java index 7918126c5..8918e174c 100644 --- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java +++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java @@ -38,18 +38,18 @@ public void testQuerySpecs() { public void testQueryPhaseSearcher() { NeuralSearch plugin = new NeuralSearch(); - Optional<QueryPhaseSearcher> queryPhaseSearcher = plugin.getQueryPhaseSearcher(); - - assertNotNull(queryPhaseSearcher); - assertTrue(queryPhaseSearcher.isEmpty()); - - initFeatureFlags(); - Optional<QueryPhaseSearcher> queryPhaseSearcherWithFeatureFlagDisabled = plugin.getQueryPhaseSearcher(); assertNotNull(queryPhaseSearcherWithFeatureFlagDisabled); assertFalse(queryPhaseSearcherWithFeatureFlagDisabled.isEmpty()); assertTrue(queryPhaseSearcherWithFeatureFlagDisabled.get() instanceof HybridQueryPhaseSearcher); + + initFeatureFlags(); + + Optional<QueryPhaseSearcher> queryPhaseSearcher = plugin.getQueryPhaseSearcher(); + + assertNotNull(queryPhaseSearcher); + assertTrue(queryPhaseSearcher.isEmpty()); } public void testProcessors() { diff --git a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java index 94866acb8..26c183832 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java +++ b/src/test/java/org/opensearch/neuralsearch/query/OpenSearchQueryTestCase.java @@ -8,7 +8,7 @@ import static java.util.Collections.emptyMap; import static java.util.Collections.singletonMap; import static java.util.stream.Collectors.toList; -import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_ENABLED; +import static org.opensearch.neuralsearch.settings.NeuralSearchSettings.NEURAL_SEARCH_HYBRID_SEARCH_DISABLED; import java.io.IOException; import java.util.Arrays; @@ -227,6 +227,6 @@ public float getMaxScore(int upTo) { @SuppressForbidden(reason = "manipulates system properties for testing") protected static void initFeatureFlags() { - System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_ENABLED.getKey(), "true"); + System.setProperty(NEURAL_SEARCH_HYBRID_SEARCH_DISABLED.getKey(), "true"); } }
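
A note on consuming the last patch: with the default flipped, hybrid search is active out of the box and the flag now works as an opt-out rather than an opt-in. A minimal sketch (not part of the patches above, and assuming the same testClusters.integTest block whose opt-in line the build.gradle hunk removes) of how a test cluster could restore the pre-2.10 behavior:

testClusters.integTest {
    // hypothetical opt-out: setting the flipped "disabled" feature flag keeps the
    // hybrid query phase searcher from being registered, as in the old default
    systemProperty('plugins.neural_search.hybrid_search_disabled', 'true')
}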