From 59f8eac36177288b9b0a8d913fd0ee8b7d0a2aee Mon Sep 17 00:00:00 2001 From: yuye-aws Date: Tue, 5 Mar 2024 16:13:19 +0800 Subject: [PATCH] api refactor for document chunking processor Signed-off-by: yuye-aws --- .../processor/DocumentChunkingProcessor.java | 302 ++++-------------- .../processor/InferenceProcessor.java | 2 +- .../DocumentChunkingProcessorTests.java | 56 +--- 3 files changed, 75 insertions(+), 285 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java index a5312c04d..34a02eb35 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java @@ -4,42 +4,50 @@ */ package org.opensearch.neuralsearch.processor; -import java.util.HashMap; import java.util.Map; import java.util.Set; import java.util.ArrayList; import java.util.List; -import java.util.LinkedHashMap; +import java.util.Objects; import java.util.function.BiConsumer; -import java.util.stream.Collectors; +import org.opensearch.cluster.metadata.IndexMetadata; +import org.opensearch.index.IndexService; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; import org.opensearch.env.Environment; import org.opensearch.index.analysis.AnalysisRegistry; import org.opensearch.indices.IndicesService; +import org.opensearch.index.IndexSettings; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.Processor; import org.opensearch.neuralsearch.processor.chunker.ChunkerFactory; import org.opensearch.neuralsearch.processor.chunker.IFieldChunker; +import org.opensearch.index.mapper.IndexFieldMapper; +import org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker; import static org.opensearch.ingest.ConfigurationUtils.readMap; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerFactory.DELIMITER_ALGORITHM; +import static org.opensearch.neuralsearch.processor.chunker.ChunkerFactory.FIXED_LENGTH_ALGORITHM; public final class DocumentChunkingProcessor extends InferenceProcessor { public static final String TYPE = "chunking"; - public static final String OUTPUT_FIELD = "output_field"; public static final String FIELD_MAP_FIELD = "field_map"; - public static final String LIST_TYPE_NESTED_MAP_KEY = "chunking"; + public static final String ALGORITHM_FIELD = "algorithm"; - private final Map chunkingFieldMap; + public static final String LIST_TYPE_NESTED_MAP_KEY = "chunking"; private final Set supportedChunkers = ChunkerFactory.getAllChunkers(); private final Settings settings; + private String chunkerType; + + private Map chunkerParameters; + private final ClusterService clusterService; private final IndicesService indicesService; @@ -50,6 +58,7 @@ public DocumentChunkingProcessor( String tag, String description, Map fieldMap, + Map algorithmMap, Settings settings, ClusterService clusterService, IndicesService indicesService, @@ -57,20 +66,9 @@ public DocumentChunkingProcessor( Environment environment, ProcessorInputValidator processorInputValidator ) { - super( - tag, - description, - TYPE, - LIST_TYPE_NESTED_MAP_KEY, - null, - transformFieldMap(fieldMap), - null, - environment, - processorInputValidator - ); - validateFieldMap(fieldMap, ""); + super(tag, description, TYPE, LIST_TYPE_NESTED_MAP_KEY, null, fieldMap, null, environment, processorInputValidator); + validateAndParseAlgorithmMap(algorithmMap); this.settings = settings; - this.chunkingFieldMap = fieldMap; this.clusterService = clusterService; this.indicesService = indicesService; this.analysisRegistry = analysisRegistry; @@ -80,115 +78,44 @@ public String getType() { return TYPE; } - private Object chunk(IFieldChunker chunker, Object content, Map chunkerParameters) { + private List chunk(String content) { // assume that content is either a map, list or string - if (content instanceof Map) { - Map chunkedPassageMap = new HashMap<>(); - Map contentMap = (Map) content; - for (Map.Entry contentEntry : contentMap.entrySet()) { - String contentKey = contentEntry.getKey(); - Object contentValue = contentEntry.getValue(); - // contentValue can also be a map, list or string - chunkedPassageMap.put(contentKey, chunk(chunker, contentValue, chunkerParameters)); - } - return chunkedPassageMap; - } else if (content instanceof List) { - List chunkedPassageList = new ArrayList<>(); - List contentList = (List) content; - for (Object contentElement : contentList) { - chunkedPassageList.addAll(chunker.chunk((String) contentElement, chunkerParameters)); - } - return chunkedPassageList; - } else { - return chunker.chunk((String) content, chunkerParameters); - } + IFieldChunker chunker = ChunkerFactory.create(chunkerType, analysisRegistry); + return chunker.chunk(content, chunkerParameters); } @SuppressWarnings("unchecked") - private void validateFieldMap(Map fieldMap, String inputPrefix) { - for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { - String inputField = fieldMapEntry.getKey(); - if (fieldMapEntry.getValue() instanceof Map) { - Map insideFieldMap = (Map) fieldMapEntry.getValue(); - if (insideFieldMap.containsKey(OUTPUT_FIELD)) { - validateChunkingFieldMap(insideFieldMap, inputPrefix + "." + inputField); - } else { - validateFieldMap(insideFieldMap, inputPrefix + "." + inputField); - } - } - } - } - - @SuppressWarnings("unchecked") - private void validateChunkingFieldMap(Map fieldMap, String inputField) { - // this function validates the parameters for chunking processors with: - // 1. the output field is a string - // 2. the chunker parameters must include and only include 1 type of chunker - // 3. the chunker parameters must be validated by each algorithm - Object outputField = fieldMap.get(OUTPUT_FIELD); - - if (!(outputField instanceof String)) { + private void validateAndParseAlgorithmMap(Map algorithmMap) { + if (algorithmMap.size() != 1) { throw new IllegalArgumentException( - "parameters for output field [" + OUTPUT_FIELD + "] cannot be cast to [" + String.class.getName() + "]" + "Unable to create the processor as [" + ALGORITHM_FIELD + "] must contain and only contain 1 algorithm" ); } - // check non string parameter key - // validate each algorithm - int chunkingAlgorithmCount = 0; - Map chunkerParameters; - for (Map.Entry parameterEntry : fieldMap.entrySet()) { - if (!(parameterEntry.getKey() instanceof String)) { - throw new IllegalArgumentException("found parameter entry with non-string key"); + for (Map.Entry algorithmEntry : algorithmMap.entrySet()) { + String algorithmKey = algorithmEntry.getKey(); + Object algorithmValue = algorithmEntry.getValue(); + if (!supportedChunkers.contains(algorithmKey)) { + throw new IllegalArgumentException( + "Unable to create the processor as chunker algorithm [" + + algorithmKey + + "] is not supported. Supported chunkers types are [" + + FIXED_LENGTH_ALGORITHM + + ", " + + DELIMITER_ALGORITHM + + "]" + ); } - String parameterKey = (String) parameterEntry.getKey(); - if (supportedChunkers.contains(parameterKey)) { - chunkingAlgorithmCount += 1; - chunkerParameters = (Map) parameterEntry.getValue(); - IFieldChunker chunker = ChunkerFactory.create(parameterKey, analysisRegistry); - chunker.validateParameters(chunkerParameters); + if (!(algorithmValue instanceof Map)) { + throw new IllegalArgumentException( + "Unable to create the processor as [" + ALGORITHM_FIELD + "] cannot be cast to [" + Map.class.getName() + "]" + ); } - } - - // should only define one algorithm - if (chunkingAlgorithmCount != 1) { - throw new IllegalArgumentException("input field [" + inputField + "] should has and only has 1 chunking algorithm"); + this.chunkerType = algorithmKey; + this.chunkerParameters = (Map) algorithmValue; } } - private static Map transformFieldMap(Map fieldMap) { - // transform the into field map for inference processor - Map transformedFieldMap = new HashMap<>(); - for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { - String inputField = fieldMapEntry.getKey(); - if (fieldMapEntry.getValue() instanceof Map) { - Map insideFieldMap = (Map) fieldMapEntry.getValue(); - if (insideFieldMap.containsKey(OUTPUT_FIELD)) { - Object outputField = insideFieldMap.get(OUTPUT_FIELD); - transformedFieldMap.put(inputField, outputField); - } else { - transformedFieldMap.put(inputField, transformFieldMap(insideFieldMap)); - } - } - } - return transformedFieldMap; - } - - private List> chunk(List contents, Map> parameter) { - // parameter only contains 1 key defining chunker type - // its value should be chunking parameters - List> chunkedContents = new ArrayList<>(); - for (Map.Entry> parameterEntry : parameter.entrySet()) { - String type = parameterEntry.getKey(); - Map chunkerParameters = parameterEntry.getValue(); - IFieldChunker chunker = ChunkerFactory.create(type, analysisRegistry); - for (String content : contents) { - chunkedContents.add(chunker.chunk(content, chunkerParameters)); - } - } - return chunkedContents; - } - @Override public void doExecute( IngestDocument ingestDocument, @@ -196,114 +123,30 @@ public void doExecute( List inferenceList, BiConsumer handler ) { - throw new RuntimeException("method doExecute() not implemented in document chunking processor"); - } - - @Override - public void execute(IngestDocument ingestDocument, BiConsumer handler) { try { processorInputValidator.validateFieldsValue(fieldMap, environment, ingestDocument, false); - Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - Map processMap = buildMapWithProcessorKeyAndOriginalValue(sourceAndMetadataMap, chunkingFieldMap); - // List inferenceList = createInferenceList(processMap); - // List> results = chunk(processMap); - // setTargetFieldsToDocument(ingestDocument, processMap, results); - handler.accept(ingestDocument, null); - /* - if (inferenceList.isEmpty()) { - handler.accept(ingestDocument, null); - } else { - // perform chunking - List> results = chunk(inferenceList, processMap); - setTargetFieldsToDocument(ingestDocument, processMap, results); - doExecute(ingestDocument, processMap, inferenceList, handler); - handler.accept(ingestDocument, null); - } - */ - } catch (Exception e) { - handler.accept(null, e); - } - } - - Map buildMapWithProcessorKeyAndOriginalValue( - Map sourceAndMetadataMap, - Map chunkingFieldMap - ) { - // the leaf map for processMap contains two key value pairs - // parameters: the chunker parameters, Map - // inferenceList: a list of strings to be chunked, List - Map mapWithProcessorKeys = new LinkedHashMap<>(); - for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { - String originalKey = fieldMapEntry.getKey(); - Object targetKey = fieldMapEntry.getValue(); - if (targetKey instanceof Map) { - Map treeRes = new LinkedHashMap<>(); - buildMapWithProcessorKeyAndOriginalValueForMapType(originalKey, targetKey, sourceAndMetadataMap, chunkingFieldMap, treeRes); - mapWithProcessorKeys.put(originalKey, treeRes.get(originalKey)); - } else { - Map leafMap = new HashMap<>(); - leafMap.put("parameters", chunkingFieldMap.get(originalKey)); - Object inferenceObject = sourceAndMetadataMap.get(originalKey); - // inferenceObject is either a string or a list of strings - if (inferenceObject instanceof List) { - leafMap.put("inferenceList", inferenceObject); - } else { - leafMap.put("inferenceList", stringToList((String) inferenceObject)); + if (Objects.equals(chunkerType, FIXED_LENGTH_ALGORITHM)) { + // add maxTokenCount setting from index metadata to chunker parameters + Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); + String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); + int maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(settings); + IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName); + if (indexMetadata != null) { + // if the index exists, read maxTokenCount from the index setting + IndexService indexService = indicesService.indexServiceSafe(indexMetadata.getIndex()); + maxTokenCount = indexService.getIndexSettings().getMaxTokenCount(); } - mapWithProcessorKeys.put(String.valueOf(targetKey), leafMap); + chunkerParameters.put(FixedTokenLengthChunker.MAX_TOKEN_COUNT_FIELD, maxTokenCount); } - } - return mapWithProcessorKeys; - } - @SuppressWarnings("unchecked") - private void buildMapWithProcessorKeyAndOriginalValueForMapType( - String parentKey, - Object processorKey, - Map sourceAndMetadataMap, - Map chunkingFieldMap, - Map treeRes - ) { - if (processorKey == null || sourceAndMetadataMap == null || chunkingFieldMap == null) return; - if (processorKey instanceof Map) { - Map next = new LinkedHashMap<>(); - if (sourceAndMetadataMap.get(parentKey) instanceof Map) { - for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { - buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - (Map) sourceAndMetadataMap.get(parentKey), - (Map) chunkingFieldMap.get(parentKey), - next - ); - } - } else if (sourceAndMetadataMap.get(parentKey) instanceof List) { - for (Map.Entry nestedFieldMapEntry : ((Map) processorKey).entrySet()) { - List> list = (List>) sourceAndMetadataMap.get(parentKey); - List listOfStrings = list.stream().map(x -> x.get(nestedFieldMapEntry.getKey())).collect(Collectors.toList()); - Map map = new LinkedHashMap<>(); - map.put(nestedFieldMapEntry.getKey(), listOfStrings); - buildMapWithProcessorKeyAndOriginalValueForMapType( - nestedFieldMapEntry.getKey(), - nestedFieldMapEntry.getValue(), - map, - (Map) chunkingFieldMap.get(nestedFieldMapEntry.getKey()), - next - ); - } - } - treeRes.put(parentKey, next); - } else { - Map leafMap = new HashMap<>(); - leafMap.put("parameters", chunkingFieldMap.get(parentKey)); - Object inferenceObject = sourceAndMetadataMap.get(parentKey); - // inferenceObject is either a string or a list of strings - if (inferenceObject instanceof List) { - leafMap.put("inferenceList", inferenceObject); - } else { - leafMap.put("inferenceList", stringToList((String) inferenceObject)); + List> chunkedResults = new ArrayList<>(); + for (String inferenceString : inferenceList) { + chunkedResults.add(chunk(inferenceString)); } - treeRes.put(parentKey, leafMap); + setTargetFieldsToDocument(ingestDocument, ProcessMap, chunkedResults); + handler.accept(ingestDocument, null); + } catch (Exception e) { + handler.accept(null, e); } } @@ -313,29 +156,6 @@ private static List stringToList(String string) { return list; } - @SuppressWarnings("unchecked") - private List> createInferenceList(Map processMap) { - List> texts = new ArrayList<>(); - processMap.entrySet().stream().filter(processMapEntry -> processMapEntry.getValue() != null).forEach(processMapEntry -> { - Map sourceValue = (Map) processMapEntry.getValue(); - // get "inferenceList" key - createInferenceListForMapTypeInput(sourceValue, texts); - }); - return texts; - } - - @SuppressWarnings("unchecked") - private void createInferenceListForMapTypeInput(Map mapInput, List> texts) { - if (mapInput.containsKey("inferenceList")) { - texts.add((List) mapInput.get("inferenceList")); - return; - } - for (Map.Entry nestedFieldMapEntry : mapInput.entrySet()) { - Map nestedMapInput = (Map) nestedFieldMapEntry.getValue(); - createInferenceListForMapTypeInput(nestedMapInput, texts); - } - } - public static class Factory implements Processor.Factory { private final Settings settings; @@ -374,10 +194,12 @@ public DocumentChunkingProcessor create( Map config ) throws Exception { Map fieldMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD); + Map algorithmMap = readMap(TYPE, processorTag, config, ALGORITHM_FIELD); return new DocumentChunkingProcessor( processorTag, description, fieldMap, + algorithmMap, settings, clusterService, indicesService, diff --git a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java index bf73be795..b88bee0e1 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/InferenceProcessor.java @@ -121,7 +121,7 @@ public void execute(IngestDocument ingestDocument, BiConsumer createInferenceList(Map knnKeyMap) { + protected List createInferenceList(Map knnKeyMap) { List texts = new ArrayList<>(); knnKeyMap.entrySet().stream().filter(knnMapEntry -> knnMapEntry.getValue() != null).forEach(knnMapEntry -> { Object sourceValue = knnMapEntry.getValue(); diff --git a/src/test/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessorTests.java b/src/test/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessorTests.java index 61590c5ed..c70d64d32 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessorTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessorTests.java @@ -110,12 +110,12 @@ private Map createDelimiterParameters() { @SneakyThrows private DocumentChunkingProcessor createFixedTokenLengthInstance() { Map config = new HashMap<>(); - Map fieldParameters = new HashMap<>(); - Map chunkerParameters = new HashMap<>(); - chunkerParameters.put(ChunkerFactory.FIXED_LENGTH_ALGORITHM, createFixedTokenLengthParameters()); - chunkerParameters.put(DocumentChunkingProcessor.OUTPUT_FIELD, OUTPUT_FIELD); - fieldParameters.put(INPUT_FIELD, chunkerParameters); - config.put(DocumentChunkingProcessor.FIELD_MAP_FIELD, fieldParameters); + Map fieldMap = new HashMap<>(); + Map algorithmMap = new HashMap<>(); + algorithmMap.put(ChunkerFactory.FIXED_LENGTH_ALGORITHM, createFixedTokenLengthParameters()); + fieldMap.put(INPUT_FIELD, OUTPUT_FIELD); + config.put(DocumentChunkingProcessor.FIELD_MAP_FIELD, fieldMap); + config.put(DocumentChunkingProcessor.ALGORITHM_FIELD, algorithmMap); Map registry = new HashMap<>(); return factory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @@ -123,12 +123,12 @@ private DocumentChunkingProcessor createFixedTokenLengthInstance() { @SneakyThrows private DocumentChunkingProcessor createDelimiterInstance() { Map config = new HashMap<>(); - Map fieldParameters = new HashMap<>(); - Map chunkerParameters = new HashMap<>(); - chunkerParameters.put(ChunkerFactory.DELIMITER_ALGORITHM, createDelimiterParameters()); - chunkerParameters.put(DocumentChunkingProcessor.OUTPUT_FIELD, OUTPUT_FIELD); - fieldParameters.put(INPUT_FIELD, chunkerParameters); - config.put(DocumentChunkingProcessor.FIELD_MAP_FIELD, fieldParameters); + Map fieldMap = new HashMap<>(); + Map algorithmMap = new HashMap<>(); + algorithmMap.put(ChunkerFactory.DELIMITER_ALGORITHM, createDelimiterParameters()); + fieldMap.put(INPUT_FIELD, OUTPUT_FIELD); + config.put(DocumentChunkingProcessor.FIELD_MAP_FIELD, fieldMap); + config.put(DocumentChunkingProcessor.ALGORITHM_FIELD, algorithmMap); Map registry = new HashMap<>(); return factory.create(registry, PROCESSOR_TAG, DESCRIPTION, config); } @@ -171,38 +171,6 @@ public void testCreate_whenFieldMapWithIllegalParameterType_failure() { assertEquals("parameters for input field [key] cannot be cast to [java.util.Map]", illegalArgumentException.getMessage()); } - public void testCreate_whenFieldMapWithEmptyOutputField_failure() { - Map config = new HashMap<>(); - Map fieldMap = new HashMap<>(); - fieldMap.put(INPUT_FIELD, ImmutableMap.of()); - config.put(DocumentChunkingProcessor.FIELD_MAP_FIELD, fieldMap); - Map registry = new HashMap<>(); - IllegalArgumentException illegalArgumentException = assertThrows( - IllegalArgumentException.class, - () -> factory.create(registry, PROCESSOR_TAG, DESCRIPTION, config) - ); - assertEquals( - "parameters for input field [" + INPUT_FIELD + "] misses [" + DocumentChunkingProcessor.OUTPUT_FIELD + "], cannot process it.", - illegalArgumentException.getMessage() - ); - } - - public void testCreate_whenFieldMapWithIllegalOutputField_failure() { - Map config = new HashMap<>(); - Map fieldMap = new HashMap<>(); - fieldMap.put(INPUT_FIELD, ImmutableMap.of(DocumentChunkingProcessor.OUTPUT_FIELD, 1)); - config.put(DocumentChunkingProcessor.FIELD_MAP_FIELD, fieldMap); - Map registry = new HashMap<>(); - IllegalArgumentException illegalArgumentException = assertThrows( - IllegalArgumentException.class, - () -> factory.create(registry, PROCESSOR_TAG, DESCRIPTION, config) - ); - assertEquals( - "parameters for output field [output_field] cannot be cast to [java.lang.String]", - illegalArgumentException.getMessage() - ); - } - public void testCreate_whenFieldMapWithIllegalKey_failure() { Map config = new HashMap<>(); Map fieldMap = new HashMap<>();