From f1418420862dfae67051dde1265be1ae09b414b4 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 11 Apr 2024 11:38:27 +0200 Subject: [PATCH 1/5] first draft --- poetry.lock | 16 ++++++++-------- pyproject.toml | 2 +- zntrack/fields/dependency.py | 9 +++++++++ zntrack/fields/field.py | 18 ++++++++++++++++++ zntrack/fields/zn/options.py | 32 ++++++++++++++++++++++++++++++++ 5 files changed, 68 insertions(+), 9 deletions(-) diff --git a/poetry.lock b/poetry.lock index 777a55a6..d5a618d5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5669,21 +5669,21 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [[package]] name = "znflow" -version = "0.1.14" +version = "0.2.0a0" description = "A general purpose framework for building and running computational graphs." optional = false -python-versions = ">=3.8,<4.0" +python-versions = ">=3.9,<4.0" files = [ - {file = "znflow-0.1.14-py3-none-any.whl", hash = "sha256:f94f21cdaece949754e6dd5beaedfb078b9331ca49e32d9a2dfaa4ac1d8f8324"}, - {file = "znflow-0.1.14.tar.gz", hash = "sha256:bf85dbb4c816a3c1ae98ed62f75feb10a22032e8bccf8d016ee6d406873c9c03"}, + {file = "znflow-0.2.0a0-py3-none-any.whl", hash = "sha256:e9d8e9a3efb800fbd06959883e84ed4ef0b6fb7304b2c1262d5eac4617aa9d39"}, + {file = "znflow-0.2.0a0.tar.gz", hash = "sha256:0c4279d30a6999938aff58b867c9ea9b7166b3d24c87794891bca2e96ac7f04c"}, ] [package.dependencies] -matplotlib = ">=3.6.3,<4.0.0" -networkx = ">=3.0,<4.0" +matplotlib = ">=3,<4" +networkx = ">=3,<4" [package.extras] -dask = ["bokeh (>=2.4.2,<3.0.0)", "dask (>=2022.12.1,<2023.0.0)", "dask-jobqueue (>=0.8.1,<0.9.0)", "distributed (>=2022.12.1,<2023.0.0)"] +dask = ["bokeh (>=2,<3)", "dask (>=2022,<2023)", "dask-jobqueue (>=0.8,<0.9)", "distributed (>=2022,<2023)"] [[package]] name = "zninit" @@ -5713,4 +5713,4 @@ files = [ [metadata] lock-version = "2.0" python-versions = ">=3.9,<4.0.0" -content-hash = "cd187c08029a8632406528c1ac370397729406979b98390499726e26ff411c46" +content-hash = "c7b4e0d280aa028c18c82c2567c0417ab055412837d88e9a88c878b916d37b2c" diff --git a/pyproject.toml b/pyproject.toml index 2bedcae1..4b8bfedd 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,7 +19,7 @@ typer = "^0.7" dot4dict = "^0.1" zninit = "^0.1" znjson = "^0.2" -znflow = "^0.1" +znflow = "0.2.0a0" varname = "^0.13" # for Python3.12 compatibliity pyzmq = "^25" diff --git a/zntrack/fields/dependency.py b/zntrack/fields/dependency.py index 1d49cbb6..b6bc6f01 100644 --- a/zntrack/fields/dependency.py +++ b/zntrack/fields/dependency.py @@ -302,3 +302,12 @@ def _update_node_name(self, entry, instance, graph, key=None): entry.name = entry_name return entry + + + def get_dvc_data(self, instance: "Node") -> dict: + """Get the DVC data.""" + return {"deps": [x.as_posix() for x in self.get_files(instance)]} + + def get_zntrack_data(self, instance: "Node") -> dict: + """Get the zntrack data.""" + return {self.name: getattr(instance, self.name)} diff --git a/zntrack/fields/field.py b/zntrack/fields/field.py index 3650a895..66c84681 100644 --- a/zntrack/fields/field.py +++ b/zntrack/fields/field.py @@ -36,6 +36,7 @@ class Field(zninit.Descriptor, abc.ABC): ---------- dvc_option : str The dvc command option for this field. + """ dvc_option: str = None @@ -49,6 +50,7 @@ def save(self, instance: "Node"): ---------- instance : Node The Node instance to save the field for. + """ raise NotImplementedError @@ -70,6 +72,7 @@ def get_files(self, instance: "Node") -> list: ------- list The affected files. + """ raise NotImplementedError @@ -83,6 +86,7 @@ def load(self, instance: "Node", lazy: bool = None): lazy : bool, optional Whether to load the field lazily. This only applies to 'LazyField' classes. + """ try: instance.__dict__[self.name] = self.get_data(instance) @@ -103,6 +107,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]: ------- typing.List[tuple] The stage add argument for this field. + """ return [ (f"--{self.dvc_option}", pathlib.Path(x).as_posix()) @@ -127,6 +132,7 @@ def get_optional_dvc_cmd( ------- typing.List[str] The optional dvc commands. + """ return [] @@ -157,6 +163,18 @@ def _write_value_to_config(self, value, instance: "Node", encoder=None): json.dump(zntrack_dict, f, indent=4, cls=encoder) + def get_zntrack_data(self, instance: "Node") -> dict: + """Get the data that will be written to the zntrack config file.""" + return {} + + def get_dvc_data(self, instance: "Node") -> dict: + """Get the data that will be written to the dvc config file.""" + return {} + + def get_params_data(self, instance: "Node") -> dict: + """Get the data that will be written to the params file.""" + return {} + class DataIsLazyError(Exception): """Exception to raise when a field is accessed that contains lazy data.""" diff --git a/zntrack/fields/zn/options.py b/zntrack/fields/zn/options.py index edb9dcc5..37a640d7 100644 --- a/zntrack/fields/zn/options.py +++ b/zntrack/fields/zn/options.py @@ -103,11 +103,20 @@ class Params(Field): ---------- dvc_option: str The DVC option to use. Default is "params". + """ dvc_option: str = "params" group = FieldGroup.PARAMETER + def get_params_data(self, instance: "Node") -> dict: + """Get the parameters data.""" + return {self.name: getattr(instance, self.name)} + + def get_dvc_data(self, instance: "Node") -> dict: + """Get the DVC data.""" + return {"params": [instance.name]} + def get_files(self, instance: "Node") -> list: """Get the list of files affected by this field. @@ -115,6 +124,7 @@ def get_files(self, instance: "Node") -> list: ------- list A list of file paths. + """ return [config.files.params] @@ -125,6 +135,7 @@ def save(self, instance: "Node"): ---------- instance : Node The node instance associated with this field. + """ file = self.get_files(instance)[0] @@ -161,6 +172,7 @@ def get_stage_add_argument(self, instance: "Node") -> typing.List[tuple]: ------- list A list of tuples containing the DVC option and the file path. + """ file = self.get_files(instance)[0] return [(f"--{self.dvc_option}", f"{file}:{instance.name}")] @@ -171,6 +183,14 @@ class Output(LazyField): group = FieldGroup.RESULT + def get_dvc_data(self, instance: "Node") -> dict: + """Get the DVC data.""" + return {"outs": [x.as_posix() for x in self.get_files(instance)]} + + def get_zntrack_data(self, instance: "Node") -> dict: + """Get the zntrack data.""" + return {self.name: pathlib.Path(f"$nwd$/{self.name}.json")} + def __init__(self, dvc_option: str, **kwargs): """Create a new Output field. @@ -180,6 +200,7 @@ def __init__(self, dvc_option: str, **kwargs): The DVC option used to specify the output file. **kwargs Additional arguments to pass to the parent constructor. + """ self.dvc_option = dvc_option super().__init__(**kwargs) @@ -196,6 +217,7 @@ def get_files(self, instance) -> list: ------- list A list containing the path of the file. + """ return [get_nwd(instance) / f"{self.name}.json"] @@ -206,6 +228,7 @@ def save(self, instance: "Node"): ---------- instance : Node The node instance. + """ try: value = self.get_value_except_lazy(instance) @@ -236,6 +259,7 @@ def get_stage_add_argument(self, instance) -> typing.List[tuple]: ------- list A list containing the DVC command for this field. + """ file = self.get_files(instance)[0] return [(f"--{self.dvc_option}", file.as_posix())] @@ -247,6 +271,14 @@ class Plots(PlotsMixin, LazyField): dvc_option: str = "plots" group = FieldGroup.RESULT + def get_dvc_data(self, instance: "Node") -> dict: + """Get the DVC data.""" + return {"plots": [x.as_posix() for x in self.get_files(instance)]} + + def get_zntrack_data(self, instance: "Node") -> dict: + """Get the zntrack data.""" + return {self.name: pathlib.Path(f"$nwd$/{self.name}.csv")} + def get_files(self, instance) -> list: """Get the path of the file in the node directory.""" return [get_nwd(instance) / f"{self.name}.csv"] From 9c0f696c0006b9a44f54d747d285efbdd1fa5a9a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Apr 2024 09:40:27 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- zntrack/fields/dependency.py | 3 +-- zntrack/fields/field.py | 6 +++--- zntrack/fields/zn/options.py | 2 +- 3 files changed, 5 insertions(+), 6 deletions(-) diff --git a/zntrack/fields/dependency.py b/zntrack/fields/dependency.py index b6bc6f01..10f0b23a 100644 --- a/zntrack/fields/dependency.py +++ b/zntrack/fields/dependency.py @@ -302,8 +302,7 @@ def _update_node_name(self, entry, instance, graph, key=None): entry.name = entry_name return entry - - + def get_dvc_data(self, instance: "Node") -> dict: """Get the DVC data.""" return {"deps": [x.as_posix() for x in self.get_files(instance)]} diff --git a/zntrack/fields/field.py b/zntrack/fields/field.py index 66c84681..0485cbe6 100644 --- a/zntrack/fields/field.py +++ b/zntrack/fields/field.py @@ -50,7 +50,7 @@ def save(self, instance: "Node"): ---------- instance : Node The Node instance to save the field for. - + """ raise NotImplementedError @@ -162,7 +162,6 @@ def _write_value_to_config(self, value, instance: "Node", encoder=None): with open(config.files.zntrack, "w") as f: json.dump(zntrack_dict, f, indent=4, cls=encoder) - def get_zntrack_data(self, instance: "Node") -> dict: """Get the data that will be written to the zntrack config file.""" return {} @@ -170,11 +169,12 @@ def get_zntrack_data(self, instance: "Node") -> dict: def get_dvc_data(self, instance: "Node") -> dict: """Get the data that will be written to the dvc config file.""" return {} - + def get_params_data(self, instance: "Node") -> dict: """Get the data that will be written to the params file.""" return {} + class DataIsLazyError(Exception): """Exception to raise when a field is accessed that contains lazy data.""" diff --git a/zntrack/fields/zn/options.py b/zntrack/fields/zn/options.py index 37a640d7..11543220 100644 --- a/zntrack/fields/zn/options.py +++ b/zntrack/fields/zn/options.py @@ -274,7 +274,7 @@ class Plots(PlotsMixin, LazyField): def get_dvc_data(self, instance: "Node") -> dict: """Get the DVC data.""" return {"plots": [x.as_posix() for x in self.get_files(instance)]} - + def get_zntrack_data(self, instance: "Node") -> dict: """Get the zntrack data.""" return {self.name: pathlib.Path(f"$nwd$/{self.name}.csv")} From 03e901f9200bc9394e909fef4d6190eb42fb8966 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 11 Apr 2024 12:32:50 +0200 Subject: [PATCH 3/5] fix for new znflow groups (used instead of str) --- zntrack/utils/__init__.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/zntrack/utils/__init__.py b/zntrack/utils/__init__.py index 4bb823be..d4b1613f 100644 --- a/zntrack/utils/__init__.py +++ b/zntrack/utils/__init__.py @@ -225,7 +225,7 @@ def cwd_temp_dir(required_files=None) -> tempfile.TemporaryDirectory: class NodeName: """The name of a node.""" - groups: list[str] + groups: znflow.Group name: str varname: str = None suffix: int = 0 @@ -235,7 +235,7 @@ def __str__(self) -> str: """Get the node name.""" name = [] if self.groups is not None: - name.extend(self.groups) + name.extend(x for x in self.groups.names[0]) if self.use_varname: name.append(self.varname) else: @@ -255,7 +255,7 @@ def get_name_without_groups(self) -> str: def update_suffix(self, project: "Project", node: "Node") -> None: """Update the suffix.""" - node_names = [x["value"].name for x in project.graph.nodes.values()] + # node_names = [x["value"].name for x in project.graph.nodes.values()] self.use_varname = project.magic_names node_names = [] From af9336dea559ab4b8ee1c3a8607c05823ee1d8a3 Mon Sep 17 00:00:00 2001 From: Fabian Zills Date: Thu, 11 Apr 2024 12:34:37 +0200 Subject: [PATCH 4/5] remove old out code --- zntrack/utils/__init__.py | 1 - 1 file changed, 1 deletion(-) diff --git a/zntrack/utils/__init__.py b/zntrack/utils/__init__.py index d4b1613f..b676f299 100644 --- a/zntrack/utils/__init__.py +++ b/zntrack/utils/__init__.py @@ -255,7 +255,6 @@ def get_name_without_groups(self) -> str: def update_suffix(self, project: "Project", node: "Node") -> None: """Update the suffix.""" - # node_names = [x["value"].name for x in project.graph.nodes.values()] self.use_varname = project.magic_names node_names = [] From 89d432b06af02beb786901ee8830f1e57b56bbe0 Mon Sep 17 00:00:00 2001 From: Niklas Kappel Date: Fri, 19 Jul 2024 15:56:07 +0200 Subject: [PATCH 5/5] Add get_dvc_data and get_zntrack data for DVCOption field --- zntrack/fields/dvc/options.py | 44 +++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/zntrack/fields/dvc/options.py b/zntrack/fields/dvc/options.py index 3cb93f47..46a7e663 100644 --- a/zntrack/fields/dvc/options.py +++ b/zntrack/fields/dvc/options.py @@ -56,6 +56,44 @@ def __init__(self, *args, **kwargs): self.dvc_option = kwargs.pop("dvc_option") super().__init__(*args, **kwargs) + def get_dvc_data(self, instance: "Node") -> dict: + """Get the data to be saved to the dvc.yaml file. + + Parameters + ---------- + instance : Node + The node instance to get the data for. + + Returns + ------- + dict + The data to be saved to the dvc.yaml file. + + """ + return {self.dvc_option: self.get_files(instance)} + + def get_zntrack_data(self, instance: "Node") -> dict: + """Get the data to be saved to the zntrack.json file. + + Parameters + ---------- + instance : Node + The node instance to get the data for. + + Returns + ------- + dict + The data to be saved to the zntrack.json file. + + """ + try: + value = instance.__dict__[self.name] + except KeyError: + # Taken from DVCOption.save. + # TODO: Should we return an empty dict if getattr fails? + value = getattr(instance, self.name) + return {self.name: {"_type": _get_import_path(value), "value": value}} + def get_files(self, instance: "Node") -> list: """Get the files affected by this field. @@ -168,3 +206,9 @@ def __get__(self, instance: "Node", owner=None): class PlotsOption(PlotsMixin, DVCOption): """Field with DVC plots kwargs.""" + + +# TODO: How was this done previously? +def _get_import_path(obj): + obj_type = type(obj) + return f"{obj_type.__module__}.{obj_type.__qualname__}"