Commit
must be kwargs (device_type is keyword-only in torch.amp.custom_fwd/custom_bwd)
balbasty committed Sep 12, 2024
1 parent 44d4816 commit 97e63ee
Showing 1 changed file with 12 additions and 12 deletions.
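Context: in recent PyTorch (2.4+), custom_fwd and custom_bwd live in torch.amp and take device_type as a keyword-only argument (custom_fwd(fwd=None, *, device_type, cast_inputs=None)), so the positional 'cuda' removed below raises a TypeError. A minimal sketch of the decorator pattern this file uses; the Square function here is illustrative only, not part of interpol:

    import torch
    from torch.amp import custom_fwd, custom_bwd

    class Square(torch.autograd.Function):
        """Toy autograd Function mirroring the decorator usage in interpol/autograd.py."""

        @staticmethod
        @custom_fwd(device_type='cuda', cast_inputs=torch.float32)  # keyword, not positional
        def forward(ctx, x):
            ctx.save_for_backward(x)
            return x * x

        @staticmethod
        @custom_bwd(device_type='cuda')  # backward runs in the dtype forward was invoked with
        def backward(ctx, grad):
            (x,) = ctx.saved_tensors
            return 2 * x * grad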
24 changes: 12 additions & 12 deletions interpol/autograd.py
@@ -157,7 +157,7 @@ def inter_to_nitorch(inter, as_type='str'):
class GridPull(torch.autograd.Function):

@staticmethod
-@custom_fwd('cuda', cast_inputs=torch.float32)
+@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input, grid, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -175,7 +175,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
return output

@staticmethod
-@custom_bwd('cuda')
+@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
@@ -187,7 +187,7 @@ def backward(ctx, grad):
class GridPush(torch.autograd.Function):

@staticmethod
-@custom_fwd('cuda', cast_inputs=torch.float32)
+@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -205,7 +205,7 @@ def forward(ctx, input, grid, shape, interpolation, bound, extrapolate):
return output

@staticmethod
-@custom_bwd('cuda')
+@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
@@ -217,7 +217,7 @@ def backward(ctx, grad):
class GridCount(torch.autograd.Function):

@staticmethod
-@custom_fwd('cuda', cast_inputs=torch.float32)
+@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, grid, shape, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -235,7 +235,7 @@ def forward(ctx, grid, shape, interpolation, bound, extrapolate):
return output

@staticmethod
-@custom_bwd('cuda')
+@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
@@ -248,7 +248,7 @@ def backward(ctx, grad):
class GridGrad(torch.autograd.Function):

@staticmethod
-@custom_fwd('cuda', cast_inputs=torch.float32)
+@custom_fwd(device_type='cuda', cast_inputs=torch.float32)
def forward(ctx, input, grid, interpolation, bound, extrapolate):

bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -266,7 +266,7 @@ def forward(ctx, input, grid, interpolation, bound, extrapolate):
return output

@staticmethod
-@custom_bwd('cuda')
+@custom_bwd(device_type='cuda')
def backward(ctx, grad):
var = ctx.saved_tensors
opt = ctx.opt
@@ -280,7 +280,7 @@ def backward(ctx, grad):
class SplineCoeff(torch.autograd.Function):

@staticmethod
-@custom_fwd('cuda')
+@custom_fwd(device_type='cuda')
def forward(ctx, input, bound, interpolation, dim, inplace):

bound = bound_to_nitorch(make_list(bound)[0], as_type='int')
@@ -297,7 +297,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
return output

@staticmethod
-@custom_bwd('cuda')
+@custom_bwd(device_type='cuda')
def backward(ctx, grad):
# symmetric filter -> backward == forward
# (I don't know if I can write into grad, so inplace=False to be safe)
@@ -308,7 +308,7 @@ def backward(ctx, grad):
class SplineCoeffND(torch.autograd.Function):

@staticmethod
-@custom_fwd('cuda')
+@custom_fwd(device_type='cuda')
def forward(ctx, input, bound, interpolation, dim, inplace):

bound = bound_to_nitorch(make_list(bound), as_type='int')
@@ -325,7 +325,7 @@ def forward(ctx, input, bound, interpolation, dim, inplace):
return output

@staticmethod
-@custom_bwd('cuda')
+@custom_bwd(device_type='cuda')
def backward(ctx, grad):
# symmetric filter -> backward == forward
# (I don't know if I can write into grad, so inplace=False to be safe)
