From 84fb0d5ab34ec6172be12fddb119a8a90e07e4d8 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 22 Sep 2023 17:54:28 -0700 Subject: [PATCH] Adding inference processor and factory, register that in plugin class Signed-off-by: Martin Gaievski --- .../ml/MLCommonsClientAccessor.java | 45 ++- .../neuralsearch/plugin/NeuralSearch.java | 7 +- .../processor/InferenceProcessor.java | 343 ++++++++++++++++++ .../InferenceGeneratorProcessorFactory.java | 41 +++ 4 files changed, 433 insertions(+), 3 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceGeneratorProcessorFactory.java diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 768584ec9..63d8779b0 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -8,6 +8,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import lombok.NonNull; @@ -19,6 +20,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -103,6 +105,20 @@ public void inferenceSentences( inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener); } + /** + * Call the ML predict API with multimodal input + * @param modelId + * @param inputObjects + * @param listener + */ + public void inferenceMultimodal( + @NonNull final String modelId, + @NonNull final Map inputObjects, + @NonNull final ActionListener>> listener + ) { + inferenceMultimodalWithRetry(modelId, inputObjects, 0, listener); + } + private void inferenceSentencesWithRetry( final List targetResponseFilters, final String modelId, @@ -110,7 +126,7 @@ private void inferenceSentencesWithRetry( final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLInput(targetResponseFilters, inputText); + MLInput mlInput = createMLTextInput(targetResponseFilters, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> vector = buildVectorFromResponse(mlOutput); log.debug("Inference Response for input sentence {} is : {} ", inputText, vector); @@ -125,7 +141,7 @@ private void inferenceSentencesWithRetry( })); } - private MLInput createMLInput(final List targetResponseFilters, List inputText) { + private MLInput createMLTextInput(final List targetResponseFilters, List inputText) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); @@ -144,4 +160,29 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return vector; } + private void inferenceMultimodalWithRetry( + final String modelId, + final Map inputObjects, + final int retryTime, + final ActionListener>> listener + ) { + MLInput mlInput = createMLMultimodalInput(inputObjects); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List> vector = buildVectorFromResponse(mlOutput); + log.debug("Inference Response for input sentence {} is : {} ", inputObjects, vector); + listener.onResponse(vector); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + inferenceMultimodalWithRetry(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); + } + + private MLInput createMLMultimodalInput(Map input) { + final MLInputDataset inputDataset = new RemoteInferenceInputDataSet(input); + return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index e94a2957d..75383c0e1 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -29,11 +29,13 @@ import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.InferenceProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.factory.InferenceGeneratorProcessorFactory; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; @@ -94,7 +96,10 @@ public List> getQueries() { @Override public Map getProcessors(Processor.Parameters parameters) { clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); - return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env)); + return Map.of(TextEmbeddingProcessor.TYPE, + new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), + InferenceProcessor.TYPE, + new InferenceGeneratorProcessorFactory(clientAccessor, parameters.env)); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java new file mode 100644 index 000000000..7c39d22a8 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -0,0 +1,343 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; +import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +/** + * This processor is used for getting embeddings for multimodal type of inference, model_id can be used to indicate which model user use, + * and field_map can be used to indicate which fields needs embedding and the corresponding keys for the embedding results. + */ +@Log4j2 +public class InferenceProcessor extends AbstractProcessor { + + public static final String TYPE = "inference_processor"; + public static final String MODEL_ID_FIELD = "model_id"; + public static final String FIELD_MAP_FIELD = "field_map"; + + private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; + + @VisibleForTesting + private final String modelId; + + private final Map fieldMap; + + private final MLCommonsClientAccessor mlCommonsClientAccessor; + + private final Environment environment; + + public InferenceProcessor( + String tag, + String description, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment + ) { + super(tag, description); + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); + validateEmbeddingConfiguration(fieldMap); + + this.modelId = modelId; + this.fieldMap = fieldMap; + this.mlCommonsClientAccessor = clientAccessor; + this.environment = environment; + } + + private void validateEmbeddingConfiguration(Map fieldMap) { + if (fieldMap == null + || fieldMap.isEmpty() + || fieldMap.entrySet() + .stream() + .anyMatch( + x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) + )) { + throw new IllegalArgumentException("Unable to create the InferenceProcessor processor as field_map has invalid key or value"); + } + } + + @Override + public IngestDocument execute(IngestDocument ingestDocument) { + return ingestDocument; + } + + /** + * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * process the inference response or failure. + * @param ingestDocument {@link IngestDocument} which is the document passed to processor. + * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. + */ + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler) { + // When received a bulk indexing request, the pipeline will be executed in this method, (see + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226). + // Before the pipeline execution, the pipeline will be marked as resolved (means executed), + // and then this overriding method will be invoked when executing the text embedding processor. + // After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation. + try { + validateEmbeddingFieldsValue(ingestDocument); + Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map inferenceMap = createInferenceMap(knnMap); + if (inferenceMap.isEmpty()) { + handler.accept(ingestDocument, null); + } else { + mlCommonsClientAccessor.inferenceMultimodal(this.modelId, inferenceMap, ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, knnMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); + } + } catch (Exception e) { + handler.accept(null, e); + } + + } + + void setVectorFieldsToDocument(IngestDocument ingestDocument, Map knnMap, List> vectors) { + Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); + log.debug("Text embedding result fetched, starting build vector output!"); + Map textEmbeddingResult = buildTextEmbeddingResult(knnMap, vectors, ingestDocument.getSourceAndMetadata()); + textEmbeddingResult.forEach(ingestDocument::setFieldValue); + } + + private Map createInferenceMap(Map knnKeyMap) { + Map objects = new HashMap<>(); + knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof Map) { + objects.putAll((Map) sourceValue); + } else { + throw new RuntimeException("Cannot build inference object"); + } + }); + return objects; + } + + @SuppressWarnings({ "unchecked" }) + private List createInferenceList(Map knnKeyMap) { + List texts = new ArrayList<>(); + knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); + } else if (sourceValue instanceof Map) { + createInferenceListForMapTypeInput(sourceValue, texts); + } else { + texts.add(sourceValue.toString()); + } + }); + return texts; + } + + @SuppressWarnings("unchecked") + private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { + if (sourceValue instanceof Map) { + ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); + } else if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); + } else { + if (sourceValue == null) return; + texts.add(sourceValue.toString()); + } + } + + @VisibleForTesting + Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + Map mapWithKnnKeys = new LinkedHashMap<>(); + for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { + String originalKey = fieldMapEntry.getKey(); + Object targetKey = fieldMapEntry.getValue(); + if (targetKey instanceof Map) { + Map treeRes = new LinkedHashMap<>(); + buildMapWithKnnKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); + mapWithKnnKeys.put(originalKey, treeRes.get(originalKey)); + } else { + mapWithKnnKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); + } + } + return mapWithKnnKeys; + } + + @SuppressWarnings({ "unchecked" }) + private void buildMapWithKnnKeyAndOriginalValueForMapType( + String parentKey, + Object knnKey, + Map sourceAndMetadataMap, + Map treeRes + ) { + if (knnKey == null || sourceAndMetadataMap == null) return; + if (knnKey instanceof Map) { + Map next = new LinkedHashMap<>(); + for (Map.Entry nestedFieldMapEntry : ((Map) knnKey).entrySet()) { + buildMapWithKnnKeyAndOriginalValueForMapType( + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next + ); + } + treeRes.put(parentKey, next); + } else { + String key = String.valueOf(knnKey); + treeRes.put(key, sourceAndMetadataMap.get(parentKey)); + } + } + + @SuppressWarnings({ "unchecked" }) + @VisibleForTesting + Map buildTextEmbeddingResult( + Map knnMap, + List> modelTensorList, + Map sourceAndMetadataMap + ) { + IndexWrapper indexWrapper = new IndexWrapper(0); + Map result = new LinkedHashMap<>(); + for (Map.Entry knnMapEntry : knnMap.entrySet()) { + String knnKey = knnMapEntry.getKey(); + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof String) { + List modelTensor = modelTensorList.get(indexWrapper.index++); + result.put(knnKey, modelTensor); + } else if (sourceValue instanceof List) { + result.put(knnKey, buildTextEmbeddingResultForListType((List) sourceValue, modelTensorList, indexWrapper)); + } else if (sourceValue instanceof Map) { + putTextEmbeddingResultToSourceMapForMapType(knnKey, sourceValue, modelTensorList, indexWrapper, sourceAndMetadataMap); + } + } + return result; + } + + @SuppressWarnings({ "unchecked" }) + private void putTextEmbeddingResultToSourceMapForMapType( + String knnKey, + Object sourceValue, + List> modelTensorList, + IndexWrapper indexWrapper, + Map sourceAndMetadataMap + ) { + if (knnKey == null || sourceAndMetadataMap == null || sourceValue == null) return; + if (sourceValue instanceof Map) { + for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { + putTextEmbeddingResultToSourceMapForMapType( + inputNestedMapEntry.getKey(), + inputNestedMapEntry.getValue(), + modelTensorList, + indexWrapper, + (Map) sourceAndMetadataMap.get(knnKey) + ); + } + } else if (sourceValue instanceof String) { + sourceAndMetadataMap.put(knnKey, modelTensorList.get(indexWrapper.index++)); + } else if (sourceValue instanceof List) { + sourceAndMetadataMap.put( + knnKey, + buildTextEmbeddingResultForListType((List) sourceValue, modelTensorList, indexWrapper) + ); + } + } + + private List>> buildTextEmbeddingResultForListType( + List sourceValue, + List> modelTensorList, + IndexWrapper indexWrapper + ) { + List>> numbers = new ArrayList<>(); + IntStream.range(0, sourceValue.size()) + .forEachOrdered(x -> numbers.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, modelTensorList.get(indexWrapper.index++)))); + return numbers; + } + + private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + if (sourceValue != null) { + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); + } + } + } + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); + } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { + validateListTypeValue(sourceKey, sourceValue); + } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); + } + } + + @SuppressWarnings({ "rawtypes" }) + private static void validateListTypeValue(String sourceKey, Object sourceValue) { + for (Object value : (List) sourceValue) { + if (value == null) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); + } else if (!(value instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); + } else if (StringUtils.isBlank(value.toString())) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); + } + } + } + + @Override + public String getType() { + return TYPE; + } + + /** + * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List}, + * we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order + * traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the + * order. And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase + * the index pointer during the recursive. + * index: the index pointer of the text embedding result. + */ + static class IndexWrapper { + private int index; + + protected IndexWrapper(int index) { + this.index = index; + } + } + +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceGeneratorProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceGeneratorProcessorFactory.java new file mode 100644 index 000000000..5cb346b70 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceGeneratorProcessorFactory.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.InferenceProcessor; +import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; + +import java.util.Map; + +import static org.opensearch.ingest.ConfigurationUtils.readMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.Factory; + +public class InferenceGeneratorProcessorFactory implements Factory { + + private final MLCommonsClientAccessor clientAccessor; + + private final Environment environment; + + public InferenceGeneratorProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + this.clientAccessor = clientAccessor; + this.environment = environment; + } + + @Override + public TextEmbeddingProcessor create( + Map registry, + String processorTag, + String description, + Map config + ) throws Exception { + String modelId = readStringProperty(InferenceProcessor.TYPE, processorTag, config, InferenceProcessor.MODEL_ID_FIELD); + Map filedMap = readMap(InferenceProcessor.TYPE, processorTag, config, InferenceProcessor.FIELD_MAP_FIELD); + return new TextEmbeddingProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment); + } +}