diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/DownloadTaskRemovedListener.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/DownloadTaskRemovedListener.java new file mode 100644 index 0000000000000..9ef04dad888ea --- /dev/null +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/DownloadTaskRemovedListener.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.packageloader.action; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.tasks.RemovedTaskListener; +import org.elasticsearch.tasks.Task; + +public record DownloadTaskRemovedListener(ModelDownloadTask trackedTask, ActionListener listener) + implements + RemovedTaskListener { + + @Override + public void onRemoved(Task task) { + if (task.getId() == trackedTask.getId() && task.getAction().equals(trackedTask.getAction())) { + if (trackedTask.getTaskException() == null) { + listener.onResponse(AcknowledgedResponse.TRUE); + } else { + listener.onFailure(trackedTask.getTaskException()); + } + } + } +} diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTask.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTask.java index 59977bd418e11..dd09c3cf65fec 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTask.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/ModelDownloadTask.java @@ -13,6 +13,7 @@ import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskId; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.MlTasks; import java.io.IOException; import java.util.Map; @@ -51,9 +52,12 @@ public void writeTo(StreamOutput out) throws IOException { } private final AtomicReference downloadProgress = new AtomicReference<>(new DownLoadProgress(0, 0)); + private final String modelId; + private volatile Exception taskException; - public ModelDownloadTask(long id, String type, String action, String description, TaskId parentTaskId, Map headers) { - super(id, type, action, description, parentTaskId, headers); + public ModelDownloadTask(long id, String type, String action, String modelId, TaskId parentTaskId, Map headers) { + super(id, type, action, taskDescription(modelId), parentTaskId, headers); + this.modelId = modelId; } void setProgress(int totalParts, int downloadedParts) { @@ -65,4 +69,19 @@ public DownloadStatus getStatus() { return new DownloadStatus(downloadProgress.get()); } + public String getModelId() { + return modelId; + } + + public void setTaskException(Exception exception) { + this.taskException = exception; + } + + public Exception getTaskException() { + return taskException; + } + + public static String taskDescription(String modelId) { + return MlTasks.downloadModelTaskDescription(modelId); + } } diff --git a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java index 76b7781b1cffe..b286ca95c74c0 100644 --- a/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java +++ b/x-pack/plugin/ml-package-loader/src/main/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackage.java @@ -30,7 +30,6 @@ import org.elasticsearch.tasks.TaskAwareRequest; import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.tasks.TaskId; -import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.common.notifications.Level; @@ -42,6 +41,9 @@ import java.io.IOException; import java.net.MalformedURLException; import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.TimeUnit; @@ -49,7 +51,6 @@ import static org.elasticsearch.xpack.core.ClientHelper.ML_ORIGIN; import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_ACTION; import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_TYPE; -import static org.elasticsearch.xpack.core.ml.MlTasks.downloadModelTaskDescription; public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction { @@ -57,6 +58,7 @@ public class TransportLoadTrainedModelPackage extends TransportMasterNodeAction< private final Client client; private final CircuitBreakerService circuitBreakerService; + final Map> downloadTrackersByModelId; @Inject public TransportLoadTrainedModelPackage( @@ -81,6 +83,7 @@ public TransportLoadTrainedModelPackage( ); this.client = new OriginSettingClient(client, ML_ORIGIN); this.circuitBreakerService = circuitBreakerService; + downloadTrackersByModelId = new HashMap<>(); } @Override @@ -91,6 +94,17 @@ protected ClusterBlockException checkBlock(Request request, ClusterState state) @Override protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) throws Exception { + if (existingDownloadInProgress(request.getModelId(), request.isWaitForCompletion(), listener)) { + logger.debug("Existing download of model [{}] in progress", request.getModelId()); + + if (request.isWaitForCompletion() == false) { + listener.onResponse(AcknowledgedResponse.TRUE); + } + + // download in progress, nothing to do + return; + } + ModelDownloadTask downloadTask = createDownloadTask(request); try { @@ -107,7 +121,7 @@ protected void masterOperation(Task task, Request request, ClusterState state, A var downloadCompleteListener = request.isWaitForCompletion() ? listener : ActionListener.noop(); - importModel(client, taskManager, request, modelImporter, downloadCompleteListener, downloadTask); + importModel(client, () -> unregisterTask(downloadTask), request, modelImporter, downloadTask, downloadCompleteListener); } catch (Exception e) { taskManager.unregister(downloadTask); listener.onFailure(e); @@ -124,22 +138,90 @@ private ParentTaskAssigningClient getParentTaskAssigningClient(Task originTask) return new ParentTaskAssigningClient(client, parentTaskId); } + /** + * Look for a current download task of the model and optionally wait + * for that task to complete if there is one. + * synchronized with {@code unregisterTask} to prevent the task being + * removed before the remove listener is added. + * @param modelId Model being downloaded + * @param isWaitForCompletion Wait until the download completes before + * calling the listener + * @param listener Model download listener + * @return True if a download task is in progress + */ + synchronized boolean existingDownloadInProgress( + String modelId, + boolean isWaitForCompletion, + ActionListener listener + ) { + var description = ModelDownloadTask.taskDescription(modelId); + var tasks = taskManager.getCancellableTasks().values(); + + ModelDownloadTask inProgress = null; + for (var task : tasks) { + if (description.equals(task.getDescription()) && task instanceof ModelDownloadTask downloadTask) { + inProgress = downloadTask; + break; + } + } + + if (inProgress != null) { + if (isWaitForCompletion == false) { + // Not waiting for the download to complete, it is enough that + // the download is in progress + return true; + } + // Otherwise register a task removed listener which is called + // once the tasks is complete and unregistered + var tracker = new DownloadTaskRemovedListener(inProgress, listener); + downloadTrackersByModelId.computeIfAbsent(modelId, s -> new ArrayList<>()).add(tracker); + taskManager.registerRemovedTaskListener(tracker); + return true; + } + + return false; + } + + /** + * Unregister the completed task triggering any remove task listeners. + * This method is synchronized to prevent the task being removed while + * {@code waitForExistingDownload} is in progress. + * @param task The completed task + */ + synchronized void unregisterTask(ModelDownloadTask task) { + taskManager.unregister(task); // unregister will call the on remove function + + var trackers = downloadTrackersByModelId.remove(task.getModelId()); + if (trackers != null) { + for (var tracker : trackers) { + taskManager.unregisterRemovedTaskListener(tracker); + } + } + } + /** * This is package scope so that we can test the logic directly. - * This should only be called from the masterOperation method and the tests + * This should only be called from the masterOperation method and the tests. + * This method is static for testing. * * @param auditClient a client which should only be used to send audit notifications. This client cannot be associated with the passed * in task, that way when the task is cancelled the notification requests can * still be performed. If it is associated with the task (i.e. via ParentTaskAssigningClient), * then the requests will throw a TaskCancelledException. + * @param unregisterTaskFn Runnable to unregister the task. Because this is a static function + * a lambda is used rather than the instance method. + * @param request The download request + * @param modelImporter The importer + * @param task Download task + * @param listener Listener */ static void importModel( Client auditClient, - TaskManager taskManager, + Runnable unregisterTaskFn, Request request, ModelImporter modelImporter, - ActionListener listener, - Task task + ModelDownloadTask task, + ActionListener listener ) { final String modelId = request.getModelId(); final long relativeStartNanos = System.nanoTime(); @@ -155,9 +237,12 @@ static void importModel( Level.INFO ); listener.onResponse(AcknowledgedResponse.TRUE); - }, exception -> listener.onFailure(processException(auditClient, modelId, exception))); + }, exception -> { + task.setTaskException(exception); + listener.onFailure(processException(auditClient, modelId, exception)); + }); - modelImporter.doImport(ActionListener.runAfter(finishListener, () -> taskManager.unregister(task))); + modelImporter.doImport(ActionListener.runAfter(finishListener, unregisterTaskFn)); } static Exception processException(Client auditClient, String modelId, Exception e) { @@ -197,14 +282,7 @@ public TaskId getParentTask() { @Override public ModelDownloadTask createTask(long id, String type, String action, TaskId parentTaskId, Map headers) { - return new ModelDownloadTask( - id, - type, - action, - downloadModelTaskDescription(request.getModelId()), - parentTaskId, - headers - ); + return new ModelDownloadTask(id, type, action, request.getModelId(), parentTaskId, headers); } }, false); } diff --git a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java index cbcfd5b760779..ae7217d34b694 100644 --- a/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java +++ b/x-pack/plugin/ml-package-loader/src/test/java/org/elasticsearch/xpack/ml/packageloader/action/TransportLoadTrainedModelPackageTests.java @@ -10,13 +10,19 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.indices.breaker.CircuitBreakerService; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.tasks.Task; import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.tasks.TaskId; import org.elasticsearch.tasks.TaskManager; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.common.notifications.Level; import org.elasticsearch.xpack.core.ml.action.AuditMlNotificationAction; import org.elasticsearch.xpack.core.ml.inference.trainedmodel.ModelPackageConfig; @@ -27,9 +33,13 @@ import java.io.IOException; import java.net.MalformedURLException; import java.net.URISyntaxException; +import java.util.Map; import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_ACTION; +import static org.elasticsearch.xpack.core.ml.MlTasks.MODEL_IMPORT_TASK_TYPE; +import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.core.Is.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; @@ -37,6 +47,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; public class TransportLoadTrainedModelPackageTests extends ESTestCase { private static final String MODEL_IMPORT_FAILURE_MSG_FORMAT = "Model importing failed due to %s [%s]"; @@ -44,17 +55,10 @@ public class TransportLoadTrainedModelPackageTests extends ESTestCase { public void testSendsFinishedUploadNotification() { var uploader = createUploader(null); var taskManager = mock(TaskManager.class); - var task = mock(Task.class); + var task = mock(ModelDownloadTask.class); var client = mock(Client.class); - TransportLoadTrainedModelPackage.importModel( - client, - taskManager, - createRequestWithWaiting(), - uploader, - ActionListener.noop(), - task - ); + TransportLoadTrainedModelPackage.importModel(client, () -> {}, createRequestWithWaiting(), uploader, task, ActionListener.noop()); var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class); // 2 notifications- the start and finish messages @@ -108,32 +112,63 @@ public void testSendsWarningNotificationForTaskCancelledException() throws Excep public void testCallsOnResponseWithAcknowledgedResponse() throws Exception { var client = mock(Client.class); var taskManager = mock(TaskManager.class); - var task = mock(Task.class); + var task = mock(ModelDownloadTask.class); ModelImporter uploader = createUploader(null); var responseRef = new AtomicReference(); var listener = ActionListener.wrap(responseRef::set, e -> fail("received an exception: " + e.getMessage())); - TransportLoadTrainedModelPackage.importModel(client, taskManager, createRequestWithWaiting(), uploader, listener, task); + TransportLoadTrainedModelPackage.importModel(client, () -> {}, createRequestWithWaiting(), uploader, task, listener); assertThat(responseRef.get(), is(AcknowledgedResponse.TRUE)); } public void testDoesNotCallListenerWhenNotWaitingForCompletion() { var uploader = mock(ModelImporter.class); var client = mock(Client.class); - var taskManager = mock(TaskManager.class); - var task = mock(Task.class); - + var task = mock(ModelDownloadTask.class); TransportLoadTrainedModelPackage.importModel( client, - taskManager, + () -> {}, createRequestWithoutWaiting(), uploader, - ActionListener.running(ESTestCase::fail), - task + task, + ActionListener.running(ESTestCase::fail) ); } + public void testWaitForExistingDownload() { + var taskManager = mock(TaskManager.class); + var modelId = "foo"; + var task = new ModelDownloadTask(1L, MODEL_IMPORT_TASK_TYPE, MODEL_IMPORT_TASK_ACTION, modelId, new TaskId("node", 1L), Map.of()); + when(taskManager.getCancellableTasks()).thenReturn(Map.of(1L, task)); + + var transportService = mock(TransportService.class); + when(transportService.getTaskManager()).thenReturn(taskManager); + + var action = new TransportLoadTrainedModelPackage( + transportService, + mock(ClusterService.class), + mock(ThreadPool.class), + mock(ActionFilters.class), + mock(IndexNameExpressionResolver.class), + mock(Client.class), + mock(CircuitBreakerService.class) + ); + + assertTrue(action.existingDownloadInProgress(modelId, true, ActionListener.noop())); + verify(taskManager).registerRemovedTaskListener(any()); + assertThat(action.downloadTrackersByModelId.entrySet(), hasSize(1)); + assertThat(action.downloadTrackersByModelId.get(modelId), hasSize(1)); + + // With wait for completion == false no new removed listener will be added + assertTrue(action.existingDownloadInProgress(modelId, false, ActionListener.noop())); + verify(taskManager, times(1)).registerRemovedTaskListener(any()); + assertThat(action.downloadTrackersByModelId.entrySet(), hasSize(1)); + assertThat(action.downloadTrackersByModelId.get(modelId), hasSize(1)); + + assertFalse(action.existingDownloadInProgress("no-task-for-this-one", randomBoolean(), ActionListener.noop())); + } + private void assertUploadCallsOnFailure(Exception exception, String message, RestStatus status, Level level) throws Exception { var esStatusException = new ElasticsearchStatusException(message, status, exception); @@ -152,7 +187,7 @@ private void assertNotificationAndOnFailure( ) throws Exception { var client = mock(Client.class); var taskManager = mock(TaskManager.class); - var task = mock(Task.class); + var task = mock(ModelDownloadTask.class); ModelImporter uploader = createUploader(thrownException); var failureRef = new AtomicReference(); @@ -160,7 +195,14 @@ private void assertNotificationAndOnFailure( (AcknowledgedResponse response) -> { fail("received a acknowledged response: " + response.toString()); }, failureRef::set ); - TransportLoadTrainedModelPackage.importModel(client, taskManager, createRequestWithWaiting(), uploader, listener, task); + TransportLoadTrainedModelPackage.importModel( + client, + () -> taskManager.unregister(task), + createRequestWithWaiting(), + uploader, + task, + listener + ); var notificationArg = ArgumentCaptor.forClass(AuditMlNotificationAction.Request.class); // 2 notifications- the starting message and the failure