From aea1b0b8bfc73a8991c244996503bfec8be4a393 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Thu, 23 Nov 2023 13:28:00 +0100 Subject: [PATCH] Add hidden support for chunking in InferenceProcessor --- .../ingest/common/ScriptProcessor.java | 2 +- multi-cluster-run.gradle | 15 +++ .../inference/InferenceResults.java | 25 +++++ .../xpack/ml/MachineLearning.java | 16 ++- .../inference/ingest/InferenceProcessor.java | 101 +++++++++++++----- .../SemanticTextInferenceProcessor.java | 35 ++++-- .../ingest/InferenceProcessorTests.java | 16 ++- 7 files changed, 160 insertions(+), 50 deletions(-) create mode 100644 multi-cluster-run.gradle diff --git a/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/ScriptProcessor.java b/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/ScriptProcessor.java index 84e66a3134b69..17e24025d8319 100644 --- a/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/ScriptProcessor.java +++ b/modules/ingest-common/src/main/java/org/elasticsearch/ingest/common/ScriptProcessor.java @@ -50,7 +50,7 @@ public final class ScriptProcessor extends AbstractProcessor { * @param precompiledIngestScriptFactory The {@link Script} precompiled script * @param scriptService The {@link ScriptService} used to execute the script. 
*/ - ScriptProcessor( + public ScriptProcessor( String tag, String description, Script script, diff --git a/multi-cluster-run.gradle b/multi-cluster-run.gradle new file mode 100644 index 0000000000000..616ddaf2c57f0 --- /dev/null +++ b/multi-cluster-run.gradle @@ -0,0 +1,15 @@ +rootProject { + if (project.name == 'elasticsearch') { + afterEvaluate { + testClusters.configureEach { + numberOfNodes = 2 + } + def cluster = testClusters.named("runTask").get() + cluster.getNodes().each { node -> + node.setting('cluster.initial_master_nodes', cluster.getLastNode().getName()) + node.setting('node.roles', '[master,data_hot,data_content]') + } + cluster.getFirstNode().setting('node.roles', '[]') + } + } +} diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceResults.java b/server/src/main/java/org/elasticsearch/inference/InferenceResults.java index 43d5c37563cc0..a4f164c2d7cbe 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceResults.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceResults.java @@ -13,6 +13,9 @@ import org.elasticsearch.ingest.IngestDocument; import org.elasticsearch.xcontent.ToXContentFragment; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -55,6 +58,28 @@ static void writeResultToField( } } + static void writeChunkResultsToField( + List results, + IngestDocument ingestDocument, + @Nullable String basePath, + String outputField) { + Objects.requireNonNull(results, "results"); + Objects.requireNonNull(ingestDocument, "ingestDocument"); + Objects.requireNonNull(outputField, "outputField"); + @SuppressWarnings("unchecked") + List> inputValues = ingestDocument.getFieldValue(basePath + "." 
+ outputField, List.class); + List> outputValues = new ArrayList<>(); + int currentResult = 0; + for (InferenceResults result : results) { + Map outputMap = new HashMap<>(); + outputMap.put("inference", result.asMap(outputField).get(outputField)); + outputMap.putAll(inputValues.get(currentResult++)); + outputValues.add(outputMap); + } + + ingestDocument.setFieldValue(basePath + "." + outputField, outputValues); + } + private static void setOrAppendValue(String path, Object value, IngestDocument ingestDocument) { if (ingestDocument.hasField(path)) { ingestDocument.appendFieldValue(path, value); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 19b3976ecc6e7..a9bc74579c165 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -49,7 +49,6 @@ import org.elasticsearch.index.analysis.CharFilterFactory; import org.elasticsearch.index.analysis.TokenizerFactory; import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.index.mapper.MetadataFieldMapper; import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.indices.AssociatedIndexDescriptor; import org.elasticsearch.indices.SystemIndexDescriptor; @@ -365,7 +364,6 @@ import org.elasticsearch.xpack.ml.job.snapshot.upgrader.SnapshotUpgradeTaskExecutor; import org.elasticsearch.xpack.ml.job.task.OpenJobPersistentTasksExecutor; import org.elasticsearch.xpack.ml.mapper.SemanticTextFieldMapper; -import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.ml.notifications.AnomalyDetectionAuditor; import org.elasticsearch.xpack.ml.notifications.DataFrameAnalyticsAuditor; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; @@ -2286,13 +2284,13 @@ public Map getMappers() { ); } -
@Override - public Map getMetadataMappers() { - return Map.of( - SemanticTextInferenceResultFieldMapper.CONTENT_TYPE, - SemanticTextInferenceResultFieldMapper.PARSER - ); - } +// @Override +// public Map getMetadataMappers() { +// return Map.of( +// SemanticTextInferenceResultFieldMapper.CONTENT_TYPE, +// SemanticTextInferenceResultFieldMapper.PARSER +// ); +// } @Override public Optional getIngestPipeline(IndexMetadata indexMetadata, Processor.Parameters parameters) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java index e600ddd42107f..6ce7aadb52e7f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessor.java @@ -109,9 +109,23 @@ public static InferenceProcessor fromInputFieldConfiguration( String modelId, InferenceConfigUpdate inferenceConfig, List inputs, - boolean ignoreMissing + boolean ignoreMissing, + boolean supportChunking ) { - return new InferenceProcessor(client, auditor, tag, description, null, modelId, inferenceConfig, null, inputs, true, ignoreMissing); + return new InferenceProcessor( + client, + auditor, + tag, + description, + null, + modelId, + inferenceConfig, + null, + inputs, + true, + ignoreMissing, + supportChunking + ); } public static InferenceProcessor fromTargetFieldConfiguration( @@ -136,6 +150,7 @@ public static InferenceProcessor fromTargetFieldConfiguration( fieldMap, null, false, + false, false ); } @@ -151,6 +166,7 @@ public static InferenceProcessor fromTargetFieldConfiguration( private final List inputs; private final boolean configuredWithInputsFields; private final boolean ignoreMissing; + private final boolean supportChunking; private InferenceProcessor( Client client, @@ -163,7 +179,8 @@ private 
InferenceProcessor( Map fieldMap, List inputs, boolean configuredWithInputsFields, - boolean ignoreMissing + boolean ignoreMissing, + boolean supportChunking ) { super(tag, description); this.configuredWithInputsFields = configuredWithInputsFields; @@ -172,6 +189,7 @@ private InferenceProcessor( this.modelId = ExceptionsHelper.requireNonNull(modelId, MODEL_ID); this.inferenceConfig = ExceptionsHelper.requireNonNull(inferenceConfig, INFERENCE_CONFIG); this.ignoreMissing = ignoreMissing; + this.supportChunking = supportChunking; if (configuredWithInputsFields) { this.inputs = ExceptionsHelper.requireNonNull(inputs, INPUT_OUTPUT); @@ -229,12 +247,23 @@ InferModelAction.Request buildRequest(IngestDocument ingestDocument) { List requestInputs = new ArrayList<>(); for (var inputFields : inputs) { try { - var inputText = ingestDocument.getFieldValue(inputFields.inputField, String.class, ignoreMissing); - // field is missing and ignoreMissing == true then a null value is returned. - if (inputText == null) { - inputText = ""; // need to send a non-null request to the same number of results back + if (supportChunking) { + @SuppressWarnings("unchecked") + List> inputChunks = ingestDocument.getFieldValue(inputFields.inputField, List.class, ignoreMissing); + // field is missing and ignoreMissing == true then a null value is returned. + if (inputChunks == null) { + requestInputs.add(""); // need to send a non-null request to the same number of results back + } else { + requestInputs.addAll(inputChunks.stream().map(m -> m.get("text").toString()).toList()); + } + } else { + var inputText = ingestDocument.getFieldValue(inputFields.inputField, String.class, ignoreMissing); + // field is missing and ignoreMissing == true then a null value is returned. 
+ if (inputText == null) { + inputText = ""; // need to send a non-null request to the same number of results back + } + requestInputs.add(inputText); } - requestInputs.add(inputText); } catch (IllegalArgumentException e) { if (ingestDocument.hasField(inputFields.inputField())) { // field is present but of the wrong type, translate to a more meaningful message @@ -297,24 +326,43 @@ void mutateDocument(InferModelAction.Response response, IngestDocument ingestDoc // String modelIdField = tag == null ? MODEL_ID_RESULTS_FIELD : MODEL_ID_RESULTS_FIELD + "." + tag; if (configuredWithInputsFields) { - if (response.getInferenceResults().size() != inputs.size()) { - throw new ElasticsearchStatusException( - "number of results [{}] does not match the number of inputs [{}]", - RestStatus.INTERNAL_SERVER_ERROR, - response.getInferenceResults().size(), - inputs.size() - ); - } + if (supportChunking) { + int currentResult = 0; + for (Factory.InputConfig input : inputs) { + int inputSize = ingestDocument.getFieldValue(input.inputField, List.class).size(); + List inputResults = response.getInferenceResults() + .subList(currentResult, currentResult + inputSize); + InferenceResults.writeChunkResultsToField(inputResults, ingestDocument, input.outputBasePath, input.outputField); + currentResult += inputSize; + } + if (currentResult != response.getInferenceResults().size()) { + throw new ElasticsearchStatusException( + "number of results [{}] does not match the number of inputs [{}]", + RestStatus.INTERNAL_SERVER_ERROR, + response.getInferenceResults().size(), + currentResult + ); + } + } else { + if (response.getInferenceResults().size() != inputs.size()) { + throw new ElasticsearchStatusException( + "number of results [{}] does not match the number of inputs [{}]", + RestStatus.INTERNAL_SERVER_ERROR, + response.getInferenceResults().size(), + inputs.size() + ); + } - for (int i = 0; i < inputs.size(); i++) { - InferenceResults.writeResultToField( - 
response.getInferenceResults().get(i), - ingestDocument, - inputs.get(i).outputBasePath(), - inputs.get(i).outputField, - response.getId() != null ? response.getId() : modelId, - i == 0 - ); + for (int i = 0; i < inputs.size(); i++) { + InferenceResults.writeResultToField( + response.getInferenceResults().get(i), + ingestDocument, + inputs.get(i).outputBasePath(), + inputs.get(i).outputField, + response.getId() != null ? response.getId() : modelId, + i == 0 + ); + } } } else { assert response.getInferenceResults().size() == 1; @@ -472,7 +520,8 @@ public InferenceProcessor create( modelId, inferenceConfigUpdate, parsedInputs, - ignoreMissing + ignoreMissing, + false ); } else { // old style configuration with target field diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java index 7d4eab2ec52a6..a2d5376b54aef 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/ingest/SemanticTextInferenceProcessor.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.ml.mapper.SemanticTextInferenceResultFieldMapper; import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; +import java.util.ArrayList; import java.util.List; import java.util.Map; import java.util.Set; @@ -26,8 +27,10 @@ public class SemanticTextInferenceProcessor extends AbstractProcessor implements public static final String TYPE = "semanticTextInference"; public static final String TAG = "semantic_text"; + public static final String TEXT_SUFFIX = ".text"; + public static final String INFERENCE_SUFFIX = ".inference"; - private final Map> fieldsForModels; + private final Map> modelForFields; private final Processor wrappedProcessor; @@ -38,18 +41,18 @@ public 
SemanticTextInferenceProcessor( Client client, InferenceAuditor inferenceAuditor, String description, - Map> fieldsForModels + Map> modelForFields ) { super(TAG, description); this.client = client; this.inferenceAuditor = inferenceAuditor; - this.fieldsForModels = fieldsForModels; + this.modelForFields = modelForFields; this.wrappedProcessor = createWrappedProcessor(); } private Processor createWrappedProcessor() { - InferenceProcessor[] inferenceProcessors = fieldsForModels.entrySet() + InferenceProcessor[] inferenceProcessors = modelForFields.entrySet() .stream() .map(e -> createInferenceProcessor(e.getKey(), e.getValue())) .toArray(InferenceProcessor[]::new); @@ -58,7 +61,11 @@ private Processor createWrappedProcessor() { private InferenceProcessor createInferenceProcessor(String modelId, Set fields) { List inputConfigs = fields.stream() - .map(f -> new InferenceProcessor.Factory.InputConfig(f, SemanticTextInferenceResultFieldMapper.NAME, f, Map.of())) + .map(field -> new InferenceProcessor.Factory.InputConfig( + SemanticTextInferenceResultFieldMapper.NAME + "." 
+ field, + SemanticTextInferenceResultFieldMapper.NAME, + field, + Map.of())) .toList(); return InferenceProcessor.fromInputFieldConfiguration( @@ -69,18 +76,28 @@ private InferenceProcessor createInferenceProcessor(String modelId, Set modelId, TextExpansionConfigUpdate.EMPTY_UPDATE, inputConfigs, - false + false, + true ); } @Override public void execute(IngestDocument ingestDocument, BiConsumer handler) { + modelForFields.forEach((modelId, fields) -> chunkText(ingestDocument, modelId, fields)); getInnerProcessor().execute(ingestDocument, handler); } - @Override - public IngestDocument execute(IngestDocument ingestDocument) throws Exception { - return getInnerProcessor().execute(ingestDocument); + private static void chunkText(IngestDocument ingestDocument, String modelId, Set fields) { + for (String field : fields) { + String value = ingestDocument.getFieldValue(field, String.class); + if (value != null) { + String[] chunks = value.split("\\."); + ingestDocument.setFieldValue(SemanticTextInferenceResultFieldMapper.NAME + "." + field, new ArrayList<>()); + for (String chunk : chunks) { + ingestDocument.appendFieldValue(SemanticTextInferenceResultFieldMapper.NAME + "." 
+ field, Map.of("text", chunk)); + } + } + } } @Override diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java index 4821efa29631f..5b8ac82ee3a2d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/ingest/InferenceProcessorTests.java @@ -523,7 +523,8 @@ public void testMutateDocumentWithInputFields() { modelId, new RegressionConfigUpdate("foo", null), inputs, - randomBoolean() + randomBoolean(), + false ); IngestDocument document = TestIngestDocument.emptyIngestDocument(); @@ -553,7 +554,8 @@ public void testMutateDocumentWithInputFieldsNested() { modelId, new RegressionConfigUpdate("foo", null), inputs, - randomBoolean() + randomBoolean(), + false ); IngestDocument document = TestIngestDocument.emptyIngestDocument(); @@ -599,7 +601,8 @@ public void testBuildRequestWithInputFields() { modelId, new EmptyConfigUpdate(), inputs, - randomBoolean() + randomBoolean(), + false ); IngestDocument document = TestIngestDocument.emptyIngestDocument(); @@ -628,7 +631,8 @@ public void testBuildRequestWithInputFields_WrongType() { modelId, new EmptyConfigUpdate(), inputs, - randomBoolean() + randomBoolean(), + false ); IngestDocument document = TestIngestDocument.emptyIngestDocument(); @@ -654,6 +658,7 @@ public void testBuildRequestWithInputFields_MissingField() { modelId, new EmptyConfigUpdate(), inputs, + false, false ); @@ -675,7 +680,8 @@ public void testBuildRequestWithInputFields_MissingField() { modelId, new EmptyConfigUpdate(), inputs, - true + true, + false ); IngestDocument document = TestIngestDocument.emptyIngestDocument();