Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport 2.x] enhancements: support neural_sparse query by tokens #701

Merged
merged 1 commit into from
Apr 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Allowing execution of hybrid query on index alias with filters ([#670](https://github.com/opensearch-project/neural-search/pull/670))
- Allowing query by raw tokens in neural_sparse query ([#693](https://github.com/opensearch-project/neural-search/pull/693))
### Bug Fixes
- Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663))
### Infrastructure
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package org.opensearch.neuralsearch.query;

import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
Expand Down Expand Up @@ -62,26 +63,26 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder<NeuralSparseQ
@VisibleForTesting
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");
@VisibleForTesting
static final ParseField QUERY_TOKENS_FIELD = new ParseField("query_tokens");
@VisibleForTesting
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
// We use max_token_score field to help WAND scorer prune query clause in lucene 9.7. But in lucene 9.8 the inner
// logics change, this field is not needed any more.
@VisibleForTesting
@Deprecated
static final ParseField MAX_TOKEN_SCORE_FIELD = new ParseField("max_token_score").withAllDeprecated();

private static MLCommonsClientAccessor ML_CLIENT;

public static void initialize(MLCommonsClientAccessor mlClient) {
NeuralSparseQueryBuilder.ML_CLIENT = mlClient;
}

private String fieldName;
private String queryText;
private String modelId;
private Float maxTokenScore;
private Supplier<Map<String, Float>> queryTokensSupplier;
private static final Version MINIMAL_SUPPORTED_VERSION_DEFAULT_MODEL_ID = Version.V_2_13_0;

public static void initialize(MLCommonsClientAccessor mlClient) {
NeuralSparseQueryBuilder.ML_CLIENT = mlClient;
}

/**
* Constructor from stream input
*
Expand All @@ -102,21 +103,31 @@ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
Map<String, Float> queryTokens = in.readMap(StreamInput::readString, StreamInput::readFloat);
this.queryTokensSupplier = () -> queryTokens;
}
// to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API
// after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead
if (StringUtils.EMPTY.equals(this.queryText)) {
this.queryText = null;
}
if (StringUtils.EMPTY.equals(this.modelId)) {
this.modelId = null;
}
}

@Override
protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeString(queryText);
out.writeString(this.fieldName);
// to be backward compatible with previous version, we need to use writeString/readString API instead of optionalString API
// after supporting query by tokens, queryText and modelId can be null. here we write an empty String instead
out.writeString(StringUtils.defaultString(this.queryText, StringUtils.EMPTY));
if (isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
out.writeOptionalString(this.modelId);
} else {
out.writeString(this.modelId);
out.writeString(StringUtils.defaultString(this.modelId, StringUtils.EMPTY));
}
out.writeOptionalFloat(maxTokenScore);
if (!Objects.isNull(queryTokensSupplier) && !Objects.isNull(queryTokensSupplier.get())) {
if (!Objects.isNull(this.queryTokensSupplier) && !Objects.isNull(this.queryTokensSupplier.get())) {
out.writeBoolean(true);
out.writeMap(queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
out.writeMap(this.queryTokensSupplier.get(), StreamOutput::writeString, StreamOutput::writeFloat);
} else {
out.writeBoolean(false);
}
Expand All @@ -126,11 +137,16 @@ protected void doWriteTo(StreamOutput out) throws IOException {
protected void doXContent(XContentBuilder xContentBuilder, Params params) throws IOException {
xContentBuilder.startObject(NAME);
xContentBuilder.startObject(fieldName);
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
if (Objects.nonNull(queryText)) {
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
}
if (Objects.nonNull(modelId)) {
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
}
if (maxTokenScore != null) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore);
if (Objects.nonNull(maxTokenScore)) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore);
if (Objects.nonNull(queryTokensSupplier) && Objects.nonNull(queryTokensSupplier.get())) {
xContentBuilder.field(QUERY_TOKENS_FIELD.getPreferredName(), queryTokensSupplier.get());
}
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
Expand All @@ -144,6 +160,16 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
* "token_score_upper_bound": float (optional)
* }
*
* or
* "SAMPLE_FIELD": {
* "query_tokens": {
* "token_a": float,
* "token_b": float,
* ...
* }
* }
*
*
* @param parser XContentParser
* @return NeuralQueryBuilder
* @throws IOException can be thrown by parser
Expand Down Expand Up @@ -171,16 +197,40 @@ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throw
}

requireValue(sparseEncodingQueryBuilder.fieldName(), "Field name must be provided for " + NAME + " query");
requireValue(
sparseEncodingQueryBuilder.queryText(),
String.format(Locale.ROOT, "%s field must be provided for [%s] query", QUERY_TEXT_FIELD.getPreferredName(), NAME)
);
if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
if (Objects.isNull(sparseEncodingQueryBuilder.queryTokensSupplier())) {
requireValue(
sparseEncodingQueryBuilder.modelId(),
String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME)
sparseEncodingQueryBuilder.queryText(),
String.format(
Locale.ROOT,
"either %s field or %s field must be provided for [%s] query",
QUERY_TEXT_FIELD.getPreferredName(),
QUERY_TOKENS_FIELD.getPreferredName(),
NAME
)
);
if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) {
requireValue(
sparseEncodingQueryBuilder.modelId(),
String.format(
Locale.ROOT,
"using %s, %s field must be provided for [%s] query",
QUERY_TEXT_FIELD.getPreferredName(),
MODEL_ID_FIELD.getPreferredName(),
NAME
)
);
}
}

if (StringUtils.EMPTY.equals(sparseEncodingQueryBuilder.queryText())) {
throw new IllegalArgumentException(
String.format(Locale.ROOT, "%s field can not be empty", QUERY_TEXT_FIELD.getPreferredName())
);
}
if (StringUtils.EMPTY.equals(sparseEncodingQueryBuilder.modelId())) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "%s field can not be empty", MODEL_ID_FIELD.getPreferredName()));
}

return sparseEncodingQueryBuilder;
}

Expand All @@ -207,6 +257,9 @@ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBui
String.format(Locale.ROOT, "[%s] query does not support [%s] field", NAME, currentFieldName)
);
}
} else if (QUERY_TOKENS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
Map<String, Float> queryTokens = parser.map(HashMap::new, XContentParser::floatValue);
sparseEncodingQueryBuilder.queryTokensSupplier(() -> queryTokens);
} else {
throw new ParsingException(
parser.getTokenLocation(),
Expand Down Expand Up @@ -293,14 +346,14 @@ private static void validateQueryTokens(Map<String, Float> queryTokens) {
@Override
protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
if (queryTokensSupplier == null && obj.queryTokensSupplier != null) return false;
if (queryTokensSupplier != null && obj.queryTokensSupplier == null) return false;
if (Objects.isNull(obj) || getClass() != obj.getClass()) return false;
if (Objects.isNull(queryTokensSupplier) && Objects.nonNull(obj.queryTokensSupplier)) return false;
if (Objects.nonNull(queryTokensSupplier) && Objects.isNull(obj.queryTokensSupplier)) return false;
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
.append(queryText, obj.queryText)
.append(modelId, obj.modelId)
.append(maxTokenScore, obj.maxTokenScore);
if (queryTokensSupplier != null) {
if (Objects.nonNull(queryTokensSupplier)) {
equalsBuilder.append(queryTokensSupplier.get(), obj.queryTokensSupplier.get());
}
return equalsBuilder.isEquals();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD;
import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TOKENS_FIELD;

import java.io.IOException;
import java.util.List;
Expand All @@ -23,6 +24,7 @@
import java.util.function.BiConsumer;
import java.util.function.Supplier;

import org.apache.commons.lang.StringUtils;
import org.apache.lucene.document.FeatureField;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
Expand Down Expand Up @@ -95,6 +97,32 @@ public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() {
assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId());
}

@SneakyThrows
public void testFromXContent_whenBuiltWithQueryTokens_thenBuildSuccessfully() {
/*
{
"VECTOR_FIELD": {
"query_tokens": {
"token_a": float_score_a,
"token_b": float_score_b
}
}
*/
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
.field(QUERY_TOKENS_FIELD.getPreferredName(), QUERY_TOKENS_SUPPLIER.get())
.endObject()
.endObject();

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);

assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName());
assertEquals(QUERY_TOKENS_SUPPLIER.get(), sparseEncodingQueryBuilder.queryTokensSupplier().get());
}

@SneakyThrows
public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
/*
Expand Down Expand Up @@ -276,13 +304,56 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() {
expectThrows(IOException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SneakyThrows
public void testFromXContent_whenBuildWithEmptyQuery_thenFail() {
/*
{
"VECTOR_FIELD": {
"query_text": ""
}
}
*/
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
.field(QUERY_TEXT_FIELD.getPreferredName(), StringUtils.EMPTY)
.endObject()
.endObject();

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SneakyThrows
public void testFromXContent_whenBuildWithEmptyModelId_thenFail() {
/*
{
"VECTOR_FIELD": {
"model_id": ""
}
}
*/
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
.field(MODEL_ID_FIELD.getPreferredName(), StringUtils.EMPTY)
.endObject()
.endObject();

XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}

@SuppressWarnings("unchecked")
@SneakyThrows
public void testToXContent() {
NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.modelId(MODEL_ID)
.queryText(QUERY_TEXT)
.maxTokenScore(MAX_TOKEN_SCORE);
.maxTokenScore(MAX_TOKEN_SCORE)
.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);

XContentBuilder builder = XContentFactory.jsonBuilder();
builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
Expand All @@ -308,15 +379,27 @@ public void testToXContent() {
assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName()));
assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName()));
assertEquals(MAX_TOKEN_SCORE, (Double) secondInnerMap.get(MAX_TOKEN_SCORE_FIELD.getPreferredName()), 0.0);
Map<String, Double> parsedQueryTokens = (Map<String, Double>) secondInnerMap.get(QUERY_TOKENS_FIELD.getPreferredName());
assertEquals(QUERY_TOKENS_SUPPLIER.get().keySet(), parsedQueryTokens.keySet());
for (Map.Entry<String, Float> entry : QUERY_TOKENS_SUPPLIER.get().entrySet()) {
assertEquals(entry.getValue(), parsedQueryTokens.get(entry.getKey()).floatValue(), 0);
}
}

public void testStreams_whenCurrentVersion_thenSuccess() {
setUpClusterService(Version.CURRENT);
testStreams();
testStreamsWithQueryTokensOnly();
}

public void testStreams_whenMinVersionIsBeforeDefaultModelId_thenSuccess() {
setUpClusterService(Version.V_2_12_0);
testStreams();
testStreamsWithQueryTokensOnly();
}

@SneakyThrows
public void testStreams() {
private void testStreams() {
NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder();
original.fieldName(FIELD_NAME);
original.queryText(QUERY_TEXT);
Expand Down Expand Up @@ -356,6 +439,26 @@ public void testStreams() {
assertEquals(original, copy);
}

@SneakyThrows
private void testStreamsWithQueryTokensOnly() {
NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder();
original.fieldName(FIELD_NAME);
original.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);

BytesStreamOutput streamOutput = new BytesStreamOutput();
original.writeTo(streamOutput);

FilterStreamInput filterStreamInput = new NamedWriteableAwareStreamInput(
streamOutput.bytes().streamInput(),
new NamedWriteableRegistry(
List.of(new NamedWriteableRegistry.Entry(QueryBuilder.class, MatchAllQueryBuilder.NAME, MatchAllQueryBuilder::new))
)
);

NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput);
assertEquals(original, copy);
}

public void testHashAndEquals() {
String fieldName1 = "field 1";
String fieldName2 = "field 2";
Expand Down Expand Up @@ -459,6 +562,18 @@ public void testHashAndEquals() {
.queryName(queryName1)
.queryTokensSupplier(() -> queryTokens2);

// Identical to sparseEncodingQueryBuilder_baseline except null query text
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nullQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.modelId(modelId1)
.boost(boost1)
.queryName(queryName1);

// Identical to sparseEncodingQueryBuilder_baseline except null model id
NeuralSparseQueryBuilder sparseEncodingQueryBuilder_nullModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.boost(boost1)
.queryName(queryName1);

assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline);
assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode());

Expand Down Expand Up @@ -491,6 +606,12 @@ public void testHashAndEquals() {

assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens, sparseEncodingQueryBuilder_diffQueryTokens);
assertNotEquals(sparseEncodingQueryBuilder_nonNullQueryTokens.hashCode(), sparseEncodingQueryBuilder_diffQueryTokens.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nullQueryText);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nullQueryText.hashCode());

assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_nullModelId);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_nullModelId.hashCode());
}

@SneakyThrows
Expand Down
Loading
Loading