diff --git a/src/mrpro/utils/__init__.py b/src/mrpro/utils/__init__.py index ee66fc3c0..894cc4e59 100644 --- a/src/mrpro/utils/__init__.py +++ b/src/mrpro/utils/__init__.py @@ -10,11 +10,14 @@ from mrpro.utils.split_idx import split_idx from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view, reshape_broadcasted from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop +from mrpro.utils.getnested import getnestedattr, getnesteditem __all__ = [ "Indexer", "broadcast_right", "fill_range_", + "getnestedattr", + "getnesteditem", "reduce_view", "remove_repeat", "reshape_broadcasted", @@ -26,4 +29,4 @@ "unsqueeze_left", "unsqueeze_right", "zero_pad_or_crop" -] +] \ No newline at end of file diff --git a/src/mrpro/utils/getnested.py b/src/mrpro/utils/getnested.py new file mode 100644 index 000000000..9d1d61046 --- /dev/null +++ b/src/mrpro/utils/getnested.py @@ -0,0 +1,76 @@ +"""Get a nested attribute.""" + +from collections.abc import Mapping +from typing import TypeVar, cast, overload + +T = TypeVar('T') + + +@overload +def getnestedattr(obj: object, *attrs: str, default: None = ..., return_type: None = ...) -> object | None: ... +@overload +def getnestedattr(obj: object, *attrs: str, default: T = ..., return_type: None = ...) -> T: ... +@overload +def getnestedattr(obj: object, *attrs: str, default: None = ..., return_type: type[T] = ...) -> T | None: ... +@overload +def getnestedattr(obj: object, *attrs: str, default: T = ..., return_type: type[T] = ...) -> T: ... + + +def getnestedattr(obj: object, *attrs: str, default: T | None = None, return_type: type[T] | None = None) -> T | None: + """ + Get a nested attribute, or return a default if any step fails. + + Parameters + ---------- + obj + object to get attribute from + attrs + attribute names to get + default + value to return if any step fails + return_type + type to cast the result to (only for type hinting) + """ + if return_type is not None and default is not None and not isinstance(default, return_type): + raise TypeError('default must be of the same type as return_type') + for attr in attrs: + try: + obj = getattr(obj, attr) + except AttributeError: + return default + return cast(T, obj) + + +@overload +def getnesteditem(obj: Mapping, *items: str, default: None = ..., return_type: None = ...) -> object | None: ... +@overload +def getnesteditem(obj: Mapping, *items: str, default: T = ..., return_type: None = ...) -> T: ... +@overload +def getnesteditem(obj: Mapping, *items: str, default: None = ..., return_type: type[T] = ...) -> T | None: ... +@overload +def getnesteditem(obj: Mapping, *items: str, default: T = ..., return_type: type[T] = ...) -> T: ... + + +def getnesteditem(obj: Mapping, *items: str, default: T | None = None, return_type: type[T] | None = None) -> T | None: + """ + Get a nested item, or return a default if any step fails. + + Parameters + ---------- + obj + object to get attribute from + items + item names to get + default + value to return if any step fails + return_type + type to cast the result to (only for type hinting) + """ + if return_type is not None and default is not None and not isinstance(default, return_type): + raise TypeError('default must be of the same type as return_type') + for item in items: + try: + obj = obj[item] + except (KeyError, TypeError): + return default + return cast(T, obj) diff --git a/tests/utils/test_getnested.py b/tests/utils/test_getnested.py new file mode 100644 index 000000000..17773774b --- /dev/null +++ b/tests/utils/test_getnested.py @@ -0,0 +1,86 @@ +from dataclasses import dataclass, field +from typing import assert_type + +import pytest +from mrpro.utils import getnestedattr, getnesteditem + + +@dataclass +class C: + """Test class for getnestedattr.""" + + c: int = 1 + + +@dataclass +class B: + """Test class for getnestedattr.""" + + b: C = field(default_factory=C) + + +@dataclass +class A: + """Test class for getnestedattr.""" + + a: B = field(default_factory=B) + + +def test_getnestedattr_value() -> None: + """Test getnestedattr with a valid path.""" + obj = A() + actual = getnestedattr(obj, 'a', 'b', 'c') + assert actual == 1 + + +def test_getnestedattr_default() -> None: + """Test getnestedattr with a missing path and a default value.""" + obj = A() + actual = getnestedattr(obj, 'a', 'doesnotexist', 'c', default=2) + assert_type(actual, int) + assert actual == 2 + + +def test_getnestedattr_type() -> None: + """Test getnestedattr with a missing path no default value, but a return type.""" + obj = A() + actual = getnestedattr(obj, 'a', 'doesnotexist', 'c', return_type=int) + assert_type(actual, int | None) + assert actual is None + + +def test_getnestedattr_default_type_error() -> None: + """Test getnestedattr with a default value and a return type that do not match.""" + obj = A() + with pytest.raises(TypeError): + getnestedattr(obj, 'a', default=2, return_type=str) + + +def test_getnesteditem_value() -> None: + """Test getnesteditem with a valid path.""" + obj = {'a': {'b': {'c': 1}}} + actual = getnesteditem(obj, 'a', 'b', 'c') + assert actual == 1 + + +def test_getnesteditem_default() -> None: + """Test getnesteditem with a missing path and a default value.""" + obj = {'a': {'b': {'c': 1}}} + actual = getnesteditem(obj, 'a', 'doesnotexist', 'c', default=2) + assert_type(actual, int) + assert actual == 2 + + +def test_getnesteditem_type() -> None: + """Test getnesteditem with a missing path no default value, but a return type.""" + obj = {'a': {'b': {'c': 1}}} + actual = getnesteditem(obj, 'a', 'doesnotexist', 'c', return_type=int) + assert_type(actual, int | None) + assert actual is None + + +def test_getnesteditem_default_type_error() -> None: + """Test getnesteditem with a default value and a return type that do not match.""" + obj = {'a': {'b': {'c': 1}}} + with pytest.raises(TypeError): + getnesteditem(obj, 'a', default=2, return_type=str)