Skip to content

Commit

Permalink
track chunkCount within function
Browse files Browse the repository at this point in the history
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws committed Mar 11, 2024
1 parent c2dbc85 commit ee585c3
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
*/
package org.opensearch.neuralsearch.processor;

import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.ArrayList;
import java.util.List;
Expand Down Expand Up @@ -49,12 +51,9 @@ public final class DocumentChunkingProcessor extends AbstractProcessor {

private static final int DEFAULT_MAX_CHUNK_LIMIT = -1;

private int maxChunkLimit = DEFAULT_MAX_CHUNK_LIMIT;

private String chunkerType;

private Map<String, Object> chunkerParameters;
private int maxChunkLimit;

private Chunker chunker;
private final Map<String, Object> fieldMap;

private final ClusterService clusterService;
Expand All @@ -65,19 +64,6 @@ public final class DocumentChunkingProcessor extends AbstractProcessor {

private final Environment environment;

/**
* Users may specify parameter max_chunk_limit for a restriction on the number of strings from chunking results.
* Here the chunkCountWrapper is to store and increase the number of chunks across all output fields.
* chunkCount: the number of chunks of chunking result.
*/
static class ChunkCountWrapper {
private int chunkCount;

protected ChunkCountWrapper(int chunkCount) {
this.chunkCount = chunkCount;
}
}

public DocumentChunkingProcessor(
String tag,
String description,
Expand Down Expand Up @@ -108,41 +94,42 @@ private void validateAndParseAlgorithmMap(Map<String, Object> algorithmMap) {
"Unable to create the processor as [" + ALGORITHM_FIELD + "] must contain and only contain 1 algorithm"
);
}

for (Map.Entry<String, Object> algorithmEntry : algorithmMap.entrySet()) {
String algorithmKey = algorithmEntry.getKey();
Object algorithmValue = algorithmEntry.getValue();
Set<String> supportedChunkers = ChunkerFactory.getAllChunkers();
if (!supportedChunkers.contains(algorithmKey)) {
throw new IllegalArgumentException(
"Unable to create the processor as chunker algorithm ["
+ algorithmKey
+ "] is not supported. Supported chunkers types are "
+ supportedChunkers
);
}
if (!(algorithmValue instanceof Map)) {
Entry<String, Object> algorithmEntry = algorithmMap.entrySet().iterator().next();
String algorithmKey = algorithmEntry.getKey();
Object algorithmValue = algorithmEntry.getValue();
Set<String> supportedChunkers = ChunkerFactory.getAllChunkers();
if (!supportedChunkers.contains(algorithmKey)) {
throw new IllegalArgumentException(
"Unable to create the processor as chunker algorithm ["
+ algorithmKey
+ "] is not supported. Supported chunkers types are "
+ supportedChunkers
);
}
if (!(algorithmValue instanceof Map)) {
throw new IllegalArgumentException(
"Unable to create the processor as [" + algorithmKey + "] parameters cannot be cast to [" + Map.class.getName() + "]"
);
}
Map<String, Object> chunkerParameters = (Map<String, Object>) algorithmValue;
if (Objects.equals(algorithmKey, FIXED_TOKEN_LENGTH_ALGORITHM)) {
chunkerParameters.put(FixedTokenLengthChunker.ANALYSIS_REGISTRY_FIELD, analysisRegistry);
}
this.chunker = ChunkerFactory.create(algorithmKey, chunkerParameters);
if (chunkerParameters.containsKey(MAX_CHUNK_LIMIT_FIELD)) {
String maxChunkLimitString = chunkerParameters.get(MAX_CHUNK_LIMIT_FIELD).toString();
if (!(NumberUtils.isParsable(maxChunkLimitString))) {
throw new IllegalArgumentException(
"Unable to create the processor as [" + algorithmKey + "] parameters cannot be cast to [" + Map.class.getName() + "]"
"Parameter [" + MAX_CHUNK_LIMIT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]"
);
}
Chunker chunker = ChunkerFactory.create(algorithmKey, analysisRegistry);
this.chunkerType = algorithmKey;
this.chunkerParameters = (Map<String, Object>) algorithmValue;
chunker.validateParameters(chunkerParameters);
if (chunkerParameters.containsKey(MAX_CHUNK_LIMIT_FIELD)) {
String maxChunkLimitString = chunkerParameters.get(MAX_CHUNK_LIMIT_FIELD).toString();
if (!(NumberUtils.isParsable(maxChunkLimitString))) {
throw new IllegalArgumentException(
"Parameter [" + MAX_CHUNK_LIMIT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]"
);
}
int maxChunkLimit = NumberUtils.createInteger(maxChunkLimitString);
if (maxChunkLimit <= 0 && maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT) {
throw new IllegalArgumentException("Parameter [" + MAX_CHUNK_LIMIT_FIELD + "] must be a positive integer");
}
this.maxChunkLimit = maxChunkLimit;
int maxChunkLimit = NumberUtils.createInteger(maxChunkLimitString);
if (maxChunkLimit <= 0 && maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT) {
throw new IllegalArgumentException("Parameter [" + MAX_CHUNK_LIMIT_FIELD + "] must be a positive integer");
}
this.maxChunkLimit = maxChunkLimit;
} else {
this.maxChunkLimit = DEFAULT_MAX_CHUNK_LIMIT;
}
}

Expand All @@ -160,41 +147,60 @@ private boolean isListOfString(Object value) {
return true;
}

private List<String> chunkString(String content, ChunkCountWrapper chunkCountWrapper) {
Chunker chunker = ChunkerFactory.create(chunkerType, analysisRegistry);
List<String> result = chunker.chunk(content, chunkerParameters);
chunkCountWrapper.chunkCount += result.size();
if (maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT && chunkCountWrapper.chunkCount > maxChunkLimit) {
private int chunkString(String content, List<String> result, Map<String, Object> runTimeParameters, int chunkCount) {
// chunk the content, return the updated chunkCount and add chunk passages to result
List<String> contentResult;
if (chunker instanceof FixedTokenLengthChunker) {
contentResult = chunker.chunk(content, runTimeParameters);
} else {
contentResult = chunker.chunk(content);
}
chunkCount += contentResult.size();
if (maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT && chunkCount > maxChunkLimit) {
throw new IllegalArgumentException(
"Unable to create the processor as the number of chunks ["
+ chunkCountWrapper.chunkCount
+ chunkCount
+ "] exceeds the maximum chunk limit ["
+ maxChunkLimit
+ "]"
);
}
return result;
result.addAll(contentResult);
return chunkCount;
}

private List<String> chunkList(List<String> contentList, ChunkCountWrapper chunkCountWrapper) {
private int chunkList(List<String> contentList, List<String> result, Map<String, Object> runTimeParameters, int chunkCount) {
// flatten the List<List<String>> output to List<String>
List<String> result = new ArrayList<>();
for (String content : contentList) {
result.addAll(chunkString(content, chunkCountWrapper));
chunkCount = chunkString(content, result, runTimeParameters, chunkCount);
}
return result;
return chunkCount;
}

@SuppressWarnings("unchecked")
private List<String> chunkLeafType(Object value, ChunkCountWrapper chunkCountWrapper) {
private int chunkLeafType(Object value, List<String> result, Map<String, Object> runTimeParameters, int chunkCount) {
// leaf type is either String or List<String>
List<String> chunkedResult = null;
// the result should be an empty string
if (value instanceof String) {
chunkedResult = chunkString(value.toString(), chunkCountWrapper);
chunkCount = chunkString(value.toString(), result, runTimeParameters, chunkCount);
} else if (isListOfString(value)) {
chunkedResult = chunkList((List<String>) value, chunkCountWrapper);
chunkCount = chunkList((List<String>) value, result, runTimeParameters, chunkCount);
}
return chunkCount;
}

private int getMaxTokenCount(Map<String, Object> sourceAndMetadataMap) {
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName);
int maxTokenCount;
if (indexMetadata != null) {
// if the index exists, read maxTokenCount from the index setting
IndexService indexService = indicesService.indexServiceSafe(indexMetadata.getIndex());
maxTokenCount = indexService.getIndexSettings().getMaxTokenCount();
} else {
maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(environment.settings());
}
return chunkedResult;
return maxTokenCount;
}

/**
Expand All @@ -204,23 +210,14 @@ private List<String> chunkLeafType(Object value, ChunkCountWrapper chunkCountWra
@Override
public IngestDocument execute(IngestDocument ingestDocument) {
validateFieldsValue(ingestDocument);
ChunkCountWrapper chunkCountWrapper = new ChunkCountWrapper(0);
if (Objects.equals(chunkerType, FIXED_TOKEN_LENGTH_ALGORITHM)) {
// add maxTokenCount setting from index metadata to chunker parameters
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
String indexName = sourceAndMetadataMap.get(IndexFieldMapper.NAME).toString();
int maxTokenCount = IndexSettings.MAX_TOKEN_COUNT_SETTING.get(environment.settings());
IndexMetadata indexMetadata = clusterService.state().metadata().index(indexName);
if (indexMetadata != null) {
// if the index exists, read maxTokenCount from the index setting
IndexService indexService = indicesService.indexServiceSafe(indexMetadata.getIndex());
maxTokenCount = indexService.getIndexSettings().getMaxTokenCount();
}
chunkerParameters.put(FixedTokenLengthChunker.MAX_TOKEN_COUNT_FIELD, maxTokenCount);
}

int chunkCount = 0;
Map<String, Object> runtimeParameters = new HashMap<>();
Map<String, Object> sourceAndMetadataMap = ingestDocument.getSourceAndMetadata();
chunkMapType(sourceAndMetadataMap, fieldMap, chunkCountWrapper);
if (chunker instanceof FixedTokenLengthChunker) {
int maxTokenCount = getMaxTokenCount(sourceAndMetadataMap);
runtimeParameters.put(FixedTokenLengthChunker.MAX_TOKEN_COUNT_FIELD, maxTokenCount);
}
chunkMapType(sourceAndMetadataMap, fieldMap, runtimeParameters, chunkCount);
return ingestDocument;
}

Expand Down Expand Up @@ -269,7 +266,12 @@ private void validateListTypeValue(String sourceKey, Object sourceValue, int max
}

@SuppressWarnings("unchecked")
private void chunkMapType(Map<String, Object> sourceAndMetadataMap, Map<String, Object> fieldMap, ChunkCountWrapper chunkCountWrapper) {
private int chunkMapType(
Map<String, Object> sourceAndMetadataMap,
Map<String, Object> fieldMap,
Map<String, Object> runtimeParameters,
int chunkCount
) {
for (Map.Entry<String, Object> fieldMapEntry : fieldMap.entrySet()) {
String originalKey = fieldMapEntry.getKey();
Object targetKey = fieldMapEntry.getValue();
Expand All @@ -280,20 +282,30 @@ private void chunkMapType(Map<String, Object> sourceAndMetadataMap, Map<String,
List<Object> sourceObjectList = (List<Object>) sourceObject;
for (Object source : sourceObjectList) {
if (source instanceof Map) {
chunkMapType((Map<String, Object>) source, (Map<String, Object>) targetKey, chunkCountWrapper);
chunkCount = chunkMapType(
(Map<String, Object>) source,
(Map<String, Object>) targetKey,
runtimeParameters,
chunkCount
);
}
}
} else if (sourceObject instanceof Map) {
chunkMapType((Map<String, Object>) sourceObject, (Map<String, Object>) targetKey, chunkCountWrapper);
chunkCount = chunkMapType(
(Map<String, Object>) sourceObject,
(Map<String, Object>) targetKey,
runtimeParameters,
chunkCount
);
}
} else {
// chunk the object when target key is a string
Object chunkObject = sourceAndMetadataMap.get(originalKey);
List<String> chunkedResult = chunkLeafType(chunkObject, chunkCountWrapper);
if (chunkedResult != null) {
sourceAndMetadataMap.put(String.valueOf(targetKey), chunkedResult);
}
List<String> chunkedResult = new ArrayList<>();
chunkCount = chunkLeafType(chunkObject, chunkedResult, runtimeParameters, chunkCount);
sourceAndMetadataMap.put(String.valueOf(targetKey), chunkedResult);
}
}
return chunkCount;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import com.google.common.collect.ImmutableList;

import java.util.ArrayList;
import java.util.Map;
import java.util.List;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
import java.util.Map;
import java.util.Set;

import org.opensearch.index.analysis.AnalysisRegistry;

/**
* A factory to create different chunking algorithm classes and return all supported chunking algorithms.
*/
Expand Down

0 comments on commit ee585c3

Please sign in to comment.