From 378f8bdbb44676577a4d78180bc4ac426014d63d Mon Sep 17 00:00:00 2001 From: Dharin Shah Date: Tue, 6 Feb 2024 19:01:54 +0100 Subject: [PATCH] add support for scored named queries (#11626) Opensearch already support labelling the queries, that returns as a list in the returned results, of which query it matched. However one of the use case while doing hybrid search with query text and dense vector is to determine individual scores for each query type. This is very useful in further analysis and building offline model to generate better weights for ranking score. Hence adding this feature that sends the client to add the score for each matched query. --------- Signed-off-by: Dharin Shah <8616130+Dharin-shah@users.noreply.github.com> Signed-off-by: Dharin Shah Co-authored-by: Dharin Shah <8616130+Dharin-shah@users.noreply.github.com> (cherry picked from commit 52b27f47bca5b3ab52cab237542f32c307d203b4) --- CHANGELOG.md | 4 + .../resources/rest-api-spec/api/search.json | 5 + .../test/search/350_matched_queries.yml | 103 +++++++++++ .../recovery/IndexPrimaryRelocationIT.java | 21 +-- .../fetch/subphase/MatchedQueriesIT.java | 105 +++++++---- .../search/functionscore/QueryRescorerIT.java | 8 +- .../action/search/SearchRequestBuilder.java | 9 + .../rest/action/search/RestSearchAction.java | 10 +- .../search/DefaultSearchContext.java | 13 ++ .../java/org/opensearch/search/SearchHit.java | 163 +++++++++++------- .../org/opensearch/search/SearchService.java | 1 + .../search/builder/SearchSourceBuilder.java | 34 ++++ .../opensearch/search/fetch/FetchContext.java | 4 + .../opensearch/search/fetch/FetchPhase.java | 2 +- .../fetch/subphase/MatchedQueriesPhase.java | 76 ++++++-- .../internal/FilteredSearchContext.java | 8 + .../search/internal/SearchContext.java | 23 +++ .../search/internal/SubSearchContext.java | 13 ++ .../org/opensearch/search/SearchHitTests.java | 161 ++++++++++++++++- .../opensearch/test/TestSearchContext.java | 12 ++ .../test/hamcrest/OpenSearchAssertions.java | 4 + .../test/hamcrest/OpenSearchMatchers.java | 30 ++++ 22 files changed, 675 insertions(+), 134 deletions(-) create mode 100644 rest-api-spec/src/main/resources/rest-api-spec/test/search/350_matched_queries.yml diff --git a/CHANGELOG.md b/CHANGELOG.md index b6ec71503a7c8..81a74e3caf05e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -53,8 +53,12 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - Remove concurrent segment search feature flag for GA launch ([#12074](https://github.com/opensearch-project/OpenSearch/pull/12074)) - Enable Fuzzy codec for doc id fields using a bloom filter ([#11022](https://github.com/opensearch-project/OpenSearch/pull/11022)) - [Metrics Framework] Adds support for Histogram metric ([#12062](https://github.com/opensearch-project/OpenSearch/pull/12062)) +<<<<<<< HEAD - [AdmissionControl] Added changes for AdmissionControl Interceptor and AdmissionControlService for RateLimiting ([#9286](https://github.com/opensearch-project/OpenSearch/pull/9286)) - [Admission Control] Integrate CPU AC with ResourceUsageCollector and add CPU AC stats to nodes/stats ([#10887](https://github.com/opensearch-project/OpenSearch/pull/10887)) +======= +- Support for returning scores in matched queries ([#11626](https://github.com/opensearch-project/OpenSearch/pull/11626)) +>>>>>>> 52b27f47bca (add support for scored named queries (#11626)) ### Dependencies - Bumps jetty version to 9.4.52.v20230823 to fix GMS-2023-1857 ([#9822](https://github.com/opensearch-project/OpenSearch/pull/9822)) diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/search.json b/rest-api-spec/src/main/resources/rest-api-spec/api/search.json index e0fbeeb83ffc4..e78d49a67a98a 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/search.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/search.json @@ -229,6 +229,11 @@ "search_pipeline": { "type": "string", "description": "The search pipeline to use to execute this request" + }, + "include_named_queries_score":{ + "type": "boolean", + "description":"Indicates whether hit.matched_queries should be rendered as a map that includes the name of the matched query associated with its score (true) or as an array containing the name of the matched queries (false)", + "default":false } }, "body":{ diff --git a/rest-api-spec/src/main/resources/rest-api-spec/test/search/350_matched_queries.yml b/rest-api-spec/src/main/resources/rest-api-spec/test/search/350_matched_queries.yml new file mode 100644 index 0000000000000..25de51a316bd4 --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/test/search/350_matched_queries.yml @@ -0,0 +1,103 @@ +setup: + - skip: + version: " - 2.12.0" + reason: "implemented for versions post 2.12.0" + +--- +"matched queries": + - do: + indices.create: + index: test + + - do: + bulk: + refresh: true + body: + - '{ "index" : { "_index" : "test_1", "_id" : "1" } }' + - '{"field" : 1 }' + - '{ "index" : { "_index" : "test_1", "_id" : "2" } }' + - '{"field" : [1, 2] }' + + - do: + search: + index: test_1 + body: + query: + bool: { + should: [ + { + match: { + field: { + query: 1, + _name: match_field_1 + } + } + }, + { + match: { + field: { + query: 2, + _name: match_field_2, + boost: 10 + } + } + } + ] + } + + - match: {hits.total.value: 2} + - length: {hits.hits.0.matched_queries: 2} + - match: {hits.hits.0.matched_queries: [ "match_field_1", "match_field_2" ]} + - length: {hits.hits.1.matched_queries: 1} + - match: {hits.hits.1.matched_queries: [ "match_field_1" ]} + +--- + +"matched queries with scores": + - do: + indices.create: + index: test + + - do: + bulk: + refresh: true + body: + - '{ "index" : { "_index" : "test_1", "_id" : "1" } }' + - '{"field" : 1 }' + - '{ "index" : { "_index" : "test_1", "_id" : "2" } }' + - '{"field" : [1, 2] }' + + - do: + search: + include_named_queries_score: true + index: test_1 + body: + query: + bool: { + should: [ + { + match: { + field: { + query: 1, + _name: match_field_1 + } + } + }, + { + match: { + field: { + query: 2, + _name: match_field_2, + boost: 10 + } + } + } + ] + } + + - match: { hits.total.value: 2 } + - length: { hits.hits.0.matched_queries: 2 } + - match: { hits.hits.0.matched_queries.match_field_1: 1 } + - match: { hits.hits.0.matched_queries.match_field_2: 10 } + - length: { hits.hits.1.matched_queries: 1 } + - match: { hits.hits.1.matched_queries.match_field_1: 1 } diff --git a/server/src/internalClusterTest/java/org/opensearch/indices/recovery/IndexPrimaryRelocationIT.java b/server/src/internalClusterTest/java/org/opensearch/indices/recovery/IndexPrimaryRelocationIT.java index c049c8ed2d4a6..9decd17d95eab 100644 --- a/server/src/internalClusterTest/java/org/opensearch/indices/recovery/IndexPrimaryRelocationIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/indices/recovery/IndexPrimaryRelocationIT.java @@ -66,19 +66,16 @@ public void testPrimaryRelocationWhileIndexing() throws Exception { ensureGreen("test"); AtomicInteger numAutoGenDocs = new AtomicInteger(); final AtomicBoolean finished = new AtomicBoolean(false); - Thread indexingThread = new Thread() { - @Override - public void run() { - while (finished.get() == false && numAutoGenDocs.get() < 10_000) { - IndexResponse indexResponse = client().prepareIndex("test").setId("id").setSource("field", "value").get(); - assertEquals(DocWriteResponse.Result.CREATED, indexResponse.getResult()); - DeleteResponse deleteResponse = client().prepareDelete("test", "id").get(); - assertEquals(DocWriteResponse.Result.DELETED, deleteResponse.getResult()); - client().prepareIndex("test").setSource("auto", true).get(); - numAutoGenDocs.incrementAndGet(); - } + Thread indexingThread = new Thread(() -> { + while (finished.get() == false && numAutoGenDocs.get() < 10_000) { + IndexResponse indexResponse = client().prepareIndex("test").setId("id").setSource("field", "value").get(); + assertEquals(DocWriteResponse.Result.CREATED, indexResponse.getResult()); + DeleteResponse deleteResponse = client().prepareDelete("test", "id").get(); + assertEquals(DocWriteResponse.Result.DELETED, deleteResponse.getResult()); + client().prepareIndex("test").setSource("auto", true).get(); + numAutoGenDocs.incrementAndGet(); } - }; + }); indexingThread.start(); ClusterState initialState = client().admin().cluster().prepareState().get().getState(); diff --git a/server/src/internalClusterTest/java/org/opensearch/search/fetch/subphase/MatchedQueriesIT.java b/server/src/internalClusterTest/java/org/opensearch/search/fetch/subphase/MatchedQueriesIT.java index 7a828c06c5cd7..a1adc6f99b92a 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/fetch/subphase/MatchedQueriesIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/fetch/subphase/MatchedQueriesIT.java @@ -61,7 +61,9 @@ import static org.opensearch.search.SearchService.CLUSTER_CONCURRENT_SEGMENT_SEARCH_SETTING; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertHitCount; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.hasItemInArray; +import static org.hamcrest.Matchers.hasKey; public class MatchedQueriesIT extends ParameterizedStaticSettingsOpenSearchIntegTestCase { @@ -95,15 +97,18 @@ public void testSimpleMatchedQueryFromFilteredQuery() throws Exception { .should(rangeQuery("number").gte(2).queryName("test2")) ) ) + .setIncludeNamedQueriesScore(true) .get(); assertHitCount(searchResponse, 3L); for (SearchHit hit : searchResponse.getHits()) { if (hit.getId().equals("3") || hit.getId().equals("2")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("test2")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("test2")); + assertThat(hit.getMatchedQueryScore("test2"), equalTo(1f)); } else if (hit.getId().equals("1")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("test1")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("test1")); + assertThat(hit.getMatchedQueryScore("test1"), equalTo(1f)); } else { fail("Unexpected document returned with id " + hit.getId()); } @@ -113,15 +118,18 @@ public void testSimpleMatchedQueryFromFilteredQuery() throws Exception { .setQuery( boolQuery().should(rangeQuery("number").lte(2).queryName("test1")).should(rangeQuery("number").gt(2).queryName("test2")) ) + .setIncludeNamedQueriesScore(true) .get(); assertHitCount(searchResponse, 3L); for (SearchHit hit : searchResponse.getHits()) { if (hit.getId().equals("1") || hit.getId().equals("2")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("test1")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("test1")); + assertThat(hit.getMatchedQueryScore("test1"), equalTo(1f)); } else if (hit.getId().equals("3")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("test2")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("test2")); + assertThat(hit.getMatchedQueryScore("test2"), equalTo(1f)); } else { fail("Unexpected document returned with id " + hit.getId()); } @@ -147,12 +155,15 @@ public void testSimpleMatchedQueryFromTopLevelFilter() throws Exception { assertHitCount(searchResponse, 3L); for (SearchHit hit : searchResponse.getHits()) { if (hit.getId().equals("1")) { - assertThat(hit.getMatchedQueries().length, equalTo(2)); - assertThat(hit.getMatchedQueries(), hasItemInArray("name")); - assertThat(hit.getMatchedQueries(), hasItemInArray("title")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(2)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("name")); + assertThat(hit.getMatchedQueryScore("name"), greaterThan(0f)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("title")); + assertThat(hit.getMatchedQueryScore("title"), greaterThan(0f)); } else if (hit.getId().equals("2") || hit.getId().equals("3")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("name")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("name")); + assertThat(hit.getMatchedQueryScore("name"), greaterThan(0f)); } else { fail("Unexpected document returned with id " + hit.getId()); } @@ -168,12 +179,15 @@ public void testSimpleMatchedQueryFromTopLevelFilter() throws Exception { assertHitCount(searchResponse, 3L); for (SearchHit hit : searchResponse.getHits()) { if (hit.getId().equals("1")) { - assertThat(hit.getMatchedQueries().length, equalTo(2)); - assertThat(hit.getMatchedQueries(), hasItemInArray("name")); - assertThat(hit.getMatchedQueries(), hasItemInArray("title")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(2)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("name")); + assertThat(hit.getMatchedQueryScore("name"), greaterThan(0f)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("title")); + assertThat(hit.getMatchedQueryScore("title"), greaterThan(0f)); } else if (hit.getId().equals("2") || hit.getId().equals("3")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("name")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("name")); + assertThat(hit.getMatchedQueryScore("name"), greaterThan(0f)); } else { fail("Unexpected document returned with id " + hit.getId()); } @@ -197,9 +211,11 @@ public void testSimpleMatchedQueryFromTopLevelFilterAndFilteredQuery() throws Ex assertHitCount(searchResponse, 3L); for (SearchHit hit : searchResponse.getHits()) { if (hit.getId().equals("1") || hit.getId().equals("2") || hit.getId().equals("3")) { - assertThat(hit.getMatchedQueries().length, equalTo(2)); - assertThat(hit.getMatchedQueries(), hasItemInArray("name")); - assertThat(hit.getMatchedQueries(), hasItemInArray("title")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(2)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("name")); + assertThat(hit.getMatchedQueryScore("name"), greaterThan(0f)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("title")); + assertThat(hit.getMatchedQueryScore("title"), greaterThan(0f)); } else { fail("Unexpected document returned with id " + hit.getId()); } @@ -231,13 +247,15 @@ public void testRegExpQuerySupportsName() throws InterruptedException { SearchResponse searchResponse = client().prepareSearch() .setQuery(QueryBuilders.regexpQuery("title", "title1").queryName("regex")) + .setIncludeNamedQueriesScore(true) .get(); assertHitCount(searchResponse, 1L); for (SearchHit hit : searchResponse.getHits()) { if (hit.getId().equals("1")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("regex")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("regex")); + assertThat(hit.getMatchedQueryScore("regex"), equalTo(1f)); } else { fail("Unexpected document returned with id " + hit.getId()); } @@ -252,15 +270,17 @@ public void testPrefixQuerySupportsName() throws InterruptedException { refresh(); indexRandomForConcurrentSearch("test1"); - SearchResponse searchResponse = client().prepareSearch() + var query = client().prepareSearch() .setQuery(QueryBuilders.prefixQuery("title", "title").queryName("prefix")) - .get(); + .setIncludeNamedQueriesScore(true); + var searchResponse = query.get(); assertHitCount(searchResponse, 1L); for (SearchHit hit : searchResponse.getHits()) { if (hit.getId().equals("1")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("prefix")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("prefix")); + assertThat(hit.getMatchedQueryScore("prefix"), equalTo(1f)); } else { fail("Unexpected document returned with id " + hit.getId()); } @@ -282,8 +302,9 @@ public void testFuzzyQuerySupportsName() throws InterruptedException { for (SearchHit hit : searchResponse.getHits()) { if (hit.getId().equals("1")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("fuzzy")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("fuzzy")); + assertThat(hit.getMatchedQueryScore("fuzzy"), greaterThan(0f)); } else { fail("Unexpected document returned with id " + hit.getId()); } @@ -300,13 +321,15 @@ public void testWildcardQuerySupportsName() throws InterruptedException { SearchResponse searchResponse = client().prepareSearch() .setQuery(QueryBuilders.wildcardQuery("title", "titl*").queryName("wildcard")) + .setIncludeNamedQueriesScore(true) .get(); assertHitCount(searchResponse, 1L); for (SearchHit hit : searchResponse.getHits()) { if (hit.getId().equals("1")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("wildcard")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("wildcard")); + assertThat(hit.getMatchedQueryScore("wildcard"), equalTo(1f)); } else { fail("Unexpected document returned with id " + hit.getId()); } @@ -328,8 +351,9 @@ public void testSpanFirstQuerySupportsName() throws InterruptedException { for (SearchHit hit : searchResponse.getHits()) { if (hit.getId().equals("1")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("span")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("span")); + assertThat(hit.getMatchedQueryScore("span"), greaterThan(0f)); } else { fail("Unexpected document returned with id " + hit.getId()); } @@ -363,11 +387,13 @@ public void testMatchedWithShould() throws Exception { assertHitCount(searchResponse, 2L); for (SearchHit hit : searchResponse.getHits()) { if (hit.getId().equals("1")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("dolor")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("dolor")); + assertThat(hit.getMatchedQueryScore("dolor"), greaterThan(0f)); } else if (hit.getId().equals("2")) { - assertThat(hit.getMatchedQueries().length, equalTo(1)); - assertThat(hit.getMatchedQueries(), hasItemInArray("elit")); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("elit")); + assertThat(hit.getMatchedQueryScore("elit"), greaterThan(0f)); } else { fail("Unexpected document returned with id " + hit.getId()); } @@ -391,7 +417,10 @@ public void testMatchedWithWrapperQuery() throws Exception { for (QueryBuilder query : queries) { SearchResponse searchResponse = client().prepareSearch().setQuery(query).get(); assertHitCount(searchResponse, 1L); - assertThat(searchResponse.getHits().getAt(0).getMatchedQueries()[0], equalTo("abc")); + SearchHit hit = searchResponse.getHits().getAt(0); + assertThat(hit.getMatchedQueriesAndScores().size(), equalTo(1)); + assertThat(hit.getMatchedQueriesAndScores(), hasKey("abc")); + assertThat(hit.getMatchedQueryScore("abc"), greaterThan(0f)); } } } diff --git a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/QueryRescorerIT.java b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/QueryRescorerIT.java index 6c4ea0cdeb1f1..5121d5023fd95 100644 --- a/server/src/internalClusterTest/java/org/opensearch/search/functionscore/QueryRescorerIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/search/functionscore/QueryRescorerIT.java @@ -83,6 +83,7 @@ import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertSecondHit; import static org.opensearch.test.hamcrest.OpenSearchAssertions.assertThirdHit; import static org.opensearch.test.hamcrest.OpenSearchAssertions.hasId; +import static org.opensearch.test.hamcrest.OpenSearchAssertions.hasMatchedQueries; import static org.opensearch.test.hamcrest.OpenSearchAssertions.hasScore; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -594,7 +595,7 @@ public void testExplain() throws Exception { SearchResponse searchResponse = client().prepareSearch() .setSearchType(SearchType.DFS_QUERY_THEN_FETCH) - .setQuery(QueryBuilders.matchQuery("field1", "the quick brown").operator(Operator.OR)) + .setQuery(QueryBuilders.matchQuery("field1", "the quick brown").operator(Operator.OR).queryName("hello-world")) .setRescorer(innerRescoreQuery, 5) .setExplain(true) .get(); @@ -602,7 +603,10 @@ public void testExplain() throws Exception { assertFirstHit(searchResponse, hasId("1")); assertSecondHit(searchResponse, hasId("2")); assertThirdHit(searchResponse, hasId("3")); - + final String[] matchedQueries = { "hello-world" }; + assertFirstHit(searchResponse, hasMatchedQueries(matchedQueries)); + assertSecondHit(searchResponse, hasMatchedQueries(matchedQueries)); + assertThirdHit(searchResponse, hasMatchedQueries(matchedQueries)); for (int j = 0; j < 3; j++) { assertThat(searchResponse.getHits().getAt(j).getExplanation().getDescription(), equalTo(descriptionModes[innerMode])); } diff --git a/server/src/main/java/org/opensearch/action/search/SearchRequestBuilder.java b/server/src/main/java/org/opensearch/action/search/SearchRequestBuilder.java index e949c5e0bea29..9dac827e7d518 100644 --- a/server/src/main/java/org/opensearch/action/search/SearchRequestBuilder.java +++ b/server/src/main/java/org/opensearch/action/search/SearchRequestBuilder.java @@ -406,6 +406,15 @@ public SearchRequestBuilder setTrackScores(boolean trackScores) { return this; } + /** + * Applies when fetching scores with named queries, and controls if scores will be tracked as well. + * Defaults to {@code false}. + */ + public SearchRequestBuilder setIncludeNamedQueriesScore(boolean includeNamedQueriesScore) { + sourceBuilder().includeNamedQueriesScores(includeNamedQueriesScore); + return this; + } + /** * Indicates if the total hit count for the query should be tracked. Requests will count total hit count accurately * up to 10,000 by default, see {@link #setTrackTotalHitsUpTo(int)} to change this value or set to true/false to always/never diff --git a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java index 080366e536da1..80dc34c4d5d68 100644 --- a/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java +++ b/server/src/main/java/org/opensearch/rest/action/search/RestSearchAction.java @@ -86,10 +86,13 @@ public class RestSearchAction extends BaseRestHandler { */ public static final String TOTAL_HITS_AS_INT_PARAM = "rest_total_hits_as_int"; public static final String TYPED_KEYS_PARAM = "typed_keys"; + public static final String INCLUDE_NAMED_QUERIES_SCORE_PARAM = "include_named_queries_score"; private static final Set RESPONSE_PARAMS; static { - final Set responseParams = new HashSet<>(Arrays.asList(TYPED_KEYS_PARAM, TOTAL_HITS_AS_INT_PARAM)); + final Set responseParams = new HashSet<>( + Arrays.asList(TYPED_KEYS_PARAM, TOTAL_HITS_AS_INT_PARAM, INCLUDE_NAMED_QUERIES_SCORE_PARAM) + ); RESPONSE_PARAMS = Collections.unmodifiableSet(responseParams); } @@ -209,6 +212,7 @@ public static void parseSearchRequest( searchRequest.pipeline(request.param("search_pipeline")); checkRestTotalHits(request, searchRequest); + request.paramAsBoolean(INCLUDE_NAMED_QUERIES_SCORE_PARAM, false); if (searchRequest.pointInTimeBuilder() != null) { preparePointInTime(searchRequest, request, namedWriteableRegistry); @@ -286,6 +290,10 @@ private static void parseSearchSource(final SearchSourceBuilder searchSourceBuil searchSourceBuilder.trackScores(request.paramAsBoolean("track_scores", false)); } + if (request.hasParam("include_named_queries_score")) { + searchSourceBuilder.includeNamedQueriesScores(request.paramAsBoolean("include_named_queries_score", false)); + } + if (request.hasParam("track_total_hits")) { if (Booleans.isBoolean(request.param("track_total_hits"))) { searchSourceBuilder.trackTotalHits(request.paramAsBoolean("track_total_hits", true)); diff --git a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java index 498d763c00db5..4195cc67f7af1 100644 --- a/server/src/main/java/org/opensearch/search/DefaultSearchContext.java +++ b/server/src/main/java/org/opensearch/search/DefaultSearchContext.java @@ -148,6 +148,8 @@ final class DefaultSearchContext extends SearchContext { private SortAndFormats sort; private Float minimumScore; private boolean trackScores = false; // when sorting, track scores as well... + + private boolean includeNamedQueriesScore = false; private int trackTotalHitsUpTo = SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO; private FieldDoc searchAfter; private CollapseContext collapse; @@ -635,6 +637,17 @@ public boolean trackScores() { return this.trackScores; } + @Override + public SearchContext includeNamedQueriesScore(boolean includeNamedQueriesScore) { + this.includeNamedQueriesScore = includeNamedQueriesScore; + return this; + } + + @Override + public boolean includeNamedQueriesScore() { + return includeNamedQueriesScore; + } + @Override public SearchContext trackTotalHitsUpTo(int trackTotalHitsUpTo) { this.trackTotalHitsUpTo = trackTotalHitsUpTo; diff --git a/server/src/main/java/org/opensearch/search/SearchHit.java b/server/src/main/java/org/opensearch/search/SearchHit.java index ff5457ce3ad77..c8ddd3584c110 100644 --- a/server/src/main/java/org/opensearch/search/SearchHit.java +++ b/server/src/main/java/org/opensearch/search/SearchHit.java @@ -65,19 +65,21 @@ import org.opensearch.index.mapper.MapperService; import org.opensearch.index.mapper.SourceFieldMapper; import org.opensearch.index.seqno.SequenceNumbers; +import org.opensearch.rest.action.search.RestSearchAction; import org.opensearch.search.fetch.subphase.highlight.HighlightField; import org.opensearch.search.lookup.SourceLookup; import org.opensearch.transport.RemoteClusterAware; import java.io.IOException; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.HashMap; import java.util.Iterator; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.stream.Collectors; import static java.util.Collections.emptyMap; import static java.util.Collections.singletonMap; @@ -121,7 +123,7 @@ public final class SearchHit implements Writeable, ToXContentObject, Iterable matchedQueries = new HashMap<>(); private Explanation explanation; @@ -216,10 +218,20 @@ public SearchHit(StreamInput in) throws IOException { sortValues = new SearchSortValues(in); size = in.readVInt(); - if (size > 0) { - matchedQueries = new String[size]; + if (in.getVersion().onOrAfter(Version.V_3_0_0)) { + if (size > 0) { + Map tempMap = in.readMap(StreamInput::readString, StreamInput::readFloat); + matchedQueries = tempMap.entrySet() + .stream() + .sorted(Map.Entry.comparingByKey()) + .collect( + Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue, (oldValue, newValue) -> oldValue, LinkedHashMap::new) + ); + } + } else { + matchedQueries = new LinkedHashMap<>(size); for (int i = 0; i < size; i++) { - matchedQueries[i] = in.readString(); + matchedQueries.put(in.readString(), Float.NaN); } } // we call the setter here because that also sets the local index parameter @@ -237,36 +249,6 @@ public SearchHit(StreamInput in) throws IOException { } } - private Map readFields(StreamInput in) throws IOException { - Map fields; - int size = in.readVInt(); - if (size == 0) { - fields = emptyMap(); - } else if (size == 1) { - DocumentField hitField = new DocumentField(in); - fields = singletonMap(hitField.getName(), hitField); - } else { - fields = new HashMap<>(size); - for (int i = 0; i < size; i++) { - DocumentField field = new DocumentField(in); - fields.put(field.getName(), field); - } - fields = unmodifiableMap(fields); - } - return fields; - } - - private void writeFields(StreamOutput out, Map fields) throws IOException { - if (fields == null) { - out.writeVInt(0); - } else { - out.writeVInt(fields.size()); - for (DocumentField field : fields.values()) { - field.writeTo(out); - } - } - } - private static final Text SINGLE_MAPPING_TYPE = new Text(MapperService.SINGLE_MAPPING_NAME); @Override @@ -303,11 +285,13 @@ public void writeTo(StreamOutput out) throws IOException { } sortValues.writeTo(out); - if (matchedQueries.length == 0) { - out.writeVInt(0); + out.writeVInt(matchedQueries.size()); + if (out.getVersion().onOrAfter(Version.V_3_0_0)) { + if (!matchedQueries.isEmpty()) { + out.writeMap(matchedQueries, StreamOutput::writeString, StreamOutput::writeFloat); + } } else { - out.writeVInt(matchedQueries.length); - for (String matchedFilter : matchedQueries) { + for (String matchedFilter : matchedQueries.keySet()) { out.writeString(matchedFilter); } } @@ -475,11 +459,11 @@ public DocumentField field(String fieldName) { } /* - * Adds a new DocumentField to the map in case both parameters are not null. - * */ + * Adds a new DocumentField to the map in case both parameters are not null. + * */ public void setDocumentField(String fieldName, DocumentField field) { if (fieldName == null || field == null) return; - if (documentFields.size() == 0) this.documentFields = new HashMap<>(); + if (documentFields.isEmpty()) this.documentFields = new HashMap<>(); this.documentFields.put(fieldName, field); } @@ -492,7 +476,7 @@ public DocumentField removeDocumentField(String fieldName) { * were required to be loaded. */ public Map getFields() { - if (metaFields.size() > 0 || documentFields.size() > 0) { + if (!metaFields.isEmpty() || !documentFields.isEmpty()) { final Map fields = new HashMap<>(); fields.putAll(metaFields); fields.putAll(documentFields); @@ -577,14 +561,45 @@ public String getClusterAlias() { } public void matchedQueries(String[] matchedQueries) { - this.matchedQueries = matchedQueries; + if (matchedQueries != null) { + for (String query : matchedQueries) { + this.matchedQueries.put(query, Float.NaN); + } + } + } + + public void matchedQueriesWithScores(Map matchedQueries) { + if (matchedQueries != null) { + this.matchedQueries = matchedQueries; + } } /** * The set of query and filter names the query matched with. Mainly makes sense for compound filters and queries. */ public String[] getMatchedQueries() { - return this.matchedQueries; + return matchedQueries == null ? new String[0] : matchedQueries.keySet().toArray(new String[0]); + } + + /** + * Returns the score of the provided named query if it matches. + *

+ * If the 'include_named_queries_score' is not set, this method will return {@link Float#NaN} + * for each named query instead of a numerical score. + *

+ * + * @param name The name of the query to retrieve the score for. + * @return The score of the named query, or {@link Float#NaN} if 'include_named_queries_score' is not set. + */ + public Float getMatchedQueryScore(String name) { + return getMatchedQueriesAndScores().get(name); + } + + /** + * @return The map of the named queries that matched and their associated score. + */ + public Map getMatchedQueriesAndScores() { + return matchedQueries == null ? Collections.emptyMap() : matchedQueries; } /** @@ -671,7 +686,7 @@ public XContentBuilder toInnerXContent(XContentBuilder builder, Params params) t for (DocumentField field : metaFields.values()) { // ignore empty metadata fields - if (field.getValues().size() == 0) { + if (field.getValues().isEmpty()) { continue; } // _ignored is the only multi-valued meta field @@ -687,10 +702,10 @@ public XContentBuilder toInnerXContent(XContentBuilder builder, Params params) t } if (documentFields.isEmpty() == false && // ignore fields all together if they are all empty - documentFields.values().stream().anyMatch(df -> df.getValues().size() > 0)) { + documentFields.values().stream().anyMatch(df -> !df.getValues().isEmpty())) { builder.startObject(Fields.FIELDS); for (DocumentField field : documentFields.values()) { - if (field.getValues().size() > 0) { + if (!field.getValues().isEmpty()) { field.toXContent(builder, params); } } @@ -704,12 +719,21 @@ public XContentBuilder toInnerXContent(XContentBuilder builder, Params params) t builder.endObject(); } sortValues.toXContent(builder, params); - if (matchedQueries.length > 0) { - builder.startArray(Fields.MATCHED_QUERIES); - for (String matchedFilter : matchedQueries) { - builder.value(matchedFilter); + if (!matchedQueries.isEmpty()) { + boolean includeMatchedQueriesScore = params.paramAsBoolean(RestSearchAction.INCLUDE_NAMED_QUERIES_SCORE_PARAM, false); + if (includeMatchedQueriesScore) { + builder.startObject(Fields.MATCHED_QUERIES); + for (Map.Entry entry : matchedQueries.entrySet()) { + builder.field(entry.getKey(), entry.getValue()); + } + builder.endObject(); + } else { + builder.startArray(Fields.MATCHED_QUERIES); + for (String matchedFilter : matchedQueries.keySet()) { + builder.value(matchedFilter); + } + builder.endArray(); } - builder.endArray(); } if (getExplanation() != null) { builder.field(Fields._EXPLANATION); @@ -814,7 +838,27 @@ public static void declareInnerHitsParseFields(ObjectParser, (p, c) -> parseInnerHits(p), new ParseField(Fields.INNER_HITS) ); - parser.declareStringArray((map, list) -> map.put(Fields.MATCHED_QUERIES, list), new ParseField(Fields.MATCHED_QUERIES)); + parser.declareField((p, map, context) -> { + XContentParser.Token token = p.currentToken(); + Map matchedQueries = new LinkedHashMap<>(); + if (token == XContentParser.Token.START_OBJECT) { + String fieldName = null; + while ((token = p.nextToken()) != XContentParser.Token.END_OBJECT) { + if (token == XContentParser.Token.FIELD_NAME) { + fieldName = p.currentName(); + } else if (token.isValue()) { + matchedQueries.put(fieldName, p.floatValue()); + } + } + } else if (token == XContentParser.Token.START_ARRAY) { + while (p.nextToken() != XContentParser.Token.END_ARRAY) { + matchedQueries.put(p.text(), Float.NaN); + } + } else { + throw new IllegalStateException("expected object or array but got [" + token + "]"); + } + map.put(Fields.MATCHED_QUERIES, matchedQueries); + }, new ParseField(Fields.MATCHED_QUERIES), ObjectParser.ValueType.OBJECT_ARRAY); parser.declareField( (map, list) -> map.put(Fields.SORT, list), SearchSortValues::fromXContent, @@ -845,7 +889,7 @@ public static SearchHit createFromMap(Map values) { assert shardId.getIndexName().equals(index); searchHit.shard(new SearchShardTarget(nodeId, shardId, clusterAlias, OriginalIndices.NONE)); } else { - // these fields get set anyways when setting the shard target, + // these fields get set anyway when setting the shard target, // but we set them explicitly when we don't have enough info to rebuild the shard target searchHit.index = index; searchHit.clusterAlias = clusterAlias; @@ -859,10 +903,7 @@ public static SearchHit createFromMap(Map values) { searchHit.sourceRef(get(SourceFieldMapper.NAME, values, null)); searchHit.explanation(get(Fields._EXPLANATION, values, null)); searchHit.setInnerHits(get(Fields.INNER_HITS, values, null)); - List matchedQueries = get(Fields.MATCHED_QUERIES, values, null); - if (matchedQueries != null) { - searchHit.matchedQueries(matchedQueries.toArray(new String[0])); - } + searchHit.matchedQueriesWithScores(get(Fields.MATCHED_QUERIES, values, null)); return searchHit; } @@ -982,7 +1023,7 @@ public boolean equals(Object obj) { && Objects.equals(documentFields, other.documentFields) && Objects.equals(metaFields, other.metaFields) && Objects.equals(getHighlightFields(), other.getHighlightFields()) - && Arrays.equals(matchedQueries, other.matchedQueries) + && Objects.equals(getMatchedQueriesAndScores(), other.getMatchedQueriesAndScores()) && Objects.equals(explanation, other.explanation) && Objects.equals(shard, other.shard) && Objects.equals(innerHits, other.innerHits) @@ -1002,7 +1043,7 @@ public int hashCode() { documentFields, metaFields, getHighlightFields(), - Arrays.hashCode(matchedQueries), + getMatchedQueriesAndScores(), explanation, shard, innerHits, diff --git a/server/src/main/java/org/opensearch/search/SearchService.java b/server/src/main/java/org/opensearch/search/SearchService.java index 572ef2034ba38..d48f6f6522ca5 100644 --- a/server/src/main/java/org/opensearch/search/SearchService.java +++ b/server/src/main/java/org/opensearch/search/SearchService.java @@ -1275,6 +1275,7 @@ private void parseSource(DefaultSearchContext context, SearchSourceBuilder sourc } } context.trackScores(source.trackScores()); + context.includeNamedQueriesScore(source.includeNamedQueriesScore()); if (source.trackTotalHitsUpTo() != null && source.trackTotalHitsUpTo() != SearchContext.TRACK_TOTAL_HITS_ACCURATE && context.scrollContext() != null) { diff --git a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java index 9661e0fe707b3..72e0c6be5416b 100644 --- a/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java +++ b/server/src/main/java/org/opensearch/search/builder/SearchSourceBuilder.java @@ -118,6 +118,7 @@ public final class SearchSourceBuilder implements Writeable, ToXContentObject, R public static final ParseField IGNORE_FAILURE_FIELD = new ParseField("ignore_failure"); public static final ParseField SORT_FIELD = new ParseField("sort"); public static final ParseField TRACK_SCORES_FIELD = new ParseField("track_scores"); + public static final ParseField INCLUDE_NAMED_QUERIES_SCORE = new ParseField("include_named_queries_score"); public static final ParseField TRACK_TOTAL_HITS_FIELD = new ParseField("track_total_hits"); public static final ParseField INDICES_BOOST_FIELD = new ParseField("indices_boost"); public static final ParseField AGGREGATIONS_FIELD = new ParseField("aggregations"); @@ -176,6 +177,8 @@ public static HighlightBuilder highlight() { private boolean trackScores = false; + private Boolean includeNamedQueriesScore; + private Integer trackTotalHitsUpTo; private SearchAfterBuilder searchAfterBuilder; @@ -285,6 +288,9 @@ public SearchSourceBuilder(StreamInput in) throws IOException { searchPipelineSource = in.readMap(); } } + if (in.getVersion().onOrAfter(Version.V_3_0_0)) { + includeNamedQueriesScore = in.readOptionalBoolean(); + } } @Override @@ -358,6 +364,9 @@ public void writeTo(StreamOutput out) throws IOException { out.writeMap(searchPipelineSource); } } + if (out.getVersion().onOrAfter(Version.V_3_0_0)) { + out.writeOptionalBoolean(includeNamedQueriesScore); + } } /** @@ -585,6 +594,22 @@ public SearchSourceBuilder trackScores(boolean trackScores) { return this; } + /** + * Applies when there are named queries, to return the scores along as well + * Defaults to {@code false}. + */ + public SearchSourceBuilder includeNamedQueriesScores(boolean includeNamedQueriesScore) { + this.includeNamedQueriesScore = includeNamedQueriesScore; + return this; + } + + /** + * Indicates whether scores will be returned as part of every search matched query.s + */ + public boolean includeNamedQueriesScore() { + return includeNamedQueriesScore != null && includeNamedQueriesScore; + } + /** * Indicates whether scores will be tracked for this request. */ @@ -1120,6 +1145,7 @@ private SearchSourceBuilder shallowCopy( rewrittenBuilder.terminateAfter = terminateAfter; rewrittenBuilder.timeout = timeout; rewrittenBuilder.trackScores = trackScores; + rewrittenBuilder.includeNamedQueriesScore = includeNamedQueriesScore; rewrittenBuilder.trackTotalHitsUpTo = trackTotalHitsUpTo; rewrittenBuilder.version = version; rewrittenBuilder.seqNoAndPrimaryTerm = seqNoAndPrimaryTerm; @@ -1172,6 +1198,8 @@ public void parseXContent(XContentParser parser, boolean checkTrailingTokens) th explain = parser.booleanValue(); } else if (TRACK_SCORES_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { trackScores = parser.booleanValue(); + } else if (INCLUDE_NAMED_QUERIES_SCORE.match(currentFieldName, parser.getDeprecationHandler())) { + includeNamedQueriesScore = parser.booleanValue(); } else if (TRACK_TOTAL_HITS_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { if (token == XContentParser.Token.VALUE_BOOLEAN || (token == XContentParser.Token.VALUE_STRING && Booleans.isBoolean(parser.text()))) { @@ -1435,6 +1463,10 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params) t builder.field(TRACK_SCORES_FIELD.getPreferredName(), true); } + if (includeNamedQueriesScore != null) { + builder.field(INCLUDE_NAMED_QUERIES_SCORE.getPreferredName(), includeNamedQueriesScore); + } + if (trackTotalHitsUpTo != null) { builder.field(TRACK_TOTAL_HITS_FIELD.getPreferredName(), trackTotalHitsUpTo); } @@ -1766,6 +1798,7 @@ public int hashCode() { terminateAfter, timeout, trackScores, + includeNamedQueriesScore, version, seqNoAndPrimaryTerm, profile, @@ -1808,6 +1841,7 @@ public boolean equals(Object obj) { && Objects.equals(terminateAfter, other.terminateAfter) && Objects.equals(timeout, other.timeout) && Objects.equals(trackScores, other.trackScores) + && Objects.equals(includeNamedQueriesScore, other.includeNamedQueriesScore) && Objects.equals(version, other.version) && Objects.equals(seqNoAndPrimaryTerm, other.seqNoAndPrimaryTerm) && Objects.equals(profile, other.profile) diff --git a/server/src/main/java/org/opensearch/search/fetch/FetchContext.java b/server/src/main/java/org/opensearch/search/fetch/FetchContext.java index 7e36ace9e2112..5be3733106655 100644 --- a/server/src/main/java/org/opensearch/search/fetch/FetchContext.java +++ b/server/src/main/java/org/opensearch/search/fetch/FetchContext.java @@ -188,6 +188,10 @@ public boolean fetchScores() { return searchContext.sort() != null && searchContext.trackScores(); } + public boolean includeNamedQueriesScore() { + return searchContext.includeNamedQueriesScore(); + } + /** * Configuration for returning inner hits */ diff --git a/server/src/main/java/org/opensearch/search/fetch/FetchPhase.java b/server/src/main/java/org/opensearch/search/fetch/FetchPhase.java index a842c0f1adc6e..1698f41caaf2b 100644 --- a/server/src/main/java/org/opensearch/search/fetch/FetchPhase.java +++ b/server/src/main/java/org/opensearch/search/fetch/FetchPhase.java @@ -91,7 +91,7 @@ /** * Fetch phase of a search request, used to fetch the actual top matching documents to be returned to the client, identified - * after reducing all of the matches returned by the query phase + * after reducing all the matches returned by the query phase * * @opensearch.api */ diff --git a/server/src/main/java/org/opensearch/search/fetch/subphase/MatchedQueriesPhase.java b/server/src/main/java/org/opensearch/search/fetch/subphase/MatchedQueriesPhase.java index 6c589438d6b4c..406d9c8b4bc03 100644 --- a/server/src/main/java/org/opensearch/search/fetch/subphase/MatchedQueriesPhase.java +++ b/server/src/main/java/org/opensearch/search/fetch/subphase/MatchedQueriesPhase.java @@ -28,12 +28,12 @@ * Modifications Copyright OpenSearch Contributors. See * GitHub history for details. */ - package org.opensearch.search.fetch.subphase; import org.apache.lucene.index.LeafReaderContext; import org.apache.lucene.search.Query; import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; import org.apache.lucene.search.ScorerSupplier; import org.apache.lucene.search.Weight; import org.apache.lucene.util.Bits; @@ -45,6 +45,7 @@ import java.io.IOException; import java.util.ArrayList; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; @@ -67,25 +68,69 @@ public FetchSubPhaseProcessor getProcessor(FetchContext context) throws IOExcept if (namedQueries.isEmpty()) { return null; } + + Map weights = prepareWeights(context, namedQueries); + + return context.includeNamedQueriesScore() ? createScoringProcessor(weights) : createNonScoringProcessor(weights); + } + + private Map prepareWeights(FetchContext context, Map namedQueries) throws IOException { Map weights = new HashMap<>(); + ScoreMode scoreMode = context.includeNamedQueriesScore() ? ScoreMode.COMPLETE : ScoreMode.COMPLETE_NO_SCORES; for (Map.Entry entry : namedQueries.entrySet()) { - weights.put( - entry.getKey(), - context.searcher().createWeight(context.searcher().rewrite(entry.getValue()), ScoreMode.COMPLETE_NO_SCORES, 1) - ); + weights.put(entry.getKey(), context.searcher().createWeight(context.searcher().rewrite(entry.getValue()), scoreMode, 1)); } + return weights; + } + + private FetchSubPhaseProcessor createScoringProcessor(Map weights) { return new FetchSubPhaseProcessor() { + final Map matchingScorers = new HashMap<>(); + + @Override + public void setNextReader(LeafReaderContext readerContext) throws IOException { + matchingScorers.clear(); + for (Map.Entry entry : weights.entrySet()) { + ScorerSupplier scorerSupplier = entry.getValue().scorerSupplier(readerContext); + if (scorerSupplier != null) { + Scorer scorer = scorerSupplier.get(0L); + if (scorer != null) { + matchingScorers.put(entry.getKey(), scorer); + } + } + } + } + + @Override + public void process(HitContext hitContext) throws IOException { + Map matches = new LinkedHashMap<>(); + int docId = hitContext.docId(); + for (Map.Entry entry : matchingScorers.entrySet()) { + Scorer scorer = entry.getValue(); + if (scorer.iterator().docID() < docId) { + scorer.iterator().advance(docId); + } + if (scorer.iterator().docID() == docId) { + matches.put(entry.getKey(), scorer.score()); + } + } + hitContext.hit().matchedQueriesWithScores(matches); + } + }; + } - final Map matchingIterators = new HashMap<>(); + private FetchSubPhaseProcessor createNonScoringProcessor(Map weights) { + return new FetchSubPhaseProcessor() { + final Map matchingBits = new HashMap<>(); @Override public void setNextReader(LeafReaderContext readerContext) throws IOException { - matchingIterators.clear(); + matchingBits.clear(); for (Map.Entry entry : weights.entrySet()) { - ScorerSupplier ss = entry.getValue().scorerSupplier(readerContext); - if (ss != null) { - Bits matchingBits = Lucene.asSequentialAccessBits(readerContext.reader().maxDoc(), ss); - matchingIterators.put(entry.getKey(), matchingBits); + ScorerSupplier scorerSupplier = entry.getValue().scorerSupplier(readerContext); + if (scorerSupplier != null) { + Bits bits = Lucene.asSequentialAccessBits(readerContext.reader().maxDoc(), scorerSupplier); + matchingBits.put(entry.getKey(), bits); } } } @@ -93,15 +138,14 @@ public void setNextReader(LeafReaderContext readerContext) throws IOException { @Override public void process(HitContext hitContext) { List matches = new ArrayList<>(); - int doc = hitContext.docId(); - for (Map.Entry iterator : matchingIterators.entrySet()) { - if (iterator.getValue().get(doc)) { - matches.add(iterator.getKey()); + int docId = hitContext.docId(); + for (Map.Entry entry : matchingBits.entrySet()) { + if (entry.getValue().get(docId)) { + matches.add(entry.getKey()); } } hitContext.hit().matchedQueries(matches.toArray(new String[0])); } }; } - } diff --git a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java index 327552cbfccdb..6a0fb376fa9bf 100644 --- a/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/FilteredSearchContext.java @@ -340,6 +340,14 @@ public FieldDoc searchAfter() { return in.searchAfter(); } + public SearchContext includeNamedQueriesScore(boolean includeNamedQueriesScore) { + return in.includeNamedQueriesScore(includeNamedQueriesScore); + } + + public boolean includeNamedQueriesScore() { + return in.includeNamedQueriesScore(); + } + @Override public SearchContext parsedPostFilter(ParsedQuery postFilter) { return in.parsedPostFilter(postFilter); diff --git a/server/src/main/java/org/opensearch/search/internal/SearchContext.java b/server/src/main/java/org/opensearch/search/internal/SearchContext.java index 57f5dc955a5da..7c4f13b946367 100644 --- a/server/src/main/java/org/opensearch/search/internal/SearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/SearchContext.java @@ -305,6 +305,29 @@ public final void assignRescoreDocIds(RescoreDocIds rescoreDocIds) { public abstract boolean trackScores(); + /** + * Determines whether named queries' scores should be included in the search results. + * By default, this is set to return false, indicating that scores from named queries are not included. + * + * @param includeNamedQueriesScore true to include scores from named queries, false otherwise. + */ + public SearchContext includeNamedQueriesScore(boolean includeNamedQueriesScore) { + // Default implementation does nothing and returns this for chaining. + // Implementations of SearchContext should override this method to actually store the value. + return this; + } + + /** + * Checks if scores from named queries are included in the search results. + * + * @return true if scores from named queries are included, false otherwise. + */ + public boolean includeNamedQueriesScore() { + // Default implementation returns false. + // Implementations of SearchContext should override this method to return the actual value. + return false; + } + public abstract SearchContext trackTotalHitsUpTo(int trackTotalHits); /** diff --git a/server/src/main/java/org/opensearch/search/internal/SubSearchContext.java b/server/src/main/java/org/opensearch/search/internal/SubSearchContext.java index 55315013ea8c9..b2c97baf78d91 100644 --- a/server/src/main/java/org/opensearch/search/internal/SubSearchContext.java +++ b/server/src/main/java/org/opensearch/search/internal/SubSearchContext.java @@ -82,6 +82,8 @@ public class SubSearchContext extends FilteredSearchContext { private boolean explain; private boolean trackScores; + + private boolean includeNamedQueriesScore; private boolean version; private boolean seqNoAndPrimaryTerm; @@ -234,6 +236,17 @@ public boolean trackScores() { return trackScores; } + @Override + public SearchContext includeNamedQueriesScore(boolean includeNamedQueriesScore) { + this.includeNamedQueriesScore = includeNamedQueriesScore; + return this; + } + + @Override + public boolean includeNamedQueriesScore() { + return includeNamedQueriesScore; + } + @Override public SearchContext parsedPostFilter(ParsedQuery postFilter) { throw new UnsupportedOperationException("Not supported"); diff --git a/server/src/test/java/org/opensearch/search/SearchHitTests.java b/server/src/test/java/org/opensearch/search/SearchHitTests.java index 88d5fb38a6cb1..13b4d9f976ed5 100644 --- a/server/src/test/java/org/opensearch/search/SearchHitTests.java +++ b/server/src/test/java/org/opensearch/search/SearchHitTests.java @@ -56,11 +56,13 @@ import org.opensearch.test.AbstractWireSerializingTestCase; import org.opensearch.test.RandomObjects; import org.opensearch.test.VersionUtils; +import org.junit.Assert; import java.io.IOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.function.Predicate; @@ -76,6 +78,25 @@ import static org.hamcrest.Matchers.nullValue; public class SearchHitTests extends AbstractWireSerializingTestCase { + + private Map getSampleMatchedQueries() { + Map matchedQueries = new LinkedHashMap<>(); + matchedQueries.put("query1", 1.0f); + matchedQueries.put("query2", 0.5f); + return matchedQueries; + } + + public static SearchHit createTestItemWithMatchedQueriesScores(boolean withOptionalInnerHits, boolean withShardTarget) { + var searchHit = createTestItem(randomFrom(XContentType.values()), withOptionalInnerHits, withShardTarget); + int size = randomIntBetween(1, 5); // Ensure at least one matched query + Map matchedQueries = new LinkedHashMap<>(size); + for (int i = 0; i < size; i++) { + matchedQueries.put(randomAlphaOfLength(5), randomFloat()); + } + searchHit.matchedQueriesWithScores(matchedQueries); + return searchHit; + } + public static SearchHit createTestItem(boolean withOptionalInnerHits, boolean withShardTarget) { return createTestItem(randomFrom(XContentType.values()), withOptionalInnerHits, withShardTarget); } @@ -129,11 +150,11 @@ public static SearchHit createTestItem(final MediaType mediaType, boolean withOp } if (randomBoolean()) { int size = randomIntBetween(0, 5); - String[] matchedQueries = new String[size]; + Map matchedQueries = new LinkedHashMap<>(size); for (int i = 0; i < size; i++) { - matchedQueries[i] = randomAlphaOfLength(5); + matchedQueries.put(randomAlphaOfLength(5), Float.NaN); } - hit.matchedQueries(matchedQueries); + hit.matchedQueriesWithScores(matchedQueries); } if (randomBoolean()) { hit.explanation(createExplanation(randomIntBetween(0, 5))); @@ -219,6 +240,21 @@ public void testFromXContentLenientParsing() throws IOException { assertToXContentEquivalent(originalBytes, toXContent(parsed, xContentType, true), xContentType); } + public void testSerializationDeserializationWithMatchedQueriesScores() throws IOException { + SearchHit searchHit = createTestItemWithMatchedQueriesScores(true, true); + SearchHit deserializedSearchHit = copyWriteable(searchHit, getNamedWriteableRegistry(), SearchHit::new, Version.V_3_0_0); + assertEquals(searchHit, deserializedSearchHit); + assertEquals(searchHit.getMatchedQueriesAndScores(), deserializedSearchHit.getMatchedQueriesAndScores()); + } + + public void testSerializationDeserializationWithMatchedQueriesList() throws IOException { + SearchHit searchHit = createTestItem(true, true); + SearchHit deserializedSearchHit = copyWriteable(searchHit, getNamedWriteableRegistry(), SearchHit::new, Version.V_2_12_0); + assertEquals(searchHit, deserializedSearchHit); + assertEquals(searchHit.getMatchedQueriesAndScores(), deserializedSearchHit.getMatchedQueriesAndScores()); + Assert.assertArrayEquals(searchHit.getMatchedQueries(), deserializedSearchHit.getMatchedQueries()); + } + /** * When e.g. with "stored_fields": "_none_", only "_index" and "_score" are returned. */ @@ -244,6 +280,125 @@ public void testToXContent() throws IOException { assertEquals("{\"_id\":\"id1\",\"_score\":1.5}", builder.toString()); } + public void testSerializeShardTargetWithNewVersion() throws Exception { + String clusterAlias = randomBoolean() ? null : "cluster_alias"; + SearchShardTarget target = new SearchShardTarget( + "_node_id", + new ShardId(new Index("_index", "_na_"), 0), + clusterAlias, + OriginalIndices.NONE + ); + + Map innerHits = new HashMap<>(); + SearchHit innerHit1 = new SearchHit(0, "_id", null, null); + innerHit1.shard(target); + SearchHit innerInnerHit2 = new SearchHit(0, "_id", null, null); + innerInnerHit2.shard(target); + innerHits.put("1", new SearchHits(new SearchHit[] { innerInnerHit2 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1f)); + innerHit1.setInnerHits(innerHits); + SearchHit innerHit2 = new SearchHit(0, "_id", null, null); + innerHit2.shard(target); + SearchHit innerHit3 = new SearchHit(0, "_id", null, null); + innerHit3.shard(target); + + innerHits = new HashMap<>(); + SearchHit hit1 = new SearchHit(0, "_id", null, null); + innerHits.put("1", new SearchHits(new SearchHit[] { innerHit1, innerHit2 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1f)); + innerHits.put("2", new SearchHits(new SearchHit[] { innerHit3 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1f)); + hit1.shard(target); + hit1.setInnerHits(innerHits); + + SearchHit hit2 = new SearchHit(0, "_id", null, null); + hit2.shard(target); + + SearchHits hits = new SearchHits(new SearchHit[] { hit1, hit2 }, new TotalHits(2, TotalHits.Relation.EQUAL_TO), 1f); + + SearchHits results = copyWriteable(hits, getNamedWriteableRegistry(), SearchHits::new, Version.V_3_0_0); + SearchShardTarget deserializedTarget = results.getAt(0).getShard(); + assertThat(deserializedTarget, equalTo(target)); + assertThat(results.getAt(0).getInnerHits().get("1").getAt(0).getShard(), notNullValue()); + assertThat(results.getAt(0).getInnerHits().get("1").getAt(0).getInnerHits().get("1").getAt(0).getShard(), notNullValue()); + assertThat(results.getAt(0).getInnerHits().get("1").getAt(1).getShard(), notNullValue()); + assertThat(results.getAt(0).getInnerHits().get("2").getAt(0).getShard(), notNullValue()); + for (SearchHit hit : results) { + assertEquals(clusterAlias, hit.getClusterAlias()); + if (hit.getInnerHits() != null) { + for (SearchHits innerhits : hit.getInnerHits().values()) { + for (SearchHit innerHit : innerhits) { + assertEquals(clusterAlias, innerHit.getClusterAlias()); + } + } + } + } + assertThat(results.getAt(1).getShard(), equalTo(target)); + } + + public void testSerializeShardTargetWithNewVersionAndMatchedQueries() throws Exception { + String clusterAlias = randomBoolean() ? null : "cluster_alias"; + SearchShardTarget target = new SearchShardTarget( + "_node_id", + new ShardId(new Index("_index", "_na_"), 0), + clusterAlias, + OriginalIndices.NONE + ); + + Map innerHits = new HashMap<>(); + SearchHit innerHit1 = new SearchHit(0, "_id", null, null); + innerHit1.shard(target); + innerHit1.matchedQueriesWithScores(getSampleMatchedQueries()); + SearchHit innerInnerHit2 = new SearchHit(0, "_id", null, null); + innerInnerHit2.shard(target); + innerHits.put("1", new SearchHits(new SearchHit[] { innerInnerHit2 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1f)); + innerHit1.setInnerHits(innerHits); + SearchHit innerHit2 = new SearchHit(0, "_id", null, null); + innerHit2.shard(target); + innerHit2.matchedQueriesWithScores(getSampleMatchedQueries()); + SearchHit innerHit3 = new SearchHit(0, "_id", null, null); + innerHit3.shard(target); + innerHit3.matchedQueriesWithScores(getSampleMatchedQueries()); + + innerHits = new HashMap<>(); + SearchHit hit1 = new SearchHit(0, "_id", null, null); + innerHits.put("1", new SearchHits(new SearchHit[] { innerHit1, innerHit2 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1f)); + innerHits.put("2", new SearchHits(new SearchHit[] { innerHit3 }, new TotalHits(1, TotalHits.Relation.EQUAL_TO), 1f)); + hit1.shard(target); + hit1.setInnerHits(innerHits); + + SearchHit hit2 = new SearchHit(0, "_id", null, null); + hit2.shard(target); + + SearchHits hits = new SearchHits(new SearchHit[] { hit1, hit2 }, new TotalHits(2, TotalHits.Relation.EQUAL_TO), 1f); + + SearchHits results = copyWriteable(hits, getNamedWriteableRegistry(), SearchHits::new, Version.V_3_0_0); + SearchShardTarget deserializedTarget = results.getAt(0).getShard(); + assertThat(deserializedTarget, equalTo(target)); + assertThat(results.getAt(0).getInnerHits().get("1").getAt(0).getShard(), notNullValue()); + assertThat(results.getAt(0).getInnerHits().get("1").getAt(0).getInnerHits().get("1").getAt(0).getShard(), notNullValue()); + assertThat(results.getAt(0).getInnerHits().get("1").getAt(1).getShard(), notNullValue()); + assertThat(results.getAt(0).getInnerHits().get("2").getAt(0).getShard(), notNullValue()); + String[] expectedMatchedQueries = new String[] { "query1", "query2" }; + String[] actualMatchedQueries = results.getAt(0).getInnerHits().get("1").getAt(0).getMatchedQueries(); + assertArrayEquals(expectedMatchedQueries, actualMatchedQueries); + + Map expectedMatchedQueriesAndScores = new LinkedHashMap<>(); + expectedMatchedQueriesAndScores.put("query1", 1.0f); + expectedMatchedQueriesAndScores.put("query2", 0.5f); + + Map actualMatchedQueriesAndScores = results.getAt(0).getInnerHits().get("1").getAt(0).getMatchedQueriesAndScores(); + assertEquals(expectedMatchedQueriesAndScores, actualMatchedQueriesAndScores); + for (SearchHit hit : results) { + assertEquals(clusterAlias, hit.getClusterAlias()); + if (hit.getInnerHits() != null) { + for (SearchHits innerhits : hit.getInnerHits().values()) { + for (SearchHit innerHit : innerhits) { + assertEquals(clusterAlias, innerHit.getClusterAlias()); + } + } + } + } + assertThat(results.getAt(1).getShard(), equalTo(target)); + } + public void testSerializeShardTarget() throws Exception { String clusterAlias = randomBoolean() ? null : "cluster_alias"; SearchShardTarget target = new SearchShardTarget( diff --git a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java index bbb3c4a070800..cbfeea4633ad6 100644 --- a/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java +++ b/test/framework/src/main/java/org/opensearch/test/TestSearchContext.java @@ -107,6 +107,7 @@ public class TestSearchContext extends SearchContext { SearchShardTask task; SortAndFormats sort; boolean trackScores = false; + boolean includeNamedQueriesScore = false; int trackTotalHitsUpTo = SearchContext.DEFAULT_TRACK_TOTAL_HITS_UP_TO; ContextIndexSearcher searcher; @@ -415,6 +416,17 @@ public boolean trackScores() { return trackScores; } + @Override + public SearchContext includeNamedQueriesScore(boolean includeNamedQueriesScore) { + this.includeNamedQueriesScore = includeNamedQueriesScore; + return this; + } + + @Override + public boolean includeNamedQueriesScore() { + return includeNamedQueriesScore; + } + @Override public SearchContext trackTotalHitsUpTo(int trackTotalHitsUpTo) { this.trackTotalHitsUpTo = trackTotalHitsUpTo; diff --git a/test/framework/src/main/java/org/opensearch/test/hamcrest/OpenSearchAssertions.java b/test/framework/src/main/java/org/opensearch/test/hamcrest/OpenSearchAssertions.java index 3a39fd30f17dd..0a0cb1ed3a263 100644 --- a/test/framework/src/main/java/org/opensearch/test/hamcrest/OpenSearchAssertions.java +++ b/test/framework/src/main/java/org/opensearch/test/hamcrest/OpenSearchAssertions.java @@ -531,6 +531,10 @@ public static Matcher hasScore(final float score) { return new OpenSearchMatchers.SearchHitHasScoreMatcher(score); } + public static Matcher hasMatchedQueries(final String[] matchedQueries) { + return new OpenSearchMatchers.SearchHitMatchedQueriesMatcher(matchedQueries); + } + public static CombinableMatcher hasProperty(Function property, Matcher valueMatcher) { return OpenSearchMatchers.HasPropertyLambdaMatcher.hasProperty(property, valueMatcher); } diff --git a/test/framework/src/main/java/org/opensearch/test/hamcrest/OpenSearchMatchers.java b/test/framework/src/main/java/org/opensearch/test/hamcrest/OpenSearchMatchers.java index 5889b7e269ed2..2be94bd53e3c1 100644 --- a/test/framework/src/main/java/org/opensearch/test/hamcrest/OpenSearchMatchers.java +++ b/test/framework/src/main/java/org/opensearch/test/hamcrest/OpenSearchMatchers.java @@ -38,6 +38,7 @@ import org.hamcrest.TypeSafeMatcher; import org.hamcrest.core.CombinableMatcher; +import java.util.Arrays; import java.util.function.Function; public class OpenSearchMatchers { @@ -111,6 +112,35 @@ public void describeTo(final Description description) { } } + public static class SearchHitMatchedQueriesMatcher extends TypeSafeMatcher { + private String[] matchedQueries; + + public SearchHitMatchedQueriesMatcher(String[] matchedQueries) { + this.matchedQueries = matchedQueries; + } + + @Override + protected boolean matchesSafely(SearchHit searchHit) { + String[] searchHitQueries = searchHit.getMatchedQueries(); + if (matchedQueries == null) { + return false; + } + Arrays.sort(searchHitQueries); + Arrays.sort(matchedQueries); + return Arrays.equals(searchHitQueries, matchedQueries); + } + + @Override + public void describeMismatchSafely(final SearchHit searchHit, final Description mismatchDescription) { + mismatchDescription.appendText(" matched queries were ").appendValue(Arrays.toString(searchHit.getMatchedQueries())); + } + + @Override + public void describeTo(final Description description) { + description.appendText("searchHit matched queries should be ").appendValue(Arrays.toString(matchedQueries)); + } + } + public static class HasPropertyLambdaMatcher extends FeatureMatcher { private final Function property;