diff --git a/server/src/main/java/module-info.java b/server/src/main/java/module-info.java index 1a082e7558577..62849bf20acfc 100644 --- a/server/src/main/java/module-info.java +++ b/server/src/main/java/module-info.java @@ -134,6 +134,7 @@ exports org.elasticsearch.action.fieldcaps; exports org.elasticsearch.action.get; exports org.elasticsearch.action.index; + exports org.elasticsearch.action.inference; exports org.elasticsearch.action.ingest; exports org.elasticsearch.action.resync; exports org.elasticsearch.action.search; diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 60c14740658bb..33872a12aa950 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -146,7 +146,7 @@ static TransportVersion def(int id) { public static final TransportVersion TOO_MANY_SCROLL_CONTEXTS_EXCEPTION_ADDED = def(8_521_00_0); public static final TransportVersion UNCONTENDED_REGISTER_ANALYSIS_ADDED = def(8_522_00_0); public static final TransportVersion TRANSFORM_GET_CHECKPOINT_TIMEOUT_ADDED = def(8_523_00_0); - + public static final TransportVersion SEMANTIC_TEXT_FIELD_ADDED = def(8_524_00_0); /* * STOP! READ THIS FIRST! No, really, * ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _ diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 13d10be86bd68..88854aba1746b 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -48,7 +48,6 @@ import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.AtomicArray; import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.core.Assertions; import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.Index; @@ -59,6 +58,8 @@ import org.elasticsearch.index.shard.ShardId; import org.elasticsearch.indices.IndexClosedException; import org.elasticsearch.indices.SystemIndices; +import org.elasticsearch.ingest.BulkRequestPreprocessor; +import org.elasticsearch.ingest.FieldInferenceBulkRequestPreprocessor; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.node.NodeClosedException; import org.elasticsearch.tasks.Task; @@ -95,7 +96,7 @@ public class TransportBulkAction extends HandledTransportAction bulkRequestPreprocessors; private final LongSupplier relativeTimeProvider; private final IngestActionForwarder ingestForwarder; private final NodeClient client; @@ -110,6 +111,7 @@ public TransportBulkAction( TransportService transportService, ClusterService clusterService, IngestService ingestService, + FieldInferenceBulkRequestPreprocessor fieldInferenceBulkRequestPreprocessor, NodeClient client, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, @@ -121,6 +123,7 @@ public TransportBulkAction( transportService, clusterService, ingestService, + fieldInferenceBulkRequestPreprocessor, client, actionFilters, indexNameExpressionResolver, @@ -135,6 +138,7 @@ public TransportBulkAction( TransportService transportService, ClusterService clusterService, IngestService ingestService, + FieldInferenceBulkRequestPreprocessor fieldInferenceBulkRequestPreprocessor, NodeClient client, ActionFilters actionFilters, IndexNameExpressionResolver indexNameExpressionResolver, @@ -146,7 +150,7 @@ public TransportBulkAction( Objects.requireNonNull(relativeTimeProvider); this.threadPool = threadPool; this.clusterService = clusterService; - this.ingestService = ingestService; + this.bulkRequestPreprocessors = List.of(ingestService, fieldInferenceBulkRequestPreprocessor); this.relativeTimeProvider = relativeTimeProvider; this.ingestForwarder = new IngestActionForwarder(transportService); this.client = client; @@ -270,45 +274,9 @@ protected void doInternalExecute(Task task, BulkRequest bulkRequest, String exec final long startTime = relativeTime(); final AtomicArray responses = new AtomicArray<>(bulkRequest.requests.size()); - boolean hasIndexRequestsWithPipelines = false; - final Metadata metadata = clusterService.state().getMetadata(); - final Version minNodeVersion = clusterService.state().getNodes().getMinNodeVersion(); - for (DocWriteRequest actionRequest : bulkRequest.requests) { - IndexRequest indexRequest = getIndexWriteRequest(actionRequest); - if (indexRequest != null) { - IngestService.resolvePipelinesAndUpdateIndexRequest(actionRequest, indexRequest, metadata); - hasIndexRequestsWithPipelines |= IngestService.hasPipeline(indexRequest); - } - - if (actionRequest instanceof IndexRequest ir) { - ir.checkAutoIdWithOpTypeCreateSupportedByVersion(minNodeVersion); - if (ir.getAutoGeneratedTimestamp() != IndexRequest.UNSET_AUTO_GENERATED_TIMESTAMP) { - throw new IllegalArgumentException("autoGeneratedTimestamp should not be set externally"); - } - } - } - - if (hasIndexRequestsWithPipelines) { - // this method (doExecute) will be called again, but with the bulk requests updated from the ingest node processing but - // also with IngestService.NOOP_PIPELINE_NAME on each request. This ensures that this on the second time through this method, - // this path is never taken. - ActionListener.run(listener, l -> { - if (Assertions.ENABLED) { - final boolean arePipelinesResolved = bulkRequest.requests() - .stream() - .map(TransportBulkAction::getIndexWriteRequest) - .filter(Objects::nonNull) - .allMatch(IndexRequest::isPipelineResolved); - assert arePipelinesResolved : bulkRequest; - } - if (clusterService.localNode().isIngestNode()) { - processBulkIndexIngestRequest(task, bulkRequest, executorName, l); - } else { - ingestForwarder.forwardIngestRequest(BulkAction.INSTANCE, bulkRequest, l); - } - }); - return; - } + // Preprocess bulk requests with ingestion services. If needs preprocessing, then return early as the preprocessing + // action will invoke this method again + if (preprocessBulkRequest(task, bulkRequest, executorName, listener)) return; // Attempt to create all the indices that we're going to need during the bulk before we start. // Step 1: collect all the indices in the request @@ -801,7 +769,41 @@ private long relativeTime() { return relativeTimeProvider.getAsLong(); } - private void processBulkIndexIngestRequest( + private boolean preprocessBulkRequest(Task task, BulkRequest bulkRequest, String executorName, ActionListener listener) { + final Metadata metadata = clusterService.state().getMetadata(); + final Version minNodeVersion = clusterService.state().getNodes().getMinNodeVersion(); + boolean needsProcessing = false; + for (BulkRequestPreprocessor preprocessor : bulkRequestPreprocessors) { + for (DocWriteRequest docWriteRequest : bulkRequest.requests) { + IndexRequest indexRequest = getIndexWriteRequest(docWriteRequest); + if (indexRequest != null) { + needsProcessing = needsProcessing || preprocessor.needsProcessing(docWriteRequest, indexRequest, metadata); + } + + if (docWriteRequest instanceof IndexRequest ir) { + ir.checkAutoIdWithOpTypeCreateSupportedByVersion(minNodeVersion); + if (ir.getAutoGeneratedTimestamp() != IndexRequest.UNSET_AUTO_GENERATED_TIMESTAMP) { + throw new IllegalArgumentException("autoGeneratedTimestamp should not be set externally"); + } + } + } + + if (needsProcessing) { + ActionListener.run(listener, l -> { + if ((preprocessor.shouldExecuteOnIngestNode() == false) || clusterService.localNode().isIngestNode()) { + preprocessBulkRequestWithPreprocessor(preprocessor, task, bulkRequest, executorName, l); + } else { + ingestForwarder.forwardIngestRequest(BulkAction.INSTANCE, bulkRequest, l); + } + }); + return true; + } + } + return false; + } + + private void preprocessBulkRequestWithPreprocessor( + BulkRequestPreprocessor preprocessor, Task task, BulkRequest original, String executorName, @@ -809,7 +811,8 @@ private void processBulkIndexIngestRequest( ) { final long ingestStartTimeInNanos = System.nanoTime(); final BulkRequestModifier bulkRequestModifier = new BulkRequestModifier(original); - ingestService.executeBulkRequest( + preprocessor.processBulkRequest( + threadPool.executor(executorName), original.numberOfActions(), () -> bulkRequestModifier, bulkRequestModifier::markItemAsDropped, diff --git a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java index 2f202dd21ad7c..56699af7f8dcf 100644 --- a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java +++ b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java @@ -49,6 +49,7 @@ import java.util.Map; import java.util.Objects; +import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_FIELD_ADDED; import static org.elasticsearch.action.ValidateActions.addValidationError; import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM; import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO; @@ -105,6 +106,8 @@ public class IndexRequest extends ReplicatedWriteRequest implement private boolean isPipelineResolved; + private boolean isFieldInferenceDone; + private boolean requireAlias; /** * This indicates whether the response to this request ought to list the ingest pipelines that were executed on the document @@ -189,6 +192,7 @@ public IndexRequest(@Nullable ShardId shardId, StreamInput in) throws IOExceptio : new ArrayList<>(possiblyImmutableExecutedPipelines); } } + isFieldInferenceDone = in.getTransportVersion().before(SEMANTIC_TEXT_FIELD_ADDED) || in.readBoolean(); } public IndexRequest() { @@ -375,6 +379,26 @@ public boolean isPipelineResolved() { return this.isPipelineResolved; } + /** + * Sets if field inference for this request has been done by the coordinating node. + * + * @param isFieldInferenceDone true if the field inference has been resolved + * @return the request + */ + public IndexRequest isFieldInferenceDone(final boolean isFieldInferenceDone) { + this.isFieldInferenceDone = isFieldInferenceDone; + return this; + } + + /** + * Returns whether the field inference for this request has been resolved by the coordinating node. + * + * @return true if the pipeline has been resolved + */ + public boolean isFieldInferenceDone() { + return this.isFieldInferenceDone; + } + /** * The source of the document to index, recopied to a new array if it is unsafe. */ @@ -755,6 +779,9 @@ private void writeBody(StreamOutput out) throws IOException { out.writeOptionalCollection(executedPipelines, StreamOutput::writeString); } } + if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_FIELD_ADDED)) { + out.writeBoolean(isFieldInferenceDone); + } } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/InferenceAction.java b/server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java similarity index 96% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/InferenceAction.java rename to server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java index 7938c2abd8d99..5e2e67786d977 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/InferenceAction.java +++ b/server/src/main/java/org/elasticsearch/action/inference/InferenceAction.java @@ -1,11 +1,12 @@ /* * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. */ -package org.elasticsearch.xpack.inference.action; +package org.elasticsearch.action.inference; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionRequest; diff --git a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java index b9ba0762e5117..796d10d5c893b 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java @@ -36,6 +36,8 @@ final class FieldTypeLookup { */ private final Map> fieldToCopiedFields; + private final Map fieldToInferenceModels; + private final int maxParentPathDots; FieldTypeLookup( @@ -48,6 +50,7 @@ final class FieldTypeLookup { final Map fullSubfieldNameToParentPath = new HashMap<>(); final Map dynamicFieldTypes = new HashMap<>(); final Map> fieldToCopiedFields = new HashMap<>(); + final Map fieldToInferenceModels = new HashMap<>(); for (FieldMapper fieldMapper : fieldMappers) { String fieldName = fieldMapper.name(); MappedFieldType fieldType = fieldMapper.fieldType(); @@ -65,6 +68,9 @@ final class FieldTypeLookup { } fieldToCopiedFields.get(targetField).add(fieldName); } + if (fieldType.hasInferenceModel()) { + fieldToInferenceModels.put(fieldName, fieldType.getInferenceModel()); + } } int maxParentPathDots = 0; @@ -97,6 +103,7 @@ final class FieldTypeLookup { // make values into more compact immutable sets to save memory fieldToCopiedFields.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue()))); this.fieldToCopiedFields = Map.copyOf(fieldToCopiedFields); + this.fieldToInferenceModels = Map.copyOf(fieldToInferenceModels); } public static int dotCount(String path) { @@ -109,6 +116,10 @@ public static int dotCount(String path) { return dotCount; } + String modelForField(String fieldName) { + return this.fieldToInferenceModels.get(fieldName); + } + /** * Returns the mapped field type for the given field name. */ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java b/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java index b68bb1a2b1987..0d86ee1607dce 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java @@ -203,6 +203,14 @@ public List dimensions() { return Collections.emptyList(); } + public String getInferenceModel() { + return null; + } + + public final boolean hasInferenceModel() { + return getInferenceModel() != null; + } + /** * @return metric type or null if the field is not a metric field */ diff --git a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java index 7c44f33fbafa5..14c3c4371c030 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java @@ -491,4 +491,8 @@ public void validateDoesNotShadow(String name) { throw new MapperParsingException("Field [" + name + "] attempted to shadow a time_series_metric"); } } + + public String modelForField(String fieldName) { + return fieldTypeLookup.modelForField(fieldName); + } } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java new file mode 100644 index 0000000000000..4bc79628268ff --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/mapper/SemanticTextFieldMapper.java @@ -0,0 +1,201 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.mapper; + +import org.apache.lucene.document.Field; +import org.apache.lucene.document.StringField; +import org.apache.lucene.search.Query; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Map; + +/** A {@link FieldMapper} for full-text fields. */ +public class SemanticTextFieldMapper extends FieldMapper { + + public static final String CONTENT_TYPE = "semantic_text"; + + public static final String TEXT_SUBFIELD_NAME = "text"; + public static final String SPARSE_VECTOR_SUBFIELD_NAME = "inference"; + + private static SemanticTextFieldMapper toType(FieldMapper in) { + return (SemanticTextFieldMapper) in; + } + + public static class Builder extends FieldMapper.Builder { + + final Parameter modelId = Parameter.stringParam("model_id", false, m -> toType(m).modelId, null).addValidator(value -> { + if (value == null) { + // TODO check the model exists + throw new IllegalArgumentException("field [model_id] must be specified"); + } + }); + + private final Parameter> meta = Parameter.metaParam(); + + public Builder(String name) { + super(name); + } + + public Builder modelId(String modelId) { + this.modelId.setValue(modelId); + return this; + } + + @Override + protected Parameter[] getParameters() { + return new Parameter[] { modelId, meta }; + } + + private SemanticTextFieldType buildFieldType(MapperBuilderContext context) { + return new SemanticTextFieldType(context.buildFullName(name), modelId.getValue(), meta.getValue()); + } + + @Override + public SemanticTextFieldMapper build(MapperBuilderContext context) { + String fullName = context.buildFullName(name); + String subfieldName = fullName + "." + SPARSE_VECTOR_SUBFIELD_NAME; + SparseVectorFieldMapper sparseVectorFieldMapper = new SparseVectorFieldMapper.Builder(subfieldName).build(context); + return new SemanticTextFieldMapper( + name(), + new SemanticTextFieldType(name(), modelId.getValue(), meta.getValue()), + modelId.getValue(), + sparseVectorFieldMapper, + copyTo, + this + ); + } + } + + public static final TypeParser PARSER = new TypeParser((n, c) -> new Builder(n), notInMultiFields(CONTENT_TYPE)); + + public static class SemanticTextFieldType extends SimpleMappedFieldType { + + private SparseVectorFieldType sparseVectorFieldType; + + private final String modelId; + + public SemanticTextFieldType(String name, String modelId, Map meta) { + super(name, true, false, false, TextSearchInfo.NONE, meta); + this.modelId = modelId; + } + + public String modelId() { + return modelId; + } + + public SparseVectorFieldType getSparseVectorFieldType() { + return this.sparseVectorFieldType; + } + + @Override + public String typeName() { + return CONTENT_TYPE; + } + + public String getInferenceModel() { + return modelId; + } + + @Override + public ValueFetcher valueFetcher(SearchExecutionContext context, String format) { + return SourceValueFetcher.identity(name(), context, format); + } + + @Override + public Query termQuery(Object value, SearchExecutionContext context) { + return sparseVectorFieldType.termQuery(value, context); + } + + @Override + public Query existsQuery(SearchExecutionContext context) { + return sparseVectorFieldType.existsQuery(context); + } + } + + private final String modelId; + private final SparseVectorFieldMapper sparseVectorFieldMapper; + + private SemanticTextFieldMapper( + String simpleName, + MappedFieldType mappedFieldType, + String modelId, + SparseVectorFieldMapper sparseVectorFieldMapper, + CopyTo copyTo, + Builder builder + ) { + super(simpleName, mappedFieldType, MultiFields.empty(), copyTo); + this.modelId = modelId; + this.sparseVectorFieldMapper = sparseVectorFieldMapper; + } + + @Override + public FieldMapper.Builder getMergeBuilder() { + return new Builder(simpleName()).init(this); + } + + @Override + public void parse(DocumentParserContext context) throws IOException { + + if (context.parser().currentToken() != XContentParser.Token.START_OBJECT) { + throw new IllegalArgumentException( + "[semantic_text] fields must be a json object, expected a START_OBJECT but got: " + context.parser().currentToken() + ); + } + + boolean textFound = false; + boolean inferenceFound = false; + for (XContentParser.Token token = context.parser().nextToken(); token != XContentParser.Token.END_OBJECT; token = context.parser() + .nextToken()) { + if (token != XContentParser.Token.FIELD_NAME) { + throw new IllegalArgumentException("[semantic_text] fields expect an object with field names, found " + token); + } + + String fieldName = context.parser().currentName(); + XContentParser.Token valueToken = context.parser().nextToken(); + switch (fieldName) { + case TEXT_SUBFIELD_NAME: + context.doc().add(new StringField(name() + TEXT_SUBFIELD_NAME, context.parser().textOrNull(), Field.Store.NO)); + textFound = true; + break; + case SPARSE_VECTOR_SUBFIELD_NAME: + sparseVectorFieldMapper.parse(context); + inferenceFound = true; + break; + default: + throw new IllegalArgumentException("Unexpected subfield value: " + fieldName); + } + } + + if (textFound == false) { + throw new IllegalArgumentException("[semantic_text] value does not contain [" + TEXT_SUBFIELD_NAME + "] subfield"); + } + if (inferenceFound == false) { + throw new IllegalArgumentException("[semantic_text] value does not contain [" + SPARSE_VECTOR_SUBFIELD_NAME + "] subfield"); + } + } + + @Override + protected void parseCreateField(DocumentParserContext context) { + throw new AssertionError("parse is implemented directly"); + } + + @Override + protected String contentType() { + return CONTENT_TYPE; + } + + @Override + public SemanticTextFieldType fieldType() { + return (SemanticTextFieldType) super.fieldType(); + } +} diff --git a/server/src/main/java/org/elasticsearch/indices/IndicesModule.java b/server/src/main/java/org/elasticsearch/indices/IndicesModule.java index 181852c2c3bc9..b94e3ca7785e6 100644 --- a/server/src/main/java/org/elasticsearch/indices/IndicesModule.java +++ b/server/src/main/java/org/elasticsearch/indices/IndicesModule.java @@ -55,6 +55,7 @@ import org.elasticsearch.index.mapper.RangeType; import org.elasticsearch.index.mapper.RoutingFieldMapper; import org.elasticsearch.index.mapper.RuntimeField; +import org.elasticsearch.index.mapper.SemanticTextFieldMapper; import org.elasticsearch.index.mapper.SeqNoFieldMapper; import org.elasticsearch.index.mapper.SourceFieldMapper; import org.elasticsearch.index.mapper.TextFieldMapper; @@ -197,6 +198,7 @@ public static Map getMappers(List mappe mappers.put(DenseVectorFieldMapper.CONTENT_TYPE, DenseVectorFieldMapper.PARSER); mappers.put(SparseVectorFieldMapper.CONTENT_TYPE, SparseVectorFieldMapper.PARSER); + mappers.put(SemanticTextFieldMapper.CONTENT_TYPE, SemanticTextFieldMapper.PARSER); for (MapperPlugin mapperPlugin : mapperPlugins) { for (Map.Entry entry : mapperPlugin.getMappers().entrySet()) { diff --git a/server/src/main/java/org/elasticsearch/ingest/AbstractBulkRequestPreprocessor.java b/server/src/main/java/org/elasticsearch/ingest/AbstractBulkRequestPreprocessor.java new file mode 100644 index 0000000000000..2e76e306e362b --- /dev/null +++ b/server/src/main/java/org/elasticsearch/ingest/AbstractBulkRequestPreprocessor.java @@ -0,0 +1,105 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.ingest; + +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.bulk.TransportBulkAction; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.plugins.internal.DocumentParsingObserver; + +import java.util.Map; +import java.util.concurrent.ExecutorService; +import java.util.function.BiConsumer; +import java.util.function.IntConsumer; +import java.util.function.Supplier; + +public abstract class AbstractBulkRequestPreprocessor implements BulkRequestPreprocessor { + + protected final Supplier documentParsingObserverSupplier; + + protected final IngestMetric ingestMetric = new IngestMetric(); + + public AbstractBulkRequestPreprocessor(Supplier documentParsingObserver) { + this.documentParsingObserverSupplier = documentParsingObserver; + } + + @Override + public void processBulkRequest( + ExecutorService executorService, + int numberOfActionRequests, + Iterable> actionRequests, + final IntConsumer onDropped, + final BiConsumer onFailure, + final BiConsumer onCompletion, + final String executorName + ) { + assert numberOfActionRequests > 0 : "numberOfActionRequests must be greater than 0 but was [" + numberOfActionRequests + "]"; + + executorService.execute(new AbstractRunnable() { + @Override + public void onFailure(Exception e) { + onCompletion.accept(null, e); + } + + @Override + protected void doRun() { + final Thread originalThread = Thread.currentThread(); + try (var refs = new RefCountingRunnable(() -> onCompletion.accept(originalThread, null))) { + int slot = 0; + for (DocWriteRequest actionRequest : actionRequests) { + IndexRequest indexRequest = TransportBulkAction.getIndexWriteRequest(actionRequest); + if (indexRequest != null) { + processIndexRequest(indexRequest, slot, refs, onDropped, onFailure); + } + slot++; + } + } + } + }); + } + + protected abstract void processIndexRequest( + IndexRequest indexRequest, + int slot, + RefCountingRunnable refs, + IntConsumer onDropped, + BiConsumer onFailure + ); + + /** + * Updates an index request based on the source of an ingest document, guarding against self-references if necessary. + */ + protected static void updateIndexRequestSource(final IndexRequest request, final IngestDocument document) { + boolean ensureNoSelfReferences = document.doNoSelfReferencesCheck(); + // we already check for self references elsewhere (and clear the bit), so this should always be false, + // keeping the check and assert as a guard against extraordinarily surprising circumstances + assert ensureNoSelfReferences == false; + request.source(document.getSource(), request.getContentType(), ensureNoSelfReferences); + } + + /** + * Builds a new ingest document from the passed-in index request. + */ + protected IngestDocument newIngestDocument(final IndexRequest request) { + return new IngestDocument( + request.index(), + request.id(), + request.version(), + request.routing(), + request.versionType(), + request.sourceAsMap(documentParsingObserverSupplier.get()) + ); + } + + protected IngestDocument newIngestDocument(final IndexRequest request, Map sourceMap) { + return new IngestDocument(request.index(), request.id(), request.version(), request.routing(), request.versionType(), sourceMap); + } +} diff --git a/server/src/main/java/org/elasticsearch/ingest/BulkRequestPreprocessor.java b/server/src/main/java/org/elasticsearch/ingest/BulkRequestPreprocessor.java new file mode 100644 index 0000000000000..73adba4726dd5 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/ingest/BulkRequestPreprocessor.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.ingest; + +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.cluster.metadata.Metadata; + +import java.util.concurrent.ExecutorService; +import java.util.function.BiConsumer; +import java.util.function.IntConsumer; + +public interface BulkRequestPreprocessor { + void processBulkRequest( + ExecutorService executorService, + int numberOfActionRequests, + Iterable> actionRequests, + IntConsumer onDropped, + BiConsumer onFailure, + BiConsumer onCompletion, + String executorName + ); + + boolean needsProcessing(DocWriteRequest docWriteRequest, IndexRequest indexRequest, Metadata metadata); + + boolean hasBeenProcessed(IndexRequest indexRequest); + + boolean shouldExecuteOnIngestNode(); +} diff --git a/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java b/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java new file mode 100644 index 0000000000000..21a99365255d9 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java @@ -0,0 +1,193 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.ingest; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.DocWriteRequest; +import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.inference.InferenceAction; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.client.internal.OriginSettingClient; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.mapper.SemanticTextFieldMapper; +import org.elasticsearch.indices.IndicesService; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.plugins.internal.DocumentParsingObserver; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.function.BiConsumer; +import java.util.function.IntConsumer; +import java.util.function.Supplier; + +public class FieldInferenceBulkRequestPreprocessor extends AbstractBulkRequestPreprocessor { + + public static final String SEMANTIC_TEXT_ORIGIN = "semantic_text"; + + private final IndicesService indicesService; + + private final ClusterService clusterService; + + private final OriginSettingClient client; + private final IndexNameExpressionResolver indexNameExpressionResolver; + + public FieldInferenceBulkRequestPreprocessor( + Supplier documentParsingObserver, + ClusterService clusterService, + IndicesService indicesService, + IndexNameExpressionResolver indexNameExpressionResolver, + Client client + ) { + super(documentParsingObserver); + this.indicesService = indicesService; + this.clusterService = clusterService; + this.client = new OriginSettingClient(client, SEMANTIC_TEXT_ORIGIN); + this.indexNameExpressionResolver = indexNameExpressionResolver; + } + + protected void processIndexRequest( + IndexRequest indexRequest, + int slot, + RefCountingRunnable refs, + IntConsumer onDropped, + final BiConsumer onFailure + ) { + assert indexRequest.isFieldInferenceDone() == false; + + IngestDocument ingestDocument = newIngestDocument(indexRequest); + List fieldNames = ingestDocument.getSource().entrySet() + .stream() + .filter(entry -> fieldNeedsInference(indexRequest, entry.getKey(), entry.getValue())) + .map(Map.Entry::getKey) + .toList(); + + // Runs inference sequentially. This makes easier sync and removes the problem of having multiple + // BulkItemResponses for a single bulk request in TransportBulkAction.unwrappingSingleItemBulkResponse + runInferenceForFields(indexRequest, fieldNames, refs.acquire(), slot, ingestDocument, onFailure); + + } + + @Override + public boolean needsProcessing(DocWriteRequest docWriteRequest, IndexRequest indexRequest, Metadata metadata) { + return (indexRequest.isFieldInferenceDone() == false) + && indexRequest.sourceAsMap() + .entrySet() + .stream() + .anyMatch(entry -> fieldNeedsInference(indexRequest, entry.getKey(), entry.getValue())); + } + + @Override + public boolean hasBeenProcessed(IndexRequest indexRequest) { + return indexRequest.isFieldInferenceDone(); + } + + @Override + public boolean shouldExecuteOnIngestNode() { + return false; + } + + private boolean fieldNeedsInference(IndexRequest indexRequest, String fieldName, Object fieldValue) { + + if (fieldValue instanceof String == false) { + return false; + } + + return getModelForField(indexRequest, fieldName) != null; + } + + private String getModelForField(IndexRequest indexRequest, String fieldName) { + // Check all indices related to request have the same model id + String model = null; + try { + Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), indexRequest); + for (Index index : indices) { + String modelForIndex = indicesService.indexService(index).mapperService().mappingLookup().modelForField(fieldName); + if ((modelForIndex == null)) { + return null; + } + if ((model != null) && modelForIndex.equals(model) == false) { + return null; + } + model = modelForIndex; + } + } catch (Exception e) { + // There's a problem retrieving the index + return null; + } + return model; + } + + private void runInferenceForFields( + IndexRequest indexRequest, + List fieldNames, + Releasable ref, + int position, + final IngestDocument ingestDocument, + BiConsumer onFailure + ) { + // We finished processing + if (fieldNames.isEmpty()) { + updateIndexRequestSource(indexRequest, ingestDocument); + indexRequest.isFieldInferenceDone(true); + ref.close(); + return; + } + String fieldName = fieldNames.get(0); + List nextFieldNames = fieldNames.subList(1, fieldNames.size()); + final String fieldValue = ingestDocument.getFieldValue(fieldName, String.class); + if (fieldValue == null) { + // Run inference for next field + runInferenceForFields(indexRequest, nextFieldNames, ref, position, ingestDocument, onFailure); + } + + String modelForField = getModelForField(indexRequest, fieldName); + assert modelForField != null : "Field " + fieldName + " has no model associated in mappings"; + + // TODO Hardcoding task type, how to get that from model ID? + InferenceAction.Request inferenceRequest = new InferenceAction.Request( + TaskType.SPARSE_EMBEDDING, + modelForField, + fieldValue, + Map.of() + ); + + client.execute(InferenceAction.INSTANCE, inferenceRequest, new ActionListener<>() { + @Override + public void onResponse(InferenceAction.Response response) { + // Transform into two subfields, one with the actual text and other with the inference + Map newFieldValue = new HashMap<>(); + newFieldValue.put(SemanticTextFieldMapper.TEXT_SUBFIELD_NAME, fieldValue); + newFieldValue.put( + SemanticTextFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME, + response.getResult().asMap(fieldName).get(fieldName) + ); + ingestDocument.setFieldValue(fieldName, newFieldValue); + + // Run inference for next fields + runInferenceForFields(indexRequest, nextFieldNames, ref, position, ingestDocument, onFailure); + } + + @Override + public void onFailure(Exception e) { + // Wrap exception in an illegal argument exception, as there is a problem with the model or model config + onFailure.accept( + position, + new IllegalArgumentException("Error performing inference for field [" + fieldName + "]: " + e.getMessage(), e) + ); + ref.close(); + } + }); + } +} diff --git a/server/src/main/java/org/elasticsearch/ingest/IngestService.java b/server/src/main/java/org/elasticsearch/ingest/IngestService.java index 3adaab078ad4a..55265b7cb532c 100644 --- a/server/src/main/java/org/elasticsearch/ingest/IngestService.java +++ b/server/src/main/java/org/elasticsearch/ingest/IngestService.java @@ -18,7 +18,6 @@ import org.elasticsearch.action.DocWriteRequest; import org.elasticsearch.action.admin.cluster.node.info.NodeInfo; import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; -import org.elasticsearch.action.bulk.TransportBulkAction; import org.elasticsearch.action.index.IndexRequest; import org.elasticsearch.action.ingest.DeletePipelineRequest; import org.elasticsearch.action.ingest.PutPipelineRequest; @@ -45,7 +44,6 @@ import org.elasticsearch.common.regex.Regex; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.CollectionUtils; -import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.common.xcontent.XContentHelper; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; @@ -94,7 +92,7 @@ /** * Holder class for several ingest related services. */ -public class IngestService implements ClusterStateApplier, ReportingService { +public class IngestService extends AbstractBulkRequestPreprocessor implements ClusterStateApplier, ReportingService { public static final String NOOP_PIPELINE_NAME = "_none"; @@ -105,7 +103,6 @@ public class IngestService implements ClusterStateApplier, ReportingService taskQueue; private final ClusterService clusterService; private final ScriptService scriptService; - private final Supplier documentParsingObserverSupplier; private final Map processorFactories; // Ideally this should be in IngestMetadata class, but we don't have the processor factories around there. // We know of all the processor factories when a node with all its plugin have been initialized. Also some @@ -184,9 +181,9 @@ public IngestService( MatcherWatchdog matcherWatchdog, Supplier documentParsingObserverSupplier ) { + super(documentParsingObserverSupplier); this.clusterService = clusterService; this.scriptService = scriptService; - this.documentParsingObserverSupplier = documentParsingObserverSupplier; this.processorFactories = processorFactories( ingestPlugins, new Processor.Parameters( @@ -540,6 +537,76 @@ private static void collectProcessorMetrics( } } + @Override + public boolean needsProcessing(DocWriteRequest docWriteRequest, IndexRequest indexRequest, Metadata metadata) { + resolvePipelinesAndUpdateIndexRequest(docWriteRequest, indexRequest, metadata); + return hasPipeline(indexRequest); + } + + @Override + public boolean hasBeenProcessed(IndexRequest indexRequest) { + return hasPipeline(indexRequest) && indexRequest.isPipelineResolved(); + } + + @Override + public boolean shouldExecuteOnIngestNode() { + return true; + } + + @Override + protected void processIndexRequest( + IndexRequest indexRequest, + int slot, + RefCountingRunnable refs, + IntConsumer onDropped, + final BiConsumer onFailure + ) { + assert indexRequest.isPipelineResolved(); + + IngestService.PipelineIterator pipelines = getAndResetPipelines(indexRequest); + if (pipelines.hasNext() == false) { + return; + } + + // start the stopwatch and acquire a ref to indicate that we're working on this document + final long startTimeInNanos = System.nanoTime(); + ingestMetric.preIngest(); + final Releasable ref = refs.acquire(); + // the document listener gives us three-way logic: a document can fail processing (1), or it can + // be successfully processed. a successfully processed document can be kept (2) or dropped (3). + final ActionListener documentListener = ActionListener.runAfter(new ActionListener<>() { + @Override + public void onResponse(Boolean kept) { + assert kept != null; + if (kept == false) { + onDropped.accept(slot); + } + } + + @Override + public void onFailure(Exception e) { + ingestMetric.ingestFailed(); + onFailure.accept(slot, e); + } + }, () -> { + // regardless of success or failure, we always stop the ingest "stopwatch" and release the ref to indicate + // that we're finished with this document + final long ingestTimeInNanos = System.nanoTime() - startTimeInNanos; + ingestMetric.postIngest(ingestTimeInNanos); + ref.close(); + }); + DocumentParsingObserver documentParsingObserver = documentParsingObserverSupplier.get(); + + IngestDocument ingestDocument = newIngestDocument(indexRequest); + + executePipelines(pipelines, indexRequest, ingestDocument, documentListener); + indexRequest.setPipelinesHaveRun(); + + assert indexRequest.index() != null; + documentParsingObserver.setIndexName(indexRequest.index()); + documentParsingObserver.close(); + } + /** * Used in this class and externally by the {@link org.elasticsearch.action.ingest.ReservedPipelineAction} */ @@ -652,87 +719,6 @@ void validatePipeline(Map ingestInfos, String pipelin ExceptionsHelper.rethrowAndSuppress(exceptions); } - public void executeBulkRequest( - final int numberOfActionRequests, - final Iterable> actionRequests, - final IntConsumer onDropped, - final BiConsumer onFailure, - final BiConsumer onCompletion, - final String executorName - ) { - assert numberOfActionRequests > 0 : "numberOfActionRequests must be greater than 0 but was [" + numberOfActionRequests + "]"; - - threadPool.executor(executorName).execute(new AbstractRunnable() { - - @Override - public void onFailure(Exception e) { - onCompletion.accept(null, e); - } - - @Override - protected void doRun() { - final Thread originalThread = Thread.currentThread(); - try (var refs = new RefCountingRunnable(() -> onCompletion.accept(originalThread, null))) { - int i = 0; - for (DocWriteRequest actionRequest : actionRequests) { - IndexRequest indexRequest = TransportBulkAction.getIndexWriteRequest(actionRequest); - if (indexRequest == null) { - i++; - continue; - } - - PipelineIterator pipelines = getAndResetPipelines(indexRequest); - if (pipelines.hasNext() == false) { - i++; - continue; - } - - // start the stopwatch and acquire a ref to indicate that we're working on this document - final long startTimeInNanos = System.nanoTime(); - totalMetrics.preIngest(); - final int slot = i; - final Releasable ref = refs.acquire(); - // the document listener gives us three-way logic: a document can fail processing (1), or it can - // be successfully processed. a successfully processed document can be kept (2) or dropped (3). - final ActionListener documentListener = ActionListener.runAfter(new ActionListener<>() { - @Override - public void onResponse(Boolean kept) { - assert kept != null; - if (kept == false) { - onDropped.accept(slot); - } - } - - @Override - public void onFailure(Exception e) { - totalMetrics.ingestFailed(); - onFailure.accept(slot, e); - } - }, () -> { - // regardless of success or failure, we always stop the ingest "stopwatch" and release the ref to indicate - // that we're finished with this document - final long ingestTimeInNanos = System.nanoTime() - startTimeInNanos; - totalMetrics.postIngest(ingestTimeInNanos); - ref.close(); - }); - DocumentParsingObserver documentParsingObserver = documentParsingObserverSupplier.get(); - - IngestDocument ingestDocument = newIngestDocument(indexRequest, documentParsingObserver); - - executePipelines(pipelines, indexRequest, ingestDocument, documentListener); - indexRequest.setPipelinesHaveRun(); - - assert actionRequest.index() != null; - documentParsingObserver.setIndexName(actionRequest.index()); - documentParsingObserver.close(); - - i++; - } - } - } - }); - } - /** * Returns the pipelines of the request, and updates the request so that it no longer references * any pipelines (both the default and final pipeline are set to the noop pipeline). @@ -1052,17 +1038,6 @@ private static void updateIndexRequestMetadata(final IndexRequest request, final } } - /** - * Updates an index request based on the source of an ingest document, guarding against self-references if necessary. - */ - private static void updateIndexRequestSource(final IndexRequest request, final IngestDocument document) { - boolean ensureNoSelfReferences = document.doNoSelfReferencesCheck(); - // we already check for self references elsewhere (and clear the bit), so this should always be false, - // keeping the check and assert as a guard against extraordinarily surprising circumstances - assert ensureNoSelfReferences == false; - request.source(document.getSource(), request.getContentType(), ensureNoSelfReferences); - } - /** * Grab the @timestamp and store it on the index request so that TSDB can use it without needing to parse * the source for this document. diff --git a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java index adcb9d29861c0..9fe0dd3939c36 100644 --- a/server/src/main/java/org/elasticsearch/node/NodeConstruction.java +++ b/server/src/main/java/org/elasticsearch/node/NodeConstruction.java @@ -122,6 +122,7 @@ import org.elasticsearch.indices.recovery.plan.RecoveryPlannerService; import org.elasticsearch.indices.recovery.plan.ShardSnapshotsService; import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.ingest.FieldInferenceBulkRequestPreprocessor; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.MonitorService; import org.elasticsearch.monitor.fs.FsHealthService; @@ -703,6 +704,13 @@ private void construct(Environment initialEnvironment, NodeServiceProvider servi searchModule.getRequestCacheKeyDifferentiator(), documentParsingObserverSupplier ); + final FieldInferenceBulkRequestPreprocessor fieldInferenceBulkRequestPreprocessor = new FieldInferenceBulkRequestPreprocessor( + documentParsingObserverSupplier, + clusterService, + indicesService, + clusterModule.getIndexNameExpressionResolver(), + client + ); final var parameters = new IndexSettingProvider.Parameters(indicesService::createIndexMapperServiceForValidation); IndexSettingProviders indexSettingProviders = new IndexSettingProviders( @@ -1075,6 +1083,7 @@ record PluginServiceInstances( b.bind(ScriptService.class).toInstance(scriptService); b.bind(AnalysisRegistry.class).toInstance(analysisModule.getAnalysisRegistry()); b.bind(IngestService.class).toInstance(ingestService); + b.bind(FieldInferenceBulkRequestPreprocessor.class).toInstance(fieldInferenceBulkRequestPreprocessor); b.bind(IndexingPressure.class).toInstance(indexingLimits); b.bind(UsageService.class).toInstance(usageService); b.bind(AggregationUsageService.class).toInstance(searchModule.getValuesSourceRegistry().getUsageService()); diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java index e097b83fb9d35..5649655480660 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIndicesThatCannotBeCreatedTests.java @@ -31,6 +31,8 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.index.VersionType; import org.elasticsearch.indices.EmptySystemIndices; +import org.elasticsearch.ingest.FieldInferenceBulkRequestPreprocessor; +import org.elasticsearch.ingest.IngestService; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.MockUtils; @@ -125,7 +127,8 @@ public boolean hasIndexAbstraction(String indexAbstraction, ClusterState state) threadPool, transportService, clusterService, - null, + mock(IngestService.class), + mock(FieldInferenceBulkRequestPreprocessor.class), null, mock(ActionFilters.class), indexNameExpressionResolver, diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java index 0168eb0488a5b..7c8c5478ee652 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionIngestTests.java @@ -42,6 +42,7 @@ import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.TestIndexNameExpressionResolver; +import org.elasticsearch.ingest.FieldInferenceBulkRequestPreprocessor; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; @@ -93,6 +94,7 @@ public class TransportBulkActionIngestTests extends ESTestCase { TransportService transportService; ClusterService clusterService; IngestService ingestService; + FieldInferenceBulkRequestPreprocessor fieldInferenceBulkRequestPreprocessor; ThreadPool threadPool; /** Arguments to callbacks we want to capture, but which require generics, so we must use @Captor */ @@ -114,6 +116,9 @@ public class TransportBulkActionIngestTests extends ESTestCase { /** True if the next call to the index action should act as an ingest node */ boolean localIngest; + /** True if IngestService should process pipelines in the next request */ + boolean shouldProcessPipelines; + /** The nodes that forwarded index requests should be cycled through. */ DiscoveryNodes nodes; DiscoveryNode remoteNode1; @@ -133,6 +138,7 @@ class TestTransportBulkAction extends TransportBulkAction { transportService, clusterService, ingestService, + fieldInferenceBulkRequestPreprocessor, null, new ActionFilters(Collections.emptySet()), TestIndexNameExpressionResolver.newInstance(), @@ -229,6 +235,9 @@ public void setupAction() { }).when(clusterService).addStateApplier(any(ClusterStateApplier.class)); // setup the mocked ingest service for capturing calls ingestService = mock(IngestService.class); + when(ingestService.shouldExecuteOnIngestNode()).thenReturn(true); + when(ingestService.needsProcessing(any(), any(), any())).thenAnswer(stub -> shouldProcessPipelines); + fieldInferenceBulkRequestPreprocessor = mock(FieldInferenceBulkRequestPreprocessor.class); action = new TestTransportBulkAction(); singleItemBulkWriteAction = new TestSingleItemBulkWriteAction(action); reset(transportService); // call on construction of action @@ -279,7 +288,8 @@ public void testIngestLocal() throws Exception { assertFalse(action.isExecuted); // haven't executed yet assertFalse(responseCalled.get()); assertFalse(failureCalled.get()); - verify(ingestService).executeBulkRequest( + verify(ingestService).processBulkRequest( + any(), eq(bulkRequest.numberOfActions()), bulkDocsItr.capture(), any(), @@ -321,7 +331,8 @@ public void testSingleItemBulkActionIngestLocal() throws Exception { assertFalse(action.isExecuted); // haven't executed yet assertFalse(responseCalled.get()); assertFalse(failureCalled.get()); - verify(ingestService).executeBulkRequest( + verify(ingestService).processBulkRequest( + any(), eq(1), bulkDocsItr.capture(), any(), @@ -367,7 +378,8 @@ public void testIngestSystemLocal() throws Exception { assertFalse(action.isExecuted); // haven't executed yet assertFalse(responseCalled.get()); assertFalse(failureCalled.get()); - verify(ingestService).executeBulkRequest( + verify(ingestService).processBulkRequest( + any(), eq(bulkRequest.numberOfActions()), bulkDocsItr.capture(), any(), @@ -404,7 +416,7 @@ public void testIngestForward() throws Exception { ActionTestUtils.execute(action, null, bulkRequest, listener); // should not have executed ingest locally - verify(ingestService, never()).executeBulkRequest(anyInt(), any(), any(), any(), any(), any()); + verify(ingestService, never()).processBulkRequest(any(), anyInt(), any(), any(), any(), any(), any()); // but instead should have sent to a remote node with the transport service ArgumentCaptor node = ArgumentCaptor.forClass(DiscoveryNode.class); verify(transportService).sendRequest(node.capture(), eq(BulkAction.NAME), any(), remoteResponseHandler.capture()); @@ -444,7 +456,7 @@ public void testSingleItemBulkActionIngestForward() throws Exception { ActionTestUtils.execute(singleItemBulkWriteAction, null, indexRequest, listener); // should not have executed ingest locally - verify(ingestService, never()).executeBulkRequest(anyInt(), any(), any(), any(), any(), any()); + verify(ingestService, never()).processBulkRequest(any(), anyInt(), any(), any(), any(), any(), any()); // but instead should have sent to a remote node with the transport service ArgumentCaptor node = ArgumentCaptor.forClass(DiscoveryNode.class); verify(transportService).sendRequest(node.capture(), eq(BulkAction.NAME), any(), remoteResponseHandler.capture()); @@ -524,7 +536,8 @@ private void validatePipelineWithBulkUpsert(@Nullable String indexRequestIndexNa assertFalse(action.isExecuted); // haven't executed yet assertFalse(responseCalled.get()); assertFalse(failureCalled.get()); - verify(ingestService).executeBulkRequest( + verify(ingestService).processBulkRequest( + any(), eq(bulkRequest.numberOfActions()), bulkDocsItr.capture(), any(), @@ -572,7 +585,8 @@ public void testDoExecuteCalledTwiceCorrectly() throws Exception { assertFalse(action.indexCreated); // no index yet assertFalse(responseCalled.get()); assertFalse(failureCalled.get()); - verify(ingestService).executeBulkRequest( + verify(ingestService).processBulkRequest( + any(), eq(1), bulkDocsItr.capture(), any(), @@ -666,7 +680,8 @@ public void testFindDefaultPipelineFromTemplateMatch() { ); assertEquals("pipeline2", indexRequest.getPipeline()); - verify(ingestService).executeBulkRequest( + verify(ingestService).processBulkRequest( + any(), eq(1), bulkDocsItr.capture(), any(), @@ -710,7 +725,8 @@ public void testFindDefaultPipelineFromV2TemplateMatch() { ); assertEquals("pipeline2", indexRequest.getPipeline()); - verify(ingestService).executeBulkRequest( + verify(ingestService).processBulkRequest( + any(), eq(1), bulkDocsItr.capture(), any(), @@ -737,7 +753,8 @@ public void testIngestCallbackExceptionHandled() throws Exception { assertFalse(action.isExecuted); // haven't executed yet assertFalse(responseCalled.get()); assertFalse(failureCalled.get()); - verify(ingestService).executeBulkRequest( + verify(ingestService).processBulkRequest( + any(), eq(bulkRequest.numberOfActions()), bulkDocsItr.capture(), any(), @@ -760,6 +777,10 @@ private void validateDefaultPipeline(IndexRequest indexRequest) { AtomicBoolean responseCalled = new AtomicBoolean(false); AtomicBoolean failureCalled = new AtomicBoolean(false); assertNull(indexRequest.getPipeline()); + when(ingestService.needsProcessing(any(), eq(indexRequest), any())).thenAnswer(i -> { + IngestService.resolvePipelinesAndUpdateIndexRequest(i.getArgument(0), i.getArgument(1), i.getArgument(2)); + return true; + }); ActionTestUtils.execute( singleItemBulkWriteAction, null, @@ -774,7 +795,8 @@ private void validateDefaultPipeline(IndexRequest indexRequest) { assertFalse(action.isExecuted); // haven't executed yet assertFalse(responseCalled.get()); assertFalse(failureCalled.get()); - verify(ingestService).executeBulkRequest( + verify(ingestService).processBulkRequest( + any(), eq(1), bulkDocsItr.capture(), any(), @@ -788,6 +810,7 @@ private void validateDefaultPipeline(IndexRequest indexRequest) { // now check success indexRequest.setPipeline(IngestService.NOOP_PIPELINE_NAME); // this is done by the real pipeline execution service when processing + when(ingestService.needsProcessing(any(), eq(indexRequest), any())).thenReturn(false); completionHandler.getValue().accept(DUMMY_WRITE_THREAD, null); assertTrue(action.isExecuted); assertFalse(responseCalled.get()); // listener would only be called by real index action, not our mocked one diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java index e2c71f3b20084..06619efd62785 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTests.java @@ -40,6 +40,8 @@ import org.elasticsearch.indices.EmptySystemIndices; import org.elasticsearch.indices.SystemIndexDescriptorUtils; import org.elasticsearch.indices.SystemIndices; +import org.elasticsearch.ingest.FieldInferenceBulkRequestPreprocessor; +import org.elasticsearch.ingest.IngestService; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.VersionUtils; import org.elasticsearch.test.index.IndexVersionUtils; @@ -61,6 +63,7 @@ import static org.elasticsearch.test.ClusterServiceUtils.createClusterService; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; public class TransportBulkActionTests extends ESTestCase { @@ -82,7 +85,8 @@ class TestTransportBulkAction extends TransportBulkAction { TransportBulkActionTests.this.threadPool, transportService, clusterService, - null, + mock(IngestService.class), + mock(FieldInferenceBulkRequestPreprocessor.class), null, new ActionFilters(Collections.emptySet()), new Resolver(), diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java index d4c5fc09e821f..9a01eb6348753 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionTookTests.java @@ -32,6 +32,8 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.IndexingPressure; import org.elasticsearch.indices.EmptySystemIndices; +import org.elasticsearch.ingest.FieldInferenceBulkRequestPreprocessor; +import org.elasticsearch.ingest.IngestService; import org.elasticsearch.tasks.Task; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.VersionUtils; @@ -58,6 +60,7 @@ import static org.elasticsearch.test.StreamsUtils.copyToStringFromClasspath; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.mockito.Mockito.mock; public class TransportBulkActionTookTests extends ESTestCase { @@ -246,7 +249,8 @@ static class TestTransportBulkAction extends TransportBulkAction { threadPool, transportService, clusterService, - null, + mock(IngestService.class), + mock(FieldInferenceBulkRequestPreprocessor.class), client, actionFilters, indexNameExpressionResolver, diff --git a/server/src/test/java/org/elasticsearch/ingest/IngestServiceTests.java b/server/src/test/java/org/elasticsearch/ingest/IngestServiceTests.java index 3b114cf0a618e..3514176f92fd9 100644 --- a/server/src/test/java/org/elasticsearch/ingest/IngestServiceTests.java +++ b/server/src/test/java/org/elasticsearch/ingest/IngestServiceTests.java @@ -207,7 +207,15 @@ public void testExecuteIndexPipelineDoesNotExist() { @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); assertTrue(failure.get()); verify(completionHandler, times(1)).accept(Thread.currentThread(), null); @@ -1106,7 +1114,8 @@ public String getType() { @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest( + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, bulkRequest.numberOfActions(), bulkRequest.requests(), indexReq -> {}, @@ -1149,7 +1158,8 @@ public void testExecuteBulkPipelineDoesNotExist() { BiConsumer failureHandler = mock(BiConsumer.class); @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest( + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, bulkRequest.numberOfActions(), bulkRequest.requests(), indexReq -> {}, @@ -1213,7 +1223,8 @@ public void close() { BiConsumer failureHandler = mock(BiConsumer.class); @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest( + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, bulkRequest.numberOfActions(), bulkRequest.requests(), indexReq -> {}, @@ -1246,7 +1257,15 @@ public void testExecuteSuccess() { final BiConsumer failureHandler = mock(BiConsumer.class); @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); verify(failureHandler, never()).accept(any(), any()); verify(completionHandler, times(1)).accept(Thread.currentThread(), null); } @@ -1279,7 +1298,15 @@ public void testDynamicTemplates() throws Exception { CountDownLatch latch = new CountDownLatch(1); final BiConsumer failureHandler = (v, e) -> { throw new AssertionError("must never fail", e); }; final BiConsumer completionHandler = (t, e) -> latch.countDown(); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); latch.await(); assertThat(indexRequest.getDynamicTemplates(), equalTo(Map.of("foo", "bar", "foo.bar", "baz"))); } @@ -1300,7 +1327,15 @@ public void testExecuteEmptyPipeline() throws Exception { final BiConsumer failureHandler = mock(BiConsumer.class); @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); verify(failureHandler, never()).accept(any(), any()); verify(completionHandler, times(1)).accept(Thread.currentThread(), null); } @@ -1354,7 +1389,15 @@ public void testExecutePropagateAllMetadataUpdates() throws Exception { final BiConsumer failureHandler = mock(BiConsumer.class); @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); verify(processor).execute(any(), any()); verify(failureHandler, never()).accept(any(), any()); verify(completionHandler, times(1)).accept(Thread.currentThread(), null); @@ -1403,7 +1446,15 @@ public void testExecuteFailure() throws Exception { final BiConsumer failureHandler = mock(BiConsumer.class); @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); verify(processor).execute(eqIndexTypeId(indexRequest.version(), indexRequest.versionType(), Map.of()), any()); verify(failureHandler, times(1)).accept(eq(0), any(RuntimeException.class)); verify(completionHandler, times(1)).accept(Thread.currentThread(), null); @@ -1452,7 +1503,15 @@ public void testExecuteSuccessWithOnFailure() throws Exception { final BiConsumer failureHandler = mock(BiConsumer.class); @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); verify(failureHandler, never()).accept(eq(0), any(IngestProcessorException.class)); verify(completionHandler, times(1)).accept(Thread.currentThread(), null); } @@ -1495,7 +1554,15 @@ public void testExecuteFailureWithNestedOnFailure() throws Exception { final BiConsumer failureHandler = mock(BiConsumer.class); @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); verify(processor).execute(eqIndexTypeId(indexRequest.version(), indexRequest.versionType(), Map.of()), any()); verify(failureHandler, times(1)).accept(eq(0), any(RuntimeException.class)); verify(completionHandler, times(1)).accept(Thread.currentThread(), null); @@ -1549,7 +1616,8 @@ public void testBulkRequestExecutionWithFailures() throws Exception { BiConsumer requestItemErrorHandler = mock(BiConsumer.class); @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest( + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, numRequest, bulkRequest.requests(), indexReq -> {}, @@ -1607,7 +1675,8 @@ public void testBulkRequestExecution() throws Exception { BiConsumer requestItemErrorHandler = mock(BiConsumer.class); @SuppressWarnings("unchecked") final BiConsumer completionHandler = mock(BiConsumer.class); - ingestService.executeBulkRequest( + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, numRequest, bulkRequest.requests(), indexReq -> {}, @@ -1720,7 +1789,15 @@ public String execute() { final IndexRequest indexRequest = new IndexRequest("_index"); indexRequest.setPipeline("_id1").setFinalPipeline("_id2"); indexRequest.source(randomAlphaOfLength(10), randomAlphaOfLength(10)); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, (integer, e) -> {}, (thread, e) -> {}, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + (integer, e) -> {}, + (thread, e) -> {}, + Names.WRITE + ); { final IngestStats ingestStats = ingestService.stats(); @@ -1791,7 +1868,15 @@ public void testStats() throws Exception { final IndexRequest indexRequest = new IndexRequest("_index"); indexRequest.setPipeline("_id1").setFinalPipeline("_none"); indexRequest.source(randomAlphaOfLength(10), randomAlphaOfLength(10)); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); final IngestStats afterFirstRequestStats = ingestService.stats(); assertThat(afterFirstRequestStats.pipelineStats().size(), equalTo(2)); @@ -1808,7 +1893,15 @@ public void testStats() throws Exception { assertProcessorStats(0, afterFirstRequestStats, "_id2", 0, 0, 0); indexRequest.setPipeline("_id2"); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); final IngestStats afterSecondRequestStats = ingestService.stats(); assertThat(afterSecondRequestStats.pipelineStats().size(), equalTo(2)); // total @@ -1830,7 +1923,15 @@ public void testStats() throws Exception { clusterState = executePut(putRequest, clusterState); ingestService.applyClusterState(new ClusterChangedEvent("", clusterState, previousClusterState)); indexRequest.setPipeline("_id1"); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); final IngestStats afterThirdRequestStats = ingestService.stats(); assertThat(afterThirdRequestStats.pipelineStats().size(), equalTo(2)); // total @@ -1853,7 +1954,15 @@ public void testStats() throws Exception { clusterState = executePut(putRequest, clusterState); ingestService.applyClusterState(new ClusterChangedEvent("", clusterState, previousClusterState)); indexRequest.setPipeline("_id1"); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, failureHandler, completionHandler, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + failureHandler, + completionHandler, + Names.WRITE + ); final IngestStats afterForthRequestStats = ingestService.stats(); assertThat(afterForthRequestStats.pipelineStats().size(), equalTo(2)); // total @@ -1941,7 +2050,8 @@ public String getDescription() { final BiConsumer completionHandler = mock(BiConsumer.class); @SuppressWarnings("unchecked") final IntConsumer dropHandler = mock(IntConsumer.class); - ingestService.executeBulkRequest( + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, bulkRequest.numberOfActions(), bulkRequest.requests(), dropHandler, @@ -2029,7 +2139,15 @@ public void testCBORParsing() throws Exception { .setPipeline("_id") .setFinalPipeline("_none"); - ingestService.executeBulkRequest(1, List.of(indexRequest), indexReq -> {}, (integer, e) -> {}, (thread, e) -> {}, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 1, + List.of(indexRequest), + indexReq -> {}, + (integer, e) -> {}, + (thread, e) -> {}, + Names.WRITE + ); } assertThat(reference.get(), is(instanceOf(byte[].class))); @@ -2100,7 +2218,15 @@ public void testSetsRawTimestamp() { bulkRequest.add(indexRequest6); bulkRequest.add(indexRequest7); bulkRequest.add(indexRequest8); - ingestService.executeBulkRequest(8, bulkRequest.requests(), indexReq -> {}, (integer, e) -> {}, (thread, e) -> {}, Names.WRITE); + ingestService.processBulkRequest( + EsExecutors.DIRECT_EXECUTOR_SERVICE, + 8, + bulkRequest.requests(), + indexReq -> {}, + (integer, e) -> {}, + (thread, e) -> {}, + Names.WRITE + ); assertThat(indexRequest1.getRawTimestamp(), nullValue()); assertThat(indexRequest2.getRawTimestamp(), nullValue()); diff --git a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java index 1b5ff3f39be22..710ecebf061c2 100644 --- a/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java +++ b/server/src/test/java/org/elasticsearch/snapshots/SnapshotResiliencyTests.java @@ -160,6 +160,7 @@ import org.elasticsearch.indices.recovery.RecoverySettings; import org.elasticsearch.indices.recovery.SnapshotFilesProvider; import org.elasticsearch.indices.recovery.plan.PeerOnlyRecoveryPlannerService; +import org.elasticsearch.ingest.FieldInferenceBulkRequestPreprocessor; import org.elasticsearch.ingest.IngestService; import org.elasticsearch.monitor.StatusInfo; import org.elasticsearch.node.ResponseCollectorService; @@ -1936,6 +1937,13 @@ protected void assertSnapshotOrGenericThread() { null, () -> DocumentParsingObserver.EMPTY_INSTANCE ), + new FieldInferenceBulkRequestPreprocessor( + () -> DocumentParsingObserver.EMPTY_INSTANCE, + clusterService, + indicesService, + indexNameExpressionResolver, + client + ), client, actionFilters, indexNameExpressionResolver, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/user/InternalUsers.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/user/InternalUsers.java index 652d6815eea46..1f5912f09d077 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/user/InternalUsers.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/user/InternalUsers.java @@ -15,8 +15,10 @@ import org.elasticsearch.action.admin.indices.settings.put.UpdateSettingsAction; import org.elasticsearch.action.admin.indices.stats.IndicesStatsAction; import org.elasticsearch.action.downsample.DownsampleAction; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.core.XPackPlugin; import org.elasticsearch.xpack.core.security.authz.RoleDescriptor; +import org.elasticsearch.xpack.core.security.authz.privilege.IndexPrivilege; import org.elasticsearch.xpack.core.security.support.MetadataUtils; import java.util.Collection; @@ -189,7 +191,26 @@ public class InternalUsers { null, new RoleDescriptor.IndicesPrivileges[] { RoleDescriptor.IndicesPrivileges.builder().indices(".synonyms*").privileges("all").allowRestrictedIndices(true).build(), - RoleDescriptor.IndicesPrivileges.builder().indices("*").privileges(ReloadAnalyzerAction.NAME).build(), }, + RoleDescriptor.IndicesPrivileges.builder().indices("*").privileges(ReloadAnalyzerAction.NAME).build()}, + null, + null, + null, + MetadataUtils.DEFAULT_RESERVED_METADATA, + Map.of() + ) + ); + + public static final InternalUser SEMANTIC_TEXT_USER = new InternalUser( + UsernamesField.SEMANTIC_TEXT_USER_NAME, + new RoleDescriptor( + UsernamesField.SEMANTIC_TEXT_ROLE_NAME, + new String[] { "monitor" }, + new RoleDescriptor.IndicesPrivileges[] { + RoleDescriptor.IndicesPrivileges.builder() + .indices(".inference*", ".secrets-inference*") + .privileges("read") + .allowRestrictedIndices(true) + .build()}, null, null, null, @@ -211,7 +232,8 @@ public class InternalUsers { ASYNC_SEARCH_USER, STORAGE_USER, DATA_STREAM_LIFECYCLE_USER, - SYNONYMS_USER + SYNONYMS_USER, + SEMANTIC_TEXT_USER ).collect(Collectors.toUnmodifiableMap(InternalUser::principal, Function.identity())); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/user/UsernamesField.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/user/UsernamesField.java index 821d222bb930c..57f6d50327496 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/user/UsernamesField.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/user/UsernamesField.java @@ -34,6 +34,8 @@ public final class UsernamesField { public static final String STORAGE_ROLE_NAME = "_storage"; public static final String SYNONYMS_USER_NAME = "_synonyms"; public static final String SYNONYMS_ROLE_NAME = "_synonyms"; + public static final String SEMANTIC_TEXT_USER_NAME = "_semantic_text"; + public static final String SEMANTIC_TEXT_ROLE_NAME = "_semantic_text"; public static final String REMOTE_MONITORING_NAME = "remote_monitoring_user"; public static final String REMOTE_MONITORING_COLLECTION_ROLE = "remote_monitoring_collector"; diff --git a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java index 0da0340084cba..3b539bf653fe6 100644 --- a/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java +++ b/x-pack/plugin/inference/src/internalClusterTest/java/org/elasticsearch/xpack/inference/integration/MockInferenceServiceIT.java @@ -24,7 +24,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.action.GetInferenceModelAction; -import org.elasticsearch.xpack.inference.action.InferenceAction; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.junit.Before; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 2f0f95cf8a911..8da72b4793bf3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -32,7 +32,7 @@ import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.inference.action.DeleteInferenceModelAction; import org.elasticsearch.xpack.inference.action.GetInferenceModelAction; -import org.elasticsearch.xpack.inference.action.InferenceAction; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.xpack.inference.action.PutInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportDeleteInferenceModelAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceModelAction; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 29909163d7b3b..81daedc8d4f18 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -9,6 +9,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.HandledTransportAction; import org.elasticsearch.common.inject.Inject; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java index 9d7a0d331b2b3..272544e6c1c06 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestInferenceAction.java @@ -11,7 +11,7 @@ import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestRequest; import org.elasticsearch.rest.action.RestToXContentListener; -import org.elasticsearch.xpack.inference.action.InferenceAction; +import org.elasticsearch.action.inference.InferenceAction; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java index 3e1bea0051656..33103eaf277d6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionRequestTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.action; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.core.Tuple; import org.elasticsearch.inference.TaskType; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java index 795923e56c6bb..122bac29f43cd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/InferenceActionResponseTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.inference.action; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.test.AbstractWireSerializingTestCase; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/AuthorizationUtils.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/AuthorizationUtils.java index d93ee6ad36c67..3b9b7b68237a0 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/AuthorizationUtils.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/AuthorizationUtils.java @@ -8,6 +8,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.ingest.FieldInferenceBulkRequestPreprocessor; import org.elasticsearch.xpack.core.ClientHelper; import org.elasticsearch.xpack.core.security.SecurityContext; import org.elasticsearch.xpack.core.security.authc.Authentication; @@ -159,6 +160,9 @@ public static void switchUserBasedOnActionOriginAndExecute( case SYNONYMS_ORIGIN: securityContext.executeAsInternalUser(InternalUsers.SYNONYMS_USER, version, consumer); break; + case FieldInferenceBulkRequestPreprocessor.SEMANTIC_TEXT_ORIGIN: + securityContext.executeAsInternalUser(InternalUsers.SEMANTIC_TEXT_USER, version, consumer); + break; default: assert false : "action.origin [" + actionOrigin + "] is unknown!"; throw new IllegalStateException("action.origin [" + actionOrigin + "] should always be a known value");