Skip to content

Commit

Permalink
attempt to fix issue #9
Browse files Browse the repository at this point in the history
  • Loading branch information
fxia22 committed Jun 12, 2017
1 parent a7de67a commit d8ef434
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 33 deletions.
18 changes: 12 additions & 6 deletions script/functions/stn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,27 +2,33 @@
import torch
from torch.autograd import Function
from _ext import my_lib

from cffi import FFI
ffi = FFI()

class STNFunction(Function):
    """Bilinear sampler (spatial transformer) autograd Function.

    Layout is BHWD for images and BHW(YX) for sampling grids, matching the
    C/CUDA kernels exposed by ``my_lib``.  This is an old-style autograd
    Function: inputs and the CUDA device id are stashed on ``self`` in
    ``forward`` for reuse in ``backward``.
    """

    def forward(self, input1, input2):
        # input1: images, input2: sampling grid.
        # NOTE(review): output is allocated as
        # (B, grid_H, grid_W, channels) — shapes inferred from this
        # allocation only; confirm against the kernel's expectations.
        self.input1 = input1
        self.input2 = input2
        output = torch.zeros(input1.size()[0], input2.size()[1],
                             input2.size()[2], input1.size()[3])
        if not input1.is_cuda:
            my_lib.BilinearSamplerBHWD_updateOutput(input1, input2, output)
        else:
            # Fix: query torch.cuda only on the CUDA path, so CPU-only
            # machines never require a CUDA runtime. The previous code
            # called torch.cuda.current_device() unconditionally — and the
            # merged diff also left a stale call to the 3-argument CUDA
            # entry point, which no longer matches the C signature.
            self.device = torch.cuda.current_device()
            # Pass the device ordinal through a C "int *" so the extension
            # can cudaSetDevice() before launching (multi-GPU, issue #9).
            self.device_c = ffi.new("int *")
            self.device_c[0] = self.device
            output = output.cuda(self.device)
            my_lib.BilinearSamplerBHWD_updateOutput_cuda(
                input1, input2, output, self.device_c)
        return output

    def backward(self, grad_output):
        grad_input1 = torch.zeros(self.input1.size())
        grad_input2 = torch.zeros(self.input2.size())
        if not grad_output.is_cuda:
            my_lib.BilinearSamplerBHWD_updateGradInput(
                self.input1, self.input2, grad_input1, grad_input2,
                grad_output)
        else:
            # self.device / self.device_c were set in forward() on the CUDA
            # path; grad_output can only be CUDA if forward() ran on CUDA.
            grad_input1 = grad_input1.cuda(self.device)
            grad_input2 = grad_input2.cuda(self.device)
            my_lib.BilinearSamplerBHWD_updateGradInput_cuda(
                self.input1, self.input2, grad_input1, grad_input2,
                grad_output, self.device_c)
        return grad_input1, grad_input2
33 changes: 18 additions & 15 deletions script/src/my_lib_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,14 @@ extern THCState *state;
// we assume BHWD format in inputImages
// we assume BHW(YX) format on grids

int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output)
int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output, int * device)
{
// THCState *state = getCutorchState(L);
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
// THCudaTensor *grids = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
// THCudaTensor *output = (THCudaTensor *)luaT_checkudata(L, 4, "torch.CudaTensor");

cudaSetDevice(device[0]);
int success = 0;
success = BilinearSamplerBHWD_updateOutput_cuda_kernel(output->size[2],
output->size[1],
Expand All @@ -27,17 +28,17 @@ int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTenso
THCudaTensor_size(state, inputImages, 1),
THCudaTensor_size(state, inputImages, 2),
THCudaTensor_size(state, output, 2),
THCudaTensor_data(state, inputImages),
THCudaTensor_data(state, inputImages),
THCudaTensor_stride(state, inputImages, 0),
THCudaTensor_stride(state, inputImages, 3),
THCudaTensor_stride(state, inputImages, 1),
THCudaTensor_stride(state, inputImages, 2),
THCudaTensor_data(state, grids),
THCudaTensor_data(state, grids),
THCudaTensor_stride(state, grids, 0),
THCudaTensor_stride(state, grids, 3),
THCudaTensor_stride(state, grids, 1),
THCudaTensor_stride(state, grids, 2),
THCudaTensor_data(state, output),
THCudaTensor_data(state, output),
THCudaTensor_stride(state, output, 0),
THCudaTensor_stride(state, output, 3),
THCudaTensor_stride(state, output, 1),
Expand All @@ -52,7 +53,7 @@ int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTenso
}

int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *gradInputImages,
THCudaTensor *gradGrids, THCudaTensor *gradOutput)
THCudaTensor *gradGrids, THCudaTensor *gradOutput, int * device)
{
// THCState *state = getCutorchState(L);
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
Expand All @@ -61,6 +62,7 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTe
// THCudaTensor *gradGrids = (THCudaTensor *)luaT_checkudata(L, 5, "torch.CudaTensor");
// THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 6, "torch.CudaTensor");

cudaSetDevice(device[0]);
int success = 0;
success = BilinearSamplerBHWD_updateGradInput_cuda_kernel(gradOutput->size[2],
gradOutput->size[1],
Expand All @@ -69,27 +71,27 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTe
THCudaTensor_size(state, inputImages, 1),
THCudaTensor_size(state, inputImages, 2),
THCudaTensor_size(state, gradOutput, 2),
THCudaTensor_data(state, inputImages),
THCudaTensor_data(state, inputImages),
THCudaTensor_stride(state, inputImages, 0),
THCudaTensor_stride(state, inputImages, 3),
THCudaTensor_stride(state, inputImages, 1),
THCudaTensor_stride(state, inputImages, 2),
THCudaTensor_data(state, grids),
THCudaTensor_data(state, grids),
THCudaTensor_stride(state, grids, 0),
THCudaTensor_stride(state, grids, 3),
THCudaTensor_stride(state, grids, 1),
THCudaTensor_stride(state, grids, 2),
THCudaTensor_data(state, gradInputImages),
THCudaTensor_data(state, gradInputImages),
THCudaTensor_stride(state, gradInputImages, 0),
THCudaTensor_stride(state, gradInputImages, 3),
THCudaTensor_stride(state, gradInputImages, 1),
THCudaTensor_stride(state, gradInputImages, 2),
THCudaTensor_data(state, gradGrids),
THCudaTensor_data(state, gradGrids),
THCudaTensor_stride(state, gradGrids, 0),
THCudaTensor_stride(state, gradGrids, 3),
THCudaTensor_stride(state, gradGrids, 1),
THCudaTensor_stride(state, gradGrids, 2),
THCudaTensor_data(state, gradOutput),
THCudaTensor_data(state, gradOutput),
THCudaTensor_stride(state, gradOutput, 0),
THCudaTensor_stride(state, gradOutput, 3),
THCudaTensor_stride(state, gradOutput, 1),
Expand All @@ -104,14 +106,15 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTe
}

int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages, THCudaTensor *grids,
THCudaTensor *gradGrids, THCudaTensor *gradOutput)
THCudaTensor *gradGrids, THCudaTensor *gradOutput, int * device)
{
// THCState *state = getCutorchState(L);
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
// THCudaTensor *grids = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
// THCudaTensor *gradGrids = (THCudaTensor *)luaT_checkudata(L, 5, "torch.CudaTensor");
// THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 6, "torch.CudaTensor");

cudaSetDevice(device[0]);
int success = 0;
success = BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda_kernel(
gradOutput->size[2],
Expand All @@ -121,22 +124,22 @@ int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages,
THCudaTensor_size(state, inputImages, 1),
THCudaTensor_size(state, inputImages, 2),
THCudaTensor_size(state, gradOutput, 2),
THCudaTensor_data(state, inputImages),
THCudaTensor_data(state, inputImages),
THCudaTensor_stride(state, inputImages, 0),
THCudaTensor_stride(state, inputImages, 3),
THCudaTensor_stride(state, inputImages, 1),
THCudaTensor_stride(state, inputImages, 2),
THCudaTensor_data(state, grids),
THCudaTensor_data(state, grids),
THCudaTensor_stride(state, grids, 0),
THCudaTensor_stride(state, grids, 3),
THCudaTensor_stride(state, grids, 1),
THCudaTensor_stride(state, grids, 2),
THCudaTensor_data(state, gradGrids),
THCudaTensor_data(state, gradGrids),
THCudaTensor_stride(state, gradGrids, 0),
THCudaTensor_stride(state, gradGrids, 3),
THCudaTensor_stride(state, gradGrids, 1),
THCudaTensor_stride(state, gradGrids, 2),
THCudaTensor_data(state, gradOutput),
THCudaTensor_data(state, gradOutput),
THCudaTensor_stride(state, gradOutput, 0),
THCudaTensor_stride(state, gradOutput, 3),
THCudaTensor_stride(state, gradOutput, 1),
Expand Down
6 changes: 3 additions & 3 deletions script/src/my_lib_cuda.h
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
// we assume BHWD format in inputImages
// we assume BHW(YX) format on grids

/* Forward bilinear sampling (BHWD images, BHW(YX) grids).  The trailing
 * int* carries the CUDA device ordinal the kernel must run on (multi-GPU
 * support, issue #9).  The merged diff left BOTH the old device-less and
 * the new prototypes in place — duplicate conflicting declarations do not
 * compile; only the post-commit signatures are kept here. */
int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output, int *device);

/* Gradients w.r.t. both the input images and the sampling grid. */
int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *gradInputImages,
                                             THCudaTensor *gradGrids, THCudaTensor *gradOutput, int *device);

/* Gradient w.r.t. the sampling grid only (input images treated as fixed). */
int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages, THCudaTensor *grids,
                                                     THCudaTensor *gradGrids, THCudaTensor *gradOutput, int *device);
19 changes: 10 additions & 9 deletions script/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,15 +44,16 @@
out.backward(input1.data)
print(input1.grad.size(), 'time:', time.time() - start)

input1 = input1.cuda()
input2 = input2.cuda()

start = time.time()
out = s(input1, input2)
print(out.size(), 'time:', time.time() - start)
start = time.time()
out.backward(input1.data)
print('time:', time.time() - start)
with torch.cuda.device(3):
input1 = input1.cuda()
input2 = input2.cuda()
start = time.time()
out = s(input1, input2)
print(out.size(), 'time:', time.time() - start)
start = time.time()
#out.backward(input1.data.cuda())
torch.sum(out).backward()
print('time:', time.time() - start)

input = Variable(torch.from_numpy(np.array([[3.6]], dtype=np.float32)), requires_grad = True)

Expand Down

0 comments on commit d8ef434

Please sign in to comment.