Skip to content

Commit

Permalink
add max chunk limit into fixed token length algorithm
Browse files Browse the repository at this point in the history
Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws committed Mar 1, 2024
1 parent ad2ecb0 commit 3a41649
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ public void validateParameters(Map<String, Object> parameters) {
Object delimiter = parameters.get(DELIMITER_FIELD);
if (!(delimiter instanceof String)) {
throw new IllegalArgumentException("delimiter parameters: " + delimiter + " must be string.");
} else if (((String) delimiter).length() == 0) {
} else if (((String) delimiter).isEmpty()) {
throw new IllegalArgumentException("delimiter parameters should not be empty.");
}
} else {
Expand Down Expand Up @@ -50,7 +50,6 @@ public List<String> chunk(String content, Map<String, Object> parameters) {
addChunkResult(chunkResult, maxChunkingNumber, content.substring(start, end + delimiter.length()));
start = end + delimiter.length();
end = content.indexOf(delimiter, start);

}

if (start < content.length()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@ public class FixedTokenLengthChunker implements IFieldChunker {

public static final String TOKEN_LIMIT_FIELD = "token_limit";
public static final String OVERLAP_RATE_FIELD = "overlap_rate";

public static final String MAX_TOKEN_COUNT_FIELD = "max_token_count";

public static String MAX_CHUNK_LIMIT_FIELD = "max_chunk_limit";
public static final String TOKENIZER_FIELD = "tokenizer";

// default values for each parameter
private static final int DEFAULT_TOKEN_LIMIT = 500;
private static final double DEFAULT_OVERLAP_RATE = 0.2;
private static final int DEFAULT_MAX_TOKEN_COUNT = 10000;
private static final int DEFAULT_MAX_CHUNK_LIMIT = -1;
private static final String DEFAULT_TOKENIZER = "standard";

private final AnalysisRegistry analysisRegistry;
Expand Down Expand Up @@ -62,6 +62,8 @@ public List<String> chunk(String content, Map<String, Object> parameters) {
int tokenLimit = DEFAULT_TOKEN_LIMIT;
double overlapRate = DEFAULT_OVERLAP_RATE;
int maxTokenCount = DEFAULT_MAX_TOKEN_COUNT;
int maxChunkLimit = DEFAULT_MAX_CHUNK_LIMIT;

String tokenizer = DEFAULT_TOKENIZER;

if (parameters.containsKey(TOKEN_LIMIT_FIELD)) {
Expand All @@ -76,6 +78,9 @@ public List<String> chunk(String content, Map<String, Object> parameters) {
if (parameters.containsKey(TOKENIZER_FIELD)) {
tokenizer = (String) parameters.get(TOKENIZER_FIELD);
}
if (parameters.containsKey(MAX_CHUNK_LIMIT_FIELD)) {
maxChunkLimit = ((Number) parameters.get(MAX_CHUNK_LIMIT_FIELD)).intValue();
}

List<String> tokens = tokenize(content, tokenizer, maxTokenCount);
List<String> passages = new ArrayList<>();
Expand All @@ -90,29 +95,42 @@ public List<String> chunk(String content, Map<String, Object> parameters) {
if (startToken + tokenLimit >= tokens.size()) {
// break the loop when already cover the last token
passage = String.join(" ", tokens.subList(startToken, tokens.size()));
passages.add(passage);
addPassageToList(passages, passage, maxChunkLimit);
break;
} else {
passage = String.join(" ", tokens.subList(startToken, startToken + tokenLimit));
passages.add(passage);
addPassageToList(passages, passage, maxChunkLimit);
}
startToken += tokenLimit - overlapTokenNumber;
}
return passages;
}

private void addPassageToList(List<String> passages, String passage, int maxChunkLimit) {
if (maxChunkLimit != DEFAULT_MAX_CHUNK_LIMIT && passages.size() + 1 >= maxChunkLimit) {
throw new IllegalArgumentException("Exceed max chunk number: " + maxChunkLimit);
}
passages.add(passage);
}

private void validatePositiveIntegerParameter(Map<String, Object> parameters, String fieldName) {
// this method validate that parameter is a positive integer
// this method accepts positive float or double number
if (!(parameters.get(fieldName) instanceof Number)) {
throw new IllegalArgumentException(
"fixed length parameter [" + fieldName + "] cannot be cast to [" + Number.class.getName() + "]"
);
}
if (((Number) parameters.get(fieldName)).intValue() <= 0) {
throw new IllegalArgumentException("fixed length parameter [" + fieldName + "] must be positive");
}
}

@Override
public void validateParameters(Map<String, Object> parameters) {
if (parameters.containsKey(TOKEN_LIMIT_FIELD)) {
if (!(parameters.get(TOKEN_LIMIT_FIELD) instanceof Number)) {
throw new IllegalArgumentException(
"fixed length parameter [" + TOKEN_LIMIT_FIELD + "] cannot be cast to [" + Number.class.getName() + "]"
);
}
if (((Number) parameters.get(TOKEN_LIMIT_FIELD)).intValue() <= 0) {
throw new IllegalArgumentException("fixed length parameter [" + TOKEN_LIMIT_FIELD + "] must be positive");
}
}
validatePositiveIntegerParameter(parameters, TOKEN_LIMIT_FIELD);
validatePositiveIntegerParameter(parameters, MAX_CHUNK_LIMIT_FIELD);
validatePositiveIntegerParameter(parameters, MAX_TOKEN_COUNT_FIELD);

if (parameters.containsKey(OVERLAP_RATE_FIELD)) {
if (!(parameters.get(OVERLAP_RATE_FIELD) instanceof Number)) {
Expand Down

0 comments on commit 3a41649

Please sign in to comment.