From 46d9594c0fab9ca8cc72dd189e0e84feb83b8d86 Mon Sep 17 00:00:00 2001 From: David Kyle Date: Fri, 6 Oct 2023 15:10:45 +0100 Subject: [PATCH] [ML] More checks and tests for parsing Inference processor config (#100335) Following on from #100205 this PR adds more tests and checks for corner cases when parsing the configuration. --- .../ingest/processors/inference.asciidoc | 3 +- .../ClassificationInferenceResultsTests.java | 9 +- .../results/ErrorInferenceResultsTests.java | 4 +- .../results/InferenceResultsTestCase.java | 52 +++++++- .../ml/inference/results/NerResultsTests.java | 9 +- ...lpClassificationInferenceResultsTests.java | 9 +- .../PyTorchPassThroughResultsTests.java | 7 +- ...uestionAnsweringInferenceResultsTests.java | 9 +- .../RegressionInferenceResultsTests.java | 7 +- .../results/TextEmbeddingResultsTests.java | 8 +- .../results/TextExpansionResultsTests.java | 7 +- .../TextSimilarityInferenceResultsTests.java | 9 +- .../results/WarningInferenceResultsTests.java | 4 +- .../inference/ingest/InferenceProcessor.java | 99 +++++++++++--- .../InferenceProcessorFactoryTests.java | 123 ++++++++++++++++-- .../ingest/InferenceProcessorTests.java | 111 +++++++++++++++- 16 files changed, 401 insertions(+), 69 deletions(-) diff --git a/docs/reference/ingest/processors/inference.asciidoc b/docs/reference/ingest/processors/inference.asciidoc index f0c029d99e14a..75b667e634cdb 100644 --- a/docs/reference/ingest/processors/inference.asciidoc +++ b/docs/reference/ingest/processors/inference.asciidoc @@ -17,10 +17,11 @@ ingested in the pipeline. |====== | Name | Required | Default | Description | `model_id` . | yes | - | (String) The ID or alias for the trained model, or the ID of the deployment. -| `input_output` | no | (List) Input fields for inference and output (destination) fields for the inference results. This options is incompatible with the `target_field` and `field_map` options. 
+| `input_output` | no | - | (List) Input fields for inference and output (destination) fields for the inference results. This option is incompatible with the `target_field` and `field_map` options. | `target_field` | no | `ml.inference.` | (String) Field added to incoming documents to contain results objects. | `field_map` | no | If defined the model's default field map | (Object) Maps the document field names to the known field names of the model. This mapping takes precedence over any default mappings provided in the model configuration. | `inference_config` | no | The default settings defined in the model | (Object) Contains the inference type and its options. +| `ignore_missing` | no | `false` | (Boolean) If `true` and any of the input fields defined in `input_output` are missing then those missing fields are quietly ignored, otherwise a missing field causes a failure. Only applies when using `input_output` configurations to explicitly list the input fields. include::common-options.asciidoc[] |====== diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java index 28e318c0dab48..a937fef23e4bc 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ClassificationInferenceResultsTests.java @@ -209,8 +209,13 @@ public void testToXContent() throws IOException { } @Override - void assertFieldValues(ClassificationInferenceResults createdInstance, IngestDocument document, String resultsField) { - String path = resultsField + "." 
+ createdInstance.getResultsField(); + void assertFieldValues( + ClassificationInferenceResults createdInstance, + IngestDocument document, + String parentField, + String resultsField + ) { + String path = parentField + resultsField; switch (createdInstance.getPredictionFieldType()) { case NUMBER -> assertThat(document.getFieldValue(path, Double.class), equalTo(createdInstance.predictedValue())); case STRING -> assertThat(document.getFieldValue(path, String.class), equalTo(createdInstance.predictedValue())); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java index e25b2da55b15b..20b2b4737c8b5 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/ErrorInferenceResultsTests.java @@ -34,7 +34,7 @@ protected ErrorInferenceResults mutateInstance(ErrorInferenceResults instance) t } @Override - void assertFieldValues(ErrorInferenceResults createdInstance, IngestDocument document, String resultsField) { - assertThat(document.getFieldValue(resultsField + ".error", String.class), equalTo(createdInstance.getException().getMessage())); + void assertFieldValues(ErrorInferenceResults createdInstance, IngestDocument document, String parentField, String resultsField) { + assertThat(document.getFieldValue(parentField + "error", String.class), equalTo(createdInstance.getException().getMessage())); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java index 27503547e5705..bda9eed40659c 100644 --- 
a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/InferenceResultsTestCase.java @@ -18,6 +18,8 @@ import java.io.IOException; import java.util.Map; +import static org.hamcrest.Matchers.equalTo; + abstract class InferenceResultsTestCase extends AbstractWireSerializingTestCase { public void testWriteToIngestDoc() throws IOException { @@ -34,11 +36,57 @@ public void testWriteToIngestDoc() throws IOException { document.setFieldValue(parentField, Map.of()); } InferenceResults.writeResult(inferenceResult, document, parentField, modelId); - assertFieldValues(inferenceResult, document, alreadyHasResult ? parentField + ".1" : parentField); + + String expectedOutputPath = alreadyHasResult ? parentField + ".1." : parentField + "."; + + assertThat( + document.getFieldValue(expectedOutputPath + InferenceResults.MODEL_ID_RESULTS_FIELD, String.class), + equalTo(modelId) + ); + if (inferenceResult instanceof NlpInferenceResults nlpInferenceResults && nlpInferenceResults.isTruncated()) { + assertTrue(document.getFieldValue(expectedOutputPath + "is_truncated", Boolean.class)); + } + + assertFieldValues(inferenceResult, document, expectedOutputPath, inferenceResult.getResultsField()); + } + } + + private void testWriteToIngestDocField() throws IOException { + for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) { + T inferenceResult = createTestInstance(); + if (randomBoolean()) { + inferenceResult = copyInstance(inferenceResult, TransportVersion.current()); + } + IngestDocument document = TestIngestDocument.emptyIngestDocument(); + String outputField = randomAlphaOfLength(10); + String modelId = randomAlphaOfLength(10); + String parentField = randomBoolean() ? 
null : randomAlphaOfLength(10); + boolean writeModelId = randomBoolean(); + + boolean alreadyHasResult = randomBoolean(); + if (alreadyHasResult && parentField != null) { + document.setFieldValue(parentField, Map.of()); + } + InferenceResults.writeResultToField(inferenceResult, document, parentField, outputField, modelId, writeModelId); + + String expectedOutputPath = parentField == null ? "" : parentField + "."; + if (alreadyHasResult && parentField != null) { + expectedOutputPath = expectedOutputPath + "1."; + } + + if (writeModelId) { + String modelIdPath = expectedOutputPath + InferenceResults.MODEL_ID_RESULTS_FIELD; + assertThat(document.getFieldValue(modelIdPath, String.class), equalTo(modelId)); + } + if (inferenceResult instanceof NlpInferenceResults nlpInferenceResults && nlpInferenceResults.isTruncated()) { + assertTrue(document.getFieldValue(expectedOutputPath + "is_truncated", Boolean.class)); + } + + assertFieldValues(inferenceResult, document, expectedOutputPath, outputField); } } - abstract void assertFieldValues(T createdInstance, IngestDocument document, String resultsField); + abstract void assertFieldValues(T createdInstance, IngestDocument document, String parentField, String resultsField); public void testWriteToDocAndSerialize() throws IOException { for (int i = 0; i < NUMBER_OF_TEST_RUNS; ++i) { diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java index 68c69ff67fa48..4be49807d27b0 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NerResultsTests.java @@ -92,15 +92,12 @@ public void testAsMap() { @Override @SuppressWarnings("unchecked") - void assertFieldValues(NerResults createdInstance, IngestDocument document, String 
resultsField) { - assertThat( - document.getFieldValue(resultsField + "." + createdInstance.getResultsField(), String.class), - equalTo(createdInstance.getAnnotatedResult()) - ); + void assertFieldValues(NerResults createdInstance, IngestDocument document, String parentField, String resultsField) { + assertThat(document.getFieldValue(parentField + resultsField, String.class), equalTo(createdInstance.getAnnotatedResult())); if (createdInstance.getEntityGroups().size() > 0) { List> resultList = (List>) document.getFieldValue( - resultsField + "." + ENTITY_FIELD, + parentField + ENTITY_FIELD, List.class ); assertThat(resultList.size(), equalTo(createdInstance.getEntityGroups().size())); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java index ac3cb638d88d1..f05b8ac3d8eab 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/NlpClassificationInferenceResultsTests.java @@ -79,8 +79,13 @@ protected Writeable.Reader instanceReader() { } @Override - void assertFieldValues(NlpClassificationInferenceResults createdInstance, IngestDocument document, String resultsField) { - String path = resultsField + "." 
+ createdInstance.getResultsField(); + void assertFieldValues( + NlpClassificationInferenceResults createdInstance, + IngestDocument document, + String parentField, + String resultsField + ) { + String path = parentField + resultsField; assertThat(document.getFieldValue(path, String.class), equalTo(createdInstance.predictedValue())); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java index c2386010a8f67..e6b38a08a75ba 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/PyTorchPassThroughResultsTests.java @@ -58,10 +58,7 @@ public void testAsMap() { } @Override - void assertFieldValues(PyTorchPassThroughResults createdInstance, IngestDocument document, String resultsField) { - assertArrayEquals( - createdInstance.getInference(), - document.getFieldValue(resultsField + "." 
+ createdInstance.getResultsField(), double[][].class) - ); + void assertFieldValues(PyTorchPassThroughResults createdInstance, IngestDocument document, String parentField, String resultsField) { + assertArrayEquals(createdInstance.getInference(), document.getFieldValue(parentField + resultsField, double[][].class)); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java index c9c65ea3f3538..29e7a5627cdd3 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/QuestionAnsweringInferenceResultsTests.java @@ -83,8 +83,13 @@ protected Writeable.Reader instanceReader() { } @Override - void assertFieldValues(QuestionAnsweringInferenceResults createdInstance, IngestDocument document, String resultsField) { - String path = resultsField + "." 
+ createdInstance.getResultsField(); + void assertFieldValues( + QuestionAnsweringInferenceResults createdInstance, + IngestDocument document, + String parentField, + String resultsField + ) { + String path = parentField + resultsField; assertThat(document.getFieldValue(path, String.class), equalTo(createdInstance.predictedValue())); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java index 27a07e8f996f7..9eef7a42da9a8 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/RegressionInferenceResultsTests.java @@ -95,10 +95,7 @@ public void testToXContent() { } @Override - void assertFieldValues(RegressionInferenceResults createdInstance, IngestDocument document, String resultsField) { - assertThat( - document.getFieldValue(resultsField + "." 
+ createdInstance.getResultsField(), Double.class), - closeTo(createdInstance.value(), 1e-10) - ); + void assertFieldValues(RegressionInferenceResults createdInstance, IngestDocument document, String parentField, String resultsField) { + assertThat(document.getFieldValue(parentField + resultsField, Double.class), closeTo(createdInstance.value(), 1e-10)); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java index d29e79698e2c9..fd3ac7f8c0d12 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextEmbeddingResultsTests.java @@ -55,11 +55,7 @@ public void testAsMap() { } @Override - void assertFieldValues(TextEmbeddingResults createdInstance, IngestDocument document, String resultsField) { - assertArrayEquals( - document.getFieldValue(resultsField + "." 
+ createdInstance.getResultsField(), double[].class), - createdInstance.getInference(), - 1e-10 - ); + void assertFieldValues(TextEmbeddingResults createdInstance, IngestDocument document, String parentField, String resultsField) { + assertArrayEquals(document.getFieldValue(parentField + resultsField, double[].class), createdInstance.getInference(), 1e-10); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java index c3b2fbf6fb556..82487960dfe8f 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextExpansionResultsTests.java @@ -47,11 +47,8 @@ protected TextExpansionResults mutateInstance(TextExpansionResults instance) { @Override @SuppressWarnings("unchecked") - void assertFieldValues(TextExpansionResults createdInstance, IngestDocument document, String resultsField) { - var ingestedTokens = (Map) document.getFieldValue( - resultsField + '.' 
+ createdInstance.getResultsField(), - Map.class - ); + void assertFieldValues(TextExpansionResults createdInstance, IngestDocument document, String parentField, String resultsField) { + var ingestedTokens = (Map) document.getFieldValue(parentField + resultsField, Map.class); var tokenMap = createdInstance.getWeightedTokens() .stream() .collect(Collectors.toMap(TextExpansionResults.WeightedToken::token, TextExpansionResults.WeightedToken::weight)); diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResultsTests.java index e543e04a01085..b72f89bf0ae97 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/TextSimilarityInferenceResultsTests.java @@ -33,8 +33,13 @@ protected Writeable.Reader instanceReader() { } @Override - void assertFieldValues(TextSimilarityInferenceResults createdInstance, IngestDocument document, String resultsField) { - String path = resultsField + "." 
+ createdInstance.getResultsField(); + void assertFieldValues( + TextSimilarityInferenceResults createdInstance, + IngestDocument document, + String parentField, + String resultsField + ) { + String path = parentField + resultsField; assertThat(document.getFieldValue(path, Double.class), equalTo(createdInstance.predictedValue())); } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java index 68379f888d11b..594fffc0c91f4 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/results/WarningInferenceResultsTests.java @@ -33,7 +33,7 @@ protected Writeable.Reader instanceReader() { } @Override - void assertFieldValues(WarningInferenceResults createdInstance, IngestDocument document, String resultsField) { - assertThat(document.getFieldValue(resultsField + ".warning", String.class), equalTo(createdInstance.getWarning())); + void assertFieldValues(WarningInferenceResults createdInstance, IngestDocument document, String parentField, String resultsField) { + assertThat(document.getFieldValue(parentField + "warning", String.class), equalTo(createdInstance.getWarning())); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index ef78078d1bbcd..905317713263e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -87,6 +87,7 @@ public class InferenceProcessor extends AbstractProcessor { public static final 
String TYPE = "inference"; public static final String MODEL_ID = "model_id"; public static final String INFERENCE_CONFIG = "inference_config"; + public static final String IGNORE_MISSING = "ignore_missing"; // target field style mappings public static final String TARGET_FIELD = "target_field"; @@ -106,9 +107,10 @@ public static InferenceProcessor fromInputFieldConfiguration( String description, String modelId, InferenceConfigUpdate inferenceConfig, - List inputs + List inputs, + boolean ignoreMissing ) { - return new InferenceProcessor(client, auditor, tag, description, null, modelId, inferenceConfig, null, inputs, true); + return new InferenceProcessor(client, auditor, tag, description, null, modelId, inferenceConfig, null, inputs, true, ignoreMissing); } public static InferenceProcessor fromTargetFieldConfiguration( @@ -121,7 +123,20 @@ public static InferenceProcessor fromTargetFieldConfiguration( InferenceConfigUpdate inferenceConfig, Map fieldMap ) { - return new InferenceProcessor(client, auditor, tag, description, targetField, modelId, inferenceConfig, fieldMap, null, false); + // ignore_missing only applies to when using the input_field config + return new InferenceProcessor( + client, + auditor, + tag, + description, + targetField, + modelId, + inferenceConfig, + fieldMap, + null, + false, + false + ); } private final Client client; @@ -134,6 +149,7 @@ public static InferenceProcessor fromTargetFieldConfiguration( private final AtomicBoolean shouldAudit = new AtomicBoolean(true); private final List inputs; private final boolean configuredWithInputsFields; + private final boolean ignoreMissing; private InferenceProcessor( Client client, @@ -145,7 +161,8 @@ private InferenceProcessor( InferenceConfigUpdate inferenceConfig, Map fieldMap, List inputs, - boolean configuredWithInputsFields + boolean configuredWithInputsFields, + boolean ignoreMissing ) { super(tag, description); this.configuredWithInputsFields = configuredWithInputsFields; @@ -153,6 +170,7 @@ 
private InferenceProcessor( this.auditor = ExceptionsHelper.requireNonNull(auditor, "auditor"); this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); + this.ignoreMissing = ignoreMissing; if (configuredWithInputsFields) { this.inputs = ExceptionsHelper.requireNonNull(inputs, INPUT_OUTPUT); @@ -205,23 +223,36 @@ void handleResponse(InferModelAction.Response response, IngestDocument ingestDoc } InferModelAction.Request buildRequest(IngestDocument ingestDocument) { - Map fields = new HashMap<>(ingestDocument.getSourceAndMetadata()); - // Add ingestMetadata as previous processors might have added metadata from which we are predicting (see: foreach processor) - if (ingestDocument.getIngestMetadata().isEmpty() == false) { - fields.put(INGEST_KEY, ingestDocument.getIngestMetadata()); - } - if (configuredWithInputsFields) { + // ignore missing only applies when using an input field list List requestInputs = new ArrayList<>(); for (var inputFields : inputs) { - var lookup = (String) fields.get(inputFields.inputField); - if (lookup == null) { - lookup = ""; // need to send a non-null request to the same number of results back + try { + var inputText = ingestDocument.getFieldValue(inputFields.inputField, String.class, ignoreMissing); + // field is missing and ignoreMissing == true then a null value is returned. 
+ if (inputText == null) { + inputText = ""; // need to send a non-null request to the same number of results back + } + requestInputs.add(inputText); + } catch (IllegalArgumentException e) { + if (ingestDocument.hasField(inputFields.inputField())) { + // field is present but of the wrong type, translate to a more meaningful message + throw new IllegalArgumentException( + "input field [" + inputFields.inputField + "] cannot be processed because it is not a text field" + ); + } else { + throw e; + } } - requestInputs.add(lookup); } return InferModelAction.Request.forTextInput(modelId, inferenceConfig, requestInputs); } else { + Map fields = new HashMap<>(ingestDocument.getSourceAndMetadata()); + // Add ingestMetadata as previous processors might have added metadata from which we are predicting (see: foreach processor) + if (ingestDocument.getIngestMetadata().isEmpty() == false) { + fields.put(INGEST_KEY, ingestDocument.getIngestMetadata()); + } + LocalModel.mapFieldsIfNecessary(fields, fieldMap); return InferModelAction.Request.forIngestDocs(modelId, List.of(fields), inferenceConfig, previouslyLicensed); } @@ -373,11 +404,13 @@ public InferenceProcessor create( inferenceConfigUpdate = inferenceConfigUpdateFromMap(inferenceConfigMap); } - List> inputs = ConfigurationUtils.readOptionalList(TYPE, tag, config, INPUT_OUTPUT); + List> inputs = readOptionalInputOutPutConfig(config, tag); boolean configuredWithInputFields = inputs != null; if (configuredWithInputFields) { // new style input/output configuration var parsedInputs = parseInputFields(tag, inputs); + // ignore missing only applies to input field config + boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, tag, config, IGNORE_MISSING, false); // validate incompatible settings are not present String targetField = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, TARGET_FIELD); @@ -414,7 +447,16 @@ public InferenceProcessor create( ); } - return fromInputFieldConfiguration(client, 
auditor, tag, description, modelId, inferenceConfigUpdate, parsedInputs); + return fromInputFieldConfiguration( + client, + auditor, + tag, + description, + modelId, + inferenceConfigUpdate, + parsedInputs, + ignoreMissing + ); } else { // old style configuration with target field String defaultTargetField = tag == null ? DEFAULT_TARGET_FIELD : DEFAULT_TARGET_FIELD + "." + tag; @@ -553,7 +595,7 @@ void checkSupportedVersion(InferenceConfig config) { List parseInputFields(String tag, List> inputs) { if (inputs.isEmpty()) { - throw newConfigurationException(TYPE, tag, INPUT_OUTPUT, "cannot be empty at least one is required"); + throw newConfigurationException(TYPE, tag, INPUT_OUTPUT, "property cannot be empty at least one is required"); } var inputNames = new HashSet(); var outputNames = new HashSet(); @@ -582,6 +624,29 @@ List parseInputFields(String tag, List> inputs) return parsedInputs; } + @SuppressWarnings("unchecked") + List> readOptionalInputOutPutConfig(Map config, String tag) { + Object inputOutputs = config.remove(INPUT_OUTPUT); + if (inputOutputs == null) { + return null; + } + + // input_output may be a single map or a list of maps + if (inputOutputs instanceof List inputOutputList) { + if (inputOutputList.isEmpty() == false) { + // check it is a list of maps + if (inputOutputList.get(0) instanceof Map == false) { + throw ConfigurationUtils.newConfigurationException(TYPE, tag, INPUT_OUTPUT, "property isn't a list of maps"); + } + } + return (List>) inputOutputList; + } else if (inputOutputs instanceof Map) { + return List.of((Map) inputOutputs); + } else { + throw ConfigurationUtils.newConfigurationException(TYPE, tag, INPUT_OUTPUT, "property isn't a map or list of maps"); + } + } + private ElasticsearchException duplicatedFieldNameError(String property, String fieldName, String tag) { return newConfigurationException(TYPE, tag, property, "names must be unique but [" + fieldName + "] is repeated"); } diff --git 
a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index 07d27a1b1bbe8..a8d3af2efe7cd 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -647,10 +647,16 @@ public void testCreateProcessorWithInputFields() { randomBoolean() ); - Map inputMap = new HashMap<>() { + Map inputMap1 = new HashMap<>() { { - put(InferenceProcessor.INPUT_FIELD, "in"); - put(InferenceProcessor.OUTPUT_FIELD, "out"); + put(InferenceProcessor.INPUT_FIELD, "in1"); + put(InferenceProcessor.OUTPUT_FIELD, "out1"); + } + }; + Map inputMap2 = new HashMap<>() { + { + put(InferenceProcessor.INPUT_FIELD, "in2"); + put(InferenceProcessor.OUTPUT_FIELD, "out2"); } }; @@ -671,8 +677,7 @@ public void testCreateProcessorWithInputFields() { Map config = new HashMap<>() { { put(InferenceProcessor.MODEL_ID, "my_model"); - put(InferenceProcessor.INPUT_OUTPUT, List.of(inputMap)); - put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, Collections.emptyMap())); + put(InferenceProcessor.INPUT_OUTPUT, List.of(inputMap1, inputMap2)); } }; // create valid inference configs with required fields @@ -693,13 +698,115 @@ public void testCreateProcessorWithInputFields() { assertTrue(inferenceProcessor.isConfiguredWithInputsFields()); var inputs = inferenceProcessor.getInputs(); - assertThat(inputs, hasSize(1)); - assertEquals(inputs.get(0), new InferenceProcessor.Factory.InputConfig("in", null, "out", Map.of())); + assertThat(inputs, hasSize(2)); + assertEquals(inputs.get(0), new InferenceProcessor.Factory.InputConfig("in1", null, "out1", Map.of())); + assertEquals(inputs.get(1), new InferenceProcessor.Factory.InputConfig("in2", 
null, "out2", Map.of())); assertNull(inferenceProcessor.getFieldMap()); assertNull(inferenceProcessor.getTargetField()); } + /** Checks that the factory accepts the INPUT_OUTPUT option both as a bare map and as a one-element list of maps, and that either form yields a processor holding the same single InputConfig. */ public void testCreateProcessorWithInputFieldSingleOrList() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( + client, + clusterService, + Settings.EMPTY, + randomBoolean() + ); + + /* first pass wraps the input map in a list, second pass supplies the map directly */ for (var isList : new boolean[] { true, false }) { + Map inputMap = new HashMap<>() { + { + put(InferenceProcessor.INPUT_FIELD, "in"); + put(InferenceProcessor.OUTPUT_FIELD, "out"); + } + }; + + Map config = new HashMap<>(); + config.put(InferenceProcessor.MODEL_ID, "my_model"); + if (isList) { + config.put(InferenceProcessor.INPUT_OUTPUT, List.of(inputMap)); + } else { + config.put(InferenceProcessor.INPUT_OUTPUT, inputMap); + } + + /* randomly add a text expansion inference config with no options */ if (randomBoolean()) { + config.put( + InferenceProcessor.INFERENCE_CONFIG, + Collections.singletonMap(TextExpansionConfigUpdate.NAME, Collections.emptyMap()) + ); + } + + var inferenceProcessor = processorFactory.create(Collections.emptyMap(), "processor_with_single_input", null, config); + assertEquals("my_model", inferenceProcessor.getModelId()); + assertTrue(inferenceProcessor.isConfiguredWithInputsFields()); + + var inputs = inferenceProcessor.getInputs(); + assertThat(inputs, hasSize(1)); + assertEquals(inputs.get(0), new InferenceProcessor.Factory.InputConfig("in", null, "out", Map.of())); + + /* no field map and no target field are expected on a processor configured this way */ assertNull(inferenceProcessor.getFieldMap()); + assertNull(inferenceProcessor.getTargetField()); + } + } + + /** Checks that the factory rejects malformed INPUT_OUTPUT values with an ElasticsearchParseException whose message names the offending property. */ public void testCreateProcessorWithInputFieldWrongType() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( + client, + clusterService, + Settings.EMPTY, + randomBoolean() + ); + + /* case: a list whose elements are not maps */ { + Map config = new HashMap<>(); + config.put(InferenceProcessor.MODEL_ID, "my_model"); + config.put(InferenceProcessor.INPUT_OUTPUT, List.of(1, 2, 3)); + + var e = expectThrows( + ElasticsearchParseException.class, + () -> processorFactory.create(Collections.emptyMap(), 
"processor_with_bad_config", null, config) + ); + assertThat(e.getMessage(), containsString("[input_output] property isn't a list of maps")); + } + /* case: a value that is neither a map nor a list */ { + Map config = new HashMap<>(); + config.put(InferenceProcessor.MODEL_ID, "my_model"); + config.put(InferenceProcessor.INPUT_OUTPUT, Boolean.TRUE); + + var e = expectThrows( + ElasticsearchParseException.class, + () -> processorFactory.create(Collections.emptyMap(), "processor_with_bad_config", null, config) + ); + assertThat(e.getMessage(), containsString("[input_output] property isn't a map or list of maps")); + } + /* case: a map with no input field entry (its only key is a Boolean) */ { + Map badMap = new HashMap<>(); + badMap.put(Boolean.TRUE, "foo"); + Map config = new HashMap<>(); + config.put(InferenceProcessor.MODEL_ID, "my_model"); + config.put(InferenceProcessor.INPUT_OUTPUT, badMap); + + var e = expectThrows( + ElasticsearchParseException.class, + () -> processorFactory.create(Collections.emptyMap(), "processor_with_bad_config", null, config) + ); + assertThat(e.getMessage(), containsString("[input_field] required property is missing")); + } + { + // empty list + Map config = new HashMap<>(); + config.put(InferenceProcessor.MODEL_ID, "my_model"); + config.put(InferenceProcessor.INPUT_OUTPUT, List.of()); + + var e = expectThrows( + ElasticsearchParseException.class, + () -> processorFactory.create(Collections.emptyMap(), "processor_with_bad_config", null, config) + ); + assertThat(e.getMessage(), containsString("[input_output] property cannot be empty at least one is required")); + } + } + public void testParsingInputFields() { InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( client, @@ -785,7 +892,7 @@ public void testParsingInputFieldsGivenNoInputs() { ); var e = expectThrows(ElasticsearchParseException.class, () -> processorFactory.parseInputFields("my_processor", List.of())); - assertThat(e.getMessage(), containsString("[input_output] cannot be empty at least one is required")); + assertThat(e.getMessage(), containsString("[input_output] 
property cannot be empty at least one is required")); } private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException { diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index f85b5e687ac3d..a68084aa6eb28 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.WarningInferenceResults; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.EmptyConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PredictionFieldType; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; @@ -34,6 +35,7 @@ import java.util.Map; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.nullValue; @@ -506,7 +508,8 @@ public void testMutateDocumentWithInputFields() { "description", modelId, new RegressionConfigUpdate("foo", null), - inputs + inputs, + randomBoolean() ); IngestDocument document = TestIngestDocument.emptyIngestDocument(); @@ -535,7 +538,8 @@ public void testMutateDocumentWithInputFieldsNested() { "description", modelId, new RegressionConfigUpdate("foo", null), - inputs + inputs, + randomBoolean() ); IngestDocument document = 
TestIngestDocument.emptyIngestDocument(); @@ -545,6 +549,8 @@ public void testMutateDocumentWithInputFieldsNested() { InferModelAction.Response response = new InferModelAction.Response(List.of(teResult1, teResult2), modelId, true); inferenceProcessor.mutateDocument(response, document); + /* the model id is expected under ml.results, next to the token fields */ assertEquals(modelId, document.getFieldValue("ml.results.model_id", String.class)); + var bodyTokens = document.getFieldValue("ml.results.body_tokens", HashMap.class); assertEquals(teResult1.getWeightedTokens().size(), bodyTokens.entrySet().size()); if (teResult1.getWeightedTokens().isEmpty() == false) { @@ -564,4 +570,105 @@ public void testMutateDocumentWithInputFieldsNested() { ); } } + + /** Checks that buildRequest reads the configured input fields from the document and passes their values as text inputs, in configuration order, with no objects to infer. */ public void testBuildRequestWithInputFields() { + String modelId = "elser"; + List inputs = new ArrayList<>(); + inputs.add(new InferenceProcessor.Factory.InputConfig("body.text", "ml.results", "body_tokens", Map.of())); + inputs.add(new InferenceProcessor.Factory.InputConfig("title.text", "ml.results", "title_tokens", Map.of())); + + InferenceProcessor inferenceProcessor = InferenceProcessor.fromInputFieldConfiguration( + client, + auditor, + "my_processor_tag", + "description", + modelId, + new EmptyConfigUpdate(), + inputs, + randomBoolean() + ); + + IngestDocument document = TestIngestDocument.emptyIngestDocument(); + document.setFieldValue("body.text", "body_text"); + document.setFieldValue("title.text", "title_text"); + /* a field that is not a configured input; it must not appear in the request */ document.setFieldValue("unrelated", "text"); + + var request = inferenceProcessor.buildRequest(document); + assertTrue(request.getObjectsToInfer().isEmpty()); + var requestInputs = request.getTextInput(); + assertThat(requestInputs, contains("body_text", "title_text")); + } + + /** Checks that buildRequest throws an IllegalArgumentException when a configured input field holds a value that is not a string. */ public void testBuildRequestWithInputFields_WrongType() { + String modelId = "elser"; + List inputs = new ArrayList<>(); + inputs.add(new InferenceProcessor.Factory.InputConfig("not_a_string", "ml.results", "tokens", Map.of())); + + InferenceProcessor inferenceProcessor = 
InferenceProcessor.fromInputFieldConfiguration( + client, + auditor, + "my_processor_tag", + "description", + modelId, + new EmptyConfigUpdate(), + inputs, + randomBoolean() + ); + + IngestDocument document = TestIngestDocument.emptyIngestDocument(); + document.setFieldValue("not_a_string", Boolean.TRUE); + document.setFieldValue("unrelated", "text"); + + var e = expectThrows(IllegalArgumentException.class, () -> inferenceProcessor.buildRequest(document)); + assertThat(e.getMessage(), containsString("input field [not_a_string] cannot be processed because it is not a text field")); + } + + /** Checks how buildRequest treats a configured input field that is absent from the document: an IllegalArgumentException when the final factory argument (ignore_missing) is false, an empty string input when it is true. */ public void testBuildRequestWithInputFields_MissingField() { + String modelId = "elser"; + List inputs = new ArrayList<>(); + inputs.add(new InferenceProcessor.Factory.InputConfig("body.text", "ml.results", "body_tokens", Map.of())); + inputs.add(new InferenceProcessor.Factory.InputConfig("title.text", "ml.results", "title_tokens", Map.of())); + + /* ignore_missing == false: the absent title.text field is reported as an error */ { + InferenceProcessor inferenceProcessor = InferenceProcessor.fromInputFieldConfiguration( + client, + auditor, + "my_processor_tag", + "description", + modelId, + new EmptyConfigUpdate(), + inputs, + false + ); + + IngestDocument document = TestIngestDocument.emptyIngestDocument(); + document.setFieldValue("body.text", "body_text"); + document.setFieldValue("unrelated", "text"); + + var e = expectThrows(IllegalArgumentException.class, () -> inferenceProcessor.buildRequest(document)); + assertThat(e.getMessage(), containsString("field [title] not present as part of path [title.text]")); + } + + // same test with ignore_missing == true + { + InferenceProcessor inferenceProcessor = InferenceProcessor.fromInputFieldConfiguration( + client, + auditor, + "my_processor_tag", + "description", + modelId, + new EmptyConfigUpdate(), + inputs, + true + ); + + IngestDocument document = TestIngestDocument.emptyIngestDocument(); + document.setFieldValue("body.text", "body_text"); + document.setFieldValue("unrelated", 1.0); + + var request = 
inferenceProcessor.buildRequest(document); + var requestInputs = request.getTextInput(); + /* the absent title.text field is represented by an empty string in the request */ assertThat(requestInputs, contains("body_text", "")); + } + } }