Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve Graph Writing #785

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ typer = "^0.7"
dot4dict = "^0.1"
zninit = "^0.1"
znjson = "^0.2"
znflow = "^0.1"
znflow = "0.2.0a0"
varname = "^0.13"
# for Python3.12 compatibility
pyzmq = "^25"
Expand Down
8 changes: 8 additions & 0 deletions zntrack/fields/dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,3 +302,11 @@ def _update_node_name(self, entry, instance, graph, key=None):
entry.name = entry_name

return entry

def get_dvc_data(self, instance: "Node") -> dict:
    """Get the DVC data.

    Parameters
    ----------
    instance : Node
        The node instance that is passed on to ``self.get_files``.

    Returns
    -------
    dict
        A dict with the single key ``"deps"`` holding the ``as_posix()``
        string of every entry returned by ``self.get_files(instance)``.

    """
    return {"deps": [x.as_posix() for x in self.get_files(instance)]}

def get_zntrack_data(self, instance: "Node") -> dict:
    """Get the zntrack data.

    Parameters
    ----------
    instance : Node
        The node instance the field value is read from.

    Returns
    -------
    dict
        A single entry mapping ``self.name`` to the value obtained with
        ``getattr(instance, self.name)``.

    """
    return {self.name: getattr(instance, self.name)}
44 changes: 44 additions & 0 deletions zntrack/fields/dvc/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,44 @@ def __init__(self, *args, **kwargs):
self.dvc_option = kwargs.pop("dvc_option")
super().__init__(*args, **kwargs)

def get_dvc_data(self, instance: "Node") -> dict:
    """Get the data to be saved to the dvc.yaml file.

    Parameters
    ----------
    instance : Node
        The node instance to get the data for.

    Returns
    -------
    dict
        A single entry mapping ``self.dvc_option`` to the unmodified
        result of ``self.get_files(instance)``.

    """
    return {self.dvc_option: self.get_files(instance)}

def get_zntrack_data(self, instance: "Node") -> dict:
    """Get the data to be saved to the zntrack.json file.

    Parameters
    ----------
    instance : Node
        The node instance to get the data for.

    Returns
    -------
    dict
        A single entry keyed by ``self.name``. Its value is a dict with
        ``"_type"`` (the result of ``_get_import_path`` for the value)
        and ``"value"`` (the value itself).

    """
    try:
        # Read the raw entry from the instance dict first; this does not
        # go through regular attribute lookup.
        value = instance.__dict__[self.name]
    except KeyError:
        # Taken from DVCOption.save.
        # TODO: Should we return an empty dict if getattr fails?
        value = getattr(instance, self.name)
    return {self.name: {"_type": _get_import_path(value), "value": value}}

def get_files(self, instance: "Node") -> list:
"""Get the files affected by this field.

Expand Down Expand Up @@ -168,3 +206,9 @@ def __get__(self, instance: "Node", owner=None):

class PlotsOption(PlotsMixin, DVCOption):
    """Field with DVC plots kwargs.

    Combines ``PlotsMixin`` with ``DVCOption`` and defines no members of
    its own.
    """


# TODO: How was this done previously?
def _get_import_path(obj: object) -> str:
    """Return the dotted ``module.qualname`` path of the type of ``obj``.

    Parameters
    ----------
    obj : object
        Any object; only its type is inspected.

    Returns
    -------
    str
        ``"<type.__module__>.<type.__qualname__>"``, for example
        ``"builtins.int"`` for an integer.

    """
    obj_type = type(obj)
    return f"{obj_type.__module__}.{obj_type.__qualname__}"
18 changes: 18 additions & 0 deletions zntrack/fields/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class Field(zninit.Descriptor, abc.ABC):
----------
dvc_option : str
The dvc command option for this field.

"""

dvc_option: str = None
Expand All @@ -49,6 +50,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The Node instance to save the field for.

"""
raise NotImplementedError

Expand All @@ -70,6 +72,7 @@ def get_files(self, instance: "Node") -> list:
-------
list
The affected files.

"""
raise NotImplementedError

Expand All @@ -83,6 +86,7 @@ def load(self, instance: "Node", lazy: bool = None):
lazy : bool, optional
Whether to load the field lazily.
This only applies to 'LazyField' classes.

"""
try:
instance.__dict__[self.name] = self.get_data(instance)
Expand All @@ -103,6 +107,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]:
-------
typing.List[tuple]
The stage add argument for this field.

"""
return [
(f"--{self.dvc_option}", pathlib.Path(x).as_posix())
Expand All @@ -127,6 +132,7 @@ def get_optional_dvc_cmd(
-------
typing.List[str]
The optional dvc commands.

"""
return []

Expand Down Expand Up @@ -156,6 +162,18 @@ def _write_value_to_config(self, value, instance: "Node", encoder=None):
with open(config.files.zntrack, "w") as f:
json.dump(zntrack_dict, f, indent=4, cls=encoder)

def get_zntrack_data(self, instance: "Node") -> dict:
    """Get the data that will be written to the zntrack config file.

    This default implementation ignores ``instance`` and contributes
    nothing.

    Returns
    -------
    dict
        Always an empty dict.

    """
    return {}

def get_dvc_data(self, instance: "Node") -> dict:
    """Get the data that will be written to the dvc config file.

    This default implementation ignores ``instance`` and contributes
    nothing.

    Returns
    -------
    dict
        Always an empty dict.

    """
    return {}

def get_params_data(self, instance: "Node") -> dict:
    """Get the data that will be written to the params file.

    This default implementation ignores ``instance`` and contributes
    nothing.

    Returns
    -------
    dict
        Always an empty dict.

    """
    return {}


class DataIsLazyError(Exception):
"""Exception to raise when a field is accessed that contains lazy data."""
Expand Down
32 changes: 32 additions & 0 deletions zntrack/fields/zn/options.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,18 +103,28 @@ class Params(Field):
----------
dvc_option: str
The DVC option to use. Default is "params".

"""

dvc_option: str = "params"
group = FieldGroup.PARAMETER

def get_params_data(self, instance: "Node") -> dict:
    """Get the parameters data.

    Parameters
    ----------
    instance : Node
        The node instance the parameter value is read from.

    Returns
    -------
    dict
        A single entry mapping ``self.name`` to the value obtained with
        ``getattr(instance, self.name)``.

    """
    return {self.name: getattr(instance, self.name)}

def get_dvc_data(self, instance: "Node") -> dict:
    """Get the DVC data.

    Parameters
    ----------
    instance : Node
        The node instance whose ``name`` is listed.

    Returns
    -------
    dict
        A dict with the single key ``"params"`` holding a one-element
        list with ``instance.name``.

    """
    return {"params": [instance.name]}

def get_files(self, instance: "Node") -> list:
    """Get the list of files affected by this field.

    The ``instance`` argument is not used; the result only depends on
    ``config.files.params``.

    Returns
    -------
    list
        A one-element list containing ``config.files.params``.

    """
    return [config.files.params]

Expand All @@ -125,6 +135,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The node instance associated with this field.

"""
file = self.get_files(instance)[0]

Expand Down Expand Up @@ -161,6 +172,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]:
-------
list
A list of tuples containing the DVC option and the file path.

"""
file = self.get_files(instance)[0]
return [(f"--{self.dvc_option}", f"{file}:{instance.name}")]
Expand All @@ -171,6 +183,14 @@ class Output(LazyField):

group = FieldGroup.RESULT

def get_dvc_data(self, instance: "Node") -> dict:
    """Get the DVC data.

    Parameters
    ----------
    instance : Node
        The node instance that is passed on to ``self.get_files``.

    Returns
    -------
    dict
        A dict with the single key ``"outs"`` holding the ``as_posix()``
        string of every entry returned by ``self.get_files(instance)``.

    """
    # NOTE(review): the key is hard-coded to "outs" although this field
    # stores its own ``dvc_option`` - confirm this is intended.
    return {"outs": [x.as_posix() for x in self.get_files(instance)]}

def get_zntrack_data(self, instance: "Node") -> dict:
    """Get the zntrack data.

    The ``instance`` argument is not used.

    Returns
    -------
    dict
        A single entry mapping ``self.name`` to the path
        ``$nwd$/<self.name>.json``. ``$nwd$`` is kept as a literal
        placeholder here; presumably it is replaced by the node working
        directory elsewhere - TODO confirm.

    """
    return {self.name: pathlib.Path(f"$nwd$/{self.name}.json")}

def __init__(self, dvc_option: str, **kwargs):
"""Create a new Output field.

Expand All @@ -180,6 +200,7 @@ def __init__(self, dvc_option: str, **kwargs):
The DVC option used to specify the output file.
**kwargs
Additional arguments to pass to the parent constructor.

"""
self.dvc_option = dvc_option
super().__init__(**kwargs)
Expand All @@ -196,6 +217,7 @@ def get_files(self, instance) -> list:
-------
list
A list containing the path of the file.

"""
return [get_nwd(instance) / f"{self.name}.json"]

Expand All @@ -206,6 +228,7 @@ def save(self, instance: "Node"):
----------
instance : Node
The node instance.

"""
try:
value = self.get_value_except_lazy(instance)
Expand Down Expand Up @@ -236,6 +259,7 @@ def get_stage_add_argument(self, instance) -> typing.List[tuple]:
-------
list
A list containing the DVC command for this field.

"""
file = self.get_files(instance)[0]
return [(f"--{self.dvc_option}", file.as_posix())]
Expand All @@ -247,6 +271,14 @@ class Plots(PlotsMixin, LazyField):
dvc_option: str = "plots"
group = FieldGroup.RESULT

def get_dvc_data(self, instance: "Node") -> dict:
    """Get the DVC data.

    Parameters
    ----------
    instance : Node
        The node instance that is passed on to ``self.get_files``.

    Returns
    -------
    dict
        A dict with the single key ``"plots"`` holding the ``as_posix()``
        string of every entry returned by ``self.get_files(instance)``.

    """
    return {"plots": [x.as_posix() for x in self.get_files(instance)]}

def get_zntrack_data(self, instance: "Node") -> dict:
    """Get the zntrack data.

    The ``instance`` argument is not used.

    Returns
    -------
    dict
        A single entry mapping ``self.name`` to the path
        ``$nwd$/<self.name>.csv``. ``$nwd$`` is kept as a literal
        placeholder here; presumably it is replaced by the node working
        directory elsewhere - TODO confirm.

    """
    return {self.name: pathlib.Path(f"$nwd$/{self.name}.csv")}

def get_files(self, instance) -> list:
    """Get the path of the file in the node directory.

    Returns
    -------
    list
        A one-element list with ``get_nwd(instance) / "<self.name>.csv"``.

    """
    return [get_nwd(instance) / f"{self.name}.csv"]
Expand Down
5 changes: 2 additions & 3 deletions zntrack/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def cwd_temp_dir(required_files=None) -> tempfile.TemporaryDirectory:
class NodeName:
"""The name of a node."""

groups: list[str]
groups: znflow.Group
name: str
varname: str = None
suffix: int = 0
Expand All @@ -235,7 +235,7 @@ def __str__(self) -> str:
"""Get the node name."""
name = []
if self.groups is not None:
name.extend(self.groups)
name.extend(x for x in self.groups.names[0])
if self.use_varname:
name.append(self.varname)
else:
Expand All @@ -255,7 +255,6 @@ def get_name_without_groups(self) -> str:

def update_suffix(self, project: "Project", node: "Node") -> None:
"""Update the suffix."""
node_names = [x["value"].name for x in project.graph.nodes.values()]
self.use_varname = project.magic_names

node_names = []
Expand Down