Skip to content

Commit

Permalink
implement unit tests for unit tests with max_chunk_limit in fixed tok…
Browse files Browse the repository at this point in the history
…en length

Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws committed Mar 1, 2024
1 parent 6065b6f commit 3a90f39
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ public List<String> chunk(String content, Map<String, Object> parameters) {

private void addChunkResult(List<String> chunkResult, int maxChunkingNumber, String candidate) {
if (chunkResult.size() >= maxChunkingNumber && maxChunkingNumber > 0) {
throw new IllegalArgumentException("Exceed max chunk number: " + maxChunkingNumber);
throw new IllegalStateException("Exceed max chunk number: " + maxChunkingNumber);
}
chunkResult.add(candidate);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -107,18 +107,22 @@ public List<String> chunk(String content, Map<String, Object> parameters) {
}

private void addPassageToList(List<String> passages, String passage, int maxChunkLimit) {
if (maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT && passages.size() + 1 >= maxChunkLimit) {
throw new IllegalArgumentException("Exceed max chunk number: " + maxChunkLimit);
if (maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT && passages.size() >= maxChunkLimit) {
throw new IllegalStateException("Exceed max chunk number: " + maxChunkLimit);
}
passages.add(passage);
}

private void validatePositiveIntegerParameter(Map<String, Object> parameters, String fieldName) {
// this method validate that parameter is a positive integer
// this method accepts positive float or double number
if (!parameters.containsKey(fieldName)) {
// all parameters are optional
return;
}
if (!(parameters.get(fieldName) instanceof Number)) {
throw new IllegalArgumentException(
"fixed length parameter [" + fieldName + "] cannot be cast to [" + Number.class.getName() + "]"
"fixed length parameter [" + fieldName + "] cannot be cast to [" + Number.class.getName() + "]"
);
}
if (((Number) parameters.get(fieldName)).intValue() <= 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,10 @@

import static java.util.Collections.singletonList;
import static java.util.Collections.singletonMap;
import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.TOKENIZER_FIELD;
import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.TOKEN_LIMIT_FIELD;
import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.OVERLAP_RATE_FIELD;
import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.TOKENIZER_FIELD;
import static org.opensearch.neuralsearch.processor.chunker.FixedTokenLengthChunker.MAX_CHUNK_LIMIT_FIELD;

public class FixedTokenLengthChunkerTests extends OpenSearchTestCase {

Expand Down Expand Up @@ -66,7 +67,10 @@ public void testValidateParameters_whenIllegalTokenLimitType_thenFail() {
IllegalArgumentException.class,
() -> FixedTokenLengthChunker.validateParameters(parameters)
);
assertEquals("fixed length parameter [token_limit] cannot be cast to [java.lang.Number]", illegalArgumentException.getMessage());
assertEquals(
"fixed length parameter [" + TOKEN_LIMIT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]",
illegalArgumentException.getMessage()
);
}

public void testValidateParameters_whenIllegalTokenLimitValue_thenFail() {
Expand All @@ -76,7 +80,7 @@ public void testValidateParameters_whenIllegalTokenLimitValue_thenFail() {
IllegalArgumentException.class,
() -> FixedTokenLengthChunker.validateParameters(parameters)
);
assertEquals("fixed length parameter [token_limit] must be positive", illegalArgumentException.getMessage());
assertEquals("fixed length parameter [" + TOKEN_LIMIT_FIELD + "] must be positive", illegalArgumentException.getMessage());
}

public void testValidateParameters_whenIllegalOverlapRateType_thenFail() {
Expand All @@ -86,7 +90,10 @@ public void testValidateParameters_whenIllegalOverlapRateType_thenFail() {
IllegalArgumentException.class,
() -> FixedTokenLengthChunker.validateParameters(parameters)
);
assertEquals("fixed length parameter [overlap_rate] cannot be cast to [java.lang.Number]", illegalArgumentException.getMessage());
assertEquals(
"fixed length parameter [" + OVERLAP_RATE_FIELD + "] cannot be cast to [" + Number.class.getName() + "]",
illegalArgumentException.getMessage()
);
}

public void testValidateParameters_whenIllegalOverlapRateValue_thenFail() {
Expand All @@ -97,7 +104,7 @@ public void testValidateParameters_whenIllegalOverlapRateValue_thenFail() {
() -> FixedTokenLengthChunker.validateParameters(parameters)
);
assertEquals(
"fixed length parameter [overlap_rate] must be between 0 and 1, 1 is not included.",
"fixed length parameter [" + OVERLAP_RATE_FIELD + "] must be between 0 and 1, 1 is not included.",
illegalArgumentException.getMessage()
);
}
Expand All @@ -109,7 +116,33 @@ public void testValidateParameters_whenIllegalTokenizerType_thenFail() {
IllegalArgumentException.class,
() -> FixedTokenLengthChunker.validateParameters(parameters)
);
assertEquals("fixed length parameter [tokenizer] cannot be cast to [java.lang.String]", illegalArgumentException.getMessage());
assertEquals(
"fixed length parameter [" + TOKENIZER_FIELD + "] cannot be cast to [" + String.class.getName() + "]",
illegalArgumentException.getMessage()
);
}

public void testValidateParameters_whenIllegalChunkLimitType_thenFail() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(MAX_CHUNK_LIMIT_FIELD, "invalid chunk limit");
IllegalArgumentException illegalArgumentException = assertThrows(
IllegalArgumentException.class,
() -> FixedTokenLengthChunker.validateParameters(parameters)
);
assertEquals(
"fixed length parameter [" + MAX_CHUNK_LIMIT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]",
illegalArgumentException.getMessage()
);
}

public void testValidateParameters_whenIllegalChunkLimitValue_thenFail() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(MAX_CHUNK_LIMIT_FIELD, -1);
IllegalArgumentException illegalArgumentException = assertThrows(
IllegalArgumentException.class,
() -> FixedTokenLengthChunker.validateParameters(parameters)
);
assertEquals("fixed length parameter [" + MAX_CHUNK_LIMIT_FIELD + "] must be positive", illegalArgumentException.getMessage());
}

public void testChunk_withTokenLimit_10() {
Expand Down Expand Up @@ -153,4 +186,31 @@ public void testChunk_withOverlapRate_half() {
expectedPassages.add("sentences and 24 tokens by standard tokenizer in OpenSearch");
assertEquals(expectedPassages, passages);
}

public void testChunk_withMaxChunkLimitOne_thenFail() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(TOKEN_LIMIT_FIELD, 10);
parameters.put(MAX_CHUNK_LIMIT_FIELD, 1);
String content =
"This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch.";
IllegalStateException illegalStateException = assertThrows(
IllegalStateException.class,
() -> FixedTokenLengthChunker.chunk(content, parameters)
);
assertEquals("Exceed max chunk number: 1", illegalStateException.getMessage());
}

public void testChunk_withMaxChunkLimitTen_thenSuccess() {
Map<String, Object> parameters = new HashMap<>();
parameters.put(TOKEN_LIMIT_FIELD, 10);
parameters.put(MAX_CHUNK_LIMIT_FIELD, 10);
String content =
"This is an example document to be chunked. The document contains a single paragraph, two sentences and 24 tokens by standard tokenizer in OpenSearch.";
List<String> passages = FixedTokenLengthChunker.chunk(content, parameters);
List<String> expectedPassages = new ArrayList<>();
expectedPassages.add("This is an example document to be chunked The document");
expectedPassages.add("The document contains a single paragraph two sentences and 24");
expectedPassages.add("and 24 tokens by standard tokenizer in OpenSearch");
assertEquals(expectedPassages, passages);
}
}

0 comments on commit 3a90f39

Please sign in to comment.