Skip to content

Commit

Permalink
hash, equals. + UT
Browse files Browse the repository at this point in the history
Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws committed Oct 11, 2023
1 parent f17f85e commit ed6f263
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.neuralsearch.query;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -286,16 +287,25 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
if (queryTokensSupplier == null && obj.queryTokensSupplier != null) return false;
if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false;
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
.append(queryText, obj.queryText)
.append(modelId, obj.modelId)
.append(maxTokenScore, obj.maxTokenScore);
if (queryTokensSupplier != null) {
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
}
return equalsBuilder.isEquals();
}

@Override
protected int doHashCode() {
return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore).toHashCode();
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore);
if (queryTokensSupplier != null) {
builder.append(queryTokensSupplier.get());
}
return builder.toHashCode();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import lombok.SneakyThrows;

import org.opensearch.client.Client;
import org.opensearch.common.SetOnce;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -281,6 +282,9 @@ public void testStreams() {
original.modelId(MODEL_ID);
original.boost(BOOST);
original.queryName(QUERY_NAME);
SetOnce<Map<String, Float>> queryTokensSetOnce = new SetOnce<>();
queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f));
original.queryTokensSupplier(queryTokensSetOnce::get);

BytesStreamOutput streamOutput = new BytesStreamOutput();
original.writeTo(streamOutput);
Expand Down Expand Up @@ -309,6 +313,8 @@ public void testHashAndEquals() {
float boost2 = 3.8f;
String queryName1 = "query-1";
String queryName2 = "query-2";
Map<String, Float> queryTokens1 = Map.of("hello", 1.0f, "world", 2.0f);
Map<String, Float> queryTokens2 = Map.of("hello", 1.0f, "world", 2.2f);

NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
Expand Down Expand Up @@ -379,6 +385,24 @@ public void testHashAndEquals() {
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(()->queryTokens1);

// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(()->queryTokens2);

assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline);
assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode());

Expand All @@ -405,6 +429,12 @@ public void testHashAndEquals() {

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffMaxTokenScore);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffMaxTokenScore.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nonNullQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens, sparseEncodingQueryBuilder_diffQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode());
}

@SneakyThrows
Expand Down

0 comments on commit ed6f263

Please sign in to comment.