diff --git a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java index dc5b6e8f2..80fcf90f4 100644 --- a/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java +++ b/src/main/java/org/opensearch/neuralsearch/plugin/NeuralSearch.java @@ -118,7 +118,7 @@ public Map getProcessors(Processor.Parameters paramet new TextImageEmbeddingProcessorFactory(clientAccessor, parameters.env, parameters.ingestService.getClusterService()), DocumentChunkingProcessor.TYPE, new DocumentChunkingProcessor.Factory( - parameters.env.settings(), + parameters.env, parameters.ingestService.getClusterService(), parameters.indicesService, parameters.analysisRegistry diff --git a/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java index 290756c12..957652638 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java @@ -9,17 +9,15 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.LinkedHashMap; -import java.util.stream.Collectors; -import java.util.stream.IntStream; -import com.google.common.annotations.VisibleForTesting; import lombok.extern.log4j.Log4j2; +import org.apache.commons.lang3.StringUtils; import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.env.Environment; import org.opensearch.index.IndexService; import org.opensearch.cluster.service.ClusterService; -import org.opensearch.common.settings.Settings; import org.opensearch.index.analysis.AnalysisRegistry; +import org.opensearch.index.mapper.MapperService; import org.opensearch.indices.IndicesService; import org.opensearch.index.IndexSettings; import org.opensearch.ingest.AbstractProcessor; @@ -45,8 +43,6 @@ public final class DocumentChunkingProcessor extends AbstractProcessor { private final Set supportedChunkers = ChunkerFactory.getAllChunkers(); - private final Settings settings; - private String chunkerType; private Map chunkerParameters; @@ -59,12 +55,14 @@ public final class DocumentChunkingProcessor extends AbstractProcessor { private final AnalysisRegistry analysisRegistry; + private final Environment environment; + public DocumentChunkingProcessor( String tag, String description, Map fieldMap, Map algorithmMap, - Settings settings, + Environment environment, ClusterService clusterService, IndicesService indicesService, AnalysisRegistry analysisRegistry @@ -72,7 +70,7 @@ public DocumentChunkingProcessor( super(tag, description); validateAndParseAlgorithmMap(algorithmMap); this.fieldMap = fieldMap; - this.settings = settings; + this.environment = environment; this.clusterService = clusterService; this.indicesService = indicesService; this.analysisRegistry = analysisRegistry; @@ -82,12 +80,6 @@ public String getType() { return TYPE; } - private List chunk(String content) { - // assume that content is either a map, list or string - IFieldChunker chunker = ChunkerFactory.create(chunkerType, analysisRegistry); - return chunker.chunk(content, chunkerParameters); - } - @SuppressWarnings("unchecked") private void validateAndParseAlgorithmMap(Map algorithmMap) { if (algorithmMap.size() != 1) { @@ -120,23 +112,56 @@ private void validateAndParseAlgorithmMap(Map algorithmMap) { } } - @Override - public IngestDocument execute(IngestDocument ingestDocument) { - Map processMap = buildMapWithProcessorKeyAndOriginalValue(ingestDocument); - List inferenceList = createInferenceList(processMap); - if (inferenceList.isEmpty()) { - return ingestDocument; - } else { - return doExecute(ingestDocument, processMap, inferenceList); + @SuppressWarnings("unchecked") + private boolean isListString(Object value) { + // an empty list is also List + if (!(value instanceof List)) { + return false; + } + for (Object element : (List) value) { + if (!(element instanceof String)) { + return false; + } + } + return true; + } + + private List chunkString(String content) { + // assume that content is either a map, list or string + IFieldChunker chunker = ChunkerFactory.create(chunkerType, analysisRegistry); + return chunker.chunk(content, chunkerParameters); + } + + private List chunkList(List contentList) { + // flatten the List> output to List + List result = new ArrayList<>(); + for (String content : contentList) { + result.addAll(chunkString(content)); + } + return result; + } + + @SuppressWarnings("unchecked") + private List chunkLeafType(Object value) { + // leaf type is either String or List + List chunkedResult = null; + if (value instanceof String) { + chunkedResult = chunkString(String.valueOf(value)); + } else if (isListString(value)) { + chunkedResult = chunkList((List) value); } + return chunkedResult; } - public IngestDocument doExecute(IngestDocument ingestDocument, Map ProcessMap, List inferenceList) { + @Override + public IngestDocument execute(IngestDocument ingestDocument) { + validateEmbeddingFieldsValue(ingestDocument); + if (Objects.equals(chunkerType, FIXED_LENGTH_ALGORITHM)) { // add maxTokenCount setting from index metadata to chunker parameters Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); - int maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(settings); + int maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(environment.settings()); IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName); if (indexMetadata != null) { // if the index exists, read maxTokenCount from the index setting @@ -146,182 +171,93 @@ public IngestDocument doExecute(IngestDocument ingestDocument, Map> chunkedResults = new ArrayList<>(); - for (String inferenceString : inferenceList) { - chunkedResults.add(chunk(inferenceString)); - } - setTargetFieldsToDocument(ingestDocument, ProcessMap, chunkedResults); + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + chunkMapType(sourceAndMetadataMap, fieldMap); + sourceAndMetadataMap.forEach(ingestDocument::setFieldValue); return ingestDocument; } - private List buildResultForListType(List sourceValue, List results, InferenceProcessor.IndexWrapper indexWrapper) { - Object peek = sourceValue.get(0); - if (peek instanceof String) { - List keyToResult = new ArrayList<>(); - IntStream.range(0, sourceValue.size()).forEachOrdered(x -> keyToResult.add(results.get(indexWrapper.index++))); - return keyToResult; - } else { - List> keyToResult = new ArrayList<>(); - for (Object nestedList : sourceValue) { - List nestedResult = new ArrayList<>(); - IntStream.range(0, ((List) nestedList).size()).forEachOrdered(x -> nestedResult.add(results.get(indexWrapper.index++))); - keyToResult.add(nestedResult); - } - return keyToResult; - } - } - - private Map buildMapWithProcessorKeyAndOriginalValue(IngestDocument ingestDocument) { + private void validateEmbeddingFieldsValue(IngestDocument ingestDocument) { Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - Map mapWithProcessorKeys = new LinkedHashMap<>(); - for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { - String originalKey = fieldMapEntry.getKey(); - Object targetKey = fieldMapEntry.getValue(); - if (targetKey instanceof Map) { - Map treeRes = new LinkedHashMap<>(); - buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, treeRes); - mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); - } else { - mapWithProcessorKeys.put(String.valueOf(targetKey), sourceAndMetadataMap.get(originalKey)); - } - } - return mapWithProcessorKeys; - } - - private void buildMapWithProcessorKeyAndOriginalValueForMapType( - String parentKey, - Object processorKey, - Map sourceAndMetadataMap, - Map treeRes - ) { - if (processorKey == null || sourceAndMetadataMap == null) return; - if (processorKey instanceof Map) { - Map next = new LinkedHashMap<>(); - if (sourceAndMetadataMap.get(parentKey) instanceof Map) { - for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { - buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - next - ); - } - } else if (sourceAndMetadataMap.get(parentKey) instanceof List) { - for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { - List> list = (List>) sourceAndMetadataMap.get(parentKey); - List listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList()); - Map map = new LinkedHashMap<>(); - map.put(nestedFieldMapEntry.getKey(), listOfStrings); - buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - map, - next - ); + for (Map.Entry embeddingFieldsEntry : fieldMap.entrySet()) { + Object sourceValue = sourceAndMetadataMap.get(embeddingFieldsEntry.getKey()); + if (sourceValue != null) { + String sourceKey = embeddingFieldsEntry.getKey(); + Class sourceValueClass = sourceValue.getClass(); + if (List.class.isAssignableFrom(sourceValueClass) || Map.class.isAssignableFrom(sourceValueClass)) { + validateNestedTypeValue(sourceKey, sourceValue, 1); + } else if (!String.class.isAssignableFrom(sourceValueClass)) { + throw new IllegalArgumentException("field [" + sourceKey + "] is neither string nor nested type, cannot process it"); + } else if (StringUtils.isBlank(sourceValue.toString())) { + throw new IllegalArgumentException("field [" + sourceKey + "] has empty string value, cannot process it"); } } - treeRes.put(parentKey, next); - } else { - String key = String.valueOf(processorKey); - treeRes.put(key, sourceAndMetadataMap.get(parentKey)); } } - @SuppressWarnings({ "unchecked" }) - private List createInferenceList(Map knnKeyMap) { - List texts = new ArrayList<>(); - knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof List) { - for (Object nestedValue : (List) sourceValue) { - if (nestedValue instanceof String) { - texts.add((String) nestedValue); - } else { - texts.addAll((List) nestedValue); - } - } - } else if (sourceValue instanceof Map) { - createInferenceListForMapTypeInput(sourceValue, texts); - } else { - texts.add(sourceValue.toString()); - } - }); - return texts; - } - - @SuppressWarnings("unchecked") - private void createInferenceListForMapTypeInput(Object sourceValue, List texts) { - if (sourceValue instanceof Map) { - ((Map) sourceValue).forEach((k, v) -> createInferenceListForMapTypeInput(v, texts)); - } else if (sourceValue instanceof List) { - texts.addAll(((List) sourceValue)); - } else { - if (sourceValue == null) return; - texts.add(sourceValue.toString()); + @SuppressWarnings({ "rawtypes", "unchecked" }) + private void validateNestedTypeValue(String sourceKey, Object sourceValue, int maxDepth) { + if (maxDepth > MapperService.INDEX_MAPPING_DEPTH_LIMIT_SETTING.get(environment.settings())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] reached max depth limit, cannot process it"); + } else if ((List.class.isAssignableFrom(sourceValue.getClass()))) { + validateListTypeValue(sourceKey, sourceValue, maxDepth); + } else if (Map.class.isAssignableFrom(sourceValue.getClass())) { + ((Map) sourceValue).values() + .stream() + .filter(Objects::nonNull) + .forEach(x -> validateNestedTypeValue(sourceKey, x, maxDepth + 1)); + } else if (!String.class.isAssignableFrom(sourceValue.getClass())) { + throw new IllegalArgumentException("map type field [" + sourceKey + "] has non-string type, cannot process it"); } } - private void setTargetFieldsToDocument(IngestDocument ingestDocument, Map processorMap, List results) { - Objects.requireNonNull(results, "embedding failed, inference returns null result!"); - log.debug("Model inference result fetched, starting build vector output!"); - Map result = buildResult(processorMap, results, ingestDocument.getSourceAndMetadata()); - result.forEach(ingestDocument::setFieldValue); - } - - @VisibleForTesting - Map buildResult(Map processorMap, List results, Map sourceAndMetadataMap) { - InferenceProcessor.IndexWrapper indexWrapper = new InferenceProcessor.IndexWrapper(0); - Map result = new LinkedHashMap<>(); - for (Map.Entry knnMapEntry : processorMap.entrySet()) { - String knnKey = knnMapEntry.getKey(); - Object sourceValue = knnMapEntry.getValue(); - if (sourceValue instanceof String) { - result.put(knnKey, results.get(indexWrapper.index++)); - } else if (sourceValue instanceof List) { - result.put(knnKey, buildResultForListType((List) sourceValue, results, indexWrapper)); - } else if (sourceValue instanceof Map) { - putResultToSourceMapForMapType(knnKey, sourceValue, results, indexWrapper, sourceAndMetadataMap); + @SuppressWarnings({ "rawtypes" }) + private void validateListTypeValue(String sourceKey, Object sourceValue, int maxDepth) { + for (Object value : (List) sourceValue) { + if (value instanceof Map) { + validateNestedTypeValue(sourceKey, value, maxDepth + 1); + } else if (value == null) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has null, cannot process it"); + } else if (!(value instanceof String)) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has non string value, cannot process it"); + } else if (StringUtils.isBlank(value.toString())) { + throw new IllegalArgumentException("list type field [" + sourceKey + "] has empty string, cannot process it"); } } - return result; } - @SuppressWarnings({ "unchecked" }) - private void putResultToSourceMapForMapType( - String processorKey, - Object sourceValue, - List results, - InferenceProcessor.IndexWrapper indexWrapper, - Map sourceAndMetadataMap - ) { - if (processorKey == null || sourceAndMetadataMap == null || sourceValue == null) return; - if (sourceValue instanceof Map) { - for (Map.Entry inputNestedMapEntry : ((Map) sourceValue).entrySet()) { - if (sourceAndMetadataMap.get(processorKey) instanceof List) { - // build output for list of nested objects - for (Map nestedElement : (List>) sourceAndMetadataMap.get(processorKey)) { - nestedElement.put(inputNestedMapEntry.getKey(), results.get(indexWrapper.index++)); + @SuppressWarnings("unchecked") + private void chunkMapType(Map sourceAndMetadataMap, Map fieldMap) { + for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { + String originalKey = fieldMapEntry.getKey(); + Object targetKey = fieldMapEntry.getValue(); + if (targetKey instanceof Map) { + // call this method recursively when target key is a map + Object sourceObject = sourceAndMetadataMap.get(originalKey); + if (sourceObject instanceof List) { + List sourceObjectList = (List) sourceObject; + for (Object source : sourceObjectList) { + if (source instanceof Map) { + chunkMapType((Map) source, (Map) targetKey); + } } - } else { - putResultToSourceMapForMapType( - inputNestedMapEntry.getKey(), - inputNestedMapEntry.getValue(), - results, - indexWrapper, - (Map) sourceAndMetadataMap.get(processorKey) - ); + } else if (sourceObject instanceof Map) { + chunkMapType((Map) sourceObject, (Map) targetKey); + } + } else { + // chunk the object when target key is a string + Object chunkObject = sourceAndMetadataMap.get(originalKey); + List chunkedResult = chunkLeafType(chunkObject); + if (chunkedResult != null) { + sourceAndMetadataMap.put(String.valueOf(targetKey), chunkedResult); } } - } else if (sourceValue instanceof String) { - sourceAndMetadataMap.put(processorKey, results.get(indexWrapper.index++)); - } else if (sourceValue instanceof List) { - sourceAndMetadataMap.put(processorKey, buildResultForListType((List) sourceValue, results, indexWrapper)); } } public static class Factory implements Processor.Factory { - private final Settings settings; + private final Environment environment; private final ClusterService clusterService; @@ -329,8 +265,13 @@ public static class Factory implements Processor.Factory { private final AnalysisRegistry analysisRegistry; - public Factory(Settings settings, ClusterService clusterService, IndicesService indicesService, AnalysisRegistry analysisRegistry) { - this.settings = settings; + public Factory( + Environment environment, + ClusterService clusterService, + IndicesService indicesService, + AnalysisRegistry analysisRegistry + ) { + this.environment = environment; this.clusterService = clusterService; this.indicesService = indicesService; this.analysisRegistry = analysisRegistry; @@ -350,7 +291,7 @@ public DocumentChunkingProcessor create( description, fieldMap, algorithmMap, - settings, + environment, clusterService, indicesService, analysisRegistry