Skip to content

Commit

Permalink
Integrate explainability for hybrid query into RRF processor
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Dec 18, 2024
1 parent 627fcb4 commit 8632f9c
Show file tree
Hide file tree
Showing 8 changed files with 339 additions and 82 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor;

import org.opensearch.action.search.SearchPhaseContext;
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

import java.util.Optional;

/**
 * Base class for all score hybridization processors. This class is responsible for executing the score hybridization process.
 * It is a pipeline processor that is executed after the query phase and before the fetch phase.
 */
public abstract class AbstractScoreHybridizationProcessor implements SearchPhaseResultsProcessor {
/**
 * Entry point for score hybridization when the search pipeline runs without a per-request processing context.
 * Delegates to {@link #hybridizeScores(SearchPhaseResults, SearchPhaseContext, Optional)} with an empty context.
 * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
 * @param searchPhaseContext {@link SearchPhaseContext} context of the search request being processed
 * @param <Result> type of the per-shard search phase result
 */
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
 * Entry point for score hybridization when the search pipeline supplies a per-request processing context.
 * Delegates to {@link #hybridizeScores(SearchPhaseResults, SearchPhaseContext, Optional)}, wrapping the
 * (possibly null) context via {@link Optional#ofNullable(Object)}.
 * @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
 * @param searchPhaseContext {@link SearchPhaseContext} context of the search request being processed
 * @param requestContext {@link PipelineProcessingContext} processing context of search pipeline; may be null
 * @param <Result> type of the per-shard search phase result
 */
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
hybridizeScores(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

/**
 * Subclasses implement the concrete score normalization and combination workflow here.
 * Both {@code process} overloads funnel into this single method.
 * @param searchPhaseResult query search results; mutated in place by implementations
 * @param searchPhaseContext context of the search request being processed
 * @param requestContextOptional optional processing context of the search pipeline; empty when the
 * pipeline was invoked without a per-request context
 * @param <Result> type of the per-shard search phase result
 */
abstract <Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
);
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.query.QuerySearchResult;

import lombok.AllArgsConstructor;
Expand All @@ -33,7 +31,7 @@
*/
@Log4j2
@AllArgsConstructor
public class NormalizationProcessor implements SearchPhaseResultsProcessor {
public class NormalizationProcessor extends AbstractScoreHybridizationProcessor {
public static final String TYPE = "normalization-processor";

private final String tag;
Expand All @@ -42,38 +40,8 @@ public class NormalizationProcessor implements SearchPhaseResultsProcessor {
private final ScoreCombinationTechnique combinationTechnique;
private final NormalizationProcessorWorkflow normalizationWorkflow;

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor. This method is called when there is no pipeline context
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.empty());
}

/**
* Method abstracts functional aspect of score normalization and score combination. Exact methods for each processing stage
* are set as part of class constructor
* @param searchPhaseResult {@link SearchPhaseResults} DTO that has query search results. Results will be mutated as part of this method execution
* @param searchPhaseContext {@link SearchContext}
* @param requestContext {@link PipelineProcessingContext} processing context of search pipeline
* @param <Result>
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext,
final PipelineProcessingContext requestContext
) {
prepareAndExecuteNormalizationWorkflow(searchPhaseResult, searchPhaseContext, Optional.ofNullable(requestContext));
}

private <Result extends SearchPhaseResult> void prepareAndExecuteNormalizationWorkflow(
<Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@
import java.util.Objects;
import java.util.Optional;

import com.google.common.annotations.VisibleForTesting;
import lombok.Getter;
import org.opensearch.ml.repackage.com.google.common.annotations.VisibleForTesting;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationTechnique;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationTechnique;
import org.opensearch.search.fetch.FetchSearchResult;
import org.opensearch.search.pipeline.PipelineProcessingContext;
import org.opensearch.search.query.QuerySearchResult;

import org.opensearch.action.search.QueryPhaseResultConsumer;
Expand All @@ -25,7 +26,6 @@
import org.opensearch.action.search.SearchPhaseResults;
import org.opensearch.search.SearchPhaseResult;
import org.opensearch.search.internal.SearchContext;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;

import lombok.AllArgsConstructor;
import lombok.extern.log4j.Log4j2;
Expand All @@ -39,7 +39,7 @@
*/
@Log4j2
@AllArgsConstructor
public class RRFProcessor implements SearchPhaseResultsProcessor {
public class RRFProcessor extends AbstractScoreHybridizationProcessor {
public static final String TYPE = "score-ranker-processor";

@Getter
Expand All @@ -57,25 +57,28 @@ public class RRFProcessor implements SearchPhaseResultsProcessor {
* @param searchPhaseContext {@link SearchContext}
*/
@Override
public <Result extends SearchPhaseResult> void process(
final SearchPhaseResults<Result> searchPhaseResult,
final SearchPhaseContext searchPhaseContext
<Result extends SearchPhaseResult> void hybridizeScores(
SearchPhaseResults<Result> searchPhaseResult,
SearchPhaseContext searchPhaseContext,
Optional<PipelineProcessingContext> requestContextOptional
) {
if (shouldSkipProcessor(searchPhaseResult)) {
log.debug("Query results are not compatible with RRF processor");
return;
}
List<QuerySearchResult> querySearchResults = getQueryPhaseSearchResults(searchPhaseResult);
Optional<FetchSearchResult> fetchSearchResult = getFetchSearchResults(searchPhaseResult);

boolean explain = Objects.nonNull(searchPhaseContext.getRequest().source().explain())
&& searchPhaseContext.getRequest().source().explain();
// make data transfer object to pass in, execute will get object with 4 or 5 fields, depending
// on coming from NormalizationProcessor or RRFProcessor
NormalizationProcessorWorkflowExecuteRequest normalizationExecuteDTO = NormalizationProcessorWorkflowExecuteRequest.builder()
.querySearchResults(querySearchResults)
.fetchSearchResultOptional(fetchSearchResult)
.normalizationTechnique(normalizationTechnique)
.combinationTechnique(combinationTechnique)
.explain(false)
.explain(explain)
.pipelineProcessingContext(requestContextOptional.orElse(null))
.build();
normalizationWorkflow.execute(normalizationExecuteDTO);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,19 @@

import lombok.ToString;
import lombok.extern.log4j.Log4j2;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;

import java.util.List;
import java.util.Map;

import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.describeCombinationTechnique;

@Log4j2
/**
* Abstracts combination of scores based on reciprocal rank fusion algorithm
*/
@ToString(onlyExplicitlyIncluded = true)
public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique {
public class RRFScoreCombinationTechnique implements ScoreCombinationTechnique, ExplainableTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "rrf";

Expand All @@ -29,4 +33,9 @@ public float combine(final float[] scores) {
}
return sumScores;
}

/**
 * Describes this combination technique for explainability output.
 * Passes an empty parameter list since no technique parameters are configured here.
 * @return human-readable description built from the technique name ("rrf")
 */
@Override
public String describe() {
return describeCombinationTechnique(TECHNIQUE_NAME, List.of());
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,34 @@

import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Locale;
import java.util.Set;
import java.util.function.BiConsumer;
import java.util.stream.IntStream;

import org.apache.commons.lang3.Range;
import org.apache.commons.lang3.math.NumberUtils;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.TopDocs;
import org.opensearch.neuralsearch.processor.CompoundTopDocs;

import lombok.ToString;
import org.opensearch.neuralsearch.processor.NormalizeScoresDTO;
import org.opensearch.neuralsearch.processor.explain.DocIdAtSearchShard;
import org.opensearch.neuralsearch.processor.explain.ExplainableTechnique;
import org.opensearch.neuralsearch.processor.explain.ExplanationDetails;

import static org.opensearch.neuralsearch.processor.explain.ExplanationUtils.getDocIdAtQueryForNormalization;

/**
* Abstracts calculation of rank scores for each document returned as part of
* reciprocal rank fusion. Rank scores are summed across subqueries in combination classes.
*/
@ToString(onlyExplicitlyIncluded = true)
public class RRFNormalizationTechnique implements ScoreNormalizationTechnique {
public class RRFNormalizationTechnique implements ScoreNormalizationTechnique, ExplainableTechnique {
@ToString.Include
public static final String TECHNIQUE_NAME = "rrf";
public static final int DEFAULT_RANK_CONSTANT = 60;
Expand Down Expand Up @@ -58,21 +65,49 @@ public RRFNormalizationTechnique(final Map<String, Object> params, final ScoreNo
public void normalize(final NormalizeScoresDTO normalizeScoresDTO) {
final List<CompoundTopDocs> queryTopDocs = normalizeScoresDTO.getQueryTopDocs();
for (CompoundTopDocs compoundQueryTopDocs : queryTopDocs) {
if (Objects.isNull(compoundQueryTopDocs)) {
continue;
}
List<TopDocs> topDocsPerSubQuery = compoundQueryTopDocs.getTopDocs();
for (TopDocs topDocs : topDocsPerSubQuery) {
int docsCountPerSubQuery = topDocs.scoreDocs.length;
ScoreDoc[] scoreDocs = topDocs.scoreDocs;
for (int j = 0; j < docsCountPerSubQuery; j++) {
// using big decimal approach to minimize error caused by floating point ops
// score = 1.f / (float) (rankConstant + j + 1))
scoreDocs[j].score = BigDecimal.ONE.divide(BigDecimal.valueOf(rankConstant + j + 1), 10, RoundingMode.HALF_UP)
.floatValue();
}
}
processTopDocs(compoundQueryTopDocs, (docId, score) -> {});
}
}

/**
 * Describes this normalization technique for explainability output,
 * including the rank constant currently in effect.
 * @return description in the form "rrf, rank_constant &lt;value&gt;"
 */
@Override
public String describe() {
// integer-to-string conversion is locale-independent, matching the Locale.ROOT formatting contract
return TECHNIQUE_NAME + ", rank_constant " + rankConstant;
}

/**
 * Produces per-document explanation details for the RRF normalization pass.
 * Accumulates every rank-based score assigned to each document across sub-queries,
 * then converts the accumulated scores into explanation details.
 * Note: score values in the supplied top docs are rewritten as a side effect of
 * {@code processTopDocs}, mirroring what {@code normalize} does.
 */
@Override
public Map<DocIdAtSearchShard, ExplanationDetails> explain(List<CompoundTopDocs> queryTopDocs) {
final Map<DocIdAtSearchShard, List<Float>> scoresByDocId = new HashMap<>();

queryTopDocs.forEach(compoundTopDocs -> processTopDocs(compoundTopDocs, (docId, score) -> {
List<Float> scoresForDoc = scoresByDocId.computeIfAbsent(docId, key -> new ArrayList<>());
scoresForDoc.add(score);
}));

return getDocIdAtQueryForNormalization(scoresByDocId, this);
}

/**
 * Walks every document of every sub-query in the given compound top docs, computes the
 * RRF score for its rank position, hands the (docId, score) pair to the supplied consumer,
 * and writes the score back into the score doc. No-op for null input.
 */
private void processTopDocs(CompoundTopDocs compoundQueryTopDocs, BiConsumer<DocIdAtSearchShard, Float> scoreProcessor) {
if (compoundQueryTopDocs == null) {
return;
}

for (var topDocs : compoundQueryTopDocs.getTopDocs()) {
var scoreDocs = topDocs.scoreDocs;
for (int rank = 0; rank < scoreDocs.length; rank++) {
// score depends only on the document's rank within its sub-query
final float rrfScore = calculateNormalizedScore(rank);
final DocIdAtSearchShard docIdAtSearchShard = new DocIdAtSearchShard(
scoreDocs[rank].doc,
compoundQueryTopDocs.getSearchShard()
);
scoreProcessor.accept(docIdAtSearchShard, rrfScore);
scoreDocs[rank].score = rrfScore;
}
}
}

/**
 * Computes the reciprocal rank fusion score for a zero-based rank position:
 * 1 / (rankConstant + position + 1), evaluated via BigDecimal at scale 10 with
 * HALF_UP rounding to minimize floating point error before narrowing to float.
 */
private float calculateNormalizedScore(int position) {
// convert zero-based position into the one-based rank term of the RRF formula
final int shiftedRank = rankConstant + position + 1;
return BigDecimal.ONE.divide(BigDecimal.valueOf(shiftedRank), 10, RoundingMode.HALF_UP).floatValue();
}

private int getRankConstant(final Map<String, Object> params) {
Expand All @@ -96,7 +131,7 @@ private void validateRankConstant(final int rankConstant) {
}
}

public static int getParamAsInteger(final Map<String, Object> parameters, final String fieldName) {
private static int getParamAsInteger(final Map<String, Object> parameters, final String fieldName) {
try {
return NumberUtils.createInteger(String.valueOf(parameters.get(fieldName)));
} catch (NumberFormatException e) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TotalHits;
import org.junit.Before;
import org.mockito.ArgumentCaptor;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.opensearch.action.OriginalIndices;
Expand Down Expand Up @@ -111,6 +112,10 @@ public void testProcess_whenValidHybridInput_thenSucceed() {

when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

SearchRequest searchRequest = new SearchRequest();
searchRequest.source(new SearchSourceBuilder());
when(mockSearchPhaseContext.getRequest()).thenReturn(searchRequest);

rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);

verify(mockNormalizationWorkflow).execute(any(NormalizationProcessorWorkflowExecuteRequest.class));
Expand Down Expand Up @@ -177,6 +182,34 @@ public void testGetFetchSearchResults() {
assertFalse(result.isPresent());
}

// Verifies that an explain=true search request makes the processor build the
// workflow DTO with explain enabled, and that the two-arg process() overload
// leaves the pipeline processing context null.
@SneakyThrows
public void testProcess_whenExplainIsTrue_thenExplanationIsAdded() {
// single hybrid query result placed into the consumer's atomic array
QuerySearchResult result = createQuerySearchResult(true);
AtomicArray<SearchPhaseResult> atomicArray = new AtomicArray<>(1);
atomicArray.set(0, result);

when(mockQueryPhaseResultConsumer.getAtomicArray()).thenReturn(atomicArray);

// request explicitly asks for explanations, so the processor must propagate explain=true
SearchRequest searchRequest = new SearchRequest();
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
sourceBuilder.explain(true);
searchRequest.source(sourceBuilder);
when(mockSearchPhaseContext.getRequest()).thenReturn(searchRequest);

rrfProcessor.process(mockQueryPhaseResultConsumer, mockSearchPhaseContext);

// Capture the actual request
ArgumentCaptor<NormalizationProcessorWorkflowExecuteRequest> requestCaptor = ArgumentCaptor.forClass(
NormalizationProcessorWorkflowExecuteRequest.class
);
verify(mockNormalizationWorkflow).execute(requestCaptor.capture());

// Verify the captured request
NormalizationProcessorWorkflowExecuteRequest capturedRequest = requestCaptor.getValue();
assertTrue(capturedRequest.isExplain());
// two-arg process() passes no pipeline context, so the DTO field must stay null
assertNull(capturedRequest.getPipelineProcessingContext());
}

private QuerySearchResult createQuerySearchResult(boolean isHybrid) {
ShardId shardId = new ShardId("index", "uuid", 0);
OriginalIndices originalIndices = new OriginalIndices(new String[] { "index" }, IndicesOptions.strictExpandOpenAndForbidClosed());
Expand Down
Loading

0 comments on commit 8632f9c

Please sign in to comment.