First working version of inference resolving
carlosdelest committed Oct 23, 2023
1 parent 260b5a9 commit 688fc7e
Showing 7 changed files with 248 additions and 64 deletions.
----------------------------------------
@@ -143,6 +143,7 @@ static TransportVersion def(int id) {
public static final TransportVersion BUILD_QUALIFIER_SEPARATED = def(8_518_00_0);
public static final TransportVersion PIPELINES_IN_BULK_RESPONSE_ADDED = def(8_519_00_0);
public static final TransportVersion PLUGIN_DESCRIPTOR_STRING_VERSION = def(8_520_00_0);
public static final TransportVersion SEMANTIC_TEXT_FIELD_ADDED = def(8_521_00_0);
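// (version ids appear to follow an M_NNN_SS_P layout: major version, per-release change id, serverless part, patch)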
/*
* STOP! READ THIS FIRST! No, really,
* ____ _____ ___ ____ _ ____ _____ _ ____ _____ _ _ ___ ____ _____ ___ ____ ____ _____ _
----------------------------------------
@@ -271,15 +271,15 @@ protected void doInternalExecute(Task task, BulkRequest bulkRequest, String exec
final AtomicArray<BulkItemResponse> responses = new AtomicArray<>(bulkRequest.requests.size());

boolean hasIndexRequestsWithPipelines = false;
boolean hasInferenceFields = false;
boolean needsFieldInference = false;
final Metadata metadata = clusterService.state().getMetadata();
final Version minNodeVersion = clusterService.state().getNodes().getMinNodeVersion();
for (DocWriteRequest<?> actionRequest : bulkRequest.requests) {
IndexRequest indexRequest = getIndexWriteRequest(actionRequest);
if (indexRequest != null) {
IngestService.resolvePipelinesAndUpdateIndexRequest(actionRequest, indexRequest, metadata);
hasIndexRequestsWithPipelines |= IngestService.hasPipeline(indexRequest);
hasInferenceFields |= ingestService.hasInferenceFields(indexRequest);
needsFieldInference |= ingestService.needsFieldInference(indexRequest);
}

if (actionRequest instanceof IndexRequest ir) {
@@ -290,7 +290,7 @@ }
}
}

if (hasIndexRequestsWithPipelines || hasInferenceFields) {
if (hasIndexRequestsWithPipelines) {
// this method (doExecute) will be called again, but with the bulk requests updated from the ingest node processing and
// also with IngestService.NOOP_PIPELINE_NAME set on each request. This ensures that, the second time through this method,
// this path is never taken.
@@ -304,14 +304,32 @@ protected void doInternalExecute(Task task, BulkRequest bulkRequest, String exec
assert arePipelinesResolved : bulkRequest;
}
if (clusterService.localNode().isIngestNode()) {
processBulkIndexIngestRequest(task, bulkRequest, executorName, l);
processPipelinesBulkIndexIngestRequest(task, bulkRequest, executorName, l);
} else {
ingestForwarder.forwardIngestRequest(BulkAction.INSTANCE, bulkRequest, l);
}
});
return;
}

if (needsFieldInference) {
// this method (doExecute) will be called again, but with the bulk requests updated with the field inference results and
// with isFieldInferenceResolved set to true on each request. This ensures that, the second time through this method,
// this path is never taken.
ActionListener.run(listener, l -> {
if (Assertions.ENABLED) {
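// On this first pass at least one write request should still have inference unresolved.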
final boolean isFieldsInferenceResolved = bulkRequest.requests()
.stream()
.map(TransportBulkAction::getIndexWriteRequest)
.filter(Objects::nonNull)
.allMatch(IndexRequest::isFieldInferenceResolved);
assert isFieldsInferenceResolved == false : bulkRequest;
}
processFieldsInferenceBulkIndexIngestRequest(task, bulkRequest, executorName, l);
});
return;
}

// Attempt to create all the indices that we're going to need during the bulk before we start.
// Step 1: collect all the indices in the request
final Map<String, Boolean> indices = bulkRequest.requests.stream()
@@ -803,15 +821,15 @@ private long relativeTime() {
return relativeTimeProvider.getAsLong();
}

private void processBulkIndexIngestRequest(
private void processPipelinesBulkIndexIngestRequest(
Task task,
BulkRequest original,
String executorName,
ActionListener<BulkResponse> listener
) {
final long ingestStartTimeInNanos = System.nanoTime();
final BulkRequestModifier bulkRequestModifier = new BulkRequestModifier(original);
ingestService.executeBulkRequest(
ingestService.executePipelinesBulkRequest(
original.numberOfActions(),
() -> bulkRequestModifier,
bulkRequestModifier::markItemAsDropped,
@@ -862,6 +880,65 @@ public boolean isForceExecution() {
);
}

private void processFieldsInferenceBulkIndexIngestRequest(
Task task,
BulkRequest original,
String executorName,
ActionListener<BulkResponse> listener
) {
final long ingestStartTimeInNanos = System.nanoTime();
final BulkRequestModifier bulkRequestModifier = new BulkRequestModifier(original);
ingestService.executeFieldInferenceBulkRequest(
original.numberOfActions(),
() -> bulkRequestModifier,
bulkRequestModifier::markItemAsDropped,
bulkRequestModifier::markItemAsFailed,
(originalThread, exception) -> {
if (exception != null) {
logger.debug("failed to execute inference for a bulk request", exception);
listener.onFailure(exception);
} else {
BulkRequest bulkRequest = bulkRequestModifier.getBulkRequest();
long ingestTookInMillis = TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - ingestStartTimeInNanos);
ActionListener<BulkResponse> actionListener = bulkRequestModifier.wrapActionListenerIfNeeded(
ingestTookInMillis,
listener
);
if (bulkRequest.requests().isEmpty()) {
// at this stage, the transport bulk action can't deal with a bulk request with no requests,
// so we stop and send an empty response back to the client.
// (this will happen if pre-processing failed for all items in the bulk)
actionListener.onResponse(new BulkResponse(new BulkItemResponse[0], 0));
} else {
ActionRunnable<BulkResponse> runnable = new ActionRunnable<>(actionListener) {
@Override
protected void doRun() {
doInternalExecute(task, bulkRequest, executorName, actionListener);
}

@Override
public boolean isForceExecution() {
// If we fork back to a write thread we should not fail just because the tp queue is full.
// (Otherwise the work done during ingest will be lost)
// It is okay to force execution here. Throttling of write requests happens prior to
// ingest when a node receives a bulk request.
return true;
}
};
// If a processor went async and returned a response on a different thread, then before we
// continue the bulk request we should fork back to a write thread:
if (originalThread == Thread.currentThread()) {
runnable.run();
} else {
threadPool.executor(executorName).execute(runnable);
}
}
}
},
executorName
);
}

static final class BulkRequestModifier implements Iterator<DocWriteRequest<?>> {

final BulkRequest bulkRequest;
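For orientation: the branch above relies on two IngestService hooks, needsFieldInference and executeFieldInferenceBulkRequest, whose implementations are not shown on this page. The following is a minimal, self-contained sketch (hypothetical names and types, not code from this commit) of the decision the first hook appears to encode:

import java.util.List;
import java.util.Map;

final class FieldInferenceSketch {

    /** Hypothetical per-index view of the fields that are backed by an inference model. */
    record InferenceField(String name, String modelId) {}

    /**
     * A request needs the inference pass when its target index maps at least one
     * inference-backed field and the coordinating node has not already resolved it
     * (mirroring IndexRequest#isFieldInferenceResolved from this commit).
     */
    static boolean needsFieldInference(
        Map<String, List<InferenceField>> inferenceFieldsByIndex,
        String targetIndex,
        boolean alreadyResolved
    ) {
        return alreadyResolved == false
            && inferenceFieldsByIndex.getOrDefault(targetIndex, List.of()).isEmpty() == false;
    }

    public static void main(String[] args) {
        var byIndex = Map.of("my-index", List.of(new InferenceField("body", "my-model")));
        System.out.println(needsFieldInference(byIndex, "my-index", false)); // true: first pass runs inference
        System.out.println(needsFieldInference(byIndex, "my-index", true));  // false: second pass skips it
    }
}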
----------------------------------------
@@ -49,6 +49,7 @@
import java.util.Map;
import java.util.Objects;

import static org.elasticsearch.TransportVersions.SEMANTIC_TEXT_FIELD_ADDED;
import static org.elasticsearch.action.ValidateActions.addValidationError;
import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_PRIMARY_TERM;
import static org.elasticsearch.index.seqno.SequenceNumbers.UNASSIGNED_SEQ_NO;
@@ -105,6 +106,8 @@ public class IndexRequest extends ReplicatedWriteRequest<IndexRequest> implement

private boolean isPipelineResolved;

private boolean isFieldInferenceResolved;

private boolean requireAlias;
/**
* This indicates whether the response to this request ought to list the ingest pipelines that were executed on the document
@@ -189,6 +192,7 @@ public IndexRequest(@Nullable ShardId shardId, StreamInput in) throws IOExceptio
: new ArrayList<>(possiblyImmutableExecutedPipelines);
}
}
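// BWC: senders older than SEMANTIC_TEXT_FIELD_ADDED never write this flag; treat their requests as already resolved.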
isFieldInferenceResolved = in.getTransportVersion().before(SEMANTIC_TEXT_FIELD_ADDED) || in.readBoolean();
}

public IndexRequest() {
@@ -375,6 +379,26 @@ public boolean isPipelineResolved() {
return this.isPipelineResolved;
}

/**
* Sets whether field inference for this request has been resolved by the coordinating node.
*
* @param isFieldInferenceResolved true if the field inference has been resolved
* @return the request
*/
public IndexRequest isFieldInferenceResolved(final boolean isFieldInferenceResolved) {
this.isFieldInferenceResolved = isFieldInferenceResolved;
return this;
}

/**
* Returns whether the field inference for this request has been resolved by the coordinating node.
*
* @return true if the field inference has been resolved
*/
public boolean isFieldInferenceResolved() {
return this.isFieldInferenceResolved;
}

/**
* The source of the document to index, recopied to a new array if it is unsafe.
*/
@@ -755,6 +779,9 @@ private void writeBody(StreamOutput out) throws IOException {
out.writeOptionalCollection(executedPipelines, StreamOutput::writeString);
}
}
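// The flag is only written when the channel's transport version understands it; older streams omit it entirely.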
if (out.getTransportVersion().onOrAfter(SEMANTIC_TEXT_FIELD_ADDED)) {
out.writeBoolean(isFieldInferenceResolved);
}
}

@Override
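Taken together, the read and write paths above gate the new boolean on SEMANTIC_TEXT_FIELD_ADDED. A minimal sketch of the resulting wire behavior, using plain ints in place of the real TransportVersion/StreamInput machinery (illustrative only):

final class WireCompatSketch {

    static final int SEMANTIC_TEXT_FIELD_ADDED = 8_521_00_0;

    /**
     * Mirrors the read path: streams from senders older than the gating version
     * never contain the boolean, so the field defaults to "already resolved" and
     * such requests never enter the inference pass.
     */
    static boolean readIsFieldInferenceResolved(int senderVersion, boolean flagOnWire) {
        return senderVersion < SEMANTIC_TEXT_FIELD_ADDED || flagOnWire;
    }

    public static void main(String[] args) {
        System.out.println(readIsFieldInferenceResolved(8_520_00_0, false)); // true: legacy default
        System.out.println(readIsFieldInferenceResolved(8_521_00_0, false)); // false: explicit value wins
    }
}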
----------------------------------------
@@ -23,7 +23,6 @@
import org.elasticsearch.index.mapper.MappedFieldType;
import org.elasticsearch.index.mapper.MapperBuilderContext;
import org.elasticsearch.index.mapper.SourceValueFetcher;
import org.elasticsearch.index.mapper.TextFieldMapper;
import org.elasticsearch.index.mapper.TextSearchInfo;
import org.elasticsearch.index.mapper.ValueFetcher;
import org.elasticsearch.index.query.SearchExecutionContext;
@@ -82,7 +81,7 @@ public SparseVectorFieldMapper build(MapperBuilderContext context) {
}

return new Builder(n);
}, notInMultiFieldsUnlessParentOfType(CONTENT_TYPE, TextFieldMapper.CONTENT_TYPE)); // TODO Change for semantic_text field type
}, notInMultiFields(CONTENT_TYPE));

public static final class SparseVectorFieldType extends MappedFieldType {

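The last change reverts a temporary allowance (the removed TODO notes it was a stopgap pending a dedicated semantic_text field type): with the stock notInMultiFields check restored, a sparse_vector field may not be declared inside another field's multi-fields. Illustratively, a mapping along these lines would now be rejected (field names are hypothetical):

final class SparseVectorMultiFieldExample {

    // Rejected: "tokens" declares sparse_vector inside the "fields" (multi-fields)
    // section of a text field, which notInMultiFields disallows.
    static final String REJECTED_MAPPING = """
        {
          "properties": {
            "body": {
              "type": "text",
              "fields": {
                "tokens": { "type": "sparse_vector" }
              }
            }
          }
        }
        """;
}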