diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 180248df8..89471ee00 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -10,6 +10,8 @@ import com.google.common.collect.ImmutableList; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.rest.BaseRestHandler; @@ -53,13 +55,13 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli // Validate content if (request.hasContent()) { - throw new IOException("Invalid request format"); + throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); } // Validate params String workflowId = request.param(WORKFLOW_ID); if (workflowId == null) { - throw new IOException("workflow_id cannot be null"); + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); } // Create request and provision diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index eace47037..abc704064 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -14,6 +14,8 @@ import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.GlobalContextHandler; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -53,7 +55,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { logger.error("Failed to save use case template : {}", exception.getMessage()); - listener.onFailure(exception); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); })); } else { // Update existing entry, full document replacement @@ -62,7 +64,7 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), exception.getMessage()); - listener.onFailure(exception); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); })); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index be4f2ed2d..0dbec5bf2 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -17,6 +17,8 @@ import org.opensearch.common.inject.Inject; import org.opensearch.common.util.concurrent.ThreadContext; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.workflow.ProcessNode; @@ -92,11 +94,11 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { logger.error("Failed to retrieve template from global context.", exception); - listener.onFailure(exception); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); })); } catch (Exception e) { logger.error("Failed to retrieve template from global context.", e); - listener.onFailure(e); + listener.onFailure(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); } } @@ -119,7 +121,7 @@ private void executeWorkflowAsync(String workflowId, Workflow workflow) { try { threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflow, provisionWorkflowListener); }); } catch (Exception exception) { - provisionWorkflowListener.onFailure(exception); + provisionWorkflowListener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); } } @@ -157,7 +159,7 @@ private void executeWorkflow(Workflow workflow, ActionListener workflowL // TODO : Create State Index request with provisioning state, start time, end time, etc, pending implementation. String for now workflowListener.onResponse("READY"); } catch (CancellationException | CompletionException ex) { - workflowListener.onFailure(ex); + workflowListener.onFailure(new FlowFrameworkException(ex.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); } } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java index 267b4f6ce..a44817cec 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java @@ -10,7 +10,9 @@ import org.opensearch.client.node.NodeClient; import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; @@ -56,8 +58,11 @@ public void testNullWorkflowIdAndTemplate() throws IOException { .withPath(this.provisionWorkflowPath) .build(); - IOException ex = expectThrows(IOException.class, () -> { provisionWorkflowRestAction.prepareRequest(request, nodeClient); }); + FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> { + provisionWorkflowRestAction.prepareRequest(request, nodeClient); + }); assertEquals("workflow_id cannot be null", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); } public void testInvalidRequestWithContent() throws IOException { @@ -66,8 +71,11 @@ public void testInvalidRequestWithContent() throws IOException { .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) .build(); - IOException ex = expectThrows(IOException.class, () -> { provisionWorkflowRestAction.prepareRequest(request, nodeClient); }); + FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> { + provisionWorkflowRestAction.prepareRequest(request, nodeClient); + }); assertEquals("Invalid request format", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); } }