From 92d9108b58f913fc2c073df99afe49a411083790 Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Sun, 21 Jan 2024 18:41:06 -0800 Subject: [PATCH] Replace all CompletableFutures with PlainActionFutures (#419) * Intial commit to remove CompletableFuture Signed-off-by: Owais Kazi * Removed CompletableFuture from ProcessNode and tests Signed-off-by: Owais Kazi * Removed CompletableFuture from create index and pipeline workflow steps Signed-off-by: Owais Kazi * Passing tests Signed-off-by: Owais Kazi * Addressed initial comments Signed-off-by: Owais Kazi * Move log line Signed-off-by: Daniel Widdis * Reenable multi-node tests Signed-off-by: Daniel Widdis * Disable fail-fast Signed-off-by: Daniel Widdis --------- Signed-off-by: Owais Kazi Signed-off-by: Daniel Widdis Co-authored-by: Daniel Widdis --- .github/workflows/CI.yml | 3 +- .../DeprovisionWorkflowTransportAction.java | 6 +- .../ProvisionWorkflowTransportAction.java | 8 +-- .../AbstractRegisterLocalModelStep.java | 20 +++---- .../AbstractRetryableWorkflowStep.java | 3 +- .../workflow/CreateConnectorStep.java | 16 +++--- .../workflow/CreateIndexStep.java | 16 +++--- .../workflow/CreateIngestPipelineStep.java | 22 +++---- .../workflow/DeleteAgentStep.java | 12 ++-- .../workflow/DeleteConnectorStep.java | 12 ++-- .../workflow/DeleteModelStep.java | 12 ++-- .../workflow/DeployModelStep.java | 18 +++--- .../flowframework/workflow/NoOpStep.java | 9 ++- .../flowframework/workflow/ProcessNode.java | 30 ++++------ .../workflow/RegisterAgentStep.java | 16 +++--- .../workflow/RegisterModelGroupStep.java | 16 +++--- .../workflow/RegisterRemoteModelStep.java | 20 +++---- .../flowframework/workflow/ToolStep.java | 10 ++-- .../workflow/UndeployModelStep.java | 14 ++--- .../flowframework/workflow/WorkflowStep.java | 5 +- ...provisionWorkflowTransportActionTests.java | 12 ++-- .../workflow/CreateConnectorStepTests.java | 8 +-- .../workflow/CreateIndexStepTests.java | 10 ++-- .../CreateIngestPipelineStepTests.java | 14 ++--- .../workflow/DeleteAgentStepTests.java | 12 ++-- .../workflow/DeleteConnectorStepTests.java | 12 ++-- .../workflow/DeleteModelStepTests.java | 12 ++-- .../workflow/DeployModelStepTests.java | 10 ++-- .../workflow/ModelGroupStepTests.java | 12 ++-- .../flowframework/workflow/NoOpStepTests.java | 5 +- .../workflow/ProcessNodeTests.java | 57 +++++++++---------- .../workflow/RegisterAgentTests.java | 8 +-- .../RegisterLocalCustomModelStepTests.java | 13 ++--- ...RegisterLocalPretrainedModelStepTests.java | 13 ++--- ...sterLocalSparseEncodingModelStepTests.java | 13 ++--- .../RegisterRemoteModelStepTests.java | 14 ++--- .../flowframework/workflow/ToolStepTests.java | 4 +- .../workflow/UndeployModelStepTests.java | 12 ++-- 38 files changed, 245 insertions(+), 264 deletions(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 14c2ca25d..060e5a2d2 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -61,6 +61,7 @@ jobs: integTest: needs: [spotless, javadoc] strategy: + fail-fast: false matrix: os: [ubuntu-latest, macos-latest, windows-latest] # Don't use 21.0.2 https://github.com/opensearch-project/flow-framework/issues/426 @@ -81,6 +82,6 @@ jobs: run: | ./gradlew integTest yamlRestTest - name: Multi Nodes Integration Testing - if: matrix.java == 21 + if: matrix.java == '21.0.1' run: | ./gradlew integTest -PnumNodes=3 diff --git a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java index e026053e1..7383e0f12 100644 --- a/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportAction.java @@ -14,6 +14,7 @@ import org.opensearch.OpenSearchStatusException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; @@ -39,7 +40,6 @@ import java.util.List; import java.util.Map; import java.util.Objects; -import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; @@ -160,9 +160,9 @@ private void executeDeprovisionSequence( ProcessNode deprovisionNode = iter.next(); ResourceCreated resource = getResourceFromDeprovisionNode(deprovisionNode, resourcesCreated); String resourceNameAndId = getResourceNameAndId(resource); - CompletableFuture deprovisionFuture = deprovisionNode.execute(); + PlainActionFuture deprovisionFuture = deprovisionNode.execute(); try { - deprovisionFuture.join(); + deprovisionFuture.get(); logger.info("Successful {} for {}", deprovisionNode.id(), resourceNameAndId); // Remove from list so we don't try again iter.remove(); diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index f625aafb6..4e8562c25 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -14,6 +14,7 @@ import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; @@ -40,7 +41,6 @@ import java.util.Locale; import java.util.Map; import java.util.concurrent.CancellationException; -import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.ERROR_FIELD; @@ -183,7 +183,7 @@ private void executeWorkflowAsync(String workflowId, List workflowS private void executeWorkflow(List workflowSequence, String workflowId) { try { - List> workflowFutureList = new ArrayList<>(); + List> workflowFutureList = new ArrayList<>(); for (ProcessNode processNode : workflowSequence) { List predecessors = processNode.predecessors(); @@ -202,8 +202,8 @@ private void executeWorkflow(List workflowSequence, String workflow workflowFutureList.add(processNode.execute()); } - // Attempt to join each workflow step future, may throw a CompletionException if any step completes exceptionally - workflowFutureList.forEach(CompletableFuture::join); + // Attempt to join each workflow step future, may throw a ExecutionException if any step completes exceptionally + workflowFutureList.forEach(PlainActionFuture::actionGet); logger.info("Provisioning completed successfully for workflow {}", workflowId); flowFrameworkIndicesHandler.updateFlowFrameworkSystemIndexDoc( diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java index 1f4a78fd0..5074f3efa 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRegisterLocalModelStep.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -28,7 +29,6 @@ import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.ALL_CONFIG; @@ -75,14 +75,14 @@ protected AbstractRegisterLocalModelStep( } @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture registerLocalModelFuture = new CompletableFuture<>(); + PlainActionFuture registerLocalModelFuture = PlainActionFuture.newFuture(); try { Map inputs = ParseUtils.getInputsFromPreviousSteps( @@ -180,7 +180,7 @@ public CompletableFuture execute( "successfully updated resources created in state index: {}", deployUpdateResponse.getIndex() ); - registerLocalModelFuture.complete( + registerLocalModelFuture.onResponse( new WorkflowData( Map.ofEntries( Map.entry(resourceName, id), @@ -192,7 +192,7 @@ public CompletableFuture execute( ); }, deployUpdateException -> { logger.error("Failed to update simulated deploy step resource", deployUpdateException); - registerLocalModelFuture.completeExceptionally( + registerLocalModelFuture.onFailure( new FlowFrameworkException( deployUpdateException.getMessage(), ExceptionsHelper.status(deployUpdateException) @@ -201,7 +201,7 @@ public CompletableFuture execute( }) ); } else { - registerLocalModelFuture.complete( + registerLocalModelFuture.onResponse( new WorkflowData( Map.ofEntries(Map.entry(resourceName, id), Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name())), currentNodeInputs.getWorkflowId(), @@ -209,16 +209,14 @@ public CompletableFuture execute( ) ); } - }, exception -> { registerLocalModelFuture.completeExceptionally(exception); }) + }, exception -> { registerLocalModelFuture.onFailure(exception); }) ); }, exception -> { logger.error("Failed to register local model"); - registerLocalModelFuture.completeExceptionally( - new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) - ); + registerLocalModelFuture.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); })); } catch (FlowFrameworkException e) { - registerLocalModelFuture.completeExceptionally(e); + registerLocalModelFuture.onFailure(e); } return registerLocalModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java index c277ab27d..63d0d7587 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/AbstractRetryableWorkflowStep.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.FutureUtils; import org.opensearch.core.action.ActionListener; @@ -68,7 +69,7 @@ protected AbstractRetryableWorkflowStep( protected void retryableGetMlTask( String workflowId, String nodeId, - CompletableFuture future, + PlainActionFuture future, String taskId, String workflowStep, ActionListener mlTaskListener diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java index 4f8cfa9a5..228b4161f 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateConnectorStep.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -33,7 +34,6 @@ import java.util.Map; import java.util.Map.Entry; import java.util.Set; -import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; @@ -70,13 +70,13 @@ public CreateConnectorStep(MachineLearningNodeClient mlClient, FlowFrameworkIndi // TODO: need to add retry conflicts here @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture createConnectorFuture = new CompletableFuture<>(); + PlainActionFuture createConnectorFuture = PlainActionFuture.newFuture(); ActionListener actionListener = new ActionListener<>() { @@ -93,7 +93,7 @@ public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { mlCreateConnectorResponse.getConnectorId(), ActionListener.wrap(response -> { logger.info("successfully updated resources created in state index: {}", response.getIndex()); - createConnectorFuture.complete( + createConnectorFuture.onResponse( new WorkflowData( Map.ofEntries(Map.entry(resourceName, mlCreateConnectorResponse.getConnectorId())), currentNodeInputs.getWorkflowId(), @@ -102,7 +102,7 @@ public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { ); }, exception -> { logger.error("Failed to update new created resource", exception); - createConnectorFuture.completeExceptionally( + createConnectorFuture.onFailure( new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) ); }) @@ -110,14 +110,14 @@ public void onResponse(MLCreateConnectorResponse mlCreateConnectorResponse) { } catch (Exception e) { logger.error("Failed to parse and update new created resource", e); - createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + createConnectorFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } @Override public void onFailure(Exception e) { logger.error("Failed to create connector"); - createConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + createConnectorFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; @@ -171,7 +171,7 @@ public void onFailure(Exception e) { mlClient.createConnector(mlInput, actionListener); } catch (FlowFrameworkException e) { - createConnectorFuture.completeExceptionally(e); + createConnectorFuture.onFailure(e); } return createConnectorFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java index 2ec4edeac..eee1f94ec 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIndexStep.java @@ -13,6 +13,7 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.settings.Settings; @@ -25,7 +26,6 @@ import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.atomic.AtomicBoolean; import static org.opensearch.flowframework.common.CommonValue.DEFAULT_MAPPING_OPTION; @@ -59,13 +59,13 @@ public CreateIndexStep(ClusterService clusterService, Client client, FlowFramewo } @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture createIndexFuture = new CompletableFuture<>(); + PlainActionFuture createIndexFuture = PlainActionFuture.newFuture(); ActionListener actionListener = new ActionListener<>() { @Override @@ -80,7 +80,7 @@ public void onResponse(CreateIndexResponse createIndexResponse) { createIndexResponse.index(), ActionListener.wrap(response -> { logger.info("successfully updated resource created in state index: {}", response.getIndex()); - createIndexFuture.complete( + createIndexFuture.onResponse( new WorkflowData( Map.of(resourceName, createIndexResponse.index()), currentNodeInputs.getWorkflowId(), @@ -89,21 +89,21 @@ public void onResponse(CreateIndexResponse createIndexResponse) { ); }, exception -> { logger.error("Failed to update new created resource", exception); - createIndexFuture.completeExceptionally( + createIndexFuture.onFailure( new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) ); }) ); } catch (Exception e) { logger.error("Failed to parse and update new created resource", e); - createIndexFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + createIndexFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } @Override public void onFailure(Exception e) { logger.error("Failed to create an index", e); - createIndexFuture.completeExceptionally(e); + createIndexFuture.onFailure(e); } }; @@ -128,7 +128,7 @@ public void onFailure(Exception e) { } } catch (Exception e) { logger.error("Failed to find the correct resource for the workflow step", e); - createIndexFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + createIndexFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } // TODO: diff --git a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java index 79637779c..d0bbed40b 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStep.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.action.ingest.PutPipelineRequest; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.client.ClusterAdminClient; import org.opensearch.common.xcontent.XContentFactory; @@ -27,7 +28,6 @@ import java.util.List; import java.util.Map; import java.util.Map.Entry; -import java.util.concurrent.CompletableFuture; import java.util.stream.Stream; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; @@ -66,14 +66,14 @@ public CreateIngestPipelineStep(Client client, FlowFrameworkIndicesHandler flowF } @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture createIngestPipelineFuture = new CompletableFuture<>(); + PlainActionFuture createIngestPipelineFuture = PlainActionFuture.newFuture(); String pipelineId = null; String description = null; @@ -127,7 +127,7 @@ public CompletableFuture execute( ); } catch (IOException e) { logger.error("Failed to create ingest pipeline configuration: " + e.getMessage()); - createIngestPipelineFuture.completeExceptionally(e); + createIngestPipelineFuture.onFailure(e); } break; } @@ -135,7 +135,9 @@ public CompletableFuture execute( if (configuration == null) { // Required workflow data not found - createIngestPipelineFuture.completeExceptionally(new Exception("Failed to create ingest pipeline, required inputs not found")); + createIngestPipelineFuture.onFailure( + new IllegalArgumentException("Failed to create ingest pipeline, required inputs not found") + ); } else { // Create PutPipelineRequest and execute PutPipelineRequest putPipelineRequest = new PutPipelineRequest(pipelineId, configuration, XContentType.JSON); @@ -153,7 +155,7 @@ public CompletableFuture execute( logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); // PutPipelineRequest returns only an AcknowledgeResponse, returning pipelineId instead // TODO: revisit this concept of pipeline_id to be consistent with what makes most sense to end user here - createIngestPipelineFuture.complete( + createIngestPipelineFuture.onResponse( new WorkflowData( Map.of(resourceName, putPipelineRequest.getId()), currentNodeInputs.getWorkflowId(), @@ -162,7 +164,7 @@ public CompletableFuture execute( ); }, exception -> { logger.error("Failed to update new created resource", exception); - createIngestPipelineFuture.completeExceptionally( + createIngestPipelineFuture.onFailure( new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) ); }) @@ -170,14 +172,12 @@ public CompletableFuture execute( } catch (Exception e) { logger.error("Failed to parse and update new created resource", e); - createIngestPipelineFuture.completeExceptionally( - new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)) - ); + createIngestPipelineFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }, exception -> { logger.error("Failed to create ingest pipeline : " + exception.getMessage()); - createIngestPipelineFuture.completeExceptionally(exception); + createIngestPipelineFuture.onFailure(exception); })); } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java index ca4c2066a..04c1cca92 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteAgentStep.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.util.ParseUtils; @@ -20,7 +21,6 @@ import java.util.Collections; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID; @@ -45,19 +45,19 @@ public DeleteAgentStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture deleteAgentFuture = new CompletableFuture<>(); + PlainActionFuture deleteAgentFuture = PlainActionFuture.newFuture(); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(DeleteResponse deleteResponse) { - deleteAgentFuture.complete( + deleteAgentFuture.onResponse( new WorkflowData( Map.ofEntries(Map.entry(AGENT_ID, deleteResponse.getId())), currentNodeInputs.getWorkflowId(), @@ -69,7 +69,7 @@ public void onResponse(DeleteResponse deleteResponse) { @Override public void onFailure(Exception e) { logger.error("Failed to delete agent"); - deleteAgentFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + deleteAgentFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; @@ -88,7 +88,7 @@ public void onFailure(Exception e) { mlClient.deleteAgent(agentId, actionListener); } catch (FlowFrameworkException e) { - deleteAgentFuture.completeExceptionally(e); + deleteAgentFuture.onFailure(e); } return deleteAgentFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java index 3d48bc1a0..6c3376369 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteConnectorStep.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.util.ParseUtils; @@ -20,7 +21,6 @@ import java.util.Collections; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; @@ -45,19 +45,19 @@ public DeleteConnectorStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture deleteConnectorFuture = new CompletableFuture<>(); + PlainActionFuture deleteConnectorFuture = PlainActionFuture.newFuture(); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(DeleteResponse deleteResponse) { - deleteConnectorFuture.complete( + deleteConnectorFuture.onResponse( new WorkflowData( Map.ofEntries(Map.entry(CONNECTOR_ID, deleteResponse.getId())), currentNodeInputs.getWorkflowId(), @@ -69,7 +69,7 @@ public void onResponse(DeleteResponse deleteResponse) { @Override public void onFailure(Exception e) { logger.error("Failed to delete connector"); - deleteConnectorFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + deleteConnectorFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; @@ -88,7 +88,7 @@ public void onFailure(Exception e) { mlClient.deleteConnector(connectorId, actionListener); } catch (FlowFrameworkException e) { - deleteConnectorFuture.completeExceptionally(e); + deleteConnectorFuture.onFailure(e); } return deleteConnectorFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java index 1152d76e4..be8e66138 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeleteModelStep.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.util.ParseUtils; @@ -20,7 +21,6 @@ import java.util.Collections; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; @@ -45,19 +45,19 @@ public DeleteModelStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture deleteModelFuture = new CompletableFuture<>(); + PlainActionFuture deleteModelFuture = PlainActionFuture.newFuture(); ActionListener actionListener = new ActionListener<>() { @Override public void onResponse(DeleteResponse deleteResponse) { - deleteModelFuture.complete( + deleteModelFuture.onResponse( new WorkflowData( Map.ofEntries(Map.entry(MODEL_ID, deleteResponse.getId())), currentNodeInputs.getWorkflowId(), @@ -69,7 +69,7 @@ public void onResponse(DeleteResponse deleteResponse) { @Override public void onFailure(Exception e) { logger.error("Failed to delete model"); - deleteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + deleteModelFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; @@ -89,7 +89,7 @@ public void onFailure(Exception e) { mlClient.deleteModel(modelId, actionListener); } catch (FlowFrameworkException e) { - deleteModelFuture.completeExceptionally(e); + deleteModelFuture.onFailure(e); } return deleteModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java index 2a52d41a1..5759f2ba2 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/DeployModelStep.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -24,7 +25,6 @@ import java.util.Collections; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; @@ -61,14 +61,14 @@ public DeployModelStep( } @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture deployModelFuture = new CompletableFuture<>(); + PlainActionFuture deployModelFuture = PlainActionFuture.newFuture(); ActionListener actionListener = new ActionListener<>() { @Override @@ -87,25 +87,21 @@ public void onResponse(MLDeployModelResponse mlDeployModelResponse) { // Deployed Model Resource has been updated String resourceName = getResourceByWorkflowStep(getName()); String id = getResourceId(mlTask); - deployModelFuture.complete( + deployModelFuture.onResponse( new WorkflowData( Map.ofEntries(Map.entry(resourceName, id), Map.entry(REGISTER_MODEL_STATUS, mlTask.getState().name())), currentNodeInputs.getWorkflowId(), currentNodeId ) ); - }, - e -> { - deployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); - } - ) + }, e -> { deployModelFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); }) ); } @Override public void onFailure(Exception e) { logger.error("Failed to deploy model"); - deployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + deployModelFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; @@ -125,7 +121,7 @@ public void onFailure(Exception e) { mlClient.deploy(modelId, actionListener); } catch (FlowFrameworkException e) { - deployModelFuture.completeExceptionally(e); + deployModelFuture.onFailure(e); } return deployModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java index 1738d6f60..e13181cf7 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/NoOpStep.java @@ -8,8 +8,9 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.support.PlainActionFuture; + import java.util.Map; -import java.util.concurrent.CompletableFuture; /** * A workflow step that does nothing. May be used for synchronizing other actions. @@ -23,13 +24,15 @@ public NoOpStep() {} public static final String NAME = "noop"; @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - return CompletableFuture.completedFuture(WorkflowData.EMPTY); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(WorkflowData.EMPTY); + return future; } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java index c4fd4c6fa..82aa2c2d1 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ProcessNode.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.unit.TimeValue; import org.opensearch.threadpool.Scheduler.ScheduledCancellable; import org.opensearch.threadpool.ThreadPool; @@ -19,7 +20,6 @@ import java.util.Map; import java.util.concurrent.CompletableFuture; import java.util.concurrent.TimeoutException; -import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; @@ -39,7 +39,7 @@ public class ProcessNode { private final ThreadPool threadPool; private final TimeValue nodeTimeout; - private final CompletableFuture future = new CompletableFuture<>(); + private final PlainActionFuture future = PlainActionFuture.newFuture(); /** * Create this node linked to its executing process, including input data and any predecessor nodes. @@ -109,7 +109,7 @@ public WorkflowData input() { * @return A future indicating the processing state of this node. * Returns {@code null} if it has not begun executing, should not happen if a workflow is sorted and executed topologically. */ - public CompletableFuture future() { + public PlainActionFuture future() { return future; } @@ -139,51 +139,45 @@ public TimeValue nodeTimeout() { * @return this node's future. * This is returned immediately, while process execution continues asynchronously. */ - public CompletableFuture execute() { + public PlainActionFuture execute() { if (this.future.isDone()) { throw new IllegalStateException("Process Node [" + this.id + "] already executed."); } CompletableFuture.runAsync(() -> { - List> predFutures = predecessors.stream().map(p -> p.future()).collect(Collectors.toList()); try { - if (!predecessors.isEmpty()) { - CompletableFuture waitForPredecessors = CompletableFuture.allOf(predFutures.toArray(new CompletableFuture[0])); - waitForPredecessors.join(); - } - - logger.info("Starting {}.", this.id); // get the input data from predecessor(s) Map inputMap = new HashMap<>(); - for (CompletableFuture cf : predFutures) { - WorkflowData wd = cf.get(); + for (ProcessNode node : predecessors) { + WorkflowData wd = node.future().actionGet(); inputMap.put(wd.getNodeId(), wd); } + logger.info("Starting {}.", this.id); ScheduledCancellable delayExec = null; if (this.nodeTimeout.compareTo(TimeValue.ZERO) > 0) { delayExec = threadPool.schedule(() -> { if (!future.isDone()) { - future.completeExceptionally(new TimeoutException("Execute timed out for " + this.id)); + future.onFailure(new TimeoutException("Execute timed out for " + this.id)); } }, this.nodeTimeout, ThreadPool.Names.SAME); } // record start time for this step. - CompletableFuture stepFuture = this.workflowStep.execute( + PlainActionFuture stepFuture = this.workflowStep.execute( this.id, this.input, inputMap, this.previousNodeInputs ); // If completed exceptionally, this is a no-op - future.complete(stepFuture.get()); + future.onResponse(stepFuture.get()); // record end time passing workflow steps if (delayExec != null) { delayExec.cancel(); } logger.info("Finished {}.", this.id); - } catch (Throwable t) { - this.future.completeExceptionally(t.getCause() == null ? t : t.getCause()); + } catch (Exception e) { + this.future.onFailure(e); } }, threadPool.executor(WORKFLOW_THREAD_POOL)); return this.future; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java index 09f874bc7..8c36575a4 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterAgentStep.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.Nullable; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -32,7 +33,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.APP_TYPE_FIELD; @@ -78,7 +78,7 @@ public RegisterAgentStep(MachineLearningNodeClient mlClient, FlowFrameworkIndice } @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, @@ -87,7 +87,7 @@ public CompletableFuture execute( String workflowId = currentNodeInputs.getWorkflowId(); - CompletableFuture registerAgentModelFuture = new CompletableFuture<>(); + PlainActionFuture registerAgentModelFuture = PlainActionFuture.newFuture(); ActionListener actionListener = new ActionListener<>() { @Override @@ -102,7 +102,7 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { mlRegisterAgentResponse.getAgentId(), ActionListener.wrap(response -> { logger.info("successfully updated resources created in state index: {}", response.getIndex()); - registerAgentModelFuture.complete( + registerAgentModelFuture.onResponse( new WorkflowData( Map.ofEntries(Map.entry(resourceName, mlRegisterAgentResponse.getAgentId())), workflowId, @@ -111,7 +111,7 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { ); }, exception -> { logger.error("Failed to update new created resource", exception); - registerAgentModelFuture.completeExceptionally( + registerAgentModelFuture.onFailure( new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) ); }) @@ -119,14 +119,14 @@ public void onResponse(MLRegisterAgentResponse mlRegisterAgentResponse) { } catch (Exception e) { logger.error("Failed to parse and update new created resource", e); - registerAgentModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + registerAgentModelFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } @Override public void onFailure(Exception e) { logger.error("Failed to register the agent"); - registerAgentModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + registerAgentModelFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; @@ -201,7 +201,7 @@ public void onFailure(Exception e) { mlClient.registerAgent(mlAgent, actionListener); } catch (FlowFrameworkException e) { - registerAgentModelFuture.completeExceptionally(e); + registerAgentModelFuture.onFailure(e); } return registerAgentModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java index 518e2879e..9acda1b6c 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterModelGroupStep.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.util.CollectionUtils; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -26,7 +27,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.ADD_ALL_BACKEND_ROLES; import static org.opensearch.flowframework.common.CommonValue.BACKEND_ROLES_FIELD; @@ -61,13 +61,13 @@ public RegisterModelGroupStep(MachineLearningNodeClient mlClient, FlowFrameworkI } @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture registerModelGroupFuture = new CompletableFuture<>(); + PlainActionFuture registerModelGroupFuture = PlainActionFuture.newFuture(); ActionListener actionListener = new ActionListener<>() { @Override @@ -82,7 +82,7 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse mlRegisterModelGroupResponse.getModelGroupId(), ActionListener.wrap(updateResponse -> { logger.info("successfully updated resources created in state index: {}", updateResponse.getIndex()); - registerModelGroupFuture.complete( + registerModelGroupFuture.onResponse( new WorkflowData( Map.ofEntries( Map.entry(resourceName, mlRegisterModelGroupResponse.getModelGroupId()), @@ -94,7 +94,7 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse ); }, exception -> { logger.error("Failed to update new created resource", exception); - registerModelGroupFuture.completeExceptionally( + registerModelGroupFuture.onFailure( new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) ); }) @@ -102,14 +102,14 @@ public void onResponse(MLRegisterModelGroupResponse mlRegisterModelGroupResponse } catch (Exception e) { logger.error("Failed to parse and update new created resource", e); - registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + registerModelGroupFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } @Override public void onFailure(Exception e) { logger.error("Failed to register model group"); - registerModelGroupFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + registerModelGroupFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; @@ -148,7 +148,7 @@ public void onFailure(Exception e) { mlClient.registerModelGroup(mlInput, actionListener); } catch (FlowFrameworkException e) { - registerModelGroupFuture.completeExceptionally(e); + registerModelGroupFuture.onFailure(e); } return registerModelGroupFuture; diff --git a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java index 4ce4eed78..8cd184a18 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStep.java @@ -11,6 +11,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.ExceptionsHelper; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -24,7 +25,6 @@ import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.DEPLOY_FIELD; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; @@ -59,14 +59,14 @@ public RegisterRemoteModelStep(MachineLearningNodeClient mlClient, FlowFramework } @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture registerRemoteModelFuture = new CompletableFuture<>(); + PlainActionFuture registerRemoteModelFuture = PlainActionFuture.newFuture(); Set requiredKeys = Set.of(NAME_FIELD, CONNECTOR_ID); Set optionalKeys = Set.of(MODEL_GROUP_ID, DESCRIPTION_FIELD, DEPLOY_FIELD); @@ -126,7 +126,7 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { completeRegisterFuture(deployUpdateResponse, resourceName, mlRegisterModelResponse); }, deployUpdateException -> { logger.error("Failed to update simulated deploy step resource", deployUpdateException); - registerRemoteModelFuture.completeExceptionally( + registerRemoteModelFuture.onFailure( new FlowFrameworkException( deployUpdateException.getMessage(), ExceptionsHelper.status(deployUpdateException) @@ -139,7 +139,7 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { } }, exception -> { logger.error("Failed to update new created resource", exception); - registerRemoteModelFuture.completeExceptionally( + registerRemoteModelFuture.onFailure( new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception)) ); }) @@ -147,15 +147,13 @@ public void onResponse(MLRegisterModelResponse mlRegisterModelResponse) { } catch (Exception e) { logger.error("Failed to parse and update new created resource", e); - registerRemoteModelFuture.completeExceptionally( - new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e)) - ); + registerRemoteModelFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } void completeRegisterFuture(UpdateResponse response, String resourceName, MLRegisterModelResponse mlRegisterModelResponse) { logger.info("successfully updated resources created in state index: {}", response.getIndex()); - registerRemoteModelFuture.complete( + registerRemoteModelFuture.onResponse( new WorkflowData( Map.ofEntries( Map.entry(resourceName, mlRegisterModelResponse.getModelId()), @@ -170,12 +168,12 @@ void completeRegisterFuture(UpdateResponse response, String resourceName, MLRegi @Override public void onFailure(Exception e) { logger.error("Failed to register remote model"); - registerRemoteModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + registerRemoteModelFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }); } catch (FlowFrameworkException e) { - registerRemoteModelFuture.completeExceptionally(e); + registerRemoteModelFuture.onFailure(e); } return registerRemoteModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index 9b6df63eb..d3d3611f1 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -10,6 +10,7 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.util.ParseUtils; import org.opensearch.ml.common.agent.MLToolSpec; @@ -17,7 +18,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.concurrent.CompletableFuture; import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; import static org.opensearch.flowframework.common.CommonValue.INCLUDE_OUTPUT_IN_AGENT_RESPONSE; @@ -34,11 +34,11 @@ public class ToolStep implements WorkflowStep { private static final Logger logger = LogManager.getLogger(ToolStep.class); - CompletableFuture toolFuture = new CompletableFuture<>(); + PlainActionFuture toolFuture = PlainActionFuture.newFuture(); static final String NAME = "create_tool"; @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, @@ -80,7 +80,7 @@ public CompletableFuture execute( MLToolSpec mlToolSpec = builder.build(); - toolFuture.complete( + toolFuture.onResponse( new WorkflowData( Map.ofEntries(Map.entry(TOOLS_FIELD, mlToolSpec)), currentNodeInputs.getWorkflowId(), @@ -91,7 +91,7 @@ public CompletableFuture execute( logger.info("Tool registered successfully {}", type); } catch (FlowFrameworkException e) { - toolFuture.completeExceptionally(e); + toolFuture.onFailure(e); } return toolFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java index 78412e482..a90ff1aa8 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/UndeployModelStep.java @@ -13,6 +13,7 @@ import org.opensearch.ExceptionsHelper; import org.opensearch.OpenSearchException; import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.util.ParseUtils; @@ -23,7 +24,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CompletableFuture; import java.util.stream.Collectors; import static org.opensearch.flowframework.common.CommonValue.SUCCESS; @@ -50,13 +50,13 @@ public UndeployModelStep(MachineLearningNodeClient mlClient) { } @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture undeployModelFuture = new CompletableFuture<>(); + PlainActionFuture undeployModelFuture = PlainActionFuture.newFuture(); ActionListener actionListener = new ActionListener<>() { @@ -64,7 +64,7 @@ public CompletableFuture execute( public void onResponse(MLUndeployModelsResponse mlUndeployModelsResponse) { List failures = mlUndeployModelsResponse.getResponse().failures(); if (failures.isEmpty()) { - undeployModelFuture.complete( + undeployModelFuture.onResponse( new WorkflowData( Map.ofEntries(Map.entry(SUCCESS, !mlUndeployModelsResponse.getResponse().hasFailures())), currentNodeInputs.getWorkflowId(), @@ -75,14 +75,14 @@ public void onResponse(MLUndeployModelsResponse mlUndeployModelsResponse) { List failedNodes = failures.stream().map(FailedNodeException::nodeId).collect(Collectors.toList()); String message = "Failed to undeploy model on nodes " + failedNodes; logger.error(message); - undeployModelFuture.completeExceptionally(new OpenSearchException(message)); + undeployModelFuture.onFailure(new OpenSearchException(message)); } } @Override public void onFailure(Exception e) { logger.error("Failed to unldeploy model"); - undeployModelFuture.completeExceptionally(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); + undeployModelFuture.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } }; @@ -102,7 +102,7 @@ public void onFailure(Exception e) { mlClient.undeploy(new String[] { modelId }, null, actionListener); } catch (FlowFrameworkException e) { - undeployModelFuture.completeExceptionally(e); + undeployModelFuture.onFailure(e); } return undeployModelFuture; } diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java index cbd127ea5..16cc2b200 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java @@ -8,8 +8,9 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.support.PlainActionFuture; + import java.util.Map; -import java.util.concurrent.CompletableFuture; /** * Interface for the workflow setup of different building blocks. @@ -24,7 +25,7 @@ public interface WorkflowStep { * @param previousNodeInputs Input params for this node that come from previous steps * @return A CompletableFuture of the building block. This block should return immediately, but not be completed until the step executes, containing either the step's output data or {@link WorkflowData#EMPTY} which may be passed to follow-on steps. */ - CompletableFuture execute( + PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, diff --git a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java index 4d29fa827..3582b7974 100644 --- a/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/DeprovisionWorkflowTransportActionTests.java @@ -10,6 +10,7 @@ import org.opensearch.action.LatchedActionListener; import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.client.Client; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; @@ -32,7 +33,6 @@ import org.junit.AfterClass; import java.util.List; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; @@ -122,9 +122,9 @@ public void testDeprovisionWorkflow() throws Exception { return null; }).when(client).execute(any(GetWorkflowStateAction.class), any(GetWorkflowStateRequest.class), any()); - when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn( - CompletableFuture.completedFuture(WorkflowData.EMPTY) - ); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onResponse(WorkflowData.EMPTY); + when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn(future); deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(WorkflowResponse.class); @@ -152,8 +152,8 @@ public void testFailToDeprovision() throws Exception { return null; }).when(client).execute(any(GetWorkflowStateAction.class), any(GetWorkflowStateRequest.class), any()); - CompletableFuture future = new CompletableFuture<>(); - future.completeExceptionally(new RuntimeException("rte")); + PlainActionFuture future = PlainActionFuture.newFuture(); + future.onFailure(new RuntimeException("rte")); when(this.deleteConnectorStep.execute(anyString(), any(WorkflowData.class), anyMap(), anyMap())).thenReturn(future); deprovisionWorkflowTransportAction.doExecute(mock(Task.class), workflowRequest, listener); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java index 54bf871f6..242dfd02d 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateConnectorStepTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; @@ -24,7 +25,6 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import org.mockito.Mock; @@ -99,7 +99,7 @@ public void testCreateConnector() throws IOException, ExecutionException, Interr return null; }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - CompletableFuture future = createConnectorStep.execute( + PlainActionFuture future = createConnectorStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), @@ -121,7 +121,7 @@ public void testCreateConnectorFailure() throws IOException { return null; }).when(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); - CompletableFuture future = createConnectorStep.execute( + PlainActionFuture future = createConnectorStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), @@ -130,7 +130,7 @@ public void testCreateConnectorFailure() throws IOException { verify(machineLearningNodeClient).createConnector(any(MLCreateConnectorInput.class), any()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Failed to create connector", ex.getCause().getMessage()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java index 32ae31c43..7eb02891c 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIndexStepTests.java @@ -10,6 +10,7 @@ import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; import org.opensearch.client.Client; @@ -31,7 +32,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; @@ -109,7 +109,7 @@ public void testCreateIndexStep() throws ExecutionException, InterruptedExceptio @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIndexStep.execute( + PlainActionFuture future = createIndexStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), @@ -119,7 +119,7 @@ public void testCreateIndexStep() throws ExecutionException, InterruptedExceptio verify(indicesAdminClient, times(1)).create(any(CreateIndexRequest.class), actionListenerCaptor.capture()); actionListenerCaptor.getValue().onResponse(new CreateIndexResponse(true, true, "demo")); - assertTrue(future.isDone() && !future.isCompletedExceptionally()); + assertTrue(future.isDone()); Map outputData = Map.of(INDEX_NAME, "demo"); assertEquals(outputData, future.get().getContent()); @@ -129,7 +129,7 @@ public void testCreateIndexStep() throws ExecutionException, InterruptedExceptio public void testCreateIndexStepFailure() throws ExecutionException, InterruptedException { @SuppressWarnings({ "unchecked" }) ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIndexStep.execute( + PlainActionFuture future = createIndexStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), @@ -140,7 +140,7 @@ public void testCreateIndexStepFailure() throws ExecutionException, InterruptedE actionListenerCaptor.getValue().onFailure(new Exception("Failed to create an index")); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof Exception); assertEquals("Failed to create an index", ex.getCause().getMessage()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java index 9473b4ff9..a36e293b9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/CreateIngestPipelineStepTests.java @@ -9,6 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.ingest.PutPipelineRequest; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.support.master.AcknowledgedResponse; import org.opensearch.action.update.UpdateResponse; import org.opensearch.client.AdminClient; @@ -22,7 +23,6 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import org.mockito.ArgumentCaptor; @@ -90,7 +90,7 @@ public void testCreateIngestPipelineStep() throws InterruptedException, Executio @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIngestPipelineStep.execute( + PlainActionFuture future = createIngestPipelineStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), @@ -103,7 +103,7 @@ public void testCreateIngestPipelineStep() throws InterruptedException, Executio verify(clusterAdminClient, times(1)).putPipeline(any(PutPipelineRequest.class), actionListenerCaptor.capture()); actionListenerCaptor.getValue().onResponse(new AcknowledgedResponse(true)); - assertTrue(future.isDone() && !future.isCompletedExceptionally()); + assertTrue(future.isDone()); assertEquals(outpuData.getContent(), future.get().getContent()); } @@ -113,7 +113,7 @@ public void testCreateIngestPipelineStepFailure() throws InterruptedException { @SuppressWarnings("unchecked") ArgumentCaptor> actionListenerCaptor = ArgumentCaptor.forClass(ActionListener.class); - CompletableFuture future = createIngestPipelineStep.execute( + PlainActionFuture future = createIngestPipelineStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), @@ -126,7 +126,7 @@ public void testCreateIngestPipelineStepFailure() throws InterruptedException { verify(clusterAdminClient, times(1)).putPipeline(any(PutPipelineRequest.class), actionListenerCaptor.capture()); actionListenerCaptor.getValue().onFailure(new Exception("Failed to create ingest pipeline")); - assertTrue(future.isDone() && future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); assertTrue(exception.getCause() instanceof Exception); @@ -148,13 +148,13 @@ public void testMissingData() throws InterruptedException { "test-node-id" ); - CompletableFuture future = createIngestPipelineStep.execute( + PlainActionFuture future = createIngestPipelineStep.execute( incorrectData.getNodeId(), incorrectData, Collections.emptyMap(), Collections.emptyMap() ); - assertTrue(future.isDone() && future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException exception = assertThrows(ExecutionException.class, () -> future.get()); assertTrue(exception.getCause() instanceof Exception); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java index 2bc13faf3..8121f53be 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteAgentStepTests.java @@ -9,6 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; @@ -20,7 +21,6 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import org.mockito.Mock; @@ -60,7 +60,7 @@ public void testDeleteAgent() throws IOException, ExecutionException, Interrupte return null; }).when(machineLearningNodeClient).deleteAgent(any(String.class), any()); - CompletableFuture future = deleteAgentStep.execute( + PlainActionFuture future = deleteAgentStep.execute( inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(AGENT_ID, agentId), "workflowId", "nodeId")), @@ -75,14 +75,14 @@ public void testDeleteAgent() throws IOException, ExecutionException, Interrupte public void testNoAgentIdInOutput() throws IOException { DeleteAgentStep deleteAgentStep = new DeleteAgentStep(machineLearningNodeClient); - CompletableFuture future = deleteAgentStep.execute( + PlainActionFuture future = deleteAgentStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), Collections.emptyMap() ); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Missing required inputs [agent_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); @@ -97,7 +97,7 @@ public void testDeleteAgentFailure() throws IOException { return null; }).when(machineLearningNodeClient).deleteAgent(any(String.class), any()); - CompletableFuture future = deleteAgentStep.execute( + PlainActionFuture future = deleteAgentStep.execute( inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(AGENT_ID, "test"), "workflowId", "nodeId")), @@ -106,7 +106,7 @@ public void testDeleteAgentFailure() throws IOException { verify(machineLearningNodeClient).deleteAgent(any(String.class), any()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Failed to delete agent", ex.getCause().getMessage()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java index 827298fb4..45aa4b7ad 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteConnectorStepTests.java @@ -9,6 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; @@ -20,7 +21,6 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import org.mockito.Mock; @@ -60,7 +60,7 @@ public void testDeleteConnector() throws IOException, ExecutionException, Interr return null; }).when(machineLearningNodeClient).deleteConnector(any(String.class), any()); - CompletableFuture future = deleteConnectorStep.execute( + PlainActionFuture future = deleteConnectorStep.execute( inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(CONNECTOR_ID, connectorId), "workflowId", "nodeId")), @@ -75,14 +75,14 @@ public void testDeleteConnector() throws IOException, ExecutionException, Interr public void testNoConnectorIdInOutput() throws IOException { DeleteConnectorStep deleteConnectorStep = new DeleteConnectorStep(machineLearningNodeClient); - CompletableFuture future = deleteConnectorStep.execute( + PlainActionFuture future = deleteConnectorStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), Collections.emptyMap() ); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Missing required inputs [connector_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); @@ -97,7 +97,7 @@ public void testDeleteConnectorFailure() throws IOException { return null; }).when(machineLearningNodeClient).deleteConnector(any(String.class), any()); - CompletableFuture future = deleteConnectorStep.execute( + PlainActionFuture future = deleteConnectorStep.execute( inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(CONNECTOR_ID, "test"), "workflowId", "nodeId")), @@ -106,7 +106,7 @@ public void testDeleteConnectorFailure() throws IOException { verify(machineLearningNodeClient).deleteConnector(any(String.class), any()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Failed to delete connector", ex.getCause().getMessage()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java index 38a603f0f..4c69347cb 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeleteModelStepTests.java @@ -9,6 +9,7 @@ package org.opensearch.flowframework.workflow; import org.opensearch.action.delete.DeleteResponse; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.Index; import org.opensearch.core.index.shard.ShardId; @@ -20,7 +21,6 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import org.mockito.Mock; @@ -60,7 +60,7 @@ public void testDeleteModel() throws IOException, ExecutionException, Interrupte return null; }).when(machineLearningNodeClient).deleteModel(any(String.class), any()); - CompletableFuture future = deleteModelStep.execute( + PlainActionFuture future = deleteModelStep.execute( inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, modelId), "workflowId", "nodeId")), @@ -75,14 +75,14 @@ public void testDeleteModel() throws IOException, ExecutionException, Interrupte public void testNoModelIdInOutput() throws IOException { DeleteModelStep deleteModelStep = new DeleteModelStep(machineLearningNodeClient); - CompletableFuture future = deleteModelStep.execute( + PlainActionFuture future = deleteModelStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), Collections.emptyMap() ); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Missing required inputs [model_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); @@ -97,7 +97,7 @@ public void testDeleteModelFailure() throws IOException { return null; }).when(machineLearningNodeClient).deleteModel(any(String.class), any()); - CompletableFuture future = deleteModelStep.execute( + PlainActionFuture future = deleteModelStep.execute( inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, "test"), "workflowId", "nodeId")), @@ -106,7 +106,7 @@ public void testDeleteModelFailure() throws IOException { verify(machineLearningNodeClient).deleteModel(any(String.class), any()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Failed to delete model", ex.getCause().getMessage()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java index 92d5be388..8f9192639 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/DeployModelStepTests.java @@ -10,6 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; @@ -34,7 +35,6 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -150,14 +150,14 @@ public void testDeployModel() throws ExecutionException, InterruptedException, I return null; }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - CompletableFuture future = deployModel.execute( + PlainActionFuture future = deployModel.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), Collections.emptyMap() ); - future.join(); + future.actionGet(); verify(machineLearningNodeClient, times(1)).deploy(any(String.class), any()); verify(machineLearningNodeClient, times(1)).getTask(any(), any()); @@ -177,7 +177,7 @@ public void testDeployModelFailure() { return null; }).when(machineLearningNodeClient).deploy(eq("modelId"), actionListenerCaptor.capture()); - CompletableFuture future = deployModel.execute( + PlainActionFuture future = deployModel.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), @@ -232,7 +232,7 @@ public void testDeployModelTaskFailure() throws IOException, InterruptedExceptio return null; }).when(machineLearningNodeClient).getTask(any(), any()); - CompletableFuture future = this.deployModel.execute( + PlainActionFuture future = this.deployModel.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java index 749a7b2d5..ea46bc9ec 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ModelGroupStepTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; @@ -25,7 +26,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import org.mockito.ArgumentCaptor; @@ -91,7 +91,7 @@ public void testRegisterModelGroup() throws ExecutionException, InterruptedExcep return null; }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - CompletableFuture future = modelGroupStep.execute( + PlainActionFuture future = modelGroupStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), @@ -118,7 +118,7 @@ public void testRegisterModelGroupFailure() throws IOException { return null; }).when(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); - CompletableFuture future = modelGroupStep.execute( + PlainActionFuture future = modelGroupStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), @@ -127,7 +127,7 @@ public void testRegisterModelGroupFailure() throws IOException { verify(machineLearningNodeClient).registerModelGroup(any(MLRegisterModelGroupInput.class), actionListenerCaptor.capture()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Failed to register model group", ex.getCause().getMessage()); @@ -137,14 +137,14 @@ public void testRegisterModelGroupFailure() throws IOException { public void testRegisterModelGroupWithNoName() throws IOException { RegisterModelGroupStep modelGroupStep = new RegisterModelGroupStep(machineLearningNodeClient, flowFrameworkIndicesHandler); - CompletableFuture future = modelGroupStep.execute( + PlainActionFuture future = modelGroupStep.execute( inputDataWithNoName.getNodeId(), inputDataWithNoName, Collections.emptyMap(), Collections.emptyMap() ); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Missing required inputs [name] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java index 1782375cc..171c75272 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/NoOpStepTests.java @@ -8,24 +8,23 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; import java.util.Collections; -import java.util.concurrent.CompletableFuture; public class NoOpStepTests extends OpenSearchTestCase { public void testNoOpStep() throws IOException { NoOpStep noopStep = new NoOpStep(); assertEquals(NoOpStep.NAME, noopStep.getName()); - CompletableFuture future = noopStep.execute( + PlainActionFuture future = noopStep.execute( "nodeId", WorkflowData.EMPTY, Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); - assertFalse(future.isCompletedExceptionally()); } } diff --git a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java index 1f67d2b0b..be822d26d 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ProcessNodeTests.java @@ -8,9 +8,11 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.OpenSearchExecutors; +import org.opensearch.common.util.concurrent.UncategorizedExecutionException; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; import org.opensearch.threadpool.TestThreadPool; @@ -21,11 +23,8 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; -import java.util.concurrent.CompletionException; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; -import java.util.concurrent.TimeoutException; import static org.opensearch.flowframework.common.CommonValue.FLOW_FRAMEWORK_THREAD_POOL_PREFIX; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_THREAD_POOL; @@ -51,10 +50,10 @@ public static void setup() { ) ); - CompletableFuture successfulFuture = new CompletableFuture<>(); - successfulFuture.complete(WorkflowData.EMPTY); - CompletableFuture failedFuture = new CompletableFuture<>(); - failedFuture.completeExceptionally(new RuntimeException("Test exception")); + PlainActionFuture successfulFuture = PlainActionFuture.newFuture(); + successfulFuture.onResponse(WorkflowData.EMPTY); + PlainActionFuture failedFuture = PlainActionFuture.newFuture(); + failedFuture.onFailure(new RuntimeException("Test exception")); successfulNode = mock(ProcessNode.class); when(successfulNode.future()).thenReturn(successfulFuture); failedNode = mock(ProcessNode.class); @@ -70,14 +69,14 @@ public void testNode() throws InterruptedException, ExecutionException { // Tests where execute nas no timeout ProcessNode nodeA = new ProcessNode("A", new WorkflowStep() { @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture f = new CompletableFuture<>(); - f.complete(new WorkflowData(Map.of("test", "output"), "test-id", "test-node-id")); + PlainActionFuture f = PlainActionFuture.newFuture(); + f.onResponse(new WorkflowData(Map.of("test", "output"), "test-id", "test-node-id")); return f; } @@ -102,7 +101,7 @@ public String getName() { assertEquals(50, nodeA.nodeTimeout().millis()); assertEquals("A", nodeA.toString()); - CompletableFuture f = nodeA.execute(); + PlainActionFuture f = nodeA.execute(); assertEquals(f, nodeA.future()); assertEquals("output", f.get().getContent().get("test")); } @@ -111,14 +110,14 @@ public void testNodeNoTimeout() throws InterruptedException, ExecutionException // Tests where execute finishes before timeout ProcessNode nodeB = new ProcessNode("B", new WorkflowStep() { @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture future = new CompletableFuture<>(); - testThreadPool.schedule(() -> future.complete(WorkflowData.EMPTY), TimeValue.timeValueMillis(100), WORKFLOW_THREAD_POOL); + PlainActionFuture future = PlainActionFuture.newFuture(); + testThreadPool.schedule(() -> future.onResponse(WorkflowData.EMPTY), TimeValue.timeValueMillis(100), WORKFLOW_THREAD_POOL); return future; } @@ -133,7 +132,7 @@ public String getName() { assertEquals(Collections.emptyList(), nodeB.predecessors()); assertEquals("B", nodeB.toString()); - CompletableFuture f = nodeB.execute(); + PlainActionFuture f = nodeB.execute(); assertEquals(f, nodeB.future()); assertEquals(WorkflowData.EMPTY, f.get()); } @@ -142,14 +141,14 @@ public void testNodeTimeout() throws InterruptedException, ExecutionException { // Tests where execute finishes after timeout ProcessNode nodeZ = new ProcessNode("Zzz", new WorkflowStep() { @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture future = new CompletableFuture<>(); - testThreadPool.schedule(() -> future.complete(WorkflowData.EMPTY), TimeValue.timeValueMinutes(1), WORKFLOW_THREAD_POOL); + PlainActionFuture future = PlainActionFuture.newFuture(); + testThreadPool.schedule(() -> future.onResponse(WorkflowData.EMPTY), TimeValue.timeValueMinutes(1), WORKFLOW_THREAD_POOL); return future; } @@ -164,24 +163,24 @@ public String getName() { assertEquals(Collections.emptyList(), nodeZ.predecessors()); assertEquals("Zzz", nodeZ.toString()); - CompletableFuture f = nodeZ.execute(); - CompletionException exception = assertThrows(CompletionException.class, () -> f.join()); - assertTrue(f.isCompletedExceptionally()); - assertEquals(TimeoutException.class, exception.getCause().getClass()); + PlainActionFuture f = nodeZ.execute(); + UncategorizedExecutionException exception = assertThrows(UncategorizedExecutionException.class, () -> f.actionGet()); + assertTrue(f.isDone()); + assertEquals(ExecutionException.class, exception.getCause().getClass()); } public void testExceptions() { // Tests where a predecessor future completed exceptionally ProcessNode nodeE = new ProcessNode("E", new WorkflowStep() { @Override - public CompletableFuture execute( + public PlainActionFuture execute( String currentNodeId, WorkflowData currentNodeInputs, Map outputs, Map previousNodeInputs ) { - CompletableFuture f = new CompletableFuture<>(); - f.complete(WorkflowData.EMPTY); + PlainActionFuture f = PlainActionFuture.newFuture(); + f.onResponse(WorkflowData.EMPTY); return f; } @@ -196,11 +195,9 @@ public String getName() { assertEquals(2, nodeE.predecessors().size()); assertEquals("E", nodeE.toString()); - CompletableFuture f = nodeE.execute(); - CompletionException exception = assertThrows(CompletionException.class, () -> f.join()); - assertTrue(f.isCompletedExceptionally()); - assertEquals("Test exception", exception.getCause().getMessage()); - + PlainActionFuture f = nodeE.execute(); + RuntimeException exception = assertThrows(RuntimeException.class, () -> f.actionGet()); + assertTrue(f.isDone()); // Tests where we already called execute assertThrows(IllegalStateException.class, () -> nodeE.execute()); } diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java index ae87b36cd..5360a4f12 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterAgentTests.java @@ -8,6 +8,7 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; @@ -25,7 +26,6 @@ import java.io.IOException; import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import org.mockito.ArgumentCaptor; @@ -105,7 +105,7 @@ public void testRegisterAgent() throws IOException, ExecutionException, Interrup return null; }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - CompletableFuture future = registerAgentStep.execute( + PlainActionFuture future = registerAgentStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), @@ -137,7 +137,7 @@ public void testRegisterAgentFailure() throws IOException { return null; }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - CompletableFuture future = registerAgentStep.execute( + PlainActionFuture future = registerAgentStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), @@ -146,7 +146,7 @@ public void testRegisterAgentFailure() throws IOException { verify(machineLearningNodeClient).registerAgent(any(MLAgent.class), actionListenerCaptor.capture()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Failed to register the agent", ex.getCause().getMessage()); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java index 010abcf2d..4e2d74865 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalCustomModelStepTests.java @@ -10,6 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; @@ -32,7 +33,6 @@ import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -161,14 +161,14 @@ public void testRegisterLocalCustomModelSuccess() throws Exception { return null; }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - CompletableFuture future = registerLocalModelStep.execute( + PlainActionFuture future = registerLocalModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), Collections.emptyMap() ); - future.join(); + future.actionGet(); verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); verify(machineLearningNodeClient, times(1)).getTask(any(), any()); @@ -185,7 +185,7 @@ public void testRegisterLocalCustomModelFailure() { return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerLocalModelStep.execute( + PlainActionFuture future = this.registerLocalModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), @@ -234,7 +234,7 @@ public void testRegisterLocalCustomModelTaskFailure() { return null; }).when(machineLearningNodeClient).getTask(any(), any()); - CompletableFuture future = this.registerLocalModelStep.execute( + PlainActionFuture future = this.registerLocalModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), @@ -247,14 +247,13 @@ public void testRegisterLocalCustomModelTaskFailure() { } public void testMissingInputs() { - CompletableFuture future = registerLocalModelStep.execute( + PlainActionFuture future = registerLocalModelStep.execute( "nodeId", new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); - assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs [")); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java index 031967713..a46c292e9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalPretrainedModelStepTests.java @@ -10,6 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; @@ -32,7 +33,6 @@ import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -155,14 +155,14 @@ public void testRegisterLocalPretrainedModelSuccess() throws Exception { return null; }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - CompletableFuture future = registerLocalPretrainedModelStep.execute( + PlainActionFuture future = registerLocalPretrainedModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), Collections.emptyMap() ); - future.join(); + future.actionGet(); verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); verify(machineLearningNodeClient, times(1)).getTask(any(), any()); @@ -179,7 +179,7 @@ public void testRegisterLocalPretrainedModelFailure() { return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerLocalPretrainedModelStep.execute( + PlainActionFuture future = this.registerLocalPretrainedModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), @@ -228,7 +228,7 @@ public void testRegisterLocalPretrainedModelTaskFailure() { return null; }).when(machineLearningNodeClient).getTask(any(), any()); - CompletableFuture future = this.registerLocalPretrainedModelStep.execute( + PlainActionFuture future = this.registerLocalPretrainedModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), @@ -241,14 +241,13 @@ public void testRegisterLocalPretrainedModelTaskFailure() { } public void testMissingInputs() { - CompletableFuture future = registerLocalPretrainedModelStep.execute( + PlainActionFuture future = registerLocalPretrainedModelStep.execute( "nodeId", new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); - assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs [")); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java index 6cedf632b..b548e09a9 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterLocalSparseEncodingModelStepTests.java @@ -10,6 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.common.settings.Settings; import org.opensearch.common.unit.TimeValue; @@ -32,7 +33,6 @@ import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; @@ -158,14 +158,14 @@ public void testRegisterLocalSparseEncodingModelSuccess() throws Exception { return null; }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - CompletableFuture future = registerLocalSparseEncodingModelStep.execute( + PlainActionFuture future = registerLocalSparseEncodingModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), Collections.emptyMap() ); - future.join(); + future.actionGet(); verify(machineLearningNodeClient, times(1)).register(any(MLRegisterModelInput.class), any()); verify(machineLearningNodeClient, times(1)).getTask(any(), any()); @@ -182,7 +182,7 @@ public void testRegisterLocalSparseEncodingModelFailure() { return null; }).when(machineLearningNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerLocalSparseEncodingModelStep.execute( + PlainActionFuture future = this.registerLocalSparseEncodingModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), @@ -231,7 +231,7 @@ public void testRegisterLocalSparseEncodingModelTaskFailure() { return null; }).when(machineLearningNodeClient).getTask(any(), any()); - CompletableFuture future = this.registerLocalSparseEncodingModelStep.execute( + PlainActionFuture future = this.registerLocalSparseEncodingModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), @@ -244,14 +244,13 @@ public void testRegisterLocalSparseEncodingModelTaskFailure() { } public void testMissingInputs() { - CompletableFuture future = registerLocalSparseEncodingModelStep.execute( + PlainActionFuture future = registerLocalSparseEncodingModelStep.execute( "nodeId", new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); - assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs [")); diff --git a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java index 2bc57f888..50766efe5 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/RegisterRemoteModelStepTests.java @@ -10,6 +10,7 @@ import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.action.update.UpdateResponse; import org.opensearch.core.action.ActionListener; import org.opensearch.core.index.shard.ShardId; @@ -23,7 +24,6 @@ import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import org.mockito.Mock; @@ -89,7 +89,7 @@ public void testRegisterRemoteModelSuccess() throws Exception { return null; }).when(flowFrameworkIndicesHandler).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); - CompletableFuture future = this.registerRemoteModelStep.execute( + PlainActionFuture future = this.registerRemoteModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), @@ -101,7 +101,6 @@ public void testRegisterRemoteModelSuccess() throws Exception { verify(flowFrameworkIndicesHandler, times(1)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); assertTrue(future.isDone()); - assertFalse(future.isCompletedExceptionally()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); @@ -137,7 +136,7 @@ public void testRegisterAndDeployRemoteModelSuccess() throws Exception { "test-node-id" ); - CompletableFuture future = this.registerRemoteModelStep.execute( + PlainActionFuture future = this.registerRemoteModelStep.execute( deployWorkflowData.getNodeId(), deployWorkflowData, Collections.emptyMap(), @@ -149,7 +148,6 @@ public void testRegisterAndDeployRemoteModelSuccess() throws Exception { verify(flowFrameworkIndicesHandler, times(2)).updateResourceInStateIndex(anyString(), anyString(), anyString(), anyString(), any()); assertTrue(future.isDone()); - assertFalse(future.isCompletedExceptionally()); assertEquals(modelId, future.get().getContent().get(MODEL_ID)); assertEquals(status, future.get().getContent().get(REGISTER_MODEL_STATUS)); } @@ -161,14 +159,13 @@ public void testRegisterRemoteModelFailure() { return null; }).when(mlNodeClient).register(any(MLRegisterModelInput.class), any()); - CompletableFuture future = this.registerRemoteModelStep.execute( + PlainActionFuture future = this.registerRemoteModelStep.execute( workflowData.getNodeId(), workflowData, Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); - assertTrue(future.isCompletedExceptionally()); ExecutionException ex = expectThrows(ExecutionException.class, () -> future.get().getClass()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("test", ex.getCause().getMessage()); @@ -176,14 +173,13 @@ public void testRegisterRemoteModelFailure() { } public void testMissingInputs() { - CompletableFuture future = this.registerRemoteModelStep.execute( + PlainActionFuture future = this.registerRemoteModelStep.execute( "nodeId", new WorkflowData(Collections.emptyMap(), "test-id", "test-node-id"), Collections.emptyMap(), Collections.emptyMap() ); assertTrue(future.isDone()); - assertTrue(future.isCompletedExceptionally()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertTrue(ex.getCause().getMessage().startsWith("Missing required inputs [")); diff --git a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java index c7e8df2d8..0dc4d7960 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/ToolStepTests.java @@ -8,13 +8,13 @@ */ package org.opensearch.flowframework.workflow; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.ml.common.agent.MLToolSpec; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; import java.util.Collections; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; public class ToolStepTests extends OpenSearchTestCase { @@ -40,7 +40,7 @@ public void setUp() throws Exception { public void testTool() throws IOException, ExecutionException, InterruptedException { ToolStep toolStep = new ToolStep(); - CompletableFuture future = toolStep.execute( + PlainActionFuture future = toolStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), diff --git a/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java b/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java index 97f008e21..6fe1cade2 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/UndeployModelStepTests.java @@ -10,6 +10,7 @@ import org.opensearch.OpenSearchException; import org.opensearch.action.FailedNodeException; +import org.opensearch.action.support.PlainActionFuture; import org.opensearch.cluster.ClusterName; import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; @@ -23,7 +24,6 @@ import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutionException; import org.mockito.Mock; @@ -68,7 +68,7 @@ public void testUndeployModel() throws IOException, ExecutionException, Interrup return null; }).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); - CompletableFuture future = UndeployModelStep.execute( + PlainActionFuture future = UndeployModelStep.execute( inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, modelId), "workflowId", "nodeId")), @@ -83,14 +83,14 @@ public void testUndeployModel() throws IOException, ExecutionException, Interrup public void testNoModelIdInOutput() throws IOException { UndeployModelStep UndeployModelStep = new UndeployModelStep(machineLearningNodeClient); - CompletableFuture future = UndeployModelStep.execute( + PlainActionFuture future = UndeployModelStep.execute( inputData.getNodeId(), inputData, Collections.emptyMap(), Collections.emptyMap() ); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof FlowFrameworkException); assertEquals("Missing required inputs [model_id] in workflow [test-id] node [test-node-id]", ex.getCause().getMessage()); @@ -114,7 +114,7 @@ public void testUndeployModelFailure() throws IOException { return null; }).when(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); - CompletableFuture future = UndeployModelStep.execute( + PlainActionFuture future = UndeployModelStep.execute( inputData.getNodeId(), inputData, Map.of("step_1", new WorkflowData(Map.of(MODEL_ID, "test"), "workflowId", "nodeId")), @@ -123,7 +123,7 @@ public void testUndeployModelFailure() throws IOException { verify(machineLearningNodeClient).undeploy(any(String[].class), any(), any()); - assertTrue(future.isCompletedExceptionally()); + assertTrue(future.isDone()); ExecutionException ex = assertThrows(ExecutionException.class, () -> future.get().getContent()); assertTrue(ex.getCause() instanceof OpenSearchException); assertEquals("Failed to undeploy model on nodes [failed-node]", ex.getCause().getMessage());