diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 1c09f5996..7cb188035 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -18,12 +18,14 @@ import lombok.RequiredArgsConstructor; import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -133,6 +135,25 @@ public void inferenceSentences( retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); } + /** + * Abstraction to call predict function of api of MLClient. It uses the custom model provided as modelId and the + * {@link FunctionName#TEXT_SIMILARITY}. The return will be sent via actionListener as a list of floats representing + * the similarity scores of the texts w.r.t. the query text, in the order of the input texts. + * + * @param modelId {@link String} ML-Commons Model Id + * @param queryText {@link String} The query to compare all the inputText to + * @param inputText {@link List} of {@link String} The texts to compare to the query + * @param listener {@link ActionListener} receives the result of the inference + */ + public void inferenceSimilarity( + @NonNull final String modelId, + @NonNull final String queryText, + @NonNull final List inputText, + @NonNull final ActionListener> listener + ) { + retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener); + } + private void retryableInferenceSentencesWithMapResult( final String modelId, final List inputText, @@ -174,12 +195,42 @@ private void retryableInferenceSentencesWithVectorResult( })); } + private void retryableInferenceSimilarityWithVectorResult( + final String modelId, + final String queryText, + final List inputText, + final int retryTime, + final ActionListener> listener + ) { + MLInput mlInput = createMLTextPairsInput(queryText, inputText); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList()); + listener.onResponse(scores); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener); + } else { + listener.onFailure(e); + } + })); + } + private MLInput createMLTextInput(final List targetResponseFilters, List inputText) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); } + private MLInput createMLTextPairsInput(final List> pairs) { + final MLInputDataset inputDataset = new TextSimilarityInputDataSet(pairs); + return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset); + } + + private MLInput createMLTextPairsInput(final String query, final List inputText) { + List> pairs = inputText.stream().map(text -> Pair.of(query, text)).collect(Collectors.toList()); + return createMLTextPairsInput(pairs); + } + private List> buildVectorFromResponse(MLOutput mlOutput) { final List> vector = new ArrayList<>(); final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 8672c6142..8c14ce494 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -37,14 +37,17 @@ import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; +import org.opensearch.neuralsearch.processor.rerank.RerankProcessor; import org.opensearch.neuralsearch.query.HybridQueryBuilder; import org.opensearch.neuralsearch.query.NeuralQueryBuilder; import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder; +import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher; import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil; import org.opensearch.plugins.ActionPlugin; @@ -57,6 +60,7 @@ import org.opensearch.script.ScriptService; import org.opensearch.search.pipeline.SearchPhaseResultsProcessor; import org.opensearch.search.pipeline.SearchRequestProcessor; +import org.opensearch.search.pipeline.SearchResponseProcessor; import org.opensearch.search.query.QueryPhaseSearcher; import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; @@ -151,4 +155,22 @@ public Map> getResponseProcessors( + Parameters parameters + ) { + return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor)); + } + + @Override + public List> getSearchExts() { + return List.of( + new SearchExtSpec<>( + RerankSearchExtBuilder.PARAM_FIELD_NAME, + in -> new RerankSearchExtBuilder(in), + parser -> RerankSearchExtBuilder.parse(parser) + ) + ); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java index 7449743ff..ed1d56b4b 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/RerankProcessorFactory.java @@ -19,11 +19,21 @@ import java.util.Map; +import lombok.AllArgsConstructor; + +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.rerank.CrossEncoderRerankProcessor; +import org.opensearch.neuralsearch.processor.rerank.RerankType; import org.opensearch.search.pipeline.Processor; import org.opensearch.search.pipeline.SearchResponseProcessor; +@AllArgsConstructor public class RerankProcessorFactory implements Processor.Factory { + public static final String RERANK_PROCESSOR_TYPE = "rerank"; + + private final MLCommonsClientAccessor clientAccessor; + @Override public SearchResponseProcessor create( final Map> processorFactories, @@ -33,6 +43,29 @@ public SearchResponseProcessor create( final Map config, final Processor.PipelineContext pipelineContext ) { - return null; + RerankType type = findRerankType(config); + switch (type) { + case CROSS_ENCODER: + @SuppressWarnings("unchecked") + Map rerankerConfig = (Map) config.get(type.getLabel()); + String modelId = rerankerConfig.get(CrossEncoderRerankProcessor.MODEL_ID_FIELD); + String rerankContext = rerankerConfig.get(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD); + return new CrossEncoderRerankProcessor(description, tag, ignoreFailure, modelId, rerankContext, clientAccessor); + default: + throw new IllegalArgumentException("could not find constructor for reranker type " + type.getLabel()); + } + } + + private RerankType findRerankType(final Map config) throws IllegalArgumentException { + for (String key : config.keySet()) { + try { + RerankType attempt = RerankType.from(key); + return attempt; + } catch (IllegalArgumentException e) { + // Assume it's just a different field in the config, so don't do anything. + // If we get to the end and there were no valid RerankTypes, then we can panic. + } + } + throw new IllegalArgumentException("no rerank type found"); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java index ea2152378..61193ea36 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/CrossEncoderRerankProcessor.java @@ -19,6 +19,7 @@ import java.io.PipedInputStream; import java.io.PipedOutputStream; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -32,10 +33,10 @@ import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.env.Environment; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder; import org.opensearch.search.SearchExtBuilder; +import org.opensearch.search.SearchHit; public class CrossEncoderRerankProcessor extends RescoringRerankProcessor { @@ -45,26 +46,22 @@ public class CrossEncoderRerankProcessor extends RescoringRerankProcessor { public static final String RERANK_CONTEXT_FIELD = "rerank_context_field"; protected final String modelId; - protected final String rerank_context; + protected final String rerankContext; protected final MLCommonsClientAccessor mlCommonsClientAccessor; - private final Environment environment; - public CrossEncoderRerankProcessor( String description, String tag, boolean ignoreFailure, String modelId, - String rerank_context, - MLCommonsClientAccessor mlCommonsClientAccessor, - Environment environment + String rerankContext, + MLCommonsClientAccessor mlCommonsClientAccessor ) { super(RerankType.CROSS_ENCODER, description, tag, ignoreFailure); this.modelId = modelId; - this.rerank_context = rerank_context; + this.rerankContext = rerankContext; this.mlCommonsClientAccessor = mlCommonsClientAccessor; - this.environment = environment; } @Override @@ -108,8 +105,21 @@ public void generateScoringContext( @Override public void rescoreSearchResponse(SearchResponse response, Map scoringContext, ActionListener> listener) { - // TODO Auto-generated method stub - throw new UnsupportedOperationException("Unimplemented method 'rescoreSearchResponse'"); + List contexts = new ArrayList<>(); + for (SearchHit hit : response.getHits()) { + contexts.add(contextFromSearchHit(hit)); + } + mlCommonsClientAccessor.inferenceSimilarity(modelId, (String) scoringContext.get(QUERY_TEXT_FIELD), contexts, listener); + } + + private String contextFromSearchHit(final SearchHit hit) { + if (hit.getFields().containsKey(this.rerankContext)) { + return (String) hit.field(this.rerankContext).getValue(); + } else if (hit.getSourceAsMap().containsKey(this.rerankContext)) { + return (String) hit.getSourceAsMap().get(this.rerankContext); + } else { + return null; + } } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java index 62ab61da4..d458c0ca2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RerankProcessor.java @@ -26,6 +26,8 @@ public interface RerankProcessor extends SearchResponseProcessor { + public static final String TYPE = "rerank"; + /** * Generate the information that this processor needs in order to rerank. * That could be as simple as grabbing a field from the search request or diff --git a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java index f88c02d0d..907c26c5d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/rerank/RescoringRerankProcessor.java @@ -48,7 +48,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp @Override public String getType() { - return "rerank-" + type.getLabel(); + return TYPE; } @Override