Skip to content

Commit

Permalink
[Backport main manually][bug fix] Fix async actions are left in neura…
Browse files Browse the repository at this point in the history
…l_sparse query (#438) (#479)

* [bug fix] Fix async actions are left in neural_sparse query (#438)

* add serialization and deserialization

Signed-off-by: zhichao-aws <[email protected]>

* hash, equals. + UT

Signed-off-by: zhichao-aws <[email protected]>

* tidy

Signed-off-by: zhichao-aws <[email protected]>

* add test

Signed-off-by: zhichao-aws <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
(cherry picked from commit 51e6c00)

* rm max_token_score

Signed-off-by: zhichao-aws <[email protected]>

* add changelog

Signed-off-by: zhichao-aws <[email protected]>

* tidy

Signed-off-by: zhichao-aws <[email protected]>

---------

Signed-off-by: zhichao-aws <[email protected]>
  • Loading branch information
zhichao-aws authored Nov 17, 2023
1 parent 00c5589 commit a72df66
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 2 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
### Bug Fixes
Fix async actions are left in neural_sparse query ([438](https://github.com/opensearch-project/neural-search/pull/438))
### Infrastructure
### Documentation
### Maintenance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;

import lombok.AllArgsConstructor;
Expand Down Expand Up @@ -84,13 +85,23 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
this.fieldName = in.readString();
this.queryText = in.readString();
this.modelId = in.readString();
if (in.readBoolean()) {
Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
this.queryTokensSupplier = () -> queryTokens;
}
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeString(queryText);
out.writeString(modelId);
if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) {
out.writeBoolean(true);
out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
} else {
out.writeBoolean(false);
}
}

@Override
Expand Down Expand Up @@ -256,16 +267,25 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
@Override
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
if (Objects.isNull(obj) || getClass() != obj.getClass()) return false;
if (Objects.isNull(queryTokensSupplier) && !Objects.isNull(obj.queryTokensSupplier)) return false;
if (!Objects.isNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false;
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
.append(queryText, obj.queryText)
.append(modelId, obj.modelId);
if (!Objects.isNull(queryTokensSupplier)) {
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
}
return equalsBuilder.isEquals();
}

@Override
protected int doHashCode() {
return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).toHashCode();
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId);
if (!Objects.isNull(queryTokensSupplier)) {
builder.append(queryTokensSupplier.get());
}
return builder.toHashCode();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
import lombok.SneakyThrows;

import org.opensearch.client.Client;
import org.opensearch.common.SetOnce;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
Expand Down Expand Up @@ -262,6 +263,23 @@ public void testStreams() {

NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput);
assertEquals(original, copy);

SetOnce<Map<String, Float>> queryTokensSetOnce = new SetOnce<>();
queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f));
original.queryTokensSupplier(queryTokensSetOnce::get);

BytesStreamOutput streamOutput2 = new BytesStreamOutput();
original.writeTo(streamOutput2);

filterStreamInput = new NamedWriteableAwareStreamInput(
streamOutput2.bytes().streamInput(),
new NamedWriteableRegistry(
List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new))
)
);

copy = new NeuralSparseQueryBuilder(filterStreamInput);
assertEquals(original, copy);
}

public void testHashAndEquals() {
Expand All @@ -275,6 +293,8 @@ public void testHashAndEquals() {
float boost2 = 3.8f;
String queryName1 = "query-1";
String queryName2 = "query-2";
Map<String, Float> queryTokens1 = Map.of("hello", 1.0f, "world", 2.0f);
Map<String, Float> queryTokens2 = Map.of("hello", 1.0f, "world", 2.2f);

NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
Expand Down Expand Up @@ -329,6 +349,22 @@ public void testHashAndEquals() {
.boost(boost1)
.queryName(queryName2);

// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens1);

// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens2);

assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline);
assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode());

Expand All @@ -352,6 +388,12 @@ public void testHashAndEquals() {

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nonNullQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens, sparseEncodingQueryBuilder_diffQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode());
}

@SneakyThrows
Expand Down

0 comments on commit a72df66

Please sign in to comment.