From e6d2130599ad5f8ddff040ae6264a0e3b16e1946 Mon Sep 17 00:00:00 2001
From: Samuel Herman
Date: Sat, 7 Oct 2023 09:49:29 -0700
Subject: [PATCH] add z-score and logging for tests

Signed-off-by: Samuel Herman
---
 build.gradle                                |   6 +
 .../NormalizationProcessorWorkflow.java     |   1 +
 .../ZScoreNormalizationTechnique.java       | 168 +++++++++++++++++
 .../ZScoreNormalizationTechniqueTests.java  | 172 ++++++++++++++++++
 src/test/resources/log4j2-test.xml          |  13 ++
 5 files changed, 360 insertions(+)
 create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java
 create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechniqueTests.java
 create mode 100644 src/test/resources/log4j2-test.xml

diff --git a/build.gradle b/build.gradle
index 853aa85e7..90731feb5 100644
--- a/build.gradle
+++ b/build.gradle
@@ -218,6 +218,12 @@ integTest {
     if (System.getProperty("test.debug") != null) {
         jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5005'
     }
+
+    systemProperty 'log4j2.configurationFile', "${projectDir}/src/test/resources/log4j2-test.xml"
+
+    // Set this to true if you want to see the logs in the terminal test output.
+    // Note: if left false, the log output will still show in your IDE.
+    testLogging.showStandardStreams = true
 }
 
 testClusters.integTest {

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
index 9e0069b21..d5e898185 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java
@@ -52,6 +52,7 @@ public void execute(
         final ScoreNormalizationTechnique normalizationTechnique,
         final ScoreCombinationTechnique combinationTechnique
     ) {
+        log.info("Entering normalization processor workflow");
         // save original state
         List<String> unprocessedDocIds = unprocessedDocIds(querySearchResults);

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java
new file mode 100644
index 000000000..6d6fadf2b
--- /dev/null
+++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechnique.java
@@ -0,0 +1,168 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor.normalization;
+
+import com.google.common.primitives.Floats;
+import lombok.ToString;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.opensearch.neuralsearch.processor.CompoundTopDocs;
+
+import java.util.Arrays;
+import java.util.List;
+import java.util.Objects;
+
+/**
+ * Implementation of the z-score normalization technique for hybrid query.
+ * It is currently modeled on the existing normalization techniques {@link L2ScoreNormalizationTechnique} and {@link MinMaxScoreNormalizationTechnique}.
+ * However, this class, like the originals, requires significant work to improve style and ease of use; see the TODO items below.
+ */
+/*
+TODO: items that apply here but also to the original normalization techniques on which this is modeled ({@link L2ScoreNormalizationTechnique} and {@link MinMaxScoreNormalizationTechnique}):
+1. Random access into an abstract List is bad practice, both stylistically and from a performance perspective, and should be removed.
+2. The mapping of identical sub queries across shards is currently completely implicit, based on ordering; it should be explicit, based on an identifier.
+3. The calculation of numOfSubQueries is convoluted; a more explicit indicator should be used instead.
+ */
+@ToString(onlyExplicitlyIncluded = true)
+public class ZScoreNormalizationTechnique implements ScoreNormalizationTechnique {
+    @ToString.Include
+    public static final String TECHNIQUE_NAME = "z_score";
+    private static final float MIN_SCORE = 0.001f;
+    private static final float SINGLE_RESULT_SCORE = 1.0f;
+
+    @Override
+    public void normalize(List<CompoundTopDocs> queryTopDocs) {
+        // TODO: clarify the shape of queryTopDocs. Is this the list of sub query results for a single shard,
+        // or a global list across all shards? If each entry were normalized only at the shard level,
+        // combining those shard-local normalized results afterwards would be questionable and needs review.
+        // Use the largest sub query count seen in any entry; per TODO item 3 above, an explicit
+        // indicator would be better than deriving the count this way.
+        int numOfSubQueries = queryTopDocs.stream()
+            .filter(Objects::nonNull)
+            .mapToInt(compoundTopDocs -> compoundTopDocs.getTopDocs().size())
+            .max()
+            .orElse(0);
+        if (numOfSubQueries == 0) {
+            // no results at all, nothing to normalize
+            return;
+        }
+
+        // computed once per sub query
+        float[] sumPerSubquery = findScoreSumPerSubQuery(queryTopDocs, numOfSubQueries);
+        long[] elementsPerSubquery = findNumberOfElementsPerSubQuery(queryTopDocs, numOfSubQueries);
+        float[] meanPerSubQuery = findMeanPerSubquery(sumPerSubquery, elementsPerSubquery);
+        float[] stdPerSubquery = findStdPerSubquery(queryTopDocs, meanPerSubQuery, elementsPerSubquery, numOfSubQueries);
+
+        // normalize each actual score using the z-score for the corresponding sub query
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
+            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
+                TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
+                for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
+                    scoreDoc.score = normalizeSingleScore(scoreDoc.score, stdPerSubquery[j], meanPerSubQuery[j]);
+                }
+            }
+        }
+    }
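+
+    /*
+     * The helpers below compute, per sub query and across all given CompoundTopDocs, the sum of scores,
+     * the number of scores, their mean, and the population standard deviation. normalizeSingleScore then
+     * replaces each score with its z-score: z = (score - mean) / std. For example, a sub query with
+     * scores 0.5 and 0.2 has mean 0.35 and std 0.15, so those scores normalize to 1.0 and -1.0,
+     * matching the expectations in the unit tests.
+     */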
+    private static float[] findScoreSumPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
+        final float[] sumOfScorePerSubQuery = new float[numOfScores];
+        Arrays.fill(sumOfScorePerSubQuery, 0);
+        // TODO: improve this; topDocsPerSubQuery.get(j) does random access on an abstract List type,
+        // which is poor style and potentially slow.
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
+            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
+                sumOfScorePerSubQuery[j] += sumScoreDocsArray(topDocsPerSubQuery.get(j).scoreDocs);
+            }
+        }
+
+        return sumOfScorePerSubQuery;
+    }
+
+    private static long[] findNumberOfElementsPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
+        final long[] numberOfElementsPerSubQuery = new long[numOfScores];
+        Arrays.fill(numberOfElementsPerSubQuery, 0);
+        // TODO: improve this; same random-access concern as above.
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
+            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
+                // count the scores that were actually retrieved and will be normalized;
+                // totalHits.value may exceed the returned scoreDocs and would skew the mean and std
+                numberOfElementsPerSubQuery[j] += topDocsPerSubQuery.get(j).scoreDocs.length;
+            }
+        }
+
+        return numberOfElementsPerSubQuery;
+    }
+
+    private static float[] findMeanPerSubquery(final float[] sumPerSubquery, final long[] elementsPerSubquery) {
+        final float[] meanPerSubQuery = new float[elementsPerSubquery.length];
+        for (int i = 0; i < elementsPerSubquery.length; i++) {
+            if (elementsPerSubquery[i] == 0) {
+                meanPerSubQuery[i] = 0;
+            } else {
+                meanPerSubQuery[i] = sumPerSubquery[i] / elementsPerSubquery[i];
+            }
+        }
+
+        return meanPerSubQuery;
+    }
+
+    private static float[] findStdPerSubquery(
+        final List<CompoundTopDocs> queryTopDocs,
+        final float[] meanPerSubQuery,
+        final long[] elementsPerSubquery,
+        final int numOfScores
+    ) {
+        final double[] deltaSumPerSubquery = new double[numOfScores];
+        Arrays.fill(deltaSumPerSubquery, 0);
+        // TODO: improve this; same random-access concern as above.
+        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
+            if (Objects.isNull(compoundQueryTopDocs)) {
+                continue;
+            }
+            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
+            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
+                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(j).scoreDocs) {
+                    deltaSumPerSubquery[j] += Math.pow(scoreDoc.score - meanPerSubQuery[j], 2);
+                }
+            }
+        }
+
+        final float[] stdPerSubQuery = new float[numOfScores];
+        for (int i = 0; i < deltaSumPerSubquery.length; i++) {
+            if (elementsPerSubquery[i] == 0) {
+                stdPerSubQuery[i] = 0;
+            } else {
+                stdPerSubQuery[i] = (float) Math.sqrt(deltaSumPerSubquery[i] / elementsPerSubquery[i]);
+            }
+        }
+
+        return stdPerSubQuery;
+    }
+
+    private static float sumScoreDocsArray(ScoreDoc[] scoreDocs) {
+        float sum = 0;
+        for (ScoreDoc scoreDoc : scoreDocs) {
+            sum += scoreDoc.score;
+        }
+
+        return sum;
+    }
+
+    private static float normalizeSingleScore(final float score, final float standardDeviation, final float mean) {
+        // edge case: when a sub query has a single result, or all its scores are identical, the score
+        // equals the mean (and the standard deviation is zero), so return a fixed score instead of
+        // dividing by zero
+        if (Floats.compare(mean, score) == 0) {
+            return SINGLE_RESULT_SCORE;
+        }
+        float normalizedScore = (score - mean) / standardDeviation;
+        // avoid an exact 0.0 score; fall back to a small positive constant
+        return normalizedScore == 0.0f ? MIN_SCORE : normalizedScore;
+    }
+}
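
For reference, the expected scores in the tests that follow can be reproduced with the same population mean and standard deviation arithmetic used above. A minimal standalone sketch (the class name is illustrative, not part of the patch), shown here for the {0.9, 0.7, 0.1} sub query from the multi-sub-query tests:

public class ZScoreCheck {
    public static void main(String[] args) {
        float[] scores = { 0.9f, 0.7f, 0.1f };
        float sum = 0;
        for (float score : scores) {
            sum += score;
        }
        float mean = sum / scores.length; // ~0.56667
        double deltaSum = 0;
        for (float score : scores) {
            deltaSum += Math.pow(score - mean, 2);
        }
        // population (not sample) standard deviation, matching findStdPerSubquery
        float std = (float) Math.sqrt(deltaSum / scores.length); // ~0.33993
        for (float score : scores) {
            // prints approximately 0.98058, 0.39223, -1.37281, the values asserted in the tests
            System.out.println((score - mean) / std);
        }
    }
}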
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechniqueTests.java b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechniqueTests.java
new file mode 100644
index 000000000..1d0c61373
--- /dev/null
+++ b/src/test/java/org/opensearch/neuralsearch/processor/normalization/ZScoreNormalizationTechniqueTests.java
@@ -0,0 +1,172 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.neuralsearch.processor.normalization;
+
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.TotalHits;
+import org.opensearch.neuralsearch.processor.CompoundTopDocs;
+import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;
+
+import java.util.List;
+
+public class ZScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase {
+    private static final float DELTA_FOR_ASSERTION = 0.0001f;
+
+    public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() {
+        ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
+        List<CompoundTopDocs> compoundTopDocs = List.of(
+            new CompoundTopDocs(
+                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
+                    )
+                )
+            )
+        );
+        normalizationTechnique.normalize(compoundTopDocs);
+
+        CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs(
+            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
+                )
+            )
+        );
+        assertNotNull(compoundTopDocs);
+        assertEquals(1, compoundTopDocs.size());
+        assertNotNull(compoundTopDocs.get(0).getTopDocs());
+        assertCompoundTopDocs(
+            new TopDocs(expectedCompoundDocs.getTotalHits(), expectedCompoundDocs.getScoreDocs().toArray(new ScoreDoc[0])),
+            compoundTopDocs.get(0).getTopDocs().get(0)
+        );
+    }
+
+    public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() {
+        ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
+        List<CompoundTopDocs> compoundTopDocs = List.of(
+            new CompoundTopDocs(
+                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
+                    ),
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) }
+                    )
+                )
+            )
+        );
+        normalizationTechnique.normalize(compoundTopDocs);
+
+        CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs(
+            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
+                ),
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(
+                    new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f) }
+                )
+            )
+        );
+        assertNotNull(compoundTopDocs);
+        assertEquals(1, compoundTopDocs.size());
+        assertNotNull(compoundTopDocs.get(0).getTopDocs());
+        for (int i = 0; i < expectedCompoundDocs.getTopDocs().size(); i++) {
+            assertCompoundTopDocs(expectedCompoundDocs.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i));
+        }
+    }
+
+    public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() {
+        ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
+        List<CompoundTopDocs> compoundTopDocs = List.of(
+            new CompoundTopDocs(
+                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
+                    ),
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) }
+                    )
+                )
+            ),
+            new CompoundTopDocs(
+                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                List.of(
+                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                    new TopDocs(
+                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                        new ScoreDoc[] { new ScoreDoc(7, 2.9f), new ScoreDoc(9, 0.7f) }
+                    )
+                )
+            )
+        );
+        normalizationTechnique.normalize(compoundTopDocs);
+
+        CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs(
+            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
+                ),
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(
+                    new TotalHits(3, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f) }
+                )
+            )
+        );
+
+        CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs(
+            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+            List.of(
+                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
+                new TopDocs(
+                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
+                    new ScoreDoc[] { new ScoreDoc(7, 1.0f), new ScoreDoc(9, -1.0f) }
+                )
+            )
+        );
+
+        assertNotNull(compoundTopDocs);
+        assertEquals(2, compoundTopDocs.size());
+        assertNotNull(compoundTopDocs.get(0).getTopDocs());
+        for (int i = 0; i < expectedCompoundDocsShard1.getTopDocs().size(); i++) {
+            assertCompoundTopDocs(expectedCompoundDocsShard1.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i));
+        }
+        assertNotNull(compoundTopDocs.get(1).getTopDocs());
+        for (int i = 0; i < expectedCompoundDocsShard2.getTopDocs().size(); i++) {
+            assertCompoundTopDocs(expectedCompoundDocsShard2.getTopDocs().get(i), compoundTopDocs.get(1).getTopDocs().get(i));
+        }
+    }
+
+    private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) {
+        assertEquals(expected.totalHits.value, actual.totalHits.value);
+        assertEquals(expected.totalHits.relation, actual.totalHits.relation);
+        assertEquals(expected.scoreDocs.length, actual.scoreDocs.length);
+        for (int i = 0; i < expected.scoreDocs.length; i++) {
+            assertEquals(expected.scoreDocs[i].score, actual.scoreDocs[i].score, DELTA_FOR_ASSERTION);
+            assertEquals(expected.scoreDocs[i].doc, actual.scoreDocs[i].doc);
+            assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex);
+        }
+    }
+}
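
Note: the build.gradle change above wires the log4j2 config and showStandardStreams into the integTest task, while the unit tests in this file run under the regular test task. Assuming the stock Gradle test filter, a single-class run would look like `./gradlew test --tests "org.opensearch.neuralsearch.processor.normalization.ZScoreNormalizationTechniqueTests"`.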
diff --git a/src/test/resources/log4j2-test.xml b/src/test/resources/log4j2-test.xml
new file mode 100644
index 000000000..32c8f6bc7
--- /dev/null
+++ b/src/test/resources/log4j2-test.xml
@@ -0,0 +1,13 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<Configuration status="WARN">
+    <Appenders>
+        <Console name="Console" target="SYSTEM_OUT">
+            <PatternLayout pattern="%d{HH:mm:ss.SSS} [%t] %-5level %logger{36} - %msg%n"/>
+        </Console>
+    </Appenders>
+    <Loggers>
+        <Root level="info">
+            <AppenderRef ref="Console"/>
+        </Root>
+    </Loggers>
+</Configuration>
\ No newline at end of file
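
For completeness, here is a hedged sketch of how the technique could be exercised end to end once it is registered with the existing normalization-processor: a search pipeline selects it by the TECHNIQUE_NAME this patch defines ("z_score"). The sketch uses the OpenSearch low-level REST client; the pipeline name and the arithmetic_mean combination technique are illustrative assumptions, not part of this patch.

import org.apache.http.HttpHost;
import org.opensearch.client.Request;
import org.opensearch.client.Response;
import org.opensearch.client.RestClient;

public class CreateZScorePipeline {
    public static void main(String[] args) throws Exception {
        // assumes a local cluster with the neural-search plugin installed
        try (RestClient client = RestClient.builder(new HttpHost("localhost", 9200, "http")).build()) {
            // "z-score-pipeline" is a hypothetical name; "z_score" matches TECHNIQUE_NAME above
            Request request = new Request("PUT", "/_search/pipeline/z-score-pipeline");
            request.setJsonEntity(
                "{"
                    + "\"phase_results_processors\": [{"
                    + "  \"normalization-processor\": {"
                    + "    \"normalization\": { \"technique\": \"z_score\" },"
                    + "    \"combination\": { \"technique\": \"arithmetic_mean\" }"
                    + "  }"
                    + "}]"
                    + "}"
            );
            Response response = client.performRequest(request);
            System.out.println(response.getStatusLine());
        }
    }
}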