
Commit

add cross-encoder specific logic and factory
Signed-off-by: HenryL27 <[email protected]>
HenryL27 committed Dec 1, 2023
1 parent f7a7944 commit aceb846
Showing 6 changed files with 131 additions and 13 deletions.
@@ -18,12 +18,14 @@
import lombok.RequiredArgsConstructor;
import lombok.extern.log4j.Log4j2;

import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.TextSimilarityInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
@@ -133,6 +135,25 @@ public void inferenceSentences(
retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener);
}

/**
* Abstraction to call the predict API of MLClient. It uses the custom model specified by modelId with the
* {@link FunctionName#TEXT_SIMILARITY} function. The result is sent via the actionListener as a list of floats
* representing the similarity scores of the input texts with respect to the query text, in the order of the input texts.
*
* @param modelId {@link String} ML-Commons Model Id
* @param queryText {@link String} The query to compare all the inputText to
* @param inputText {@link List} of {@link String} The texts to compare to the query
* @param listener {@link ActionListener} receives the result of the inference
*/
public void inferenceSimilarity(
@NonNull final String modelId,
@NonNull final String queryText,
@NonNull final List<String> inputText,
@NonNull final ActionListener<List<Float>> listener
) {
retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, 0, listener);
}

private void retryableInferenceSentencesWithMapResult(
final String modelId,
final List<String> inputText,
@@ -174,12 +195,42 @@ private void retryableInferenceSentencesWithVectorResult(
}));
}

private void retryableInferenceSimilarityWithVectorResult(
final String modelId,
final String queryText,
final List<String> inputText,
final int retryTime,
final ActionListener<List<Float>> listener
) {
MLInput mlInput = createMLTextPairsInput(queryText, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<Float> scores = buildVectorFromResponse(mlOutput).stream().map(v -> v.get(0)).collect(Collectors.toList());
listener.onResponse(scores);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
retryableInferenceSimilarityWithVectorResult(modelId, queryText, inputText, retryTime + 1, listener);
} else {
listener.onFailure(e);
}
}));
}

private MLInput createMLTextInput(final List<String> targetResponseFilters, List<String> inputText) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
}

private MLInput createMLTextPairsInput(final List<Pair<String, String>> pairs) {
final MLInputDataset inputDataset = new TextSimilarityInputDataSet(pairs);
return new MLInput(FunctionName.TEXT_SIMILARITY, null, inputDataset);
}

private MLInput createMLTextPairsInput(final String query, final List<String> inputText) {
List<Pair<String, String>> pairs = inputText.stream().map(text -> Pair.of(query, text)).collect(Collectors.toList());
return createMLTextPairsInput(pairs);
}

private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
final List<List<Float>> vector = new ArrayList<>();
final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput;
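For illustration, a minimal caller sketch for the new inferenceSimilarity API (not part of this commit). It assumes an MLCommonsClientAccessor instance is available and that a cross-encoder model is already deployed; the model id, texts, and method name below are illustrative only.

import java.util.List;
import org.opensearch.core.action.ActionListener;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

// Hypothetical caller of the new similarity API; identifiers are illustrative.
static void scorePassages(MLCommonsClientAccessor clientAccessor) {
    List<String> passages = List.of("OpenSearch is a distributed search engine.", "Bananas are yellow.");
    clientAccessor.inferenceSimilarity(
        "my-cross-encoder-model-id",   // assumed id of a deployed cross-encoder model
        "what is opensearch",          // query text
        passages,                      // texts to score against the query
        ActionListener.wrap(scores -> {
            // one similarity score per passage, returned in input order
            for (int i = 0; i < scores.size(); i++) {
                System.out.printf("%s -> %.4f%n", passages.get(i), scores.get(i));
            }
        }, e -> e.printStackTrace())   // handle inference failure
    );
}
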
22 changes: 22 additions & 0 deletions src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -37,14 +37,17 @@
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer;
import org.opensearch.neuralsearch.processor.rerank.RerankProcessor;
import org.opensearch.neuralsearch.query.HybridQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralQueryBuilder;
import org.opensearch.neuralsearch.query.NeuralSparseQueryBuilder;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.neuralsearch.search.query.HybridQueryPhaseSearcher;
import org.opensearch.neuralsearch.util.NeuralSearchClusterUtil;
import org.opensearch.plugins.ActionPlugin;
@@ -57,6 +60,7 @@
import org.opensearch.script.ScriptService;
import org.opensearch.search.pipeline.SearchPhaseResultsProcessor;
import org.opensearch.search.pipeline.SearchRequestProcessor;
import org.opensearch.search.pipeline.SearchResponseProcessor;
import org.opensearch.search.query.QueryPhaseSearcher;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;
@@ -151,4 +155,22 @@ public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchReques
) {
return Map.of(NeuralQueryEnricherProcessor.TYPE, new NeuralQueryEnricherProcessor.Factory());
}

@Override
public Map<String, org.opensearch.search.pipeline.Processor.Factory<SearchResponseProcessor>> getResponseProcessors(
Parameters parameters
) {
return Map.of(RerankProcessor.TYPE, new RerankProcessorFactory(clientAccessor));
}

@Override
public List<SearchPlugin.SearchExtSpec<?>> getSearchExts() {
return List.of(
new SearchExtSpec<>(
RerankSearchExtBuilder.PARAM_FIELD_NAME,
in -> new RerankSearchExtBuilder(in),
parser -> RerankSearchExtBuilder.parse(parser)
)
);
}
}
@@ -19,11 +19,21 @@

import java.util.Map;

import lombok.AllArgsConstructor;

import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.rerank.CrossEncoderRerankProcessor;
import org.opensearch.neuralsearch.processor.rerank.RerankType;
import org.opensearch.search.pipeline.Processor;
import org.opensearch.search.pipeline.SearchResponseProcessor;

@AllArgsConstructor
public class RerankProcessorFactory implements Processor.Factory<SearchResponseProcessor> {

public static final String RERANK_PROCESSOR_TYPE = "rerank";

private final MLCommonsClientAccessor clientAccessor;

@Override
public SearchResponseProcessor create(
final Map<String, Processor.Factory<SearchResponseProcessor>> processorFactories,
@@ -33,6 +43,29 @@ public SearchResponseProcessor create(
final Map<String, Object> config,
final Processor.PipelineContext pipelineContext
) {
return null;
RerankType type = findRerankType(config);
switch (type) {
case CROSS_ENCODER:
@SuppressWarnings("unchecked")
Map<String, String> rerankerConfig = (Map<String, String>) config.get(type.getLabel());
String modelId = rerankerConfig.get(CrossEncoderRerankProcessor.MODEL_ID_FIELD);
String rerankContext = rerankerConfig.get(CrossEncoderRerankProcessor.RERANK_CONTEXT_FIELD);
return new CrossEncoderRerankProcessor(description, tag, ignoreFailure, modelId, rerankContext, clientAccessor);
default:
throw new IllegalArgumentException("could not find constructor for reranker type " + type.getLabel());
}
}

private RerankType findRerankType(final Map<String, Object> config) throws IllegalArgumentException {
for (String key : config.keySet()) {
try {
RerankType attempt = RerankType.from(key);
return attempt;
} catch (IllegalArgumentException e) {
// Assume it's just a different field in the config, so don't do anything.
// If we get to the end and there were no valid RerankTypes, then we can panic.
}
}
throw new IllegalArgumentException("no rerank type found");
}
}
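As a rough sketch (not from this commit) of how the factory might be exercised, the snippet below builds the kind of nested config the create method parses. The outer key assumes the CROSS_ENCODER label is "cross_encoder" and the model-id key assumes MODEL_ID_FIELD is "model_id"; neither constant's value appears in this diff, and the create parameter order follows the usual Processor.Factory contract.

import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.factory.RerankProcessorFactory;
import org.opensearch.search.pipeline.SearchResponseProcessor;

// Hypothetical wiring; keys marked "assumed" are not shown in this diff.
static SearchResponseProcessor buildCrossEncoderReranker(MLCommonsClientAccessor clientAccessor) throws Exception {
    Map<String, String> crossEncoderConfig = new HashMap<>();
    crossEncoderConfig.put("model_id", "my-cross-encoder-model-id");   // assumed MODEL_ID_FIELD key
    crossEncoderConfig.put("rerank_context_field", "passage_text");    // RERANK_CONTEXT_FIELD as declared in the processor

    Map<String, Object> config = new HashMap<>();
    config.put("cross_encoder", crossEncoderConfig);                   // assumed RerankType.CROSS_ENCODER label

    RerankProcessorFactory factory = new RerankProcessorFactory(clientAccessor);
    return factory.create(
        Collections.emptyMap(),      // other response-processor factories (unused here)
        "rerank-demo",               // tag
        "cross-encoder reranker",    // description
        false,                       // ignoreFailure
        config,
        null                         // pipeline context
    );
}
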
@@ -19,6 +19,7 @@

import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@@ -32,10 +33,10 @@
import org.opensearch.core.xcontent.ToXContent;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.env.Environment;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.query.ext.RerankSearchExtBuilder;
import org.opensearch.search.SearchExtBuilder;
import org.opensearch.search.SearchHit;

public class CrossEncoderRerankProcessor extends RescoringRerankProcessor {

@@ -45,26 +46,22 @@ public class CrossEncoderRerankProcessor extends RescoringRerankProcessor {
public static final String RERANK_CONTEXT_FIELD = "rerank_context_field";

protected final String modelId;
protected final String rerank_context;
protected final String rerankContext;

protected final MLCommonsClientAccessor mlCommonsClientAccessor;

private final Environment environment;

public CrossEncoderRerankProcessor(
String description,
String tag,
boolean ignoreFailure,
String modelId,
String rerank_context,
MLCommonsClientAccessor mlCommonsClientAccessor,
Environment environment
String rerankContext,
MLCommonsClientAccessor mlCommonsClientAccessor
) {
super(RerankType.CROSS_ENCODER, description, tag, ignoreFailure);
this.modelId = modelId;
this.rerank_context = rerank_context;
this.rerankContext = rerankContext;
this.mlCommonsClientAccessor = mlCommonsClientAccessor;
this.environment = environment;
}

@Override
@@ -108,8 +105,21 @@ public void generateScoringContext(

@Override
public void rescoreSearchResponse(SearchResponse response, Map<String, Object> scoringContext, ActionListener<List<Float>> listener) {
// TODO Auto-generated method stub
throw new UnsupportedOperationException("Unimplemented method 'rescoreSearchResponse'");
List<String> contexts = new ArrayList<>();
for (SearchHit hit : response.getHits()) {
contexts.add(contextFromSearchHit(hit));
}
mlCommonsClientAccessor.inferenceSimilarity(modelId, (String) scoringContext.get(QUERY_TEXT_FIELD), contexts, listener);
}

private String contextFromSearchHit(final SearchHit hit) {
if (hit.getFields().containsKey(this.rerankContext)) {
return (String) hit.field(this.rerankContext).getValue();
} else if (hit.getSourceAsMap().containsKey(this.rerankContext)) {
return (String) hit.getSourceAsMap().get(this.rerankContext);
} else {
return null;
}
}

}
@@ -26,6 +26,8 @@

public interface RerankProcessor extends SearchResponseProcessor {

public static final String TYPE = "rerank";

/**
* Generate the information that this processor needs in order to rerank.
* That could be as simple as grabbing a field from the search request or
@@ -48,7 +48,7 @@ public SearchResponse processResponse(SearchRequest request, SearchResponse resp

@Override
public String getType() {
return "rerank-" + type.getLabel();
return TYPE;
}

@Override
