[onert] Revisit DepthwiseConv train cker (#13546)
This commit revisits the DepthwiseConv train cker, changing the class into free functions.

ONE-DCO-1.0-Signed-off-by: Jiyoung Yun <[email protected]>
jyoungyun authored Aug 1, 2024
1 parent e7249f5 commit b7f78f3
Showing 5 changed files with 99 additions and 108 deletions.
12 changes: 12 additions & 0 deletions compute/cker/include/cker/eigen/EigenSupport.h
@@ -113,6 +113,18 @@ inline const Eigen::ThreadPoolDevice *GetThreadPoolDevice()
   return ctx.device.get();
 }
 
+template <typename T> int64_t kPacketSize()
+{
+  typedef typename Eigen::internal::packet_traits<T>::type Packet;
+  return sizeof(Packet) / sizeof(T);
+}
+
+inline int getThreadCount()
+{
+  const Eigen::ThreadPoolDevice &d = *eigen_support::GetThreadPoolDevice();
+  return d.numThreads();
+}
+
 } // namespace eigen_support
 } // namespace cker
 } // namespace nnfw
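As a quick illustration of the two new helpers, a minimal sketch of hypothetical driver code
(assuming the cker headers are on the include path; the printed packet width depends on the
SIMD target Eigen was built for, e.g. 8 lanes for float under AVX):

#include <cker/eigen/EigenSupport.h>

#include <iostream>

int main()
{
  // Lanes of T per Eigen SIMD packet, e.g. 32-byte AVX packet / 4-byte float = 8.
  std::cout << "float lanes per packet: " << nnfw::cker::eigen_support::kPacketSize<float>()
            << "\n";
  // Size of the shared Eigen thread pool only. Call sites that also run work on the
  // main thread add 1 when sizing per-thread scratch buffers, as the diffs below show.
  std::cout << "pool threads: " << nnfw::cker::eigen_support::getThreadCount() << "\n";
  return 0;
}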
146 changes: 63 additions & 83 deletions compute/cker/include/cker/train/operation/DepthwiseConv.h
@@ -28,90 +28,70 @@ namespace cker
 namespace train
 {
 
-class DepthwiseConv
+template <typename T>
+void backpropInput(const DepthwiseConvParams &params, const Shape &incoming_shape,
+                   const T *incoming_data, const Shape &filter_shape, const T *filter_data,
+                   T *padded_filter_data, const Shape &grad_shape, T *grad_data, bool pad_filter,
+                   T *filter_buffers_data, T *filter_dim_buffers_data)
 {
-public:
-  DepthwiseConv() = default;
-
-  template <typename T> int64_t kPacketSize() const
-  {
-    typedef typename Eigen::internal::packet_traits<T>::type Packet;
-    return sizeof(Packet) / sizeof(T);
-  }
-
-  int getThreadCount() const
-  {
-    // NOTE The Eigen library uses both main thread as well as a thread pool.
-    // Therefore, it needs to add an additional memory buffer for main thread.
-    const Eigen::ThreadPoolDevice &d = *eigen_support::GetThreadPoolDevice();
-    return d.numThreads() + 1;
-  }
-
-  template <typename T>
-  void backpropInput(const DepthwiseConvParams &params, const Shape &incoming_shape,
-                     const T *incoming_data, const Shape &filter_shape, const T *filter_data,
-                     T *padded_filter_data, const Shape &grad_shape, T *grad_data, bool pad_filter,
-                     T *filter_buffers_data, T *filter_dim_buffers_data)
-  {
-    if (params.stride_height != params.stride_width)
-      throw std::runtime_error("Not support different length strides");
-
-    if (params.dilation_height_factor != 1 || params.dilation_width_factor != 1)
-      throw std::runtime_error{"Not support dilation other than 1."};
-
-    const int batch = MatchingDim(incoming_shape, 0, grad_shape, 0);
-    const int input_depth = grad_shape.Dims(3);
-    const int output_depth = incoming_shape.Dims(3);
-    const int incoming_height = incoming_shape.Dims(1);
-    const int incoming_width = incoming_shape.Dims(2);
-    const int grad_height = grad_shape.Dims(1);
-    const int grad_width = grad_shape.Dims(2);
-    const int stride = params.stride_height;
-    const int depth_multiplier = params.depth_multiplier;
-    const int filter_height = filter_shape.Dims(1);
-    const int filter_width = filter_shape.Dims(2);
-    const int pad_height = params.padding_values.height;
-    const int pad_width = params.padding_values.width;
-
-    depthwise_conv_op::LaunchDepthwiseConvBackpropInputOp<Eigen::ThreadPoolDevice, T>()(
-      batch, grad_height, grad_width, input_depth, filter_height, filter_width, depth_multiplier,
-      stride, pad_height, pad_width, incoming_height, incoming_width, output_depth, incoming_data,
-      filter_data, padded_filter_data, grad_data, pad_filter, filter_buffers_data,
-      filter_dim_buffers_data);
-  }
-
-  template <typename T>
-  void backpropFilter(const DepthwiseConvParams &params, const Shape &incoming_shape,
-                      const T *incoming_data, const Shape &input_shape, const T *input_data,
-                      const Shape &filter_grad_shape, T *filter_grad_data, T *padded_filter_data,
-                      T *filter_buffers_data)
-  {
-    if (params.stride_height != params.stride_width)
-      throw std::runtime_error("Not support different length strides");
-
-    if (params.dilation_height_factor != 1 || params.dilation_width_factor != 1)
-      throw std::runtime_error{"Not support dilation other than 1."};
-
-    const int batch = MatchingDim(incoming_shape, 0, input_shape, 0);
-    const int input_depth = input_shape.Dims(3);
-    const int output_depth = incoming_shape.Dims(3);
-    const int incoming_height = incoming_shape.Dims(1);
-    const int incoming_width = incoming_shape.Dims(2);
-    const int input_height = input_shape.Dims(1);
-    const int input_width = input_shape.Dims(2);
-    const int stride = params.stride_height;
-    const int depth_multiplier = params.depth_multiplier;
-    const int filter_height = filter_grad_shape.Dims(1);
-    const int filter_width = filter_grad_shape.Dims(2);
-    const int pad_height = params.padding_values.height;
-    const int pad_width = params.padding_values.width;
-
-    depthwise_conv_op::LaunchDepthwiseConvBackpropFilterOp<Eigen::ThreadPoolDevice, T>()(
-      batch, input_height, input_width, input_depth, filter_height, filter_width, depth_multiplier,
-      stride, pad_height, pad_width, incoming_height, incoming_width, output_depth, incoming_data,
-      input_data, filter_grad_data, padded_filter_data, filter_buffers_data);
-  }
-};
+  if (params.stride_height != params.stride_width)
+    throw std::runtime_error("Not support different length strides");
+
+  if (params.dilation_height_factor != 1 || params.dilation_width_factor != 1)
+    throw std::runtime_error{"Not support dilation other than 1."};
+
+  const int batch = MatchingDim(incoming_shape, 0, grad_shape, 0);
+  const int input_depth = grad_shape.Dims(3);
+  const int output_depth = incoming_shape.Dims(3);
+  const int incoming_height = incoming_shape.Dims(1);
+  const int incoming_width = incoming_shape.Dims(2);
+  const int grad_height = grad_shape.Dims(1);
+  const int grad_width = grad_shape.Dims(2);
+  const int stride = params.stride_height;
+  const int depth_multiplier = params.depth_multiplier;
+  const int filter_height = filter_shape.Dims(1);
+  const int filter_width = filter_shape.Dims(2);
+  const int pad_height = params.padding_values.height;
+  const int pad_width = params.padding_values.width;
+
+  depthwise_conv_op::LaunchDepthwiseConvBackpropInputOp<Eigen::ThreadPoolDevice, T>()(
+    batch, grad_height, grad_width, input_depth, filter_height, filter_width, depth_multiplier,
+    stride, pad_height, pad_width, incoming_height, incoming_width, output_depth, incoming_data,
+    filter_data, padded_filter_data, grad_data, pad_filter, filter_buffers_data,
+    filter_dim_buffers_data);
+}
+
+template <typename T>
+void backpropFilter(const DepthwiseConvParams &params, const Shape &incoming_shape,
+                    const T *incoming_data, const Shape &input_shape, const T *input_data,
+                    const Shape &filter_grad_shape, T *filter_grad_data, T *padded_filter_data,
+                    T *filter_buffers_data)
+{
+  if (params.stride_height != params.stride_width)
+    throw std::runtime_error("Not support different length strides");
+
+  if (params.dilation_height_factor != 1 || params.dilation_width_factor != 1)
+    throw std::runtime_error{"Not support dilation other than 1."};
+
+  const int batch = MatchingDim(incoming_shape, 0, input_shape, 0);
+  const int input_depth = input_shape.Dims(3);
+  const int output_depth = incoming_shape.Dims(3);
+  const int incoming_height = incoming_shape.Dims(1);
+  const int incoming_width = incoming_shape.Dims(2);
+  const int input_height = input_shape.Dims(1);
+  const int input_width = input_shape.Dims(2);
+  const int stride = params.stride_height;
+  const int depth_multiplier = params.depth_multiplier;
+  const int filter_height = filter_grad_shape.Dims(1);
+  const int filter_width = filter_grad_shape.Dims(2);
+  const int pad_height = params.padding_values.height;
+  const int pad_width = params.padding_values.width;
+
+  depthwise_conv_op::LaunchDepthwiseConvBackpropFilterOp<Eigen::ThreadPoolDevice, T>()(
+    batch, input_height, input_width, input_depth, filter_height, filter_width, depth_multiplier,
+    stride, pad_height, pad_width, incoming_height, incoming_width, output_depth, incoming_data,
+    input_data, filter_grad_data, padded_filter_data, filter_buffers_data);
+}
 
 } // namespace train
 } // namespace cker
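The call-site impact of the class-to-functions change is uniform: there is no kernel object to
construct or store any more, and callers keep owning all scratch buffers, exactly as before. A
minimal sketch of the new usage (the wrapper below is hypothetical; the backpropInput signature
is the one declared above):

#include <cker/train/operation/DepthwiseConv.h>

// Hypothetical wrapper that forwards caller-owned buffers to the stateless free function.
void computeInputGrad(const nnfw::cker::DepthwiseConvParams &params,
                      const nnfw::cker::Shape &incoming_shape, const float *incoming_data,
                      const nnfw::cker::Shape &filter_shape, const float *filter_data,
                      float *padded_filter, const nnfw::cker::Shape &grad_shape, float *grad_data,
                      bool pad_filter, float *filter_buffers, float *filter_dim_buffers)
{
  // Before: nnfw::cker::train::DepthwiseConv kernel; kernel.backpropInput(params, ...);
  // After: a direct call; the class carried no state, so nothing is lost by removing it.
  nnfw::cker::train::backpropInput(params, incoming_shape, incoming_data, filter_shape,
                                   filter_data, padded_filter, grad_shape, grad_data, pad_filter,
                                   filter_buffers, filter_dim_buffers);
}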
31 changes: 15 additions & 16 deletions compute/cker/src/train/DepthwiseConv.test.cc
@@ -14,6 +14,7 @@
  * limitations under the License.
  */
 
+#include <cker/eigen/EigenSupport.h>
 #include <cker/train/operation/DepthwiseConv.h>
 
 #include <gtest/gtest.h>
@@ -25,14 +26,11 @@ namespace
 template <typename T> class DepthwiseConvVerifier
 {
 public:
-  DepthwiseConvVerifier() : _dconv_kernel{new nnfw::cker::train::DepthwiseConv()}
-  {
-    _dconv_kernel = std::make_unique<nnfw::cker::train::DepthwiseConv>();
-  }
+  DepthwiseConvVerifier() = default;
 
   void prepare(const nnfw::cker::Shape &incoming_shape, const nnfw::cker::Shape &filter_shape)
   {
-    const int k_packet_size = _dconv_kernel->kPacketSize<T>();
+    const int k_packet_size = nnfw::cker::eigen_support::kPacketSize<T>();
     const int batch = incoming_shape.Dims(0);
     const int out_depth = incoming_shape.Dims(3);
     const int filter_rows = filter_shape.Dims(1);
@@ -49,7 +47,9 @@ template <typename T> class DepthwiseConvVerifier
     }
 
     {
-      const int thread_count = _dconv_kernel->getThreadCount();
+      // NOTE The Eigen library uses both main thread as well as a thread pool.
+      // Therefore, it needs to add an additional memory buffer for main thread.
+      const int thread_count = nnfw::cker::eigen_support::getThreadCount() + 1;
 
       nnfw::cker::Shape filter_buffer_shape(
         {thread_count, filter_spatial_size, padded_filter_inner_dim_size});
@@ -71,10 +71,10 @@ template <typename T> class DepthwiseConvVerifier
     calculateInputGradExpected(params, incoming_shape, incoming_data, filter_shape, filter_data,
                                grad_shape, expected.data());
 
-    _dconv_kernel->backpropInput(params, incoming_shape, incoming_data, filter_shape, filter_data,
-                                 _padded_filter.data(), grad_shape, gradient.data(),
-                                 _use_padded_filter, _filter_buffers.data(),
-                                 _filter_dim_buffers.data());
+    nnfw::cker::train::backpropInput(params, incoming_shape, incoming_data, filter_shape,
+                                     filter_data, _padded_filter.data(), grad_shape,
+                                     gradient.data(), _use_padded_filter, _filter_buffers.data(),
+                                     _filter_dim_buffers.data());
 
     for (size_t i = 0; i < gradient.size(); ++i)
       EXPECT_NEAR(gradient[i], expected[i], 1e-3f);
@@ -87,7 +87,7 @@ template <typename T> class DepthwiseConvVerifier
   {
     std::vector<T> gradient(grad_shape.FlatSize(), static_cast<T>(0));
 
-    EXPECT_ANY_THROW(_dconv_kernel->backpropInput(
+    EXPECT_ANY_THROW(nnfw::cker::train::backpropInput(
       params, incoming_shape, incoming_data, filter_shape, filter_data, _padded_filter.data(),
       grad_shape, gradient.data(), _use_padded_filter, _filter_buffers.data(),
       _filter_dim_buffers.data()));
@@ -104,9 +104,9 @@ template <typename T> class DepthwiseConvVerifier
     calculateFilterGradExpected(params, incoming_shape, incoming_data, input_shape, input_data,
                                 filter_grad_shape, expected.data());
 
-    _dconv_kernel->backpropFilter(params, incoming_shape, incoming_data, input_shape, input_data,
-                                  filter_grad_shape, gradient.data(), _padded_filter.data(),
-                                  _filter_buffers.data());
+    nnfw::cker::train::backpropFilter(params, incoming_shape, incoming_data, input_shape,
+                                      input_data, filter_grad_shape, gradient.data(),
+                                      _padded_filter.data(), _filter_buffers.data());
 
     for (size_t i = 0; i < gradient.size(); ++i)
       EXPECT_NEAR(gradient[i], expected[i], 1e-3f);
@@ -119,7 +119,7 @@ template <typename T> class DepthwiseConvVerifier
   {
     std::vector<T> gradient(filter_grad_shape.FlatSize(), static_cast<T>(0));
 
-    EXPECT_ANY_THROW(_dconv_kernel->backpropFilter(
+    EXPECT_ANY_THROW(nnfw::cker::train::backpropFilter(
       params, incoming_shape, incoming_data, input_shape, input_data, filter_grad_shape,
       gradient.data(), _padded_filter.data(), _filter_buffers.data()));
   }
@@ -186,7 +186,6 @@ template <typename T> class DepthwiseConvVerifier
   }
 
 private:
-  std::unique_ptr<nnfw::cker::train::DepthwiseConv> _dconv_kernel;
   bool _use_padded_filter;
   std::vector<T> _padded_filter;
   std::vector<T> _filter_buffers;
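The "+ 1" on getThreadCount() mirrors the NOTE above: Eigen schedules work on both the pool
threads and the calling (main) thread, so one extra scratch buffer is required. A sketch of the
buffer sizing, assuming padded_filter_inner_dim_size rounds the output depth up to a whole
number of SIMD packets (the exact formula lives in a collapsed part of the test, so this is an
assumption suggested by the name):

#include <cker/eigen/EigenSupport.h>

#include <cstdint>

// Hypothetical helper: total element count of the per-thread filter scratch buffers.
int64_t filterBufferElements(int filter_rows, int filter_cols, int out_depth)
{
  const int64_t k_packet_size = nnfw::cker::eigen_support::kPacketSize<float>();
  const int64_t filter_spatial_size = filter_rows * filter_cols;
  // Assumed: channel dimension padded up to a multiple of the packet width.
  const int64_t padded_filter_inner_dim_size =
    ((out_depth + k_packet_size - 1) / k_packet_size) * k_packet_size;
  // One buffer per pool thread, plus one for the main thread.
  const int64_t thread_count = nnfw::cker::eigen_support::getThreadCount() + 1;
  return thread_count * filter_spatial_size * padded_filter_inner_dim_size;
}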
15 changes: 9 additions & 6 deletions runtime/onert/backend/train/ops/DepthwiseConvolutionLayer.cc
@@ -18,6 +18,8 @@
 
 #include "OperationUtils.h"
 
+#include <cker/eigen/EigenSupport.h>
+#include <cker/train/operation/DepthwiseConv.h>
 #include <cker/train/operation/ReLU.h>
 
 namespace onert
@@ -33,8 +35,7 @@ DepthwiseConvolutionLayer::DepthwiseConvolutionLayer()
   : cpu::ops::DepthwiseConvolutionLayer(), _grad_weights{nullptr}, _grad_bias{nullptr},
     _back_prop_input{nullptr}, _back_prop_output{nullptr}, _act_back_prop_output{nullptr},
     _use_padded_filter{false}, _padded_filter{nullptr}, _filter_buffers{nullptr},
-    _filter_dim_buffers{nullptr},
-    _dconv_kernel{std::make_unique<nnfw::cker::train::DepthwiseConv>()}
+    _filter_dim_buffers{nullptr}
 {
   // DO NOTHING
 }
@@ -66,7 +67,7 @@ void DepthwiseConvolutionLayer::configureBackward(IPortableTensor *back_prop_inp
     {
       case OperandType::FLOAT32:
       {
-        return _dconv_kernel->kPacketSize<float>();
+        return nnfw::cker::eigen_support::kPacketSize<float>();
       }
       default:
         throw std::runtime_error("train DepthwiseConvolutionLayer: unsupported data type");
@@ -93,7 +94,9 @@ void DepthwiseConvolutionLayer::configureBackward(IPortableTensor *back_prop_inp
   _padded_filter->setBuffer(std::make_shared<basic::Allocator>(_padded_filter->total_size()));
 
   // prepare out_bprop and in_bprop buffer for cker
-  const int thread_count = _dconv_kernel->getThreadCount();
+  // NOTE The Eigen library uses both main thread as well as a thread pool.
+  // Therefore, it needs to add an additional memory buffer for main thread.
+  const int thread_count = nnfw::cker::eigen_support::getThreadCount() + 1;
 
   auto filter_buffers_info = ir::OperandInfo(_kernel->get_info());
   filter_buffers_info.shape({thread_count, filter_spatial_size, padded_filter_inner_dim_size});
@@ -151,14 +154,14 @@ void DepthwiseConvolutionLayer::backwardFloat32()
   dconv_params.dilation_width_factor = _dilationWidth;
 
   // Calculate gradient for input
-  _dconv_kernel->backpropInput(
+  nnfw::cker::train::backpropInput(
     dconv_params, getShape(backprop_act), getBuffer<float>(backprop_act), getShape(_kernel),
     getBuffer<float>(_kernel), getBuffer<float>(_padded_filter.get()), getShape(_back_prop_input),
     getBuffer<float>(_back_prop_input), _use_padded_filter, getBuffer<float>(_filter_buffers.get()),
     getBuffer<float>(_filter_dim_buffers.get()));
 
   // Calculate gradient for weights
-  _dconv_kernel->backpropFilter(
+  nnfw::cker::train::backpropFilter(
     dconv_params, getShape(backprop_act), getBuffer<float>(backprop_act), getShape(_input),
     getBuffer<float>(_input), getShape(_grad_weights), getBuffer<float>(_grad_weights),
     getBuffer<float>(_padded_filter.get()), getBuffer<float>(_filter_buffers.get()));
3 changes: 0 additions & 3 deletions runtime/onert/backend/train/ops/DepthwiseConvolutionLayer.h
@@ -17,7 +17,6 @@
 #ifndef __ONERT_BACKEND_TRAIN_OPS_DEPTHWISECONVOLUTIONLAYER_H__
 #define __ONERT_BACKEND_TRAIN_OPS_DEPTHWISECONVOLUTIONLAYER_H__
 
-#include <cker/train/operation/DepthwiseConv.h>
 #include <ops/DepthwiseConvolutionLayer.h>
 #include <backend/basic/Allocator.h>
 
@@ -61,8 +60,6 @@ class DepthwiseConvolutionLayer : public ::onert::exec::train::ITrainableFunctio
   std::unique_ptr<Tensor> _padded_filter;
   std::unique_ptr<Tensor> _filter_buffers;
   std::unique_ptr<Tensor> _filter_dim_buffers;
-
-  std::unique_ptr<nnfw::cker::train::DepthwiseConv> _dconv_kernel;
 };
 
 } // namespace ops
