diff --git a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java index fc33621ae..473b9a873 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportAction.java @@ -118,8 +118,10 @@ protected void doExecute(Task task, ReprovisionWorkflowRequest request, ActionLi client.execute(GetWorkflowStateAction.INSTANCE, getStateRequest, ActionListener.wrap(response -> { context.restore(); - if (!ProvisioningProgress.DONE.equals(ProvisioningProgress.valueOf(response.getWorkflowState().getState()))) { - String errorMessage = "The template can not be reprovisioned unless its provisioning state is DONE: " + workflowId; + State currentState = State.valueOf(response.getWorkflowState().getState()); + if (State.PROVISIONING.equals(currentState) || State.NOT_STARTED.equals(currentState)) { + String errorMessage = "The template can not be reprovisioned unless its provisioning state is DONE or FAILED: " + + workflowId; throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 1e4294ad7..83e9b9310 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -272,12 +272,10 @@ public List<ProcessNode> createReprovisionSequence( // Case 4 : No modification to existing node, create proxy step to pass down required input to dependent nodes // Node ID should give us resources created - ResourceCreated nodeResource = null; - for (ResourceCreated resourceCreated : resourcesCreated) { - if (resourceCreated.workflowStepId().equals(node.id())) { - nodeResource = resourceCreated; - } - } + ResourceCreated nodeResource = resourcesCreated.stream() + .filter(rc -> rc.workflowStepId().equals(node.id())) + .findFirst() + .orElse(null); if (nodeResource != null) { // create process node diff --git a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java index 938574eed..899b7c83c 100644 --- a/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/ReprovisionWorkflowTransportActionTests.java @@ -16,8 +16,8 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; -import org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.ResourceCreated; +import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowState; @@ -117,7 +117,7 @@ public void testReprovisionWorkflow() throws Exception { WorkflowState state = mock(WorkflowState.class); ResourceCreated resourceCreated = new ResourceCreated("stepName", workflowId, "resourceType", "resourceId"); - when(state.getState()).thenReturn(ProvisioningProgress.DONE.toString()); + when(state.getState()).thenReturn(State.COMPLETED.toString()); when(state.resourcesCreated()).thenReturn(List.of(resourceCreated)); listener.onResponse(new GetWorkflowStateResponse(state, true)); return null; @@ -143,7 +143,7 @@ public void testReprovisionWorkflow() throws Exception { assertEquals(workflowId, responseCaptor.getValue().getWorkflowId()); } - public void testReprovisionInProgressWorkflow() throws Exception { + public void testReprovisionProvisioningWorkflow() throws Exception { String workflowId = "1"; Template mockTemplate = mock(Template.class); @@ -164,7 +164,7 @@ public void testReprovisionInProgressWorkflow() throws Exception { WorkflowState state = mock(WorkflowState.class); ResourceCreated resourceCreated = new ResourceCreated("stepName", workflowId, "resourceType", "resourceId"); - when(state.getState()).thenReturn(ProvisioningProgress.IN_PROGRESS.toString()); + when(state.getState()).thenReturn(State.PROVISIONING.toString()); when(state.resourcesCreated()).thenReturn(List.of(resourceCreated)); listener.onResponse(new GetWorkflowStateResponse(state, true)); return null; @@ -178,7 +178,7 @@ public void testReprovisionInProgressWorkflow() throws Exception { ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals( - "The template can not be reprovisioned unless its provisioning state is DONE: 1", + "The template can not be reprovisioned unless its provisioning state is DONE or FAILED: 1", exceptionCaptor.getValue().getMessage() ); } @@ -204,7 +204,7 @@ public void testReprovisionNotStartedWorkflow() throws Exception { WorkflowState state = mock(WorkflowState.class); ResourceCreated resourceCreated = new ResourceCreated("stepName", workflowId, "resourceType", "resourceId"); - when(state.getState()).thenReturn(ProvisioningProgress.NOT_STARTED.toString()); + when(state.getState()).thenReturn(State.NOT_STARTED.toString()); when(state.resourcesCreated()).thenReturn(List.of(resourceCreated)); listener.onResponse(new GetWorkflowStateResponse(state, true)); return null; @@ -218,7 +218,7 @@ public void testReprovisionNotStartedWorkflow() throws Exception { ArgumentCaptor<Exception> exceptionCaptor = ArgumentCaptor.forClass(Exception.class); verify(listener, times(1)).onFailure(exceptionCaptor.capture()); assertEquals( - "The template can not be reprovisioned unless its provisioning state is DONE: 1", + "The template can not be reprovisioned unless its provisioning state is DONE or FAILED: 1", exceptionCaptor.getValue().getMessage() ); } @@ -244,7 +244,7 @@ public void testFailedStateUpdate() throws Exception { WorkflowState state = mock(WorkflowState.class); ResourceCreated resourceCreated = new ResourceCreated("stepName", workflowId, "resourceType", "resourceId"); - when(state.getState()).thenReturn(ProvisioningProgress.DONE.toString()); + when(state.getState()).thenReturn(State.COMPLETED.toString()); when(state.resourcesCreated()).thenReturn(List.of(resourceCreated)); listener.onResponse(new GetWorkflowStateResponse(state, true)); return null;