Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

getnested #594

Draft
wants to merge 5 commits into
base: gh/fzimmermann89/37/head
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion src/mrpro/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,14 @@
from mrpro.utils.split_idx import split_idx
from mrpro.utils.reshape import broadcast_right, unsqueeze_left, unsqueeze_right, reduce_view, reshape_broadcasted
from mrpro.utils.zero_pad_or_crop import zero_pad_or_crop
from mrpro.utils.getnested import getnestedattr, getnesteditem

__all__ = [
"Indexer",
"broadcast_right",
"fill_range_",
"getnestedattr",
"getnesteditem",
"reduce_view",
"remove_repeat",
"reshape_broadcasted",
Expand All @@ -26,4 +29,4 @@
"unsqueeze_left",
"unsqueeze_right",
"zero_pad_or_crop"
]
]
76 changes: 76 additions & 0 deletions src/mrpro/utils/getnested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Get a nested attribute."""

from collections.abc import Mapping
from typing import TypeVar, cast, overload

T = TypeVar('T')


@overload
def getnestedattr(obj: object, *attrs: str, default: None = ..., return_type: None = ...) -> object | None: ...
@overload
def getnestedattr(obj: object, *attrs: str, default: T = ..., return_type: None = ...) -> T: ...
@overload
def getnestedattr(obj: object, *attrs: str, default: None = ..., return_type: type[T] = ...) -> T | None: ...
@overload
def getnestedattr(obj: object, *attrs: str, default: T = ..., return_type: type[T] = ...) -> T: ...


def getnestedattr(obj: object, *attrs: str, default: T | None = None, return_type: type[T] | None = None) -> T | None:
"""
Get a nested attribute, or return a default if any step fails.

Parameters
----------
obj
object to get attribute from
attrs
attribute names to get
default
value to return if any step fails
return_type
type to cast the result to (only for type hinting)
"""
if return_type is not None and default is not None and not isinstance(default, return_type):
raise TypeError('default must be of the same type as return_type')
for attr in attrs:
try:
obj = getattr(obj, attr)
except AttributeError:
return default
return cast(T, obj)


@overload
def getnesteditem(obj: Mapping, *items: str, default: None = ..., return_type: None = ...) -> object | None: ...
@overload
def getnesteditem(obj: Mapping, *items: str, default: T = ..., return_type: None = ...) -> T: ...
@overload
def getnesteditem(obj: Mapping, *items: str, default: None = ..., return_type: type[T] = ...) -> T | None: ...
@overload
def getnesteditem(obj: Mapping, *items: str, default: T = ..., return_type: type[T] = ...) -> T: ...


def getnesteditem(obj: Mapping, *items: str, default: T | None = None, return_type: type[T] | None = None) -> T | None:
"""
Get a nested item, or return a default if any step fails.

Parameters
----------
obj
object to get attribute from
items
item names to get
default
value to return if any step fails
return_type
type to cast the result to (only for type hinting)
"""
if return_type is not None and default is not None and not isinstance(default, return_type):
raise TypeError('default must be of the same type as return_type')
for item in items:
try:
obj = obj[item]
except (KeyError, TypeError):
return default
return cast(T, obj)
86 changes: 86 additions & 0 deletions tests/utils/test_getnested.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from dataclasses import dataclass, field
from typing import assert_type

import pytest
from mrpro.utils import getnestedattr, getnesteditem


@dataclass
class C:
"""Test class for getnestedattr."""

c: int = 1


@dataclass
class B:
"""Test class for getnestedattr."""

b: C = field(default_factory=C)


@dataclass
class A:
"""Test class for getnestedattr."""

a: B = field(default_factory=B)


def test_getnestedattr_value() -> None:
"""Test getnestedattr with a valid path."""
obj = A()
actual = getnestedattr(obj, 'a', 'b', 'c')
assert actual == 1


def test_getnestedattr_default() -> None:
"""Test getnestedattr with a missing path and a default value."""
obj = A()
actual = getnestedattr(obj, 'a', 'doesnotexist', 'c', default=2)
assert_type(actual, int)
assert actual == 2


def test_getnestedattr_type() -> None:
"""Test getnestedattr with a missing path no default value, but a return type."""
obj = A()
actual = getnestedattr(obj, 'a', 'doesnotexist', 'c', return_type=int)
assert_type(actual, int | None)
assert actual is None


def test_getnestedattr_default_type_error() -> None:
"""Test getnestedattr with a default value and a return type that do not match."""
obj = A()
with pytest.raises(TypeError):
getnestedattr(obj, 'a', default=2, return_type=str)


def test_getnesteditem_value() -> None:
"""Test getnesteditem with a valid path."""
obj = {'a': {'b': {'c': 1}}}
actual = getnesteditem(obj, 'a', 'b', 'c')
assert actual == 1


def test_getnesteditem_default() -> None:
"""Test getnesteditem with a missing path and a default value."""
obj = {'a': {'b': {'c': 1}}}
actual = getnesteditem(obj, 'a', 'doesnotexist', 'c', default=2)
assert_type(actual, int)
assert actual == 2


def test_getnesteditem_type() -> None:
"""Test getnesteditem with a missing path no default value, but a return type."""
obj = {'a': {'b': {'c': 1}}}
actual = getnesteditem(obj, 'a', 'doesnotexist', 'c', return_type=int)
assert_type(actual, int | None)
assert actual is None


def test_getnesteditem_default_type_error() -> None:
"""Test getnesteditem with a default value and a return type that do not match."""
obj = {'a': {'b': {'c': 1}}}
with pytest.raises(TypeError):
getnesteditem(obj, 'a', default=2, return_type=str)
Loading