From 2192bb52e6f8122462b88e1dd6c67cef9e0828b8 Mon Sep 17 00:00:00 2001 From: Martin Gaievski Date: Thu, 28 Sep 2023 10:15:58 -0700 Subject: [PATCH] Changed approach to a hardcoded fields for image and text Signed-off-by: Martin Gaievski --- CHANGELOG.md | 1 + build.gradle | 1 + .../ml/MLCommonsClientAccessor.java | 51 +++- .../neuralsearch/plugin/NeuralSearch.java | 8 +- ....java => TextImageEmbeddingProcessor.java} | 220 +++++------------ .../factory/InferenceProcessorFactory.java | 36 --- .../TextImageEmbeddingProcessorFactory.java | 45 ++++ .../query/NeuralQueryBuilder.java | 26 +- .../constants/TestCommonConstants.java | 2 + .../ml/MLCommonsClientAccessorTests.java | 31 +++ .../processor/InferenceProcessorTests.java | 141 ----------- .../processor/NormalizationProcessorIT.java | 30 ++- .../processor/ScoreCombinationIT.java | 8 +- .../processor/ScoreNormalizationIT.java | 24 +- .../TextImageEmbeddingProcessorTests.java | 229 ++++++++++++++++++ .../InferenceProcessorFactoryTests.java | 48 ---- ...xtImageEmbeddingProcessorFactoryTests.java | 115 +++++++++ .../query/NeuralQueryBuilderTests.java | 40 ++- .../neuralsearch/query/NeuralQueryIT.java | 45 ++++ 19 files changed, 693 insertions(+), 408 deletions(-) rename src/main/java/org/opensearch/neuralsearch/processor/{InferenceProcessor.java => TextImageEmbeddingProcessor.java} (51%) delete mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java delete mode 100644 src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java create mode 100644 src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java delete mode 100644 src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java create mode 100644 
src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index da2ae9ec9..ac7c783c4 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), ## [Unreleased 2.x](https://github.com/opensearch-project/neural-search/compare/2.10...2.x) ### Features Support sparse semantic retrieval by introducing `sparse_encoding` ingest processor and query builder ([#333](https://github.com/opensearch-project/neural-search/pull/333)) +Added Multimodal semantic search feature ([#359](https://github.com/opensearch-project/neural-search/pull/359)) ### Enhancements ### Bug Fixes ### Infrastructure diff --git a/build.gradle b/build.gradle index 613a75088..48d3184d9 100644 --- a/build.gradle +++ b/build.gradle @@ -151,6 +151,7 @@ dependencies { runtimeOnly group: 'org.reflections', name: 'reflections', version: '0.9.12' runtimeOnly group: 'org.javassist', name: 'javassist', version: '3.29.2-GA' runtimeOnly group: 'org.opensearch', name: 'common-utils', version: "${opensearch_build}" + runtimeOnly group: 'org.apache.commons', name: 'commons-text', version: '1.10.0' runtimeOnly group: 'com.google.code.gson', name: 'gson', version: '2.10.1' runtimeOnly group: 'org.json', name: 'json', version: '20230227' } diff --git a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java index e9dc342ec..c0d4992c6 100644 --- a/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java +++ b/src/main/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessor.java @@ -5,9 +5,11 @@ package org.opensearch.neuralsearch.ml; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT; + import 
java.util.ArrayList; import java.util.Arrays; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -22,7 +24,6 @@ import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.dataset.TextDocsInputDataSet; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelResultFilter; @@ -115,6 +116,14 @@ public void inferenceSentencesWithMapResult( retryableInferenceSentencesWithMapResult(modelId, inputText, 0, listener); } + public void inferenceSentences( + @NonNull final String modelId, + @NonNull final Map inputObjects, + @NonNull final ActionListener> listener + ) { + inferenceSentencesWithRetry(TARGET_RESPONSE_FILTERS, modelId, inputObjects, 0, listener); + } + private void retryableInferenceSentencesWithMapResult( final String modelId, final List inputText, @@ -198,4 +207,42 @@ private List> buildVectorFromResponse(MLOutput mlOutput) { } return resultMaps; } + + private List buildSingleVectorFromResponse(MLOutput mlOutput) { + final List> vector = buildVectorFromResponse(mlOutput); + return vector.isEmpty() ? 
new ArrayList<>() : vector.get(0); + } + + private void inferenceSentencesWithRetry( + @NonNull final List targetResponseFilters, + final String modelId, + final Map inputObjects, + final int retryTime, + final ActionListener> listener + ) { + MLInput mlInput = createMLMultimodalInput(targetResponseFilters, inputObjects); + mlClient.predict(modelId, mlInput, ActionListener.wrap(mlOutput -> { + final List vector = buildSingleVectorFromResponse(mlOutput); + log.debug("Inference Response for input sentence {} is : {} ", inputObjects, vector); + listener.onResponse(vector); + }, e -> { + if (RetryUtil.shouldRetry(e, retryTime)) { + final int retryTimeAdd = retryTime + 1; + inferenceSentencesWithRetry(targetResponseFilters, modelId, inputObjects, retryTimeAdd, listener); + } else { + listener.onFailure(e); + } + })); + } + + private MLInput createMLMultimodalInput(final List targetResponseFilters, Map input) { + List inputText = new ArrayList<>(); + inputText.add(input.get(INPUT_TEXT)); + if (input.containsKey(INPUT_IMAGE)) { + inputText.add(input.get(INPUT_IMAGE)); + } + final ModelResultFilter modelResultFilter = new ModelResultFilter(false, true, targetResponseFilters, null); + final MLInputDataset inputDataset = new TextDocsInputDataSet(inputText, modelResultFilter); + return new MLInput(FunctionName.TEXT_EMBEDDING, null, inputDataset); + } } diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index afc44fa8c..ceccb5332 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -28,17 +28,17 @@ import org.opensearch.ingest.Processor; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.InferenceProcessor; import 
org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; +import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; import org.opensearch.neuralsearch.processor.combination.ScoreCombinationFactory; import org.opensearch.neuralsearch.processor.combination.ScoreCombiner; -import org.opensearch.neuralsearch.processor.factory.InferenceProcessorFactory; import org.opensearch.neuralsearch.processor.factory.NormalizationProcessorFactory; import org.opensearch.neuralsearch.processor.factory.SparseEncodingProcessorFactory; import org.opensearch.neuralsearch.processor.factory.TextEmbeddingProcessorFactory; +import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizationFactory; import org.opensearch.neuralsearch.processor.normalization.ScoreNormalizer; import org.opensearch.neuralsearch.query.HybridQueryBuilder; @@ -105,8 +105,8 @@ public Map getProcessors(Processor.Parameters paramet new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), SparseEncodingProcessor.TYPE, new SparseEncodingProcessorFactory(clientAccessor, parameters.env), - InferenceProcessor.TYPE, - new InferenceProcessorFactory(clientAccessor, parameters.env) + TextImageEmbeddingProcessor.TYPE, + new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env) ); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java similarity index 51% rename from src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java rename to src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java index 
a630a7a1f..4659a43c6 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessor.java @@ -5,16 +5,15 @@ package org.opensearch.neuralsearch.processor; -import java.util.ArrayList; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.function.BiConsumer; import java.util.function.Supplier; -import java.util.stream.IntStream; import lombok.extern.log4j.Log4j2; @@ -27,37 +26,42 @@ import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; import com.google.common.annotations.VisibleForTesting; -import com.google.common.collect.ImmutableMap; /** - * This processor is used for getting embeddings for multimodal type of inference, model_id can be used to indicate which model user use, + * This processor is used for user input data text and image embedding processing, model_id can be used to indicate which model user use, * and field_map can be used to indicate which fields needs embedding and the corresponding keys for the embedding results. 
*/ @Log4j2 -public class InferenceProcessor extends AbstractProcessor { +public class TextImageEmbeddingProcessor extends AbstractProcessor { - public static final String TYPE = "inference-processor"; + public static final String TYPE = "text_image_embedding"; public static final String MODEL_ID_FIELD = "model_id"; + public static final String EMBEDDING_FIELD = "embedding"; public static final String FIELD_MAP_FIELD = "field_map"; - - private static final String LIST_TYPE_NESTED_MAP_KEY = "knn"; + public static final String TEXT_FIELD_NAME = "text"; + public static final String IMAGE_FIELD_NAME = "image"; + public static final String INPUT_TEXT = "inputText"; + public static final String INPUT_IMAGE = "inputImage"; @VisibleForTesting private final String modelId; - - private final Map fieldMap; + private final String embedding; + private final Map fieldMap; private final MLCommonsClientAccessor mlCommonsClientAccessor; private final Environment environment; - private static final int MAX_CONTENT_LENGTH_IN_BYTES = 10 * 1024 * 1024; // limit of 10Mb per field value + // limit of 16Mb per field value. 
This is from current bedrock model, calculated as 2048*2048 pixels (24 bit), + // image to base64 encoding assumed to have 4/3 ratio, assuming UTF-8 encoding average of 1 byte per character + private static final int MAX_CONTENT_LENGTH_IN_BYTES = 16 * 1024 * 1024; - public InferenceProcessor( + public TextImageEmbeddingProcessor( String tag, String description, String modelId, - Map fieldMap, + String embedding, + Map fieldMap, MLCommonsClientAccessor clientAccessor, Environment environment ) { @@ -66,20 +70,23 @@ public InferenceProcessor( validateEmbeddingConfiguration(fieldMap); this.modelId = modelId; + this.embedding = embedding; this.fieldMap = fieldMap; this.mlCommonsClientAccessor = clientAccessor; this.environment = environment; } - private void validateEmbeddingConfiguration(Map fieldMap) { + private void validateEmbeddingConfiguration(Map fieldMap) { if (fieldMap == null || fieldMap.isEmpty() || fieldMap.entrySet() .stream() - .anyMatch( - x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue().toString()) - )) { - throw new IllegalArgumentException("Unable to create the InferenceProcessor processor as field_map has invalid key or value"); + .anyMatch(x -> StringUtils.isBlank(x.getKey()) || Objects.isNull(x.getValue()) || StringUtils.isBlank(x.getValue()))) { + throw new IllegalArgumentException("Unable to create the TextImageEmbedding processor as field_map has invalid key or value"); + } + + if (fieldMap.entrySet().stream().anyMatch(entry -> !Set.of(TEXT_FIELD_NAME, IMAGE_FIELD_NAME).contains(entry.getKey()))) { + throw new IllegalArgumentException("Unable to create the TextImageEmbedding processor as field_map has unsupported field name"); } } @@ -103,13 +110,15 @@ public void execute(IngestDocument ingestDocument, BiConsumer knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); - Map> inferenceMap = createInferenceMap(knnMap); + Map knnMap = buildMapWithKnnKeyAndOriginalValue(ingestDocument); // 
["passage_embedding", ["my_text_123123123", + // "image_base64_123456780"]] + Map inferenceMap = createInferences(knnMap); // ["text_field" : "my_text_123123123", "image_field" : + // "image_base64_123456780"] if (inferenceMap.isEmpty()) { handler.accept(ingestDocument, null); } else { - mlCommonsClientAccessor.inferenceMultimodal(this.modelId, inferenceMap, ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, knnMap, vectors); + mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceMap, ActionListener.wrap(vectors -> { + setVectorFieldsToDocument(ingestDocument, vectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } @@ -119,91 +128,49 @@ public void execute(IngestDocument ingestDocument, BiConsumer knnMap, List> vectors) { + void setVectorFieldsToDocument(IngestDocument ingestDocument, List vectors) { Objects.requireNonNull(vectors, "embedding failed, inference returns null result!"); log.debug("Text embedding result fetched, starting build vector output!"); - Map textEmbeddingResult = buildTextEmbeddingResult(knnMap, vectors, ingestDocument.getSourceAndMetadata()); + Map textEmbeddingResult = buildTextEmbeddingResult(this.embedding, vectors); textEmbeddingResult.forEach(ingestDocument::setFieldValue); } - private Map> createInferenceMap(Map knnKeyMap) { - Map> objects = new HashMap<>(); - knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof Map) { - Map sourceValues = (Map) sourceValue; - if (sourceValues.entrySet() - .stream() - .anyMatch( - entry -> entry.getKey().length() > MAX_CONTENT_LENGTH_IN_BYTES - || entry.getValue().length() > MAX_CONTENT_LENGTH_IN_BYTES - )) { - throw new IllegalArgumentException( - String.format(Locale.ROOT, "content cannot be longer than a %d bytes", MAX_CONTENT_LENGTH_IN_BYTES) - ); - } - objects.put(knnMapEntry.getKey(), 
sourceValues); - } else { - throw new RuntimeException("Cannot build inference object"); - } - }); - return objects; - } - @SuppressWarnings({ "unchecked" }) - private List createInferenceList(Map knnKeyMap) { - List texts = new ArrayList<>(); - knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); - } else if (sourceValue instanceof Map) { - createInferenceListForMapTypeInput(sourceValue, texts); - } else { - texts.add(sourceValue.toString()); - } - }); - return texts; - } - - @SuppressWarnings("unchecked") - private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { - if (sourceValue instanceof Map) { - ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); - } else if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); - } else { - if (sourceValue == null) return; - texts.add(sourceValue.toString()); + private Map createInferences(Map knnKeyMap) { + Map texts = new HashMap<>(); + if (fieldMap.containsKey(TEXT_FIELD_NAME) && knnKeyMap.containsKey(fieldMap.get(TEXT_FIELD_NAME))) { + texts.put(INPUT_TEXT, knnKeyMap.get(fieldMap.get(TEXT_FIELD_NAME))); + } + if (fieldMap.containsKey(IMAGE_FIELD_NAME) && knnKeyMap.containsKey(fieldMap.get(IMAGE_FIELD_NAME))) { + texts.put(INPUT_IMAGE, knnKeyMap.get(fieldMap.get(IMAGE_FIELD_NAME))); } + return texts; } @VisibleForTesting - Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { + Map buildMapWithKnnKeyAndOriginalValue(IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - Map mapWithKnnKeys = new LinkedHashMap<>(); - for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { - String originalKey = fieldMapEntry.getKey(); - Object targetKey = fieldMapEntry.getValue(); - if (targetKey instanceof Map) { - // Map treeRes = 
new LinkedHashMap<>(); - // buildMapWithKnnKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); - // mapWithKnnKeys.put(originalKey, treeRes.get(originalKey)); - Map knnMap = Map.of( - "value", - sourceAndMetadataMap.get(originalKey).toString(), - "model_input", - ((Map) targetKey).get("model_input").toString() + Map mapWithKnnKeys = new LinkedHashMap<>(); + for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { + String originalKey = fieldMapEntry.getValue(); // field from ingest document that we need to sent as model input, part of + // processor definition + + if (!sourceAndMetadataMap.containsKey(originalKey)) { + continue; + } + if (!(sourceAndMetadataMap.get(originalKey) instanceof String)) { + throw new IllegalArgumentException("Unsupported format of the field in the document, value must be a string"); + } + if (((String) sourceAndMetadataMap.get(originalKey)).length() > MAX_CONTENT_LENGTH_IN_BYTES) { + throw new IllegalArgumentException( + String.format(Locale.ROOT, "content cannot be longer than a %d bytes", MAX_CONTENT_LENGTH_IN_BYTES) ); - mapWithKnnKeys.put(originalKey, knnMap); - } else { - mapWithKnnKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); } + mapWithKnnKeys.put(originalKey, (String) sourceAndMetadataMap.get(originalKey)); } return mapWithKnnKeys; } - @SuppressWarnings({ "unchecked" }) private void buildMapWithKnnKeyAndOriginalValueForMapType( String parentKey, Object knnKey, @@ -230,65 +197,15 @@ private void buildMapWithKnnKeyAndOriginalValueForMapType( @SuppressWarnings({ "unchecked" }) @VisibleForTesting - Map buildTextEmbeddingResult( - Map knnMap, - List> modelTensorList, - Map sourceAndMetadataMap - ) { - IndexWrapper indexWrapper = new IndexWrapper(0); + Map buildTextEmbeddingResult(String knnKey, List modelTensorList) { Map result = new LinkedHashMap<>(); - for (Map.Entry knnMapEntry : knnMap.entrySet()) { - String knnKey = knnMapEntry.getKey(); - Object sourceValue = 
knnMapEntry.getValue(); - List modelTensor = modelTensorList.get(indexWrapper.index++); - result.put(knnKey, modelTensor); - } + result.put(knnKey, modelTensorList); return result; } - @SuppressWarnings({ "unchecked" }) - private void putTextEmbeddingResultToSourceMapForMapType( - String knnKey, - Object sourceValue, - List> modelTensorList, - IndexWrapper indexWrapper, - Map sourceAndMetadataMap - ) { - if (knnKey == null || sourceAndMetadataMap == null || sourceValue == null) return; - if (sourceValue instanceof Map) { - for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { - putTextEmbeddingResultToSourceMapForMapType( - inputNestedMapEntry.getKey(), - inputNestedMapEntry.getValue(), - modelTensorList, - indexWrapper, - (Map) sourceAndMetadataMap.get(knnKey) - ); - } - } else if (sourceValue instanceof String) { - sourceAndMetadataMap.put(knnKey, modelTensorList.get(indexWrapper.index++)); - } else if (sourceValue instanceof List) { - sourceAndMetadataMap.put( - knnKey, - buildTextEmbeddingResultForListType((List) sourceValue, modelTensorList, indexWrapper) - ); - } - } - - private List>> buildTextEmbeddingResultForListType( - List sourceValue, - List> modelTensorList, - IndexWrapper indexWrapper - ) { - List>> numbers = new ArrayList<>(); - IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> numbers.add(ImmutableMap.of(LIST_TYPE_NESTED_MAP_KEY, modelTensorList.get(indexWrapper.index++)))); - return numbers; - } - private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); if (sourceValue != null) { String sourceKey = embeddingFieldsEntry.getKey(); @@ -340,21 +257,4 @@ private static void validateListTypeValue(String sourceKey, Object sourceValue) public 
String getType() { return TYPE; } - - /** - * Since we need to build a {@link List} as the input for text embedding, and the result type is {@link List} of {@link List}, - * we need to map the result back to the input one by one with exactly order. For nested map type input, we're performing a pre-order - * traversal to extract the input strings, so when mapping back to the nested map, we still need a pre-order traversal to ensure the - * order. And we also need to ensure the index pointer goes forward in the recursive, so here the IndexWrapper is to store and increase - * the index pointer during the recursive. - * index: the index pointer of the text embedding result. - */ - static class IndexWrapper { - private int index; - - protected IndexWrapper(int index) { - this.index = index; - } - } - } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java deleted file mode 100644 index b2dd1fc28..000000000 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactory.java +++ /dev/null @@ -1,36 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.processor.factory; - -import static org.opensearch.ingest.ConfigurationUtils.readMap; -import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; -import static org.opensearch.ingest.Processor.Factory; - -import java.util.Map; - -import org.opensearch.env.Environment; -import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.InferenceProcessor; - -public class InferenceProcessorFactory implements Factory { - - private final MLCommonsClientAccessor clientAccessor; - - private final Environment environment; - - public InferenceProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { - 
this.clientAccessor = clientAccessor; - this.environment = environment; - } - - @Override - public InferenceProcessor create(Map registry, String processorTag, String description, Map config) - throws Exception { - String modelId = readStringProperty(InferenceProcessor.TYPE, processorTag, config, InferenceProcessor.MODEL_ID_FIELD); - Map filedMap = readMap(InferenceProcessor.TYPE, processorTag, config, InferenceProcessor.FIELD_MAP_FIELD); - return new InferenceProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment); - } -} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java new file mode 100644 index 000000000..6069fa40c --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactory.java @@ -0,0 +1,45 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.opensearch.ingest.ConfigurationUtils.readMap; +import static org.opensearch.ingest.ConfigurationUtils.readStringProperty; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.Factory; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.TYPE; + +import java.util.Map; + +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; + +public class TextImageEmbeddingProcessorFactory implements Factory { + + 
private final MLCommonsClientAccessor clientAccessor; + + private final Environment environment; + + public TextImageEmbeddingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + this.clientAccessor = clientAccessor; + this.environment = environment; + } + + @Override + public TextImageEmbeddingProcessor create( + Map registry, + String processorTag, + String description, + Map config + ) throws Exception { + String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); + String embedding = readStringProperty(TYPE, processorTag, config, EMBEDDING_FIELD); + Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); + return new TextImageEmbeddingProcessor(processorTag, description, modelId, embedding, filedMap, clientAccessor, environment); + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java index ebcd9a88b..5cfb561df 100644 --- a/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java +++ b/src/main/java/org/opensearch/neuralsearch/query/NeuralQueryBuilder.java @@ -7,8 +7,12 @@ import static org.opensearch.knn.index.query.KNNQueryBuilder.FILTER_FIELD; import static org.opensearch.neuralsearch.common.VectorUtil.vectorAsListToArray; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_IMAGE; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.INPUT_TEXT; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; import java.util.function.Supplier; import lombok.AccessLevel; @@ -19,6 +23,7 @@ import lombok.experimental.Accessors; import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang.StringUtils; import org.apache.commons.lang.builder.EqualsBuilder; import org.apache.commons.lang.builder.HashCodeBuilder; import org.apache.lucene.search.Query; @@ -59,6 +64,9 @@ public class 
NeuralQueryBuilder extends AbstractQueryBuilder @VisibleForTesting static final ParseField QUERY_TEXT_FIELD = new ParseField("query_text"); + @VisibleForTesting + static final ParseField QUERY_IMAGE_FIELD = new ParseField("query_image"); + @VisibleForTesting static final ParseField MODEL_ID_FIELD = new ParseField("model_id"); @@ -75,6 +83,7 @@ public static void initialize(MLCommonsClientAccessor mlClient) { private String fieldName; private String queryText; + private String queryImage; private String modelId; private int k = DEFAULT_K; @VisibleForTesting @@ -162,7 +171,9 @@ public static NeuralQueryBuilder fromXContent(XContentParser parser) throws IOEx + "]" ); } - requireValue(neuralQueryBuilder.queryText(), "Query text must be provided for neural query"); + if (StringUtils.isBlank(neuralQueryBuilder.queryText()) && StringUtils.isBlank(neuralQueryBuilder.queryImage())) { + throw new IllegalArgumentException("Either query text or image text must be provided for neural query"); + } requireValue(neuralQueryBuilder.fieldName(), "Field name must be provided for neural query"); requireValue(neuralQueryBuilder.modelId(), "Model ID must be provided for neural query"); @@ -178,6 +189,8 @@ private static void parseQueryParams(XContentParser parser, NeuralQueryBuilder n } else if (token.isValue()) { if (QUERY_TEXT_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { neuralQueryBuilder.queryText(parser.text()); + } else if (QUERY_IMAGE_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { + neuralQueryBuilder.queryImage(parser.text()); } else if (MODEL_ID_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { neuralQueryBuilder.modelId(parser.text()); } else if (K_FIELD.match(currentFieldName, parser.getDeprecationHandler())) { @@ -221,13 +234,20 @@ protected QueryBuilder doRewrite(QueryRewriteContext queryRewriteContext) { } SetOnce vectorSetOnce = new SetOnce<>(); + Map inferenceInput = new HashMap<>(); + if 
(StringUtils.isNotBlank(queryText())) { + inferenceInput.put(INPUT_TEXT, queryText()); + } + if (StringUtils.isNotBlank(queryImage())) { + inferenceInput.put(INPUT_IMAGE, queryImage()); + } queryRewriteContext.registerAsyncAction( - ((client, actionListener) -> ML_CLIENT.inferenceSentence(modelId(), queryText(), ActionListener.wrap(floatList -> { + ((client, actionListener) -> ML_CLIENT.inferenceSentences(modelId(), inferenceInput, ActionListener.wrap(floatList -> { vectorSetOnce.set(vectorAsListToArray(floatList)); actionListener.onResponse(null); }, actionListener::onFailure))) ); - return new NeuralQueryBuilder(fieldName(), queryText(), modelId(), k(), vectorSetOnce::get, filter()); + return new NeuralQueryBuilder(fieldName(), queryText(), queryImage(), modelId(), k(), vectorSetOnce::get, filter()); } @Override diff --git a/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java b/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java index 2776a53e6..185934b07 100644 --- a/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java +++ b/src/test/java/org/opensearch/neuralsearch/constants/TestCommonConstants.java @@ -6,6 +6,7 @@ package org.opensearch.neuralsearch.constants; import java.util.List; +import java.util.Map; import lombok.AccessLevel; import lombok.NoArgsConstructor; @@ -16,4 +17,5 @@ public class TestCommonConstants { public static final List TARGET_RESPONSE_FILTERS = List.of("sentence_embedding"); public static final Float[] PREDICT_VECTOR_ARRAY = new Float[] { 2.0f, 3.0f }; public static final List SENTENCES_LIST = List.of("TEXT"); + public static final Map SENTENCES_MAP = Map.of("inputText", "Text query", "inputImage", "base641234567890"); } diff --git a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java index 1b10966dc..b972d474b 100644 --- 
a/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/ml/MLCommonsClientAccessorTests.java @@ -280,6 +280,37 @@ public void testInferenceSentencesWithMapResult_whenNotRetryableException_thenFa Mockito.verify(resultListener).onFailure(illegalStateException); } + public void testInferenceMultimodal_whenValidInput_thenSuccess() { + final List vector = new ArrayList<>(List.of(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(createModelTensorOutput(TestCommonConstants.PREDICT_VECTOR_ARRAY)); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + Mockito.verify(singleSentenceResultListener).onResponse(vector); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + + public void testInferenceMultimodal_whenExceptionFromMLClient_thenFailure() { + final RuntimeException exception = new RuntimeException(); + Mockito.doAnswer(invocation -> { + final ActionListener actionListener = invocation.getArgument(2); + actionListener.onFailure(exception); + return null; + }).when(client).predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + accessor.inferenceSentences(TestCommonConstants.MODEL_ID, TestCommonConstants.SENTENCES_MAP, singleSentenceResultListener); + + Mockito.verify(client) + .predict(Mockito.eq(TestCommonConstants.MODEL_ID), Mockito.isA(MLInput.class), Mockito.isA(ActionListener.class)); + 
Mockito.verify(singleSentenceResultListener).onFailure(exception); + Mockito.verifyNoMoreInteractions(singleSentenceResultListener); + } + private ModelTensorOutput createModelTensorOutput(final Float[] output) { final List tensorsList = new ArrayList<>(); final List mlModelTensorList = new ArrayList<>(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java deleted file mode 100644 index 60d0dc143..000000000 --- a/src/test/java/org/opensearch/neuralsearch/processor/InferenceProcessorTests.java +++ /dev/null @@ -1,141 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.processor; - -import static org.mockito.ArgumentMatchers.anyMap; -import static org.mockito.ArgumentMatchers.anyString; -import static org.mockito.Mockito.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.isA; -import static org.mockito.Mockito.isNull; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; - -import java.util.ArrayList; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.function.BiConsumer; - -import lombok.SneakyThrows; - -import org.junit.Before; -import org.mockito.InjectMocks; -import org.mockito.Mock; -import org.mockito.MockitoAnnotations; -import org.opensearch.OpenSearchParseException; -import org.opensearch.common.settings.Settings; -import org.opensearch.core.action.ActionListener; -import org.opensearch.env.Environment; -import org.opensearch.ingest.IngestDocument; -import org.opensearch.ingest.Processor; -import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.factory.InferenceProcessorFactory; -import org.opensearch.test.OpenSearchTestCase; - -import 
com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; - -public class InferenceProcessorTests extends OpenSearchTestCase { - - @Mock - private MLCommonsClientAccessor mlCommonsClientAccessor; - - @Mock - private Environment env; - - @InjectMocks - private InferenceProcessorFactory inferenceProcessorFactory; - private static final String PROCESSOR_TAG = "mockTag"; - private static final String DESCRIPTION = "mockDescription"; - - @Before - public void setup() { - MockitoAnnotations.openMocks(this); - Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); - when(env.settings()).thenReturn(settings); - } - - @SneakyThrows - private InferenceProcessor createInstance(List> vector) { - Map registry = new HashMap<>(); - Map config = new HashMap<>(); - config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - config.put( - TextEmbeddingProcessor.FIELD_MAP_FIELD, - ImmutableMap.of( - "key1", - Map.of("model_input", "ModelInput1", "model_output", "ModelOutput1", "embedding", "key1Mapped"), - "key2", - Map.of("model_input", "ModelInput2", "model_output", "ModelOutput2", "embedding", "key2Mapped") - ) - ); - return inferenceProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); - } - - @SneakyThrows - public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() { - Map registry = new HashMap<>(); - Map config = new HashMap<>(); - config.put(TextEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); - try { - inferenceProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); - } catch (OpenSearchParseException e) { - assertEquals("[field_map] required property is missing", e.getMessage()); - } - } - - public void testExecute_successful() { - Map sourceAndMetadata = new HashMap<>(); - sourceAndMetadata.put("key1", "value1"); - sourceAndMetadata.put("key2", "value2"); - IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new 
HashMap<>()); - InferenceProcessor processor = createInstance(createMockVectorWithLength(2)); - - List> modelTensorList = createMockVectorResult(); - doAnswer(invocation -> { - ActionListener>> listener = invocation.getArgument(2); - listener.onResponse(modelTensorList); - return null; - }).when(mlCommonsClientAccessor).inferenceMultimodal(anyString(), anyMap(), isA(ActionListener.class)); - - BiConsumer handler = mock(BiConsumer.class); - processor.execute(ingestDocument, handler); - verify(handler).accept(any(IngestDocument.class), isNull()); - } - - private List> createMockVectorResult() { - List> modelTensorList = new ArrayList<>(); - List number1 = ImmutableList.of(1.234f, 2.354f); - List number2 = ImmutableList.of(3.234f, 4.354f); - List number3 = ImmutableList.of(5.234f, 6.354f); - List number4 = ImmutableList.of(7.234f, 8.354f); - List number5 = ImmutableList.of(9.234f, 10.354f); - List number6 = ImmutableList.of(11.234f, 12.354f); - List number7 = ImmutableList.of(13.234f, 14.354f); - modelTensorList.add(number1); - modelTensorList.add(number2); - modelTensorList.add(number3); - modelTensorList.add(number4); - modelTensorList.add(number5); - modelTensorList.add(number6); - modelTensorList.add(number7); - return modelTensorList; - } - - private List> createMockVectorWithLength(int size) { - float suffix = .234f; - List> result = new ArrayList<>(); - for (int i = 0; i < size * 2;) { - List number = new ArrayList<>(); - number.add(i++ + suffix); - number.add(i++ + suffix); - result.add(number); - } - return result; - } -} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java index 3cd71e5a1..86e75f736 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/NormalizationProcessorIT.java @@ -94,7 +94,15 @@ public void 
testResultProcessor_whenOneShardAndQueryMatches_thenSuccessful() { createSearchPipelineWithResultsPostProcessor(SEARCH_PIPELINE); String modelId = getDeployedModelId(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_DOC_TEXT1, + "", + modelId, + 5, + null, + null + ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -129,7 +137,15 @@ public void testResultProcessor_whenDefaultProcessorConfigAndQueryMatches_thenSu createSearchPipelineWithDefaultResultsPostProcessor(SEARCH_PIPELINE); String modelId = getDeployedModelId(); - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_DOC_TEXT1, + "", + modelId, + 5, + null, + null + ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); @@ -153,7 +169,15 @@ public void testResultProcessor_whenMultipleShardsAndQueryMatches_thenSuccessful String modelId = getDeployedModelId(); int totalExpectedDocQty = 6; - NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 6, null, null); + NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_DOC_TEXT1, + "", + modelId, + 6, + null, + null + ); TermQueryBuilder termQueryBuilder = QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3); HybridQueryBuilder hybridQueryBuilder = new HybridQueryBuilder(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java 
b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java index 03b77549a..4993df7fb 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreCombinationIT.java @@ -213,7 +213,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); - hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapDefaultNorm = search( @@ -236,7 +236,7 @@ public void testHarmonicMeanCombination_whenOneShardAndQueryMatches_thenSuccessf ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); - hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapL2Norm = search( @@ -279,7 +279,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderDefaultNorm = new HybridQueryBuilder(); - hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderDefaultNorm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); 
hybridQueryBuilderDefaultNorm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapDefaultNorm = search( @@ -302,7 +302,7 @@ public void testGeometricMeanCombination_whenOneShardAndQueryMatches_thenSuccess ); HybridQueryBuilder hybridQueryBuilderL2Norm = new HybridQueryBuilder(); - hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderL2Norm.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null)); hybridQueryBuilderL2Norm.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapL2Norm = search( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java index 6f98e8d5e..aa133c44d 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/ScoreNormalizationIT.java @@ -96,7 +96,9 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); - hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderArithmeticMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapArithmeticMean = search( @@ -119,7 +121,9 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); - hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + 
hybridQueryBuilderHarmonicMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapHarmonicMean = search( @@ -142,7 +146,9 @@ public void testL2Norm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); - hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderGeometricMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapGeometricMean = search( @@ -185,7 +191,9 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { String modelId = getDeployedModelId(); HybridQueryBuilder hybridQueryBuilderArithmeticMean = new HybridQueryBuilder(); - hybridQueryBuilderArithmeticMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderArithmeticMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderArithmeticMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapArithmeticMean = search( @@ -208,7 +216,9 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderHarmonicMean = new HybridQueryBuilder(); - hybridQueryBuilderHarmonicMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderHarmonicMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); 
hybridQueryBuilderHarmonicMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapHarmonicMean = search( @@ -231,7 +241,9 @@ public void testMinMaxNorm_whenOneShardAndQueryMatches_thenSuccessful() { ); HybridQueryBuilder hybridQueryBuilderGeometricMean = new HybridQueryBuilder(); - hybridQueryBuilderGeometricMean.add(new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, "", modelId, 5, null, null)); + hybridQueryBuilderGeometricMean.add( + new NeuralQueryBuilder(TEST_KNN_VECTOR_FIELD_NAME_1, TEST_DOC_TEXT1, "", modelId, 5, null, null) + ); hybridQueryBuilderGeometricMean.add(QueryBuilders.termQuery(TEST_TEXT_FIELD_NAME_1, TEST_QUERY_TEXT3)); Map searchResponseAsMapGeometricMean = search( diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java new file mode 100644 index 000000000..dfe974d35 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextImageEmbeddingProcessorTests.java @@ -0,0 +1,229 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor; + +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.isNull; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.IMAGE_FIELD_NAME; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.TEXT_FIELD_NAME; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; 
+import java.util.function.BiConsumer; +import java.util.function.Supplier; + +import lombok.SneakyThrows; + +import org.junit.Before; +import org.mockito.InjectMocks; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opensearch.OpenSearchParseException; +import org.opensearch.common.settings.Settings; +import org.opensearch.core.action.ActionListener; +import org.opensearch.env.Environment; +import org.opensearch.ingest.IngestDocument; +import org.opensearch.ingest.Processor; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.factory.TextImageEmbeddingProcessorFactory; +import org.opensearch.test.OpenSearchTestCase; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; + +public class TextImageEmbeddingProcessorTests extends OpenSearchTestCase { + + @Mock + private MLCommonsClientAccessor mlCommonsClientAccessor; + + @Mock + private Environment env; + + @InjectMocks + private TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory; + private static final String PROCESSOR_TAG = "mockTag"; + private static final String DESCRIPTION = "mockDescription"; + + @Before + public void setup() { + MockitoAnnotations.openMocks(this); + Settings settings = Settings.builder().put("index.mapping.depth.limit", 20).build(); + when(env.settings()).thenReturn(settings); + } + + @SneakyThrows + private TextImageEmbeddingProcessor createInstance(List> vector) { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextImageEmbeddingProcessor.EMBEDDING_FIELD, "my_embedding_field"); + config.put( + TextImageEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "image_field") + ); + return textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + } + + 
@SneakyThrows + public void testTextEmbeddingProcessConstructor_whenConfigMapEmpty_throwIllegalArgumentException() { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + OpenSearchParseException e = expectThrows( /* fails the test when create() throws nothing, unlike a bare try/catch */ + OpenSearchParseException.class, + () -> textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config) + ); + assertEquals("[embedding] required property is missing", e.getMessage()); + } + + public void testExecute_successful() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "value1"); + sourceAndMetadata.put("my_text_field", "value2"); + sourceAndMetadata.put("key3", "value3"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + + List<List<Float>> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener<List<List<Float>>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + @SneakyThrows + public void testExecute_whenInferenceThrowInterruptedException_throwRuntimeException() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_text_field", "value1"); + sourceAndMetadata.put("another_text_field", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + Map registry = new HashMap<>(); + MLCommonsClientAccessor accessor = mock(MLCommonsClientAccessor.class); + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory(accessor, env); + + Map config = new HashMap<>(); +
config.put(TextImageEmbeddingProcessor.MODEL_ID_FIELD, "mockModelId"); + config.put(TextImageEmbeddingProcessor.EMBEDDING_FIELD, "my_embedding_field"); + config.put( + TextImageEmbeddingProcessor.FIELD_MAP_FIELD, + ImmutableMap.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "image_field") + ); + TextImageEmbeddingProcessor processor = textImageEmbeddingProcessorFactory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); + doThrow(new RuntimeException()).when(accessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(RuntimeException.class)); + } + + public void testExecute_withListTypeInput_successful() { + List list1 = ImmutableList.of("test1", "test2", "test3"); + List list2 = ImmutableList.of("test4", "test5", "test6"); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_text_field", "value1"); + sourceAndMetadata.put("another_text_field", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(6)); + + List> modelTensorList = createMockVectorResult(); + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onResponse(modelTensorList); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(any(IngestDocument.class), isNull()); + } + + public void testExecute_mapDepthReachLimit_throwIllegalArgumentException() { + Map ret = createMaxDepthLimitExceedMap(() -> 1); + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", "hello world"); + sourceAndMetadata.put("my_text_field", ret); + IngestDocument ingestDocument = new 
IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + public void testExecute_MLClientAccessorThrowFail_handlerFailure() { + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("my_text_field", "value1"); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + TextImageEmbeddingProcessor processor = createInstance(createMockVectorWithLength(2)); + + doAnswer(invocation -> { + ActionListener>> listener = invocation.getArgument(2); + listener.onFailure(new IllegalArgumentException("illegal argument")); + return null; + }).when(mlCommonsClientAccessor).inferenceSentences(anyString(), anyMap(), isA(ActionListener.class)); + + BiConsumer handler = mock(BiConsumer.class); + processor.execute(ingestDocument, handler); + verify(handler).accept(isNull(), any(IllegalArgumentException.class)); + } + + private List> createMockVectorResult() { + List> modelTensorList = new ArrayList<>(); + List number1 = ImmutableList.of(1.234f, 2.354f); + List number2 = ImmutableList.of(3.234f, 4.354f); + List number3 = ImmutableList.of(5.234f, 6.354f); + List number4 = ImmutableList.of(7.234f, 8.354f); + List number5 = ImmutableList.of(9.234f, 10.354f); + List number6 = ImmutableList.of(11.234f, 12.354f); + List number7 = ImmutableList.of(13.234f, 14.354f); + modelTensorList.add(number1); + modelTensorList.add(number2); + modelTensorList.add(number3); + modelTensorList.add(number4); + modelTensorList.add(number5); + modelTensorList.add(number6); + modelTensorList.add(number7); + return modelTensorList; + } + + private List> createMockVectorWithLength(int size) { + float suffix = .234f; + List> result = new ArrayList<>(); + for (int i = 0; i < 
size * 2;) { + List number = new ArrayList<>(); + number.add(i++ + suffix); + number.add(i++ + suffix); + result.add(number); + } + return result; + } + + private Map createMaxDepthLimitExceedMap(Supplier maxDepthSupplier) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > 21) { + return null; + } + Map innerMap = new HashMap<>(); + Map ret = createMaxDepthLimitExceedMap(() -> maxDepth + 1); + if (ret == null) return innerMap; + innerMap.put("hello", ret); + return innerMap; + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java deleted file mode 100644 index 55aba5443..000000000 --- a/src/test/java/org/opensearch/neuralsearch/processor/factory/InferenceProcessorFactoryTests.java +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.neuralsearch.processor.factory; - -import static org.mockito.Mockito.mock; -import static org.opensearch.neuralsearch.processor.InferenceProcessor.FIELD_MAP_FIELD; -import static org.opensearch.neuralsearch.processor.InferenceProcessor.MODEL_ID_FIELD; - -import java.util.HashMap; -import java.util.Map; - -import lombok.SneakyThrows; - -import org.opensearch.env.Environment; -import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; -import org.opensearch.neuralsearch.processor.InferenceProcessor; -import org.opensearch.test.OpenSearchTestCase; - -public class InferenceProcessorFactoryTests extends OpenSearchTestCase { - - private static final String NORMALIZATION_METHOD = "min_max"; - private static final String COMBINATION_METHOD = "arithmetic_mean"; - - @SneakyThrows - public void testNormalizationProcessor_whenNoParams_thenSuccessful() { - InferenceProcessorFactory inferenceProcessorFactory = new InferenceProcessorFactory( - mock(MLCommonsClientAccessor.class), - 
mock(Environment.class) - ); - - final Map processorFactories = new HashMap<>(); - String tag = "tag"; - String description = "description"; - boolean ignoreFailure = false; - Map config = new HashMap<>(); - config.put(MODEL_ID_FIELD, "1234567678"); - config.put( - FIELD_MAP_FIELD, - Map.of("passage_text", Map.of("model_input", "TextInput1", "model_output", "TextEmbdedding1", "embedding", "passage_embedding")) - ); - InferenceProcessor inferenceProcessor = inferenceProcessorFactory.create(processorFactories, tag, description, config); - assertNotNull(inferenceProcessor); - assertEquals("inference-processor", inferenceProcessor.getType()); - } -} diff --git a/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java new file mode 100644 index 000000000..768a67161 --- /dev/null +++ b/src/test/java/org/opensearch/neuralsearch/processor/factory/TextImageEmbeddingProcessorFactoryTests.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.neuralsearch.processor.factory; + +import static org.mockito.Mockito.mock; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.FIELD_MAP_FIELD; +import static org.opensearch.neuralsearch.processor.TextEmbeddingProcessor.MODEL_ID_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.EMBEDDING_FIELD; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.IMAGE_FIELD_NAME; +import static org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor.TEXT_FIELD_NAME; + +import java.util.HashMap; +import java.util.Map; + +import lombok.SneakyThrows; + +import org.opensearch.env.Environment; +import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import 
org.opensearch.neuralsearch.processor.TextImageEmbeddingProcessor; +import org.opensearch.test.OpenSearchTestCase; + +public class TextImageEmbeddingProcessorFactoryTests extends OpenSearchTestCase { + + @SneakyThrows + public void testNormalizationProcessor_whenAllParamsPassed_thenSuccessful() { + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(MLCommonsClientAccessor.class), + mock(Environment.class) + ); + + final Map processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map config = new HashMap<>(); + config.put(MODEL_ID_FIELD, "1234567678"); + config.put(EMBEDDING_FIELD, "embedding_field"); + config.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field", IMAGE_FIELD_NAME, "my_image_field")); + TextImageEmbeddingProcessor inferenceProcessor = textImageEmbeddingProcessorFactory.create( + processorFactories, + tag, + description, + config + ); + assertNotNull(inferenceProcessor); + assertEquals("text_image_embedding", inferenceProcessor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenOnlyOneParamSet_thenSuccessful() { + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(MLCommonsClientAccessor.class), + mock(Environment.class) + ); + + final Map processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map configOnlyTextField = new HashMap<>(); + configOnlyTextField.put(MODEL_ID_FIELD, "1234567678"); + configOnlyTextField.put(EMBEDDING_FIELD, "embedding_field"); + configOnlyTextField.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field")); + TextImageEmbeddingProcessor processor = textImageEmbeddingProcessorFactory.create( + processorFactories, + tag, + description, + configOnlyTextField + ); + assertNotNull(processor); + 
assertEquals("text_image_embedding", processor.getType()); + + Map configOnlyImageField = new HashMap<>(); + configOnlyImageField.put(MODEL_ID_FIELD, "1234567678"); + configOnlyImageField.put(EMBEDDING_FIELD, "embedding_field"); + configOnlyImageField.put(FIELD_MAP_FIELD, Map.of(IMAGE_FIELD_NAME, "my_image_field")); /* image-only mapping; previously repeated the text-only case */ + processor = textImageEmbeddingProcessorFactory.create(processorFactories, tag, description, configOnlyImageField); + assertNotNull(processor); + assertEquals("text_image_embedding", processor.getType()); + } + + @SneakyThrows + public void testNormalizationProcessor_whenMixOfParamsOrEmptyParams_thenFail() { + TextImageEmbeddingProcessorFactory textImageEmbeddingProcessorFactory = new TextImageEmbeddingProcessorFactory( + mock(MLCommonsClientAccessor.class), + mock(Environment.class) + ); + + final Map processorFactories = new HashMap<>(); + String tag = "tag"; + String description = "description"; + boolean ignoreFailure = false; + Map configMixOfFields = new HashMap<>(); + configMixOfFields.put(MODEL_ID_FIELD, "1234567678"); + configMixOfFields.put(EMBEDDING_FIELD, "embedding_field"); + configMixOfFields.put(FIELD_MAP_FIELD, Map.of(TEXT_FIELD_NAME, "my_text_field", "random_field_name", "random_field")); + IllegalArgumentException exception = expectThrows( + IllegalArgumentException.class, + () -> textImageEmbeddingProcessorFactory.create(processorFactories, tag, description, configMixOfFields) + ); + assertEquals(exception.getMessage(), "Unable to create the TextImageEmbedding processor as field_map has unsupported field name"); + Map configNoFields = new HashMap<>(); + configNoFields.put(MODEL_ID_FIELD, "1234567678"); + configNoFields.put(EMBEDDING_FIELD, "embedding_field"); + configNoFields.put(FIELD_MAP_FIELD, Map.of()); + exception = expectThrows( + IllegalArgumentException.class, + () -> textImageEmbeddingProcessorFactory.create(processorFactories, tag, description, configNoFields) + ); + assertEquals(exception.getMessage(), "Unable to create the
TextImageEmbedding processor as field_map has invalid key or value"); + } +} diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java index f389dfd22..b7a9145de 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryBuilderTests.java @@ -6,6 +6,7 @@ package org.opensearch.neuralsearch.query; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.opensearch.core.xcontent.ToXContent.EMPTY_PARAMS; @@ -16,6 +17,7 @@ import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.K_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.MODEL_ID_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.NAME; +import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_IMAGE_FIELD; import static org.opensearch.neuralsearch.query.NeuralQueryBuilder.QUERY_TEXT_FIELD; import java.io.IOException; @@ -56,6 +58,7 @@ public class NeuralQueryBuilderTests extends OpenSearchTestCase { private static final String FIELD_NAME = "testField"; private static final String QUERY_TEXT = "Hello world!"; + private static final String IMAGE_TEXT = "base641234567890"; private static final String MODEL_ID = "mfgfgdsfgfdgsde"; private static final int K = 10; private static final float BOOST = 1.8f; @@ -70,6 +73,7 @@ public void testFromXContent_whenBuiltWithDefaults_thenBuildSuccessfully() { { "VECTOR_FIELD": { "query_text": "string", + "query_image": "string", "model_id": "string", "k": int } @@ -111,6 +115,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { .startObject() .startObject(FIELD_NAME) .field(QUERY_TEXT_FIELD.getPreferredName(), QUERY_TEXT) + 
.field(QUERY_IMAGE_FIELD.getPreferredName(), IMAGE_TEXT) .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) .field(K_FIELD.getPreferredName(), K) .field(BOOST_FIELD.getPreferredName(), BOOST) @@ -124,6 +129,7 @@ public void testFromXContent_whenBuiltWithOptionals_thenBuildSuccessfully() { assertEquals(FIELD_NAME, neuralQueryBuilder.fieldName()); assertEquals(QUERY_TEXT, neuralQueryBuilder.queryText()); + assertEquals(IMAGE_TEXT, neuralQueryBuilder.queryImage()); assertEquals(MODEL_ID, neuralQueryBuilder.modelId()); assertEquals(K, neuralQueryBuilder.k()); assertEquals(BOOST, neuralQueryBuilder.boost(), 0.0); @@ -262,6 +268,33 @@ public void testFromXContent_whenBuildWithDuplicateParameters_thenFail() { expectThrows(IOException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); } + @SneakyThrows + public void testFromXContent_whenNoQueryField_thenFail() { + /* NOTE(review): payload omits query_text/query_image but also repeats model_id and k - confirm which of the two triggers the expected IOException + { + "VECTOR_FIELD": { + "model_id": "string", + "model_id": "string", + "k": int, + "k": int + } + } + */ + XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() + .startObject() + .startObject(FIELD_NAME) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .field(MODEL_ID_FIELD.getPreferredName(), MODEL_ID) + .field(K_FIELD.getPreferredName(), K) + .endObject() + .endObject(); + + XContentParser contentParser = createParser(xContentBuilder); + contentParser.nextToken(); + expectThrows(IOException.class, () -> NeuralQueryBuilder.fromXContent(contentParser)); + } + @SneakyThrows public void testFromXContent_whenBuiltWithInvalidFilter_thenFail() { /* @@ -337,6 +370,7 @@ public void testStreams() { NeuralQueryBuilder original = new NeuralQueryBuilder(); original.fieldName(FIELD_NAME); original.queryText(QUERY_TEXT); + original.queryImage(IMAGE_TEXT); original.modelId(MODEL_ID); original.k(K); original.boost(BOOST); @@ -362,6 +396,7 @@ public void testHashAndEquals() { String fieldName2 = "field 2"; String queryText1 = "query text 1"; String
queryText2 = "query text 2"; + String imageText1 = "query image 1"; String modelId1 = "model-1"; String modelId2 = "model-2"; float boost1 = 1.8f; @@ -376,6 +411,7 @@ public void testHashAndEquals() { NeuralQueryBuilder neuralQueryBuilder_baseline = new NeuralQueryBuilder().fieldName(fieldName1) .queryText(queryText1) + .queryImage(imageText1) .modelId(modelId1) .k(k1) .boost(boost1) @@ -512,7 +548,7 @@ public void testRewrite_whenVectorSupplierNull_thenSetVectorSupplier() { ActionListener> listener = invocation.getArgument(2); listener.onResponse(expectedVector); return null; - }).when(mlCommonsClientAccessor).inferenceSentence(any(), any(), any()); + }).when(mlCommonsClientAccessor).inferenceSentences(any(), anyMap(), any()); NeuralQueryBuilder.initialize(mlCommonsClientAccessor); final CountDownLatch inProgressLatch = new CountDownLatch(1); @@ -539,6 +575,7 @@ public void testRewrite_whenVectorNull_thenReturnCopy() { Supplier nullSupplier = () -> null; NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) .queryText(QUERY_TEXT) + .queryImage(IMAGE_TEXT) .modelId(MODEL_ID) .k(K) .vectorSupplier(nullSupplier); @@ -549,6 +586,7 @@ public void testRewrite_whenVectorNull_thenReturnCopy() { public void testRewrite_whenVectorSupplierAndVectorSet_thenReturnKNNQueryBuilder() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder().fieldName(FIELD_NAME) .queryText(QUERY_TEXT) + .queryImage(IMAGE_TEXT) .modelId(MODEL_ID) .k(K) .vectorSupplier(TEST_VECTOR_SUPPLIER); diff --git a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java index c55da9e7f..6a8e48465 100644 --- a/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java +++ b/src/test/java/org/opensearch/neuralsearch/query/NeuralQueryIT.java @@ -32,6 +32,7 @@ public class NeuralQueryIT extends BaseNeuralSearchIT { private static final String TEST_NESTED_INDEX_NAME = 
"test-neural-nested-index"; private static final String TEST_MULTI_DOC_INDEX_NAME = "test-neural-multi-doc-index"; private static final String TEST_QUERY_TEXT = "Hello world"; + private static final String TEST_IMAGE_TEXT = "/9j/4AAQSkZJRgABAQAASABIAAD"; private static final String TEST_KNN_VECTOR_FIELD_NAME_1 = "test-knn-vector-1"; private static final String TEST_KNN_VECTOR_FIELD_NAME_2 = "test-knn-vector-2"; private static final String TEST_TEXT_FIELD_NAME_1 = "test-text-field"; @@ -80,6 +81,7 @@ public void testBasicQuery() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -115,6 +117,7 @@ public void testBoostQuery() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -159,6 +162,7 @@ public void testRescoreQuery() { NeuralQueryBuilder rescoreNeuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -207,6 +211,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { NeuralQueryBuilder neuralQueryBuilder1 = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -215,6 +220,7 @@ public void testBooleanQuery_withMultipleNeuralQueries() { NeuralQueryBuilder neuralQueryBuilder2 = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_2, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -263,6 +269,7 @@ public void testBooleanQuery_withNeuralAndBM25Queries() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -307,6 +314,7 @@ public void testNestedQuery() { NeuralQueryBuilder neuralQueryBuilder = new NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_NESTED, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -349,6 +357,7 @@ public void testFilterQuery() { NeuralQueryBuilder neuralQueryBuilder = new 
NeuralQueryBuilder( TEST_KNN_VECTOR_FIELD_NAME_1, TEST_QUERY_TEXT, + "", modelId, 1, null, @@ -362,6 +371,42 @@ public void testFilterQuery() { assertEquals(expectedScore, objectToFloat(firstInnerHit.get("_score")), 0.0); } + /** + * Runs a neural query that carries both a text and an image input, of the form: + * { + * "query": { + * "neural": { + * "text_knn": { + * "query_text": "Hello world", + * "query_image": "base64_1234567890", + * "model_id": "dcsdcasd", + * "k": 1 + * } + * } + * } + * } + */ + @SneakyThrows + public void testMultimodalQuery() { + initializeIndexIfNotExist(TEST_BASIC_INDEX_NAME); + final String deployedModelId = getDeployedModelId(); + final NeuralQueryBuilder multimodalQuery = new NeuralQueryBuilder( + TEST_KNN_VECTOR_FIELD_NAME_1, + TEST_QUERY_TEXT, + TEST_IMAGE_TEXT, + deployedModelId, + 1, + null, + null + ); + final Map response = search(TEST_BASIC_INDEX_NAME, multimodalQuery, 1); + final Map topHit = getFirstInnerHit(response); + + assertEquals("1", topHit.get("_id")); + final float expectedScore = computeExpectedScore(deployedModelId, testVector, TEST_SPACE_TYPE, TEST_QUERY_TEXT); + assertEquals(expectedScore, objectToFloat(topHit.get("_score")), 0.0); + } + private void initializeIndexIfNotExist(String indexName) throws IOException { if (TEST_BASIC_INDEX_NAME.equals(indexName) && !indexExists(TEST_BASIC_INDEX_NAME)) { prepareKnnIndex(