From 15adfb456fc04d5328e0c98f15c4c802d1dab9ba Mon Sep 17 00:00:00 2001 From: Felix Zimmermann Date: Mon, 6 Jan 2025 00:47:29 +0100 Subject: [PATCH] Update [ghstack-poisoned] --- src/mrpro/data/CheckDataMixin.py | 105 +++++++++++++++++++++--------- src/mrpro/data/MoveDataMixin.py | 2 +- tests/data/test_checkdatamixin.py | 71 +++++++++++++++----- 3 files changed, 133 insertions(+), 45 deletions(-) diff --git a/src/mrpro/data/CheckDataMixin.py b/src/mrpro/data/CheckDataMixin.py index ba42c50e..a87f6889 100644 --- a/src/mrpro/data/CheckDataMixin.py +++ b/src/mrpro/data/CheckDataMixin.py @@ -60,7 +60,7 @@ class FieldTypeError(RuntimeCheckError): @lru_cache -def _parse_string_to_shape_specification(dim_str: str) -> tuple[tuple[_DimType, ...], int | None]: +def string_to_shape_specification(dim_str: str) -> tuple[tuple[_DimType, ...], int | None]: """ Parse a string representation of a shape specification into dimension types. @@ -134,10 +134,10 @@ def _parse_string_to_shape_specification(dim_str: str) -> tuple[tuple[_DimType, return tuple(dims), index_variadic -def _shape_specification_to_string(dims: tuple[_DimType, ...]) -> str: +def shape_specification_to_string(dims: tuple[_DimType, ...]) -> str: """Convert a shape specification to a string. - The inverse of `_parse_string_to_shape_specification`. + The inverse of `string_to_shape_specification`. Parameters ---------- @@ -323,7 +323,7 @@ def _check_named_variadic_dim( @lru_cache -def _parse_string_to_size(shape_string: str, memo: ShapeMemo) -> tuple[int, ...]: +def string_to_size(shape_string: str, memo: ShapeMemo) -> tuple[int, ...]: """Convert a string representation of a shape to a shape. Parameters @@ -340,7 +340,7 @@ def _parse_string_to_size(shape_string: str, memo: ShapeMemo) -> tuple[int, ...] """ shape: list[int] = [] - for dim in _parse_string_to_shape_specification(shape_string)[0]: + for dim in string_to_shape_specification(shape_string)[0]: if isinstance(dim, _FixedDim): shape.append(int(dim.name)) elif isinstance(dim, _NamedDim): @@ -422,7 +422,7 @@ def __init__(self, dtype: torch.dtype | Sequence[torch.dtype] | None = None, sha For more information, see `jaxtyping `__ """ if shape is not None: - self.shape, self.index_variadic = _parse_string_to_shape_specification(shape) + self.shape, self.index_variadic = string_to_shape_specification(shape) else: self.shape = None if dtype is not None: @@ -452,7 +452,7 @@ def __repr__(self) -> str: """Get a string representation of the annotation.""" arguments = [] if self.shape is not None: - arguments.append(f"shape='{_shape_specification_to_string(self.shape) }'") + arguments.append(f"shape='{shape_specification_to_string(self.shape) }'") if self.dtype is not None: arguments.append(f'dtype={self.dtype}') representation = f"Annotation({', '.join(arguments)})" @@ -621,28 +621,7 @@ def check_invariants(self, recurse: bool = False) -> None: name = elem.name expected_type = elem.type value = getattr(self, name) - if get_origin(expected_type) is Annotated: - expected_type, *annotations = get_args(expected_type) - else: - annotations = [] - if get_origin(expected_type) is Union: - expected_type = reduce(lambda a, b: a | b, get_args(expected_type)) - if not isinstance(expected_type, type | UnionType): - raise TypeError( - f'Expected a type, got {type(expected_type)}. This could be caused by __future__.annotations' - ) - if not isinstance(value, expected_type): - raise FieldTypeError(f'Expected {expected_type} for {name}, got {type(value)}') - for annotation in annotations: - # there could be other annotations not related to the shape and dtype - if isinstance(annotation, Annotation): - try: - memo = annotation.check(value, memo=memo, strict=False) - except RuntimeCheckError as e: - raise type(e)( - f'Dataclass invariant violated for {self.__class__.__name__}.{name}: {e}\n {annotation}.' - ) from None - + memo = runtime_check(value, expected_type, memo) if recurse and isinstance(value, CheckDataMixin): value.check_invariants(recurse=True) @@ -707,3 +686,71 @@ def __exit__( self.check() except RuntimeCheckError as e: raise type(e)(e) from None + + +def runtime_check(value: object, type_hint: type | str | UnionType, memo: ShapeMemo | None = None) -> ShapeMemo | None: + """Perform runtime type, dtype and shape checks on a value. + + Parameters + ---------- + value + the object to check + type_hint + the type hint to check against + memo + a memoization object storing the shape of named dimensions + from previous checks. Will not be modified. + + Returns + ------- + memo + a new memoization object storing the shape of named dimensions from + previous and the current check. + + Raises + ------ + TypeError + If the type hint is not a type or string that can be evaluated to a type. + FieldTypeError + If the value does not match the type hint. + DtypeError + If the value does not match the dtype hint. + ShapeError + If the value does not match the shape hint. + RuntimeCheckError + On union types, if all options failed. Contains a list of the exceptions + that were raised for the options. + """ + if isinstance(type_hint, str): # stringified type hint + try: + type_hint = eval(type_hint) # noqa: S307 + except: # noqa: E722 + raise TypeError( + f'Cannot evaluate type hint string {type_hint}. Consider removing stringified type annotations.' + ) from None + + if get_origin(type_hint) is Union or isinstance(type_hint, UnionType): + exceptions = [] + for option in get_args(type_hint): + try: + # recursive call to check the value against the option + # if an option matches, we exit + return runtime_check(value, option, memo) + except RuntimeCheckError as e: + exceptions.append((option, e)) + raise RuntimeCheckError( + 'None of the options matched:\n' + + '\n'.join(f' - {t} failed due to an {e.__class__.__name__}: {e}' for t, e in exceptions) + ) + if get_origin(type_hint) is Annotated: + type_hint, *annotations = get_args(type_hint) + memo = runtime_check(value, type_hint, memo) + for annotation in annotations: + if isinstance(annotation, Annotation): # there might be other annotations + memo = annotation.check(value, memo=memo) + return memo + if not isinstance(type_hint, type): + raise TypeError(f'Expected a type, got {type(type_hint)}') + if not isinstance(value, type_hint): + raise FieldTypeError(f'Expected {type_hint.__name__}, got {type(value).__name__}') + return memo diff --git a/src/mrpro/data/MoveDataMixin.py b/src/mrpro/data/MoveDataMixin.py index 99bcb3df..6f4a59d4 100644 --- a/src/mrpro/data/MoveDataMixin.py +++ b/src/mrpro/data/MoveDataMixin.py @@ -391,7 +391,7 @@ def device(self) -> torch.device | None: Looks at each field of a dataclass implementing a device attribute, such as torch.Tensors or MoveDataMixin instances. If the devices - of the fields differ, an InconsistentDeviceError is raised, otherwise + of the fields iffer, an InconsistentDeviceError is raised, otherwise the device is returned. If no field implements a device attribute, None is returned. diff --git a/tests/data/test_checkdatamixin.py b/tests/data/test_checkdatamixin.py index d01aacb6..06025457 100644 --- a/tests/data/test_checkdatamixin.py +++ b/tests/data/test_checkdatamixin.py @@ -7,16 +7,16 @@ Annotation, CheckDataMixin, DtypeError, - FieldTypeError, + RuntimeCheckError, ShapeError, ShapeMemo, SpecificationError, SuspendDataChecks, _FixedDim, _NamedDim, - _parse_string_to_shape_specification, - _parse_string_to_size, - _shape_specification_to_string, + string_to_shape_specification, + string_to_size, + shape_specification_to_string, ) @@ -70,6 +70,17 @@ class WithOptional(CheckDataMixin): Union[None, torch.Tensor], Annotation(dtype=(torch.float32, torch.float64), shape='*#other coil #k2 #k1 #k0') # noqa: UP007 ] = None integer: int | None = None + outer_or_tensor: ( + Annotated[torch.Tensor, Annotation(dtype=(torch.float32, torch.float64), shape='*#other coil #k2 #k1 #k0')] + | None + ) = None + outer_optional_tensor: Optional[ # noqa: UP007 + Annotated[torch.Tensor, Annotation(dtype=(torch.float32, torch.float64), shape='*#other coil #k2 #k1 #k0')] + ] = None + outer_union_tensor: Union[ # noqa: UP007 + None, + Annotated[torch.Tensor, Annotation(dtype=(torch.float32, torch.float64), shape='*#other coil #k2 #k1 #k0')], + ] = None def test_slots() -> None: @@ -87,19 +98,49 @@ def test_frozen() -> None: def test_optional() -> None: """Test dataclass with None-able attributes""" WithOptional() - WithOptional(torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), 1) + WithOptional( + torch.ones(1, 1, 1, 1), + torch.ones(1, 1, 1, 1), + torch.ones(1, 1, 1, 1), + torch.ones(1, 1, 1, 1), + 1, + torch.ones(1, 1, 1, 1), + torch.ones(1, 1, 1, 1), + torch.ones(1, 1, 1, 1), + ) def test_optional_fail() -> None: """Test exceptions with dataclass with None-able attributes""" - with pytest.raises(ShapeError): + with pytest.raises(RuntimeCheckError): WithOptional(None, torch.ones(1)) - with pytest.raises(ShapeError): + with pytest.raises(RuntimeCheckError): WithOptional(None, None, torch.ones(1)) - with pytest.raises(ShapeError): + with pytest.raises(RuntimeCheckError): WithOptional(None, None, None, torch.ones(1)) - with pytest.raises(FieldTypeError): + with pytest.raises(RuntimeCheckError): + WithOptional(None, None, None, None, None, torch.ones(1)) + with pytest.raises(RuntimeCheckError): + WithOptional(None, None, None, None, None, None, torch.ones(1)) + with pytest.raises(RuntimeCheckError): + WithOptional(None, None, None, None, None, None, None, torch.ones(1)) + + with pytest.raises(RuntimeCheckError): + WithOptional('not a tensor') # type:ignore[arg-type] + with pytest.raises(RuntimeCheckError): + WithOptional(None, 'not a tensor') # type:ignore[arg-type] + with pytest.raises(RuntimeCheckError): + WithOptional(None, None, 'not a tensor') # type:ignore[arg-type] + with pytest.raises(RuntimeCheckError): + WithOptional(None, None, None, 'not a tensor') # type:ignore[arg-type] + with pytest.raises(RuntimeCheckError): WithOptional(None, None, None, None, 'not an integer') # type:ignore[arg-type] + with pytest.raises(RuntimeCheckError): + WithOptional(None, None, None, None, None, 'not a tensor') # type:ignore[arg-type] + with pytest.raises(RuntimeCheckError): + WithOptional(None, None, None, None, None, None, 'not a tensor') # type:ignore[arg-type] + with pytest.raises(RuntimeCheckError): + WithOptional(None, None, None, None, None, None, None, 'not a tensor') # type:ignore[arg-type] def test_checked_dataclass_success() -> None: @@ -221,7 +262,7 @@ def test_dype_fail() -> None: ) def test_parse_shape(string: str, expected: tuple) -> None: """Test parsing of shape string""" - parsed = _parse_string_to_shape_specification(string) + parsed = string_to_shape_specification(string) assert parsed == expected @@ -237,7 +278,7 @@ def test_parse_shape(string: str, expected: tuple) -> None: ) def test_specification_to_string(expected: str, shape: tuple) -> None: """Test conversion of parsed specification back to a string""" - string = _shape_specification_to_string(shape) + string = shape_specification_to_string(shape) assert string == expected @@ -248,12 +289,12 @@ def test_string_to_shape() -> None: memo = instance._memo # type:ignore[attr-defined] memo = memo | {'fromdict': 8} memo = memo | ShapeMemo(frommemo=9) - shape = _parse_string_to_size('fromdict frommemo fixed=3 dim 1', memo) + shape = string_to_size('fromdict frommemo fixed=3 dim 1', memo) assert shape == (8, 9, 3, 2, 1) with pytest.raises(KeyError): - _parse_string_to_size('doesnotexist', memo) + string_to_size('doesnotexist', memo) with pytest.raises(KeyError): - _parse_string_to_size('*doesnotexist', memo) + string_to_size('*doesnotexist', memo) with pytest.raises(SpecificationError): - _parse_string_to_size('...', memo) + string_to_size('...', memo)