From c198181a3336c47200e34b3ddee3a35582fc64b3 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 22 Sep 2023 17:54:28 -0700 Subject: [PATCH] Adding inference processor and factory, register that in plugin class Signed-off-by: Martin Gaievski --- .../ml/MLCommonsClientAccessor.java | 48 ++- .../neuralsearch/plugin/NeuralSearch.java | 10 +- .../processor/InferenceProcessor.java | 355 ++++++++++++++++++ .../factory/InferenceProcessorFactory.java | 36 ++ .../processor/InferenceProcessorTests.java | 141 +++++++ .../InferenceProcessorFactoryTests.java | 48 +++ 6 files changed, 634 insertions(+), 4 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 768584ec9..0a07aea42 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -7,7 +7,9 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.stream.Collectors; import lombok.NonNull; @@ -19,6 +21,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -103,6 +106,20 @@ public void inferenceSentences( inferenceSentencesWithRetry(targetResponseFilters, modelId, inputText, 0, listener); } + /** + * Call the ML predict API with multimodal input + * @param modelId + * @param inputObjects + * @param listener + */ + public void inferenceMultimodal( + @NonNull final String modelId, + @NonNull final Map> inputObjects, + @NonNull final ActionListener>> listener + ) { + inferenceMultimodalWithRetry(modelId, inputObjects, 0, listener); + } + private void inferenceSentencesWithRetry( final List targetResponseFilters, final String modelId, @@ -110,7 +127,7 @@ private void inferenceSentencesWithRetry( final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLInput(targetResponseFilters, inputText); + MLInput mlInput = createMLTextInput(targetResponseFilters, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> vector = buildVectorFromResponse(mlOutput); log.debug("Inference Response for input sentence {} is : {} ", inputText, vector); @@ -125,7 +142,7 @@ private void inferenceSentencesWithRetry( })); } - private MLInput createMLInput(final List targetResponseFilters, List inputText) { + private MLInput createMLTextInput(final List targetResponseFilters, List inputText) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); @@ -144,4 +161,31 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return vector; } + private void inferenceMultimodalWithRetry( + final String modelId, + final Map> inputObjects, + final int retryTime, + final ActionListener>> listener + ) { + MLInput mlInput = createMLMultimodalInput(inputObjects); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List> vector = buildVectorFromResponse(mlOutput); + log.debug("Inference Response for input sentence {} is : {} ", inputObjects, vector); + listener.onResponse(vector); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + inferenceMultimodalWithRetry(modelId, inputObjects, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); + } + + private MLInput createMLMultimodalInput(Map> input) { + Map remoteInferenceInput = new HashMap<>(); + input.forEach((key, value) -> remoteInferenceInput.put(value.get("model_input"), value.get("value"))); + final MLInputDataset inputDataset = new RemoteInferenceInputDataSet(remoteInferenceInput); + return new MLInput(FunctionName.REMOTE, null, inputDataset); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index e94a2957d..bcaca2785 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -9,7 +9,6 @@ import java.util.Arrays; import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Optional; @@ -29,11 +28,13 @@ import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.InferenceProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.factory.InferenceProcessorFactory; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; @@ -94,7 +95,12 @@ public List> getQueries() { @Override public Map getProcessors(Processor.Parameters parameters) { clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); - return Collections.singletonMap(TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env)); + return Map.of( + TextEmbeddingProcessor.TYPE, + new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), + InferenceProcessor.TYPE, + new InferenceProcessorFactory(clientAccessor, parameters.env) + ); } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java new file mode 100644 index 000000000..08f1bb7e0 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -0,0 +1,355 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; + +/** + * This processor is used for getting embeddings for multimodal type of inference, model_id can be used to indicate which model user use, + * and field_map can be used to indicate which fields needs embedding and the corresponding keys for the embedding results. + */ +@Log4j2 +public class InferenceProcessor extends AbstractProcessor { + + public static final String TYPE = "inference-processor"; + public static final String MODEL_ID_FIELD = "model_id"; + public static final String FIELD_MAP_FIELD = "field_map"; + + private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; + + @VisibleForTesting + private final String modelId; + + private final Map fieldMap; + + private final MLCommonsClientAccessor mlCommonsClientAccessor; + + private final Environment environment; + + private static final int MAX_CONTENT_LENGTH_IN_BYTES = 10 * 1024 * 1024; // limit of 10Mb per field value + + public InferenceProcessor( + String tag, + String description, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment + ) { + super(tag, description); + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); + validateEmbeddingConfiguration(fieldMap); + + this.modelId = modelId; + this.fieldMap = fieldMap; + this.mlCommonsClientAccessor = clientAccessor; + this.environment = environment; + } + + private void validateEmbeddingConfiguration(Map fieldMap) { + if (fieldMap == null + || fieldMap.isEmpty() + || fieldMap.entrySet() + .stream() + .anyMatch( + x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) + )) { + throw new IllegalArgumentException("Unable to create the InferenceProcessor processor as field_map has invalid key or value"); + } + } + + @Override + public IngestDocument execute(IngestDocument ingestDocument) { + return ingestDocument; + } + + /** + * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * process the inference response or failure. + * @param ingestDocument {@link IngestDocument} which is the document passed to processor. + * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. + */ + @Override + public void execute(IngestDocument ingestDocument, BiConsumer handler) { + // When received a bulk indexing request, the pipeline will be executed in this method, (see + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226). + // Before the pipeline execution, the pipeline will be marked as resolved (means executed), + // and then this overriding method will be invoked when executing the text embedding processor. + // After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation. + try { + validateEmbeddingFieldsValue(ingestDocument); + Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map> inferenceMap = createInferenceMap(knnMap); + if (inferenceMap.isEmpty()) { + handler.accept(ingestDocument, null); + } else { + mlCommonsClientAccessor.inferenceMultimodal(this.modelId, inferenceMap, ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, knnMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); + } + } catch (Exception e) { + handler.accept(null, e); + } + + } + + void setVectorFieldsToDocument(IngestDocument ingestDocument, Map knnMap, List> vectors) { + Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); + log.debug("Text embedding result fetched, starting build vector output!"); + Map textEmbeddingResult = buildTextEmbeddingResult(knnMap, vectors, ingestDocument.getSourceAndMetadata()); + textEmbeddingResult.forEach(ingestDocument::setFieldValue); + } + + private Map> createInferenceMap(Map knnKeyMap) { + Map> objects = new HashMap<>(); + knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof Map) { + Map sourceValues = (Map) sourceValue; + if (sourceValues.entrySet().stream().anyMatch(entry -> entry.getKey().length() > MAX_CONTENT_LENGTH_IN_BYTES)) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "content cannot be longer than a %d bytes", MAX_CONTENT_LENGTH_IN_BYTES) + ); + } + objects.put(knnMapEntry.getKey(), sourceValues); + } else { + throw new RuntimeException("Cannot build inference object"); + } + }); + return objects; + } + + @SuppressWarnings({ "unchecked" }) + private List createInferenceList(Map knnKeyMap) { + List texts = new ArrayList<>(); + knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); + } else if (sourceValue instanceof Map) { + createInferenceListForMapTypeInput(sourceValue, texts); + } else { + texts.add(sourceValue.toString()); + } + }); + return texts; + } + + @SuppressWarnings("unchecked") + private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { + if (sourceValue instanceof Map) { + ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); + } else if (sourceValue instanceof List) { + texts.addAll(((List) sourceValue)); + } else { + if (sourceValue == null) return; + texts.add(sourceValue.toString()); + } + } + + @VisibleForTesting + Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + Map mapWithKnnKeys = new LinkedHashMap<>(); + for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { + String originalKey = fieldMapEntry.getKey(); + Object targetKey = fieldMapEntry.getValue(); + if (targetKey instanceof Map) { + // Map treeRes = new LinkedHashMap<>(); + // buildMapWithKnnKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); + // mapWithKnnKeys.put(originalKey, treeRes.get(originalKey)); + Map knnMap = Map.of( + "value", + sourceAndMetadataMap.get(originalKey).toString(), + "model_input", + ((Map) targetKey).get("model_input").toString() + ); + mapWithKnnKeys.put(originalKey, knnMap); + } else { + mapWithKnnKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); + } + } + return mapWithKnnKeys; + } + + @SuppressWarnings({ "unchecked" }) + private void buildMapWithKnnKeyAndOriginalValueForMapType( + String parentKey, + Object knnKey, + Map sourceAndMetadataMap, + Map treeRes + ) { + if (knnKey == null || sourceAndMetadataMap == null) return; + if (knnKey instanceof Map) { + Map next = new LinkedHashMap<>(); + for (Map.Entry nestedFieldMapEntry : ((Map) knnKey).entrySet()) { + buildMapWithKnnKeyAndOriginalValueForMapType( + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next + ); + } + treeRes.put(parentKey, next); + } else { + String key = String.valueOf(knnKey); + treeRes.put(key, sourceAndMetadataMap.get(parentKey)); + } + } + + @SuppressWarnings({ "unchecked" }) + @VisibleForTesting + Map buildTextEmbeddingResult( + Map knnMap, + List> modelTensorList, + Map sourceAndMetadataMap + ) { + IndexWrapper indexWrapper = new IndexWrapper(0); + Map result = new LinkedHashMap<>(); + for (Map.Entry knnMapEntry : knnMap.entrySet()) { + String knnKey = knnMapEntry.getKey(); + Object sourceValue = knnMapEntry.getValue(); + List modelTensor = modelTensorList.get(indexWrapper.index++); + result.put(knnKey, modelTensor); + } + return result; + } + + @SuppressWarnings({ "unchecked" }) + private void putTextEmbeddingResultToSourceMapForMapType( + String knnKey, + Object sourceValue, + List> modelTensorList, + IndexWrapper indexWrapper, + Map sourceAndMetadataMap + ) { + if (knnKey == null || sourceAndMetadataMap == null || sourceValue == null) return; + if (sourceValue instanceof Map) { + for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { + putTextEmbeddingResultToSourceMapForMapType( + inputNestedMapEntry.getKey(), + inputNestedMapEntry.getValue(), + modelTensorList, + indexWrapper, + (Map) sourceAndMetadataMap.get(knnKey) + ); + } + } else if (sourceValue instanceof String) { + sourceAndMetadataMap.put(knnKey, modelTensorList.get(indexWrapper.index++)); + } else if (sourceValue instanceof List) { + sourceAndMetadataMap.put( + knnKey, + buildTextEmbeddingResultForListType((List) sourceValue, modelTensorList, indexWrapper) + ); + } + } + + private List>> buildTextEmbeddingResultForListType( + List sourceValue, + List> modelTensorList, + IndexWrapper indexWrapper + ) { + List>> numbers = new ArrayList<>(); + IntStream.range(0, sourceValue.size()) + .forEachOrdered(x -> numbers.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, modelTensorList.get(indexWrapper.index++)))); + return numbers; + } + + private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + if (sourceValue != null) { + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); + } + } + } + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); + } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { + validateListTypeValue(sourceKey, sourceValue); + } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); + } + } + + @SuppressWarnings({ "rawtypes" }) + private static void validateListTypeValue(String sourceKey, Object sourceValue) { + for (Object value : (List) sourceValue) { + if (value == null) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); + } else if (!(value instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); + } else if (StringUtils.isBlank(value.toString())) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); + } + } + } + + @Override + public String getType() { + return TYPE; + } + + /** + * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List}, + * we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order + * traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the + * order. And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase + * the index pointer during the recursive. + * index: the index pointer of the text embedding result. + */ + static class IndexWrapper { + private int index; + + protected IndexWrapper(int index) { + this.index = index; + } + } + +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java new file mode 100644 index 000000000..b2dd1fc28 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.opensearch.ingest.ConfigurationUtils.readMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.ingest.Processor.Factory; + +import java.util.Map; + +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.InferenceProcessor; + +public class InferenceProcessorFactory implements Factory { + + private final MLCommonsClientAccessor clientAccessor; + + private final Environment environment; + + public InferenceProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + this.clientAccessor = clientAccessor; + this.environment = environment; + } + + @Override + public InferenceProcessor create(Map registry, String processorTag, String description, Map config) + throws Exception { + String modelId = readStringProperty(InferenceProcessor.TYPE, processorTag, config, InferenceProcessor.MODEL_ID_FIELD); + Map filedMap = readMap(InferenceProcessor.TYPE, processorTag, config, InferenceProcessor.FIELD_MAP_FIELD); + return new InferenceProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java new file mode 100644 index 000000000..60d0dc143 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -0,0 +1,141 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + +import lombok.SneakyThrows; + +import org.junit.Before; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.factory.InferenceProcessorFactory; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class InferenceProcessorTests extends OpenSearchTestCase { + + @Mock + private MLCommonsClientAccessor mlCommonsClientAccessor; + + @Mock + private Environment env; + + @InjectMocks + private InferenceProcessorFactory inferenceProcessorFactory; + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); + when(env.settings()).thenReturn(settings); + } + + @SneakyThrows + private InferenceProcessor createInstance(List> vector) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of( + "key1", + Map.of("model_input", "ModelInput1", "model_output", "ModelOutput1", "embedding", "key1Mapped"), + "key2", + Map.of("model_input", "ModelInput2", "model_output", "ModelOutput2", "embedding", "key2Mapped") + ) + ); + return inferenceProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + try { + inferenceProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } catch (OpenSearchParseException e) { + assertEquals("[field_map] required property is missing", e.getMessage()); + } + } + + public void testExecute_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + InferenceProcessor processor = createInstance(createMockVectorWithLength(2)); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceMultimodal(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + private List> createMockVectorResult() { + List> modelTensorList = new ArrayList<>(); + List number1 = ImmutableList.of(1.234f, 2.354f); + List number2 = ImmutableList.of(3.234f, 4.354f); + List number3 = ImmutableList.of(5.234f, 6.354f); + List number4 = ImmutableList.of(7.234f, 8.354f); + List number5 = ImmutableList.of(9.234f, 10.354f); + List number6 = ImmutableList.of(11.234f, 12.354f); + List number7 = ImmutableList.of(13.234f, 14.354f); + modelTensorList.add(number1); + modelTensorList.add(number2); + modelTensorList.add(number3); + modelTensorList.add(number4); + modelTensorList.add(number5); + modelTensorList.add(number6); + modelTensorList.add(number7); + return modelTensorList; + } + + private List> createMockVectorWithLength(int size) { + float suffix = .234f; + List> result = new ArrayList<>(); + for (int i = 0; i < size * 2;) { + List number = new ArrayList<>(); + number.add(i++ + suffix); + number.add(i++ + suffix); + result.add(number); + } + return result; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java new file mode 100644 index 000000000..55aba5443 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.processor.InferenceProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.InferenceProcessor.MODEL_ID_FIELD; + +import java.util.HashMap; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.InferenceProcessor; +import org.opensearch.test.OpenSearchTestCase; + +public class InferenceProcessorFactoryTests extends OpenSearchTestCase { + + private static final String NORMALIZATION_METHOD = "min_max"; + private static final String COMBINATION_METHOD = "arithmetic_mean"; + + @SneakyThrows + public void testNormalizationProcessor_whenNoParams_thenSuccessful() { + InferenceProcessorFactory inferenceProcessorFactory = new InferenceProcessorFactory( + mock(MLCommonsClientAccessor.class), + mock(Environment.class) + ); + + final Map processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, "1234567678"); + config.put( + FIELD_MAP_FIELD, + Map.of("passage_text", Map.of("model_input", "TextInput1", "model_output", "TextEmbdedding1", "embedding", "passage_embedding")) + ); + InferenceProcessor inferenceProcessor = inferenceProcessorFactory.create(processorFactories, tag, description, config); + assertNotNull(inferenceProcessor); + assertEquals("inference-processor", inferenceProcessor.getType()); + } +}