chunker factory create with analysis registry
Signed-off-by: yuye-aws <[email protected]>
yuye-aws committed Feb 22, 2024
1 parent ab5c4eb commit d3f3d58
Showing 2 changed files with 47 additions and 25 deletions.
@@ -8,10 +8,8 @@
import java.util.Set;
import java.util.ArrayList;
import java.util.List;
import lombok.extern.log4j.Log4j2;

import org.opensearch.client.Client;
import org.opensearch.client.node.NodeClient;
import org.opensearch.index.analysis.AnalysisRegistry;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.Processor;
import org.opensearch.ingest.AbstractProcessor;
@@ -21,32 +19,29 @@
import static org.opensearch.ingest.ConfigurationUtils.readMap;
import static org.opensearch.neuralsearch.processor.InferenceProcessor.FIELD_MAP_FIELD;

@Log4j2
public final class DocumentChunkingProcessor extends AbstractProcessor {

public static final String TYPE = "chunking";
public static final String OUTPUT_FIELD = "output_field";

private final Map<String, Object> fieldMap;

private static NodeClient nodeClient;

private final Set<String> supportedChunkers = ChunkerFactory.getChunkers();

public DocumentChunkingProcessor(String tag, String description, Map<String, Object> fieldMap) {
private final AnalysisRegistry analysisRegistry;

public DocumentChunkingProcessor(String tag, String description, Map<String, Object> fieldMap, AnalysisRegistry analysisRegistry) {
super(tag, description);
validateDocumentChunkingFieldMap(fieldMap);
this.fieldMap = fieldMap;
}

public static void initialize(Client nodeClient) {
DocumentChunkingProcessor.nodeClient = (NodeClient) nodeClient;
this.analysisRegistry = analysisRegistry;
}

public String getType() {
return TYPE;
}

@SuppressWarnings("unchecked")
private void validateDocumentChunkingFieldMap(Map<String, Object> fieldMap) {
if (fieldMap == null || fieldMap.isEmpty()) {
throw new IllegalArgumentException("Unable to create the processor as field_map is null or empty");
@@ -66,8 +61,7 @@ private void validateDocumentChunkingFieldMap(Map<String, Object> fieldMap) {
);
}

// Casting parameters to a map
Map<?, ?> parameterMap = (Map<?, ?>) parameters;
Map<String, Object> parameterMap = (Map<String, Object>) parameters;

// output field must be string
if (!(parameterMap.containsKey(OUTPUT_FIELD))) {
@@ -93,7 +87,7 @@ private void validateDocumentChunkingFieldMap(Map<String, Object> fieldMap) {
if (supportedChunkers.contains(parameterKey)) {
chunkingAlgorithmCount += 1;
chunkerParameters = (Map<String, Object>) parameterEntry.getValue();
IFieldChunker chunker = ChunkerFactory.create(parameterKey, nodeClient);
IFieldChunker chunker = ChunkerFactory.create(parameterKey, analysisRegistry);
chunker.validateParameters(chunkerParameters);
}
}
@@ -109,7 +103,7 @@ private void validateDocumentChunkingFieldMap(Map<String, Object> fieldMap) {
}

@Override
public final IngestDocument execute(IngestDocument document) {
public IngestDocument execute(IngestDocument document) {
for (Map.Entry<String, Object> fieldMapEntry : fieldMap.entrySet()) {
String inputField = fieldMapEntry.getKey();
Object content = document.getFieldValue(inputField, Object.class);
@@ -118,7 +112,22 @@ public final IngestDocument execute(IngestDocument document) {
throw new IllegalArgumentException("input field in document [" + inputField + "] is null, cannot process it.");
}

if (!(content instanceof String)) {
if (content instanceof List<?>) {
List<?> contentList = (List<?>) content;
for (Object contentElement : contentList) {
if (!(contentElement instanceof String)) {
throw new IllegalArgumentException(
"element in input field list ["
+ inputField
+ "] of type ["
+ contentElement.getClass().getName()
+ "] cannot be cast to ["
+ String.class.getName()
+ "]"
);
}
}
} else if (!(content instanceof String)) {
throw new IllegalArgumentException(
"input field ["
+ inputField
@@ -134,13 +143,21 @@ public final IngestDocument execute(IngestDocument document) {
String outputField = (String) parameters.get(OUTPUT_FIELD);
List<String> chunkedPassages = new ArrayList<>();

// parameter has been checked that there is only one algorithm
// we have validated that there is one chunking algorithm
// and that chunkerParameters is of type Map<String, Object>
for (Map.Entry<?, ?> parameterEntry : parameters.entrySet()) {
String parameterKey = (String) parameterEntry.getKey();
if (supportedChunkers.contains(parameterKey)) {
Map<?, ?> chunkerParameters = (Map<?, ?>) parameterEntry.getValue();
IFieldChunker chunker = ChunkerFactory.create(parameterKey, nodeClient);
chunkedPassages = chunker.chunk((String) content, (Map<String, Object>) chunkerParameters);
@SuppressWarnings("unchecked")
Map<String, Object> chunkerParameters = (Map<String, Object>) parameterEntry.getValue();
IFieldChunker chunker = ChunkerFactory.create(parameterKey, analysisRegistry);
if (content instanceof String) {
chunkedPassages = chunker.chunk((String) content, chunkerParameters);
} else {
for (Object contentElement : (List<?>) content) {
chunkedPassages.addAll(chunker.chunk((String) contentElement, chunkerParameters));
}
}
}
}
document.setFieldValue(outputField, chunkedPassages);
@@ -149,7 +166,12 @@ public final IngestDocument execute(IngestDocument document) {
}

public static class Factory implements Processor.Factory {
public Factory() {}

private final AnalysisRegistry analysisRegistry;

public Factory(AnalysisRegistry analysisRegistry) {
this.analysisRegistry = analysisRegistry;
}

@Override
public DocumentChunkingProcessor create(
@@ -159,7 +181,7 @@ public DocumentChunkingProcessor create(
Map<String, Object> config
) throws Exception {
Map<String, Object> fieldMap = readMap(TYPE, processorTag, config, FIELD_MAP_FIELD);
return new DocumentChunkingProcessor(processorTag, description, fieldMap);
return new DocumentChunkingProcessor(processorTag, description, fieldMap, analysisRegistry);
}

}
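The Factory now receives the AnalysisRegistry through its constructor rather than the processor being primed with a static NodeClient. Below is a minimal sketch of how a plugin could wire this up through the standard ingest extension point; the plugin class is hypothetical and not part of this commit, and it assumes that Processor.Parameters exposes the node's AnalysisRegistry and that DocumentChunkingProcessor lives in org.opensearch.neuralsearch.processor.

import java.util.Map;

import org.opensearch.ingest.Processor;
import org.opensearch.neuralsearch.processor.DocumentChunkingProcessor;
import org.opensearch.plugins.IngestPlugin;
import org.opensearch.plugins.Plugin;

// Hypothetical plugin wiring, for illustration only; the class name and
// package assumptions above are not taken from this commit.
public class ExampleChunkingPlugin extends Plugin implements IngestPlugin {

    @Override
    public Map<String, Processor.Factory> getProcessors(Processor.Parameters parameters) {
        // Processor.Parameters is assumed to expose the node's AnalysisRegistry,
        // which the factory hands to every DocumentChunkingProcessor it creates.
        return Map.of(DocumentChunkingProcessor.TYPE, new DocumentChunkingProcessor.Factory(parameters.analysisRegistry));
    }
}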
@@ -4,7 +4,7 @@
*/
package org.opensearch.neuralsearch.processor.chunker;

import org.opensearch.client.node.NodeClient;
import org.opensearch.index.analysis.AnalysisRegistry;

import java.util.Set;

@@ -13,10 +13,10 @@ public class ChunkerFactory {
public static final String FIXED_LENGTH_ALGORITHM = "fix_length";
public static final String DELIMITER_ALGORITHM = "delimiter";

public static IFieldChunker create(String type, NodeClient nodeClient) {
public static IFieldChunker create(String type, AnalysisRegistry analysisRegistry) {
switch (type) {
case FIXED_LENGTH_ALGORITHM:
return new FixedTokenLengthChunker(nodeClient);
return new FixedTokenLengthChunker(analysisRegistry);
case DELIMITER_ALGORITHM:
return new DelimiterChunker();
default:
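For reference, a minimal usage sketch of the updated factory signature. The AnalysisRegistry is assumed to be obtained elsewhere (for example via the Processor.Parameters wiring sketched above), and the "delimiter" parameter key is illustrative rather than taken from this commit.

import java.util.List;
import java.util.Map;

import org.opensearch.index.analysis.AnalysisRegistry;
import org.opensearch.neuralsearch.processor.chunker.ChunkerFactory;
import org.opensearch.neuralsearch.processor.chunker.IFieldChunker;

// Illustrative sketch only: the "delimiter" parameter name is hypothetical.
public class ChunkerFactoryExample {

    public static List<String> chunkByDelimiter(AnalysisRegistry analysisRegistry, String text) {
        IFieldChunker chunker = ChunkerFactory.create(ChunkerFactory.DELIMITER_ALGORITHM, analysisRegistry);
        Map<String, Object> parameters = Map.of("delimiter", "\n");
        // validateParameters and chunk mirror the calls made by DocumentChunkingProcessor above.
        chunker.validateParameters(parameters);
        return chunker.chunk(text, parameters);
    }
}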
