From 3a416496a6b5bf3256bdbfb56e66dbd97cc0c8ec Mon Sep 17 00:00:00 2001 From: yuye-aws Date: Fri, 1 Mar 2024 17:29:58 +0800 Subject: [PATCH] add max chunk limit into fixed token length algorithm Signed-off-by: yuye-aws --- .../processor/chunker/DelimiterChunker.java | 3 +- .../chunker/FixedTokenLengthChunker.java | 46 +++++++++++++------ 2 files changed, 33 insertions(+), 16 deletions(-) diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java index c6b5c3ae9..fc5f41d24 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java @@ -22,7 +22,7 @@ public void validateParameters(Map parameters) { Object delimiter = parameters.get(DELIMITER_FIELD); if (!(delimiter instanceof String)) { throw new IllegalArgumentException("delimiter parameters: " + delimiter + " must be string."); - } else if (((String) delimiter).length() == 0) { + } else if (((String) delimiter).isEmpty()) { throw new IllegalArgumentException("delimiter parameters should not be empty."); } } else { @@ -50,7 +50,6 @@ public List chunk(String content, Map parameters) { addChunkResult(chunkResult, maxChunkingNumber, content.substring(start, end + delimiter.length())); start = end + delimiter.length(); end = content.indexOf(delimiter, start); - } if (start < content.length()) { diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java index 3079fcf8e..b57a8fdca 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/FixedTokenLengthChunker.java @@ -21,15 +21,15 @@ public class FixedTokenLengthChunker implements 
IFieldChunker { public static final String TOKEN_LIMIT_FIELD = "token_limit"; public static final String OVERLAP_RATE_FIELD = "overlap_rate"; - public static final String MAX_TOKEN_COUNT_FIELD = "max_token_count"; - + public static final String MAX_CHUNK_LIMIT_FIELD = "max_chunk_limit"; public static final String TOKENIZER_FIELD = "tokenizer"; // default values for each parameter private static final int DEFAULT_TOKEN_LIMIT = 500; private static final double DEFAULT_OVERLAP_RATE = 0.2; private static final int DEFAULT_MAX_TOKEN_COUNT = 10000; + private static final int DEFAULT_MAX_CHUNK_LIMIT = -1; private static final String DEFAULT_TOKENIZER = "standard"; private final AnalysisRegistry analysisRegistry; @@ -62,6 +62,8 @@ public List chunk(String content, Map parameters) { int tokenLimit = DEFAULT_TOKEN_LIMIT; double overlapRate = DEFAULT_OVERLAP_RATE; int maxTokenCount = DEFAULT_MAX_TOKEN_COUNT; + int maxChunkLimit = DEFAULT_MAX_CHUNK_LIMIT; + String tokenizer = DEFAULT_TOKENIZER; if (parameters.containsKey(TOKEN_LIMIT_FIELD)) { @@ -76,6 +78,9 @@ public List chunk(String content, Map parameters) { if (parameters.containsKey(TOKENIZER_FIELD)) { tokenizer = (String) parameters.get(TOKENIZER_FIELD); } + if (parameters.containsKey(MAX_CHUNK_LIMIT_FIELD)) { + maxChunkLimit = ((Number) parameters.get(MAX_CHUNK_LIMIT_FIELD)).intValue(); + } List tokens = tokenize(content, tokenizer, maxTokenCount); List passages = new ArrayList<>(); @@ -90,29 +95,49 @@ public List chunk(String content, Map parameters) { if (startToken + tokenLimit >= tokens.size()) { // break the loop when already cover the last token passage = String.join(" ", tokens.subList(startToken, tokens.size())); - passages.add(passage); + addPassageToList(passages, passage, maxChunkLimit); break; } else { passage = String.join(" ", tokens.subList(startToken, startToken + tokenLimit)); - passages.add(passage); + addPassageToList(passages, passage, maxChunkLimit); } startToken += tokenLimit - overlapTokenNumber; }
return passages; } + /** Appends passage to passages, throwing IllegalArgumentException if that would exceed maxChunkLimit; DEFAULT_MAX_CHUNK_LIMIT disables the check. */ + private void addPassageToList(List passages, String passage, int maxChunkLimit) { + if (maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT && passages.size() >= maxChunkLimit) { + throw new IllegalArgumentException("Exceed max chunk number: " + maxChunkLimit); + } + passages.add(passage); + } + + /** + * Validates an optional parameter: when present it must be a Number whose intValue() is positive. + * Absent keys are skipped so chunk() keeps its default; fractions are truncated first (1.5 passes, 0.5 fails). + * @throws IllegalArgumentException if the value is not a Number or its integer part is not positive + */ + private void validatePositiveIntegerParameter(Map parameters, String fieldName) { + if (!parameters.containsKey(fieldName)) { + return; + } + if (!(parameters.get(fieldName) instanceof Number)) { + throw new IllegalArgumentException( + "fixed length parameter [" + fieldName + "] cannot be cast to [" + Number.class.getName() + "]" + ); + } + if (((Number) parameters.get(fieldName)).intValue() <= 0) { + throw new IllegalArgumentException("fixed length parameter [" + fieldName + "] must be positive"); + } + } + @Override public void validateParameters(Map parameters) { - if (parameters.containsKey(TOKEN_LIMIT_FIELD)) { - if (!(parameters.get(TOKEN_LIMIT_FIELD) instanceof Number)) { - throw new IllegalArgumentException( - "fixed length parameter [" + TOKEN_LIMIT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]" - ); - } - if (((Number) parameters.get(TOKEN_LIMIT_FIELD)).intValue() <= 0) { - throw new IllegalArgumentException("fixed length parameter [" + TOKEN_LIMIT_FIELD + "] must be positive"); - } - } + validatePositiveIntegerParameter(parameters, TOKEN_LIMIT_FIELD); + validatePositiveIntegerParameter(parameters, MAX_CHUNK_LIMIT_FIELD); + validatePositiveIntegerParameter(parameters, MAX_TOKEN_COUNT_FIELD); if (parameters.containsKey(OVERLAP_RATE_FIELD)) { if (!(parameters.get(OVERLAP_RATE_FIELD) instanceof Number)) {