Commit c61db6a
Adding inference processor and factory, registering them in the plugin class
Signed-off-by: Martin Gaievski <[email protected]>
martin-gaievski committed Sep 29, 2023
1 parent 848ec21 commit c61db6a
Showing 6 changed files with 599 additions and 3 deletions.
@@ -7,6 +7,7 @@
 
 import java.util.ArrayList;
 import java.util.Arrays;
+import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.stream.Collectors;
@@ -21,6 +22,7 @@
 import org.opensearch.ml.common.FunctionName;
 import org.opensearch.ml.common.dataset.MLInputDataset;
 import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
+import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
 import org.opensearch.ml.common.input.MLInput;
 import org.opensearch.ml.common.output.MLOutput;
 import org.opensearch.ml.common.output.model.ModelResultFilter;
@@ -140,7 +142,7 @@ private void retryableInferenceSentencesWithVectorResult(
         final int retryTime,
         final ActionListener<List<List<Float>>> listener
     ) {
-        MLInput mlInput = createMLInput(targetResponseFilters, inputText);
+        MLInput mlInput = createMLTextInput(targetResponseFilters, inputText);
         mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
             final List<List<Float>> vector = buildVectorFromResponse(mlOutput);
             listener.onResponse(vector);
@@ -154,6 +156,12 @@
         }));
     }
 
+    private MLInput createMLTextInput(final List<String> targetResponseFilters, List<String> inputText) {
+        final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
+        final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
+        return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset);
+    }
+
     private MLInput createMLInput(final List<String> targetResponseFilters, List<String> inputText) {
         final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null);
         final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter);
@@ -190,5 +198,4 @@ private List<List<Float>> buildVectorFromResponse(MLOutput mlOutput) {
         }
         return resultMaps;
     }
-
 }
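
Note: the hunks above add imports for HashMap and RemoteInferenceInputDataSet without showing the method that uses them (the remainder of this file's changes is not rendered on this page). For orientation only, a map-based input for a remote, connector-backed model would typically be assembled along these lines; the class name, helper name, and the "inputs" parameter key are assumptions, not part of the commit:

import java.util.HashMap;
import java.util.List;
import java.util.Map;

import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet;
import org.opensearch.ml.common.input.MLInput;

final class RemoteInputSketch {
    // Hypothetical counterpart to createMLTextInput: for remote models the payload
    // travels as a parameter map instead of a text-docs dataset.
    static MLInput createRemoteInput(final List<String> inputText) {
        final Map<String, String> parameters = new HashMap<>();
        // "inputs" is an assumed key; real connectors define their own parameter names.
        parameters.put("inputs", String.join(",", inputText));
        final MLInputDataset inputDataset = RemoteInferenceInputDataSet.builder().parameters(parameters).build();
        return new MLInput(FunctionName.REMOTE, null, inputDataset);
    }
}
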
@@ -28,12 +28,14 @@
 import org.opensearch.ingest.Processor;
 import org.opensearch.ml.client.MachineLearningNodeClient;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
+import org.opensearch.neuralsearch.processor.InferenceProcessor;
 import org.opensearch.neuralsearch.processor.NormalizationProcessor;
 import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow;
 import org.opensearch.neuralsearch.processor.SparseEncodingProcessor;
 import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory;
 import org.opensearch.neuralsearch.processor.combination.ScoreCombiner;
+import org.opensearch.neuralsearch.processor.factory.InferenceProcessorFactory;
 import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory;
 import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory;
 import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory;
@@ -102,7 +104,9 @@ public Map<String, Processor.Factory> getProcessors(Processor.Parameters paramet
             TextEmbeddingProcessor.TYPE,
             new TextEmbeddingProcessorFactory(clientAccessor, parameters.env),
             SparseEncodingProcessor.TYPE,
-            new SparseEncodingProcessorFactory(clientAccessor, parameters.env)
+            new SparseEncodingProcessorFactory(clientAccessor, parameters.env),
+            InferenceProcessor.TYPE,
+            new InferenceProcessorFactory(clientAccessor, parameters.env)
         );
     }
 
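Note: the InferenceProcessorFactory registered above is defined in one of the four changed files not rendered on this page. As a rough, structural sketch only, a factory registered under a processor type implements the Processor.Factory interface roughly as below; the constructor shape mirrors the call site in the hunk above, while the "model_id" config key and the commented-out construction are assumptions, not the committed implementation:

import java.util.Map;

import org.opensearch.env.Environment;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;

final class InferenceProcessorFactorySketch implements Processor.Factory {
    private final MLCommonsClientAccessor clientAccessor;
    private final Environment environment;

    InferenceProcessorFactorySketch(final MLCommonsClientAccessor clientAccessor, final Environment environment) {
        this.clientAccessor = clientAccessor;
        this.environment = environment;
    }

    @Override
    public Processor create(
        final Map<String, Processor.Factory> registry,
        final String processorTag,
        final String description,
        final Map<String, Object> config
    ) throws Exception {
        // A real factory reads its settings from the pipeline config; "model_id" is an assumed key.
        final String modelId = ConfigurationUtils.readStringProperty("inference", processorTag, config, "model_id");
        // The committed factory would return the concrete processor here, e.g.
        // new InferenceProcessor(processorTag, description, modelId, ..., clientAccessor, environment).
        throw new UnsupportedOperationException("structural sketch only");
    }
}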