Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timeout for each workflow steps #317

Merged
merged 5 commits into from
Dec 27, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,13 @@ private CommonValue() {}
/** Schema Version field name */
public static final String SCHEMA_VERSION_FIELD = "schema_version";
/** Global Context Index Name */
public static final String GLOBAL_CONTEXT_INDEX = ".plugins-ai-global-context";
public static final String GLOBAL_CONTEXT_INDEX = ".plugins-flow-framework-templates";
/** Global Context index mapping file path */
public static final String GLOBAL_CONTEXT_INDEX_MAPPING = "mappings/global-context.json";
/** Global Context index mapping version */
public static final Integer GLOBAL_CONTEXT_INDEX_VERSION = 1;
/** Workflow State Index Name */
public static final String WORKFLOW_STATE_INDEX = ".plugins-workflow-state";
public static final String WORKFLOW_STATE_INDEX = ".plugins-flow-framework-state";
/** Workflow State index mapping file path */
public static final String WORKFLOW_STATE_INDEX_MAPPING = "mappings/workflow-state.json";
/** Workflow State index mapping version */
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
*/
package org.opensearch.flowframework.model;

import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.xcontent.ToXContentObject;
import org.opensearch.core.xcontent.XContentBuilder;
import org.opensearch.core.xcontent.XContentParser;
Expand All @@ -23,6 +24,7 @@
import java.util.Map.Entry;
import java.util.Objects;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken;
import static org.opensearch.flowframework.common.CommonValue.TOOLS_ORDER_FIELD;
import static org.opensearch.flowframework.util.ParseUtils.buildStringToObjectMap;
Expand Down Expand Up @@ -50,7 +52,7 @@ public class WorkflowNode implements ToXContentObject {
/** The field defining the timeout value for this node */
public static final String NODE_TIMEOUT_FIELD = "node_timeout";
/** The default timeout value if the template doesn't override it */
public static final String NODE_TIMEOUT_DEFAULT_VALUE = "15s";
public static final TimeValue NODE_TIMEOUT_DEFAULT_VALUE = new TimeValue(10, SECONDS);
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved

private final String id; // unique id
private final String type; // maps to a WorkflowStep
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,12 @@
*/
package org.opensearch.flowframework.model;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
import org.opensearch.flowframework.exception.FlowFrameworkException;

import java.io.IOException;
import java.util.ArrayList;
Expand All @@ -21,27 +26,34 @@
*/
public class WorkflowStepValidator {

private static final Logger logger = LogManager.getLogger(WorkflowStepValidator.class);

/** Inputs field name */
private static final String INPUTS_FIELD = "inputs";
/** Outputs field name */
private static final String OUTPUTS_FIELD = "outputs";
/** Required Plugins field name */
private static final String REQUIRED_PLUGINS = "required_plugins";
/** Timeout field name */
private static final String TIMEOUT = "timeout";

private List<String> inputs;
private List<String> outputs;
private List<String> requiredPlugins;
private TimeValue timeout;

/**
* Intantiate the object representing a Workflow Step validator
* Instantiate the object representing a Workflow Step validator
* @param inputs the workflow step inputs
* @param outputs the workflow step outputs
* @param requiredPlugins the required plugins for this workflow step
* @param timeout the timeout for this workflow step
*/
public WorkflowStepValidator(List<String> inputs, List<String> outputs, List<String> requiredPlugins) {
public WorkflowStepValidator(List<String> inputs, List<String> outputs, List<String> requiredPlugins, TimeValue timeout) {
this.inputs = inputs;
this.outputs = outputs;
this.requiredPlugins = requiredPlugins;
this.timeout = timeout;
}

/**
Expand All @@ -54,6 +66,7 @@ public static WorkflowStepValidator parse(XContentParser parser) throws IOExcept
List<String> parsedInputs = new ArrayList<>();
List<String> parsedOutputs = new ArrayList<>();
List<String> requiredPlugins = new ArrayList<>();
TimeValue timeout = null;

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
Expand All @@ -78,11 +91,22 @@ public static WorkflowStepValidator parse(XContentParser parser) throws IOExcept
requiredPlugins.add(parser.text());
}
break;
case TIMEOUT:
try {
timeout = TimeValue.parseTimeValue(parser.text(), TIMEOUT);
} catch (IllegalArgumentException e) {
logger.error("Failed to parse TIMEOUT value for field [" + fieldName + "]", e);
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
throw new FlowFrameworkException(
"Failed to parse workflow-step.json file for field [" + fieldName + "]",
RestStatus.INTERNAL_SERVER_ERROR
);
}
break;
default:
throw new IOException("Unable to parse field [" + fieldName + "] in a WorkflowStepValidator object.");
}
}
return new WorkflowStepValidator(parsedInputs, parsedOutputs, requiredPlugins);
return new WorkflowStepValidator(parsedInputs, parsedOutputs, requiredPlugins, timeout);
}

/**
Expand All @@ -103,9 +127,17 @@ public List<String> getOutputs() {

/**
* Get the required plugins
* @return the outputs
* @return the required plugins
*/
public List<String> getRequiredPlugins() {
return List.copyOf(requiredPlugins);
}

/**
* Get the timeout
* @return the timeout
*/
public TimeValue getTimeout() {
return timeout;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.CompletableFuture;
Expand Down Expand Up @@ -131,7 +132,6 @@ public List<ProcessNode> sortProcessNodes(Workflow workflow, String workflowId)
idToNodeMap.put(processNode.id(), processNode);
nodes.add(processNode);
}

return nodes;
}

Expand All @@ -141,7 +141,7 @@ public List<ProcessNode> sortProcessNodes(Workflow workflow, String workflowId)
* @throws Exception if validation fails
*/
public void validate(List<ProcessNode> processNodes) throws Exception {
WorkflowValidator validator = WorkflowValidator.parse("mappings/workflow-steps.json");
WorkflowValidator validator = readWorkflowValidator(processNodes.get(0).id());
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
validatePluginsInstalled(processNodes, validator);
validateGraph(processNodes, validator);
}
Expand Down Expand Up @@ -244,20 +244,43 @@ public void validateGraph(List<ProcessNode> processNodes, WorkflowValidator vali
);
}
}
}

private WorkflowValidator readWorkflowValidator(String workflowId) {
try {
return WorkflowValidator.parse("mappings/workflow-steps.json");
} catch (Exception e) {
logger.error("Failed to read workflow-steps mapping file", e);
throw new FlowFrameworkException(
"Workflow " + workflowId + " failed at reading workflow-steps mapping file",
RestStatus.INTERNAL_SERVER_ERROR
);
}
}

private TimeValue parseTimeout(WorkflowNode node) {
String timeoutValue = (String) node.userInputs().getOrDefault(NODE_TIMEOUT_FIELD, NODE_TIMEOUT_DEFAULT_VALUE);
/**
* A method for parsing workflow timeout value.
* The value could be parsed from node NODE_TIMEOUT_FIELD, the timeout field in workflow-step.json,
* or the default NODE_TIMEOUT_DEFAULT_VALUE
* @param node the workflow nde
dbwiddis marked this conversation as resolved.
Show resolved Hide resolved
* @return the timeout value
*/
protected TimeValue parseTimeout(WorkflowNode node) {
WorkflowValidator validator = readWorkflowValidator(node.id());
TimeValue nodeTimeoutValue = Optional.ofNullable(validator.getWorkflowStepValidators().get(node.type()).getTimeout())
.orElse(NODE_TIMEOUT_DEFAULT_VALUE);
String nodeTimeoutAsString = nodeTimeoutValue.getSeconds() + "s";
String timeoutValue = (String) node.userInputs().getOrDefault(NODE_TIMEOUT_FIELD, nodeTimeoutAsString);
String fieldName = String.join(".", node.id(), USER_INPUTS_FIELD, NODE_TIMEOUT_FIELD);
TimeValue timeValue = TimeValue.parseTimeValue(timeoutValue, fieldName);
if (timeValue.millis() < 0) {
TimeValue userInputTimeValue = TimeValue.parseTimeValue(timeoutValue, fieldName);

if (userInputTimeValue.millis() < 0) {
throw new FlowFrameworkException(
"Failed to parse timeout value [" + timeoutValue + "] for field [" + fieldName + "]. Must be positive",
RestStatus.BAD_REQUEST
);
}
return timeValue;
return userInputTimeValue;
}

private static List<WorkflowNode> topologicalSort(List<WorkflowNode> workflowNodes, List<WorkflowEdge> workflowEdges) {
Expand Down
6 changes: 4 additions & 2 deletions src/main/resources/mappings/workflow-steps.json
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@
],
"required_plugins":[
"opensearch-ml"
]
],
"timeout": "60s"
},
"register_remote_model": {
"inputs": [
Expand Down Expand Up @@ -109,7 +110,8 @@
],
"required_plugins":[
"opensearch-ml"
]
],
"timeout": "15s"
},
"undeploy_model": {
"inputs":[
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.opensearch.common.settings.ClusterSettings;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.core.xcontent.XContentParser;
Expand Down Expand Up @@ -136,7 +137,7 @@ public void testNodeDetails() throws IOException {
ProcessNode node = workflow.get(0);
assertEquals("default_timeout", node.id());
assertEquals(CreateIngestPipelineStep.class, node.workflowStep().getClass());
assertEquals(15, node.nodeTimeout().seconds());
assertEquals(10, node.nodeTimeout().seconds());
node = workflow.get(1);
assertEquals("custom_timeout", node.id());
assertEquals(CreateIndexStep.class, node.workflowStep().getClass());
Expand Down Expand Up @@ -317,7 +318,7 @@ public void testSuccessfulGraphValidation() throws Exception {
workflowProcessSorter.validateGraph(sortedProcessNodes, validator);
}

public void testFailedGraphValidation() {
public void testFailedGraphValidation() throws IOException {

// Create Register Model workflow node with missing connector_id field
WorkflowNode registerModel = new WorkflowNode(
Expand Down Expand Up @@ -509,4 +510,46 @@ public void testFailedInstalledPluginValidation() throws Exception {
exception.getMessage()
);
}

public void testReadWorkflowStepFile_withDefaultTimeout() throws IOException {
// read timeout from node NODE_TIMEOUT_FIELD
WorkflowNode createConnector = new WorkflowNode(
"workflow_step_1",
CreateConnectorStep.NAME,
Map.of(),
Map.ofEntries(
Map.entry("name", ""),
Map.entry("description", ""),
Map.entry("version", ""),
Map.entry("protocol", ""),
Map.entry("parameters", ""),
Map.entry("credential", ""),
Map.entry("actions", ""),
Map.entry("node_timeout", "50s")
)
);
TimeValue createConnectorTimeout = workflowProcessSorter.parseTimeout(createConnector);
assertEquals(createConnectorTimeout.getSeconds(), 50);

// read timeout from workflow-step.json overwrite value
WorkflowNode deployModel = new WorkflowNode(
"workflow_step_3",
DeployModelStep.NAME,
Map.ofEntries(Map.entry("workflow_step_2", "model_id")),
Map.of()
);
TimeValue deployModelTimeout = workflowProcessSorter.parseTimeout(deployModel);
assertEquals(deployModelTimeout.getSeconds(), 15);

// read timeout from NODE_TIMEOUT_DEFAULT_VALUE when there's no node NODE_TIMEOUT_FIELD
// and no overwrite timeout value in workflow-step.json
WorkflowNode registerModel = new WorkflowNode(
"workflow_step_2",
RegisterRemoteModelStep.NAME,
Map.ofEntries(Map.entry("workflow_step_1", "connector_id")),
Map.ofEntries(Map.entry("name", "name"), Map.entry("function_name", "remote"), Map.entry("description", "description"))
);
TimeValue registerRemoteModelTimeout = workflowProcessSorter.parseTimeout(registerModel);
assertEquals(registerRemoteModelTimeout.getSeconds(), 10);
}
}