Skip to content

Commit

Permalink
Adding inference processor and factory, register that in plugin class
Browse files Browse the repository at this point in the history
Signed-off-by: Martin Gaievski <[email protected]>
  • Loading branch information
martin-gaievski committed Sep 25, 2023
1 parent adecea5 commit 4a8d596
Show file tree
Hide file tree
Showing 6 changed files with 634 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,9 @@

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import lombok.NonNull;
Expand All @@ -19,6 +21,7 @@
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.output.model.ModelResultFilter;
Expand Down Expand Up @@ -103,14 +106,28 @@ public void inferenceSentences(
inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener);
}

/**
* Call the ML predict API with multimodal input
* @param modelId
* @param inputObjects
* @param listener
*/
public void inferenceMultimodal(
@NonNull final String modelId,
@NonNull final Map<String, Map<String, String>> inputObjects,
@NonNull final ActionListener<List<List<Float>>> listener
) {
inferenceMultimodalWithRetry(modelId, inputObjects, 0, listener);
}

private void inferenceSentencesWithRetry(
final List<String> targetResponseFilters,
final String modelId,
final List<String> inputText,
final int retryTime,
final ActionListener<List<List<Float>>> listener
) {
MLInput mlInput = createMLInput(targetResponseFilters, inputText);
MLInput mlInput = createMLTextInput(targetResponseFilters, inputText);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputText, vector);
Expand All @@ -125,7 +142,7 @@ private void inferenceSentencesWithRetry(
}));
}

private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
private MLInput createMLTextInput(final List<String> targetResponseFilters, List<String> inputText) {
final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
Expand All @@ -144,4 +161,31 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
return vector;
}

private void inferenceMultimodalWithRetry(
final String modelId,
final Map<String, Map<String, String>> inputObjects,
final int retryTime,
final ActionListener<List<List<Float>>> listener
) {
MLInput mlInput = createMLMultimodalInput(inputObjects);
mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
log.debug("Inference Response for input sentence {} is : {} ", inputObjects, vector);
listener.onResponse(vector);
}, e -> {
if (RetryUtil.shouldRetry(e, retryTime)) {
final int retryTimeAdd = retryTime + 1;
inferenceMultimodalWithRetry(modelId, inputObjects, retryTimeAdd, listener);
} else {
listener.onFailure(e);
}
}));
}

private MLInput createMLMultimodalInput(Map<String, Map<String, String>> input) {
Map<String, String> remoteInferenceInput = new HashMap<>();
input.forEach((key, value) -> remoteInferenceInput.put(value.get("model_input"), value.get("value")));
final MLInputDataset inputDataset = new RemoteInferenceInputDataSet(remoteInferenceInput);
return new MLInput(FunctionName.REMOTE, null, inputDataset);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@

import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Optional;
Expand All @@ -29,11 +28,13 @@
import org.opensearch.ingest.Processor;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
import org.opensearch.neuralsearch.processor.InferenceProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessor;
import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
import org.opensearch.neuralsearch.processor.factory.InferenceProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory;
Expand Down Expand Up @@ -94,7 +95,12 @@ public List<QuerySpec<?>> getQueries() {
@Override
public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client));
return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env));
return Map.of(
TextEmbeddingProcessor.TYPE,
new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
InferenceProcessor.TYPE,
new InferenceProcessorFactory(clientAccessor, parameters.env)
);
}

@Override
Expand Down
Loading

0 comments on commit 4a8d596

Please sign in to comment.