diff --git a/script/functions/stn.py b/script/functions/stn.py index b8610f3..9a9d53f 100644 --- a/script/functions/stn.py +++ b/script/functions/stn.py @@ -2,27 +2,33 @@ import torch from torch.autograd import Function from _ext import my_lib - +from cffi import FFI +ffi = FFI() class STNFunction(Function): def forward(self, input1, input2): self.input1 = input1 self.input2 = input2 + self.device_c = ffi.new("int *") output = torch.zeros(input1.size()[0], input2.size()[1], input2.size()[2], input1.size()[3]) + #print('device %d' % torch.cuda.current_device()) + self.device = torch.cuda.current_device() + self.device_c[0] = self.device if not input1.is_cuda: my_lib.BilinearSamplerBHWD_updateOutput(input1, input2, output) else: - output = output.cuda() - my_lib.BilinearSamplerBHWD_updateOutput_cuda(input1, input2, output) + output = output.cuda(self.device) + my_lib.BilinearSamplerBHWD_updateOutput_cuda(input1, input2, output, self.device_c) return output def backward(self, grad_output): grad_input1 = torch.zeros(self.input1.size()) grad_input2 = torch.zeros(self.input2.size()) + #print('backward device %d' % self.device) if not grad_output.is_cuda: my_lib.BilinearSamplerBHWD_updateGradInput(self.input1, self.input2, grad_input1, grad_input2, grad_output) else: - grad_input1 = grad_input1.cuda() - grad_input2 = grad_input2.cuda() - my_lib.BilinearSamplerBHWD_updateGradInput_cuda(self.input1, self.input2, grad_input1, grad_input2, grad_output) + grad_input1 = grad_input1.cuda(self.device) + grad_input2 = grad_input2.cuda(self.device) + my_lib.BilinearSamplerBHWD_updateGradInput_cuda(self.input1, self.input2, grad_input1, grad_input2, grad_output, self.device_c) return grad_input1, grad_input2 diff --git a/script/src/my_lib_cuda.c b/script/src/my_lib_cuda.c index 68b523d..12a0f56 100644 --- a/script/src/my_lib_cuda.c +++ b/script/src/my_lib_cuda.c @@ -12,13 +12,14 @@ extern THCState *state; // we assume BHWD format in inputImages // we assume BHW(YX) format on grids 
-int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output) +int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output, int * device) { // THCState *state = getCutorchState(L); // THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor"); // THCudaTensor *grids = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor"); // THCudaTensor *output = (THCudaTensor *)luaT_checkudata(L, 4, "torch.CudaTensor"); + cudaSetDevice(device[0]); int success = 0; success = BilinearSamplerBHWD_updateOutput_cuda_kernel(output->size[2], output->size[1], @@ -27,17 +28,17 @@ int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTenso THCudaTensor_size(state, inputImages, 1), THCudaTensor_size(state, inputImages, 2), THCudaTensor_size(state, output, 2), - THCudaTensor_data(state, inputImages), + THCudaTensor_data(state, inputImages), THCudaTensor_stride(state, inputImages, 0), THCudaTensor_stride(state, inputImages, 3), THCudaTensor_stride(state, inputImages, 1), THCudaTensor_stride(state, inputImages, 2), - THCudaTensor_data(state, grids), + THCudaTensor_data(state, grids), THCudaTensor_stride(state, grids, 0), THCudaTensor_stride(state, grids, 3), THCudaTensor_stride(state, grids, 1), THCudaTensor_stride(state, grids, 2), - THCudaTensor_data(state, output), + THCudaTensor_data(state, output), THCudaTensor_stride(state, output, 0), THCudaTensor_stride(state, output, 3), THCudaTensor_stride(state, output, 1), @@ -52,7 +53,7 @@ int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTenso } int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *gradInputImages, - THCudaTensor *gradGrids, THCudaTensor *gradOutput) + THCudaTensor *gradGrids, THCudaTensor *gradOutput, int * device) { // THCState *state = getCutorchState(L); // THCudaTensor *inputImages = (THCudaTensor 
*)luaT_checkudata(L, 2, "torch.CudaTensor"); @@ -61,6 +62,7 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTe // THCudaTensor *gradGrids = (THCudaTensor *)luaT_checkudata(L, 5, "torch.CudaTensor"); // THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 6, "torch.CudaTensor"); + cudaSetDevice(device[0]); int success = 0; success = BilinearSamplerBHWD_updateGradInput_cuda_kernel(gradOutput->size[2], gradOutput->size[1], @@ -69,27 +71,27 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTe THCudaTensor_size(state, inputImages, 1), THCudaTensor_size(state, inputImages, 2), THCudaTensor_size(state, gradOutput, 2), - THCudaTensor_data(state, inputImages), + THCudaTensor_data(state, inputImages), THCudaTensor_stride(state, inputImages, 0), THCudaTensor_stride(state, inputImages, 3), THCudaTensor_stride(state, inputImages, 1), THCudaTensor_stride(state, inputImages, 2), - THCudaTensor_data(state, grids), + THCudaTensor_data(state, grids), THCudaTensor_stride(state, grids, 0), THCudaTensor_stride(state, grids, 3), THCudaTensor_stride(state, grids, 1), THCudaTensor_stride(state, grids, 2), - THCudaTensor_data(state, gradInputImages), + THCudaTensor_data(state, gradInputImages), THCudaTensor_stride(state, gradInputImages, 0), THCudaTensor_stride(state, gradInputImages, 3), THCudaTensor_stride(state, gradInputImages, 1), THCudaTensor_stride(state, gradInputImages, 2), - THCudaTensor_data(state, gradGrids), + THCudaTensor_data(state, gradGrids), THCudaTensor_stride(state, gradGrids, 0), THCudaTensor_stride(state, gradGrids, 3), THCudaTensor_stride(state, gradGrids, 1), THCudaTensor_stride(state, gradGrids, 2), - THCudaTensor_data(state, gradOutput), + THCudaTensor_data(state, gradOutput), THCudaTensor_stride(state, gradOutput, 0), THCudaTensor_stride(state, gradOutput, 3), THCudaTensor_stride(state, gradOutput, 1), @@ -104,7 +106,7 @@ int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor 
*inputImages, THCudaTe } int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages, THCudaTensor *grids, - THCudaTensor *gradGrids, THCudaTensor *gradOutput) + THCudaTensor *gradGrids, THCudaTensor *gradOutput, int * device) { // THCState *state = getCutorchState(L); // THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor"); @@ -112,6 +114,7 @@ int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages, // THCudaTensor *gradGrids = (THCudaTensor *)luaT_checkudata(L, 5, "torch.CudaTensor"); // THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 6, "torch.CudaTensor"); + cudaSetDevice(device[0]); int success = 0; success = BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda_kernel( gradOutput->size[2], @@ -121,22 +124,22 @@ int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages, THCudaTensor_size(state, inputImages, 1), THCudaTensor_size(state, inputImages, 2), THCudaTensor_size(state, gradOutput, 2), - THCudaTensor_data(state, inputImages), + THCudaTensor_data(state, inputImages), THCudaTensor_stride(state, inputImages, 0), THCudaTensor_stride(state, inputImages, 3), THCudaTensor_stride(state, inputImages, 1), THCudaTensor_stride(state, inputImages, 2), - THCudaTensor_data(state, grids), + THCudaTensor_data(state, grids), THCudaTensor_stride(state, grids, 0), THCudaTensor_stride(state, grids, 3), THCudaTensor_stride(state, grids, 1), THCudaTensor_stride(state, grids, 2), - THCudaTensor_data(state, gradGrids), + THCudaTensor_data(state, gradGrids), THCudaTensor_stride(state, gradGrids, 0), THCudaTensor_stride(state, gradGrids, 3), THCudaTensor_stride(state, gradGrids, 1), THCudaTensor_stride(state, gradGrids, 2), - THCudaTensor_data(state, gradOutput), + THCudaTensor_data(state, gradOutput), THCudaTensor_stride(state, gradOutput, 0), THCudaTensor_stride(state, gradOutput, 3), THCudaTensor_stride(state, gradOutput, 1), diff --git a/script/src/my_lib_cuda.h 
b/script/src/my_lib_cuda.h index 11f7458..09dcb4e 100644 --- a/script/src/my_lib_cuda.h +++ b/script/src/my_lib_cuda.h @@ -2,10 +2,10 @@ // we assume BHWD format in inputImages // we assume BHW(YX) format on grids -int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output); +int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output, int *); int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *gradInputImages, - THCudaTensor *gradGrids, THCudaTensor *gradOutput); + THCudaTensor *gradGrids, THCudaTensor *gradOutput, int *); int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages, THCudaTensor *grids, - THCudaTensor *gradGrids, THCudaTensor *gradOutput); + THCudaTensor *gradGrids, THCudaTensor *gradOutput, int *); diff --git a/script/test.py b/script/test.py index e2e692c..214958b 100644 --- a/script/test.py +++ b/script/test.py @@ -44,15 +44,16 @@ out.backward(input1.data) print(input1.grad.size(), 'time:', time.time() - start) -input1 = input1.cuda() -input2 = input2.cuda() - -start = time.time() -out = s(input1, input2) -print(out.size(), 'time:', time.time() - start) -start = time.time() -out.backward(input1.data) -print('time:', time.time() - start) +with torch.cuda.device(3): + input1 = input1.cuda() + input2 = input2.cuda() + start = time.time() + out = s(input1, input2) + print(out.size(), 'time:', time.time() - start) + start = time.time() + #out.backward(input1.data.cuda()) + torch.sum(out).backward() + print('time:', time.time() - start) input = Variable(torch.from_numpy(np.array([[3.6]], dtype=np.float32)), requires_grad = True)