From f5344d248cee53c0a8994fbe38fab9c69f5a50b3 Mon Sep 17 00:00:00 2001 From: Michael Froh Date: Fri, 31 May 2024 18:34:22 +0000 Subject: [PATCH] Fix negative scores returned from `multi_match` query with `cross_fields` (#13829) Under specific circumstances, when using `cross_fields` scoring on a `multi_match` query, we can end up with negative scores from the inverse document frequency calculation in the BM25 formula. Specifically, the IDF is calculated as: ``` log(1 + (N - n + 0.5) / (n + 0.5)) ``` where `N` is the number of documents containing the field and `n` is the number of documents containing the given term in the field. Obviously, `n` should always be less than or equal to `N`. Unfortunately, `cross_fields` makes up a new value for `n` and tries to use it across all fields. This change finds the (nonzero) value of `N` for each field and uses that as an upper bound for the new value of `n`. Signed-off-by: Michael Froh --------- Signed-off-by: Michael Froh --- CHANGELOG.md | 1 + .../test/search/50_multi_match.yml | 35 ++++++++++++++++++ .../lucene/queries/BlendedTermQuery.java | 8 +++- .../test/rest/yaml/section/Assertion.java | 37 +++++++++++++++++++ .../yaml/section/GreaterThanAssertion.java | 1 + .../section/GreaterThanEqualToAssertion.java | 1 + .../rest/yaml/section/LessThanAssertion.java | 1 + .../section/LessThanOrEqualToAssertion.java | 1 + 8 files changed, 84 insertions(+), 1 deletion(-) create mode 100644 rest-api-spec/src/main/resources/rest-api-spec/test/search/50_multi_match.yml diff --git a/CHANGELOG.md b/CHANGELOG.md index a98d0be56a658..f43e8e40c7338 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -46,6 +46,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Fix get field mapping API returns 404 error in mixed cluster with multiple versions ([#13624](https://github.com/opensearch-project/OpenSearch/pull/13624)) - Allow clearing `remote_store.compatibility_mode` setting ([#13646](https://github.com/opensearch-project/OpenSearch/pull/13646)) - Fix ReplicaShardBatchAllocator to batch shards without duplicates ([#13710](https://github.com/opensearch-project/OpenSearch/pull/13710)) +- Don't return negative scores from `multi_match` query with `cross_fields` type ([#13829](https://github.com/opensearch-project/OpenSearch/pull/13829)) - Pass parent filter to inner hit query ([#13903](https://github.com/opensearch-project/OpenSearch/pull/13903)) ### Security diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search/50_multi_match.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search/50_multi_match.yml new file mode 100644 index 0000000000000..34acb5985b555 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search/50_multi_match.yml @@ -0,0 +1,35 @@ +"Cross fields do not return negative scores": + - skip: + version: " - 2.99.99" + reason: "This fix is in 2.15. Until we do the BWC dance, we need to skip all pre-3.0, though." + - do: + index: + index: test + id: 1 + body: { "color" : "orange red yellow" } + - do: + index: + index: test + id: 2 + body: { "color": "orange red purple", "shape": "red square" } + - do: + index: + index: test + id: 3 + body: { "color" : "orange red yellow purple" } + - do: + indices.refresh: { } + - do: + search: + index: test + body: + query: + multi_match: + query: "red" + type: "cross_fields" + fields: [ "color", "shape^100"] + tie_breaker: 0.1 + explain: true + - match: { hits.total.value: 3 } + - match: { hits.hits.0._id: "2" } + - gt: { hits.hits.2._score: 0.0 } diff --git a/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java b/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java index b47b974b96fed..34e1e210d7137 100644 --- a/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java +++ b/server/src/main/java/org/apache/lucene/queries/BlendedTermQuery.java @@ -120,6 +120,7 @@ protected void blend(final TermStates[] contexts, int maxDoc, IndexReader reader } int max = 0; long minSumTTF = Long.MAX_VALUE; + int[] docCounts = new int[contexts.length]; for (int i = 0; i < contexts.length; i++) { TermStates ctx = contexts[i]; int df = ctx.docFreq(); @@ -133,6 +134,7 @@ protected void blend(final TermStates[] contexts, int maxDoc, IndexReader reader // we need to find out the minimum sumTTF to adjust the statistics // otherwise the statistics don't match minSumTTF = Math.min(minSumTTF, reader.getSumTotalTermFreq(terms[i].field())); + docCounts[i] = reader.getDocCount(terms[i].field()); } } if (maxDoc > minSumTTF) { @@ -175,7 +177,11 @@ protected int compare(int i, int j) { if (prev > current) { actualDf++; } - contexts[i] = ctx = adjustDF(reader.getContext(), ctx, Math.min(maxDoc, actualDf)); + // Per field, we want to guarantee that the adjusted df does not exceed the number of docs with the field. + // That is, in the IDF formula (log(1 + (N - n + 0.5) / (n + 0.5))), we need to make sure that n (the + // adjusted df) is never bigger than N (the number of docs with the field). + int fieldMaxDoc = Math.min(maxDoc, docCounts[i]); + contexts[i] = ctx = adjustDF(reader.getContext(), ctx, Math.min(fieldMaxDoc, actualDf)); prev = current; sumTTF += ctx.totalTermFreq(); } diff --git a/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/Assertion.java b/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/Assertion.java index b9cbaacdf8873..732d4291ae670 100644 --- a/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/Assertion.java +++ b/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/Assertion.java @@ -37,6 +37,8 @@ import java.io.IOException; import java.util.Map; +import static org.junit.Assert.fail; + /** * Base class for executable sections that hold assertions */ @@ -79,6 +81,41 @@ protected final Object getActualValue(ClientYamlTestExecutionContext executionCo return executionContext.response(field); } + static Object convertActualValue(Object actualValue, Object expectedValue) { + if (actualValue == null || expectedValue.getClass().isAssignableFrom(actualValue.getClass())) { + return actualValue; + } + if (actualValue instanceof Number && expectedValue instanceof Number) { + if (expectedValue instanceof Float) { + return Float.parseFloat(actualValue.toString()); + } else if (expectedValue instanceof Double) { + return Double.parseDouble(actualValue.toString()); + } else if (expectedValue instanceof Integer) { + return Integer.parseInt(actualValue.toString()); + } else if (expectedValue instanceof Long) { + return Long.parseLong(actualValue.toString()); + } + } + // Force a class cast exception here, so developers can flesh out the above logic as needed. + try { + expectedValue.getClass().cast(actualValue); + } catch (ClassCastException e) { + fail( + "Type mismatch: Expected value (" + + expectedValue + + ") has type " + + expectedValue.getClass() + + ". " + + "Actual value (" + + actualValue + + ") has type " + + actualValue.getClass() + + "." + ); + } + return actualValue; + } + @Override public XContentLocation getLocation() { return location; diff --git a/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/GreaterThanAssertion.java b/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/GreaterThanAssertion.java index 4c2e70f37a33c..0d20dc7c326b0 100644 --- a/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/GreaterThanAssertion.java +++ b/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/GreaterThanAssertion.java @@ -71,6 +71,7 @@ public GreaterThanAssertion(XContentLocation location, String field, Object expe @Override protected void doAssert(Object actualValue, Object expectedValue) { logger.trace("assert that [{}] is greater than [{}] (field: [{}])", actualValue, expectedValue, getField()); + actualValue = convertActualValue(actualValue, expectedValue); assertThat( "value of [" + getField() + "] is not comparable (got [" + safeClass(actualValue) + "])", actualValue, diff --git a/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java b/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java index 8e929eff44348..a6435c1303489 100644 --- a/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java +++ b/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/GreaterThanEqualToAssertion.java @@ -72,6 +72,7 @@ public GreaterThanEqualToAssertion(XContentLocation location, String field, Obje @Override protected void doAssert(Object actualValue, Object expectedValue) { logger.trace("assert that [{}] is greater than or equal to [{}] (field: [{}])", actualValue, expectedValue, getField()); + actualValue = convertActualValue(actualValue, expectedValue); assertThat( "value of [" + getField() + "] is not comparable (got [" + safeClass(actualValue) + "])", actualValue, diff --git a/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/LessThanAssertion.java b/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/LessThanAssertion.java index d6e2ae1e23996..acffe03d34439 100644 --- a/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/LessThanAssertion.java +++ b/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/LessThanAssertion.java @@ -72,6 +72,7 @@ public LessThanAssertion(XContentLocation location, String field, Object expecte @Override protected void doAssert(Object actualValue, Object expectedValue) { logger.trace("assert that [{}] is less than [{}] (field: [{}])", actualValue, expectedValue, getField()); + actualValue = convertActualValue(actualValue, expectedValue); assertThat( "value of [" + getField() + "] is not comparable (got [" + safeClass(actualValue) + "])", actualValue, diff --git a/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java b/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java index ee46c04496f32..d685d3e46a543 100644 --- a/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java +++ b/test/framework/src/main/java/org/opensearch/test/rest/yaml/section/LessThanOrEqualToAssertion.java @@ -72,6 +72,7 @@ public LessThanOrEqualToAssertion(XContentLocation location, String field, Objec @Override protected void doAssert(Object actualValue, Object expectedValue) { logger.trace("assert that [{}] is less than or equal to [{}] (field: [{}])", actualValue, expectedValue, getField()); + actualValue = convertActualValue(actualValue, expectedValue); assertThat( "value of [" + getField() + "] is not comparable (got [" + safeClass(actualValue) + "])", actualValue,