More refactoring around runnables
carlosdelest committed Dec 22, 2023
1 parent 4802b23 commit 9bb66e1
Showing 1 changed file with 21 additions and 29 deletions.
@@ -644,8 +644,7 @@ private Map<ShardId, List<BulkItemRequest>> groupRequestsByShards(ClusterState c
             if (ia.getParentDataStream() != null &&
             // avoid valid cases when directly indexing into a backing index
             // (for example when directly indexing into .ds-logs-foobar-000001)
-                ia.getName().equals(docWriteRequest.index()) == false
-                && docWriteRequest.opType() != OpType.CREATE) {
+                ia.getName().equals(docWriteRequest.index()) == false && docWriteRequest.opType() != OpType.CREATE) {
                 throw new IllegalArgumentException("only write ops with an op_type of create are allowed in data streams");
             }

@@ -686,16 +685,15 @@ private void executeBulkRequestsByShard(Map<ShardId, List<BulkItemRequest>> requ
         }

         String nodeId = clusterService.localNode().getId();
-        try(RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(() -> {
+        Runnable onBulkItemsComplete = () -> {
             listener.onResponse(
-                new BulkResponse(
-                    responses.toArray(new BulkItemResponse[responses.length()]),
-                    buildTookInMillis(startTimeNanos)
-                )
+                new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos))
             );
             // Allow memory for bulk shard request items to be reclaimed before all items have been completed
             bulkRequest = null;
-        })) {
+        };
+
+        try (RefCountingRunnable bulkItemRequestCompleteRefCount = new RefCountingRunnable(onBulkItemsComplete)) {
             for (Map.Entry<ShardId, List<BulkItemRequest>> entry : requestsByShard.entrySet()) {
                 final ShardId shardId = entry.getKey();
                 final List<BulkItemRequest> requests = entry.getValue();
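
For readers new to this pattern: the try-with-resources block holds the initial reference, each per-shard dispatch acquires another, and the completion Runnable runs exactly once, after the last reference has been released. The sketch below shows those semantics in plain Java; SimpleRefCountingRunnable is a hypothetical stand-in for illustration, not Elasticsearch's org.elasticsearch.action.support.RefCountingRunnable.

    import java.util.concurrent.atomic.AtomicInteger;

    final class SimpleRefCountingRunnable implements AutoCloseable {
        private final Runnable onAllReleased;
        private final AtomicInteger refs = new AtomicInteger(1); // the initial reference

        SimpleRefCountingRunnable(Runnable onAllReleased) {
            this.onAllReleased = onAllReleased;
        }

        // Hands out one more reference; the returned handle must be closed
        // exactly once (this sketch has no double-release guard).
        AutoCloseable acquire() {
            refs.incrementAndGet();
            return this::release;
        }

        private void release() {
            if (refs.decrementAndGet() == 0) {
                onAllReleased.run(); // last reference gone: run the completion callback
            }
        }

        @Override
        public void close() {
            release(); // drops the initial reference, as leaving the try block does above
        }
    }

Extracting the callback into a named onBulkItemsComplete variable, as this commit does, keeps the try header on a single readable line and lets the completion path be read on its own.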
@@ -717,11 +715,7 @@ private void executeBulkRequestsByShard(Map<ShardId, List<BulkItemRequest>> requ
         }
     }

-    private void performInferenceAndExecute(
-        BulkShardRequest bulkShardRequest,
-        ClusterState clusterState,
-        Releasable releaseOnFinish
-    ) {
+    private void performInferenceAndExecute(BulkShardRequest bulkShardRequest, ClusterState clusterState, Releasable releaseOnFinish) {

         Map<String, Set<String>> fieldsForModels = clusterState.metadata()
             .index(bulkShardRequest.shardId().getIndex())
@@ -730,18 +724,18 @@ private void performInferenceAndExecute(
             executeBulkShardRequest(bulkShardRequest, releaseOnFinish);
         }

-        // TODO Should we create a specific ThreadPool?
-        try (var bulkItemReqRef = new RefCountingRunnable(() -> {
+        Runnable onInferenceComplete = () -> {
             BulkShardRequest errorsFilteredShardRequest = new BulkShardRequest(
                 bulkShardRequest.shardId(),
                 bulkRequest.getRefreshPolicy(),
                 Arrays.stream(bulkShardRequest.items()).filter(Objects::nonNull).toArray(BulkItemRequest[]::new)
             );
             executeBulkShardRequest(errorsFilteredShardRequest, releaseOnFinish);
-        })) {
+        };
+        try (var bulkItemReqRef = new RefCountingRunnable(onInferenceComplete)) {
             for (BulkItemRequest request : bulkShardRequest.items()) {
-                performInferenceOnBulkItemRequest(bulkShardRequest, request, fieldsForModels, bulkItemReqRef);
+                performInferenceOnBulkItemRequest(bulkShardRequest, request, fieldsForModels, bulkItemReqRef.acquire());
             }
         }
     }
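
The other change in this hunk is the handoff at the call site: the loop now passes bulkItemReqRef.acquire(), so performInferenceOnBulkItemRequest receives a plain Releasable that it owns and must close on every path, instead of reaching into the RefCountingRunnable itself. Below is a sketch of that ownership transfer, reusing the hypothetical SimpleRefCountingRunnable above, with AutoCloseable standing in for Elasticsearch's Releasable and an invented process() standing in for the inference work.

    import java.util.List;
    import java.util.concurrent.CompletableFuture;

    final class HandoffExample {
        static void runAll(List<String> items, Runnable whenDone) {
            try (SimpleRefCountingRunnable refs = new SimpleRefCountingRunnable(whenDone)) {
                for (String item : items) {
                    AutoCloseable ref = refs.acquire(); // acquired by the caller...
                    CompletableFuture.runAsync(() -> {
                        try {
                            process(item); // placeholder for the per-item async work
                        } finally {
                            try {
                                ref.close(); // ...closed by the task, on success or failure
                            } catch (Exception ignored) {}
                        }
                    });
                }
            } // drops the initial reference; whenDone runs once the last task finishes
        }

        static void process(String item) {}
    }

Handing the callee an already-acquired reference makes its contract explicit: close exactly once, on every path. The releaseOnFinish.close() calls in the next hunks are that contract being honored.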
Expand All @@ -750,21 +744,19 @@ private void performInferenceOnBulkItemRequest(
BulkShardRequest bulkShardRequest,
BulkItemRequest request,
Map<String, Set<String>> fieldsForModels,
RefCountingRunnable bulkItemReqRef
Releasable releaseOnFinish
) {
DocWriteRequest<?> docWriteRequest = request.request();
Map<String, Object> sourceMap = null;
if (docWriteRequest instanceof IndexRequest indexRequest) {
sourceMap = indexRequest.sourceAsMap();
} else if (docWriteRequest instanceof UpdateRequest updateRequest) {
sourceMap = updateRequest.docAsUpsert()
? updateRequest.upsertRequest().sourceAsMap()
: updateRequest.doc().sourceAsMap();
sourceMap = updateRequest.docAsUpsert() ? updateRequest.upsertRequest().sourceAsMap() : updateRequest.doc().sourceAsMap();
}
if (sourceMap == null || sourceMap.isEmpty()) {
releaseOnFinish.close();
return;
}
bulkItemReqRef.acquire();
final Map<String, Object> docMap = new ConcurrentHashMap<>(sourceMap);

// When a document completes processing, update the source with the inference
@@ -778,7 +770,7 @@ private void performInferenceOnBulkItemRequest(
                     updateRequest.doc().source(docMap);
                 }
             }
-            bulkItemReqRef.close();
+            releaseOnFinish.close();
         })) {

             for (Map.Entry<String, Set<String>> fieldModelsEntrySet : fieldsForModels.entrySet()) {
@@ -837,8 +829,11 @@ public void onFailure(Exception e) {

                     final String indexName = request.index();
                     DocWriteRequest<?> docWriteRequest = request.request();
-                    BulkItemResponse.Failure failure = new BulkItemResponse.Failure(indexName, docWriteRequest.id(),
-                        new IllegalArgumentException("Error performing inference: " + e.getMessage(), e));
+                    BulkItemResponse.Failure failure = new BulkItemResponse.Failure(
+                        indexName,
+                        docWriteRequest.id(),
+                        new IllegalArgumentException("Error performing inference: " + e.getMessage(), e)
+                    );
                     responses.set(request.id(), BulkItemResponse.failure(request.id(), docWriteRequest.opType(), failure));
                     // make sure the request gets never processed again
                     bulkShardRequest.items()[request.id()] = null;
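
The failure handler records the error in the responses array and then nulls out the item's slot; the onInferenceComplete Runnable shown earlier rebuilds the shard request with Arrays.stream(items).filter(Objects::nonNull), so a failed item is dropped rather than re-executed. A small self-contained illustration of that null-out-and-filter step, with strings standing in for BulkItemRequest:

    import java.util.Arrays;
    import java.util.Objects;

    final class FilterFailedItemsExample {
        public static void main(String[] args) {
            String[] items = { "item-0", "item-1", "item-2" };

            items[1] = null; // mark the failed slot so it is never processed again

            // Rebuild the batch from the surviving items, as errorsFilteredShardRequest does
            String[] surviving = Arrays.stream(items).filter(Objects::nonNull).toArray(String[]::new);
            System.out.println(Arrays.toString(surviving)); // prints [item-0, item-2]
        }
    }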
@@ -895,10 +890,7 @@ private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasab
         if (bulkShardRequest.items().length == 0) {
             // No requests to execute due to previous errors, terminate early
             listener.onResponse(
-                new BulkResponse(
-                    responses.toArray(new BulkItemResponse[responses.length()]),
-                    buildTookInMillis(startTimeNanos)
-                )
+                new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos))
             );
             releaseOnFinish.close();
             return;
