Filtering can now be performed on the GPU #32

Open · wants to merge 2 commits into master
3 changes: 3 additions & 0 deletions .gitignore
@@ -1,2 +1,5 @@
*.pyc
src/cpp/high_dim_filter.so
*.h5
*.png
*.o
2 changes: 2 additions & 0 deletions README.md
@@ -45,6 +45,8 @@ You should not see any errors while importing `tensorflow` and `keras` above.

### Step 3: Build CRF-RNN custom op C++ code

**Note**: Set the `USE_GPU` variable at the top of the Makefile (`USE_GPU := 1` for GPU, `0` for CPU) to select which implementation to build.

Run `make` inside the `crfasrnn_keras/src/cpp` directory:
```
$ cd crfasrnn_keras/src/cpp
$ make
```
5 changes: 4 additions & 1 deletion run_demo.py
@@ -26,7 +26,7 @@
sys.path.insert(1, './src')
from crfrnn_model import get_crfrnn_model_def
import util

import time

def main():
    input_file = 'image.jpg'
@@ -39,9 +39,12 @@ def main():
    model.load_weights(saved_model_path)

    img_data, img_h, img_w = util.get_preprocessed_image(input_file)
    tic = time.clock()  # CPU time on Unix; time.time() would measure wall-clock
    probs = model.predict(img_data, verbose=False)[0, :, :, :]
    toc = time.clock()
    segmentation = util.get_label_image(probs, img_h, img_w)
    segmentation.save(output_file)
    print "Time taken: " + str(toc - tic)


if __name__ == '__main__':
    main()
22 changes: 17 additions & 5 deletions src/cpp/Makefile
@@ -13,9 +13,13 @@
# Define the compiler
CC := g++

# Use GPU implementation? (1 = GPU, 0 = CPU)
USE_GPU := 1

# Read Tensorflow paths
TF_INC := $(shell python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
TF_LIB := $(shell python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
TF_CFLAGS := $(shell python -c 'import tensorflow as tf; print(" ".join(tf.sysconfig.get_compile_flags()))')

# Is the Tensorflow version >= 1.4?
TF_VERSION_GTE_1_4 := $(shell expr `python -c 'import tensorflow as tf; print(tf.__version__)' | cut -f1,2 -d.` \>= 1.4)
@@ -31,18 +35,26 @@ endif
# Set some more flags if the Tensorflow version is >= 1.4
ifeq ($(TF_VERSION_GTE_1_4), 1)
CFLAGS += -I$(TF_INC)/external/nsync/public
LDFLAGS := -L$(TF_LIB) -ltensorflow_framework
LDFLAGS := -L$(TF_LIB) -ltensorflow_framework -L/usr/local/cuda/lib64 -lcuda -lcudart
else
LDFLAGS :=
endif

# Define build targets
.PHONY: all clean

high_dim_filter.so: high_dim_filter.cc modified_permutohedral.cc
	$(CC) $(CFLAGS) -o high_dim_filter.so high_dim_filter.cc modified_permutohedral.cc $(LDFLAGS)

all: high_dim_filter.so

high_dim_filter.so: cudacode.o modified_permutohedral.o high_dim_filter.cc modified_permutohedral.cc
	g++ $(CFLAGS) -o high_dim_filter.so high_dim_filter.cc modified_permutohedral.cc cudacode.o modified_permutohedral.o $(LDFLAGS) -D FILTER_GPU=$(USE_GPU)

# Device code is compiled separately with nvcc; --expt-relaxed-constexpr is
# needed for the Eigen headers pulled in via TensorFlow.
cudacode.o: high_dim_filter.cu
	nvcc -std=c++11 -c -o cudacode.o high_dim_filter.cu $(TF_CFLAGS) -x cu -Xcompiler -fPIC --expt-relaxed-constexpr -D FILTER_GPU=$(USE_GPU)

modified_permutohedral.o: modified_permutohedral.cu
	nvcc -std=c++11 -c -o modified_permutohedral.o modified_permutohedral.cu $(TF_CFLAGS) -x cu -Xcompiler -fPIC --expt-relaxed-constexpr -D FILTER_GPU=$(USE_GPU)

clean:
	$(RM) high_dim_filter.so
	$(RM) high_dim_filter.so cudacode.o modified_permutohedral.o

all: high_dim_filter.so
95 changes: 95 additions & 0 deletions src/cpp/hash_helper.cu
@@ -0,0 +1,95 @@
// This file is taken from caffe/crfasrnn


// Map a hash value to a bucket index (the table has 2*table_capacity buckets)
#define modHash(n) ((n) % (2 * table_capacity))

namespace caffe {

template<int kd>
__device__ __host__ static unsigned int hash(signed short *key) {
  unsigned int k = 0;
  for (int i = 0; i < kd; i++) {
    k += key[i];
    k = k * 1664525;
  }
  return k;
}

template<int kd>
__device__ __host__ static unsigned int hash(int *key) {
  unsigned int k = 0;
  for (int i = 0; i < kd; i++) {
    k += key[i];
    k = k * 1664525;
  }
  return k;
}

template<int kd>
__device__ static int hashTableInsert(unsigned int fh, signed short *key,
                                      signed short *table_keys,
                                      int *table_entries,
                                      int table_capacity,
                                      unsigned int slot) {
  int h = modHash(fh);
  while (1) {
    int *e = &table_entries[h];
    // If the cell is empty (-1), lock it (-2)
    int contents = atomicCAS(e, -1, -2);

    if (contents == -2) {
      // If it was locked already, move on to the next cell

    } else if (contents == -1) {
      // If it was empty, we successfully locked it; write our key
      for (int i = 0; i < kd; i++) {
        table_keys[slot * kd + i] = key[i];
      }
      // Unlock
      atomicExch(e, slot);

      return h;
    } else {
      // The cell is unlocked and has a key in it; check if it matches
      bool match = true;
      for (int i = 0; i < kd && match; i++) {
        match = (table_keys[contents * kd + i] == key[i]);
      }
      if (match) return h;
    }
    // Increment the bucket with wraparound
    h++;
    if (h == table_capacity * 2) h = 0;
  }
}

template<int kd>
__device__ static int hashTableInsert(signed short *key,
                                      signed short *table_keys,
                                      int *table_entries,
                                      int table_capacity,
                                      unsigned int slot) {
  unsigned int myHash = hash<kd>(key);
  return hashTableInsert<kd>(myHash, key, table_keys, table_entries,
                             table_capacity, slot);
}

template<int kd>
__device__ static int hashTableRetrieve(signed short *key,
                                        const int *table_entries,
                                        const signed short *table_keys,
                                        const int table_capacity) {
  int h = modHash(hash<kd>(key));
  while (1) {
    const int *e = table_entries + h;
    if (*e == -1) return -1;
    bool match = true;
    for (int i = 0; i < kd && match; i++) {
      match = (table_keys[(*e) * kd + i] == key[i]);
    }
    if (match) return *e;

    h++;
    if (h == table_capacity * 2) h = 0;
  }
}

} // namespace caffe
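Taken together, these helpers implement a lock-free, linearly probed open-addressed hash table: `hashTableInsert` claims an empty bucket by compare-and-swapping -1 (empty) to -2 (locked), writes its key, then publishes the payload slot with `atomicExch`, while `hashTableRetrieve` probes buckets until it finds a matching key or an empty cell. Sizing the table at `2*table_capacity` buckets keeps the load factor at or below one half, so probe chains stay short. A minimal usage sketch follows; the kernel name, launch shape, and the assumption that `table_entries` starts out filled with -1 are illustrative, not part of this PR:

```
// Hypothetical usage sketch (not from this PR): every thread inserts one
// kd-dimensional key, using its global index as the payload slot.
// Assumes table_entries holds 2*table_capacity ints, all initialized to -1.
template <int kd>
__global__ void insertKeysKernel(const signed short* keys, int n,
                                 signed short* table_keys,
                                 int* table_entries,
                                 int table_capacity) {
  int idx = blockIdx.x * blockDim.x + threadIdx.x;
  if (idx >= n) return;

  signed short key[kd];
  for (int i = 0; i < kd; i++) key[i] = keys[idx * kd + i];

  caffe::hashTableInsert<kd>(key, table_keys, table_entries,
                             table_capacity, idx);
}

// Example launch for 5-dimensional bilateral keys:
//   insertKeysKernel<5><<<(n + 255) / 256, 256>>>(d_keys, n, d_table_keys,
//                                                 d_table_entries, capacity);
```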

107 changes: 78 additions & 29 deletions src/cpp/high_dim_filter.cc
@@ -3,6 +3,8 @@
*
* Copyright (c) 2017 Sadeep Jayasumana
*
* Modified by Tom Joy 2018, [email protected]
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
@@ -26,7 +28,7 @@
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "modified_permutohedral.h"
#include "include/high_dim_filter.h"

using namespace tensorflow;

@@ -61,19 +63,62 @@ void compute_bilateral_kernel(float * const output_kernel, const Tensor& rgb_ten
}

REGISTER_OP("HighDimFilter")
    .Attr("T: {float}")
    .Attr("bilateral: bool")
    .Attr("theta_alpha: float = 1.0")
    .Attr("theta_beta: float = 1.0")
    .Attr("theta_gamma: float = 1.0")
    .Attr("backwards: bool = false")
    .Input("raw: float32")
    .Input("rgb: float32")
    .Output("filtered: float32")
    .Input("raw: T")
    .Input("rgb: T")
    .Output("filtered: T")
    .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
      c->set_output(0, c->input(0));
      return Status::OK();
    });


using CPUDevice = Eigen::ThreadPoolDevice;
using GPUDevice = Eigen::GpuDevice;

// CPU specialization: builds the kernel features on the host and runs the
// CPU permutohedral lattice.
template <>
struct HighDimFilterFunctor<CPUDevice> {
  void operator()(const CPUDevice& d,
                  const tensorflow::Tensor& input_q,
                  const tensorflow::Tensor& input_img,
                  tensorflow::Tensor* out,
                  const FilterParams& params) {

    ModifiedPermutohedral mp;

    const int channels = input_q.dim_size(0);
    const int height = input_img.dim_size(1);
    const int width = input_img.dim_size(2);
    const int num_pixels = width * height;

    if (params.bilateral_) {
      float * const kernel_vals = new float[5 * num_pixels];
      compute_bilateral_kernel(kernel_vals, input_img,
                               params.theta_alpha_, params.theta_beta_);
      mp.init_cpu(kernel_vals, 5, num_pixels);
      mp.compute_cpu(*out, input_q, channels, params.backwards_);

      delete[] kernel_vals;
    } else {
      float * const kernel_vals = new float[2 * num_pixels];
      compute_spatial_kernel(kernel_vals, width, height, params.theta_gamma_);
      mp.init_cpu(kernel_vals, 2, num_pixels);
      mp.compute_cpu(*out, input_q, channels, params.backwards_);

      delete[] kernel_vals;
    }
  }
};
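The GPU path is expected to provide the matching specialization in the CUDA translation unit that the Makefile compiles with nvcc. A declaration-only sketch of its likely shape (the member layout, and any `init_gpu`/`compute_gpu` counterparts on `ModifiedPermutohedral`, are assumptions; they are not visible in this diff):

```
// Hypothetical sketch: the GPU specialization mirrors the CPU one above,
// with operator() defined in the .cu sources built by the cudacode.o and
// modified_permutohedral.o Makefile targets so that nvcc compiles the
// device code.
template <>
struct HighDimFilterFunctor<GPUDevice> {
  void operator()(const GPUDevice& d,
                  const tensorflow::Tensor& input_q,
                  const tensorflow::Tensor& input_img,
                  tensorflow::Tensor* out,
                  const FilterParams& params);
};
```

Keeping the op class device-agnostic and pushing the device-specific work into a functor specialization is the standard TensorFlow pattern for ops with both CPU and GPU kernels.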


template <typename Device, typename T>
class HighDimFilterOp : public OpKernel {
 public:
  explicit HighDimFilterOp(OpKernelConstruction* context) : OpKernel(context) {
@@ -97,36 +142,25 @@ class HighDimFilterOp : public OpKernel {
    // Grab the RGB image tensor
    const Tensor& image_tensor = context->input(1);

    const int channels = input_tensor.dim_size(0);
    const int height = input_tensor.dim_size(1);
    const int width = input_tensor.dim_size(2);
    const int num_pixels = width * height;

    // Create the output tensor
    Tensor* output_tensor = NULL;
    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
                                                     &output_tensor));
    ModifiedPermutohedral mp;

    if (bilateral_) {
      float * const kernel_vals = new float[5 * num_pixels];
      compute_bilateral_kernel(kernel_vals, image_tensor,
                               theta_alpha_, theta_beta_);
      mp.init(kernel_vals, 5, num_pixels);
      mp.compute(*output_tensor, input_tensor, channels, backwards_);

      delete[] kernel_vals;
    } else {
      float * const kernel_vals = new float[2 * num_pixels];
      compute_spatial_kernel(kernel_vals, width, height, theta_gamma_);
      mp.init(kernel_vals, 2, num_pixels);
      mp.compute(*output_tensor, input_tensor, channels, backwards_);

      delete[] kernel_vals;
    }

    const FilterParams params(bilateral_,
                              theta_alpha_,
                              theta_beta_,
                              theta_gamma_,
                              backwards_);
    // Run the filter on the device selected at registration time
    HighDimFilterFunctor<Device>()(
        context->eigen_device<Device>(),
        input_tensor,
        image_tensor,
        output_tensor,
        params);
  }

 private:
  bool bilateral_;
  float theta_alpha_;
@@ -135,4 +169,19 @@ class HighDimFilterOp : public OpKernel {
  bool backwards_;
};

REGISTER_KERNEL_BUILDER(Name("HighDimFilter").Device(DEVICE_CPU), HighDimFilterOp);

// Register kernels
#define REGISTER_CPU(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("HighDimFilter").Device(DEVICE_CPU).TypeConstraint<T>("T"),   \
      HighDimFilterOp<CPUDevice, T>);
REGISTER_CPU(float);

#if FILTER_GPU
#define REGISTER_GPU(T)                                                  \
  REGISTER_KERNEL_BUILDER(                                               \
      Name("HighDimFilter").Device(DEVICE_GPU).TypeConstraint<T>("T"),   \
      HighDimFilterOp<GPUDevice, T>);
REGISTER_GPU(float);
#endif
