From 1081bb98a6a0df8ded57dbe96e927c3444583b57 Mon Sep 17 00:00:00 2001 From: Michael Froh <froh@amazon.com> Date: Fri, 10 May 2024 21:57:12 +0000 Subject: [PATCH] Do not pass negative scores into function_score or script_score queries In theory, Lucene scores should never go negative. To stop users from writing `function_score` and `script_score` queries that return negative values, we explicitly check their outputs and throw an exception when negative. Unfortunately, due to a subtle, more complicated bug in multi_match queries, sometimes those might (incorrectly) return negative scores. While that problem is also worth solving, we should protect function and script scoring from throwing an exception just for passing through a negative value that they had no hand in computing. Signed-off-by: Michael Froh <froh@amazon.com> --- CHANGELOG.md | 2 +- .../rest-api-spec/test/painless/30_search.yml | 76 +++++++++++++++++++ .../search/function/FunctionScoreQuery.java | 6 +- .../search/function/ScriptScoreFunction.java | 14 +++- .../org/opensearch/script/ScoreScript.java | 4 +- .../index/query/NegativeBoostQuery.java | 72 ++++++++++++++++++ .../functionscore/FunctionScoreTests.java | 19 +++++ 7 files changed, 185 insertions(+), 8 deletions(-) create mode 100644 server/src/test/java/org/opensearch/index/query/NegativeBoostQuery.java diff --git a/CHANGELOG.md b/CHANGELOG.md index f5d0ab4f7af38..9396907a107f0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -24,8 +24,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Remove handling of index.mapper.dynamic in AutoCreateIndex([#13067](https://github.com/opensearch-project/OpenSearch/pull/13067)) ### Fixed -- Fix negative RequestStats metric issue ([#13553](https://github.com/opensearch-project/OpenSearch/pull/13553)) - Fix get field mapping API returns 404 error in mixed cluster with multiple versions ([#13624](https://github.com/opensearch-project/OpenSearch/pull/13624)) +- Replace negative input scores to function/script score queries with zero to avoid downstream exception ([#13627](https://github.com/opensearch-project/OpenSearch/pull/13627)) ### Security diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml index a006fde630716..e48add37f44b8 100644 --- a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/30_search.yml @@ -482,3 +482,79 @@ }] - match: { error.root_cause.0.type: "illegal_argument_exception" } - match: { error.root_cause.0.reason: "script score function must not produce negative scores, but got: [-9.0]"} + +--- +"Do not throw exception if input score is negative": + - do: + index: + index: test + id: 1 + body: { "color" : "orange red yellow" } + - do: + index: + index: test + id: 2 + body: { "color": "orange red purple", "shape": "red square" } + - do: + index: + index: test + id: 3 + body: { "color" : "orange red yellow purple" } + - do: + indices.refresh: { } + - do: + search: + index: test + body: + query: + function_score: + query: + multi_match: + query: "red" + type: "cross_fields" + fields: [ "color", "shape^100"] + tie_breaker: 0.1 + functions: [{ + "script_score": { + "script": { + "lang": "painless", + "source": "_score" + } + } + }] + explain: true + - match: { hits.total.value: 3 } + - match: { hits.hits.2._score: 0.0 } + - do: + search: + index: test + body: + query: + function_score: + query: + multi_match: + query: "red" + type: "cross_fields" + fields: [ "color", "shape^100"] + tie_breaker: 0.1 + weight: 1 + explain: true + - match: { hits.total.value: 3 } + - match: { hits.hits.2._score: 0.0 } + - do: + search: + index: test + body: + query: + script_score: + query: + multi_match: + query: "red" + type: "cross_fields" + fields: [ "color", "shape^100"] + tie_breaker: 0.1 + script: + source: "_score" + explain: true + - match: { hits.total.value: 3 } + - match: { hits.hits.2._score: 0.0 } diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java b/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java index cb93e80288a98..512dec9f1f355 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/FunctionScoreQuery.java @@ -533,8 +533,10 @@ public float score() throws IOException { int docId = docID(); // Even if the weight is created with needsScores=false, it might // be costly to call score(), so we explicitly check if scores - // are needed - float subQueryScore = needsScores ? super.score() : 0f; + // are needed. + // While the function scorer should never turn a score negative, we + // must guard against the input score being negative. + float subQueryScore = needsScores ? Math.max(0f, super.score()) : 0f; if (leafFunctions.length == 0) { return subQueryScore; } diff --git a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java index 38c356a8be4b0..146dfd4440e2f 100644 --- a/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java +++ b/server/src/main/java/org/opensearch/common/lucene/search/function/ScriptScoreFunction.java @@ -52,8 +52,14 @@ public class ScriptScoreFunction extends ScoreFunction { static final class CannedScorer extends Scorable { - protected int docid; - protected float score; + private int docid; + private float score; + + public void score(float subScore) { + // We check to make sure the script score function never makes a score negative, but we need to make + // sure the script score function does not receive negative input. + this.score = Math.max(0.0f, subScore); + } @Override public int docID() { @@ -105,7 +111,7 @@ public LeafScoreFunction getLeafScoreFunction(LeafReaderContext ctx) throws IOEx public double score(int docId, float subQueryScore) throws IOException { leafScript.setDocument(docId); scorer.docid = docId; - scorer.score = subQueryScore; + scorer.score(subQueryScore); double result = leafScript.execute(null); if (result < 0f) { throw new IllegalArgumentException("script score function must not produce negative scores, but got: [" + result + "]"); @@ -119,7 +125,7 @@ public Explanation explainScore(int docId, Explanation subQueryScore) throws IOE if (leafScript instanceof ExplainableScoreScript) { leafScript.setDocument(docId); scorer.docid = docId; - scorer.score = subQueryScore.getValue().floatValue(); + scorer.score(subQueryScore.getValue().floatValue()); exp = ((ExplainableScoreScript) leafScript).explain(subQueryScore, functionName); } else { double score = score(docId, subQueryScore.getValue().floatValue()); diff --git a/server/src/main/java/org/opensearch/script/ScoreScript.java b/server/src/main/java/org/opensearch/script/ScoreScript.java index 70de636a655f2..63bf9f43d133f 100644 --- a/server/src/main/java/org/opensearch/script/ScoreScript.java +++ b/server/src/main/java/org/opensearch/script/ScoreScript.java @@ -165,7 +165,9 @@ public void setDocument(int docid) { public void setScorer(Scorable scorer) { this.scoreSupplier = () -> { try { - return scorer.score(); + // The ScoreScript is forbidden from returning a negative value. + // We should guard against receiving negative input. + return Math.max(0f, scorer.score()); } catch (IOException e) { throw new UncheckedIOException(e); } diff --git a/server/src/test/java/org/opensearch/index/query/NegativeBoostQuery.java b/server/src/test/java/org/opensearch/index/query/NegativeBoostQuery.java new file mode 100644 index 0000000000000..91697cddcd133 --- /dev/null +++ b/server/src/test/java/org/opensearch/index/query/NegativeBoostQuery.java @@ -0,0 +1,72 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.index.query; + +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Weight; + +import java.io.IOException; + +/** + * Similar to Lucene's BoostQuery, but will accept negative boost values (which is normally wrong, since scores + * should not be negative). Useful for testing that other query types guard against negative input scores. + */ +public class NegativeBoostQuery extends Query { + private final Query query; + private final float boost; + + public NegativeBoostQuery(Query query, float boost) { + if (boost >= 0) { + throw new IllegalArgumentException("Expected negative boost. Use BoostQuery if boost is non-negative."); + } + this.boost = boost; + this.query = query; + } + + @Override + public String toString(String field) { + StringBuilder builder = new StringBuilder(); + builder.append("("); + builder.append(query.toString(field)); + builder.append(")"); + builder.append("^"); + builder.append(boost); + return builder.toString(); + } + + @Override + public void visit(QueryVisitor visitor) { + query.visit(visitor); + } + + @Override + public boolean equals(Object other) { + return sameClassAs(other) && equalsTo(getClass().cast(other)); + } + + private boolean equalsTo(NegativeBoostQuery other) { + return query.equals(other.query) && Float.floatToIntBits(boost) == Float.floatToIntBits(other.boost); + } + + @Override + public int hashCode() { + int h = classHash(); + h = 31 * h + query.hashCode(); + h = 31 * h + Float.floatToIntBits(boost); + return h; + } + + @Override + public Weight createWeight(IndexSearcher searcher, ScoreMode scoreMode, float boost) throws IOException { + return query.createWeight(searcher, scoreMode, boost * this.boost); + } +} diff --git a/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java b/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java index 0ea91efc568d0..02ffd97835ef0 100644 --- a/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java +++ b/server/src/test/java/org/opensearch/index/query/functionscore/FunctionScoreTests.java @@ -74,6 +74,7 @@ import org.opensearch.index.fielddata.ScriptDocValues; import org.opensearch.index.fielddata.SortedBinaryDocValues; import org.opensearch.index.fielddata.SortedNumericDoubleValues; +import org.opensearch.index.query.NegativeBoostQuery; import org.opensearch.search.DocValueFormat; import org.opensearch.search.MultiValueMode; import org.opensearch.search.aggregations.support.ValuesSourceType; @@ -1095,6 +1096,24 @@ public void testExceptionOnNegativeScores() { assertThat(exc.getMessage(), not(containsString("consider using log1p or log2p instead of log to avoid negative scores"))); } + public void testNoExceptionOnNegativeScoreInput() throws IOException { + IndexSearcher localSearcher = new IndexSearcher(reader); + TermQuery termQuery = new TermQuery(new Term(FIELD, "out")); + + // test that field_value_factor function throws an exception on negative scores + FieldValueFactorFunction.Modifier modifier = FieldValueFactorFunction.Modifier.NONE; + + final ScoreFunction fvfFunction = new FieldValueFactorFunction(FIELD, 1, modifier, 1.0, new IndexNumericFieldDataStub()); + FunctionScoreQuery fsQuery1 = new FunctionScoreQuery( + new NegativeBoostQuery(termQuery, -10f), + fvfFunction, + CombineFunction.MULTIPLY, + null, + Float.POSITIVE_INFINITY + ); + localSearcher.search(fsQuery1, 1); + } + public void testExceptionOnLnNegativeScores() { IndexSearcher localSearcher = new IndexSearcher(reader); TermQuery termQuery = new TermQuery(new Term(FIELD, "out"));