Skip to content

Commit

Permalink
Completed test implementations
Browse files Browse the repository at this point in the history
Signed-off-by: Daniel Widdis <[email protected]>
  • Loading branch information
dbwiddis committed Dec 13, 2023
1 parent 60d699d commit a18a9c3
Show file tree
Hide file tree
Showing 2 changed files with 115 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -145,8 +145,9 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
}
}));
} catch (Exception e) {
logger.error("Failed to retrieve template from global context.", e);
listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)));
String message = "Failed to retrieve template from global context.";
logger.error(message, e);
listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(e)));
}
}

Expand All @@ -155,8 +156,8 @@ private void getResourcesAndExecute(
List<ProcessNode> provisionProcessSequence,
ActionListener<WorkflowResponse> listener
) {
GetWorkflowStateRequest getRequest = new GetWorkflowStateRequest(workflowId, true);
client.execute(GetWorkflowStateAction.INSTANCE, getRequest, ActionListener.wrap(response -> {
GetWorkflowStateRequest getStateRequest = new GetWorkflowStateRequest(workflowId, true);
client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> {
// Get a map of step id to created resources
final Map<String, ResourceCreated> resourceMap = response.getWorkflowState()
.resourcesCreated()
Expand All @@ -166,8 +167,9 @@ private void getResourcesAndExecute(
// Now finally do the deprovision
executeDeprovisionSequence(workflowId, resourceMap, provisionProcessSequence, listener);
}, exception -> {
logger.error("Failed to get workflow state for workflow " + workflowId);
listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)));
String message = "Failed to get workflow state for workflow " + workflowId;
logger.error(message, exception);
listener.onFailure(new FlowFrameworkException(message, ExceptionsHelper.status(exception)));
}));
}

Expand All @@ -177,7 +179,7 @@ private void executeDeprovisionSequence(
List<ProcessNode> provisionProcessSequence,
ActionListener<WorkflowResponse> listener
) {
// Create a list of ProcessNodes with ta corresponding deprovision workflow steps
// Create a list of ProcessNodes with the corresponding deprovision workflow steps
List<ProcessNode> deprovisionProcessSequence = provisionProcessSequence.stream()
// Only include nodes that created a resource
.filter(pn -> resourceMap.containsKey(pn.id()))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,33 +14,46 @@
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.ThreadContext;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.bytes.BytesReference;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.flowframework.TestHelpers;
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.model.ResourceCreated;
import org.opensearch.flowframework.model.Template;
import org.opensearch.flowframework.model.Workflow;
import org.opensearch.flowframework.model.WorkflowEdge;
import org.opensearch.flowframework.model.WorkflowNode;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.flowframework.util.EncryptorUtils;
import org.opensearch.flowframework.workflow.CreateConnectorStep;
import org.opensearch.flowframework.workflow.DeleteConnectorStep;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowProcessSorter;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.index.get.GetResult;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.TestThreadPool;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.TransportService;
import org.junit.AfterClass;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.TimeUnit;

import org.mockito.ArgumentCaptor;

import static org.opensearch.flowframework.common.CommonValue.GLOBAL_CONTEXT_INDEX;
import static org.opensearch.flowframework.common.CommonValue.PROVISION_WORKFLOW;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyMap;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
Expand All @@ -49,19 +62,20 @@

public class DeprovisionWorkflowTransportActionTests extends OpenSearchTestCase {

private ThreadPool threadPool;
private static ThreadPool threadPool = new TestThreadPool(DeprovisionWorkflowTransportActionTests.class.getName());
private Client client;
private WorkflowProcessSorter workflowProcessSorter;
private WorkflowStepFactory workflowStepFactory;
private DeleteConnectorStep deleteConnectorStep;
private DeprovisionWorkflowTransportAction deprovisionWorkflowTransportAction;
private Template template;
private GetResult getResult;
private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler;
private EncryptorUtils encryptorUtils;

@Override
public void setUp() throws Exception {
super.setUp();
this.threadPool = mock(ThreadPool.class);
this.client = mock(Client.class);
this.workflowProcessSorter = mock(WorkflowProcessSorter.class);
this.workflowStepFactory = mock(WorkflowStepFactory.class);
Expand All @@ -81,23 +95,32 @@ public void setUp() throws Exception {

Version templateVersion = Version.fromString("1.0.0");
List<Version> compatibilityVersions = List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0"));
WorkflowNode nodeA = new WorkflowNode("A", "a-type", Map.of(), Map.of("foo", "bar"));
WorkflowNode nodeB = new WorkflowNode("B", "b-type", Map.of(), Map.of("baz", "qux"));
WorkflowEdge edgeAB = new WorkflowEdge("A", "B");
List<WorkflowNode> nodes = List.of(nodeA, nodeB);
List<WorkflowEdge> edges = List.of(edgeAB);
WorkflowNode node = new WorkflowNode("step_1", "create_connector", Collections.emptyMap(), Collections.emptyMap());
List<WorkflowNode> nodes = List.of(node);
List<WorkflowEdge> edges = Collections.emptyList();
Workflow workflow = new Workflow(Map.of("key", "value"), nodes, edges);

this.template = new Template(
"test",
"description",
"use case",
templateVersion,
compatibilityVersions,
Map.of("deprovision", workflow),
Map.of(PROVISION_WORKFLOW, workflow),
Map.of(),
TestHelpers.randomUser()
);
this.getResult = mock(GetResult.class);

MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
ProcessNode processNode = mock(ProcessNode.class);
when(processNode.id()).thenReturn("step_1");
when(processNode.workflowStep()).thenReturn(new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
when(processNode.previousNodeInputs()).thenReturn(Collections.emptyMap());
when(processNode.input()).thenReturn(WorkflowData.EMPTY);
when(processNode.nodeTimeout()).thenReturn(TimeValue.timeValueSeconds(5));
when(this.workflowProcessSorter.sortProcessNodes(any(Workflow.class), any(String.class))).thenReturn(List.of(processNode));
this.deleteConnectorStep = mock(DeleteConnectorStep.class);
when(this.workflowStepFactory.createStep("delete_connector")).thenReturn(deleteConnectorStep);

ThreadPool clientThreadPool = mock(ThreadPool.class);
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
Expand All @@ -106,49 +129,106 @@ public void setUp() throws Exception {
when(clientThreadPool.getThreadContext()).thenReturn(threadContext);
}

public void testDedeprovisionWorkflow() {
@AfterClass
public static void cleanup() {
ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS);
}

public void testDeprovisionWorkflow() throws IOException {
String workflowId = "1";
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null);
when(getResult.sourceAsString()).thenReturn(this.template.toJson());

doAnswer(invocation -> {
ActionListener<GetResponse> responseListener = invocation.getArgument(1);

XContentBuilder builder = XContentFactory.jsonBuilder();
this.template.toXContent(builder, null);
BytesReference templateBytesRef = BytesReference.bytes(builder);
GetResult getResult = new GetResult(GLOBAL_CONTEXT_INDEX, workflowId, 1, 1, 1, true, templateBytesRef, null, null);
when(getResult.isExists()).thenReturn(true);
responseListener.onResponse(new GetResponse(getResult));
return null;
}).when(client).get(any(GetRequest.class), any());

when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template);

doAnswer(invocation -> {
ActionListener<GetWorkflowStateResponse> responseListener = invocation.getArgument(2);

WorkflowState state = WorkflowState.builder()
.resourcesCreated(List.of(new ResourceCreated("create_connector", "step_1", "connectorId")))
.build();
responseListener.onResponse(new GetWorkflowStateResponse(state, true));
return null;
}).when(client).execute(any(GetWorkflowStateAction.class), any(GetWorkflowStateRequest.class), any());

when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn(
CompletableFuture.completedFuture(WorkflowData.EMPTY)
);

deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener);
// TODO: need a lot more mocking for happy path
// ArgumentCaptor<WorkflowResponse> responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class);
ArgumentCaptor<WorkflowResponse> responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class);

// verify(listener, times(1)).onResponse(responseCaptor.capture());
// assertEquals(workflowId, responseCaptor.getValue().getWorkflowId());
verify(listener, times(1)).onResponse(responseCaptor.capture());
assertEquals(workflowId, responseCaptor.getValue().getWorkflowId());
}

public void testFailedToRetrieveTemplateFromGlobalContext() {
String workflowId = "1";
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest request = new WorkflowRequest("1", null);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null);
when(getResult.sourceAsString()).thenReturn(this.template.toJson());

doAnswer(invocation -> {
ActionListener<GetResponse> responseListener = invocation.getArgument(1);
responseListener.onFailure(new Exception("Failed to retrieve template from global context."));

when(getResult.isExists()).thenReturn(false);
responseListener.onResponse(new GetResponse(getResult));
return null;
}).when(client).get(any(GetRequest.class), any());

deprovisionWorkflowTransportAction.doExecute(mock(Task.class), request, listener);
deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);

verify(listener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals("Failed to retrieve template from global context.", exceptionCaptor.getValue().getMessage());
assertEquals("Failed to retrieve template (1) from global context.", exceptionCaptor.getValue().getMessage());
}

public void testFailToDeprovision() throws IOException {
String workflowId = "1";
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null);
when(getResult.sourceAsString()).thenReturn(this.template.toJson());

doAnswer(invocation -> {
ActionListener<GetResponse> responseListener = invocation.getArgument(1);

when(getResult.isExists()).thenReturn(true);
responseListener.onResponse(new GetResponse(getResult));
return null;
}).when(client).get(any(GetRequest.class), any());

when(encryptorUtils.decryptTemplateCredentials(any())).thenReturn(template);

doAnswer(invocation -> {
ActionListener<GetWorkflowStateResponse> responseListener = invocation.getArgument(2);

WorkflowState state = WorkflowState.builder()
.resourcesCreated(List.of(new ResourceCreated("deploy_model", "step_1", "modelId")))
.build();
responseListener.onResponse(new GetWorkflowStateResponse(state, true));
return null;
}).when(client).execute(any(GetWorkflowStateAction.class), any(GetWorkflowStateRequest.class), any());

CompletableFuture<WorkflowData> future = new CompletableFuture<>();
future.completeExceptionally(new RuntimeException("rte"));
when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn(future);

deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);

verify(listener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals("Failed to deprovision some resources: [model_id modelId].", exceptionCaptor.getValue().getMessage());
}
}

0 comments on commit a18a9c3

Please sign in to comment.