diff --git a/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java b/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java index 31577604f..2670d3882 100644 --- a/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java +++ b/src/main/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunker.java @@ -14,6 +14,8 @@ public DelimiterChunker() {} public static String DELIMITER_FIELD = "delimiter"; + public static String MAX_CHUNK_LIMIT_FIELD = "max_chunk_limit"; + @Override public void validateParameters(Map parameters) { if (parameters.containsKey(DELIMITER_FIELD)) { @@ -26,11 +28,26 @@ public void validateParameters(Map parameters) { } else { throw new IllegalArgumentException("You must contain field:" + DELIMITER_FIELD + " in your parameter."); } + if (parameters.containsKey(MAX_CHUNK_LIMIT_FIELD)) { + Object maxChunkLimit = parameters.get(MAX_CHUNK_LIMIT_FIELD); + if (!(maxChunkLimit instanceof String)) { + throw new IllegalArgumentException( + "Parameter max_chunk_limit:" + maxChunkLimit.toString() + " cannot be converted to integer." + ); + } else { + try { + int maxChunkingNumber = Integer.valueOf((String) maxChunkLimit); + } catch (Exception exception) { + throw new IllegalArgumentException("Parameter max_chunk_limit:" + maxChunkLimit + " cannot be converted to integer."); + } + } + } } @Override public List chunk(String content, Map parameters) { String delimiter = (String) parameters.get(DELIMITER_FIELD); + int maxChunkingNumber = Integer.valueOf((String) parameters.getOrDefault(MAX_CHUNK_LIMIT_FIELD, "0")); List chunkResult = new ArrayList<>(); int start = 0; int end = content.indexOf(delimiter); @@ -39,10 +56,16 @@ public List chunk(String content, Map parameters) { chunkResult.add(content.substring(start, end + delimiter.length())); start = end + delimiter.length(); end = content.indexOf(delimiter, start); + if (chunkResult.size() > maxChunkingNumber && maxChunkingNumber > 0) { + throw new IllegalArgumentException("Exceed max chunk number: " + maxChunkingNumber); + } } if (start < content.length()) { chunkResult.add(content.substring(start)); + if (chunkResult.size() > maxChunkingNumber && maxChunkingNumber > 0) { + throw new IllegalArgumentException("Exceed max chunk number: " + maxChunkingNumber); + } } return chunkResult; diff --git a/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java b/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java index d201ab574..147f91588 100644 --- a/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java +++ b/src/test/java/org/opensearch/neuralsearch/processor/chunker/DelimiterChunkerTests.java @@ -13,6 +13,7 @@ import static junit.framework.TestCase.assertEquals; import static org.junit.Assert.assertThrows; import static org.opensearch.neuralsearch.processor.chunker.DelimiterChunker.DELIMITER_FIELD; +import static org.opensearch.neuralsearch.processor.chunker.DelimiterChunker.MAX_CHUNK_LIMIT_FIELD; public class DelimiterChunkerTests extends OpenSearchTestCase { @@ -24,6 +25,22 @@ public void testChunkerWithNoDelimiterField() { Assert.assertEquals("You must contain field:" + DELIMITER_FIELD + " in your parameter.", exception.getMessage()); } + public void testChunkerWithWrongLimitFieldList() { + DelimiterChunker chunker = new DelimiterChunker(); + String content = "a\nb\nc\nd"; + Map inputParameters = Map.of(MAX_CHUNK_LIMIT_FIELD, List.of("-1"), DELIMITER_FIELD, "\n"); + Exception exception = assertThrows(IllegalArgumentException.class, () -> chunker.validateParameters(inputParameters)); + Assert.assertEquals("Parameter max_chunk_limit:" + List.of("-1") + " cannot be converted to integer.", exception.getMessage()); + } + + public void testChunkerWithWrongLimitField() { + DelimiterChunker chunker = new DelimiterChunker(); + String content = "a\nb\nc\nd"; + Map inputParameters = Map.of(MAX_CHUNK_LIMIT_FIELD, "1000\n", DELIMITER_FIELD, "\n"); + Exception exception = assertThrows(IllegalArgumentException.class, () -> chunker.validateParameters(inputParameters)); + Assert.assertEquals("Parameter max_chunk_limit:1000\n cannot be converted to integer.", exception.getMessage()); + } + public void testChunkerWithDelimiterFieldNotString() { DelimiterChunker chunker = new DelimiterChunker(); String content = "a\nb\nc\nd"; @@ -40,6 +57,14 @@ public void testChunkerWithDelimiterFieldNoString() { Assert.assertEquals("delimiter parameters should not be empty.", exception.getMessage()); } + public void testChunkerWithLimitNumber() { + DelimiterChunker chunker = new DelimiterChunker(); + String content = "a\nb\nc\nd"; + Map inputParameters = Map.of(DELIMITER_FIELD, "\n", MAX_CHUNK_LIMIT_FIELD, "1"); + Exception exception = assertThrows(IllegalArgumentException.class, () -> chunker.chunk(content, inputParameters)); + Assert.assertEquals("Exceed max chunk number: 1", exception.getMessage()); + } + public void testChunker() { DelimiterChunker chunker = new DelimiterChunker(); String content = "a\nb\nc\nd";