diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index ecce8ec50..5a849fd89 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -53,6 +53,8 @@ private CommonValue() {} public static final String WORKFLOW_URI = FLOW_FRAMEWORK_BASE_URI + "/workflow"; /** Field name for workflow Id, the document Id of the indexed use case template */ public static final String WORKFLOW_ID = "workflow_id"; + /** Field name for dry run, the flag to indicate if validation is necessary */ + public static final String DRY_RUN = "dryrun"; /** The field name for provision workflow within a use case template*/ public static final String PROVISION_WORKFLOW = "provision"; diff --git a/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java b/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java index f3cb55950..7e8aefc15 100644 --- a/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java +++ b/src/main/java/org/opensearch/flowframework/exception/FlowFrameworkException.java @@ -9,11 +9,15 @@ package org.opensearch.flowframework.exception; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; /** * Representation of Flow Framework Exceptions */ -public class FlowFrameworkException extends RuntimeException { +public class FlowFrameworkException extends RuntimeException implements ToXContentObject { private static final long serialVersionUID = 1L; @@ -60,4 +64,9 @@ public FlowFrameworkException(String message, Throwable cause, RestStatus restSt public RestStatus getRestStatus() { return restStatus; } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + return builder.startObject().field("error", this.getMessage()).endObject(); + } } diff --git a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java index 04a3fac5b..1b0f7c9d7 100644 --- a/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java +++ b/src/main/java/org/opensearch/flowframework/indices/FlowFrameworkIndicesHandler.java @@ -12,6 +12,7 @@ import com.google.common.io.Resources; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.admin.indices.create.CreateIndexRequest; import org.opensearch.action.admin.indices.create.CreateIndexResponse; import org.opensearch.action.admin.indices.mapping.put.PutMappingRequest; @@ -29,6 +30,7 @@ import org.opensearch.common.xcontent.XContentType; import org.opensearch.commons.authuser.User; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContent; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.exception.FlowFrameworkException; @@ -148,7 +150,7 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe } }, e -> { logger.error("Failed to create index " + indexName, e); - internalListener.onFailure(e); + internalListener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); }); CreateIndexRequest request = new CreateIndexRequest(indexName).mapping(mapping).settings(indexSettings); client.admin().indices().create(request, actionListener); @@ -181,8 +183,14 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe ); } }, exception -> { - logger.error("Failed to update index setting for: " + indexName, exception); - internalListener.onFailure(exception); + String errorMessage = "Failed to update index setting for: " + indexName; + logger.error(errorMessage, exception); + internalListener.onFailure( + new FlowFrameworkException( + errorMessage + " : " + exception.getMessage(), + ExceptionsHelper.status(exception) + ) + ); })); } else { internalListener.onFailure( @@ -190,8 +198,14 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe ); } }, exception -> { - logger.error("Failed to update index " + indexName, exception); - internalListener.onFailure(exception); + String errorMessage = "Failed to update index " + indexName; + logger.error(errorMessage, exception); + internalListener.onFailure( + new FlowFrameworkException( + errorMessage + " : " + exception.getMessage(), + ExceptionsHelper.status(exception) + ) + ); }) ); } else { @@ -200,8 +214,11 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe internalListener.onResponse(true); } }, e -> { - logger.error("Failed to update index mapping", e); - internalListener.onFailure(e); + String errorMessage = "Failed to update index mapping"; + logger.error(errorMessage, e); + internalListener.onFailure( + new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e)) + ); })); } else { // No need to update index if it's already updated. @@ -209,8 +226,9 @@ public void initFlowFrameworkIndexIfAbsent(FlowFrameworkIndex index, ActionListe } } } catch (Exception e) { - logger.error("Failed to init index " + indexName, e); - listener.onFailure(e); + String errorMessage = "Failed to init index " + indexName; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); } } @@ -272,8 +290,9 @@ public void putTemplateToGlobalContext(Template template, ActionListener context.restore())); } catch (Exception e) { - logger.error("Failed to index global_context index"); - listener.onFailure(e); + String errorMessage = "Failed to index global_context index"; + logger.error(errorMessage); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); } }, e -> { logger.error("Failed to create global_context index", e); @@ -310,13 +329,15 @@ public void putInitialStateToWorkflowState(String workflowId, User user, ActionL request.id(workflowId); client.index(request, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { - logger.error("Failed to put state index document", e); - listener.onFailure(e); + String errorMessage = "Failed to put state index document"; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); } }, e -> { - logger.error("Failed to create global_context index", e); - listener.onFailure(e); + String errorMessage = "Failed to create global_context index"; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); })); } @@ -332,7 +353,7 @@ public void updateTemplateInGlobalContext(String documentId, Template template, + documentId + ", global_context index does not exist."; logger.error(exceptionMessage); - listener.onFailure(new Exception(exceptionMessage)); + listener.onFailure(new FlowFrameworkException(exceptionMessage, RestStatus.BAD_REQUEST)); } else { IndexRequest request = new IndexRequest(GLOBAL_CONTEXT_INDEX).id(documentId); try ( @@ -343,8 +364,9 @@ public void updateTemplateInGlobalContext(String documentId, Template template, .setRefreshPolicy(WriteRequest.RefreshPolicy.IMMEDIATE); client.index(request, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { - logger.error("Failed to update global_context entry : {}. {}", documentId, e.getMessage()); - listener.onFailure(e); + String errorMessage = "Failed to update global_context entry : " + documentId; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); } } } @@ -365,7 +387,7 @@ public void updateFlowFrameworkSystemIndexDoc( if (!doesIndexExist(indexName)) { String exceptionMessage = "Failed to update document for given workflow due to missing " + indexName + " index"; logger.error(exceptionMessage); - listener.onFailure(new Exception(exceptionMessage)); + listener.onFailure(new FlowFrameworkException(exceptionMessage, RestStatus.BAD_REQUEST)); } else { try (ThreadContext.StoredContext context = client.threadPool().getThreadContext().stashContext()) { UpdateRequest updateRequest = new UpdateRequest(indexName, documentId); @@ -376,8 +398,9 @@ public void updateFlowFrameworkSystemIndexDoc( // TODO: decide what condition can be considered as an update conflict and add retry strategy client.update(updateRequest, ActionListener.runBefore(listener, () -> context.restore())); } catch (Exception e) { - logger.error("Failed to update {} entry : {}. {}", indexName, documentId, e.getMessage()); - listener.onFailure(e); + String errorMessage = "Failed to update " + indexName + " entry : " + documentId; + logger.error(errorMessage, e); + listener.onFailure(new FlowFrameworkException(errorMessage + " : " + e.getMessage(), ExceptionsHelper.status(e))); } } } diff --git a/src/main/java/org/opensearch/flowframework/model/Template.java b/src/main/java/org/opensearch/flowframework/model/Template.java index a05c374d8..fbafb8fd0 100644 --- a/src/main/java/org/opensearch/flowframework/model/Template.java +++ b/src/main/java/org/opensearch/flowframework/model/Template.java @@ -177,7 +177,7 @@ public static Template parse(XContentParser parser) throws IOException { } } if (name == null) { - throw new IOException("An template object requires a name."); + throw new IOException("A template object requires a name."); } return new Template(name, description, useCase, templateVersion, compatibilityVersion, workflows, user); diff --git a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java index ace440f75..b5400e247 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestCreateWorkflowAction.java @@ -9,18 +9,26 @@ package org.opensearch.flowframework.rest; import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.Template; import org.opensearch.flowframework.transport.CreateWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.action.RestToXContentListener; import java.io.IOException; import java.util.List; import java.util.Locale; +import static org.opensearch.flowframework.common.CommonValue.DRY_RUN; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_ID; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_URI; @@ -29,6 +37,7 @@ */ public class RestCreateWorkflowAction extends BaseRestHandler { + private static final Logger logger = LogManager.getLogger(RestCreateWorkflowAction.class); private static final String CREATE_WORKFLOW_ACTION = "create_workflow_action"; /** @@ -53,11 +62,32 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + try { - String workflowId = request.param(WORKFLOW_ID); - Template template = Template.parse(request.content().utf8ToString()); - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template); - return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel)); + String workflowId = request.param(WORKFLOW_ID); + Template template = Template.parse(request.content().utf8ToString()); + boolean dryRun = request.paramAsBoolean(DRY_RUN, false); + + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, template, dryRun); + + return channel -> client.execute(CreateWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.CREATED, builder)); + }, exception -> { + try { + FlowFrameworkException ex = (FlowFrameworkException) exception; + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + } catch (IOException e) { + logger.error("Failed to send back create workflow exception", e); + } + })); + } catch (Exception e) { + FlowFrameworkException ex = new FlowFrameworkException(e.getMessage(), RestStatus.BAD_REQUEST); + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); + } } } diff --git a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java index 89471ee00..1bd07eaf0 100644 --- a/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java +++ b/src/main/java/org/opensearch/flowframework/rest/RestProvisionWorkflowAction.java @@ -9,14 +9,19 @@ package org.opensearch.flowframework.rest; import com.google.common.collect.ImmutableList; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.client.node.NodeClient; +import org.opensearch.core.action.ActionListener; import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.transport.ProvisionWorkflowAction; import org.opensearch.flowframework.transport.WorkflowRequest; import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.BytesRestResponse; import org.opensearch.rest.RestRequest; -import org.opensearch.rest.action.RestToXContentListener; import java.io.IOException; import java.util.List; @@ -30,6 +35,8 @@ */ public class RestProvisionWorkflowAction extends BaseRestHandler { + private static final Logger logger = LogManager.getLogger(RestProvisionWorkflowAction.class); + private static final String PROVISION_WORKFLOW_ACTION = "provision_workflow_action"; /** @@ -52,21 +59,35 @@ public List routes() { @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { - - // Validate content - if (request.hasContent()) { - throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); - } - - // Validate params String workflowId = request.param(WORKFLOW_ID); - if (workflowId == null) { - throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + try { + // Validate content + if (request.hasContent()) { + throw new FlowFrameworkException("Invalid request format", RestStatus.BAD_REQUEST); + } + // Validate params + if (workflowId == null) { + throw new FlowFrameworkException("workflow_id cannot be null", RestStatus.BAD_REQUEST); + } + // Create request and provision + WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); + return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, ActionListener.wrap(response -> { + XContentBuilder builder = response.toXContent(channel.newBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(RestStatus.OK, builder)); + }, exception -> { + try { + FlowFrameworkException ex = (FlowFrameworkException) exception; + XContentBuilder exceptionBuilder = ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS); + channel.sendResponse(new BytesRestResponse(ex.getRestStatus(), exceptionBuilder)); + } catch (IOException e) { + logger.error("Failed to send back provision workflow exception", e); + } + })); + } catch (FlowFrameworkException ex) { + return channel -> channel.sendResponse( + new BytesRestResponse(ex.getRestStatus(), ex.toXContent(channel.newErrorBuilder(), ToXContent.EMPTY_PARAMS)) + ); } - - // Create request and provision - WorkflowRequest workflowRequest = new WorkflowRequest(workflowId, null); - return channel -> client.execute(ProvisionWorkflowAction.INSTANCE, workflowRequest, new RestToXContentListener<>(channel)); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java index c0baccc21..a6b809fc8 100644 --- a/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/CreateWorkflowTransportAction.java @@ -11,6 +11,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.client.Client; @@ -23,9 +24,14 @@ import org.opensearch.flowframework.model.ProvisioningProgress; import org.opensearch.flowframework.model.State; import org.opensearch.flowframework.model.Template; +import org.opensearch.flowframework.model.Workflow; +import org.opensearch.flowframework.workflow.ProcessNode; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; +import java.util.List; + import static org.opensearch.flowframework.common.CommonValue.PROVISIONING_PROGRESS_FIELD; import static org.opensearch.flowframework.common.CommonValue.STATE_FIELD; import static org.opensearch.flowframework.common.CommonValue.WORKFLOW_STATE_INDEX; @@ -38,6 +44,7 @@ public class CreateWorkflowTransportAction extends HandledTransportAction listener) { + User user = getUserContext(client); Template templateWithUser = new Template( request.getTemplate().name(), @@ -72,6 +83,21 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { @@ -83,12 +109,21 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { logger.error("Failed to save workflow state : {}", exception.getMessage()); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + } }) ); }, exception -> { logger.error("Failed to save use case template : {}", exception.getMessage()); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + } + })); } else { // Update existing entry, full document replacement @@ -105,15 +140,31 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { logger.error("Failed to update workflow state : {}", exception.getMessage()); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + } }) ); }, exception -> { logger.error("Failed to updated use case template {} : {}", request.getWorkflowId(), exception.getMessage()); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + if (exception instanceof FlowFrameworkException) { + listener.onFailure(exception); + } else { + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); + } + }) ); } } + private void validateWorkflows(Template template) throws Exception { + for (Workflow workflow : template.workflows().values()) { + List sortedNodes = workflowProcessSorter.sortProcessNodes(workflow); + workflowProcessSorter.validateGraph(sortedNodes); + } + } + } diff --git a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java index 22ac414e5..443bbf8a6 100644 --- a/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/ProvisionWorkflowTransportAction.java @@ -11,6 +11,7 @@ import com.google.common.collect.ImmutableMap; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.ExceptionsHelper; import org.opensearch.action.get.GetRequest; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; @@ -133,17 +134,17 @@ protected void doExecute(Task task, WorkflowRequest request, ActionListener { - if (exception instanceof IllegalArgumentException) { + if (exception instanceof FlowFrameworkException) { logger.error("Workflow validation failed for workflow : " + workflowId); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.BAD_REQUEST)); + listener.onFailure(exception); } else { logger.error("Failed to retrieve template from global context.", exception); - listener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + listener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); } })); } catch (Exception e) { logger.error("Failed to retrieve template from global context.", e); - listener.onFailure(new FlowFrameworkException(e.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); } } @@ -166,7 +167,7 @@ private void executeWorkflowAsync(String workflowId, List workflowS try { threadPool.executor(PROVISION_THREAD_POOL).execute(() -> { executeWorkflow(workflowSequence, provisionWorkflowListener); }); } catch (Exception exception) { - provisionWorkflowListener.onFailure(new FlowFrameworkException(exception.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + provisionWorkflowListener.onFailure(new FlowFrameworkException(exception.getMessage(), ExceptionsHelper.status(exception))); } } @@ -206,7 +207,7 @@ private void executeWorkflow(List workflowSequence, ActionListener< } catch (IllegalArgumentException e) { workflowListener.onFailure(new FlowFrameworkException(e.getMessage(), RestStatus.BAD_REQUEST)); } catch (Exception ex) { - workflowListener.onFailure(new FlowFrameworkException(ex.getMessage(), RestStatus.INTERNAL_SERVER_ERROR)); + workflowListener.onFailure(new FlowFrameworkException(ex.getMessage(), ExceptionsHelper.status(ex))); } } diff --git a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java index 0b105552f..2d2046329 100644 --- a/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java +++ b/src/main/java/org/opensearch/flowframework/transport/WorkflowRequest.java @@ -32,15 +32,30 @@ public class WorkflowRequest extends ActionRequest { */ @Nullable private Template template; + /** + * Validation flag + */ + private boolean dryRun; /** - * Instantiates a new WorkflowRequest + * Instantiates a new WorkflowRequest and defaults dry run to false * @param workflowId the documentId of the workflow * @param template the use case template which describes the workflow */ public WorkflowRequest(@Nullable String workflowId, @Nullable Template template) { + this(workflowId, template, false); + } + + /** + * Instantiates a new WorkflowRequest + * @param workflowId the documentId of the workflow + * @param template the use case template which describes the workflow + * @param dryRun flag to indicate if validation is necessary + */ + public WorkflowRequest(@Nullable String workflowId, @Nullable Template template, boolean dryRun) { this.workflowId = workflowId; this.template = template; + this.dryRun = dryRun; } /** @@ -53,6 +68,7 @@ public WorkflowRequest(StreamInput in) throws IOException { this.workflowId = in.readOptionalString(); String templateJson = in.readOptionalString(); this.template = templateJson == null ? null : Template.parse(templateJson); + this.dryRun = in.readBoolean(); } /** @@ -73,11 +89,20 @@ public Template getTemplate() { return this.template; } + /** + * Gets the dry run validation flag + * @return the dry run boolean + */ + public boolean isDryRun() { + return this.dryRun; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeOptionalString(workflowId); out.writeOptionalString(template == null ? null : template.toJson()); + out.writeBoolean(dryRun); } @Override diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 745de5921..10a038cbb 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -11,6 +11,8 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; @@ -131,7 +133,10 @@ public void validateGraph(List processNodes) throws Exception { if (!allInputs.containsAll(expectedInputs)) { expectedInputs.removeAll(allInputs); - throw new IllegalArgumentException("Invalid graph, missing the following required inputs : " + expectedInputs.toString()); + throw new FlowFrameworkException( + "Invalid graph, missing the following required inputs : " + expectedInputs.toString(), + RestStatus.BAD_REQUEST + ); } } @@ -142,8 +147,9 @@ private TimeValue parseTimeout(WorkflowNode node) { String fieldName = String.join(".", node.id(), USER_INPUTS_FIELD, NODE_TIMEOUT_FIELD); TimeValue timeValue = TimeValue.parseTimeValue(timeoutValue, fieldName); if (timeValue.millis() < 0) { - throw new IllegalArgumentException( - "Failed to parse timeout value [" + timeoutValue + "] for field [" + fieldName + "]. Must be positive" + throw new FlowFrameworkException( + "Failed to parse timeout value [" + timeoutValue + "] for field [" + fieldName + "]. Must be positive", + RestStatus.BAD_REQUEST ); } return timeValue; @@ -155,14 +161,14 @@ private static List topologicalSort(List workflowNod for (WorkflowEdge edge : workflowEdges) { String source = edge.source(); if (!nodeIds.contains(source)) { - throw new IllegalArgumentException("Edge source " + source + " does not correspond to a node."); + throw new FlowFrameworkException("Edge source " + source + " does not correspond to a node.", RestStatus.BAD_REQUEST); } String dest = edge.destination(); if (!nodeIds.contains(dest)) { - throw new IllegalArgumentException("Edge destination " + dest + " does not correspond to a node."); + throw new FlowFrameworkException("Edge destination " + dest + " does not correspond to a node.", RestStatus.BAD_REQUEST); } if (source.equals(dest)) { - throw new IllegalArgumentException("Edge connects node " + source + " to itself."); + throw new FlowFrameworkException("Edge connects node " + source + " to itself.", RestStatus.BAD_REQUEST); } } @@ -185,7 +191,7 @@ private static List topologicalSort(List workflowNod Queue sourceNodes = new ArrayDeque<>(); workflowNodes.stream().filter(n -> !predecessorEdges.containsKey(n)).forEach(n -> sourceNodes.add(n)); if (sourceNodes.isEmpty()) { - throw new IllegalArgumentException("No start node detected: all nodes have a predecessor."); + throw new FlowFrameworkException("No start node detected: all nodes have a predecessor.", RestStatus.BAD_REQUEST); } logger.debug("Start node(s): {}", sourceNodes); @@ -208,7 +214,7 @@ private static List topologicalSort(List workflowNod } } if (!graph.isEmpty()) { - throw new IllegalArgumentException("Cycle detected: " + graph); + throw new FlowFrameworkException("Cycle detected: " + graph, RestStatus.BAD_REQUEST); } logger.debug("Execution sequence: {}", sortedNodes); return sortedNodes; diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json index 23eb81c00..241a8ecbc 100644 --- a/src/main/resources/mappings/workflow-steps.json +++ b/src/main/resources/mappings/workflow-steps.json @@ -32,7 +32,7 @@ "version", "protocol", "parameters", - "credentials", + "credential", "actions" ], "outputs":[ diff --git a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java index d897c6756..3daaa4536 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestCreateWorkflowActionTests.java @@ -11,6 +11,7 @@ import org.opensearch.Version; import org.opensearch.client.node.NodeClient; import org.opensearch.core.common.bytes.BytesArray; +import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; import org.opensearch.flowframework.TestHelpers; import org.opensearch.flowframework.model.Template; @@ -20,9 +21,9 @@ import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; import org.opensearch.test.rest.FakeRestRequest; -import java.io.IOException; import java.util.List; import java.util.Locale; import java.util.Map; @@ -84,13 +85,14 @@ public void testRestCreateWorkflowActionRoutes() { } - public void testInvalidCreateWorkflowRequest() throws IOException { + public void testInvalidCreateWorkflowRequest() throws Exception { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.createWorkflowPath) .withContent(new BytesArray(invalidTemplate), MediaTypeRegistry.JSON) .build(); - - IOException ex = expectThrows(IOException.class, () -> { createWorkflowRestAction.prepareRequest(request, nodeClient); }); - assertEquals("Unable to parse field [invalid] in a template object.", ex.getMessage()); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + createWorkflowRestAction.handleRequest(request, channel, nodeClient); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("Unable to parse field [invalid] in a template object.")); } } diff --git a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java index a44817cec..4d9ef22e4 100644 --- a/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java +++ b/src/test/java/org/opensearch/flowframework/rest/RestProvisionWorkflowActionTests.java @@ -12,13 +12,12 @@ import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.MediaTypeRegistry; -import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.rest.RestHandler.Route; import org.opensearch.rest.RestRequest; import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.test.rest.FakeRestChannel; import org.opensearch.test.rest.FakeRestRequest; -import java.io.IOException; import java.util.List; import java.util.Locale; @@ -51,31 +50,35 @@ public void testRestProvisiionWorkflowActionRoutes() { assertEquals(this.provisionWorkflowPath, routes.get(0).getPath()); } - public void testNullWorkflowIdAndTemplate() throws IOException { + public void testNullWorkflowId() throws Exception { - // Request with no content or params + // Request with no params RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.provisionWorkflowPath) .build(); - FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> { - provisionWorkflowRestAction.prepareRequest(request, nodeClient); - }); - assertEquals("workflow_id cannot be null", ex.getMessage()); - assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); + FakeRestChannel channel = new FakeRestChannel(request, true, 1); + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); + + assertEquals(1, channel.errors().get()); + assertEquals(RestStatus.BAD_REQUEST, channel.capturedResponse().status()); + assertTrue(channel.capturedResponse().content().utf8ToString().contains("workflow_id cannot be null")); } - public void testInvalidRequestWithContent() throws IOException { + public void testInvalidRequestWithContent() { RestRequest request = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) .withPath(this.provisionWorkflowPath) .withContent(new BytesArray("request body"), MediaTypeRegistry.JSON) .build(); - FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> { - provisionWorkflowRestAction.prepareRequest(request, nodeClient); + FakeRestChannel channel = new FakeRestChannel(request, false, 1); + IllegalArgumentException ex = expectThrows(IllegalArgumentException.class, () -> { + provisionWorkflowRestAction.handleRequest(request, channel, nodeClient); }); - assertEquals("Invalid request format", ex.getMessage()); - assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); + assertEquals( + "request [POST /_plugins/_flow_framework/workflow/{workflow_id}/_provision] does not support having a body", + ex.getMessage() + ); } } diff --git a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java index b6f7bea2d..fbec8a034 100644 --- a/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/CreateWorkflowTransportActionTests.java @@ -22,6 +22,8 @@ import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; import org.opensearch.flowframework.util.ParseUtils; +import org.opensearch.flowframework.workflow.WorkflowProcessSorter; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ThreadPool; @@ -43,6 +45,7 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { private CreateWorkflowTransportAction createWorkflowTransportAction; private FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private WorkflowProcessSorter workflowProcessSorter; private Template template; private Client client = mock(Client.class); private ThreadPool threadPool; @@ -52,14 +55,16 @@ public class CreateWorkflowTransportActionTests extends OpenSearchTestCase { @Override public void setUp() throws Exception { super.setUp(); + threadPool = mock(ThreadPool.class); this.flowFrameworkIndicesHandler = mock(FlowFrameworkIndicesHandler.class); + this.workflowProcessSorter = new WorkflowProcessSorter(mock(WorkflowStepFactory.class), threadPool); this.createWorkflowTransportAction = new CreateWorkflowTransportAction( mock(TransportService.class), mock(ActionFilters.class), + workflowProcessSorter, flowFrameworkIndicesHandler, client ); - threadPool = mock(ThreadPool.class); // client = mock(Client.class); ThreadContext threadContext = new ThreadContext(Settings.EMPTY); // threadContext = mock(ThreadContext.class); @@ -88,6 +93,67 @@ public void setUp() throws Exception { ); } + public void testFailedDryRunValidation() { + + WorkflowNode createConnector = new WorkflowNode( + "workflow_step_1", + "create_connector", + Map.of(), + Map.ofEntries( + Map.entry("name", ""), + Map.entry("description", ""), + Map.entry("version", ""), + Map.entry("protocol", ""), + Map.entry("parameters", ""), + Map.entry("credential", ""), + Map.entry("actions", "") + ) + ); + + WorkflowNode registerModel = new WorkflowNode( + "workflow_step_2", + "register_model", + Map.ofEntries(Map.entry("workflow_step_1", "connector_id")), + Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description")) + ); + + WorkflowNode deployModel = new WorkflowNode( + "workflow_step_3", + "deploy_model", + Map.ofEntries(Map.entry("workflow_step_2", "model_id")), + Map.of() + ); + + WorkflowEdge edge1 = new WorkflowEdge(createConnector.id(), registerModel.id()); + WorkflowEdge edge2 = new WorkflowEdge(registerModel.id(), deployModel.id()); + WorkflowEdge cyclicalEdge = new WorkflowEdge(deployModel.id(), createConnector.id()); + + Workflow workflow = new Workflow( + Map.of(), + List.of(createConnector, registerModel, deployModel), + List.of(edge1, edge2, cyclicalEdge) + ); + + Template cyclicalTemplate = new Template( + "test", + "description", + "use case", + Version.fromString("1.0.0"), + List.of(Version.fromString("2.0.0"), Version.fromString("3.0.0")), + Map.of("workflow", workflow), + TestHelpers.randomUser() + ); + + @SuppressWarnings("unchecked") + ActionListener listener = mock(ActionListener.class); + WorkflowRequest createNewWorkflow = new WorkflowRequest(null, cyclicalTemplate, true); + + createWorkflowTransportAction.doExecute(mock(Task.class), createNewWorkflow, listener); + ArgumentCaptor exceptionCaptor = ArgumentCaptor.forClass(Exception.class); + verify(listener, times(1)).onFailure(exceptionCaptor.capture()); + assertEquals("No start node detected: all nodes have a predecessor.", exceptionCaptor.getValue().getMessage()); + } + public void testFailedToCreateNewWorkflow() { @SuppressWarnings("unchecked") ActionListener listener = mock(ActionListener.class); diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index 9f629ff9e..65fccbb7e 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -141,31 +141,35 @@ public void testOrdering() throws IOException { public void testCycles() { Exception ex; - ex = assertThrows(IllegalArgumentException.class, () -> parse(workflow(List.of(node("A")), List.of(edge("A", "A"))))); + ex = assertThrows(FlowFrameworkException.class, () -> parse(workflow(List.of(node("A")), List.of(edge("A", "A"))))); assertEquals("Edge connects node A to itself.", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); ex = assertThrows( - IllegalArgumentException.class, + FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "B")))) ); assertEquals("Edge connects node B to itself.", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); ex = assertThrows( - IllegalArgumentException.class, + FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "B"), edge("B", "A")))) ); assertEquals(NO_START_NODE_DETECTED, ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); ex = assertThrows( - IllegalArgumentException.class, + FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("B"), node("C")), List.of(edge("A", "B"), edge("B", "C"), edge("C", "B")))) ); assertTrue(ex.getMessage().startsWith(CYCLE_DETECTED)); assertTrue(ex.getMessage().contains("B->C")); assertTrue(ex.getMessage().contains("C->B")); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); ex = assertThrows( - IllegalArgumentException.class, + FlowFrameworkException.class, () -> parse( workflow( List.of(node("A"), node("B"), node("C"), node("D")), @@ -177,6 +181,7 @@ public void testCycles() { assertTrue(ex.getMessage().contains("B->C")); assertTrue(ex.getMessage().contains("C->D")); assertTrue(ex.getMessage().contains("D->B")); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); } public void testNoEdges() throws IOException { @@ -196,13 +201,15 @@ public void testNoEdges() throws IOException { public void testExceptions() throws IOException { Exception ex = assertThrows( - IllegalArgumentException.class, + FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("C", "B")))) ); assertEquals("Edge source C does not correspond to a node.", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); - ex = assertThrows(IllegalArgumentException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "C"))))); + ex = assertThrows(FlowFrameworkException.class, () -> parse(workflow(List.of(node("A"), node("B")), List.of(edge("A", "C"))))); assertEquals("Edge destination C does not correspond to a node.", ex.getMessage()); + assertEquals(RestStatus.BAD_REQUEST, ((FlowFrameworkException) ex).getRestStatus()); ex = assertThrows( FlowFrameworkException.class, @@ -223,7 +230,7 @@ public void testSuccessfulGraphValidation() throws Exception { Map.entry("version", ""), Map.entry("protocol", ""), Map.entry("parameters", ""), - Map.entry("credentials", ""), + Map.entry("credential", ""), Map.entry("actions", "") ) ); @@ -268,11 +275,11 @@ public void testFailedGraphValidation() { Workflow workflow = new Workflow(Map.of(), List.of(registerModel, deployModel), List.of(edge)); List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow); - IllegalArgumentException ex = expectThrows( - IllegalArgumentException.class, + FlowFrameworkException ex = expectThrows( + FlowFrameworkException.class, () -> workflowProcessSorter.validateGraph(sortedProcessNodes) ); assertEquals("Invalid graph, missing the following required inputs : [connector_id]", ex.getMessage()); - + assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); } }