Skip to content

Commit

Permalink
[Feature/agent_framework] Add Undeploy Model Step (#236)
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 18, 2023
1 parent b0a2d01 commit 830cff2
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ private CommonValue() {}
/** The provision workflow thread pool name */
public static final String PROVISION_THREAD_POOL = "opensearch_workflow_provision";

/** Success name field */
public static final String SUCCESS = "success";
/** Index name field */
public static final String INDEX_NAME = "index_name";
/** Type field */
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.ExceptionsHelper;
import org.opensearch.OpenSearchException;
import org.opensearch.action.FailedNodeException;
import org.opensearch.core.action.ActionListener;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.flowframework.util.ParseUtils;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.stream.Collectors;

import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.SUCCESS;

/**
* Step to undeploy model
*/
public class UndeployModelStep implements WorkflowStep {

private static final Logger logger = LogManager.getLogger(UndeployModelStep.class);

private MachineLearningNodeClient mlClient;

static final String NAME = "undeploy_model";

/**
* Instantiate this class
* @param mlClient Machine Learning client to perform the undeploy
*/
public UndeployModelStep(MachineLearningNodeClient mlClient) {
this.mlClient = mlClient;
}

@Override
public CompletableFuture<WorkflowData> execute(
String currentNodeId,
WorkflowData currentNodeInputs,
Map<String, WorkflowData> outputs,
Map<String, String> previousNodeInputs
) throws IOException {
CompletableFuture<WorkflowData> undeployModelFuture = new CompletableFuture<>();

ActionListener<MLUndeployModelsResponse> actionListener = new ActionListener<>() {

@Override
public void onResponse(MLUndeployModelsResponse mlUndeployModelsResponse) {
List<FailedNodeException> failures = mlUndeployModelsResponse.getResponse().failures();
if (failures.isEmpty()) {
undeployModelFuture.complete(
new WorkflowData(
Map.ofEntries(Map.entry(SUCCESS, !mlUndeployModelsResponse.getResponse().hasFailures())),
currentNodeInputs.getWorkflowId(),
currentNodeInputs.getNodeId()
)
);
} else {
List<String> failedNodes = failures.stream().map(FailedNodeException::nodeId).collect(Collectors.toList());
String message = "Failed to undeploy model on nodes " + failedNodes;
logger.error(message);
undeployModelFuture.completeExceptionally(new OpenSearchException(message));
}
}

@Override
public void onFailure(Exception e) {
logger.error("Failed to unldeploy model");
undeployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
}
};

Set<String> requiredKeys = Set.of(MODEL_ID);
Set<String> optionalKeys = Collections.emptySet();

try {
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps(
requiredKeys,
optionalKeys,
currentNodeInputs,
outputs,
previousNodeInputs
);

String modelId = inputs.get(MODEL_ID).toString();

mlClient.undeploy(new String[] { modelId }, null, actionListener);
} catch (FlowFrameworkException e) {
undeployModelFuture.completeExceptionally(e);
}
return undeployModelFuture;
}

@Override
public String getName() {
return NAME;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ public WorkflowStepFactory(
);
stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient));
stepMap.put(UndeployModelStep.NAME, new UndeployModelStep(mlClient));
stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient));
stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient, flowFrameworkIndicesHandler));
Expand Down
8 changes: 8 additions & 0 deletions src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,14 @@
"deploy_model_status"
]
},
"undeploy_model": {
"inputs":[
"model_id"
],
"outputs":[
"success"
]
},
"register_model_group": {
"inputs":[
"name"
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/
package org.opensearch.flowframework.workflow;

import org.opensearch.OpenSearchException;
import org.opensearch.action.FailedNodeException;
import org.opensearch.cluster.ClusterName;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.flowframework.exception.FlowFrameworkException;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse;
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;

import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import static org.opensearch.flowframework.common.CommonValue.MODEL_ID;
import static org.opensearch.flowframework.common.CommonValue.SUCCESS;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.verify;

public class UndeployModelStepTests extends OpenSearchTestCase {
private WorkflowData inputData;

@Mock
MachineLearningNodeClient machineLearningNodeClient;

@Override
public void setUp() throws Exception {
super.setUp();

MockitoAnnotations.openMocks(this);

inputData = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id");
}

public void testUndeployModel() throws IOException, ExecutionException, InterruptedException {

String modelId = randomAlphaOfLength(5);
UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient);

doAnswer(invocation -> {
ClusterName clusterName = new ClusterName("clusterName");
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(2);
MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse(
clusterName,
Collections.emptyList(),
Collections.emptyList()
);
MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse);
actionListener.onResponse(output);
return null;
}).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any());

CompletableFuture<WorkflowData> future = UndeployModelStep.execute(
inputData.getNodeId(),
inputData,
Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, modelId), "workflowId", "nodeId")),
Map.of("step_1", MODEL_ID)
);
verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any());

assertTrue(future.isDone());
assertTrue((boolean) future.get().getContent().get(SUCCESS));
}

public void testNoModelIdInOutput() throws IOException {
UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient);

CompletableFuture<WorkflowData> future = UndeployModelStep.execute(
inputData.getNodeId(),
inputData,
Collections.emptyMap(),
Collections.emptyMap()
);

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof FlowFrameworkException);
assertEquals("Missing required inputs [model_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage());
}

public void testUndeployModelFailure() throws IOException {
UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient);

doAnswer(invocation -> {
ClusterName clusterName = new ClusterName("clusterName");
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(2);
MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse(
clusterName,
Collections.emptyList(),
List.of(new FailedNodeException("failed-node", "Test message", null))
);
MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse);
actionListener.onResponse(output);

actionListener.onFailure(new FlowFrameworkException("Failed to undeploy model", RestStatus.INTERNAL_SERVER_ERROR));
return null;
}).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any());

CompletableFuture<WorkflowData> future = UndeployModelStep.execute(
inputData.getNodeId(),
inputData,
Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, "test"), "workflowId", "nodeId")),
Map.of("step_1", MODEL_ID)
);

verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any());

assertTrue(future.isCompletedExceptionally());
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent());
assertTrue(ex.getCause() instanceof OpenSearchException);
assertEquals("Failed to undeploy model on nodes [failed-node]", ex.getCause().getMessage());
}
}

0 comments on commit 830cff2

Please sign in to comment.