Skip to content

Commit

Permalink
Add timeout for each workflow steps (#317)
Browse files Browse the repository at this point in the history
* read timeout from workflow-steps

Signed-off-by: Jackie Han <[email protected]>

* change indices name

Signed-off-by: Jackie Han <[email protected]>

* address comments

Signed-off-by: Jackie Han <[email protected]>

* address comment - parse timeout into TimeValue

Signed-off-by: Jackie Han <[email protected]>

* address minor comments

Signed-off-by: Jackie Han <[email protected]>

---------

Signed-off-by: Jackie Han <[email protected]>
(cherry picked from commit 9fdf03c)
Signed-off-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] committed Dec 27, 2023
1 parent ccb5371 commit d1ab850
Show file tree
Hide file tree
Showing 6 changed files with 120 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ private CommonValue() {}
/** Schema Version field name */
public static final String SCHEMA_VERSION_FIELD = "schema_version";
/** Global Context Index Name */
public static final String GLOBAL_CONTEXT_INDEX = ".plugins-ai-global-context";
public static final String GLOBAL_CONTEXT_INDEX = ".plugins-flow-framework-templates";
/** Global Context index mapping file path */
public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json";
/** Global Context index mapping version */
public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1;
/** Workflow State Index Name */
public static final String WORKFLOW_STATE_INDEX = ".plugins-workflow-state";
public static final String WORKFLOW_STATE_INDEX = ".plugins-flow-framework-state";
/** Workflow State index mapping file path */
public static final String WORKFLOW_STATE_INDEX_MAPPING = "mappings/workflow-state.json";
/** Workflow State index mapping version */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.flowframework.model;

import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
Expand All @@ -23,6 +24,7 @@
import java.util.Map.Entry;
import java.util.Objects;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap;
Expand Down Expand Up @@ -50,7 +52,7 @@ public class WorkflowNode implements ToXContentObject {
/** The field defining the timeout value for this node */
public static final String NODE_TIMEOUT_FIELD = "node_timeout";
/** The default timeout value if the template doesn't override it */
public static final String NODE_TIMEOUT_DEFAULT_VALUE = "15s";
public static final TimeValue NODE_TIMEOUT_DEFAULT_VALUE = new TimeValue(10, SECONDS);

private final String id; // unique id
private final String type; // maps to a WorkflowStep
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
*/
package org.opensearch.flowframework.model;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.exception.FlowFrameworkException;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -21,27 +26,34 @@
*/
public class WorkflowStepValidator {

private static final Logger logger = LogManager.getLogger(WorkflowStepValidator.class);

/** Inputs field name */
private static final String INPUTS_FIELD = "inputs";
/** Outputs field name */
private static final String OUTPUTS_FIELD = "outputs";
/** Required Plugins field name */
private static final String REQUIRED_PLUGINS = "required_plugins";
/** Timeout field name */
private static final String TIMEOUT = "timeout";

private List<String> inputs;
private List<String> outputs;
private List<String> requiredPlugins;
private TimeValue timeout;

/**
* Intantiate the object representing a Workflow Step validator
* Instantiate the object representing a Workflow Step validator
* @param inputs the workflow step inputs
* @param outputs the workflow step outputs
* @param requiredPlugins the required plugins for this workflow step
* @param timeout the timeout for this workflow step
*/
public WorkflowStepValidator(List<String> inputs, List<String> outputs, List<String> requiredPlugins) {
public WorkflowStepValidator(List<String> inputs, List<String> outputs, List<String> requiredPlugins, TimeValue timeout) {
this.inputs = inputs;
this.outputs = outputs;
this.requiredPlugins = requiredPlugins;
this.timeout = timeout;
}

/**
Expand All @@ -54,6 +66,7 @@ public static WorkflowStepValidator parse(XContentParser parser) throws IOExcept
List<String> parsedInputs = new ArrayList<>();
List<String> parsedOutputs = new ArrayList<>();
List<String> requiredPlugins = new ArrayList<>();
TimeValue timeout = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -78,11 +91,22 @@ public static WorkflowStepValidator parse(XContentParser parser) throws IOExcept
requiredPlugins.add(parser.text());
}
break;
case TIMEOUT:
try {
timeout = TimeValue.parseTimeValue(parser.text(), TIMEOUT);
} catch (IllegalArgumentException e) {
logger.error("Failed to parse TIMEOUT value for field [{}]", fieldName, e);
throw new FlowFrameworkException(
"Failed to parse workflow-step.json file for field [" + fieldName + "]",
RestStatus.INTERNAL_SERVER_ERROR
);
}
break;
default:
throw new IOException("Unable to parse field [" + fieldName + "] in a WorkflowStepValidator object.");
}
}
return new WorkflowStepValidator(parsedInputs, parsedOutputs, requiredPlugins);
return new WorkflowStepValidator(parsedInputs, parsedOutputs, requiredPlugins, timeout);
}

/**
Expand All @@ -103,9 +127,17 @@ public List<String> getOutputs() {

/**
* Get the required plugins
* @return the outputs
* @return the required plugins
*/
public List<String> getRequiredPlugins() {
return List.copyOf(requiredPlugins);
}

/**
* Get the timeout
* @return the timeout
*/
public TimeValue getTimeout() {
return timeout;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -131,7 +132,6 @@ public List<ProcessNode> sortProcessNodes(Workflow workflow, String workflowId)
idToNodeMap.put(processNode.id(), processNode);
nodes.add(processNode);
}

return nodes;
}

Expand All @@ -141,7 +141,7 @@ public List<ProcessNode> sortProcessNodes(Workflow workflow, String workflowId)
* @throws Exception if validation fails
*/
public void validate(List<ProcessNode> processNodes) throws Exception {
WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json");
WorkflowValidator validator = readWorkflowValidator();
validatePluginsInstalled(processNodes, validator);
validateGraph(processNodes, validator);
}
Expand Down Expand Up @@ -244,20 +244,43 @@ public void validateGraph(List<ProcessNode> processNodes, WorkflowValidator vali
);
}
}
}

private WorkflowValidator readWorkflowValidator() {
try {
return WorkflowValidator.parse("mappings/workflow-steps.json");
} catch (Exception e) {
logger.error("Failed at reading workflow-steps mapping file", e);
throw new FlowFrameworkException(
"Failed at reading workflow-steps.json mapping file for a new workflow.",
RestStatus.INTERNAL_SERVER_ERROR
);
}
}

private TimeValue parseTimeout(WorkflowNode node) {
String timeoutValue = (String) node.userInputs().getOrDefault(NODE_TIMEOUT_FIELD, NODE_TIMEOUT_DEFAULT_VALUE);
/**
* A method for parsing workflow timeout value.
* The value could be parsed from node NODE_TIMEOUT_FIELD, the timeout field in workflow-step.json,
* or the default NODE_TIMEOUT_DEFAULT_VALUE
* @param node the workflow node
* @return the timeout value
*/
protected TimeValue parseTimeout(WorkflowNode node) {
WorkflowValidator validator = readWorkflowValidator();
TimeValue nodeTimeoutValue = Optional.ofNullable(validator.getWorkflowStepValidators().get(node.type()).getTimeout())
.orElse(NODE_TIMEOUT_DEFAULT_VALUE);
String nodeTimeoutAsString = nodeTimeoutValue.getSeconds() + "s";
String timeoutValue = (String) node.userInputs().getOrDefault(NODE_TIMEOUT_FIELD, nodeTimeoutAsString);
String fieldName = String.join(".", node.id(), USER_INPUTS_FIELD, NODE_TIMEOUT_FIELD);
TimeValue timeValue = TimeValue.parseTimeValue(timeoutValue, fieldName);
if (timeValue.millis() < 0) {
TimeValue userInputTimeValue = TimeValue.parseTimeValue(timeoutValue, fieldName);

if (userInputTimeValue.millis() < 0) {
throw new FlowFrameworkException(
"Failed to parse timeout value [" + timeoutValue + "] for field [" + fieldName + "]. Must be positive",
RestStatus.BAD_REQUEST
);
}
return timeValue;
return userInputTimeValue;
}

private static List<WorkflowNode> topologicalSort(List<WorkflowNode> workflowNodes, List<WorkflowEdge> workflowEdges) {
Expand Down
6 changes: 4 additions & 2 deletions src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@
],
"required_plugins":[
"opensearch-ml"
]
],
"timeout": "60s"
},
"register_remote_model": {
"inputs": [
Expand Down Expand Up @@ -109,7 +110,8 @@
],
"required_plugins":[
"opensearch-ml"
]
],
"timeout": "15s"
},
"undeploy_model": {
"inputs":[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
Expand Down Expand Up @@ -136,7 +137,7 @@ public void testNodeDetails() throws IOException {
ProcessNode node = workflow.get(0);
assertEquals("default_timeout", node.id());
assertEquals(CreateIngestPipelineStep.class, node.workflowStep().getClass());
assertEquals(15, node.nodeTimeout().seconds());
assertEquals(10, node.nodeTimeout().seconds());
node = workflow.get(1);
assertEquals("custom_timeout", node.id());
assertEquals(CreateIndexStep.class, node.workflowStep().getClass());
Expand Down Expand Up @@ -317,7 +318,7 @@ public void testSuccessfulGraphValidation() throws Exception {
workflowProcessSorter.validateGraph(sortedProcessNodes, validator);
}

public void testFailedGraphValidation() {
public void testFailedGraphValidation() throws IOException {

// Create Register Model workflow node with missing connector_id field
WorkflowNode registerModel = new WorkflowNode(
Expand Down Expand Up @@ -509,4 +510,46 @@ public void testFailedInstalledPluginValidation() throws Exception {
exception.getMessage()
);
}

public void testReadWorkflowStepFile_withDefaultTimeout() throws IOException {
// read timeout from node NODE_TIMEOUT_FIELD
WorkflowNode createConnector = new WorkflowNode(
"workflow_step_1",
CreateConnectorStep.NAME,
Map.of(),
Map.ofEntries(
Map.entry("name", ""),
Map.entry("description", ""),
Map.entry("version", ""),
Map.entry("protocol", ""),
Map.entry("parameters", ""),
Map.entry("credential", ""),
Map.entry("actions", ""),
Map.entry("node_timeout", "50s")
)
);
TimeValue createConnectorTimeout = workflowProcessSorter.parseTimeout(createConnector);
assertEquals(createConnectorTimeout.getSeconds(), 50);

// read timeout from workflow-step.json overwrite value
WorkflowNode deployModel = new WorkflowNode(
"workflow_step_3",
DeployModelStep.NAME,
Map.ofEntries(Map.entry("workflow_step_2", "model_id")),
Map.of()
);
TimeValue deployModelTimeout = workflowProcessSorter.parseTimeout(deployModel);
assertEquals(deployModelTimeout.getSeconds(), 15);

// read timeout from NODE_TIMEOUT_DEFAULT_VALUE when there's no node NODE_TIMEOUT_FIELD
// and no overwrite timeout value in workflow-step.json
WorkflowNode registerModel = new WorkflowNode(
"workflow_step_2",
RegisterRemoteModelStep.NAME,
Map.ofEntries(Map.entry("workflow_step_1", "connector_id")),
Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description"))
);
TimeValue registerRemoteModelTimeout = workflowProcessSorter.parseTimeout(registerModel);
assertEquals(registerRemoteModelTimeout.getSeconds(), 10);
}
}

0 comments on commit d1ab850

Please sign in to comment.