Skip to content

Commit

Permalink
Add integration and unit tests for missing RRF coverage (#997)
Browse files Browse the repository at this point in the history
* Initial unit test implementation

Signed-off-by: Ryan Bogan <[email protected]>

---------
Signed-off-by: Ryan Bogan <[email protected]>
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
ryanbogan authored and martin-gaievski committed Dec 17, 2024
1 parent 582e882 commit 627fcb4
Show file tree
Hide file tree
Showing 5 changed files with 366 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import java.util.Optional;

import lombok.Getter;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
Expand Down Expand Up @@ -99,7 +100,8 @@ public boolean isIgnoreFailure() {
return false;
}

private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
@VisibleForTesting
<Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPhaseResults<Result> searchPhaseResult) {
if (Objects.isNull(searchPhaseResult) || !(searchPhaseResult instanceof QueryPhaseResultConsumer queryPhaseResultConsumer)) {
return true;
}
Expand All @@ -112,7 +114,8 @@ private <Result extends SearchPhaseResult> boolean shouldSkipProcessor(SearchPha
* @param searchPhaseResult
* @return true if results are from hybrid query
*/
private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
@VisibleForTesting
boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
// check for delimiter at the end of the score docs.
return Objects.nonNull(searchPhaseResult.queryResult())
&& Objects.nonNull(searchPhaseResult.queryResult().topDocs())
Expand All @@ -121,17 +124,16 @@ private boolean isHybridQuery(final SearchPhaseResult searchPhaseResult) {
&& isHybridQueryStartStopElement(searchPhaseResult.queryResult().topDocs().topDocs.scoreDocs[0]);
}

private <Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(
final SearchPhaseResults<Result> results
) {
<Result extends SearchPhaseResult> List<QuerySearchResult> getQueryPhaseSearchResults(final SearchPhaseResults<Result> results) {
return results.getAtomicArray()
.asList()
.stream()
.map(result -> result == null ? null : result.queryResult())
.collect(Collectors.toList());
}

private <Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
@VisibleForTesting
<Result extends SearchPhaseResult> Optional<FetchSearchResult> getFetchSearchResults(
final SearchPhaseResults<Result> searchPhaseResults
) {
Optional<Result> optionalFirstSearchPhaseResult = searchPhaseResults.getAtomicArray().asList().stream().findFirst();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
import java.util.Map;
import java.util.Objects;

import lombok.AllArgsConstructor;
import lombok.Getter;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;
Expand Down Expand Up @@ -58,8 +60,8 @@ public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
for (ScoreDoc scoreDoc : subQueryTopDoc.scoreDocs) {
scoreDoc.score = normalizeSingleScore(
scoreDoc.score,
minMaxScores.minScoresPerSubquery()[j],
minMaxScores.maxScoresPerSubquery()[j]
minMaxScores.getMinScoresPerSubquery()[j],
minMaxScores.getMaxScoresPerSubquery()[j]
);
}
}
Expand Down Expand Up @@ -96,8 +98,8 @@ public Map<DocIdAtSearchShard, ExplanationDetails> explain(final List<CompoundTo
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(scoreDoc.doc, compoundQueryTopDocs.getSearchShard());
float normalizedScore = normalizeSingleScore(
scoreDoc.score,
minMaxScores.minScoresPerSubquery()[j],
minMaxScores.maxScoresPerSubquery()[j]
minMaxScores.getMinScoresPerSubquery()[j],
minMaxScores.getMaxScoresPerSubquery()[j]
);
normalizedScores.computeIfAbsent(docIdAtSearchShard, k -> new ArrayList<>()).add(normalizedScore);
scoreDoc.score = normalizedScore;
Expand Down Expand Up @@ -171,6 +173,10 @@ private float normalizeSingleScore(final float score, final float minScore, fina
/**
* Result class to hold min and max scores for each sub query
*/
private record MinMaxScores(float[] minScoresPerSubquery, float[] maxScoresPerSubquery) {
@AllArgsConstructor
@Getter
private class MinMaxScores {
float[] minScoresPerSubquery;
float[] maxScoresPerSubquery;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.SneakyThrows;
import org.opensearch.index.query.MatchQueryBuilder;
import org.opensearch.knn.index.query.KNNQueryBuilder;
import org.opensearch.neuralsearch.BaseNeuralSearchIT;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;

import static org.opensearch.neuralsearch.util.TestUtils.DELTA_FOR_SCORE_ASSERTION;
import static org.opensearch.neuralsearch.util.TestUtils.TEST_SPACE_TYPE;

/**
 * Integration test covering the RRF (Reciprocal Rank Fusion) search pipeline end to end:
 * index documents through a text-embedding ingest pipeline, run a hybrid (lexical + kNN)
 * query through an RRF search pipeline, and verify the fused scores.
 */
public class RRFProcessorIT extends BaseNeuralSearchIT {

    private static final String RRF_INDEX_NAME = "rrf-index";
    private static final String RRF_SEARCH_PIPELINE = "rrf-search-pipeline";
    private static final String RRF_INGEST_PIPELINE = "rrf-ingest-pipeline";
    private static final int RRF_DIMENSION = 5;

    // Sequential document id handed to addDocument(...); starts at 1 and grows per doc.
    private int currentDoc = 1;

    @SneakyThrows
    public void testRRF_whenValidInput_thenSucceed() {
        try {
            // Set up ingest pipeline, kNN index, corpus, and the default RRF search pipeline.
            createPipelineProcessor(null, RRF_INGEST_PIPELINE, ProcessorType.TEXT_EMBEDDING);
            prepareKnnIndex(
                RRF_INDEX_NAME,
                Collections.singletonList(new KNNFieldConfig("passage_embedding", RRF_DIMENSION, TEST_SPACE_TYPE))
            );
            addDocuments();
            createDefaultRRFSearchPipeline();

            Map<String, Object> searchResponse = search(
                RRF_INDEX_NAME,
                getHybridQueryBuilder(),
                null,
                5,
                Map.of("search_pipeline", RRF_SEARCH_PIPELINE)
            );

            Map<String, Object> hits = (Map<String, Object>) searchResponse.get("hits");
            ArrayList<HashMap<String, Object>> hitsList = (ArrayList<HashMap<String, Object>>) hits.get("hits");

            // Expected scores are 1/61, 1/62, 1/63 — RRF fusion with the default rank constant of 60.
            assertEquals(3, hitsList.size());
            assertEquals(0.016393442, (Double) hitsList.getFirst().get("_score"), DELTA_FOR_SCORE_ASSERTION);
            assertEquals(0.016129032, (Double) hitsList.get(1).get("_score"), DELTA_FOR_SCORE_ASSERTION);
            assertEquals(0.015873017, (Double) hitsList.getLast().get("_score"), DELTA_FOR_SCORE_ASSERTION);
        } finally {
            wipeOfTestResources(RRF_INDEX_NAME, RRF_INGEST_PIPELINE, null, RRF_SEARCH_PIPELINE);
        }
    }

    /**
     * Builds the hybrid query whose two sub-queries (lexical match + kNN vector) the
     * RRF processor fuses at search time.
     */
    private HybridQueryBuilder getHybridQueryBuilder() {
        HybridQueryBuilder hybridQuery = new HybridQueryBuilder();
        hybridQuery.add(new MatchQueryBuilder("text", "cowboy rodeo bronco"));
        hybridQuery.add(
            new KNNQueryBuilder.Builder().fieldName("passage_embedding").k(5).vector(new float[] { 0.1f, 1.2f, 2.3f, 3.4f, 4.5f }).build()
        );
        return hybridQuery;
    }

    /** Indexes the small fixed corpus of five caption/image pairs used by the test. */
    @SneakyThrows
    private void addDocuments() {
        String[][] corpus = {
            {
                "A West Virginia university women 's basketball team , officials , and a small gathering of fans are in a West Virginia arena .",
                "4319130149.jpg" },
            { "A wild animal races across an uncut field with a minimal amount of trees .", "1775029934.jpg" },
            {
                "People line the stands which advertise Freemont 's orthopedics , a cowboy rides a light brown bucking bronco .",
                "2664027527.jpg" },
            { "A man who is riding a wild horse in the rodeo is very near to falling off .", "4427058951.jpg" },
            { "A rodeo cowboy , wearing a cowboy hat , is being thrown off of a wild white horse .", "2691147709.jpg" } };
        for (String[] doc : corpus) {
            addDocument(doc[0], doc[1]);
        }
    }

    /** Delegates to the base-class indexing helper, assigning sequential string ids. */
    @SneakyThrows
    private void addDocument(String description, String imageText) {
        addDocument(RRF_INDEX_NAME, String.valueOf(currentDoc++), "text", description, "image_text", imageText);
    }
}
Loading

0 comments on commit 627fcb4

Please sign in to comment.