diff --git a/src/hydra_zen/__init__.py b/src/hydra_zen/__init__.py index 00d6e166..563c2e8b 100644 --- a/src/hydra_zen/__init__.py +++ b/src/hydra_zen/__init__.py @@ -9,6 +9,7 @@ to_yaml, ) from ._launch import launch +from ._like import like from ._version import get_versions from .structured_configs import ( ZenField, @@ -36,6 +37,7 @@ "ZenField", "make_custom_builds_fn", "launch", + "like", ] __version__ = get_versions()["version"] diff --git a/src/hydra_zen/_like.py b/src/hydra_zen/_like.py new file mode 100644 index 00000000..cce1d417 --- /dev/null +++ b/src/hydra_zen/_like.py @@ -0,0 +1,41 @@ +from typing import TypeVar, cast + +T = TypeVar("T") + + +class _Tracker: + def __init__(self, origin, tracks=None): + self.origin = origin + + self.tracker = [] if tracks is None else tracks.copy() + + def __repr__(self) -> str: + base = f"Like({repr(self.origin)})" + for item in self.tracker: + if isinstance(item, tuple): + _, args, kwargs = item + contents = "" + if args: + contents += ", ".join(repr(x) for x in args) + if kwargs: + if args: + contents += ", " + contents += ", ".join((f"{k}={v}" for k, v in kwargs.items())) + + base += f"({contents})" + else: + base += f".{item}" + return base + + def __call__(self, *args, **kwargs): + return _Tracker(self.origin, self.tracker + [("__call__", args, kwargs)]) + + def __getattr__(self, name): + # IPython will make attribute calls on objects; we don't want to track these. + if name != "_ipython_canary_method_should_not_exist_" and name != "__wrapped__": + return _Tracker(self.origin, self.tracker + [name]) + return self + + +def like(obj: T) -> T: + return cast(T, _Tracker(obj)) diff --git a/src/hydra_zen/structured_configs/_implementations.py b/src/hydra_zen/structured_configs/_implementations.py index 31b1bf6b..1f40cd45 100644 --- a/src/hydra_zen/structured_configs/_implementations.py +++ b/src/hydra_zen/structured_configs/_implementations.py @@ -60,6 +60,7 @@ ) from hydra_zen.typing._implementations import DataClass, HasTarget, _DataClass +from .._like import _Tracker from ._value_conversion import ZEN_VALUE_CONVERSION _T = TypeVar("_T") @@ -396,6 +397,20 @@ def wrapper(decorated_obj: Any) -> Any: return wrapper +def make_like(target, tracks): + from hydra._internal import utils as _hydra_internal_utils + + obj = _hydra_internal_utils._locate(target) + + for item in tracks: + if not isinstance(item, str): + _, args, kwargs = item + obj = obj(*args, **kwargs) + else: + obj = getattr(obj, item) + return obj + + def just(obj: Importable) -> Type[Just[Importable]]: """Returns a config that, when instantiated by Hydra, "just" returns the un-instantiated target-object. @@ -457,6 +472,9 @@ def just(obj: Importable) -> Type[Just[Importable]]: >>> conf.reduction_fn(conf.data) 6 """ + if isinstance(obj, _Tracker): + return builds(make_like, _utils.get_obj_path(obj.origin), obj.tracker) + try: obj_path = _utils.get_obj_path(obj) except AttributeError: