From 627fcb4f5082b0f29d0d3f89364b05c79a7ae4f6 Mon Sep 17 00:00:00 2001 From: Ryan Bogan Date: Tue, 3 Dec 2024 08:59:26 -0800 Subject: [PATCH] Add integration and unit tests for missing RRF coverage (#997) * Initial unit test implementation Signed-off-by: Ryan Bogan --------- Signed-off-by: Ryan Bogan Signed-off-by: Martin Gaievski --- .../neuralsearch/processor/RRFProcessor.java | 14 +- .../MinMaxScoreNormalizationTechnique.java | 16 +- .../processor/RRFProcessorIT.java | 93 +++++++ .../processor/RRFProcessorTests.java | 226 ++++++++++++++++++ .../neuralsearch/BaseNeuralSearchIT.java | 28 +++ 5 files changed, 366 insertions(+), 11 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java index c8f78691a..ca67f2d1c 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/RRFProcessor.java @@ -13,6 +13,7 @@ import java.util.Optional; import lombok.Getter; +import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; import org.opensearch.search.fetch.FetchSearchResult; @@ -99,7 +100,8 @@ public boolean isIgnoreFailure() { return false; } - private boolean shouldSkipProcessor(SearchPhaseResults searchPhaseResult) { + @VisibleForTesting + boolean shouldSkipProcessor(SearchPhaseResults searchPhaseResult) { if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) { return true; } @@ -112,7 +114,8 @@ private boolean shouldSkipProcessor(SearchPha * @param searchPhaseResult * @return true if results are from hybrid query */ - private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { + @VisibleForTesting + boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { // check for delimiter at the end of the score docs. return Objects.nonNull(searchPhaseResult.queryResult()) && Objects.nonNull(searchPhaseResult.queryResult().topDocs()) @@ -121,9 +124,7 @@ private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) { && isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]); } - private List getQueryPhaseSearchResults( - final SearchPhaseResults results - ) { + List getQueryPhaseSearchResults(final SearchPhaseResults results) { return results.getAtomicArray() .asList() .stream() @@ -131,7 +132,8 @@ private List getQueryPhase .collect(Collectors.toList()); } - private Optional getFetchSearchResults( + @VisibleForTesting + Optional getFetchSearchResults( final SearchPhaseResults searchPhaseResults ) { Optional optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst(); diff --git a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java index da16d6c96..7da4c4330 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/normalization/MinMaxScoreNormalizationTechnique.java @@ -12,6 +12,8 @@ import java.util.Map; import java.util.Objects; +import lombok.AllArgsConstructor; +import lombok.Getter; import org.apache.lucene.search.ScoreDoc; import org.apache.lucene.search.TopDocs; import org.opensearch.neuralsearch.processor.CompoundTopDocs; @@ -58,8 +60,8 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) { for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) { scoreDoc.score = normalizeSingleScore( scoreDoc.score, - minMaxScores.minScoresPerSubquery()[j], - minMaxScores.maxScoresPerSubquery()[j] + minMaxScores.getMinScoresPerSubquery()[j], + minMaxScores.getMaxScoresPerSubquery()[j] ); } } @@ -96,8 +98,8 @@ public Map explain(final List new ArrayList<>()).add(normalizedScore); scoreDoc.score = normalizedScore; @@ -171,6 +173,10 @@ private float normalizeSingleScore(final float score, final float minScore, fina /** * Result class to hold min and max scores for each sub query */ - private record MinMaxScores(float[] minScoresPerSubquery, float[] maxScoresPerSubquery) { + @AllArgsConstructor + @Getter + private class MinMaxScores { + float[] minScoresPerSubquery; + float[] maxScoresPerSubquery; } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java new file mode 100644 index 000000000..fccabab5c --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorIT.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.SneakyThrows; +import org.opensearch.index.query.MatchQueryBuilder; +import org.opensearch.knn.index.query.KNNQueryBuilder; +import org.opensearch.neuralsearch.BaseNeuralSearchIT; +import org.opensearch.neuralsearch.query.HybridQueryBuilder; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION; +import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE; + +public class RRFProcessorIT extends BaseNeuralSearchIT { + + private int currentDoc = 1; + private static final String RRF_INDEX_NAME = "rrf-index"; + private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; + private static final String RRF_INGEST_PIPELINE = "rrf-ingest-pipeline"; + + private static final int RRF_DIMENSION = 5; + + @SneakyThrows + public void testRRF_whenValidInput_thenSucceed() { + try { + createPipelineProcessor(null, RRF_INGEST_PIPELINE, ProcessorType.TEXT_EMBEDDING); + prepareKnnIndex( + RRF_INDEX_NAME, + Collections.singletonList(new KNNFieldConfig("passage_embedding", RRF_DIMENSION, TEST_SPACE_TYPE)) + ); + addDocuments(); + createDefaultRRFSearchPipeline(); + + HybridQueryBuilder hybridQueryBuilder = getHybridQueryBuilder(); + + Map results = search( + RRF_INDEX_NAME, + hybridQueryBuilder, + null, + 5, + Map.of("search_pipeline", RRF_SEARCH_PIPELINE) + ); + Map hits = (Map) results.get("hits"); + ArrayList> hitsList = (ArrayList>) hits.get("hits"); + assertEquals(3, hitsList.size()); + assertEquals(0.016393442, (Double) hitsList.getFirst().get("_score"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.016129032, (Double) hitsList.get(1).get("_score"), DELTA_FOR_SCORE_ASSERTION); + assertEquals(0.015873017, (Double) hitsList.getLast().get("_score"), DELTA_FOR_SCORE_ASSERTION); + } finally { + wipeOfTestResources(RRF_INDEX_NAME, RRF_INGEST_PIPELINE, null, RRF_SEARCH_PIPELINE); + } + } + + private HybridQueryBuilder getHybridQueryBuilder() { + MatchQueryBuilder matchQueryBuilder = new MatchQueryBuilder("text", "cowboy rodeo bronco"); + KNNQueryBuilder knnQueryBuilder = new KNNQueryBuilder.Builder().fieldName("passage_embedding") + .k(5) + .vector(new float[] { 0.1f, 1.2f, 2.3f, 3.4f, 4.5f }) + .build(); + + HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); + hybridQueryBuilder.add(matchQueryBuilder); + hybridQueryBuilder.add(knnQueryBuilder); + return hybridQueryBuilder; + } + + @SneakyThrows + private void addDocuments() { + addDocument( + "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena .", + "4319130149.jpg" + ); + addDocument("A wild animal races across an uncut field with a minimal amount of trees .", "1775029934.jpg"); + addDocument( + "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco .", + "2664027527.jpg" + ); + addDocument("A man who is riding a wild horse in the rodeo is very near to falling off .", "4427058951.jpg"); + addDocument("A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse .", "2691147709.jpg"); + } + + @SneakyThrows + private void addDocument(String description, String imageText) { + addDocument(RRF_INDEX_NAME, String.valueOf(currentDoc++), "text", description, "image_text", imageText); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java new file mode 100644 index 000000000..b7764128f --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/RRFProcessorTests.java @@ -0,0 +1,226 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import lombok.SneakyThrows; +import org.apache.lucene.search.ScoreDoc; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.TotalHits; +import org.junit.Before; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.action.OriginalIndices; +import org.opensearch.action.search.QueryPhaseResultConsumer; +import org.opensearch.action.search.SearchPhaseContext; +import org.opensearch.action.search.SearchPhaseName; +import org.opensearch.action.search.SearchPhaseResults; +import org.opensearch.action.search.SearchRequest; +import org.opensearch.action.support.IndicesOptions; +import org.opensearch.common.lucene.search.TopDocsAndMaxScore; +import org.opensearch.common.util.concurrent.AtomicArray; +import org.opensearch.core.common.Strings; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique; +import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique; +import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil; +import org.opensearch.search.DocValueFormat; +import org.opensearch.search.SearchPhaseResult; +import org.opensearch.search.SearchShardTarget; +import org.opensearch.search.builder.SearchSourceBuilder; +import org.opensearch.search.fetch.FetchSearchResult; +import org.opensearch.search.internal.AliasFilter; +import org.opensearch.search.internal.ShardSearchContextId; +import org.opensearch.search.internal.ShardSearchRequest; +import org.opensearch.search.query.QuerySearchResult; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.List; +import java.util.Optional; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class RRFProcessorTests extends OpenSearchTestCase { + + @Mock + private ScoreNormalizationTechnique mockNormalizationTechnique; + @Mock + private ScoreCombinationTechnique mockCombinationTechnique; + @Mock + private NormalizationProcessorWorkflow mockNormalizationWorkflow; + @Mock + private SearchPhaseResults mockSearchPhaseResults; + @Mock + private SearchPhaseContext mockSearchPhaseContext; + @Mock + private QueryPhaseResultConsumer mockQueryPhaseResultConsumer; + + private RRFProcessor rrfProcessor; + private static final String TAG = "tag"; + private static final String DESCRIPTION = "description"; + + @Before + @SneakyThrows + public void setUp() { + super.setUp(); + MockitoAnnotations.openMocks(this); + rrfProcessor = new RRFProcessor(TAG, DESCRIPTION, mockNormalizationTechnique, mockCombinationTechnique, mockNormalizationWorkflow); + } + + @SneakyThrows + public void testGetType() { + assertEquals(RRFProcessor.TYPE, rrfProcessor.getType()); + } + + @SneakyThrows + public void testGetBeforePhase() { + assertEquals(SearchPhaseName.QUERY, rrfProcessor.getBeforePhase()); + } + + @SneakyThrows + public void testGetAfterPhase() { + assertEquals(SearchPhaseName.FETCH, rrfProcessor.getAfterPhase()); + } + + @SneakyThrows + public void testIsIgnoreFailure() { + assertFalse(rrfProcessor.isIgnoreFailure()); + } + + @SneakyThrows + public void testProcess_whenNullSearchPhaseResult_thenSkipWorkflow() { + rrfProcessor.process(null, mockSearchPhaseContext); + verify(mockNormalizationWorkflow, never()).execute(any()); + } + + @SneakyThrows + public void testProcess_whenNonQueryPhaseResultConsumer_thenSkipWorkflow() { + rrfProcessor.process(mockSearchPhaseResults, mockSearchPhaseContext); + verify(mockNormalizationWorkflow, never()).execute(any()); + } + + @SneakyThrows + public void testProcess_whenValidHybridInput_thenSucceed() { + QuerySearchResult result = createQuerySearchResult(true); + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + verify(mockNormalizationWorkflow).execute(any(NormalizationProcessorWorkflowExecuteRequest.class)); + } + + @SneakyThrows + public void testProcess_whenValidNonHybridInput_thenSucceed() { + QuerySearchResult result = createQuerySearchResult(false); + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, result); + + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext); + + verify(mockNormalizationWorkflow, never()).execute(any(NormalizationProcessorWorkflowExecuteRequest.class)); + } + + @SneakyThrows + public void testGetTag() { + assertEquals(TAG, rrfProcessor.getTag()); + } + + @SneakyThrows + public void testGetDescription() { + assertEquals(DESCRIPTION, rrfProcessor.getDescription()); + } + + @SneakyThrows + public void testShouldSkipProcessor() { + assertTrue(rrfProcessor.shouldSkipProcessor(null)); + assertTrue(rrfProcessor.shouldSkipProcessor(mockSearchPhaseResults)); + + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, createQuerySearchResult(false)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + assertTrue(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer)); + + atomicArray.set(0, createQuerySearchResult(true)); + assertFalse(rrfProcessor.shouldSkipProcessor(mockQueryPhaseResultConsumer)); + } + + @SneakyThrows + public void testGetQueryPhaseSearchResults() { + AtomicArray atomicArray = new AtomicArray<>(2); + atomicArray.set(0, createQuerySearchResult(true)); + atomicArray.set(1, createQuerySearchResult(false)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + List results = rrfProcessor.getQueryPhaseSearchResults(mockQueryPhaseResultConsumer); + assertEquals(2, results.size()); + assertNotNull(results.get(0)); + assertNotNull(results.get(1)); + } + + @SneakyThrows + public void testGetFetchSearchResults() { + AtomicArray atomicArray = new AtomicArray<>(1); + atomicArray.set(0, createQuerySearchResult(true)); + when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray); + + Optional result = rrfProcessor.getFetchSearchResults(mockQueryPhaseResultConsumer); + assertFalse(result.isPresent()); + } + + private QuerySearchResult createQuerySearchResult(boolean isHybrid) { + ShardId shardId = new ShardId("index", "uuid", 0); + OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed()); + SearchRequest searchRequest = new SearchRequest("index"); + searchRequest.source(new SearchSourceBuilder()); + searchRequest.allowPartialSearchResults(true); + + int numberOfShards = 1; + AliasFilter aliasFilter = new AliasFilter(null, Strings.EMPTY_ARRAY); + float indexBoost = 1.0f; + long nowInMillis = System.currentTimeMillis(); + String clusterAlias = null; + String[] indexRoutings = Strings.EMPTY_ARRAY; + + ShardSearchRequest shardSearchRequest = new ShardSearchRequest( + originalIndices, + searchRequest, + shardId, + numberOfShards, + aliasFilter, + indexBoost, + nowInMillis, + clusterAlias, + indexRoutings + ); + + QuerySearchResult result = new QuerySearchResult( + new ShardSearchContextId("test", 1), + new SearchShardTarget("node1", shardId, clusterAlias, originalIndices), + shardSearchRequest + ); + result.from(0).size(10); + + ScoreDoc[] scoreDocs; + if (isHybrid) { + scoreDocs = new ScoreDoc[] { HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(0) }; + } else { + scoreDocs = new ScoreDoc[] { new ScoreDoc(0, 1.0f) }; + } + + TopDocs topDocs = new TopDocs(new TotalHits(1, TotalHits.Relation.EQUAL_TO), scoreDocs); + TopDocsAndMaxScore topDocsAndMaxScore = new TopDocsAndMaxScore(topDocs, 1.0f); + result.topDocs(topDocsAndMaxScore, new DocValueFormat[0]); + + return result; + } +} diff --git a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java index 4f154e78b..5107296c3 100644 --- a/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java +++ b/src/testFixtures/java/org/opensearch/neuralsearch/BaseNeuralSearchIT.java @@ -91,6 +91,7 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { ); private static final Set SUCCESS_STATUSES = Set.of(RestStatus.CREATED, RestStatus.OK); protected static final String CONCURRENT_SEGMENT_SEARCH_ENABLED = "search.concurrent_segment_search.enabled"; + protected static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); @@ -1468,4 +1469,31 @@ protected enum ProcessorType { TEXT_IMAGE_EMBEDDING, SPARSE_ENCODING } + + @SneakyThrows + protected void createDefaultRRFSearchPipeline() { + String requestBody = XContentFactory.jsonBuilder() + .startObject() + .field("description", "Post processor for hybrid search") + .startArray("phase_results_processors") + .startObject() + .startObject("score-ranker-processor") + .startObject("combination") + .field("technique", "rrf") + .endObject() + .endObject() + .endObject() + .endArray() + .endObject() + .toString(); + + makeRequest( + client(), + "PUT", + String.format(LOCALE, "/_search/pipeline/%s", RRF_SEARCH_PIPELINE), + null, + toHttpEntity(String.format(LOCALE, requestBody)), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, DEFAULT_USER_AGENT)) + ); + } }