diff --git a/CHANGELOG.md b/CHANGELOG.md
index da2ae9ec9..15532572f 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333))
### Enhancements
+Add `max_token_score` parameter to improve the execution efficiency for `neural_sparse` query clause ([#348](https://github.com/opensearch-project/neural-search/pull/348))
### Bug Fixes
### Infrastructure
### Documentation
diff --git a/src/main/java/org/apache/lucene/BoundedLinearFeatureQuery.java b/src/main/java/org/apache/lucene/BoundedLinearFeatureQuery.java
new file mode 100644
index 000000000..617662363
--- /dev/null
+++ b/src/main/java/org/apache/lucene/BoundedLinearFeatureQuery.java
@@ -0,0 +1,235 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/*
+ * Modifications Copyright OpenSearch Contributors. See
+ * GitHub history for details.
+ */
+
+/*
+ * This class is based on Lucene's FeatureQuery. We use LinearFunction to
+ * build the query and add an upper bound to it.
+ */
+
+package org.apache.lucene;
+
+import java.io.IOException;
+import java.util.Objects;
+
+import org.apache.lucene.index.ImpactsEnum;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.PostingsEnum;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.index.Terms;
+import org.apache.lucene.index.TermsEnum;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.ImpactsDISI;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.QueryVisitor;
+import org.apache.lucene.search.ScoreMode;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.search.similarities.Similarity.SimScorer;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * The feature queries of input tokens are wrapped by a Lucene BooleanQuery, which uses the WAND
+ * algorithm to accelerate execution. The WAND algorithm leverages the score upper bound of
+ * sub-queries to skip non-competitive tokens. However, the original Lucene FeatureQuery uses
+ * Float.MAX_VALUE as the score upper bound, which invalidates WAND.
+ *
+ * To mitigate this issue, we rewrite the FeatureQuery as BoundedLinearFeatureQuery. The caller can
+ * set the token score upper bound of this query. For our use case, we use LinearFunction
+ * as the score function.
+ *
+ * This class combines both FeatureQuery
+ * and FeatureField together
+ * and will be deprecated after OpenSearch upgrades Lucene to version 9.8.
+ */
+
+public final class BoundedLinearFeatureQuery extends Query {
+
+ private final String fieldName;
+ private final String featureName;
+ private final Float scoreUpperBound;
+
+ public BoundedLinearFeatureQuery(String fieldName, String featureName, Float scoreUpperBound) {
+ this.fieldName = Objects.requireNonNull(fieldName);
+ this.featureName = Objects.requireNonNull(featureName);
+ this.scoreUpperBound = Objects.requireNonNull(scoreUpperBound);
+ }
+
+ @Override
+ public Query rewrite(IndexSearcher indexSearcher) throws IOException {
+ // LinearFunction returns the same object on rewrite
+ return super.rewrite(indexSearcher);
+ }
+
+ @Override
+ public boolean equals(Object obj) {
+ if (obj == null || getClass() != obj.getClass()) {
+ return false;
+ }
+ BoundedLinearFeatureQuery that = (BoundedLinearFeatureQuery) obj;
+ return Objects.equals(fieldName, that.fieldName)
+ && Objects.equals(featureName, that.featureName)
+ && Objects.equals(scoreUpperBound, that.scoreUpperBound);
+ }
+
+ @Override
+ public int hashCode() {
+ int h = getClass().hashCode();
+ h = 31 * h + fieldName.hashCode();
+ h = 31 * h + featureName.hashCode();
+ h = 31 * h + scoreUpperBound.hashCode();
+ return h;
+ }
+
+ @Override
+ public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
+ if (!scoreMode.needsScores()) {
+ // We don't need scores (e.g. for faceting), and since features are stored as terms,
+ // allow TermQuery to optimize in this case
+ TermQuery tq = new TermQuery(new Term(fieldName, featureName));
+ return searcher.rewrite(tq).createWeight(searcher, scoreMode, boost);
+ }
+
+ return new Weight(this) {
+
+ @Override
+ public boolean isCacheable(LeafReaderContext ctx) {
+ return false;
+ }
+
+ @Override
+ public Explanation explain(LeafReaderContext context, int doc) throws IOException {
+ String desc = "weight(" + getQuery() + " in " + doc + ") [\" BoundedLinearFeatureQuery \"]";
+
+ Terms terms = context.reader().terms(fieldName);
+ if (terms == null) {
+ return Explanation.noMatch(desc + ". Field " + fieldName + " doesn't exist.");
+ }
+ TermsEnum termsEnum = terms.iterator();
+ if (termsEnum.seekExact(new BytesRef(featureName)) == false) {
+ return Explanation.noMatch(desc + ". Feature " + featureName + " doesn't exist.");
+ }
+
+ PostingsEnum postings = termsEnum.postings(null, PostingsEnum.FREQS);
+ if (postings.advance(doc) != doc) {
+ return Explanation.noMatch(desc + ". Feature " + featureName + " isn't set.");
+ }
+
+ int freq = postings.freq();
+ float featureValue = decodeFeatureValue(freq);
+ float score = boost * featureValue;
+ return Explanation.match(
+ score,
+ "Linear function on the " + fieldName + " field for the " + featureName + " feature, computed as w * S from:",
+ Explanation.match(boost, "w, weight of this function"),
+ Explanation.match(featureValue, "S, feature value")
+ );
+ }
+
+ @Override
+ public Scorer scorer(LeafReaderContext context) throws IOException {
+ Terms terms = Terms.getTerms(context.reader(), fieldName);
+ TermsEnum termsEnum = terms.iterator();
+ if (termsEnum.seekExact(new BytesRef(featureName)) == false) {
+ return null;
+ }
+
+ final SimScorer scorer = new SimScorer() {
+ @Override
+ public float score(float freq, long norm) {
+ return boost * decodeFeatureValue(freq);
+ }
+ };
+ final ImpactsEnum impacts = termsEnum.impacts(PostingsEnum.FREQS);
+ final ImpactsDISI impactsDisi = new ImpactsDISI(impacts, impacts, scorer);
+
+ return new Scorer(this) {
+
+ @Override
+ public int docID() {
+ return impacts.docID();
+ }
+
+ @Override
+ public float score() throws IOException {
+ return scorer.score(impacts.freq(), 1L);
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return impactsDisi;
+ }
+
+ @Override
+ public int advanceShallow(int target) throws IOException {
+ return impactsDisi.advanceShallow(target);
+ }
+
+ @Override
+ public float getMaxScore(int upTo) throws IOException {
+ return impactsDisi.getMaxScore(upTo);
+ }
+
+ @Override
+ public void setMinCompetitiveScore(float minScore) {
+ impactsDisi.setMinCompetitiveScore(minScore);
+ }
+ };
+ }
+ };
+ }
+
+ @Override
+ public void visit(QueryVisitor visitor) {
+ if (visitor.acceptField(fieldName)) {
+ visitor.visitLeaf(this);
+ }
+ }
+
+ @Override
+ public String toString(String field) {
+ return "BoundedLinearFeatureQuery(field=" + fieldName + ", feature=" + featureName + ", scoreUpperBound=" + scoreUpperBound + ")";
+ }
+
+ // MAX_FREQ and decodeFeatureValue are adapted from Lucene's FeatureField (see FeatureField.decodeFeatureValue)
+ static final int MAX_FREQ = Float.floatToIntBits(Float.MAX_VALUE) >>> 15;
+
+ // Re-implemented here so that the decoded value is capped at scoreUpperBound.
+ private float decodeFeatureValue(float freq) {
+ if (freq > MAX_FREQ) {
+ return scoreUpperBound;
+ }
+ int tf = (int) freq; // lossless
+ int featureBits = tf << 15;
+ return Math.min(Float.intBitsToFloat(featureBits), scoreUpperBound);
+ }
+}
diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
index 2ac8853e4..3ce9582f4 100644
--- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
+++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -41,7 +41,7 @@
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
-import org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder;
+import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.plugins.ActionPlugin;
import org.opensearch.plugins.ExtensiblePlugin;
@@ -81,7 +81,7 @@ public Collection createComponents(
final Supplier repositoriesServiceSupplier
) {
NeuralQueryBuilder.initialize(clientAccessor);
- SparseEncodingQueryBuilder.initialize(clientAccessor);
+ NeuralSparseQueryBuilder.initialize(clientAccessor);
normalizationProcessorWorkflow = new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner());
return List.of(clientAccessor);
}
@@ -91,7 +91,7 @@ public List> getQueries() {
return Arrays.asList(
new QuerySpec<>(NeuralQueryBuilder.NAME, NeuralQueryBuilder::new, NeuralQueryBuilder::fromXContent),
new QuerySpec<>(HybridQueryBuilder.NAME, HybridQueryBuilder::new, HybridQueryBuilder::fromXContent),
- new QuerySpec<>(SparseEncodingQueryBuilder.NAME, SparseEncodingQueryBuilder::new, SparseEncodingQueryBuilder::fromXContent)
+ new QuerySpec<>(NeuralSparseQueryBuilder.NAME, NeuralSparseQueryBuilder::new, NeuralSparseQueryBuilder::fromXContent)
);
}
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java
similarity index 97%
rename from src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java
rename to src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java
index 4ac63d419..acf0eb32b 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/NLPProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java
@@ -32,7 +32,7 @@
* and set the target fields according to the field name map.
*/
@Log4j2
-public abstract class NLPProcessor extends AbstractProcessor {
+public abstract class InferenceProcessor extends AbstractProcessor {
public static final String MODEL_ID_FIELD = "model_id";
public static final String FIELD_MAP_FIELD = "field_map";
@@ -51,7 +51,7 @@ public abstract class NLPProcessor extends AbstractProcessor {
private final Environment environment;
- public NLPProcessor(
+ public InferenceProcessor(
String tag,
String description,
String type,
@@ -249,7 +249,7 @@ protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map buildNLPResult(Map processorMap, List> results, Map sourceAndMetadataMap) {
- NLPProcessor.IndexWrapper indexWrapper = new NLPProcessor.IndexWrapper(0);
+ InferenceProcessor.IndexWrapper indexWrapper = new InferenceProcessor.IndexWrapper(0);
Map result = new LinkedHashMap<>();
for (Map.Entry knnMapEntry : processorMap.entrySet()) {
String knnKey = knnMapEntry.getKey();
@@ -270,7 +270,7 @@ private void putNLPResultToSourceMapForMapType(
String processorKey,
Object sourceValue,
List> results,
- NLPProcessor.IndexWrapper indexWrapper,
+ InferenceProcessor.IndexWrapper indexWrapper,
Map sourceAndMetadataMap
) {
if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return;
@@ -294,7 +294,7 @@ private void putNLPResultToSourceMapForMapType(
private List> buildNLPResultForListType(
List sourceValue,
List> results,
- NLPProcessor.IndexWrapper indexWrapper
+ InferenceProcessor.IndexWrapper indexWrapper
) {
List> keyToResult = new ArrayList<>();
IntStream.range(0, sourceValue.size())
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java
index 275117809..b5bb85aac 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java
@@ -22,7 +22,7 @@
* and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the sparse encoding results.
*/
@Log4j2
-public final class SparseEncodingProcessor extends NLPProcessor {
+public final class SparseEncodingProcessor extends InferenceProcessor {
public static final String TYPE = "sparse_encoding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "sparse_encoding";
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java
index 1df60baea..c30d14caf 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java
@@ -21,7 +21,7 @@
* and field_map can be used to indicate which fields needs text embedding and the corresponding keys for the text embedding results.
*/
@Log4j2
-public final class TextEmbeddingProcessor extends NLPProcessor {
+public final class TextEmbeddingProcessor extends InferenceProcessor {
public static final String TYPE = "text_embedding";
public static final String LIST_TYPE_NESTED_MAP_KEY = "knn";
diff --git a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
similarity index 81%
rename from src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java
rename to src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
index 4b8b6f0d4..fd15b431b 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilder.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
@@ -21,9 +21,10 @@
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
-import org.apache.lucene.document.FeatureField;
+import org.apache.lucene.BoundedLinearFeatureQuery;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.Query;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
@@ -44,7 +45,7 @@
import com.google.common.annotations.VisibleForTesting;
/**
- * SparseEncodingQueryBuilder is responsible for handling "sparse_encoding" query types. It uses an ML SPARSE_ENCODING model
+ * NeuralSparseQueryBuilder is responsible for handling "neural_sparse" query types. It uses an ML NEURAL_SPARSE model
* or SPARSE_TOKENIZE model to produce a Map with String keys and Float values for input text. Then it will be transformed
* to Lucene FeatureQuery wrapped by Lucene BooleanQuery.
*/
@@ -55,22 +56,25 @@
@Accessors(chain = true, fluent = true)
@NoArgsConstructor
@AllArgsConstructor
-public class SparseEncodingQueryBuilder extends AbstractQueryBuilder {
- public static final String NAME = "sparse_encoding";
+public class NeuralSparseQueryBuilder extends AbstractQueryBuilder {
+ public static final String NAME = "neural_sparse";
@VisibleForTesting
static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text");
@VisibleForTesting
static final ParseField MODEL_ID_FIELD = new ParseField("model_id");
+ @VisibleForTesting
+ static final ParseField MAX_TOKEN_SCORE_FIELD = new ParseField("max_token_score");
private static MLCommonsClientAccessor ML_CLIENT;
public static void initialize(MLCommonsClientAccessor mlClient) {
- SparseEncodingQueryBuilder.ML_CLIENT = mlClient;
+ NeuralSparseQueryBuilder.ML_CLIENT = mlClient;
}
private String fieldName;
private String queryText;
private String modelId;
+ private Float maxTokenScore;
private Supplier> queryTokensSupplier;
/**
@@ -79,11 +83,12 @@ public static void initialize(MLCommonsClientAccessor mlClient) {
* @param in StreamInput to initialize object from
* @throws IOException thrown if unable to read from input stream
*/
- public SparseEncodingQueryBuilder(StreamInput in) throws IOException {
+ public NeuralSparseQueryBuilder(StreamInput in) throws IOException {
super(in);
this.fieldName = in.readString();
this.queryText = in.readString();
this.modelId = in.readString();
+ this.maxTokenScore = in.readOptionalFloat();
}
@Override
@@ -91,6 +96,7 @@ protected void doWriteTo(StreamOutput out) throws IOException {
out.writeString(fieldName);
out.writeString(queryText);
out.writeString(modelId);
+ out.writeOptionalFloat(maxTokenScore);
}
@Override
@@ -99,6 +105,7 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
xContentBuilder.startObject(fieldName);
xContentBuilder.field(QUERY_TEXT_FIELD.getPreferredName(), queryText);
xContentBuilder.field(MODEL_ID_FIELD.getPreferredName(), modelId);
+ if (maxTokenScore != null) xContentBuilder.field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), maxTokenScore);
printBoostAndQueryName(xContentBuilder);
xContentBuilder.endObject();
xContentBuilder.endObject();
@@ -108,15 +115,16 @@ protected void doXContent(XContentBuilder xContentBuilder, Params params) throws
* The expected parsing form looks like:
* "SAMPLE_FIELD": {
* "query_text": "string",
- * "model_id": "string"
+ * "model_id": "string",
+ * "max_token_score": float (optional)
* }
*
* @param parser XContentParser
* @return NeuralQueryBuilder
* @throws IOException can be thrown by parser
*/
- public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) throws IOException {
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder();
+ public static NeuralSparseQueryBuilder fromXContent(XContentParser parser) throws IOException {
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder();
if (parser.currentToken() != XContentParser.Token.START_OBJECT) {
throw new ParsingException(parser.getTokenLocation(), "First token of " + NAME + "query must be START_OBJECT");
}
@@ -146,11 +154,14 @@ public static SparseEncodingQueryBuilder fromXContent(XContentParser parser) thr
sparseEncodingQueryBuilder.modelId(),
String.format(Locale.ROOT, "%s field must be provided for [%s] query", MODEL_ID_FIELD.getPreferredName(), NAME)
);
+ if (sparseEncodingQueryBuilder.maxTokenScore != null && sparseEncodingQueryBuilder.maxTokenScore <= 0) {
+ throw new IllegalArgumentException(MAX_TOKEN_SCORE_FIELD.getPreferredName() + " must be larger than 0.");
+ }
return sparseEncodingQueryBuilder;
}
- private static void parseQueryParams(XContentParser parser, SparseEncodingQueryBuilder sparseEncodingQueryBuilder) throws IOException {
+ private static void parseQueryParams(XContentParser parser, NeuralSparseQueryBuilder sparseEncodingQueryBuilder) throws IOException {
XContentParser.Token token;
String currentFieldName = "";
while ((token = parser.nextToken()) != XContentParser.Token.END_OBJECT) {
@@ -165,6 +176,8 @@ private static void parseQueryParams(XContentParser parser, SparseEncodingQueryB
sparseEncodingQueryBuilder.queryText(parser.text());
} else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
sparseEncodingQueryBuilder.modelId(parser.text());
+ } else if (MAX_TOKEN_SCORE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) {
+ sparseEncodingQueryBuilder.maxTokenScore(parser.floatValue());
} else {
throw new ParsingException(
parser.getTokenLocation(),
@@ -200,9 +213,10 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) throws
}, actionListener::onFailure)
))
);
- return new SparseEncodingQueryBuilder().fieldName(fieldName)
+ return new NeuralSparseQueryBuilder().fieldName(fieldName)
.queryText(queryText)
.modelId(modelId)
+ .maxTokenScore(maxTokenScore)
.queryTokensSupplier(queryTokensSetOnce::get);
}
@@ -214,9 +228,14 @@ protected Query doToQuery(QueryShardContext context) throws IOException {
Map queryTokens = queryTokensSupplier.get();
validateQueryTokens(queryTokens);
+ final Float scoreUpperBound = maxTokenScore != null ? maxTokenScore : Float.MAX_VALUE;
+
BooleanQuery.Builder builder = new BooleanQuery.Builder();
for (Map.Entry entry : queryTokens.entrySet()) {
- builder.add(FeatureField.newLinearQuery(fieldName, entry.getKey(), entry.getValue()), BooleanClause.Occur.SHOULD);
+ builder.add(
+ new BoostQuery(new BoundedLinearFeatureQuery(fieldName, entry.getKey(), scoreUpperBound), entry.getValue()),
+ BooleanClause.Occur.SHOULD
+ );
}
return builder.build();
}
@@ -254,18 +273,19 @@ private static void validateQueryTokens(Map queryTokens) {
}
@Override
- protected boolean doEquals(SparseEncodingQueryBuilder obj) {
+ protected boolean doEquals(NeuralSparseQueryBuilder obj) {
if (this == obj) return true;
if (obj == null || getClass() != obj.getClass()) return false;
EqualsBuilder equalsBuilder = new EqualsBuilder().append(fieldName, obj.fieldName)
.append(queryText, obj.queryText)
- .append(modelId, obj.modelId);
+ .append(modelId, obj.modelId)
+ .append(maxTokenScore, obj.maxTokenScore);
return equalsBuilder.isEquals();
}
@Override
protected int doHashCode() {
- return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).toHashCode();
+ return new HashCodeBuilder().append(fieldName).append(queryText).append(modelId).append(maxTokenScore).toHashCode();
}
@Override
diff --git a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java
index 76ce0fa16..853fc743d 100644
--- a/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java
+++ b/src/main/java/org/opensearch/neuralsearch/util/TokenWeightUtil.java
@@ -12,7 +12,7 @@
import java.util.stream.Collectors;
/**
- * Utility class for working with sparse_encoding queries and ingest processor.
+ * Utility class for working with neural_sparse queries and ingest processor.
* Used to fetch the (token, weight) Map from the response returned by {@link org.opensearch.neuralsearch.ml.MLCommonsClientAccessor}
*
*/
diff --git a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
similarity index 73%
rename from src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java
rename to src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
index 6cb122c4f..34850dcb7 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
@@ -11,9 +11,10 @@
import static org.opensearch.index.query.AbstractQueryBuilder.BOOST_FIELD;
import static org.opensearch.index.query.AbstractQueryBuilder.NAME_FIELD;
import static org.opensearch.neuralsearch.TestUtils.xContentBuilderToMap;
-import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.MODEL_ID_FIELD;
-import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.NAME;
-import static org.opensearch.neuralsearch.query.SparseEncodingQueryBuilder.QUERY_TEXT_FIELD;
+import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MAX_TOKEN_SCORE_FIELD;
+import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.MODEL_ID_FIELD;
+import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.NAME;
+import static org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder.QUERY_TEXT_FIELD;
import java.io.IOException;
import java.util.List;
@@ -42,13 +43,14 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.test.OpenSearchTestCase;
-public class SparseEncodingQueryBuilderTests extends OpenSearchTestCase {
+public class NeuralSparseQueryBuilderTests extends OpenSearchTestCase {
private static final String FIELD_NAME = "testField";
private static final String QUERY_TEXT = "Hello world!";
private static final String MODEL_ID = "mfgfgdsfgfdgsde";
private static final float BOOST = 1.8f;
private static final String QUERY_NAME = "queryName";
+ private static final Float MAX_TOKEN_SCORE = 123f;
private static final Supplier> QUERY_TOKENS_SUPPLIER = () -> Map.of("hello", 1.f, "world", 2.f);
@SneakyThrows
@@ -71,7 +73,7 @@ public void testFromXContent_whenBuiltWithQueryText_thenBuildSuccessfully() {
XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser);
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);
assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText());
@@ -85,6 +87,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
"VECTOR_FIELD": {
"query_text": "string",
"model_id": "string",
+ "max_token_score": 123.0,
"boost": 10.0,
"_name": "something",
}
@@ -95,6 +98,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
.startObject(FIELD_NAME)
.field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
+ .field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), MAX_TOKEN_SCORE)
.field(BOOST_FIELD.getPreferredName(), BOOST)
.field(NAME_FIELD.getPreferredName(), QUERY_NAME)
.endObject()
@@ -102,11 +106,12 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder = SparseEncodingQueryBuilder.fromXContent(contentParser);
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);
assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText());
assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId());
+ assertEquals(MAX_TOKEN_SCORE, sparseEncodingQueryBuilder.maxTokenScore(), 0.0);
assertEquals(BOOST, sparseEncodingQueryBuilder.boost(), 0.0);
assertEquals(QUERY_NAME, sparseEncodingQueryBuilder.queryName());
}
@@ -137,7 +142,7 @@ public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() {
XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
- expectThrows(ParsingException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
+ expectThrows(ParsingException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}
@SneakyThrows
@@ -158,7 +163,31 @@ public void testFromXContent_whenBuildWithMissingQuery_thenFail() {
XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
- expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
+ expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
+ }
+
+ @SneakyThrows
+ public void testFromXContent_whenBuildWithNegativeMaxTokenScore_thenFail() {
+ /*
+ {
+ "VECTOR_FIELD": {
+ "query_text": "string",
+ "model_id": "string",
+ "max_token_score": -1
+ }
+ }
+ */
+ XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
+ .startObject()
+ .startObject(FIELD_NAME)
+ .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
+ .field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), -1f)
+ .endObject()
+ .endObject();
+
+ XContentParser contentParser = createParser(xContentBuilder);
+ contentParser.nextToken();
+ expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}
@SneakyThrows
@@ -179,7 +208,7 @@ public void testFromXContent_whenBuildWithMissingModelId_thenFail() {
XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
- expectThrows(IllegalArgumentException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
+ expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}
@SneakyThrows
@@ -206,15 +235,16 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() {
XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
- expectThrows(IOException.class, () -> SparseEncodingQueryBuilder.fromXContent(contentParser));
+ expectThrows(IOException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}
@SuppressWarnings("unchecked")
@SneakyThrows
public void testToXContent() {
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME)
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.modelId(MODEL_ID)
- .queryText(QUERY_TEXT);
+ .queryText(QUERY_TEXT)
+ .maxTokenScore(MAX_TOKEN_SCORE);
XContentBuilder builder = XContentFactory.jsonBuilder();
builder = sparseEncodingQueryBuilder.toXContent(builder, ToXContent.EMPTY_PARAMS);
@@ -239,13 +269,15 @@ public void testToXContent() {
assertEquals(MODEL_ID, secondInnerMap.get(MODEL_ID_FIELD.getPreferredName()));
assertEquals(QUERY_TEXT, secondInnerMap.get(QUERY_TEXT_FIELD.getPreferredName()));
+ assertEquals(MAX_TOKEN_SCORE, (Double) secondInnerMap.get(MAX_TOKEN_SCORE_FIELD.getPreferredName()), 0.0);
}
@SneakyThrows
public void testStreams() {
- SparseEncodingQueryBuilder original = new SparseEncodingQueryBuilder();
+ NeuralSparseQueryBuilder original = new NeuralSparseQueryBuilder();
original.fieldName(FIELD_NAME);
original.queryText(QUERY_TEXT);
+ original.maxTokenScore(MAX_TOKEN_SCORE);
original.modelId(MODEL_ID);
original.boost(BOOST);
original.queryName(QUERY_NAME);
@@ -260,7 +292,7 @@ public void testStreams() {
)
);
- SparseEncodingQueryBuilder copy = new SparseEncodingQueryBuilder(filterStreamInput);
+ NeuralSparseQueryBuilder copy = new NeuralSparseQueryBuilder(filterStreamInput);
assertEquals(original, copy);
}
@@ -271,64 +303,82 @@ public void testHashAndEquals() {
String queryText2 = "query text 2";
String modelId1 = "model-1";
String modelId2 = "model-2";
+ float maxTokenScore1 = 1.1f;
+ float maxTokenScore2 = 2.2f;
float boost1 = 1.8f;
float boost2 = 3.8f;
String queryName1 = "query-1";
String queryName2 = "query-2";
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baseline = new SparseEncodingQueryBuilder().fieldName(fieldName1)
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baseline = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
+ .maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);
// Identical to sparseEncodingQueryBuilder_baseline
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new SparseEncodingQueryBuilder().fieldName(fieldName1)
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder_baselineCopy = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
+ .maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);
// Identical to sparseEncodingQueryBuilder_baseline except default boost and query name
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new SparseEncodingQueryBuilder().fieldName(
- fieldName1
- ).queryText(queryText1).modelId(modelId1);
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder_defaultBoostAndQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1)
+ .queryText(queryText1)
+ .modelId(modelId1)
+ .maxTokenScore(maxTokenScore1);
// Identical to sparseEncodingQueryBuilder_baseline except diff field name
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new SparseEncodingQueryBuilder().fieldName(fieldName2)
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffFieldName = new NeuralSparseQueryBuilder().fieldName(fieldName2)
.queryText(queryText1)
.modelId(modelId1)
+ .maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);
// Identical to sparseEncodingQueryBuilder_baseline except diff query text
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new SparseEncodingQueryBuilder().fieldName(fieldName1)
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryText = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText2)
.modelId(modelId1)
+ .maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);
// Identical to sparseEncodingQueryBuilder_baseline except diff model ID
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffModelId = new SparseEncodingQueryBuilder().fieldName(fieldName1)
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffModelId = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId2)
+ .maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName1);
// Identical to sparseEncodingQueryBuilder_baseline except diff boost
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffBoost = new SparseEncodingQueryBuilder().fieldName(fieldName1)
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffBoost = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
+ .maxTokenScore(maxTokenScore1)
.boost(boost2)
.queryName(queryName1);
// Identical to sparseEncodingQueryBuilder_baseline except diff query name
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new SparseEncodingQueryBuilder().fieldName(fieldName1)
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffQueryName = new NeuralSparseQueryBuilder().fieldName(fieldName1)
.queryText(queryText1)
.modelId(modelId1)
+ .maxTokenScore(maxTokenScore1)
.boost(boost1)
.queryName(queryName2);
+ // Identical to sparseEncodingQueryBuilder_baseline except diff max token score
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder_diffMaxTokenScore = new NeuralSparseQueryBuilder().fieldName(fieldName1)
+ .queryText(queryText1)
+ .modelId(modelId1)
+ .maxTokenScore(maxTokenScore2)
+ .boost(boost1)
+ .queryName(queryName1);
+
assertEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_baseline);
assertEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_baseline.hashCode());
@@ -352,11 +402,14 @@ public void testHashAndEquals() {
assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffQueryName);
assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffQueryName.hashCode());
+
+ assertNotEquals(sparseEncodingQueryBuilder_baseline, sparseEncodingQueryBuilder_diffMaxTokenScore);
+ assertNotEquals(sparseEncodingQueryBuilder_baseline.hashCode(), sparseEncodingQueryBuilder_diffMaxTokenScore.hashCode());
}
@SneakyThrows
public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier() {
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME)
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID);
Map expectedMap = Map.of("1", 1f, "2", 2f);
@@ -366,7 +419,7 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier()
listener.onResponse(List.of(Map.of("response", List.of(expectedMap))));
return null;
}).when(mlCommonsClientAccessor).inferenceSentencesWithMapResult(any(), any(), any());
- SparseEncodingQueryBuilder.initialize(mlCommonsClientAccessor);
+ NeuralSparseQueryBuilder.initialize(mlCommonsClientAccessor);
final CountDownLatch inProgressLatch = new CountDownLatch(1);
QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class);
@@ -382,7 +435,7 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier()
return null;
}).when(queryRewriteContext).registerAsyncAction(any());
- SparseEncodingQueryBuilder queryBuilder = (SparseEncodingQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext);
+ NeuralSparseQueryBuilder queryBuilder = (NeuralSparseQueryBuilder) sparseEncodingQueryBuilder.doRewrite(queryRewriteContext);
assertNotNull(queryBuilder.queryTokensSupplier());
assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS));
assertEquals(expectedMap, queryBuilder.queryTokensSupplier().get());
@@ -390,7 +443,7 @@ public void testRewrite_whenQueryTokensSupplierNull_thenSetQueryTokensSupplier()
@SneakyThrows
public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() {
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(FIELD_NAME)
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
.queryText(QUERY_TEXT)
.modelId(MODEL_ID)
.queryTokensSupplier(QUERY_TOKENS_SUPPLIER);
diff --git a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java
similarity index 61%
rename from src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java
rename to src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java
index 54991d7e2..672ab2940 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/SparseEncodingQueryIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java
@@ -21,16 +21,16 @@
import org.opensearch.neuralsearch.TestUtils;
import org.opensearch.neuralsearch.common.BaseSparseEncodingIT;
-public class SparseEncodingQueryIT extends BaseSparseEncodingIT {
+public class NeuralSparseQueryIT extends BaseSparseEncodingIT {
private static final String TEST_BASIC_INDEX_NAME = "test-sparse-basic-index";
- private static final String TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME = "test-sparse-multi-field-index";
- private static final String TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME = "test-sparse-text-and-field-index";
+ private static final String TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME = "test-sparse-multi-field-index";
+ private static final String TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME = "test-sparse-text-and-field-index";
private static final String TEST_NESTED_INDEX_NAME = "test-sparse-nested-index";
private static final String TEST_QUERY_TEXT = "Hello world a b";
- private static final String TEST_SPARSE_ENCODING_FIELD_NAME_1 = "test-sparse-encoding-1";
- private static final String TEST_SPARSE_ENCODING_FIELD_NAME_2 = "test-sparse-encoding-2";
+ private static final String TEST_NEURAL_SPARSE_FIELD_NAME_1 = "test-sparse-encoding-1";
+ private static final String TEST_NEURAL_SPARSE_FIELD_NAME_2 = "test-sparse-encoding-2";
private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field";
- private static final String TEST_SPARSE_ENCODING_FIELD_NAME_NESTED = "nested.sparse_encoding.field";
+ private static final String TEST_NEURAL_SPARSE_FIELD_NAME_NESTED = "nested.neural_sparse.field";
private static final List TEST_TOKENS = List.of("hello", "world", "a", "b", "c");
@@ -55,7 +55,7 @@ public void tearDown() {
* Tests basic query:
* {
* "query": {
- * "sparse_encoding": {
+ * "neural_sparse": {
* "text_sparse": {
* "query_text": "Hello world a b",
* "model_id": "dcsdcasd"
@@ -68,9 +68,9 @@ public void tearDown() {
public void testBasicQueryUsingQueryText() {
initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME);
String modelId = getDeployedModelId();
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(
- TEST_SPARSE_ENCODING_FIELD_NAME_1
- ).queryText(TEST_QUERY_TEXT).modelId(modelId);
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
+ .queryText(TEST_QUERY_TEXT)
+ .modelId(modelId);
Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1);
Map firstInnerHit = getFirstInnerHit(searchResponseAsMap);
@@ -83,7 +83,47 @@ public void testBasicQueryUsingQueryText() {
* Tests basic query:
* {
* "query": {
- * "sparse_encoding": {
+ * "neural_sparse": {
+ * "text_sparse": {
+ * "query_text": "Hello world a b",
+ * "model_id": "dcsdcasd",
+ * "max_token_score": 0.00001
+ * }
+ * }
+ * }
+ * }
+ */
+ @SneakyThrows
+ public void testBasicQueryWithMaxTokenScore() { // verifies the hit score when maxTokenScore is set on the query builder
+ float maxTokenScore = 0.00001f; // passed to the query and reused as the Math.min cap in the expected-score sum below
+ initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME);
+ String modelId = getDeployedModelId();
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
+ .queryText(TEST_QUERY_TEXT)
+ .modelId(modelId)
+ .maxTokenScore(maxTokenScore);
+ Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1);
+ Map firstInnerHit = getFirstInnerHit(searchResponseAsMap);
+
+ assertEquals("1", firstInnerHit.get("_id")); // "1" is the id initializeIndexIfNotExist gives the doc it adds to the basic index
+ Map queryTokens = runSparseModelInference(modelId, TEST_QUERY_TEXT); // same model id and query text as the query above
+ float expectedScore = 0f;
+ for (Map.Entry entry : queryTokens.entrySet()) {
+ if (testRankFeaturesDoc.containsKey(entry.getKey())) { // only query tokens that are also keys of testRankFeaturesDoc contribute
+ expectedScore += entry.getValue() * Math.min( // query-side value times the doc-side value, capped at maxTokenScore
+ getFeatureFieldCompressedNumber(testRankFeaturesDoc.get(entry.getKey())),
+ maxTokenScore
+ );
+ }
+ }
+ assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA);
+ }
+
+ /**
+ * Tests basic query with boost:
+ * {
+ * "query": {
+ * "neural_sparse": {
* "text_sparse": {
* "query_text": "Hello world a b",
* "model_id": "dcsdcasd",
@@ -97,9 +137,10 @@ public void testBasicQueryUsingQueryText() {
public void testBoostQuery() {
initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME);
String modelId = getDeployedModelId();
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(
- TEST_SPARSE_ENCODING_FIELD_NAME_1
- ).queryText(TEST_QUERY_TEXT).modelId(modelId).boost(2.0f);
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
+ .queryText(TEST_QUERY_TEXT)
+ .modelId(modelId)
+ .boost(2.0f);
Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, sparseEncodingQueryBuilder, 1);
Map firstInnerHit = getFirstInnerHit(searchResponseAsMap);
@@ -117,7 +158,7 @@ public void testBoostQuery() {
* "rescore": {
* "query": {
* "rescore_query": {
- * "sparse_encoding": {
+ * "neural_sparse": {
* "text_sparse": {
* * "query_text": "Hello world a b",
* * "model_id": "dcsdcasd"
@@ -133,9 +174,9 @@ public void testRescoreQuery() {
initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME);
String modelId = getDeployedModelId();
MatchAllQueryBuilder matchAllQueryBuilder = new MatchAllQueryBuilder();
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(
- TEST_SPARSE_ENCODING_FIELD_NAME_1
- ).queryText(TEST_QUERY_TEXT).modelId(modelId);
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
+ .queryText(TEST_QUERY_TEXT)
+ .modelId(modelId);
Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, matchAllQueryBuilder, sparseEncodingQueryBuilder, 1);
Map firstInnerHit = getFirstInnerHit(searchResponseAsMap);
@@ -150,13 +191,13 @@ public void testRescoreQuery() {
* "query": {
* "bool" : {
* "should": [
- * "sparse_encoding": {
+ * "neural_sparse": {
* "field1": {
* "query_text": "Hello world a b",
* "model_id": "dcsdcasd"
* }
* },
- * "sparse_encoding": {
+ * "neural_sparse": {
* "field2": {
* "query_text": "Hello world a b",
* "model_id": "dcsdcasd"
@@ -169,20 +210,20 @@ public void testRescoreQuery() {
*/
@SneakyThrows
public void testBooleanQuery_withMultipleSparseEncodingQueries() {
- initializeIndexIfNotExist(TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME);
+ initializeIndexIfNotExist(TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME);
String modelId = getDeployedModelId();
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder1 = new SparseEncodingQueryBuilder().fieldName(
- TEST_SPARSE_ENCODING_FIELD_NAME_1
- ).queryText(TEST_QUERY_TEXT).modelId(modelId);
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder2 = new SparseEncodingQueryBuilder().fieldName(
- TEST_SPARSE_ENCODING_FIELD_NAME_2
- ).queryText(TEST_QUERY_TEXT).modelId(modelId);
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder1 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
+ .queryText(TEST_QUERY_TEXT)
+ .modelId(modelId);
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder2 = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_2)
+ .queryText(TEST_QUERY_TEXT)
+ .modelId(modelId);
boolQueryBuilder.should(sparseEncodingQueryBuilder1).should(sparseEncodingQueryBuilder2);
- Map searchResponseAsMap = search(TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME, boolQueryBuilder, 1);
+ Map searchResponseAsMap = search(TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME, boolQueryBuilder, 1);
Map firstInnerHit = getFirstInnerHit(searchResponseAsMap);
assertEquals("1", firstInnerHit.get("_id"));
@@ -196,13 +237,13 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries() {
* "query": {
* "bool" : {
* "should": [
- * "sparse_encoding": {
+ * "neural_sparse": {
* "field1": {
* "query_text": "Hello world a b",
* "model_id": "dcsdcasd"
* }
* },
- * "sparse_encoding": {
+ * "neural_sparse": {
* "field2": {
* "query_text": "Hello world a b",
* "model_id": "dcsdcasd"
@@ -215,17 +256,17 @@ public void testBooleanQuery_withMultipleSparseEncodingQueries() {
*/
@SneakyThrows
public void testBooleanQuery_withSparseEncodingAndBM25Queries() {
- initializeIndexIfNotExist(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME);
+ initializeIndexIfNotExist(TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME);
String modelId = getDeployedModelId();
BoolQueryBuilder boolQueryBuilder = new BoolQueryBuilder();
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(
- TEST_SPARSE_ENCODING_FIELD_NAME_1
- ).queryText(TEST_QUERY_TEXT).modelId(modelId);
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_NEURAL_SPARSE_FIELD_NAME_1)
+ .queryText(TEST_QUERY_TEXT)
+ .modelId(modelId);
MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT);
boolQueryBuilder.should(sparseEncodingQueryBuilder).should(matchQueryBuilder);
- Map searchResponseAsMap = search(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME, boolQueryBuilder, 1);
+ Map searchResponseAsMap = search(TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME, boolQueryBuilder, 1);
Map firstInnerHit = getFirstInnerHit(searchResponseAsMap);
assertEquals("1", firstInnerHit.get("_id"));
@@ -235,41 +276,41 @@ public void testBooleanQuery_withSparseEncodingAndBM25Queries() {
@SneakyThrows
public void testBasicQueryUsingQueryText_whenQueryWrongFieldType_thenFail() {
- initializeIndexIfNotExist(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME);
+ initializeIndexIfNotExist(TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME);
String modelId = getDeployedModelId();
- SparseEncodingQueryBuilder sparseEncodingQueryBuilder = new SparseEncodingQueryBuilder().fieldName(TEST_TEXT_FIELD_NAME_1)
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(TEST_TEXT_FIELD_NAME_1)
.queryText(TEST_QUERY_TEXT)
.modelId(modelId);
- expectThrows(ResponseException.class, () -> search(TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME, sparseEncodingQueryBuilder, 1));
+ expectThrows(ResponseException.class, () -> search(TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME, sparseEncodingQueryBuilder, 1));
}
@SneakyThrows
protected void initializeIndexIfNotExist(String indexName) {
if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(indexName)) {
- prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1));
- addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1), List.of(testRankFeaturesDoc));
+ prepareSparseEncodingIndex(indexName, List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1));
+ addSparseEncodingDoc(indexName, "1", List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1), List.of(testRankFeaturesDoc));
assertEquals(1, getDocCount(indexName));
}
- if (TEST_MULTI_SPARSE_ENCODING_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) {
- prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2));
+ if (TEST_MULTI_NEURAL_SPARSE_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) {
+ prepareSparseEncodingIndex(indexName, List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1, TEST_NEURAL_SPARSE_FIELD_NAME_2));
addSparseEncodingDoc(
indexName,
"1",
- List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1, TEST_SPARSE_ENCODING_FIELD_NAME_2),
+ List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1, TEST_NEURAL_SPARSE_FIELD_NAME_2),
List.of(testRankFeaturesDoc, testRankFeaturesDoc)
);
assertEquals(1, getDocCount(indexName));
}
- if (TEST_TEXT_AND_SPARSE_ENCODING_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) {
- prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1));
+ if (TEST_TEXT_AND_NEURAL_SPARSE_FIELD_INDEX_NAME.equals(indexName) && !indexExists(indexName)) {
+ prepareSparseEncodingIndex(indexName, List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1));
addSparseEncodingDoc(
indexName,
"1",
- List.of(TEST_SPARSE_ENCODING_FIELD_NAME_1),
+ List.of(TEST_NEURAL_SPARSE_FIELD_NAME_1),
List.of(testRankFeaturesDoc),
List.of(TEST_TEXT_FIELD_NAME_1),
List.of(TEST_QUERY_TEXT)
@@ -278,8 +319,8 @@ protected void initializeIndexIfNotExist(String indexName) {
}
if (TEST_NESTED_INDEX_NAME.equals(indexName) && !indexExists(indexName)) {
- prepareSparseEncodingIndex(indexName, List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED));
- addSparseEncodingDoc(indexName, "1", List.of(TEST_SPARSE_ENCODING_FIELD_NAME_NESTED), List.of(testRankFeaturesDoc));
+ prepareSparseEncodingIndex(indexName, List.of(TEST_NEURAL_SPARSE_FIELD_NAME_NESTED));
+ addSparseEncodingDoc(indexName, "1", List.of(TEST_NEURAL_SPARSE_FIELD_NAME_NESTED), List.of(testRankFeaturesDoc));
assertEquals(1, getDocCount(TEST_NESTED_INDEX_NAME));
}
}
diff --git a/src/test/resources/processor/SparseEncodingIndexMappings.json b/src/test/resources/processor/SparseEncodingIndexMappings.json
index 87dee278e..9748e8f3d 100644
--- a/src/test/resources/processor/SparseEncodingIndexMappings.json
+++ b/src/test/resources/processor/SparseEncodingIndexMappings.json
@@ -23,4 +23,4 @@
}
}
}
-}
\ No newline at end of file
+}
diff --git a/src/test/resources/processor/SparseEncodingPipelineConfiguration.json b/src/test/resources/processor/SparseEncodingPipelineConfiguration.json
index 82d13c8fe..04a4baf80 100644
--- a/src/test/resources/processor/SparseEncodingPipelineConfiguration.json
+++ b/src/test/resources/processor/SparseEncodingPipelineConfiguration.json
@@ -15,4 +15,4 @@
}
}
]
-}
\ No newline at end of file
+}
diff --git a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json
index c45334bae..50b4b8a9b 100644
--- a/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json
+++ b/src/test/resources/processor/UploadSparseEncodingModelRequestBody.json
@@ -7,4 +7,4 @@
"model_group_id": "",
"model_content_hash_value": "b345e9e943b62c405a8dd227ef2c46c84c5ff0a0b71b584be9132b37bce91a9a",
"url": "https://github.com/opensearch-project/ml-commons/raw/main/ml-algorithms/src/test/resources/org/opensearch/ml/engine/algorithms/sparse_encoding/sparse_demo.zip"
-}
\ No newline at end of file
+}