Skip to content

Commit

Permalink
Check for existing download
Browse files Browse the repository at this point in the history
  • Loading branch information
davidkyle committed Nov 15, 2024
1 parent f400839 commit 7d98561
Show file tree
Hide file tree
Showing 4 changed files with 207 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0; you may not use this file except in compliance with the Elastic License
* 2.0.
*/

package org.elasticsearch.xpack.ml.packageloader.action;

import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.support.master.AcknowledgedResponse;
import org.elasticsearch.tasks.RemovedTaskListener;
import org.elasticsearch.tasks.Task;

/**
 * A {@link RemovedTaskListener} that watches for the removal of one specific
 * model download task and relays that task's outcome to a waiting listener.
 *
 * @param trackedTask the download task whose removal is being watched
 * @param listener    called with the outcome once {@code trackedTask} has been removed
 */
public record DownloadTaskRemovedListener(ModelDownloadTask trackedTask, ActionListener<AcknowledgedResponse> listener)
    implements
        RemovedTaskListener {

    /**
     * Called when a task is removed. Removals of any task other than the
     * tracked one (compared by id and action) are ignored. For the tracked
     * task the listener is failed with the exception recorded on the task,
     * or completed with an acknowledged response if none was recorded.
     *
     * @param task the task that was removed
     */
    @Override
    public void onRemoved(Task task) {
        final boolean sameId = task.getId() == trackedTask.getId();
        if (sameId == false || task.getAction().equals(trackedTask.getAction()) == false) {
            // A different task was removed, keep waiting
            return;
        }

        final Exception failure = trackedTask.getTaskException();
        if (failure != null) {
            listener.onFailure(failure);
        } else {
            listener.onResponse(AcknowledgedResponse.TRUE);
        }
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.elasticsearch.tasks.Task;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xpack.core.ml.MlTasks;

import java.io.IOException;
import java.util.Map;
Expand Down Expand Up @@ -51,9 +52,12 @@ public void writeTo(StreamOutput out) throws IOException {
}

private final AtomicReference<DownLoadProgress> downloadProgress = new AtomicReference<>(new DownLoadProgress(0, 0));
private final String modelId;
private volatile Exception taskException;

public ModelDownloadTask(long id, String type, String action, String description, TaskId parentTaskId, Map<String, String> headers) {
super(id, type, action, description, parentTaskId, headers);
public ModelDownloadTask(long id, String type, String action, String modelId, TaskId parentTaskId, Map<String, String> headers) {
super(id, type, action, taskDescription(modelId), parentTaskId, headers);
this.modelId = modelId;
}

void setProgress(int totalParts, int downloadedParts) {
Expand All @@ -65,4 +69,19 @@ public DownloadStatus getStatus() {
return new DownloadStatus(downloadProgress.get());
}

/**
 * @return the id of the model this task is downloading, as passed to the constructor
 */
public String getModelId() {
return modelId;
}

/**
 * Record the exception the download failed with. The value is stored in a
 * volatile field and is returned by {@link #getTaskException()}.
 *
 * @param exception the failure to record
 */
public void setTaskException(Exception exception) {
this.taskException = exception;
}

/**
 * @return the exception last passed to {@link #setTaskException(Exception)},
 * or {@code null} if none has been set
 */
public Exception getTaskException() {
return taskException;
}

/**
 * Build the task description for a download of the given model. The
 * constructor passes this value to the superclass as the task description.
 * Delegates to {@code MlTasks.downloadModelTaskDescription}.
 *
 * @param modelId the model being downloaded
 * @return the task description for that model's download task
 */
public static String taskDescription(String modelId) {
return MlTasks.downloadModelTaskDescription(modelId);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,6 @@
import org.elasticsearch.tasks.TaskAwareRequest;
import org.elasticsearch.tasks.TaskCancelledException;
import org.elasticsearch.tasks.TaskId;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.TransportService;
import org.elasticsearch.xpack.core.common.notifications.Level;
Expand All @@ -42,21 +41,24 @@
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN;
import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_ACTION;
import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_TYPE;
import static org.elasticsearch.xpack.core.ml.MlTasks.downloadModelTaskDescription;

public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction<Request, AcknowledgedResponse> {

private static final Logger logger = LogManager.getLogger(TransportLoadTrainedModelPackage.class);

private final Client client;
private final CircuitBreakerService circuitBreakerService;
final Map<String, List<DownloadTaskRemovedListener>> downloadTrackersByModelId;

@Inject
public TransportLoadTrainedModelPackage(
Expand All @@ -81,6 +83,7 @@ public TransportLoadTrainedModelPackage(
);
this.client = new OriginSettingClient(client, ML_ORIGIN);
this.circuitBreakerService = circuitBreakerService;
downloadTrackersByModelId = new HashMap<>();
}

@Override
Expand All @@ -91,6 +94,17 @@ protected ClusterBlockException checkBlock(Request request, ClusterState state)
@Override
protected void masterOperation(Task task, Request request, ClusterState state, ActionListener<AcknowledgedResponse> listener)
throws Exception {
if (existingDownloadInProgress(request.getModelId(), request.isWaitForCompletion(), listener)) {
logger.debug("Existing download of model [{}] in progress", request.getModelId());

if (request.isWaitForCompletion() == false) {
listener.onResponse(AcknowledgedResponse.TRUE);
}

// download in progress, nothing to do
return;
}

ModelDownloadTask downloadTask = createDownloadTask(request);

try {
Expand All @@ -107,7 +121,7 @@ protected void masterOperation(Task task, Request request, ClusterState state, A

var downloadCompleteListener = request.isWaitForCompletion() ? listener : ActionListener.<AcknowledgedResponse>noop();

importModel(client, taskManager, request, modelImporter, downloadCompleteListener, downloadTask);
importModel(client, () -> unregisterTask(downloadTask), request, modelImporter, downloadTask, downloadCompleteListener);
} catch (Exception e) {
taskManager.unregister(downloadTask);
listener.onFailure(e);
Expand All @@ -124,22 +138,90 @@ private ParentTaskAssigningClient getParentTaskAssigningClient(Task originTask)
return new ParentTaskAssigningClient(client, parentTaskId);
}

/**
 * Check whether a download of the model is already running and, when the
 * caller wants to wait, arrange for {@code listener} to be notified once
 * that download's task has been removed.
 * <p>
 * A running download is a cancellable {@code ModelDownloadTask} whose
 * description equals the download task description for {@code modelId}.
 * This method is synchronized with {@code unregisterTask} to prevent the
 * task being removed before the remove listener is added.
 *
 * @param modelId Model being downloaded
 * @param isWaitForCompletion If true and a download is in progress, register a
 *                            removed task listener wrapping {@code listener};
 *                            if false {@code listener} is not used by this method
 * @param listener Model download listener
 * @return True if a download task is in progress
 */
synchronized boolean existingDownloadInProgress(
    String modelId,
    boolean isWaitForCompletion,
    ActionListener<AcknowledgedResponse> listener
) {
    final String wantedDescription = ModelDownloadTask.taskDescription(modelId);

    // Find the first cancellable task that is a download of this model
    ModelDownloadTask running = null;
    for (var candidate : taskManager.getCancellableTasks().values()) {
        if (candidate instanceof ModelDownloadTask download && wantedDescription.equals(candidate.getDescription())) {
            running = download;
            break;
        }
    }

    if (running == null) {
        return false;
    }

    if (isWaitForCompletion) {
        // Track the running task; the removed task listener is called
        // once the task is complete and unregistered
        var removedListener = new DownloadTaskRemovedListener(running, listener);
        downloadTrackersByModelId.computeIfAbsent(modelId, key -> new ArrayList<>()).add(removedListener);
        taskManager.registerRemovedTaskListener(removedListener);
    }
    // When not waiting it is enough that the download is in progress
    return true;
}

/**
 * Unregister the completed task, which triggers any removed task
 * listeners, then drop the listeners that were tracking this task's model.
 * This method is synchronized to prevent the task being removed while
 * {@code existingDownloadInProgress} is in progress.
 *
 * @param task The completed task
 */
synchronized void unregisterTask(ModelDownloadTask task) {
    // unregister will call the on remove function
    taskManager.unregister(task);

    List<DownloadTaskRemovedListener> finishedTrackers = downloadTrackersByModelId.remove(task.getModelId());
    if (finishedTrackers == null) {
        // nothing was waiting on this download
        return;
    }
    finishedTrackers.forEach(taskManager::unregisterRemovedTaskListener);
}

/**
* This is package scope so that we can test the logic directly.
* This should only be called from the masterOperation method and the tests
* This should only be called from the masterOperation method and the tests.
* This method is static for testing.
*
* @param auditClient a client which should only be used to send audit notifications. This client cannot be associated with the passed
* in task, that way when the task is cancelled the notification requests can
* still be performed. If it is associated with the task (i.e. via ParentTaskAssigningClient),
* then the requests will throw a TaskCancelledException.
* @param unregisterTaskFn Runnable to unregister the task. Because this is a static function
* a lambda is used rather than the instance method.
* @param request The download request
* @param modelImporter The importer
* @param task Download task
* @param listener Listener
*/
static void importModel(
Client auditClient,
TaskManager taskManager,
Runnable unregisterTaskFn,
Request request,
ModelImporter modelImporter,
ActionListener<AcknowledgedResponse> listener,
Task task
ModelDownloadTask task,
ActionListener<AcknowledgedResponse> listener
) {
final String modelId = request.getModelId();
final long relativeStartNanos = System.nanoTime();
Expand All @@ -155,9 +237,12 @@ static void importModel(
Level.INFO
);
listener.onResponse(AcknowledgedResponse.TRUE);
}, exception -> listener.onFailure(processException(auditClient, modelId, exception)));
}, exception -> {
task.setTaskException(exception);
listener.onFailure(processException(auditClient, modelId, exception));
});

modelImporter.doImport(ActionListener.runAfter(finishListener, () -> taskManager.unregister(task)));
modelImporter.doImport(ActionListener.runAfter(finishListener, unregisterTaskFn));
}

static Exception processException(Client auditClient, String modelId, Exception e) {
Expand Down Expand Up @@ -197,14 +282,7 @@ public TaskId getParentTask() {

@Override
public ModelDownloadTask createTask(long id, String type, String action, TaskId parentTaskId, Map<String, String> headers) {
return new ModelDownloadTask(
id,
type,
action,
downloadModelTaskDescription(request.getModelId()),
parentTaskId,
headers
);
return new ModelDownloadTask(id, type, action, request.getModelId(), parentTaskId, headers);
}
}, false);
}
Expand Down
Loading

0 comments on commit 7d98561

Please sign in to comment.