Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding support for generic re-ranker interface and opensearch ml re-ranker for improving search relavancy. #494

Merged
merged 27 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
bdf5d9b
Add rerank processor interfaces
HenryL27 Nov 14, 2023
17cff65
add cross-encoder specific logic and factory
HenryL27 Nov 14, 2023
8d476db
add unittests
HenryL27 Nov 16, 2023
4efa463
add integration test
HenryL27 Nov 18, 2023
de96761
use string.format() instead of concatenation
HenryL27 Dec 1, 2023
6f85824
rename generateScoringContext to generateRerankingContext
HenryL27 Dec 1, 2023
a30180c
add name change in test too. whoops
HenryL27 Dec 1, 2023
b8820ec
start refactoring with contextSaourceFetchers
HenryL27 Dec 4, 2023
5e1c00b
refactor to use contextSourceFetchers to get context
HenryL27 Dec 5, 2023
2976807
rename CrossEncoder to TextSimilarity
HenryL27 Dec 5, 2023
5332fee
add query_context layer to search ext
HenryL27 Dec 5, 2023
aa1d524
add javadocs
HenryL27 Dec 5, 2023
77301d9
update to new asyncProcessResponse api
HenryL27 Dec 11, 2023
e8de412
rename reranktype to ML_OPENSEARCH
HenryL27 Dec 18, 2023
a7090b2
improve error messages for bad rerank type config
HenryL27 Dec 18, 2023
797eaf6
simplify configuration/factory logic
HenryL27 Dec 18, 2023
ddf2866
improve handling for non-flat-string context fields
HenryL27 Dec 18, 2023
14c8f89
rename TextSimilarity files to MLOpenSearch files
HenryL27 Dec 18, 2023
577f855
apply spotless after rebase
HenryL27 Dec 19, 2023
e3cf218
update changelog
HenryL27 Dec 21, 2023
7a6595f
after rebase
HenryL27 Jan 9, 2024
708fb66
Address pr comments and fix XContent in search ext
HenryL27 Jan 10, 2024
2d04075
move contextSourceFetchers to their own subdirectory
HenryL27 Jan 10, 2024
a39428b
Apply suggestions from code review
HenryL27 Jan 11, 2024
f462965
CR changes
HenryL27 Jan 11, 2024
db8bec1
finish CR comments and fix broken unittest
HenryL27 Jan 11, 2024
7962ffa
fix unittest names
HenryL27 Jan 11, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),

## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.11...2.x)
### Features
- Add rerank processor interface and ml-commons reranker ([#494](https://github.com/opensearch-project/neural-search/pull/494))
### Enhancements
### Bug Fixes
- Fixing multiple issues reported in #497 ([#524](https://github.com/opensearch-project/neural-search/pull/524))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
Expand Down Expand Up @@ -132,6 +133,25 @@ public void inferenceSentences(
retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
}

/**
* Abstraction to call predict function of api of MLClient. It uses the custom model provided as modelId and the
* {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via actionListener as a list of floats representing
* the similarity scores of the texts w.r.t. the query text, in the order of the input texts.
*
* @param modelId {@link String} ML-Commons Model Id
* @param queryText {@link String} The query to compare all the inputText to
* @param inputText {@link List} of {@link String} The texts to compare to the query
* @param listener {@link ActionListener} receives the result of the inference
*/
public void inferenceSimilarity(
@NonNull final String modelId,
@NonNull final String queryText,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<Float>> listener
) {
retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
Expand Down Expand Up @@ -173,12 +193,37 @@ private void retryableInferenceSentencesWithVectorResult(
}));
}

private void retryableInferenceSimilarityWithVectorResult(
final String modelId,
final String queryText,
final List<String> inputText,
final int retryTime,
final ActionListener<List<Float>> listener
) {
MLInput mlInput = createMLTextPairsInput(queryText, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
listener.onResponse(scores);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener);
} else {
listener.onFailure(e);
}
}));
}

private MLInput createMLTextInput(final List<String> targetResponseFilters, List<String> inputText) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
}

private MLInput createMLTextPairsInput(final String query, final List<String> inputText) {
final MLInputDataset inputDataset = new TextSimilarityInputDataSet(query, inputText);
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
}

private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<Float>> vector = new ArrayList<>();
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,17 @@
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
Expand All @@ -54,6 +57,7 @@
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;
Expand Down Expand Up @@ -150,4 +154,22 @@
) {
return Map.of(NeuralQueryEnricherProcessor.TYPE, new NeuralQueryEnricherProcessor.Factory());
}

@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchResponseProcessor>> getResponseProcessors(
Parameters parameters
) {
return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor));

Check warning on line 162 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java#L162

Added line #L162 was not covered by tests
}

@Override
public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
return List.of(

Check warning on line 167 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java#L167

Added line #L167 was not covered by tests
new SearchExtSpec<>(
RerankSearchExtBuilder.PARAM_FIELD_NAME,
in -> new RerankSearchExtBuilder(in),
parser -> RerankSearchExtBuilder.parse(parser)

Check warning on line 171 in src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java#L170-L171

Added lines #L170 - L171 were not covered by tests
)
);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
HenryL27 marked this conversation as resolved.
Show resolved Hide resolved
package org.opensearch.neuralsearch.processor.factory;

import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.StringJoiner;

import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.MLOpenSearchRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankType;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

import com.google.common.collect.Sets;

import lombok.AllArgsConstructor;

/**
* Factory for rerank processors. Must:
* - Instantiate the right kind of rerank processor
* - Instantiate the appropriate context source fetchers
*/
@AllArgsConstructor
public class RerankProcessorFactory implements Processor.Factory<SearchResponseProcessor> {

public static final String RERANK_PROCESSOR_TYPE = "rerank";
public static final String CONTEXT_CONFIG_FIELD = "context";

private final MLCommonsClientAccessor clientAccessor;

@Override
public SearchResponseProcessor create(
final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
final String tag,
final String description,
final boolean ignoreFailure,
final Map<String, Object> config,
final Processor.PipelineContext pipelineContext
) {
RerankType type = findRerankType(config);
boolean includeQueryContextFetcher = ContextFetcherFactory.shouldIncludeQueryContextFetcher(type);
List<ContextSourceFetcher> contextFetchers = ContextFetcherFactory.createFetchers(config, includeQueryContextFetcher, tag);
switch (type) {
case ML_OPENSEARCH:
Map<String, Object> rerankerConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, type.getLabel());
String modelId = ConfigurationUtils.readStringProperty(
RERANK_PROCESSOR_TYPE,
tag,
rerankerConfig,
MLOpenSearchRerankProcessor.MODEL_ID_FIELD
);
return new MLOpenSearchRerankProcessor(description, tag, ignoreFailure, modelId, contextFetchers, clientAccessor);
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "Cannot build reranker type %s", type.getLabel()));

Check warning on line 64 in src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java#L64

Added line #L64 was not covered by tests
}
}

private RerankType findRerankType(final Map<String, Object> config) throws IllegalArgumentException {
// Set of rerank type labels in the config
Set<String> rerankTypes = Sets.intersection(config.keySet(), RerankType.labelMap().keySet());
// A rerank type must be provided
if (rerankTypes.size() == 0) {
StringJoiner msgBuilder = new StringJoiner(", ", "No rerank type found. Possible rerank types are: [", "]");
for (RerankType t : RerankType.values()) {
msgBuilder.add(t.getLabel());
}
throw new IllegalArgumentException(msgBuilder.toString());
}
// Only one rerank type may be provided
if (rerankTypes.size() > 1) {
StringJoiner msgBuilder = new StringJoiner(", ", "Multiple rerank types found: [", "]. Only one is permitted.");
rerankTypes.forEach(rt -> msgBuilder.add(rt));
throw new IllegalArgumentException(msgBuilder.toString());

Check warning on line 83 in src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java

View check run for this annotation

Codecov / codecov/patch

src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java#L81-L83

Added lines #L81 - L83 were not covered by tests
}
return RerankType.from(rerankTypes.iterator().next());
}

/**
* Factory class for context fetchers. Constructs a list of context fetchers
* specified in the pipeline config (and maybe the query context fetcher)
*/
private static class ContextFetcherFactory {

/**
* Map rerank types to whether they should include the query context source fetcher
* @param type the constructing RerankType
* @return does this RerankType depend on the QueryContextSourceFetcher?
*/
public static boolean shouldIncludeQueryContextFetcher(RerankType type) {
return type == RerankType.ML_OPENSEARCH;
}

/**
* Create necessary queryContextFetchers for this processor
* @param config processor config object. Look for "context" field to find fetchers
* @param includeQueryContextFetcher should I include the queryContextFetcher?
* @return list of contextFetchers for the processor to use
*/
public static List<ContextSourceFetcher> createFetchers(
Map<String, Object> config,
boolean includeQueryContextFetcher,
String tag
) {
Map<String, Object> contextConfig = ConfigurationUtils.readMap(RERANK_PROCESSOR_TYPE, tag, config, CONTEXT_CONFIG_FIELD);
List<ContextSourceFetcher> fetchers = new ArrayList<>();
for (String key : contextConfig.keySet()) {
Object cfg = contextConfig.get(key);
switch (key) {
case DocumentContextSourceFetcher.NAME:
fetchers.add(DocumentContextSourceFetcher.create(cfg));
break;
default:
throw new IllegalArgumentException(String.format(Locale.ROOT, "unrecognized context field: %s", key));
}
}
if (includeQueryContextFetcher) {
fetchers.add(new QueryContextSourceFetcher());
}
navneet1v marked this conversation as resolved.
Show resolved Hide resolved
return fetchers;
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.neuralsearch.processor.rerank;

import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.stream.Collectors;

import org.opensearch.action.search.SearchResponse;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.rerank.context.ContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.DocumentContextSourceFetcher;
import org.opensearch.neuralsearch.processor.rerank.context.QueryContextSourceFetcher;

/**
* Rescoring Rerank Processor that uses a TextSimilarity model in ml-commons to rescore
*/
public class MLOpenSearchRerankProcessor extends RescoringRerankProcessor {

public static final String MODEL_ID_FIELD = "model_id";

protected final String modelId;

protected final MLCommonsClientAccessor mlCommonsClientAccessor;
Comment on lines +27 to +29
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why protected?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in case of future subclasses that want it. Maybe such a thing will never occur? idk.


/**
* Constructor
* @param description
* @param tag
* @param ignoreFailure
* @param modelId id of TEXT_SIMILARITY model
* @param contextSourceFetchers
* @param mlCommonsClientAccessor
*/
public MLOpenSearchRerankProcessor(
final String description,
final String tag,
final boolean ignoreFailure,
final String modelId,
final List<ContextSourceFetcher> contextSourceFetchers,
final MLCommonsClientAccessor mlCommonsClientAccessor
) {
super(RerankType.ML_OPENSEARCH, description, tag, ignoreFailure, contextSourceFetchers);
this.modelId = modelId;
this.mlCommonsClientAccessor = mlCommonsClientAccessor;
}

@Override
public void rescoreSearchResponse(
final SearchResponse response,
final Map<String, Object> rerankingContext,
final ActionListener<List<Float>> listener
) {
Object ctxObj = rerankingContext.get(DocumentContextSourceFetcher.DOCUMENT_CONTEXT_LIST_FIELD);
if (!(ctxObj instanceof List<?>)) {
listener.onFailure(
new IllegalStateException(
String.format(
Locale.ROOT,
"No document context found! Perhaps \"%s.%s\" is missing from the pipeline definition?",
RerankProcessorFactory.CONTEXT_CONFIG_FIELD,
DocumentContextSourceFetcher.NAME
)
)
);
return;
}
List<?> ctxList = (List<?>) ctxObj;
List<String> contexts = ctxList.stream().map(str -> (String) str).collect(Collectors.toList());
martin-gaievski marked this conversation as resolved.
Show resolved Hide resolved
mlCommonsClientAccessor.inferenceSimilarity(
modelId,
(String) rerankingContext.get(QueryContextSourceFetcher.QUERY_TEXT_FIELD),
contexts,
listener
);
}

}
Loading
Loading