From cc6a6b2087a19d4216f1731b68f50defabe63f48 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Wed, 3 Apr 2024 16:09:04 -0700 Subject: [PATCH] Add support for local cache in hybrid query (#663) Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + .../NormalizationProcessorWorkflow.java | 14 +- .../NormalizationProcessorTests.java | 7 + .../NormalizationProcessorWorkflowTests.java | 87 ++++- .../neuralsearch/query/HybridQueryIT.java | 308 ++++++++++++++++++ 5 files changed, 405 insertions(+), 12 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f20523654..060acdb29 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -19,6 +19,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ### Features ### Enhancements ### Bug Fixes +- Add support for request_cache flag in hybrid query ([#663](https://github.com/opensearch-project/neural-search/pull/663)) ### Infrastructure ### Documentation ### Maintenance diff --git a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java index a20b52517..f317a9e12 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflow.java @@ -138,7 +138,13 @@ private void updateOriginalFetchResults( // 3. update original scores to normalized and combined values // 4. order scores based on normalized and combined values FetchSearchResult fetchSearchResult = fetchSearchResultOptional.get(); - SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult); + // checking case when results are cached + boolean requestCache = Objects.nonNull(querySearchResults) + && !querySearchResults.isEmpty() + && Objects.nonNull(querySearchResults.get(0).getShardSearchRequest().requestCache()) + && querySearchResults.get(0).getShardSearchRequest().requestCache(); + + SearchHit[] searchHitArray = getSearchHits(docIds, fetchSearchResult, requestCache); // create map of docId to index of search hits. This solves (2), duplicates are from // delimiter and start/stop elements, they all have same valid doc_id. For this map @@ -168,7 +174,7 @@ private void updateOriginalFetchResults( fetchSearchResult.hits(updatedSearchHits); } - private SearchHit[] getSearchHits(final List docIds, final FetchSearchResult fetchSearchResult) { + private SearchHit[] getSearchHits(final List docIds, final FetchSearchResult fetchSearchResult, final boolean requestCache) { SearchHits searchHits = fetchSearchResult.hits(); SearchHit[] searchHitArray = searchHits.getHits(); // validate the both collections are of the same size @@ -177,7 +183,9 @@ private SearchHit[] getSearchHits(final List docIds, final FetchSearchR "score normalization processor cannot produce final query result, fetch query phase returns empty results" ); } - if (searchHitArray.length != docIds.size()) { + // in case of cached request results of fetch and query may be different, only restriction is + // that number of query results size is greater or equal size of fetch results + if ((!requestCache && searchHitArray.length != docIds.size()) || requestCache && docIds.size() < searchHitArray.length) { throw new IllegalStateException( String.format( Locale.ROOT, diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java index 57698cd7e..dd185e227 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorTests.java @@ -53,6 +53,7 @@ import org.opensearch.search.aggregations.pipeline.PipelineAggregator; import org.opensearch.search.fetch.FetchSearchResult; import org.opensearch.search.fetch.QueryFetchSearchResult; +import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.TestThreadPool; @@ -401,6 +402,9 @@ public void testResultTypes_whenQueryAndFetchPresentAndSizeSame_thenCallNormaliz QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult); queryFetchSearchResult.setShardIndex(shardId); + ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class); + when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE); + querySearchResult.setShardSearchRequest(shardSearchRequest); queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); @@ -485,6 +489,9 @@ public void testResultTypes_whenQueryAndFetchPresentButSizeDifferent_thenFail() QueryFetchSearchResult queryFetchSearchResult = new QueryFetchSearchResult(querySearchResult, fetchSearchResult); queryFetchSearchResult.setShardIndex(shardId); + ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class); + when(shardSearchRequest.requestCache()).thenReturn(Boolean.FALSE); + querySearchResult.setShardSearchRequest(shardSearchRequest); queryPhaseResultConsumer.consumeResult(queryFetchSearchResult, partialReduceLatch::countDown); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java index f34f8f59b..2f880ce74 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorWorkflowTests.java @@ -4,7 +4,9 @@ */ package org.opensearch.neuralsearch.processor; +import static org.mockito.Mockito.mock; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.when; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults; import static org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults; @@ -29,6 +31,7 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.SearchShardTarget; import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.internal.ShardSearchRequest; import org.opensearch.search.query.QuerySearchResult; import org.opensearch.test.OpenSearchTestCase; @@ -156,6 +159,9 @@ public void testFetchResults_whenOneShardAndQueryAndFetchResultsPresent_thenDoNo ); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); + ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class); + when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE); + querySearchResult.setShardSearchRequest(shardSearchRequest); querySearchResults.add(querySearchResult); SearchHit[] searchHitArray = new SearchHit[] { new SearchHit(0, "10", Map.of(), Map.of()), @@ -213,6 +219,9 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom ); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); + ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class); + when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE); + querySearchResult.setShardSearchRequest(shardSearchRequest); querySearchResults.add(querySearchResult); SearchHit[] searchHitArray = new SearchHit[] { new SearchHit(-1, "10", Map.of(), Map.of()), @@ -236,7 +245,7 @@ public void testFetchResults_whenOneShardAndMultipleNodes_thenDoNormalizationCom TestUtils.assertFetchResultScores(fetchSearchResult, 4); } - public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() { + public void testFetchResultsAndNoCache_whenOneShardAndMultipleNodesAndMismatchResults_thenFail() { NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) ); @@ -270,15 +279,11 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then ); querySearchResult.setSearchShardTarget(searchShardTarget); querySearchResult.setShardIndex(shardId); + ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class); + when(shardSearchRequest.requestCache()).thenReturn(Boolean.FALSE); + querySearchResult.setShardSearchRequest(shardSearchRequest); querySearchResults.add(querySearchResult); - SearchHit[] searchHitArray = new SearchHit[] { - new SearchHit(-1, "10", Map.of(), Map.of()), - new SearchHit(-1, "10", Map.of(), Map.of()), - new SearchHit(-1, "10", Map.of(), Map.of()), - new SearchHit(-1, "1", Map.of(), Map.of()), - new SearchHit(-1, "2", Map.of(), Map.of()), - new SearchHit(-1, "3", Map.of(), Map.of()) }; - SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); + SearchHits searchHits = getSearchHits(); fetchSearchResult.hits(searchHits); expectThrows( @@ -291,4 +296,68 @@ public void testFetchResults_whenOneShardAndMultipleNodesAndMismatchResults_then ) ); } + + public void testFetchResultsAndCache_whenOneShardAndMultipleNodesAndMismatchResults_thenSuccessful() { + NormalizationProcessorWorkflow normalizationProcessorWorkflow = spy( + new NormalizationProcessorWorkflow(new ScoreNormalizer(), new ScoreCombiner()) + ); + + List querySearchResults = new ArrayList<>(); + FetchSearchResult fetchSearchResult = new FetchSearchResult(); + int shardId = 0; + SearchShardTarget searchShardTarget = new SearchShardTarget( + "node", + new ShardId("index", "uuid", shardId), + null, + OriginalIndices.NONE + ); + QuerySearchResult querySearchResult = new QuerySearchResult(); + querySearchResult.topDocs( + new TopDocsAndMaxScore( + new TopDocs( + new TotalHits(4, TotalHits.Relation.EQUAL_TO), + new ScoreDoc[] { + createStartStopElementForHybridSearchResults(0), + createDelimiterElementForHybridSearchResults(0), + new ScoreDoc(0, 0.5f), + new ScoreDoc(2, 0.3f), + new ScoreDoc(4, 0.25f), + new ScoreDoc(10, 0.2f), + createStartStopElementForHybridSearchResults(0) } + ), + 0.5f + ), + new DocValueFormat[0] + ); + querySearchResult.setSearchShardTarget(searchShardTarget); + querySearchResult.setShardIndex(shardId); + ShardSearchRequest shardSearchRequest = mock(ShardSearchRequest.class); + when(shardSearchRequest.requestCache()).thenReturn(Boolean.TRUE); + querySearchResult.setShardSearchRequest(shardSearchRequest); + querySearchResults.add(querySearchResult); + SearchHits searchHits = getSearchHits(); + fetchSearchResult.hits(searchHits); + + normalizationProcessorWorkflow.execute( + querySearchResults, + Optional.of(fetchSearchResult), + ScoreNormalizationFactory.DEFAULT_METHOD, + ScoreCombinationFactory.DEFAULT_METHOD + ); + + TestUtils.assertQueryResultScores(querySearchResults); + TestUtils.assertFetchResultScores(fetchSearchResult, 4); + } + + private static SearchHits getSearchHits() { + SearchHit[] searchHitArray = new SearchHit[] { + new SearchHit(-1, "10", Map.of(), Map.of()), + new SearchHit(-1, "10", Map.of(), Map.of()), + new SearchHit(-1, "10", Map.of(), Map.of()), + new SearchHit(-1, "1", Map.of(), Map.of()), + new SearchHit(-1, "2", Map.of(), Map.of()), + new SearchHit(-1, "3", Map.of(), Map.of()) }; + SearchHits searchHits = new SearchHits(searchHitArray, new TotalHits(7, TotalHits.Relation.EQUAL_TO), 10); + return searchHits; + } } diff --git a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java index 487a378df..38aa69075 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/HybridQueryIT.java @@ -29,6 +29,7 @@ import org.opensearch.index.query.MatchQueryBuilder; import org.opensearch.index.query.NestedQueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.RangeQueryBuilder; import org.opensearch.index.query.TermQueryBuilder; import org.opensearch.neuralsearch.BaseNeuralSearchIT; @@ -43,6 +44,8 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String TEST_MULTI_DOC_INDEX_NAME_ONE_SHARD = "test-hybrid-multi-doc-single-shard-index"; private static final String TEST_MULTI_DOC_INDEX_WITH_NESTED_TYPE_NAME_ONE_SHARD = "test-hybrid-multi-doc-nested-type-single-shard-index"; + private static final String TEST_INDEX_WITH_KEYWORDS_ONE_SHARD = "test-hybrid-keywords-single-shard-index"; + private static final String TEST_INDEX_WITH_KEYWORDS_THREE_SHARDS = "test-hybrid-keywords-three-shards-index"; private static final String TEST_QUERY_TEXT = "greetings"; private static final String TEST_QUERY_TEXT2 = "salute"; private static final String TEST_QUERY_TEXT3 = "hello"; @@ -59,6 +62,18 @@ public class HybridQueryIT extends BaseNeuralSearchIT { private static final String NESTED_FIELD_2 = "lastname"; private static final String NESTED_FIELD_1_VALUE = "john"; private static final String NESTED_FIELD_2_VALUE = "black"; + private static final String KEYWORD_FIELD_1 = "doc_keyword"; + private static final String KEYWORD_FIELD_1_VALUE = "workable"; + private static final String KEYWORD_FIELD_2_VALUE = "angry"; + private static final String KEYWORD_FIELD_3_VALUE = "likeable"; + private static final String KEYWORD_FIELD_4_VALUE = "entire"; + private static final String INTEGER_FIELD_PRICE = "doc_price"; + private static final int INTEGER_FIELD_PRICE_1_VALUE = 130; + private static final int INTEGER_FIELD_PRICE_2_VALUE = 100; + private static final int INTEGER_FIELD_PRICE_3_VALUE = 200; + private static final int INTEGER_FIELD_PRICE_4_VALUE = 25; + private static final int INTEGER_FIELD_PRICE_5_VALUE = 30; + private static final int INTEGER_FIELD_PRICE_6_VALUE = 350; private final float[] testVector1 = createRandomVector(TEST_DIMENSION); private final float[] testVector2 = createRandomVector(TEST_DIMENSION); private final float[] testVector3 = createRandomVector(TEST_DIMENSION); @@ -399,6 +414,170 @@ public void testIndexWithNestedFields_whenHybridQueryIncludesNested_thenSuccess( } } + @SneakyThrows + public void testRequestCache_whenOneShardAndQueryReturnResults_thenSuccessful() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_INDEX_WITH_KEYWORDS_ONE_SHARD); + modelId = prepareModel(); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(KEYWORD_FIELD_1, KEYWORD_FIELD_2_VALUE); + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_PRICE).gte(10).lte(1000); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(matchQueryBuilder); + hybridQueryBuilder.add(rangeQueryBuilder); + + // first query with cache flag executed normally by reading documents from index + Map firstSearchResponseAsMap = search( + TEST_INDEX_WITH_KEYWORDS_ONE_SHARD, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "request_cache", Boolean.TRUE.toString()) + ); + + int firstQueryHitCount = getHitCount(firstSearchResponseAsMap); + assertTrue(firstQueryHitCount > 0); + + List> hitsNestedList = getNestedHits(firstSearchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(firstSearchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(firstQueryHitCount, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // second query is served from the cache + Map secondSearchResponseAsMap = search( + TEST_INDEX_WITH_KEYWORDS_ONE_SHARD, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "request_cache", Boolean.TRUE.toString()) + ); + + assertEquals(firstQueryHitCount, getHitCount(secondSearchResponseAsMap)); + + List> hitsNestedListSecondQuery = getNestedHits(secondSearchResponseAsMap); + List idsSecondQuery = new ArrayList<>(); + List scoresSecondQuery = new ArrayList<>(); + for (Map oneHit : hitsNestedListSecondQuery) { + idsSecondQuery.add((String) oneHit.get("_id")); + scoresSecondQuery.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue( + IntStream.range(0, scoresSecondQuery.size() - 1) + .noneMatch(idx -> scoresSecondQuery.get(idx) < scoresSecondQuery.get(idx + 1)) + ); + // verify that all ids are unique + assertEquals(Set.copyOf(idsSecondQuery).size(), idsSecondQuery.size()); + + Map totalSecondQuery = getTotalHits(secondSearchResponseAsMap); + assertNotNull(totalSecondQuery.get("value")); + assertEquals(firstQueryHitCount, totalSecondQuery.get("value")); + assertNotNull(totalSecondQuery.get("relation")); + assertEquals(RELATION_EQUAL_TO, totalSecondQuery.get("relation")); + } finally { + wipeOfTestResources(TEST_INDEX_WITH_KEYWORDS_ONE_SHARD, null, modelId, SEARCH_PIPELINE); + } + } + + @SneakyThrows + public void testRequestCache_whenMultipleShardsQueryReturnResults_thenSuccessful() { + String modelId = null; + try { + initializeIndexIfNotExist(TEST_INDEX_WITH_KEYWORDS_THREE_SHARDS); + modelId = prepareModel(); + createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); + MatchQueryBuilder matchQueryBuilder = QueryBuilders.matchQuery(KEYWORD_FIELD_1, KEYWORD_FIELD_2_VALUE); + RangeQueryBuilder rangeQueryBuilder = QueryBuilders.rangeQuery(INTEGER_FIELD_PRICE).gte(10).lte(1000); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(matchQueryBuilder); + hybridQueryBuilder.add(rangeQueryBuilder); + + // first query with cache flag executed normally by reading documents from index + Map firstSearchResponseAsMap = search( + TEST_INDEX_WITH_KEYWORDS_THREE_SHARDS, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "request_cache", Boolean.TRUE.toString()) + ); + + int firstQueryHitCount = getHitCount(firstSearchResponseAsMap); + assertTrue(firstQueryHitCount > 0); + + List> hitsNestedList = getNestedHits(firstSearchResponseAsMap); + List ids = new ArrayList<>(); + List scores = new ArrayList<>(); + for (Map oneHit : hitsNestedList) { + ids.add((String) oneHit.get("_id")); + scores.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue(IntStream.range(0, scores.size() - 1).noneMatch(idx -> scores.get(idx) < scores.get(idx + 1))); + // verify that all ids are unique + assertEquals(Set.copyOf(ids).size(), ids.size()); + + Map total = getTotalHits(firstSearchResponseAsMap); + assertNotNull(total.get("value")); + assertEquals(firstQueryHitCount, total.get("value")); + assertNotNull(total.get("relation")); + assertEquals(RELATION_EQUAL_TO, total.get("relation")); + + // second query is served from the cache + Map secondSearchResponseAsMap = search( + TEST_INDEX_WITH_KEYWORDS_THREE_SHARDS, + hybridQueryBuilder, + null, + 10, + Map.of("search_pipeline", SEARCH_PIPELINE, "request_cache", Boolean.TRUE.toString()) + ); + + assertEquals(firstQueryHitCount, getHitCount(secondSearchResponseAsMap)); + + List> hitsNestedListSecondQuery = getNestedHits(secondSearchResponseAsMap); + List idsSecondQuery = new ArrayList<>(); + List scoresSecondQuery = new ArrayList<>(); + for (Map oneHit : hitsNestedListSecondQuery) { + idsSecondQuery.add((String) oneHit.get("_id")); + scoresSecondQuery.add((Double) oneHit.get("_score")); + } + + // verify that scores are in desc order + assertTrue( + IntStream.range(0, scoresSecondQuery.size() - 1) + .noneMatch(idx -> scoresSecondQuery.get(idx) < scoresSecondQuery.get(idx + 1)) + ); + // verify that all ids are unique + assertEquals(Set.copyOf(idsSecondQuery).size(), idsSecondQuery.size()); + + Map totalSecondQuery = getTotalHits(secondSearchResponseAsMap); + assertNotNull(totalSecondQuery.get("value")); + assertEquals(firstQueryHitCount, totalSecondQuery.get("value")); + assertNotNull(totalSecondQuery.get("relation")); + assertEquals(RELATION_EQUAL_TO, totalSecondQuery.get("relation")); + } finally { + wipeOfTestResources(TEST_INDEX_WITH_KEYWORDS_THREE_SHARDS, null, modelId, SEARCH_PIPELINE); + } + } + @SneakyThrows private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { @@ -490,6 +669,104 @@ private void initializeIndexIfNotExist(String indexName) throws IOException { List.of(Map.of(NESTED_FIELD_1, NESTED_FIELD_1_VALUE, NESTED_FIELD_2, NESTED_FIELD_2_VALUE)) ); } + + if (TEST_INDEX_WITH_KEYWORDS_ONE_SHARD.equals(indexName) && !indexExists(TEST_INDEX_WITH_KEYWORDS_ONE_SHARD)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_PRICE), List.of(KEYWORD_FIELD_1), List.of(), 1), + "" + ); + addDocWithKeywordsAndIntFields( + indexName, + "1", + INTEGER_FIELD_PRICE, + INTEGER_FIELD_PRICE_1_VALUE, + KEYWORD_FIELD_1, + KEYWORD_FIELD_1_VALUE + ); + addDocWithKeywordsAndIntFields(indexName, "2", INTEGER_FIELD_PRICE, INTEGER_FIELD_PRICE_2_VALUE, null, null); + addDocWithKeywordsAndIntFields( + indexName, + "3", + INTEGER_FIELD_PRICE, + INTEGER_FIELD_PRICE_3_VALUE, + KEYWORD_FIELD_1, + KEYWORD_FIELD_2_VALUE + ); + addDocWithKeywordsAndIntFields( + indexName, + "4", + INTEGER_FIELD_PRICE, + INTEGER_FIELD_PRICE_4_VALUE, + KEYWORD_FIELD_1, + KEYWORD_FIELD_3_VALUE + ); + addDocWithKeywordsAndIntFields( + indexName, + "5", + INTEGER_FIELD_PRICE, + INTEGER_FIELD_PRICE_5_VALUE, + KEYWORD_FIELD_1, + KEYWORD_FIELD_4_VALUE + ); + addDocWithKeywordsAndIntFields( + indexName, + "6", + INTEGER_FIELD_PRICE, + INTEGER_FIELD_PRICE_6_VALUE, + KEYWORD_FIELD_1, + KEYWORD_FIELD_4_VALUE + ); + } + + if (TEST_INDEX_WITH_KEYWORDS_THREE_SHARDS.equals(indexName) && !indexExists(TEST_INDEX_WITH_KEYWORDS_THREE_SHARDS)) { + createIndexWithConfiguration( + indexName, + buildIndexConfiguration(List.of(), List.of(), List.of(INTEGER_FIELD_PRICE), List.of(KEYWORD_FIELD_1), List.of(), 3), + "" + ); + addDocWithKeywordsAndIntFields( + indexName, + "1", + INTEGER_FIELD_PRICE, + INTEGER_FIELD_PRICE_1_VALUE, + KEYWORD_FIELD_1, + KEYWORD_FIELD_1_VALUE + ); + addDocWithKeywordsAndIntFields(indexName, "2", INTEGER_FIELD_PRICE, INTEGER_FIELD_PRICE_2_VALUE, null, null); + addDocWithKeywordsAndIntFields( + indexName, + "3", + INTEGER_FIELD_PRICE, + INTEGER_FIELD_PRICE_3_VALUE, + KEYWORD_FIELD_1, + KEYWORD_FIELD_2_VALUE + ); + addDocWithKeywordsAndIntFields( + indexName, + "4", + INTEGER_FIELD_PRICE, + INTEGER_FIELD_PRICE_4_VALUE, + KEYWORD_FIELD_1, + KEYWORD_FIELD_3_VALUE + ); + addDocWithKeywordsAndIntFields( + indexName, + "5", + INTEGER_FIELD_PRICE, + INTEGER_FIELD_PRICE_5_VALUE, + KEYWORD_FIELD_1, + KEYWORD_FIELD_4_VALUE + ); + addDocWithKeywordsAndIntFields( + indexName, + "6", + INTEGER_FIELD_PRICE, + INTEGER_FIELD_PRICE_6_VALUE, + KEYWORD_FIELD_1, + KEYWORD_FIELD_4_VALUE + ); + } } private void addDocsToIndex(final String testMultiDocIndexName) { @@ -532,4 +809,35 @@ private Optional getMaxScore(Map searchResponseAsMap) { Map hitsMap = (Map) searchResponseAsMap.get("hits"); return hitsMap.get("max_score") == null ? Optional.empty() : Optional.of(((Double) hitsMap.get("max_score")).floatValue()); } + + private void addDocWithKeywordsAndIntFields( + final String indexName, + final String docId, + final String integerField, + final Integer integerFieldValue, + final String keywordField, + final String keywordFieldValue + ) { + List intFields = integerField == null ? List.of() : List.of(integerField); + List intValues = integerFieldValue == null ? List.of() : List.of(integerFieldValue); + List keywordFields = keywordField == null ? List.of() : List.of(keywordField); + List keywordValues = keywordFieldValue == null ? List.of() : List.of(keywordFieldValue); + + addKnnDoc( + indexName, + docId, + List.of(), + List.of(), + List.of(), + List.of(), + List.of(), + List.of(), + intFields, + intValues, + keywordFields, + keywordValues, + List.of(), + List.of() + ); + } }