Made inference sequential so it works better with multiple fields
carlosdelest committed Oct 31, 2023
1 parent ac89ac5 commit 338ecd7
Showing 1 changed file with 59 additions and 45 deletions.
@@ -18,15 +18,16 @@
import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver;
import org.elasticsearch.cluster.metadata.Metadata;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.index.IndexService;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.index.Index;
import org.elasticsearch.index.mapper.SemanticTextFieldMapper;
import org.elasticsearch.indices.IndicesService;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.plugins.internal.DocumentParsingObserver;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.BiConsumer;
import java.util.function.IntConsumer;
import java.util.function.Supplier;
@@ -65,22 +66,17 @@ protected void processIndexRequest(
) {
assert indexRequest.isFieldInferenceDone() == false;

refs.acquire();
// Inference responses can update the fields concurrently
final Map<String, Object> sourceMap = new ConcurrentHashMap<>(indexRequest.sourceAsMap());
try (var inferenceRefs = new RefCountingRunnable(() -> onInferenceComplete(refs, indexRequest, sourceMap))) {
sourceMap.entrySet()
.stream()
.filter(entry -> fieldNeedsInference(indexRequest, entry.getKey(), entry.getValue()))
.forEach(entry -> {
runInferenceForField(indexRequest, entry.getKey(), inferenceRefs, slot, sourceMap, onFailure);
});
}
}
IngestDocument ingestDocument = newIngestDocument(indexRequest);
List<String> fieldNames = ingestDocument.getSource().entrySet()
.stream()
.filter(entry -> fieldNeedsInference(indexRequest, entry.getKey(), entry.getValue()))
.map(Map.Entry::getKey)
.toList();

// Runs inference sequentially. This makes synchronization easier and avoids the problem of having multiple
// BulkItemResponses for a single bulk request in TransportBulkAction.unwrappingSingleItemBulkResponse
runInferenceForFields(indexRequest, fieldNames, refs.acquire(), slot, ingestDocument, onFailure);

private void onInferenceComplete(RefCountingRunnable refs, IndexRequest indexRequest, Map<String, Object> sourceMap) {
updateIndexRequestSource(indexRequest, newIngestDocument(indexRequest, sourceMap));
refs.close();
}

@Override
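For context on the change above: instead of fanning out one concurrent inference call per field over a shared ConcurrentHashMap, the new code walks the field list recursively, issuing the next async call from the previous call's callback. A minimal, self-contained sketch of that pattern follows; the names (SequentialRunner, asyncCall) are illustrative and not part of the Elasticsearch code.

import java.util.List;
import java.util.function.BiConsumer;

class SequentialRunner {

    // Walks the list head-first: each callback schedules the next element, so at
    // most one async call is in flight at a time and completion is signalled once.
    static void runSequentially(List<String> fields, Runnable onDone, BiConsumer<String, Runnable> asyncCall) {
        if (fields.isEmpty()) {
            onDone.run(); // base case: every field processed
            return;
        }
        String head = fields.get(0);
        List<String> tail = fields.subList(1, fields.size());
        asyncCall.accept(head, () -> runSequentially(tail, onDone, asyncCall));
    }

    public static void main(String[] args) {
        runSequentially(
            List.of("title", "body"),
            () -> System.out.println("all fields done"),
            (field, next) -> {
                System.out.println("inferring " + field); // stand-in for the inference client call
                next.run();
            }
        );
    }
}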
@@ -112,26 +108,50 @@ private boolean fieldNeedsInference(IndexRequest indexRequest, String fieldName,
}

private String getModelForField(IndexRequest indexRequest, String fieldName) {
IndexService indexService = indicesService.indexService(
indexNameExpressionResolver.concreteSingleIndex(clusterService.state(), indexRequest)
);
return indexService.mapperService().mappingLookup().modelForField(fieldName);
// Check that all indices related to the request have the same model id
String model = null;
try {
Index[] indices = indexNameExpressionResolver.concreteIndices(clusterService.state(), indexRequest);
for (Index index : indices) {
String modelForIndex = indicesService.indexService(index).mapperService().mappingLookup().modelForField(fieldName);
if (modelForIndex == null) {
return null;
}
if ((model != null) && modelForIndex.equals(model) == false) {
return null;
}
model = modelForIndex;
}
} catch (Exception e) {
// There was a problem retrieving one of the indices
return null;
}
return model;
}
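The rewritten getModelForField has to produce one answer even when the request resolves to several concrete indices (for example, when writing through an alias): it returns the model id only if every index maps the field to the same non-null model, and null otherwise, in which case the field is skipped. A standalone sketch of that rule over plain maps, with made-up index data:

import java.util.List;
import java.util.Map;

class ModelResolution {

    // Mirrors the loop above: bail out with null as soon as any index lacks a
    // model for the field or two indices disagree; otherwise return the shared id.
    static String sharedModel(List<Map<String, String>> modelPerIndex, String field) {
        String model = null;
        for (Map<String, String> mapping : modelPerIndex) {
            String candidate = mapping.get(field);
            if (candidate == null) {
                return null; // field has no model on this index
            }
            if (model != null && candidate.equals(model) == false) {
                return null; // indices disagree on the model
            }
            model = candidate;
        }
        return model;
    }

    public static void main(String[] args) {
        Map<String, String> idx1 = Map.of("body", "my-elser-model");
        Map<String, String> idx2 = Map.of("body", "my-elser-model");
        Map<String, String> idx3 = Map.of("body", "another-model");
        System.out.println(sharedModel(List.of(idx1, idx2), "body")); // my-elser-model
        System.out.println(sharedModel(List.of(idx1, idx3), "body")); // null
    }
}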

private void runInferenceForField(
private void runInferenceForFields(
IndexRequest indexRequest,
String fieldName,
RefCountingRunnable refs,
List<String> fieldNames,
Releasable ref,
int position,
final Map<String, Object> sourceAsMap,
final IngestDocument ingestDocument,
BiConsumer<Integer, Exception> onFailure
) {
final String fieldValue = (String) sourceAsMap.get(fieldName);
if (fieldValue == null) {
// We have finished processing all fields
if (fieldNames.isEmpty()) {
updateIndexRequestSource(indexRequest, ingestDocument);
indexRequest.isFieldInferenceDone(true);
ref.close();
return;
}
String fieldName = fieldNames.get(0);
List<String> nextFieldNames = fieldNames.subList(1, fieldNames.size());
final String fieldValue = ingestDocument.getFieldValue(fieldName, String.class);
if (fieldValue == null) {
// Nothing to infer for this field; run inference for the next one
runInferenceForFields(indexRequest, nextFieldNames, ref, position, ingestDocument, onFailure);
return;
}

refs.acquire();
String modelForField = getModelForField(indexRequest, fieldName);
assert modelForField != null : "Field " + fieldName + " has no model associated in mappings";

@@ -143,37 +163,31 @@ private void runInferenceForField(
Map.of()
);

final long startTimeInNanos = System.nanoTime();
ingestMetric.preIngest();
client.execute(InferenceAction.INSTANCE, inferenceRequest, ActionListener.runAfter(new ActionListener<InferenceAction.Response>() {
client.execute(InferenceAction.INSTANCE, inferenceRequest, new ActionListener<>() {
@Override
public void onResponse(InferenceAction.Response response) {
// Transform the field into two subfields, one with the original text and the other with the inference results
Map<String, Object> newFieldValue = new HashMap<>();
newFieldValue.put(SemanticTextFieldMapper.TEXT_SUBFIELD_NAME, fieldValue);
newFieldValue.put(
SemanticTextFieldMapper.SPARSE_VECTOR_SUBFIELD_NAME,
response.getResult().asMap(fieldName).get(fieldName)
);
sourceAsMap.put(fieldName, newFieldValue);
ingestDocument.setFieldValue(fieldName, newFieldValue);

// Run inference for the next fields
runInferenceForFields(indexRequest, nextFieldNames, ref, position, ingestDocument, onFailure);
}

@Override
public void onFailure(Exception e) {
// Wrap exception in an illegal argument exception, as there is a problem with the model or model config
onFailure.accept(
position,
new IllegalArgumentException("Error performing inference for field [" + fieldName + "]: " + e.getMessage(), e)
);
ingestMetric.ingestFailed();
ref.close();
}
}, () -> {
// regardless of success or failure, we always stop the ingest "stopwatch" and release the ref to indicate
// that we're finished with this document
indexRequest.isFieldInferenceDone(true);
final long ingestTimeInNanos = System.nanoTime() - startTimeInNanos;
ingestMetric.postIngest(ingestTimeInNanos);
refs.close();
}));
});
}
}
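For readers following the onResponse branch: a successful inference call rewrites the field in the document source from a plain string into an object with two subfields, one keeping the original text and one holding the sparse vector produced by the model. A rough before/after sketch with plain maps; the literal keys "text" and "sparse_vector" are stand-ins for the SemanticTextFieldMapper subfield-name constants:

import java.util.HashMap;
import java.util.Map;

class SubfieldRewrite {
    public static void main(String[] args) {
        // before: the user indexed "body" as raw text
        Map<String, Object> source = new HashMap<>();
        source.put("body", "the quick brown fox");

        // pretend inference output: a sparse vector of token weights
        Map<String, Double> sparseVector = Map.of("quick", 1.8, "fox", 2.3);

        // after: the field value becomes an object carrying both representations
        Map<String, Object> newFieldValue = new HashMap<>();
        newFieldValue.put("text", source.get("body"));      // TEXT_SUBFIELD_NAME
        newFieldValue.put("sparse_vector", sparseVector);   // SPARSE_VECTOR_SUBFIELD_NAME
        source.put("body", newFieldValue);

        System.out.println(source); // {body={sparse_vector={...}, text=the quick brown fox}} (map order may vary)
    }
}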
