
Commit

Update
[ghstack-poisoned]
fzimmermann89 committed Jan 5, 2025
2 parents e790db9 + 15adfb4 commit 57414b9
Showing 3 changed files with 133 additions and 45 deletions.
105 changes: 76 additions & 29 deletions src/mrpro/data/CheckDataMixin.py
@@ -60,7 +60,7 @@ class FieldTypeError(RuntimeCheckError):


@lru_cache
def _parse_string_to_shape_specification(dim_str: str) -> tuple[tuple[_DimType, ...], int | None]:
def string_to_shape_specification(dim_str: str) -> tuple[tuple[_DimType, ...], int | None]:
"""
Parse a string representation of a shape specification into dimension types.
@@ -134,10 +134,10 @@ def _parse_string_to_shape_specification(dim_str: str) -> tuple[tuple[_DimType,
return tuple(dims), index_variadic


def _shape_specification_to_string(dims: tuple[_DimType, ...]) -> str:
def shape_specification_to_string(dims: tuple[_DimType, ...]) -> str:
"""Convert a shape specification to a string.
The inverse of `_parse_string_to_shape_specification`.
The inverse of `string_to_shape_specification`.
Parameters
----------
@@ -323,7 +323,7 @@ def _check_named_variadic_dim(


@lru_cache
def _parse_string_to_size(shape_string: str, memo: ShapeMemo) -> tuple[int, ...]:
def string_to_size(shape_string: str, memo: ShapeMemo) -> tuple[int, ...]:
"""Convert a string representation of a shape to a shape.
Parameters
@@ -340,7 +340,7 @@ def _parse_string_to_size(shape_string: str, memo: ShapeMemo) -> tuple[int, ...]
"""
shape: list[int] = []
for dim in _parse_string_to_shape_specification(shape_string)[0]:
for dim in string_to_shape_specification(shape_string)[0]:
if isinstance(dim, _FixedDim):
shape.append(int(dim.name))
elif isinstance(dim, _NamedDim):
@@ -422,7 +422,7 @@ def __init__(self, dtype: torch.dtype | Sequence[torch.dtype] | None = None, sha
For more information, see `jaxtyping <https://docs.kidger.site/jaxtyping/api/array/>`__
"""
if shape is not None:
self.shape, self.index_variadic = _parse_string_to_shape_specification(shape)
self.shape, self.index_variadic = string_to_shape_specification(shape)
else:
self.shape = None
if dtype is not None:
@@ -452,7 +452,7 @@ def __repr__(self) -> str:
"""Get a string representation of the annotation."""
arguments = []
if self.shape is not None:
arguments.append(f"shape='{_shape_specification_to_string(self.shape) }'")
arguments.append(f"shape='{shape_specification_to_string(self.shape) }'")
if self.dtype is not None:
arguments.append(f'dtype={self.dtype}')
representation = f"Annotation({', '.join(arguments)})"
@@ -621,28 +621,7 @@ def check_invariants(self, recurse: bool = False) -> None:
name = elem.name
expected_type = elem.type
value = getattr(self, name)
if get_origin(expected_type) is Annotated:
expected_type, *annotations = get_args(expected_type)
else:
annotations = []
if get_origin(expected_type) is Union:
expected_type = reduce(lambda a, b: a | b, get_args(expected_type))
if not isinstance(expected_type, type | UnionType):
raise TypeError(
f'Expected a type, got {type(expected_type)}. This could be caused by __future__.annotations'
)
if not isinstance(value, expected_type):
raise FieldTypeError(f'Expected {expected_type} for {name}, got {type(value)}')
for annotation in annotations:
# there could be other annotations not related to the shape and dtype
if isinstance(annotation, Annotation):
try:
memo = annotation.check(value, memo=memo, strict=False)
except RuntimeCheckError as e:
raise type(e)(
f'Dataclass invariant violated for {self.__class__.__name__}.{name}: {e}\n {annotation}.'
) from None

memo = runtime_check(value, expected_type, memo)
if recurse and isinstance(value, CheckDataMixin):
value.check_invariants(recurse=True)
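
As an orientation aid (not part of the diff): with this change, field validation is delegated to the new runtime_check helper added below. A minimal sketch of re-validating an existing instance, assuming obj is any CheckDataMixin dataclass instance (the name obj is hypothetical):

    obj.check_invariants()              # re-check the type, dtype and shape invariants of all fields
    obj.check_invariants(recurse=True)  # additionally re-check nested CheckDataMixin fields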

@@ -707,3 +686,71 @@ def __exit__(
self.check()
except RuntimeCheckError as e:
raise type(e)(e) from None


def runtime_check(value: object, type_hint: type | str | UnionType, memo: ShapeMemo | None = None) -> ShapeMemo | None:
"""Perform runtime type, dtype and shape checks on a value.
Parameters
----------
value
the object to check
type_hint
the type hint to check against
memo
a memoization object storing the shape of named dimensions
from previous checks. Will not be modified.
Returns
-------
memo
a new memoization object storing the shape of named dimensions from
previous and the current check.
Raises
------
TypeError
If the type hint is not a type or string that can be evaluated to a type.
FieldTypeError
If the value does not match the type hint.
DtypeError
If the value does not match the dtype hint.
ShapeError
If the value does not match the shape hint.
RuntimeCheckError
On union types, if all options failed. Contains a list of the exceptions
that were raised for the options.
"""
if isinstance(type_hint, str): # stringified type hint
try:
type_hint = eval(type_hint) # noqa: S307
except: # noqa: E722
raise TypeError(
f'Cannot evaluate type hint string {type_hint}. Consider removing stringified type annotations.'
) from None

if get_origin(type_hint) is Union or isinstance(type_hint, UnionType):
exceptions = []
for option in get_args(type_hint):
try:
# recursive call to check the value against the option
# if an option matches, we exit
return runtime_check(value, option, memo)
except RuntimeCheckError as e:
exceptions.append((option, e))
raise RuntimeCheckError(
'None of the options matched:\n'
+ '\n'.join(f' - {t} failed due to an {e.__class__.__name__}: {e}' for t, e in exceptions)
)
if get_origin(type_hint) is Annotated:
type_hint, *annotations = get_args(type_hint)
memo = runtime_check(value, type_hint, memo)
for annotation in annotations:
if isinstance(annotation, Annotation): # there might be other annotations
memo = annotation.check(value, memo=memo)
return memo
if not isinstance(type_hint, type):
raise TypeError(f'Expected a type, got {type(type_hint)}')
if not isinstance(value, type_hint):
raise FieldTypeError(f'Expected {type_hint.__name__}, got {type(value).__name__}')
return memo
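
To illustrate the new helper, here is a small usage sketch. It is not part of the diff and assumes that runtime_check, Annotation and RuntimeCheckError are importable from mrpro.data.CheckDataMixin, as in the test module further below:

    from typing import Annotated

    import torch

    from mrpro.data.CheckDataMixin import Annotation, RuntimeCheckError, runtime_check

    hint = Annotated[torch.Tensor, Annotation(dtype=torch.float32, shape='coil k2 k1 k0')] | None

    memo = runtime_check(None, hint)                          # None matches the second union option
    memo = runtime_check(torch.ones(8, 4, 4, 4), hint, memo)  # binds coil=8, k2=4, k1=4, k0=4
    try:
        runtime_check(torch.ones(2, 4, 4, 4), hint, memo)     # coil=2 contradicts the memoized coil=8
    except RuntimeCheckError as error:
        print(error)                                          # lists why each union option failed

The returned memo threads the sizes of named dimensions from one check into the next, which is what allows consistency checks across several fields of a dataclass.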
2 changes: 1 addition & 1 deletion src/mrpro/data/MoveDataMixin.py
@@ -391,7 +391,7 @@ def device(self) -> torch.device | None:
Looks at each field of a dataclass implementing a device attribute,
such as torch.Tensors or MoveDataMixin instances. If the devices
of the fields differ, an InconsistentDeviceError is raised, otherwise
of the fields iffer, an InconsistentDeviceError is raised, otherwise
the device is returned. If no field implements a device attribute,
None is returned.
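
For context (not part of the diff), a minimal sketch of the behaviour this docstring describes, assuming MoveDataMixin is importable from mrpro.data and is used on an ordinary dataclass, as elsewhere in mrpro; the Pair class is hypothetical:

    from dataclasses import dataclass

    import torch

    from mrpro.data import MoveDataMixin

    @dataclass
    class Pair(MoveDataMixin):
        a: torch.Tensor
        b: torch.Tensor

    pair = Pair(torch.ones(2), torch.zeros(2))
    print(pair.device)  # both fields live on the CPU, so torch.device('cpu') is returned
    # Fields on different devices would make reading .device raise InconsistentDeviceError;
    # a dataclass without any device-carrying fields would return None instead.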
71 changes: 56 additions & 15 deletions tests/data/test_checkdatamixin.py
@@ -7,16 +7,16 @@
Annotation,
CheckDataMixin,
DtypeError,
FieldTypeError,
RuntimeCheckError,
ShapeError,
ShapeMemo,
SpecificationError,
SuspendDataChecks,
_FixedDim,
_NamedDim,
_parse_string_to_shape_specification,
_parse_string_to_size,
_shape_specification_to_string,
string_to_shape_specification,
string_to_size,
shape_specification_to_string,
)


@@ -70,6 +70,17 @@ class WithOptional(CheckDataMixin):
Union[None, torch.Tensor], Annotation(dtype=(torch.float32, torch.float64), shape='*#other coil #k2 #k1 #k0') # noqa: UP007
] = None
integer: int | None = None
outer_or_tensor: (
Annotated[torch.Tensor, Annotation(dtype=(torch.float32, torch.float64), shape='*#other coil #k2 #k1 #k0')]
| None
) = None
outer_optional_tensor: Optional[ # noqa: UP007
Annotated[torch.Tensor, Annotation(dtype=(torch.float32, torch.float64), shape='*#other coil #k2 #k1 #k0')]
] = None
outer_union_tensor: Union[ # noqa: UP007
None,
Annotated[torch.Tensor, Annotation(dtype=(torch.float32, torch.float64), shape='*#other coil #k2 #k1 #k0')],
] = None


def test_slots() -> None:
@@ -87,19 +87,49 @@ def test_frozen() -> None:
def test_optional() -> None:
"""Test dataclass with None-able attributes"""
WithOptional()
WithOptional(torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), 1)
WithOptional(
torch.ones(1, 1, 1, 1),
torch.ones(1, 1, 1, 1),
torch.ones(1, 1, 1, 1),
torch.ones(1, 1, 1, 1),
1,
torch.ones(1, 1, 1, 1),
torch.ones(1, 1, 1, 1),
torch.ones(1, 1, 1, 1),
)


def test_optional_fail() -> None:
"""Test exceptions with dataclass with None-able attributes"""
with pytest.raises(ShapeError):
with pytest.raises(RuntimeCheckError):
WithOptional(None, torch.ones(1))
with pytest.raises(ShapeError):
with pytest.raises(RuntimeCheckError):
WithOptional(None, None, torch.ones(1))
with pytest.raises(ShapeError):
with pytest.raises(RuntimeCheckError):
WithOptional(None, None, None, torch.ones(1))
with pytest.raises(FieldTypeError):
with pytest.raises(RuntimeCheckError):
WithOptional(None, None, None, None, None, torch.ones(1))
with pytest.raises(RuntimeCheckError):
WithOptional(None, None, None, None, None, None, torch.ones(1))
with pytest.raises(RuntimeCheckError):
WithOptional(None, None, None, None, None, None, None, torch.ones(1))

with pytest.raises(RuntimeCheckError):
WithOptional('not a tensor') # type:ignore[arg-type]
with pytest.raises(RuntimeCheckError):
WithOptional(None, 'not a tensor') # type:ignore[arg-type]
with pytest.raises(RuntimeCheckError):
WithOptional(None, None, 'not a tensor') # type:ignore[arg-type]
with pytest.raises(RuntimeCheckError):
WithOptional(None, None, None, 'not a tensor') # type:ignore[arg-type]
with pytest.raises(RuntimeCheckError):
WithOptional(None, None, None, None, 'not an integer') # type:ignore[arg-type]
with pytest.raises(RuntimeCheckError):
WithOptional(None, None, None, None, None, 'not a tensor') # type:ignore[arg-type]
with pytest.raises(RuntimeCheckError):
WithOptional(None, None, None, None, None, None, 'not a tensor') # type:ignore[arg-type]
with pytest.raises(RuntimeCheckError):
WithOptional(None, None, None, None, None, None, None, 'not a tensor') # type:ignore[arg-type]


def test_checked_dataclass_success() -> None:
@@ -221,7 +262,7 @@ def test_dype_fail() -> None:
)
def test_parse_shape(string: str, expected: tuple) -> None:
"""Test parsing of shape string"""
parsed = _parse_string_to_shape_specification(string)
parsed = string_to_shape_specification(string)
assert parsed == expected


@@ -237,7 +278,7 @@ def test_parse_shape(string: str, expected: tuple) -> None:
)
def test_specification_to_string(expected: str, shape: tuple) -> None:
"""Test conversion of parsed specification back to a string"""
string = _shape_specification_to_string(shape)
string = shape_specification_to_string(shape)
assert string == expected


@@ -248,12 +289,12 @@ def test_string_to_shape() -> None:
memo = instance._memo # type:ignore[attr-defined]
memo = memo | {'fromdict': 8}
memo = memo | ShapeMemo(frommemo=9)
shape = _parse_string_to_size('fromdict frommemo fixed=3 dim 1', memo)
shape = string_to_size('fromdict frommemo fixed=3 dim 1', memo)
assert shape == (8, 9, 3, 2, 1)

with pytest.raises(KeyError):
_parse_string_to_size('doesnotexist', memo)
string_to_size('doesnotexist', memo)
with pytest.raises(KeyError):
_parse_string_to_size('*doesnotexist', memo)
string_to_size('*doesnotexist', memo)
with pytest.raises(SpecificationError):
_parse_string_to_size('...', memo)
string_to_size('...', memo)
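
Finally, a usage sketch of the pattern these tests exercise. It is not part of the diff; the import path and the assumption that subclassing CheckDataMixin yields a checked dataclass follow the WithOptional fixture above, and the MaybeImage class is hypothetical:

    from typing import Annotated

    import torch

    from mrpro.data.CheckDataMixin import Annotation, CheckDataMixin, RuntimeCheckError

    class MaybeImage(CheckDataMixin):
        """Image data that may be absent, checked at construction."""

        data: Annotated[torch.Tensor, Annotation(dtype=torch.float32, shape='coil k2 k1 k0')] | None = None

    MaybeImage()                        # None is a valid option of the outer union
    MaybeImage(torch.ones(8, 4, 4, 4))  # float32 tensor with a matching shape: passes
    try:
        MaybeImage(torch.ones(8, 4, 4, 4).int())  # wrong dtype: no union option matches
    except RuntimeCheckError as error:
        print(error)                    # explains why each option failed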
