diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index a1ea84a5a..fe201abae 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -10,11 +10,13 @@ import java.util.Map; import java.util.Objects; import java.util.function.BiConsumer; +import java.util.function.Supplier; import java.util.stream.Collectors; import java.util.stream.IntStream; import org.apache.commons.lang3.StringUtils; import org.opensearch.env.Environment; +import org.opensearch.index.mapper.MapperService; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.IngestDocument; import org.opensearch.neuralsearch.ml.MLCommonsClientAccessor; @@ -43,11 +45,11 @@ public abstract class InferenceProcessor extends AbstractProcessor { protected final String modelId; - protected final Map fieldMap; + private final Map fieldMap; protected final MLCommonsClientAccessor mlCommonsClientAccessor; - protected final Environment environment; + private final Environment environment; public InferenceProcessor( String tag, @@ -61,8 +63,9 @@ public InferenceProcessor( ) { super(tag, description); this.type = type; - validateEmbeddingConfiguration(fieldMap); if (StringUtils.isBlank(modelId)) throw new IllegalArgumentException("model_id is null or empty, cannot process it"); + validateEmbeddingConfiguration(fieldMap); + this.listTypeNestedMapKey = listTypeNestedMapKey; this.modelId = modelId; this.fieldMap = fieldMap; @@ -103,13 +106,13 @@ public IngestDocument execute(IngestDocument ingestDocument) throws Exception { @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { try { - processorInputValidator.validateFieldsValue(fieldMap, environment, ingestDocument, false); - Map processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); - List inferenceList = createInferenceList(processMap); - if (inferenceList.isEmpty()) { + validateEmbeddingFieldsValue(ingestDocument); + Map ProcessMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); + List inferenceList = createInferenceList(ProcessMap); + if (inferenceList.size() == 0) { handler.accept(ingestDocument, null); } else { - doExecute(ingestDocument, processMap, inferenceList, handler); + doExecute(ingestDocument, ProcessMap, inferenceList, handler); } } catch (Exception e) { handler.accept(null, e); @@ -117,18 +120,12 @@ public void execute(IngestDocument ingestDocument, BiConsumer createInferenceList(Map knnKeyMap) { + private List createInferenceList(Map knnKeyMap) { List texts = new ArrayList<>(); knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { Object sourceValue = knnMapEntry.getValue(); if (sourceValue instanceof List) { - for (Object nestedValue : (List) sourceValue) { - if (nestedValue instanceof String) { - texts.add((String) nestedValue); - } else { - texts.addAll((List) nestedValue); - } - } + texts.addAll(((List) sourceValue)); } else if (sourceValue instanceof Map) { createInferenceListForMapTypeInput(sourceValue, texts); } else { @@ -207,16 +204,68 @@ private void buildMapWithProcessorKeyAndOriginalValueForMapType( } } - protected void setTargetFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List results) { + private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + if (sourceValue != null) { + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, () -> 1); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); + } + } + } + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + int maxDepth = maxDepthSupplier.get(); + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); + } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { + validateListTypeValue(sourceKey, sourceValue, maxDepthSupplier); + } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, () -> maxDepth + 1)); + } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has empty string, cannot process it"); + } + } + + @SuppressWarnings({ "rawtypes" }) + private void validateListTypeValue(String sourceKey, Object sourceValue, Supplier maxDepthSupplier) { + for (Object value : (List) sourceValue) { + if (value instanceof Map) { + validateNestedTypeValue(sourceKey, value, () -> maxDepthSupplier.get() + 1); + } else if (value == null) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); + } else if (!(value instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); + } else if (StringUtils.isBlank(value.toString())) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); + } + } + } + + protected void setVectorFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List results) { Objects.requireNonNull(results, "embedding failed, inference returns null result!"); log.debug("Model inference result fetched, starting build vector output!"); - Map result = buildResult(processorMap, results, ingestDocument.getSourceAndMetadata()); - result.forEach(ingestDocument::setFieldValue); + Map nlpResult = buildNLPResult(processorMap, results, ingestDocument.getSourceAndMetadata()); + nlpResult.forEach(ingestDocument::setFieldValue); } @SuppressWarnings({ "unchecked" }) @VisibleForTesting - Map buildResult(Map processorMap, List results, Map sourceAndMetadataMap) { + Map buildNLPResult(Map processorMap, List results, Map sourceAndMetadataMap) { IndexWrapper indexWrapper = new IndexWrapper(0); Map result = new LinkedHashMap<>(); for (Map.Entry knnMapEntry : processorMap.entrySet()) { @@ -225,16 +274,16 @@ Map buildResult(Map processorMap, List result if (sourceValue instanceof String) { result.put(knnKey, results.get(indexWrapper.index++)); } else if (sourceValue instanceof List) { - result.put(knnKey, buildResultForListType((List) sourceValue, results, indexWrapper)); + result.put(knnKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper)); } else if (sourceValue instanceof Map) { - putResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap); + putNLPResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap); } } return result; } @SuppressWarnings({ "unchecked" }) - private void putResultToSourceMapForMapType( + private void putNLPResultToSourceMapForMapType( String processorKey, Object sourceValue, List results, @@ -245,12 +294,12 @@ private void putResultToSourceMapForMapType( if (sourceValue instanceof Map) { for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { if (sourceAndMetadataMap.get(processorKey) instanceof List) { - // build output for list of nested objects + // build nlp output for list of nested objects for (Map nestedElement : (List>) sourceAndMetadataMap.get(processorKey)) { nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++)); } } else { - putResultToSourceMapForMapType( + putNLPResultToSourceMapForMapType( inputNestedMapEntry.getKey(), inputNestedMapEntry.getValue(), results, @@ -262,27 +311,15 @@ private void putResultToSourceMapForMapType( } else if (sourceValue instanceof String) { sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++)); } else if (sourceValue instanceof List) { - sourceAndMetadataMap.put(processorKey, buildResultForListType((List) sourceValue, results, indexWrapper)); + sourceAndMetadataMap.put(processorKey, buildNLPResultForListType((List) sourceValue, results, indexWrapper)); } } - protected List buildResultForListType(List sourceValue, List results, IndexWrapper indexWrapper) { - Object peek = sourceValue.get(0); - if (peek instanceof String) { - List> keyToResult = new ArrayList<>(); - IntStream.range(0, sourceValue.size()) - .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); - return keyToResult; - } else { - List>> keyToResult = new ArrayList<>(); - for (Object nestedList : sourceValue) { - List> nestedResult = new ArrayList<>(); - IntStream.range(0, ((List) nestedList).size()) - .forEachOrdered(x -> nestedResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); - keyToResult.add(nestedResult); - } - return keyToResult; - } + private List> buildNLPResultForListType(List sourceValue, List results, IndexWrapper indexWrapper) { + List> keyToResult = new ArrayList<>(); + IntStream.range(0, sourceValue.size()) + .forEachOrdered(x -> keyToResult.add(ImmutableMap.of(listTypeNestedMapKey, results.get(indexWrapper.index++)))); + return keyToResult; } @Override @@ -299,7 +336,7 @@ public String getType() { * index: the index pointer of the text embedding result. */ static class IndexWrapper { - protected int index; + private int index; protected IndexWrapper(int index) { this.index = index;