PoC DO NOT MERGE - Store semantic_text mapping info #9

server/src/main/java/module-info.java (1 addition, 0 deletions)

@@ -134,6 +134,7 @@
exports org.elasticsearch.action.fieldcaps;
exports org.elasticsearch.action.get;
exports org.elasticsearch.action.index;
exports org.elasticsearch.action.inference;
exports org.elasticsearch.action.ingest;
exports org.elasticsearch.action.resync;
exports org.elasticsearch.action.search;

server/src/main/java/org/elasticsearch/TransportVersions.java

@@ -146,7 +146,7 @@ static TransportVersion def(int id) {
public static final TransportVersion TOO_MANY_SCROLL_CONTEXTS_EXCEPTION_ADDED = def(8_521_00_0);
public static final TransportVersion UNCONTENDED_REGISTER_ANALYSIS_ADDED = def(8_522_00_0);
public static final TransportVersion TRANSFORM_GET_CHECKPOINT_TIMEOUT_ADDED = def(8_523_00_0);

public static final TransportVersion SEMANTIC_TEXT_FIELD_ADDED = def(8_524_00_0);
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _

server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java

@@ -48,7 +48,6 @@
import org.elasticsearch.common.inject.Inject;
import org.elasticsearch.common.util.concurrent.AtomicArray;
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.core.Assertions;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.Index;
@@ -59,6 +58,8 @@
import org.elasticsearch.index.shard.ShardId;
import org.elasticsearch.indices.IndexClosedException;
import org.elasticsearch.indices.SystemIndices;
import org.elasticsearch.ingest.BulkRequestPreprocessor;
import org.elasticsearch.ingest.FieldInferenceBulkRequestPreprocessor;
import org.elasticsearch.ingest.IngestService;
import org.elasticsearch.node.NodeClosedException;
import org.elasticsearch.tasks.Task;
@@ -95,7 +96,7 @@ public class TransportBulkAction extends HandledTransportAction<BulkRequest, Bul

private final ThreadPool threadPool;
private final ClusterService clusterService;
private final IngestService ingestService;
private final List<BulkRequestPreprocessor> bulkRequestPreprocessors;
private final LongSupplier relativeTimeProvider;
private final IngestActionForwarder ingestForwarder;
private final NodeClient client;
@@ -110,6 +111,7 @@ public TransportBulkAction(
TransportService transportService,
ClusterService clusterService,
IngestService ingestService,
FieldInferenceBulkRequestPreprocessor fieldInferenceBulkRequestPreprocessor,
NodeClient client,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver,
@@ -121,6 +123,7 @@ public TransportBulkAction(
transportService,
clusterService,
ingestService,
fieldInferenceBulkRequestPreprocessor,
client,
actionFilters,
indexNameExpressionResolver,
@@ -135,6 +138,7 @@ public TransportBulkAction(
TransportService transportService,
ClusterService clusterService,
IngestService ingestService,
FieldInferenceBulkRequestPreprocessor fieldInferenceBulkRequestPreprocessor,
NodeClient client,
ActionFilters actionFilters,
IndexNameExpressionResolver indexNameExpressionResolver,
@@ -146,7 +150,7 @@ public TransportBulkAction(
Objects.requireNonNull(relativeTimeProvider);
this.threadPool = threadPool;
this.clusterService = clusterService;
this.ingestService = ingestService;
this.bulkRequestPreprocessors = List.of(ingestService, fieldInferenceBulkRequestPreprocessor);
this.relativeTimeProvider = relativeTimeProvider;
this.ingestForwarder = new IngestActionForwarder(transportService);
this.client = client;
@@ -270,45 +274,9 @@ protected void doInternalExecute(Task task, BulkRequest bulkRequest, String exec
final long startTime = relativeTime();
final AtomicArray<BulkItemResponse> responses = new AtomicArray<>(bulkRequest.requests.size());

boolean hasIndexRequestsWithPipelines = false;
final Metadata metadata = clusterService.state().getMetadata();
final Version minNodeVersion = clusterService.state().getNodes().getMinNodeVersion();
for (DocWriteRequest<?> actionRequest : bulkRequest.requests) {
IndexRequest indexRequest = getIndexWriteRequest(actionRequest);
if (indexRequest != null) {
IngestService.resolvePipelinesAndUpdateIndexRequest(actionRequest, indexRequest, metadata);
hasIndexRequestsWithPipelines |= IngestService.hasPipeline(indexRequest);
}

if (actionRequest instanceof IndexRequest ir) {
ir.checkAutoIdWithOpTypeCreateSupportedByVersion(minNodeVersion);
if (ir.getAutoGeneratedTimestamp() != IndexRequest.UNSET_AUTO_GENERATED_TIMESTAMP) {
throw new IllegalArgumentException("autoGeneratedTimestamp should not be set externally");
}
}
}

if (hasIndexRequestsWithPipelines) {
// this method (doExecute) will be called again, but with the bulk requests updated from the ingest node processing but
// also with IngestService.NOOP_PIPELINE_NAME on each request. This ensures that this on the second time through this method,
// this path is never taken.
ActionListener.run(listener, l -> {
if (Assertions.ENABLED) {
final boolean arePipelinesResolved = bulkRequest.requests()
.stream()
.map(TransportBulkAction::getIndexWriteRequest)
.filter(Objects::nonNull)
.allMatch(IndexRequest::isPipelineResolved);
assert arePipelinesResolved : bulkRequest;
}
if (clusterService.localNode().isIngestNode()) {
processBulkIndexIngestRequest(task, bulkRequest, executorName, l);
} else {
ingestForwarder.forwardIngestRequest(BulkAction.INSTANCE, bulkRequest, l);
}
});
return;
}
// Preprocess the bulk request using the registered preprocessors. If any preprocessing is needed,
// return early; the preprocessing action will invoke this method again once it completes.
if (preprocessBulkRequest(task, bulkRequest, executorName, listener)) return;

// Attempt to create all the indices that we're going to need during the bulk before we start.
// Step 1: collect all the indices in the request
@@ -801,15 +769,50 @@ private long relativeTime() {
return relativeTimeProvider.getAsLong();
}

private void processBulkIndexIngestRequest(
private boolean preprocessBulkRequest(Task task, BulkRequest bulkRequest, String executorName, ActionListener<BulkResponse> listener) {
final Metadata metadata = clusterService.state().getMetadata();
final Version minNodeVersion = clusterService.state().getNodes().getMinNodeVersion();
boolean needsProcessing = false;
for (BulkRequestPreprocessor preprocessor : bulkRequestPreprocessors) {
for (DocWriteRequest<?> docWriteRequest : bulkRequest.requests) {
IndexRequest indexRequest = getIndexWriteRequest(docWriteRequest);
if (indexRequest != null) {
needsProcessing = needsProcessing || preprocessor.needsProcessing(docWriteRequest, indexRequest, metadata);
}

if (docWriteRequest instanceof IndexRequest ir) {
ir.checkAutoIdWithOpTypeCreateSupportedByVersion(minNodeVersion);
if (ir.getAutoGeneratedTimestamp() != IndexRequest.UNSET_AUTO_GENERATED_TIMESTAMP) {
throw new IllegalArgumentException("autoGeneratedTimestamp should not be set externally");
}
}
}

if (needsProcessing) {
ActionListener.run(listener, l -> {
if ((preprocessor.shouldExecuteOnIngestNode() == false) || clusterService.localNode().isIngestNode()) {
preprocessBulkRequestWithPreprocessor(preprocessor, task, bulkRequest, executorName, l);
} else {
ingestForwarder.forwardIngestRequest(BulkAction.INSTANCE, bulkRequest, l);
}
});
return true;
}
}
return false;
}

private void preprocessBulkRequestWithPreprocessor(
BulkRequestPreprocessor preprocessor,
Task task,
BulkRequest original,
String executorName,
ActionListener<BulkResponse> listener
) {
final long ingestStartTimeInNanos = System.nanoTime();
final BulkRequestModifier bulkRequestModifier = new BulkRequestModifier(original);
ingestService.executeBulkRequest(
preprocessor.processBulkRequest(
threadPool.executor(executorName),
original.numberOfActions(),
() -> bulkRequestModifier,
bulkRequestModifier::markItemAsDropped,
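
For orientation: the hunks above call into a BulkRequestPreprocessor abstraction whose definition is not part of this diff. A minimal sketch of what the interface plausibly looks like, reconstructed from the call sites in TransportBulkAction (the two boolean methods and their arguments appear verbatim above; everything else is an assumption):

package org.elasticsearch.ingest;

import org.elasticsearch.action.DocWriteRequest;
import org.elasticsearch.action.index.IndexRequest;
import org.elasticsearch.cluster.metadata.Metadata;

// Hypothetical reconstruction from the call sites above, not the PR's actual source.
public interface BulkRequestPreprocessor {

    // Whether this preprocessor has work to do for the given request, e.g. a
    // resolved ingest pipeline or a semantic_text field that needs inference.
    boolean needsProcessing(DocWriteRequest<?> docWriteRequest, IndexRequest indexRequest, Metadata metadata);

    // Ingest pipelines must run on an ingest node, so IngestService would return
    // true here, while field inference can run on the coordinating node.
    boolean shouldExecuteOnIngestNode();

    // The asynchronous processBulkRequest(...) method is also part of the
    // interface, but its parameter list is cut off in the hunk above and is
    // therefore not reproduced here.
}

IngestService and FieldInferenceBulkRequestPreprocessor would then be the two implementations collected into bulkRequestPreprocessors in the constructor.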

server/src/main/java/org/elasticsearch/action/index/IndexRequest.java

@@ -49,6 +49,7 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_FIELD_ADDED;
import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
@@ -105,6 +106,8 @@ public class IndexRequest extends ReplicatedWriteRequest<IndexRequest> implement

private boolean isPipelineResolved;

private boolean isFieldInferenceDone;

private boolean requireAlias;
/**
* This indicates whether the response to this request ought to list the ingest pipelines that were executed on the document
@@ -189,6 +192,7 @@ public IndexRequest(@Nullable ShardId shardId, StreamInput in) throws IOExceptio
: new ArrayList<>(possiblyImmutableExecutedPipelines);
}
}
isFieldInferenceDone = in.getTransportVersion().before(SEMANTIC_TEXT_FIELD_ADDED) || in.readBoolean();
}

public IndexRequest() {
@@ -375,6 +379,26 @@ public boolean isPipelineResolved() {
return this.isPipelineResolved;
}

/**
* Sets whether field inference for this request has already been performed by the coordinating node.
*
* @param isFieldInferenceDone true if field inference has been performed
* @return the request
*/
public IndexRequest isFieldInferenceDone(final boolean isFieldInferenceDone) {
this.isFieldInferenceDone = isFieldInferenceDone;
return this;
}

/**
* Returns whether field inference for this request has already been performed by the coordinating node.
*
* @return true if field inference has been performed
*/
public boolean isFieldInferenceDone() {
return this.isFieldInferenceDone;
}

/**
* The source of the document to index, recopied to a new array if it is unsafe.
*/
@@ -755,6 +779,9 @@ private void writeBody(StreamOutput out) throws IOException {
out.writeOptionalCollection(executedPipelines, StreamOutput::writeString);
}
}
if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_FIELD_ADDED)) {
out.writeBoolean(isFieldInferenceDone);
}
}

@Override
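
A detail worth calling out in the serialization changes above: on read, isFieldInferenceDone = in.getTransportVersion().before(SEMANTIC_TEXT_FIELD_ADDED) || in.readBoolean() makes the flag default to true when the sender is older than SEMANTIC_TEXT_FIELD_ADDED, and the matching onOrAfter guard in writeBody only emits the boolean to peers that expect it. The apparent intent is that requests from older nodes, which never carry the flag, are treated as already processed rather than being routed through field inference in a mixed-version cluster.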

(inference action class moved from x-pack into server; the package changes from org.elasticsearch.xpack.inference.action to org.elasticsearch.action.inference below)

@@ -1,11 +1,12 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.xpack.inference.action;
package org.elasticsearch.action.inference;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.action.ActionRequest;

server/src/main/java/org/elasticsearch/index/mapper/FieldTypeLookup.java

@@ -36,6 +36,8 @@ final class FieldTypeLookup {
*/
private final Map<String, Set<String>> fieldToCopiedFields;

private final Map<String, String> fieldToInferenceModels;

private final int maxParentPathDots;

FieldTypeLookup(
@@ -48,6 +50,7 @@
final Map<String, String> fullSubfieldNameToParentPath = new HashMap<>();
final Map<String, DynamicFieldType> dynamicFieldTypes = new HashMap<>();
final Map<String, Set<String>> fieldToCopiedFields = new HashMap<>();
final Map<String, String> fieldToInferenceModels = new HashMap<>();
for (FieldMapper fieldMapper : fieldMappers) {
String fieldName = fieldMapper.name();
MappedFieldType fieldType = fieldMapper.fieldType();
@@ -65,6 +68,9 @@
}
fieldToCopiedFields.get(targetField).add(fieldName);
}
if (fieldType.hasInferenceModel()) {
fieldToInferenceModels.put(fieldName, fieldType.getInferenceModel());
}
}

int maxParentPathDots = 0;
@@ -97,6 +103,7 @@ final class FieldTypeLookup {
// make values into more compact immutable sets to save memory
fieldToCopiedFields.entrySet().forEach(e -> e.setValue(Set.copyOf(e.getValue())));
this.fieldToCopiedFields = Map.copyOf(fieldToCopiedFields);
this.fieldToInferenceModels = Map.copyOf(fieldToInferenceModels);
}

public static int dotCount(String path) {
@@ -109,6 +116,10 @@ public static int dotCount(String path) {
return dotCount;
}

String modelForField(String fieldName) {
return this.fieldToInferenceModels.get(fieldName);
}

/**
* Returns the mapped field type for the given field name.
*/

server/src/main/java/org/elasticsearch/index/mapper/MappedFieldType.java

@@ -203,6 +203,14 @@ public List<String> dimensions() {
return Collections.emptyList();
}

public String getInferenceModel() {
return null;
}

public final boolean hasInferenceModel() {
return getInferenceModel() != null;
}

/**
* @return metric type or null if the field is not a metric field
*/
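
The two hooks added to MappedFieldType above are the extension point: a field type opts into inference by returning a non-null model id from getInferenceModel(). The semantic_text mapper itself is not part of this diff, so the following is only a hypothetical sketch of such an override (class name, constructor arguments, and the modelId mapping parameter are assumptions):

package org.elasticsearch.index.mapper;

import java.util.Map;

import org.apache.lucene.search.Query;
import org.elasticsearch.index.query.SearchExecutionContext;

// Hypothetical sketch; the actual semantic_text mapper is not included in this diff.
public class SemanticTextFieldType extends MappedFieldType {

    private final String modelId; // inference model id taken from the field mapping

    public SemanticTextFieldType(String name, String modelId) {
        // Assumed constructor arguments, following the common MappedFieldType pattern.
        super(name, false, false, false, TextSearchInfo.NONE, Map.of());
        this.modelId = modelId;
    }

    @Override
    public String getInferenceModel() {
        // A non-null value makes hasInferenceModel() return true, which is what
        // FieldTypeLookup keys on when building fieldToInferenceModels.
        return modelId;
    }

    @Override
    public String typeName() {
        return "semantic_text";
    }

    @Override
    public Query termQuery(Object value, SearchExecutionContext context) {
        throw new IllegalArgumentException("semantic_text fields do not support term queries");
    }

    @Override
    public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
        // Omitted in this sketch.
        throw new UnsupportedOperationException("not part of this sketch");
    }
}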

server/src/main/java/org/elasticsearch/index/mapper/MappingLookup.java

@@ -491,4 +491,8 @@ public void validateDoesNotShadow(String name) {
throw new MapperParsingException("Field [" + name + "] attempted to shadow a time_series_metric");
}
}

public String modelForField(String fieldName) {
return fieldTypeLookup.modelForField(fieldName);
}
}
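
Taken together, the pieces sketch a coordinating-node flow: FieldTypeLookup records field-to-model mappings when the mapping is built, MappingLookup.modelForField exposes them, and IndexRequest.isFieldInferenceDone guards against reprocessing. A rough, hypothetical sketch of what FieldInferenceBulkRequestPreprocessor might do per document (its body is not in this diff; runInference is an invented helper):

// Only modelForField(), isFieldInferenceDone() and the source accessors used
// below are actually established by this diff; the rest is a guess.
void preprocess(IndexRequest indexRequest, MappingLookup mappingLookup) {
    if (indexRequest.isFieldInferenceDone()) {
        return; // second pass through TransportBulkAction; nothing left to do
    }
    Map<String, Object> source = indexRequest.sourceAsMap();
    for (Map.Entry<String, Object> entry : source.entrySet()) {
        String modelId = mappingLookup.modelForField(entry.getKey());
        if (modelId != null && entry.getValue() instanceof String text) {
            // Invented helper: run the configured model and replace the raw
            // text with the text plus its inference results.
            entry.setValue(runInference(modelId, text));
        }
    }
    indexRequest.source(source);
    indexRequest.isFieldInferenceDone(true); // stop preprocessBulkRequest from re-entering
}

private Object runInference(String modelId, String text) {
    // Placeholder: would call the inference service for modelId.
    return text;
}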