diff --git a/server/src/main/java/org/elasticsearch/ingest/IngestService.java b/server/src/main/java/org/elasticsearch/ingest/IngestService.java index 3adaab078ad4a..e376e6f8a8a24 100644 --- a/server/src/main/java/org/elasticsearch/ingest/IngestService.java +++ b/server/src/main/java/org/elasticsearch/ingest/IngestService.java @@ -20,6 +20,7 @@ import org.elasticsearch.action.admin.cluster.node.info.NodesInfoResponse; import org.elasticsearch.action.bulk.TransportBulkAction; import org.elasticsearch.action.index.IndexRequest; +import org.elasticsearch.action.inference.InferenceAction; import org.elasticsearch.action.ingest.DeletePipelineRequest; import org.elasticsearch.action.ingest.PutPipelineRequest; import org.elasticsearch.action.support.RefCountingRunnable; @@ -57,6 +58,7 @@ import org.elasticsearch.index.IndexSettings; import org.elasticsearch.index.VersionType; import org.elasticsearch.index.analysis.AnalysisRegistry; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.node.ReportingService; import org.elasticsearch.plugins.IngestPlugin; import org.elasticsearch.plugins.internal.DocumentParsingObserver; @@ -107,6 +109,7 @@ public class IngestService implements ClusterStateApplier, ReportingService documentParsingObserverSupplier; private final Map processorFactories; + private final Client client; // Ideally this should be in IngestMetadata class, but we don't have the processor factories around there. // We know of all the processor factories when a node with all its plugin have been initialized. Also some // processor factories rely on other node services. 
Custom metadata is statically registered when classes @@ -153,6 +156,11 @@ public static MatcherWatchdog createGrokThreadWatchdog(Environment env, ThreadPo : batchExecutionContext.initialState().copyAndUpdateMetadata(b -> b.putCustom(IngestMetadata.TYPE, finalIngestMetadata)); }; + /* Reports whether any key of the request's source map satisfies fieldNeedsInference for the request's index. */ public boolean hasInferenceFields(IndexRequest indexRequest) { + return indexRequest.sourceAsMap().keySet().stream() + .anyMatch(fieldName -> fieldNeedsInference(indexRequest.index(), fieldName)); + } + /** * Specialized cluster state update task specifically for ingest pipeline operations. * These operations all receive an AcknowledgedResponse. @@ -187,6 +195,7 @@ public IngestService( this.clusterService = clusterService; this.scriptService = scriptService; this.documentParsingObserverSupplier = documentParsingObserverSupplier; + this.client = client; this.processorFactories = processorFactories( ingestPlugins, new Processor.Parameters( @@ -676,61 +685,112 @@ protected void doRun() { int i = 0; for (DocWriteRequest actionRequest : actionRequests) { IndexRequest indexRequest = TransportBulkAction.getIndexWriteRequest(actionRequest); - if (indexRequest == null) { - i++; - continue; - } + if (indexRequest != null) { + PipelineIterator pipelines = getAndResetPipelines(indexRequest); + if (pipelines.hasNext()) { + executePipelinesOnActionRequest( + actionRequest, + i, + refs.acquire(), + indexRequest, + pipelines, + onDropped, + onFailure + ); + } - PipelineIterator pipelines = getAndResetPipelines(indexRequest); - if (pipelines.hasNext() == false) { - i++; - continue; + /* Inference is dispatched for every source field accepted by fieldNeedsInference, whether or not pipelines were started above for the same request; nothing here waits for one before starting the other. NOTE(review): confirm that overlapping work on the same IndexRequest is intended. */ String index = indexRequest.index(); + Map sourceMap = indexRequest.sourceAsMap(); + final int position = i; + sourceMap.entrySet().stream() + .filter(entry -> fieldNeedsInference(index, entry.getKey())) + .forEach(entry -> { + runInferenceForField(indexRequest, entry.getKey(), entry.getValue(), refs, position, onFailure); + }); } + i++; + } + } + } + }); + } - // start the stopwatch and acquire a ref to indicate that we're 
working on this document - final long startTimeInNanos = System.nanoTime(); - totalMetrics.preIngest(); - final int slot = i; - final Releasable ref = refs.acquire(); - // the document listener gives us three-way logic: a document can fail processing (1), or it can - // be successfully processed. a successfully processed document can be kept (2) or dropped (3). - final ActionListener documentListener = ActionListener.runAfter(new ActionListener<>() { - @Override - public void onResponse(Boolean kept) { - assert kept != null; - if (kept == false) { - onDropped.accept(slot); - } - } + // TODO actual mapping check here + /* Placeholder check: a field needs inference when its name starts with "infer_"; the index argument is currently unused. */ private boolean fieldNeedsInference(String index, String fieldName) { + return fieldName.startsWith("infer_"); + } - @Override - public void onFailure(Exception e) { - totalMetrics.ingestFailed(); - onFailure.accept(slot, e); - } - }, () -> { - // regardless of success or failure, we always stop the ingest "stopwatch" and release the ref to indicate - // that we're finished with this document - final long ingestTimeInNanos = System.nanoTime() - startTimeInNanos; - totalMetrics.postIngest(ingestTimeInNanos); - ref.close(); - }); - DocumentParsingObserver documentParsingObserver = documentParsingObserverSupplier.get(); + /* Sends an inference request for a single field of the index request. Builds a new IngestDocument from the request and returns immediately when that document does not contain fieldName. Otherwise a reference is acquired from the RefCountingRunnable passed as ref, and that acquired reference (not the runnable itself) is closed once the inference call responds or fails. Failures are reported through onFailure at position. fieldValue is currently unused. NOTE(review): the inference result is only set on the local IngestDocument; this method does not copy it back into indexRequest. */ private void runInferenceForField(IndexRequest indexRequest, String fieldName, Object fieldValue, RefCountingRunnable ref, int position, BiConsumer onFailure) { + var ingestDocument = newIngestDocument(indexRequest, documentParsingObserverSupplier.get()); + if (ingestDocument.hasField(fieldName) == false) { + return; + } - IngestDocument ingestDocument = newIngestDocument(indexRequest, documentParsingObserver); + /* keep the Releasable returned by acquire(); closing the RefCountingRunnable itself would not release it */ final Releasable inferenceRef = ref.acquire(); - executePipelines(pipelines, indexRequest, ingestDocument, documentListener); - indexRequest.setPipelinesHaveRun(); + // TODO Hardcoding model ID and task type + InferenceAction.Request inferenceRequest = new InferenceAction.Request(TaskType.SPARSE_EMBEDDING, "my-elser-model", 
ingestDocument.getFieldValue(fieldName, String.class), Map.of()); - assert actionRequest.index() != null; - documentParsingObserver.setIndexName(actionRequest.index()); - documentParsingObserver.close(); + client.execute(InferenceAction.INSTANCE, inferenceRequest, new ActionListener() { + @Override + public void onResponse(InferenceAction.Response response) { + ingestDocument.setFieldValue(fieldName + ".inference", response.getResult().asMap(fieldName).get(fieldName)); + inferenceRef.close(); + } - i++; - } + @Override + public void onFailure(Exception e) { + onFailure.accept(position, e); + inferenceRef.close(); + } + }); + } + + /* Executes the pipelines for one bulk item at index slot. Records ingest metrics around the run, forwards dropped documents to onDropped and failures to onFailure, and closes the ref supplied by the caller once the document listener has completed. */ private void executePipelinesOnActionRequest( + DocWriteRequest actionRequest, + final int slot, + final Releasable ref, + IndexRequest indexRequest, + PipelineIterator pipelines, + IntConsumer onDropped, + BiConsumer onFailure + ) { + // start the stopwatch and acquire a ref to indicate that we're working on this document + final long startTimeInNanos = System.nanoTime(); + totalMetrics.preIngest(); + // the document listener gives us three-way logic: a document can fail processing (1), or it can + // be successfully processed. a successfully processed document can be kept (2) or dropped (3). 
+ final ActionListener documentListener = ActionListener.runAfter(new ActionListener<>() { + @Override + public void onResponse(Boolean kept) { + assert kept != null; + if (kept == false) { + onDropped.accept(slot); } } + + @Override + public void onFailure(Exception e) { + totalMetrics.ingestFailed(); + onFailure.accept(slot, e); + } + }, () -> { + // regardless of success or failure, we always stop the ingest "stopwatch" and release the ref to indicate + // that we're finished with this document + final long ingestTimeInNanos = System.nanoTime() - startTimeInNanos; + totalMetrics.postIngest(ingestTimeInNanos); + ref.close(); }); + DocumentParsingObserver documentParsingObserver = documentParsingObserverSupplier.get(); + + IngestDocument ingestDocument = newIngestDocument(indexRequest, documentParsingObserver); + + executePipelinesOnActionRequest(pipelines, indexRequest, ingestDocument, documentListener); + indexRequest.setPipelinesHaveRun(); + + assert actionRequest.index() != null; + documentParsingObserver.setIndexName(actionRequest.index()); + documentParsingObserver.close(); } /** @@ -805,7 +865,7 @@ public PipelineSlot next() { } } - private void executePipelines( + private void executePipelinesOnActionRequest( final PipelineIterator pipelines, final IndexRequest indexRequest, final IngestDocument ingestDocument, @@ -925,7 +985,7 @@ private void executePipelines( } if (newPipelines.hasNext()) { - executePipelines(newPipelines, indexRequest, ingestDocument, listener); + executePipelinesOnActionRequest(newPipelines, indexRequest, ingestDocument, listener); } else { // update the index request's source and (potentially) cache the timestamp for TSDB updateIndexRequestSource(indexRequest, ingestDocument);