Skip to content

Commit

Permalink
fix: update the DSL 2.0 logic for determining whether a Component is …
Browse files Browse the repository at this point in the history
…replicating or not (#340)

* fix: update the DSL 2.0 logic for determining whether a Component is replicating or not

Signed-off-by: Vassilis Vassiladis <[email protected]>

---------

Signed-off-by: Vassilis Vassiladis <[email protected]>
  • Loading branch information
VassilisVassiliadis authored and GitHub Enterprise committed Jun 19, 2024
1 parent a49f425 commit 5a8ac57
Show file tree
Hide file tree
Showing 3 changed files with 247 additions and 44 deletions.
145 changes: 105 additions & 40 deletions python/experiment/model/frontends/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,9 +203,9 @@ def __str__(self):

def split(
self,
scopes: typing.Iterable[typing.Tuple[str]],
scopes: typing.Iterable[typing.Tuple[str, ...]],
) -> typing.Tuple[
typing.Tuple[str],
typing.Tuple[str, ...],
typing.Optional[str],
]:
"""Utility method to infer the partition of an OutputReference location into stepName and fileRef
Expand Down Expand Up @@ -819,7 +819,7 @@ def _replace_many_parameter_references(

def replace_parameter_references(
value: ParameterValueType,
all_scopes: typing.Dict[typing.Tuple[str], "ScopeStack.Scope"],
all_scopes: typing.Dict[typing.Tuple[str, ...], "ScopeStack.Scope"],
location: typing.Iterable[str],
is_replica: bool,
variables: typing.Optional[typing.Dict[str, ParameterValueType]] = None
Expand Down Expand Up @@ -1045,7 +1045,7 @@ def fold_in_defaults_of_parameters(self):

def resolve_parameter_references_of_instance(
self: "ScopeStack.Scope",
all_scopes: typing.Dict[typing.Tuple[str], "ScopeStack.Scope"]
all_scopes: typing.Dict[typing.Tuple[str, ...], "ScopeStack.Scope"]
) -> typing.List[Exception]:
errors: typing.List[Exception] = []
for idx, (name, value) in enumerate(self.parameters.items()):
Expand All @@ -1069,7 +1069,7 @@ def resolve_parameter_references_of_instance(

def resolve_output_references_of_instance(
self: "ScopeStack.Scope",
all_scopes: typing.Dict[typing.Tuple[str], "ScopeStack.Scope"],
all_scopes: typing.Dict[typing.Tuple[str, ...], "ScopeStack.Scope"],
ensure_references_point_to_sibling_steps: bool,
) -> typing.List[Exception]:
errors: typing.List[Exception] = []
Expand Down Expand Up @@ -1107,7 +1107,7 @@ def resolve_output_references_of_instance(

def resolve_legacy_data_references(
self: "ScopeStack.Scope",
all_scopes: typing.Dict[typing.Tuple[str], "ScopeStack.Scope"],
all_scopes: typing.Dict[typing.Tuple[str, ...], "ScopeStack.Scope"],
) -> typing.List[Exception]:
errors: typing.List[Exception] = []
for idx, (name, value) in enumerate(self.parameters.items()):
Expand Down Expand Up @@ -1160,7 +1160,7 @@ def replace_step_references(
self,
value: ParameterValueType,
field: typing.List[str],
all_scopes: typing.Dict[typing.Tuple[str], "ScopeStack.Scope"],
all_scopes: typing.Dict[typing.Tuple[str, ...], "ScopeStack.Scope"],
ensure_references_point_to_sibling_steps: bool = True,
) -> ParameterValueType:
"""Rewrites all references to steps in a value to their absolute form
Expand Down Expand Up @@ -1190,7 +1190,6 @@ def replace_step_references(
sibling_steps = []
if len(self.location) > 1:
uid_parent = tuple(self.location[:-1])
parent_workflow_name = uid_parent[-1] if len(uid_parent) > 0 else "**missing**"

try:
parent_scope = all_scopes[uid_parent]
Expand Down Expand Up @@ -1257,7 +1256,15 @@ def __init__(self):
# The root of the namespace is what the entrypoint invokes; its name is always `entry-instance`.
# The values are the ScopeEntries which are effectively instances of a Template i.e.
# The instance name, definition, and parameters of a Template
self.scopes: typing.Dict[typing.Tuple[str], ScopeStack.Scope] = {}
self.scopes: typing.Dict[typing.Tuple[str, ...], ScopeStack.Scope] = {}

# VV: Keys are "locations" of Component template instances, values are either True or False indicating
# whether the component is replicating or not
self.replicating_components: typing.Dict[typing.Tuple[str, ...], bool] = {}

# VV: Keys are "locations" of Component template instances, values are either True or False indicating
# whether the component is aggregating or not
self.aggregating_components: typing.Dict[typing.Tuple[str, ...], bool] = {}

def can_template_replicate(self, location: typing.Iterable[str]) -> bool:
"""Returns whether the template can replicate
Expand All @@ -1277,27 +1284,79 @@ def can_template_replicate(self, location: typing.Iterable[str]) -> bool:
If the location does not map to a known scope
"""

location = list(location)
scope: ScopeStack.Scope = self.scopes[tuple(location)]
if isinstance(scope.template, Workflow):
return False

# VV: This component is using a variable for its workflowAttributes.replicate field.
# For now, we can assume that this means the component is replicating
if scope.template.workflowAttributes.replicate not in ["0", 0, "", None]:
return True

if scope.template.workflowAttributes.aggregate is True:
self.aggregating_components[tuple(location)] = True
return False

# VV: This component doesn't replicate, let's start walking from its producers all the way to the
# root of the experiment. Stop when one of the components is either replicating or aggregating.
# If you reach the entrypoint then this component is not replicating.

# VV: Iterate scopes starting from the CURRENT template and moving up till you reach:
# 1. an aggregating component -> replica = False
# 2. a POTENTIALLY replicating component -> replica = True
# 3. the entrypoint -> replica = False
pattern_vanilla = re.compile(OutputReferenceVanilla)
pattern_nested = re.compile(OutputReferenceNested)

while location:
scope = self.scopes[tuple(location)]
location = location[:-1]
component_locations_to_check = [scope.location]

if not isinstance(scope.template, Component):
while component_locations_to_check:
location = component_locations_to_check.pop()

scope: ScopeStack.Scope = self.scopes[tuple(location)]
if isinstance(scope.template, Workflow):
continue

# VV: This component is using a variable for its workflowAttributes.replicate field.
# For now, we can assume that this means the component is replicating
if scope.template.workflowAttributes.replicate not in ["0", 0, "", None]:
return True
elif (
isinstance(scope.template.workflowAttributes.aggregate, bool)
and scope.template.workflowAttributes.aggregate is True
):
return False

for value in scope.parameters.values():
if not isinstance(value, str):
continue

for pattern in [pattern_vanilla, pattern_nested]:
for match in pattern.finditer(value):
ref = OutputReference.from_str(match.group(0))
location = ref.location

producer = None
while location:
try:
producer = self.scopes[tuple(location)]
if isinstance(producer.template, Component) is False:
continue
break
except KeyError:
# VV: This location doesn't map to a component. The OutputReference must be pointing
# to an output of a step. Trim one level and check whether that points to a known step.
location = location[:-1]

if not producer:
continue

replicating = self.replicating_components.get(tuple(location))

if (replicating
or producer.template.workflowAttributes.replicate not in ["0", 0, "", None]):
self.replicating_components[tuple(scope.location)] = True
return True

if (self.aggregating_components.get(tuple(producer.location), False) is True
or producer.template.workflowAttributes.aggregate is True):
# VV: If this component is not aggregating then it **might** be replicating
self.aggregating_components[tuple(producer.location)] = True
break

# VV: Cannot tell whether this component replicates or not, need to visit its producer
component_locations_to_check.append(producer.location)

return False

Expand Down Expand Up @@ -1880,7 +1939,7 @@ def discover_legacy_references(self) -> typing.Dict[str, typing.List[str]]:

def convert_outputreferences_to_datareferences(
self,
uid_to_name: typing.Dict[typing.Tuple[str], typing.Tuple[int, str]],
uid_to_name: typing.Dict[typing.Tuple[str, ...], typing.Tuple[int, str]],
location: typing.List[typing.Union[str, int]],
):
"""Utility method to convert OutputReference instances into Legacy DataReferences
Expand Down Expand Up @@ -1917,17 +1976,18 @@ def convert_outputreferences_to_datareferences(
self.flowir["command"]["arguments"] = args

# VV: TODO Here we'll need to do something about :copy - I'll figure this out in a future update
for match in pattern_output.finditer(args):
ref = OutputReference.from_str(match.group(0))
if not ref.method:
raise experiment.model.errors.DSLInvalidFieldError(
location=["components", self.scope.template.signature.name, "command", "arguments"],
underlying_error=ValueError(f"The arguments of {self.scope.location} contain a reference to "
f"the output {match.group(0)} but the OutputReference is partial, it does not "
f"end with a :$method suffix.")
)
else:
arguments_output.add(match.group(0))
if isinstance(args, str):
for match in pattern_output.finditer(args):
ref = OutputReference.from_str(match.group(0))
if not ref.method:
raise experiment.model.errors.DSLInvalidFieldError(
location=["components", self.scope.template.signature.name, "command", "arguments"],
underlying_error=ValueError(f"The arguments of {self.scope.location} contain a reference to "
f"the output {match.group(0)} but the OutputReference is partial, it does not "
f"end with a :$method suffix.")
)
else:
arguments_output.add(match.group(0))

for idx, (name, value) in enumerate(self.scope.parameters.items()):
if not isinstance(value, str):
Expand Down Expand Up @@ -1960,8 +2020,9 @@ def convert_outputreferences_to_datareferences(
)
self.errors.append(err)

for match in pattern_legacy.finditer(args):
arguments_legacy.add(match.group(0))
if isinstance(args, str):
for match in pattern_legacy.finditer(args):
arguments_legacy.add(match.group(0))

for name, value in self.scope.parameters.items():
if not isinstance(value, str):
Expand Down Expand Up @@ -2081,9 +2142,13 @@ def replace_parameter_references(
except experiment.model.errors.DSLInvalidFieldError as e:
self.errors.append(e)
except Exception as e:
str_location = "/".join(map(str, self.scope.dsl_location()))

self.errors.append(
experiment.model.errors.DSLInvalidFieldError(
self.template_dsl_location + node.location, underlying_error=e
self.template_dsl_location + node.location,
underlying_error=ValueError(f"The component was instantiated at {str_location}. "
f"Error: {e}")
)
)

Expand Down Expand Up @@ -2210,7 +2275,7 @@ def namespace_to_flowir(
override_entrypoint_args=override_entrypoint_args
)

components: typing.Dict[typing.Tuple[str], ComponentFlowIR] = {}
components: typing.Dict[typing.Tuple[str, ...], ComponentFlowIR] = {}
errors = []

for location, scope in scopes.scopes.items():
Expand All @@ -2236,7 +2301,7 @@ def namespace_to_flowir(
raise experiment.model.errors.DSLInvalidError.from_errors(errors)

component_names: typing.Dict[str, int] = {}
uid_to_name: typing.Dict[typing.Tuple[str], typing.Tuple[int, str]] = {}
uid_to_name: typing.Dict[typing.Tuple[str, ...], typing.Tuple[int, str]] = {}

pattern_name = re.compile(SignatureNamePattern)

Expand Down
10 changes: 8 additions & 2 deletions python/experiment/model/frontends/flowir.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import pprint
import re
import traceback
import typing
from string import Template
from threading import RLock
from typing import (Any, Callable, Dict, List, MutableMapping, Optional, Set,
Expand Down Expand Up @@ -4930,8 +4931,13 @@ def refresh_component_dictionary(self):
)
self._component_dictionary[comp_id] = component

def replicate(self, platform=None, ignore_errors=False, top_level_folders=None):
# type: (str, bool, List[str]) -> DictFlowIR
def replicate(
self,
platform: typing.Optional[str]=None,
ignore_errors: bool=False,
top_level_folders: typing.Optional[typing.List[str]]=None
) -> DictFlowIR:

"""Replicates a primitive FlowIRConcrete
Arguments:
Expand Down
Loading

0 comments on commit 5a8ac57

Please sign in to comment.