From 338ecd781f5d832f28277175b449b1ae996ce340 Mon Sep 17 00:00:00 2001
From: carlosdelest
Date: Tue, 31 Oct 2023 12:42:23 +0100
Subject: [PATCH] Made inference sequential so it works better with multiple
 fields

---
 ...FieldInferenceBulkRequestPreprocessor.java | 105 ++++++++++--------
 1 file changed, 60 insertions(+), 45 deletions(-)

diff --git a/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java b/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java
index 4e390bf92e569..21a99365255d9 100644
--- a/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java
+++ b/server/src/main/java/org/elasticsearch/ingest/FieldInferenceBulkRequestPreprocessor.java
@@ -18,15 +18,16 @@
 import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
 import org.elasticsearch.cluster.metadata.Metadata;
 import org.elasticsearch.cluster.service.ClusterService;
-import org.elasticsearch.index.IndexService;
+import org.elasticsearch.core.Releasable;
+import org.elasticsearch.index.Index;
 import org.elasticsearch.index.mapper.SemanticTextFieldMapper;
 import org.elasticsearch.indices.IndicesService;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.plugins.internal.DocumentParsingObserver;
 
 import java.util.HashMap;
+import java.util.List;
 import java.util.Map;
-import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.BiConsumer;
 import java.util.function.IntConsumer;
 import java.util.function.Supplier;
@@ -65,22 +66,17 @@ protected void processIndexRequest(
     ) {
         assert indexRequest.isFieldInferenceDone() == false;
 
-        refs.acquire();
-        // Inference responses can update the fields concurrently
-        final Map<String, Object> sourceMap = new ConcurrentHashMap<>(indexRequest.sourceAsMap());
-        try (var inferenceRefs = new RefCountingRunnable(() -> onInferenceComplete(refs, indexRequest, sourceMap))) {
-            sourceMap.entrySet()
-                .stream()
-                .filter(entry -> fieldNeedsInference(indexRequest, entry.getKey(), entry.getValue()))
-                .forEach(entry -> {
-                    runInferenceForField(indexRequest, entry.getKey(), inferenceRefs, slot, sourceMap, onFailure);
-                });
-        }
-    }
+        IngestDocument ingestDocument = newIngestDocument(indexRequest);
+        List<String> fieldNames = ingestDocument.getSource().entrySet()
+            .stream()
+            .filter(entry -> fieldNeedsInference(indexRequest, entry.getKey(), entry.getValue()))
+            .map(Map.Entry::getKey)
+            .toList();
+
+        // Runs inference sequentially. This makes synchronization easier and avoids producing multiple
+        // BulkItemResponses for a single bulk request in TransportBulkAction.unwrappingSingleItemBulkResponse
+        runInferenceForFields(indexRequest, fieldNames, refs.acquire(), slot, ingestDocument, onFailure);
 
-    private void onInferenceComplete(RefCountingRunnable refs, IndexRequest indexRequest, Map<String, Object> sourceMap) {
-        updateIndexRequestSource(indexRequest, newIngestDocument(indexRequest, sourceMap));
-        refs.close();
     }
 
     @Override
@@ -112,26 +108,51 @@ private boolean fieldNeedsInference(IndexRequest indexRequest, String fieldName,
     }
 
     private String getModelForField(IndexRequest indexRequest, String fieldName) {
-        IndexService indexService = indicesService.indexService(
-            indexNameExpressionResolver.concreteSingleIndex(clusterService.state(), indexRequest)
-        );
-        return indexService.mapperService().mappingLookup().modelForField(fieldName);
+        // Check that all indices related to the request have the same model id for this field
+        String model = null;
+        try {
+            Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), indexRequest);
+            for (Index index : indices) {
+                String modelForIndex = indicesService.indexService(index).mapperService().mappingLookup().modelForField(fieldName);
+                if (modelForIndex == null) {
+                    return null;
+                }
+                if ((model != null) && modelForIndex.equals(model) == false) {
+                    return null;
+                }
+                model = modelForIndex;
+            }
+        } catch (Exception e) {
+            // There's a problem retrieving the index
+            return null;
+        }
+        return model;
     }
 
-    private void runInferenceForField(
+    private void runInferenceForFields(
         IndexRequest indexRequest,
-        String fieldName,
-        RefCountingRunnable refs,
+        List<String> fieldNames,
+        Releasable ref,
         int position,
-        final Map<String, Object> sourceAsMap,
+        final IngestDocument ingestDocument,
         BiConsumer<Integer, Exception> onFailure
     ) {
-        final String fieldValue = (String) sourceAsMap.get(fieldName);
-        if (fieldValue == null) {
+        // We have finished processing all the fields
+        if (fieldNames.isEmpty()) {
+            updateIndexRequestSource(indexRequest, ingestDocument);
+            indexRequest.isFieldInferenceDone(true);
+            ref.close();
             return;
         }
+        String fieldName = fieldNames.get(0);
+        List<String> nextFieldNames = fieldNames.subList(1, fieldNames.size());
+        final String fieldValue = ingestDocument.getFieldValue(fieldName, String.class);
+        if (fieldValue == null) {
+            // Skip this field and run inference for the next one
+            runInferenceForFields(indexRequest, nextFieldNames, ref, position, ingestDocument, onFailure);
+            return;
+        }
 
-        refs.acquire();
         String modelForField = getModelForField(indexRequest, fieldName);
         assert modelForField != null : "Field " + fieldName + " has no model associated in mappings";
 
@@ -143,37 +164,31 @@ private void runInferenceForField(
             Map.of()
         );
 
-        final long startTimeInNanos = System.nanoTime();
-        ingestMetric.preIngest();
-        client.execute(InferenceAction.INSTANCE, inferenceRequest, ActionListener.runAfter(new ActionListener<>() {
+        client.execute(InferenceAction.INSTANCE, inferenceRequest, new ActionListener<>() {
             @Override
             public void onResponse(InferenceAction.Response response) {
                 // Transform into two subfields, one with the actual text and other with the inference
                 Map<String, Object> newFieldValue = new HashMap<>();
                 newFieldValue.put(SemanticTextFieldMapper.TEXT_SUBFIELD_NAME, fieldValue);
                 newFieldValue.put(
-                    SemanticTextFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME,
-                    response.getResult().asMap(fieldName).get(fieldName)
+                    SemanticTextFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME,
+                    response.getResult().asMap(fieldName).get(fieldName)
                 );
-                sourceAsMap.put(fieldName, newFieldValue);
+                ingestDocument.setFieldValue(fieldName, newFieldValue);
+
+                // Run inference for the next fields
+                runInferenceForFields(indexRequest, nextFieldNames, ref, position, ingestDocument, onFailure);
             }
 
             @Override
             public void onFailure(Exception e) {
                 // Wrap exception in an illegal argument exception, as there is a problem with the model or model config
                 onFailure.accept(
-                    position,
-                    new IllegalArgumentException("Error performing inference for field [" + fieldName + "]: " + e.getMessage(), e)
+                    position,
+                    new IllegalArgumentException("Error performing inference for field [" + fieldName + "]: " + e.getMessage(), e)
                 );
-                ingestMetric.ingestFailed();
+                ref.close();
            }
-        }, () -> {
-            // regardless of success or failure, we always stop the ingest "stopwatch" and release the ref to indicate
-            // that we're finished with this document
-            indexRequest.isFieldInferenceDone(true);
-            final long ingestTimeInNanos = System.nanoTime() - startTimeInNanos;
-            ingestMetric.postIngest(ingestTimeInNanos);
-            refs.close();
-        }));
+        });
     }
 }
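
Notes:

runInferenceForFields drives one async inference call at a time and only
recurses from the success callback, so the single acquired ref is released
exactly once and each bulk item yields a single response for
TransportBulkAction.unwrappingSingleItemBulkResponse. Below is a minimal
standalone sketch of that pattern; InferenceClient and infer are illustrative
stand-ins, not the actual Elasticsearch API:

    import java.util.List;
    import java.util.Map;
    import java.util.function.BiConsumer;

    final class SequentialInferenceSketch {

        /** Stand-in for the async inference client; calls back with a result or an exception. */
        interface InferenceClient {
            void infer(String model, String text, BiConsumer<Map<String, Object>, Exception> callback);
        }

        private final InferenceClient client;

        SequentialInferenceSketch(InferenceClient client) {
            this.client = client;
        }

        void run(List<String> fieldNames, Map<String, Object> source, String model, Runnable releaseRef) {
            if (fieldNames.isEmpty()) {
                releaseRef.run();   // all fields processed: release the single ref exactly once
                return;
            }
            String field = fieldNames.get(0);
            List<String> rest = fieldNames.subList(1, fieldNames.size());
            client.infer(model, (String) source.get(field), (result, e) -> {
                if (e != null) {
                    releaseRef.run();   // a failure stops the chain but still releases the ref
                    return;
                }
                source.put(field, result);
                run(rest, source, model, releaseRef);   // next field starts only after this one finished
            });
        }
    }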
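On success, each semantic text field value is replaced by an object with two
subfields: the original text and the inference output. The literal subfield
names come from SemanticTextFieldMapper.TEXT_SUBFIELD_NAME and
SemanticTextFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME; "text" and "inference"
below are assumed placeholders for illustration:

    import java.util.Map;

    final class SubfieldShapeSketch {
        // Mirrors the onResponse transformation: keep the original value and
        // attach the model output next to it under the same top-level field.
        static Map<String, Object> transformed(String fieldValue, Object inferenceResult) {
            return Map.of(
                "text", fieldValue,            // TEXT_SUBFIELD_NAME (assumed literal)
                "inference", inferenceResult   // SPARSE_VECTOR_SUBFIELD_NAME (assumed literal)
            );
        }
    }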
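getModelForField now resolves every concrete index behind the request and
only returns a model id when all of them agree; a missing mapping, a
conflicting mapping, or a resolution error yields null. The same invariant
restated as a sketch, with a plain Function standing in for the
mapperService().mappingLookup().modelForField() chain:

    import java.util.List;
    import java.util.function.Function;

    final class CommonModelSketch {
        /** Returns the model id shared by all indices, or null when any index
         *  lacks a model for the field or two indices disagree. */
        static String commonModel(List<String> indices, Function<String, String> modelForIndex) {
            String model = null;
            for (String index : indices) {
                String m = modelForIndex.apply(index);
                if (m == null) {
                    return null;    // field has no model in this index
                }
                if (model != null && m.equals(model) == false) {
                    return null;    // two indices point at different models
                }
                model = m;
            }
            return model;
        }
    }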