diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index 98c32b189..5bb54cc2c 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -119,8 +119,7 @@ public void inferenceSentencesWithMapResult( /** * Abstraction to call predict function of api of MLClient with provided targetResponse filters. It uses the * custom model provided as modelId and run the {@link FunctionName#TEXT_EMBEDDING}. The return will be sent - * using the actionListener which will have a {@link List} of {@link List} of {@link Float} in the order of - * inputText. + * using the actionListener which will have a list of floats in the order of inputText. * * @param modelId {@link String} * @param inputObjects {@link Map} of {@link String}, {@link String} on which inference needs to happen @@ -212,7 +211,7 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { return resultMaps; } - private List buildSingleVectorFromResponse(MLOutput mlOutput) { + private List buildSingleVectorFromResponse(final MLOutput mlOutput) { final List> vector = buildVectorFromResponse(mlOutput); return vector.isEmpty() ? new ArrayList<>() : vector.get(0); } @@ -239,7 +238,7 @@ private void retryableInferenceSentencesWithSingleVectorResult( })); } - private MLInput createMLMultimodalInput(final List targetResponseFilters, Map input) { + private MLInput createMLMultimodalInput(final List targetResponseFilters, final Map input) { List inputText = new ArrayList<>(); inputText.add(input.get(INPUT_TEXT)); if (input.containsKey(INPUT_IMAGE)) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index c5ff97822..f77d72157 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -8,6 +8,7 @@ import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; +import java.util.Locale; import java.util.Map; import java.util.Objects; import java.util.Set; @@ -41,8 +42,8 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor { public static final String IMAGE_FIELD_NAME = "image"; public static final String INPUT_TEXT = "inputText"; public static final String INPUT_IMAGE = "inputImage"; + private static final Set VALID_FIELD_NAMES = Set.of(TEXT_FIELD_NAME, IMAGE_FIELD_NAME); - @VisibleForTesting private final String modelId; private final String embedding; private final Map fieldMap; @@ -51,13 +52,13 @@ public class TextImageEmbeddingProcessor extends AbstractProcessor { private final Environment environment; public TextImageEmbeddingProcessor( - String tag, - String description, - String modelId, - String embedding, - Map fieldMap, - MLCommonsClientAccessor clientAccessor, - Environment environment + final String tag, + final String description, + final String modelId, + final String embedding, + final Map fieldMap, + final MLCommonsClientAccessor clientAccessor, + final Environment environment ) { super(tag, description); if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, can not process it"); @@ -70,17 +71,21 @@ public TextImageEmbeddingProcessor( this.environment = environment; } - private void validateEmbeddingConfiguration(Map fieldMap) { + private void validateEmbeddingConfiguration(final Map fieldMap) { if (fieldMap == null || fieldMap.isEmpty() - || fieldMap.entrySet() - .stream() - .anyMatch(x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue()))) { + || fieldMap.entrySet().stream().anyMatch(x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()))) { throw new IllegalArgumentException("Unable to create the TextImageEmbedding processor as field_map has invalid key or value"); } - if (fieldMap.entrySet().stream().anyMatch(entry -> !Set.of(TEXT_FIELD_NAME, IMAGE_FIELD_NAME).contains(entry.getKey()))) { - throw new IllegalArgumentException("Unable to create the TextImageEmbedding processor as field_map has unsupported field name"); + if (fieldMap.entrySet().stream().anyMatch(entry -> !VALID_FIELD_NAMES.contains(entry.getKey()))) { + throw new IllegalArgumentException( + String.format( + Locale.ROOT, + "Unable to create the TextImageEmbedding processor with provided field name(s). Following names are supported [%s]", + String.join(",", VALID_FIELD_NAMES) + ) + ); } } @@ -115,7 +120,7 @@ public void execute(final IngestDocument ingestDocument, final BiConsumer vectors) { + private void setVectorFieldsToDocument(final IngestDocument ingestDocument, final List vectors) { Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); log.debug("Text embedding result fetched, starting build vector output!"); Map textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors); @@ -123,7 +128,7 @@ private void setVectorFieldsToDocument(IngestDocument ingestDocument, List createInferences(Map knnKeyMap) { + private Map createInferences(final Map knnKeyMap) { Map texts = new HashMap<>(); if (fieldMap.containsKey(TEXT_FIELD_NAME) && knnKeyMap.containsKey(fieldMap.get(TEXT_FIELD_NAME))) { texts.put(INPUT_TEXT, knnKeyMap.get(fieldMap.get(TEXT_FIELD_NAME))); @@ -135,7 +140,7 @@ private Map createInferences(Map knnKeyMap) { } @VisibleForTesting - Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { + Map buildMapWithKnnKeyAndOriginalValue(final IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); Map mapWithKnnKeys = new LinkedHashMap<>(); for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { @@ -155,37 +160,39 @@ Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocu @SuppressWarnings({ "unchecked" }) @VisibleForTesting - Map buildTextEmbeddingResult(String knnKey, List modelTensorList) { + Map buildTextEmbeddingResult(final String knnKey, List modelTensorList) { Map result = new LinkedHashMap<>(); result.put(knnKey, modelTensorList); return result; } - private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { + private void validateEmbeddingFieldsValue(final IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); - if (sourceValue != null) { - String sourceKey = embeddingFieldsEntry.getKey(); - Class sourceValueClass = sourceValue.getClass(); - if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(sourceKey, sourceValue, () -> 1); - } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); - } + if (Objects.isNull(sourceValue)) { + continue; } + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, can not process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, can not process it"); + } + } } @SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + private void validateNestedTypeValue(final String sourceKey, final Object sourceValue, final Supplier maxDepthSupplier) { int maxDepth = maxDepthSupplier.get(); if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, can not process it"); } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { - validateListTypeValue(sourceKey, sourceValue); + validateListTypeValue(sourceKey, (List) sourceValue); } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { ((Map) sourceValue).values() .stream() @@ -199,8 +206,8 @@ private void validateNestedTypeValue(String sourceKey, Object sourceValue, Suppl } @SuppressWarnings({ "rawtypes" }) - private static void validateListTypeValue(String sourceKey, Object sourceValue) { - for (Object value : (List) sourceValue) { + private static void validateListTypeValue(final String sourceKey, final List sourceValue) { + for (Object value : sourceValue) { if (value == null) { throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, can not process it"); } else if (!(value instanceof String)) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index 0c9a6fa2c..adf6f6d21 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -25,17 +25,17 @@ public class TextEmbeddingProcessorFactory implements Processor.Factory { private final Environment environment; - public TextEmbeddingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + public TextEmbeddingProcessorFactory(final MLCommonsClientAccessor clientAccessor, final Environment environment) { this.clientAccessor = clientAccessor; this.environment = environment; } @Override public TextEmbeddingProcessor create( - Map registry, - String processorTag, - String description, - Map config + final Map registry, + final String processorTag, + final String description, + final Map config ) throws Exception { String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java index 768a67161..39d1f14de 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java @@ -5,6 +5,8 @@ package org.opensearch.neuralsearch.processor.factory; +import static org.hamcrest.Matchers.allOf; +import static org.hamcrest.Matchers.containsString; import static org.mockito.Mockito.mock; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; @@ -101,7 +103,16 @@ public void testNormalizationProcessor_whenMixOfParamsOrEmptyParams_thenFail() { IllegalArgumentException.class, () -> textImageEmbeddingProcessorFactory.create(processorFactories, tag, description, configMixOfFields) ); - assertEquals(exception.getMessage(), "Unable to create the TextImageEmbedding processor as field_map has unsupported field name"); + org.hamcrest.MatcherAssert.assertThat( + exception.getMessage(), + allOf( + containsString( + "Unable to create the TextImageEmbedding processor with provided field name(s). Following names are supported [" + ), + containsString("image"), + containsString("text") + ) + ); Map configNoFields = new HashMap<>(); configNoFields.put(MODEL_ID_FIELD, "1234567678"); configNoFields.put(EMBEDDING_FIELD, "embedding_field");