diff --git a/torch_ops/__init__.py b/torch_ops/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/torch_ops/flexpool.py b/torch_ops/flexpool.py
new file mode 100644
index 0000000..571cf02
--- /dev/null
+++ b/torch_ops/flexpool.py
@@ -0,0 +1,50 @@
+# Copyright 2019 ComputerGraphics Tuebingen. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""PyTorch op performing the flex pooling operation."""
+
+
+from torch.autograd import Function
+import torch
+
+import flexpool_cuda
+
+torch.manual_seed(42)
+
+
+class FlexPoolFunction(Function):
+    """Autograd function wrapping the `flexpool_cuda` extension."""
+
+    @staticmethod
+    def forward(ctx, features, neighborhood):
+        """Return the pooled features computed by the CUDA extension.
+
+        The extension rejects inputs that are not contiguous CUDA tensors.
+        """
+        outputs = flexpool_cuda.forward(features, neighborhood)
+        output, argmax = outputs[:2]
+        # The extension's backward takes the inputs and the argmax indices.
+        ctx.save_for_backward(features, neighborhood, argmax)
+
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        """Return the gradient for `features`; `neighborhood` gets None."""
+        features, neighborhood, argmax = ctx.saved_tensors
+        # The extension insists on contiguous tensors.
+        grad_features = flexpool_cuda.backward(
+            features, neighborhood, grad_output.contiguous(), argmax)[0]
+
+        return grad_features, None
diff --git a/torch_ops/flexpool_cuda.cpp b/torch_ops/flexpool_cuda.cpp
new file mode 100644
index 0000000..5bb94d6
--- /dev/null
+++ b/torch_ops/flexpool_cuda.cpp
@@ -0,0 +1,71 @@
+/* Copyright 2019 ComputerGraphics Tuebingen. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Authors: Fabian Groh, Patrick Wieschollek, Hendrik P.A. Lensch
+
+#include <torch/extension.h>
+
+#include <vector>
+
+// CUDA forward declarations
+
+std::vector<at::Tensor> flexpool_cuda_forward(at::Tensor features,
+                                              at::Tensor neighborhood);
+
+std::vector<at::Tensor> flexpool_cuda_backward(at::Tensor features,
+                                               at::Tensor neighborhood,
+                                               at::Tensor topdiff,
+                                               at::Tensor argmax);
+
+// C++ interface
+
+// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
+#define CHECK_CUDA(x) \
+  AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
+#define CHECK_CONTIGUOUS(x) \
+  AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
+#define CHECK_INPUT(x) \
+  CHECK_CUDA(x);       \
+  CHECK_CONTIGUOUS(x)
+
+// Checks both inputs, then runs the CUDA forward pass.
+std::vector<at::Tensor> flexpool_forward(at::Tensor features,
+                                         at::Tensor neighborhood) {
+  CHECK_INPUT(features);
+  CHECK_INPUT(neighborhood);
+
+  return flexpool_cuda_forward(features, neighborhood);
+}
+
+// Checks all four inputs, then runs the CUDA backward pass.
+std::vector<at::Tensor> flexpool_backward(at::Tensor features,
+                                          at::Tensor neighborhood,
+                                          at::Tensor topdiff,
+                                          at::Tensor argmax) {
+  CHECK_INPUT(features);
+  CHECK_INPUT(neighborhood);
+  CHECK_INPUT(topdiff);
+  CHECK_INPUT(argmax);
+
+  return flexpool_cuda_backward(features, neighborhood, topdiff, argmax);
+}
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+  m.def("forward", &flexpool_forward, "FlexPool forward (CUDA)");
+  m.def("backward", &flexpool_backward, "FlexPool backward (CUDA)");
+}
+
+#undef CHECK_CUDA
+#undef CHECK_CONTIGUOUS
+#undef CHECK_INPUT
diff --git a/torch_ops/flexpool_cuda_kernel.cu b/torch_ops/flexpool_cuda_kernel.cu
new file mode
100644
index 0000000..faa4782
--- /dev/null
+++ b/torch_ops/flexpool_cuda_kernel.cu
@@ -0,0 +1,83 @@
+/* Copyright 2019 ComputerGraphics Tuebingen. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+// Authors: Fabian Groh, Patrick Wieschollek, Hendrik P.A. Lensch
+
+#include <ATen/ATen.h>
+
+#include <cuda.h>
+#include <cuda_runtime.h>
+
+#include <vector>
+
+namespace {
+
+#include "../user_ops/kernels/flex_pool_kernel_gpu_impl.cuh"
+
+}  // namespace
+
+// Launches the pooling kernel; returns {output, argmax}, each [B, D, N].
+std::vector<at::Tensor> flexpool_cuda_forward(at::Tensor features,
+                                              at::Tensor neighborhood) {
+  const int B = features.size(0);
+  const int D = features.size(1);
+  const int N = features.size(2);
+
+  const int K = neighborhood.size(1);
+
+  auto output = at::zeros({B, D, N}, features.type());
+  auto argmax = at::zeros({B, D, N}, neighborhood.type());
+
+  const int threads = 32;
+  dim3 block(threads, threads, 1);
+  dim3 grid(up2(N, threads), up2(D, threads), B);
+
+  AT_DISPATCH_FLOATING_TYPES(features.type(), "flexpool_forward_cuda", ([&] {
+    forward<scalar_t><<<grid, block>>>(
+        B, N, K, D, features.data<scalar_t>(),
+        neighborhood.data<int>(),
+        output.data<scalar_t>(), argmax.data<int>(),
+        std::numeric_limits<scalar_t>::lowest());
+  }));
+
+  return {output, argmax};
+}
+
+// Launches the gradient kernel; returns {bottom_diff} of shape [B, D, N].
+std::vector<at::Tensor> flexpool_cuda_backward(at::Tensor features,
+                                               at::Tensor neighborhood,
+                                               at::Tensor topdiff,
+                                               at::Tensor argmax) {
+  const int B = features.size(0);
+  const int D = features.size(1);
+  const int N = features.size(2);
+
+  const int K = neighborhood.size(1);
+
+  auto bottom_diff = at::zeros({B, D, N}, features.type());
+
+  const int threads = 32;
+  dim3 block(threads, threads, 1);
+  dim3 grid(up2(N, threads), up2(D, threads), B);
+
+  AT_DISPATCH_FLOATING_TYPES(features.type(), "flexpool_backward_cuda", ([&] {
+    backward<scalar_t><<<grid, block>>>(
+        B, N, K, D, features.data<scalar_t>(),
+        neighborhood.data<int>(),
+        topdiff.data<scalar_t>(), argmax.data<int>(),
+        bottom_diff.data<scalar_t>());
+  }));
+
+  return {bottom_diff};
+}
diff --git a/torch_ops/setup.py b/torch_ops/setup.py
new file mode 100644
index 0000000..50b030c
--- /dev/null
+++ b/torch_ops/setup.py
@@ -0,0 +1,30 @@
+from setuptools import setup
+from torch.utils.cpp_extension import BuildExtension, CUDAExtension
+import sysconfig
+
+# get_config_var() returns None when CFLAGS is not defined.
+extra_compile_args = (sysconfig.get_config_var('CFLAGS') or '').split()
+extra_compile_args += ["-std=c++11", "-Wall", "-Wextra"]
+extra_compile_args += ['--expt-relaxed-constexpr']
+
+flags = []
+
+setup(
+    name='patchmatch_cuda',
+    ext_modules=[
+        CUDAExtension(
+            'flexpool_cuda',
+            sources=[
+                'flexpool_cuda.cpp',
+                'flexpool_cuda_kernel.cu',
+            ],
+            extra_compile_args={
+                "cxx": flags,
+                "nvcc": flags + ["--expt-relaxed-constexpr", "-O2",
+                                 "--gpu-architecture=sm_61"],
+            },),
+    ],
+    extra_compile_args=extra_compile_args,
+    cmdclass={
+        'build_ext': BuildExtension
+    })
diff --git a/user_ops/kernels/cuda_utils.h b/user_ops/kernels/cuda_utils.h
index 2b7fc4a..40e1fd4 100644
--- a/user_ops/kernels/cuda_utils.h
+++ b/user_ops/kernels/cuda_utils.h
@@ -7,4 +7,5 @@ __device__ T* DynamicSharedMemory() {
   return reinterpret_cast<T*>(s_shm);
 }
 
+
 #endif  // LIB_CUDA_UTILS_H_
diff --git a/user_ops/kernels/flex_pool_kernel_gpu.cu.cc b/user_ops/kernels/flex_pool_kernel_gpu.cu.cc
index 8ad9bbe..8bbf5f0 100644
--- a/user_ops/kernels/flex_pool_kernel_gpu.cu.cc
+++ b/user_ops/kernels/flex_pool_kernel_gpu.cu.cc
@@ -25,72 +25,10 @@ limitations under the License.
#include "tensorflow/core/util/cuda_kernel_helper.h" namespace { -inline int up2(int len, int th) { return (len - 1) / th + 1; } -template -__global__ void forward(const int B, const int N, const int K, const int D, - const Dtype* features, const int* neighborhood, - Dtype* output, int* argmax, Dtype float_min_value) { - // features: each feature description for each point [B, D, N]. - // neighborhood: all K nearest neighbors [B, K, N]. - // output: each feature description for each point [B, D, N]. - // argmax: global id in neighborhood who was winning the pooling [B, D, N]. - const int b = blockIdx.z; - - for (int d = blockIdx.y * blockDim.y + threadIdx.y; d < D; - d += blockDim.y * gridDim.y) { - for (int n = blockIdx.x * blockDim.x + threadIdx.x; n < N; - n += blockDim.x * gridDim.x) { - Dtype best_value = float_min_value; - int best_id = 0; - - const int current_flat = b * D * N + d * N + n; - - for (int k_ = 0; k_ < K; ++k_) { - const int other_global_id = neighborhood[b * K * N + k_ * N + n]; - const Dtype v = features[b * D * N + d * N + other_global_id]; - - if (best_value < v) { - best_id = other_global_id; - best_value = v; - } - } - - output[current_flat] = best_value; - argmax[current_flat] = best_id; - } - } -} +#include "flex_pool_kernel_gpu_impl.cuh" + -template -__global__ void backward(const int B, const int N, const int K, const int D, - - const Dtype* features, const int* neighborhood, - const Dtype* topdiff, const int* argmax, - - Dtype* grad_features) { - // features: each feature description for each point [B, D, N]. - // neighborhood: all K nearest neighbors [B, K, N]. - // gradients: topdiff[B, D, N]. - // argmax: argmax[B, D, N]. - // grad_features: gradient to each feature description for each point [B, D, - // N]. 
- const int b = blockIdx.z; - - for (int d = blockIdx.y * blockDim.y + threadIdx.y; d < D; - d += blockDim.y * gridDim.y) { - for (int n = blockIdx.x * blockDim.x + threadIdx.x; n < N; - n += blockDim.x * gridDim.x) { - const int top_id_flat = b * D * N + d * N + n; - const int argmax_id = argmax[top_id_flat]; - const int bottom_id_flat = b * D * N + d * N + argmax_id; - - // TODO(patwie): scattered write, yeah :-( - tensorflow::CudaAtomicAdd(&grad_features[bottom_id_flat], - topdiff[top_id_flat]); - } - } -} } // namespace diff --git a/user_ops/kernels/flex_pool_kernel_gpu_impl.cuh b/user_ops/kernels/flex_pool_kernel_gpu_impl.cuh new file mode 100644 index 0000000..81c9646 --- /dev/null +++ b/user_ops/kernels/flex_pool_kernel_gpu_impl.cuh @@ -0,0 +1,86 @@ +/* Copyright 2017 ComputerGraphics Tuebingen. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Authors: Fabian Groh, Patrick Wieschollek, Hendrik P.A. Lensch + +#ifndef LIB_FLEX_POOL_KERNEL_GPU_IMPL_H_ +#define LIB_FLEX_POOL_KERNEL_GPU_IMPL_H_ + +inline int up2(int len, int th) { return (len - 1) / th + 1; } + +template +__global__ void forward(const int B, const int N, const int K, const int D, + const Dtype* features, const int* neighborhood, + Dtype* output, int* argmax, Dtype float_min_value) { + // features: each feature description for each point [B, D, N]. + // neighborhood: all K nearest neighbors [B, K, N]. 
+ // output: each feature description for each point [B, D, N]. + // argmax: global id in neighborhood who was winning the pooling [B, D, N]. + const int b = blockIdx.z; + + for (int d = blockIdx.y * blockDim.y + threadIdx.y; d < D; + d += blockDim.y * gridDim.y) { + for (int n = blockIdx.x * blockDim.x + threadIdx.x; n < N; + n += blockDim.x * gridDim.x) { + Dtype best_value = float_min_value; + int best_id = 0; + + const int current_flat = b * D * N + d * N + n; + + for (int k_ = 0; k_ < K; ++k_) { + const int other_global_id = neighborhood[b * K * N + k_ * N + n]; + const Dtype v = features[b * D * N + d * N + other_global_id]; + + if (best_value < v) { + best_id = other_global_id; + best_value = v; + } + } + + output[current_flat] = best_value; + argmax[current_flat] = best_id; + } + } +} + +template +__global__ void backward(const int B, const int N, const int K, const int D, + + const Dtype* features, const int* neighborhood, + const Dtype* topdiff, const int* argmax, + + Dtype* grad_features) { + // features: each feature description for each point [B, D, N]. + // neighborhood: all K nearest neighbors [B, K, N]. + // gradients: topdiff[B, D, N]. + // argmax: argmax[B, D, N]. + // grad_features: gradient to each feature description for each point [B, D, + // N]. + const int b = blockIdx.z; + + for (int d = blockIdx.y * blockDim.y + threadIdx.y; d < D; + d += blockDim.y * gridDim.y) { + for (int n = blockIdx.x * blockDim.x + threadIdx.x; n < N; + n += blockDim.x * gridDim.x) { + const int top_id_flat = b * D * N + d * N + n; + const int argmax_id = argmax[top_id_flat]; + const int bottom_id_flat = b * D * N + d * N + argmax_id; + + // TODO(patwie): scattered write, yeah :-( + atomicAdd(&grad_features[bottom_id_flat], topdiff[top_id_flat]); + } + } +} + +#endif // LIB_FLEX_POOL_KERNEL_GPU_IMPL_H_