Skip to content

Commit

Permalink
Adding Reciprocal Rank Fusion (RRF) in hybrid query (opensearch-proje…
Browse files Browse the repository at this point in the history
…ct#1086)

* Reciprocal Rank Fusion (RRF) normalization technique in hybrid query (opensearch-project#874)

---------
Signed-off-by: Isaac Johnson <[email protected]>
Signed-off-by: Ryan Bogan <[email protected]>
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
ryanbogan committed Jan 14, 2025
1 parent 96fa384 commit 0b0e376
Show file tree
Hide file tree
Showing 46 changed files with 2,245 additions and 260 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.18...2.x)
### Features
- Pagination in Hybrid query ([#1048](https://github.com/opensearch-project/neural-search/pull/1048))
- Implement Reciprocal Rank Fusion score normalization/combination technique in hybrid query ([#874](https://github.com/opensearch-project/neural-search/pull/874))
### Enhancements
- Explainability in hybrid query ([#970](https://github.com/opensearch-project/neural-search/pull/970))
- Implement pruning for neural sparse ingestion pipeline and two phase search processor ([#988](https://github.com/opensearch-project/neural-search/pull/988))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ private void validateIndexQuery(final String modelId) {
.modelId(modelId)
.maxDistance(100000f)
.build();

Map<String, Object> responseWithMaxDistanceQuery = search(getIndexNameForTest(), neuralQueryBuilderWithMaxDistanceQuery, 1);
assertNotNull(responseWithMaxDistanceQuery);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,24 @@
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
import org.opensearch.neuralsearch.processor.NeuralSparseTwoPhaseProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.ExplanationResponseProcessor;
import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.TextChunkingProcessor;
import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.RRFProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.ExplanationResponseProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextChunkingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RRFProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
Expand Down Expand Up @@ -157,7 +159,9 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchPhaseR
) {
return Map.of(
NormalizationProcessor.TYPE,
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
new NormalizationProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory),
RRFProcessor.TYPE,
new RRFProcessorFactory(normalizationProcessorWorkflow, scoreNormalizationFactory, scoreCombinationFactory)
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

import java.util.Optional;

/**
* Base class for all score hybridization processors. This class is responsible for executing the score hybridization process.
* It is a pipeline processor that is executed after the query phase and before the fetch phase.
*/
public abstract class AbstractScoreHybridizationProcessor implements SearchPhaseResultsProcessor {
/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor. This method is called when there is pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult
* @param searchPhaseContext
* @param requestContextOptional
* @param <Result>
*/
abstract <Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ public SearchResponse processResponse(
);
}
// Create and set final explanation combining all components
Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore();
Explanation finalExplanation = Explanation.match(
searchHit.getScore(),
finalScore,
// combination level explanation is always a single detail
combinationExplanation.getScoreDetails().get(0).getValue(),
normalizedExplanation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ Map<String, Object> buildMapWithTargetKeys(IngestDocument ingestDocument) {
buildNestedMap(originalKey, targetKey, sourceAndMetadataMap, treeRes);
mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey));
} else {
mapWithProcessorKeys.put(String.valueOf(targetKey), normalizeSourceValue(sourceAndMetadataMap.get(originalKey)));
mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
}
}
return mapWithProcessorKeys;
Expand Down Expand Up @@ -357,9 +357,8 @@ void buildNestedMap(String parentKey, Object processorKey, Map<String, Object> s
}
treeRes.merge(parentKey, next, REMAPPING_FUNCTION);
} else {
Object parentValue = sourceAndMetadataMap.get(parentKey);
String key = String.valueOf(processorKey);
treeRes.put(key, normalizeSourceValue(parentValue));
treeRes.put(key, sourceAndMetadataMap.get(parentKey));
}
}

Expand Down Expand Up @@ -404,7 +403,7 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
indexName,
clusterService,
environment,
true
false
);
}

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.query.QuerySearchResult;

import java.util.List;
import java.util.Optional;

/**
* DTO object to hold data in NormalizationProcessorWorkflow class
* in NormalizationProcessorWorkflow.
*/
@AllArgsConstructor
@Builder
@Getter
public class NormalizationExecuteDTO {
@NonNull
private List<QuerySearchResult> querySearchResults;
@NonNull
private Optional<FetchSearchResult> fetchSearchResultOptional;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
@NonNull
private ScoreCombinationTechnique combinationTechnique;
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;

import lombok.AllArgsConstructor;
Expand All @@ -33,7 +31,7 @@
*/
@Log4j2
@AllArgsConstructor
public class NormalizationProcessor implements SearchPhaseResultsProcessor {
public class NormalizationProcessor extends AbstractScoreHybridizationProcessor {
public static final String TYPE = "normalization-processor";

private final String tag;
Expand All @@ -42,38 +40,8 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {
private final ScoreCombinationTechnique combinationTechnique;
private final NormalizationProcessorWorkflow normalizationWorkflow;

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWorkflow(
<Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,12 @@
import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.neuralsearch.processor.combination.CombineScoresDto;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.explain.CombinedExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationPayload;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
Expand Down Expand Up @@ -57,44 +55,30 @@ public class NormalizationProcessorWorkflow {

/**
* Start execution of this workflow
* @param querySearchResults input data with QuerySearchResult from multiple shards
* @param normalizationTechnique technique for score normalization
* @param combinationTechnique technique for score combination
* @param request contains querySearchResults input data with QuerySearchResult
* from multiple shards, fetchSearchResultOptional, normalizationTechnique technique for score normalization
* combinationTechnique technique for score combination, and nullable rankConstant only used in RRF technique
*/
public void execute(
final List<QuerySearchResult> querySearchResults,
final Optional<FetchSearchResult> fetchSearchResultOptional,
final ScoreNormalizationTechnique normalizationTechnique,
final ScoreCombinationTechnique combinationTechnique,
final SearchPhaseContext searchPhaseContext
) {
NormalizationProcessorWorkflowExecuteRequest request = NormalizationProcessorWorkflowExecuteRequest.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResultOptional)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.explain(false)
.searchPhaseContext(searchPhaseContext)
.build();
execute(request);
}

public void execute(final NormalizationProcessorWorkflowExecuteRequest request) {
List<QuerySearchResult> querySearchResults = request.getQuerySearchResults();
Optional<FetchSearchResult> fetchSearchResultOptional = request.getFetchSearchResultOptional();

// save original state
List<Integer> unprocessedDocIds = unprocessedDocIds(querySearchResults);
List<Integer> unprocessedDocIds = unprocessedDocIds(request.getQuerySearchResults());

// pre-process data
log.debug("Pre-process query results");
List<CompoundTopDocs> queryTopDocs = getQueryTopDocs(querySearchResults);

explain(request, queryTopDocs);

// Data transfer object for score normalization used to pass nullable rankConstant which is only used in RRF
NormalizeScoresDTO normalizeScoresDTO = NormalizeScoresDTO.builder()
.queryTopDocs(queryTopDocs)
.normalizationTechnique(request.getNormalizationTechnique())
.build();

// normalize
log.debug("Do score normalization");
scoreNormalizer.normalizeScores(queryTopDocs, request.getNormalizationTechnique());
scoreNormalizer.normalizeScores(normalizeScoresDTO);

CombineScoresDto combineScoresDTO = CombineScoresDto.builder()
.queryTopDocs(queryTopDocs)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import lombok.AllArgsConstructor;
import lombok.Builder;
import lombok.Getter;
import lombok.NonNull;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;

import java.util.List;

/**
* DTO object to hold data required for score normalization.
*/
@AllArgsConstructor
@Builder
@Getter
public class NormalizeScoresDTO {
@NonNull
private List<CompoundTopDocs> queryTopDocs;
@NonNull
private ScoreNormalizationTechnique normalizationTechnique;
}
Loading

0 comments on commit 0b0e376

Please sign in to comment.