[FEATURE] Add z-score for the normalization processor #376 #468
ScoreCombiner.java
@@ -26,8 +26,6 @@
@Log4j2
public class ScoreCombiner {

-    private static final Float ZERO_SCORE = 0.0f;
Review comment: any reason why we are removing this?
Reply: It's not in use anywhere in the code.
    /**
     * Performs score combination based on input combination technique. Mutates input object by updating combined scores
     * Main steps we're doing for combination:

ZScoreNormalizationTechnique.java (new file)
@@ -0,0 +1,166 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.neuralsearch.processor.normalization;

import com.google.common.primitives.Floats;
import lombok.ToString;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

import java.util.Arrays;
import java.util.List;
import java.util.Objects;

/**
 * Implementation of the z-score normalization technique for hybrid query.
 * This is currently modeled on the existing normalization techniques {@link L2ScoreNormalizationTechnique} and {@link MinMaxScoreNormalizationTechnique}.
 * However, this class as well as the original ones require significant work to improve style and ease of use, see TODO items below.
 */
/*
TODO: Some todo items that apply here but also to the original normalization techniques on which this is modeled, {@link L2ScoreNormalizationTechnique} and {@link MinMaxScoreNormalizationTechnique}
1. Random access to an abstract list object is a bad practice, both stylistically and from a performance perspective, and should be removed
Review comment: Can you please provide an alternative for what should be used? As per my understanding, random access on a List is bad if the List's concrete implementation is LinkedList. But what I have seen generally is that we use ArrayList, which is backed by an array, hence random access is done in constant time.
Reply: It should be fine if we know the exact implementation of List, as Navneet mentioned. But with a List we can use a functional style more easily, without an expensive array -> stream conversion; that was the reason why we switched to a List.
Reply: Usually it is highly discouraged to do [comment truncated]
Reply: I'm ok to switch from using general [comment truncated]
Reply: @martin-gaievski same here, I added the comment with the intention of proposing it as a separate refactoring PR.
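A minimal sketch of one alternative discussed here: keep a single explicit counter while iterating, so the per-sub-query arrays are filled without positional get(j) calls on the List. Names are reused from this PR; this is purely illustrative, not part of the change:

```java
// Sketch: accumulate per-sub-query sums without indexed access into the List
int j = 0;
for (TopDocs topDocs : topDocsPerSubQuery) {
    sumOfScorePerSubQuery[j++] += sumScoreDocsArray(topDocs.scoreDocs);
}
```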
2. Identical sub queries and their distribution between shards are currently completely implicit, based on ordering, and should be explicit, based on an identifier
Review comment: This is really a good thought, but the problem is that none of the query clauses in OpenSearch supports identifiers. This was discussed during the implementation. The problem is the way the results are returned after the QueryPhase: they come back in a ScoreDocs array, which doesn't support identifiers. We could work around that, but it would require changes to the interface of OpenSearch core, hence we decided against it to make sure that we stay compatible with OpenSearch core. If there is an alternative supported in OpenSearch, please let us know; maybe we are missing something.
Reply: sounds good @navneet1v, I will give it some thought and will come up with a suggestion. In any case, I'm not planning to do it as part of this change. We can keep it for now, and I can suggest a refactor or just remove it if it's not achievable.
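Purely as an illustration of what an explicit identifier could look like (a hypothetical type, not part of this PR and not supported by OpenSearch core today):

```java
// Hypothetical carrier pairing a sub-query with a stable id, so per-shard results
// could be matched by identifier instead of by position in the list.
final class SubQueryTopDocs {
    private final String subQueryId;
    private final TopDocs topDocs;

    SubQueryTopDocs(final String subQueryId, final TopDocs topDocs) {
        this.subQueryId = subQueryId;
        this.topDocs = topDocs;
    }

    String getSubQueryId() {
        return subQueryId;
    }

    TopDocs getTopDocs() {
        return topDocs;
    }
}
```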
3. Weird calculation of numOfSubQueries instead of having a more explicit indicator
Review comment: same as above.
*/
@ToString(onlyExplicitlyIncluded = true)
public class ZScoreNormalizationTechnique implements ScoreNormalizationTechnique {
    @ToString.Include
    public static final String TECHNIQUE_NAME = "z_score";
    private static final float SINGLE_RESULT_SCORE = 1.0f;

    @Override
    public void normalize(List<CompoundTopDocs> queryTopDocs) {
Review comment: please make all args of all public methods [comment truncated]
        // why are we doing that? is List<CompoundTopDocs> the list of subqueries for a single shard? or a global list of all subqueries across shards?
        // If a subquery comes from each shard then when is it combined? that seems weird that combination will do combination of normalized results that each is normalized just based on shard level result
Review comment: Let's talk about these on the GitHub issue and not on the PR.
Reply: ack, I think this comment is no longer relevant; I put it there and forgot to remove it, so feel free to ignore this one.
        int numOfSubQueries = queryTopDocs.stream()
            .filter(Objects::nonNull)
            .filter(topDocs -> topDocs.getTopDocs().size() > 0)
            .findAny()
            .get()
            .getTopDocs()
            .size();
Review comment: Can you please add more checks for nulls and empty objects? I think we're assuming a lot, e.g. [comment truncated]
Reply: yeah agreed, I will add checks; currently it's modeled on the existing normalization techniques (MinMax/L2), which use similar code.
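A minimal sketch of what the extra guards discussed here might look like, assuming an input with no non-empty results should be treated as a no-op (types and accessors follow this PR; the exact policy is an assumption, not part of the change):

```java
// Hypothetical defensive variant of the sub-query count lookup: avoids Optional.get()
// on an empty stream and treats "no results at all" as nothing to normalize.
int numOfSubQueries = queryTopDocs.stream()
    .filter(Objects::nonNull)
    .map(CompoundTopDocs::getTopDocs)
    .filter(Objects::nonNull)
    .filter(topDocsList -> !topDocsList.isEmpty())
    .findAny()
    .map(List::size)
    .orElse(0);
if (numOfSubQueries == 0) {
    return; // nothing to normalize
}
```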

        // to be done for each subquery
        float[] sumPerSubquery = findScoreSumPerSubQuery(queryTopDocs, numOfSubQueries);
        long[] elementsPerSubquery = findNumberOfElementsPerSubQuery(queryTopDocs, numOfSubQueries);
        float[] meanPerSubQuery = findMeanPerSubquery(sumPerSubquery, elementsPerSubquery);
        float[] stdPerSubquery = findStdPerSubquery(queryTopDocs, meanPerSubQuery, elementsPerSubquery, numOfSubQueries);

        // do normalization using actual score and z-scores for corresponding sub query
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                TopDocs subQueryTopDoc = topDocsPerSubQuery.get(j);
                for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
                    scoreDoc.score = normalizeSingleScore(scoreDoc.score, stdPerSubquery[j], meanPerSubQuery[j]);
                }
            }
        }
    }

    static private float[] findScoreSumPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
Review comment: nit: private would be better unless you have a specific reason for this to be static. A better way would be moving all these methods to another class to make it easier to write unit tests.
Reply: The convention I was following is that if a method does not depend on any instance state, it should be static.
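One possible shape of the refactor suggested here, with the pure statistics functions pulled into their own helper so they can be unit tested directly. Class and method names are hypothetical, not part of this PR:

```java
// Hypothetical helper holding the per-sub-query statistics as pure functions.
final class ZScoreStatistics {
    private ZScoreStatistics() {}

    // population mean of the given scores; 0 for an empty input
    static float mean(final float[] scores) {
        if (scores.length == 0) {
            return 0f;
        }
        float sum = 0f;
        for (float score : scores) {
            sum += score;
        }
        return sum / scores.length;
    }

    // population standard deviation around the supplied mean; 0 for an empty input
    static float std(final float[] scores, final float mean) {
        if (scores.length == 0) {
            return 0f;
        }
        double deltaSum = 0d;
        for (float score : scores) {
            deltaSum += Math.pow(score - mean, 2);
        }
        return (float) Math.sqrt(deltaSum / scores.length);
    }
}
```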
        final float[] sumOfScorePerSubQuery = new float[numOfScores];
        Arrays.fill(sumOfScorePerSubQuery, 0);
        // TODO: make this better, currently
        // this is a horrible implementation in particular when it comes to the topDocsPerSubQuery.get(j)
        // which does a random search on an abstract list type.
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                sumOfScorePerSubQuery[j] += sumScoreDocsArray(topDocsPerSubQuery.get(j).scoreDocs);
            }
        }

        return sumOfScorePerSubQuery;
    }

    static private long[] findNumberOfElementsPerSubQuery(final List<CompoundTopDocs> queryTopDocs, final int numOfScores) {
        final long[] numberOfElementsPerSubQuery = new long[numOfScores];
        Arrays.fill(numberOfElementsPerSubQuery, 0);
        // TODO: make this better, currently
        // this is a horrible implementation in particular when it comes to the topDocsPerSubQuery.get(j)
        // which does a random search on an abstract list type.
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                numberOfElementsPerSubQuery[j] += topDocsPerSubQuery.get(j).totalHits.value;
            }
        }

        return numberOfElementsPerSubQuery;
    }

    static private float[] findMeanPerSubquery(final float[] sumPerSubquery, final long[] elementsPerSubquery) {
        final float[] meanPerSubQuery = new float[elementsPerSubquery.length];
        for (int i = 0; i < elementsPerSubquery.length; i++) {
            if (elementsPerSubquery[i] == 0) {
                meanPerSubQuery[i] = 0;
            } else {
                meanPerSubQuery[i] = sumPerSubquery[i] / elementsPerSubquery[i];
            }
        }

        return meanPerSubQuery;
    }

    static private float[] findStdPerSubquery(final List<CompoundTopDocs> queryTopDocs, final float[] meanPerSubQuery, final long[] elementsPerSubquery, final int numOfScores) {
        final double[] deltaSumPerSubquery = new double[numOfScores];
        Arrays.fill(deltaSumPerSubquery, 0);

        // TODO: make this better, currently
        // this is a horrible implementation in particular when it comes to the topDocsPerSubQuery.get(j)
        // which does a random search on an abstract list type.
Review comment: Please provide a reason why this is bad and how it can be improved.
Review comment: Also, let's avoid using subjective words like 'horrible'.
Reply: @heemin32 my apologies, will avoid it in the future.
        for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
            if (Objects.isNull(compoundQueryTopDocs)) {
                continue;
            }
            List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
            for (int j = 0; j < topDocsPerSubQuery.size(); j++) {
                for (ScoreDoc scoreDoc : topDocsPerSubQuery.get(j).scoreDocs) {
                    deltaSumPerSubquery[j] += Math.pow(scoreDoc.score - meanPerSubQuery[j], 2);
                }
            }
        }

        final float[] stdPerSubQuery = new float[numOfScores];
        for (int i = 0; i < deltaSumPerSubquery.length; i++) {
            if (elementsPerSubquery[i] == 0) {
                stdPerSubQuery[i] = 0;
            } else {
                stdPerSubQuery[i] = (float) Math.sqrt(deltaSumPerSubquery[i] / elementsPerSubquery[i]);
            }
        }

        return stdPerSubQuery;
    }

    static private float sumScoreDocsArray(ScoreDoc[] scoreDocs) {
(sam-herman marked this conversation as resolved.)
        float sum = 0;
        for (ScoreDoc scoreDoc : scoreDocs) {
            sum += scoreDoc.score;
        }

        return sum;
    }

    private static float normalizeSingleScore(final float score, final float standardDeviation, final float mean) {
        // edge case: when the score equals the mean (e.g. there is only one result and the standard deviation is zero), return a constant score
        if (Floats.compare(mean, score) == 0) {
            return SINGLE_RESULT_SCORE;
        }
        return (score - mean) / standardDeviation;
    }
}

ZScoreNormalizationTechniqueTests.java (new file)
@@ -0,0 +1,172 @@
/*
 * Copyright OpenSearch Contributors
 * SPDX-License-Identifier: Apache-2.0
 */

package org.opensearch.neuralsearch.processor.normalization;

import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
import org.opensearch.neuralsearch.query.OpenSearchQueryTestCase;

import java.util.List;

public class ZScoreNormalizationTechniqueTests extends OpenSearchQueryTestCase {
    private static final float DELTA_FOR_ASSERTION = 0.0001f;

    public void testNormalization_whenResultFromOneShardOneSubQuery_thenSuccessful() {
        ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
        List<CompoundTopDocs> compoundTopDocs = List.of(
            new CompoundTopDocs(
                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                List.of(
                    new TopDocs(
                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
                    )
                )
            )
        );
        normalizationTechnique.normalize(compoundTopDocs);

        CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs(
            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
            List.of(
                new TopDocs(
                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                    new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
                )
            )
        );
        assertNotNull(compoundTopDocs);
        assertEquals(1, compoundTopDocs.size());
        assertNotNull(compoundTopDocs.get(0).getTopDocs());
        assertCompoundTopDocs(
            new TopDocs(expectedCompoundDocs.getTotalHits(), expectedCompoundDocs.getScoreDocs().toArray(new ScoreDoc[0])),
            compoundTopDocs.get(0).getTopDocs().get(0)
        );
    }

    public void testNormalization_whenResultFromOneShardMultipleSubQueries_thenSuccessful() {
        ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
        List<CompoundTopDocs> compoundTopDocs = List.of(
            new CompoundTopDocs(
                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
                List.of(
                    new TopDocs(
                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
                    ),
                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
                    new TopDocs(
                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
                        new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) }
                    )
                )
            )
        );
        normalizationTechnique.normalize(compoundTopDocs);

        CompoundTopDocs expectedCompoundDocs = new CompoundTopDocs(
            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
            List.of(
                new TopDocs(
                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                    new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
                ),
                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
                new TopDocs(
                    new TotalHits(3, TotalHits.Relation.EQUAL_TO),
                    new ScoreDoc[] { new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f) }
Review comment: can you please add a function or description with a formula of how those expected scores are calculated? Also, why do we have a negative score value for one of the ScoreDocs?
Reply: @martin-gaievski will add documentation. Regarding negatives, z-scores can also be negative. (See the worked computation after this test method.)
                )
            )
        );
        assertNotNull(compoundTopDocs);
        assertEquals(1, compoundTopDocs.size());
        assertNotNull(compoundTopDocs.get(0).getTopDocs());
        for (int i = 0; i < expectedCompoundDocs.getTopDocs().size(); i++) {
            assertCompoundTopDocs(expectedCompoundDocs.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i));
        }
    }
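Worked computation for the review thread above, using the per-sub-query population mean and standard deviation that findMeanPerSubquery and findStdPerSubquery compute. For the third sub-query with raw scores 0.9, 0.7, 0.1:

mean = (0.9 + 0.7 + 0.1) / 3 ≈ 0.5667
std = sqrt(((0.9 - 0.5667)^2 + (0.7 - 0.5667)^2 + (0.1 - 0.5667)^2) / 3) ≈ 0.3399
z(0.9) = (0.9 - 0.5667) / 0.3399 ≈ 0.9806
z(0.7) = (0.7 - 0.5667) / 0.3399 ≈ 0.3922
z(0.1) = (0.1 - 0.5667) / 0.3399 ≈ -1.3728

A raw score below its sub-query mean therefore gets a negative normalized value, which is expected for z-scores.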

    public void testNormalization_whenResultFromMultipleShardsMultipleSubQueries_thenSuccessful() {
        ZScoreNormalizationTechnique normalizationTechnique = new ZScoreNormalizationTechnique();
        List<CompoundTopDocs> compoundTopDocs = List.of(
            new CompoundTopDocs(
                new TotalHits(3, TotalHits.Relation.EQUAL_TO),
                List.of(
                    new TopDocs(
                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                        new ScoreDoc[] { new ScoreDoc(2, 0.5f), new ScoreDoc(4, 0.2f) }
                    ),
                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
                    new TopDocs(
                        new TotalHits(3, TotalHits.Relation.EQUAL_TO),
                        new ScoreDoc[] { new ScoreDoc(3, 0.9f), new ScoreDoc(4, 0.7f), new ScoreDoc(2, 0.1f) }
                    )
                )
            ),
            new CompoundTopDocs(
                new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                List.of(
                    new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
                    new TopDocs(
                        new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                        new ScoreDoc[] { new ScoreDoc(7, 2.9f), new ScoreDoc(9, 0.7f) }
                    )
                )
            )
        );
        normalizationTechnique.normalize(compoundTopDocs);

        CompoundTopDocs expectedCompoundDocsShard1 = new CompoundTopDocs(
            new TotalHits(3, TotalHits.Relation.EQUAL_TO),
            List.of(
                new TopDocs(
                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                    new ScoreDoc[] { new ScoreDoc(2, 1.0f), new ScoreDoc(4, -1.0f) }
                ),
                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
                new TopDocs(
                    new TotalHits(3, TotalHits.Relation.EQUAL_TO),
                    new ScoreDoc[] { new ScoreDoc(3, 0.98058068f), new ScoreDoc(4, 0.39223227f), new ScoreDoc(2, -1.37281295f) }
                )
            )
        );

        CompoundTopDocs expectedCompoundDocsShard2 = new CompoundTopDocs(
            new TotalHits(2, TotalHits.Relation.EQUAL_TO),
            List.of(
                new TopDocs(new TotalHits(0, TotalHits.Relation.EQUAL_TO), new ScoreDoc[0]),
                new TopDocs(
                    new TotalHits(2, TotalHits.Relation.EQUAL_TO),
                    new ScoreDoc[] { new ScoreDoc(7, 1.0f), new ScoreDoc(9, -1.0f) }
                )
            )
        );

        assertNotNull(compoundTopDocs);
        assertEquals(2, compoundTopDocs.size());
        assertNotNull(compoundTopDocs.get(0).getTopDocs());
        for (int i = 0; i < expectedCompoundDocsShard1.getTopDocs().size(); i++) {
            assertCompoundTopDocs(expectedCompoundDocsShard1.getTopDocs().get(i), compoundTopDocs.get(0).getTopDocs().get(i));
        }
        assertNotNull(compoundTopDocs.get(1).getTopDocs());
        for (int i = 0; i < expectedCompoundDocsShard2.getTopDocs().size(); i++) {
            assertCompoundTopDocs(expectedCompoundDocsShard2.getTopDocs().get(i), compoundTopDocs.get(1).getTopDocs().get(i));
        }
    }

    private void assertCompoundTopDocs(TopDocs expected, TopDocs actual) {
        assertEquals(expected.totalHits.value, actual.totalHits.value);
        assertEquals(expected.totalHits.relation, actual.totalHits.relation);
        assertEquals(expected.scoreDocs.length, actual.scoreDocs.length);
        for (int i = 0; i < expected.scoreDocs.length; i++) {
            assertEquals(expected.scoreDocs[i].score, actual.scoreDocs[i].score, DELTA_FOR_ASSERTION);
            assertEquals(expected.scoreDocs[i].doc, actual.scoreDocs[i].doc);
            assertEquals(expected.scoreDocs[i].shardIndex, actual.scoreDocs[i].shardIndex);
        }
    }
}
Review comment: we can remove this.