Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
[ghstack-poisoned]
  • Loading branch information
fzimmermann89 committed Jan 2, 2025
1 parent c889b6a commit fbaf94a
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 28 deletions.
20 changes: 11 additions & 9 deletions src/mrpro/data/CheckDataMixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
from collections.abc import Iterator, Mapping, Sequence
from dataclasses import Field, dataclass, fields
from functools import lru_cache, reduce
from types import TracebackType
from typing import Annotated, Any, ClassVar, Literal, TypeAlias, get_args, get_origin
from types import TracebackType, UnionType
from typing import Annotated, Any, ClassVar, Literal, TypeAlias, Union, get_args, get_origin

import torch
from typing_extensions import Protocol, Self, runtime_checkable
Expand Down Expand Up @@ -99,7 +99,7 @@ def _parse_string_to_shape_specification(dim_str: str) -> tuple[tuple[_DimType,
if elem != '...':
raise SpecificationError("Anonymous multiple axes '...' must be used on its own; " f'got {elem}')
if index_variadic is not None:
raise SpecificationError('Cannot use variadic specifiers (`*name` or `...`) ' 'more than once.')
raise SpecificationError('Cannot use variadic specifiers (`*name` or `...`) more than once.')
index_variadic = index
dims.append(_anonymous_variadic_dim)
continue
Expand All @@ -118,7 +118,7 @@ def _parse_string_to_shape_specification(dim_str: str) -> tuple[tuple[_DimType,
elif variadic:
dims.append(_NamedVariadicDim(elem, broadcastable))
if index_variadic is not None:
raise SpecificationError('Cannot use variadic specifiers (`*name` or `...`) ' 'more than once.')
raise SpecificationError('Cannot use variadic specifiers (`*name` or `...`) more than once.')
index_variadic = index
elif anonymous:
dims.append(_anonymous_dim)
Expand Down Expand Up @@ -179,10 +179,10 @@ def __init__(self, *arg, **kwargs):
self._d = dict(*arg, **kwargs)

def __getitem__(self, key: str) -> tuple[tuple[int, ...], bool]:
"""Get the shape of a named dimension."""
"""Get the shape and if it can be broadcasted of a named dimension."""
value = self._d[key]
if isinstance(value, int):
return (value,), False
return (value,), False # Default to non-broadcastable
return value

def __len__(self) -> int:
Expand Down Expand Up @@ -414,7 +414,7 @@ def __init__(self, dtype: torch.dtype | Sequence[torch.dtype] | None = None, sha
were the string will be ignored and only serves as documentation.
Example:
`*#batch channel=2 depth #height #width` indicates that the object has at least dimensions.
`*#batch channel=2 depth #height #width` indicates that the object has at least dimensions.
The last two dimensions are named `height` and `width`.
These must be broadcastable for all objects using the same memoization object.
The depth dimensions must match exactly for all objects using the same memoization object.
Expand Down Expand Up @@ -627,7 +627,9 @@ def check_invariants(self, recurse: bool = False) -> None:
expected_type, *annotations = get_args(expected_type)
else:
annotations = []
if not isinstance(expected_type, type):
if get_origin(expected_type) is Union:
expected_type = reduce(lambda a, b: a | b, get_args(expected_type))
if not isinstance(expected_type, type | UnionType):
raise TypeError(
f'Expected a type, got {type(expected_type)}. This could be caused by __future__.annotations'
)
Expand All @@ -637,7 +639,7 @@ def check_invariants(self, recurse: bool = False) -> None:
# there could be other annotations not related to the shape and dtype
if isinstance(annotation, Annotation):
try:
memo = annotation.check(value, memo=memo, strict=True)
memo = annotation.check(value, memo=memo, strict=False)
except RuntimeCheckError as e:
raise type(e)(
f'Dataclass invariant violated for {self.__class__.__name__}.{name}: {e}\n {annotation}.'
Expand Down
121 changes: 102 additions & 19 deletions tests/data/test_checkdatamixin.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,16 @@
from dataclasses import dataclass
from typing import Annotated
from typing import Annotated, Optional, Union

import pytest
import torch
from mrpro.data.CheckDataMixin import (
Annotation,
CheckDataMixin,
DtypeError,
FieldTypeError,
ShapeError,
ShapeMemo,
SpecificationError,
SuspendDataChecks,
_FixedDim,
_NamedDim,
Expand Down Expand Up @@ -40,30 +44,80 @@ class CheckedDataClass(CheckDataMixin):
class Slots(CheckDataMixin):
"""A test dataclass with slots"""

tensor: Annotated[torch.Tensor, Annotation(shape='... _ 5 dim')]
tensor1: Annotated[torch.Tensor, Annotation(shape='dim')]
tensor2: Annotated[torch.Tensor, Annotation(shape='... _ 5 dim')]


@dataclass(frozen=True)
class Frozen(CheckDataMixin):
"""A frozen test dataclass"""

tensor: Annotated[torch.Tensor, Annotation(dtype=(torch.float32,))]
tensor1: Annotated[torch.Tensor, Annotation(dtype=(torch.float32,))]


def test_checked_dataclass_success():
@dataclass
class WithOptional(CheckDataMixin):
"""A dataclass with None-able fields"""

tensor: torch.Tensor | None = None
or_tensor: Annotated[
torch.Tensor | None, Annotation(dtype=(torch.float32, torch.float64), shape='*#other coil #k2 #k1 #k0')
] = None
optional_tensor: Annotated[
Optional[torch.Tensor], Annotation(dtype=(torch.float32, torch.float64), shape='*#other coil #k2 #k1 #k0') # noqa: UP007
] = None
union_tensor: Annotated[
Union[None, torch.Tensor], Annotation(dtype=(torch.float32, torch.float64), shape='*#other coil #k2 #k1 #k0') # noqa: UP007
] = None
integer: int | None = None


def test_slots() -> None:
"""Test dataclass with slots."""
# Declaration of the dataclass is already the main test
Slots(torch.zeros(2), torch.zeros(1, 5, 2))


def test_frozen() -> None:
"""Test frozen dataclass."""
# Declaration of the dataclass is already the main test
Frozen(torch.ones(1))


def test_optional() -> None:
"""Test dataclass with None-able attributes"""
WithOptional()
WithOptional(torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), torch.ones(1, 1, 1, 1), 1)


def test_optional_fail() -> None:
"""Test exceptions with dataclass with None-able attributes"""
with pytest.raises(ShapeError):
WithOptional(None, torch.ones(1))
with pytest.raises(ShapeError):
WithOptional(None, None, torch.ones(1))
with pytest.raises(ShapeError):
WithOptional(None, None, None, torch.ones(1))
with pytest.raises(FieldTypeError):
WithOptional(None, None, None, None, 'not an integer') # type:ignore[arg-type]


def test_checked_dataclass_success() -> None:
"""Test successful checked dataclass"""
n_coil = 2
n_k2 = 3
n_k1 = 3
n_k0 = 4
n_other = (1, 2, 3)
_ = CheckedDataClass(
CheckedDataClass(
float_tensor=torch.ones(*n_other, n_coil, n_k2, n_k1, n_k0),
int_tensor=torch.zeros(*(1,), 1, n_k2, 1, 1, dtype=torch.int),
string='test',
)


def test_checked_dataclass_variadic_fail():
def test_checked_dataclass_variadic_fail() -> None:
"""Test exception raised on wrong variadic size"""
n_coil = 2
n_k2 = 3
n_k1 = 3
Expand All @@ -75,43 +129,45 @@ def test_checked_dataclass_variadic_fail():
match=f"'*other' is {tuple_to_regex(n_other_fail)}, "
f'which cannot be broadcast with the existing value of {tuple_to_regex(n_other)}',
):
_ = CheckedDataClass(
CheckedDataClass(
float_tensor=torch.ones(*n_other, n_coil, n_k2, n_k1, n_k0),
int_tensor=torch.zeros(*n_other_fail, 1, n_k2, 1, 1, dtype=torch.int),
string='test',
)


def test_checked_dataclass_fixed_fail():
def test_checked_dataclass_fixed_fail() -> None:
"""Test exception raised on wrong fixed size"""
n_coil = 2
n_k2 = 3
n_k1 = 3
n_k0 = 4
n_other = (1, 2, 3)
not_one = 17
with pytest.raises(ShapeError, match=f' the dimension size {not_one} does not equal 1'):
_ = CheckedDataClass(
CheckedDataClass(
float_tensor=torch.ones(*n_other, n_coil, n_k2, n_k1, n_k0),
int_tensor=torch.zeros(*n_other, 1, n_k2, 1, not_one, dtype=torch.int),
string='test',
)


def test_suspend_check_success():
def test_suspend_check_success() -> None:
"""Test the SuspendDataChecks context with a valid shape on exit"""
with SuspendDataChecks():
instance = Slots(torch.zeros(1))
instance = Slots(torch.zeros(6), torch.zeros(1))
# fix the shape
instance.tensor = torch.zeros(2, 3, 4, 5, 6)
instance.tensor2 = torch.zeros(2, 3, 4, 5, 6)


def test_suspend_check_fail():
"""Test the SuspendDataChecks context with an invalid shape on exit"""
with pytest.raises(ShapeError, match='dimensions'), SuspendDataChecks():
_ = Slots(torch.zeros(1))
# needs to be assigned to exist after leaving the Suspend context
_ = Slots(torch.zeros(1), torch.zeros(1))


def test_shape():
def test_shape() -> None:
"""Test the shape property"""
n_coil = 2
n_k2 = 3
Expand All @@ -126,7 +182,7 @@ def test_shape():
assert instance.shape == (*n_other, n_coil, n_k2, n_k1, n_k0)


def test_dype():
def test_dype() -> None:
"""Test the dtype property"""
instance = CheckedDataClass(
float_tensor=torch.ones(3, 4, 5, 6, dtype=torch.float64),
Expand All @@ -137,6 +193,22 @@ def test_dype():
assert instance.dtype == torch.float64


def test_dype_fail() -> None:
"""Test the dtype exception"""
with pytest.raises(DtypeError):
CheckedDataClass(
float_tensor=torch.ones(3, 4, 5, 6, dtype=torch.int),
int_tensor=torch.zeros(2, 1, 4, 1, 1, dtype=torch.int),
string='wrong float_tensor',
)
with pytest.raises(DtypeError):
CheckedDataClass(
float_tensor=torch.ones(3, 4, 5, 6, dtype=torch.int),
int_tensor=torch.zeros(2, 1, 4, 1, 1, dtype=torch.float32),
string='wrong int_tensor',
)


@pytest.mark.parametrize(
('string', 'expected'),
[
Expand All @@ -148,6 +220,7 @@ def test_dype():
ids=['fixed', 'named broadcastable', 'anonymous', 'anonymous variadic'],
)
def test_parse_shape(string: str, expected: tuple) -> None:
"""Test parsing of shape string"""
parsed = _parse_string_to_shape_specification(string)
assert parsed == expected

Expand All @@ -168,9 +241,19 @@ def test_specification_to_string(expected: str, shape: tuple) -> None:
assert string == expected


def test_string_to_shape():
def test_string_to_shape() -> None:
"""Test conversion of string to shape"""
instance = Slots(torch.zeros(1, 2, 5, 2)) # has shape hint '... _ 5 dim'
instance = Slots(torch.zeros(2), torch.zeros(1, 2, 5, 2)) # has shape hint '... _ 5 dim'
instance.check_invariants()
shape = _parse_string_to_size('fixed=3 dim 1', instance._memo) # type:ignore[attr-defined]
assert shape == (3, 2, 1)
memo = instance._memo # type:ignore[attr-defined]
memo = memo | {'fromdict': 8}
memo = memo | ShapeMemo(frommemo=9)
shape = _parse_string_to_size('fromdict frommemo fixed=3 dim 1', memo)
assert shape == (8, 9, 3, 2, 1)

with pytest.raises(KeyError):
_parse_string_to_size('doesnotexist', memo)
with pytest.raises(KeyError):
_parse_string_to_size('*doesnotexist', memo)
with pytest.raises(SpecificationError):
_parse_string_to_size('...', memo)

0 comments on commit fbaf94a

Please sign in to comment.