From f17f85e75e1fd87bf45de2aa90300c59890103b7 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 11 Oct 2023 21:33:16 +0800 Subject: [PATCH 1/4] add serialization and deserialization Signed-off-by: zhichao-aws --- .../neuralsearch/query/NeuralSparseQueryBuilder.java | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index fd15b431b..3a842a4ab 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -89,6 +89,10 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException { this.queryText = in.readString(); this.modelId = in.readString(); this.maxTokenScore = in.readOptionalFloat(); + if (in.readBoolean()) { + Map queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat); + this.queryTokensSupplier = () -> queryTokens; + } } @Override @@ -97,6 +101,12 @@ protected void doWriteTo(StreamOutput out) throws IOException { out.writeString(queryText); out.writeString(modelId); out.writeOptionalFloat(maxTokenScore); + if (queryTokensSupplier != null && queryTokensSupplier.get() != null) { + out.writeBoolean(true); + out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat); + } else { + out.writeBoolean(false); + } } @Override From ed6f26314602c1b2f5f9bec0eafd4cd6a5ddaa9b Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 11 Oct 2023 22:01:50 +0800 Subject: [PATCH 2/4] hash, equals. + UT Signed-off-by: zhichao-aws --- .../query/NeuralSparseQueryBuilder.java | 12 +++++++- .../query/NeuralSparseQueryBuilderTests.java | 30 +++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index 3a842a4ab..f9085b3dd 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -6,6 +6,7 @@ package org.opensearch.neuralsearch.query; import java.io.IOException; +import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; @@ -286,16 +287,25 @@ private static void validateQueryTokens(Map queryTokens) { protected boolean doEquals(NeuralSparseQueryBuilder obj) { if (this == obj) return true; if (obj == null || getClass() != obj.getClass()) return false; + if (queryTokensSupplier == null && obj.queryTokensSupplier != null) return false; + if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false; EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName) .append(queryText, obj.queryText) .append(modelId, obj.modelId) .append(maxTokenScore, obj.maxTokenScore); + if (queryTokensSupplier != null) { + equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get()); + } return equalsBuilder.isEquals(); } @Override protected int doHashCode() { - return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore).toHashCode(); + HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore); + if (queryTokensSupplier != null) { + builder.append(queryTokensSupplier.get()); + } + return builder.toHashCode(); } @Override diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index 34850dcb7..fbb752a25 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -27,6 +27,7 @@ import lombok.SneakyThrows; import org.opensearch.client.Client; +import org.opensearch.common.SetOnce; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; @@ -281,6 +282,9 @@ public void testStreams() { original.modelId(MODEL_ID); original.boost(BOOST); original.queryName(QUERY_NAME); + SetOnce> queryTokensSetOnce = new SetOnce<>(); + queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f)); + original.queryTokensSupplier(queryTokensSetOnce::get); BytesStreamOutput streamOutput = new BytesStreamOutput(); original.writeTo(streamOutput); @@ -309,6 +313,8 @@ public void testHashAndEquals() { float boost2 = 3.8f; String queryName1 = "query-1"; String queryName2 = "query-2"; + Map queryTokens1 = Map.of("hello", 1.0f, "world", 2.0f); + Map queryTokens2 = Map.of("hello", 1.0f, "world", 2.2f); NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1) .queryText(queryText1) @@ -379,6 +385,24 @@ public void testHashAndEquals() { .boost(boost1) .queryName(queryName1); + // Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .maxTokenScore(maxTokenScore1) + .boost(boost1) + .queryName(queryName1) + .queryTokensSupplier(()->queryTokens1); + + // Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier + NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1) + .queryText(queryText1) + .modelId(modelId1) + .maxTokenScore(maxTokenScore1) + .boost(boost1) + .queryName(queryName1) + .queryTokensSupplier(()->queryTokens2); + assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline); assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode()); @@ -405,6 +429,12 @@ public void testHashAndEquals() { assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffMaxTokenScore); assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffMaxTokenScore.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nonNullQueryTokens); + assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode()); + + assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens, sparseEncodingQueryBuilder_diffQueryTokens); + assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode()); } @SneakyThrows From 1f1382f21cd3b304522b076b1ffaedcdd678d4ae Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 11 Oct 2023 22:03:00 +0800 Subject: [PATCH 3/4] tidy Signed-off-by: zhichao-aws --- .../query/NeuralSparseQueryBuilder.java | 1 - .../query/NeuralSparseQueryBuilderTests.java | 24 +++++++++---------- 2 files changed, 12 insertions(+), 13 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java index f9085b3dd..d883af23d 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java @@ -6,7 +6,6 @@ package org.opensearch.neuralsearch.query; import java.io.IOException; -import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index fbb752a25..d40d70c09 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -387,21 +387,21 @@ public void testHashAndEquals() { // Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1) - .queryText(queryText1) - .modelId(modelId1) - .maxTokenScore(maxTokenScore1) - .boost(boost1) - .queryName(queryName1) - .queryTokensSupplier(()->queryTokens1); + .queryText(queryText1) + .modelId(modelId1) + .maxTokenScore(maxTokenScore1) + .boost(boost1) + .queryName(queryName1) + .queryTokensSupplier(() -> queryTokens1); // Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1) - .queryText(queryText1) - .modelId(modelId1) - .maxTokenScore(maxTokenScore1) - .boost(boost1) - .queryName(queryName1) - .queryTokensSupplier(()->queryTokens2); + .queryText(queryText1) + .modelId(modelId1) + .maxTokenScore(maxTokenScore1) + .boost(boost1) + .queryName(queryName1) + .queryTokensSupplier(() -> queryTokens2); assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline); assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode()); From 462b61451c806db7ef23a0d2a99ca49308371ef1 Mon Sep 17 00:00:00 2001 From: zhichao-aws Date: Wed, 11 Oct 2023 22:23:07 +0800 Subject: [PATCH 4/4] add test Signed-off-by: zhichao-aws --- .../query/NeuralSparseQueryBuilderTests.java | 20 ++++++++++++++++--- 1 file changed, 17 insertions(+), 3 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java index d40d70c09..a50ab4fb8 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java @@ -282,9 +282,6 @@ public void testStreams() { original.modelId(MODEL_ID); original.boost(BOOST); original.queryName(QUERY_NAME); - SetOnce> queryTokensSetOnce = new SetOnce<>(); - queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f)); - original.queryTokensSupplier(queryTokensSetOnce::get); BytesStreamOutput streamOutput = new BytesStreamOutput(); original.writeTo(streamOutput); @@ -298,6 +295,23 @@ public void testStreams() { NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput); assertEquals(original, copy); + + SetOnce> queryTokensSetOnce = new SetOnce<>(); + queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f)); + original.queryTokensSupplier(queryTokensSetOnce::get); + + streamOutput = new BytesStreamOutput(); + original.writeTo(streamOutput); + + filterStreamInput = new NamedWriteableAwareStreamInput( + streamOutput.bytes().streamInput(), + new NamedWriteableRegistry( + List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new)) + ) + ); + + copy = new NeuralSparseQueryBuilder(filterStreamInput); + assertEquals(original, copy); } public void testHashAndEquals() {