Add hidden support for chunking in InferenceProcessor
carlosdelest committed Nov 23, 2023
1 parent 5f3cd33 commit aea1b0b
Showing 7 changed files with 160 additions and 50 deletions.
@@ -50,7 +50,7 @@ public final class ScriptProcessor extends AbstractProcessor {
* @param precompiledIngestScriptFactory The {@link Script} precompiled script
* @param scriptService The {@link ScriptService} used to execute the script.
*/
-    ScriptProcessor(
+    public ScriptProcessor(
String tag,
String description,
Script script,
15 changes: 15 additions & 0 deletions multi-cluster-run.gradle
@@ -0,0 +1,15 @@
+rootProject {
+    if (project.name == 'elasticsearch') {
+        afterEvaluate {
+            testClusters.configureEach {
+                numberOfNodes = 2
+            }
+            def cluster = testClusters.named("runTask").get()
+            cluster.getNodes().each { node ->
+                node.setting('cluster.initial_master_nodes', cluster.getLastNode().getName())
+                node.setting('node.roles', '[master,data_hot,data_content]')
+            }
+            cluster.getFirstNode().setting('node.roles', '[]')
+        }
+    }
+}
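This build script is not referenced anywhere else in the commit; it reads like a development convenience for running a local two-node cluster, presumably applied as a Gradle init script (for example `./gradlew run --init-script multi-cluster-run.gradle`, an assumed invocation, not one documented by the commit). It points `cluster.initial_master_nodes` at the last node, gives every node master and hot/content data roles, then strips all roles from the first node so it acts as a coordinating-only node.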
@@ -13,6 +13,9 @@
import org.elasticsearch.ingest.IngestDocument;
import org.elasticsearch.xcontent.ToXContentFragment;

+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import java.util.Objects;

@@ -55,6 +58,28 @@ static void writeResultToField(
}
}

+static void writeChunkResultsToField(
+    List<InferenceResults> results,
+    IngestDocument ingestDocument,
+    @Nullable String basePath,
+    String outputField) {
+    Objects.requireNonNull(results, "results");
+    Objects.requireNonNull(ingestDocument, "ingestDocument");
+    Objects.requireNonNull(outputField, "outputField");
+    @SuppressWarnings("unchecked")
+    List<Map<String, Object>> inputValues = ingestDocument.getFieldValue(basePath + "." + outputField, List.class);
+    List<Map<String, Object>> outputValues = new ArrayList<>();
+    int currentResult = 0;
+    for (InferenceResults result : results) {
+        Map<String, Object> outputMap = new HashMap<>();
+        outputMap.put("inference", result.asMap(outputField).get(outputField));
+        outputMap.putAll(inputValues.get(currentResult++)); // advance to the chunk map that produced this result
+        outputValues.add(outputMap);
+    }
+
+    ingestDocument.setFieldValue(basePath + "." + outputField, outputValues);
+}
+
private static void setOrAppendValue(String path, Object value, IngestDocument ingestDocument) {
if (ingestDocument.hasField(path)) {
ingestDocument.appendFieldValue(path, value);
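For orientation, here is a minimal standalone sketch (not part of the commit; the sample texts and token weights are fabricated) of the field shape writeChunkResultsToField assumes: the value at basePath + "." + outputField starts as a list of maps that each hold a "text" chunk, and the method rewrites it so each map also carries the matching "inference" result.

// Sketch only: assumed shape of the chunked field before and after writeChunkResultsToField.
import java.util.List;
import java.util.Map;

public class ChunkFieldShapeSketch {
    public static void main(String[] args) {
        // Before: one map per chunk, written upstream by the chunking step.
        List<Map<String, Object>> before = List.of(
            Map.of("text", "Machine learning is fun"),
            Map.of("text", " It also powers semantic search")
        );
        // After: each chunk map additionally carries the result for that chunk
        // (a fabricated token-weight map stands in for a real InferenceResults value).
        List<Map<String, Object>> after = List.of(
            Map.of("text", "Machine learning is fun", "inference", Map.of("learning", 0.9, "fun", 0.7)),
            Map.of("text", " It also powers semantic search", "inference", Map.of("search", 0.8))
        );
        System.out.println(before);
        System.out.println(after);
    }
}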
@@ -49,7 +49,6 @@
import org.elasticsearch.index.analysis.CharFilterFactory;
import org.elasticsearch.index.analysis.TokenizerFactory;
import org.elasticsearch.index.mapper.Mapper;
-import org.elasticsearch.index.mapper.MetadataFieldMapper;
import org.elasticsearch.index.query.QueryBuilder;
import org.elasticsearch.indices.AssociatedIndexDescriptor;
import org.elasticsearch.indices.SystemIndexDescriptor;
@@ -365,7 +364,6 @@
import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor;
import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor;
import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper;
-import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper;
import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor;
import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;
@@ -2286,13 +2284,13 @@ public Map<String, Mapper.TypeParser> getMappers() {
);
}

-@Override
-public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
-    return Map.of(
-        SemanticTextInferenceResultFieldMapper.CONTENT_TYPE,
-        SemanticTextInferenceResultFieldMapper.PARSER
-    );
-}
+// @Override
+// public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
+//     return Map.of(
+//         SemanticTextInferenceResultFieldMapper.CONTENT_TYPE,
+//         SemanticTextInferenceResultFieldMapper.PARSER
+//     );
+// }

@Override
public Optional<Pipeline> getIngestPipeline(IndexMetadata indexMetadata, Processor.Parameters parameters) {
@@ -109,9 +109,23 @@ public static InferenceProcessor fromInputFieldConfiguration(
String modelId,
InferenceConfigUpdate inferenceConfig,
List<Factory.InputConfig> inputs,
-boolean ignoreMissing
+boolean ignoreMissing,
+boolean supportChunking
) {
-return new InferenceProcessor(client, auditor, tag, description, null, modelId, inferenceConfig, null, inputs, true, ignoreMissing);
+return new InferenceProcessor(
+    client,
+    auditor,
+    tag,
+    description,
+    null,
+    modelId,
+    inferenceConfig,
+    null,
+    inputs,
+    true,
+    ignoreMissing,
+    supportChunking
+);
}

public static InferenceProcessor fromTargetFieldConfiguration(
@@ -136,6 +150,7 @@ public static InferenceProcessor fromTargetFieldConfiguration(
fieldMap,
null,
false,
+false,
false
);
}
@@ -151,6 +166,7 @@ public static InferenceProcessor fromTargetFieldConfiguration(
private final List<Factory.InputConfig> inputs;
private final boolean configuredWithInputsFields;
private final boolean ignoreMissing;
+private final boolean supportChunking;

private InferenceProcessor(
Client client,
@@ -163,7 +179,8 @@ private InferenceProcessor(
Map<String, String> fieldMap,
List<Factory.InputConfig> inputs,
boolean configuredWithInputsFields,
-boolean ignoreMissing
+boolean ignoreMissing,
+boolean supportChunking
) {
super(tag, description);
this.configuredWithInputsFields = configuredWithInputsFields;
@@ -172,6 +189,7 @@ private InferenceProcessor(
this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID);
this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG);
this.ignoreMissing = ignoreMissing;
+this.supportChunking = supportChunking;

if (configuredWithInputsFields) {
this.inputs = ExceptionsHelper.requireNonNull(inputs, INPUT_OUTPUT);
@@ -229,12 +247,23 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) {
List<String> requestInputs = new ArrayList<>();
for (var inputFields : inputs) {
try {
-var inputText = ingestDocument.getFieldValue(inputFields.inputField, String.class, ignoreMissing);
-// if the field is missing and ignoreMissing == true, a null value is returned
-if (inputText == null) {
-    inputText = ""; // need to send a non-null input to get the same number of results back
-}
-requestInputs.add(inputText);
+if (supportChunking) {
+    @SuppressWarnings("unchecked")
+    List<Map<String, Object>> inputChunks = ingestDocument.getFieldValue(inputFields.inputField, List.class, ignoreMissing);
+    // if the field is missing and ignoreMissing == true, a null value is returned
+    if (inputChunks == null) {
+        requestInputs.add(""); // need to send a non-null input to get the same number of results back
+    } else {
+        requestInputs.addAll(inputChunks.stream().map(m -> m.get("text").toString()).toList());
+    }
+} else {
+    var inputText = ingestDocument.getFieldValue(inputFields.inputField, String.class, ignoreMissing);
+    // if the field is missing and ignoreMissing == true, a null value is returned
+    if (inputText == null) {
+        inputText = ""; // need to send a non-null input to get the same number of results back
+    }
+    requestInputs.add(inputText);
+}
} catch (IllegalArgumentException e) {
if (ingestDocument.hasField(inputFields.inputField())) {
// field is present but of the wrong type, translate to a more meaningful message
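Since this hunk is the request-building half of the chunking support, a small standalone sketch may help (assumed data, not from the commit): with chunking enabled, every input field contributes one request string per chunk, so the flat request carries all fields' chunks concatenated in order.

// Sketch: chunked fields flatten into a single ordered list of request inputs.
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

public class FlattenRequestSketch {
    public static void main(String[] args) {
        // Field A produced two chunks, field B produced one.
        List<List<Map<String, Object>>> chunkedFields = List.of(
            List.of(Map.of("text", "chunk a1"), Map.of("text", "chunk a2")),
            List.of(Map.of("text", "chunk b1"))
        );
        List<String> requestInputs = new ArrayList<>();
        for (List<Map<String, Object>> chunks : chunkedFields) {
            chunks.forEach(chunk -> requestInputs.add(chunk.get("text").toString()));
        }
        System.out.println(requestInputs); // [chunk a1, chunk a2, chunk b1]
    }
}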
@@ -297,24 +326,43 @@ void mutateDocument(InferModelAction.Response response, IngestDocument ingestDoc
// String modelIdField = tag == null ? MODEL_ID_RESULTS_FIELD : MODEL_ID_RESULTS_FIELD + "." + tag;

if (configuredWithInputsFields) {
-if (response.getInferenceResults().size() != inputs.size()) {
-    throw new ElasticsearchStatusException(
-        "number of results [{}] does not match the number of inputs [{}]",
-        RestStatus.INTERNAL_SERVER_ERROR,
-        response.getInferenceResults().size(),
-        inputs.size()
-    );
-}
+if (supportChunking) {
+    int currentResult = 0;
+    for (Factory.InputConfig input : inputs) {
+        int inputSize = ingestDocument.getFieldValue(input.inputField, List.class).size();
+        List<InferenceResults> inputResults = response.getInferenceResults()
+            .subList(currentResult, currentResult + inputSize);
+        InferenceResults.writeChunkResultsToField(inputResults, ingestDocument, input.outputBasePath, input.outputField);
+        currentResult += inputSize;
+    }
+    if (currentResult != response.getInferenceResults().size()) {
+        throw new ElasticsearchStatusException(
+            "number of results [{}] does not match the number of inputs [{}]",
+            RestStatus.INTERNAL_SERVER_ERROR,
+            response.getInferenceResults().size(),
+            currentResult
+        );
+    }
+} else {
+    if (response.getInferenceResults().size() != inputs.size()) {
+        throw new ElasticsearchStatusException(
+            "number of results [{}] does not match the number of inputs [{}]",
+            RestStatus.INTERNAL_SERVER_ERROR,
+            response.getInferenceResults().size(),
+            inputs.size()
+        );
+    }

-for (int i = 0; i < inputs.size(); i++) {
-    InferenceResults.writeResultToField(
-        response.getInferenceResults().get(i),
-        ingestDocument,
-        inputs.get(i).outputBasePath(),
-        inputs.get(i).outputField,
-        response.getId() != null ? response.getId() : modelId,
-        i == 0
-    );
-}
+    for (int i = 0; i < inputs.size(); i++) {
+        InferenceResults.writeResultToField(
+            response.getInferenceResults().get(i),
+            ingestDocument,
+            inputs.get(i).outputBasePath(),
+            inputs.get(i).outputField,
+            response.getId() != null ? response.getId() : modelId,
+            i == 0
+        );
+    }
+}
} else {
assert response.getInferenceResults().size() == 1;
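mutateDocument undoes the flattening done in buildRequest: it walks the inputs in the same order, re-reads each field's chunk count, and slices the flat results list with subList. A standalone sketch of that partitioning (fabricated result strings; the real code carries InferenceResults values):

// Sketch: slicing the flat results list back into per-field runs by chunk count.
import java.util.List;

public class PartitionResultsSketch {
    public static void main(String[] args) {
        List<String> results = List.of("r-a1", "r-a2", "r-b1"); // three results for three chunks
        int[] chunkCounts = { 2, 1 }; // field A had two chunks, field B had one
        int currentResult = 0;
        for (int count : chunkCounts) {
            List<String> fieldResults = results.subList(currentResult, currentResult + count);
            System.out.println(fieldResults);
            currentResult += count;
        }
        // currentResult must equal results.size() here; otherwise the processor throws.
    }
}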
@@ -472,7 +520,8 @@ public InferenceProcessor create(
modelId,
inferenceConfigUpdate,
parsedInputs,
-ignoreMissing
+ignoreMissing,
+false
);
} else {
// old style configuration with target field
@@ -17,6 +17,7 @@
import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper;
import org.elasticsearch.xpack.ml.notifications.InferenceAuditor;

+import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -26,8 +27,10 @@ public class SemanticTextInferenceProcessor extends AbstractProcessor implements

public static final String TYPE = "semanticTextInference";
public static final String TAG = "semantic_text";
+public static final String TEXT_SUFFIX = ".text";
+public static final String INFERENCE_SUFFIX = ".inference";

-private final Map<String, Set<String>> fieldsForModels;
+private final Map<String, Set<String>> modelForFields;

private final Processor wrappedProcessor;

@@ -38,18 +41,18 @@ public SemanticTextInferenceProcessor(
Client client,
InferenceAuditor inferenceAuditor,
String description,
-Map<String, Set<String>> fieldsForModels
+Map<String, Set<String>> modelForFields
) {
super(TAG, description);
this.client = client;
this.inferenceAuditor = inferenceAuditor;

-this.fieldsForModels = fieldsForModels;
+this.modelForFields = modelForFields;
this.wrappedProcessor = createWrappedProcessor();
}

private Processor createWrappedProcessor() {
-InferenceProcessor[] inferenceProcessors = fieldsForModels.entrySet()
+InferenceProcessor[] inferenceProcessors = modelForFields.entrySet()
.stream()
.map(e -> createInferenceProcessor(e.getKey(), e.getValue()))
.toArray(InferenceProcessor[]::new);
@@ -58,7 +61,11 @@ private Processor createWrappedProcessor() {

private InferenceProcessor createInferenceProcessor(String modelId, Set<String> fields) {
List<InferenceProcessor.Factory.InputConfig> inputConfigs = fields.stream()
-.map(f -> new InferenceProcessor.Factory.InputConfig(f, SemanticTextInferenceResultFieldMapper.NAME, f, Map.of()))
+.map(field -> new InferenceProcessor.Factory.InputConfig(
+    SemanticTextInferenceResultFieldMapper.NAME + "." + field,
+    SemanticTextInferenceResultFieldMapper.NAME,
+    field,
+    Map.of()))
.toList();

return InferenceProcessor.fromInputFieldConfiguration(
@@ -69,18 +76,28 @@ private InferenceProcessor createInferenceProcessor(String modelId, Set<String>
modelId,
TextExpansionConfigUpdate.EMPTY_UPDATE,
inputConfigs,
-false
+false,
+true
);
}

@Override
public void execute(IngestDocument ingestDocument, BiConsumer<IngestDocument, Exception> handler) {
+modelForFields.forEach((modelId, fields) -> chunkText(ingestDocument, modelId, fields));
getInnerProcessor().execute(ingestDocument, handler);
}

-@Override
-public IngestDocument execute(IngestDocument ingestDocument) throws Exception {
-    return getInnerProcessor().execute(ingestDocument);
-}
+private static void chunkText(IngestDocument ingestDocument, String modelId, Set<String> fields) {
+    for (String field : fields) {
+        String value = ingestDocument.getFieldValue(field, String.class);
+        if (value != null) {
+            String[] chunks = value.split("\\.");
+            ingestDocument.setFieldValue(SemanticTextInferenceResultFieldMapper.NAME + "." + field, new ArrayList<>());
+            for (String chunk : chunks) {
+                ingestDocument.appendFieldValue(SemanticTextInferenceResultFieldMapper.NAME + "." + field, Map.of("text", chunk));
+            }
+        }
+    }
+}

@Override
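The chunking in chunkText above is deliberately naive: it splits on literal periods and stores whatever String.split returns. A minimal standalone sketch of that behavior (not from the commit), useful for seeing why chunks keep leading spaces and lose their delimiters:

// Sketch of the String.split("\\.") behavior used by chunkText above.
public class SplitSketch {
    public static void main(String[] args) {
        String value = "First sentence. Second sentence.";
        String[] chunks = value.split("\\.");
        // Prints: [First sentence][ Second sentence]
        // The periods are consumed, the second chunk keeps its leading space,
        // and trailing empty strings are dropped by split.
        for (String chunk : chunks) {
            System.out.print("[" + chunk + "]");
        }
        System.out.println();
    }
}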