Skip to content

Commit

Permalink
Optimize parameter parsing in text chunking processor (#733)
Browse files Browse the repository at this point in the history
* Optimize parameter parsing in text chunking processor

Signed-off-by: yuye-aws <[email protected]>

* add change log

Signed-off-by: yuye-aws <[email protected]>

* fix unit tests in delimiter chunker

Signed-off-by: yuye-aws <[email protected]>

* fix unit tests in fixed token length chunker

Signed-off-by: yuye-aws <[email protected]>

* remove redundant

Signed-off-by: yuye-aws <[email protected]>

* refactor chunker parameter parser

Signed-off-by: yuye-aws <[email protected]>

* unit tests for chunker parameter parser

Signed-off-by: yuye-aws <[email protected]>

* fix comment

Signed-off-by: yuye-aws <[email protected]>

* spotless apply

Signed-off-by: yuye-aws <[email protected]>

---------

Signed-off-by: yuye-aws <[email protected]>
  • Loading branch information
yuye-aws authored May 21, 2024
1 parent 7c54c86 commit 038b1ec
Show file tree
Hide file tree
Showing 9 changed files with 421 additions and 93 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
### Features
### Enhancements
- Pass empty doc collector instead of top docs collector to improve hybrid query latencies by 20% ([#731](https://github.com/opensearch-project/neural-search/pull/731))
- Optimize parameter parsing in text chunking processor ([#733](https://github.com/opensearch-project/neural-search/pull/733))
### Bug Fixes
### Infrastructure
### Documentation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@
import static org.opensearch.neuralsearch.processor.chunker.Chunker.DEFAULT_MAX_CHUNK_LIMIT;
import static org.opensearch.neuralsearch.processor.chunker.Chunker.DISABLED_MAX_CHUNK_LIMIT;
import static org.opensearch.neuralsearch.processor.chunker.Chunker.CHUNK_STRING_COUNT_FIELD;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseInteger;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerWithDefault;

/**
* This processor is used for text chunking.
Expand Down Expand Up @@ -115,8 +116,8 @@ private void parseAlgorithmMap(final Map<String, Object> algorithmMap) {
}
Map<String, Object> chunkerParameters = (Map<String, Object>) algorithmValue;
// parse processor level max chunk limit
this.maxChunkLimit = parseIntegerParameter(chunkerParameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT);
if (maxChunkLimit < 0 && maxChunkLimit != DISABLED_MAX_CHUNK_LIMIT) {
this.maxChunkLimit = parseIntegerWithDefault(chunkerParameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT);
if (maxChunkLimit <= 0 && maxChunkLimit != DISABLED_MAX_CHUNK_LIMIT) {
throw new IllegalArgumentException(
String.format(
Locale.ROOT,
Expand Down Expand Up @@ -309,10 +310,10 @@ private List<String> chunkString(final String content, final Map<String, Object>
}
List<String> contentResult = chunker.chunk(content, runTimeParameters);
// update chunk_string_count for each string
int chunkStringCount = parseIntegerParameter(runTimeParameters, CHUNK_STRING_COUNT_FIELD, 1);
int chunkStringCount = parseInteger(runTimeParameters, CHUNK_STRING_COUNT_FIELD);
runTimeParameters.put(CHUNK_STRING_COUNT_FIELD, chunkStringCount - 1);
// update runtime max_chunk_limit if not disabled
int runtimeMaxChunkLimit = parseIntegerParameter(runTimeParameters, MAX_CHUNK_LIMIT_FIELD, maxChunkLimit);
int runtimeMaxChunkLimit = parseInteger(runTimeParameters, MAX_CHUNK_LIMIT_FIELD);
if (runtimeMaxChunkLimit != DISABLED_MAX_CHUNK_LIMIT) {
runTimeParameters.put(MAX_CHUNK_LIMIT_FIELD, runtimeMaxChunkLimit - contentResult.size());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,7 @@ private ChunkerParameterParser() {} // no instance of this util class
* Parse String type parameter.
* Throw IllegalArgumentException if parameter is not a string or an empty string.
*/
public static String parseStringParameter(final Map<String, Object> parameters, final String fieldName, final String defaultValue) {
if (!parameters.containsKey(fieldName)) {
// all string parameters are optional
return defaultValue;
}
public static String parseString(final Map<String, Object> parameters, final String fieldName) {
Object fieldValue = parameters.get(fieldName);
if (!(fieldValue instanceof String)) {
throw new IllegalArgumentException(
Expand All @@ -40,14 +36,23 @@ public static String parseStringParameter(final Map<String, Object> parameters,
}

/**
* Parse integer type parameter.
* Throw IllegalArgumentException if parameter is not an integer.
* Parse String type parameter.
* Return default value if the parameter is missing.
* Throw IllegalArgumentException if parameter is not a string or an empty string.
*/
public static int parseIntegerParameter(final Map<String, Object> parameters, final String fieldName, final int defaultValue) {
public static String parseStringWithDefault(final Map<String, Object> parameters, final String fieldName, final String defaultValue) {
if (!parameters.containsKey(fieldName)) {
// all integer parameters are optional
// all string parameters are optional
return defaultValue;
}
return parseString(parameters, fieldName);
}

/**
* Parse integer type parameter with default value.
* Throw IllegalArgumentException if the parameter is not an integer.
*/
public static int parseInteger(final Map<String, Object> parameters, final String fieldName) {
String fieldValueString = parameters.get(fieldName).toString();
try {
return NumberUtils.createInteger(fieldValueString);
Expand All @@ -58,27 +63,54 @@ public static int parseIntegerParameter(final Map<String, Object> parameters, fi
}
}

/**
* Parse integer type parameter with default value.
* Return default value if the parameter is missing.
* Throw IllegalArgumentException if the parameter is not an integer.
*/
public static int parseIntegerWithDefault(final Map<String, Object> parameters, final String fieldName, final int defaultValue) {
if (!parameters.containsKey(fieldName)) {
// return the default value when parameter is missing
return defaultValue;
}
return parseInteger(parameters, fieldName);
}

/**
* Parse integer type parameter with positive value.
* Throw IllegalArgumentException if parameter is not a positive integer.
* Return default value if the parameter is missing.
* Throw IllegalArgumentException if the parameter is not a positive integer.
*/
public static int parsePositiveIntegerParameter(final Map<String, Object> parameters, final String fieldName, final int defaultValue) {
int fieldValueInt = parseIntegerParameter(parameters, fieldName, defaultValue);
public static int parsePositiveInteger(final Map<String, Object> parameters, final String fieldName) {
int fieldValueInt = parseInteger(parameters, fieldName);
if (fieldValueInt <= 0) {
throw new IllegalArgumentException(String.format(Locale.ROOT, "Parameter [%s] must be positive.", fieldName));
}
return fieldValueInt;
}

/**
* Parse double type parameter.
* Throw IllegalArgumentException if parameter is not a double.
* Parse integer type parameter with positive value.
* Return default value if the parameter is missing.
* Throw IllegalArgumentException if the parameter is not a positive integer.
*/
public static double parseDoubleParameter(final Map<String, Object> parameters, final String fieldName, final double defaultValue) {
public static int parsePositiveIntegerWithDefault(
final Map<String, Object> parameters,
final String fieldName,
final Integer defaultValue
) {
if (!parameters.containsKey(fieldName)) {
// all double parameters are optional
return defaultValue;
}
return parsePositiveInteger(parameters, fieldName);
}

/**
* Parse double type parameter.
* Throw IllegalArgumentException if parameter is not a double.
*/
public static double parseDouble(final Map<String, Object> parameters, final String fieldName) {
String fieldValueString = parameters.get(fieldName).toString();
try {
return NumberUtils.createDouble(fieldValueString);
Expand All @@ -88,4 +120,17 @@ public static double parseDoubleParameter(final Map<String, Object> parameters,
);
}
}

/**
* Parse double type parameter.
* Return default value if the parameter is missing.
* Throw IllegalArgumentException if parameter is not a double.
*/
public static double parseDoubleWithDefault(final Map<String, Object> parameters, final String fieldName, final double defaultValue) {
if (!parameters.containsKey(fieldName)) {
// all double parameters are optional
return defaultValue;
}
return parseDouble(parameters, fieldName);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
import java.util.List;
import java.util.ArrayList;

import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseInteger;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringWithDefault;

/**
* The implementation {@link Chunker} for delimiter algorithm
Expand All @@ -23,7 +23,6 @@ public final class DelimiterChunker implements Chunker {
public static final String DEFAULT_DELIMITER = "\n\n";

private String delimiter;
private int maxChunkLimit;

public DelimiterChunker(final Map<String, Object> parameters) {
parseParameters(parameters);
Expand All @@ -39,8 +38,7 @@ public DelimiterChunker(final Map<String, Object> parameters) {
*/
@Override
public void parseParameters(Map<String, Object> parameters) {
this.delimiter = parseStringParameter(parameters, DELIMITER_FIELD, DEFAULT_DELIMITER);
this.maxChunkLimit = parseIntegerParameter(parameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT);
this.delimiter = parseStringWithDefault(parameters, DELIMITER_FIELD, DEFAULT_DELIMITER);
}

/**
Expand All @@ -53,8 +51,8 @@ public void parseParameters(Map<String, Object> parameters) {
*/
@Override
public List<String> chunk(final String content, final Map<String, Object> runtimeParameters) {
int runtimeMaxChunkLimit = parseIntegerParameter(runtimeParameters, MAX_CHUNK_LIMIT_FIELD, maxChunkLimit);
int chunkStringCount = parseIntegerParameter(runtimeParameters, CHUNK_STRING_COUNT_FIELD, 1);
int runtimeMaxChunkLimit = parseInteger(runtimeParameters, MAX_CHUNK_LIMIT_FIELD);
int chunkStringCount = parseInteger(runtimeParameters, CHUNK_STRING_COUNT_FIELD);

List<String> chunkResult = new ArrayList<>();
int start = 0, end;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@
import org.opensearch.action.admin.indices.analyze.AnalyzeAction;
import org.opensearch.action.admin.indices.analyze.AnalyzeAction.AnalyzeToken;
import static org.opensearch.action.admin.indices.analyze.TransportAnalyzeAction.analyze;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseDoubleParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseIntegerParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parsePositiveIntegerParameter;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseInteger;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseStringWithDefault;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parseDoubleWithDefault;
import static org.opensearch.neuralsearch.processor.chunker.ChunkerParameterParser.parsePositiveIntegerWithDefault;

/**
* The implementation {@link Chunker} for fixed token length algorithm.
Expand All @@ -33,10 +33,9 @@ public final class FixedTokenLengthChunker implements Chunker {
public static final String MAX_TOKEN_COUNT_FIELD = "max_token_count";
public static final String TOKENIZER_FIELD = "tokenizer";

// default values for each parameter
// default values for each non-runtime parameter
private static final int DEFAULT_TOKEN_LIMIT = 384;
private static final double DEFAULT_OVERLAP_RATE = 0.0;
private static final int DEFAULT_MAX_TOKEN_COUNT = 10000;
private static final String DEFAULT_TOKENIZER = "standard";

// parameter restrictions
Expand All @@ -54,7 +53,6 @@ public final class FixedTokenLengthChunker implements Chunker {

// parameter value
private int tokenLimit;
private int maxChunkLimit;
private String tokenizer;
private double overlapRate;
private final AnalysisRegistry analysisRegistry;
Expand All @@ -81,10 +79,9 @@ public FixedTokenLengthChunker(final Map<String, Object> parameters) {
*/
@Override
public void parseParameters(Map<String, Object> parameters) {
this.tokenLimit = parsePositiveIntegerParameter(parameters, TOKEN_LIMIT_FIELD, DEFAULT_TOKEN_LIMIT);
this.overlapRate = parseDoubleParameter(parameters, OVERLAP_RATE_FIELD, DEFAULT_OVERLAP_RATE);
this.tokenizer = parseStringParameter(parameters, TOKENIZER_FIELD, DEFAULT_TOKENIZER);
this.maxChunkLimit = parseIntegerParameter(parameters, MAX_CHUNK_LIMIT_FIELD, DEFAULT_MAX_CHUNK_LIMIT);
this.tokenLimit = parsePositiveIntegerWithDefault(parameters, TOKEN_LIMIT_FIELD, DEFAULT_TOKEN_LIMIT);
this.overlapRate = parseDoubleWithDefault(parameters, OVERLAP_RATE_FIELD, DEFAULT_OVERLAP_RATE);
this.tokenizer = parseStringWithDefault(parameters, TOKENIZER_FIELD, DEFAULT_TOKENIZER);
if (overlapRate < OVERLAP_RATE_LOWER_BOUND || overlapRate > OVERLAP_RATE_UPPER_BOUND) {
throw new IllegalArgumentException(
String.format(
Expand Down Expand Up @@ -121,9 +118,9 @@ public void parseParameters(Map<String, Object> parameters) {
*/
@Override
public List<String> chunk(final String content, final Map<String, Object> runtimeParameters) {
int maxTokenCount = parsePositiveIntegerParameter(runtimeParameters, MAX_TOKEN_COUNT_FIELD, DEFAULT_MAX_TOKEN_COUNT);
int runtimeMaxChunkLimit = parseIntegerParameter(runtimeParameters, MAX_CHUNK_LIMIT_FIELD, this.maxChunkLimit);
int chunkStringCount = parseIntegerParameter(runtimeParameters, CHUNK_STRING_COUNT_FIELD, 1);
int maxTokenCount = parseInteger(runtimeParameters, MAX_TOKEN_COUNT_FIELD);
int runtimeMaxChunkLimit = parseInteger(runtimeParameters, MAX_CHUNK_LIMIT_FIELD);
int chunkStringCount = parseInteger(runtimeParameters, CHUNK_STRING_COUNT_FIELD);

List<AnalyzeToken> tokens = tokenize(content, tokenizer, maxTokenCount);
List<String> chunkResult = new ArrayList<>();
Expand Down
Loading

0 comments on commit 038b1ec

Please sign in to comment.