-
Notifications
You must be signed in to change notification settings - Fork 37
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[Feature/agent_framework] Add Undeploy Model Step (#236)
Signed-off-by: Daniel Widdis <[email protected]>
- Loading branch information
Showing
5 changed files
with
256 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
114 changes: 114 additions & 0 deletions
114
src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
package org.opensearch.flowframework.workflow; | ||
|
||
import org.apache.logging.log4j.LogManager; | ||
import org.apache.logging.log4j.Logger; | ||
import org.opensearch.ExceptionsHelper; | ||
import org.opensearch.OpenSearchException; | ||
import org.opensearch.action.FailedNodeException; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.flowframework.exception.FlowFrameworkException; | ||
import org.opensearch.flowframework.util.ParseUtils; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; | ||
|
||
import java.io.IOException; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.Set; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.stream.Collectors; | ||
|
||
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; | ||
import static org.opensearch.flowframework.common.CommonValue.SUCCESS; | ||
|
||
/** | ||
* Step to undeploy model | ||
*/ | ||
public class UndeployModelStep implements WorkflowStep { | ||
|
||
private static final Logger logger = LogManager.getLogger(UndeployModelStep.class); | ||
|
||
private MachineLearningNodeClient mlClient; | ||
|
||
static final String NAME = "undeploy_model"; | ||
|
||
/** | ||
* Instantiate this class | ||
* @param mlClient Machine Learning client to perform the undeploy | ||
*/ | ||
public UndeployModelStep(MachineLearningNodeClient mlClient) { | ||
this.mlClient = mlClient; | ||
} | ||
|
||
@Override | ||
public CompletableFuture<WorkflowData> execute( | ||
String currentNodeId, | ||
WorkflowData currentNodeInputs, | ||
Map<String, WorkflowData> outputs, | ||
Map<String, String> previousNodeInputs | ||
) throws IOException { | ||
CompletableFuture<WorkflowData> undeployModelFuture = new CompletableFuture<>(); | ||
|
||
ActionListener<MLUndeployModelsResponse> actionListener = new ActionListener<>() { | ||
|
||
@Override | ||
public void onResponse(MLUndeployModelsResponse mlUndeployModelsResponse) { | ||
List<FailedNodeException> failures = mlUndeployModelsResponse.getResponse().failures(); | ||
if (failures.isEmpty()) { | ||
undeployModelFuture.complete( | ||
new WorkflowData( | ||
Map.ofEntries(Map.entry(SUCCESS, !mlUndeployModelsResponse.getResponse().hasFailures())), | ||
currentNodeInputs.getWorkflowId(), | ||
currentNodeInputs.getNodeId() | ||
) | ||
); | ||
} else { | ||
List<String> failedNodes = failures.stream().map(FailedNodeException::nodeId).collect(Collectors.toList()); | ||
String message = "Failed to undeploy model on nodes " + failedNodes; | ||
logger.error(message); | ||
undeployModelFuture.completeExceptionally(new OpenSearchException(message)); | ||
} | ||
} | ||
|
||
@Override | ||
public void onFailure(Exception e) { | ||
logger.error("Failed to unldeploy model"); | ||
undeployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); | ||
} | ||
}; | ||
|
||
Set<String> requiredKeys = Set.of(MODEL_ID); | ||
Set<String> optionalKeys = Collections.emptySet(); | ||
|
||
try { | ||
Map<String, Object> inputs = ParseUtils.getInputsFromPreviousSteps( | ||
requiredKeys, | ||
optionalKeys, | ||
currentNodeInputs, | ||
outputs, | ||
previousNodeInputs | ||
); | ||
|
||
String modelId = inputs.get(MODEL_ID).toString(); | ||
|
||
mlClient.undeploy(new String[] { modelId }, null, actionListener); | ||
} catch (FlowFrameworkException e) { | ||
undeployModelFuture.completeExceptionally(e); | ||
} | ||
return undeployModelFuture; | ||
} | ||
|
||
@Override | ||
public String getName() { | ||
return NAME; | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
131 changes: 131 additions & 0 deletions
131
src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,131 @@ | ||
/* | ||
* Copyright OpenSearch Contributors | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
* The OpenSearch Contributors require contributions made to | ||
* this file be licensed under the Apache-2.0 license or a | ||
* compatible open source license. | ||
*/ | ||
package org.opensearch.flowframework.workflow; | ||
|
||
import org.opensearch.OpenSearchException; | ||
import org.opensearch.action.FailedNodeException; | ||
import org.opensearch.cluster.ClusterName; | ||
import org.opensearch.core.action.ActionListener; | ||
import org.opensearch.core.rest.RestStatus; | ||
import org.opensearch.flowframework.exception.FlowFrameworkException; | ||
import org.opensearch.ml.client.MachineLearningNodeClient; | ||
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelNodesResponse; | ||
import org.opensearch.ml.common.transport.undeploy.MLUndeployModelsResponse; | ||
import org.opensearch.test.OpenSearchTestCase; | ||
|
||
import java.io.IOException; | ||
import java.util.Collections; | ||
import java.util.List; | ||
import java.util.Map; | ||
import java.util.concurrent.CompletableFuture; | ||
import java.util.concurrent.ExecutionException; | ||
|
||
import org.mockito.Mock; | ||
import org.mockito.MockitoAnnotations; | ||
|
||
import static org.opensearch.flowframework.common.CommonValue.MODEL_ID; | ||
import static org.opensearch.flowframework.common.CommonValue.SUCCESS; | ||
import static org.mockito.ArgumentMatchers.any; | ||
import static org.mockito.Mockito.doAnswer; | ||
import static org.mockito.Mockito.verify; | ||
|
||
public class UndeployModelStepTests extends OpenSearchTestCase { | ||
private WorkflowData inputData; | ||
|
||
@Mock | ||
MachineLearningNodeClient machineLearningNodeClient; | ||
|
||
@Override | ||
public void setUp() throws Exception { | ||
super.setUp(); | ||
|
||
MockitoAnnotations.openMocks(this); | ||
|
||
inputData = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); | ||
} | ||
|
||
public void testUndeployModel() throws IOException, ExecutionException, InterruptedException { | ||
|
||
String modelId = randomAlphaOfLength(5); | ||
UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient); | ||
|
||
doAnswer(invocation -> { | ||
ClusterName clusterName = new ClusterName("clusterName"); | ||
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(2); | ||
MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse( | ||
clusterName, | ||
Collections.emptyList(), | ||
Collections.emptyList() | ||
); | ||
MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse); | ||
actionListener.onResponse(output); | ||
return null; | ||
}).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); | ||
|
||
CompletableFuture<WorkflowData> future = UndeployModelStep.execute( | ||
inputData.getNodeId(), | ||
inputData, | ||
Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, modelId), "workflowId", "nodeId")), | ||
Map.of("step_1", MODEL_ID) | ||
); | ||
verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); | ||
|
||
assertTrue(future.isDone()); | ||
assertTrue((boolean) future.get().getContent().get(SUCCESS)); | ||
} | ||
|
||
public void testNoModelIdInOutput() throws IOException { | ||
UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient); | ||
|
||
CompletableFuture<WorkflowData> future = UndeployModelStep.execute( | ||
inputData.getNodeId(), | ||
inputData, | ||
Collections.emptyMap(), | ||
Collections.emptyMap() | ||
); | ||
|
||
assertTrue(future.isCompletedExceptionally()); | ||
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); | ||
assertTrue(ex.getCause() instanceof FlowFrameworkException); | ||
assertEquals("Missing required inputs [model_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); | ||
} | ||
|
||
public void testUndeployModelFailure() throws IOException { | ||
UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient); | ||
|
||
doAnswer(invocation -> { | ||
ClusterName clusterName = new ClusterName("clusterName"); | ||
ActionListener<MLUndeployModelsResponse> actionListener = invocation.getArgument(2); | ||
MLUndeployModelNodesResponse mlUndeployModelNodesResponse = new MLUndeployModelNodesResponse( | ||
clusterName, | ||
Collections.emptyList(), | ||
List.of(new FailedNodeException("failed-node", "Test message", null)) | ||
); | ||
MLUndeployModelsResponse output = new MLUndeployModelsResponse(mlUndeployModelNodesResponse); | ||
actionListener.onResponse(output); | ||
|
||
actionListener.onFailure(new FlowFrameworkException("Failed to undeploy model", RestStatus.INTERNAL_SERVER_ERROR)); | ||
return null; | ||
}).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); | ||
|
||
CompletableFuture<WorkflowData> future = UndeployModelStep.execute( | ||
inputData.getNodeId(), | ||
inputData, | ||
Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, "test"), "workflowId", "nodeId")), | ||
Map.of("step_1", MODEL_ID) | ||
); | ||
|
||
verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); | ||
|
||
assertTrue(future.isCompletedExceptionally()); | ||
ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); | ||
assertTrue(ex.getCause() instanceof OpenSearchException); | ||
assertEquals("Failed to undeploy model on nodes [failed-node]", ex.getCause().getMessage()); | ||
} | ||
} |