From d3c1b535d71408b7f81c08b946dc19ae3ea9b4cb Mon Sep 17 00:00:00 2001 From: Owais Kazi Date: Mon, 19 Feb 2024 10:55:02 -0800 Subject: [PATCH] Moved workflow-steps.json to Enum (#523) * Created enum for workflow steps json Signed-off-by: Owais Kazi * Added getters for enum and rest of the enums Signed-off-by: Owais Kazi * Removed workflow-steps.json entries and the file completely Signed-off-by: Owais Kazi * Fixed tests Signed-off-by: Owais Kazi * Added entry to CHANGELOG.md Signed-off-by: Owais Kazi * Addressed PR comments and removed the map Signed-off-by: Owais Kazi * Updated CHANGELOG.md Signed-off-by: Owais Kazi --------- Signed-off-by: Owais Kazi (cherry picked from commit be0df19e3e3579a018e0d6f0ba423228f0d0491d) --- CHANGELOG.md | 39 +- DEVELOPER_GUIDE.md | 2 +- .../flowframework/common/CommonValue.java | 2 + .../model/WorkflowStepValidator.java | 73 +--- .../model/WorkflowValidator.java | 36 -- .../transport/GetWorkflowStepResponse.java | 4 +- .../GetWorkflowStepTransportAction.java | 13 +- .../flowframework/workflow/ToolStep.java | 1 + .../workflow/WorkflowProcessSorter.java | 41 +- .../workflow/WorkflowStepFactory.java | 369 +++++++++++++++++- .../resources/mappings/workflow-steps.json | 183 --------- .../model/WorkflowStepValidatorTests.java | 32 +- .../model/WorkflowValidatorTests.java | 144 +++++-- .../GetWorkflowStepTransportActionTests.java | 7 +- .../workflow/WorkflowProcessSorterTests.java | 16 +- 15 files changed, 552 insertions(+), 410 deletions(-) delete mode 100644 src/main/resources/mappings/workflow-steps.json diff --git a/CHANGELOG.md b/CHANGELOG.md index 7bfc2c10b..3486c6e90 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,22 +1,23 @@ -CHANGELOG - +# CHANGELOG +All notable changes to this project are documented in this file. Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.1.0/) -2.12.0 Initial Release -## [Unreleased] -### Added -- Github workflow for changelog verification ([#440](https://github.com/opensearch-project/flow-framework/pull/440)) - -### Changed - -### Deprecated - -### Removed - -### Fixed - -### Security - - -[Unreleased]: https://github.com/opensearch-project/flow-framework/compare/2.x...HEAD +## [Unreleased 3.0](https://github.com/opensearch-project/flow-framework/compare/2.x...HEAD) +### Features +### Enhancements +### Bug Fixes +### Infrastructure +### Documentation +### Maintenance +### Refactoring + +## [Unreleased 2.x](https://github.com/opensearch-project/flow-framework/compare/2.12...2.x) +### Features +### Enhancements +### Bug Fixes +### Infrastructure +### Documentation +### Maintenance +### Refactoring +- Moved workflow-steps.json to Enum ([#523](https://github.com/opensearch-project/flow-framework/pull/523)) diff --git a/DEVELOPER_GUIDE.md b/DEVELOPER_GUIDE.md index 11ae0b89c..7b76a7dfe 100644 --- a/DEVELOPER_GUIDE.md +++ b/DEVELOPER_GUIDE.md @@ -108,6 +108,6 @@ To add functionality to workflows, add new Workflow Steps to the [`org.opensearc 1. Implement the [Workflow](https://github.com/opensearch-project/flow-framework/blob/main/src/main/java/org/opensearch/flowframework/workflow/WorkflowStep.java) interface. See existing steps for examples for input, output, and API execution. 2. Choose a unique name for the step which is not used by other steps. This will align with the `step_type` field in the templates and should be descriptive of what the step does. 3. Add a constructor and call it from the [WorkflowStepFactory](https://github.com/opensearch-project/flow-framework/blob/main/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java). -4. Add a configuration to the [`workflow-steps.json`](https://github.com/opensearch-project/flow-framework/blob/main/src/main/resources/mappings/workflow-steps.json) file specifying required inputs, outputs, required plugins, and optionally a different timeout than the default. +4. Add an entry to the [WorkflowStepFactory](https://github.com/opensearch-project/flow-framework/blob/main/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java) enum specifying required inputs, outputs, required plugins, and optionally a different timeout than the default. 5. If your step provisions a resource that should be deprovisioned, create the corresponding step and add both steps to the [`WorkflowResources`](https://github.com/opensearch-project/flow-framework/blob/main/src/main/java/org/opensearch/flowframework/common/WorkflowResources.java) enum. 6. Write unit and integration tests. diff --git a/src/main/java/org/opensearch/flowframework/common/CommonValue.java b/src/main/java/org/opensearch/flowframework/common/CommonValue.java index 5be7482e0..ef54addff 100644 --- a/src/main/java/org/opensearch/flowframework/common/CommonValue.java +++ b/src/main/java/org/opensearch/flowframework/common/CommonValue.java @@ -188,4 +188,6 @@ private CommonValue() {} public static final String RESOURCE_TYPE = "resource_type"; /** The field name for the resource id */ public static final String RESOURCE_ID = "resource_id"; + /** The field name for the opensearch-ml plugin */ + public static final String OPENSEARCH_ML = "opensearch-ml"; } diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java index 6f1b4242a..e18b0e20e 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowStepValidator.java @@ -11,18 +11,12 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.common.unit.TimeValue; -import org.opensearch.core.rest.RestStatus; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.flowframework.exception.FlowFrameworkException; import java.io.IOException; -import java.util.ArrayList; import java.util.List; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - /** * This represents an object of workflow steps json which maps each step to expected inputs and outputs */ @@ -39,6 +33,7 @@ public class WorkflowStepValidator implements ToXContentObject { /** Timeout field name */ private static final String TIMEOUT = "timeout"; + private String workflowStep; private List inputs; private List outputs; private List requiredPlugins; @@ -46,74 +41,26 @@ public class WorkflowStepValidator implements ToXContentObject { /** * Instantiate the object representing a Workflow Step validator + * @param workflowStep name of the workflow step * @param inputs the workflow step inputs * @param outputs the workflow step outputs * @param requiredPlugins the required plugins for this workflow step * @param timeout the timeout for this workflow step */ - public WorkflowStepValidator(List inputs, List outputs, List requiredPlugins, TimeValue timeout) { + public WorkflowStepValidator( + String workflowStep, + List inputs, + List outputs, + List requiredPlugins, + TimeValue timeout + ) { + this.workflowStep = workflowStep; this.inputs = inputs; this.outputs = outputs; this.requiredPlugins = requiredPlugins; this.timeout = timeout; } - /** - * Parse raw json content into a WorkflowStepValidator instance - * @param parser json based content parser - * @return an instance of the WorkflowStepValidator - * @throws IOException if the content cannot be parsed correctly - */ - public static WorkflowStepValidator parse(XContentParser parser) throws IOException { - List parsedInputs = new ArrayList<>(); - List parsedOutputs = new ArrayList<>(); - List requiredPlugins = new ArrayList<>(); - TimeValue timeout = null; - - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = parser.currentName(); - parser.nextToken(); - switch (fieldName) { - case INPUTS_FIELD: - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - parsedInputs.add(parser.text()); - } - break; - case OUTPUTS_FIELD: - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - parsedOutputs.add(parser.text()); - } - break; - case REQUIRED_PLUGINS: - ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_ARRAY) { - requiredPlugins.add(parser.text()); - } - break; - case TIMEOUT: - try { - timeout = TimeValue.parseTimeValue(parser.text(), TIMEOUT); - } catch (IllegalArgumentException e) { - logger.error("Failed to parse TIMEOUT value for field [{}]", fieldName, e); - throw new FlowFrameworkException( - "Failed to parse workflow-step.json file for field [" + fieldName + "]", - RestStatus.INTERNAL_SERVER_ERROR - ); - } - break; - default: - throw new FlowFrameworkException( - "Unable to parse field [" + fieldName + "] in a WorkflowStepValidator object.", - RestStatus.BAD_REQUEST - ); - } - } - return new WorkflowStepValidator(parsedInputs, parsedOutputs, requiredPlugins, timeout); - } - /** * Get the required inputs * @return the inputs diff --git a/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java b/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java index 262eed800..cbc8d29af 100644 --- a/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java +++ b/src/main/java/org/opensearch/flowframework/model/WorkflowValidator.java @@ -11,15 +11,10 @@ import org.opensearch.common.xcontent.json.JsonXContent; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.flowframework.util.ParseUtils; import java.io.IOException; -import java.util.HashMap; import java.util.Map; -import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; - /** * This represents the workflow steps json which maps each step to expected inputs and outputs */ @@ -35,37 +30,6 @@ public WorkflowValidator(Map workflowStepValidato this.workflowStepValidators = workflowStepValidators; } - /** - * Parse raw json content into a WorkflowValidator instance - * @param parser json based content parser - * @return an instance of the WorkflowValidator - * @throws IOException if the content cannot be parsed correctly - */ - public static WorkflowValidator parse(XContentParser parser) throws IOException { - - Map workflowStepValidators = new HashMap<>(); - - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); - while (parser.nextToken() != XContentParser.Token.END_OBJECT) { - String type = parser.currentName(); - ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); - workflowStepValidators.put(type, WorkflowStepValidator.parse(parser)); - } - return new WorkflowValidator(workflowStepValidators); - } - - /** - * Parse a workflow step JSON file into a WorkflowValidator object - * - * @param file the file name of the workflow step json - * @return A {@link WorkflowValidator} represented by the JSON - * @throws IOException on failure to read and parse the json file - */ - public static WorkflowValidator parse(String file) throws IOException { - String json = ParseUtils.resourceToString("/" + file); - return parse(ParseUtils.jsonToParser(json)); - } - /** * Output this object in a compact JSON string. * diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepResponse.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepResponse.java index 0d9ce9371..3e736a124 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepResponse.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepResponse.java @@ -14,6 +14,7 @@ import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.flowframework.model.WorkflowValidator; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; import java.io.IOException; @@ -23,6 +24,7 @@ public class GetWorkflowStepResponse extends ActionResponse implements ToXContentObject { private WorkflowValidator workflowValidator; + private WorkflowStepFactory workflowStepFactory; /** * Instantiates a new GetWorkflowStepResponse from an input stream @@ -31,7 +33,7 @@ public class GetWorkflowStepResponse extends ActionResponse implements ToXConten */ public GetWorkflowStepResponse(StreamInput in) throws IOException { super(in); - this.workflowValidator = WorkflowValidator.parse(in.readString()); + this.workflowValidator = this.workflowStepFactory.getWorkflowValidator(); } /** diff --git a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java index 0232c28f6..8b4d8a001 100644 --- a/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java +++ b/src/main/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportAction.java @@ -18,6 +18,7 @@ import org.opensearch.core.action.ActionListener; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.model.WorkflowValidator; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -27,21 +28,29 @@ public class GetWorkflowStepTransportAction extends HandledTransportAction { private final Logger logger = LogManager.getLogger(GetWorkflowStepTransportAction.class); + private final WorkflowStepFactory workflowStepFactory; /** * Instantiates a new GetWorkflowStepTransportAction instance * @param transportService the transport service * @param actionFilters action filters + * @param workflowStepFactory The factory instantiating workflow steps */ @Inject - public GetWorkflowStepTransportAction(TransportService transportService, ActionFilters actionFilters) { + public GetWorkflowStepTransportAction( + TransportService transportService, + ActionFilters actionFilters, + WorkflowStepFactory workflowStepFactory + ) { super(GetWorkflowStepAction.NAME, transportService, actionFilters, WorkflowRequest::new); + this.workflowStepFactory = workflowStepFactory; } @Override protected void doExecute(Task task, ActionRequest request, ActionListener listener) { try { - listener.onResponse(new GetWorkflowStepResponse(WorkflowValidator.parse("mappings/workflow-steps.json"))); + WorkflowValidator workflowValidator = this.workflowStepFactory.getWorkflowValidator(); + listener.onResponse(new GetWorkflowStepResponse(workflowValidator)); } catch (Exception e) { logger.error("Failed to retrieve workflow step json.", e); listener.onFailure(new FlowFrameworkException(e.getMessage(), ExceptionsHelper.status(e))); diff --git a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java index d3d3611f1..2a5536638 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java +++ b/src/main/java/org/opensearch/flowframework/workflow/ToolStep.java @@ -106,6 +106,7 @@ private Map getToolsParametersMap( Map previousNodeInputs, Map outputs ) { + @SuppressWarnings("unchecked") Map parametersMap = (Map) parameters; Optional previousNodeModel = previousNodeInputs.entrySet() .stream() diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java index 4b4078098..ac6d75d58 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowProcessSorter.java @@ -19,7 +19,6 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; -import org.opensearch.flowframework.model.WorkflowValidator; import org.opensearch.plugins.PluginInfo; import org.opensearch.plugins.PluginsService; import org.opensearch.threadpool.ThreadPool; @@ -43,6 +42,10 @@ import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_DEFAULT_VALUE; import static org.opensearch.flowframework.model.WorkflowNode.NODE_TIMEOUT_FIELD; import static org.opensearch.flowframework.model.WorkflowNode.USER_INPUTS_FIELD; +import static org.opensearch.flowframework.workflow.WorkflowStepFactory.WorkflowSteps.getInputByWorkflowType; +import static org.opensearch.flowframework.workflow.WorkflowStepFactory.WorkflowSteps.getOutputByWorkflowType; +import static org.opensearch.flowframework.workflow.WorkflowStepFactory.WorkflowSteps.getRequiredPluginsByWorkflowType; +import static org.opensearch.flowframework.workflow.WorkflowStepFactory.WorkflowSteps.getTimeoutByWorkflowType; /** * Converts a workflow of nodes and edges into a topologically sorted list of Process Nodes. @@ -139,31 +142,28 @@ public List sortProcessNodes(Workflow workflow, String workflowId) * @throws Exception if validation fails */ public void validate(List processNodes, PluginsService pluginsService) throws Exception { - WorkflowValidator validator = readWorkflowValidator(); List installedPlugins = pluginsService.info() .getPluginInfos() .stream() .map(PluginInfo::getName) .collect(Collectors.toList()); - validatePluginsInstalled(processNodes, validator, installedPlugins); - validateGraph(processNodes, validator); + validatePluginsInstalled(processNodes, installedPlugins); + validateGraph(processNodes); } /** * Validates a sorted workflow, determines if each process node's required plugins are currently installed * @param processNodes A list of process nodes - * @param validator The validation definitions for the workflow steps * @param installedPlugins The list of installed plugins * @throws Exception on validation failure */ - public void validatePluginsInstalled(List processNodes, WorkflowValidator validator, List installedPlugins) - throws Exception { + public void validatePluginsInstalled(List processNodes, List installedPlugins) throws Exception { // Iterate through process nodes in graph for (ProcessNode processNode : processNodes) { // Retrieve required plugins of this node based on type String nodeType = processNode.workflowStep().getName(); - List requiredPlugins = new ArrayList<>(validator.getWorkflowStepValidators().get(nodeType).getRequiredPlugins()); + List requiredPlugins = new ArrayList<>(getRequiredPluginsByWorkflowType(nodeType)); if (!installedPlugins.containsAll(requiredPlugins)) { requiredPlugins.removeAll(installedPlugins); throw new FlowFrameworkException( @@ -180,10 +180,9 @@ public void validatePluginsInstalled(List processNodes, WorkflowVal /** * Validates a sorted workflow, determines if each process node's user inputs and predecessor outputs match the expected workflow step inputs * @param processNodes A list of process nodes - * @param validator The validation definitions for the workflow steps * @throws Exception on validation failure */ - public void validateGraph(List processNodes, WorkflowValidator validator) throws Exception { + public void validateGraph(List processNodes) throws Exception { // Iterate through process nodes in graph for (ProcessNode processNode : processNodes) { @@ -196,7 +195,7 @@ public void validateGraph(List processNodes, WorkflowValidator vali // Compile a list of outputs from the predecessor nodes based on type List predecessorOutputs = predecessorNodeTypes.stream() - .map(nodeType -> validator.getWorkflowStepValidators().get(nodeType).getOutputs()) + .map(nodeType -> getOutputByWorkflowType(nodeType)) .flatMap(Collection::stream) .collect(Collectors.toList()); @@ -208,9 +207,7 @@ public void validateGraph(List processNodes, WorkflowValidator vali .collect(Collectors.toList()); // Retrieve list of required inputs from the current process node and compare - List expectedInputs = new ArrayList<>( - validator.getWorkflowStepValidators().get(processNode.workflowStep().getName()).getInputs() - ); + List expectedInputs = new ArrayList<>(getInputByWorkflowType(processNode.workflowStep().getName())); if (!allInputs.containsAll(expectedInputs)) { expectedInputs.removeAll(allInputs); @@ -225,18 +222,6 @@ public void validateGraph(List processNodes, WorkflowValidator vali } } - private WorkflowValidator readWorkflowValidator() { - try { - return WorkflowValidator.parse("mappings/workflow-steps.json"); - } catch (Exception e) { - logger.error("Failed at reading workflow-steps mapping file", e); - throw new FlowFrameworkException( - "Failed at reading workflow-steps.json mapping file for a new workflow.", - RestStatus.INTERNAL_SERVER_ERROR - ); - } - } - /** * A method for parsing workflow timeout value. * The value could be parsed from node NODE_TIMEOUT_FIELD, the timeout field in workflow-step.json, @@ -245,9 +230,7 @@ private WorkflowValidator readWorkflowValidator() { * @return the timeout value */ protected TimeValue parseTimeout(WorkflowNode node) { - WorkflowValidator validator = readWorkflowValidator(); - TimeValue nodeTimeoutValue = Optional.ofNullable(validator.getWorkflowStepValidators().get(node.type()).getTimeout()) - .orElse(NODE_TIMEOUT_DEFAULT_VALUE); + TimeValue nodeTimeoutValue = Optional.ofNullable(getTimeoutByWorkflowType(node.type())).orElse(NODE_TIMEOUT_DEFAULT_VALUE); String nodeTimeoutAsString = nodeTimeoutValue.getSeconds() + "s"; String timeoutValue = (String) node.userInputs().getOrDefault(NODE_TIMEOUT_FIELD, nodeTimeoutAsString); String fieldName = String.join(".", node.id(), USER_INPUTS_FIELD, NODE_TIMEOUT_FIELD); diff --git a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java index ab63bd7a8..7f54b17cc 100644 --- a/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java +++ b/src/main/java/org/opensearch/flowframework/workflow/WorkflowStepFactory.java @@ -8,25 +8,63 @@ */ package org.opensearch.flowframework.workflow; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.client.Client; import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.unit.TimeValue; +import org.opensearch.core.common.Strings; import org.opensearch.core.rest.RestStatus; import org.opensearch.flowframework.common.FlowFrameworkSettings; import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; +import org.opensearch.flowframework.model.WorkflowStepValidator; +import org.opensearch.flowframework.model.WorkflowValidator; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.threadpool.ThreadPool; +import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.function.Supplier; +import static org.opensearch.flowframework.common.CommonValue.ACTIONS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.CREDENTIAL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.DESCRIPTION_FIELD; +import static org.opensearch.flowframework.common.CommonValue.EMBEDDING_DIMENSION; +import static org.opensearch.flowframework.common.CommonValue.FRAMEWORK_TYPE; +import static org.opensearch.flowframework.common.CommonValue.FUNCTION_NAME; +import static org.opensearch.flowframework.common.CommonValue.MODEL_CONTENT_HASH_VALUE; +import static org.opensearch.flowframework.common.CommonValue.MODEL_FORMAT; +import static org.opensearch.flowframework.common.CommonValue.MODEL_GROUP_STATUS; +import static org.opensearch.flowframework.common.CommonValue.MODEL_TYPE; +import static org.opensearch.flowframework.common.CommonValue.NAME_FIELD; +import static org.opensearch.flowframework.common.CommonValue.OPENSEARCH_ML; +import static org.opensearch.flowframework.common.CommonValue.PARAMETERS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.PROTOCOL_FIELD; +import static org.opensearch.flowframework.common.CommonValue.REGISTER_MODEL_STATUS; +import static org.opensearch.flowframework.common.CommonValue.SUCCESS; +import static org.opensearch.flowframework.common.CommonValue.TOOLS_FIELD; +import static org.opensearch.flowframework.common.CommonValue.TYPE; +import static org.opensearch.flowframework.common.CommonValue.URL; +import static org.opensearch.flowframework.common.CommonValue.VERSION_FIELD; +import static org.opensearch.flowframework.common.WorkflowResources.AGENT_ID; +import static org.opensearch.flowframework.common.WorkflowResources.CONNECTOR_ID; +import static org.opensearch.flowframework.common.WorkflowResources.MODEL_GROUP_ID; +import static org.opensearch.flowframework.common.WorkflowResources.MODEL_ID; + /** * Generates instances implementing {@link WorkflowStep}. */ public class WorkflowStepFactory { private final Map> stepMap = new HashMap<>(); + private static final Logger logger = LogManager.getLogger(WorkflowStepFactory.class); + private static ThreadPool threadPool; + private static MachineLearningNodeClient mlClient; + private static FlowFrameworkIndicesHandler flowFrameworkIndicesHandler; + private static FlowFrameworkSettings flowFrameworkSettings; /** * Instantiate this class. @@ -46,32 +84,327 @@ public WorkflowStepFactory( FlowFrameworkIndicesHandler flowFrameworkIndicesHandler, FlowFrameworkSettings flowFrameworkSettings ) { - stepMap.put(NoOpStep.NAME, NoOpStep::new); - stepMap.put( + this.threadPool = threadPool; + this.mlClient = mlClient; + this.flowFrameworkIndicesHandler = flowFrameworkIndicesHandler; + this.flowFrameworkSettings = flowFrameworkSettings; + // Initialize the WorkflowSteps enum inside the constructor + for (WorkflowSteps workflowStep : WorkflowSteps.values()) { + stepMap.put(workflowStep.getWorkflowStepName(), workflowStep.step()); + } + } + + /** + * Enum encapsulating the different step names, their inputs, outputs, required plugin and timeout of the step + */ + + public enum WorkflowSteps { + + /** Noop Step */ + NOOP("noop", Collections.emptyList(), Collections.emptyList(), Collections.emptyList(), null, NoOpStep::new), + + /** Create Connector Step */ + CREATE_CONNECTOR( + CreateConnectorStep.NAME, + List.of(NAME_FIELD, DESCRIPTION_FIELD, VERSION_FIELD, PROTOCOL_FIELD, PARAMETERS_FIELD, CREDENTIAL_FIELD, ACTIONS_FIELD), + List.of(CONNECTOR_ID), + List.of(OPENSEARCH_ML), + TimeValue.timeValueSeconds(60), + () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler) + ), + + /** Register Local Custom Model Step */ + REGISTER_LOCAL_CUSTOM_MODEL( RegisterLocalCustomModelStep.NAME, + List.of( + NAME_FIELD, + VERSION_FIELD, + MODEL_FORMAT, + FUNCTION_NAME, + MODEL_CONTENT_HASH_VALUE, + URL, + MODEL_TYPE, + EMBEDDING_DIMENSION, + FRAMEWORK_TYPE + ), + List.of(MODEL_ID, REGISTER_MODEL_STATUS), + List.of(OPENSEARCH_ML), + TimeValue.timeValueSeconds(60), () -> new RegisterLocalCustomModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) - ); - stepMap.put( + ), + + /** Register Local Sparse Encoding Model Step */ + REGISTER_LOCAL_SPARSE_ENCODING_MODEL( RegisterLocalSparseEncodingModelStep.NAME, + List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT, FUNCTION_NAME, MODEL_CONTENT_HASH_VALUE, URL), + List.of(MODEL_ID, REGISTER_MODEL_STATUS), + List.of(OPENSEARCH_ML), + TimeValue.timeValueSeconds(60), () -> new RegisterLocalSparseEncodingModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) - ); - stepMap.put( + ), + + /** Register Local Pretrained Model Step */ + REGISTER_LOCAL_PRETRAINED_MODEL( RegisterLocalPretrainedModelStep.NAME, + List.of(NAME_FIELD, VERSION_FIELD, MODEL_FORMAT), + List.of(MODEL_ID, REGISTER_MODEL_STATUS), + List.of(OPENSEARCH_ML), + TimeValue.timeValueSeconds(60), () -> new RegisterLocalPretrainedModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) - ); - stepMap.put(RegisterRemoteModelStep.NAME, () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(DeleteModelStep.NAME, () -> new DeleteModelStep(mlClient)); - stepMap.put( + ), + + /** Register Remote Model Step */ + REGISTER_REMOTE_MODEL( + RegisterRemoteModelStep.NAME, + List.of(NAME_FIELD, CONNECTOR_ID), + List.of(MODEL_ID, REGISTER_MODEL_STATUS), + List.of(OPENSEARCH_ML), + null, + () -> new RegisterRemoteModelStep(mlClient, flowFrameworkIndicesHandler) + ), + + /** Register Model Group Step */ + REGISTER_MODEL_GROUP( + RegisterModelGroupStep.NAME, + List.of(NAME_FIELD), + List.of(MODEL_GROUP_ID, MODEL_GROUP_STATUS), + List.of(OPENSEARCH_ML), + null, + () -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler) + ), + + /** Deploy Model Step */ + DEPLOY_MODEL( DeployModelStep.NAME, + List.of(MODEL_ID), + List.of(MODEL_ID), + List.of(OPENSEARCH_ML), + TimeValue.timeValueSeconds(15), () -> new DeployModelStep(threadPool, mlClient, flowFrameworkIndicesHandler, flowFrameworkSettings) - ); - stepMap.put(UndeployModelStep.NAME, () -> new UndeployModelStep(mlClient)); - stepMap.put(CreateConnectorStep.NAME, () -> new CreateConnectorStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(DeleteConnectorStep.NAME, () -> new DeleteConnectorStep(mlClient)); - stepMap.put(RegisterModelGroupStep.NAME, () -> new RegisterModelGroupStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(ToolStep.NAME, ToolStep::new); - stepMap.put(RegisterAgentStep.NAME, () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler)); - stepMap.put(DeleteAgentStep.NAME, () -> new DeleteAgentStep(mlClient)); + ), + + /** Undeploy Model Step */ + UNDEPLOY_MODEL( + UndeployModelStep.NAME, + List.of(MODEL_ID), + List.of(SUCCESS), + List.of(OPENSEARCH_ML), + null, + () -> new UndeployModelStep(mlClient) + ), + + /** Delete Model Step */ + DELETE_MODEL( + DeleteModelStep.NAME, + List.of(MODEL_ID), + List.of(MODEL_ID), + List.of(OPENSEARCH_ML), + null, + () -> new DeleteModelStep(mlClient) + ), + + /** Delete Connector Step */ + DELETE_CONNECTOR( + DeleteConnectorStep.NAME, + List.of(CONNECTOR_ID), + List.of(CONNECTOR_ID), + List.of(OPENSEARCH_ML), + null, + () -> new DeleteConnectorStep(mlClient) + ), + + /** Register Agent Step */ + REGISTER_AGENT( + RegisterAgentStep.NAME, + List.of(NAME_FIELD, TYPE), + List.of(AGENT_ID), + List.of(OPENSEARCH_ML), + null, + () -> new RegisterAgentStep(mlClient, flowFrameworkIndicesHandler) + ), + + /** Delete Agent Step */ + DELETE_AGENT( + DeleteAgentStep.NAME, + List.of(AGENT_ID), + List.of(AGENT_ID), + List.of(OPENSEARCH_ML), + null, + () -> new DeleteAgentStep(mlClient) + ), + + /** Create Tool Step */ + CREATE_TOOL(ToolStep.NAME, List.of(TYPE), List.of(TOOLS_FIELD), List.of(OPENSEARCH_ML), null, ToolStep::new); + + private final String workflowStepName; + private final List inputs; + private final List outputs; + private final List requiredPlugins; + private final TimeValue timeout; + private final Supplier workflowStep; + + WorkflowSteps( + String workflowStepName, + List inputs, + List outputs, + List requiredPlugins, + TimeValue timeout, + Supplier workflowStep + ) { + this.workflowStepName = workflowStepName; + this.inputs = List.copyOf(inputs); + this.outputs = List.copyOf(outputs); + this.requiredPlugins = requiredPlugins; + this.timeout = timeout; + this.workflowStep = workflowStep; + } + + /** + * Returns the workflowStep for the given enum Constant + * @return the workflowStep of this data. + */ + public String getWorkflowStepName() { + return workflowStepName; + } + + /** + * Get the required inputs + * @return the inputs + */ + public List inputs() { + return inputs; + } + + /** + * Get the required outputs + * @return the outputs + */ + public List outputs() { + return outputs; + } + + /** + * Get the required plugins + * @return the required plugins + */ + public List requiredPlugins() { + return requiredPlugins; + } + + /** + * Get the timeout + * @return the timeout + */ + public TimeValue timeout() { + return timeout; + } + + /** + * Get the step + * @return the step + */ + public Supplier step() { + return workflowStep; + } + + /** + * Get the workflow step validator object + * @return the WorkflowStepValidator + */ + public WorkflowStepValidator getWorkflowStepValidator() { + return new WorkflowStepValidator(workflowStepName, inputs, outputs, requiredPlugins, timeout); + }; + + /** + * Gets the timeout based on the workflowStep. + * @param workflowStep workflow step type + * @return the resource that will be created + * @throws FlowFrameworkException if workflow step doesn't exist in enum + */ + public static TimeValue getTimeoutByWorkflowType(String workflowStep) throws FlowFrameworkException { + if (!Strings.isNullOrEmpty(workflowStep)) { + for (WorkflowSteps mapping : values()) { + if (workflowStep.equals(mapping.getWorkflowStepName())) { + return mapping.timeout(); + } + } + } + String errorMessage = "Unable to find workflow timeout for step: " + workflowStep; + logger.error(errorMessage); + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } + + /** + * Gets the required plugins based on the workflowStep. + * @param workflowStep workflow step type + * @return the resource that will be created + * @throws FlowFrameworkException if workflow step doesn't exist in enum + */ + public static List getRequiredPluginsByWorkflowType(String workflowStep) throws FlowFrameworkException { + if (!Strings.isNullOrEmpty(workflowStep)) { + for (WorkflowSteps mapping : values()) { + if (workflowStep.equals(mapping.getWorkflowStepName())) { + return mapping.requiredPlugins(); + } + } + } + String errorMessage = "Unable to find workflow required plugins for step: " + workflowStep; + logger.error(errorMessage); + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } + + /** + * Gets the output based on the workflowStep. + * @param workflowStep workflow step type + * @return the resource that will be created + * @throws FlowFrameworkException if workflow step doesn't exist in enum + */ + public static List getOutputByWorkflowType(String workflowStep) throws FlowFrameworkException { + if (!Strings.isNullOrEmpty(workflowStep)) { + for (WorkflowSteps mapping : values()) { + if (workflowStep.equals(mapping.getWorkflowStepName())) { + return mapping.outputs(); + } + } + } + String errorMessage = "Unable to find workflow output for step " + workflowStep; + logger.error(errorMessage); + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } + + /** + * Gets the input based on the workflowStep. + * @param workflowStep workflow step type + * @return the resource that will be created + * @throws FlowFrameworkException if workflow step doesn't exist in enum + */ + public static List getInputByWorkflowType(String workflowStep) throws FlowFrameworkException { + if (!Strings.isNullOrEmpty(workflowStep)) { + for (WorkflowSteps mapping : values()) { + if (workflowStep.equals(mapping.getWorkflowStepName())) { + return mapping.inputs(); + } + } + } + String errorMessage = "Unable to find workflow input for step: " + workflowStep; + logger.error(errorMessage); + throw new FlowFrameworkException(errorMessage, RestStatus.BAD_REQUEST); + } + + } + + /** + * Get the object of WorkflowValidator consisting of workflow steps + * @return WorkflowValidator + */ + public WorkflowValidator getWorkflowValidator() { + Map workflowStepValidators = new HashMap<>(); + + for (WorkflowSteps mapping : WorkflowSteps.values()) { + workflowStepValidators.put(mapping.getWorkflowStepName(), mapping.getWorkflowStepValidator()); + } + + return new WorkflowValidator(workflowStepValidators); } /** diff --git a/src/main/resources/mappings/workflow-steps.json b/src/main/resources/mappings/workflow-steps.json deleted file mode 100644 index 6b7793696..000000000 --- a/src/main/resources/mappings/workflow-steps.json +++ /dev/null @@ -1,183 +0,0 @@ -{ - "noop": { - "inputs":[], - "outputs":[], - "required_plugins":[] - }, - "create_connector": { - "inputs":[ - "name", - "description", - "version", - "protocol", - "parameters", - "credential", - "actions" - ], - "outputs":[ - "connector_id" - ], - "required_plugins":[ - "opensearch-ml" - ], - "timeout": "60s" - }, - "delete_connector": { - "inputs": [ - "connector_id" - ], - "outputs":[ - "connector_id" - ], - "required_plugins":[ - "opensearch-ml" - ] - }, - "register_local_custom_model": { - "inputs":[ - "name", - "version", - "model_format", - "function_name", - "model_content_hash_value", - "url", - "model_type", - "embedding_dimension", - "framework_type" - ], - "outputs":[ - "model_id", - "register_model_status" - ], - "required_plugins":[ - "opensearch-ml" - ], - "timeout": "60s" - }, - "register_local_sparse_encoding_model": { - "inputs":[ - "name", - "version", - "model_format", - "function_name", - "model_content_hash_value", - "url" - ], - "outputs":[ - "model_id", - "register_model_status" - ], - "required_plugins":[ - "opensearch-ml" - ], - "timeout": "60s" - }, - "register_local_pretrained_model": { - "inputs":[ - "name", - "version", - "model_format" - ], - "outputs":[ - "model_id", - "register_model_status" - ], - "required_plugins":[ - "opensearch-ml" - ], - "timeout": "60s" - }, - "register_remote_model": { - "inputs": [ - "name", - "connector_id" - ], - "outputs": [ - "model_id", - "register_model_status" - ], - "required_plugins":[ - "opensearch-ml" - ] - }, - "delete_model": { - "inputs": [ - "model_id" - ], - "outputs":[ - "model_id" - ], - "required_plugins":[ - "opensearch-ml" - ] - }, - "deploy_model": { - "inputs":[ - "model_id" - ], - "outputs":[ - "model_id" - ], - "required_plugins":[ - "opensearch-ml" - ], - "timeout": "15s" - }, - "undeploy_model": { - "inputs":[ - "model_id" - ], - "outputs":[ - "success" - ], - "required_plugins":[ - "opensearch-ml" - ] - }, - "register_model_group": { - "inputs":[ - "name" - ], - "outputs":[ - "model_group_id", - "model_group_status" - ], - "required_plugins":[ - "opensearch-ml" - ] - }, - "register_agent": { - "inputs":[ - "name", - "type" - ], - "outputs":[ - "agent_id" - ], - "required_plugins":[ - "opensearch-ml" - ] - }, - "delete_agent": { - "inputs": [ - "agent_id" - ], - "outputs":[ - "agent_id" - ], - "required_plugins":[ - "opensearch-ml" - ] - }, - "create_tool": { - "inputs": [ - "type" - ], - "outputs": [ - "tools" - ], - "required_plugins":[ - "opensearch-ml" - ] - } -} diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java index 1426ebfae..f155e1f90 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowStepValidatorTests.java @@ -8,41 +8,33 @@ */ package org.opensearch.flowframework.model; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.flowframework.exception.FlowFrameworkException; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.test.OpenSearchTestCase; import java.io.IOException; +import java.util.HashMap; +import java.util.Map; public class WorkflowStepValidatorTests extends OpenSearchTestCase { - private String validValidator; - private String invalidValidator; - @Override public void setUp() throws Exception { super.setUp(); - validValidator = "{\"inputs\":[\"input_value\"],\"outputs\":[\"output_value\"]}"; - invalidValidator = "{\"inputs\":[\"input_value\"],\"invalid_field\":[\"output_value\"]}"; } public void testParseWorkflowStepValidator() throws IOException { - XContentParser parser = TemplateTestJsonUtil.jsonToParser(validValidator); - WorkflowStepValidator workflowStepValidator = WorkflowStepValidator.parse(parser); - assertEquals(1, workflowStepValidator.getInputs().size()); - assertEquals(1, workflowStepValidator.getOutputs().size()); + Map workflowStepValidators = new HashMap<>(); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStepValidator() + ); - assertEquals("input_value", workflowStepValidator.getInputs().get(0)); - assertEquals("output_value", workflowStepValidator.getOutputs().get(0)); - } + assertEquals(7, WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.inputs().size()); + assertEquals(1, WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.outputs().size()); - public void testFailedParseWorkflowStepValidator() throws IOException { - XContentParser parser = TemplateTestJsonUtil.jsonToParser(invalidValidator); - FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> WorkflowStepValidator.parse(parser)); - assertEquals("Unable to parse field [invalid_field] in a WorkflowStepValidator object.", ex.getMessage()); - assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); + assertEquals("name", WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.inputs().get(0)); + assertEquals("connector_id", WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.outputs().get(0)); } } diff --git a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java index 46f1e7950..73fb3186c 100644 --- a/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java +++ b/src/test/java/org/opensearch/flowframework/model/WorkflowValidatorTests.java @@ -15,10 +15,7 @@ import org.opensearch.common.settings.ClusterSettings; import org.opensearch.common.settings.Setting; import org.opensearch.common.settings.Settings; -import org.opensearch.core.rest.RestStatus; -import org.opensearch.core.xcontent.XContentParser; import org.opensearch.flowframework.common.FlowFrameworkSettings; -import org.opensearch.flowframework.exception.FlowFrameworkException; import org.opensearch.flowframework.indices.FlowFrameworkIndicesHandler; import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.ml.client.MachineLearningNodeClient; @@ -27,7 +24,9 @@ import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -41,41 +40,135 @@ public class WorkflowValidatorTests extends OpenSearchTestCase { - private String validWorkflowStepJson; - private String invalidWorkflowStepJson; private FlowFrameworkSettings flowFrameworkSettings; @Override public void setUp() throws Exception { super.setUp(); - validWorkflowStepJson = - "{\"workflow_step_1\":{\"inputs\":[\"input_1\",\"input_2\"],\"outputs\":[\"output_1\"]},\"workflow_step_2\":{\"inputs\":[\"input_1\",\"input_2\",\"input_3\"],\"outputs\":[\"output_1\",\"output_2\",\"output_3\"]}}"; - invalidWorkflowStepJson = - "{\"workflow_step_1\":{\"bad_field\":[\"input_1\",\"input_2\"],\"outputs\":[\"output_1\"]},\"workflow_step_2\":{\"inputs\":[\"input_1\",\"input_2\",\"input_3\"],\"outputs\":[\"output_1\",\"output_2\",\"output_3\"]}}"; flowFrameworkSettings = mock(FlowFrameworkSettings.class); when(flowFrameworkSettings.isFlowFrameworkEnabled()).thenReturn(true); } public void testParseWorkflowValidator() throws IOException { + Map workflowStepValidators = new HashMap<>(); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.CREATE_CONNECTOR.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.DELETE_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.DELETE_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.DEPLOY_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.DEPLOY_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_CUSTOM_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_CUSTOM_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_PRETRAINED_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_PRETRAINED_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_SPARSE_ENCODING_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_LOCAL_SPARSE_ENCODING_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_REMOTE_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_REMOTE_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_MODEL_GROUP.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_MODEL_GROUP.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.REGISTER_AGENT.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.REGISTER_AGENT.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.CREATE_TOOL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.CREATE_TOOL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.UNDEPLOY_MODEL.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.UNDEPLOY_MODEL.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.DELETE_CONNECTOR.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.DELETE_CONNECTOR.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.DELETE_AGENT.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.DELETE_AGENT.getWorkflowStepValidator() + ); + workflowStepValidators.put( + WorkflowStepFactory.WorkflowSteps.NOOP.getWorkflowStepName(), + WorkflowStepFactory.WorkflowSteps.NOOP.getWorkflowStepValidator() + ); - XContentParser parser = TemplateTestJsonUtil.jsonToParser(validWorkflowStepJson); - WorkflowValidator validator = WorkflowValidator.parse(parser); + WorkflowValidator validator = new WorkflowValidator(workflowStepValidators); - assertEquals(2, validator.getWorkflowStepValidators().size()); - assertTrue(validator.getWorkflowStepValidators().keySet().contains("workflow_step_1")); - assertEquals(2, validator.getWorkflowStepValidators().get("workflow_step_1").getInputs().size()); - assertEquals(1, validator.getWorkflowStepValidators().get("workflow_step_1").getOutputs().size()); - assertTrue(validator.getWorkflowStepValidators().keySet().contains("workflow_step_2")); - assertEquals(3, validator.getWorkflowStepValidators().get("workflow_step_2").getInputs().size()); - assertEquals(3, validator.getWorkflowStepValidators().get("workflow_step_2").getOutputs().size()); - } + assertEquals(14, validator.getWorkflowStepValidators().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_connector")); + assertEquals(7, validator.getWorkflowStepValidators().get("create_connector").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("create_connector").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("delete_model")); + assertEquals(1, validator.getWorkflowStepValidators().get("delete_model").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("delete_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("deploy_model")); + assertEquals(1, validator.getWorkflowStepValidators().get("deploy_model").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("deploy_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_remote_model")); + assertEquals(2, validator.getWorkflowStepValidators().get("register_remote_model").getInputs().size()); + assertEquals(2, validator.getWorkflowStepValidators().get("register_remote_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_model_group")); + assertEquals(1, validator.getWorkflowStepValidators().get("register_model_group").getInputs().size()); + assertEquals(2, validator.getWorkflowStepValidators().get("register_model_group").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_local_custom_model")); + assertEquals(9, validator.getWorkflowStepValidators().get("register_local_custom_model").getInputs().size()); + assertEquals(2, validator.getWorkflowStepValidators().get("register_local_custom_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_local_sparse_encoding_model")); + assertEquals(6, validator.getWorkflowStepValidators().get("register_local_sparse_encoding_model").getInputs().size()); + assertEquals(2, validator.getWorkflowStepValidators().get("register_local_sparse_encoding_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_local_pretrained_model")); + assertEquals(3, validator.getWorkflowStepValidators().get("register_local_pretrained_model").getInputs().size()); + assertEquals(2, validator.getWorkflowStepValidators().get("register_local_pretrained_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("undeploy_model")); + assertEquals(1, validator.getWorkflowStepValidators().get("undeploy_model").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("undeploy_model").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("delete_connector")); + assertEquals(1, validator.getWorkflowStepValidators().get("delete_connector").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("delete_connector").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("register_agent")); + assertEquals(2, validator.getWorkflowStepValidators().get("register_agent").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("register_agent").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("delete_agent")); + assertEquals(1, validator.getWorkflowStepValidators().get("delete_agent").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("delete_agent").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("create_tool")); + assertEquals(1, validator.getWorkflowStepValidators().get("create_tool").getInputs().size()); + assertEquals(1, validator.getWorkflowStepValidators().get("create_tool").getOutputs().size()); + + assertTrue(validator.getWorkflowStepValidators().keySet().contains("noop")); + assertEquals(0, validator.getWorkflowStepValidators().get("noop").getInputs().size()); + assertEquals(0, validator.getWorkflowStepValidators().get("noop").getOutputs().size()); - public void testFailedParseWorkflowValidator() throws IOException { - XContentParser parser = TemplateTestJsonUtil.jsonToParser(invalidWorkflowStepJson); - FlowFrameworkException ex = expectThrows(FlowFrameworkException.class, () -> WorkflowValidator.parse(parser)); - assertEquals("Unable to parse field [bad_field] in a WorkflowStepValidator object.", ex.getMessage()); - assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); } public void testWorkflowStepFactoryHasValidators() throws IOException { @@ -106,8 +199,7 @@ public void testWorkflowStepFactoryHasValidators() throws IOException { flowFrameworkSettings ); - // Read in workflow-steps.json - WorkflowValidator workflowValidator = WorkflowValidator.parse("mappings/workflow-steps.json"); + WorkflowValidator workflowValidator = workflowStepFactory.getWorkflowValidator(); // Get all workflow step validator types List registeredWorkflowValidatorTypes = new ArrayList(workflowValidator.getWorkflowStepValidators().keySet()); diff --git a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java index 6e6e39f80..685198e3d 100644 --- a/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java +++ b/src/test/java/org/opensearch/flowframework/transport/GetWorkflowStepTransportActionTests.java @@ -10,6 +10,7 @@ import org.opensearch.action.support.ActionFilters; import org.opensearch.core.action.ActionListener; +import org.opensearch.flowframework.workflow.WorkflowStepFactory; import org.opensearch.tasks.Task; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.transport.TransportService; @@ -30,7 +31,11 @@ public class GetWorkflowStepTransportActionTests extends OpenSearchTestCase { public void setUp() throws Exception { super.setUp(); - this.getWorkflowStepTransportAction = new GetWorkflowStepTransportAction(mock(TransportService.class), mock(ActionFilters.class)); + this.getWorkflowStepTransportAction = new GetWorkflowStepTransportAction( + mock(TransportService.class), + mock(ActionFilters.class), + mock(WorkflowStepFactory.class) + ); } public void testGetWorkflowStepAction() throws IOException { diff --git a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java index d66e4f391..7a2250ca5 100644 --- a/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java +++ b/src/test/java/org/opensearch/flowframework/workflow/WorkflowProcessSorterTests.java @@ -24,7 +24,6 @@ import org.opensearch.flowframework.model.Workflow; import org.opensearch.flowframework.model.WorkflowEdge; import org.opensearch.flowframework.model.WorkflowNode; -import org.opensearch.flowframework.model.WorkflowValidator; import org.opensearch.ml.client.MachineLearningNodeClient; import org.opensearch.test.OpenSearchTestCase; import org.opensearch.threadpool.ScalingExecutorBuilder; @@ -86,8 +85,8 @@ private static List parse(String json) throws IOException { private static WorkflowProcessSorter workflowProcessSorter; private static Client client = mock(Client.class); private static ClusterService clusterService = mock(ClusterService.class); - private static WorkflowValidator validator; private static FlowFrameworkSettings flowFrameworkSettings; + private static WorkflowStepFactory workflowStepFactory; @BeforeClass public static void setup() throws IOException { @@ -128,7 +127,6 @@ public static void setup() throws IOException { flowFrameworkSettings ); workflowProcessSorter = new WorkflowProcessSorter(factory, testThreadPool, clusterService, client, flowFrameworkSettings); - validator = WorkflowValidator.parse("mappings/workflow-steps.json"); } @AfterClass @@ -386,7 +384,7 @@ public void testSuccessfulGraphValidation() throws Exception { ); List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); - workflowProcessSorter.validateGraph(sortedProcessNodes, validator); + workflowProcessSorter.validateGraph(sortedProcessNodes); } public void testFailedGraphValidation() throws IOException { @@ -410,7 +408,7 @@ public void testFailedGraphValidation() throws IOException { List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); FlowFrameworkException ex = expectThrows( FlowFrameworkException.class, - () -> workflowProcessSorter.validateGraph(sortedProcessNodes, validator) + () -> workflowProcessSorter.validateGraph(sortedProcessNodes) ); assertEquals("Invalid workflow, node [workflow_step_1] missing the following required inputs : [connector_id]", ex.getMessage()); assertEquals(RestStatus.BAD_REQUEST, ex.getRestStatus()); @@ -455,11 +453,7 @@ public void testSuccessfulInstalledPluginValidation() throws Exception { ); List sortedProcessNodes = workflowProcessSorter.sortProcessNodes(workflow, "123"); - workflowProcessSorter.validatePluginsInstalled( - sortedProcessNodes, - validator, - List.of("opensearch-flow-framework", "opensearch-ml") - ); + workflowProcessSorter.validatePluginsInstalled(sortedProcessNodes, List.of("opensearch-flow-framework", "opensearch-ml")); } public void testFailedInstalledPluginValidation() throws Exception { @@ -503,7 +497,7 @@ public void testFailedInstalledPluginValidation() throws Exception { FlowFrameworkException exception = expectThrows( FlowFrameworkException.class, - () -> workflowProcessSorter.validatePluginsInstalled(sortedProcessNodes, validator, List.of("opensearch-flow-framework")) + () -> workflowProcessSorter.validatePluginsInstalled(sortedProcessNodes, List.of("opensearch-flow-framework")) ); assertEquals(