Skip to content

Commit

Permalink
Fix thread hang in flow framework issue (#413)
Browse files Browse the repository at this point in the history
* Fix hang in flow framework

Signed-off-by: zane-neo <[email protected]>

* Fix spotless

Signed-off-by: Daniel Widdis <[email protected]>

* Fix tests to account for async delay

Signed-off-by: Daniel Widdis <[email protected]>

* Update GHA version for security tests

Signed-off-by: Daniel Widdis <[email protected]>

---------

Signed-off-by: zane-neo <[email protected]>
Signed-off-by: Daniel Widdis <[email protected]>
Co-authored-by: Daniel Widdis <[email protected]>
(cherry picked from commit bf2cad0)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and dbwiddis committed Jan 17, 2024
1 parent 9c7c2e1 commit b46651e
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 33 deletions.
12 changes: 6 additions & 6 deletions .github/workflows/test_security.yml
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
name: Security test workflow for Flow Framework
on:
push:
branches:
- "*"
branches-ignore:
- 'whitesource-remediate/**'
- 'backport/**'
pull_request:
branches:
- "*"
types: [opened, synchronize, reopened]

jobs:
Get-CI-Image-Tag:
Expand All @@ -30,9 +30,9 @@ jobs:

steps:
- name: Checkout Flow Framework
uses: actions/checkout@v1
uses: actions/checkout@v3
- name: Setup Java ${{ matrix.java }}
uses: actions/setup-java@v1
uses: actions/setup-java@v3
with:
distribution: 'temurin'
java-version: ${{ matrix.java }}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
import static org.opensearch.flowframework.common.CommonValue.PROVISION_START_TIME_FIELD;
import static org.opensearch.flowframework.common.CommonValue.RESOURCES_CREATED_FIELD;
import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD;
import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL;
import static org.opensearch.flowframework.common.WorkflowResources.getDeprovisionStepByWorkflowStep;
import static org.opensearch.flowframework.common.WorkflowResources.getResourceByWorkflowStep;

Expand Down Expand Up @@ -101,7 +102,8 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener<Work
context.restore();

// Retrieve resources from workflow state and deprovision
executeDeprovisionSequence(workflowId, response.getWorkflowState().resourcesCreated(), listener);
threadPool.executor(WORKFLOW_THREAD_POOL)
.execute(() -> executeDeprovisionSequence(workflowId, response.getWorkflowState().resourcesCreated(), listener));
}, exception -> {
String message = "Failed to get workflow state for workflow " + workflowId;
logger.error(message, exception);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.flowframework.transport;

import org.opensearch.action.LatchedActionListener;
import org.opensearch.action.support.ActionFilters;
import org.opensearch.client.Client;
import org.opensearch.common.settings.Settings;
Expand All @@ -19,12 +20,9 @@
import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler;
import org.opensearch.flowframework.model.ResourceCreated;
import org.opensearch.flowframework.model.WorkflowState;
import org.opensearch.flowframework.workflow.CreateConnectorStep;
import org.opensearch.flowframework.workflow.DeleteConnectorStep;
import org.opensearch.flowframework.workflow.ProcessNode;
import org.opensearch.flowframework.workflow.WorkflowData;
import org.opensearch.flowframework.workflow.WorkflowStepFactory;
import org.opensearch.ml.client.MachineLearningNodeClient;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.FixedExecutorBuilder;
Expand All @@ -33,10 +31,9 @@
import org.opensearch.transport.TransportService;
import org.junit.AfterClass;

import java.io.IOException;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;

import org.mockito.ArgumentCaptor;
Expand All @@ -50,6 +47,7 @@
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
Expand Down Expand Up @@ -77,47 +75,41 @@ public class DeprovisionWorkflowTransportActionTests extends OpenSearchTestCase
public void setUp() throws Exception {
super.setUp();
this.client = mock(Client.class);
ThreadPool clientThreadPool = spy(threadPool);
when(client.threadPool()).thenReturn(clientThreadPool);
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);
when(clientThreadPool.getThreadContext()).thenReturn(threadContext);

this.workflowStepFactory = mock(WorkflowStepFactory.class);
this.deleteConnectorStep = mock(DeleteConnectorStep.class);
when(this.workflowStepFactory.createStep("delete_connector")).thenReturn(deleteConnectorStep);

this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class);
flowFrameworkSettings = mock(FlowFrameworkSettings.class);
when(flowFrameworkSettings.getRequestTimeout()).thenReturn(TimeValue.timeValueSeconds(10));

this.deprovisionWorkflowTransportAction = new DeprovisionWorkflowTransportAction(
mock(TransportService.class),
mock(ActionFilters.class),
threadPool,
clientThreadPool,
client,
workflowStepFactory,
flowFrameworkIndicesHandler,
flowFrameworkSettings
);

MachineLearningNodeClient mlClient = new MachineLearningNodeClient(client);
ProcessNode processNode = mock(ProcessNode.class);
when(processNode.id()).thenReturn("step_1");
when(processNode.workflowStep()).thenReturn(new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler));
when(processNode.previousNodeInputs()).thenReturn(Collections.emptyMap());
when(processNode.input()).thenReturn(WorkflowData.EMPTY);
when(processNode.nodeTimeout()).thenReturn(TimeValue.timeValueSeconds(5));
this.deleteConnectorStep = mock(DeleteConnectorStep.class);
when(this.workflowStepFactory.createStep("delete_connector")).thenReturn(deleteConnectorStep);

ThreadPool clientThreadPool = mock(ThreadPool.class);
ThreadContext threadContext = new ThreadContext(Settings.EMPTY);

when(client.threadPool()).thenReturn(clientThreadPool);
when(clientThreadPool.getThreadContext()).thenReturn(threadContext);
}

@AfterClass
public static void cleanup() {
ThreadPool.terminate(threadPool, 500, TimeUnit.MILLISECONDS);
}

public void testDeprovisionWorkflow() throws IOException {
public void testDeprovisionWorkflow() throws Exception {
String workflowId = "1";

CountDownLatch latch = new CountDownLatch(1);
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
ActionListener<WorkflowResponse> listener = spy(new LatchedActionListener<WorkflowResponse>(mock(ActionListener.class), latch));
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null);

doAnswer(invocation -> {
Expand All @@ -137,14 +129,17 @@ public void testDeprovisionWorkflow() throws IOException {
deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener);
ArgumentCaptor<WorkflowResponse> responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class);

latch.await(5, TimeUnit.SECONDS);
verify(listener, times(1)).onResponse(responseCaptor.capture());
assertEquals(workflowId, responseCaptor.getValue().getWorkflowId());
}

public void testFailToDeprovision() throws IOException {
public void testFailToDeprovision() throws Exception {
String workflowId = "1";

CountDownLatch latch = new CountDownLatch(1);
@SuppressWarnings("unchecked")
ActionListener<WorkflowResponse> listener = mock(ActionListener.class);
ActionListener<WorkflowResponse> listener = spy(new LatchedActionListener<WorkflowResponse>(mock(ActionListener.class), latch));
WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null);

doAnswer(invocation -> {
Expand All @@ -164,6 +159,7 @@ public void testFailToDeprovision() throws IOException {
deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener);
ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class);

latch.await(5, TimeUnit.SECONDS);
verify(listener, times(1)).onFailure(exceptionCaptor.capture());
assertEquals("Failed to deprovision some resources: [model_id modelId].", exceptionCaptor.getValue().getMessage());
}
Expand Down

0 comments on commit b46651e

Please sign in to comment.