add more tests
jimczi committed Mar 18, 2024
1 parent ebc26d2 commit 86ddc9d
Showing 8 changed files with 544 additions and 153 deletions.
@@ -48,10 +48,10 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe
}

/**
* Set the transient metadata indicating that this request requires running inference
* before proceeding.
* Public for test
* Set the transient metadata indicating that this request requires running inference before proceeding.
*/
void setFieldInferenceMetadata(Map<String, Set<String>> fieldsInferenceMetadata) {
public void setFieldInferenceMetadata(Map<String, Set<String>> fieldsInferenceMetadata) {
this.fieldsInferenceMetadata = fieldsInferenceMetadata;
}

@@ -64,6 +64,13 @@ public Map<String, Set<String>> consumeFieldInferenceMetadata() {
return ret;
}

/**
* Public for test
*/
public Map<String, Set<String>> getFieldsInferenceMetadata() {
return fieldsInferenceMetadata;
}

public long totalSizeInBytes() {
long totalSizeInBytes = 0;
for (int i = 0; i < items.length; i++) {
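Since the commit message is "add more tests" and the setter above is now public "for test", here is a minimal sketch of how a test could exercise the accessors. The endpoint id, field names, and surrounding fixture are assumptions for illustration, not code from this commit:

import org.elasticsearch.action.bulk.BulkItemRequest;
import org.elasticsearch.action.bulk.BulkShardRequest;
import org.elasticsearch.action.support.WriteRequest.RefreshPolicy;
import org.elasticsearch.index.shard.ShardId;

import java.util.Map;
import java.util.Set;

class FieldInferenceMetadataSketch {
    // Hypothetical helper: shardId and items would come from the test fixture.
    static void roundTrip(ShardId shardId, BulkItemRequest[] items) {
        BulkShardRequest request = new BulkShardRequest(shardId, RefreshPolicy.NONE, items);
        // Keys are inference endpoint ids, values are the source fields routed to that endpoint.
        Map<String, Set<String>> metadata = Map.of("my-inference-id", Set.of("title", "body"));
        request.setFieldInferenceMetadata(metadata);
        // The new getter exposes the map without consuming it.
        assert request.getFieldsInferenceMetadata() == metadata;
        // consumeFieldInferenceMetadata() is expected to hand back the same mapping once,
        // for one-time use by the inference action filter.
        Map<String, Set<String>> consumed = request.consumeFieldInferenceMetadata();
        assert consumed.equals(metadata);
    }
}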
@@ -1086,6 +1086,10 @@ public String typeName() {
return CONTENT_TYPE;
}

public Integer getDims() {
return dims;
}

@Override
public ValueFetcher valueFetcher(SearchExecutionContext context, String format) {
if (format != null) {
@@ -55,8 +55,8 @@
import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderFactory;
import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings;
import org.elasticsearch.xpack.inference.logging.ThrottlerManager;
import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;
import org.elasticsearch.xpack.inference.rest.RestDeleteInferenceModelAction;
import org.elasticsearch.xpack.inference.rest.RestGetInferenceModelAction;
@@ -285,7 +285,7 @@ public Map<String, Mapper.TypeParser> getMappers() {

@Override
public Map<String, MetadataFieldMapper.TypeParser> getMetadataMappers() {
return Map.of(SemanticTextInferenceResultFieldMapper.NAME, SemanticTextInferenceResultFieldMapper.PARSER);
return Map.of(InferenceResultFieldMapper.NAME, InferenceResultFieldMapper.PARSER);
}

@Override
@@ -10,6 +10,7 @@
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.ResourceNotFoundException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRequest;
@@ -33,11 +34,9 @@
import org.elasticsearch.inference.InferenceServiceRegistry;
import org.elasticsearch.inference.InputType;
import org.elasticsearch.inference.Model;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.tasks.Task;
import org.elasticsearch.xpack.core.inference.results.ChunkedSparseEmbeddingResults;
import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults;
import org.elasticsearch.xpack.inference.mapper.SemanticTextInferenceResultFieldMapper;
import org.elasticsearch.xpack.inference.mapper.SemanticTextModelSettings;
import org.elasticsearch.xpack.inference.mapper.InferenceResultFieldMapper;
import org.elasticsearch.xpack.inference.registry.ModelRegistry;

import java.util.ArrayList;
@@ -50,7 +49,7 @@

/**
* An {@link ActionFilter} that performs inference on {@link BulkShardRequest} asynchronously and stores the results in
* the individual {@link BulkItemRequest}. The results are then consumed by the {@link SemanticTextInferenceResultFieldMapper}
* the individual {@link BulkItemRequest}. The results are then consumed by the {@link InferenceResultFieldMapper}
* in the subsequent {@link TransportShardBulkAction} downstream.
*/
public class ShardBulkInferenceActionFilter implements ActionFilter {
@@ -82,7 +81,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
case TransportShardBulkAction.ACTION_NAME:
BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata();
if (fieldInferenceMetadata != null) {
if (fieldInferenceMetadata != null && fieldInferenceMetadata.size() > 0) {
Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
} else {
@@ -110,18 +109,7 @@ private record FieldInferenceRequest(int id, String field, String input) {}

private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults chunkedResults) {}

private record FieldInferenceResponseAccumulator(int id, List<FieldInferenceResponse> responses, List<Exception> failures) {
Exception createFailureOrNull() {
if (failures.isEmpty()) {
return null;
}
Exception main = failures.get(0);
for (int i = 1; i < failures.size(); i++) {
main.addSuppressed(failures.get(i));
}
return main;
}
}
private record FieldInferenceResponseAccumulator(int id, List<FieldInferenceResponse> responses, List<Exception> failures) {}

private class AsyncBulkShardInferenceAction implements Runnable {
private final Map<String, Set<String>> fieldInferenceMetadata;
@@ -147,7 +135,11 @@ public void run() {
try {
for (var inferenceResponse : inferenceResults.asList()) {
var request = bulkShardRequest.items()[inferenceResponse.id];
applyInference(request, inferenceResponse);
try {
applyInferenceResponses(request, inferenceResponse);
} catch (Exception exc) {
request.abort(bulkShardRequest.index(), exc);
}
}
} finally {
onCompletion.run();
@@ -189,8 +181,8 @@ public void onResponse(ModelRegistry.UnparsedModel unparsedModel) {
var request = requests.get(i);
inferenceResults.get(request.id).failures.add(
new ResourceNotFoundException(
"Inference service [{}] not found for field [{}]",
unparsedModel.service(),
"Inference id [{}] not found for field [{}]",
inferenceId,
request.field
)
);
@@ -221,9 +213,8 @@ public void onResponse(List<ChunkedInferenceServiceResults> results) {
for (int i = 0; i < results.size(); i++) {
var request = requests.get(i);
var result = results.get(i);
inferenceResults.get(request.id).responses.add(
new FieldInferenceResponse(request.field, inferenceProvider.model, result)
);
var acc = inferenceResults.get(request.id);
acc.responses.add(new FieldInferenceResponse(request.field, inferenceProvider.model, result));
}
}

@@ -254,38 +245,34 @@ public void onFailure(Exception exc) {
}

/**
* Apply the {@link FieldInferenceResponseAccumulator} to the provider {@link BulkItemRequest}.
* Applies the {@link FieldInferenceResponseAccumulator} to the provided {@link BulkItemRequest}.
* If the response contains failures, the bulk item request is marked as failed for the downstream action.
* Otherwise, the source of the request is augmented with the field inference results.
*/
private void applyInference(BulkItemRequest request, FieldInferenceResponseAccumulator inferenceResponse) {
Exception failure = inferenceResponse.createFailureOrNull();
if (failure != null) {
request.abort(bulkShardRequest.index(), failure);
private void applyInferenceResponses(BulkItemRequest item, FieldInferenceResponseAccumulator response) {
if (response.failures().isEmpty() == false) {
for (var failure : response.failures()) {
item.abort(item.index(), failure);
}
return;
}
final IndexRequest indexRequest = getIndexRequestOrNull(request.request());
final Map<String, Object> newDocMap = indexRequest.sourceAsMap();
final Map<String, Object> inferenceMetadataMap = new LinkedHashMap<>();
newDocMap.put(SemanticTextInferenceResultFieldMapper.NAME, inferenceMetadataMap);
for (FieldInferenceResponse fieldResponse : inferenceResponse.responses) {
List<Map<String, Object>> chunks = new ArrayList<>();
if (fieldResponse.chunkedResults instanceof ChunkedSparseEmbeddingResults textExpansionResults) {
for (var chunk : textExpansionResults.getChunkedResults()) {
chunks.add(chunk.asMap());
}
} else if (fieldResponse.chunkedResults instanceof ChunkedTextEmbeddingResults textEmbeddingResults) {
for (var chunk : textEmbeddingResults.getChunks()) {
chunks.add(chunk.asMap());
}
} else {
request.abort(bulkShardRequest.index(), new IllegalArgumentException("TODO"));
return;

final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
Map<String, Object> newDocMap = indexRequest.sourceAsMap();
Map<String, Object> inferenceMap = new LinkedHashMap<>();
// ignore the existing inference map if any
newDocMap.put(InferenceResultFieldMapper.NAME, inferenceMap);
for (FieldInferenceResponse fieldResponse : response.responses()) {
try {
InferenceResultFieldMapper.applyFieldInference(
inferenceMap,
fieldResponse.field(),
fieldResponse.model(),
fieldResponse.chunkedResults()
);
} catch (Exception exc) {
item.abort(item.index(), exc);
}
Map<String, Object> fieldMap = new LinkedHashMap<>();
fieldMap.putAll(new SemanticTextModelSettings(fieldResponse.model).asMap());
fieldMap.put(SemanticTextInferenceResultFieldMapper.INFERENCE_RESULTS, chunks);
inferenceMetadataMap.put(fieldResponse.field, fieldMap);
}
indexRequest.source(newDocMap);
}
@@ -294,38 +281,46 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
Map<String, List<FieldInferenceRequest>> fieldRequestsMap = new LinkedHashMap<>();
for (var item : bulkShardRequest.items()) {
if (item.getPrimaryResponse() != null) {
// item was already aborted/processed by a filter in the chain upstream (e.g. security).
// item was already aborted/processed by a filter in the chain upstream (e.g. security)
continue;
}
final IndexRequest indexRequest = getIndexRequestOrNull(item.request());
if (indexRequest == null) {
continue;
}
final Map<String, Object> docMap = indexRequest.sourceAsMap();
List<FieldInferenceRequest> fieldRequests = null;
for (var pair : fieldInferenceMetadata.entrySet()) {
String inferenceId = pair.getKey();
for (var field : pair.getValue()) {
for (var entry : fieldInferenceMetadata.entrySet()) {
String inferenceId = entry.getKey();
for (var field : entry.getValue()) {
var value = XContentMapValues.extractValue(field, docMap);
if (value == null) {
continue;
}
if (value instanceof String valueStr) {
if (inferenceResults.get(item.id()) == null) {
inferenceResults.set(
if (inferenceResults.get(item.id()) == null) {
inferenceResults.set(
item.id(),
new FieldInferenceResponseAccumulator(
item.id(),
new FieldInferenceResponseAccumulator(
item.id(),
Collections.synchronizedList(new ArrayList<>()),
Collections.synchronizedList(new ArrayList<>())
)
);
}
if (fieldRequests == null) {
fieldRequests = new ArrayList<>();
fieldRequestsMap.put(inferenceId, fieldRequests);
}
Collections.synchronizedList(new ArrayList<>()),
Collections.synchronizedList(new ArrayList<>())
)
);
}
if (value instanceof String valueStr) {
List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(
inferenceId,
k -> new ArrayList<>()
);
fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr));
} else {
inferenceResults.get(item.id()).failures.add(
new ElasticsearchStatusException(
"Invalid format for field [{}], expected [String] got [{}]",
RestStatus.BAD_REQUEST,
field,
value.getClass().getSimpleName()
)
);
}
}
}
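The else branch above is new: a non-string value in a field configured for inference now fails only that bulk item with a BAD_REQUEST instead of being silently skipped as before. A hypothetical document that would take this path (index name and field name are made up):

import org.elasticsearch.action.index.IndexRequest;

import java.util.Map;

class InvalidInferenceFieldSketch {
    static IndexRequest badDocument() {
        // "body" is assumed to be mapped for inference; the integer value triggers
        // "Invalid format for field [body], expected [String] got [Integer]" (BAD_REQUEST),
        // the failure is recorded in the item's accumulator, and only this item is aborted
        // while the rest of the bulk request proceeds.
        return new IndexRequest("my-index").id("1").source(Map.of("body", 42));
    }
}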