Merge from feature branch
carlosdelest committed Mar 19, 2024
1 parent dd73d01 commit c5de0da
Showing 4 changed files with 94 additions and 94 deletions.
Changed file 1 of 4
@@ -209,8 +209,8 @@ private void executeBulkRequestsByShard(Map<ShardId, List<BulkItemRequest>> requ
                 requests.toArray(new BulkItemRequest[0])
             );
             var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName());
-            if (indexMetadata != null && indexMetadata.getFieldsForModels().isEmpty() == false) {
-                bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldsForModels());
+            if (indexMetadata != null && indexMetadata.getFieldInferenceMetadata().isEmpty() == false) {
+                bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldInferenceMetadata());
             }
             bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards());
             bulkShardRequest.timeout(bulkRequest.timeout());
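Note on the new type: this hunk replaces the coordinator-side lookup of getFieldsForModels() (a Map<String, Set<String>> keyed by inference ID) with getFieldInferenceMetadata(), which returns a FieldInferenceMetadata object keyed by field name. That class is not part of this commit, so the following is only a minimal sketch of the shape implied by its call sites in this diff (getFieldInferenceOptions(), isEmpty(), and the FieldInferenceOptions(inferenceId, sourceFields) constructor used in the tests); the real org.elasticsearch.cluster.metadata.FieldInferenceMetadata may differ.

// Sketch only (not from this commit): minimal shape of FieldInferenceMetadata implied by its call sites here.
import java.util.Map;
import java.util.Set;

public record FieldInferenceMetadata(Map<String, FieldInferenceMetadata.FieldInferenceOptions> fieldInferenceOptions) {

    // One entry per inference field: the inference endpoint to call plus any extra source fields feeding it.
    public record FieldInferenceOptions(String inferenceId, Set<String> sourceFields) {}

    public Map<String, FieldInferenceOptions> getFieldInferenceOptions() {
        return fieldInferenceOptions;
    }

    public boolean isEmpty() {
        return fieldInferenceOptions.isEmpty();
    }
}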
Changed file 2 of 4
@@ -15,14 +15,14 @@
 import org.elasticsearch.action.support.replication.ReplicatedWriteRequest;
 import org.elasticsearch.action.support.replication.ReplicationRequest;
 import org.elasticsearch.action.update.UpdateRequest;
+import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.util.set.Sets;
 import org.elasticsearch.index.shard.ShardId;
 import org.elasticsearch.transport.RawIndexingDataTransportRequest;

 import java.io.IOException;
-import java.util.Map;
 import java.util.Set;

 public final class BulkShardRequest extends ReplicatedWriteRequest<BulkShardRequest>
@@ -34,7 +34,7 @@ public final class BulkShardRequest extends ReplicatedWriteRequ

     private final BulkItemRequest[] items;

-    private transient Map<String, Set<String>> fieldsInferenceMetadata = null;
+    private transient FieldInferenceMetadata fieldsInferenceMetadataMap = null;

     public BulkShardRequest(StreamInput in) throws IOException {
         super(in);
@@ -51,24 +51,24 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRe
      * Public for test
      * Set the transient metadata indicating that this request requires running inference before proceeding.
      */
-    public void setFieldInferenceMetadata(Map<String, Set<String>> fieldsInferenceMetadata) {
-        this.fieldsInferenceMetadata = fieldsInferenceMetadata;
+    public void setFieldInferenceMetadata(FieldInferenceMetadata fieldsInferenceMetadata) {
+        this.fieldsInferenceMetadataMap = fieldsInferenceMetadata;
     }

     /**
      * Consumes the inference metadata to execute inference on the bulk items just once.
      */
-    public Map<String, Set<String>> consumeFieldInferenceMetadata() {
-        var ret = fieldsInferenceMetadata;
-        fieldsInferenceMetadata = null;
+    public FieldInferenceMetadata consumeFieldInferenceMetadata() {
+        FieldInferenceMetadata ret = fieldsInferenceMetadataMap;
+        fieldsInferenceMetadataMap = null;
         return ret;
     }

     /**
      * Public for test
      */
-    public Map<String, Set<String>> getFieldsInferenceMetadata() {
-        return fieldsInferenceMetadata;
+    public FieldInferenceMetadata getFieldsInferenceMetadataMap() {
+        return fieldsInferenceMetadataMap;
     }

     public long totalSizeInBytes() {
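The BulkShardRequest change keeps the inference metadata transient and consumable at most once: consumeFieldInferenceMetadata() nulls the field, so a second call returns null and the filter cannot run inference twice for the same request. A standalone usage sketch, reusing the constructor calls that appear in the tests of this commit (the field name "foo" and endpoint id "bar" are placeholder values, not part of the commit):

// Usage sketch only (not part of the commit).
import org.elasticsearch.action.bulk.BulkItemRequest;
import org.elasticsearch.action.bulk.BulkShardRequest;
import org.elasticsearch.action.support.WriteRequest;
import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
import org.elasticsearch.index.shard.ShardId;

import java.util.Map;
import java.util.Set;

class ConsumeOnceSketch {
    static void demo() {
        BulkShardRequest request = new BulkShardRequest(
            new ShardId("test", "test", 0),
            WriteRequest.RefreshPolicy.NONE,
            new BulkItemRequest[0]
        );
        // "foo" is the inference field, "bar" the inference endpoint id (hypothetical values).
        request.setFieldInferenceMetadata(
            new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of())))
        );

        FieldInferenceMetadata first = request.consumeFieldInferenceMetadata();   // the metadata set above
        FieldInferenceMetadata second = request.consumeFieldInferenceMetadata();  // null: already consumed
        assert first != null && second == null;
        assert request.getFieldsInferenceMetadataMap() == null;
    }
}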
Changed file 3 of 4
@@ -24,6 +24,7 @@
 import org.elasticsearch.action.support.ActionFilterChain;
 import org.elasticsearch.action.support.RefCountingRunnable;
 import org.elasticsearch.action.update.UpdateRequest;
+import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
@@ -44,7 +45,6 @@
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.stream.Collectors;

 /**
@@ -81,7 +81,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
             case TransportShardBulkAction.ACTION_NAME:
                 BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
                 var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata();
-                if (fieldInferenceMetadata != null && fieldInferenceMetadata.size() > 0) {
+                if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
                     Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
                     processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
                 } else {
@@ -96,7 +96,7 @@ public <Request extends ActionRequest, Response extends ActionResponse> void app
     }

     private void processBulkShardRequest(
-        Map<String, Set<String>> fieldInferenceMetadata,
+        FieldInferenceMetadata fieldInferenceMetadata,
         BulkShardRequest bulkShardRequest,
         Runnable onCompletion
     ) {
@@ -112,13 +112,13 @@ private record FieldInferenceResponse(String field, Model model, ChunkedInferenc
     private record FieldInferenceResponseAccumulator(int id, List<FieldInferenceResponse> responses, List<Exception> failures) {}

     private class AsyncBulkShardInferenceAction implements Runnable {
-        private final Map<String, Set<String>> fieldInferenceMetadata;
+        private final FieldInferenceMetadata fieldInferenceMetadata;
         private final BulkShardRequest bulkShardRequest;
         private final Runnable onCompletion;
         private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;

         private AsyncBulkShardInferenceAction(
-            Map<String, Set<String>> fieldInferenceMetadata,
+            FieldInferenceMetadata fieldInferenceMetadata,
             BulkShardRequest bulkShardRequest,
             Runnable onCompletion
         ) {
@@ -289,39 +289,35 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(Bu
                 continue;
             }
             final Map<String, Object> docMap = indexRequest.sourceAsMap();
-            for (var entry : fieldInferenceMetadata.entrySet()) {
-                String inferenceId = entry.getKey();
-                for (var field : entry.getValue()) {
-                    var value = XContentMapValues.extractValue(field, docMap);
-                    if (value == null) {
-                        continue;
-                    }
-                    if (inferenceResults.get(item.id()) == null) {
-                        inferenceResults.set(
+            for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) {
+                String field = entry.getKey();
+                String inferenceId = entry.getValue().inferenceId();
+                var value = XContentMapValues.extractValue(field, docMap);
+                if (value == null) {
+                    continue;
+                }
+                if (inferenceResults.get(item.id()) == null) {
+                    inferenceResults.set(
+                        item.id(),
+                        new FieldInferenceResponseAccumulator(
                             item.id(),
-                            new FieldInferenceResponseAccumulator(
-                                item.id(),
-                                Collections.synchronizedList(new ArrayList<>()),
-                                Collections.synchronizedList(new ArrayList<>())
-                            )
-                        );
-                    }
-                    if (value instanceof String valueStr) {
-                        List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(
-                            inferenceId,
-                            k -> new ArrayList<>()
-                        );
-                        fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr));
-                    } else {
-                        inferenceResults.get(item.id()).failures.add(
-                            new ElasticsearchStatusException(
-                                "Invalid format for field [{}], expected [String] got [{}]",
-                                RestStatus.BAD_REQUEST,
-                                field,
-                                value.getClass().getSimpleName()
-                            )
-                        );
-                    }
+                            Collections.synchronizedList(new ArrayList<>()),
+                            Collections.synchronizedList(new ArrayList<>())
+                        )
+                    );
+                }
+                if (value instanceof String valueStr) {
+                    List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
+                    fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr));
+                } else {
+                    inferenceResults.get(item.id()).failures.add(
+                        new ElasticsearchStatusException(
+                            "Invalid format for field [{}], expected [String] got [{}]",
+                            RestStatus.BAD_REQUEST,
+                            field,
+                            value.getClass().getSimpleName()
+                        )
+                    );
                 }
             }
         }
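In the rewritten createFieldInferenceRequests() loop above, iteration is now per target field (field -> FieldInferenceOptions), while the extracted values are still batched per inference endpoint id so each endpoint is invoked once per bulk shard request. A standalone sketch of that grouping, with hypothetical field names and endpoint id and simplified record types standing in for the real classes:

// Standalone sketch (hypothetical names, simplified types; not part of the commit).
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;

public class GroupByInferenceIdSketch {
    record FieldInferenceOptions(String inferenceId, Set<String> sourceFields) {}
    record FieldInferenceRequest(int id, String field, String value) {}

    public static void main(String[] args) {
        // Two inference fields configured to use the same (hypothetical) endpoint.
        Map<String, FieldInferenceOptions> fieldOptions = Map.of(
            "title_semantic", new FieldInferenceOptions("my-elser-endpoint", Set.of()),
            "body_semantic", new FieldInferenceOptions("my-elser-endpoint", Set.of())
        );
        Map<String, Object> doc = Map.of("title_semantic", "a short title", "body_semantic", "some longer body text");

        int itemId = 0;
        Map<String, List<FieldInferenceRequest>> requestsByInferenceId = new HashMap<>();
        for (var entry : fieldOptions.entrySet()) {
            Object value = doc.get(entry.getKey());
            if (value instanceof String valueStr) {
                requestsByInferenceId.computeIfAbsent(entry.getValue().inferenceId(), k -> new ArrayList<>())
                    .add(new FieldInferenceRequest(itemId, entry.getKey(), valueStr));
            }
        }
        // Both fields end up in one batch under "my-elser-endpoint".
        System.out.println(requestsByInferenceId);
    }
}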
Changed file 4 of 4
@@ -16,6 +16,7 @@
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.support.ActionFilterChain;
 import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.index.shard.ShardId;
@@ -40,7 +41,6 @@

 import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -79,7 +79,7 @@ public void testFilterNoop() throws Exception {
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
-                assertNull(((BulkShardRequest) request).getFieldsInferenceMetadata());
+                assertNull(((BulkShardRequest) request).getFieldsInferenceMetadataMap());
             } finally {
                 chainExecuted.countDown();
             }
@@ -91,7 +91,9 @@ public void testFilterNoop() throws Exception {
             WriteRequest.RefreshPolicy.NONE,
             new BulkItemRequest[0]
         );
-        request.setFieldInferenceMetadata(Map.of("foo", Set.of("bar")));
+        request.setFieldInferenceMetadata(
+            new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of())))
+        );
         filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
         awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
     }
@@ -104,7 +106,7 @@ public void testInferenceNotFound() throws Exception {
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
                 BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
-                assertNull(bulkShardRequest.getFieldsInferenceMetadata());
+                assertNull(bulkShardRequest.getFieldsInferenceMetadataMap());
                 for (BulkItemRequest item : bulkShardRequest.items()) {
                     assertNotNull(item.getPrimaryResponse());
                     assertTrue(item.getPrimaryResponse().isFailed());
@@ -118,11 +120,15 @@ public void testInferenceNotFound() throws Exception {
         ActionListener actionListener = mock(ActionListener.class);
         Task task = mock(Task.class);

-        Map<String, Set<String>> inferenceFields = Map.of(
-            model.getInferenceEntityId(),
-            Set.of("field1"),
-            "inference_0",
-            Set.of("field2", "field3")
+        FieldInferenceMetadata inferenceFields = new FieldInferenceMetadata(
+            Map.of(
+                "field1",
+                new FieldInferenceMetadata.FieldInferenceOptions(model.getInferenceEntityId(), Set.of()),
+                "field2",
+                new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()),
+                "field3",
+                new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of())
+            )
         );
         BulkItemRequest[] items = new BulkItemRequest[10];
         for (int i = 0; i < items.length; i++) {
@@ -144,19 +150,19 @@ public void testManyRandomDocs() throws Exception {
         }

         int numInferenceFields = randomIntBetween(1, 5);
-        Map<String, Set<String>> inferenceFields = new HashMap<>();
+        Map<String, FieldInferenceMetadata.FieldInferenceOptions> inferenceFieldsMap = new HashMap<>();
         for (int i = 0; i < numInferenceFields; i++) {
-            String inferenceId = randomFrom(inferenceModelMap.keySet());
             String field = randomAlphaOfLengthBetween(5, 10);
-            var res = inferenceFields.computeIfAbsent(inferenceId, k -> new HashSet<>());
-            res.add(field);
+            String inferenceId = randomFrom(inferenceModelMap.keySet());
+            inferenceFieldsMap.put(field, new FieldInferenceMetadata.FieldInferenceOptions(inferenceId, Set.of()));
         }
+        FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(inferenceFieldsMap);

         int numRequests = randomIntBetween(100, 1000);
         BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests];
         BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests];
         for (int id = 0; id < numRequests; id++) {
-            BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFields);
+            BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, fieldInferenceMetadata);
             originalRequests[id] = res[0];
             modifiedRequests[id] = res[1];
         }
@@ -167,7 +173,7 @@ public void testManyRandomDocs() throws Exception {
             try {
                 assertThat(request, instanceOf(BulkShardRequest.class));
                 BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
-                assertNull(bulkShardRequest.getFieldsInferenceMetadata());
+                assertNull(bulkShardRequest.getFieldsInferenceMetadataMap());
                 BulkItemRequest[] items = bulkShardRequest.items();
                 assertThat(items.length, equalTo(originalRequests.length));
                 for (int id = 0; id < items.length; id++) {
@@ -186,7 +192,7 @@ public void testManyRandomDocs() throws Exception {
         ActionListener actionListener = mock(ActionListener.class);
         Task task = mock(Task.class);
         BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests);
-        original.setFieldInferenceMetadata(inferenceFields);
+        original.setFieldInferenceMetadata(fieldInferenceMetadata);
         filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain);
         awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
     }
@@ -257,42 +263,40 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool
     private static BulkItemRequest[] randomBulkItemRequest(
         int id,
         Map<String, StaticModel> modelMap,
-        Map<String, Set<String>> inferenceFieldMap
+        FieldInferenceMetadata fieldInferenceMetadata
     ) {
         Map<String, Object> docMap = new LinkedHashMap<>();
         Map<String, Object> inferenceResultsMap = new LinkedHashMap<>();
-        for (var entry : inferenceFieldMap.entrySet()) {
-            String inferenceId = entry.getKey();
-            var model = modelMap.get(inferenceId);
-            for (var field : entry.getValue()) {
-                String text = randomAlphaOfLengthBetween(10, 100);
-                docMap.put(field, text);
-                if (model == null) {
-                    // ignore results, the doc should fail with a resource not found exception
-                    continue;
-                }
-                int numChunks = randomIntBetween(1, 5);
-                List<String> chunks = new ArrayList<>();
-                for (int i = 0; i < numChunks; i++) {
-                    chunks.add(randomAlphaOfLengthBetween(5, 10));
-                }
-                TaskType taskType = model.getTaskType();
-                final ChunkedInferenceServiceResults results;
-                switch (taskType) {
-                    case TEXT_EMBEDDING:
-                        results = randomTextEmbeddings(chunks);
-                        break;
+        for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) {
+            String field = entry.getKey();
+            var model = modelMap.get(entry.getValue().inferenceId());
+            String text = randomAlphaOfLengthBetween(10, 100);
+            docMap.put(field, text);
+            if (model == null) {
+                // ignore results, the doc should fail with a resource not found exception
+                continue;
+            }
+            int numChunks = randomIntBetween(1, 5);
+            List<String> chunks = new ArrayList<>();
+            for (int i = 0; i < numChunks; i++) {
+                chunks.add(randomAlphaOfLengthBetween(5, 10));
+            }
+            TaskType taskType = model.getTaskType();
+            final ChunkedInferenceServiceResults results;
+            switch (taskType) {
+                case TEXT_EMBEDDING:
+                    results = randomTextEmbeddings(chunks);
+                    break;

-                    case SPARSE_EMBEDDING:
-                        results = randomSparseEmbeddings(chunks);
-                        break;
+                case SPARSE_EMBEDDING:
+                    results = randomSparseEmbeddings(chunks);
+                    break;

-                    default:
-                        throw new AssertionError("Unknown task type " + taskType.name());
-                }
-                model.putResult(text, results);
-                InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results);
+                default:
+                    throw new AssertionError("Unknown task type " + taskType.name());
             }
+            model.putResult(text, results);
+            InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results);
         }
         Map<String, Object> expectedDocMap = new LinkedHashMap<>(docMap);
         expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap);
