Skip to content

Commit

Permalink
add z-score and logging for tests
Browse files Browse the repository at this point in the history
Signed-off-by: Samuel Herman <[email protected]>
  • Loading branch information
sam-herman committed Oct 17, 2023
1 parent 0f73cc6 commit e6d2130
Show file tree
Hide file tree
Showing 5 changed files with 360 additions and 0 deletions.
6 changes: 6 additions & 0 deletions build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ integTest {
if (System.getProperty("test.debug") != null) {
jvmArgs '-agentlib:jdwp=transport=dt_socket,server=y,suspend=y,address=*:5005'
}

systemProperty 'log4j2.configurationFile', "${projectDir}/src/test/resources/log4j2-test.xml"

// Set this to true if you want to see the logs in the terminal test output.
// Note: if set to false, the log output will still show in your IDE.
testLogging.showStandardStreams = true
}

testClusters.integTest {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ public void execute(
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique
) {
log.info("Entering normalization processor workflow");
// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.normalization;

import com.google.common.primitives.Floats;
import lombok.ToString;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/*
TODO: open items. Per the original author these also apply to the normalization techniques this class is modeled on,
{@link L2ScoreNormalizationTechnique} and {@link MinMaxScoreNormalizationTechnique}:
1. Random access to an abstract list object is bad practice both stylistically and from a performance perspective and should be
   removed (this class iterates the sub-query lists instead of indexing into them)
2. Identical sub queries and their distribution between shards is currently completely implicit based on ordering and should be
   explicit based on an identifier
3. The number of sub queries is derived from the data instead of coming from a more explicit indicator
*/
/**
 * Implementation of the z-score normalization technique for hybrid query.
 * <p>
 * Scores are grouped by the position of their sub-query inside each {@link CompoundTopDocs}. For every position the scores of
 * all entries of the input list are pooled, their mean and population standard deviation are computed, and every score is
 * replaced in place by {@code (score - mean) / standardDeviation}. Normalized scores can therefore be negative.
 * <p>
 * This is modeled on the existing normalization techniques {@link L2ScoreNormalizationTechnique} and
 * {@link MinMaxScoreNormalizationTechnique}.
 */
@ToString(onlyExplicitlyIncluded = true)
public class ZScoreNormalizationTechnique implements ScoreNormalizationTechnique {
    @ToString.Include
    public static final String TECHNIQUE_NAME = "z_score";
    // Returned instead of a z-score of exactly zero.
    // NOTE(review): presumably so that an average document is not treated as unscored downstream - confirm.
    private static final float MIN_SCORE = 0.001f;
    // Returned when the scores of a sub-query have no spread (single result, or all results share one score).
    private static final float SINGLE_RESULT_SCORE = 1.0f;

    /**
     * Replaces, in place, every score in {@code queryTopDocs} by its z-score within the corresponding sub-query.
     * <p>
     * Null entries of the list are skipped. A null or empty list, or a list without any sub-query results, is left untouched.
     *
     * @param queryTopDocs per-entry results, each holding one {@link TopDocs} per sub-query; sub-queries are matched across
     *                     entries by their position in {@link CompoundTopDocs#getTopDocs()}
     */
    @Override
    public void normalize(final List<CompoundTopDocs> queryTopDocs) {
        // NOTE(review): the original author questioned whether this list holds the sub-queries of a single shard or of all
        // shards, and at which point per-shard normalized results get combined - confirm against the caller.
        final int numOfSubQueries = findNumberOfSubQueries(queryTopDocs);
        if (numOfSubQueries == 0) {
            return;
        }

        final double[] sumPerSubQuery = new double[numOfSubQueries];
        final long[] elementsPerSubQuery = new long[numOfSubQueries];
        forEachSubQuery(queryTopDocs, (subQueryIndex, subQueryTopDocs) -> {
            sumPerSubQuery[subQueryIndex] += sumScoreDocsArray(subQueryTopDocs.scoreDocs);
            // count the scores that were actually summed; totalHits may report more hits than were returned
            elementsPerSubQuery[subQueryIndex] += subQueryTopDocs.scoreDocs.length;
        });
        final float[] meanPerSubQuery = findMeanPerSubQuery(sumPerSubQuery, elementsPerSubQuery);
        final float[] stdPerSubQuery = findStdPerSubQuery(queryTopDocs, meanPerSubQuery, elementsPerSubQuery);

        // do normalization using actual score and the statistics of the corresponding sub query
        forEachSubQuery(queryTopDocs, (subQueryIndex, subQueryTopDocs) -> {
            for (ScoreDoc scoreDoc : subQueryTopDocs.scoreDocs) {
                scoreDoc.score = normalizeSingleScore(scoreDoc.score, stdPerSubQuery[subQueryIndex], meanPerSubQuery[subQueryIndex]);
            }
        });
    }

    /** Callback receiving one sub-query result together with its position inside its {@link CompoundTopDocs}. */
    @FunctionalInterface
    private interface SubQueryVisitor {
        void visit(int subQueryIndex, TopDocs subQueryTopDocs);
    }

    /** Visits every sub-query result of every non-null entry, in list order. */
    private static void forEachSubQuery(final List<CompoundTopDocs> queryTopDocs, final SubQueryVisitor visitor) {
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            int subQueryIndex = 0;
            for (TopDocs subQueryTopDocs : compoundQueryTopDocs.getTopDocs()) {
                visitor.visit(subQueryIndex++, subQueryTopDocs);
            }
        }
    }

    /**
     * Returns the largest number of sub-query results held by any non-null entry, or 0 when there is none. Taking the maximum
     * keeps the per-sub-query arrays large enough for every entry, even if entries carry different numbers of sub-queries.
     */
    private static int findNumberOfSubQueries(final List<CompoundTopDocs> queryTopDocs) {
        if (Objects.isNull(queryTopDocs)) {
            return 0;
        }
        return queryTopDocs.stream()
            .filter(Objects::nonNull)
            .mapToInt(compoundQueryTopDocs -> compoundQueryTopDocs.getTopDocs().size())
            .max()
            .orElse(0);
    }

    /** Returns the mean score per sub-query; a sub-query without any score gets a mean of 0. */
    private static float[] findMeanPerSubQuery(final double[] sumPerSubQuery, final long[] elementsPerSubQuery) {
        final float[] meanPerSubQuery = new float[elementsPerSubQuery.length];
        for (int i = 0; i < elementsPerSubQuery.length; i++) {
            meanPerSubQuery[i] = elementsPerSubQuery[i] == 0 ? 0.0f : (float) (sumPerSubQuery[i] / elementsPerSubQuery[i]);
        }
        return meanPerSubQuery;
    }

    /** Returns the population standard deviation per sub-query; a sub-query without any score gets a deviation of 0. */
    private static float[] findStdPerSubQuery(
        final List<CompoundTopDocs> queryTopDocs,
        final float[] meanPerSubQuery,
        final long[] elementsPerSubQuery
    ) {
        final double[] squaredDeltaSumPerSubQuery = new double[meanPerSubQuery.length];
        forEachSubQuery(queryTopDocs, (subQueryIndex, subQueryTopDocs) -> {
            for (ScoreDoc scoreDoc : subQueryTopDocs.scoreDocs) {
                final double delta = (double) scoreDoc.score - meanPerSubQuery[subQueryIndex];
                squaredDeltaSumPerSubQuery[subQueryIndex] += delta * delta;
            }
        });

        final float[] stdPerSubQuery = new float[meanPerSubQuery.length];
        for (int i = 0; i < stdPerSubQuery.length; i++) {
            stdPerSubQuery[i] = elementsPerSubQuery[i] == 0
                ? 0.0f
                : (float) Math.sqrt(squaredDeltaSumPerSubQuery[i] / elementsPerSubQuery[i]);
        }
        return stdPerSubQuery;
    }

    /** Sums the scores of the given documents in double precision. */
    private static double sumScoreDocsArray(final ScoreDoc[] scoreDocs) {
        return Arrays.stream(scoreDocs).mapToDouble(scoreDoc -> scoreDoc.score).sum();
    }

    /**
     * Computes the z-score of a single score.
     *
     * @return {@code SINGLE_RESULT_SCORE} when the deviation is zero (the z-score is undefined), {@code MIN_SCORE} when the
     *         z-score is exactly zero, otherwise {@code (score - mean) / standardDeviation}
     */
    private static float normalizeSingleScore(final float score, final float standardDeviation, final float mean) {
        // edge case: a single score, or all scores identical, so there is no spread to divide by
        if (Floats.compare(standardDeviation, 0.0f) == 0) {
            return SINGLE_RESULT_SCORE;
        }
        final float normalizedScore = (score - mean) / standardDeviation;
        return normalizedScore == 0.0f ? MIN_SCORE : normalizedScore;
    }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.neuralsearch.processor.normalization;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;

import java.util.List;

public class ZScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase {
    private static final float DELTA_FOR_ASSERTION = 0.0001f;

    /** One entry with a single sub-query: the two scores land at +1 and -1. */
    public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() {
        final ZScoreNormalizationTechnique technique = new ZScoreNormalizationTechnique();
        final List<CompoundTopDocs> actualResults = List.of(
            new CompoundTopDocs(exactHits(2), List.of(subQueryHits(new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f))))
        );
        technique.normalize(actualResults);

        final CompoundTopDocs expectedResult = new CompoundTopDocs(
            exactHits(2),
            List.of(subQueryHits(new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f)))
        );

        assertNotNull(actualResults);
        assertEquals(1, actualResults.size());
        assertNotNull(actualResults.get(0).getTopDocs());
        final TopDocs expectedTopDocs = new TopDocs(
            expectedResult.getTotalHits(),
            expectedResult.getScoreDocs().toArray(new ScoreDoc[0])
        );
        assertCompoundTopDocs(expectedTopDocs, actualResults.get(0).getTopDocs().get(0));
    }

    /** One entry with three sub-queries, the middle one empty: each sub-query is normalized on its own. */
    public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() {
        final ZScoreNormalizationTechnique technique = new ZScoreNormalizationTechnique();
        final List<CompoundTopDocs> actualResults = List.of(
            new CompoundTopDocs(
                exactHits(3),
                List.of(
                    subQueryHits(new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f)),
                    subQueryHits(),
                    subQueryHits(new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f))
                )
            )
        );
        technique.normalize(actualResults);

        final CompoundTopDocs expectedResult = new CompoundTopDocs(
            exactHits(3),
            List.of(
                subQueryHits(new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f)),
                subQueryHits(),
                subQueryHits(new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f))
            )
        );

        assertNotNull(actualResults);
        assertEquals(1, actualResults.size());
        assertSubQueriesMatch(expectedResult, actualResults.get(0));
    }

    /** Two entries with different numbers of sub-queries: sub-queries are matched across entries by position. */
    public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() {
        final ZScoreNormalizationTechnique technique = new ZScoreNormalizationTechnique();
        final List<CompoundTopDocs> actualResults = List.of(
            new CompoundTopDocs(
                exactHits(3),
                List.of(
                    subQueryHits(new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f)),
                    subQueryHits(),
                    subQueryHits(new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f))
                )
            ),
            new CompoundTopDocs(
                exactHits(2),
                List.of(subQueryHits(), subQueryHits(new ScoreDoc(7, 2.9f), new ScoreDoc(9, 0.7f)))
            )
        );
        technique.normalize(actualResults);

        final CompoundTopDocs expectedFirstShard = new CompoundTopDocs(
            exactHits(3),
            List.of(
                subQueryHits(new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f)),
                subQueryHits(),
                subQueryHits(new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f))
            )
        );
        final CompoundTopDocs expectedSecondShard = new CompoundTopDocs(
            exactHits(2),
            List.of(subQueryHits(), subQueryHits(new ScoreDoc(7, 1.0f), new ScoreDoc(9, -1.0f)))
        );

        assertNotNull(actualResults);
        assertEquals(2, actualResults.size());
        assertSubQueriesMatch(expectedFirstShard, actualResults.get(0));
        assertSubQueriesMatch(expectedSecondShard, actualResults.get(1));
    }

    /** Builds an exact ({@code EQUAL_TO}) total-hits value. */
    private static TotalHits exactHits(final long count) {
        return new TotalHits(count, TotalHits.Relation.EQUAL_TO);
    }

    /** Builds a sub-query result whose exact total hit count equals the number of given documents. */
    private static TopDocs subQueryHits(final ScoreDoc... scoreDocs) {
        return new TopDocs(exactHits(scoreDocs.length), scoreDocs);
    }

    /** Checks that the actual entry has sub-query results and that each expected sub-query matches by position. */
    private void assertSubQueriesMatch(final CompoundTopDocs expectedEntry, final CompoundTopDocs actualEntry) {
        assertNotNull(actualEntry.getTopDocs());
        final List<TopDocs> expectedSubQueries = expectedEntry.getTopDocs();
        for (int position = 0; position < expectedSubQueries.size(); position++) {
            assertCompoundTopDocs(expectedSubQueries.get(position), actualEntry.getTopDocs().get(position));
        }
    }

    /** Compares hit totals, relation, and every document's score (within tolerance), doc id and shard index. */
    private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) {
        assertEquals(expected.totalHits.value, actual.totalHits.value);
        assertEquals(expected.totalHits.relation, actual.totalHits.relation);

        final ScoreDoc[] expectedDocs = expected.scoreDocs;
        final ScoreDoc[] actualDocs = actual.scoreDocs;
        assertEquals(expectedDocs.length, actualDocs.length);

        int position = 0;
        for (ScoreDoc expectedDoc : expectedDocs) {
            final ScoreDoc actualDoc = actualDocs[position++];
            assertEquals(expectedDoc.score, actualDoc.score, DELTA_FOR_ASSERTION);
            assertEquals(expectedDoc.doc, actualDoc.doc);
            assertEquals(expectedDoc.shardIndex, actualDoc.shardIndex);
        }
    }
}
13 changes: 13 additions & 0 deletions src/test/resources/log4j2-test.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
<?xml version="1.0" encoding="UTF-8"?>
<!--
Log4j2 configuration for test runs. build.gradle passes this file to the integTest task through the
log4j2.configurationFile system property.
-->
<!-- status="WARN": per the Log4j2 documentation this is the level for Log4j's own internal status messages, not for application logs. -->
<Configuration status="WARN">
<Appenders>
<!-- Single appender: everything is written to standard output. -->
<Console name="Console" target="SYSTEM_OUT">
<!-- Line layout: time of day with milliseconds, thread name, level, logger name, then the message and a newline. -->
<PatternLayout pattern="%d{HH:mm:ss.SSS} [%t] %-5level %logger{36} - %msg%n"/>
</Console>
</Appenders>
<Loggers>
<!-- The root logger is set to debug and sends its events to the Console appender above. -->
<Root level="debug">
<AppenderRef ref="Console"/>
</Root>
</Loggers>
</Configuration>

0 comments on commit e6d2130

Please sign in to comment.