From d8803b5ca2e44c7eb20f7264d5fd9088043a4c33 Mon Sep 17 00:00:00 2001
From: David Kyle
Date: Tue, 12 Sep 2023 11:57:12 +0100
Subject: [PATCH] Add input configuration

---
 .../ingest/processors/inference.asciidoc | 13 +
 .../trainedmodel/ClassificationConfig.java | 1 -
 .../trainedmodel/InferenceConfig.java | 2 +
 .../ml/inference/trainedmodel/NlpConfig.java | 1 -
 .../trainedmodel/RegressionConfig.java | 1 -
 .../inference/ingest/InferenceProcessor.java | 277 +++++++++++++--
 .../InferenceProcessorFactoryTests.java | 321 ++++++++++++++++--
 .../ingest/InferenceProcessorTests.java | 24 +-
 8 files changed, 570 insertions(+), 70 deletions(-)

diff --git a/docs/reference/ingest/processors/inference.asciidoc b/docs/reference/ingest/processors/inference.asciidoc
index 9b358643df734..c424e37b70b91 100644
--- a/docs/reference/ingest/processors/inference.asciidoc
+++ b/docs/reference/ingest/processors/inference.asciidoc
@@ -19,10 +19,23 @@ ingested in the pipeline.
 | `model_id` . | yes | - | (String) The ID or alias for the trained model, or the ID of the deployment.
 | `target_field` | no | `ml.inference.` | (String) Field added to incoming documents to contain results objects.
 | `field_map` | no | If defined the model's default field map | (Object) Maps the document field names to the known field names of the model. This mapping takes precedence over any default mappings provided in the model configuration.
+| `input_field` | no |
 | `inference_config` | no | The default settings defined in the model | (Object) Contains the inference type and its options.
 include::common-options.asciidoc[]
 |======

+[source,js]
+--------------------------------------------------
+{
+  "inference": {
+    "model_id": "model_deployment_for_inference",
+    "target_field": "FlightDelayMin_prediction_infer",
+    "input_field": "",
+    "inference_config": { "regression": {} }
+  }
+}
+--------------------------------------------------
+// NOTCONSOLE

 [source,js]
 --------------------------------------------------
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java
index 5f5b42605ed7e..156fd76a9419c 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java
+++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/ClassificationConfig.java
@@ -23,7 +23,6 @@ public class ClassificationConfig implements LenientlyParsedInferenceConfig, Str

     public static final ParseField NAME = new ParseField("classification");

-    public static final ParseField RESULTS_FIELD = new ParseField("results_field");
     public static final ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes");
     public static final ParseField TOP_CLASSES_RESULTS_FIELD = new ParseField("top_classes_results_field");
     public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values");
diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java
index 23bc2307b52f1..2b043cf022a3d 100644
--- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java
+++ 
b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/InferenceConfig.java @@ -8,6 +8,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.io.stream.VersionedNamedWriteable; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xpack.core.ml.MlConfigVersion; import org.elasticsearch.xpack.core.ml.utils.NamedXContentObject; @@ -15,6 +16,7 @@ public interface InferenceConfig extends NamedXContentObject, VersionedNamedWrit String DEFAULT_TOP_CLASSES_RESULTS_FIELD = "top_classes"; String DEFAULT_RESULTS_FIELD = "predicted_value"; + ParseField RESULTS_FIELD = new ParseField("results_field"); boolean isTargetTypeSupported(TargetType targetType); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java index 689575bf3d0e0..5f4840dceb7bd 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/NlpConfig.java @@ -15,7 +15,6 @@ public interface NlpConfig extends LenientlyParsedInferenceConfig, StrictlyParse ParseField VOCABULARY = new ParseField("vocabulary"); ParseField TOKENIZATION = new ParseField("tokenization"); ParseField CLASSIFICATION_LABELS = new ParseField("classification_labels"); - ParseField RESULTS_FIELD = new ParseField("results_field"); ParseField NUM_TOP_CLASSES = new ParseField("num_top_classes"); MlConfigVersion MINIMUM_NLP_SUPPORTED_VERSION = MlConfigVersion.V_8_0_0; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java index 04365d5c7ec1c..8ea53b2725523 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/trainedmodel/RegressionConfig.java @@ -24,7 +24,6 @@ public class RegressionConfig implements LenientlyParsedInferenceConfig, Strictl public static final ParseField NAME = new ParseField("regression"); private static final MlConfigVersion MIN_SUPPORTED_VERSION = MlConfigVersion.V_7_6_0; private static final TransportVersion MIN_SUPPORTED_TRANSPORT_VERSION = TransportVersions.V_7_6_0; - public static final ParseField RESULTS_FIELD = new ParseField("results_field"); public static final ParseField NUM_TOP_FEATURE_IMPORTANCE_VALUES = new ParseField("num_top_feature_importance_values"); public static RegressionConfig EMPTY_PARAMS = new RegressionConfig(DEFAULT_RESULTS_FIELD, null); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index 5d1c60964b8b9..4120cd6669e7f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -57,15 +57,17 @@ import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.elasticsearch.xpack.ml.utils.InferenceProcessorInfoExtractor; +import java.util.ArrayList; import java.util.Collections; import 
java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiConsumer; import java.util.function.Consumer; -import static org.elasticsearch.inference.InferenceResults.MODEL_ID_RESULTS_FIELD; +import static org.elasticsearch.ingest.ConfigurationUtils.newConfigurationException; import static org.elasticsearch.ingest.IngestDocument.INGEST_KEY; import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ClientHelper.executeAsyncWithOrigin; @@ -82,23 +84,57 @@ public class InferenceProcessor extends AbstractProcessor { ); public static final String TYPE = "inference"; + public static final String MODEL_ID = "model_id"; public static final String INFERENCE_CONFIG = "inference_config"; + + // target field style mappings public static final String TARGET_FIELD = "target_field"; public static final String FIELD_MAPPINGS = "field_mappings"; public static final String FIELD_MAP = "field_map"; private static final String DEFAULT_TARGET_FIELD = "ml.inference"; + // input field config + public static final String INPUT = "input"; + public static final String INPUT_FIELD = "input_field"; + public static final String OUTPUT_FIELD = "output_field"; + + public static InferenceProcessor fromInputFieldConfiguration( + Client client, + InferenceAuditor auditor, + String tag, + String description, + String modelId, + InferenceConfigUpdate inferenceConfig, + List inputs + ) { + return new InferenceProcessor(client, auditor, tag, description, null, modelId, inferenceConfig, null, inputs, true); + } + + public static InferenceProcessor fromTargetFieldConfiguration( + Client client, + InferenceAuditor auditor, + String tag, + String description, + String targetField, + String modelId, + InferenceConfigUpdate inferenceConfig, + Map fieldMap + ) { + return new InferenceProcessor(client, auditor, tag, description, targetField, modelId, inferenceConfig, fieldMap, null, false); + } + private final Client client; private final String modelId; - private final String targetField; private final InferenceConfigUpdate inferenceConfig; private final Map fieldMap; private final InferenceAuditor auditor; private volatile boolean previouslyLicensed; private final AtomicBoolean shouldAudit = new AtomicBoolean(true); + private final List inputs; + private final boolean configuredWithInputsFields; - public InferenceProcessor( + private InferenceProcessor( Client client, InferenceAuditor auditor, String tag, @@ -106,15 +142,26 @@ public InferenceProcessor( String targetField, String modelId, InferenceConfigUpdate inferenceConfig, - Map fieldMap + Map fieldMap, + List inputs, + boolean configuredWithInputsFields ) { super(tag, description); + this.configuredWithInputsFields = configuredWithInputsFields; this.client = ExceptionsHelper.requireNonNull(client, "client"); - this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD); this.auditor = ExceptionsHelper.requireNonNull(auditor, "auditor"); - this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID_RESULTS_FIELD); + this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); - this.fieldMap = ExceptionsHelper.requireNonNull(fieldMap, FIELD_MAP); + + if (configuredWithInputsFields) { + this.inputs = ExceptionsHelper.requireNonNull(inputs, INPUT); + this.targetField = null; + this.fieldMap = null; + } 
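// Target-field style (the original configuration): results are written under a single target field and
// fieldMap optionally renames document fields before inference, so the per-input list stays null.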
else { + this.inputs = null; + this.targetField = ExceptionsHelper.requireNonNull(targetField, TARGET_FIELD); + this.fieldMap = ExceptionsHelper.requireNonNull(fieldMap, FIELD_MAP); + } } public String getModelId() { @@ -123,11 +170,20 @@ public String getModelId() { @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { + + InferModelAction.Request request; + try { + request = buildRequest(ingestDocument); + } catch (ElasticsearchStatusException e) { + handler.accept(ingestDocument, e); + return; + } + executeAsyncWithOrigin( client, ML_ORIGIN, InferModelAction.INSTANCE, - this.buildRequest(ingestDocument), + request, ActionListener.wrap(r -> handleResponse(r, ingestDocument, handler), e -> handler.accept(ingestDocument, e)) ); } @@ -153,8 +209,21 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) { if (ingestDocument.getIngestMetadata().isEmpty() == false) { fields.put(INGEST_KEY, ingestDocument.getIngestMetadata()); } - LocalModel.mapFieldsIfNecessary(fields, fieldMap); - return InferModelAction.Request.forIngestDocs(modelId, List.of(fields), inferenceConfig, previouslyLicensed); + + if (configuredWithInputsFields) { + List requestInputs = new ArrayList<>(); + for (var inputFields : inputs) { + var lookup = (String)fields.get(inputFields.inputField); + if (lookup == null) { + lookup = ""; // need to send a non-null request to the same number of results back + } + requestInputs.add(lookup); + } + return InferModelAction.Request.forTextInput(modelId, inferenceConfig, requestInputs); + } else { + LocalModel.mapFieldsIfNecessary(fields, fieldMap); + return InferModelAction.Request.forIngestDocs(modelId, List.of(fields), inferenceConfig, previouslyLicensed); + } } void auditWarningAboutLicenseIfNecessary() { @@ -171,13 +240,36 @@ void mutateDocument(InferModelAction.Response response, IngestDocument ingestDoc if (response.getInferenceResults().isEmpty()) { throw new ElasticsearchStatusException("Unexpected empty inference response", RestStatus.INTERNAL_SERVER_ERROR); } - assert response.getInferenceResults().size() == 1; - InferenceResults.writeResult( - response.getInferenceResults().get(0), - ingestDocument, - targetField, - response.getId() != null ? response.getId() : modelId - ); + + // TODO + // The field where the model Id is written to. + // If multiple inference processors are in the same pipeline, it is wise to tag them + // The tag will keep default value entries from stepping on each other + // String modelIdField = tag == null ? MODEL_ID_RESULTS_FIELD : MODEL_ID_RESULTS_FIELD + "." + tag; + + if (configuredWithInputsFields) { + if (response.getInferenceResults().size() != inputs.size()) { + throw new ElasticsearchStatusException("number of results [{}] does not match the number of inputs [{}]", + RestStatus.INTERNAL_SERVER_ERROR, response.getInferenceResults().size(), inputs.size()); + } + + for (int i=0; i< inputs.size(); i++) { + InferenceResults.writeResult( + response.getInferenceResults().get(i), + ingestDocument, + inputs.get(i).outputField, + response.getId() != null ? response.getId() : modelId + ); + } + } else { + assert response.getInferenceResults().size() == 1; + InferenceResults.writeResult( + response.getInferenceResults().get(0), + ingestDocument, + targetField, + response.getId() != null ? 
response.getId() : modelId + ); + } } @Override @@ -190,6 +282,30 @@ public String getType() { return TYPE; } + boolean isConfiguredWithInputsFields() { + return configuredWithInputsFields; + } + + public List getInputs() { + return inputs; + } + + Map getFieldMap() { + return fieldMap; + } + + String getTargetField() { + return targetField; + } + + InferenceConfigUpdate getInferenceConfig() { + return inferenceConfig; + } + + InferenceAuditor getAuditor() { + return auditor; + } + public static final class Factory implements Processor.Factory, Consumer { private static final Logger logger = LogManager.getLogger(Factory.class); @@ -237,37 +353,91 @@ public InferenceProcessor create( ); } - String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID_RESULTS_FIELD); - String defaultTargetField = tag == null ? DEFAULT_TARGET_FIELD : DEFAULT_TARGET_FIELD + "." + tag; - // If multiple inference processors are in the same pipeline, it is wise to tag them - // The tag will keep default value entries from stepping on each other - String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD, defaultTargetField); - Map fieldMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAP); - if (fieldMap == null) { - fieldMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS); - // TODO Remove in 9?.x - if (fieldMap != null) { - LoggingDeprecationHandler.INSTANCE.logRenamedField(null, () -> null, FIELD_MAPPINGS, FIELD_MAP); - } - } - if (fieldMap == null) { - fieldMap = Collections.emptyMap(); - } + String modelId = ConfigurationUtils.readStringProperty(TYPE, tag, config, MODEL_ID); InferenceConfigUpdate inferenceConfigUpdate; Map inferenceConfigMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, INFERENCE_CONFIG); if (inferenceConfigMap == null) { if (minNodeVersion.before(EmptyConfigUpdate.minimumSupportedVersion())) { // an inference config is required when the empty update is not supported - throw ConfigurationUtils.newConfigurationException(TYPE, tag, INFERENCE_CONFIG, "required property is missing"); + throw newConfigurationException(TYPE, tag, INFERENCE_CONFIG, "required property is missing"); } - inferenceConfigUpdate = new EmptyConfigUpdate(); } else { inferenceConfigUpdate = inferenceConfigUpdateFromMap(inferenceConfigMap); } - return new InferenceProcessor(client, auditor, tag, description, targetField, modelId, inferenceConfigUpdate, fieldMap); + Map input = ConfigurationUtils.readOptionalMap(TYPE, tag, config, INPUT); + boolean configuredWithInputFields = input != null; + if (configuredWithInputFields) { + // new style input/output configuration + var parsedInputs = parseInputFields(tag, List.of(input)); + + // validate incompatible settings are not present + String targetField = ConfigurationUtils.readOptionalStringProperty(TYPE, tag, config, TARGET_FIELD); + if (targetField != null) { + throw newConfigurationException( + TYPE, + tag, + TARGET_FIELD, + "option is incompatible with [" + + INPUT + + "]." + + " Use the [" + + OUTPUT_FIELD + + "] option to specify where to write the inference results to." + ); + } + + if (inferenceConfigUpdate.getResultsField() != null) { + throw newConfigurationException( + TYPE, + tag, + null, + "The [" + + INFERENCE_CONFIG + + "." + + InferenceConfig.RESULTS_FIELD.getPreferredName() + + "] setting is incompatible with using [" + + INPUT + + "]. Prefer to use the [" + + INPUT + + "." + + OUTPUT_FIELD + + "] option to specify where to write the inference results to." 
+ ); + } + + return fromInputFieldConfiguration(client, auditor, tag, description, modelId, inferenceConfigUpdate, parsedInputs); + } else { + // old style configuration with target field + String defaultTargetField = tag == null ? DEFAULT_TARGET_FIELD : DEFAULT_TARGET_FIELD + "." + tag; + // If multiple inference processors are in the same pipeline, it is wise to tag them + // The tag will keep default value entries from stepping on each other + String targetField = ConfigurationUtils.readStringProperty(TYPE, tag, config, TARGET_FIELD, defaultTargetField); + Map fieldMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAP); + if (fieldMap == null) { + fieldMap = ConfigurationUtils.readOptionalMap(TYPE, tag, config, FIELD_MAPPINGS); + // TODO Remove in 9?.x + if (fieldMap != null) { + LoggingDeprecationHandler.INSTANCE.logRenamedField(null, () -> null, FIELD_MAPPINGS, FIELD_MAP); + } + } + + if (fieldMap == null) { + fieldMap = Collections.emptyMap(); + } + return fromTargetFieldConfiguration( + client, + auditor, + tag, + description, + targetField, + modelId, + inferenceConfigUpdate, + fieldMap + ); + } } // Package private for testing @@ -374,5 +544,40 @@ void checkSupportedVersion(InferenceConfig config) { ); } } + + List parseInputFields(String tag, List> inputs) { + if (inputs.isEmpty()) { + throw newConfigurationException(TYPE, tag, INPUT, "cannot be empty at least one is required"); + } + var inputNames = new HashSet(); + var outputNames = new HashSet(); + var parsedInputs = new ArrayList(); + + for (var input : inputs) { + String inputField = ConfigurationUtils.readStringProperty(TYPE, tag, input, INPUT_FIELD); + String outputField = ConfigurationUtils.readStringProperty(TYPE, tag, input, OUTPUT_FIELD); + + if (inputNames.add(inputField) == false) { + throw duplicatedFieldNameError(INPUT_FIELD, inputField, tag); + } + if (outputNames.add(outputField) == false) { + throw duplicatedFieldNameError(OUTPUT_FIELD, outputField, tag); + } + + if (input.isEmpty()) { + parsedInputs.add(new InputConfig(inputField, outputField, Map.of())); + } else { + parsedInputs.add(new InputConfig(inputField, outputField, new HashMap<>(input))); + } + } + + return parsedInputs; + } + + private ElasticsearchException duplicatedFieldNameError(String property, String fieldName, String tag) { + return newConfigurationException(TYPE, tag, property, "names must be unique but [" + fieldName + "] is repeated"); + } + + public record InputConfig(String inputField, String outputField, Map extras) {} } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java index 3854ad984467d..e0c366f52aaf1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorFactoryTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.ml.inference.ingest; import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchParseException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterName; @@ -36,21 +37,33 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.ml.MlConfigVersion; import 
org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.FillMaskConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.NerConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.PassThroughConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.QuestionAnsweringConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.RegressionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextClassificationConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextEmbeddingConfigUpdate; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfig; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextExpansionConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.TextSimilarityConfigUpdate; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfig; +import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ZeroShotClassificationConfigUpdate; import org.elasticsearch.xpack.ml.MachineLearning; import org.junit.Before; import java.io.IOException; import java.net.InetAddress; +import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.HashMap; @@ -59,7 +72,11 @@ import java.util.Map; import java.util.Set; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.empty; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasEntry; +import static org.hamcrest.Matchers.hasSize; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -89,7 +106,7 @@ public void setUpVariables() { clusterService = new ClusterService(settings, clusterSettings, tp, null); } - public void testCreateProcessorWithTooManyExisting() throws Exception { + public void testCreateProcessorWithTooManyExisting() { Set includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false)); includeNodeInfoValues.forEach(includeNodeInfo -> { @@ -135,7 +152,7 @@ public void testCreateProcessorWithInvalidInferenceConfig() { Map config = new HashMap<>() { { put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); + put(InferenceProcessor.MODEL_ID, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("unknown_type", Collections.emptyMap())); } @@ -158,7 +175,7 @@ public void testCreateProcessorWithInvalidInferenceConfig() { Map config2 = new HashMap<>() { { put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - 
put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); + put(InferenceProcessor.MODEL_ID, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap("regression", "boom")); } @@ -172,7 +189,7 @@ public void testCreateProcessorWithInvalidInferenceConfig() { Map config3 = new HashMap<>() { { put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); + put(InferenceProcessor.MODEL_ID, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put(InferenceProcessor.INFERENCE_CONFIG, Collections.emptyMap()); } @@ -185,7 +202,7 @@ public void testCreateProcessorWithInvalidInferenceConfig() { }); } - public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException { + public void testCreateProcessorWithTooOldMinNodeVersion() { Set includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false)); includeNodeInfoValues.forEach(includeNodeInfo -> { @@ -203,7 +220,7 @@ public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException { Map regression = new HashMap<>() { { put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); + put(InferenceProcessor.MODEL_ID, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put( InferenceProcessor.INFERENCE_CONFIG, @@ -224,7 +241,7 @@ public void testCreateProcessorWithTooOldMinNodeVersion() throws IOException { Map classification = new HashMap<>() { { put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); + put(InferenceProcessor.MODEL_ID, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put( InferenceProcessor.INFERENCE_CONFIG, @@ -315,7 +332,7 @@ public void testCreateProcessorWithEmptyConfigNotSupportedOnOldNode() throws IOE Map minimalConfig = new HashMap<>() { { - put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); + put(InferenceProcessor.MODEL_ID, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); } }; @@ -342,7 +359,7 @@ public void testCreateProcessor() { Map regression = new HashMap<>() { { put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); + put(InferenceProcessor.MODEL_ID, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put( InferenceProcessor.INFERENCE_CONFIG, @@ -351,12 +368,18 @@ public void testCreateProcessor() { } }; - processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression); + var processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression); + assertEquals(includeNodeInfo, processor.getAuditor().includeNodeInfo()); + assertFalse(processor.isConfiguredWithInputsFields()); + assertEquals("my_model", processor.getModelId()); + assertEquals("result", processor.getTargetField()); + assertThat(processor.getFieldMap().entrySet(), empty()); + assertNull(processor.getInputs()); Map classification = new HashMap<>() { { put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); + put(InferenceProcessor.MODEL_ID, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); put( InferenceProcessor.INFERENCE_CONFIG, @@ -368,19 +391,49 @@ public void testCreateProcessor() { } }; - processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, classification); + processor = 
processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, classification); + assertFalse(processor.isConfiguredWithInputsFields()); Map mininmal = new HashMap<>() { { - put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); + put(InferenceProcessor.MODEL_ID, "my_model"); put(InferenceProcessor.TARGET_FIELD, "result"); } }; - processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, mininmal); + processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, mininmal); + assertFalse(processor.isConfiguredWithInputsFields()); + assertEquals("my_model", processor.getModelId()); + assertEquals("result", processor.getTargetField()); + assertNull(processor.getInputs()); }); } + public void testCreateProcessorWithFieldMap() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory(client, clusterService, Settings.EMPTY, false); + + Map config = new HashMap<>() { + { + put(InferenceProcessor.FIELD_MAP, Collections.singletonMap("source", "dest")); + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "result"); + put( + InferenceProcessor.INFERENCE_CONFIG, + Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()) + ); + } + }; + + var processor = processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, config); + assertFalse(processor.isConfiguredWithInputsFields()); + assertEquals("my_model", processor.getModelId()); + assertEquals("result", processor.getTargetField()); + assertNull(processor.getInputs()); + var fieldMap = processor.getFieldMap(); + assertThat(fieldMap.entrySet(), hasSize(1)); + assertThat(fieldMap, hasEntry("source", "dest")); + } + public void testCreateProcessorWithDuplicateFields() { Set includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false)); @@ -395,7 +448,7 @@ public void testCreateProcessorWithDuplicateFields() { Map regression = new HashMap<>() { { put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); - put(InferenceResults.MODEL_ID_RESULTS_FIELD, "my_model"); + put(InferenceProcessor.MODEL_ID, "my_model"); put(InferenceProcessor.TARGET_FIELD, "ml"); put( InferenceProcessor.INFERENCE_CONFIG, @@ -415,7 +468,41 @@ public void testCreateProcessorWithDuplicateFields() { }); } - public void testParseFromMap() { + public void testCreateProcessorWithIgnoreMissing() { + Set includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false)); + + includeNodeInfoValues.forEach(includeNodeInfo -> { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( + client, + clusterService, + Settings.EMPTY, + includeNodeInfo + ); + + Map regression = new HashMap<>() { + { + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.FIELD_MAP, Collections.emptyMap()); + put("ignore_missing", Boolean.TRUE); + put( + InferenceProcessor.INFERENCE_CONFIG, + Collections.singletonMap( + RegressionConfig.NAME.getPreferredName(), + Collections.singletonMap(RegressionConfig.RESULTS_FIELD.getPreferredName(), "warning") + ) + ); + } + }; + + Exception ex = expectThrows( + Exception.class, + () -> processorFactory.create(Collections.emptyMap(), "my_inference_processor", null, regression) + ); + assertThat(ex.getMessage(), equalTo("Invalid inference config. 
" + "More than one field is configured as [warning]")); + }); + } + + public void testParseInferenceConfigFromMap() { Set includeNodeInfoValues = new HashSet<>(Arrays.asList(true, false)); includeNodeInfoValues.forEach(includeNodeInfo -> { @@ -433,6 +520,7 @@ public void testParseFromMap() { Tuple.tuple(PassThroughConfig.NAME, Map.of()), Tuple.tuple(TextClassificationConfig.NAME, Map.of()), Tuple.tuple(TextEmbeddingConfig.NAME, Map.of()), + Tuple.tuple(TextExpansionConfig.NAME, Map.of()), Tuple.tuple(ZeroShotClassificationConfig.NAME, Map.of()), Tuple.tuple(QuestionAnsweringConfig.NAME, Map.of("question", "What is the answer to life, the universe and everything?")) )) { @@ -444,8 +532,203 @@ public void testParseFromMap() { }); } - private static ClusterState buildClusterState(Metadata metadata) { - return ClusterState.builder(new ClusterName("_name")).metadata(metadata).build(); + public void testCreateProcessorWithIncompatibleTargetFieldSetting() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( + client, + clusterService, + Settings.EMPTY, + randomBoolean() + ); + + Map input = new HashMap<>() { + { + put(InferenceProcessor.INPUT_FIELD, "in"); + put(InferenceProcessor.OUTPUT_FIELD, "out"); + } + }; + + Map config = new HashMap<>() { + { + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.TARGET_FIELD, "ml"); + put(InferenceProcessor.INPUT, input); + } + }; + + ElasticsearchParseException ex = expectThrows( + ElasticsearchParseException.class, + () -> processorFactory.create(Collections.emptyMap(), "processor_with_inputs", null, config) + ); + assertThat( + ex.getMessage(), + containsString( + "[target_field] option is incompatible with [input]. Use the [output_field] option to specify where to write the " + + "inference results to." + ) + ); + } + + public void testCreateProcessorWithIncompatibleResultFieldSetting() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( + client, + clusterService, + Settings.EMPTY, + randomBoolean() + ); + + Map input = new HashMap<>() { + { + put(InferenceProcessor.INPUT_FIELD, "in"); + put(InferenceProcessor.OUTPUT_FIELD, "out"); + } + }; + + Map config = new HashMap<>() { + { + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.INPUT, input); + put( + InferenceProcessor.INFERENCE_CONFIG, + Collections.singletonMap( + TextExpansionConfig.NAME, + Collections.singletonMap(TextExpansionConfig.RESULTS_FIELD.getPreferredName(), "foo") + ) + ); + } + }; + + ElasticsearchParseException ex = expectThrows( + ElasticsearchParseException.class, + () -> processorFactory.create(Collections.emptyMap(), "processor_with_inputs", null, config) + ); + assertThat( + ex.getMessage(), + containsString( + "The [inference_config.results_field] setting is incompatible with using [input]. " + + "Prefer to use the [input.output_field] option to specify where to write the inference results to." 
+ ) + ); + } + + public void testCreateProcessorWithInputFields() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( + client, + clusterService, + Settings.EMPTY, + randomBoolean() + ); + + Map inputMap = new HashMap<>() { + { + put(InferenceProcessor.INPUT_FIELD, "in"); + put(InferenceProcessor.OUTPUT_FIELD, "out"); + } + }; + + String inferenceConfigType = randomFrom( + ClassificationConfigUpdate.NAME.getPreferredName(), + RegressionConfigUpdate.NAME.getPreferredName(), + FillMaskConfigUpdate.NAME, + NerConfigUpdate.NAME, + PassThroughConfigUpdate.NAME, + QuestionAnsweringConfigUpdate.NAME, + TextClassificationConfigUpdate.NAME, + TextEmbeddingConfigUpdate.NAME, + TextExpansionConfigUpdate.NAME, + TextSimilarityConfigUpdate.NAME, + ZeroShotClassificationConfigUpdate.NAME + ); + + Map config = new HashMap<>() { + { + put(InferenceProcessor.MODEL_ID, "my_model"); + put(InferenceProcessor.INPUT, inputMap); + put(InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(inferenceConfigType, Collections.emptyMap())); + } + }; + + var inferenceProcessor = processorFactory.create(Collections.emptyMap(), "processor_with_inputs", null, config); + assertEquals("my_model", inferenceProcessor.getModelId()); + assertTrue(inferenceProcessor.isConfiguredWithInputsFields()); + + var inputs = inferenceProcessor.getInputs(); + assertThat(inputs, hasSize(1)); + assertEquals(inputs.get(0), new InferenceProcessor.Factory.InputConfig("in", "out", Map.of())); + + assertNull(inferenceProcessor.getFieldMap()); + assertNull(inferenceProcessor.getTargetField()); + } + + public void testParsingInputFields() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( + client, + clusterService, + Settings.EMPTY, + randomBoolean() + ); + + int numInputs = randomIntBetween(1, 3); + List> inputs = new ArrayList<>(); + for (int i = 0; i < numInputs; i++) { + Map inputMap = new HashMap<>(); + inputMap.put(InferenceProcessor.INPUT_FIELD, "in" + i); + inputMap.put(InferenceProcessor.OUTPUT_FIELD, "out" + i); + inputs.add(inputMap); + } + + var parsedInputs = processorFactory.parseInputFields("my_processor", inputs); + assertThat(parsedInputs, hasSize(numInputs)); + for (int i = 0; i < numInputs; i++) { + assertEquals(new InferenceProcessor.Factory.InputConfig("in" + i, "out" + i, Map.of()), parsedInputs.get(i)); + } + } + + public void testParsingInputFieldsDuplicateFieldNames() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( + client, + clusterService, + Settings.EMPTY, + randomBoolean() + ); + + int numInputs = 2; + { + List> inputs = new ArrayList<>(); + for (int i = 0; i < numInputs; i++) { + Map inputMap = new HashMap<>(); + inputMap.put(InferenceProcessor.INPUT_FIELD, "in"); + inputMap.put(InferenceProcessor.OUTPUT_FIELD, "out" + i); + inputs.add(inputMap); + } + + var e = expectThrows(ElasticsearchParseException.class, () -> processorFactory.parseInputFields("my_processor", inputs)); + assertThat(e.getMessage(), containsString("[input_field] names must be unique but [in] is repeated")); + } + + { + List> inputs = new ArrayList<>(); + for (int i = 0; i < numInputs; i++) { + Map inputMap = new HashMap<>(); + inputMap.put(InferenceProcessor.INPUT_FIELD, "in" + i); + inputMap.put(InferenceProcessor.OUTPUT_FIELD, "out"); + inputs.add(inputMap); + } + + var e = expectThrows(ElasticsearchParseException.class, () -> processorFactory.parseInputFields("my_processor", inputs)); + assertThat(e.getMessage(), containsString("[output_field] 
names must be unique but [out] is repeated")); + } + } + + public void testParsingInputFieldsGivenNoInputs() { + InferenceProcessor.Factory processorFactory = new InferenceProcessor.Factory( + client, + clusterService, + Settings.EMPTY, + randomBoolean() + ); + + var e = expectThrows(ElasticsearchParseException.class, () -> processorFactory.parseInputFields("my_processor", List.of())); + assertThat(e.getMessage(), containsString("[input] cannot be empty at least one is required")); } private static ClusterState buildClusterStateWithModelReferences(String... modelId) throws IOException { @@ -513,7 +796,7 @@ private static Map forEachProcessorWithInference(String modelId) private static Map inferenceProcessorForModel(String modelId) { return Collections.singletonMap(InferenceProcessor.TYPE, new HashMap<>() { { - put(InferenceResults.MODEL_ID_RESULTS_FIELD, modelId); + put(InferenceProcessor.MODEL_ID, modelId); put( InferenceProcessor.INFERENCE_CONFIG, Collections.singletonMap(RegressionConfig.NAME.getPreferredName(), Collections.emptyMap()) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 32d66401785ea..bc521922e6c05 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -56,7 +56,7 @@ public void setUpVariables() { public void testMutateDocumentWithClassification() { String targetField = "ml.my_processor"; - InferenceProcessor inferenceProcessor = new InferenceProcessor( + InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration( client, auditor, "my_processor", @@ -89,7 +89,7 @@ public void testMutateDocumentWithClassification() { public void testMutateDocumentClassificationTopNClasses() { ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, null, null); ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, null, PredictionFieldType.STRING); - InferenceProcessor inferenceProcessor = new InferenceProcessor( + InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration( client, auditor, "my_processor", @@ -126,7 +126,7 @@ public void testMutateDocumentClassificationTopNClasses() { public void testMutateDocumentClassificationFeatureInfluence() { ClassificationConfig classificationConfig = new ClassificationConfig(2, null, null, 2, PredictionFieldType.STRING); ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, null, null, 2, null); - InferenceProcessor inferenceProcessor = new InferenceProcessor( + InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration( client, auditor, "my_processor", @@ -180,7 +180,7 @@ public void testMutateDocumentClassificationFeatureInfluence() { public void testMutateDocumentClassificationTopNClassesWithSpecificField() { ClassificationConfig classificationConfig = new ClassificationConfig(2, "result", "tops", null, PredictionFieldType.STRING); ClassificationConfigUpdate classificationConfigUpdate = new ClassificationConfigUpdate(2, "result", "tops", null, null); - InferenceProcessor inferenceProcessor = new InferenceProcessor( + InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration( 
client, auditor, "my_processor", @@ -217,7 +217,7 @@ public void testMutateDocumentClassificationTopNClassesWithSpecificField() { public void testMutateDocumentRegression() { RegressionConfig regressionConfig = new RegressionConfig("foo"); RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", null); - InferenceProcessor inferenceProcessor = new InferenceProcessor( + InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration( client, auditor, "my_processor", @@ -244,7 +244,7 @@ public void testMutateDocumentRegression() { public void testMutateDocumentRegressionWithTopFeatures() { RegressionConfig regressionConfig = new RegressionConfig("foo", 2); RegressionConfigUpdate regressionConfigUpdate = new RegressionConfigUpdate("foo", 2); - InferenceProcessor inferenceProcessor = new InferenceProcessor( + InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration( client, auditor, "my_processor", @@ -280,7 +280,7 @@ public void testGenerateRequestWithEmptyMapping() { String modelId = "model"; Integer topNClasses = randomBoolean() ? null : randomIntBetween(1, 10); - InferenceProcessor processor = new InferenceProcessor( + InferenceProcessor processor = InferenceProcessor.fromTargetFieldConfiguration( client, auditor, "my_processor", @@ -320,7 +320,7 @@ public void testGenerateWithMapping() { fieldMapping.put("categorical", "new_categorical"); fieldMapping.put("_ingest._value", "metafield"); - InferenceProcessor processor = new InferenceProcessor( + InferenceProcessor processor = InferenceProcessor.fromTargetFieldConfiguration( client, auditor, "my_processor", @@ -362,7 +362,7 @@ public void testGenerateWithMappingNestedFields() { fieldMapping.put("value2", "new_value2"); fieldMapping.put("categorical.bar", "new_categorical"); - InferenceProcessor processor = new InferenceProcessor( + InferenceProcessor processor = InferenceProcessor.fromTargetFieldConfiguration( client, auditor, "my_processor", @@ -390,7 +390,7 @@ public void testGenerateWithMappingNestedFields() { public void testHandleResponseLicenseChanged() { String targetField = "regression_value"; - InferenceProcessor inferenceProcessor = new InferenceProcessor( + InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration( client, auditor, "my_processor", @@ -440,7 +440,7 @@ public void testHandleResponseLicenseChanged() { public void testMutateDocumentWithWarningResult() { String targetField = "regression_value"; - InferenceProcessor inferenceProcessor = new InferenceProcessor( + InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration( client, auditor, "my_processor", @@ -468,7 +468,7 @@ public void testMutateDocumentWithWarningResult() { public void testMutateDocumentWithModelIdResult() { String modelAlias = "special_model"; String modelId = "regression-123"; - InferenceProcessor inferenceProcessor = new InferenceProcessor( + InferenceProcessor inferenceProcessor = InferenceProcessor.fromTargetFieldConfiguration( client, auditor, "my_processor",
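
For reference, the factory changes above read the new `input` option as an object with an `input_field` and an `output_field`, reject `target_field` and `inference_config.results_field` when it is present, and send the gathered text to the model via `InferModelAction.Request.forTextInput`. A pipeline definition using that style would presumably look like the sketch below; the `content` and `content_prediction` field names and the `text_embedding` config type are illustrative only and are not taken from this patch.

[source,js]
--------------------------------------------------
{
  "inference": {
    "model_id": "model_deployment_for_inference",
    "input": {
      "input_field": "content",
      "output_field": "content_prediction"
    },
    "inference_config": { "text_embedding": {} }
  }
}
--------------------------------------------------
// NOTCONSOLE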