diff --git a/pyproject.toml b/pyproject.toml index 608261844..84eb412a7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,8 @@ urls = {Homepage = "https://github.com/zettaai/zetta_utils"} requires-python = ">3.8,<3.11" dependencies = [ "attrs >= 21.3", - "typeguard @ git+https://github.com/agronholm/typeguard.git@63b28186a9d00bb489f282f80b76834ac3d057fc", + "typeguard", + "typeguard @ git+https://github.com/agronholm/typeguard.git@f377be389765ed0db104b41d78fce3c45e72e149", "cachetools >= 5.2.0", "fsspec >= 2022.8.2", "rich >= 12.6.0", diff --git a/zetta_utils/builder/build.py b/zetta_utils/builder/build.py index c1efbac97..a73710491 100644 --- a/zetta_utils/builder/build.py +++ b/zetta_utils/builder/build.py @@ -2,7 +2,7 @@ from __future__ import annotations import json -from typing import Any, Callable, Final, Optional, Union +from typing import Any, Callable, Final import attrs from typeguard import typechecked @@ -22,8 +22,8 @@ @typechecked def build( - spec: Optional[Union[dict, list]] = None, - path: Optional[str] = None, + spec: dict | list | None = None, + path: str | None = None, ) -> Any: """Build an object from the given spec. @@ -171,9 +171,13 @@ class BuilderPartial: spec: dict[str, Any] _built_spec_kwargs: dict[str, Any] | None = attrs.field(init=False, default=None) + name: str | None = None + # name: str | None = None def get_display_name(self): # pragma: no cover # pretty print - if SPECIAL_KEYS["type"] in self.spec: + if self.name is not None: + return self.name + elif SPECIAL_KEYS["type"] in self.spec: return self.spec[SPECIAL_KEYS["type"]] else: return "BuilderPartial" diff --git a/zetta_utils/builder/built_in_registrations.py b/zetta_utils/builder/built_in_registrations.py index 83715fafd..7e9070674 100644 --- a/zetta_utils/builder/built_in_registrations.py +++ b/zetta_utils/builder/built_in_registrations.py @@ -11,7 +11,7 @@ @register("lambda", False) -def efficient_parse_lambda_str(lambda_str: str) -> Callable: +def efficient_parse_lambda_str(lambda_str: str, name: str | None = None) -> Callable: """Parses strings that are lambda functions""" if not isinstance(lambda_str, str): raise TypeError("`lambda_str` must be a string.") @@ -20,7 +20,7 @@ def efficient_parse_lambda_str(lambda_str: str) -> Callable: if len(lambda_str) > LAMBDA_STR_MAX_LENGTH: raise ValueError(f"`lambda_str` must be at most {LAMBDA_STR_MAX_LENGTH} characters.") - return BuilderPartial(spec={"@type": "invoke_lambda_str", "lambda_str": lambda_str}) + return BuilderPartial(spec={"@type": "invoke_lambda_str", "lambda_str": lambda_str}, name=name) @register("invoke_lambda_str", False) diff --git a/zetta_utils/layer/volumetric/layer_set/layer.py b/zetta_utils/layer/volumetric/layer_set/layer.py index 258b3b55e..279b6d42b 100644 --- a/zetta_utils/layer/volumetric/layer_set/layer.py +++ b/zetta_utils/layer/volumetric/layer_set/layer.py @@ -33,7 +33,7 @@ def __getitem__(self, idx: UserVolumetricIndex) -> dict[str, torch.Tensor]: return self.read_with_procs(idx=idx_backend) def __setitem__( - self, idx: UserVolumetricIndex, data: Mapping[str, torch.Tensor | int | float | bool] + self, idx: UserVolumetricIndex, data: Mapping[str, Union[torch.Tensor, int, float, bool]] ): idx_backend: VolumetricIndex | None = None idx_last: VolumetricIndex | None = None