From 13e22645337165756a837ec6c2daab3743593e52 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Tue, 3 Oct 2023 16:23:05 +0100 Subject: [PATCH] Write results to output field --- .../inference/InferenceResults.java | 49 ++++++- .../ClassificationInferenceResults.java | 14 +- .../results/ErrorInferenceResults.java | 6 + .../ml/inference/results/FillMaskResults.java | 7 + .../results/NlpInferenceResults.java | 13 +- .../results/PyTorchPassThroughResults.java | 7 + .../QuestionAnsweringInferenceResults.java | 12 ++ .../results/RawInferenceResults.java | 5 + .../results/RegressionInferenceResults.java | 14 +- .../results/TextEmbeddingResults.java | 7 + .../results/TextExpansionResults.java | 7 + .../TextSimilarityInferenceResults.java | 7 + .../results/WarningInferenceResults.java | 6 + .../InferenceIngestInputConfigIT.java | 133 ++++++++++++++++++ .../inference/ingest/InferenceProcessor.java | 65 ++++++--- .../InferenceProcessorFactoryTests.java | 79 +++++++++-- .../ingest/InferenceProcessorTests.java | 73 ++++++++++ 17 files changed, 468 insertions(+), 36 deletions(-) create mode 100644 x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestInputConfigIT.java diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceResults.java index 76dfcae2ba530..43d5c37563cc0 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceResults.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceResults.java @@ -9,6 +9,7 @@ package org.elasticsearch.inference; import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.core.Nullable; import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.xcontent.ToXContentFragment; @@ -24,16 +25,58 @@ static void writeResult(InferenceResults results, IngestDocument ingestDocument, Objects.requireNonNull(resultField, "resultField"); 
Map resultMap = results.asMap(); resultMap.put(MODEL_ID_RESULTS_FIELD, modelId); - if (ingestDocument.hasField(resultField)) { - ingestDocument.appendFieldValue(resultField, resultMap); + setOrAppendValue(resultField, resultMap, ingestDocument); + } + + static void writeResultToField( + InferenceResults results, + IngestDocument ingestDocument, + @Nullable String basePath, + String outputField, + String modelId, + boolean includeModelId + ) { + Objects.requireNonNull(results, "results"); + Objects.requireNonNull(ingestDocument, "ingestDocument"); + Objects.requireNonNull(outputField, "outputField"); + Map resultMap = results.asMap(outputField); + if (includeModelId) { + resultMap.put(MODEL_ID_RESULTS_FIELD, modelId); + } + if (basePath == null) { + // insert the results into the root of the document + for (var entry : resultMap.entrySet()) { + setOrAppendValue(entry.getKey(), entry.getValue(), ingestDocument); + } } else { - ingestDocument.setFieldValue(resultField, resultMap); + for (var entry : resultMap.entrySet()) { + setOrAppendValue(basePath + "." 
+ entry.getKey(), entry.getValue(), ingestDocument); + } + } + } + + private static void setOrAppendValue(String path, Object value, IngestDocument ingestDocument) { + if (ingestDocument.hasField(path)) { + ingestDocument.appendFieldValue(path, value); + } else { + ingestDocument.setFieldValue(path, value); } } String getResultsField(); + /** + * Convert to a map + * @return Map representation of the InferenceResult + */ Map asMap(); + /** + * Convert to a map placing the inference result in {@code outputField} + * @param outputField Write the inference result to this field + * @return Map representation of the InferenceResult + */ + Map asMap(String outputField); + Object predictedValue(); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java index ebf946aa34716..12921166b1489 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResults.java @@ -220,6 +220,19 @@ public String getResultsField() { public Map asMap() { Map map = new LinkedHashMap<>(); map.put(resultsField, predictionFieldType.transformPredictedValue(value(), valueAsString())); + addSupportingFieldsToMap(map); + return map; + } + + @Override + public Map asMap(String outputField) { + Map map = new LinkedHashMap<>(); + map.put(outputField, predictionFieldType.transformPredictedValue(value(), valueAsString())); + addSupportingFieldsToMap(map); + return map; + } + + private void addSupportingFieldsToMap(Map map) { if (topClasses.isEmpty() == false) { map.put(topNumClassesField, topClasses.stream().map(TopClassEntry::asValueMap).collect(Collectors.toList())); } @@ -235,7 +248,6 @@ public Map asMap() { 
featureImportance.stream().map(ClassificationFeatureImportance::toMap).collect(Collectors.toList()) ); } - return map; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java index daf1175bbc547..c3b3a8f7d88f2 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResults.java @@ -74,6 +74,12 @@ public Map asMap() { return asMap; } + @Override + public Map asMap(String outputField) { + // errors do not have a result + return asMap(); + } + @Override public String toString() { return Strings.toString(this); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java index 863efad3e3b0e..4fad9b535e4e1 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/FillMaskResults.java @@ -55,6 +55,13 @@ void addMapFields(Map map) { map.put(resultsField + "_sequence", predictedSequence); } + @Override + public Map asMap(String outputField) { + var map = super.asMap(outputField); + map.put(outputField + "_sequence", predictedSequence); + return map; + } + @Override public String getWriteableName() { return NAME; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java index ab1939013e976..4efb719137c65 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/NlpInferenceResults.java @@ -59,10 +59,21 @@ public final XContentBuilder toXContent(XContentBuilder builder, Params params) public final Map asMap() { Map map = new LinkedHashMap<>(); addMapFields(map); + addSupportingFieldsToMap(map); + return map; + } + + @Override + public Map asMap(String outputField) { + Map map = new LinkedHashMap<>(); + addSupportingFieldsToMap(map); + return map; + } + + private void addSupportingFieldsToMap(Map map) { if (isTruncated) { map.put("is_truncated", isTruncated); } - return map; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java index 668ba6d773c26..de49fb2252ad0 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResults.java @@ -65,6 +65,13 @@ void addMapFields(Map map) { map.put(resultsField, inference); } + @Override + public Map asMap(String outputField) { + var map = super.asMap(outputField); + map.put(outputField, inference); + return map; + } + @Override public Object predictedValue() { throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java index dec5343fe8ddf..e9e41ce963bec 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResults.java @@ -121,6 +121,18 @@ public String predictedValue() { @Override void addMapFields(Map map) { map.put(resultsField, answer); + addSupportingFieldsToMap(map); + } + + @Override + public Map asMap(String outputField) { + var map = super.asMap(outputField); + map.put(outputField, answer); + addSupportingFieldsToMap(map); + return map; + } + + private void addSupportingFieldsToMap(Map map) { map.put(START_OFFSET.getPreferredName(), startOffset); map.put(END_OFFSET.getPreferredName(), endOffset); if (topClasses.isEmpty() == false) { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java index b7744a6c54800..6f1e3e423b240 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RawInferenceResults.java @@ -69,6 +69,11 @@ public Map asMap() { throw new UnsupportedOperationException("[raw] does not support map conversion"); } + @Override + public Map asMap(String outputField) { + throw new UnsupportedOperationException("[raw] does not support map conversion"); + } + @Override public Object predictedValue() { return null; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java index e0af785f827e3..8f1a884c3b1f0 100644 --- 
a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResults.java @@ -121,11 +121,23 @@ public String getResultsField() { @Override public Map asMap() { Map map = new LinkedHashMap<>(); + addSupportingFieldsToMap(map); map.put(resultsField, value()); + return map; + } + + @Override + public Map asMap(String outputField) { + Map map = new LinkedHashMap<>(); + addSupportingFieldsToMap(map); + map.put(outputField, value()); + return map; + } + + private void addSupportingFieldsToMap(Map map) { if (featureImportance.isEmpty() == false) { map.put(FEATURE_IMPORTANCE, featureImportance.stream().map(RegressionFeatureImportance::toMap).collect(Collectors.toList())); } - return map; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java index bd42439ecb35a..526c2ec7b7aaa 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResults.java @@ -72,6 +72,13 @@ void addMapFields(Map map) { map.put(resultsField, inference); } + @Override + public Map asMap(String outputField) { + var map = super.asMap(outputField); + map.put(outputField, inference); + return map; + } + @Override public Object predictedValue() { throw new UnsupportedOperationException("[" + NAME + "] does not support a single predicted value"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java index 
0dfc0dccbb8f6..45aa4d51e0ad6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResults.java @@ -119,4 +119,11 @@ void doWriteTo(StreamOutput out) throws IOException { void addMapFields(Map map) { map.put(resultsField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight))); } + + @Override + public Map asMap(String outputField) { + var map = super.asMap(outputField); + map.put(outputField, weightedTokens.stream().collect(Collectors.toMap(WeightedToken::token, WeightedToken::weight))); + return map; + } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java index 6848d47f187bd..b8b75e2bf7eb4 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResults.java @@ -68,6 +68,13 @@ void addMapFields(Map map) { map.put(resultsField, score); } + @Override + public Map asMap(String outputField) { + var map = super.asMap(outputField); + map.put(outputField, score); + return map; + } + @Override public String getWriteableName() { return NAME; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java index a956c91007934..254955f7da9d6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java +++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResults.java @@ -75,6 +75,12 @@ public Map asMap() { return asMap; } + @Override + public Map asMap(String outputField) { + // warnings do not have a result + return asMap(); + } + @Override public Object predictedValue() { return null; diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestInputConfigIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestInputConfigIT.java new file mode 100644 index 0000000000000..a3e5c3993398e --- /dev/null +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/InferenceIngestInputConfigIT.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.ml.integration; + +import org.elasticsearch.client.Request; +import org.elasticsearch.client.Response; +import org.elasticsearch.core.Strings; +import org.elasticsearch.xpack.core.ml.utils.MapHelper; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.hasSize; + +public class InferenceIngestInputConfigIT extends PyTorchModelRestTestCase { + + @SuppressWarnings("unchecked") + public void testIngestWithInputFields() throws IOException { + String modelId = "test_ingest_with_input_fields"; + createPassThroughModel(modelId); + putModelDefinition(modelId, PyTorchModelIT.BASE_64_ENCODED_MODEL, PyTorchModelIT.RAW_MODEL_SIZE); + putVocabulary(List.of("these", "are", "my", "words"), modelId); + startDeployment(modelId); + + String inputOutput = """ + [ + { + "input_field": "body", + "output_field": "body_tokens" + } + ] + """; + String docs = """ + [ + { + "_source": { + "body": "these are" + } + }, + { + "_source": { + "body": "my words" + } + } + ] + """; + var simulateResponse = simulatePipeline(pipelineDefinition(modelId, inputOutput), docs); + var responseMap = entityAsMap(simulateResponse); + var simulatedDocs = (List>) responseMap.get("docs"); + assertThat(simulatedDocs, hasSize(2)); + assertNotNull(MapHelper.dig("doc._source.body_tokens", simulatedDocs.get(0))); + assertNotNull(MapHelper.dig("doc._source.body_tokens", simulatedDocs.get(1))); + } + + @SuppressWarnings("unchecked") + public void testIngestWithMultipleInputFields() throws IOException { + String modelId = "test_ingest_with_multiple_input_fields"; + createPassThroughModel(modelId); + putModelDefinition(modelId, PyTorchModelIT.BASE_64_ENCODED_MODEL, PyTorchModelIT.RAW_MODEL_SIZE); + putVocabulary(List.of("these", "are", "my", "words"), modelId); + startDeployment(modelId); + + String inputOutput = """ + [ + { + "input_field": "title", + "output_field": "ml.body_tokens" + }, + { + "input_field": "body", 
+ "output_field": "ml.title_tokens" + } + ] + """; + + String docs = """ + [ + { + "_source": { + "title": "my", + "body": "these are" + } + }, + { + "_source": { + "title": "are", + "body": "my words" + } + } + ] + """; + var simulateResponse = simulatePipeline(pipelineDefinition(modelId, inputOutput), docs); + var responseMap = entityAsMap(simulateResponse); + var simulatedDocs = (List>) responseMap.get("docs"); + assertThat(simulatedDocs, hasSize(2)); + assertNotNull(MapHelper.dig("doc._source.ml.title_tokens", simulatedDocs.get(0))); + assertNotNull(MapHelper.dig("doc._source.ml.body_tokens", simulatedDocs.get(0))); + assertNotNull(MapHelper.dig("doc._source.ml.title_tokens", simulatedDocs.get(1))); + assertNotNull(MapHelper.dig("doc._source.ml.body_tokens", simulatedDocs.get(1))); + } + + private static String pipelineDefinition(String modelId, String inputOutput) { + return Strings.format(""" + { + "processors": [ + { + "inference": { + "model_id": "%s", + "input_output": %s + } + } + ] + }""", modelId, inputOutput); + } + + private Response simulatePipeline(String pipelineDef, String docs) throws IOException { + String simulate = Strings.format(""" + { + "pipeline": %s, + "docs": %s + }""", pipelineDef, docs); + + Request request = new Request("POST", "_ingest/pipeline/_simulate?error_trace=true"); + request.setJsonEntity(simulate); + return client().performRequest(request); + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 4120cd6669e7f..ef78078d1bbcd 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -17,6 +17,7 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; 
import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.AbstractProcessor; import org.elasticsearch.ingest.ConfigurationUtils; @@ -94,7 +95,7 @@ public class InferenceProcessor extends AbstractProcessor { private static final String DEFAULT_TARGET_FIELD = "ml.inference"; // input field config - public static final String INPUT = "input"; + public static final String INPUT_OUTPUT = "input_output"; public static final String INPUT_FIELD = "input_field"; public static final String OUTPUT_FIELD = "output_field"; @@ -154,7 +155,7 @@ private InferenceProcessor( this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); if (configuredWithInputsFields) { - this.inputs = ExceptionsHelper.requireNonNull(inputs, INPUT); + this.inputs = ExceptionsHelper.requireNonNull(inputs, INPUT_OUTPUT); this.targetField = null; this.fieldMap = null; } else { @@ -213,7 +214,7 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) { if (configuredWithInputsFields) { List requestInputs = new ArrayList<>(); for (var inputFields : inputs) { - var lookup = (String)fields.get(inputFields.inputField); + var lookup = (String) fields.get(inputFields.inputField); if (lookup == null) { lookup = ""; // need to send a non-null request to the same number of results back } @@ -249,16 +250,22 @@ void mutateDocument(InferModelAction.Response response, IngestDocument ingestDoc if (configuredWithInputsFields) { if (response.getInferenceResults().size() != inputs.size()) { - throw new ElasticsearchStatusException("number of results [{}] does not match the number of inputs [{}]", - RestStatus.INTERNAL_SERVER_ERROR, response.getInferenceResults().size(), inputs.size()); + throw new ElasticsearchStatusException( + "number of results [{}] does not match the number of inputs [{}]", + RestStatus.INTERNAL_SERVER_ERROR, + 
response.getInferenceResults().size(), + inputs.size() + ); } - for (int i=0; i< inputs.size(); i++) { - InferenceResults.writeResult( + for (int i = 0; i < inputs.size(); i++) { + InferenceResults.writeResultToField( response.getInferenceResults().get(i), ingestDocument, + inputs.get(i).outputBasePath(), inputs.get(i).outputField, - response.getId() != null ? response.getId() : modelId + response.getId() != null ? response.getId() : modelId, + i == 0 ); } } else { @@ -341,7 +348,6 @@ public InferenceProcessor create( String description, Map config ) { - if (this.maxIngestProcessors <= currentInferenceProcessors) { throw new ElasticsearchStatusException( "Max number of inference processors reached, total inference processors [{}]. " @@ -367,11 +373,11 @@ public InferenceProcessor create( inferenceConfigUpdate = inferenceConfigUpdateFromMap(inferenceConfigMap); } - Map input = ConfigurationUtils.readOptionalMap(TYPE, tag, config, INPUT); - boolean configuredWithInputFields = input != null; + List> inputs = ConfigurationUtils.readOptionalList(TYPE, tag, config, INPUT_OUTPUT); + boolean configuredWithInputFields = inputs != null; if (configuredWithInputFields) { // new style input/output configuration - var parsedInputs = parseInputFields(tag, List.of(input)); + var parsedInputs = parseInputFields(tag, inputs); // validate incompatible settings are not present String targetField = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, TARGET_FIELD); @@ -381,7 +387,7 @@ public InferenceProcessor create( tag, TARGET_FIELD, "option is incompatible with [" - + INPUT + + INPUT_OUTPUT + "]." + " Use the [" + OUTPUT_FIELD @@ -399,9 +405,9 @@ public InferenceProcessor create( + "." + InferenceConfig.RESULTS_FIELD.getPreferredName() + "] setting is incompatible with using [" - + INPUT + + INPUT_OUTPUT + "]. Prefer to use the [" - + INPUT + + INPUT_OUTPUT + "." + OUTPUT_FIELD + "] option to specify where to write the inference results to." 
@@ -547,7 +553,7 @@ void checkSupportedVersion(InferenceConfig config) { List parseInputFields(String tag, List> inputs) { if (inputs.isEmpty()) { - throw newConfigurationException(TYPE, tag, INPUT, "cannot be empty at least one is required"); + throw newConfigurationException(TYPE, tag, INPUT_OUTPUT, "cannot be empty at least one is required"); } var inputNames = new HashSet(); var outputNames = new HashSet(); @@ -564,10 +570,12 @@ List parseInputFields(String tag, List> inputs) throw duplicatedFieldNameError(OUTPUT_FIELD, outputField, tag); } + var outputPaths = extractBasePathAndFinalElement(outputField); + if (input.isEmpty()) { - parsedInputs.add(new InputConfig(inputField, outputField, Map.of())); + parsedInputs.add(new InputConfig(inputField, outputPaths.v1(), outputPaths.v2(), Map.of())); } else { - parsedInputs.add(new InputConfig(inputField, outputField, new HashMap<>(input))); + parsedInputs.add(new InputConfig(inputField, outputPaths.v1(), outputPaths.v2(), new HashMap<>(input))); } } @@ -578,6 +586,25 @@ private ElasticsearchException duplicatedFieldNameError(String property, String return newConfigurationException(TYPE, tag, property, "names must be unique but [" + fieldName + "] is repeated"); } - public record InputConfig(String inputField, String outputField, Map extras) {} + /** + * {@code outputField} can be a dot '.' separated path of elements. + * Extract the base path (everything before the last '.') and the final + * element. + * If {@code outputField} does not contain any dotted elements the base + * path is null. 
+ * + * @param outputField The path to split + * @return Tuple of {@code <basePath, finalElement>} + */ + static Tuple extractBasePathAndFinalElement(String outputField) { + int lastIndex = outputField.lastIndexOf('.'); + if (lastIndex < 0) { + return new Tuple<>(null, outputField); + } else { + return new Tuple<>(outputField.substring(0, lastIndex), outputField.substring(lastIndex + 1)); + } + } + + public record InputConfig(String inputField, String outputBasePath, String outputField, Map extras) {} } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index e0c366f52aaf1..07d27a1b1bbe8 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -27,7 +27,6 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.core.Tuple; -import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.ingest.PipelineConfiguration; import org.elasticsearch.test.ESTestCase; @@ -434,6 +433,36 @@ public void testCreateProcessorWithFieldMap() { assertThat(fieldMap, hasEntry("source", "dest")); } + public void testCreateProcessorWithInputOutputs() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, false); + + Map config = new HashMap<>(); + config.put(InferenceProcessor.MODEL_ID, "my_model"); + + Map input1 = new HashMap<>(); + input1.put(InferenceProcessor.INPUT_FIELD, "in1"); + input1.put(InferenceProcessor.OUTPUT_FIELD, "out1"); + Map input2 = new HashMap<>(); + input2.put(InferenceProcessor.INPUT_FIELD, "in2"); + 
input2.put(InferenceProcessor.OUTPUT_FIELD, "out2"); + + List> inputOutputs = new ArrayList<>(); + inputOutputs.add(input1); + inputOutputs.add(input2); + config.put(InferenceProcessor.INPUT_OUTPUT, inputOutputs); + + var processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, config); + assertTrue(processor.isConfiguredWithInputsFields()); + assertEquals("my_model", processor.getModelId()); + var configuredInputs = processor.getInputs(); + assertThat(configuredInputs, hasSize(2)); + assertEquals(configuredInputs.get(0).inputField(), "in1"); + assertEquals(configuredInputs.get(0).outputField(), "out1"); + assertEquals(configuredInputs.get(1).inputField(), "in2"); + assertEquals(configuredInputs.get(1).outputField(), "out2"); + + } + public void testCreateProcessorWithDuplicateFields() { Set includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false)); @@ -551,7 +580,7 @@ public void testCreateProcessorWithIncompatibleTargetFieldSetting() { { put(InferenceProcessor.MODEL_ID, "my_model"); put(InferenceProcessor.TARGET_FIELD, "ml"); - put(InferenceProcessor.INPUT, input); + put(InferenceProcessor.INPUT_OUTPUT, List.of(input)); } }; @@ -562,7 +591,7 @@ public void testCreateProcessorWithIncompatibleTargetFieldSetting() { assertThat( ex.getMessage(), containsString( - "[target_field] option is incompatible with [input]. Use the [output_field] option to specify where to write the " + "[target_field] option is incompatible with [input_output]. Use the [output_field] option to specify where to write the " + "inference results to." 
) ); @@ -586,7 +615,7 @@ public void testCreateProcessorWithIncompatibleResultFieldSetting() { Map config = new HashMap<>() { { put(InferenceProcessor.MODEL_ID, "my_model"); - put(InferenceProcessor.INPUT, input); + put(InferenceProcessor.INPUT_OUTPUT, List.of(input)); put( InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap( @@ -604,8 +633,8 @@ public void testCreateProcessorWithIncompatibleResultFieldSetting() { assertThat( ex.getMessage(), containsString( - "The [inference_config.results_field] setting is incompatible with using [input]. " - + "Prefer to use the [input.output_field] option to specify where to write the inference results to." + "The [inference_config.results_field] setting is incompatible with using [input_output]. " + + "Prefer to use the [input_output.output_field] option to specify where to write the inference results to." ) ); } @@ -642,10 +671,22 @@ public void testCreateProcessorWithInputFields() { Map config = new HashMap<>() { { put(InferenceProcessor.MODEL_ID, "my_model"); - put(InferenceProcessor.INPUT, inputMap); + put(InferenceProcessor.INPUT_OUTPUT, List.of(inputMap)); put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, Collections.emptyMap())); } }; + // create valid inference configs with required fields + if (inferenceConfigType.equals(TextSimilarityConfigUpdate.NAME)) { + var inferenceConfig = new HashMap(); + inferenceConfig.put(TextSimilarityConfig.TEXT.getPreferredName(), "text to compare"); + config.put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, inferenceConfig)); + } else if (inferenceConfigType.equals(QuestionAnsweringConfigUpdate.NAME)) { + var inferenceConfig = new HashMap(); + inferenceConfig.put(QuestionAnsweringConfig.QUESTION.getPreferredName(), "why is the sky blue?"); + config.put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, inferenceConfig)); + } else { + 
config.put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, Collections.emptyMap())); + } var inferenceProcessor = processorFactory.create(Collections.emptyMap(), "processor_with_inputs", null, config); assertEquals("my_model", inferenceProcessor.getModelId()); @@ -653,7 +694,7 @@ public void testCreateProcessorWithInputFields() { var inputs = inferenceProcessor.getInputs(); assertThat(inputs, hasSize(1)); - assertEquals(inputs.get(0), new InferenceProcessor.Factory.InputConfig("in", "out", Map.of())); + assertEquals(inputs.get(0), new InferenceProcessor.Factory.InputConfig("in", null, "out", Map.of())); assertNull(inferenceProcessor.getFieldMap()); assertNull(inferenceProcessor.getTargetField()); @@ -672,14 +713,14 @@ public void testParsingInputFields() { for (int i = 0; i < numInputs; i++) { Map inputMap = new HashMap<>(); inputMap.put(InferenceProcessor.INPUT_FIELD, "in" + i); - inputMap.put(InferenceProcessor.OUTPUT_FIELD, "out" + i); + inputMap.put(InferenceProcessor.OUTPUT_FIELD, "out." 
+ i); inputs.add(inputMap); } var parsedInputs = processorFactory.parseInputFields("my_processor", inputs); assertThat(parsedInputs, hasSize(numInputs)); for (int i = 0; i < numInputs; i++) { - assertEquals(new InferenceProcessor.Factory.InputConfig("in" + i, "out" + i, Map.of()), parsedInputs.get(i)); + assertEquals(new InferenceProcessor.Factory.InputConfig("in" + i, "out", Integer.toString(i), Map.of()), parsedInputs.get(i)); } } @@ -719,6 +760,22 @@ public void testParsingInputFieldsDuplicateFieldNames() { } } + public void testExtractBasePathAndFinalElement() { + { + String path = "foo.bar.result"; + var extractedPaths = InferenceProcessor.Factory.extractBasePathAndFinalElement(path); + assertEquals("foo.bar", extractedPaths.v1()); + assertEquals("result", extractedPaths.v2()); + } + + { + String path = "result"; + var extractedPaths = InferenceProcessor.Factory.extractBasePathAndFinalElement(path); + assertNull(extractedPaths.v1()); + assertEquals("result", extractedPaths.v2()); + } + } + public void testParsingInputFieldsGivenNoInputs() { InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( client, @@ -728,7 +785,7 @@ public void testParsingInputFieldsGivenNoInputs() { ); var e = expectThrows(ElasticsearchParseException.class, () -> processorFactory.parseInputFields("my_processor", List.of())); - assertThat(e.getMessage(), containsString("[input] cannot be empty at least one is required")); + assertThat(e.getMessage(), containsString("[input_output] cannot be empty at least one is required")); } private static ClusterState buildClusterStateWithModelReferences(String... 
modelId) throws IOException { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index bc521922e6c05..f85b5e687ac3d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -16,6 +16,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ClassificationInferenceResults; import org.elasticsearch.xpack.core.ml.inference.results.RegressionFeatureImportance; import org.elasticsearch.xpack.core.ml.inference.results.RegressionInferenceResults; +import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResultsTests; import org.elasticsearch.xpack.core.ml.inference.results.TopClassEntry; import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; @@ -491,4 +492,76 @@ public void testMutateDocumentWithModelIdResult() { assertThat(document.getFieldValue("ml.my_processor.foo", Double.class), equalTo(0.7)); assertThat(document.getFieldValue("ml.my_processor.model_id", String.class), equalTo(modelId)); } + + public void testMutateDocumentWithInputFields() { + String modelId = "regression-123"; + List inputs = new ArrayList<>(); + inputs.add(new InferenceProcessor.Factory.InputConfig("body", null, "body_result", Map.of())); + inputs.add(new InferenceProcessor.Factory.InputConfig("content", null, "content_result", Map.of())); + + InferenceProcessor inferenceProcessor = InferenceProcessor.fromInputFieldConfiguration( + client, + auditor, + "my_processor_tag", + "description", + modelId, + new RegressionConfigUpdate("foo", null), + inputs + ); + + IngestDocument document = TestIngestDocument.emptyIngestDocument(); + + 
InferModelAction.Response response = new InferModelAction.Response( + List.of(new RegressionInferenceResults(0.7, "ignore"), new RegressionInferenceResults(1.0, "ignore")), + modelId, + true + ); + inferenceProcessor.mutateDocument(response, document); + + assertThat(document.getFieldValue("body_result", Double.class), equalTo(0.7)); + assertThat(document.getFieldValue("content_result", Double.class), equalTo(1.0)); + } + + public void testMutateDocumentWithInputFieldsNested() { + String modelId = "elser"; + List inputs = new ArrayList<>(); + inputs.add(new InferenceProcessor.Factory.InputConfig("body", "ml.results", "body_tokens", Map.of())); + inputs.add(new InferenceProcessor.Factory.InputConfig("content", "ml.results", "content_tokens", Map.of())); + + InferenceProcessor inferenceProcessor = InferenceProcessor.fromInputFieldConfiguration( + client, + auditor, + "my_processor_tag", + "description", + modelId, + new RegressionConfigUpdate("foo", null), + inputs + ); + + IngestDocument document = TestIngestDocument.emptyIngestDocument(); + + var teResult1 = TextExpansionResultsTests.createRandomResults(); + var teResult2 = TextExpansionResultsTests.createRandomResults(); + InferModelAction.Response response = new InferModelAction.Response(List.of(teResult1, teResult2), modelId, true); + inferenceProcessor.mutateDocument(response, document); + + var bodyTokens = document.getFieldValue("ml.results.body_tokens", HashMap.class); + assertEquals(teResult1.getWeightedTokens().size(), bodyTokens.entrySet().size()); + if (teResult1.getWeightedTokens().isEmpty() == false) { + assertEquals( + (float) bodyTokens.get(teResult1.getWeightedTokens().get(0).token()), + teResult1.getWeightedTokens().get(0).weight(), + 0.001 + ); + } + var contentTokens = document.getFieldValue("ml.results.content_tokens", HashMap.class); + assertEquals(teResult2.getWeightedTokens().size(), contentTokens.entrySet().size()); + if (teResult2.getWeightedTokens().isEmpty() == false) { + assertEquals( + 
(float) contentTokens.get(teResult2.getWeightedTokens().get(0).token()), + teResult2.getWeightedTokens().get(0).weight(), + 0.001 + ); + } + } }