Skip to content

Commit

Permalink
feat: implement checks for parameters and variables of DSL 2.0 templa…
Browse files Browse the repository at this point in the history
…tes during lightweight validation (#321)

* feat: implement checks for parameters and variables of DSL 2.0 templates during lightweight validation

Contributes to https://github.ibm.com/st4sd/st4sd-runtime-core/issues/319
Signed-off-by: Vassilis Vassiladis <[email protected]>
  • Loading branch information
VassilisVassiliadis authored and GitHub Enterprise committed Nov 30, 2023
1 parent 3249d96 commit fa992f7
Show file tree
Hide file tree
Showing 2 changed files with 225 additions and 5 deletions.
170 changes: 170 additions & 0 deletions python/experiment/model/frontends/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@
# - a list of namespace identifiers
"""
import copy
import dataclasses
import typing

import experiment.model.codes
Expand Down Expand Up @@ -1468,6 +1469,12 @@ class Action(str, enum.Enum):
else:
known_template_names[temp.signature.name] = kind

if not namespace.entrypoint:
if dsl_error.underlying_errors:
raise dsl_error
else:
return

try:
initial_template = namespace.get_template(namespace.entrypoint.entryInstance)
except Exception as e:
Expand Down Expand Up @@ -1829,6 +1836,16 @@ def __init__(
# VV: this could be a component with a bunch of errors so we cannot trust its contents
self.flowir["command"] = {}

if is_replica:
for idx in range(len(scope.template.signature.parameters)):
param = scope.template.signature.parameters[idx]
if param.name == "replica":
errors.append(experiment.model.errors.DSLInvalidFieldError(
location=template_dsl_location + ["signature", "parameters", idx, "name"],
underlying_error = ValueError('Replicating components cannot define a '
'parameter called "replica"')
))

@property
def step_name(self) -> str:
# VV: The step name is the last entry in the location of the Template instance
Expand Down Expand Up @@ -2164,6 +2181,14 @@ def namespace_to_flowir(
experiment.model.errors.DSLInvalidError:
When there are errors
"""
if not namespace.entrypoint:
raise experiment.model.errors.DSLInvalidError.from_errors(
[
experiment.model.errors.DSLInvalidFieldError(
location=["entrypoint"], underlying_error=ValueError("Missing entrypoint")
)
]
)

# VV: Make a copy in the namespace and work on that, because in `auto_generate_entrypoint()` we'll patch the
# entrypoint so that special parameters of the entry-instance Template (like input.<name> and data.<name>)
Expand Down Expand Up @@ -2434,3 +2459,148 @@ def discover_parameter_conflicts(

if errors:
raise experiment.model.errors.DSLInvalidError.from_errors(errors)


def lightweight_validate(
    namespace: Namespace,
    override_entrypoint_args: typing.Optional[typing.Dict[str, ParameterValueType]] = None,
):
    """Performs lightweight validation of the contents of a Namespace.

    This method can deal with Namespaces which do not have an entrypoint or have an entrypoint with
    incomplete information.

    If the method does not raise an exception this does not mean that the Namespace is completely correct.
    The only way to fully validate a namespace is to produce FlowIR from its definition.

    The checks performed here are:

    - default values of Component and Workflow parameters must not reference parameters/variables
    - a Component must not define a variable with the same name as one of its parameters
    - string fields of a Component may only reference its parameters, its variables, and `replica`
    - whatever `ScopeStack.discover_all_instances_of_templates()` reports

    Args:
        namespace:
            The namespace definition
        override_entrypoint_args:
            Overrides the arguments of the entrypoint

    Raises:
        experiment.model.errors.DSLInvalidError: If the namespace contains errors
    """
    re_var = re.compile(ParameterPattern)

    def find_references(text: str) -> typing.List[str]:
        """Returns the sorted, de-duplicated names that @text references.

        The first and last 2 characters of each match are dropped (e.g. `%(name)s` -> `name`).
        Sorting keeps the order of the reported errors deterministic across runs.
        """
        return sorted({match.group()[2:-2] for match in re_var.finditer(text)})

    def no_param_refs_in_parameter_defaults(
        location: typing.List[typing.Union[str, int]],
        template: typing.Union[Component, Workflow],
    ) -> typing.List[experiment.model.errors.DSLInvalidFieldError]:
        """Reports every reference found in the default value of a parameter of @template.

        Args:
            location:
                The location of @template in the namespace (e.g. ["workflows", 0])
            template:
                The template whose signature to inspect

        Returns:
            One error per (parameter, referenced name) pair
        """
        errors = []

        for idx in range(len(template.signature.parameters)):
            # Read from `template`, NOT from a variable of the enclosing scope: the helper is
            # used for both Components and Workflows.
            param = template.signature.parameters[idx]

            if not param.default:
                continue

            # NOTE(review): a regular expression can only scan strings; this assumes that
            # defaults of other types cannot hold references - confirm
            if not isinstance(param.default, str):
                continue

            for ref in find_references(param.default):
                errors.append(
                    experiment.model.errors.DSLInvalidFieldError(
                        location=location + ["signature", "parameters", idx, "default"],
                        underlying_error=ValueError(
                            f'Reference to parameter/variable "{ref}"'
                        )
                    )
                )

        return errors

    def detect_unknown_variables(
        comp: Component,
        index: int
    ) -> typing.List[experiment.model.errors.DSLInvalidFieldError]:
        """Reports invalid references and shadowed parameters of a Component.

        Args:
            comp:
                The component to inspect
            index:
                The index of @comp in `namespace.components`, used to build error locations

        Returns:
            The errors that were discovered (empty if none)
        """
        @dataclasses.dataclass
        class Field:
            # Path to the value, relative to the root of the component
            location: typing.List[typing.Union[str, int]]
            # The value itself: the code below handles dictionaries, lists, and strings
            value: typing.Any

        comp_location = ["components", index]

        # This list is a stack whose top is its LAST item. Children are pushed one by one and the
        # most recently pushed entry is visited first (depth-first).
        pending: typing.List[Field] = [
            Field(
                location=[],
                value=comp.model_dump(by_alias=True, exclude_unset=True, exclude_defaults=True)
            )
        ]

        errors = []

        parameters = [x.name for x in comp.signature.parameters]
        variables = list(comp.variables)

        for common in sorted(set(variables).intersection(parameters)):
            errors.append(
                experiment.model.errors.DSLInvalidFieldError(
                    location=comp_location + ["variables", common],
                    underlying_error=ValueError(f'Component defines a variable that shadows its parameter "{common}"')
                )
            )

        may_reference = set(parameters + variables)

        # VV: `replica` is a special variable which the runtime auto-injects to components that replicate
        # for the time being let's assume that it's valid for components to reference this variable
        may_reference.add("replica")

        errors.extend(no_param_refs_in_parameter_defaults(location=comp_location, template=comp))

        while pending:
            current = pending.pop()

            try:
                if isinstance(current.value, dict):
                    for key, value in current.value.items():
                        pending.append(Field(location=current.location + [key], value=value))
                elif isinstance(current.value, str):
                    for unknown in find_references(current.value):
                        if unknown in may_reference:
                            continue

                        errors.append(
                            experiment.model.errors.DSLInvalidFieldError(
                                location=comp_location + current.location,
                                underlying_error=ValueError(
                                    f'Reference to unknown parameter or variable "{unknown}"'
                                )
                            )
                        )
                elif isinstance(current.value, list):
                    # The parameters of the signature are handled by
                    # no_param_refs_in_parameter_defaults() above
                    if current.location == ["signature", "parameters"]:
                        continue

                    for idx, value in enumerate(current.value):
                        pending.append(Field(location=current.location + [idx], value=value))
            except Exception as exc:
                # Keep validating the remaining fields; prefix the location with the component
                # so that this error reads like the others in this function
                errors.append(experiment.model.errors.DSLInvalidFieldError(
                    location=comp_location + current.location,
                    underlying_error=ValueError(f"Internal error while validating field, underlying issue: {exc}")
                ))

        return errors

    errors_acc = []

    for idx in range(len(namespace.components)):
        # VV: enumerate messes up typehints
        comp = namespace.components[idx]
        errors_acc.extend(detect_unknown_variables(comp, idx))

    for idx in range(len(namespace.workflows)):
        # VV: enumerate messes up typehints
        wf = namespace.workflows[idx]
        errors_acc.extend(no_param_refs_in_parameter_defaults(location=["workflows", idx], template=wf))

    scopes = ScopeStack()
    try:
        scopes.discover_all_instances_of_templates(
            namespace=namespace,
            override_entrypoint_args=override_entrypoint_args
        )
    except experiment.model.errors.DSLInvalidError as e:
        # Merge the errors of the template discovery with those collected above
        errors_acc.extend(e.underlying_errors)

    if errors_acc:
        raise experiment.model.errors.DSLInvalidError.from_errors(errors_acc)
60 changes: 55 additions & 5 deletions tests/test_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,9 +770,8 @@ def dsl_conflicting_templates() -> typing.Dict[str, typing.Any]:

def test_dsl_conflicting_templates(dsl_conflicting_templates: typing.Dict[str, typing.Any]):
namespace = experiment.model.frontends.dsl.Namespace(**dsl_conflicting_templates)
scopes = experiment.model.frontends.dsl.ScopeStack()
with pytest.raises(experiment.model.errors.DSLInvalidError) as e:
scopes.discover_all_instances_of_templates(namespace)
experiment.model.frontends.dsl.lightweight_validate(namespace)

assert e.value.errors() == [
{'loc': ['workflows', 1], 'msg': 'There already is a Workflow template called main'},
Expand All @@ -788,11 +787,9 @@ def test_dsl_nested_workflows(dsl_nested_workflows: typing.Dict[str, typing.Any]

def test_detect_cycle(dsl_with_cycle: typing.Dict[str, typing.Any]):
namespace = experiment.model.frontends.dsl.Namespace(**dsl_with_cycle)
scopes = experiment.model.frontends.dsl.ScopeStack()


with pytest.raises(experiment.model.errors.DSLInvalidError) as e:
scopes.discover_all_instances_of_templates(namespace)
experiment.model.frontends.dsl.lightweight_validate(namespace)

exc = e.value

Expand Down Expand Up @@ -1313,3 +1310,56 @@ def test_nanopore_parse():

assert comp['command']['arguments'] == ('-n "$((%(replica)s+1)),+0p" input/cif_files.dat:ref '
'| awk -F "/" \'{print $1}\'')


def test_lightweight_validation_missing_variables():
    """Lightweight validation flags a field that references an undefined name."""
    # NOTE(review): the YAML below is flush-left exactly as it appears in this view; its
    # nesting indentation may have been lost - confirm against the repository
    document = """
components:
- signature:
name: foo
command:
executable: bar
workflowAttributes:
replicate: "%(replicas)s"
"""
    namespace = experiment.model.frontends.dsl.Namespace(**yaml.safe_load(document))

    expected = [
        {
            "loc": ["components", 0, "workflowAttributes", "replicate"],
            "msg": 'Reference to unknown parameter or variable "replicas"'
        }
    ]

    with pytest.raises(experiment.model.errors.DSLInvalidError) as raised:
        experiment.model.frontends.dsl.lightweight_validate(namespace)

    assert raised.value.errors() == expected


def test_lightweight_validation_parameter_referencing_variable():
    """Lightweight validation flags a parameter default that contains a reference."""
    # NOTE(review): the YAML below is flush-left exactly as it appears in this view; its
    # nesting indentation may have been lost - confirm against the repository
    document = """
components:
- signature:
name: foo
parameters:
- name: hello
default: "%(something)s"
command:
executable: bar
"""
    namespace = experiment.model.frontends.dsl.Namespace(**yaml.safe_load(document))

    expected = [
        {
            "loc": ["components", 0, "signature", "parameters", 0, "default"],
            "msg": 'Reference to parameter/variable "something"'
        }
    ]

    with pytest.raises(experiment.model.errors.DSLInvalidError) as raised:
        experiment.model.frontends.dsl.lightweight_validate(namespace)

    assert raised.value.errors() == expected

0 comments on commit fa992f7

Please sign in to comment.