From cd323ab2defc88f7575c0b8ad00574f9e84ee33b Mon Sep 17 00:00:00 2001 From: zane-neo Date: Fri, 1 Mar 2024 14:40:14 +0800 Subject: [PATCH] Support list> type in embedding and extract validation logic to common class Signed-off-by: zane-neo --- .../neuralsearch/plugin/NeuralSearch.java | 6 +- .../processor/InferenceProcessor.java | 118 +++++++----------- .../processor/ProcessorInputValidator.java | 93 ++++++++++++++ .../processor/SparseEncodingProcessor.java | 7 +- .../processor/TextEmbeddingProcessor.java | 7 +- .../SparseEncodingProcessorFactory.java | 19 ++- .../TextEmbeddingProcessorFactory.java | 20 ++- .../TextEmbeddingProcessorTests.java | 14 +-- 8 files changed, 190 insertions(+), 94 deletions(-) create mode 100644 src/main/java/org/opensearch/neuralsearch/processor/ProcessorInputValidator.java diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index dc5b6e8f2..13f554621 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -29,6 +29,7 @@ import org.opensearch.neuralsearch.processor.NeuralQueryEnricherProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessor; import org.opensearch.neuralsearch.processor.NormalizationProcessorWorkflow; +import org.opensearch.neuralsearch.processor.ProcessorInputValidator; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; import org.opensearch.neuralsearch.processor.DocumentChunkingProcessor; @@ -109,11 +110,12 @@ public List> getQueries() { @Override public Map getProcessors(Processor.Parameters parameters) { clientAccessor = new MLCommonsClientAccessor(new MachineLearningNodeClient(parameters.client)); + ProcessorInputValidator processorInputValidator = new ProcessorInputValidator(); return Map.of( 
TextEmbeddingProcessor.TYPE, - new TextEmbeddingProcessorFactory(clientAccessor, parameters.env), + new TextEmbeddingProcessorFactory(clientAccessor, parameters.env, processorInputValidator), SparseEncodingProcessor.TYPE, - new SparseEncodingProcessorFactory(clientAccessor, parameters.env), + new SparseEncodingProcessorFactory(clientAccessor, parameters.env, processorInputValidator), TextImageEmbeddingProcessor.TYPE, new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()), DocumentChunkingProcessor.TYPE, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index fe201abae..4762c699d 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -10,13 +10,11 @@ import java.util.Map; import java.util.Objects; import java.util.function.BiConsumer; -import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.commons.lang3.StringUtils; import org.opensearch.env.Environment; -import org.opensearch.index.mapper.MapperService; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -51,6 +49,8 @@ public abstract class InferenceProcessor extends AbstractProcessor { private final Environment environment; + private final ProcessorInputValidator processorInputValidator; + public InferenceProcessor( String tag, String description, @@ -59,7 +59,8 @@ public InferenceProcessor( String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, - Environment environment + Environment environment, + ProcessorInputValidator processorInputValidator ) { super(tag, description); this.type = type; @@ -71,6 +72,7 @@ public 
InferenceProcessor( this.fieldMap = fieldMap; this.mlCommonsClientAccessor = clientAccessor; this.environment = environment; + this.processorInputValidator = processorInputValidator; } private void validateEmbeddingConfiguration(Map fieldMap) { @@ -106,13 +108,13 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception { @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { try { - validateEmbeddingFieldsValue(ingestDocument); - Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); - List inferenceList = createInferenceList(ProcessMap); + processorInputValidator.validateFieldsValue(fieldMap, environment, ingestDocument, false); + Map processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + List inferenceList = createInferenceList(processMap); if (inferenceList.size() == 0) { handler.accept(ingestDocument, null); } else { - doExecute(ingestDocument, ProcessMap, inferenceList, handler); + doExecute(ingestDocument, processMap, inferenceList, handler); } } catch (Exception e) { handler.accept(null, e); @@ -125,7 +127,13 @@ private List createInferenceList(Map knnKeyMap) { knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { Object sourceValue = knnMapEntry.getValue(); if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); + for (Object nestedValue : (List) sourceValue) { + if (nestedValue instanceof String) { + texts.add((String) nestedValue); + } else { + texts.addAll((List) nestedValue); + } + } } else if (sourceValue instanceof Map) { createInferenceListForMapTypeInput(sourceValue, texts); } else { @@ -204,68 +212,16 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType( } } - private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { - Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { - Object 
sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); - if (sourceValue != null) { - String sourceKey = embeddingFieldsEntry.getKey(); - Class sourceValueClass = sourceValue.getClass(); - if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { - validateNestedTypeValue(sourceKey, sourceValue, () -> 1); - } else if (!String.class.isAssignableFrom(sourceValueClass)) { - throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); - } - } - } - } - - @SuppressWarnings({ "rawtypes", "unchecked" }) - private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { - int maxDepth = maxDepthSupplier.get(); - if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); - } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { - validateListTypeValue(sourceKey, sourceValue, maxDepthSupplier); - } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { - ((Map) sourceValue).values() - .stream() - .filter(Objects::nonNull) - .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); - } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); - } else if (StringUtils.isBlank(sourceValue.toString())) { - throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); - } - } - - @SuppressWarnings({ "rawtypes" }) - private void validateListTypeValue(String sourceKey, Object sourceValue, Supplier 
maxDepthSupplier) { - for (Object value : (List) sourceValue) { - if (value instanceof Map) { - validateNestedTypeValue(sourceKey, value, () -> maxDepthSupplier.get() + 1); - } else if (value == null) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); - } else if (!(value instanceof String)) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); - } else if (StringUtils.isBlank(value.toString())) { - throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); - } - } - } - - protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List results) { + protected void setTargetFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List results) { Objects.requireNonNull(results, "embedding failed, inference returns null result!"); log.debug("Model inference result fetched, starting build vector output!"); - Map nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata()); - nlpResult.forEach(ingestDocument::setFieldValue); + Map result = buildResult(processorMap, results, ingestDocument.getSourceAndMetadata()); + result.forEach(ingestDocument::setFieldValue); } @SuppressWarnings({ "unchecked" }) @VisibleForTesting - Map buildNLPResult(Map processorMap, List results, Map sourceAndMetadataMap) { + Map buildResult(Map processorMap, List results, Map sourceAndMetadataMap) { IndexWrapper indexWrapper = new IndexWrapper(0); Map result = new LinkedHashMap<>(); for (Map.Entry knnMapEntry : processorMap.entrySet()) { @@ -274,16 +230,16 @@ Map buildNLPResult(Map processorMap, List res if (sourceValue instanceof String) { result.put(knnKey, results.get(indexWrapper.index++)); } else if (sourceValue instanceof List) { - result.put(knnKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper)); + result.put(knnKey, 
buildResultForListType((List) sourceValue, results, indexWrapper)); } else if (sourceValue instanceof Map) { - putNLPResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap); + putResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap); } } return result; } @SuppressWarnings({ "unchecked" }) - private void putNLPResultToSourceMapForMapType( + private void putResultToSourceMapForMapType( String processorKey, Object sourceValue, List results, @@ -294,12 +250,12 @@ private void putNLPResultToSourceMapForMapType( if (sourceValue instanceof Map) { for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { if (sourceAndMetadataMap.get(processorKey) instanceof List) { - // build nlp output for list of nested objects + // build output for list of nested objects for (Map nestedElement : (List>) sourceAndMetadataMap.get(processorKey)) { nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++)); } } else { - putNLPResultToSourceMapForMapType( + putResultToSourceMapForMapType( inputNestedMapEntry.getKey(), inputNestedMapEntry.getValue(), results, @@ -311,15 +267,27 @@ private void putNLPResultToSourceMapForMapType( } else if (sourceValue instanceof String) { sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++)); } else if (sourceValue instanceof List) { - sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper)); + sourceAndMetadataMap.put(processorKey, buildResultForListType((List) sourceValue, results, indexWrapper)); } } - private List> buildNLPResultForListType(List sourceValue, List results, IndexWrapper indexWrapper) { - List> keyToResult = new ArrayList<>(); - IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); - return keyToResult; + protected List buildResultForListType(List sourceValue, 
List results, IndexWrapper indexWrapper) { + Object peek = sourceValue.isEmpty() ? null : sourceValue.get(0); /* empty list: nothing to peek, falls through and returns an empty result */ + if (peek instanceof String) { + List> keyToResult = new ArrayList<>(); + IntStream.range(0, sourceValue.size()) + .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); + return keyToResult; + } else { + List>> keyToResult = new ArrayList<>(); + for (Object nestedList : sourceValue) { + List> nestedResult = new ArrayList<>(); + IntStream.range(0, ((List) nestedList).size()) + .forEachOrdered(x -> nestedResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); + keyToResult.add(nestedResult); + } + return keyToResult; + } } @Override diff --git a/src/main/java/org/opensearch/neuralsearch/processor/ProcessorInputValidator.java b/src/main/java/org/opensearch/neuralsearch/processor/ProcessorInputValidator.java new file mode 100644 index 000000000..57766c911 --- /dev/null +++ b/src/main/java/org/opensearch/neuralsearch/processor/ProcessorInputValidator.java @@ -0,0 +1,93 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.neuralsearch.processor; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.env.Environment; +import org.opensearch.index.mapper.MapperService; +import org.opensearch.ingest.IngestDocument; + +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.function.Supplier; + +public class ProcessorInputValidator { + + public void validateFieldsValue( + Map fieldMap, + Environment environment, + IngestDocument ingestDocument, + boolean allowEmpty + ) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + if (sourceValue 
!= null) { + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = 
sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, environment, allowEmpty, () -> 1); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); + } else if (!allowEmpty && StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); + } + } + } + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue( + String sourceKey, + Object sourceValue, + Environment environment, + boolean allowEmpty, + Supplier maxDepthSupplier + ) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); + } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { + validateListTypeValue(sourceKey, sourceValue, environment, allowEmpty, maxDepthSupplier); + } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, environment, allowEmpty, () -> maxDepth + 1)); + } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); + } else if (!allowEmpty && StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); + } + } + + @SuppressWarnings({ "rawtypes" }) + private void validateListTypeValue( + String sourceKey, + Object sourceValue, + Environment environment, + 
boolean allowEmpty, + Supplier maxDepthSupplier + ) { + for (Object value : (List) sourceValue) { + if (value instanceof Map) { + validateNestedTypeValue(sourceKey, value, environment, allowEmpty, () -> maxDepthSupplier.get() + 1); + } else if (value == null) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); + } else if (value instanceof List) { + for (Object nestedValue : (List) value) { /* iterate the nested list itself, not its parent list */ + if (!(nestedValue instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); + } + } + } else if (!(value instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); + } else if (!allowEmpty && StringUtils.isBlank(value.toString())) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); + } + } + } +} diff --git a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java index 8acf95bf7..1087651c2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/SparseEncodingProcessor.java @@ -32,9 +32,10 @@ public SparseEncodingProcessor( String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, - Environment environment + Environment environment, + ProcessorInputValidator processorInputValidator ) { - super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment); + super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, processorInputValidator); } @Override @@ -45,7 +46,7 @@ public void doExecute( BiConsumer handler ) { 
mlCommonsClientAccessor.inferenceSentencesWithMapResult(this.modelId, inferenceList, 
ActionListener.wrap(resultMaps -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); + setTargetFieldsToDocument(ingestDocument, ProcessMap, TokenWeightUtil.fetchListOfTokenWeightMap(resultMaps)); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java index c1b8f92a6..ce024cd87 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessor.java @@ -31,9 +31,10 @@ public TextEmbeddingProcessor( String modelId, Map fieldMap, MLCommonsClientAccessor clientAccessor, - Environment environment + Environment environment, + ProcessorInputValidator processorInputValidator ) { - super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment); + super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, modelId, fieldMap, clientAccessor, environment, processorInputValidator); } @Override @@ -44,7 +45,7 @@ public void doExecute( BiConsumer handler ) { mlCommonsClientAccessor.inferenceSentences(this.modelId, inferenceList, ActionListener.wrap(vectors -> { - setVectorFieldsToDocument(ingestDocument, ProcessMap, vectors); + setTargetFieldsToDocument(ingestDocument, ProcessMap, vectors); handler.accept(ingestDocument, null); }, e -> { handler.accept(null, e); })); } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java index 95b2803a0..d5b90c406 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/factory/SparseEncodingProcessorFactory.java @@ -15,6 +15,7 @@ import org.opensearch.env.Environment; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.ProcessorInputValidator; import org.opensearch.neuralsearch.processor.SparseEncodingProcessor; import lombok.extern.log4j.Log4j2; @@ -26,10 +27,16 @@ public class SparseEncodingProcessorFactory implements Processor.Factory { private final MLCommonsClientAccessor clientAccessor; private final Environment environment; + private ProcessorInputValidator processorInputValidator; - public SparseEncodingProcessorFactory(MLCommonsClientAccessor clientAccessor, Environment environment) { + public SparseEncodingProcessorFactory( + MLCommonsClientAccessor clientAccessor, + Environment environment, + ProcessorInputValidator processorInputValidator + ) { this.clientAccessor = clientAccessor; this.environment = environment; + this.processorInputValidator = processorInputValidator; } @Override @@ -42,6 +49,14 @@ public SparseEncodingProcessor create( String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); Map fieldMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); - return new SparseEncodingProcessor(processorTag, description, modelId, fieldMap, clientAccessor, environment); + return new SparseEncodingProcessor( + processorTag, + description, + modelId, + fieldMap, + clientAccessor, + environment, + processorInputValidator + ); } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java index 7802cb1f6..061ac3474 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java +++ 
b/src/main/java/org/opensearch/neuralsearch/processor/factory/TextEmbeddingProcessorFactory.java @@ -15,6 +15,7 @@ import org.opensearch.env.Environment; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; +import org.opensearch.neuralsearch.processor.ProcessorInputValidator; import org.opensearch.neuralsearch.processor.TextEmbeddingProcessor; /** @@ -26,9 +27,16 @@ public class TextEmbeddingProcessorFactory implements Processor.Factory { private final Environment environment; - public TextEmbeddingProcessorFactory(final MLCommonsClientAccessor clientAccessor, final Environment environment) { + private ProcessorInputValidator processorInputValidator; + + public TextEmbeddingProcessorFactory( + final MLCommonsClientAccessor clientAccessor, + final Environment environment, + ProcessorInputValidator processorInputValidator + ) { this.clientAccessor = clientAccessor; this.environment = environment; + this.processorInputValidator = processorInputValidator; } @Override @@ -40,6 +48,14 @@ public TextEmbeddingProcessor create( ) throws Exception { String modelId = readStringProperty(TYPE, processorTag, config, MODEL_ID_FIELD); Map filedMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); - return new TextEmbeddingProcessor(processorTag, description, modelId, filedMap, clientAccessor, environment); + return new TextEmbeddingProcessor( + processorTag, + description, + modelId, + filedMap, + clientAccessor, + environment, + processorInputValidator + ); } } diff --git a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java index 60408d820..25d41c345 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/TextEmbeddingProcessorTests.java @@ -357,7 +357,7 @@ public void testProcessResponse_successful() 
throws Exception { Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); - processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); + processor.setTargetFieldsToDocument(ingestDocument, knnMap, modelTensorList); assertEquals(12, ingestDocument.getSourceAndMetadata().size()); } @@ -378,7 +378,7 @@ public void testBuildVectorOutput_withPlainStringValue_successful() { assertEquals(knnKeyList.get(lastIndex), configValueList.get(lastIndex).toString()); List> modelTensorList = createMockVectorResult(); - Map result = processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + Map result = processor.buildResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); assertTrue(result.containsKey("oriKey1_knn")); assertTrue(result.containsKey("oriKey2_knn")); assertTrue(result.containsKey("oriKey3_knn")); @@ -395,7 +395,7 @@ public void testBuildVectorOutput_withNestedMap_successful() { TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); - processor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + processor.buildResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); Map favoritesMap = (Map) ingestDocument.getSourceAndMetadata().get("favorites"); assertNotNull(favoritesMap); Map favoriteGames = (Map) favoritesMap.get("favorite.games"); @@ -411,7 +411,7 @@ public void testBuildVectorOutput_withNestedList_successful() { TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); Map knnMap = textEmbeddingProcessor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); - textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, 
ingestDocument.getSourceAndMetadata()); + textEmbeddingProcessor.buildResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); List> nestedObj = (List>) ingestDocument.getSourceAndMetadata().get("nestedField"); assertTrue(nestedObj.get(0).containsKey("vectorField")); assertTrue(nestedObj.get(1).containsKey("vectorField")); @@ -425,7 +425,7 @@ public void testBuildVectorOutput_withNestedList_Level2_successful() { TextEmbeddingProcessor textEmbeddingProcessor = createInstanceWithNestedMapConfiguration(config); Map knnMap = textEmbeddingProcessor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); - textEmbeddingProcessor.buildNLPResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); + textEmbeddingProcessor.buildResult(knnMap, modelTensorList, ingestDocument.getSourceAndMetadata()); Map nestedLevel1 = (Map) ingestDocument.getSourceAndMetadata().get("nestedField"); List> nestedObj = (List>) nestedLevel1.get("nestedField"); assertTrue(nestedObj.get(0).containsKey("vectorField")); @@ -440,10 +440,10 @@ public void test_updateDocument_appendVectorFieldsToDocument_successful() { TextEmbeddingProcessor processor = createInstanceWithNestedMapConfiguration(config); Map knnMap = processor.buildMapWithProcessorKeyAndOriginalValue(ingestDocument); List> modelTensorList = createMockVectorResult(); - processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList); + processor.setTargetFieldsToDocument(ingestDocument, knnMap, modelTensorList); List> modelTensorList1 = createMockVectorResult(); - processor.setVectorFieldsToDocument(ingestDocument, knnMap, modelTensorList1); + processor.setTargetFieldsToDocument(ingestDocument, knnMap, modelTensorList1); assertEquals(12, ingestDocument.getSourceAndMetadata().size()); assertEquals(2, ((List) ingestDocument.getSourceAndMetadata().get("oriKey6_knn")).size()); }