-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
ghstack-source-id: 234f87615364bbe56634289e84438a8ca15aab60 ghstack-comment-id: 2564579214 Pull Request resolved: #594
- Loading branch information
1 parent
c250f03
commit 2ebcc2b
Showing
3 changed files
with
166 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
"""Get a nested attribute.""" | ||
|
||
from collections.abc import Mapping | ||
from typing import TypeVar, cast, overload | ||
|
||
T = TypeVar('T') | ||
|
||
|
||
@overload | ||
def getnestedattr(obj: object, *attrs: str, default: None = ..., return_type: None = ...) -> object | None: ... | ||
@overload | ||
def getnestedattr(obj: object, *attrs: str, default: T = ..., return_type: None = ...) -> T: ... | ||
@overload | ||
def getnestedattr(obj: object, *attrs: str, default: None = ..., return_type: type[T] = ...) -> T | None: ... | ||
@overload | ||
def getnestedattr(obj: object, *attrs: str, default: T = ..., return_type: type[T] = ...) -> T: ... | ||
|
||
|
||
def getnestedattr(obj: object, *attrs: str, default: T | None = None, return_type: type[T] | None = None) -> T | None: | ||
""" | ||
Get a nested attribute, or return a default if any step fails. | ||
Parameters | ||
---------- | ||
obj | ||
object to get attribute from | ||
attrs | ||
attribute names to get | ||
default | ||
value to return if any step fails | ||
return_type | ||
type to cast the result to (only for type hinting) | ||
""" | ||
if return_type is not None and default is not None and not isinstance(default, return_type): | ||
raise TypeError('default must be of the same type as return_type') | ||
for attr in attrs: | ||
try: | ||
obj = getattr(obj, attr) | ||
except AttributeError: | ||
return default | ||
return cast(T, obj) | ||
|
||
|
||
@overload | ||
def getnesteditem(obj: Mapping, *items: str, default: None = ..., return_type: None = ...) -> object | None: ... | ||
@overload | ||
def getnesteditem(obj: Mapping, *items: str, default: T = ..., return_type: None = ...) -> T: ... | ||
@overload | ||
def getnesteditem(obj: Mapping, *items: str, default: None = ..., return_type: type[T] = ...) -> T | None: ... | ||
@overload | ||
def getnesteditem(obj: Mapping, *items: str, default: T = ..., return_type: type[T] = ...) -> T: ... | ||
|
||
|
||
def getnesteditem(obj: Mapping, *items: str, default: T | None = None, return_type: type[T] | None = None) -> T | None: | ||
""" | ||
Get a nested item, or return a default if any step fails. | ||
Parameters | ||
---------- | ||
obj | ||
object to get attribute from | ||
items | ||
item names to get | ||
default | ||
value to return if any step fails | ||
return_type | ||
type to cast the result to (only for type hinting) | ||
""" | ||
if return_type is not None and default is not None and not isinstance(default, return_type): | ||
raise TypeError('default must be of the same type as return_type') | ||
for item in items: | ||
try: | ||
obj = obj[item] | ||
except (KeyError, TypeError): | ||
return default | ||
return cast(T, obj) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
from dataclasses import dataclass, field | ||
from typing import assert_type | ||
|
||
import pytest | ||
from mrpro.utils import getnestedattr, getnesteditem | ||
|
||
|
||
@dataclass | ||
class C: | ||
"""Test class for getnestedattr.""" | ||
|
||
c: int = 1 | ||
|
||
|
||
@dataclass | ||
class B: | ||
"""Test class for getnestedattr.""" | ||
|
||
b: C = field(default_factory=C) | ||
|
||
|
||
@dataclass | ||
class A: | ||
"""Test class for getnestedattr.""" | ||
|
||
a: B = field(default_factory=B) | ||
|
||
|
||
def test_getnestedattr_value() -> None: | ||
"""Test getnestedattr with a valid path.""" | ||
obj = A() | ||
actual = getnestedattr(obj, 'a', 'b', 'c') | ||
assert actual == 1 | ||
|
||
|
||
def test_getnestedattr_default() -> None: | ||
"""Test getnestedattr with a missing path and a default value.""" | ||
obj = A() | ||
actual = getnestedattr(obj, 'a', 'doesnotexist', 'c', default=2) | ||
assert_type(actual, int) | ||
assert actual == 2 | ||
|
||
|
||
def test_getnestedattr_type() -> None: | ||
"""Test getnestedattr with a missing path no default value, but a return type.""" | ||
obj = A() | ||
actual = getnestedattr(obj, 'a', 'doesnotexist', 'c', return_type=int) | ||
assert_type(actual, int | None) | ||
assert actual is None | ||
|
||
|
||
def test_getnestedattr_default_type_error() -> None: | ||
"""Test getnestedattr with a default value and a return type that do not match.""" | ||
obj = A() | ||
with pytest.raises(TypeError): | ||
getnestedattr(obj, 'a', default=2, return_type=str) | ||
|
||
|
||
def test_getnesteditem_value() -> None: | ||
"""Test getnesteditem with a valid path.""" | ||
obj = {'a': {'b': {'c': 1}}} | ||
actual = getnesteditem(obj, 'a', 'b', 'c') | ||
assert actual == 1 | ||
|
||
|
||
def test_getnesteditem_default() -> None: | ||
"""Test getnesteditem with a missing path and a default value.""" | ||
obj = {'a': {'b': {'c': 1}}} | ||
actual = getnesteditem(obj, 'a', 'doesnotexist', 'c', default=2) | ||
assert_type(actual, int) | ||
assert actual == 2 | ||
|
||
|
||
def test_getnesteditem_type() -> None: | ||
"""Test getnesteditem with a missing path no default value, but a return type.""" | ||
obj = {'a': {'b': {'c': 1}}} | ||
actual = getnesteditem(obj, 'a', 'doesnotexist', 'c', return_type=int) | ||
assert_type(actual, int | None) | ||
assert actual is None | ||
|
||
|
||
def test_getnesteditem_default_type_error() -> None: | ||
"""Test getnesteditem with a default value and a return type that do not match.""" | ||
obj = {'a': {'b': {'c': 1}}} | ||
with pytest.raises(TypeError): | ||
getnesteditem(obj, 'a', default=2, return_type=str) |