From ee585c3d4ccff35e072cf1aa7b67556c6edafbc7 Mon Sep 17 00:00:00 2001 From: yuye-aws Date: Mon, 11 Mar 2024 12:33:36 +0800 Subject: [PATCH] track chunkCount within function Signed-off-by: yuye-aws --- .../processor/DocumentChunkingProcessor.java | 188 ++++++++++-------- .../processor/chunker/Chunker.java | 1 - .../processor/chunker/ChunkerFactory.java | 2 - 3 files changed, 100 insertions(+), 91 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java b/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java index 63d04a505..c3432dfaa 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/DocumentChunkingProcessor.java @@ -4,7 +4,9 @@ */ package org.opensearch.neuralsearch.processor; +import java.util.HashMap; import java.util.Map; +import java.util.Map.Entry; import java.util.Set; import java.util.ArrayList; import java.util.List; @@ -49,12 +51,9 @@ public final class DocumentChunkingProcessor extends AbstractProcessor { private static final int DEFAULT_MAX_CHUNK_LIMIT = -1; - private int maxChunkLimit = DEFAULT_MAX_CHUNK_LIMIT; - - private String chunkerType; - - private Map chunkerParameters; + private int maxChunkLimit; + private Chunker chunker; private final Map fieldMap; private final ClusterService clusterService; @@ -65,19 +64,6 @@ public final class DocumentChunkingProcessor extends AbstractProcessor { private final Environment environment; - /** - * Users may specify parameter max_chunk_limit for a restriction on the number of strings from chunking results. - * Here the chunkCountWrapper is to store and increase the number of chunks across all output fields. - * chunkCount: the number of chunks of chunking result. - */ - static class ChunkCountWrapper { - private int chunkCount; - - protected ChunkCountWrapper(int chunkCount) { - this.chunkCount = chunkCount; - } - } - public DocumentChunkingProcessor( String tag, String description, @@ -108,41 +94,42 @@ private void validateAndParseAlgorithmMap(Map algorithmMap) { "Unable to create the processor as [" + ALGORITHM_FIELD + "] must contain and only contain 1 algorithm" ); } - - for (Map.Entry algorithmEntry : algorithmMap.entrySet()) { - String algorithmKey = algorithmEntry.getKey(); - Object algorithmValue = algorithmEntry.getValue(); - Set supportedChunkers = ChunkerFactory.getAllChunkers(); - if (!supportedChunkers.contains(algorithmKey)) { - throw new IllegalArgumentException( - "Unable to create the processor as chunker algorithm [" - + algorithmKey - + "] is not supported. Supported chunkers types are " - + supportedChunkers - ); - } - if (!(algorithmValue instanceof Map)) { + Entry algorithmEntry = algorithmMap.entrySet().iterator().next(); + String algorithmKey = algorithmEntry.getKey(); + Object algorithmValue = algorithmEntry.getValue(); + Set supportedChunkers = ChunkerFactory.getAllChunkers(); + if (!supportedChunkers.contains(algorithmKey)) { + throw new IllegalArgumentException( + "Unable to create the processor as chunker algorithm [" + + algorithmKey + + "] is not supported. Supported chunkers types are " + + supportedChunkers + ); + } + if (!(algorithmValue instanceof Map)) { + throw new IllegalArgumentException( + "Unable to create the processor as [" + algorithmKey + "] parameters cannot be cast to [" + Map.class.getName() + "]" + ); + } + Map chunkerParameters = (Map) algorithmValue; + if (Objects.equals(algorithmKey, FIXED_TOKEN_LENGTH_ALGORITHM)) { + chunkerParameters.put(FixedTokenLengthChunker.ANALYSIS_REGISTRY_FIELD, analysisRegistry); + } + this.chunker = ChunkerFactory.create(algorithmKey, chunkerParameters); + if (chunkerParameters.containsKey(MAX_CHUNK_LIMIT_FIELD)) { + String maxChunkLimitString = chunkerParameters.get(MAX_CHUNK_LIMIT_FIELD).toString(); + if (!(NumberUtils.isParsable(maxChunkLimitString))) { throw new IllegalArgumentException( - "Unable to create the processor as [" + algorithmKey + "] parameters cannot be cast to [" + Map.class.getName() + "]" + "Parameter [" + MAX_CHUNK_LIMIT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]" ); } - Chunker chunker = ChunkerFactory.create(algorithmKey, analysisRegistry); - this.chunkerType = algorithmKey; - this.chunkerParameters = (Map) algorithmValue; - chunker.validateParameters(chunkerParameters); - if (chunkerParameters.containsKey(MAX_CHUNK_LIMIT_FIELD)) { - String maxChunkLimitString = chunkerParameters.get(MAX_CHUNK_LIMIT_FIELD).toString(); - if (!(NumberUtils.isParsable(maxChunkLimitString))) { - throw new IllegalArgumentException( - "Parameter [" + MAX_CHUNK_LIMIT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]" - ); - } - int maxChunkLimit = NumberUtils.createInteger(maxChunkLimitString); - if (maxChunkLimit <= 0 && maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT) { - throw new IllegalArgumentException("Parameter [" + MAX_CHUNK_LIMIT_FIELD + "] must be a positive integer"); - } - this.maxChunkLimit = maxChunkLimit; + int maxChunkLimit = NumberUtils.createInteger(maxChunkLimitString); + if (maxChunkLimit <= 0 && maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT) { + throw new IllegalArgumentException("Parameter [" + MAX_CHUNK_LIMIT_FIELD + "] must be a positive integer"); } + this.maxChunkLimit = maxChunkLimit; + } else { + this.maxChunkLimit = DEFAULT_MAX_CHUNK_LIMIT; } } @@ -160,41 +147,60 @@ private boolean isListOfString(Object value) { return true; } - private List chunkString(String content, ChunkCountWrapper chunkCountWrapper) { - Chunker chunker = ChunkerFactory.create(chunkerType, analysisRegistry); - List result = chunker.chunk(content, chunkerParameters); - chunkCountWrapper.chunkCount += result.size(); - if (maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT && chunkCountWrapper.chunkCount > maxChunkLimit) { + private int chunkString(String content, List result, Map runTimeParameters, int chunkCount) { + // chunk the content, return the updated chunkCount and add chunk passages to result + List contentResult; + if (chunker instanceof FixedTokenLengthChunker) { + contentResult = chunker.chunk(content, runTimeParameters); + } else { + contentResult = chunker.chunk(content); + } + chunkCount += contentResult.size(); + if (maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT && chunkCount > maxChunkLimit) { throw new IllegalArgumentException( "Unable to create the processor as the number of chunks [" - + chunkCountWrapper.chunkCount + + chunkCount + "] exceeds the maximum chunk limit [" + maxChunkLimit + "]" ); } - return result; + result.addAll(contentResult); + return chunkCount; } - private List chunkList(List contentList, ChunkCountWrapper chunkCountWrapper) { + private int chunkList(List contentList, List result, Map runTimeParameters, int chunkCount) { // flatten the List> output to List - List result = new ArrayList<>(); for (String content : contentList) { - result.addAll(chunkString(content, chunkCountWrapper)); + chunkCount = chunkString(content, result, runTimeParameters, chunkCount); } - return result; + return chunkCount; } @SuppressWarnings("unchecked") - private List chunkLeafType(Object value, ChunkCountWrapper chunkCountWrapper) { + private int chunkLeafType(Object value, List result, Map runTimeParameters, int chunkCount) { // leaf type is either String or List - List chunkedResult = null; + // the result should be an empty string if (value instanceof String) { - chunkedResult = chunkString(value.toString(), chunkCountWrapper); + chunkCount = chunkString(value.toString(), result, runTimeParameters, chunkCount); } else if (isListOfString(value)) { - chunkedResult = chunkList((List) value, chunkCountWrapper); + chunkCount = chunkList((List) value, result, runTimeParameters, chunkCount); + } + return chunkCount; + } + + private int getMaxTokenCount(Map sourceAndMetadataMap) { + String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); + IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName); + int maxTokenCount; + if (indexMetadata != null) { + // if the index exists, read maxTokenCount from the index setting + IndexService indexService = indicesService.indexServiceSafe(indexMetadata.getIndex()); + maxTokenCount = indexService.getIndexSettings().getMaxTokenCount(); + } else { + maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(environment.settings()); } - return chunkedResult; + return maxTokenCount; } /** @@ -204,23 +210,14 @@ private List chunkLeafType(Object value, ChunkCountWrapper chunkCountWra @Override public IngestDocument execute(IngestDocument ingestDocument) { validateFieldsValue(ingestDocument); - ChunkCountWrapper chunkCountWrapper = new ChunkCountWrapper(0); - if (Objects.equals(chunkerType, FIXED_TOKEN_LENGTH_ALGORITHM)) { - // add maxTokenCount setting from index metadata to chunker parameters - Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString(); - int maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(environment.settings()); - IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName); - if (indexMetadata != null) { - // if the index exists, read maxTokenCount from the index setting - IndexService indexService = indicesService.indexServiceSafe(indexMetadata.getIndex()); - maxTokenCount = indexService.getIndexSettings().getMaxTokenCount(); - } - chunkerParameters.put(FixedTokenLengthChunker.MAX_TOKEN_COUNT_FIELD, maxTokenCount); - } - + int chunkCount = 0; + Map runtimeParameters = new HashMap<>(); Map sourceAndMetadataMap = ingestDocument.getSourceAndMetadata(); - chunkMapType(sourceAndMetadataMap, fieldMap, chunkCountWrapper); + if (chunker instanceof FixedTokenLengthChunker) { + int maxTokenCount = getMaxTokenCount(sourceAndMetadataMap); + runtimeParameters.put(FixedTokenLengthChunker.MAX_TOKEN_COUNT_FIELD, maxTokenCount); + } + chunkMapType(sourceAndMetadataMap, fieldMap, runtimeParameters, chunkCount); return ingestDocument; } @@ -269,7 +266,12 @@ private void validateListTypeValue(String sourceKey, Object sourceValue, int max } @SuppressWarnings("unchecked") - private void chunkMapType(Map sourceAndMetadataMap, Map fieldMap, ChunkCountWrapper chunkCountWrapper) { + private int chunkMapType( + Map sourceAndMetadataMap, + Map fieldMap, + Map runtimeParameters, + int chunkCount + ) { for (Map.Entry fieldMapEntry : fieldMap.entrySet()) { String originalKey = fieldMapEntry.getKey(); Object targetKey = fieldMapEntry.getValue(); @@ -280,20 +282,30 @@ private void chunkMapType(Map sourceAndMetadataMap, Map sourceObjectList = (List) sourceObject; for (Object source : sourceObjectList) { if (source instanceof Map) { - chunkMapType((Map) source, (Map) targetKey, chunkCountWrapper); + chunkCount = chunkMapType( + (Map) source, + (Map) targetKey, + runtimeParameters, + chunkCount + ); } } } else if (sourceObject instanceof Map) { - chunkMapType((Map) sourceObject, (Map) targetKey, chunkCountWrapper); + chunkCount = chunkMapType( + (Map) sourceObject, + (Map) targetKey, + runtimeParameters, + chunkCount + ); } } else { // chunk the object when target key is a string Object chunkObject = sourceAndMetadataMap.get(originalKey); - List chunkedResult = chunkLeafType(chunkObject, chunkCountWrapper); - if (chunkedResult != null) { - sourceAndMetadataMap.put(String.valueOf(targetKey), chunkedResult); - } + List chunkedResult = new ArrayList<>(); + chunkCount = chunkLeafType(chunkObject, chunkedResult, runtimeParameters, chunkCount); + sourceAndMetadataMap.put(String.valueOf(targetKey), chunkedResult); } } + return chunkCount; } } diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/Chunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/Chunker.java index 90f5a11ee..29a0539f2 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/Chunker.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/Chunker.java @@ -6,7 +6,6 @@ import com.google.common.collect.ImmutableList; -import java.util.ArrayList; import java.util.Map; import java.util.List; diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerFactory.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerFactory.java index 2f72eab19..086cb0b71 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerFactory.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/ChunkerFactory.java @@ -7,8 +7,6 @@ import java.util.Map; import java.util.Set; -import org.opensearch.index.analysis.AnalysisRegistry; - /** * A factory to create different chunking algorithm classes and return all supported chunking algorithms. */