Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Share implementation between a PyTorch and TensorFlow #9

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Empty file added torch_ops/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions torch_ops/flexpool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright 2019 ComputerGraphics Tuebingen. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Tensorflow op performing flex convolution operation."""


from torch.autograd import Function
import torch

import flexpool_cuda

torch.manual_seed(42)


class FlexPoolFunction(Function):
@staticmethod
def forward(ctx, features, neighborhood):
outputs = flexpool_cuda.forward(features, neighborhood)
output, argmax = outputs[:2]
ctx.save_for_backward(output, argmax)

return output
69 changes: 69 additions & 0 deletions torch_ops/flexpool_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/* Copyright 2019 ComputerGraphics Tuebingen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Authors: Fabian Groh, Patrick Wieschollek, Hendrik P.A. Lensch

#include <torch/torch.h>

#include <vector>

// CUDA forward declarations

std::vector<at::Tensor> flexpool_cuda_forward(at::Tensor features,
at::Tensor neighborhood);

std::vector<at::Tensor> flexpool_cuda_backward(at::Tensor features,
at::Tensor neighborhood,
at::Tensor topdiff,
at::Tensor argmax);

// C++ interface

// NOTE: AT_ASSERT has become AT_CHECK on master after 0.4.
#define CHECK_CUDA(x) \
AT_ASSERTM(x.type().is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)

std::vector<at::Tensor> flexpool_forward(at::Tensor features,
at::Tensor neighborhood) {
CHECK_INPUT(features);
CHECK_INPUT(neighborhood);

return flexpool_cuda_forward(features, neighborhood);
}

std::vector<at::Tensor> flexpool_backward(at::Tensor features,
at::Tensor neighborhood,
at::Tensor topdiff,
at::Tensor argmax) {
CHECK_INPUT(features);
CHECK_INPUT(neighborhood);
CHECK_INPUT(topdiff);
CHECK_INPUT(argmax);

return flexpool_cuda_backward(features, neighborhood, topdiff, argmax);
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &flexpool_forward, "FlexPool forward (CUDA)");
m.def("backward", &flexpool_forward, "FlexPool backward (CUDA)");
}

#undef CHECK_CUDA
#undef CHECK_CONTIGUOUS
#undef CHECK_INPUT
81 changes: 81 additions & 0 deletions torch_ops/flexpool_cuda_kernel.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
/* Copyright 2019 ComputerGraphics Tuebingen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Authors: Fabian Groh, Patrick Wieschollek, Hendrik P.A. Lensch

#include <ATen/ATen.h>

#include <cuda.h>
#include <cuda_runtime.h>

#include <vector>

namespace {

#include "../user_ops/kernels/flex_pool_kernel_gpu_impl.cuh"

} // namespace

std::vector<at::Tensor> flexpool_cuda_forward(at::Tensor features,
at::Tensor neighborhood) {
const int B = features.size(0);
const int D = features.size(1);
const int N = features.size(2);

const int K = neighborhood.size(1);

auto output = at::zeros({B, D, N}, features.type());
auto argmax = at::zeros({B, D, N}, neighborhood.type());

const int threads = 32;
dim3 block(threads, threads, 1);
dim3 grid(up2(N, threads), up2(D, threads), B);

AT_DISPATCH_FLOATING_TYPES(features.type(), "flexpool_forward_cuda", ([&] {
forward<scalar_t><<<grid, block>>>(
B, N, K, D, features.data<scalar_t>(),
neighborhood.data<int>(),
output.data<scalar_t>(), argmax.data<int>(),
std::numeric_limits<scalar_t>::lowest());
}));

return {output, argmax};
}

std::vector<at::Tensor> flexpool_cuda_backward(at::Tensor features,
at::Tensor neighborhood,
at::Tensor topdiff,
at::Tensor argmax) {
const int B = features.size(0);
const int D = features.size(1);
const int N = features.size(2);

const int K = neighborhood.size(1);

auto bottom_diff = at::zeros({B, D, N}, features.type());

const int threads = 32;
dim3 block(threads, threads, 1);
dim3 grid(up2(N, threads), up2(D, threads), B);

AT_DISPATCH_FLOATING_TYPES(features.type(), "flexpool_backward_cuda", ([&] {
backward<scalar_t><<<grid, block>>>(
B, N, K, D, features.data<scalar_t>(),
neighborhood.data<int>(),
topdiff.data<scalar_t>(), argmax.data<int>(),
bottom_diff.data<scalar_t>());
}));

return {bottom_diff};
}
29 changes: 29 additions & 0 deletions torch_ops/setup.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
from setuptools import setup
from torch.utils.cpp_extension import BuildExtension, CUDAExtension
import sysconfig

extra_compile_args = sysconfig.get_config_var('CFLAGS').split()
extra_compile_args += ["-std=c++11", "-Wall", "-Wextra"]
extra_compile_args += ['--expt-relaxed-constexpr']

flags = []

setup(
name='patchmatch_cuda',
ext_modules=[
CUDAExtension(
'flexpool_cuda',
sources=[
'flexpool_cuda.cpp',
'flexpool_cuda_kernel.cu',
],
extra_compile_args={
"cxx": flags,
"nvcc": flags + ["--expt-relaxed-constexpr", "-O2",
"--gpu-architecture=sm_61"],
},),
],
extra_compile_args=extra_compile_args,
cmdclass={
'build_ext': BuildExtension
})
1 change: 1 addition & 0 deletions user_ops/kernels/cuda_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,5 @@ __device__ T* DynamicSharedMemory() {
return reinterpret_cast<T*>(s_shm);
}


#endif // LIB_CUDA_UTILS_H_
66 changes: 2 additions & 64 deletions user_ops/kernels/flex_pool_kernel_gpu.cu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,72 +25,10 @@ limitations under the License.
#include "tensorflow/core/util/cuda_kernel_helper.h"

namespace {
inline int up2(int len, int th) { return (len - 1) / th + 1; }

template <typename Dtype>
__global__ void forward(const int B, const int N, const int K, const int D,
const Dtype* features, const int* neighborhood,
Dtype* output, int* argmax, Dtype float_min_value) {
// features: each feature description for each point [B, D, N].
// neighborhood: all K nearest neighbors [B, K, N].
// output: each feature description for each point [B, D, N].
// argmax: global id in neighborhood who was winning the pooling [B, D, N].
const int b = blockIdx.z;

for (int d = blockIdx.y * blockDim.y + threadIdx.y; d < D;
d += blockDim.y * gridDim.y) {
for (int n = blockIdx.x * blockDim.x + threadIdx.x; n < N;
n += blockDim.x * gridDim.x) {
Dtype best_value = float_min_value;
int best_id = 0;

const int current_flat = b * D * N + d * N + n;

for (int k_ = 0; k_ < K; ++k_) {
const int other_global_id = neighborhood[b * K * N + k_ * N + n];
const Dtype v = features[b * D * N + d * N + other_global_id];

if (best_value < v) {
best_id = other_global_id;
best_value = v;
}
}

output[current_flat] = best_value;
argmax[current_flat] = best_id;
}
}
}
#include "flex_pool_kernel_gpu_impl.cuh"


template <typename Dtype>
__global__ void backward(const int B, const int N, const int K, const int D,

const Dtype* features, const int* neighborhood,
const Dtype* topdiff, const int* argmax,

Dtype* grad_features) {
// features: each feature description for each point [B, D, N].
// neighborhood: all K nearest neighbors [B, K, N].
// gradients: topdiff[B, D, N].
// argmax: argmax[B, D, N].
// grad_features: gradient to each feature description for each point [B, D,
// N].
const int b = blockIdx.z;

for (int d = blockIdx.y * blockDim.y + threadIdx.y; d < D;
d += blockDim.y * gridDim.y) {
for (int n = blockIdx.x * blockDim.x + threadIdx.x; n < N;
n += blockDim.x * gridDim.x) {
const int top_id_flat = b * D * N + d * N + n;
const int argmax_id = argmax[top_id_flat];
const int bottom_id_flat = b * D * N + d * N + argmax_id;

// TODO(patwie): scattered write, yeah :-(
tensorflow::CudaAtomicAdd(&grad_features[bottom_id_flat],
topdiff[top_id_flat]);
}
}
}

} // namespace

Expand Down
86 changes: 86 additions & 0 deletions user_ops/kernels/flex_pool_kernel_gpu_impl.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/* Copyright 2017 ComputerGraphics Tuebingen. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Authors: Fabian Groh, Patrick Wieschollek, Hendrik P.A. Lensch

#ifndef LIB_FLEX_POOL_KERNEL_GPU_IMPL_H_
#define LIB_FLEX_POOL_KERNEL_GPU_IMPL_H_

inline int up2(int len, int th) { return (len - 1) / th + 1; }

template <typename Dtype>
__global__ void forward(const int B, const int N, const int K, const int D,
const Dtype* features, const int* neighborhood,
Dtype* output, int* argmax, Dtype float_min_value) {
// features: each feature description for each point [B, D, N].
// neighborhood: all K nearest neighbors [B, K, N].
// output: each feature description for each point [B, D, N].
// argmax: global id in neighborhood who was winning the pooling [B, D, N].
const int b = blockIdx.z;

for (int d = blockIdx.y * blockDim.y + threadIdx.y; d < D;
d += blockDim.y * gridDim.y) {
for (int n = blockIdx.x * blockDim.x + threadIdx.x; n < N;
n += blockDim.x * gridDim.x) {
Dtype best_value = float_min_value;
int best_id = 0;

const int current_flat = b * D * N + d * N + n;

for (int k_ = 0; k_ < K; ++k_) {
const int other_global_id = neighborhood[b * K * N + k_ * N + n];
const Dtype v = features[b * D * N + d * N + other_global_id];

if (best_value < v) {
best_id = other_global_id;
best_value = v;
}
}

output[current_flat] = best_value;
argmax[current_flat] = best_id;
}
}
}

template <typename Dtype>
__global__ void backward(const int B, const int N, const int K, const int D,

const Dtype* features, const int* neighborhood,
const Dtype* topdiff, const int* argmax,

Dtype* grad_features) {
// features: each feature description for each point [B, D, N].
// neighborhood: all K nearest neighbors [B, K, N].
// gradients: topdiff[B, D, N].
// argmax: argmax[B, D, N].
// grad_features: gradient to each feature description for each point [B, D,
// N].
const int b = blockIdx.z;

for (int d = blockIdx.y * blockDim.y + threadIdx.y; d < D;
d += blockDim.y * gridDim.y) {
for (int n = blockIdx.x * blockDim.x + threadIdx.x; n < N;
n += blockDim.x * gridDim.x) {
const int top_id_flat = b * D * N + d * N + n;
const int argmax_id = argmax[top_id_flat];
const int bottom_id_flat = b * D * N + d * N + argmax_id;

// TODO(patwie): scattered write, yeah :-(
atomicAdd(&grad_features[bottom_id_flat], topdiff[top_id_flat]);
}
}
}

#endif // LIB_FLEX_POOL_KERNEL_GPU_IMPL_H_