Skip to content

Commit

Permalink
[BUG FIX] Fix bwc failure in neural sparse search (#696)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhichao-aws authored Apr 20, 2024
1 parent dd3b30c commit 7b0229d
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
- Fix async actions are left in neural_sparse query ([#438](https://github.com/opensearch-project/neural-search/pull/438))
- Fix typo for sparse encoding processor factory([#578](https://github.com/opensearch-project/neural-search/pull/578))
- Add non-null check for queryBuilder in NeuralQueryEnricherProcessor ([#615](https://github.com/opensearch-project/neural-search/pull/615))
- Add max_token_score field placeholder in NeuralSparseQueryBuilder to fix the rolling-upgrade BWC tests from 2.x nodes ([#696](https://github.com/opensearch-project/neural-search/pull/696))
### Infrastructure
- Adding integration tests for scenario of hybrid query with aggregations ([#632](https://github.com/opensearch-project/neural-search/pull/632))
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
*/
package org.opensearch.neuralsearch.bwc;

import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import org.junit.Before;
import org.opensearch.common.settings.Settings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,9 @@
*/
package org.opensearch.neuralsearch.bwc;

import java.net.URL;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Locale;
import java.util.Objects;
import java.util.Optional;
import org.junit.Before;
import org.opensearch.common.settings.Settings;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@
import lombok.extern.log4j.Log4j2;

/**
* SparseEncodingQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML SPARSE_ENCODING model
* NeuralSparseQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model
* or SPARSE_TOKENIZE model to produce a Map with String keys and Float values for the input text. That map is then
* transformed into Lucene FeatureQuery clauses wrapped in a Lucene BooleanQuery.
*/
Expand All @@ -63,6 +63,11 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");
@VisibleForTesting
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
// The max_token_score field was used to help the WAND scorer prune query clauses in Lucene 9.7. In Lucene 9.8 the
// internal logic changed, so this field is no longer needed.
@VisibleForTesting
@Deprecated
static final ParseField MAX_TOKEN_SCORE_FIELD = new ParseField("max_token_score").withAllDeprecated();

private static MLCommonsClientAccessor ML_CLIENT;

Expand All @@ -73,6 +78,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
private String fieldName;
private String queryText;
private String modelId;
private Float maxTokenScore;
private Supplier<Map<String, Float>> queryTokensSupplier;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;

Expand All @@ -91,6 +97,7 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
} else {
this.modelId = in.readString();
}
this.maxTokenScore = in.readOptionalFloat();
if (in.readBoolean()) {
Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
this.queryTokensSupplier = () -> queryTokens;
Expand All @@ -106,6 +113,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
} else {
out.writeString(this.modelId);
}
out.writeOptionalFloat(maxTokenScore);
if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) {
out.writeBoolean(true);
out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
Expand All @@ -122,6 +130,7 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
if (Objects.nonNull(modelId)) {
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
}
if (maxTokenScore != null) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore);
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
Expand All @@ -131,7 +140,8 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
* The expected parsing form looks like:
* "SAMPLE_FIELD": {
* "query_text": "string",
* "model_id": "string"
* "model_id": "string",
* "max_token_score": float (optional)
* }
*
* @param parser XContentParser
Expand Down Expand Up @@ -189,6 +199,8 @@ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui
sparseEncodingQueryBuilder.queryText(parser.text());
} else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
sparseEncodingQueryBuilder.modelId(parser.text());
} else if (MAX_TOKEN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
sparseEncodingQueryBuilder.maxTokenScore(parser.floatValue());
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -227,6 +239,7 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
return new NeuralSparseQueryBuilder().fieldName(fieldName)
.queryText(queryText)
.modelId(modelId)
.maxTokenScore(maxTokenScore)
.queryTokensSupplier(queryTokensSetOnce::get);
}

Expand Down Expand Up @@ -280,22 +293,23 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
/**
 * Compares this builder with {@code obj} on the fields specific to this query type.
 *
 * Returns true immediately for the same instance, and false when {@code obj} is null, is of a different
 * class, or when exactly one of the two builders has a query-token supplier. Otherwise the comparison covers
 * fieldName, queryText, modelId, maxTokenScore and, when a supplier is set, the values both suppliers return.
 *
 * NOTE(review): this listing appears to show both the removed and the added form of several lines
 * (the Objects.isNull(...) checks next to the equivalent "== null" checks) - confirm which form is current.
 *
 * @param obj the other builder to compare against
 * @return true when all compared fields are equal
 */
@Override
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (Objects.isNull(obj) || getClass() != obj.getClass()) return false;
if (Objects.isNull(queryTokensSupplier) && !Objects.isNull(obj.queryTokensSupplier)) return false;
if (!Objects.isNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false;
if (obj == null || getClass() != obj.getClass()) return false;
// A builder with a token supplier never equals one without.
if (queryTokensSupplier == null && obj.queryTokensSupplier != null) return false;
if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false;
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
.append(queryText, obj.queryText)
.append(modelId, obj.modelId);
if (!Objects.isNull(queryTokensSupplier)) {
.append(modelId, obj.modelId)
.append(maxTokenScore, obj.maxTokenScore);
// Both suppliers are non-null at this point, so compare the values they produce.
if (queryTokensSupplier != null) {
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
}
return equalsBuilder.isEquals();
}

/**
 * Builds the hash code from fieldName, queryText, modelId and maxTokenScore, plus the value returned by the
 * query-token supplier when one is set.
 *
 * NOTE(review): this listing appears to show both the removed and the added form of the builder line and
 * of the supplier null check - confirm which form is current.
 *
 * @return the hash code produced by the HashCodeBuilder
 */
@Override
protected int doHashCode() {
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId);
if (!Objects.isNull(queryTokensSupplier)) {
HashCodeBuilder builder = new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore);
// The supplied token map only contributes when a supplier has been set.
if (queryTokensSupplier != null) {
builder.append(queryTokensSupplier.get());
}
return builder.toHashCode();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD;
import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MAX_TOKEN_SCORE_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD;
Expand All @@ -22,6 +23,9 @@
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import org.apache.lucene.document.FeatureField;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.junit.Before;
import org.opensearch.Version;
import org.opensearch.client.Client;
Expand All @@ -37,9 +41,11 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.util.NeuralSearchClusterTestUtils;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
Expand All @@ -54,6 +60,7 @@ public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase {
private static final String MODEL_ID = "mfgfgdsfgfdgsde";
private static final float BOOST = 1.8f;
private static final String QUERY_NAME = "queryName";
private static final Float MAX_TOKEN_SCORE = 123f;
private static final Supplier<Map<String, Float>> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f);

@Before
Expand Down Expand Up @@ -121,6 +128,32 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
assertEquals(QUERY_NAME, sparseEncodingQueryBuilder.queryName());
}

/**
 * Parses a neural_sparse query body that contains the deprecated {@code max_token_score} field and asserts
 * that the expected deprecation warning was recorded.
 */
@SneakyThrows
public void testFromXContent_whenBuiltWithMaxTokenScore_thenThrowWarning() {
    /*
      {
          "VECTOR_FIELD": {
            "query_text": "string",
            "model_id": "string",
            "max_token_score": 123.0
          }
      }
    */
    XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
        .startObject()
        .startObject(FIELD_NAME)
        .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
        .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
        .field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), MAX_TOKEN_SCORE)
        .endObject()
        .endObject();

    XContentParser contentParser = createParser(xContentBuilder);
    contentParser.nextToken();
    // Only the deprecation warning is under test; the parsed builder itself is not inspected,
    // so it is not kept in a local.
    NeuralSparseQueryBuilder.fromXContent(contentParser);
    assertWarnings("Deprecated field [max_token_score] used, this field is unused and will be removed entirely");
}

@SneakyThrows
public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() {
/*
Expand Down Expand Up @@ -248,7 +281,8 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() {
public void testToXContent() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.modelId(MODEL_ID)
.queryText(QUERY_TEXT);
.queryText(QUERY_TEXT)
.maxTokenScore(MAX_TOKEN_SCORE);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
Expand All @@ -273,6 +307,7 @@ public void testToXContent() {

assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName()));
assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName()));
assertEquals(MAX_TOKEN_SCORE, (Double) secondInnerMap.get(MAX_TOKEN_SCORE_FIELD.getPreferredName()), 0.0);
}

public void testStreams_whenMinVersionIsBeforeDefaultModelId_thenSuccess() {
Expand All @@ -285,6 +320,7 @@ public void testStreams() {
NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder();
original.fieldName(FIELD_NAME);
original.queryText(QUERY_TEXT);
original.maxTokenScore(MAX_TOKEN_SCORE);
original.modelId(MODEL_ID);
original.boost(BOOST);
original.queryName(QUERY_NAME);
Expand All @@ -306,11 +342,11 @@ public void testStreams() {
queryTokensSetOnce.set(Map.of("hello", 1.0f, "world", 2.0f));
original.queryTokensSupplier(queryTokensSetOnce::get);

BytesStreamOutput streamOutput2 = new BytesStreamOutput();
original.writeTo(streamOutput2);
streamOutput = new BytesStreamOutput();
original.writeTo(streamOutput);

filterStreamInput = new NamedWriteableAwareStreamInput(
streamOutput2.bytes().streamInput(),
streamOutput.bytes().streamInput(),
new NamedWriteableRegistry(
List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new))
)
Expand All @@ -327,6 +363,8 @@ public void testHashAndEquals() {
String queryText2 = "query text 2";
String modelId1 = "model-1";
String modelId2 = "model-2";
float maxTokenScore1 = 1.1f;
float maxTokenScore2 = 2.2f;
float boost1 = 1.8f;
float boost2 = 3.8f;
String queryName1 = "query-1";
Expand All @@ -337,60 +375,77 @@ public void testHashAndEquals() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except default boost and query name
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1);
.modelId(modelId1)
.maxTokenScore(maxTokenScore1);

// Identical to sparseEncodingQueryBuilder_baseline except diff field name
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new NeuralSparseQueryBuilder().fieldName(fieldName2)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff query text
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText2)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff model ID
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId2)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff boost
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffBoost = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost2)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except diff query name
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName2);

// Identical to sparseEncodingQueryBuilder_baseline except diff max token score
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffMaxTokenScore = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore2)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except non-null query tokens supplier
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nonNullQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens1);
Expand All @@ -399,6 +454,7 @@ public void testHashAndEquals() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryTokens = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
.maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens2);
Expand Down Expand Up @@ -427,6 +483,9 @@ public void testHashAndEquals() {
assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffMaxTokenScore);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffMaxTokenScore.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nonNullQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode());

Expand Down Expand Up @@ -486,4 +545,23 @@ private void setUpClusterService(Version version) {
ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version);
NeuralSearchClusterUtil.instance().initialize(clusterService);
}

/**
 * Verifies that {@code doToQuery} produces a BooleanQuery with one SHOULD linear FeatureField clause per
 * supplied query token when the mocked field mapper reports the type name "rank_features". The builder is
 * given a max token score as well; the expected query is built without it.
 */
@SneakyThrows
public void testDoToQuery_successfulDoToQuery() {
    NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
        .maxTokenScore(MAX_TOKEN_SCORE)
        .queryText(QUERY_TEXT)
        .modelId(MODEL_ID)
        .queryTokensSupplier(QUERY_TOKENS_SUPPLIER);

    // Mock the shard context so any field lookup resolves to a field whose type name is "rank_features".
    QueryShardContext mockedQueryShardContext = mock(QueryShardContext.class);
    MappedFieldType mockedMappedFieldType = mock(MappedFieldType.class);
    doAnswer(invocation -> "rank_features").when(mockedMappedFieldType).typeName();
    doAnswer(invocation -> mockedMappedFieldType).when(mockedQueryShardContext).fieldMapper(any());

    // Expected query: one SHOULD clause per token from QUERY_TOKENS_SUPPLIER ("hello" -> 1, "world" -> 2).
    BooleanQuery.Builder targetQueryBuilder = new BooleanQuery.Builder();
    targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "hello", 1.f), BooleanClause.Occur.SHOULD);
    targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "world", 2.f), BooleanClause.Occur.SHOULD);

    // assertEquals takes (expected, actual); this order keeps failure messages accurate.
    assertEquals(targetQueryBuilder.build(), sparseEncodingQueryBuilder.doToQuery(mockedQueryShardContext));
}
}

0 comments on commit 7b0229d

Please sign in to comment.