diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java index fbf169162..784b67374 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -145,8 +145,9 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener provisionProcessSequence, ActionListener listener ) { - GetWorkflowStateRequest getRequest = new GetWorkflowStateRequest(workflowId, true); - client.execute(GetWorkflowStateAction.INSTANCE, getRequest, ActionListener.wrap(response -> { + GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true); + client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { // Get a map of step id to created resources final Map resourceMap = response.getWorkflowState() .resourcesCreated() @@ -166,8 +167,9 @@ private void getResourcesAndExecute( // Now finally do the deprovision executeDeprovisionSequence(workflowId, resourceMap, provisionProcessSequence, listener); }, exception -> { - logger.error("Failed to get workflow state for workflow " + workflowId); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + String message = "Failed to get workflow state for workflow " + workflowId; + logger.error(message, exception); + listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(exception))); })); } @@ -177,7 +179,7 @@ private void executeDeprovisionSequence( List provisionProcessSequence, ActionListener listener ) { - // Create a list of ProcessNodes with ta corresponding deprovision workflow steps + // Create a list of ProcessNodes with the corresponding deprovision workflow steps List deprovisionProcessSequence = provisionProcessSequence.stream() // Only include nodes that created a resource .filter(pn -> resourceMap.containsKey(pn.id())) diff --git a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java index a1f1fb517..fd52e67f1 100644 --- a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java @@ -14,33 +14,46 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; +import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.ThreadContext; -import org.opensearch.common.xcontent.XContentFactory; import org.opensearch.core.action.ActionListener; -import org.opensearch.core.common.bytes.BytesReference; -import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.ResourceCreated; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; +import org.opensearch.flowframework.model.WorkflowState; import org.opensearch.flowframework.util.EncryptorUtils; +import org.opensearch.flowframework.workflow.CreateConnectorStep; +import org.opensearch.flowframework.workflow.DeleteConnectorStep; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowData; import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.index.get.GetResult; +import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; +import org.junit.AfterClass; +import java.io.IOException; +import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.TimeUnit; import org.mockito.ArgumentCaptor; -import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX; +import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyMap; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; @@ -49,19 +62,20 @@ public class DeprovisionWorkflowTransportActionTests extends OpenSearchTestCase { - private ThreadPool threadPool; + private static ThreadPool threadPool = new TestThreadPool(DeprovisionWorkflowTransportActionTests.class.getName()); private Client client; private WorkflowProcessSorter workflowProcessSorter; private WorkflowStepFactory workflowStepFactory; + private DeleteConnectorStep deleteConnectorStep; private DeprovisionWorkflowTransportAction deprovisionWorkflowTransportAction; private Template template; + private GetResult getResult; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; private EncryptorUtils encryptorUtils; @Override public void setUp() throws Exception { super.setUp(); - this.threadPool = mock(ThreadPool.class); this.client = mock(Client.class); this.workflowProcessSorter = mock(WorkflowProcessSorter.class); this.workflowStepFactory = mock(WorkflowStepFactory.class); @@ -81,23 +95,32 @@ public void setUp() throws Exception { Version templateVersion = Version.fromString("1.0.0"); List compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")); - WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar")); - WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux")); - WorkflowEdge edgeAB = new WorkflowEdge("A", "B"); - List nodes = List.of(nodeA, nodeB); - List edges = List.of(edgeAB); + WorkflowNode node = new WorkflowNode("step_1", "create_connector", Collections.emptyMap(), Collections.emptyMap()); + List nodes = List.of(node); + List edges = Collections.emptyList(); Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges); - this.template = new Template( "test", "description", "use case", templateVersion, compatibilityVersions, - Map.of("deprovision", workflow), + Map.of(PROVISION_WORKFLOW, workflow), Map.of(), TestHelpers.randomUser() ); + this.getResult = mock(GetResult.class); + + MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client); + ProcessNode processNode = mock(ProcessNode.class); + when(processNode.id()).thenReturn("step_1"); + when(processNode.workflowStep()).thenReturn(new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); + when(processNode.previousNodeInputs()).thenReturn(Collections.emptyMap()); + when(processNode.input()).thenReturn(WorkflowData.EMPTY); + when(processNode.nodeTimeout()).thenReturn(TimeValue.timeValueSeconds(5)); + when(this.workflowProcessSorter.sortProcessNodes(any(Workflow.class), any(String.class))).thenReturn(List.of(processNode)); + this.deleteConnectorStep = mock(DeleteConnectorStep.class); + when(this.workflowStepFactory.createStep("delete_connector")).thenReturn(deleteConnectorStep); ThreadPool clientThreadPool = mock(ThreadPool.class); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); @@ -106,49 +129,106 @@ public void setUp() throws Exception { when(clientThreadPool.getThreadContext()).thenReturn(threadContext); } - public void testDedeprovisionWorkflow() { + @AfterClass + public static void cleanup() { + ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS); + } + public void testDeprovisionWorkflow() throws IOException { String workflowId = "1"; @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + when(getResult.sourceAsString()).thenReturn(this.template.toJson()); doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); - XContentBuilder builder = XContentFactory.jsonBuilder(); - this.template.toXContent(builder, null); - BytesReference templateBytesRef = BytesReference.bytes(builder); - GetResult getResult = new GetResult(GLOBAL_CONTEXT_INDEX, workflowId, 1, 1, 1, true, templateBytesRef, null, null); + when(getResult.isExists()).thenReturn(true); responseListener.onResponse(new GetResponse(getResult)); return null; }).when(client).get(any(GetRequest.class), any()); when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + + WorkflowState state = WorkflowState.builder() + .resourcesCreated(List.of(new ResourceCreated("create_connector", "step_1", "connectorId"))) + .build(); + responseListener.onResponse(new GetWorkflowStateResponse(state, true)); + return null; + }).when(client).execute(any(GetWorkflowStateAction.class), any(GetWorkflowStateRequest.class), any()); + + when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn( + CompletableFuture.completedFuture(WorkflowData.EMPTY) + ); + deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); - // TODO: need a lot more mocking for happy path - // ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); - // verify(listener, times(1)).onResponse(responseCaptor.capture()); - // assertEquals(workflowId, responseCaptor.getValue().getWorkflowId()); + verify(listener, times(1)).onResponse(responseCaptor.capture()); + assertEquals(workflowId, responseCaptor.getValue().getWorkflowId()); } public void testFailedToRetrieveTemplateFromGlobalContext() { + String workflowId = "1"; @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); - WorkflowRequest request = new WorkflowRequest("1", null); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + when(getResult.sourceAsString()).thenReturn(this.template.toJson()); + doAnswer(invocation -> { ActionListener responseListener = invocation.getArgument(1); - responseListener.onFailure(new Exception("Failed to retrieve template from global context.")); + + when(getResult.isExists()).thenReturn(false); + responseListener.onResponse(new GetResponse(getResult)); return null; }).when(client).get(any(GetRequest.class), any()); - deprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener); + deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); - assertEquals("Failed to retrieve template from global context.", exceptionCaptor.getValue().getMessage()); + assertEquals("Failed to retrieve template (1) from global context.", exceptionCaptor.getValue().getMessage()); } + public void testFailToDeprovision() throws IOException { + String workflowId = "1"; + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + when(getResult.sourceAsString()).thenReturn(this.template.toJson()); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(1); + + when(getResult.isExists()).thenReturn(true); + responseListener.onResponse(new GetResponse(getResult)); + return null; + }).when(client).get(any(GetRequest.class), any()); + + when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template); + + doAnswer(invocation -> { + ActionListener responseListener = invocation.getArgument(2); + + WorkflowState state = WorkflowState.builder() + .resourcesCreated(List.of(new ResourceCreated("deploy_model", "step_1", "modelId"))) + .build(); + responseListener.onResponse(new GetWorkflowStateResponse(state, true)); + return null; + }).when(client).execute(any(GetWorkflowStateAction.class), any(GetWorkflowStateRequest.class), any()); + + CompletableFuture future = new CompletableFuture<>(); + future.completeExceptionally(new RuntimeException("rte")); + when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn(future); + + deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("Failed to deprovision some resources: [model_id modelId].", exceptionCaptor.getValue().getMessage()); + } }