diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java new file mode 100644 index 000000000..bf0fae33e --- /dev/null +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java @@ -0,0 +1,105 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; + +import java.io.IOException; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.CompletableFuture; + +import static org.opensearch.flowframework.common.CommonValue.CONNECTOR_ID; + +/** + * Step to delete a connector for a remote model + */ +public class DeleteConnectorStep implements WorkflowStep { + + private static final Logger logger = LogManager.getLogger(DeleteConnectorStep.class); + + private MachineLearningNodeClient mlClient; + + static final String NAME = "delete_connector"; + + /** + * Instantiate this class + * @param mlClient Machine Learning client to perform the deletion + */ + public DeleteConnectorStep(MachineLearningNodeClient mlClient) { + this.mlClient = mlClient; + } + + @Override + public CompletableFuture execute( + String currentNodeId, + WorkflowData currentNodeInputs, + Map outputs, + Map previousNodeInputs + ) throws IOException { + CompletableFuture deleteConnectorFuture = new CompletableFuture<>(); + + ActionListener actionListener = new ActionListener<>() { + + @Override + public void onResponse(DeleteResponse deleteResponse) { + deleteConnectorFuture.complete( + new WorkflowData( + Map.ofEntries(Map.entry("connector_id", deleteResponse.getId())), + currentNodeInputs.getWorkflowId(), + currentNodeInputs.getNodeId() + ) + ); + } + + @Override + public void onFailure(Exception e) { + logger.error("Failed to delete connector"); + deleteConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + } + }; + + String connectorId = null; + + // Previous Node inputs defines which step the connector ID came from + Optional previousNode = previousNodeInputs.entrySet() + .stream() + .filter(e -> CONNECTOR_ID.equals(e.getValue())) + .map(Map.Entry::getKey) + .findFirst(); + if (previousNode.isPresent()) { + WorkflowData previousNodeOutput = outputs.get(previousNode.get()); + if (previousNodeOutput != null && previousNodeOutput.getContent().containsKey(CONNECTOR_ID)) { + connectorId = previousNodeOutput.getContent().get(CONNECTOR_ID).toString(); + } + } + + if (connectorId != null) { + mlClient.deleteConnector(connectorId, actionListener); + } else { + deleteConnectorFuture.completeExceptionally( + new FlowFrameworkException("Required field " + CONNECTOR_ID + " is not provided", RestStatus.BAD_REQUEST) + ); + } + + return deleteConnectorFuture; + } + + @Override + public String getName() { + return NAME; + } +} diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index c9e565bba..bac65c23a 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -25,7 +25,6 @@ public class WorkflowStepFactory { private final Map stepMap = new HashMap<>(); - private final FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; /** * Instantiate this class. @@ -42,17 +41,6 @@ public WorkflowStepFactory( Client client, MachineLearningNodeClient mlClient, FlowFrameworkIndicesHandler flowFrameworkIndicesHandler - ) { - this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; - populateMap(settings, clusterService, client, mlClient, flowFrameworkIndicesHandler); - } - - private void populateMap( - Settings settings, - ClusterService clusterService, - Client client, - MachineLearningNodeClient mlClient, - FlowFrameworkIndicesHandler flowFrameworkIndicesHandler ) { stepMap.put(NoOpStep.NAME, new NoOpStep()); stepMap.put(CreateIndexStep.NAME, new CreateIndexStep(clusterService, client)); @@ -61,6 +49,7 @@ private void populateMap( stepMap.put(RegisterRemoteModelStep.NAME, new RegisterRemoteModelStep(mlClient)); stepMap.put(DeployModelStep.NAME, new DeployModelStep(mlClient)); stepMap.put(CreateConnectorStep.NAME, new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); + stepMap.put(DeleteConnectorStep.NAME, new DeleteConnectorStep(mlClient)); stepMap.put(ModelGroupStep.NAME, new ModelGroupStep(mlClient)); stepMap.put(ToolStep.NAME, new ToolStep()); stepMap.put(RegisterAgentStep.NAME, new RegisterAgentStep(mlClient)); @@ -80,7 +69,7 @@ public WorkflowStep createStep(String type) { /** * Gets the step map - * @return the step map + * @return a read-only copy of the step map */ public Map getStepMap() { return Map.copyOf(this.stepMap); diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index c9794d4ea..5840b0906 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -39,6 +39,14 @@ "connector_id" ] }, + "delete_connector": { + "inputs": [ + "connector_id" + ], + "outputs":[ + "connector_id" + ] + }, "register_local_model": { "inputs":[ "name", diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index de3add996..1135a0ca6 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -25,7 +25,6 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; -import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.MockitoAnnotations; @@ -81,15 +80,12 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr String connectorId = "connect"; CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler); - @SuppressWarnings("unchecked") - ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); MLCreateConnectorResponse output = new MLCreateConnectorResponse(connectorId); actionListener.onResponse(output); return null; - }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); CompletableFuture future = createConnectorStep.execute( inputData.getNodeId(), @@ -98,8 +94,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr Collections.emptyMap() ); - verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); - + verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); assertTrue(future.isDone()); assertEquals(connectorId, future.get().getContent().get("connector_id")); @@ -108,14 +103,11 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr public void testCreateConnectorFailure() throws IOException { CreateConnectorStep createConnectorStep = new CreateConnectorStep(machineLearningNodeClient, flowFrameworkIndicesHandler); - @SuppressWarnings("unchecked") - ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(1); actionListener.onFailure(new FlowFrameworkException("Failed to create connector", RestStatus.INTERNAL_SERVER_ERROR)); return null; - }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); CompletableFuture future = createConnectorStep.execute( inputData.getNodeId(), @@ -124,7 +116,7 @@ public void testCreateConnectorFailure() throws IOException { Collections.emptyMap() ); - verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), actionListenerCaptor.capture()); + verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java new file mode 100644 index 000000000..3c997a02e --- /dev/null +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java @@ -0,0 +1,115 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ +package org.opensearch.flowframework.workflow; + +import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.index.Index; +import org.opensearch.core.index.shard.ShardId; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.common.CommonValue; +import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.ml.client.MachineLearningNodeClient; +import org.opensearch.ml.common.transport.connector.MLCreateConnectorResponse; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.util.Collections; +import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ExecutionException; + +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +public class DeleteConnectorStepTests extends OpenSearchTestCase { + private WorkflowData inputData; + + @Mock + MachineLearningNodeClient machineLearningNodeClient; + + @Override + public void setUp() throws Exception { + super.setUp(); + + MockitoAnnotations.openMocks(this); + + inputData = new WorkflowData(Map.of(CommonValue.CONNECTOR_ID, "test"), "test-id", "test-node-id"); + } + + public void testDeleteConnector() throws IOException, ExecutionException, InterruptedException { + + String connectorId = randomAlphaOfLength(5); + DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient); + + doAnswer(invocation -> { + String connectorIdArg = invocation.getArgument(0); + ActionListener actionListener = invocation.getArgument(1); + ShardId shardId = new ShardId(new Index("indexName", "uuid"), 1); + DeleteResponse output = new DeleteResponse(shardId, connectorIdArg, 1, 1, 1, true); + actionListener.onResponse(output); + return null; + }).when(machineLearningNodeClient).deleteConnector(any(String.class), any()); + + CompletableFuture future = deleteConnectorStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of("connector_id", connectorId), "workflowId", "nodeId")), + Map.of("step_1", "connector_id") + ); + verify(machineLearningNodeClient).deleteConnector(any(String.class), any()); + + assertTrue(future.isDone()); + assertEquals(connectorId, future.get().getContent().get("connector_id")); + } + + public void testNoConnectorIdInOutput() throws IOException { + DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient); + + CompletableFuture future = deleteConnectorStep.execute( + inputData.getNodeId(), + inputData, + Collections.emptyMap(), + Collections.emptyMap() + ); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Required field connector_id is not provided", ex.getCause().getMessage()); + } + + public void testDeleteConnectorFailure() throws IOException { + DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(1); + actionListener.onFailure(new FlowFrameworkException("Failed to delete connector", RestStatus.INTERNAL_SERVER_ERROR)); + return null; + }).when(machineLearningNodeClient).deleteConnector(any(String.class), any()); + + CompletableFuture future = deleteConnectorStep.execute( + inputData.getNodeId(), + inputData, + Map.of("step_1", new WorkflowData(Map.of("connector_id", "test"), "workflowId", "nodeId")), + Map.of("step_1", "connector_id") + ); + + verify(machineLearningNodeClient).deleteConnector(any(String.class), any()); + + assertTrue(future.isCompletedExceptionally()); + ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); + assertTrue(ex.getCause() instanceof FlowFrameworkException); + assertEquals("Failed to delete connector", ex.getCause().getMessage()); + } +}