From d7a4524813b10439c8e0985faa993c17ada91f77 Mon Sep 17 00:00:00 2001 From: Tom Joy Date: Tue, 1 May 2018 13:04:33 +0100 Subject: [PATCH 1/2] Adding filtering via GPU, this is an interactive reabsed comit. See below commits Calling filtering from functor Moving away from tensorflow::Tensors, this causes problems when using GPU Adding functionality to makefile and removing .cc.cul files Implemented templated specialization of functors, also moved away from the tensorflow CUDA macros and defined my own, we have also reverted back to Tensors Fixing issue with registering GPU op when no gpu is present Compiling modified permutohedral.cu, getting an error with invalid pointer Now filtering on GPU Fixing channel error in filter operation and removing std::couts Updates to readme Tidying up before PR --- .gitignore | 3 + README.md | 2 + run_demo.py | 5 +- src/cpp/Makefile | 22 +- src/cpp/hash_helper.cu | 93 ++++ src/cpp/high_dim_filter.cc | 105 +++- src/cpp/high_dim_filter.cu | 85 +++ src/cpp/include/cuda_macros.h | 35 ++ src/cpp/include/hash_table.h | 151 +++++ src/cpp/include/high_dim_filter.h | 42 ++ .../{ => include}/modified_permutohedral.h | 42 +- src/cpp/modified_permutohedral.cc | 19 +- src/cpp/modified_permutohedral.cu | 525 ++++++++++++++++++ 13 files changed, 1074 insertions(+), 55 deletions(-) create mode 100644 src/cpp/hash_helper.cu create mode 100644 src/cpp/high_dim_filter.cu create mode 100644 src/cpp/include/cuda_macros.h create mode 100644 src/cpp/include/hash_table.h create mode 100644 src/cpp/include/high_dim_filter.h rename src/cpp/{ => include}/modified_permutohedral.h (77%) create mode 100644 src/cpp/modified_permutohedral.cu diff --git a/.gitignore b/.gitignore index 3ae6c3c..46ee421 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,5 @@ *.pyc src/cpp/high_dim_filter.so +*.h5 +*.png +*.o diff --git a/README.md b/README.md index 7dfdec2..c04ac47 100644 --- a/README.md +++ b/README.md @@ -45,6 +45,8 @@ You should not see any errors while importing `tensorflow` and `keras` above. ### Step 3: Build CRF-RNN custom op C++ code +**Note**: Edit the makefile to select whether to run on GPU/CPU. + Run `make` inside the `crfasrnn_keras/src/cpp` directory: ``` $ cd crfasrnn_keras/src/cpp diff --git a/run_demo.py b/run_demo.py index e0a287a..98a3873 100644 --- a/run_demo.py +++ b/run_demo.py @@ -26,7 +26,7 @@ sys.path.insert(1, './src') from crfrnn_model import get_crfrnn_model_def import util - +import time def main(): input_file = 'image.jpg' @@ -39,9 +39,12 @@ def main(): model.load_weights(saved_model_path) img_data, img_h, img_w = util.get_preprocessed_image(input_file) + tic = time.clock() probs = model.predict(img_data, verbose=False)[0, :, :, :] + toc = time.clock() segmentation = util.get_label_image(probs, img_h, img_w) segmentation.save(output_file) + print "Time taken: " + str(toc - tic) if __name__ == '__main__': diff --git a/src/cpp/Makefile b/src/cpp/Makefile index ff4a152..569f561 100644 --- a/src/cpp/Makefile +++ b/src/cpp/Makefile @@ -13,9 +13,13 @@ # Define the compiler CC := g++ +#Use GPU Implementation? +USE_GPU := 1 + # Read Tensorflow paths TF_INC := $(shell python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())') TF_LIB := $(shell python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())') +TF_CFLAGS=$(shell python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))') # Is the Tensorflow version >= 1.4? TF_VERSION_GTE_1_4 := $(shell expr `python -c 'import tensorflow as tf; print(tf.__version__)' | cut -f1,2 -d.` \>= 1.4) @@ -31,7 +35,7 @@ endif # Set some more flags if the Tensorflow version is >= 1.4 ifeq ($(TF_VERSION_GTE_1_4), 1) CFLAGS += -I$(TF_INC)/external/nsync/public - LDFLAGS := -L$(TF_LIB) -ltensorflow_framework + LDFLAGS := -L$(TF_LIB) -ltensorflow_framework -L/usr/local/cuda/lib64 -lcuda -lcudart else LDFLAGS := endif @@ -39,10 +43,18 @@ endif # Define build targets .PHONY: all clean -high_dim_filter.so: high_dim_filter.cc modified_permutohedral.cc - $(CC) $(CFLAGS) -o high_dim_filter.so high_dim_filter.cc modified_permutohedral.cc $(LDFLAGS) + +all: high_dim_filter.so + +high_dim_filter.so: cudacode.o modified_permutohedral.o high_dim_filter.cc modified_permutohedral.cc + g++ $(CFLAGS) -o high_dim_filter.so high_dim_filter.cc modified_permutohedral.cc cudacode.o modified_permutohedral.o $(LDFLAGS) -D FILTER_GPU=$(USE_GPU) + +cudacode.o: high_dim_filter.cu + nvcc -std=c++11 -c -o cudacode.o high_dim_filter.cu $(TF_CFLAGS) -x cu -Xcompiler -fPIC --expt-relaxed-constexpr -D FILTER_GPU=$(USE_GPU) + +modified_permutohedral.o: modified_permutohedral.cu + nvcc -std=c++11 -c -o modified_permutohedral.o modified_permutohedral.cu $(TF_CFLAGS) -x cu -Xcompiler -fPIC --expt-relaxed-constexpr -D FILTER_GPU=$(USE_GPU) clean: - $(RM) high_dim_filter.so + $(RM) high_dim_filter.so cudacode.o modified_permutohedral.o -all: high_dim_filter.so diff --git a/src/cpp/hash_helper.cu b/src/cpp/hash_helper.cu new file mode 100644 index 0000000..1ed7bff --- /dev/null +++ b/src/cpp/hash_helper.cu @@ -0,0 +1,93 @@ + +#define modHash(n) ((n)%(2*table_capacity)); + +namespace caffe { + template + __device__ __host__ static unsigned int hash(signed short *key) { + unsigned int k = 0; + for (int i = 0; i < kd; i++) { + k += key[i]; + k = k * 1664525; + } + return k; + } + template + __device__ __host__ static unsigned int has(int *key) { + unsigned int k = 0; + for (int i = 0; i < kd; i++) { + k += key[i]; + k = k * 1664525; + } + return k; + } + template + __device__ static int hashTableInsert(unsigned int fh, signed short *key, + signed short * table_keys, + int* table_entries, + int table_capacity, + unsigned int slot) + { + int h = modHash(fh); + while (1) { + int *e = &table_entries[h]; + // if the cell is empty (-1), lock it (-2) + int contents = atomicCAS(e, -1, -2); + + if (contents == -2) { + // If it was locked already, move on the next cell + + } else if (contents == -1) { + // If it was empty, we successfully locked it, write our key + for (int i = 0; i < kd; i++) { + table_keys[slot*kd+i] = key[i]; + } + // Unlock + atomicExch(e, slot); + + return h; + } else { + // The cell is unlocked and has a key in it, check if it matches + bool match = true; + for (int i = 0; i < kd && match; i++) { + match = (table_keys[contents*kd+i] == key[i]); + } + if (match) return h; + } + // increment the bucket with wraparound + h++; + if (h == table_capacity*2) h = 0; + } + } + + template + __device__ static int hashTableInsert(signed short *key, + signed short* table_keys, + int* table_entries, + int table_capacity, + unsigned int slot) { + unsigned int myHash = hash(key); + return hashTableInsert(myHash, key, table_keys, table_entries, table_capacity, slot); + } + + template + __device__ static int hashTableRetrieve(signed short*key, + const int * table_entries, + const signed short* table_keys, + const int table_capacity) { + int h = modHash(hash(key)); + while (1) { + const int *e = table_entries + h; + if (*e == -1) return -1; + bool match = true; + for (int i = 0; i < kd && match; i++) { + match = (table_keys[(*e)*kd+i] == key[i]); + } + if (match) return *e; + + h++; + if (h == table_capacity*2) h = 0; + } + } + +} //namespace caffe + diff --git a/src/cpp/high_dim_filter.cc b/src/cpp/high_dim_filter.cc index 25485ce..4c1d800 100644 --- a/src/cpp/high_dim_filter.cc +++ b/src/cpp/high_dim_filter.cc @@ -26,7 +26,7 @@ #include "tensorflow/core/framework/shape_inference.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/tensor_shape.h" -#include "modified_permutohedral.h" +#include "include/high_dim_filter.h" using namespace tensorflow; @@ -61,19 +61,62 @@ void compute_bilateral_kernel(float * const output_kernel, const Tensor& rgb_ten } REGISTER_OP("HighDimFilter") + .Attr("T: {float}") .Attr("bilateral: bool") .Attr("theta_alpha: float = 1.0") .Attr("theta_beta: float = 1.0") .Attr("theta_gamma: float = 1.0") .Attr("backwards: bool = false") - .Input("raw: float32") - .Input("rgb: float32") - .Output("filtered: float32") + .Input("raw: T") + .Input("rgb: T") + .Output("filtered: T") .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) { c->set_output(0, c->input(0)); return Status::OK(); }); + +using CPUDevice = Eigen::ThreadPoolDevice; +using GPUDevice = Eigen::GpuDevice; + +template <> +struct HighDimFilterFunctor { + void operator()(const CPUDevice& d, + const tensorflow::Tensor & input_q, + const tensorflow::Tensor & input_img, + tensorflow::Tensor * out, + const FilterParams & params) { + + ModifiedPermutohedral mp; + + const int channels = input_q.dim_size(0); + const int height = input_img.dim_size(1); + const int width = input_img.dim_size(2); + const int num_pixels = width * height; + + if (params.bilateral_) { + float * const kernel_vals = new float[5 * num_pixels]; + compute_bilateral_kernel(kernel_vals, input_img, + params.theta_alpha_, params.theta_beta_); + mp.init_cpu(kernel_vals, 5, num_pixels); + mp.compute_cpu(*out, input_q, channels, params.backwards_); + + delete[] kernel_vals; + } else { + float * const kernel_vals = new float[2 * num_pixels]; + compute_spatial_kernel(kernel_vals, width, height, params.theta_gamma_); + mp.init_cpu(kernel_vals, 2, num_pixels); + mp.compute_cpu(*out, input_q, channels, params.backwards_); + + delete[] kernel_vals; + } + + } + +}; + + +template class HighDimFilterOp : public OpKernel { public: explicit HighDimFilterOp(OpKernelConstruction* context) : OpKernel(context) { @@ -97,36 +140,25 @@ class HighDimFilterOp : public OpKernel { // Grab the RGB image tensor const Tensor& image_tensor = context->input(1); - const int channels = input_tensor.dim_size(0); - const int height = input_tensor.dim_size(1); - const int width = input_tensor.dim_size(2); - const int num_pixels = width * height; - // Create the output tensor Tensor* output_tensor = NULL; OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(), &output_tensor)); - ModifiedPermutohedral mp; - - if (bilateral_) { - float * const kernel_vals = new float[5 * num_pixels]; - compute_bilateral_kernel(kernel_vals, image_tensor, - theta_alpha_, theta_beta_); - mp.init(kernel_vals, 5, num_pixels); - mp.compute(*output_tensor, input_tensor, channels, backwards_); - - delete[] kernel_vals; - } else { - float * const kernel_vals = new float[2 * num_pixels]; - compute_spatial_kernel(kernel_vals, width, height, theta_gamma_); - mp.init(kernel_vals, 2, num_pixels); - mp.compute(*output_tensor, input_tensor, channels, backwards_); - - delete[] kernel_vals; - } + const FilterParams params(bilateral_, + theta_alpha_, + theta_beta_, + theta_gamma_, + backwards_); + //filter + HighDimFilterFunctor()( + context->eigen_device(), + input_tensor, + image_tensor, + output_tensor, + params); } - + private: bool bilateral_; float theta_alpha_; @@ -135,4 +167,19 @@ class HighDimFilterOp : public OpKernel { bool backwards_; }; -REGISTER_KERNEL_BUILDER(Name("HighDimFilter").Device(DEVICE_CPU), HighDimFilterOp); + +// Register kernels +#define REGISTER_CPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("HighDimFilter").Device(DEVICE_CPU).TypeConstraint("T"), \ + HighDimFilterOp); +REGISTER_CPU(float); + +#if FILTER_GPU +#define REGISTER_GPU(T) \ + REGISTER_KERNEL_BUILDER( \ + Name("HighDimFilter").Device(DEVICE_GPU).TypeConstraint("T"), \ + HighDimFilterOp); +REGISTER_GPU(float); +#endif + diff --git a/src/cpp/high_dim_filter.cu b/src/cpp/high_dim_filter.cu new file mode 100644 index 0000000..599bc4e --- /dev/null +++ b/src/cpp/high_dim_filter.cu @@ -0,0 +1,85 @@ +#if FILTER_GPU +#define EIGEN_USE_GPU + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor_shape.h" +#include "include/high_dim_filter.h" +#include "include/cuda_macros.h" + + +using GPUDevice = Eigen::GpuDevice; +using namespace tensorflow; + +__global__ void compute_bilateral_kernel_gpu(const int num_pixels_, + const float * const rgb_blob, + const int width_, + float theta_alpha_, float theta_beta_, + float* const output_kernel) { + CUDA_KERNEL_LOOP(p, num_pixels_) { + output_kernel[5 * p] = (float)(p % width_) / theta_alpha_; + output_kernel[5 * p + 1] = (float)(p / width_) / theta_alpha_; + output_kernel[5 * p + 2] = (float)(rgb_blob[p] / theta_beta_); + output_kernel[5 * p + 3] = (float)((rgb_blob + num_pixels_)[p] / theta_beta_); + output_kernel[5 * p + 4] = (float)((rgb_blob + num_pixels_ * 2)[p] / theta_beta_); + } +} + +__global__ void compute_spatial_kernel_gpu(const int num_pixels_, + const int width_, float theta_gamma_, + float* const output_kernel) { + + CUDA_KERNEL_LOOP(p, num_pixels_) { + output_kernel[2*p] = static_cast(p % width_) / theta_gamma_; + output_kernel[2*p + 1] = static_cast(p / width_) / theta_gamma_; + } +} + +template<> +void HighDimFilterFunctor::operator()( + const GPUDevice& d, + const tensorflow::Tensor & input_q, + const tensorflow::Tensor & input_img, + tensorflow::Tensor * out, + const FilterParams & params) { + + const int channels = input_q.dim_size(0); + const int height = input_img.dim_size(1); + const int width = input_img.dim_size(2); + const int num_pixels = width * height; + + ModifiedPermutohedral mp; + + if (params.bilateral_) { + float * kernel_vals; + + CUDA_CHECK(cudaMalloc((void**)&kernel_vals, 5 * num_pixels * sizeof(float))); + compute_bilateral_kernel_gpu<<>>( + num_pixels, input_img.flat().data(), width, + params.theta_alpha_, params.theta_beta_, kernel_vals); + CUDA_POST_KERNEL_CHECK; + + mp.init_gpu(kernel_vals, 5, width, height); + mp.compute_gpu(out->flat().data(), input_q.flat().data(), channels, params.backwards_); + CUDA_CHECK(cudaFree(kernel_vals)); + } else { + float * kernel_vals; + CUDA_CHECK(cudaMalloc((void**)&kernel_vals, 2 * num_pixels * sizeof(float))); + compute_spatial_kernel_gpu<<>>( + num_pixels, width, params.theta_gamma_, kernel_vals); + CUDA_POST_KERNEL_CHECK; + + mp.init_gpu(kernel_vals, 2, width, height); + mp.compute_gpu(out->flat().data(), input_q.flat().data(), channels, params.backwards_); + CUDA_CHECK(cudaFree(kernel_vals)); + } + + mp.freeMatrix(); + +} + + +template struct HighDimFilterFunctor; + +#endif \ No newline at end of file diff --git a/src/cpp/include/cuda_macros.h b/src/cpp/include/cuda_macros.h new file mode 100644 index 0000000..55cc514 --- /dev/null +++ b/src/cpp/include/cuda_macros.h @@ -0,0 +1,35 @@ +#ifndef CUDA_MACROS_H_ +#define CUDA_MACROS_H_ + +#include + +#define EQ(a, b) \ + ((a) == (b)) + +#define CUDA_KERNEL_LOOP(i, n) \ + for (int i = blockIdx.x * blockDim.x + threadIdx.x; \ + i < (n); \ + i += blockDim.x * gridDim.x) + +#define CUDA_POST_KERNEL_CHECK \ + if (cudaSuccess != cudaPeekAtLastError()) \ + std::cout << "Cuda kernel failed. Error: " \ + << cudaGetErrorString(cudaPeekAtLastError()) + +// CUDA: various checks for different function calls. +#define CUDA_CHECK(condition) \ + /* Code block avoids redefinition of cudaError_t error */ \ + do { \ + cudaError_t error = condition; \ + if(error != cudaSuccess) std::cout << cudaGetErrorString(error); \ +} while (0) + +// CUDA: use 512 threads per block +const int CUDA_NUM_THREADS = 512; + +// CUDA: number of blocks for threads. +inline int CUDA_GET_BLOCKS(const int N) { + return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS; +} + +#endif //CUDA_MACROS_H_ diff --git a/src/cpp/include/hash_table.h b/src/cpp/include/hash_table.h new file mode 100644 index 0000000..49e43d3 --- /dev/null +++ b/src/cpp/include/hash_table.h @@ -0,0 +1,151 @@ +#ifndef HASH_TABLE_HPP +#define HASH_TABLE_HPP + +#include "cuda_macros.h" + +#define modHash(n) ((n)%(2*table_capacity)); + + +class HashTable +{ + public: + int *table_entries; + unsigned int table_capacity; + signed short *table_keys; + bool create; + + HashTable() : create(false) {} + + void createHashTable(const int capacity, const int kd){ + #ifdef FILTER_GPU + // TODO? use symbol to go in constant memory instead + // Initialize table_capacity + table_capacity = (unsigned int)capacity ; + + // Initialize table_entries + CUDA_CHECK(cudaMalloc((void **) &table_entries, 2*capacity*sizeof(int))); + CUDA_CHECK(cudaMemset(table_entries, -1, 2*capacity*sizeof(int))); + + // Initialize table_keys + CUDA_CHECK(cudaMalloc((void **) &table_keys, capacity*kd*sizeof(signed short))); + CUDA_CHECK(cudaMemset(table_keys, 0, capacity*kd*sizeof(signed short))); + + // Set create to true + create = true; + #endif // FILTER_GPU + } + + void resetHashTable(const int capacity, const int kd){ + #ifdef FILTER_GPU + // Initialize table_capacity + table_capacity = (unsigned int)capacity ; + + // Reset table_entries + CUDA_CHECK(cudaMemset(table_entries, -1, 2*capacity*sizeof(int))); + + // Resettable_keys + CUDA_CHECK(cudaMemset(table_keys, 0, capacity*kd*sizeof(signed short))); + #endif // FILTER_GPU + } + + ~HashTable(){ + #ifdef FILTER_GPU + if(create){ + // Free pointers allocated during + CUDA_CHECK(cudaFree(table_entries)); + CUDA_CHECK(cudaFree(table_keys)); + } + #endif //FILTER_GPU + } + +}; + +template +__device__ __host__ static unsigned int hash(signed short *key) { +unsigned int k = 0; +for (int i = 0; i < kd; i++) { + k += key[i]; + k = k * 1664525; +} +return k; +} +template +__device__ __host__ static unsigned int has(int *key) { +unsigned int k = 0; +for (int i = 0; i < kd; i++) { + k += key[i]; + k = k * 1664525; +} +return k; +} +template +__device__ static int hashTableInsert(unsigned int fh, signed short *key, +signed short * table_keys, +int* table_entries, +int table_capacity, +unsigned int slot) +{ + int h = modHash(fh); + while (1) { + int *e = &table_entries[h]; + // if the cell is empty (-1), lock it (-2) + int contents = atomicCAS(e, -1, -2); + + if (contents == -2) { + // If it was locked already, move on the next cell + + } else if (contents == -1) { + // If it was empty, we successfully locked it, write our key + for (int i = 0; i < kd; i++) { + table_keys[slot*kd+i] = key[i]; + } + // Unlock + atomicExch(e, slot); + + return h; + } else { + // The cell is unlocked and has a key in it, check if it matches + bool match = true; + for (int i = 0; i < kd && match; i++) { + match = (table_keys[contents*kd+i] == key[i]); + } + if (match) return h; + } + // increment the bucket with wraparound + h++; + if (h == table_capacity*2) h = 0; + } +} + +template +__device__ static int hashTableInsert(signed short *key, +signed short* table_keys, +int* table_entries, +int table_capacity, +unsigned int slot) { + unsigned int myHash = hash(key); + return hashTableInsert(myHash, key, table_keys, table_entries, table_capacity, slot); +} + +template +__device__ static int hashTableRetrieve(signed short*key, + const int * table_entries, + const signed short* table_keys, + const int table_capacity) { + int h = modHash(hash(key)); + while (1) { + const int *e = table_entries + h; + if (*e == -1) return -1; + bool match = true; + for (int i = 0; i < kd && match; i++) { + match = (table_keys[(*e)*kd+i] == key[i]); + } + if (match) return *e; + + h++; + if (h == table_capacity*2) h = 0; + } + } + + +#endif //HASH_TABLE_HPP diff --git a/src/cpp/include/high_dim_filter.h b/src/cpp/include/high_dim_filter.h new file mode 100644 index 0000000..2698e24 --- /dev/null +++ b/src/cpp/include/high_dim_filter.h @@ -0,0 +1,42 @@ +#ifndef KERNEL_HIGH_DIM_FILTER_H_ +#define KERNEL_HIGH_DIM_FILTER_H_ + +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_kernel.h" +#include "modified_permutohedral.h" + + +using GPUDevice = Eigen::GpuDevice; +using CPUDevice = Eigen::ThreadPoolDevice; + +struct FilterParams +{ + FilterParams(bool bilateral, + float theta_alpha, + float theta_beta, + float theta_gamma, + bool backwards) : + bilateral_(bilateral), + theta_alpha_(theta_alpha), + theta_beta_(theta_beta), + theta_gamma_(theta_gamma), + backwards_(backwards) {} + + bool bilateral_; + float theta_alpha_; + float theta_beta_; + float theta_gamma_; + bool backwards_; +}; + +template +struct HighDimFilterFunctor { + void operator()(const Device& d, + const tensorflow::Tensor & input_q, + const tensorflow::Tensor & input_img, + tensorflow::Tensor * out, + const FilterParams & params); +}; + + +#endif // KERNEL_HIGH_DIM_FILTER_H_ \ No newline at end of file diff --git a/src/cpp/modified_permutohedral.h b/src/cpp/include/modified_permutohedral.h similarity index 77% rename from src/cpp/modified_permutohedral.h rename to src/cpp/include/modified_permutohedral.h index 47396f3..2fbf578 100644 --- a/src/cpp/modified_permutohedral.h +++ b/src/cpp/include/modified_permutohedral.h @@ -2,10 +2,8 @@ This file contains a modified version of the "permutohedral.h" code available at http://graphics.stanford.edu/projects/drf/. Copyright notice of the original file is included below: - Copyright (c) 2013, Philipp Krähenbühl All rights reserved. - Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright @@ -16,7 +14,6 @@ * Neither the name of the Stanford University nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - THIS SOFTWARE IS PROVIDED BY Philipp Krähenbühl ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -40,12 +37,21 @@ #include #include "tensorflow/core/framework/tensor.h" - +#if __CUDACC__ +#include "hash_table.h" //this should not be compiled by nvcc +#endif using namespace tensorflow; /************************************************/ /*** ModifiedPermutohedral Lattice ***/ /************************************************/ + +typedef struct MatrixEntry { + int index; + float weight; +} MatrixEntry; + + class ModifiedPermutohedral { protected: struct Neighbors { @@ -58,8 +64,13 @@ class ModifiedPermutohedral { std::vector offset_, rank_; std::vector barycentric_; std::vector blur_neighbors_; + #if __CUDACC__ + bool is_init; + MatrixEntry * matrix; + HashTable table; + #endif // Number of elements, size of sparse discretized space, dimension of features - int N_, M_, d_; + int N_, M_, d_, w_, h_; void sseCompute(Tensor &out, const Tensor &in, int value_size, bool reverse = false, bool add = false) const; @@ -67,13 +78,26 @@ class ModifiedPermutohedral { void seqCompute(Tensor &out, const Tensor &in, int value_size, bool reverse = false, bool add = false) const; + public: - ModifiedPermutohedral(); + ModifiedPermutohedral() : N_( 0 ), M_( 0 ), d_( 0 ) {} + ~ModifiedPermutohedral() {} - void init(const float *features, int num_dimensions, int num_points); + #if __CUDACC__ + void freeMatrix(){ + if (is_init) { + CUDA_CHECK(cudaFree(matrix)); + } + } + #endif - void compute(Tensor &out, const Tensor &in, int value_size, + void init_cpu(const float *features, int num_dimensions, int num_points); + void init_gpu(const float* features, int num_dimensions, int w, int h); + + void compute_cpu(Tensor &out, const Tensor &in, int value_size, + bool reverse = false, bool add = false) const; + void compute_gpu(float* out, const float* in, int value_size, bool reverse = false, bool add = false) const; }; -#endif //_MODIFIED_PERMUTOHEDRAL_HPP_ +#endif //_MODIFIED_PERMUTOHEDRAL_HPP_ \ No newline at end of file diff --git a/src/cpp/modified_permutohedral.cc b/src/cpp/modified_permutohedral.cc index 5c26256..495c90b 100644 --- a/src/cpp/modified_permutohedral.cc +++ b/src/cpp/modified_permutohedral.cc @@ -2,10 +2,8 @@ This file contains a modified version of the "permutohedral.cpp" code available at http://graphics.stanford.edu/projects/drf/. Copyright notice of the original file is included below: - Copyright (c) 2013, Philipp Krähenbühl All rights reserved. - Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright @@ -16,7 +14,6 @@ * Neither the name of the Stanford University nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. - THIS SOFTWARE IS PROVIDED BY Philipp Krähenbühl ''AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE @@ -30,7 +27,7 @@ */ //#include "stdafx.h" -#include "modified_permutohedral.h" +#include "include/modified_permutohedral.h" #ifdef __SSE__ // SSE Permutoheral lattice @@ -135,10 +132,10 @@ class HashTableCopy{ /*** ModifiedPermutohedral Lattice ***/ /************************************************/ -ModifiedPermutohedral::ModifiedPermutohedral():N_( 0 ), M_( 0 ), d_( 0 ) { -} +//ModifiedPermutohedral::ModifiedPermutohedral():N_( 0 ), M_( 0 ), d_( 0 ) {} + #ifdef SSE_PERMUTOHEDRAL -void ModifiedPermutohedral::init(const float* features, int num_dimensions, int num_points) +void ModifiedPermutohedral::init_cpu(const float* features, int num_dimensions, int num_points) { // Compute the lattice coordinates for each feature [there is going to be a lot of magic here N_ = num_points; @@ -321,7 +318,7 @@ void ModifiedPermutohedral::init(const float* features, int num_dimensions, int delete[] n2; } #else -void ModifiedPermutohedral::init (const float* features, int num_dimensions, int num_points) +void ModifiedPermutohedral::init_cpu(const float* features, int num_dimensions, int num_points) { // Compute the lattice coordinates for each feature [there is going to be a lot of magic here N_ = num_points; @@ -622,15 +619,15 @@ void ModifiedPermutohedral::sseCompute(Tensor& out_tensor, const Tensor& in_tens #else void ModifiedPermutohedral::sseCompute(Tensor& out, const Tensor& in, int value_size, bool reverse, bool add) const { - seqCompute( out, in, value_size, reverse, add); + seqCompute_cpu( out, in, value_size, reverse, add); } #endif -void ModifiedPermutohedral::compute(Tensor& out, const Tensor& in, int value_size, bool reverse, bool add) const +void ModifiedPermutohedral::compute_cpu(Tensor& out, const Tensor& in, int value_size, bool reverse, bool add) const { if (value_size <= 2) seqCompute(out, in, value_size, reverse, add); else sseCompute(out, in, value_size, reverse, add); -} +} \ No newline at end of file diff --git a/src/cpp/modified_permutohedral.cu b/src/cpp/modified_permutohedral.cu new file mode 100644 index 0000000..7940317 --- /dev/null +++ b/src/cpp/modified_permutohedral.cu @@ -0,0 +1,525 @@ +#define BLOCK_SIZE 64 + +#include +#include "include/modified_permutohedral.h" +#include "include/cuda_macros.h" +#include "hash_helper.cu" + + +template +__global__ void set_kernel(const int n, const Dtype alpha, Dtype* y) { + CUDA_KERNEL_LOOP(index, n) { + y[index] = alpha; + } +} + +template +void gpu_set(const int N, const Dtype alpha, Dtype* Y) { + if (alpha == 0) { + CUDA_CHECK(cudaMemset(Y, 0, sizeof(Dtype) * N)); // NOLINT(caffe/alt_fn) + return; + } + // NOLINT_NEXT_LINE(whitespace/operators) + set_kernel<<>>( + N, alpha, Y); +} + +template void gpu_set(const int N, const int alpha, int* Y); +template void gpu_set(const int N, const float alpha, float* Y); +template void gpu_set(const int N, const double alpha, double* Y); + +static void swapHashTableValues(float* oldValues, float *newValues, float* table_values,size_t size) { + CUDA_CHECK(cudaMemcpy(oldValues,table_values,size,cudaMemcpyDeviceToDevice)); + CUDA_CHECK(cudaMemcpy(table_values,newValues,size,cudaMemcpyDeviceToDevice)); + CUDA_CHECK(cudaMemcpy(newValues,oldValues,size,cudaMemcpyDeviceToDevice)); + // Works but give poorer results + //oldValues = table_values; + //table_values = newValues; + //newValues = oldValues; +} + +template +__global__ static void createMatrix(const int w, const int h, + const float *positions, + int *table_entries, + int table_capacity, + signed short* table_keys, + const float *scaleFactor, + MatrixEntry *matrix) +{ + // scanline order + //const int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; + //const bool outOfBounds = (idx>=num_points) ; + //const int threadId = idx; + + // 8x8 blocks + const int x = threadIdx.x + blockIdx.x * blockDim.x; + const int y = threadIdx.y + blockIdx.y * blockDim.y; + const int threadId = threadIdx.y*blockDim.x + threadIdx.x; + const int idx = y*w + x; + const bool outOfBounds = (x >= w) || (y >= h); + + float myElevated[pd+1]; + const float *myPosition = positions + idx*pd; + + int myGreedy[pd+1]; + int myRank[pd+1]; + + float myBarycentric[pd+2]; + __shared__ short keys[pd*BLOCK_SIZE]; + short *myKey = keys + threadId * pd; + + if (!outOfBounds) { + + myElevated[pd] = -pd*(myPosition[pd-1])*scaleFactor[pd-1]; + for (int i = pd-1; i > 0; i--) { + myElevated[i] = (myElevated[i+1] - + i*(myPosition[i-1])*scaleFactor[i-1] + + (i+2)*(myPosition[i])*scaleFactor[i]); + } + myElevated[0] = myElevated[1] + 2*(myPosition[0])*scaleFactor[0]; + + + // find the closest zero-colored lattice point + + // greedily search for the closest zero-colored lattice point + signed short sum = 0; + for (int i = 0; i <= pd; i++) { + float v = myElevated[i]*(1.0f/(pd+1)); + float up = ceilf(v) * (pd+1); + float down = floorf(v) * (pd+1); + if (up - myElevated[i] < myElevated[i] - down) { + myGreedy[i] = (signed short)up; + } else { + myGreedy[i] = (signed short)down; + } + sum += myGreedy[i]; + } + sum /= pd+1; + + // sort differential to find the permutation between this simplex and the canonical one + for (int i = 0; i <= pd; i++) { + myRank[i] = 0; + for (int j = 0; j <= pd; j++) { + if (myElevated[i] - myGreedy[i] < myElevated[j] - myGreedy[j] || + (myElevated[i] - myGreedy[i] == myElevated[j] - myGreedy[j] + && i > j)) { + myRank[i]++; + } + } + } + + if (sum > 0) { // sum too large, need to bring down the ones with the smallest differential + for (int i = 0; i <= pd; i++) { + if (myRank[i] >= pd + 1 - sum) { + myGreedy[i] -= pd+1; + myRank[i] += sum - (pd+1); + } else { + myRank[i] += sum; + } + } + } else if (sum < 0) { // sum too small, need to bring up the ones with largest differential + for (int i = 0; i <= pd; i++) { + if (myRank[i] < -sum) { + myGreedy[i] += pd+1; + myRank[i] += (pd+1) + sum; + } else { + myRank[i] += sum; + } + } + } + + // turn delta into barycentric coords + for (int i = 0; i <= pd+1; i++) { + myBarycentric[i] = 0; + } + + for (int i = 0; i <= pd; i++) { + float delta = (myElevated[i] - myGreedy[i]) * (1.0f/(pd+1)); + myBarycentric[pd-myRank[i]] += delta; + myBarycentric[pd+1-myRank[i]] -= delta; + } + myBarycentric[0] += 1.0f + myBarycentric[pd+1]; + } + + for (int color = 0; color <= pd; color++) { + // Compute the location of the lattice point explicitly (all but + // the last coordinate - it's redundant because they sum to zero) + if (!outOfBounds) { + for (int i = 0; i < pd; i++) { + myKey[i] = myGreedy[i] + color; + if (myRank[i] > pd-color) myKey[i] -= (pd+1); + } + } + + if (!outOfBounds) { + MatrixEntry r; + r.index = hashTableInsert(myKey, table_keys, table_entries, + table_capacity, idx*(pd+1)+color); + r.weight = myBarycentric[color]; + matrix[idx*(pd+1) + color] = r; + } + } +} + +template +__global__ static void cleanHashTable(const int n, + int *table_entries, + int table_capacity, + signed short* table_keys) +{ + const int idx = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x * blockDim.y + threadIdx.x; + + if (idx >= n) return; + + // find my hash table entry + int *e = table_entries + idx; + + // Check if I created my own key in the previous phase + if (*e >= 0) { + // Rehash my key and reset the pointer in order to merge with + // any other pixel that created a different entry under the + // same key. If the computation was serial this would never + // happen, but sometimes race conditions can make the same key + // be inserted twice. hashTableRetrieve always returns the + // earlier, so it's no problem as long as we rehash now. + *e = hashTableRetrieve(table_keys + *e*kd, + table_entries, table_keys, table_capacity); + } +} + +template +__global__ static void resetIndex(const int w, const int h, + MatrixEntry *matrix, + int *table_entries) +{ + const int x = threadIdx.x + blockIdx.x * blockDim.x; + const int y = threadIdx.y + (blockIdx.y/(pd+1)) * blockDim.y; + const int color = blockIdx.y % (pd+1); + const int idx = y*w + x; + const bool outOfBounds = (x >= w) || (y >= h); + if (!outOfBounds){ + MatrixEntry r = matrix[idx*(pd+1)+color]; + matrix[idx*(pd+1)+color].index = table_entries[r.index]; + } +} + +template +__global__ static void splatCache(const int w, const int h, const int vd, + const Dtype *values, + const MatrixEntry *matrix, + float *table_values) +{ + const int x = threadIdx.x + blockIdx.x * blockDim.x; + const int y = threadIdx.y + (blockIdx.y/(pd+1)) * blockDim.y; + const int threadId = threadIdx.y*blockDim.x + threadIdx.x; + const int color = blockIdx.y % (pd+1); + const int idx = y*w + x; + const bool outOfBounds = (x >= w) || (y >= h); + + __shared__ int sharedOffsets[BLOCK_SIZE]; + extern __shared__ float sharedValues[]; + int myOffset = -1; + float *myValue = sharedValues + threadId*(vd+1); + + if (!outOfBounds) { + + const Dtype *value = values + idx; + + MatrixEntry r = matrix[idx*(pd+1)+color]; + + // convert the matrix entry from a pointer into the entries array to a pointer into the keys/values array + //matrix[idx*(pd+1)+color].index = r.index = table_entries[r.index]; + // record the offset into the keys/values array in shared space + myOffset = sharedOffsets[threadId] = r.index*(vd+1); + + for (int j = 0; j < vd; j++) { + myValue[j] = (float)value[j*w*h]*r.weight; + } + myValue[vd] = r.weight; + + } else { + sharedOffsets[threadId] = -1; + } + + __syncthreads(); + + // am I the first thread in this block to care about this key? + + if (outOfBounds) return; + + for (int i = 0; i < BLOCK_SIZE; i++) { + if (i < threadId) { + if (myOffset == sharedOffsets[i]) { + // somebody else with higher priority cares about this key + return; + } + } else if (i > threadId) { + if (myOffset == sharedOffsets[i]) { + // someone else with lower priority cares about this key, accumulate it into mine + for (int j = 0; j <= vd; j++) { + sharedValues[threadId*(vd+1) + j] += sharedValues[i*(vd+1) + j]; + } + } + } + } + + // only the threads with something to write to main memory are still going + float *val = table_values + myOffset; + for (int j = 0; j <= vd; j++) { + atomicAdd(val+j, myValue[j]); + } +} + +template +__global__ static void blur(int n, float *newValues, + const MatrixEntry *matrix, + const int *table_entries, + const signed short *table_keys, + const int table_capacity, + float *table_values, + int color, + const int vd) +{ + const int idx = (blockIdx.y * gridDim.x + blockIdx.x) * blockDim.x * blockDim.y + threadIdx.x; + + if (idx >= n) return; + + // Check if I'm valid + if (matrix[idx].index != idx) return; + + // find my key and the keys of my neighbours + short myKey[pd+1]; + short np[pd+1]; + short nm[pd+1]; + + for (int i = 0; i < pd; i++) { + myKey[i] = table_keys[idx*pd+i]; + np[i] = myKey[i]+1; + nm[i] = myKey[i]-1; + } + + + np[color] -= pd+1; + nm[color] += pd+1; + + int offNp = hashTableRetrieve(np, table_entries, table_keys, table_capacity); + int offNm = hashTableRetrieve(nm, table_entries, table_keys, table_capacity); + + float *valMe = table_values + (vd+1)*idx; + float *valNp = table_values + (vd+1)*offNp; + float *valNm = table_values + (vd+1)*offNm; + float *valOut = newValues + (vd+1)*idx; + + if (offNp >= 0 && offNm >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNp[i] + (valMe[i]*2) + valNm[i])/2; + } + } else if (offNp >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNp[i] + (valMe[i]*2))/2; + } + } else if (offNm >= 0) { + for (int i = 0; i <= vd; i++) { + valOut[i] = (valNm[i] + (valMe[i]*2))/2; + } + } else { + for (int i = 0; i <= vd; i++) { + valOut[i] = valMe[i]; + } + } +} + +template +__global__ static void slice(const int w, const int h, const int vd, + Dtype *values, + const MatrixEntry *matrix, + float *table_values, + bool add) { + //const int idx = blockIdx.x * BLOCK_SIZE + threadIdx.x; + + const int x = threadIdx.x + blockIdx.x * blockDim.x; + const int y = threadIdx.y + blockIdx.y * blockDim.y; + const int threadId = threadIdx.y*blockDim.x + threadIdx.x; + const int idx = y*w + x; + const bool outOfBounds = (x >= w) || (y >= h); + + if (outOfBounds) return; + + extern __shared__ float localValue[]; + + float *myValue = localValue + threadId*vd; + float myWeight = 0; + + for (int i = 0; i < vd; i++) { + myValue[i] = 0; + } + + for (int i = 0; i <= pd; i++) { + MatrixEntry r = matrix[idx*(pd+1) + i]; + const float *val = table_values + r.index*(vd+1); + for (int j = 0; j < vd; j++) { + myValue[j] += r.weight*val[j]; + } + myWeight += r.weight*val[vd]; + } + + //myWeight = 1.0f/myWeight; + float alpha = 1.0f / (1+powf(2, -pd)); + for (int j = 0; j < vd; j++){ + if(!add){ + values[j*w*h + idx] = 0; + } + values[j*w*h + idx] += myValue[j]*alpha; + } +} + + +template +void gpu_init(const float* features, + HashTable* table, + MatrixEntry* matrix, + const int w, const int h) +{ + int num_points = w*h ; + // Scan line order + //unsigned int blocks = (num_points-1)/64 + 1; + //unsigned int blockSize = 64; + dim3 blocks((w-1)/8+1, (h-1)/8+1, 1); + dim3 blockSize(8, 8, 1); + + float blurVariance = 0.5 ; + float * scaleFactor; + float* scaleFactorHost = new float[pd]; + + // Create Scale factor vector and give it to GPU + // num_dimensions is likely to be low so do that + // on the CPU + for (int i = 0; i < pd; i++) { + scaleFactorHost[i] = (pd+1)*sqrtf((1.0/6 + blurVariance)/((i+1)*(i+2))); + } + CUDA_CHECK(cudaMalloc((void**)&scaleFactor, sizeof(float)*pd)); + CUDA_CHECK(cudaMemcpy(scaleFactor, scaleFactorHost, sizeof(float)*pd, cudaMemcpyHostToDevice)); + + createMatrix<<>>(w, h, + features, + table->table_entries, + table->table_capacity, + table->table_keys, + scaleFactor, + matrix); + CUDA_POST_KERNEL_CHECK; + // fix duplicate hash table entries + int cleanBlockSize = 32; + dim3 cleanBlocks((num_points-1)/cleanBlockSize+1, 2*(pd+1), 1); + cleanHashTable<<>>(2*num_points*(pd+1), + table->table_entries, table->table_capacity, table->table_keys); + CUDA_POST_KERNEL_CHECK; + + blocks.y *= pd+1; + resetIndex<<>>(w, h, matrix, table->table_entries) ; + CUDA_POST_KERNEL_CHECK; + + // Clean intermediate variables + delete[] scaleFactorHost; + CUDA_CHECK(cudaFree(scaleFactor)); + +} + +template +void gpu_compute(Dtype* out, const Dtype* in, const HashTable &table, + const MatrixEntry* matrix, + int w, int h, int vd, + bool reverse, bool add){ + + // Create table_values + int num_points = w*h ; + float *table_values ; + CUDA_CHECK(cudaMalloc((void**)&table_values, sizeof(float)*(vd+1)*num_points*(pd+1))) ; + gpu_set(num_points*(vd+1)*(pd+1), 0, table_values) ; + + dim3 blocks((w-1)/8+1, (h-1)/8+1, 1); + dim3 blockSize(8, 8, 1); + + // splat splits by color, so extend the y coordinate to our blocks to represent that + blocks.y *= pd+1; + splatCache<<>>(w, h, vd, + in, + matrix, + table_values); + CUDA_POST_KERNEL_CHECK; + + // blur + int cleanBlockSize = 32; + dim3 cleanBlocks((num_points-1)/cleanBlockSize+1, 2*(pd+1), 1); + float *newValues; + float *oldValues; + size_t size = num_points*(pd+1)*(vd+1)*sizeof(float); + CUDA_CHECK(cudaMalloc((void**)&(newValues), size)); + CUDA_CHECK(cudaMalloc((void**)&(oldValues), size)); + gpu_set(num_points*(vd+1)*(pd+1), 0, newValues) ; + for (int color = reverse?pd:0; color <= pd && color>=0; reverse?color--:color++) { + blur<<>>(num_points*(pd+1), newValues, + matrix, + table.table_entries, + table.table_keys, + table.table_capacity, + table_values, + color, + vd); + CUDA_POST_KERNEL_CHECK; + // swap pointers does not seem to work... + swapHashTableValues(oldValues, newValues, table_values, size); + } + + // slice + blocks.y /= (pd+1); + slice<<>>(w, h, vd, out, matrix, table_values, add); + CUDA_POST_KERNEL_CHECK; + + // Free memory + CUDA_CHECK(cudaFree(table_values)) ; + CUDA_CHECK(cudaFree(newValues)) ; + CUDA_CHECK(cudaFree(oldValues)) ; +} + +void ModifiedPermutohedral::init_gpu(const float* features, int num_dimensions, int w, int h) { + //Initialize Hash table + if(!is_init){ + table.createHashTable(w*h*(num_dimensions+1), num_dimensions); + CUDA_CHECK(cudaMalloc((void **)&matrix, sizeof(MatrixEntry)*(w*h*(num_dimensions+1)))); + } else { + table.resetHashTable(w_*h_*(d_+1), d_); + } + w_ = w ; + h_ = h ; + d_ = num_dimensions ; + N_ = w*h ; + switch(num_dimensions){ + case 2: + gpu_init<2>(features, &table, matrix, w_, h_); + break; + case 5: + gpu_init<5>(features, &table, matrix, w_, h_); + break; + default: + std::cout << "num_dimensions should be 2 or 5"; + } + is_init = true; +} + +void ModifiedPermutohedral::compute_gpu(float* out, const float* in, int value_size, bool reverse, bool add) const { + // Losing time by dynamically allocating memory but more general function + if(!is_init) + std::cout << "Initialize lattice before doing any computing"; + switch(d_){ + case 2: + gpu_compute<2, float>(out, in, table, matrix, w_, h_, value_size, reverse, add); + break; + case 5: + gpu_compute<5, float>(out, in, table, matrix, w_, h_, value_size, reverse, add); + break; + default: + std::cout << "num_dimensions should be 2 or 5"; + } +} From 43c2db51110cb410fbac5de6d44a9c7dd4136350 Mon Sep 17 00:00:00 2001 From: Tom Joy Date: Wed, 9 May 2018 12:51:38 +0100 Subject: [PATCH 2/2] Adding Licences and corresponding info Updating licence --- src/cpp/hash_helper.cu | 2 ++ src/cpp/high_dim_filter.cc | 2 ++ src/cpp/high_dim_filter.cu | 26 ++++++++++++++++++++++++++ src/cpp/include/hash_table.h | 3 +++ src/cpp/include/high_dim_filter.h | 26 ++++++++++++++++++++++++++ src/cpp/modified_permutohedral.cu | 2 ++ 6 files changed, 61 insertions(+) diff --git a/src/cpp/hash_helper.cu b/src/cpp/hash_helper.cu index 1ed7bff..a54212a 100644 --- a/src/cpp/hash_helper.cu +++ b/src/cpp/hash_helper.cu @@ -1,3 +1,5 @@ +//This file is take from caffe/crfasrnn + #define modHash(n) ((n)%(2*table_capacity)); diff --git a/src/cpp/high_dim_filter.cc b/src/cpp/high_dim_filter.cc index 4c1d800..e9cc20a 100644 --- a/src/cpp/high_dim_filter.cc +++ b/src/cpp/high_dim_filter.cc @@ -3,6 +3,8 @@ * * Copyright (c) 2017 Sadeep Jayasumana * + * Modified by Tom Joy 2018, tomjoy@robots.ox.ac.uk + * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal * in the Software without restriction, including without limitation the rights diff --git a/src/cpp/high_dim_filter.cu b/src/cpp/high_dim_filter.cu index 599bc4e..2ab7b3d 100644 --- a/src/cpp/high_dim_filter.cu +++ b/src/cpp/high_dim_filter.cu @@ -1,3 +1,29 @@ +/* + * MIT License + * + * Copyright (c) 2017 Sadeep Jayasumana + * + * Author Tom Joy 2018, tomjoy@robots.ox.ac.uk + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + #if FILTER_GPU #define EIGEN_USE_GPU diff --git a/src/cpp/include/hash_table.h b/src/cpp/include/hash_table.h index 49e43d3..b738a1d 100644 --- a/src/cpp/include/hash_table.h +++ b/src/cpp/include/hash_table.h @@ -1,3 +1,6 @@ +//This file is take from torrvision/crfasrnn + + #ifndef HASH_TABLE_HPP #define HASH_TABLE_HPP diff --git a/src/cpp/include/high_dim_filter.h b/src/cpp/include/high_dim_filter.h index 2698e24..e722ce6 100644 --- a/src/cpp/include/high_dim_filter.h +++ b/src/cpp/include/high_dim_filter.h @@ -1,3 +1,29 @@ +/* + * MIT License + * + * Copyright (c) 2017 Sadeep Jayasumana + * + * Author Tom Joy 2018, tomjoy@robots.ox.ac.uk + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ + #ifndef KERNEL_HIGH_DIM_FILTER_H_ #define KERNEL_HIGH_DIM_FILTER_H_ diff --git a/src/cpp/modified_permutohedral.cu b/src/cpp/modified_permutohedral.cu index 7940317..e48838a 100644 --- a/src/cpp/modified_permutohedral.cu +++ b/src/cpp/modified_permutohedral.cu @@ -1,3 +1,5 @@ +//This file is take from torrvision/crfasrnn + #define BLOCK_SIZE 64 #include