Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FIX: deprecated import + expose affine_grid #21

Merged
merged 3 commits into from
Sep 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .github/workflows/test-and-publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@ jobs:
pytorch-version: "1.11"
- python-version: "3.11"
pytorch-version: "2.0"
- python-version: "3.12"
pytorch-version: "2.4"
steps:
- uses: actions/checkout@v3
- uses: ./.github/actions/test
Expand Down
8 changes: 4 additions & 4 deletions interpol/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .api import *
from .resize import *
from .restrict import *
from . import backend
from .api import * # noqa: F401, F403
from .resize import * # noqa: F401, F403
from .restrict import * # noqa: F401, F403
from . import backend # noqa: F401

from . import _version
__version__ = _version.get_versions()['version']
34 changes: 23 additions & 11 deletions interpol/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,20 @@
"""High level interpolation API"""

__all__ = ['grid_pull', 'grid_push', 'grid_count', 'grid_grad',
'spline_coeff', 'spline_coeff_nd',
'identity_grid', 'add_identity_grid', 'add_identity_grid_']
__all__ = [
'pull',
'push',
'count',
'grid_pull',
'grid_push',
'grid_count',
'grid_grad',
'spline_coeff',
'spline_coeff_nd',
'identity_grid',
'add_identity_grid',
'add_identity_grid_',
'affine_grid',
]

import torch
from .utils import expanded_shape, matvec
Expand Down Expand Up @@ -44,7 +56,7 @@
https://en.wikipedia.org/wiki/Discrete_sine_transform"""

_doc_bound_coeff = \
"""`bound` can be an int, a string or a BoundType.
"""`bound` can be an int, a string or a BoundType.
Possible values are:
- 'replicate' or 'nearest' : a a a | a b c d | d d d
- 'dct1' or 'mirror' : d c b | a b c d | c b a
Expand All @@ -61,7 +73,7 @@
- `dct2` corresponds to mirroring about the edge of the first/last voxel
See https://en.wikipedia.org/wiki/Discrete_cosine_transform
https://en.wikipedia.org/wiki/Discrete_sine_transform

/!\ Only 'dct1', 'dct2' and 'dft' are implemented for interpolation
orders >= 6."""

Expand Down Expand Up @@ -143,11 +155,11 @@ def grid_pull(input, grid, interpolation='linear', bound='zero',
{interpolation}

{bound}
If the input dtype is not a floating point type, the input image is
assumed to contain labels. Then, unique labels are extracted
and resampled individually, making them soft labels. Finally,
the label map is reconstructed from the individual soft labels by

If the input dtype is not a floating point type, the input image is
assumed to contain labels. Then, unique labels are extracted
and resampled individually, making them soft labels. Finally,
the label map is reconstructed from the individual soft labels by
assigning the label with maximum soft value.

Parameters
Expand Down Expand Up @@ -290,7 +302,7 @@ def grid_count(grid, shape=None, interpolation='linear', bound='zero',
def grid_grad(input, grid, interpolation='linear', bound='zero',
extrapolate=False, prefilter=False):
"""Sample spatial gradients of an image with respect to a deformation field.

Notes
-----
{interpolation}
Expand Down
60 changes: 46 additions & 14 deletions interpol/autograd.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,41 @@
grid_grad, grid_grad_backward)
from .utils import fake_decorator
try:
from torch.cuda.amp import custom_fwd, custom_bwd
from torch.amp import custom_fwd, custom_bwd
except (ModuleNotFoundError, ImportError):
custom_fwd = custom_bwd = fake_decorator
try:
from torch.cuda.amp import (
custom_fwd as _custom_fwd_cuda,
custom_bwd as _custom_bwd_cuda
)
except (ModuleNotFoundError, ImportError):
_custom_fwd_cuda = _custom_bwd_cuda = fake_decorator

try:
from torch.cpu.amp import (
custom_fwd as _custom_fwd_cpu,
custom_bwd as _custom_bwd_cpu
)
except (ModuleNotFoundError, ImportError):
_custom_fwd_cpu = _custom_bwd_cpu = fake_decorator

def custom_fwd(fwd=None, *, device_type, cast_inputs=None):
if device_type == 'cuda':
decorator = _custom_fwd_cuda(cast_inputs=cast_inputs)
return decorator(fwd) if fwd else decorator
if device_type == 'cpu':
decorator = _custom_fwd_cpu(cast_inputs=cast_inputs)
return decorator(fwd) if fwd else decorator
return fake_decorator(fwd) if fwd else decorator

def custom_bwd(bwd=None, *, device_type):
if device_type == 'cuda':
decorator = _custom_bwd_cuda
return decorator(bwd) if bwd else decorator
if device_type == 'cpu':
decorator = _custom_bwd_cpu
return decorator(bwd) if bwd else decorator
return fake_decorator(bwd) if bwd else decorator


def make_list(x):
Expand Down Expand Up @@ -125,7 +157,7 @@ def inter_to_nitorch(inter, as_type='str'):
class GridPull(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input, grid, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
Expand All @@ -143,7 +175,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
Expand All @@ -155,7 +187,7 @@ def backward(ctx, grad):
class GridPush(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
Expand All @@ -173,7 +205,7 @@ def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
Expand All @@ -185,7 +217,7 @@ def backward(ctx, grad):
class GridCount(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, grid, shape, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
Expand All @@ -203,7 +235,7 @@ def forward(ctx, grid, shape, interpolation, bound, extrapolate):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
Expand All @@ -216,7 +248,7 @@ def backward(ctx, grad):
class GridGrad(torch.autograd.Function):

@staticmethod
@custom_fwd(cast_inputs=torch.float32)
@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input, grid, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
Expand All @@ -234,7 +266,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
Expand All @@ -248,7 +280,7 @@ def backward(ctx, grad):
class SplineCoeff(torch.autograd.Function):

@staticmethod
@custom_fwd
@custom_fwd(device_type='cuda')
def forward(ctx, input, bound, interpolation, dim, inplace):

bound = bound_to_nitorch(make_list(bound)[0], as_type='int')
Expand All @@ -265,7 +297,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
# symmetric filter -> backward == forward
# (I don't know if I can write into grad, so inplace=False to be safe)
Expand All @@ -276,7 +308,7 @@ def backward(ctx, grad):
class SplineCoeffND(torch.autograd.Function):

@staticmethod
@custom_fwd
@custom_fwd(device_type='cuda')
def forward(ctx, input, bound, interpolation, dim, inplace):

bound = bound_to_nitorch(make_list(bound), as_type='int')
Expand All @@ -293,7 +325,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
return output

@staticmethod
@custom_bwd
@custom_bwd(device_type='cuda')
def backward(ctx, grad):
# symmetric filter -> backward == forward
# (I don't know if I can write into grad, so inplace=False to be safe)
Expand Down
Loading