Skip to content

Commit

Permalink
Integrate explainability for hybrid query into RRF processor (#1037)
Browse files Browse the repository at this point in the history
* Integrate explainability for hybrid query into RRF processor

Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski authored Dec 23, 2024
1 parent 627fcb4 commit 1d93192
Show file tree
Hide file tree
Showing 14 changed files with 708 additions and 91 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

import java.util.Optional;

/**
* Base class for all score hybridization processors. This class is responsible for executing the score hybridization process.
* It is a pipeline processor that is executed after the query phase and before the fetch phase.
*/
public abstract class AbstractScoreHybridizationProcessor implements SearchPhaseResultsProcessor {
/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor. This method is called when there is pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult
* @param searchPhaseContext
* @param requestContextOptional
* @param <Result>
*/
abstract <Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -111,8 +111,9 @@ public SearchResponse processResponse(
);
}
// Create and set final explanation combining all components
Float finalScore = Float.isNaN(searchHit.getScore()) ? 0.0f : searchHit.getScore();
Explanation finalExplanation = Explanation.match(
searchHit.getScore(),
finalScore,
// combination level explanation is always a single detail
combinationExplanation.getScoreDetails().get(0).getValue(),
normalizedExplanation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;

import lombok.AllArgsConstructor;
Expand All @@ -33,7 +31,7 @@
*/
@Log4j2
@AllArgsConstructor
public class NormalizationProcessor implements SearchPhaseResultsProcessor {
public class NormalizationProcessor extends AbstractScoreHybridizationProcessor {
public static final String TYPE = "normalization-processor";

private final String tag;
Expand All @@ -42,38 +40,8 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {
private final ScoreCombinationTechnique combinationTechnique;
private final NormalizationProcessorWorkflow normalizationWorkflow;

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWorkflow(
<Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
import java.util.Objects;
import java.util.Optional;

import com.google.common.annotations.VisibleForTesting;
import lombok.Getter;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.query.QuerySearchResult;

import org.opensearch.action.search.QueryPhaseResultConsumer;
Expand All @@ -25,7 +26,6 @@
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
Expand All @@ -39,7 +39,7 @@
*/
@Log4j2
@AllArgsConstructor
public class RRFProcessor implements SearchPhaseResultsProcessor {
public class RRFProcessor extends AbstractScoreHybridizationProcessor {
public static final String TYPE = "score-ranker-processor";

@Getter
Expand All @@ -57,25 +57,28 @@ public class RRFProcessor implements SearchPhaseResultsProcessor {
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
<Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
) {
if (shouldSkipProcessor(searchPhaseResult)) {
log.debug("Query results are not compatible with RRF processor");
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);

boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain())
&& searchPhaseContext.getRequest().source().explain();
// make data transfer object to pass in, execute will get object with 4 or 5 fields, depending
// on coming from NormalizationProcessor or RRFProcessor
NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.explain(false)
.explain(explain)
.pipelineProcessingContext(requestContextOptional.orElse(null))
.build();
normalizationWorkflow.execute(normalizationExecuteDTO);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,39 @@

import lombok.ToString;
import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import java.util.Map;
import java.util.List;
import java.util.Objects;

import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique;

@Log4j2
/**
* Abstracts combination of scores based on reciprocal rank fusion algorithm
*/
@ToString(onlyExplicitlyIncluded = true)
public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique {
public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "rrf";

// Not currently using weights for RRF, no need to modify or verify these params
public RRFScoreCombinationTechnique(final Map<String, Object> params, final ScoreCombinationUtil combinationUtil) {}
public RRFScoreCombinationTechnique() {}

@Override
public float combine(final float[] scores) {
if (Objects.isNull(scores)) {
throw new IllegalArgumentException("scores array cannot be null");
}
float sumScores = 0.0f;
for (float score : scores) {
sumScores += score;
}
return sumScores;
}

@Override
public String describe() {
return describeCombinationTechnique(TECHNIQUE_NAME, List.of());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ public class ScoreCombinationFactory {
GeometricMeanScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new GeometricMeanScoreCombinationTechnique(params, scoreCombinationUtil),
RRFScoreCombinationTechnique.TECHNIQUE_NAME,
params -> new RRFScoreCombinationTechnique(params, scoreCombinationUtil)
params -> new RRFScoreCombinationTechnique()
);

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,34 @@

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Locale;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.stream.IntStream;

import org.apache.commons.lang3.Range;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

import lombok.ToString;
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;

import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;

/**
* Abstracts calculation of rank scores for each document returned as part of
* reciprocal rank fusion. Rank scores are summed across subqueries in combination classes.
*/
@ToString(onlyExplicitlyIncluded = true)
public class RRFNormalizationTechnique implements ScoreNormalizationTechnique {
public class RRFNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "rrf";
public static final int DEFAULT_RANK_CONSTANT = 60;
Expand Down Expand Up @@ -58,21 +65,49 @@ public RRFNormalizationTechnique(final Map<String, Object> params, final ScoreNo
public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
final List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (TopDocs topDocs : topDocsPerSubQuery) {
int docsCountPerSubQuery = topDocs.scoreDocs.length;
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
for (int j = 0; j < docsCountPerSubQuery; j++) {
// using big decimal approach to minimize error caused by floating point ops
// score = 1.f / (float) (rankConstant + j + 1))
scoreDocs[j].score = BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + j + 1), 10, RoundingMode.HALF_UP)
.floatValue();
}
}
processTopDocs(compoundQueryTopDocs, (docId, score) -> {});
}
}

@Override
public String describe() {
return String.format(Locale.ROOT, "%s, rank_constant [%s]", TECHNIQUE_NAME, rankConstant);
}

@Override
public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
Map<DocIdAtSearchShard, List<Float>> normalizedScores = new HashMap<>();

for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
processTopDocs(
compoundQueryTopDocs,
(docId, score) -> normalizedScores.computeIfAbsent(docId, k -> new ArrayList<>()).add(score)
);
}

return getDocIdAtQueryForNormalization(normalizedScores, this);
}

private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, BiConsumer<DocIdAtSearchShard, Float> scoreProcessor) {
if (Objects.isNull(compoundQueryTopDocs)) {
return;
}

compoundQueryTopDocs.getTopDocs().forEach(topDocs -> {
IntStream.range(0, topDocs.scoreDocs.length).forEach(position -> {
float normalizedScore = calculateNormalizedScore(position);
DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(
topDocs.scoreDocs[position].doc,
compoundQueryTopDocs.getSearchShard()
);
scoreProcessor.accept(docIdAtSearchShard, normalizedScore);
topDocs.scoreDocs[position].score = normalizedScore;
});
});
}

private float calculateNormalizedScore(int position) {
return BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + position + 1), 10, RoundingMode.HALF_UP).floatValue();
}

private int getRankConstant(final Map<String, Object> params) {
Expand All @@ -96,7 +131,7 @@ private void validateRankConstant(final int rankConstant) {
}
}

public static int getParamAsInteger(final Map<String, Object> parameters, final String fieldName) {
private static int getParamAsInteger(final Map<String, Object> parameters, final String fieldName) {
try {
return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName)));
} catch (NumberFormatException e) {
Expand Down
Loading

0 comments on commit 1d93192

Please sign in to comment.