diff --git a/build.gradle b/build.gradle
index 853aa85e7..90731feb5 100644
--- a/build.gradle
+++ b/build.gradle
@@ -218,6 +218,12 @@ integTest {
     if (System.getProperty("test.debug") != null) {
         jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5005'
     }
+
+    systemProperty 'log4j2.configurationFile', "${projectDir}/src/test/resources/log4j2-test.xml"
+
+    // Set this to true if you want to see the logs in the terminal test output.
+    // Note: if left false, the log output will still show in your IDE.
+    testLogging.showStandardStreams = true
 }
 
 testClusters.integTest {
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
index 0293efae6..5b9cd4378 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/combination/ScoreCombiner.java
@@ -26,8 +26,6 @@
 @Log4j2
 public class ScoreCombiner {
 
-    private static final Float ZERO_SCORE = 0.0f;
-
     /**
      * Performs score combination based on input combination technique. Mutates input object by updating combined scores
      * Main steps we're doing for combination:
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java
index 667c237c7..c42df96fe 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ScoreNormalizationFactory.java
@@ -19,7 +19,9 @@ public class ScoreNormalizationFactory {
         MinMaxScoreNormalizationTechnique.TECHNIQUE_NAME,
         new MinMaxScoreNormalizationTechnique(),
         L2ScoreNormalizationTechnique.TECHNIQUE_NAME,
-        new L2ScoreNormalizationTechnique()
+        new L2ScoreNormalizationTechnique(),
+        ZScoreNormalizationTechnique.TECHNIQUE_NAME,
+        new ZScoreNormalizationTechnique()
     );
 
     /**
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java
new file mode 100644
index 000000000..fc97e8a4b
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java
@@ -0,0 +1,168 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor.normalization;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+
+import lombok.ToString;
+
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.opensearch.neuralsearch.processor.CompoundTopDocs;
+
+import com.google.common.primitives.Floats;
+
+/**
+ * Implementation of the z-score normalization technique for hybrid query.
+ * This is currently modeled on the existing normalization techniques {@link L2ScoreNormalizationTechnique} and {@link MinMaxScoreNormalizationTechnique}.
+ * However, this class as well as the original ones require significant work to improve style and ease of use, see the TODO items below.
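+ *
+ * <p>For illustration, a search pipeline definition that selects this technique could look as follows
+ * (a sketch only; see the normalization-processor documentation for the authoritative request shape):
+ * <pre>
+ * PUT /_search/pipeline/z-score-pipeline
+ * {
+ *   "phase_results_processors": [
+ *     {
+ *       "normalization-processor": {
+ *         "normalization": { "technique": "z_score" },
+ *         "combination": { "technique": "arithmetic_mean" }
+ *       }
+ *     }
+ *   ]
+ * }
+ * </pre>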
+ */
+/*
+TODO: Some todo items that apply here but also to the original normalization techniques on which this is modeled, {@link L2ScoreNormalizationTechnique} and {@link MinMaxScoreNormalizationTechnique}:
+1. Random access to an abstract list object is a bad practice both stylistically and from a performance perspective, and should be removed
+2. Identical sub queries and their distribution between shards are currently completely implicit, based on ordering, and should be made explicit, based on an identifier
+3. Implicit calculation of numOfSubQueries instead of having a more explicit upstream indicator/metadata regarding it
+ */
+@ToString(onlyExplicitlyIncluded = true)
+public class ZScoreNormalizationTechnique implements ScoreNormalizationTechnique {
+    @ToString.Include
+    public static final String TECHNIQUE_NAME = "z_score";
+    private static final float SINGLE_RESULT_SCORE = 1.0f;
+
+    @Override
+    public void normalize(final List<CompoundTopDocs> queryTopDocs) {
+        /*
+        TODO: There is an implicit assumption in this calculation that probably needs to be made clearer by passing some metadata with the results.
+        Currently assuming that any single non-empty shard result will contain entries for all sub queries, with 0 hits for the empty ones.
+        */
+        final Optional<CompoundTopDocs> maybeCompoundTopDocs = queryTopDocs.stream()
+            .filter(Objects::nonNull)
+            .filter(topDocs -> topDocs.getTopDocs().size() > 0)
+            .findAny();
+
+        final int numOfSubQueries = maybeCompoundTopDocs.map(compoundTopDocs -> compoundTopDocs.getTopDocs().size()).orElse(0);
+
+        // to be done for each subquery
+        float[] sumPerSubquery = findScoreSumPerSubQuery(queryTopDocs, numOfSubQueries);
+        long[] elementsPerSubquery = findNumberOfElementsPerSubQuery(queryTopDocs, numOfSubQueries);
+        float[] meanPerSubQuery = findMeanPerSubquery(sumPerSubquery, elementsPerSubquery);
+        float[] stdPerSubquery = findStdPerSubquery(queryTopDocs, meanPerSubQuery, elementsPerSubquery, numOfSubQueries);
+
+        // do normalization using actual score and z-scores for the corresponding sub query
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
+            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
+                TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
+                for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
+                    scoreDoc.score = normalizeSingleScore(scoreDoc.score, stdPerSubquery[j], meanPerSubQuery[j]);
+                }
+            }
+        }
+    }
+
+    private static float[] findScoreSumPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
+        final float[] sumOfScorePerSubQuery = new float[numOfScores];
+        Arrays.fill(sumOfScorePerSubQuery, 0);
+        // TODO: make this syntactically clearer regarding performance by avoiding List.get(j) with an abstract List type
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
+            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
+                sumOfScorePerSubQuery[j] += sumScoreDocsArray(topDocsPerSubQuery.get(j).scoreDocs);
+            }
+        }
+
+        return sumOfScorePerSubQuery;
+    }
+
+    private static long[] findNumberOfElementsPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
+        final long[] numberOfElementsPerSubQuery = new long[numOfScores];
+        Arrays.fill(numberOfElementsPerSubQuery, 0);
+        // TODO: make this syntactically clearer regarding performance by avoiding List.get(j) with an abstract List type
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
+            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
+                numberOfElementsPerSubQuery[j] += topDocsPerSubQuery.get(j).totalHits.value;
+            }
+        }
+
+        return numberOfElementsPerSubQuery;
+    }
+
+    private static float[] findMeanPerSubquery(final float[] sumPerSubquery, final long[] elementsPerSubquery) {
+        final float[] meanPerSubQuery = new float[elementsPerSubquery.length];
+        for (int i = 0; i < elementsPerSubquery.length; i++) {
+            if (elementsPerSubquery[i] == 0) {
+                meanPerSubQuery[i] = 0;
+            } else {
+                meanPerSubQuery[i] = sumPerSubquery[i] / elementsPerSubquery[i];
+            }
+        }
+
+        return meanPerSubQuery;
+    }
+
+    private static float[] findStdPerSubquery(
+        final List<CompoundTopDocs> queryTopDocs,
+        final float[] meanPerSubQuery,
+        final long[] elementsPerSubquery,
+        final int numOfScores
+    ) {
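+        // population standard deviation per sub query: sqrt(sum((score - mean)^2) / n), with the squared deltas accumulated across all shards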
+        final double[] deltaSumPerSubquery = new double[numOfScores];
+        Arrays.fill(deltaSumPerSubquery, 0);
+        // TODO: make this syntactically clearer regarding performance by avoiding List.get(j) with an abstract List type
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
+            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
+                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(j).scoreDocs) {
+                    deltaSumPerSubquery[j] += Math.pow(scoreDoc.score - meanPerSubQuery[j], 2);
+                }
+            }
+        }
+
+        final float[] stdPerSubQuery = new float[numOfScores];
+        for (int i = 0; i < deltaSumPerSubquery.length; i++) {
+            if (elementsPerSubquery[i] == 0) {
+                stdPerSubQuery[i] = 0;
+            } else {
+                stdPerSubQuery[i] = (float) Math.sqrt(deltaSumPerSubquery[i] / elementsPerSubquery[i]);
+            }
+        }
+
+        return stdPerSubQuery;
+    }
+
+    private static float sumScoreDocsArray(final ScoreDoc[] scoreDocs) {
+        float sum = 0;
+        for (ScoreDoc scoreDoc : scoreDocs) {
+            sum += scoreDoc.score;
+        }
+
+        return sum;
+    }
+
+    private static float normalizeSingleScore(final float score, final float standardDeviation, final float mean) {
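+        // standard z-score: z = (score - mean) / standardDeviation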
+        // edge case when the score equals the mean, e.g. when there is only a single score; the standard deviation is 0 in that case
+        if (Floats.compare(mean, score) == 0) {
+            return SINGLE_RESULT_SCORE;
+        }
+        return (score - mean) / standardDeviation;
+    }
+}
diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
index 9c24e81fd..e6265724f 100644
--- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java
@@ -760,4 +760,19 @@ private String registerModelGroup() {
         assertNotNull(modelGroupId);
         return modelGroupId;
     }
+
+    protected List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
+        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
+        return (List<Map<String, Object>>) hitsMap.get("hits");
+    }
+
+    protected Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
+        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
+        return (Map<String, Object>) hitsMap.get("total");
+    }
+
+    protected Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
+        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
+        return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
+    }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/HybridQueryZScoreIT.java b/src/test/java/org/opensearch/neuralsearch/processor/HybridQueryZScoreIT.java
new file mode 100644
index 000000000..23db97fe2
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/processor/HybridQueryZScoreIT.java
@@ -0,0 +1,205 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor;
+
+import static org.opensearch.neuralsearch.TestUtils.DELTA_FOR_SCORE_ASSERTION;
+import static org.opensearch.neuralsearch.TestUtils.createRandomVector;
+
+import java.io.IOException;
+import java.util.*;
+import java.util.stream.IntStream;
+
+import lombok.SneakyThrows;
+
+import org.junit.After;
+import org.junit.Before;
+import org.opensearch.index.query.BoolQueryBuilder;
+import org.opensearch.index.query.QueryBuilders;
+import org.opensearch.index.query.TermQueryBuilder;
+import org.opensearch.knn.index.SpaceType;
+import org.opensearch.neuralsearch.common.BaseNeuralSearchIT;
+import org.opensearch.neuralsearch.processor.normalization.ZScoreNormalizationTechnique;
+import org.opensearch.neuralsearch.query.HybridQueryBuilder;
+import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
+
+import com.google.common.primitives.Floats;
+
+public class HybridQueryZScoreIT extends BaseNeuralSearchIT {
+    private static final String TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME = "test-neural-vector-doc-field-index";
+    private static final String TEST_QUERY_TEXT = "greetings";
+    private static final String TEST_QUERY_TEXT4 = "place";
+    private static final String TEST_QUERY_TEXT5 = "welcome";
+    private static final String TEST_DOC_TEXT1 = "Hello world";
+    private static final String TEST_DOC_TEXT2 = "Hi to this place";
+    private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1";
+    private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2";
+    private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field-1";
+
+    private static final int TEST_DIMENSION = 768;
+    private static final SpaceType TEST_SPACE_TYPE = SpaceType.L2;
+    private final float[] testVector1 = createRandomVector(TEST_DIMENSION);
+    private final float[] testVector2 = createRandomVector(TEST_DIMENSION);
+    private static final String RELATION_EQUAL_TO = "eq";
+    private static final String SEARCH_PIPELINE = "phase-results-pipeline";
+
+    @Before
+    public void setUp() throws Exception {
+        super.setUp();
+        updateClusterSettings();
+        prepareModel();
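+        // register a search pipeline that normalizes with z_score and combines with the default combination technique,
+        // giving both sub queries equal weight (weights [0.5, 0.5])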
+        createSearchPipeline(
+            SEARCH_PIPELINE,
+            ZScoreNormalizationTechnique.TECHNIQUE_NAME,
+            DEFAULT_COMBINATION_METHOD,
+            Map.of(PARAM_NAME_WEIGHTS, "[0.5,0.5]")
+        );
+    }
+
+    @After
+    @SneakyThrows
+    public void tearDown() {
+        super.tearDown();
+        deleteSearchPipeline(SEARCH_PIPELINE);
+        /* This is required to minimize the chance of a model failing to deploy due to an open memory circuit breaker,
+         * which happens if we leave a model behind from a previous test case. We use a new model for every test, so the old model
+         * can be undeployed and deleted to free resources after each test case execution.
+         */
+        findDeployedModels().forEach(this::deleteModel);
+    }
+
+    @Override
+    public boolean isUpdateClusterSettings() {
+        return false;
+    }
+
+    @Override
+    protected boolean preserveClusterUponCompletion() {
+        return true;
+    }
+
+    /**
+     * Tests a complex query with multiple nested sub-queries:
+     * {
+     *     "query": {
+     *         "hybrid": {
+     *             "queries": [
+     *                 {
+     *                     "bool": {
+     *                         "should": [
+     *                             {
+     *                                 "term": {
+     *                                     "text": "word1"
+     *                                 }
+     *                             },
+     *                             {
+     *                                 "term": {
+     *                                     "text": "word2"
+     *                                 }
+     *                             }
+     *                         ]
+     *                     }
+     *                 },
+     *                 {
+     *                     "term": {
+     *                         "text": "word3"
+     *                     }
+     *                 }
+     *             ]
+     *         }
+     *     }
+     * }
+     */
+    @SneakyThrows
+    public void testComplexQuery_withZScoreNormalization() {
+        initializeIndexIfNotExist();
+
+        TermQueryBuilder termQueryBuilder2 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT4);
+        TermQueryBuilder termQueryBuilder3 = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT5);
+        BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
+        boolQueryBuilder.should(termQueryBuilder2).should(termQueryBuilder3);
+
+        String modelId = getDeployedModelId();
+        NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(
+            TEST_KNN_VECTOR_FIELD_NAME_1,
+            TEST_QUERY_TEXT,
+            modelId,
+            5,
+            null,
+            null
+        );
+
+        HybridQueryBuilder hybridQueryBuilderNeuralThenTerm = new HybridQueryBuilder();
+        hybridQueryBuilderNeuralThenTerm.add(neuralQueryBuilder);
+        hybridQueryBuilderNeuralThenTerm.add(boolQueryBuilder);
+
+        final Map<String, Object> searchResponseAsMap = search(
+            TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
+            hybridQueryBuilderNeuralThenTerm,
+            null,
+            5,
+            Map.of("search_pipeline", SEARCH_PIPELINE)
+        );
+
+        assertEquals(2, getHitCount(searchResponseAsMap));
+
+        List<Map<String, Object>> hits1NestedList = getNestedHits(searchResponseAsMap);
+        List<String> ids = new ArrayList<>();
+        List<Double> scores = new ArrayList<>();
+        for (Map<String, Object> oneHit : hits1NestedList) {
+            ids.add((String) oneHit.get("_id"));
+            scores.add((Double) oneHit.get("_score"));
+        }
+
+        assertEquals(2, scores.size());
+        // by design, when there are only two results, their z-score normalized scores will be 1 and -1 respectively;
+        // furthermore, the combination logic with weights should make doc1Score: (1 * w1 + 0.98 * w2) / (w1 + w2), doc2Score: -1 ~ 0
+        assertEquals(0.9999, scores.get(0).floatValue(), DELTA_FOR_SCORE_ASSERTION);
+        assertEquals(0, scores.get(1).floatValue(), DELTA_FOR_SCORE_ASSERTION);
+
+        // verify that scores are in desc order
+        assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1)));
+        // verify that all ids are unique
+        assertEquals(Set.copyOf(ids).size(), ids.size());
+
+        Map<String, Object> total = getTotalHits(searchResponseAsMap);
+        assertNotNull(total.get("value"));
+        assertEquals(2, total.get("value"));
+        assertNotNull(total.get("relation"));
+        assertEquals(RELATION_EQUAL_TO, total.get("relation"));
+    }
+
+    private void initializeIndexIfNotExist() throws IOException {
+        if (!indexExists(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME)) {
+            prepareKnnIndex(
+                TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
+                List.of(
+                    new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DIMENSION, TEST_SPACE_TYPE),
+                    new KNNFieldConfig(TEST_KNN_VECTOR_FIELD_NAME_2, TEST_DIMENSION, TEST_SPACE_TYPE)
+                ),
+                1
+            );
+
+            addKnnDoc(
+                TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
+                "1",
+                List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2),
+                List.of(Floats.asList(testVector1).toArray(), Floats.asList(testVector1).toArray()),
+                Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
+                Collections.singletonList(TEST_DOC_TEXT1)
+            );
+            addKnnDoc(
+                TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME,
+                "2",
+                List.of(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_KNN_VECTOR_FIELD_NAME_2),
+                List.of(Floats.asList(testVector2).toArray(), Floats.asList(testVector2).toArray()),
+                Collections.singletonList(TEST_TEXT_FIELD_NAME_1),
+                Collections.singletonList(TEST_DOC_TEXT2)
+            );
+            assertEquals(2, getDocCount(TEST_BASIC_VECTOR_DOC_FIELD_INDEX_NAME));
+        }
+    }
+}
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
index 3cd71e5a1..79db226e1 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java
@@ -12,7 +12,6 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.stream.IntStream;
 
@@ -341,21 +340,6 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
         }
     }
 
-    private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return (List<Map<String, Object>>) hitsMap.get("hits");
-    }
-
-    private Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return (Map<String, Object>) hitsMap.get("total");
-    }
-
-    private Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
-    }
-
     private void assertQueryResults(Map<String, Object> searchResponseAsMap, int totalExpectedDocQty, boolean assertMinScore) {
         assertQueryResults(searchResponseAsMap, totalExpectedDocQty, assertMinScore, Range.between(0.5f, 1.0f));
     }
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechniqueTests.java
new file mode 100644
index 000000000..45e350dbb
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechniqueTests.java
@@ -0,0 +1,174 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor.normalization;
+
+import java.util.List;
+
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.opensearch.neuralsearch.processor.CompoundTopDocs;
+import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
+
+public class ZScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase {
+    private static final float DELTA_FOR_ASSERTION = 0.0001f;
+
+    /**
+     * The z-score measures the relative distance from the center of the distribution and hence can also be negative.
+     * When only two values are available, their z-scores will be 1 and -1 respectively.
+     * For more information regarding z-scores see, for example, https://www.z-table.com/
+     */
+    public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() {
+        ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
+        List<CompoundTopDocs> compoundTopDocs = List.of(
+            new CompoundTopDocs(
+                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
+                    )
+                )
+            )
+        );
+        normalizationTechnique.normalize(compoundTopDocs);
+
+        // since we only have two scores of 0.5 and 0.2, their z-scores will be 1 and -1
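+        // (for any two distinct scores a and b: mean = (a + b) / 2 and population std = |a - b| / 2, so the z-scores are always +1 and -1)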
+        CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs(
+            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) })
+            )
+        );
+        assertNotNull(compoundTopDocs);
+        assertEquals(1, compoundTopDocs.size());
+        assertNotNull(compoundTopDocs.get(0).getTopDocs());
+        assertCompoundTopDocs(
+            new TopDocs(expectedCompoundDocs.getTotalHits(), expectedCompoundDocs.getScoreDocs().toArray(new ScoreDoc[0])),
+            compoundTopDocs.get(0).getTopDocs().get(0)
+        );
+    }
+
+    public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() {
+        ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
+        List<CompoundTopDocs> compoundTopDocs = List.of(
+            new CompoundTopDocs(
+                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
+                    ),
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) }
+                    )
+                )
+            )
+        );
+        normalizationTechnique.normalize(compoundTopDocs);
+
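+        // third sub query scores {0.9, 0.7, 0.1}: mean is ~0.5667 and population std is ~0.3399, hence z-scores of ~{0.98058, 0.39223, -1.37281}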
+        CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs(
+            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
+                ),
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(
+                    new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f) }
+                )
+            )
+        );
+        assertNotNull(compoundTopDocs);
+        assertEquals(1, compoundTopDocs.size());
+        assertNotNull(compoundTopDocs.get(0).getTopDocs());
+        for (int i = 0; i < expectedCompoundDocs.getTopDocs().size(); i++) {
+            assertCompoundTopDocs(expectedCompoundDocs.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i));
+        }
+    }
+
+    public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() {
+        ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
+        List<CompoundTopDocs> compoundTopDocs = List.of(
+            new CompoundTopDocs(
+                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
+                    ),
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) }
+                    )
+                )
+            ),
+            new CompoundTopDocs(
+                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(7, 2.9f), new ScoreDoc(9, 0.7f) }
+                    )
+                )
+            )
+        );
+        normalizationTechnique.normalize(compoundTopDocs);
+
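+        // mean and std are accumulated across shards per sub query position; the second sub query has hits only on shard 2,
+        // so its scores {2.9, 0.7} normalize to +1 and -1 exactly as in the single-shard two-score case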
+        CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs(
+            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
+                ),
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(
+                    new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f) }
+                )
+            )
+        );
+
+        CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs(
+            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(new TotalHits(2, TotalHits.Relation.EQUAL_TO), new ScoreDoc[] { new ScoreDoc(7, 1.0f), new ScoreDoc(9, -1.0f) })
+            )
+        );
+
+        assertNotNull(compoundTopDocs);
+        assertEquals(2, compoundTopDocs.size());
+        assertNotNull(compoundTopDocs.get(0).getTopDocs());
+        for (int i = 0; i < expectedCompoundDocsShard1.getTopDocs().size(); i++) {
+            assertCompoundTopDocs(expectedCompoundDocsShard1.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i));
+        }
+        assertNotNull(compoundTopDocs.get(1).getTopDocs());
+        for (int i = 0; i < expectedCompoundDocsShard2.getTopDocs().size(); i++) {
+            assertCompoundTopDocs(expectedCompoundDocsShard2.getTopDocs().get(i), compoundTopDocs.get(1).getTopDocs().get(i));
+        }
+    }
+
+    private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) {
+        assertEquals(expected.totalHits.value, actual.totalHits.value);
+        assertEquals(expected.totalHits.relation, actual.totalHits.relation);
+        assertEquals(expected.scoreDocs.length, actual.scoreDocs.length);
+        for (int i = 0; i < expected.scoreDocs.length; i++) {
+            assertEquals(expected.scoreDocs[i].score, actual.scoreDocs[i].score, DELTA_FOR_ASSERTION);
+            assertEquals(expected.scoreDocs[i].doc, actual.scoreDocs[i].doc);
+            assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex);
+        }
+    }
+}
diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
index eec6955ff..229374730 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java
@@ -13,7 +13,6 @@
 import java.util.Collections;
 import java.util.List;
 import java.util.Map;
-import java.util.Optional;
 import java.util.Set;
 import java.util.stream.IntStream;
 
@@ -267,19 +266,4 @@ private void initializeIndexIfNotExist(String indexName) throws IOException {
             assertEquals(3, getDocCount(TEST_MULTI_DOC_INDEX_NAME));
         }
     }
-
-    private List<Map<String, Object>> getNestedHits(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return (List<Map<String, Object>>) hitsMap.get("hits");
-    }
-
-    private Map<String, Object> getTotalHits(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return (Map<String, Object>) hitsMap.get("total");
-    }
-
-    private Optional<Float> getMaxScore(Map<String, Object> searchResponseAsMap) {
-        Map<String, Object> hitsMap = (Map<String, Object>) searchResponseAsMap.get("hits");
-        return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue());
-    }
 }
diff --git a/src/test/resources/log4j2-test.xml b/src/test/resources/log4j2-test.xml
new file mode 100644
index 000000000..32c8f6bc7
--- /dev/null
+++ b/src/test/resources/log4j2-test.xml
@@ -0,0 +1,13 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<Configuration>
+    <Appenders>
+        <Console name="Console" target="SYSTEM_OUT">
+            <PatternLayout pattern="[%d{ISO8601}][%-5p][%c{1}] %m%n"/>
+        </Console>
+    </Appenders>
+    <Loggers>
+        <Root level="info">
+            <AppenderRef ref="Console"/>
+        </Root>
+    </Loggers>
+</Configuration>
\ No newline at end of file