From 5682c22085fae0ad7f986ce7e9d5e5886404a16d Mon Sep 17 00:00:00 2001 From: Daniel Widdis Date: Tue, 5 Dec 2023 21:08:44 -0800 Subject: [PATCH] [Feature/agent_framework] Add Delete Agent Step (#246) Delete Agent Step Signed-off-by: Daniel Widdis --- .../workflow/DeleteAgentStep.java | 100 ++++++++++++++++ .../workflow/WorkflowStepFactory.java | 1 + .../resources/mappings/workflow-steps.json | 8 ++ .../workflow/DeleteAgentStepTests.java | 113 ++++++++++++++++++ 4 files changed, 222 insertions(+) create mode 100644 src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java create mode 100644 src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java new file mode 100644 index 000000000..d97b4ed28 --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java @@ -0,0 +1,100 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.ml.client.MachineLearningNodeClient; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.AGENT_ID; + +/** + * Step to delete a agent for a remote model + */ +public class DeleteAgentStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(DeleteAgentStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "delete_agent"; + + /** + * Instantiate this class + * @param mlClient Machine Learning client to perform the deletion + */ + public DeleteAgentStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + CompletableFuture deleteAgentFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(DeleteResponse deleteResponse) { + deleteAgentFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry("agent_id", deleteResponse.getId())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to delete agent"); + deleteAgentFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + Set requiredKeys = Set.of(AGENT_ID); + Set optionalKeys = Collections.emptySet(); + + try { + Map inputs = ParseUtils.getInputsFromPreviousSteps( + requiredKeys, + optionalKeys, + currentNodeInputs, + outputs, + previousNodeInputs + ); + String agentId = (String) inputs.get(AGENT_ID); + + mlClient.deleteAgent(agentId, actionListener); + } catch (FlowFrameworkException e) { + deleteAgentFuture.completeExceptionally(e); + } + return deleteAgentFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index 3a2ccb4e1..c2e55b100 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -59,6 +59,7 @@ public WorkflowStepFactory( stepMap.put(ModelGroupStep.NAME, () -> new ModelGroupStep(mlClient, flowFrameworkIndicesHandler)); stepMap.put(ToolStep.NAME, ToolStep::new); stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeleteAgentStep.NAME, () -> new DeleteAgentStep(mlClient)); } /** diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index e3263d9a2..149b1cfce 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -123,6 +123,14 @@ "agent_id" ] }, + "delete_agent": { + "inputs": [ + "agent_id" + ], + "outputs":[ + "agent_id" + ] + }, "create_tool": { "inputs": [ "type" diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java new file mode 100644 index 000000000..a893b8928 --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java @@ -0,0 +1,113 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class DeleteAgentStepTests extends OpenSearchTestCase { + private WorkflowData inputData; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"); + } + + public void testDeleteAgent() throws IOException, ExecutionException, InterruptedException { + + String agentId = randomAlphaOfLength(5); + DeleteAgentStep deleteAgentStep = new DeleteAgentStep(machineLearningNodeClient); + + doAnswer(invocation -> { + String agentIdArg = invocation.getArgument(0); + ActionListener actionListener = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, agentIdArg, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).deleteAgent(any(String.class), any()); + + CompletableFuture future = deleteAgentStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of("agent_id", agentId), "workflowId", "nodeId")), + Map.of("step_1", "agent_id") + ); + verify(machineLearningNodeClient).deleteAgent(any(String.class), any()); + + assertTrue(future.isDone()); + assertEquals(agentId, future.get().getContent().get("agent_id")); + } + + public void testNoAgentIdInOutput() throws IOException { + DeleteAgentStep deleteAgentStep = new DeleteAgentStep(machineLearningNodeClient); + + CompletableFuture future = deleteAgentStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Missing required inputs [agent_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); + } + + public void testDeleteAgentFailure() throws IOException { + DeleteAgentStep deleteAgentStep = new DeleteAgentStep(machineLearningNodeClient); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to delete agent", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).deleteAgent(any(String.class), any()); + + CompletableFuture future = deleteAgentStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of("agent_id", "test"), "workflowId", "nodeId")), + Map.of("step_1", "agent_id") + ); + + verify(machineLearningNodeClient).deleteAgent(any(String.class), any()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to delete agent", ex.getCause().getMessage()); + } +}