Skip to content

Commit

Permalink
Adding support for generic re-ranker interface and opensearch ml re-r…
Browse files Browse the repository at this point in the history
…anker for improving search relavancy. (opensearch-project#494)

* Add rerank processor interfaces

Signed-off-by: HenryL27 <[email protected]>

* add cross-encoder specific logic and factory

Signed-off-by: HenryL27 <[email protected]>

* add unittests

Signed-off-by: HenryL27 <[email protected]>

* add integration test

Signed-off-by: HenryL27 <[email protected]>

* use string.format() instead of concatenation

Signed-off-by: HenryL27 <[email protected]>

* rename generateScoringContext to generateRerankingContext

Signed-off-by: HenryL27 <[email protected]>

* add name change in test too. whoops

Signed-off-by: HenryL27 <[email protected]>

* start refactoring with contextSaourceFetchers

Signed-off-by: HenryL27 <[email protected]>

* refactor to use contextSourceFetchers to get context

Signed-off-by: HenryL27 <[email protected]>

* rename CrossEncoder to TextSimilarity

Signed-off-by: HenryL27 <[email protected]>

* add query_context layer to search ext

Signed-off-by: HenryL27 <[email protected]>

* add javadocs

Signed-off-by: HenryL27 <[email protected]>

* update to new asyncProcessResponse api

Signed-off-by: HenryL27 <[email protected]>

* rename reranktype to ML_OPENSEARCH

Signed-off-by: HenryL27 <[email protected]>

* improve error messages for bad rerank type config

Signed-off-by: HenryL27 <[email protected]>

* simplify configuration/factory logic

Signed-off-by: HenryL27 <[email protected]>

* improve handling for non-flat-string context fields

Signed-off-by: HenryL27 <[email protected]>

* rename TextSimilarity files to MLOpenSearch files

Signed-off-by: HenryL27 <[email protected]>

* apply spotless after rebase

Signed-off-by: HenryL27 <[email protected]>

* update changelog

Signed-off-by: HenryL27 <[email protected]>

* after rebase

Signed-off-by: HenryL27 <[email protected]>

* Address pr comments and fix XContent in search ext

Signed-off-by: HenryL27 <[email protected]>

* move contextSourceFetchers to their own subdirectory

Signed-off-by: HenryL27 <[email protected]>

* Apply suggestions from code review

Co-authored-by: Martin Gaievski <[email protected]>
Signed-off-by: HenryL27 <[email protected]>

* CR changes

Signed-off-by: HenryL27 <[email protected]>

* finish CR comments and fix broken unittest

Signed-off-by: HenryL27 <[email protected]>

* fix unittest names

Signed-off-by: HenryL27 <[email protected]>

---------

Signed-off-by: HenryL27 <[email protected]>
Co-authored-by: Martin Gaievski <[email protected]>
  • Loading branch information
2 people authored and ylwu-amzn committed Jan 26, 2024
1 parent 26699ea commit 8c2de82
Show file tree
Hide file tree
Showing 20 changed files with 1,811 additions and 37 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.11...2.x)
### Features
- Add rerank processor interface and ml-commons reranker ([#494](https://github.com/opensearch-project/neural-search/pull/494))
### Enhancements
### Bug Fixes
- Fixed exception for case when Hybrid query being wrapped into bool query ([#490](https://github.com/opensearch-project/neural-search/pull/490))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
Expand Down Expand Up @@ -137,6 +138,25 @@ public void inferenceSentences(
retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
}

/**
* Abstraction to call predict function of api of MLClient. It uses the custom model provided as modelId and the
* {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via actionListener as a list of floats representing
* the similarity scores of the texts w.r.t. the query text, in the order of the input texts.
*
* @param modelId {@link String} ML-Commons Model Id
* @param queryText {@link String} The query to compare all the inputText to
* @param inputText {@link List} of {@link String} The texts to compare to the query
* @param listener {@link ActionListener} receives the result of the inference
*/
public void inferenceSimilarity(
@NonNull final String modelId,
@NonNull final String queryText,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<Float>> listener
) {
retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
Expand Down Expand Up @@ -178,12 +198,37 @@ private void retryableInferenceSentencesWithVectorResult(
}));
}

private void retryableInferenceSimilarityWithVectorResult(
final String modelId,
final String queryText,
final List<String> inputText,
final int retryTime,
final ActionListener<List<Float>> listener
) {
MLInput mlInput = createMLTextPairsInput(queryText, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
listener.onResponse(scores);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener);
} else {
listener.onFailure(e);
}
}));
}

private MLInput createMLTextInput(final List<String> targetResponseFilters, List<String> inputText) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
}

private MLInput createMLTextPairsInput(final String query, final List<String> inputText) {
final MLInputDataset inputDataset = new TextSimilarityInputDataSet(query, inputText);
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
}

private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<Float>> vector = new ArrayList<>();
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
Expand Down
22 changes: 22 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,17 @@
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
Expand All @@ -54,6 +57,7 @@
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;
Expand Down Expand Up @@ -150,4 +154,22 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchReques
) {
return Map.of(NeuralQueryEnricherProcessor.TYPE, new NeuralQueryEnricherProcessor.Factory());
}

@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchResponseProcessor>> getResponseProcessors(
Parameters parameters
) {
return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor));
}

@Override
public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
return List.of(
new SearchExtSpec<>(
RerankSearchExtBuilder.PARAM_FIELD_NAME,
in -> new RerankSearchExtBuilder(in),
parser -> RerankSearchExtBuilder.parse(parser)
)
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.factory;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;

import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankType;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import com.google.common.collect.Sets;

import lombok.AllArgsConstructor;

/**
* Factory for rerank processors. Must:
* - Instantiate the right kind of rerank processor
* - Instantiate the appropriate context source fetchers
*/
@AllArgsConstructor
public class RerankProcessorFactory implements Processor.Factory<SearchResponseProcessor> {

public static final String RERANK_PROCESSOR_TYPE = "rerank";
public static final String CONTEXT_CONFIG_FIELD = "context";

private final MLCommonsClientAccessor clientAccessor;

@Override
public SearchResponseProcessor create(
final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
final String tag,
final String description,
final boolean ignoreFailure,
final Map<String, Object> config,
final Processor.PipelineContext pipelineContext
) {
RerankType type = findRerankType(config);
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag);
switch (type) {
case ML_OPENSEARCH:
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());
String modelId = ConfigurationUtils.readStringProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
MLOpenSearchRerankProcessor.MODEL_ID_FIELD
);
return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor);
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel()));
}
}

private RerankType findRerankType(final Map<String, Object> config) throws IllegalArgumentException {
// Set of rerank type labels in the config
Set<String> rerankTypes = Sets.intersection(config.keySet(), RerankType.labelMap().keySet());
// A rerank type must be provided
if (rerankTypes.size() == 0) {
StringJoiner msgBuilder = new StringJoiner(", ", "No rerank type found. Possible rerank types are: [", "]");
for (RerankType t : RerankType.values()) {
msgBuilder.add(t.getLabel());
}
throw new IllegalArgumentException(msgBuilder.toString());
}
// Only one rerank type may be provided
if (rerankTypes.size() > 1) {
StringJoiner msgBuilder = new StringJoiner(", ", "Multiple rerank types found: [", "]. Only one is permitted.");
rerankTypes.forEach(rt -> msgBuilder.add(rt));
throw new IllegalArgumentException(msgBuilder.toString());
}
return RerankType.from(rerankTypes.iterator().next());
}

/**
* Factory class for context fetchers. Constructs a list of context fetchers
* specified in the pipeline config (and maybe the query context fetcher)
*/
private static class ContextFetcherFactory {

/**
* Map rerank types to whether they should include the query context source fetcher
* @param type the constructing RerankType
* @return does this RerankType depend on the QueryContextSourceFetcher?
*/
public static boolean shouldIncludeQueryContextFetcher(RerankType type) {
return type == RerankType.ML_OPENSEARCH;
}

/**
* Create necessary queryContextFetchers for this processor
* @param config processor config object. Look for "context" field to find fetchers
* @param includeQueryContextFetcher should I include the queryContextFetcher?
* @return list of contextFetchers for the processor to use
*/
public static List<ContextSourceFetcher> createFetchers(
Map<String, Object> config,
boolean includeQueryContextFetcher,
String tag
) {
Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD);
List<ContextSourceFetcher> fetchers = new ArrayList<>();
for (String key : contextConfig.keySet()) {
Object cfg = contextConfig.get(key);
switch (key) {
case DocumentContextSourceFetcher.NAME:
fetchers.add(DocumentContextSourceFetcher.create(cfg));
break;
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key));
}
}
if (includeQueryContextFetcher) {
fetchers.add(new QueryContextSourceFetcher());
}
return fetchers;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.rerank;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher;

/**
* Rescoring Rerank Processor that uses a TextSimilarity model in ml-commons to rescore
*/
public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor {

public static final String MODEL_ID_FIELD = "model_id";

protected final String modelId;

protected final MLCommonsClientAccessor mlCommonsClientAccessor;

/**
* Constructor
* @param description
* @param tag
* @param ignoreFailure
* @param modelId id of TEXT_SIMILARITY model
* @param contextSourceFetchers
* @param mlCommonsClientAccessor
*/
public MLOpenSearchRerankProcessor(
final String description,
final String tag,
final boolean ignoreFailure,
final String modelId,
final List<ContextSourceFetcher> contextSourceFetchers,
final MLCommonsClientAccessor mlCommonsClientAccessor
) {
super(RerankType.ML_OPENSEARCH, description, tag, ignoreFailure, contextSourceFetchers);
this.modelId = modelId;
this.mlCommonsClientAccessor = mlCommonsClientAccessor;
}

@Override
public void rescoreSearchResponse(
final SearchResponse response,
final Map<String, Object> rerankingContext,
final ActionListener<List<Float>> listener
) {
Object ctxObj = rerankingContext.get(DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD);
if (!(ctxObj instanceof List<?>)) {
listener.onFailure(
new IllegalStateException(
String.format(
Locale.ROOT,
"No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?",
RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
DocumentContextSourceFetcher.NAME
)
)
);
return;
}
List<?> ctxList = (List<?>) ctxObj;
List<String> contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList());
mlCommonsClientAccessor.inferenceSimilarity(
modelId,
(String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD),
contexts,
listener
);
}

}
Loading

0 comments on commit 8c2de82

Please sign in to comment.