Skip to content

Commit

Permalink
Adding inference processor and factory, register that in plugin class
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Sep 23, 2023
1 parent adecea5 commit 84fb0d5
Show file tree
Hide file tree
Showing 4 changed files with 433 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import lombok.NonNull;
Expand All @@ -19,6 +20,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
Expand Down Expand Up @@ -103,14 +105,28 @@ public void inferenceSentences(
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener);
}

/**
 * Abstraction to call predict function of the ML client with multimodal input (for example
 * text plus image content). Delegates to the retrying variant starting at attempt 0.
 *
 * @param modelId id of the model to run inference against
 * @param inputObjects map of input field name to input value forwarded to the remote model
 * @param listener receives the list of embedding vectors on success, or the failure cause
 */
public void inferenceMultimodal(
    @NonNull final String modelId,
    @NonNull final Map<String, String> inputObjects,
    @NonNull final ActionListener<List<List<Float>>> listener
) {
    inferenceMultimodalWithRetry(modelId, inputObjects, 0, listener);
}

private void inferenceSentencesWithRetry(
final List<String> targetResponseFilters,
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<List<List<Float>>> listener
) {
MLInput mlInput = createMLInput(targetResponseFilters, inputText);
MLInput mlInput = createMLTextInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
Expand All @@ -125,7 +141,7 @@ private void inferenceSentencesWithRetry(
}));
}

private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
private MLInput createMLTextInput(final List<String> targetResponseFilters, List<String> inputText) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
Expand All @@ -144,4 +160,29 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return vector;
}

/**
 * Calls the ML client predict API with multimodal input, retrying on retryable failures
 * up to the limit enforced by {@link RetryUtil#shouldRetry}.
 *
 * @param modelId id of the model to run inference against
 * @param inputObjects map of input field name to input value forwarded to the remote model
 * @param retryTime number of retries already attempted
 * @param listener receives the list of embedding vectors on success, or the failure cause
 */
private void inferenceMultimodalWithRetry(
    final String modelId,
    final Map<String, String> inputObjects,
    final int retryTime,
    final ActionListener<List<List<Float>>> listener
) {
    MLInput mlInput = createMLMultimodalInput(inputObjects);
    mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
        final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
        log.debug("Inference Response for input sentence {} is : {} ", inputObjects, vector);
        listener.onResponse(vector);
    }, e -> {
        if (RetryUtil.shouldRetry(e, retryTime)) {
            // Fix: the original recursive call passed `targetResponseFilters`, which is not a
            // parameter of this method and not in scope here (it belongs to the sentence-based
            // retry path) — that would not compile. This method takes no filter argument.
            inferenceMultimodalWithRetry(modelId, inputObjects, retryTime + 1, listener);
        } else {
            listener.onFailure(e);
        }
    }));
}

/**
 * Builds an {@link MLInput} that wraps the multimodal key/value pairs in a
 * {@link RemoteInferenceInputDataSet} for the TEXT_EMBEDDING function.
 *
 * @param input map of input field name to input value sent to the remote model
 * @return MLInput ready to be passed to the ML client predict call
 */
private MLInput createMLMultimodalInput(Map<String, String> input) {
    final MLInputDataset remoteDataset = new RemoteInferenceInputDataSet(input);
    return new MLInput(FunctionName.TEXT_EMBEDDING, null, remoteDataset);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,13 @@
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.InferenceProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.InferenceGeneratorProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
Expand Down Expand Up @@ -94,7 +96,10 @@ public List<QuerySpec<?>> getQueries() {
@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
    // Single shared accessor instance used by every processor factory this plugin registers.
    clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
    // Fix: the diff residue kept the replaced `return Collections.singletonMap(...)` line
    // next to the new `return Map.of(...)` — two consecutive return statements do not
    // compile. Keep only the post-commit form registering both processor types.
    return Map.of(
        TextEmbeddingProcessor.TYPE,
        new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
        InferenceProcessor.TYPE,
        new InferenceGeneratorProcessorFactory(clientAccessor, parameters.env)
    );
}

@Override
Expand Down
Loading

0 comments on commit 84fb0d5

Please sign in to comment.