Skip to content

Commit

Permalink
FIX: deprecated import + expose affine_grid
Browse files Browse the repository at this point in the history
  • Loading branch information
balbasty committed Sep 12, 2024
1 parent a4d5f53 commit bd2c1f6
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 17 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/test-and-publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ jobs:
pytorch-version: "1.11"
- python-version: "3.11"
pytorch-version: "2.0"
- python-version: "3.12"
pytorch-version: "2.4"
steps:
- uses: actions/checkout@v3
- uses: ./.github/actions/test
Expand Down
8 changes: 4 additions & 4 deletions interpol/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .api import *
from .resize import *
from .restrict import *
from . import backend
from .api import * # noqa: F401, F403
from .resize import * # noqa: F401, F403
from .restrict import * # noqa: F401, F403
from . import backend # noqa: F401

from . import _version
__version__ = _version.get_versions()['version']
34 changes: 23 additions & 11 deletions interpol/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
"""High level interpolation API"""

__all__ = ['grid_pull', 'grid_push', 'grid_count', 'grid_grad',
'spline_coeff', 'spline_coeff_nd',
'identity_grid', 'add_identity_grid', 'add_identity_grid_']
__all__ = [
'pull',
'push',
'count',
'grid_pull',
'grid_push',
'grid_count',
'grid_grad',
'spline_coeff',
'spline_coeff_nd',
'identity_grid',
'add_identity_grid',
'add_identity_grid_',
'affine_grid',
]

import torch
from .utils import expanded_shape, matvec
Expand Down Expand Up @@ -44,7 +56,7 @@
https://en.wikipedia.org/wiki/Discrete_sine_transform"""

_doc_bound_coeff = \
"""`bound` can be an int, a string or a BoundType.
"""`bound` can be an int, a string or a BoundType.
Possible values are:
- 'replicate' or 'nearest' : a a a | a b c d | d d d
- 'dct1' or 'mirror' : d c b | a b c d | c b a
Expand All @@ -61,7 +73,7 @@
- `dct2` corresponds to mirroring about the edge of the first/last voxel
See https://en.wikipedia.org/wiki/Discrete_cosine_transform
https://en.wikipedia.org/wiki/Discrete_sine_transform
/!\ Only 'dct1', 'dct2' and 'dft' are implemented for interpolation
orders >= 6."""

Expand Down Expand Up @@ -143,11 +155,11 @@ def grid_pull(input, grid, interpolation='linear', bound='zero',
{interpolation}
{bound}
If the input dtype is not a floating point type, the input image is
assumed to contain labels. Then, unique labels are extracted
and resampled individually, making them soft labels. Finally,
the label map is reconstructed from the individual soft labels by
If the input dtype is not a floating point type, the input image is
assumed to contain labels. Then, unique labels are extracted
and resampled individually, making them soft labels. Finally,
the label map is reconstructed from the individual soft labels by
assigning the label with maximum soft value.
Parameters
Expand Down Expand Up @@ -290,7 +302,7 @@ def grid_count(grid, shape=None, interpolation='linear', bound='zero',
def grid_grad(input, grid, interpolation='linear', bound='zero',
extrapolate=False, prefilter=False):
"""Sample spatial gradients of an image with respect to a deformation field.
Notes
-----
{interpolation}
Expand Down
7 changes: 5 additions & 2 deletions interpol/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,12 @@
grid_grad, grid_grad_backward)
from .utils import fake_decorator
try:
from torch.cuda.amp import custom_fwd, custom_bwd
from torch.amp import custom_fwd, custom_bwd
except (ModuleNotFoundError, ImportError):
custom_fwd = custom_bwd = fake_decorator
try:
from torch.cuda.amp import custom_fwd, custom_bwd
except (ModuleNotFoundError, ImportError):
custom_fwd = custom_bwd = fake_decorator


def make_list(x):
Expand Down

0 comments on commit bd2c1f6

Please sign in to comment.