From e458e9d65f3693f6874634174a95fc62d63b82f9 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 22 Sep 2023 17:54:28 -0700 Subject: [PATCH 01/14] Adding inference processor and factory, register that in plugin class Signed-off-by: Martin Gaievski --- .../ml/MLCommonsClientAccessor.java | 11 +- .../neuralsearch/plugin/NeuralSearch.java | 6 +- .../processor/InferenceProcessor.java | 307 ++++++++++-------- .../factory/InferenceProcessorFactory.java | 36 ++ .../processor/InferenceProcessorTests.java | 141 ++++++++ .../InferenceProcessorFactoryTests.java | 48 +++ 6 files changed, 410 insertions(+), 139 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 6f8b790bb..e9dc342ec 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -7,6 +7,7 @@ import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -21,6 +22,7 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; +import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -140,7 +142,7 @@ private void retryableInferenceSentencesWithVectorResult( final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLInput(targetResponseFilters, inputText); + MLInput mlInput = createMLTextInput(targetResponseFilters, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> vector = buildVectorFromResponse(mlOutput); listener.onResponse(vector); @@ -154,6 +156,12 @@ private void retryableInferenceSentencesWithVectorResult( })); } + private MLInput createMLTextInput(final List targetResponseFilters, List inputText) { + final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); + final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); + return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); + } + private MLInput createMLInput(final List targetResponseFilters, List inputText) { final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); @@ -190,5 +198,4 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { } return resultMaps; } - } diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index d69203f92..7bcafd781 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -29,12 +29,14 @@ import 
org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; +import org.opensearch.neuralsearch.processor.InferenceProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; +import org.opensearch.neuralsearch.processor.factory.InferenceProcessorFactory; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; @@ -106,7 +108,9 @@ public Map getProcessors(Processor.Parameters paramet TextEmbeddingProcessor.TYPE, new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), SparseEncodingProcessor.TYPE, - new SparseEncodingProcessorFactory(clientAccessor, parameters.env) + new SparseEncodingProcessorFactory(clientAccessor, parameters.env), + InferenceProcessor.TYPE, + new InferenceProcessorFactory(clientAccessor, parameters.env) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index acf0eb32b..cc482fa2f 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -6,8 +6,10 @@ package org.opensearch.neuralsearch.processor; import java.util.ArrayList; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.function.BiConsumer; @@ -17,6 +19,7 @@ import lombok.extern.log4j.Log4j2; import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.action.ActionListener; import org.opensearch.env.Environment; import org.opensearch.index.mapper.MapperService; import org.opensearch.ingest.AbstractProcessor; @@ -27,46 +30,41 @@ import com.google.common.collect.ImmutableMap; /** - * The abstract class for text processing use cases. Users provide a field name map and a model id. - * During ingestion, the processor will use the corresponding model to inference the input texts, - * and set the target fields according to the field name map. + * This processor is used for getting embeddings for multimodal type of inference, model_id can be used to indicate which model user use, + * and field_map can be used to indicate which fields needs embedding and the corresponding keys for the embedding results. */ @Log4j2 -public abstract class InferenceProcessor extends AbstractProcessor { +public class InferenceProcessor extends AbstractProcessor { + public static final String TYPE = "inference-processor"; public static final String MODEL_ID_FIELD = "model_id"; public static final String FIELD_MAP_FIELD = "field_map"; - private final String type; + private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; - // This field is used for nested knn_vector/rank_features field. 
The value of the field will be used as the - // default key for the nested object. - private final String listTypeNestedMapKey; - - protected final String modelId; + @VisibleForTesting + private final String modelId; private final Map fieldMap; - protected final MLCommonsClientAccessor mlCommonsClientAccessor; + private final MLCommonsClientAccessor mlCommonsClientAccessor; private final Environment environment; + private static final int MAX_CONTENT_LENGTH_IN_BYTES = 10 * 1024 * 1024; // limit of 10Mb per field value + public InferenceProcessor( - String tag, - String description, - String type, - String listTypeNestedMapKey, - String modelId, - Map fieldMap, - MLCommonsClientAccessor clientAccessor, - Environment environment + String tag, + String description, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment ) { super(tag, description); - this.type = type; - if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it"); + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); validateEmbeddingConfiguration(fieldMap); - this.listTypeNestedMapKey = listTypeNestedMapKey; this.modelId = modelId; this.fieldMap = fieldMap; this.mlCommonsClientAccessor = clientAccessor; @@ -75,25 +73,18 @@ public InferenceProcessor( private void validateEmbeddingConfiguration(Map fieldMap) { if (fieldMap == null - || fieldMap.size() == 0 - || fieldMap.entrySet() + || fieldMap.isEmpty() + || fieldMap.entrySet() .stream() .anyMatch( - x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) + x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) )) { - throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value"); + throw new IllegalArgumentException("Unable to create the InferenceProcessor processor as field_map has invalid key or value"); } } - public abstract void doExecute( - IngestDocument ingestDocument, - Map ProcessMap, - List inferenceList, - BiConsumer handler - ); - @Override - public IngestDocument execute(IngestDocument ingestDocument) throws Exception { + public IngestDocument execute(IngestDocument ingestDocument) { return ingestDocument; } @@ -105,18 +96,58 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception { */ @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { + // When received a bulk indexing request, the pipeline will be executed in this method, (see + // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226). + // Before the pipeline execution, the pipeline will be marked as resolved (means executed), + // and then this overriding method will be invoked when executing the text embedding processor. + // After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation. 
try { validateEmbeddingFieldsValue(ingestDocument); - Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); - List inferenceList = createInferenceList(ProcessMap); - if (inferenceList.size() == 0) { + Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map> inferenceMap = createInferenceMap(knnMap); + if (inferenceMap.isEmpty()) { handler.accept(ingestDocument, null); } else { - doExecute(ingestDocument, ProcessMap, inferenceList, handler); + mlCommonsClientAccessor.inferenceMultimodal(this.modelId, inferenceMap, ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, knnMap, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); } } catch (Exception e) { handler.accept(null, e); } + + } + + void setVectorFieldsToDocument(IngestDocument ingestDocument, Map knnMap, List> vectors) { + Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); + log.debug("Text embedding result fetched, starting build vector output!"); + Map textEmbeddingResult = buildTextEmbeddingResult(knnMap, vectors, ingestDocument.getSourceAndMetadata()); + textEmbeddingResult.forEach(ingestDocument::setFieldValue); + } + + private Map> createInferenceMap(Map knnKeyMap) { + Map> objects = new HashMap<>(); + knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { + Object sourceValue = knnMapEntry.getValue(); + if (sourceValue instanceof Map) { + Map sourceValues = (Map) sourceValue; + if (sourceValues.entrySet() + .stream() + .anyMatch( + entry -> entry.getKey().length() > MAX_CONTENT_LENGTH_IN_BYTES + || entry.getValue().length() > MAX_CONTENT_LENGTH_IN_BYTES + )) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "content cannot be longer than a %d bytes", MAX_CONTENT_LENGTH_IN_BYTES) + ); + } + objects.put(knnMapEntry.getKey(), sourceValues); + } else { + throw new RuntimeException("Cannot build inference object"); + } + }); + return objects; } @SuppressWarnings({ "unchecked" }) @@ -148,47 +179,113 @@ private void createInferenceListForMapTypeInput(Object sourceValue, List } @VisibleForTesting - Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) { + Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - Map mapWithProcessorKeys = new LinkedHashMap<>(); + Map mapWithKnnKeys = new LinkedHashMap<>(); for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { String originalKey = fieldMapEntry.getKey(); Object targetKey = fieldMapEntry.getValue(); if (targetKey instanceof Map) { - Map treeRes = new LinkedHashMap<>(); - buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); - mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); + // Map treeRes = new LinkedHashMap<>(); + // buildMapWithKnnKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); + // mapWithKnnKeys.put(originalKey, treeRes.get(originalKey)); + Map knnMap = Map.of( + "value", + sourceAndMetadataMap.get(originalKey).toString(), + "model_input", + ((Map) targetKey).get("model_input").toString() + ); + mapWithKnnKeys.put(originalKey, knnMap); } else { - mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); + mapWithKnnKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); } } - return mapWithProcessorKeys; + return 
mapWithKnnKeys; } - private void buildMapWithProcessorKeyAndOriginalValueForMapType( - String parentKey, - Object processorKey, - Map sourceAndMetadataMap, - Map treeRes + @SuppressWarnings({ "unchecked" }) + private void buildMapWithKnnKeyAndOriginalValueForMapType( + String parentKey, + Object knnKey, + Map sourceAndMetadataMap, + Map treeRes ) { - if (processorKey == null || sourceAndMetadataMap == null) return; - if (processorKey instanceof Map) { + if (knnKey == null || sourceAndMetadataMap == null) return; + if (knnKey instanceof Map) { Map next = new LinkedHashMap<>(); - for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { - buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next + for (Map.Entry nestedFieldMapEntry : ((Map) knnKey).entrySet()) { + buildMapWithKnnKeyAndOriginalValueForMapType( + nestedFieldMapEntry.getKey(), + nestedFieldMapEntry.getValue(), + (Map) sourceAndMetadataMap.get(parentKey), + next ); } treeRes.put(parentKey, next); } else { - String key = String.valueOf(processorKey); + String key = String.valueOf(knnKey); treeRes.put(key, sourceAndMetadataMap.get(parentKey)); } } + @SuppressWarnings({ "unchecked" }) + @VisibleForTesting + Map buildTextEmbeddingResult( + Map knnMap, + List> modelTensorList, + Map sourceAndMetadataMap + ) { + IndexWrapper indexWrapper = new IndexWrapper(0); + Map result = new LinkedHashMap<>(); + for (Map.Entry knnMapEntry : knnMap.entrySet()) { + String knnKey = knnMapEntry.getKey(); + Object sourceValue = knnMapEntry.getValue(); + List modelTensor = modelTensorList.get(indexWrapper.index++); + result.put(knnKey, modelTensor); + } + return result; + } + + @SuppressWarnings({ "unchecked" }) + private void putTextEmbeddingResultToSourceMapForMapType( + String knnKey, + Object sourceValue, + List> modelTensorList, + IndexWrapper indexWrapper, + Map sourceAndMetadataMap + ) { + if (knnKey == null || sourceAndMetadataMap == null || sourceValue == null) return; + if (sourceValue instanceof Map) { + for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { + putTextEmbeddingResultToSourceMapForMapType( + inputNestedMapEntry.getKey(), + inputNestedMapEntry.getValue(), + modelTensorList, + indexWrapper, + (Map) sourceAndMetadataMap.get(knnKey) + ); + } + } else if (sourceValue instanceof String) { + sourceAndMetadataMap.put(knnKey, modelTensorList.get(indexWrapper.index++)); + } else if (sourceValue instanceof List) { + sourceAndMetadataMap.put( + knnKey, + buildTextEmbeddingResultForListType((List) sourceValue, modelTensorList, indexWrapper) + ); + } + } + + private List>> buildTextEmbeddingResultForListType( + List sourceValue, + List> modelTensorList, + IndexWrapper indexWrapper + ) { + List>> numbers = new ArrayList<>(); + IntStream.range(0, sourceValue.size()) + .forEachOrdered(x -> numbers.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, modelTensorList.get(indexWrapper.index++)))); + return numbers; + } + private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { @@ -199,9 +296,9 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { validateNestedTypeValue(sourceKey, sourceValue, () -> 1); } else if 
(!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); } } } @@ -211,100 +308,37 @@ private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { int maxDepth = maxDepthSupplier.get(); if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { validateListTypeValue(sourceKey, sourceValue); } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); } } @SuppressWarnings({ "rawtypes" }) - private void validateListTypeValue(String sourceKey, Object sourceValue) { + private static void validateListTypeValue(String sourceKey, Object sourceValue) { for (Object value : (List) sourceValue) { if (value == null) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); - } - } - } - - protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List results) { - Objects.requireNonNull(results, "embedding failed, inference returns null result!"); - log.debug("Model inference result fetched, starting build vector output!"); - Map nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata()); - nlpResult.forEach(ingestDocument::setFieldValue); - } - - @SuppressWarnings({ 
"unchecked" }) - @VisibleForTesting - Map buildNLPResult(Map processorMap, List results, Map sourceAndMetadataMap) { - InferenceProcessor.IndexWrapper indexWrapper = new InferenceProcessor.IndexWrapper(0); - Map result = new LinkedHashMap<>(); - for (Map.Entry knnMapEntry : processorMap.entrySet()) { - String knnKey = knnMapEntry.getKey(); - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof String) { - result.put(knnKey, results.get(indexWrapper.index++)); - } else if (sourceValue instanceof List) { - result.put(knnKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper)); - } else if (sourceValue instanceof Map) { - putNLPResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap); - } - } - return result; - } - - @SuppressWarnings({ "unchecked" }) - private void putNLPResultToSourceMapForMapType( - String processorKey, - Object sourceValue, - List results, - InferenceProcessor.IndexWrapper indexWrapper, - Map sourceAndMetadataMap - ) { - if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; - if (sourceValue instanceof Map) { - for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { - putNLPResultToSourceMapForMapType( - inputNestedMapEntry.getKey(), - inputNestedMapEntry.getValue(), - results, - indexWrapper, - (Map) sourceAndMetadataMap.get(processorKey) - ); + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); } - } else if (sourceValue instanceof String) { - sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++)); - } else if (sourceValue instanceof List) { - sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper)); } } - private List> buildNLPResultForListType( - List sourceValue, - List results, - InferenceProcessor.IndexWrapper indexWrapper - ) { - List> keyToResult = new ArrayList<>(); - IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); - return keyToResult; - } - @Override public String getType() { - return type; + return TYPE; } /** @@ -322,4 +356,5 @@ protected IndexWrapper(int index) { this.index = index; } } -} + +} \ No newline at end of file diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java new file mode 100644 index 000000000..b2dd1fc28 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java @@ -0,0 +1,36 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.opensearch.ingest.ConfigurationUtils.readMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.ingest.Processor.Factory; + +import java.util.Map; + +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.InferenceProcessor; + +public class InferenceProcessorFactory implements Factory { + + private final MLCommonsClientAccessor clientAccessor; + + private final Environment environment; + + public InferenceProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + 
this.clientAccessor = clientAccessor; + this.environment = environment; + } + + @Override + public InferenceProcessor create(Map registry, String processorTag, String description, Map config) + throws Exception { + String modelId = readStringProperty(InferenceProcessor.TYPE, processorTag, config, InferenceProcessor.MODEL_ID_FIELD); + Map filedMap = readMap(InferenceProcessor.TYPE, processorTag, config, InferenceProcessor.FIELD_MAP_FIELD); + return new InferenceProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java new file mode 100644 index 000000000..60d0dc143 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java @@ -0,0 +1,141 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; + +import lombok.SneakyThrows; + +import org.junit.Before; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.factory.InferenceProcessorFactory; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class InferenceProcessorTests extends OpenSearchTestCase { + + @Mock + private MLCommonsClientAccessor mlCommonsClientAccessor; + + @Mock + private Environment env; + + @InjectMocks + private InferenceProcessorFactory inferenceProcessorFactory; + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); + when(env.settings()).thenReturn(settings); + } + + @SneakyThrows + private InferenceProcessor createInstance(List> vector) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put( + TextEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of( + "key1", + Map.of("model_input", "ModelInput1", "model_output", "ModelOutput1", "embedding", "key1Mapped"), + "key2", + Map.of("model_input", "ModelInput2", "model_output", "ModelOutput2", "embedding", "key2Mapped") + ) + ); + return inferenceProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + public void 
testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + try { + inferenceProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } catch (OpenSearchParseException e) { + assertEquals("[field_map] required property is missing", e.getMessage()); + } + } + + public void testExecute_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + InferenceProcessor processor = createInstance(createMockVectorWithLength(2)); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceMultimodal(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + private List> createMockVectorResult() { + List> modelTensorList = new ArrayList<>(); + List number1 = ImmutableList.of(1.234f, 2.354f); + List number2 = ImmutableList.of(3.234f, 4.354f); + List number3 = ImmutableList.of(5.234f, 6.354f); + List number4 = ImmutableList.of(7.234f, 8.354f); + List number5 = ImmutableList.of(9.234f, 10.354f); + List number6 = ImmutableList.of(11.234f, 12.354f); + List number7 = ImmutableList.of(13.234f, 14.354f); + modelTensorList.add(number1); + modelTensorList.add(number2); + modelTensorList.add(number3); + modelTensorList.add(number4); + modelTensorList.add(number5); + modelTensorList.add(number6); + modelTensorList.add(number7); + return modelTensorList; + } + + private List> createMockVectorWithLength(int size) { + float suffix = .234f; + List> result = new ArrayList<>(); + for (int i = 0; i < size * 2;) { + List number = new ArrayList<>(); + number.add(i++ + suffix); + number.add(i++ + suffix); + result.add(number); + } + return result; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java new file mode 100644 index 000000000..55aba5443 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.processor.InferenceProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.InferenceProcessor.MODEL_ID_FIELD; + +import java.util.HashMap; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.InferenceProcessor; +import org.opensearch.test.OpenSearchTestCase; + +public class InferenceProcessorFactoryTests extends OpenSearchTestCase { + + private static final String NORMALIZATION_METHOD = "min_max"; + private static final String COMBINATION_METHOD = "arithmetic_mean"; + + @SneakyThrows + public void 
testNormalizationProcessor_whenNoParams_thenSuccessful() { + InferenceProcessorFactory inferenceProcessorFactory = new InferenceProcessorFactory( + mock(MLCommonsClientAccessor.class), + mock(Environment.class) + ); + + final Map processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, "1234567678"); + config.put( + FIELD_MAP_FIELD, + Map.of("passage_text", Map.of("model_input", "TextInput1", "model_output", "TextEmbdedding1", "embedding", "passage_embedding")) + ); + InferenceProcessor inferenceProcessor = inferenceProcessorFactory.create(processorFactories, tag, description, config); + assertNotNull(inferenceProcessor); + assertEquals("inference-processor", inferenceProcessor.getType()); + } +} From 9f6b3f4c35a59fc7d4072e1beb24b31aebe15dad Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 28 Sep 2023 10:15:58 -0700 Subject: [PATCH 02/14] Changed approach to a hardcoded fields for image and text Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + build.gradle | 1 + .../ml/MLCommonsClientAccessor.java | 61 ++- .../neuralsearch/plugin/NeuralSearch.java | 8 +- .../processor/InferenceProcessor.java | 360 ------------------ .../TextImageEmbeddingProcessor.java | 227 +++++++++++ .../factory/InferenceProcessorFactory.java | 36 -- .../TextImageEmbeddingProcessorFactory.java | 48 +++ .../query/NeuralQueryBuilder.java | 26 +- .../constants/TestCommonConstants.java | 2 + .../ml/MLCommonsClientAccessorTests.java | 31 ++ .../processor/InferenceProcessorTests.java | 141 ------- .../processor/NormalizationProcessorIT.java | 30 +- .../processor/ScoreCombinationIT.java | 8 +- .../processor/ScoreNormalizationIT.java | 24 +- .../TextImageEmbeddingProcessorTests.java | 272 +++++++++++++ .../InferenceProcessorFactoryTests.java | 48 --- ...xtImageEmbeddingProcessorFactoryTests.java | 115 ++++++ .../query/NeuralQueryBuilderTests.java | 40 +- .../neuralsearch/query/NeuralQueryIT.java | 45 +++ 20 files changed, 916 insertions(+), 608 deletions(-) delete mode 100644 src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java delete mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java delete mode 100644 src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java delete mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index eba702c00..4fe6cc1b4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.10...2.x) ### Features Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333)) +Added Multimodal semantic search feature 
([#359](https://github.com/opensearch-project/neural-search/pull/359)) ### Enhancements Add `max_token_score` parameter to improve the execution efficiency for `neural_sparse` query clause ([#348](https://github.com/opensearch-project/neural-search/pull/348)) ### Bug Fixes diff --git a/build.gradle b/build.gradle index 8fac39682..3f5da7997 100644 --- a/build.gradle +++ b/build.gradle @@ -151,6 +151,7 @@ dependencies { runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12' runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA' runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}" + runtimeOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' runtimeOnly group: 'org.json', name: 'json', version: '20230227' } diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index e9dc342ec..2117b220b 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -5,9 +5,11 @@ package org.opensearch.neuralsearch.ml; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT; + import java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -22,7 +24,6 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -115,6 +116,24 @@ public void inferenceSentencesWithMapResult( retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); } + /** + * Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the + * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent + * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of + * inputText. + * + * @param modelId {@link String} + * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to happen + * @param listener {@link ActionListener} which will be called when prediction is completed or errored out. + */ + public void inferenceSentences( + @NonNull final String modelId, + @NonNull final Map inputObjects, + @NonNull final ActionListener> listener + ) { + inferenceSentencesWithRetry(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); + } + private void retryableInferenceSentencesWithMapResult( final String modelId, final List inputText, @@ -198,4 +217,42 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { } return resultMaps; } + + private List buildSingleVectorFromResponse(MLOutput mlOutput) { + final List> vector = buildVectorFromResponse(mlOutput); + return vector.isEmpty() ? 
new ArrayList<>() : vector.get(0); + } + + private void inferenceSentencesWithRetry( + @NonNull final List targetResponseFilters, + final String modelId, + final Map inputObjects, + final int retryTime, + final ActionListener> listener + ) { + MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List vector = buildSingleVectorFromResponse(mlOutput); + log.debug("Inference Response for input sentence {} is : {} ", inputObjects, vector); + listener.onResponse(vector); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + inferenceSentencesWithRetry(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); + } + + private MLInput createMLMultimodalInput(final List targetResponseFilters, Map input) { + List inputText = new ArrayList<>(); + inputText.add(input.get(INPUT_TEXT)); + if (input.containsKey(INPUT_IMAGE)) { + inputText.add(input.get(INPUT_IMAGE)); + } + final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); + final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); + return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index 7bcafd781..cf1a2f9bd 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -29,17 +29,17 @@ import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; -import org.opensearch.neuralsearch.processor.InferenceProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; -import org.opensearch.neuralsearch.processor.factory.InferenceProcessorFactory; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; @@ -109,8 +109,8 @@ public Map getProcessors(Processor.Parameters paramet new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), SparseEncodingProcessor.TYPE, new SparseEncodingProcessorFactory(clientAccessor, parameters.env), - InferenceProcessor.TYPE, - new InferenceProcessorFactory(clientAccessor, parameters.env) + TextImageEmbeddingProcessor.TYPE, + new 
TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java deleted file mode 100644 index cc482fa2f..000000000 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ /dev/null @@ -1,360 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.processor; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Locale; -import java.util.Map; -import java.util.Objects; -import java.util.function.BiConsumer; -import java.util.function.Supplier; -import java.util.stream.IntStream; - -import lombok.extern.log4j.Log4j2; - -import org.apache.commons.lang3.StringUtils; -import org.opensearch.core.action.ActionListener; -import org.opensearch.env.Environment; -import org.opensearch.index.mapper.MapperService; -import org.opensearch.ingest.AbstractProcessor; -import org.opensearch.ingest.IngestDocument; -import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; - -import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; - -/** - * This processor is used for getting embeddings for multimodal type of inference, model_id can be used to indicate which model user use, - * and field_map can be used to indicate which fields needs embedding and the corresponding keys for the embedding results. - */ -@Log4j2 -public class InferenceProcessor extends AbstractProcessor { - - public static final String TYPE = "inference-processor"; - public static final String MODEL_ID_FIELD = "model_id"; - public static final String FIELD_MAP_FIELD = "field_map"; - - private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; - - @VisibleForTesting - private final String modelId; - - private final Map fieldMap; - - private final MLCommonsClientAccessor mlCommonsClientAccessor; - - private final Environment environment; - - private static final int MAX_CONTENT_LENGTH_IN_BYTES = 10 * 1024 * 1024; // limit of 10Mb per field value - - public InferenceProcessor( - String tag, - String description, - String modelId, - Map fieldMap, - MLCommonsClientAccessor clientAccessor, - Environment environment - ) { - super(tag, description); - if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); - validateEmbeddingConfiguration(fieldMap); - - this.modelId = modelId; - this.fieldMap = fieldMap; - this.mlCommonsClientAccessor = clientAccessor; - this.environment = environment; - } - - private void validateEmbeddingConfiguration(Map fieldMap) { - if (fieldMap == null - || fieldMap.isEmpty() - || fieldMap.entrySet() - .stream() - .anyMatch( - x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) - )) { - throw new IllegalArgumentException("Unable to create the InferenceProcessor processor as field_map has invalid key or value"); - } - } - - @Override - public IngestDocument execute(IngestDocument ingestDocument) { - return ingestDocument; - } - - /** - * This method will be invoked by PipelineService to make async inference and then delegate the handler to - * process the inference response or failure. - * @param ingestDocument {@link IngestDocument} which is the document passed to processor. 
- * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. - */ - @Override - public void execute(IngestDocument ingestDocument, BiConsumer handler) { - // When received a bulk indexing request, the pipeline will be executed in this method, (see - // https://github.com/opensearch-project/OpenSearch/blob/main/server/src/main/java/org/opensearch/action/bulk/TransportBulkAction.java#L226). - // Before the pipeline execution, the pipeline will be marked as resolved (means executed), - // and then this overriding method will be invoked when executing the text embedding processor. - // After the inference completes, the handler will invoke the doInternalExecute method again to run actual write operation. - try { - validateEmbeddingFieldsValue(ingestDocument); - Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); - Map> inferenceMap = createInferenceMap(knnMap); - if (inferenceMap.isEmpty()) { - handler.accept(ingestDocument, null); - } else { - mlCommonsClientAccessor.inferenceMultimodal(this.modelId, inferenceMap, ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, knnMap, vectors); - handler.accept(ingestDocument, null); - }, e -> { handler.accept(null, e); })); - } - } catch (Exception e) { - handler.accept(null, e); - } - - } - - void setVectorFieldsToDocument(IngestDocument ingestDocument, Map knnMap, List> vectors) { - Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); - log.debug("Text embedding result fetched, starting build vector output!"); - Map textEmbeddingResult = buildTextEmbeddingResult(knnMap, vectors, ingestDocument.getSourceAndMetadata()); - textEmbeddingResult.forEach(ingestDocument::setFieldValue); - } - - private Map> createInferenceMap(Map knnKeyMap) { - Map> objects = new HashMap<>(); - knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof Map) { - Map sourceValues = (Map) sourceValue; - if (sourceValues.entrySet() - .stream() - .anyMatch( - entry -> entry.getKey().length() > MAX_CONTENT_LENGTH_IN_BYTES - || entry.getValue().length() > MAX_CONTENT_LENGTH_IN_BYTES - )) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "content cannot be longer than a %d bytes", MAX_CONTENT_LENGTH_IN_BYTES) - ); - } - objects.put(knnMapEntry.getKey(), sourceValues); - } else { - throw new RuntimeException("Cannot build inference object"); - } - }); - return objects; - } - - @SuppressWarnings({ "unchecked" }) - private List createInferenceList(Map knnKeyMap) { - List texts = new ArrayList<>(); - knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); - } else if (sourceValue instanceof Map) { - createInferenceListForMapTypeInput(sourceValue, texts); - } else { - texts.add(sourceValue.toString()); - } - }); - return texts; - } - - @SuppressWarnings("unchecked") - private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { - if (sourceValue instanceof Map) { - ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); - } else if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); - } else { - if (sourceValue == null) return; - texts.add(sourceValue.toString()); - } - } - - 
@VisibleForTesting - Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { - Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - Map mapWithKnnKeys = new LinkedHashMap<>(); - for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { - String originalKey = fieldMapEntry.getKey(); - Object targetKey = fieldMapEntry.getValue(); - if (targetKey instanceof Map) { - // Map treeRes = new LinkedHashMap<>(); - // buildMapWithKnnKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); - // mapWithKnnKeys.put(originalKey, treeRes.get(originalKey)); - Map knnMap = Map.of( - "value", - sourceAndMetadataMap.get(originalKey).toString(), - "model_input", - ((Map) targetKey).get("model_input").toString() - ); - mapWithKnnKeys.put(originalKey, knnMap); - } else { - mapWithKnnKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); - } - } - return mapWithKnnKeys; - } - - @SuppressWarnings({ "unchecked" }) - private void buildMapWithKnnKeyAndOriginalValueForMapType( - String parentKey, - Object knnKey, - Map sourceAndMetadataMap, - Map treeRes - ) { - if (knnKey == null || sourceAndMetadataMap == null) return; - if (knnKey instanceof Map) { - Map next = new LinkedHashMap<>(); - for (Map.Entry nestedFieldMapEntry : ((Map) knnKey).entrySet()) { - buildMapWithKnnKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next - ); - } - treeRes.put(parentKey, next); - } else { - String key = String.valueOf(knnKey); - treeRes.put(key, sourceAndMetadataMap.get(parentKey)); - } - } - - @SuppressWarnings({ "unchecked" }) - @VisibleForTesting - Map buildTextEmbeddingResult( - Map knnMap, - List> modelTensorList, - Map sourceAndMetadataMap - ) { - IndexWrapper indexWrapper = new IndexWrapper(0); - Map result = new LinkedHashMap<>(); - for (Map.Entry knnMapEntry : knnMap.entrySet()) { - String knnKey = knnMapEntry.getKey(); - Object sourceValue = knnMapEntry.getValue(); - List modelTensor = modelTensorList.get(indexWrapper.index++); - result.put(knnKey, modelTensor); - } - return result; - } - - @SuppressWarnings({ "unchecked" }) - private void putTextEmbeddingResultToSourceMapForMapType( - String knnKey, - Object sourceValue, - List> modelTensorList, - IndexWrapper indexWrapper, - Map sourceAndMetadataMap - ) { - if (knnKey == null || sourceAndMetadataMap == null || sourceValue == null) return; - if (sourceValue instanceof Map) { - for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { - putTextEmbeddingResultToSourceMapForMapType( - inputNestedMapEntry.getKey(), - inputNestedMapEntry.getValue(), - modelTensorList, - indexWrapper, - (Map) sourceAndMetadataMap.get(knnKey) - ); - } - } else if (sourceValue instanceof String) { - sourceAndMetadataMap.put(knnKey, modelTensorList.get(indexWrapper.index++)); - } else if (sourceValue instanceof List) { - sourceAndMetadataMap.put( - knnKey, - buildTextEmbeddingResultForListType((List) sourceValue, modelTensorList, indexWrapper) - ); - } - } - - private List>> buildTextEmbeddingResultForListType( - List sourceValue, - List> modelTensorList, - IndexWrapper indexWrapper - ) { - List>> numbers = new ArrayList<>(); - IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> numbers.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, modelTensorList.get(indexWrapper.index++)))); - return numbers; - } - - private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { - Map sourceAndMetadataMap = 
ingestDocument.getSourceAndMetadata(); - for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { - Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); - if (sourceValue != null) { - String sourceKey = embeddingFieldsEntry.getKey(); - Class sourceValueClass = sourceValue.getClass(); - if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(sourceKey, sourceValue, () -> 1); - } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); - } - } - } - } - - @SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { - int maxDepth = maxDepthSupplier.get(); - if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); - } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { - validateListTypeValue(sourceKey, sourceValue); - } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { - ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); - } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); - } - } - - @SuppressWarnings({ "rawtypes" }) - private static void validateListTypeValue(String sourceKey, Object sourceValue) { - for (Object value : (List) sourceValue) { - if (value == null) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); - } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); - } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); - } - } - } - - @Override - public String getType() { - return TYPE; - } - - /** - * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List}, - * we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order - * traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the - * order. And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase - * the index pointer during the recursive. - * index: the index pointer of the text embedding result. 
- */ - static class IndexWrapper { - private int index; - - protected IndexWrapper(int index) { - this.index = index; - } - } - -} \ No newline at end of file diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java new file mode 100644 index 000000000..29664dcf2 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -0,0 +1,227 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; + +import com.google.common.annotations.VisibleForTesting; + +/** + * This processor is used for user input data text and image embedding processing, model_id can be used to indicate which model user use, + * and field_map can be used to indicate which fields needs embedding and the corresponding keys for the embedding results. + */ +@Log4j2 +public class TextImageEmbeddingProcessor extends AbstractProcessor { + + public static final String TYPE = "text_image_embedding"; + public static final String MODEL_ID_FIELD = "model_id"; + public static final String EMBEDDING_FIELD = "embedding"; + public static final String FIELD_MAP_FIELD = "field_map"; + public static final String TEXT_FIELD_NAME = "text"; + public static final String IMAGE_FIELD_NAME = "image"; + public static final String INPUT_TEXT = "inputText"; + public static final String INPUT_IMAGE = "inputImage"; + + @VisibleForTesting + private final String modelId; + private final String embedding; + private final Map fieldMap; + + private final MLCommonsClientAccessor mlCommonsClientAccessor; + private final Environment environment; + // limit of 16Mb per field value. 
This is from current bedrock model, calculated as 2048*2048 pixels (24 bit), + // image to base64 encoding assumed to have 4/3 ratio, assuming UTF-8 encoding average of 1 byte per character + private static final int MAX_CONTENT_LENGTH_IN_BYTES = 16 * 1024 * 1024; + + public TextImageEmbeddingProcessor( + String tag, + String description, + String modelId, + String embedding, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment + ) { + super(tag, description); + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); + validateEmbeddingConfiguration(fieldMap); + + this.modelId = modelId; + this.embedding = embedding; + this.fieldMap = fieldMap; + this.mlCommonsClientAccessor = clientAccessor; + this.environment = environment; + } + + private void validateEmbeddingConfiguration(Map fieldMap) { + if (fieldMap == null + || fieldMap.isEmpty() + || fieldMap.entrySet() + .stream() + .anyMatch(x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue()))) { + throw new IllegalArgumentException("Unable to create the TextImageEmbedding processor as field_map has invalid key or value"); + } + + if (fieldMap.entrySet().stream().anyMatch(entry -> !Set.of(TEXT_FIELD_NAME, IMAGE_FIELD_NAME).contains(entry.getKey()))) { + throw new IllegalArgumentException("Unable to create the TextImageEmbedding processor as field_map has unsupported field name"); + } + } + + @Override + public IngestDocument execute(IngestDocument ingestDocument) { + return ingestDocument; + } + + /** + * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * process the inference response or failure. + * @param ingestDocument {@link IngestDocument} which is the document passed to processor. + * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. 
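+     * Note: any validation or inference failure is passed to the handler as the exception argument rather than being thrown from this method.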
+ */ + @Override + public void execute(final IngestDocument ingestDocument, final BiConsumer handler) { + try { + validateEmbeddingFieldsValue(ingestDocument); + Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); + Map inferenceMap = createInferences(knnMap); + if (inferenceMap.isEmpty()) { + handler.accept(ingestDocument, null); + } else { + mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceMap, ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, vectors); + handler.accept(ingestDocument, null); + }, e -> { handler.accept(null, e); })); + } + } catch (Exception e) { + handler.accept(null, e); + } + + } + + private void setVectorFieldsToDocument(IngestDocument ingestDocument, List vectors) { + Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); + log.debug("Text embedding result fetched, starting build vector output!"); + Map textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors); + textEmbeddingResult.forEach(ingestDocument::setFieldValue); + } + + @SuppressWarnings({ "unchecked" }) + private Map createInferences(Map knnKeyMap) { + Map texts = new HashMap<>(); + if (fieldMap.containsKey(TEXT_FIELD_NAME) && knnKeyMap.containsKey(fieldMap.get(TEXT_FIELD_NAME))) { + texts.put(INPUT_TEXT, knnKeyMap.get(fieldMap.get(TEXT_FIELD_NAME))); + } + if (fieldMap.containsKey(IMAGE_FIELD_NAME) && knnKeyMap.containsKey(fieldMap.get(IMAGE_FIELD_NAME))) { + texts.put(INPUT_IMAGE, knnKeyMap.get(fieldMap.get(IMAGE_FIELD_NAME))); + } + return texts; + } + + @VisibleForTesting + Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + Map mapWithKnnKeys = new LinkedHashMap<>(); + for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { + String originalKey = fieldMapEntry.getValue(); // field from ingest document that we need to sent as model input, part of + // processor definition + + if (!sourceAndMetadataMap.containsKey(originalKey)) { + continue; + } + if (!(sourceAndMetadataMap.get(originalKey) instanceof String)) { + throw new IllegalArgumentException("Unsupported format of the field in the document, value must be a string"); + } + if (((String) sourceAndMetadataMap.get(originalKey)).length() > MAX_CONTENT_LENGTH_IN_BYTES) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "content cannot be longer than a %d bytes", MAX_CONTENT_LENGTH_IN_BYTES) + ); + } + mapWithKnnKeys.put(originalKey, (String) sourceAndMetadataMap.get(originalKey)); + } + return mapWithKnnKeys; + } + + @SuppressWarnings({ "unchecked" }) + @VisibleForTesting + Map buildTextEmbeddingResult(String knnKey, List modelTensorList) { + Map result = new LinkedHashMap<>(); + result.put(knnKey, modelTensorList); + return result; + } + + private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + if (sourceValue != null) { + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither 
string nor nested type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); + } + } + } + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); + } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { + validateListTypeValue(sourceKey, sourceValue); + } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, can not process it"); + } + } + + @SuppressWarnings({ "rawtypes" }) + private static void validateListTypeValue(String sourceKey, Object sourceValue) { + for (Object value : (List) sourceValue) { + if (value == null) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); + } else if (!(value instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, can not process it"); + } else if (StringUtils.isBlank(value.toString())) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, can not process it"); + } + } + } + + @Override + public String getType() { + return TYPE; + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java deleted file mode 100644 index b2dd1fc28..000000000 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.processor.factory; - -import static org.opensearch.ingest.ConfigurationUtils.readMap; -import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; -import static org.opensearch.ingest.Processor.Factory; - -import java.util.Map; - -import org.opensearch.env.Environment; -import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.InferenceProcessor; - -public class InferenceProcessorFactory implements Factory { - - private final MLCommonsClientAccessor clientAccessor; - - private final Environment environment; - - public InferenceProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { - this.clientAccessor = clientAccessor; - this.environment = environment; - } - - @Override - public InferenceProcessor create(Map registry, String processorTag, String description, Map config) - throws Exception { - String modelId = readStringProperty(InferenceProcessor.TYPE, processorTag, config, 
InferenceProcessor.MODEL_ID_FIELD); - Map filedMap = readMap(InferenceProcessor.TYPE, processorTag, config, InferenceProcessor.FIELD_MAP_FIELD); - return new InferenceProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment); - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java new file mode 100644 index 000000000..a7ae347e0 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java @@ -0,0 +1,48 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.opensearch.ingest.ConfigurationUtils.readMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.Factory; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.TYPE; + +import java.util.Map; + +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; + +/** + * Factory for text_image embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input. + */ +public class TextImageEmbeddingProcessorFactory implements Factory { + + private final MLCommonsClientAccessor clientAccessor; + + private final Environment environment; + + public TextImageEmbeddingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + this.clientAccessor = clientAccessor; + this.environment = environment; + } + + @Override + public TextImageEmbeddingProcessor create( + Map registry, + String processorTag, + String description, + Map config + ) throws Exception { + String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); + String embedding = readStringProperty(TYPE, processorTag, config, EMBEDDING_FIELD); + Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); + return new TextImageEmbeddingProcessor(processorTag, description, modelId, embedding, filedMap, clientAccessor, environment); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index 7b78be269..edb9aace0 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -7,8 +7,12 @@ import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.function.Supplier; import lombok.AccessLevel; @@ -19,6 +23,7 @@ import 
lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.lucene.search.Query; @@ -61,6 +66,9 @@ public class NeuralQueryBuilder extends AbstractQueryBuilder @VisibleForTesting static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text"); + @VisibleForTesting + static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image"); + @VisibleForTesting static final ParseField MODEL_ID_FIELD = new ParseField("model_id"); @@ -77,6 +85,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) { private String fieldName; private String queryText; + private String queryImage; private String modelId; private int k = DEFAULT_K; @VisibleForTesting @@ -177,7 +186,9 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx + "]" ); } - requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query"); + if (StringUtils.isBlank(neuralQueryBuilder.queryText()) && StringUtils.isBlank(neuralQueryBuilder.queryImage())) { + throw new IllegalArgumentException("Either query text or image text must be provided for neural query"); + } requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query"); if (!isClusterOnOrAfterMinReqVersionForDefaultModelIdSupport()) { requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query"); @@ -194,6 +205,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n } else if (token.isValue()) { if (QUERY_TEXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { neuralQueryBuilder.queryText(parser.text()); + } else if (QUERY_IMAGE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + neuralQueryBuilder.queryImage(parser.text()); } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { neuralQueryBuilder.modelId(parser.text()); } else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { @@ -237,13 +250,20 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { } SetOnce vectorSetOnce = new SetOnce<>(); + Map inferenceInput = new HashMap<>(); + if (StringUtils.isNotBlank(queryText())) { + inferenceInput.put(INPUT_TEXT, queryText()); + } + if (StringUtils.isNotBlank(queryImage())) { + inferenceInput.put(INPUT_IMAGE, queryImage()); + } queryRewriteContext.registerAsyncAction( - ((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> { + ((client, actionListener) -> ML_CLIENT.inferenceSentences(modelId(), inferenceInput, ActionListener.wrap(floatList -> { vectorSetOnce.set(vectorAsListToArray(floatList)); actionListener.onResponse(null); }, actionListener::onFailure))) ); - return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get, filter()); + return new NeuralQueryBuilder(fieldName(), queryText(), queryImage(), modelId(), k(), vectorSetOnce::get, filter()); } @Override diff --git a/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java b/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java index 2776a53e6..185934b07 100644 --- a/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java +++ b/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java @@ -6,6 +6,7 @@ package 
org.opensearch.neuralsearch.constants; import java.util.List; +import java.util.Map; import lombok.AccessLevel; import lombok.NoArgsConstructor; @@ -16,4 +17,5 @@ public class TestCommonConstants { public static final List TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); public static final Float[] PREDICT_VECTOR_ARRAY = new Float[] { 2.0f, 3.0f }; public static final List SENTENCES_LIST = List.of("TEXT"); + public static final Map SENTENCES_MAP = Map.of("inputText", "Text query", "inputImage", "base641234567890"); } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 1b10966dc..b972d474b 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -280,6 +280,37 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa Mockito.verify(resultListener).onFailure(illegalStateException); } + public void testInferenceMultimodal_whenValidInput_thenSuccess() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onResponse(vector); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { + final RuntimeException exception = new RuntimeException(); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(exception); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(exception); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java deleted file mode 100644 index 60d0dc143..000000000 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.processor; - -import static org.mockito.ArgumentMatchers.anyMap; -import static 
org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.isA; -import static org.mockito.Mockito.isNull; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.BiConsumer; - -import lombok.SneakyThrows; - -import org.junit.Before; -import org.mockito.InjectMocks; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchParseException; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; -import org.opensearch.env.Environment; -import org.opensearch.ingest.IngestDocument; -import org.opensearch.ingest.Processor; -import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.factory.InferenceProcessorFactory; -import org.opensearch.test.OpenSearchTestCase; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; - -public class InferenceProcessorTests extends OpenSearchTestCase { - - @Mock - private MLCommonsClientAccessor mlCommonsClientAccessor; - - @Mock - private Environment env; - - @InjectMocks - private InferenceProcessorFactory inferenceProcessorFactory; - private static final String PROCESSOR_TAG = "mockTag"; - private static final String DESCRIPTION = "mockDescription"; - - @Before - public void setup() { - MockitoAnnotations.openMocks(this); - Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); - when(env.settings()).thenReturn(settings); - } - - @SneakyThrows - private InferenceProcessor createInstance(List> vector) { - Map registry = new HashMap<>(); - Map config = new HashMap<>(); - config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put( - TextEmbeddingProcessor.FIELD_MAP_FIELD, - ImmutableMap.of( - "key1", - Map.of("model_input", "ModelInput1", "model_output", "ModelOutput1", "embedding", "key1Mapped"), - "key2", - Map.of("model_input", "ModelInput2", "model_output", "ModelOutput2", "embedding", "key2Mapped") - ) - ); - return inferenceProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); - } - - @SneakyThrows - public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() { - Map registry = new HashMap<>(); - Map config = new HashMap<>(); - config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - try { - inferenceProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); - } catch (OpenSearchParseException e) { - assertEquals("[field_map] required property is missing", e.getMessage()); - } - } - - public void testExecute_successful() { - Map sourceAndMetadata = new HashMap<>(); - sourceAndMetadata.put("key1", "value1"); - sourceAndMetadata.put("key2", "value2"); - IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - InferenceProcessor processor = createInstance(createMockVectorWithLength(2)); - - List> modelTensorList = createMockVectorResult(); - doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); - listener.onResponse(modelTensorList); - return null; - }).when(mlCommonsClientAccessor).inferenceMultimodal(anyString(), anyMap(), isA(ActionListener.class)); - - BiConsumer handler = mock(BiConsumer.class); 
- processor.execute(ingestDocument, handler); - verify(handler).accept(any(IngestDocument.class), isNull()); - } - - private List> createMockVectorResult() { - List> modelTensorList = new ArrayList<>(); - List number1 = ImmutableList.of(1.234f, 2.354f); - List number2 = ImmutableList.of(3.234f, 4.354f); - List number3 = ImmutableList.of(5.234f, 6.354f); - List number4 = ImmutableList.of(7.234f, 8.354f); - List number5 = ImmutableList.of(9.234f, 10.354f); - List number6 = ImmutableList.of(11.234f, 12.354f); - List number7 = ImmutableList.of(13.234f, 14.354f); - modelTensorList.add(number1); - modelTensorList.add(number2); - modelTensorList.add(number3); - modelTensorList.add(number4); - modelTensorList.add(number5); - modelTensorList.add(number6); - modelTensorList.add(number7); - return modelTensorList; - } - - private List> createMockVectorWithLength(int size) { - float suffix = .234f; - List> result = new ArrayList<>(); - for (int i = 0; i < size * 2;) { - List number = new ArrayList<>(); - number.add(i++ + suffix); - number.add(i++ + suffix); - result.add(number); - } - return result; - } -} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index 3cd71e5a1..86e75f736 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -94,7 +94,15 @@ public void testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); String modelId = getDeployedModelId(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_DOC_TEXT1, + "", + modelId, + 5, + null, + null + ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -129,7 +137,15 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu createSearchPipelineWithDefaultResultsPostProcessor(SEARCH_PIPELINE); String modelId = getDeployedModelId(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_DOC_TEXT1, + "", + modelId, + 5, + null, + null + ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -153,7 +169,15 @@ public void testResultProcessor_whenMultipleShardsAndQueryMatches_thenSuccessful String modelId = getDeployedModelId(); int totalExpectedDocQty = 6; - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 6, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_DOC_TEXT1, + "", + modelId, + 6, + null, + null + ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java 
b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index 03b77549a..4993df7fb 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -213,7 +213,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); - hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapDefaultNorm = search( @@ -236,7 +236,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); - hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapL2Norm = search( @@ -279,7 +279,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); - hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapDefaultNorm = search( @@ -302,7 +302,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); - hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapL2Norm = search( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 6f98e8d5e..aa133c44d 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -96,7 +96,9 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); - hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderArithmeticMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); 
hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapArithmeticMean = search( @@ -119,7 +121,9 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); - hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderHarmonicMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapHarmonicMean = search( @@ -142,7 +146,9 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); - hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderGeometricMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapGeometricMean = search( @@ -185,7 +191,9 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); - hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderArithmeticMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapArithmeticMean = search( @@ -208,7 +216,9 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); - hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderHarmonicMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapHarmonicMean = search( @@ -231,7 +241,9 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); - hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderGeometricMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapGeometricMean = search( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java new file mode 100644 index 000000000..6acd5901d --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -0,0 +1,272 @@ +/* + * Copyright OpenSearch Contributors + * 
SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.IMAGE_FIELD_NAME; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.TEXT_FIELD_NAME; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.Supplier; + +import lombok.SneakyThrows; + +import org.junit.Before; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class TextImageEmbeddingProcessorTests extends OpenSearchTestCase { + + @Mock + private MLCommonsClientAccessor mlCommonsClientAccessor; + + @Mock + private Environment env; + + @InjectMocks + private TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory; + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); + when(env.settings()).thenReturn(settings); + } + + @SneakyThrows + private TextImageEmbeddingProcessor createInstance(List> vector) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextImageEmbeddingProcessor.EMBEDDING_FIELD, "my_embedding_field"); + config.put( + TextImageEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "image_field") + ); + return textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + @SneakyThrows + public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + try { + textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } catch (OpenSearchParseException e) { + assertEquals("[embedding] required property is missing", e.getMessage()); + } + } + + @SneakyThrows + public void testTextEmbeddingProcessConstructor_whenEmptyModelId_throwIllegalArgumentException() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, ""); + 
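+        // blank model_id is expected to fail fast in the processor constructor with an IllegalArgumentException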
config.put(TextImageEmbeddingProcessor.EMBEDDING_FIELD, "my_embedding_field"); + config.put( + TextImageEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "image_field") + ); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("model_id is null or empty, can not process it", exception.getMessage()); + } + + public void testExecute_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("my_text_field", "value2"); + sourceAndMetadata.put("key3", "value3"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + @SneakyThrows + public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_text_field", "value1"); + sourceAndMetadata.put("another_text_field", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + Map registry = new HashMap<>(); + MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(accessor, env); + + Map config = new HashMap<>(); + config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextImageEmbeddingProcessor.EMBEDDING_FIELD, "my_embedding_field"); + config.put( + TextImageEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "image_field") + ); + TextImageEmbeddingProcessor processor = textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(RuntimeException.class)); + } + + public void testExecute_withListTypeInput_successful() { + List list1 = ImmutableList.of("test1", "test2", "test3"); + List list2 = ImmutableList.of("test4", "test5", "test6"); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_text_field", "value1"); + sourceAndMetadata.put("another_text_field", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(6)); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = 
mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { + Map ret = createMaxDepthLimitExceedMap(() -> 1); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "hello world"); + sourceAndMetadata.put("my_text_field", ret); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_MLClientAccessorThrowFail_handlerFailure() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_text_field", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("illegal argument")); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { + Map map1 = ImmutableMap.of("test1", "test2"); + Map map2 = ImmutableMap.of("test3", 209.3D); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", map1); + sourceAndMetadata.put("my_text_field", map2); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { + Map map1 = ImmutableMap.of("test1", "test2"); + Map map2 = ImmutableMap.of("test3", " "); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", map1); + sourceAndMetadata.put("my_text_field", map2); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + private List> createMockVectorResult() { + List> modelTensorList = new ArrayList<>(); + List number1 = ImmutableList.of(1.234f, 2.354f); + List number2 = ImmutableList.of(3.234f, 4.354f); + List number3 = ImmutableList.of(5.234f, 6.354f); + List number4 = ImmutableList.of(7.234f, 8.354f); + List number5 = ImmutableList.of(9.234f, 10.354f); + List number6 = ImmutableList.of(11.234f, 12.354f); + List number7 = ImmutableList.of(13.234f, 14.354f); + modelTensorList.add(number1); + modelTensorList.add(number2); + modelTensorList.add(number3); + modelTensorList.add(number4); + modelTensorList.add(number5); + modelTensorList.add(number6); + 
modelTensorList.add(number7); + return modelTensorList; + } + + private List> createMockVectorWithLength(int size) { + float suffix = .234f; + List> result = new ArrayList<>(); + for (int i = 0; i < size * 2;) { + List number = new ArrayList<>(); + number.add(i++ + suffix); + number.add(i++ + suffix); + result.add(number); + } + return result; + } + + private Map createMaxDepthLimitExceedMap(Supplier maxDepthSupplier) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > 21) { + return null; + } + Map innerMap = new HashMap<>(); + Map ret = createMaxDepthLimitExceedMap(() -> maxDepth + 1); + if (ret == null) return innerMap; + innerMap.put("hello", ret); + return innerMap; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java deleted file mode 100644 index 55aba5443..000000000 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.processor.factory; - -import static org.mockito.Mockito.mock; -import static org.opensearch.neuralsearch.processor.InferenceProcessor.FIELD_MAP_FIELD; -import static org.opensearch.neuralsearch.processor.InferenceProcessor.MODEL_ID_FIELD; - -import java.util.HashMap; -import java.util.Map; - -import lombok.SneakyThrows; - -import org.opensearch.env.Environment; -import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.InferenceProcessor; -import org.opensearch.test.OpenSearchTestCase; - -public class InferenceProcessorFactoryTests extends OpenSearchTestCase { - - private static final String NORMALIZATION_METHOD = "min_max"; - private static final String COMBINATION_METHOD = "arithmetic_mean"; - - @SneakyThrows - public void testNormalizationProcessor_whenNoParams_thenSuccessful() { - InferenceProcessorFactory inferenceProcessorFactory = new InferenceProcessorFactory( - mock(MLCommonsClientAccessor.class), - mock(Environment.class) - ); - - final Map processorFactories = new HashMap<>(); - String tag = "tag"; - String description = "description"; - boolean ignoreFailure = false; - Map config = new HashMap<>(); - config.put(MODEL_ID_FIELD, "1234567678"); - config.put( - FIELD_MAP_FIELD, - Map.of("passage_text", Map.of("model_input", "TextInput1", "model_output", "TextEmbdedding1", "embedding", "passage_embedding")) - ); - InferenceProcessor inferenceProcessor = inferenceProcessorFactory.create(processorFactories, tag, description, config); - assertNotNull(inferenceProcessor); - assertEquals("inference-processor", inferenceProcessor.getType()); - } -} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java new file mode 100644 index 000000000..768a67161 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; +import static 
org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.IMAGE_FIELD_NAME; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.TEXT_FIELD_NAME; + +import java.util.HashMap; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; +import org.opensearch.test.OpenSearchTestCase; + +public class TextImageEmbeddingProcessorFactoryTests extends OpenSearchTestCase { + + @SneakyThrows + public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() { + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(MLCommonsClientAccessor.class), + mock(Environment.class) + ); + + final Map processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, "1234567678"); + config.put(EMBEDDING_FIELD, "embedding_field"); + config.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "my_image_field")); + TextImageEmbeddingProcessor inferenceProcessor = textImageEmbeddingProcessorFactory.create( + processorFactories, + tag, + description, + config + ); + assertNotNull(inferenceProcessor); + assertEquals("text_image_embedding", inferenceProcessor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() { + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(MLCommonsClientAccessor.class), + mock(Environment.class) + ); + + final Map processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map configOnlyTextField = new HashMap<>(); + configOnlyTextField.put(MODEL_ID_FIELD, "1234567678"); + configOnlyTextField.put(EMBEDDING_FIELD, "embedding_field"); + configOnlyTextField.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field")); + TextImageEmbeddingProcessor processor = textImageEmbeddingProcessorFactory.create( + processorFactories, + tag, + description, + configOnlyTextField + ); + assertNotNull(processor); + assertEquals("text_image_embedding", processor.getType()); + + Map configOnlyImageField = new HashMap<>(); + configOnlyImageField.put(MODEL_ID_FIELD, "1234567678"); + configOnlyImageField.put(EMBEDDING_FIELD, "embedding_field"); + configOnlyImageField.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field")); + processor = textImageEmbeddingProcessorFactory.create(processorFactories, tag, description, configOnlyImageField); + assertNotNull(processor); + assertEquals("text_image_embedding", processor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenMixOfParamsOrEmptyParams_thenFail() { + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(MLCommonsClientAccessor.class), + mock(Environment.class) + ); + + final Map processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map configMixOfFields = new 
HashMap<>(); + configMixOfFields.put(MODEL_ID_FIELD, "1234567678"); + configMixOfFields.put(EMBEDDING_FIELD, "embedding_field"); + configMixOfFields.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field", "random_field_name", "random_field")); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> textImageEmbeddingProcessorFactory.create(processorFactories, tag, description, configMixOfFields) + ); + assertEquals(exception.getMessage(), "Unable to create the TextImageEmbedding processor as field_map has unsupported field name"); + Map configNoFields = new HashMap<>(); + configNoFields.put(MODEL_ID_FIELD, "1234567678"); + configNoFields.put(EMBEDDING_FIELD, "embedding_field"); + configNoFields.put(FIELD_MAP_FIELD, Map.of()); + exception = expectThrows( + IllegalArgumentException.class, + () -> textImageEmbeddingProcessorFactory.create(processorFactories, tag, description, configNoFields) + ); + assertEquals(exception.getMessage(), "Unable to create the TextImageEmbedding processor as field_map has invalid key or value"); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 681c1247d..968ed0233 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -6,6 +6,7 @@ package org.opensearch.neuralsearch.query; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; @@ -16,6 +17,7 @@ import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.NAME; +import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_IMAGE_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_TEXT_FIELD; import java.io.IOException; @@ -60,6 +62,7 @@ public class NeuralQueryBuilderTests extends OpenSearchTestCase { private static final String FIELD_NAME = "testField"; private static final String QUERY_TEXT = "Hello world!"; + private static final String IMAGE_TEXT = "base641234567890"; private static final String MODEL_ID = "mfgfgdsfgfdgsde"; private static final int K = 10; private static final float BOOST = 1.8f; @@ -74,6 +77,7 @@ public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() { { "VECTOR_FIELD": { "query_text": "string", + "query_image": "string", "model_id": "string", "k": int } @@ -117,6 +121,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { .startObject() .startObject(FIELD_NAME) .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + .field(QUERY_IMAGE_FIELD.getPreferredName(), IMAGE_TEXT) .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) .field(K_FIELD.getPreferredName(), K) .field(BOOST_FIELD.getPreferredName(), BOOST) @@ -130,6 +135,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(IMAGE_TEXT, neuralQueryBuilder.queryImage()); assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); assertEquals(K, 
neuralQueryBuilder.k()); assertEquals(BOOST, neuralQueryBuilder.boost(), 0.0); @@ -269,6 +275,33 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { expectThrows(IOException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); } + @SneakyThrows + public void testFromXContent_whenNoQueryField_thenFail() { + /* + { + "VECTOR_FIELD": { + "model_id": "string", + "model_id": "string", + "k": int, + "k": int + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IOException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); + } + @SneakyThrows public void testFromXContent_whenBuiltWithInvalidFilter_thenFail() { /* @@ -352,6 +385,7 @@ private void testStreams() { NeuralQueryBuilder original = new NeuralQueryBuilder(); original.fieldName(FIELD_NAME); original.queryText(QUERY_TEXT); + original.queryImage(IMAGE_TEXT); original.modelId(MODEL_ID); original.k(K); original.boost(BOOST); @@ -377,6 +411,7 @@ public void testHashAndEquals() { String fieldName2 = "field 2"; String queryText1 = "query text 1"; String queryText2 = "query text 2"; + String imageText1 = "query image 1"; String modelId1 = "model-1"; String modelId2 = "model-2"; float boost1 = 1.8f; @@ -391,6 +426,7 @@ public void testHashAndEquals() { NeuralQueryBuilder neuralQueryBuilder_baseline = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) + .queryImage(imageText1) .modelId(modelId1) .k(k1) .boost(boost1) @@ -527,7 +563,7 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { ActionListener> listener = invocation.getArgument(2); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentence(any(), any(), any()); + }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -554,6 +590,7 @@ public void testRewrite_whenVectorNull_thenReturnCopy() { Supplier nullSupplier = () -> null; NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) .queryText(QUERY_TEXT) + .queryImage(IMAGE_TEXT) .modelId(MODEL_ID) .k(K) .vectorSupplier(nullSupplier); @@ -564,6 +601,7 @@ public void testRewrite_whenVectorNull_thenReturnCopy() { public void testRewrite_whenVectorSupplierAndVectorSet_thenReturnKNNQueryBuilder() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) .queryText(QUERY_TEXT) + .queryImage(IMAGE_TEXT) .modelId(MODEL_ID) .k(K) .vectorSupplier(TEST_VECTOR_SUPPLIER); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index c55da9e7f..6a8e48465 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -32,6 +32,7 @@ public class NeuralQueryIT extends BaseNeuralSearchIT { private static final String TEST_NESTED_INDEX_NAME = "test-neural-nested-index"; private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index"; 
private static final String TEST_QUERY_TEXT = "Hello world"; + private static final String TEST_IMAGE_TEXT = "/9j/4AAQSkZJRgABAQAASABIAAD"; private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field"; @@ -80,6 +81,7 @@ public void testBasicQuery() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -115,6 +117,7 @@ public void testBoostQuery() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -159,6 +162,7 @@ public void testRescoreQuery() { NeuralQueryBuilder rescoreNeuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -207,6 +211,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -215,6 +220,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_2, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -263,6 +269,7 @@ public void testBooleanQuery_withNeuralAndBM25Queries() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -307,6 +314,7 @@ public void testNestedQuery() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_NESTED, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -349,6 +357,7 @@ public void testFilterQuery() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -362,6 +371,42 @@ public void testFilterQuery() { assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } + /** + * Tests basic query for multimodal: + * { + * "query": { + * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "query_image": "base64_1234567890", + * "model_id": "dcsdcasd", + * "k": 1 + * } + * } + * } + * } + */ + @SneakyThrows + public void testMultimodalQuery() { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + String modelId = getDeployedModelId(); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_QUERY_TEXT, + TEST_IMAGE_TEXT, + modelId, + 1, + null, + null + ); + Map searchResponseAsMap = search(TEST_BASIC_INDEX_NAME, neuralQueryBuilder, 1); + Map firstInnerHit = getFirstInnerHit(searchResponseAsMap); + + assertEquals("1", firstInnerHit.get("_id")); + float expectedScore = computeExpectedScore(modelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); + } + private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { prepareKnnIndex( From 24e243dcd44184bc4623b7745564f677eb981c02 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Fri, 29 Sep 2023 14:20:44 -0700 Subject: [PATCH 03/14] Adding integ test for new processor Signed-off-by: Martin Gaievski --- .../common/BaseNeuralSearchIT.java | 34 +++++- .../processor/SparseEncodingProcessIT.java | 8 +- 
.../TextImageEmbeddingProcessorIT.java | 110 ++++++++++++++++++ ...tImageEmbeddingProcessorConfiguration.json | 15 +++ 4 files changed, 154 insertions(+), 13 deletions(-) create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java create mode 100644 src/test/resources/processor/PipelineForTextImageEmbeddingProcessorConfiguration.json diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index f5cda0535..159e0aa53 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -73,17 +73,21 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { protected static final String DEFAULT_COMBINATION_METHOD = "arithmetic_mean"; protected static final String PARAM_NAME_WEIGHTS = "weights"; - protected String PIPELINE_CONFIGURATION_NAME = "processor/PipelineConfiguration.json"; + protected static final Map PIPELINE_CONFIGS_BY_TYPE = Map.of( + ProcessorType.TEXT_EMBEDDING, + "processor/PipelineConfiguration.json", + ProcessorType.SPARSE_ENCODING, + "processor/SparseEncodingPipelineConfiguration.json", + ProcessorType.TEXT_IMAGE_EMBEDDING, + "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json" + ); + // protected String PIPELINE_CONFIGURATION_NAME = "processor/PipelineConfiguration.json"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); protected ThreadPool threadPool; protected ClusterService clusterService; - protected void setPipelineConfigurationName(String pipelineConfigurationName) { - this.PIPELINE_CONFIGURATION_NAME = pipelineConfigurationName; - } - @Before public void setupSettings() { threadPool = setUpThreadPool(); @@ -263,13 +267,21 @@ protected void createIndexWithConfiguration(String indexName, String indexConfig } protected void createPipelineProcessor(String modelId, String pipelineName) throws Exception { + createPipelineProcessor(modelId, pipelineName, ProcessorType.TEXT_EMBEDDING); + } + + protected void createPipelineProcessor(String modelId, String pipelineName, ProcessorType processorType) throws Exception { Response pipelineCreateResponse = makeRequest( client(), "PUT", "/_ingest/pipeline/" + pipelineName, null, toHttpEntity( - String.format(LOCALE, Files.readString(Path.of(classLoader.getResource(PIPELINE_CONFIGURATION_NAME).toURI())), modelId) + String.format( + LOCALE, + Files.readString(Path.of(classLoader.getResource(PIPELINE_CONFIGS_BY_TYPE.get(processorType)).toURI())), + modelId + ) ), ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) ); @@ -760,4 +772,14 @@ protected String getDeployedModelId() { assertEquals(1, modelIds.size()); return modelIds.iterator().next(); } + + /** + * Enumeration for types of pipeline processors, used to lookup resources like create + * processor request as those are type specific + */ + protected enum ProcessorType { + TEXT_EMBEDDING, + TEXT_IMAGE_EMBEDDING, + SPARSE_ENCODING + } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java index 51bdf3acc..da2373a3a 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessIT.java @@ -15,7 +15,6 @@ import 
org.apache.http.message.BasicHeader; import org.apache.http.util.EntityUtils; import org.junit.After; -import org.junit.Before; import org.opensearch.client.Response; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.common.xcontent.XContentType; @@ -40,14 +39,9 @@ public void tearDown() { findDeployedModels().forEach(this::deleteModel); } - @Before - public void setPipelineName() { - this.setPipelineConfigurationName("processor/SparseEncodingPipelineConfiguration.json"); - } - public void testSparseEncodingProcessor() throws Exception { String modelId = prepareModel(); - createPipelineProcessor(modelId, PIPELINE_NAME); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.SPARSE_ENCODING); createSparseEncodingIndex(); ingestDocument(); assertEquals(1, getDocCount(INDEX_NAME)); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java new file mode 100644 index 000000000..59d964b9c --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorIT.java @@ -0,0 +1,110 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.apache.http.HttpHeaders; +import org.apache.http.message.BasicHeader; +import org.apache.http.util.EntityUtils; +import org.junit.After; +import org.opensearch.client.Response; +import org.opensearch.common.xcontent.XContentHelper; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.neuralsearch.common.BaseNeuralSearchIT; +
import com.google.common.collect.ImmutableList; + +/** + * Testing text_and_image_embedding ingest processor. We can only test text in integ tests, as none of the pre-built models + * supports both text and image. 
+ */ +public class TextImageEmbeddingProcessorIT extends BaseNeuralSearchIT { + + private static final String INDEX_NAME = "text_image_embedding_index"; + private static final String PIPELINE_NAME = "ingest-pipeline"; + + @After + @SneakyThrows + public void tearDown() { + super.tearDown(); + findDeployedModels().forEach(this::deleteModel); + } + + public void testEmbeddingProcessor_whenIngestingDocumentWithSourceMatchingTextMapping_thenSuccessful() throws Exception { + String modelId = uploadModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING); + createTextImageEmbeddingIndex(); + ingestDocumentWithTextMappedToEmbeddingField(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + public void testEmbeddingProcessor_whenIngestingDocumentWithSourceWithoutMatchingInMapping_thenSuccessful() throws Exception { + String modelId = uploadModel(); + loadModel(modelId); + createPipelineProcessor(modelId, PIPELINE_NAME, ProcessorType.TEXT_IMAGE_EMBEDDING); + createTextImageEmbeddingIndex(); + ingestDocumentWithoutMappedFields(); + assertEquals(1, getDocCount(INDEX_NAME)); + } + + private String uploadModel() throws Exception { + String requestBody = Files.readString(Path.of(classLoader.getResource("processor/UploadModelRequestBody.json").toURI())); + return uploadModel(requestBody); + } + + private void createTextImageEmbeddingIndex() throws Exception { + createIndexWithConfiguration( + INDEX_NAME, + Files.readString(Path.of(classLoader.getResource("processor/IndexMappings.json").toURI())), + PIPELINE_NAME + ); + } + + private void ingestDocumentWithTextMappedToEmbeddingField() throws Exception { + String ingestDocumentBody = "{\n" + + " \"title\": \"This is a good day\",\n" + + " \"description\": \"daily logging\",\n" + + " \"passage_text\": \"A very nice day today\",\n" + + " \"favorites\": {\n" + + " \"game\": \"overwatch\",\n" + + " \"movie\": null\n" + + " }\n" + + "}\n"; + ingestDocument(ingestDocumentBody); + } + + private void ingestDocumentWithoutMappedFields() throws Exception { + String ingestDocumentBody = "{\n" + + " \"title\": \"This is a good day\",\n" + + " \"description\": \"daily logging\",\n" + + " \"some_random_field\": \"Today is a sunny weather\"\n" + + "}\n"; + ingestDocument(ingestDocumentBody); + } + + private void ingestDocument(final String ingestDocument) throws Exception { + Response response = makeRequest( + client(), + "POST", + INDEX_NAME + "/_doc?refresh", + null, + toHttpEntity(ingestDocument), + ImmutableList.of(new BasicHeader(HttpHeaders.USER_AGENT, "Kibana")) + ); + Map map = XContentHelper.convertToMap( + XContentType.JSON.xContent(), + EntityUtils.toString(response.getEntity()), + false + ); + assertEquals("created", map.get("result")); + } +} diff --git a/src/test/resources/processor/PipelineForTextImageEmbeddingProcessorConfiguration.json b/src/test/resources/processor/PipelineForTextImageEmbeddingProcessorConfiguration.json new file mode 100644 index 000000000..60d5dc051 --- /dev/null +++ b/src/test/resources/processor/PipelineForTextImageEmbeddingProcessorConfiguration.json @@ -0,0 +1,15 @@ +{ + "description": "text image embedding pipeline", + "processors": [ + { + "text_image_embedding": { + "model_id": "%s", + "embedding": "passage_embedding", + "field_map": { + "text": "passage_text", + "image": "passage_image" + } + } + } + ] +} From e0e27d2b58aa732c79c53b017e99eecda3aef74e Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 2 Oct 2023 09:35:27 -0700 Subject: [PATCH 04/14] Remove max 
field value length limit Signed-off-by: Martin Gaievski --- .../processor/TextImageEmbeddingProcessor.java | 9 --------- 1 file changed, 9 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index 29664dcf2..c5ff97822 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -8,7 +8,6 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; -import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -50,9 +49,6 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor { private final MLCommonsClientAccessor mlCommonsClientAccessor; private final Environment environment; - // limit of 16Mb per field value. This is from current bedrock model, calculated as 2048*2048 pixels (24 bit), - // image to base64 encoding assumed to have 4/3 ratio, assuming UTF-8 encoding average of 1 byte per character - private static final int MAX_CONTENT_LENGTH_IN_BYTES = 16 * 1024 * 1024; public TextImageEmbeddingProcessor( String tag, @@ -152,11 +148,6 @@ Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocu if (!(sourceAndMetadataMap.get(originalKey) instanceof String)) { throw new IllegalArgumentException("Unsupported format of the field in the document, value must be a string"); } - if (((String) sourceAndMetadataMap.get(originalKey)).length() > MAX_CONTENT_LENGTH_IN_BYTES) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "content cannot be longer than a %d bytes", MAX_CONTENT_LENGTH_IN_BYTES) - ); - } mapWithKnnKeys.put(originalKey, (String) sourceAndMetadataMap.get(originalKey)); } return mapWithKnnKeys; From dba368fea2c52d54a346bd94ecf1884a0da063e0 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 2 Oct 2023 10:09:29 -0700 Subject: [PATCH 05/14] Removed redundant non-null for argument, renamed some private methods Signed-off-by: Martin Gaievski --- .../neuralsearch/ml/MLCommonsClientAccessor.java | 16 +++++----------- 1 file changed, 5 insertions(+), 11 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 2117b220b..98c32b189 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -131,7 +131,7 @@ public void inferenceSentences( @NonNull final Map inputObjects, @NonNull final ActionListener> listener ) { - inferenceSentencesWithRetry(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); + retryableInferenceSentencesWithSingleVectorResult(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); } private void retryableInferenceSentencesWithMapResult( @@ -140,7 +140,7 @@ private void retryableInferenceSentencesWithMapResult( final int retryTime, final ActionListener>> listener ) { - MLInput mlInput = createMLInput(null, inputText); + MLInput mlInput = createMLTextInput(null, inputText); mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { final List> result = buildMapResultFromResponse(mlOutput); listener.onResponse(result); @@ -181,12 +181,6 @@ private MLInput createMLTextInput(final List targetResponseFilters, List return new 
MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); } - private MLInput createMLInput(final List targetResponseFilters, List inputText) { - final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); - final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); - return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); - } - private List> buildVectorFromResponse(MLOutput mlOutput) { final List> vector = new ArrayList<>(); final ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlOutput; @@ -223,8 +217,8 @@ private List buildSingleVectorFromResponse(MLOutput mlOutput) { return vector.isEmpty() ? new ArrayList<>() : vector.get(0); } - private void inferenceSentencesWithRetry( - @NonNull final List targetResponseFilters, + private void retryableInferenceSentencesWithSingleVectorResult( + final List targetResponseFilters, final String modelId, final Map inputObjects, final int retryTime, @@ -238,7 +232,7 @@ private void inferenceSentencesWithRetry( }, e -> { if (RetryUtil.shouldRetry(e, retryTime)) { final int retryTimeAdd = retryTime + 1; - inferenceSentencesWithRetry(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); + retryableInferenceSentencesWithSingleVectorResult(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); } else { listener.onFailure(e); } From 79a2e4b99554437f9012dae3ef1d49ab7e531af6 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 2 Oct 2023 13:20:24 -0700 Subject: [PATCH 06/14] Adding tests to improve test coverage Signed-off-by: Martin Gaievski --- .../ml/MLCommonsClientAccessorTests.java | 17 +++++++++++++ .../TextImageEmbeddingProcessorTests.java | 25 +++++++++++++------ 2 files changed, 35 insertions(+), 7 deletions(-) diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index b972d474b..ce2773f2f 100644 --- a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -311,6 +311,23 @@ public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { Mockito.verifyNoMoreInteractions(singleSentenceResultListener); } + public void testInferenceSentencesMultimodal_whenNodeNotConnectedException_thenRetryThreeTimes() { + final NodeNotConnectedException nodeNodeConnectedException = new NodeNotConnectedException( + mock(DiscoveryNode.class), + "Node not connected" + ); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(nodeNodeConnectedException); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client, times(4)) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onFailure(nodeNodeConnectedException); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); diff --git 
a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java index 6acd5901d..97597691d 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -65,7 +65,7 @@ public void setup() { } @SneakyThrows - private TextImageEmbeddingProcessor createInstance(List> vector) { + private TextImageEmbeddingProcessor createInstance() { Map registry = new HashMap<>(); Map config = new HashMap<>(); config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); @@ -112,7 +112,7 @@ public void testExecute_successful() { sourceAndMetadata.put("my_text_field", "value2"); sourceAndMetadata.put("key3", "value3"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextImageEmbeddingProcessor processor = createInstance(); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -157,7 +157,7 @@ public void testExecute_withListTypeInput_successful() { sourceAndMetadata.put("my_text_field", "value1"); sourceAndMetadata.put("another_text_field", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(6)); + TextImageEmbeddingProcessor processor = createInstance(); List> modelTensorList = createMockVectorResult(); doAnswer(invocation -> { @@ -177,7 +177,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { sourceAndMetadata.put("key1", "hello world"); sourceAndMetadata.put("my_text_field", ret); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextImageEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -188,7 +188,7 @@ public void testExecute_MLClientAccessorThrowFail_handlerFailure() { sourceAndMetadata.put("my_text_field", "value1"); sourceAndMetadata.put("key2", "value2"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextImageEmbeddingProcessor processor = createInstance(); doAnswer(invocation -> { ActionListener>> listener = invocation.getArgument(2); @@ -208,7 +208,7 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() { sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("my_text_field", map2); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextImageEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); @@ -221,12 +221,23 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() { sourceAndMetadata.put("key1", map1); sourceAndMetadata.put("my_text_field", map2); IngestDocument 
ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + TextImageEmbeddingProcessor processor = createInstance(); BiConsumer handler = mock(BiConsumer.class); processor.execute(ingestDocument, handler); verify(handler).accept(isNull(), any(IllegalArgumentException.class)); } + public void testExecute_hybridTypeInput_successful() throws Exception { + List list1 = ImmutableList.of("test1", "test2"); + Map> map1 = ImmutableMap.of("test3", list1); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key2", map1); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(); + IngestDocument document = processor.execute(ingestDocument); + assert document.getSourceAndMetadata().containsKey("key2"); + } + private List> createMockVectorResult() { List> modelTensorList = new ArrayList<>(); List number1 = ImmutableList.of(1.234f, 2.354f); From 62a7b1035d338070909df1ed2500175c586e872b Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Mon, 2 Oct 2023 17:57:04 -0700 Subject: [PATCH 07/14] Address code review comments Signed-off-by: Martin Gaievski --- .../ml/MLCommonsClientAccessor.java | 7 +- .../TextImageEmbeddingProcessor.java | 73 ++++++++++--------- .../TextEmbeddingProcessorFactory.java | 10 +-- ...xtImageEmbeddingProcessorFactoryTests.java | 13 +++- 4 files changed, 60 insertions(+), 43 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 98c32b189..5bb54cc2c 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -119,8 +119,7 @@ public void inferenceSentencesWithMapResult( /** * Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of - * inputText. + * using the actionListener which will have a list of floats in the order of inputText. * * @param modelId {@link String} * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to happen @@ -212,7 +211,7 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return resultMaps; } - private List buildSingleVectorFromResponse(MLOutput mlOutput) { + private List buildSingleVectorFromResponse(final MLOutput mlOutput) { final List> vector = buildVectorFromResponse(mlOutput); return vector.isEmpty() ? 
new ArrayList<>() : vector.get(0); } @@ -239,7 +238,7 @@ private void retryableInferenceSentencesWithSingleVectorResult( })); } - private MLInput createMLMultimodalInput(final List targetResponseFilters, Map input) { + private MLInput createMLMultimodalInput(final List targetResponseFilters, final Map input) { List inputText = new ArrayList<>(); inputText.add(input.get(INPUT_TEXT)); if (input.containsKey(INPUT_IMAGE)) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index c5ff97822..f77d72157 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -8,6 +8,7 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -41,8 +42,8 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor { public static final String IMAGE_FIELD_NAME = "image"; public static final String INPUT_TEXT = "inputText"; public static final String INPUT_IMAGE = "inputImage"; + private static final Set VALID_FIELD_NAMES = Set.of(TEXT_FIELD_NAME, IMAGE_FIELD_NAME); - @VisibleForTesting private final String modelId; private final String embedding; private final Map fieldMap; @@ -51,13 +52,13 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor { private final Environment environment; public TextImageEmbeddingProcessor( - String tag, - String description, - String modelId, - String embedding, - Map fieldMap, - MLCommonsClientAccessor clientAccessor, - Environment environment + final String tag, + final String description, + final String modelId, + final String embedding, + final Map fieldMap, + final MLCommonsClientAccessor clientAccessor, + final Environment environment ) { super(tag, description); if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); @@ -70,17 +71,21 @@ public TextImageEmbeddingProcessor( this.environment = environment; } - private void validateEmbeddingConfiguration(Map fieldMap) { + private void validateEmbeddingConfiguration(final Map fieldMap) { if (fieldMap == null || fieldMap.isEmpty() - || fieldMap.entrySet() - .stream() - .anyMatch(x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue()))) { + || fieldMap.entrySet().stream().anyMatch(x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()))) { throw new IllegalArgumentException("Unable to create the TextImageEmbedding processor as field_map has invalid key or value"); } - if (fieldMap.entrySet().stream().anyMatch(entry -> !Set.of(TEXT_FIELD_NAME, IMAGE_FIELD_NAME).contains(entry.getKey()))) { - throw new IllegalArgumentException("Unable to create the TextImageEmbedding processor as field_map has unsupported field name"); + if (fieldMap.entrySet().stream().anyMatch(entry -> !VALID_FIELD_NAMES.contains(entry.getKey()))) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Unable to create the TextImageEmbedding processor with provided field name(s). 
Following names are supported [%s]", + String.join(",", VALID_FIELD_NAMES) + ) + ); } } @@ -115,7 +120,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer vectors) { + private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List vectors) { Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); log.debug("Text embedding result fetched, starting build vector output!"); Map textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors); @@ -123,7 +128,7 @@ private void setVectorFieldsToDocument(IngestDocument ingestDocument, List createInferences(Map knnKeyMap) { + private Map createInferences(final Map knnKeyMap) { Map texts = new HashMap<>(); if (fieldMap.containsKey(TEXT_FIELD_NAME) && knnKeyMap.containsKey(fieldMap.get(TEXT_FIELD_NAME))) { texts.put(INPUT_TEXT, knnKeyMap.get(fieldMap.get(TEXT_FIELD_NAME))); @@ -135,7 +140,7 @@ private Map createInferences(Map knnKeyMap) { } @VisibleForTesting - Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { + Map buildMapWithKnnKeyAndOriginalValue(final IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); Map mapWithKnnKeys = new LinkedHashMap<>(); for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { @@ -155,37 +160,39 @@ Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocu @SuppressWarnings({ "unchecked" }) @VisibleForTesting - Map buildTextEmbeddingResult(String knnKey, List modelTensorList) { + Map buildTextEmbeddingResult(final String knnKey, List modelTensorList) { Map result = new LinkedHashMap<>(); result.put(knnKey, modelTensorList); return result; } - private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { + private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); - if (sourceValue != null) { - String sourceKey = embeddingFieldsEntry.getKey(); - Class sourceValueClass = sourceValue.getClass(); - if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(sourceKey, sourceValue, () -> 1); - } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); - } + if (Objects.isNull(sourceValue)) { + continue; } + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); + } + } } @SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + 
private void validateNestedTypeValue(final String sourceKey, final Object sourceValue, final Supplier maxDepthSupplier) { int maxDepth = maxDepthSupplier.get(); if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { - validateListTypeValue(sourceKey, sourceValue); + validateListTypeValue(sourceKey, (List) sourceValue); } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { ((Map) sourceValue).values() .stream() @@ -199,8 +206,8 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl } @SuppressWarnings({ "rawtypes" }) - private static void validateListTypeValue(String sourceKey, Object sourceValue) { - for (Object value : (List) sourceValue) { + private static void validateListTypeValue(final String sourceKey, final List sourceValue) { + for (Object value : sourceValue) { if (value == null) { throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); } else if (!(value instanceof String)) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index 0c9a6fa2c..adf6f6d21 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -25,17 +25,17 @@ public class TextEmbeddingProcessorFactory implements Processor.Factory { private final Environment environment; - public TextEmbeddingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + public TextEmbeddingProcessorFactory(final MLCommonsClientAccessor clientAccessor, final Environment environment) { this.clientAccessor = clientAccessor; this.environment = environment; } @Override public TextEmbeddingProcessor create( - Map registry, - String processorTag, - String description, - Map config + final Map registry, + final String processorTag, + final String description, + final Map config ) throws Exception { String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java index 768a67161..39d1f14de 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java @@ -5,6 +5,8 @@ package org.opensearch.neuralsearch.processor.factory; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.mock; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; @@ -101,7 +103,16 @@ public void testNormalizationProcessor_whenMixOfParamsOrEmptyParams_thenFail() { IllegalArgumentException.class, () -> textImageEmbeddingProcessorFactory.create(processorFactories, tag, 
description, configMixOfFields) ); - assertEquals(exception.getMessage(), "Unable to create the TextImageEmbedding processor as field_map has unsupported field name"); + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + allOf( + containsString( + "Unable to create the TextImageEmbedding processor with provided field name(s). Following names are supported [" + ), + containsString("image"), + containsString("text") + ) + ); Map configNoFields = new HashMap<>(); configNoFields.put(MODEL_ID_FIELD, "1234567678"); configNoFields.put(EMBEDDING_FIELD, "embedding_field"); From e820584f59b87decab66a14dd4e8253281608c54 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 3 Oct 2023 08:22:33 -0700 Subject: [PATCH 08/14] Remove commented code from test Signed-off-by: Martin Gaievski --- .../org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java | 1 - 1 file changed, 1 deletion(-) diff --git a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java index 159e0aa53..ef79df153 100644 --- a/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java +++ b/src/test/java/org/opensearch/neuralsearch/common/BaseNeuralSearchIT.java @@ -81,7 +81,6 @@ public abstract class BaseNeuralSearchIT extends OpenSearchSecureRestTestCase { ProcessorType.TEXT_IMAGE_EMBEDDING, "processor/PipelineForTextImageEmbeddingProcessorConfiguration.json" ); - // protected String PIPELINE_CONFIGURATION_NAME = "processor/PipelineConfiguration.json"; protected final ClassLoader classLoader = this.getClass().getClassLoader(); From c1d6c0ac86c4371f3429e68f16f12394ea9ea899 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 3 Oct 2023 09:31:30 -0700 Subject: [PATCH 09/14] Fixed bug when mapped field name retrieved incorrectly for getting inferences Signed-off-by: Martin Gaievski --- .../TextImageEmbeddingProcessor.java | 10 +-- .../TextImageEmbeddingProcessorTests.java | 70 ++++++++++++++++++- 2 files changed, 73 insertions(+), 7 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index f77d72157..70ddc0d60 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -169,18 +169,18 @@ Map buildTextEmbeddingResult(final String knnKey, List mo private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { - Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + String mappedSourceKey = embeddingFieldsEntry.getValue(); + Object sourceValue = sourceAndMetadataMap.get(mappedSourceKey); if (Objects.isNull(sourceValue)) { continue; } - String sourceKey = embeddingFieldsEntry.getKey(); Class sourceValueClass = sourceValue.getClass(); if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1); } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); + throw new 
IllegalArgumentException("field [" + mappedSourceKey + "] is neither string nor nested type, can not process it"); } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); + throw new IllegalArgumentException("field [" + mappedSourceKey + "] has empty string value, can not process it"); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java index 97597691d..c0cab4422 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -89,6 +89,54 @@ public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalA } } + @SneakyThrows + public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_throwIllegalArgumentException() { + boolean ignoreFailure = false; + String modelId = "mockModelId"; + String embeddingField = "my_embedding_field"; + + // create with null type mapping + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> new TextImageEmbeddingProcessor(PROCESSOR_TAG, DESCRIPTION, modelId, embeddingField, null, mlCommonsClientAccessor, env) + ); + assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); + + // type mapping has empty key + exception = expectThrows( + IllegalArgumentException.class, + () -> new TextImageEmbeddingProcessor( + PROCESSOR_TAG, + DESCRIPTION, + modelId, + embeddingField, + Map.of("", "my_field"), + mlCommonsClientAccessor, + env + ) + ); + assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); + + // type mapping has empty value + // use vanila java syntax because it allows null values + Map typeMapping = new HashMap<>(); + typeMapping.put("my_field", null); + + exception = expectThrows( + IllegalArgumentException.class, + () -> new TextImageEmbeddingProcessor( + PROCESSOR_TAG, + DESCRIPTION, + modelId, + embeddingField, + typeMapping, + mlCommonsClientAccessor, + env + ) + ); + assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage()); + } + @SneakyThrows public void testTextEmbeddingProcessConstructor_whenEmptyModelId_throwIllegalArgumentException() { Map registry = new HashMap<>(); @@ -111,6 +159,7 @@ public void testExecute_successful() { sourceAndMetadata.put("key1", "value1"); sourceAndMetadata.put("my_text_field", "value2"); sourceAndMetadata.put("key3", "value3"); + sourceAndMetadata.put("image_field", "base64_of_image_1234567890"); IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); TextImageEmbeddingProcessor processor = createInstance(); @@ -151,8 +200,6 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep } public void testExecute_withListTypeInput_successful() { - List list1 = ImmutableList.of("test1", "test2", "test3"); - List list2 = ImmutableList.of("test4", "test5", "test6"); Map sourceAndMetadata = new HashMap<>(); sourceAndMetadata.put("my_text_field", "value1"); sourceAndMetadata.put("another_text_field", "value2"); @@ -238,6 +285,25 @@ public void testExecute_hybridTypeInput_successful() throws 
Exception { assert document.getSourceAndMetadata().containsKey("key2"); } + public void testExecute_whenInferencesAreEmpty_thenSuccessful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_field", "value1"); + sourceAndMetadata.put("another_text_field", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + private List> createMockVectorResult() { List> modelTensorList = new ArrayList<>(); List number1 = ImmutableList.of(1.234f, 2.354f); From f3fd17518c9ab190659f1355f9d01c877e8f0d3f Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 3 Oct 2023 09:52:09 -0700 Subject: [PATCH 10/14] Added unit tests Signed-off-by: Martin Gaievski --- .../query/NeuralQueryBuilderTests.java | 52 +++++++++++++++++++ 1 file changed, 52 insertions(+) diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index 968ed0233..5462acc76 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -51,6 +51,7 @@ import org.opensearch.index.query.MatchNoneQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryRewriteContext; +import org.opensearch.index.query.QueryShardContext; import org.opensearch.knn.index.query.KNNQueryBuilder; import org.opensearch.neuralsearch.common.VectorUtil; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -586,6 +587,42 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { assertArrayEquals(VectorUtil.vectorAsListToArray(expectedVector), queryBuilder.vectorSupplier().get(), 0.0f); } + @SneakyThrows + public void testRewrite_whenVectorSupplierNullAndQueryTextAndImageTextSet_thenSetVectorSupplier() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .queryImage(IMAGE_TEXT) + .modelId(MODEL_ID) + .k(K); + List expectedVector = Arrays.asList(1.0f, 2.0f, 3.0f, 4.0f, 5.0f); + MLCommonsClientAccessor mlCommonsClientAccessor = mock(MLCommonsClientAccessor.class); + doAnswer(invocation -> { + ActionListener> listener = invocation.getArgument(2); + listener.onResponse(expectedVector); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); + NeuralQueryBuilder.initialize(mlCommonsClientAccessor); + + final CountDownLatch inProgressLatch = new CountDownLatch(1); + QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); + doAnswer(invocation -> { + BiConsumer> biConsumer = invocation.getArgument(0); + biConsumer.accept( + null, + ActionListener.wrap( + response -> inProgressLatch.countDown(), + err -> fail("Failed to set vector supplier: " + err.getMessage()) + ) + ); + return null; + }).when(queryRewriteContext).registerAsyncAction(any()); + + NeuralQueryBuilder queryBuilder = 
(NeuralQueryBuilder) neuralQueryBuilder.doRewrite(queryRewriteContext); + assertNotNull(queryBuilder.vectorSupplier()); + assertTrue(inProgressLatch.await(5, TimeUnit.SECONDS)); + assertArrayEquals(VectorUtil.vectorAsListToArray(expectedVector), queryBuilder.vectorSupplier().get(), 0.0f); + } + public void testRewrite_whenVectorNull_thenReturnCopy() { Supplier nullSupplier = () -> null; NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) @@ -626,6 +663,21 @@ public void testRewrite_whenFilterSet_thenKNNQueryBuilderFilterSet() { assertEquals(neuralQueryBuilder.filter(), knnQueryBuilder.getFilter()); } + public void testQueryCreation_whenCreateQueryWithDoToQuery_thenFail() { + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) + .queryText(QUERY_TEXT) + .modelId(MODEL_ID) + .k(K) + .vectorSupplier(TEST_VECTOR_SUPPLIER) + .filter(TEST_FILTER); + QueryShardContext queryShardContext = mock(QueryShardContext.class); + UnsupportedOperationException exception = expectThrows( + UnsupportedOperationException.class, + () -> neuralQueryBuilder.doToQuery(queryShardContext) + ); + assertEquals("Query cannot be created by NeuralQueryBuilder directly", exception.getMessage()); + } + private void setUpClusterService(Version version) { ClusterService clusterService = NeuralSearchClusterTestUtils.mockClusterService(version); NeuralSearchClusterUtil.instance().initialize(clusterService); From 21f356be9bdc39047cdea8e318a09fba2d0e5f18 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Tue, 3 Oct 2023 10:49:31 -0700 Subject: [PATCH 11/14] Rebased on recent changes in 2.x Signed-off-by: Martin Gaievski --- .../processor/InferenceProcessor.java | 321 ++++++++++++++++++ .../query/NeuralQueryBuilderTests.java | 14 +- 2 files changed, 328 insertions(+), 7 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java new file mode 100644 index 000000000..19944c11b --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -0,0 +1,321 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.BiConsumer; +import java.util.function.Supplier; +import java.util.stream.IntStream; + +import lombok.extern.log4j.Log4j2; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.env.Environment; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.ingest.AbstractProcessor; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableMap; + +/** + * The abstract class for text processing use cases. Users provide a field name map and a model id. + * During ingestion, the processor will use the corresponding model to inference the input texts, + * and set the target fields according to the field name map. 
+ */ +@Log4j2 +public abstract class InferenceProcessor extends AbstractProcessor { + + public static final String MODEL_ID_FIELD = "model_id"; + public static final String FIELD_MAP_FIELD = "field_map"; + + private final String type; + + // This field is used for nested knn_vector/rank_features field. The value of the field will be used as the + // default key for the nested object. + private final String listTypeNestedMapKey; + + protected final String modelId; + + private final Map fieldMap; + + protected final MLCommonsClientAccessor mlCommonsClientAccessor; + + private final Environment environment; + + public InferenceProcessor( + String tag, + String description, + String type, + String listTypeNestedMapKey, + String modelId, + Map fieldMap, + MLCommonsClientAccessor clientAccessor, + Environment environment + ) { + super(tag, description); + this.type = type; + if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it"); + validateEmbeddingConfiguration(fieldMap); + + this.listTypeNestedMapKey = listTypeNestedMapKey; + this.modelId = modelId; + this.fieldMap = fieldMap; + this.mlCommonsClientAccessor = clientAccessor; + this.environment = environment; + } + + private void validateEmbeddingConfiguration(Map fieldMap) { + if (fieldMap == null + || fieldMap.size() == 0 + || fieldMap.entrySet() + .stream() + .anyMatch( + x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) + )) { + throw new IllegalArgumentException("Unable to create the processor as field_map has invalid key or value"); + } + } + + public abstract void doExecute( + IngestDocument ingestDocument, + Map ProcessMap, + List inferenceList, + BiConsumer handler + ); + + @Override + public IngestDocument execute(IngestDocument ingestDocument) throws Exception { + return ingestDocument; + } + + /** + * This method will be invoked by PipelineService to make async inference and then delegate the handler to + * process the inference response or failure. + * @param ingestDocument {@link IngestDocument} which is the document passed to processor. + * @param handler {@link BiConsumer} which is the handler which can be used after the inference task is done. 
+     */
+    @Override
+    public void execute(IngestDocument ingestDocument, BiConsumer handler) {
+        try {
+            validateEmbeddingFieldsValue(ingestDocument);
+            Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument);
+            List inferenceList = createInferenceList(ProcessMap);
+            if (inferenceList.size() == 0) {
+                handler.accept(ingestDocument, null);
+            } else {
+                doExecute(ingestDocument, ProcessMap, inferenceList, handler);
+            }
+        } catch (Exception e) {
+            handler.accept(null, e);
+        }
+    }
+
+    @SuppressWarnings({ "unchecked" })
+    private List createInferenceList(Map knnKeyMap) {
+        List texts = new ArrayList<>();
+        knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> {
+            Object sourceValue = knnMapEntry.getValue();
+            if (sourceValue instanceof List) {
+                texts.addAll(((List) sourceValue));
+            } else if (sourceValue instanceof Map) {
+                createInferenceListForMapTypeInput(sourceValue, texts);
+            } else {
+                texts.add(sourceValue.toString());
+            }
+        });
+        return texts;
+    }
+
+    @SuppressWarnings("unchecked")
+    private void createInferenceListForMapTypeInput(Object sourceValue, List texts) {
+        if (sourceValue instanceof Map) {
+            ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts));
+        } else if (sourceValue instanceof List) {
+            texts.addAll(((List) sourceValue));
+        } else {
+            if (sourceValue == null) return;
+            texts.add(sourceValue.toString());
+        }
+    }
+
+    @VisibleForTesting
+    Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) {
+        Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
+        Map mapWithProcessorKeys = new LinkedHashMap<>();
+        for (Map.Entry fieldMapEntry : fieldMap.entrySet()) {
+            String originalKey = fieldMapEntry.getKey();
+            Object targetKey = fieldMapEntry.getValue();
+            if (targetKey instanceof Map) {
+                Map treeRes = new LinkedHashMap<>();
+                buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes);
+                mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey));
+            } else {
+                mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey));
+            }
+        }
+        return mapWithProcessorKeys;
+    }
+
+    private void buildMapWithProcessorKeyAndOriginalValueForMapType(
+        String parentKey,
+        Object processorKey,
+        Map sourceAndMetadataMap,
+        Map treeRes
+    ) {
+        if (processorKey == null || sourceAndMetadataMap == null) return;
+        if (processorKey instanceof Map) {
+            Map next = new LinkedHashMap<>();
+            for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) {
+                buildMapWithProcessorKeyAndOriginalValueForMapType(
+                    nestedFieldMapEntry.getKey(),
+                    nestedFieldMapEntry.getValue(),
+                    (Map) sourceAndMetadataMap.get(parentKey),
+                    next
+                );
+            }
+            treeRes.put(parentKey, next);
+        } else {
+            String key = String.valueOf(processorKey);
+            treeRes.put(key, sourceAndMetadataMap.get(parentKey));
+        }
+    }
+
+    private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) {
+        Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
+        for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) {
+            Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey());
+            if (sourceValue != null) {
+                String sourceKey = embeddingFieldsEntry.getKey();
+                Class sourceValueClass = sourceValue.getClass();
+                if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
+                    validateNestedTypeValue(sourceKey, sourceValue, () -> 1);
+                } else if (!String.class.isAssignableFrom(sourceValueClass)) {
+                    throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it");
+                } else if (StringUtils.isBlank(sourceValue.toString())) {
+                    throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it");
+                }
+            }
+        }
+    }
+
+    @SuppressWarnings({ "rawtypes", "unchecked" })
+    private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) {
+        int maxDepth = maxDepthSupplier.get();
+        if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
+            throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it");
+        } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
+            validateListTypeValue(sourceKey, sourceValue);
+        } else if (Map.class.isAssignableFrom(sourceValue.getClass())) {
+            ((Map) sourceValue).values()
+                .stream()
+                .filter(Objects::nonNull)
+                .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
+        } else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
+            throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it");
+        } else if (StringUtils.isBlank(sourceValue.toString())) {
+            throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it");
+        }
+    }
+
+    @SuppressWarnings({ "rawtypes" })
+    private void validateListTypeValue(String sourceKey, Object sourceValue) {
+        for (Object value : (List) sourceValue) {
+            if (value == null) {
+                throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it");
+            } else if (!(value instanceof String)) {
+                throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it");
+            } else if (StringUtils.isBlank(value.toString())) {
+                throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it");
+            }
+        }
+    }
+
+    protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List results) {
+        Objects.requireNonNull(results, "embedding failed, inference returns null result!");
+        log.debug("Model inference result fetched, starting build vector output!");
+        Map nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata());
+        nlpResult.forEach(ingestDocument::setFieldValue);
+    }
+
+    @SuppressWarnings({ "unchecked" })
+    @VisibleForTesting
+    Map buildNLPResult(Map processorMap, List results, Map sourceAndMetadataMap) {
+        IndexWrapper indexWrapper = new IndexWrapper(0);
+        Map result = new LinkedHashMap<>();
+        for (Map.Entry knnMapEntry : processorMap.entrySet()) {
+            String knnKey = knnMapEntry.getKey();
+            Object sourceValue = knnMapEntry.getValue();
+            if (sourceValue instanceof String) {
+                result.put(knnKey, results.get(indexWrapper.index++));
+            } else if (sourceValue instanceof List) {
+                result.put(knnKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper));
+            } else if (sourceValue instanceof Map) {
+                putNLPResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap);
+            }
+        }
+        return result;
+    }
+
+    @SuppressWarnings({ "unchecked" })
+    private void putNLPResultToSourceMapForMapType(
+        String processorKey,
+        Object sourceValue,
+        List results,
+        IndexWrapper indexWrapper,
+        Map sourceAndMetadataMap
+    ) {
+        if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return;
+        if (sourceValue instanceof Map) {
+            for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) {
+                putNLPResultToSourceMapForMapType(
+                    inputNestedMapEntry.getKey(),
+                    inputNestedMapEntry.getValue(),
+                    results,
+                    indexWrapper,
+                    (Map) sourceAndMetadataMap.get(processorKey)
+                );
+            }
+        } else if (sourceValue instanceof String) {
+            sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++));
+        } else if (sourceValue instanceof List) {
+            sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper));
+        }
+    }
+
+    private List> buildNLPResultForListType(List sourceValue, List results, IndexWrapper indexWrapper) {
+        List> keyToResult = new ArrayList<>();
+        IntStream.range(0, sourceValue.size())
+            .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++))));
+        return keyToResult;
+    }
+
+    @Override
+    public String getType() {
+        return type;
+    }
+
+    /**
+     * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List},
+     * we need to map the results back to the inputs one by one in exactly the same order. For nested map type input we perform a
+     * pre-order traversal to extract the input strings, so when mapping back to the nested map we need the same pre-order traversal
+     * to preserve the order. We also need the index pointer to move forward during the recursion, so IndexWrapper stores and
+     * increments the index pointer into the text embedding results.
+     * index: the index pointer of the text embedding result.
+     */
+    static class IndexWrapper {
+        private int index;
+
+        protected IndexWrapper(int index) {
+            this.index = index;
+        }
+    }
+}
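The IndexWrapper javadoc above describes mapping a flat list of embedding results back onto nested fields by repeating the same pre-order traversal that produced the inference inputs, with a single index pointer shared across the recursion. As a standalone illustration (not part of the patch), the sketch below shows that bookkeeping; the Cursor class and the collect/mapBack helpers are hypothetical stand-ins, not the plugin's API.

    import java.util.ArrayList;
    import java.util.LinkedHashMap;
    import java.util.List;
    import java.util.Map;

    class PreOrderMappingSketch {
        // Single mutable cursor shared across the recursion, mirroring IndexWrapper.
        static final class Cursor {
            int index;
        }

        // Collect input strings in pre-order (the order the inference inputs are built).
        static void collect(Object value, List<String> inputs) {
            if (value instanceof Map) {
                ((Map<?, ?>) value).values().forEach(v -> collect(v, inputs));
            } else if (value instanceof String) {
                inputs.add((String) value);
            }
        }

        // Walk the same structure again in pre-order and consume one result per input string.
        static Object mapBack(Object value, List<List<Float>> results, Cursor cursor) {
            if (value instanceof Map) {
                Map<String, Object> mapped = new LinkedHashMap<>();
                ((Map<?, ?>) value).forEach((k, v) -> mapped.put(String.valueOf(k), mapBack(v, results, cursor)));
                return mapped;
            }
            return results.get(cursor.index++);
        }

        public static void main(String[] args) {
            Map<String, Object> nested = new LinkedHashMap<>();
            nested.put("body", "world");
            Map<String, Object> fields = new LinkedHashMap<>();
            fields.put("title", "hello");
            fields.put("nested", nested);

            List<String> inputs = new ArrayList<>();
            collect(fields, inputs);
            // Pretend each input produced a one-dimensional embedding equal to its length.
            List<List<Float>> results = new ArrayList<>();
            inputs.forEach(s -> results.add(List.of((float) s.length())));
            System.out.println(mapBack(fields, results, new Cursor()));
        }
    }

Running it prints the nested structure with each string replaced by its embedding, consuming the results in the same order they were produced, which is the invariant the IndexWrapper comment is describing.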
diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java
index 5462acc76..94665d879 100644
--- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java
@@ -665,15 +665,15 @@ public void testRewrite_whenFilterSet_thenKNNQueryBuilderFilterSet() {
     public void testQueryCreation_whenCreateQueryWithDoToQuery_thenFail() {
         NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME)
-            .queryText(QUERY_TEXT)
-            .modelId(MODEL_ID)
-            .k(K)
-            .vectorSupplier(TEST_VECTOR_SUPPLIER)
-            .filter(TEST_FILTER);
+            .queryText(QUERY_TEXT)
+            .modelId(MODEL_ID)
+            .k(K)
+            .vectorSupplier(TEST_VECTOR_SUPPLIER)
+            .filter(TEST_FILTER);
         QueryShardContext queryShardContext = mock(QueryShardContext.class);
         UnsupportedOperationException exception = expectThrows(
-            UnsupportedOperationException.class,
-            () -> neuralQueryBuilder.doToQuery(queryShardContext)
+            UnsupportedOperationException.class,
+            () -> neuralQueryBuilder.doToQuery(queryShardContext)
         );
         assertEquals("Query cannot be created by NeuralQueryBuilder directly", exception.getMessage());
     }

From 98341e7bd4b85d3f203ef02f52dcc11048640d33 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Tue, 3 Oct 2023 12:16:32 -0700
Subject: [PATCH 12/14] Fixed integ test, indexes must use lucene as knn engine

Signed-off-by: Martin Gaievski
---
 src/test/resources/processor/IndexMappings.json | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/src/test/resources/processor/IndexMappings.json b/src/test/resources/processor/IndexMappings.json
index 5464a9311..02de6c6af 100644
--- a/src/test/resources/processor/IndexMappings.json
+++ b/src/test/resources/processor/IndexMappings.json
@@ -70,7 +70,16 @@
       },
       "passage_embedding": {
         "type": "knn_vector",
-        "dimension": 768
+        "dimension": 768,
+        "method": {
+          "name": "hnsw",
+          "space_type": "l2",
+          "engine": "lucene",
+          "parameters": {
+            "ef_construction": 128,
+            "m": 24
+          }
+        }
       },
       "passage_text": {
         "type": "text"

From 69813cee9a855e6711c792c7298dab301e0be0fe Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Tue, 3 Oct 2023 12:38:12 -0700
Subject: [PATCH 13/14] Add final keyword to some method arguments

Signed-off-by: Martin Gaievski
---
 .../factory/TextImageEmbeddingProcessorFactory.java | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java
index a7ae347e0..df13c523b 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java
@@ -28,17 +28,17 @@ public class TextImageEmbeddingProcessorFactory implements Factory {
     private final Environment environment;
-    public TextImageEmbeddingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) {
+    public TextImageEmbeddingProcessorFactory(final MLCommonsClientAccessor clientAccessor, final Environment environment) {
         this.clientAccessor = clientAccessor;
         this.environment = environment;
     }
     @Override
     public TextImageEmbeddingProcessor create(
-        Map registry,
-        String processorTag,
-        String description,
-        Map config
+        final Map registry,
+        final String processorTag,
+        final String description,
+        final Map config
     ) throws Exception {
         String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD);
         String embedding = readStringProperty(TYPE, processorTag, config, EMBEDDING_FIELD);
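The create() method above reads MODEL_ID_FIELD, EMBEDDING_FIELD and FIELD_MAP_FIELD out of the processor's configuration map. As a rough sketch (not part of the patch) of the kind of configuration it consumes; the embedding key string, the field_map entries and the model id used below are assumptions for illustration only.

    import java.util.HashMap;
    import java.util.Map;

    class TextImageEmbeddingConfigSketch {
        static Map<String, Object> exampleConfig() {
            Map<String, Object> config = new HashMap<>();
            config.put("model_id", "my-multimodal-model-id");   // value read via MODEL_ID_FIELD; id is assumed
            config.put("embedding", "vector_embedding");        // target field read via EMBEDDING_FIELD; key string assumed
            // Mapping read via FIELD_MAP_FIELD; the text/image keys and source field names are assumed.
            config.put("field_map", Map.of("text", "my_text_field", "image", "image_field"));
            return config;
        }
    }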
From fd8b1fdcd1fdb7078b31652532e811c49def3648 Mon Sep 17 00:00:00 2001
From: Martin Gaievski
Date: Tue, 3 Oct 2023 15:32:11 -0700
Subject: [PATCH 14/14] Allow updates on dynamic index setting max_depth_limit

Signed-off-by: Martin Gaievski
---
 .../ml/MLCommonsClientAccessor.java | 2 +-
 .../neuralsearch/plugin/NeuralSearch.java | 2 +-
 .../TextImageEmbeddingProcessor.java | 23 +++++++---
 .../TextImageEmbeddingProcessorFactory.java | 22 +++++++---
 .../plugin/NeuralSearchTests.java | 13 +++++-
 .../TextImageEmbeddingProcessorTests.java | 44 ++++++++++++++++---
 ...xtImageEmbeddingProcessorFactoryTests.java | 10 +++--
 7 files changed, 93 insertions(+), 23 deletions(-)

diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
index 5bb54cc2c..1c09f5996 100644
--- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java
@@ -226,7 +226,7 @@ private void retryableInferenceSentencesWithSingleVectorResult(
         MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects);
         mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> {
             final List vector = buildSingleVectorFromResponse(mlOutput);
-            log.debug("Inference Response for input sentence {} is : {} ", inputObjects, vector);
+            log.debug("Inference Response for input sentence is : {} ", vector);
             listener.onResponse(vector);
         }, e -> {
             if (RetryUtil.shouldRetry(e, retryTime)) {
diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
index cf1a2f9bd..8672c6142 100644
--- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
+++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java
@@ -110,7 +110,7 @@ public Map getProcessors(Processor.Parameters paramet
             SparseEncodingProcessor.TYPE,
             new SparseEncodingProcessorFactory(clientAccessor, parameters.env),
             TextImageEmbeddingProcessor.TYPE,
-            new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env)
+            new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService())
         );
     }
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java
index 70ddc0d60..a0d9606e9 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java
@@ -18,8 +18,11 @@
 import lombok.extern.log4j.Log4j2;
 
 import org.apache.commons.lang3.StringUtils;
+import org.opensearch.cluster.service.ClusterService;
+import org.opensearch.common.settings.Settings;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.env.Environment;
+import org.opensearch.index.mapper.IndexFieldMapper;
 import org.opensearch.index.mapper.MapperService;
 import org.opensearch.ingest.AbstractProcessor;
 import org.opensearch.ingest.IngestDocument;
@@ -50,6 +53,7 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor {
 
     private final MLCommonsClientAccessor mlCommonsClientAccessor;
     private final Environment environment;
+    private final ClusterService clusterService;
 
     public TextImageEmbeddingProcessor(
         final String tag,
@@ -58,7 +62,8 @@ public TextImageEmbeddingProcessor(
         final String embedding,
         final Map fieldMap,
         final MLCommonsClientAccessor clientAccessor,
-        final Environment environment
+        final Environment environment,
+        final ClusterService clusterService
     ) {
         super(tag, description);
         if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it");
@@ -69,6 +74,7 @@ public TextImageEmbeddingProcessor(
         this.fieldMap = fieldMap;
         this.mlCommonsClientAccessor = clientAccessor;
         this.environment = environment;
+        this.clusterService = clusterService;
     }
 
     private void validateEmbeddingConfiguration(final Map fieldMap) {
@@ -176,7 +182,8 @@ private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) {
             }
             Class sourceValueClass = sourceValue.getClass();
             if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) {
-                validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1);
+                String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
+                validateNestedTypeValue(mappedSourceKey, sourceValue, () -> 1, indexName);
             } else if (!String.class.isAssignableFrom(sourceValueClass)) {
                 throw new IllegalArgumentException("field [" + mappedSourceKey + "] is neither string nor nested type, can not process it");
             } else if (StringUtils.isBlank(sourceValue.toString())) {
@@ -187,9 +194,15 @@ private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) {
     }
 
     @SuppressWarnings({ "rawtypes", "unchecked" })
-    private void validateNestedTypeValue(final String sourceKey, final Object sourceValue, final Supplier maxDepthSupplier) {
+    private void validateNestedTypeValue(
+        final String sourceKey,
+        final Object sourceValue,
+        final Supplier maxDepthSupplier,
+        final String indexName
+    ) {
         int maxDepth = maxDepthSupplier.get();
-        if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) {
+        Settings indexSettings = clusterService.state().metadata().index(indexName).getSettings();
+        if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings)) {
             throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it");
         } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) {
             validateListTypeValue(sourceKey, (List) sourceValue);
@@ -197,7 +210,7 @@ private void validateNestedTypeValue(final String sourceKey, final Object source
             ((Map) sourceValue).values()
                 .stream()
                 .filter(Objects::nonNull)
-                .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1));
+                .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1, indexName));
         } else if (!String.class.isAssignableFrom(sourceValue.getClass())) {
             throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, can not process it");
         } else if (StringUtils.isBlank(sourceValue.toString())) {
diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java
index df13c523b..c18ec6fb3 100644
--- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java
+++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java
@@ -15,6 +15,9 @@
 
 import java.util.Map;
 
+import lombok.AllArgsConstructor;
+
+import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.env.Environment;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
@@ -22,16 +25,12 @@
 /**
  * Factory for text_image embedding ingest processor for ingestion pipeline. Instantiates processor based on user provided input.
  */
+@AllArgsConstructor
 public class TextImageEmbeddingProcessorFactory implements Factory {
 
     private final MLCommonsClientAccessor clientAccessor;
-
     private final Environment environment;
-
-    public TextImageEmbeddingProcessorFactory(final MLCommonsClientAccessor clientAccessor, final Environment environment) {
-        this.clientAccessor = clientAccessor;
-        this.environment = environment;
-    }
+    private final ClusterService clusterService;
 
     @Override
     public TextImageEmbeddingProcessor create(
@@ -43,6 +42,15 @@ public TextImageEmbeddingProcessor create(
         String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD);
         String embedding = readStringProperty(TYPE, processorTag, config, EMBEDDING_FIELD);
         Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD);
-        return new TextImageEmbeddingProcessor(processorTag, description, modelId, embedding, filedMap, clientAccessor, environment);
+        return new TextImageEmbeddingProcessor(
+            processorTag,
+            description,
+            modelId,
+            embedding,
+            filedMap,
+            clientAccessor,
+            environment,
+            clusterService
+        );
     }
 }
diff --git a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
index 8cae15678..69791681e 100644
--- a/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/plugin/NeuralSearchTests.java
@@ -11,6 +11,7 @@
 import java.util.Map;
 import java.util.Optional;
 
+import org.opensearch.ingest.IngestService;
 import org.opensearch.ingest.Processor;
 import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor;
 import org.opensearch.neuralsearch.processor.NormalizationProcessor;
@@ -56,7 +57,17 @@ public void testQueryPhaseSearcher() {
 
     public void testProcessors() {
         NeuralSearch plugin = new NeuralSearch();
-        Processor.Parameters processorParams = mock(Processor.Parameters.class);
+        Processor.Parameters processorParams = new Processor.Parameters(
+            null,
+            null,
+            null,
+            null,
+            null,
+            null,
+            mock(IngestService.class),
+            null,
+            null
+        );
         Map processors = plugin.getProcessors(processorParams);
         assertNotNull(processors);
         assertNotNull(processors.get(TextEmbeddingProcessor.TYPE));
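The TextImageEmbeddingProcessor change above resolves index.mapping.depth.limit from the per-index settings held in cluster state instead of the node-level Environment settings, which is what allows a dynamic update of the setting to take effect without recreating the pipeline. Below is a minimal sketch of that lookup (not part of the patch), using only the OpenSearch calls that already appear in the diff; the helper class itself is hypothetical and omits null handling for a missing index.

    import org.opensearch.cluster.service.ClusterService;
    import org.opensearch.common.settings.Settings;
    import org.opensearch.index.mapper.MapperService;

    final class DepthLimitLookupSketch {
        private DepthLimitLookupSketch() {}

        // Read the per-index settings from the current cluster state, so a dynamic
        // "index.mapping.depth.limit" update is picked up on the next ingested document.
        static long currentDepthLimit(final ClusterService clusterService, final String indexName) {
            final Settings indexSettings = clusterService.state().metadata().index(indexName).getSettings();
            return MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(indexSettings);
        }
    }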
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java
index c0cab4422..bae336d4a 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java
@@ -32,9 +32,14 @@
 import org.mockito.Mock;
 import org.mockito.MockitoAnnotations;
 import org.opensearch.OpenSearchParseException;
+import org.opensearch.cluster.ClusterState;
+import org.opensearch.cluster.metadata.IndexMetadata;
+import org.opensearch.cluster.metadata.Metadata;
+import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.common.settings.Settings;
 import org.opensearch.core.action.ActionListener;
 import org.opensearch.env.Environment;
+import org.opensearch.index.mapper.IndexFieldMapper;
 import org.opensearch.ingest.IngestDocument;
 import org.opensearch.ingest.Processor;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
@@ -48,9 +53,16 @@ public class TextImageEmbeddingProcessorTests extends OpenSearchTestCase {
 
     @Mock
     private MLCommonsClientAccessor mlCommonsClientAccessor;
-
     @Mock
     private Environment env;
+    @Mock
+    private ClusterService clusterService;
+    @Mock
+    private ClusterState clusterState;
+    @Mock
+    private Metadata metadata;
+    @Mock
+    private IndexMetadata indexMetadata;
 
     @InjectMocks
     private TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory;
@@ -62,6 +74,10 @@ public void setup() {
         MockitoAnnotations.openMocks(this);
         Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build();
         when(env.settings()).thenReturn(settings);
+        when(clusterService.state()).thenReturn(clusterState);
+        when(clusterState.metadata()).thenReturn(metadata);
+        when(metadata.index(anyString())).thenReturn(indexMetadata);
+        when(indexMetadata.getSettings()).thenReturn(settings);
     }
 
     @SneakyThrows
@@ -98,7 +114,16 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t
         // create with null type mapping
         IllegalArgumentException exception = expectThrows(
             IllegalArgumentException.class,
-            () -> new TextImageEmbeddingProcessor(PROCESSOR_TAG, DESCRIPTION, modelId, embeddingField, null, mlCommonsClientAccessor, env)
+            () -> new TextImageEmbeddingProcessor(
+                PROCESSOR_TAG,
+                DESCRIPTION,
+                modelId,
+                embeddingField,
+                null,
+                mlCommonsClientAccessor,
+                env,
+                clusterService
+            )
         );
         assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage());
@@ -112,7 +137,8 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t
                 embeddingField,
                 Map.of("", "my_field"),
                 mlCommonsClientAccessor,
-                env
+                env,
+                clusterService
             )
         );
         assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage());
@@ -131,7 +157,8 @@ public void testTextEmbeddingProcessConstructor_whenTypeMappingIsNullOrInvalid_t
                 embeddingField,
                 typeMapping,
                 mlCommonsClientAccessor,
-                env
+                env,
+                clusterService
             )
         );
         assertEquals("Unable to create the TextImageEmbedding processor as field_map has invalid key or value", exception.getMessage());
@@ -183,7 +210,11 @@ public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeExcep
         IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
         Map registry = new HashMap<>();
         MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class);
-        TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(accessor, env);
+        TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(
+            accessor,
+            env,
+            clusterService
+        );
 
         Map config = new HashMap<>();
         config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId");
@@ -223,6 +254,7 @@ public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() {
         Map sourceAndMetadata = new HashMap<>();
         sourceAndMetadata.put("key1", "hello world");
         sourceAndMetadata.put("my_text_field", ret);
+        sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
         IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
         TextImageEmbeddingProcessor processor = createInstance();
         BiConsumer handler = mock(BiConsumer.class);
@@ -254,6 +286,7 @@ public void testExecute_mapHasNonStringValue_throwIllegalArgumentException() {
         Map sourceAndMetadata = new HashMap<>();
         sourceAndMetadata.put("key1", map1);
         sourceAndMetadata.put("my_text_field", map2);
+        sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
         IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
         TextImageEmbeddingProcessor processor = createInstance();
         BiConsumer handler = mock(BiConsumer.class);
@@ -267,6 +300,7 @@ public void testExecute_mapHasEmptyStringValue_throwIllegalArgumentException() {
         Map sourceAndMetadata = new HashMap<>();
         sourceAndMetadata.put("key1", map1);
         sourceAndMetadata.put("my_text_field", map2);
+        sourceAndMetadata.put(IndexFieldMapper.NAME, "my_index");
         IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
         TextImageEmbeddingProcessor processor = createInstance();
         BiConsumer handler = mock(BiConsumer.class);
diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java
index 39d1f14de..cbf53b8fc 100644
--- a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java
+++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java
@@ -19,6 +19,7 @@
 
 import lombok.SneakyThrows;
 
+import org.opensearch.cluster.service.ClusterService;
 import org.opensearch.env.Environment;
 import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor;
 import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor;
@@ -30,7 +31,8 @@ public class TextImageEmbeddingProcessorFactoryTests extends OpenSearchTestCase
     public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() {
         TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(
             mock(MLCommonsClientAccessor.class),
-            mock(Environment.class)
+            mock(Environment.class),
+            mock(ClusterService.class)
         );
 
         final Map processorFactories = new HashMap<>();
@@ -55,7 +57,8 @@ public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() {
     public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() {
         TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(
             mock(MLCommonsClientAccessor.class),
-            mock(Environment.class)
+            mock(Environment.class),
+            mock(ClusterService.class)
        );
 
         final Map processorFactories = new HashMap<>();
@@ -88,7 +91,8 @@ public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() {
     public void testNormalizationProcessor_whenMixOfParamsOrEmptyParams_thenFail() {
         TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(
             mock(MLCommonsClientAccessor.class),
-            mock(Environment.class)
+            mock(Environment.class),
+            mock(ClusterService.class)
         );
 
         final Map processorFactories = new HashMap<>();