Skip to content

Commit

Permalink
getnested
Browse files Browse the repository at this point in the history
ghstack-source-id: f75571db8f1c4b1588e754a91dcf7f341c8a31ee
ghstack-comment-id: 2564579214
Pull Request resolved: #594
  • Loading branch information
fzimmermann89 committed Dec 29, 2024
1 parent 51b4558 commit d767dce
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 1 deletion.
5 changes: 4 additions & 1 deletion src/mrpro/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
from mrpro.utils.split_idx import split_idx
from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view, reshape_broadcasted
from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop
from mrpro.utils.getnested import getnestedattr, getnesteditem

__all__ = [
"Indexer",
"broadcast_right",
"fill_range_",
"getnestedattr",
"getnesteditem",
"reduce_view",
"remove_repeat",
"reshape_broadcasted",
Expand All @@ -26,4 +29,4 @@
"unsqueeze_left",
"unsqueeze_right",
"zero_pad_or_crop"
]
]
76 changes: 76 additions & 0 deletions src/mrpro/utils/getnested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Get a nested attribute."""

from collections.abc import Mapping
from typing import TypeVar, cast, overload

T = TypeVar('T')


@overload
def getnestedattr(obj: object, *attrs: str, default: None = ..., return_type: None = ...) -> object | None: ...
@overload
def getnestedattr(obj: object, *attrs: str, default: T = ..., return_type: None = ...) -> T: ...
@overload
def getnestedattr(obj: object, *attrs: str, default: None = ..., return_type: type[T] = ...) -> T | None: ...
@overload
def getnestedattr(obj: object, *attrs: str, default: T = ..., return_type: type[T] = ...) -> T: ...


def getnestedattr(obj: object, *attrs: str, default: T | None = None, return_type: type[T] | None = None) -> T | None:
"""
Get a nested attribute, or return a default if any step fails.
Parameters
----------
obj
object to get attribute from
attrs
attribute names to get
default
value to return if any step fails
return_type
type to cast the result to (only for type hinting)
"""
if return_type is not None and default is not None and not isinstance(default, return_type):
raise TypeError('default must be of the same type as return_type')
for attr in attrs:
try:
obj = getattr(obj, attr)
except AttributeError:
return default
return cast(T, obj)


@overload
def getnesteditem(obj: Mapping, *items: str, default: None = ..., return_type: None = ...) -> object | None: ...
@overload
def getnesteditem(obj: Mapping, *items: str, default: T = ..., return_type: None = ...) -> T: ...
@overload
def getnesteditem(obj: Mapping, *items: str, default: None = ..., return_type: type[T] = ...) -> T | None: ...
@overload
def getnesteditem(obj: Mapping, *items: str, default: T = ..., return_type: type[T] = ...) -> T: ...


def getnesteditem(obj: Mapping, *items: str, default: T | None = None, return_type: type[T] | None = None) -> T | None:
"""
Get a nested item, or return a default if any step fails.
Parameters
----------
obj
object to get attribute from
items
item names to get
default
value to return if any step fails
return_type
type to cast the result to (only for type hinting)
"""
if return_type is not None and default is not None and not isinstance(default, return_type):
raise TypeError('default must be of the same type as return_type')
for item in items:
try:
obj = obj[item]
except (KeyError, TypeError):
return default
return cast(T, obj)
86 changes: 86 additions & 0 deletions tests/utils/test_getnested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from dataclasses import dataclass, field
from typing import assert_type

import pytest
from mrpro.utils import getnestedattr, getnesteditem


@dataclass
class C:
"""Test class for getnestedattr."""

c: int = 1


@dataclass
class B:
"""Test class for getnestedattr."""

b: C = field(default_factory=C)


@dataclass
class A:
"""Test class for getnestedattr."""

a: B = field(default_factory=B)


def test_getnestedattr_value() -> None:
"""Test getnestedattr with a valid path."""
obj = A()
actual = getnestedattr(obj, 'a', 'b', 'c')
assert actual == 1


def test_getnestedattr_default() -> None:
"""Test getnestedattr with a missing path and a default value."""
obj = A()
actual = getnestedattr(obj, 'a', 'doesnotexist', 'c', default=2)
assert_type(actual, int)
assert actual == 2


def test_getnestedattr_type() -> None:
"""Test getnestedattr with a missing path no default value, but a return type."""
obj = A()
actual = getnestedattr(obj, 'a', 'doesnotexist', 'c', return_type=int)
assert_type(actual, int | None)
assert actual is None


def test_getnestedattr_default_type_error() -> None:
"""Test getnestedattr with a default value and a return type that do not match."""
obj = A()
with pytest.raises(TypeError):
getnestedattr(obj, 'a', default=2, return_type=str)


def test_getnesteditem_value() -> None:
"""Test getnesteditem with a valid path."""
obj = {'a': {'b': {'c': 1}}}
actual = getnesteditem(obj, 'a', 'b', 'c')
assert actual == 1


def test_getnesteditem_default() -> None:
"""Test getnesteditem with a missing path and a default value."""
obj = {'a': {'b': {'c': 1}}}
actual = getnesteditem(obj, 'a', 'doesnotexist', 'c', default=2)
assert_type(actual, int)
assert actual == 2


def test_getnesteditem_type() -> None:
"""Test getnesteditem with a missing path no default value, but a return type."""
obj = {'a': {'b': {'c': 1}}}
actual = getnesteditem(obj, 'a', 'doesnotexist', 'c', return_type=int)
assert_type(actual, int | None)
assert actual is None


def test_getnesteditem_default_type_error() -> None:
"""Test getnesteditem with a default value and a return type that do not match."""
obj = {'a': {'b': {'c': 1}}}
with pytest.raises(TypeError):
getnesteditem(obj, 'a', default=2, return_type=str)

0 comments on commit d767dce

Please sign in to comment.