diff --git a/.github/workflows/test-and-publish.yaml b/.github/workflows/test-and-publish.yaml index 4c46a9c..2b9576e 100644 --- a/.github/workflows/test-and-publish.yaml +++ b/.github/workflows/test-and-publish.yaml @@ -50,6 +50,8 @@ jobs: pytorch-version: "1.11" - python-version: "3.11" pytorch-version: "2.0" + - python-version: "3.12" + pytorch-version: "2.4" steps: - uses: actions/checkout@v3 - uses: ./.github/actions/test diff --git a/interpol/__init__.py b/interpol/__init__.py index ecb4add..91dee73 100644 --- a/interpol/__init__.py +++ b/interpol/__init__.py @@ -1,7 +1,7 @@ -from .api import * -from .resize import * -from .restrict import * -from . import backend +from .api import * # noqa: F401, F403 +from .resize import * # noqa: F401, F403 +from .restrict import * # noqa: F401, F403 +from . import backend # noqa: F401 from . import _version __version__ = _version.get_versions()['version'] diff --git a/interpol/api.py b/interpol/api.py index b7c0066..b128368 100755 --- a/interpol/api.py +++ b/interpol/api.py @@ -1,8 +1,20 @@ """High level interpolation API""" -__all__ = ['grid_pull', 'grid_push', 'grid_count', 'grid_grad', - 'spline_coeff', 'spline_coeff_nd', - 'identity_grid', 'add_identity_grid', 'add_identity_grid_'] +__all__ = [ + 'pull', + 'push', + 'count', + 'grid_pull', + 'grid_push', + 'grid_count', + 'grid_grad', + 'spline_coeff', + 'spline_coeff_nd', + 'identity_grid', + 'add_identity_grid', + 'add_identity_grid_', + 'affine_grid', +] import torch from .utils import expanded_shape, matvec @@ -44,7 +56,7 @@ https://en.wikipedia.org/wiki/Discrete_sine_transform""" _doc_bound_coeff = \ -"""`bound` can be an int, a string or a BoundType. +"""`bound` can be an int, a string or a BoundType. Possible values are: - 'replicate' or 'nearest' : a a a | a b c d | d d d - 'dct1' or 'mirror' : d c b | a b c d | c b a @@ -61,7 +73,7 @@ - `dct2` corresponds to mirroring about the edge of the first/last voxel See https://en.wikipedia.org/wiki/Discrete_cosine_transform https://en.wikipedia.org/wiki/Discrete_sine_transform - + /!\ Only 'dct1', 'dct2' and 'dft' are implemented for interpolation orders >= 6.""" @@ -143,11 +155,11 @@ def grid_pull(input, grid, interpolation='linear', bound='zero', {interpolation} {bound} - - If the input dtype is not a floating point type, the input image is - assumed to contain labels. Then, unique labels are extracted - and resampled individually, making them soft labels. Finally, - the label map is reconstructed from the individual soft labels by + + If the input dtype is not a floating point type, the input image is + assumed to contain labels. Then, unique labels are extracted + and resampled individually, making them soft labels. Finally, + the label map is reconstructed from the individual soft labels by assigning the label with maximum soft value. Parameters @@ -290,7 +302,7 @@ def grid_count(grid, shape=None, interpolation='linear', bound='zero', def grid_grad(input, grid, interpolation='linear', bound='zero', extrapolate=False, prefilter=False): """Sample spatial gradients of an image with respect to a deformation field. - + Notes ----- {interpolation} diff --git a/interpol/autograd.py b/interpol/autograd.py index 40cace9..2025be9 100644 --- a/interpol/autograd.py +++ b/interpol/autograd.py @@ -10,9 +10,12 @@ grid_grad, grid_grad_backward) from .utils import fake_decorator try: - from torch.cuda.amp import custom_fwd, custom_bwd + from torch.amp import custom_fwd, custom_bwd except (ModuleNotFoundError, ImportError): - custom_fwd = custom_bwd = fake_decorator + try: + from torch.cuda.amp import custom_fwd, custom_bwd + except (ModuleNotFoundError, ImportError): + custom_fwd = custom_bwd = fake_decorator def make_list(x):