diff --git a/CHANGELOG.md b/CHANGELOG.md
index 10063185c..438c4d65a 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,3 +21,4 @@ Fixed exception for case when Hybrid query being wrapped into bool query ([#490]
### Documentation
### Maintenance
### Refactoring
+Deprecate the `max_token_score` field in `neural_sparse` query clause ([#478](https://github.com/opensearch-project/neural-search/pull/478))
diff --git a/src/main/java/org/apache/lucene/BoundedLinearFeatureQuery.java b/src/main/java/org/apache/lucene/BoundedLinearFeatureQuery.java
deleted file mode 100644
index a914f3156..000000000
--- a/src/main/java/org/apache/lucene/BoundedLinearFeatureQuery.java
+++ /dev/null
@@ -1,237 +0,0 @@
-/*
- * SPDX-License-Identifier: Apache-2.0
- *
- * The OpenSearch Contributors require contributions made to
- * this file be licensed under the Apache-2.0 license or a
- * compatible open source license.
- */
-
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-/*
- * Modifications Copyright OpenSearch Contributors. See
- * GitHub history for details.
- */
-
-/*
- * This class is built based on lucene FeatureQuery. We use LinearFuntion to
- * build the query and add an upperbound to it.
- */
-
-package org.apache.lucene;
-
-import java.io.IOException;
-import java.util.Objects;
-
-import org.apache.lucene.index.ImpactsEnum;
-import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.index.PostingsEnum;
-import org.apache.lucene.index.Term;
-import org.apache.lucene.index.Terms;
-import org.apache.lucene.index.TermsEnum;
-import org.apache.lucene.search.DocIdSetIterator;
-import org.apache.lucene.search.Explanation;
-import org.apache.lucene.search.ImpactsDISI;
-import org.apache.lucene.search.IndexSearcher;
-import org.apache.lucene.search.MaxScoreCache;
-import org.apache.lucene.search.Query;
-import org.apache.lucene.search.QueryVisitor;
-import org.apache.lucene.search.ScoreMode;
-import org.apache.lucene.search.Scorer;
-import org.apache.lucene.search.TermQuery;
-import org.apache.lucene.search.Weight;
-import org.apache.lucene.search.similarities.Similarity.SimScorer;
-import org.apache.lucene.util.BytesRef;
-
-/**
- * The feature queries of input tokens are wrapped by lucene BooleanQuery, which use WAND algorithm
- * to accelerate the execution. The WAND algorithm leverage the score upper bound of sub-queries to
- * skip non-competitive tokens. However, origin lucene FeatureQuery use Float.MAX_VALUE as the score
- * upper bound, and this invalidates WAND.
- *
- * To mitigate this issue, we rewrite the FeatureQuery to BoundedLinearFeatureQuery. The caller can
- * set the token score upperbound of this query. And according to our use case, we use LinearFunction
- * as the score function.
- *
- * This class combines both FeatureQuery
- * and FeatureField together
- * and will be deprecated after OpenSearch upgraded lucene to version 9.8.
- */
-
-public final class BoundedLinearFeatureQuery extends Query {
-
- private final String fieldName;
- private final String featureName;
- private final Float scoreUpperBound;
-
- public BoundedLinearFeatureQuery(String fieldName, String featureName, Float scoreUpperBound) {
- this.fieldName = Objects.requireNonNull(fieldName);
- this.featureName = Objects.requireNonNull(featureName);
- this.scoreUpperBound = Objects.requireNonNull(scoreUpperBound);
- }
-
- @Override
- public Query rewrite(IndexSearcher indexSearcher) throws IOException {
- // LinearFunction return same object for rewrite
- return super.rewrite(indexSearcher);
- }
-
- @Override
- public boolean equals(Object obj) {
- if (obj == null || getClass() != obj.getClass()) {
- return false;
- }
- BoundedLinearFeatureQuery that = (BoundedLinearFeatureQuery) obj;
- return Objects.equals(fieldName, that.fieldName)
- && Objects.equals(featureName, that.featureName)
- && Objects.equals(scoreUpperBound, that.scoreUpperBound);
- }
-
- @Override
- public int hashCode() {
- int h = getClass().hashCode();
- h = 31 * h + fieldName.hashCode();
- h = 31 * h + featureName.hashCode();
- h = 31 * h + scoreUpperBound.hashCode();
- return h;
- }
-
- @Override
- public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException {
- if (!scoreMode.needsScores()) {
- // We don't need scores (e.g. for faceting), and since features are stored as terms,
- // allow TermQuery to optimize in this case
- TermQuery tq = new TermQuery(new Term(fieldName, featureName));
- return searcher.rewrite(tq).createWeight(searcher, scoreMode, boost);
- }
-
- return new Weight(this) {
-
- @Override
- public boolean isCacheable(LeafReaderContext ctx) {
- return false;
- }
-
- @Override
- public Explanation explain(LeafReaderContext context, int doc) throws IOException {
- String desc = "weight(" + getQuery() + " in " + doc + ") [\" BoundedLinearFeatureQuery \"]";
-
- Terms terms = context.reader().terms(fieldName);
- if (terms == null) {
- return Explanation.noMatch(desc + ". Field " + fieldName + " doesn't exist.");
- }
- TermsEnum termsEnum = terms.iterator();
- if (termsEnum.seekExact(new BytesRef(featureName)) == false) {
- return Explanation.noMatch(desc + ". Feature " + featureName + " doesn't exist.");
- }
-
- PostingsEnum postings = termsEnum.postings(null, PostingsEnum.FREQS);
- if (postings.advance(doc) != doc) {
- return Explanation.noMatch(desc + ". Feature " + featureName + " isn't set.");
- }
-
- int freq = postings.freq();
- float featureValue = decodeFeatureValue(freq);
- float score = boost * featureValue;
- return Explanation.match(
- score,
- "Linear function on the " + fieldName + " field for the " + featureName + " feature, computed as w * S from:",
- Explanation.match(boost, "w, weight of this function"),
- Explanation.match(featureValue, "S, feature value")
- );
- }
-
- @Override
- public Scorer scorer(LeafReaderContext context) throws IOException {
- Terms terms = Terms.getTerms(context.reader(), fieldName);
- TermsEnum termsEnum = terms.iterator();
- if (termsEnum.seekExact(new BytesRef(featureName)) == false) {
- return null;
- }
-
- final SimScorer scorer = new SimScorer() {
- @Override
- public float score(float freq, long norm) {
- return boost * decodeFeatureValue(freq);
- }
- };
- final ImpactsEnum impacts = termsEnum.impacts(PostingsEnum.FREQS);
- MaxScoreCache maxScoreCache = new MaxScoreCache(impacts, scorer);
- final ImpactsDISI impactsDisi = new ImpactsDISI(impacts, maxScoreCache);
-
- return new Scorer(this) {
-
- @Override
- public int docID() {
- return impacts.docID();
- }
-
- @Override
- public float score() throws IOException {
- return scorer.score(impacts.freq(), 1L);
- }
-
- @Override
- public DocIdSetIterator iterator() {
- return impactsDisi;
- }
-
- @Override
- public int advanceShallow(int target) throws IOException {
- return impactsDisi.getMaxScoreCache().advanceShallow(target);
- }
-
- @Override
- public float getMaxScore(int upTo) throws IOException {
- return impactsDisi.getMaxScoreCache().getMaxScore(upTo);
- }
-
- @Override
- public void setMinCompetitiveScore(float minScore) {
- impactsDisi.setMinCompetitiveScore(minScore);
- }
- };
- }
- };
- }
-
- @Override
- public void visit(QueryVisitor visitor) {
- if (visitor.acceptField(fieldName)) {
- visitor.visitLeaf(this);
- }
- }
-
- @Override
- public String toString(String field) {
- return "BoundedLinearFeatureQuery(field=" + fieldName + ", feature=" + featureName + ", scoreUpperBound=" + scoreUpperBound + ")";
- }
-
- // the field and decodeFeatureValue are modified from FeatureField.decodeFeatureValue
- static final int MAX_FREQ = Float.floatToIntBits(Float.MAX_VALUE) >>> 15;
-
- // Rewriting this function to make scoreUpperBound work.
- private float decodeFeatureValue(float freq) {
- if (freq > MAX_FREQ) {
- return scoreUpperBound;
- }
- int tf = (int) freq; // lossless
- int featureBits = tf << 15;
- return Math.min(Float.intBitsToFloat(featureBits), scoreUpperBound);
- }
-}
diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
index d883af23d..20eeb2e11 100644
--- a/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
+++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilder.java
@@ -21,10 +21,9 @@
import org.apache.commons.lang.StringUtils;
import org.apache.commons.lang.builder.EqualsBuilder;
import org.apache.commons.lang.builder.HashCodeBuilder;
-import org.apache.lucene.BoundedLinearFeatureQuery;
+import org.apache.lucene.document.FeatureField;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
-import org.apache.lucene.search.BoostQuery;
import org.apache.lucene.search.Query;
import org.opensearch.common.SetOnce;
import org.opensearch.core.ParseField;
@@ -62,8 +61,11 @@ public class NeuralSparseQueryBuilder extends AbstractQueryBuilder queryTokens = queryTokensSupplier.get();
validateQueryTokens(queryTokens);
- final Float scoreUpperBound = maxTokenScore != null ? maxTokenScore : Float.MAX_VALUE;
-
BooleanQuery.Builder builder = new BooleanQuery.Builder();
for (Map.Entry entry : queryTokens.entrySet()) {
- builder.add(
- new BoostQuery(new BoundedLinearFeatureQuery(fieldName, entry.getKey(), scoreUpperBound), entry.getValue()),
- BooleanClause.Occur.SHOULD
- );
+ builder.add(FeatureField.newLinearQuery(fieldName, entry.getKey(), entry.getValue()), BooleanClause.Occur.SHOULD);
}
return builder.build();
}
diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
index a50ab4fb8..9d1a1627b 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryBuilderTests.java
@@ -26,6 +26,9 @@
import lombok.SneakyThrows;
+import org.apache.lucene.document.FeatureField;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
import org.opensearch.client.Client;
import org.opensearch.common.SetOnce;
import org.opensearch.common.io.stream.BytesStreamOutput;
@@ -38,9 +41,11 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
+import org.opensearch.index.mapper.MappedFieldType;
import org.opensearch.index.query.MatchAllQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryRewriteContext;
+import org.opensearch.index.query.QueryShardContext;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.test.OpenSearchTestCase;
@@ -88,7 +93,6 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
"VECTOR_FIELD": {
"query_text": "string",
"model_id": "string",
- "max_token_score": 123.0,
"boost": 10.0,
"_name": "something",
}
@@ -99,7 +103,6 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
.startObject(FIELD_NAME)
.field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
- .field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), MAX_TOKEN_SCORE)
.field(BOOST_FIELD.getPreferredName(), BOOST)
.field(NAME_FIELD.getPreferredName(), QUERY_NAME)
.endObject()
@@ -112,22 +115,19 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() {
assertEquals(FIELD_NAME, sparseEncodingQueryBuilder.fieldName());
assertEquals(QUERY_TEXT, sparseEncodingQueryBuilder.queryText());
assertEquals(MODEL_ID, sparseEncodingQueryBuilder.modelId());
- assertEquals(MAX_TOKEN_SCORE, sparseEncodingQueryBuilder.maxTokenScore(), 0.0);
assertEquals(BOOST, sparseEncodingQueryBuilder.boost(), 0.0);
assertEquals(QUERY_NAME, sparseEncodingQueryBuilder.queryName());
}
@SneakyThrows
- public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() {
+ public void testFromXContent_whenBuiltWithMaxTokenScore_thenThrowWarning() {
/*
{
"VECTOR_FIELD": {
"query_text": "string",
"model_id": "string",
- "boost": 10.0,
- "_name": "something",
- },
- "invalid": 10
+ "max_token_score": 123.0
+ }
}
*/
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
@@ -135,46 +135,51 @@ public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() {
.startObject(FIELD_NAME)
.field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
- .field(BOOST_FIELD.getPreferredName(), BOOST)
- .field(NAME_FIELD.getPreferredName(), QUERY_NAME)
+ .field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), MAX_TOKEN_SCORE)
.endObject()
- .field("invalid", 10)
.endObject();
XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
- expectThrows(ParsingException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = NeuralSparseQueryBuilder.fromXContent(contentParser);
+ assertWarnings("Deprecated field [max_token_score] used, this field is unused and will be removed entirely");
}
@SneakyThrows
- public void testFromXContent_whenBuildWithMissingQuery_thenFail() {
+ public void testFromXContent_whenBuildWithMultipleRootFields_thenFail() {
/*
{
"VECTOR_FIELD": {
- "model_id": "string"
- }
+ "query_text": "string",
+ "model_id": "string",
+ "boost": 10.0,
+ "_name": "something",
+ },
+ "invalid": 10
}
*/
XContentBuilder xContentBuilder = XContentFactory.jsonBuilder()
.startObject()
.startObject(FIELD_NAME)
+ .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT)
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
+ .field(BOOST_FIELD.getPreferredName(), BOOST)
+ .field(NAME_FIELD.getPreferredName(), QUERY_NAME)
.endObject()
+ .field("invalid", 10)
.endObject();
XContentParser contentParser = createParser(xContentBuilder);
contentParser.nextToken();
- expectThrows(IllegalArgumentException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
+ expectThrows(ParsingException.class, () -> NeuralSparseQueryBuilder.fromXContent(contentParser));
}
@SneakyThrows
- public void testFromXContent_whenBuildWithNegativeMaxTokenScore_thenFail() {
+ public void testFromXContent_whenBuildWithMissingQuery_thenFail() {
/*
{
"VECTOR_FIELD": {
- "query_text": "string",
- "model_id": "string",
- "max_token_score": -1
+ "model_id": "string"
}
}
*/
@@ -182,7 +187,6 @@ public void testFromXContent_whenBuildWithNegativeMaxTokenScore_thenFail() {
.startObject()
.startObject(FIELD_NAME)
.field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID)
- .field(MAX_TOKEN_SCORE_FIELD.getPreferredName(), -1f)
.endObject()
.endObject();
@@ -498,4 +502,23 @@ public void testRewrite_whenQueryTokensSupplierSet_thenReturnSelf() {
queryBuilder = sparseEncodingQueryBuilder.doRewrite(null);
assertTrue(queryBuilder == sparseEncodingQueryBuilder);
}
+
+ @SneakyThrows
+ public void testDoToQuery_successfulDoToQuery() {
+ NeuralSparseQueryBuilder sparseEncodingQueryBuilder = new NeuralSparseQueryBuilder().fieldName(FIELD_NAME)
+ .maxTokenScore(MAX_TOKEN_SCORE)
+ .queryText(QUERY_TEXT)
+ .modelId(MODEL_ID)
+ .queryTokensSupplier(QUERY_TOKENS_SUPPLIER);
+ QueryShardContext mockedQueryShardContext = mock(QueryShardContext.class);
+ MappedFieldType mockedMappedFieldType = mock(MappedFieldType.class);
+ doAnswer(invocation -> "rank_features").when(mockedMappedFieldType).typeName();
+ doAnswer(invocation -> mockedMappedFieldType).when(mockedQueryShardContext).fieldMapper(any());
+
+ BooleanQuery.Builder targetQueryBuilder = new BooleanQuery.Builder();
+ targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "hello", 1.f), BooleanClause.Occur.SHOULD);
+ targetQueryBuilder.add(FeatureField.newLinearQuery(FIELD_NAME, "world", 2.f), BooleanClause.Occur.SHOULD);
+
+ assertEquals(sparseEncodingQueryBuilder.doToQuery(mockedQueryShardContext), targetQueryBuilder.build());
+ }
}
diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java
index 672ab2940..12bd1c1cb 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralSparseQueryIT.java
@@ -106,16 +106,7 @@ public void testBasicQueryWithMaxTokenScore() {
Map firstInnerHit = getFirstInnerHit(searchResponseAsMap);
assertEquals("1", firstInnerHit.get("_id"));
- Map queryTokens = runSparseModelInference(modelId, TEST_QUERY_TEXT);
- float expectedScore = 0f;
- for (Map.Entry entry : queryTokens.entrySet()) {
- if (testRankFeaturesDoc.containsKey(entry.getKey())) {
- expectedScore += entry.getValue() * Math.min(
- getFeatureFieldCompressedNumber(testRankFeaturesDoc.get(entry.getKey())),
- maxTokenScore
- );
- }
- }
+ float expectedScore = computeExpectedScore(modelId, testRankFeaturesDoc, TEST_QUERY_TEXT);
assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), DELTA);
}