diff --git a/.gitattributes b/.gitattributes
old mode 100644
new mode 100755
diff --git a/.gitignore b/.gitignore
old mode 100644
new mode 100755
index 3ae6c3c..b7ddf34
--- a/.gitignore
+++ b/.gitignore
@@ -1,2 +1,3 @@
 *.pyc
-src/cpp/high_dim_filter.so
+*.so
+crfrnn_keras_model.h5
\ No newline at end of file
diff --git a/LICENSE b/LICENSE
old mode 100644
new mode 100755
diff --git a/README.md b/README.md
old mode 100644
new mode 100755
index dadef5e..66a4bc7
--- a/README.md
+++ b/README.md
@@ -1,4 +1,6 @@
 # CRF-RNN for Semantic Image Segmentation - Keras/Tensorflow version
+## Forked from [sadeepj/crfasrnn_keras](https://github.com/sadeepj/crfasrnn_keras).
+#### Credit for all content below due to sadeepj/crfasrnn_keras contributors
 ![sample](sample.png)
 
 Live demo:      [http://crfasrnn.torr.vision](http://crfasrnn.torr.vision)
diff --git a/download_model_weights.sh b/download_model_weights.sh
new file mode 100755
index 0000000..51f3c69
--- /dev/null
+++ b/download_model_weights.sh
@@ -0,0 +1 @@
+wget https://github.com/sadeepj/crfasrnn_keras/releases/download/v1.0/crfrnn_keras_model.h5
diff --git a/image.jpg b/image.jpg
old mode 100644
new mode 100755
diff --git a/labels.png b/labels.png
new file mode 100755
index 0000000..5b89df2
Binary files /dev/null and b/labels.png differ
diff --git a/requirements.txt b/requirements.txt
old mode 100644
new mode 100755
diff --git a/requirements_gpu.txt b/requirements_gpu.txt
old mode 100644
new mode 100755
diff --git a/run_demo.py b/run_demo.py
old mode 100644
new mode 100755
index e0a287a..218107e
--- a/run_demo.py
+++ b/run_demo.py
@@ -23,7 +23,7 @@
 """
 
 import sys
-sys.path.insert(1, './src')
+sys.path.insert(1, './src/python')
 from crfrnn_model import get_crfrnn_model_def
 import util
 
diff --git a/sample.png b/sample.png
old mode 100644
new mode 100755
diff --git a/src/cpp/Makefile b/src/cpp/Makefile
old mode 100644
new mode 100755
index ff4a152..ee4cb52
--- a/src/cpp/Makefile
+++ b/src/cpp/Makefile
@@ -12,13 +12,15 @@
 # Define the compiler
 CC := g++
+# Define the target python version
+PYTHON := python3.6
 
 # Read Tensorflow paths
-TF_INC := $(shell python -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
-TF_LIB := $(shell python -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
+TF_INC := $(shell $(PYTHON) -c 'import tensorflow as tf; print(tf.sysconfig.get_include())')
+TF_LIB := $(shell $(PYTHON) -c 'import tensorflow as tf; print(tf.sysconfig.get_lib())')
 
 # Is the Tensorflow version >= 1.4?
-TF_VERSION_GTE_1_4 := $(shell expr `python -c 'import tensorflow as tf; print(tf.__version__)' | cut -f1,2 -d.` \>= 1.4)
+TF_VERSION_GTE_1_4 := $(shell expr `$(PYTHON) -c 'import tensorflow as tf; print(tf.__version__)' | cut -f1,2 -d.` \>= 1.4)
 
 # Flags required for all cases
 CFLAGS := -std=c++11 -D_GLIBCXX_USE_CXX11_ABI=0 -shared -fPIC -I$(TF_INC) -O2
 
@@ -40,9 +42,9 @@ endif
 .PHONY: all clean
 
 high_dim_filter.so: high_dim_filter.cc modified_permutohedral.cc
-	$(CC) $(CFLAGS) -o high_dim_filter.so high_dim_filter.cc modified_permutohedral.cc $(LDFLAGS)
+	$(CC) $(CFLAGS) -o build/high_dim_filter.so high_dim_filter.cc modified_permutohedral.cc $(LDFLAGS)
 
 clean:
-	$(RM) high_dim_filter.so
+	$(RM) $(OUTDIR)/high_dim_filter.so
 
 all: high_dim_filter.so
diff --git a/src/cpp/high_dim_filter.cc b/src/cpp/high_dim_filter.cc
old mode 100644
new mode 100755
index 25485ce..b01d247
--- a/src/cpp/high_dim_filter.cc
+++ b/src/cpp/high_dim_filter.cc
@@ -30,45 +30,99 @@
 using namespace tensorflow;
 
-void compute_spatial_kernel(float * const output_kernel, const int width,
-                            const int height, const float theta_gamma) {
+void compute_spatial_kernel(float * const output_kernel,
+                            const int width,
+                            const int height,
+                            const float theta_gamma) {
 
-  const int num_pixels = width * height;
-  for (int p = 0; p < num_pixels; ++p) {
-    output_kernel[2 * p] = static_cast<float>(p % width) / theta_gamma;
-    output_kernel[2 * p + 1] = static_cast<float>(p / width) / theta_gamma;
-  }
+    const int num_pixels = width * height;
+    for (int p = 0; p < num_pixels; ++p) {
+        output_kernel[2 * p] = static_cast<float>(p % width) / theta_gamma;
+        output_kernel[2 * p + 1] = static_cast<float>(p / width) / theta_gamma;
+    }
 }
 
-void compute_bilateral_kernel(float * const output_kernel, const Tensor& rgb_tensor,
-                              const float theta_alpha, const float theta_beta) {
-
-  const int height = rgb_tensor.dim_size(1);
-  const int width = rgb_tensor.dim_size(2);
-  const int num_pixels = height * width;
-  auto rgb = rgb_tensor.flat<float>();
+void compute_spatial_kernel_3d(float * const output_kernel,
+                               const int width,
+                               const int height,
+                               const int depth,
+                               const float theta_gamma,
+                               const float theta_gamma_z) {
+    const int hw = height * width;
+    const int num_voxels = depth * height * width;
+    for (int p = 0; p < num_voxels; ++p) {
+        output_kernel[3 * p] = static_cast<float>(p % width) / theta_gamma;
+        output_kernel[3 * p + 1] = static_cast<float>(p / width) / theta_gamma;
+        output_kernel[3 * p + 2] = static_cast<float>(p / hw) / theta_gamma_z;
+    }
+}
 
-  for (int p = 0; p < num_pixels; ++p) {
-    // Spatial terms
-    output_kernel[5 * p] = static_cast<float>(p % width) / theta_alpha;
-    output_kernel[5 * p + 1] = static_cast<float>(p / width) / theta_alpha;
+void compute_bilateral_kernel(float * const output_kernel,
+                              const Tensor& image_tensor,
+                              const float theta_alpha,
+                              const float theta_beta) {
+
+    const int unary_channels = image_tensor.dim_size(0);
+    const int height = image_tensor.dim_size(1);
+    const int width = image_tensor.dim_size(2);
+    const int num_pixels = height * width;
+    auto rgb = image_tensor.flat<float>();
+
+    // Number of output unary_channels: rgb unary_channels plus two spatial (x, y) unary_channels
+    const int oc = unary_channels + 2;
+    for (int p = 0; p < num_pixels; ++p) {
+        // Spatial terms
+        output_kernel[oc * p] = static_cast<float>(p % width) / theta_alpha;
+        output_kernel[oc * p + 1] = static_cast<float>(p / width) / theta_alpha;
+
+        // Color channel terms
+        for (int i = 0; i < unary_channels; ++i) {
+            output_kernel[oc * p + i + 2] =
+                static_cast<float>(rgb(p + i * num_pixels) / theta_beta);
+        }
+    }
+}
 
-    // Color terms
-    output_kernel[5 * p + 2] = static_cast<float>(rgb(p) / theta_beta);
-    output_kernel[5 * p + 3] = static_cast<float>(rgb(num_pixels + p) / theta_beta);
-    output_kernel[5 * p + 4] = static_cast<float>(rgb(2 * num_pixels + p) / theta_beta);
-  }
+void compute_bilateral_kernel_3d(float * const output_kernel,
+                                 const Tensor& image_tensor,
+                                 const float theta_alpha,
+                                 const float theta_alpha_z,
+                                 const float theta_beta) {
+    const int unary_channels = image_tensor.dim_size(0);
+    const int depth = image_tensor.dim_size(1);
+    const int height = image_tensor.dim_size(2);
+    const int width = image_tensor.dim_size(3);
+    const int hw = height * width;
+    const int num_pixels = depth * height * width;
+
+    auto rgb = image_tensor.flat<float>();
+
+    const int oc = unary_channels + 3;
+    for (int p = 0; p < num_pixels; ++p) {
+        output_kernel[oc * p] = static_cast<float>(p % width) / theta_alpha;
+        output_kernel[oc * p + 1] = static_cast<float>(p / width) / theta_alpha;
+        output_kernel[oc * p + 2] = static_cast<float>(p / hw) / theta_alpha_z;
+
+        // Color channel terms
+        for (int i = 0; i < unary_channels; ++i) {
+            output_kernel[oc * p + i + 3] =
+                static_cast<float>(rgb(p + i * num_pixels)) / theta_beta;
+        }
+    }
 }
 
 REGISTER_OP("HighDimFilter")
+  .Attr("T: {float}")
   .Attr("bilateral: bool")
   .Attr("theta_alpha: float = 1.0")
+  .Attr("theta_alpha_z: float = 1.0")
   .Attr("theta_beta: float = 1.0")
   .Attr("theta_gamma: float = 1.0")
+  .Attr("theta_gamma_z: float = 1.0")
   .Attr("backwards: bool = false")
-  .Input("raw: float32")
-  .Input("rgb: float32")
-  .Output("filtered: float32")
+  .Input("raw: T")
+  .Input("rgb: T")
+  .Output("filtered: T")
   .SetShapeFn([](::tensorflow::shape_inference::InferenceContext* c) {
     c->set_output(0, c->input(0));
     return Status::OK();
   });
@@ -82,10 +136,14 @@ class HighDimFilterOp : public OpKernel {
                    context->GetAttr("bilateral", &bilateral_));
     OP_REQUIRES_OK(context,
                    context->GetAttr("theta_alpha", &theta_alpha_));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("theta_alpha_z", &theta_alpha_z_));
     OP_REQUIRES_OK(context,
                    context->GetAttr("theta_beta", &theta_beta_));
     OP_REQUIRES_OK(context,
                    context->GetAttr("theta_gamma", &theta_gamma_));
+    OP_REQUIRES_OK(context,
+                   context->GetAttr("theta_gamma_z", &theta_gamma_z_));
     OP_REQUIRES_OK(context,
                    context->GetAttr("backwards", &backwards_));
   }
@@ -93,35 +151,61 @@ class HighDimFilterOp : public OpKernel {
   void Compute(OpKernelContext* context) override {
 
     // Grab the unary tensor
-    const Tensor& input_tensor = context->input(0);
+    const Tensor& unary_tensor = context->input(0);
     // Grab the RGB image tensor
     const Tensor& image_tensor = context->input(1);
-
-    const int channels = input_tensor.dim_size(0);
-    const int height = input_tensor.dim_size(1);
-    const int width = input_tensor.dim_size(2);
-    const int num_pixels = width * height;
+
+    const int spatial_dims = image_tensor.dims() - 1;
+    const bool is_3d = spatial_dims == 3;
+
+    const int image_channels = image_tensor.dim_size(0);
+    const int bilateral_channels = image_channels + spatial_dims;
+    const int unary_channels = unary_tensor.dim_size(0);
+    const int depth = is_3d ? image_tensor.dim_size(1) : 1;
+    const int height = image_tensor.dim_size(spatial_dims - 1);
+    const int width = image_tensor.dim_size(spatial_dims);
+    const int num_pixels = width * height * depth;
 
     // Create the output tensor
     Tensor* output_tensor = NULL;
-    OP_REQUIRES_OK(context, context->allocate_output(0, input_tensor.shape(),
+    OP_REQUIRES_OK(context, context->allocate_output(0, unary_tensor.shape(),
                                                      &output_tensor));
 
     ModifiedPermutohedral mp;
 
     if (bilateral_) {
-      float * const kernel_vals = new float[5 * num_pixels];
-      compute_bilateral_kernel(kernel_vals, image_tensor,
-                               theta_alpha_, theta_beta_);
-      mp.init(kernel_vals, 5, num_pixels);
-      mp.compute(*output_tensor, input_tensor, channels, backwards_);
-
+      float * const kernel_vals = new float[bilateral_channels * num_pixels];
+      if (is_3d) {
+        compute_bilateral_kernel_3d(kernel_vals,
+                                    image_tensor,
+                                    theta_alpha_,
+                                    theta_alpha_z_,
+                                    theta_beta_);
+      } else {
+        compute_bilateral_kernel(kernel_vals,
+                                 image_tensor,
+                                 theta_alpha_,
+                                 theta_beta_);
+      }
+      mp.init(kernel_vals, bilateral_channels, num_pixels);
+      mp.compute(*output_tensor, unary_tensor, unary_channels, backwards_);
      delete[] kernel_vals;
     } else {
-      float * const kernel_vals = new float[2 * num_pixels];
-      compute_spatial_kernel(kernel_vals, width, height, theta_gamma_);
-      mp.init(kernel_vals, 2, num_pixels);
-      mp.compute(*output_tensor, input_tensor, channels, backwards_);
-
+      float * const kernel_vals = new float[spatial_dims * num_pixels];
+      if (is_3d) {
+        compute_spatial_kernel_3d(kernel_vals,
+                                  width,
+                                  height,
+                                  depth,
+                                  theta_gamma_,
+                                  theta_gamma_z_);
+      } else {
+        compute_spatial_kernel(kernel_vals,
+                               width,
+                               height,
+                               theta_gamma_);
+      }
+      mp.init(kernel_vals, spatial_dims, num_pixels);
+      mp.compute(*output_tensor, unary_tensor, unary_channels, backwards_);
      delete[] kernel_vals;
     }
@@ -130,8 +214,10 @@
  private:
   bool bilateral_;
   float theta_alpha_;
+  float theta_alpha_z_;
   float theta_beta_;
   float theta_gamma_;
+  float theta_gamma_z_;
   bool backwards_;
 };
 
diff --git a/src/cpp/modified_permutohedral.cc b/src/cpp/modified_permutohedral.cc
old mode 100644
new mode 100755
diff --git a/src/cpp/modified_permutohedral.h b/src/cpp/modified_permutohedral.h
old mode 100644
new mode 100755
diff --git a/src/python/crfrnn_layer.py b/src/python/crfrnn_layer.py
new file mode 100755
index 0000000..de6e980
--- /dev/null
+++ b/src/python/crfrnn_layer.py
@@ -0,0 +1,281 @@
+"""
+MIT License
+
+Copyright (c) 2017 Sadeep Jayasumana
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
+"""
+
+import numpy as np
+import tensorflow as tf
+from keras.engine.topology import Layer
+from high_dim_filter_loader import custom_module
+
+
+def _diagonal_initializer(shape):
+    return np.eye(shape[0], shape[1], dtype=np.float32)
+
+
+def _potts_model_initializer(shape):
+    return -1 * _diagonal_initializer(shape)
+
+
+class CrfRnnLayer3D(Layer):
+    """ Implements a 3D variant of the 2D CRF-RNN layer described in:
+
+    Conditional Random Fields as Recurrent Neural Networks,
+    S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du,
+    C. Huang and P. Torr, ICCV 2015
+    """
+
+    def __init__(self,
+                 image_dims,
+                 num_classes,
+                 theta_alpha=160.,
+                 theta_alpha_z=40.,
+                 theta_beta=3.,
+                 theta_gamma=3.,
+                 theta_gamma_z=1.,
+                 num_iterations=10,
+                 **kwargs):
+        self.image_dims = image_dims
+        self.num_classes = num_classes
+        self.theta_alpha = theta_alpha
+        self.theta_alpha_z = theta_alpha_z
+        self.theta_beta = theta_beta
+        self.theta_gamma = theta_gamma
+        self.theta_gamma_z = theta_gamma_z
+        self.num_iterations = num_iterations
+
+        self.spatial_ker_weights = None
+        self.bilateral_ker_weights = None
+        self.compatibility_matrix = None
+
+        super(CrfRnnLayer3D, self).__init__(**kwargs)
+
+        pass
+
+    def build(self, input_shape):
+        # Weights of the spatial kernel
+        self.spatial_ker_weights = self.add_weight(
+            name='spatial_ker_weights',
+            shape=(self.num_classes, self.num_classes),
+            initializer=_diagonal_initializer,
+            trainable=True)
+
+        # Weights of the bilateral kernel
+        self.bilateral_ker_weights = self.add_weight(
+            name='bilateral_ker_weights',
+            shape=(self.num_classes, self.num_classes),
+            initializer=_diagonal_initializer,
+            trainable=True)
+
+        # Compatibility matrix
+        self.compatibility_matrix = self.add_weight(
+            name='compatibility_matrix',
+            shape=(self.num_classes, self.num_classes),
+            initializer=_potts_model_initializer,
+            trainable=True)
+
+        super(CrfRnnLayer3D, self).build(input_shape)
+
+        pass
+
+    def call(self, inputs):
+        image = inputs[0, 0:-self.num_classes, ...]
+        unaries = inputs[0, -self.num_classes:, ...]
+
+        c = self.num_classes
+        d, h, w = self.image_dims
+
+        # Prepare filter normalization coefficients
+        all_ones = np.ones((c, d, h, w), dtype=np.float32)
+        spatial_norm_vals = custom_module.high_dim_filter(
+            all_ones,
+            image,
+            bilateral=False,
+            theta_gamma=self.theta_gamma,
+            theta_gamma_z=self.theta_gamma_z)
+        bilateral_norm_vals = custom_module.high_dim_filter(
+            all_ones,
+            image,
+            bilateral=True,
+            theta_alpha=self.theta_alpha,
+            theta_alpha_z=self.theta_alpha_z,
+            theta_beta=self.theta_beta)
+
+        q_values = unaries
+
+        for i in range(self.num_iterations):
+            softmax_out = tf.nn.softmax(q_values, 0)
+
+            # Spatial filtering
+            spatial_out = custom_module.high_dim_filter(
+                softmax_out,
+                image,
+                bilateral=False,
+                theta_gamma=self.theta_gamma,
+                theta_gamma_z=self.theta_gamma_z)
+            spatial_out = spatial_out / spatial_norm_vals
+
+            # Bilateral filtering
+            bilateral_out = custom_module.high_dim_filter(
+                softmax_out,
+                image,
+                bilateral=True,
+                theta_alpha=self.theta_alpha,
+                theta_alpha_z=self.theta_alpha_z,
+                theta_beta=self.theta_beta)
+            bilateral_out = bilateral_out / bilateral_norm_vals
+
+            # Weighting filter outputs
+            message_passing = (tf.matmul(self.spatial_ker_weights,
+                                         tf.reshape(spatial_out, (c, -1))) +
+                               tf.matmul(self.bilateral_ker_weights,
+                                         tf.reshape(bilateral_out, (c, -1))))
+
+            # Compatibility transform
+            pairwise = tf.matmul(self.compatibility_matrix, message_passing)
+
+            # Adding unary potentials
+            pairwise = tf.reshape(pairwise, (c, d, h, w))
+            q_values = unaries - pairwise
+
+        return tf.reshape(q_values, (1, c, d, h, w))
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
+
+
+class CrfRnnLayer2D(Layer):
+    """ Implements the 2D CRF-RNN layer described in:
+
+    Conditional Random Fields as Recurrent Neural Networks,
+    S. Zheng, S. Jayasumana, B. Romera-Paredes, V. Vineet, Z. Su, D. Du,
+    C. Huang and P. Torr, ICCV 2015
+    """
+
+    def __init__(self,
+                 image_dims,
+                 num_classes,
+                 theta_alpha=160.,
+                 theta_beta=3.,
+                 theta_gamma=3.,
+                 num_iterations=10,
+                 **kwargs):
+        self.image_dims = image_dims
+        self.num_classes = num_classes
+        self.theta_alpha = theta_alpha
+        self.theta_beta = theta_beta
+        self.theta_gamma = theta_gamma
+        self.num_iterations = num_iterations
+
+        self.spatial_ker_weights = None
+        self.bilateral_ker_weights = None
+        self.compatibility_matrix = None
+
+        super(CrfRnnLayer2D, self).__init__(**kwargs)
+
+        pass
+
+    def build(self, input_shape):
+        # Weights of the spatial kernel
+        self.spatial_ker_weights = self.add_weight(
+            name='spatial_ker_weights',
+            shape=(self.num_classes, self.num_classes),
+            initializer=_diagonal_initializer,
+            trainable=True)
+
+        # Weights of the bilateral kernel
+        self.bilateral_ker_weights = self.add_weight(
+            name='bilateral_ker_weights',
+            shape=(self.num_classes, self.num_classes),
+            initializer=_diagonal_initializer,
+            trainable=True)
+
+        # Compatibility matrix
+        self.compatibility_matrix = self.add_weight(
+            name='compatibility_matrix',
+            shape=(self.num_classes, self.num_classes),
+            initializer=_potts_model_initializer,
+            trainable=True)
+
+        super(CrfRnnLayer2D, self).build(input_shape)
+
+        pass
+
+    def call(self, inputs):
+        unaries = inputs[0, 0:self.num_classes, ...]
+        image = inputs[0, self.num_classes:, ...]
+
+        c, h, w = self.num_classes, self.image_dims[0], self.image_dims[1]
+        all_ones = np.ones((c, h, w), dtype=np.float32)
+
+        # Prepare filter normalization coefficients
+        spatial_norm_vals = custom_module.high_dim_filter(
+            all_ones,
+            image,
+            bilateral=False,
+            theta_gamma=self.theta_gamma)
+        bilateral_norm_vals = custom_module.high_dim_filter(
+            all_ones,
+            image,
+            bilateral=True,
+            theta_alpha=self.theta_alpha,
+            theta_beta=self.theta_beta)
+
+        q_values = unaries
+
+        for i in range(self.num_iterations):
+            softmax_out = tf.nn.softmax(q_values, 0)
+
+            # Spatial filtering
+            spatial_out = custom_module.high_dim_filter(
+                softmax_out,
+                image,
+                bilateral=False,
+                theta_gamma=self.theta_gamma)
+            spatial_out = spatial_out / spatial_norm_vals
+
+            # Bilateral filtering
+            bilateral_out = custom_module.high_dim_filter(
+                softmax_out,
+                image,
+                bilateral=True,
+                theta_alpha=self.theta_alpha,
+                theta_beta=self.theta_beta)
+            bilateral_out = bilateral_out / bilateral_norm_vals
+
+            # Weighting filter outputs
+            message_passing = (tf.matmul(self.spatial_ker_weights,
+                                         tf.reshape(spatial_out, (c, -1))) +
+                               tf.matmul(self.bilateral_ker_weights,
+                                         tf.reshape(bilateral_out, (c, -1))))
+
+            # Compatibility transform
+            pairwise = tf.matmul(self.compatibility_matrix, message_passing)
+
+            # Adding unary potentials
+            pairwise = tf.reshape(pairwise, (c, h, w))
+            q_values = unaries - pairwise
+
+        return tf.reshape(q_values, (1, c, h, w))
+
+    def compute_output_shape(self, input_shape):
+        return input_shape
diff --git a/src/crfrnn_layer.py b/src/python/crfrnn_layer_rgb.py
old mode 100644
new mode 100755
similarity index 100%
rename from src/crfrnn_layer.py
rename to src/python/crfrnn_layer_rgb.py
diff --git a/src/crfrnn_model.py b/src/python/crfrnn_model.py
old mode 100644
new mode 100755
similarity index 93%
rename from src/crfrnn_model.py
rename to src/python/crfrnn_model.py
index bdf5944..0644c8a
--- a/src/crfrnn_model.py
+++ b/src/python/crfrnn_model.py
@@ -25,7 +25,7 @@
 from keras.models import Model
 from keras.layers import Conv2D, MaxPooling2D, Input, ZeroPadding2D, \
     Dropout, Conv2DTranspose, Cropping2D, Add
-from crfrnn_layer import CrfRnnLayer
+from crfrnn_layer_rgb import CrfRnnLayer
 
 
 def get_crfrnn_model_def():
@@ -103,12 +103,12 @@ def get_crfrnn_model_def():
     upscore = Cropping2D(((31, 37), (31, 37)))(upsample)
 
     output = CrfRnnLayer(image_dims=(height, weight),
-                             num_classes=21,
-                             theta_alpha=160.,
-                             theta_beta=3.,
-                             theta_gamma=3.,
-                             num_iterations=10,
-                             name='crfrnn')([upscore, img_input])
+                         num_classes=21,
+                         theta_alpha=160.,
+                         theta_beta=3.,
+                         theta_gamma=3.,
+                         num_iterations=10,
+                         name='crfrnn')([upscore, img_input])
 
     # Build the model
     model = Model(img_input, output, name='crfrnn_net')
diff --git a/src/high_dim_filter_loader.py b/src/python/high_dim_filter_loader.py
old mode 100644
new mode 100755
similarity index 84%
rename from src/high_dim_filter_loader.py
rename to src/python/high_dim_filter_loader.py
index 2ce32a1..4daa497
--- a/src/high_dim_filter_loader.py
+++ b/src/python/high_dim_filter_loader.py
@@ -25,7 +25,11 @@
 import os
 import tensorflow as tf
 from tensorflow.python.framework import ops
 
-custom_module = tf.load_op_library(os.path.join(os.path.dirname(__file__), 'cpp', 'high_dim_filter.so'))
+custom_module = tf.load_op_library(os.path.join(os.path.dirname(__file__),
+                                                '..',
+                                                'cpp',
+                                                'build',
+                                                'high_dim_filter.so'))
 
 @ops.RegisterGradient('HighDimFilter')
@@ -46,8 +50,10 @@ def _high_dim_filter_grad(op, grad):
     grad_vals = custom_module.high_dim_filter(grad, rgb,
                                               bilateral=op.get_attr('bilateral'),
                                               theta_alpha=op.get_attr('theta_alpha'),
+                                              theta_alpha_z=op.get_attr('theta_alpha_z'),
                                               theta_beta=op.get_attr('theta_beta'),
                                               theta_gamma=op.get_attr('theta_gamma'),
+                                              theta_gamma_z=op.get_attr('theta_gamma_z'),
                                               backwards=True)
 
     return [grad_vals, tf.zeros_like(rgb)]
diff --git a/src/test_gradients.py b/src/python/test_gradients.py
old mode 100644
new mode 100755
similarity index 100%
rename from src/test_gradients.py
rename to src/python/test_gradients.py
diff --git a/src/util.py b/src/python/util.py
old mode 100644
new mode 100755
similarity index 100%
rename from src/util.py
rename to src/python/util.py