diff --git a/go.mod b/go.mod index 2779149..f5577f2 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( go.arcalot.io/log/v2 v2.2.0 go.flow.arcalot.io/deployer v0.6.1 go.flow.arcalot.io/dockerdeployer v0.7.3 - go.flow.arcalot.io/expressions v0.4.3 + go.flow.arcalot.io/expressions v0.4.4 go.flow.arcalot.io/kubernetesdeployer v0.9.3 go.flow.arcalot.io/pluginsdk v0.13.0 go.flow.arcalot.io/podmandeployer v0.11.3 diff --git a/go.sum b/go.sum index 7d5f8c3..6d17e09 100644 --- a/go.sum +++ b/go.sum @@ -137,6 +137,8 @@ go.flow.arcalot.io/dockerdeployer v0.7.3 h1:CLvSdqfoE8oZADI0wfry46SXR4CQjB6Qh+6Y go.flow.arcalot.io/dockerdeployer v0.7.3/go.mod h1:YWw9+GbYJxEnlahlYCx4UOJe+QNkecf8+EBtSIQD0aE= go.flow.arcalot.io/expressions v0.4.3 h1:0BRRghutHp0sctsITHe/A1le0yYiJtKNTxm27T+P6Og= go.flow.arcalot.io/expressions v0.4.3/go.mod h1:UORX78N4ep71wOzNXdIo/UY+6SdDD0id0mvuRNEQMeM= +go.flow.arcalot.io/expressions v0.4.4 h1:bYTC7YDmgDWcsdyY41+IvTJbvsM1rdE3ZBJhB+jNPHQ= +go.flow.arcalot.io/expressions v0.4.4/go.mod h1:0Y2LgynO1SWA4bqsnKlCxqLME9zOR8tWKg3g+RG+FFQ= go.flow.arcalot.io/kubernetesdeployer v0.9.3 h1:XKiqmCqXb6ZLwP5IQTAKS/gJHpq0Ub/yEjCfgAwQF2A= go.flow.arcalot.io/kubernetesdeployer v0.9.3/go.mod h1:DtB6HR7HBt/HA1vME0faIpOQ/lhfBJjL6OAGgT3Bu/Q= go.flow.arcalot.io/pluginsdk v0.13.0 h1:bZqohrDkyAHsWmFJbyvPkjqUALPNJqObefVQrmYqUTw= diff --git a/workflow/workflow.go b/workflow/workflow.go index 06a4a79..e1eaf1e 100644 --- a/workflow/workflow.go +++ b/workflow/workflow.go @@ -687,45 +687,7 @@ func (l *loopState) resolveExpressions(inputData any, dataModel any) (any, error l.logger.Debugf("Evaluating expression %s...", expr.String()) return expr.Evaluate(dataModel, l.callableFunctions, l.workflowContext) case *infer.OneOfExpression: - l.logger.Debugf("Evaluating oneof expression %s...", expr.String()) - - // Get the node the OneOf uses to check which Or dependency resolved first (the others will either not be - // in the resolved list, or they will be obviated) - oneOfNode, err := l.dag.GetNodeByID(expr.NodePath) - if err != nil { - return nil, fmt.Errorf("failed to get node to resolve oneof expression (%w)", err) - } - dependencies := oneOfNode.ResolvedDependencies() - firstResolvedDependency := "" - for dependency, dependencyType := range dependencies { - if dependencyType == dgraph.OrDependency { - firstResolvedDependency = dependency - break - } else if dependencyType == dgraph.ObviatedDependency { - l.logger.Infof("Multiple OR cases triggered; skipping %q", dependency) - } - } - if firstResolvedDependency == "" { - return nil, fmt.Errorf("could not find resolved dependency for oneof expression %q", expr.String()) - } - optionID := strings.Replace(firstResolvedDependency, expr.NodePath+".", "", 1) - optionExpr, found := expr.Options[optionID] - if !found { - return nil, fmt.Errorf("could not find oneof option %q for oneof %q", optionID, expr) - } - // Still pass the current node in due to the possibility of a foreach within a foreach. - subTypeResolution, err := l.resolveExpressions(optionExpr, dataModel) - if err != nil { - return nil, err - } - // Validate that it returned a map type (this is required because oneof subtypes need to be objects) - subTypeObjectMap, ok := subTypeResolution.(map[any]any) - if !ok { - return nil, fmt.Errorf("sub-type for oneof is not an object; got %T", subTypeResolution) - } - // Now add the discriminator - subTypeObjectMap[expr.Discriminator] = optionID - return subTypeObjectMap, nil + return l.resolveOneOfExpression(expr, dataModel) } v := reflect.ValueOf(inputData) @@ -758,6 +720,63 @@ func (l *loopState) resolveExpressions(inputData any, dataModel any) (any, error } } +func (l *loopState) resolveOneOfExpression(expr *infer.OneOfExpression, dataModel any) (any, error) { + l.logger.Debugf("Evaluating oneof expression %s...", expr.String()) + + // Get the node the OneOf uses to check which Or dependency resolved first (the others will either not be + // in the resolved list, or they will be obviated) + oneOfNode, err := l.dag.GetNodeByID(expr.NodePath) + if err != nil { + return nil, fmt.Errorf("failed to get node to resolve oneof expression (%w)", err) + } + dependencies := oneOfNode.ResolvedDependencies() + firstResolvedDependency := "" + for dependency, dependencyType := range dependencies { + if dependencyType == dgraph.OrDependency { + firstResolvedDependency = dependency + break + } else if dependencyType == dgraph.ObviatedDependency { + l.logger.Infof("Multiple OR cases triggered; skipping %q", dependency) + } + } + if firstResolvedDependency == "" { + return nil, fmt.Errorf("could not find resolved dependency for oneof expression %q", expr.String()) + } + optionID := strings.Replace(firstResolvedDependency, expr.NodePath+".", "", 1) + optionExpr, found := expr.Options[optionID] + if !found { + return nil, fmt.Errorf("could not find oneof option %q for oneof %q", optionID, expr) + } + // Still pass the current node in due to the possibility of a foreach within a foreach. + subTypeResolution, err := l.resolveExpressions(optionExpr, dataModel) + if err != nil { + return nil, err + } + + // Validate that it returned a map type (this is required because oneof subtypes need to be objects) + // With a special case for values from the providers, which are map[string]any instead of map[any]any + // The output must be copied since it could be referenced several times. + var outputData map[any]any + switch subTypeObjectMap := subTypeResolution.(type) { + case map[string]any: + outputData = make(map[any]any, len(subTypeObjectMap)) + for k, v := range subTypeObjectMap { + outputData[k] = v + } + case map[any]any: + outputData = make(map[any]any, len(subTypeObjectMap)) + for k, v := range subTypeObjectMap { + outputData[k] = v + } + default: + return nil, fmt.Errorf("sub-type for oneof is not the serialized version of an object (a map); got %T", subTypeResolution) + } + // Now add the discriminator + outputData[expr.Discriminator] = optionID + + return outputData, nil +} + // stageChangeHandler is implementing step.StageChangeHandler. type stageChangeHandler struct { onStageChange func( diff --git a/workflow/workflow_test.go b/workflow/workflow_test.go index dcb9af9..780a10d 100644 --- a/workflow/workflow_test.go +++ b/workflow/workflow_test.go @@ -1327,6 +1327,75 @@ func TestGracefullyDisabledStepWorkflow(t *testing.T) { assert.Equals(t, outputDataMap["result"], "disabled_wait_output") } +var shorthandGracefullyDisabledStepWorkflow = ` +version: v0.2.0 +input: + root: WorkflowInput + objects: + WorkflowInput: + id: WorkflowInput + properties: + step_enabled: + type: + type_id: bool +steps: + simple_wait: + plugin: + src: "n/a" + deployment_type: "builtin" + step: wait + input: + wait_time_ms: 0 + enabled: !expr $.input.step_enabled +outputs: + both: + all_output_output: !ordisabled $.steps.simple_wait.outputs + success_output: !ordisabled $.steps.simple_wait.outputs.success +` + +func TestShorthandGracefullyDisabledStepWorkflow(t *testing.T) { + // Run a workflow where the output uses the !ordisabledexpr tag to create a `oneof` expression + // to allow the step to be disabled while still resolving the output. + // Since it's referencing the simple_wait output twice with oneof, but in different ways, + // this is also testing that the oneof doesn't incorrectly mutate the original data source. + preparedWorkflow := assert.NoErrorR[workflow.ExecutableWorkflow](t)( + getTestImplPreparedWorkflow(t, shorthandGracefullyDisabledStepWorkflow), + ) + outputID, outputData, err := preparedWorkflow.Execute(context.Background(), map[string]any{ + "step_enabled": true, + }) + assert.NoError(t, err) + assert.Equals(t, outputID, "both") + assert.Equals(t, outputData.(map[any]any), map[any]any{ + "all_output_output": map[any]any{ + "result": "enabled", + "success": map[any]any{ + "message": "Plugin slept for 0 ms.", + }, + }, + "success_output": map[any]any{ + "result": "enabled", + "message": "Plugin slept for 0 ms.", + }, + }) + // Test step disabled case + outputID, outputData, err = preparedWorkflow.Execute(context.Background(), map[string]any{ + "step_enabled": false, + }) + assert.NoError(t, err) + assert.Equals(t, outputID, "both") + assert.Equals(t, outputData.(map[any]any), map[any]any{ + "all_output_output": map[any]any{ + "result": "disabled", + "message": "Step simple_wait/wait disabled", + }, + "success_output": map[any]any{ + "result": "disabled", + "message": "Step simple_wait/wait disabled", + }, + }) +} + var oneofWithOneOptionWorkflow = ` version: v0.2.0 input: diff --git a/workflow/yaml.go b/workflow/yaml.go index 735e08b..e79a99e 100644 --- a/workflow/yaml.go +++ b/workflow/yaml.go @@ -3,6 +3,7 @@ package workflow import ( "fmt" "go.flow.arcalot.io/engine/internal/infer" + "regexp" "strings" "go.flow.arcalot.io/engine/internal/step" @@ -50,6 +51,9 @@ func (y yamlConverter) FromYAML(data []byte) (*Workflow, error) { return workflow, nil } +// YamlExprTag is the key to specify that the following code should be interpreted as an expression. +const YamlExprTag = "!expr" + // YamlOneOfKey is the key to specify the oneof options within a !oneof section. const YamlOneOfKey = "one_of" @@ -59,6 +63,21 @@ const YamlDiscriminatorKey = "discriminator" // YamlOneOfTag is the yaml tag that allows the section to be interpreted as a OneOf. const YamlOneOfTag = "!oneof" +// OrDisabledTag is the key to specify that the following code should be interpreted as a `oneof` type with +// two possible outputs: the expr specified or the disabled output. +const OrDisabledTag = "!ordisabled" + +func buildExpression(data yaml.Node, path []string, tag string) (expressions.Expression, error) { + if data.Type() != yaml.TypeIDString { + return nil, fmt.Errorf("%s found on non-string node at %s", tag, strings.Join(path, " -> ")) + } + expr, err := expressions.New(data.Value()) + if err != nil { + return nil, fmt.Errorf("failed to compile expression at %s (%w)", strings.Join(path, " -> "), err) + } + return expr, nil +} + func buildOneOfExpressions(data yaml.Node, path []string) (any, error) { if data.Type() != yaml.TypeIDMap { return nil, fmt.Errorf( @@ -103,19 +122,47 @@ func buildOneOfExpressions(data yaml.Node, path []string) (any, error) { }, nil } +var stepPathRegex = regexp.MustCompile(`((?:\$.)?steps\.[^.]+)(\..+)`) + +// Builds a oneof for the given path, or the step disabled output. +// Requires this to be a valid step output. But it is flexible to support all outputs, +// a specific output, or a field within a specific output. +func buildResultOrDisabledExpression(data yaml.Node, path []string) (any, error) { + successExpr, err := buildExpression(data, path, OrDisabledTag) + if err != nil { + return nil, err + } + // Parse the step + capturedParts := stepPathRegex.FindStringSubmatch(data.Value()) + if len(capturedParts) != 3 { + return nil, fmt.Errorf("unable to parse expression in %s at %s; got %s; must be in format $.steps.step_name.outputs.output", + OrDisabledTag, strings.Join(path, " -> "), data.Value()) + } + // Index 0 is the entire capture, index 1 is the step path, and index 2 is the present case + stepPath := capturedParts[1] + disabledPath := stepPath + ".disabled.output" + disabledExpr, err := expressions.New(disabledPath) + if err != nil { + return nil, fmt.Errorf("failed to compile auto-generated disable case for %s expression at %s; is %q a valid path? (%w)", OrDisabledTag, strings.Join(path, " -> "), disabledPath, err) + } + // Now create a `oneof` expression that handles this situation. + return &infer.OneOfExpression{ + Discriminator: "result", + Options: map[string]any{ + "enabled": successExpr, + "disabled": disabledExpr, + }, + }, nil +} + func yamlBuildExpressions(data yaml.Node, path []string) (any, error) { switch data.Tag() { - case "!expr": - if data.Type() != yaml.TypeIDString { - return nil, fmt.Errorf("!expr found on non-string node at %s", strings.Join(path, " -> ")) - } - expr, err := expressions.New(data.Value()) - if err != nil { - return nil, fmt.Errorf("failed to compile expression at %s (%w)", strings.Join(path, " -> "), err) - } - return expr, nil + case YamlExprTag: + return buildExpression(data, path, YamlExprTag) case YamlOneOfTag: return buildOneOfExpressions(data, path) + case OrDisabledTag: + return buildResultOrDisabledExpression(data, path) } switch data.Type() { case yaml.TypeIDString: diff --git a/workflow/yaml_test.go b/workflow/yaml_test.go index 8389f2a..812bc1b 100644 --- a/workflow/yaml_test.go +++ b/workflow/yaml_test.go @@ -101,5 +101,75 @@ func TestBuildOneOfExpression_InputValidation(t *testing.T) { _, err = oneofResult.(*infer.OneOfExpression).Type(nil, nil, nil) assert.Error(t, err) assert.Contains(t, err.Error(), "not an object") +} + +func TestBuildResultOrDisabledExpression_Simple(t *testing.T) { + // Test without root $ + yamlInput := []byte(`!ordisabled steps.test.outputs`) + input := assert.NoErrorR[yaml.Node](t)(yaml.New().Parse(yamlInput)) + result, err := buildResultOrDisabledExpression(input, make([]string, 0)) + assert.NoError(t, err) + assert.InstanceOf[*infer.OneOfExpression](t, result) + oneOfResult := result.(*infer.OneOfExpression) + assert.Equals(t, oneOfResult.Discriminator, "result") + assert.Equals(t, oneOfResult.Options, map[string]any{ + "enabled": lang.Must2(expressions.New("steps.test.outputs")), + "disabled": lang.Must2(expressions.New("steps.test.disabled.output")), + }) + + // Test with all outputs + yamlInput = []byte(`!ordisabled $.steps.test.outputs`) + input = assert.NoErrorR[yaml.Node](t)(yaml.New().Parse(yamlInput)) + result, err = buildResultOrDisabledExpression(input, make([]string, 0)) + assert.NoError(t, err) + assert.InstanceOf[*infer.OneOfExpression](t, result) + oneOfResult = result.(*infer.OneOfExpression) + assert.Equals(t, oneOfResult.Discriminator, "result") + assert.Equals(t, oneOfResult.Options, map[string]any{ + "enabled": lang.Must2(expressions.New("$.steps.test.outputs")), + "disabled": lang.Must2(expressions.New("$.steps.test.disabled.output")), + }) + // Test with a specific output + yamlInput = []byte(`!ordisabled $.steps.test.outputs.success`) + input = assert.NoErrorR[yaml.Node](t)(yaml.New().Parse(yamlInput)) + result, err = buildResultOrDisabledExpression(input, make([]string, 0)) + assert.NoError(t, err) + assert.InstanceOf[*infer.OneOfExpression](t, result) + oneOfResult = result.(*infer.OneOfExpression) + assert.Equals(t, oneOfResult.Discriminator, "result") + assert.Equals(t, oneOfResult.Options, map[string]any{ + "enabled": lang.Must2(expressions.New("$.steps.test.outputs.success")), + "disabled": lang.Must2(expressions.New("$.steps.test.disabled.output")), + }) +} + +func TestBuildResultOrDisabledExpression_InvalidPattern(t *testing.T) { + // Missing the output + yamlInput := []byte(`!ordisabled $.steps.test`) + input := assert.NoErrorR[yaml.Node](t)(yaml.New().Parse(yamlInput)) + _, err := buildResultOrDisabledExpression(input, make([]string, 0)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unable to parse expression") + // Trailing period. This could either trigger an unable to parse expression error + // or a token not found error depending on the order of the function under test. + yamlInput = []byte(`!ordisabled $.steps.test.`) + input = assert.NoErrorR[yaml.Node](t)(yaml.New().Parse(yamlInput)) + _, err = buildResultOrDisabledExpression(input, make([]string, 0)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "token not found") + // Misspelled steps + yamlInput = []byte(`!ordisabled $.stepswrong.test`) + input = assert.NoErrorR[yaml.Node](t)(yaml.New().Parse(yamlInput)) + _, err = buildResultOrDisabledExpression(input, make([]string, 0)) + assert.Error(t, err) + assert.Contains(t, err.Error(), "unable to parse expression") +} + +func TestBuildExpression_WrongType(t *testing.T) { + yamlInput := []byte(`!expr {}`) // A map + input := assert.NoErrorR[yaml.Node](t)(yaml.New().Parse(yamlInput)) + _, err := buildExpression(input, make([]string, 0), YamlExprTag) + assert.Error(t, err) + assert.Contains(t, err.Error(), "found on non-string node") }