diff --git a/CHANGELOG.md b/CHANGELOG.md index f5d0ab4f7af38..9396907a107f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,8 +24,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Remove handling of index.mapper.dynamic in AutoCreateIndex([#13067](https://github.com/opensearch-project/OpenSearch/pull/13067)) ### Fixed -- Fix negative RequestStats metric issue ([#13553](https://github.com/opensearch-project/OpenSearch/pull/13553)) - Fix get field mapping API returns 404 error in mixed cluster with multiple versions ([#13624](https://github.com/opensearch-project/OpenSearch/pull/13624)) +- Replace negative input scores to function/script score queries with zero to avoid downstream exception ([#13627](https://github.com/opensearch-project/OpenSearch/pull/13627)) ### Security diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml index a006fde630716..e48add37f44b8 100644 --- a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml @@ -482,3 +482,79 @@ }] - match: { error.root_cause.0.type: "illegal_argument_exception" } - match: { error.root_cause.0.reason: "script score function must not produce negative scores, but got: [-9.0]"} + +--- +"Do not throw exception if input score is negative": + - do: + index: + index: test + id: 1 + body: { "color" : "orange red yellow" } + - do: + index: + index: test + id: 2 + body: { "color": "orange red purple", "shape": "red square" } + - do: + index: + index: test + id: 3 + body: { "color" : "orange red yellow purple" } + - do: + indices.refresh: { } + - do: + search: + index: test + body: + query: + function_score: + query: + multi_match: + query: "red" + type: "cross_fields" + fields: [ "color", "shape^100"] + 
tie_breaker: 0.1 + functions: [{ + "script_score": { + "script": { + "lang": "painless", + "source": "_score" + } + } + }] + explain: true + - match: { hits.total.value: 3 } + - match: { hits.hits.2._score: 0.0 } + - do: + search: + index: test + body: + query: + function_score: + query: + multi_match: + query: "red" + type: "cross_fields" + fields: [ "color", "shape^100"] + tie_breaker: 0.1 + weight: 1 + explain: true + - match: { hits.total.value: 3 } + - match: { hits.hits.2._score: 0.0 } + - do: + search: + index: test + body: + query: + script_score: + query: + multi_match: + query: "red" + type: "cross_fields" + fields: [ "color", "shape^100"] + tie_breaker: 0.1 + script: + source: "_score" + explain: true + - match: { hits.total.value: 3 } + - match: { hits.hits.2._score: 0.0 } diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java b/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java index cb93e80288a98..512dec9f1f355 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java @@ -533,8 +533,10 @@ public float score() throws IOException { int docId = docID(); // Even if the weight is created with needsScores=false, it might // be costly to call score(), so we explicitly check if scores - // are needed - float subQueryScore = needsScores ? super.score() : 0f; + // are needed. + // While the function scorer should never turn a score negative, we + // must guard against the input score being negative. + float subQueryScore = needsScores ? 
Math.max(0f, super.score()) : 0f; if (leafFunctions.length == 0) { return subQueryScore; } diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java index 38c356a8be4b0..146dfd4440e2f 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java @@ -52,8 +52,14 @@ public class ScriptScoreFunction extends ScoreFunction { static final class CannedScorer extends Scorable { - protected int docid; - protected float score; + private int docid; + private float score; + + public void score(float subScore) { + // We check to make sure the script score function never makes a score negative, but we need to make + // sure the script score function does not receive negative input. + this.score = Math.max(0.0f, subScore); + } @Override public int docID() { @@ -105,7 +111,7 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx public double score(int docId, float subQueryScore) throws IOException { leafScript.setDocument(docId); scorer.docid = docId; - scorer.score = subQueryScore; + scorer.score(subQueryScore); double result = leafScript.execute(null); if (result < 0f) { throw new IllegalArgumentException("script score function must not produce negative scores, but got: [" + result + "]"); @@ -119,7 +125,7 @@ public Explanation explainScore(int docId, Explanation subQueryScore) throws IOE if (leafScript instanceof ExplainableScoreScript) { leafScript.setDocument(docId); scorer.docid = docId; - scorer.score = subQueryScore.getValue().floatValue(); + scorer.score(subQueryScore.getValue().floatValue()); exp = ((ExplainableScoreScript) leafScript).explain(subQueryScore, functionName); } else { double score = score(docId, subQueryScore.getValue().floatValue()); diff --git 
a/server/src/main/java/org/opensearch/script/ScoreScript.java b/server/src/main/java/org/opensearch/script/ScoreScript.java index 70de636a655f2..63bf9f43d133f 100644 --- a/server/src/main/java/org/opensearch/script/ScoreScript.java +++ b/server/src/main/java/org/opensearch/script/ScoreScript.java @@ -165,7 +165,9 @@ public void setDocument(int docid) { public void setScorer(Scorable scorer) { this.scoreSupplier = () -> { try { - return scorer.score(); + // The ScoreScript is forbidden from returning a negative value. + // We should guard against receiving negative input. + return Math.max(0f, scorer.score()); } catch (IOException e) { throw new UncheckedIOException(e); } diff --git a/server/src/test/java/org/opensearch/index/query/NegativeBoostQuery.java b/server/src/test/java/org/opensearch/index/query/NegativeBoostQuery.java new file mode 100644 index 0000000000000..06b29cd1c1303 --- /dev/null +++ b/server/src/test/java/org/opensearch/index/query/NegativeBoostQuery.java @@ -0,0 +1,114 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.query; + +import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.Explanation; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; + +import java.io.IOException; + +/** + * Similar to Lucene's BoostQuery, but will accept negative boost values (which is normally wrong, since scores + * should not be negative). Useful for testing that other query types guard against negative input scores. 
+ */ +public class NegativeBoostQuery extends Query { + private final Query query; + private final float boost; + + public NegativeBoostQuery(Query query, float boost) { + if (boost >= 0) { + throw new IllegalArgumentException("Expected negative boost. Use BoostQuery if boost is non-negative."); + } + this.boost = boost; + this.query = query; + } + + @Override + public String toString(String field) { + StringBuilder builder = new StringBuilder(); + builder.append("("); + builder.append(query.toString(field)); + builder.append(")"); + builder.append("^"); + builder.append(boost); + return builder.toString(); + } + + @Override + public void visit(QueryVisitor visitor) { + query.visit(visitor); + } + + @Override + public boolean equals(Object other) { + return sameClassAs(other) && equalsTo(getClass().cast(other)); + } + + private boolean equalsTo(NegativeBoostQuery other) { + return query.equals(other.query) && Float.floatToIntBits(boost) == Float.floatToIntBits(other.boost); + } + + @Override + public int hashCode() { + int h = classHash(); + h = 31 * h + query.hashCode(); + h = 31 * h + Float.floatToIntBits(boost); + return h; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + final float negativeBoost = this.boost; + Weight delegate = query.createWeight(searcher, scoreMode, boost); + return new Weight(this) { + @Override + public Explanation explain(LeafReaderContext context, int doc) throws IOException { + return delegate.explain(context, doc); + } + + @Override + public Scorer scorer(LeafReaderContext context) throws IOException { + Scorer delegateScorer = delegate.scorer(context); + return new Scorer(this) { + @Override + public DocIdSetIterator iterator() { + return delegateScorer.iterator(); + } + + @Override + public float getMaxScore(int upTo) throws IOException { + return delegateScorer.getMaxScore(upTo); + } + + @Override + public float score() throws IOException { + return 
delegateScorer.score() * negativeBoost; + } + + @Override + public int docID() { + return delegateScorer.docID(); + } + }; + } + + @Override + public boolean isCacheable(LeafReaderContext ctx) { + return delegate.isCacheable(ctx); + } + }; + } +} diff --git a/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java b/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java index 0ea91efc568d0..02ffd97835ef0 100644 --- a/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java +++ b/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java @@ -74,6 +74,7 @@ import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.SortedBinaryDocValues; import org.opensearch.index.fielddata.SortedNumericDoubleValues; +import org.opensearch.index.query.NegativeBoostQuery; import org.opensearch.search.DocValueFormat; import org.opensearch.search.MultiValueMode; import org.opensearch.search.aggregations.support.ValuesSourceType; @@ -1095,6 +1096,24 @@ public void testExceptionOnNegativeScores() { assertThat(exc.getMessage(), not(containsString("consider using log1p or log2p instead of log to avoid negative scores"))); } + public void testNoExceptionOnNegativeScoreInput() throws IOException { + IndexSearcher localSearcher = new IndexSearcher(reader); + TermQuery termQuery = new TermQuery(new Term(FIELD, "out")); + + // test that field_value_factor function does not throw an exception on negative input scores + FieldValueFactorFunction.Modifier modifier = FieldValueFactorFunction.Modifier.NONE; + + final ScoreFunction fvfFunction = new FieldValueFactorFunction(FIELD, 1, modifier, 1.0, new IndexNumericFieldDataStub()); + FunctionScoreQuery fsQuery1 = new FunctionScoreQuery( + new NegativeBoostQuery(termQuery, -10f), + fvfFunction, + CombineFunction.MULTIPLY, + null, + Float.POSITIVE_INFINITY + ); + localSearcher.search(fsQuery1, 1); + } + public void
testExceptionOnLnNegativeScores() { IndexSearcher localSearcher = new IndexSearcher(reader); TermQuery termQuery = new TermQuery(new Term(FIELD, "out")); diff --git a/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java b/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java index 55c50b8cf854d..6d3fdb8f06a0b 100644 --- a/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java +++ b/server/src/test/java/org/opensearch/search/query/ScriptScoreQueryTests.java @@ -39,6 +39,7 @@ import org.apache.lucene.index.DirectoryReader; import org.apache.lucene.index.IndexWriter; import org.apache.lucene.index.LeafReaderContext; +import org.apache.lucene.index.Term; import org.apache.lucene.search.DocIdSetIterator; import org.apache.lucene.search.Explanation; import org.apache.lucene.search.IndexSearcher; @@ -46,12 +47,15 @@ import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreMode; import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.TermQuery; +import org.apache.lucene.search.TopDocs; import org.apache.lucene.search.TwoPhaseIterator; import org.apache.lucene.search.Weight; import org.apache.lucene.store.Directory; import org.opensearch.Version; import org.opensearch.common.lucene.search.Queries; import org.opensearch.common.lucene.search.function.ScriptScoreQuery; +import org.opensearch.index.query.NegativeBoostQuery; import org.opensearch.script.ScoreScript; import org.opensearch.script.Script; import org.opensearch.script.ScriptType; @@ -64,6 +68,7 @@ import java.io.IOException; import java.util.HashMap; import java.util.Map; +import java.util.function.BiFunction; import java.util.function.Function; import static org.hamcrest.CoreMatchers.containsString; @@ -185,6 +190,15 @@ public void testScriptScoreErrorOnNegativeScore() { assertTrue(e.getMessage().contains("Must be a non-negative score!")); } + public void testNoExceptionOnNegativeInputScore() throws IOException { + 
Script script = new Script("script that returns _score"); + ScoreScript.LeafFactory factory = newFactory(script, true, (s, e) -> s.get_score()); + NegativeBoostQuery negativeBoostQuery = new NegativeBoostQuery(new TermQuery(new Term("field", "text")), -10.0f); + ScriptScoreQuery query = new ScriptScoreQuery(negativeBoostQuery, script, factory, -1f, "index", 0, Version.CURRENT); + TopDocs topDocs = searcher.search(query, 1); + assertEquals(0.0f, topDocs.scoreDocs[0].score, 0.0001); + } + public void testTwoPhaseIteratorDelegation() throws IOException { Map params = new HashMap<>(); String scriptSource = "doc['field'].value != null ? 2.0 : 0.0"; // Adjust based on actual field and logic @@ -220,6 +234,14 @@ private ScoreScript.LeafFactory newFactory( Script script, boolean needsScore, Function function + ) { + return newFactory(script, needsScore, (s, e) -> function.apply(e)); + } + + private ScoreScript.LeafFactory newFactory( + Script script, + boolean needsScore, + BiFunction function ) { SearchLookup lookup = mock(SearchLookup.class); LeafSearchLookup leafLookup = mock(LeafSearchLookup.class); @@ -236,7 +258,7 @@ public ScoreScript newInstance(LeafReaderContext ctx) throws IOException { return new ScoreScript(script.getParams(), lookup, indexSearcher, leafReaderContext) { @Override public double execute(ExplanationHolder explanation) { - return function.apply(explanation); + return function.apply(this, explanation); } }; }