From 1defea382ff8df4dcaff3d33261d12df4a4ea1a0 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Wed, 22 Jun 2022 13:57:45 +0800 Subject: [PATCH 01/15] linear: add OP context --- aten/src/ATen/native/mkldnn/Common.h | 16 ++ aten/src/ATen/native/mkldnn/LinearPrepack.cpp | 140 ++++++++++++++++++ aten/src/ATen/native/mkldnn/LinearPrepack.h | 41 +++++ aten/src/ATen/native/mkldnn/OpContext.cpp | 22 +++ aten/src/ATen/native/mkldnn/OpContext.h | 53 +++++++ .../mkldnn/RegisterMkldnnOpContextClass.cpp | 35 +++++ 6 files changed, 307 insertions(+) create mode 100644 aten/src/ATen/native/mkldnn/LinearPrepack.cpp create mode 100644 aten/src/ATen/native/mkldnn/LinearPrepack.h diff --git a/aten/src/ATen/native/mkldnn/Common.h b/aten/src/ATen/native/mkldnn/Common.h index 4e048ebce75978..da6a2c3f604cb2 100644 --- a/aten/src/ATen/native/mkldnn/Common.h +++ b/aten/src/ATen/native/mkldnn/Common.h @@ -39,6 +39,22 @@ struct ContextConv final { attr_(attr) {} }; +struct ContextLinear final { + ideep::tensor weight_packed_; + c10::optional at_bias_; + ideep::attr_t attr_; + + ContextLinear() = delete; + + ContextLinear( + ideep::tensor&& weight_packed, + c10::optional at_bias, + ideep::attr_t attr) + : weight_packed_(std::move(weight_packed)), + at_bias_(std::move(at_bias)), + attr_(attr) {} +}; + } // namespace mkldnn } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mkldnn/LinearPrepack.cpp b/aten/src/ATen/native/mkldnn/LinearPrepack.cpp new file mode 100644 index 00000000000000..d26eab5ae83593 --- /dev/null +++ b/aten/src/ATen/native/mkldnn/LinearPrepack.cpp @@ -0,0 +1,140 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +#if AT_MKLDNN_ENABLED() + +namespace at { +namespace native { +namespace mkldnn { +namespace internal { +namespace linear { + +using namespace torch::jit::mkldnn; + +c10::intrusive_ptr createLinearPrePackOpContext( + Tensor weight, + c10::optional bias, + std::vector input_size, + std::string attr, + std::vector> scalars, + c10::optional algorithm) { + auto it = fusion_attr_map().find(attr); + TORCH_CHECK(it != fusion_attr_map().end(), "Fusion behavior undefined."); + ideep::attr_t op_attr = it->second.attr_function(scalars, algorithm); + return mkldnn::MkldnnLinearOpContext::create_context( + std::move(weight), std::move(bias), std::move(input_size), op_attr); +} + +ContextLinear create( + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef input_size, + const ideep::attr_t& attr) { + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + ideep::tensor w = itensor_view_from_dense(weight); + auto dtype = w.get_data_type(); + + int64_t b_size = std::accumulate( + input_size.begin(), + input_size.end(), + (int64_t)1, + std::multiplies()) / + input_size[input_size.size() - 1]; + + auto out_features = weight.size(0); + auto in_features = weight.size(1); + ideep::dims reshaped_input_size = {b_size, in_features}; + + ideep::tensor::desc expected_weight_desc = + ideep::inner_product_forward::expected_weights_desc( + {out_features, in_features}, + reshaped_input_size, + /* w_dtype */ dtype, + /* x_dtype */ dtype); + + ideep::tensor packed_weight; + packed_weight.init(expected_weight_desc); + packed_weight.feed_from(w); + + return ContextLinear{ + std::move(packed_weight), + bias.has_value() ? 
c10::make_optional(*bias) : c10::nullopt, + std::move(attr)}; +} + +Tensor run(ContextLinear& context, const Tensor& input) { + const ideep::tensor& mkldnn_weight = context.weight_packed_; + + auto input_size = input.sizes(); + + const int64_t dim = input.dim(); + auto input_reshaped = + dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)}); + + std::vector output_size(input_size.begin(), input_size.end() - 1); + output_size.push_back(mkldnn_weight.get_dim(0)); + auto output = at::empty(output_size, input.options()); + + if (dim != 2) { + std::vector output_size_reshaped = { + input_reshaped.size(0), mkldnn_weight.get_dim(0)}; + output = output.reshape(output_size_reshaped); + } + + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + const ideep::tensor mkldnn_input = itensor_view_from_dense(input_reshaped); + ideep::tensor mkldnn_output = itensor_view_from_dense(output); + + c10::MaybeOwned bias_maybe_owned = + at::borrow_from_optional_tensor(context.at_bias_); + const Tensor& bias = *bias_maybe_owned; + + if (bias.defined()) { + const ideep::tensor mkldnn_bias = itensor_view_from_dense(bias); + ideep::inner_product_forward::compute( + mkldnn_input, + mkldnn_weight, + mkldnn_bias, + mkldnn_output, + ideep::scale_t(), + ideep::scale_t(), + ideep::scale_t(), + context.attr_); + } else { + ideep::inner_product_forward::compute( + mkldnn_input, + mkldnn_weight, + mkldnn_output, + ideep::scale_t(), + ideep::scale_t(), + ideep::scale_t(), + context.attr_); + } + + if (dim != 2) { + output = output.reshape(output_size); + } + + return output; +} + +Tensor linear_run( + const Tensor& input, + const c10::intrusive_ptr& op_context) { + return op_context->run(input); +} + +} // namespace linear +} // namespace internal +} // namespace mkldnn +} // namespace native +} // namespace at + +#endif // AT_MKLDNN_ENABLED() \ No newline at end of file diff --git a/aten/src/ATen/native/mkldnn/LinearPrepack.h b/aten/src/ATen/native/mkldnn/LinearPrepack.h new file mode 100644 index 00000000000000..aa43b6791862da --- /dev/null +++ b/aten/src/ATen/native/mkldnn/LinearPrepack.h @@ -0,0 +1,41 @@ +#pragma once + +#include +#include +#include + +#if AT_MKLDNN_ENABLED() + +namespace at { +namespace native { +namespace mkldnn { +namespace internal { +namespace linear { + +c10::intrusive_ptr createLinearPrePackOpContext( + Tensor weight, + c10::optional bias, + std::vector input_size, + std::string attr, + std::vector> scalars, + c10::optional algorithm); + +Tensor linear_run( + const Tensor& input, + const c10::intrusive_ptr& op_context); + +ContextLinear create( + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef input_size, + const ideep::attr_t& attr); + +Tensor run(ContextLinear& context, const Tensor& input); + +} // namespace linear +} // namespace internal +} // namespace mkldnn +} // namespace native +} // namespace at + +#endif // AT_MKLDNN_ENABLED() \ No newline at end of file diff --git a/aten/src/ATen/native/mkldnn/OpContext.cpp b/aten/src/ATen/native/mkldnn/OpContext.cpp index 2716b4908eb30d..f408da5daa15e8 100644 --- a/aten/src/ATen/native/mkldnn/OpContext.cpp +++ b/aten/src/ATen/native/mkldnn/OpContext.cpp @@ -1,4 +1,5 @@ #include +#include #include #if AT_MKLDNN_ENABLED() @@ -40,6 +41,27 @@ void MkldnnConvOpContext::run(const Tensor& input, void* output) { return mkldnn::internal::convolution::run(op_context_, input, output); } +c10::intrusive_ptr MkldnnLinearOpContext::create_context( + at::Tensor&& weight, + c10::optional&& bias, + std::vector&& 
input_size, + const ideep::attr_t& attr) { + auto op_context = + mkldnn::internal::linear::create(weight, bias, input_size, attr); + + auto linear_op_context = c10::make_intrusive( + std::move(weight), + std::move(bias), + std::move(input_size), + std::move(op_context)); + + return linear_op_context; +} + +Tensor MkldnnLinearOpContext::run(const Tensor& input) { + return mkldnn::internal::linear::run(op_context_, input); +} + } // namespace mkldnn } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mkldnn/OpContext.h b/aten/src/ATen/native/mkldnn/OpContext.h index 6c095982884057..591e9b29973cf0 100644 --- a/aten/src/ATen/native/mkldnn/OpContext.h +++ b/aten/src/ATen/native/mkldnn/OpContext.h @@ -93,6 +93,59 @@ class MkldnnConvOpContext final : public ConvOpContext { const ideep::attr_t& attr); }; +using SerializationTypeLinearPrePack = std::tuple< + at::Tensor, + c10::optional, + std::vector, + std::string, + std::vector>, + c10::optional>; + +class LinearOpContext : public torch::jit::CustomClassHolder { + protected: + Tensor orig_weight_; + c10::optional orig_bias_; + std::vector input_size_; + std::string attr_; + std::vector> scalars_; + c10::optional algorithm_; + + public: + SerializationTypeLinearPrePack unpack() { + return std::make_tuple( + orig_weight_, orig_bias_, input_size_, attr_, scalars_, algorithm_); + } + + virtual at::Tensor run(const at::Tensor& input) = 0; + + // TODO: run with void* output +}; + +class MkldnnLinearOpContext final : public LinearOpContext { + private: + ContextLinear op_context_; + + public: + MkldnnLinearOpContext( + Tensor&& weight, + c10::optional&& bias, + std::vector&& input_size, + ContextLinear&& op_context) + : op_context_(std::move(op_context)) { + orig_weight_ = std::move(weight); + orig_bias_ = std::move(bias); + input_size_ = std::move(input_size); + } + + virtual at::Tensor run(const at::Tensor& input) override; + + static c10::intrusive_ptr create_context( + at::Tensor&& weight, + c10::optional&& bias, + std::vector&& input_size, + const ideep::attr_t& attr); +}; + } // namespace mkldnn } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp index 534e6388f4428b..b1c36c6072d793 100644 --- a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp +++ b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -11,6 +12,7 @@ namespace native { namespace mkldnn { using namespace internal::convolution; +using namespace internal::linear; TORCH_LIBRARY(mkldnn, m) { m.class_(TORCH_SELECTIVE_CLASS("ConvOpContext")) @@ -36,14 +38,39 @@ TORCH_LIBRARY(mkldnn, m) { std::move(std::get<8>(state)), std::move(std::get<9>(state))); }); + + m.class_(TORCH_SELECTIVE_CLASS("LinearOpContext")) + .def_pickle( + [](const c10::intrusive_ptr& op_context) + -> SerializationTypeLinearPrePack { // __getstate__ + return op_context->unpack(); + }, + [](SerializationTypeLinearPrePack state) + -> c10::intrusive_ptr { // __setstate__ + return createLinearPrePackOpContext( + std::move(std::get<0>(state)), + std::move(std::get<1>(state)), + std::move(std::get<2>(state)), + std::move(std::get<3>(state)), + std::move(std::get<4>(state)), + std::move(std::get<5>(state))); + }); } TORCH_LIBRARY(mkldnn_prepacked, m) { + // conv m.def(TORCH_SELECTIVE_SCHEMA( "mkldnn_prepacked::conv2d_prepack(Tensor W, Tensor? 
B, int[2] stride, int[2] padding, int[2] dilation, int groups, int[4] input_size, str attr, Scalar?[] scalars, str? algorithm) -> __torch__.torch.classes.mkldnn.ConvOpContext")); m.def(TORCH_SELECTIVE_SCHEMA( "mkldnn_prepacked::conv2d_run(Tensor X, __torch__.torch.classes.mkldnn.ConvOpContext W_prepack) -> Tensor Y")); + + // linear + m.def(TORCH_SELECTIVE_SCHEMA( + "mkldnn_prepacked::linear_prepack(Tensor w, Tensor? B, int[] input_sizes, str attr, Scalar?[] scalars, str? algorithm) -> __torch__.torch.classes.mkldnn.LinearOpContext")); + + m.def(TORCH_SELECTIVE_SCHEMA( + "mkldnn_prepacked::linear_run(Tensor X, __torch__.torch.classes.mkldnn.LinearOpContext W_prepack) -> Tensor Y")); } TORCH_LIBRARY_IMPL(mkldnn_prepacked, CPU, m) { @@ -53,6 +80,14 @@ TORCH_LIBRARY_IMPL(mkldnn_prepacked, CPU, m) { m.impl( TORCH_SELECTIVE_NAME("mkldnn_prepacked::conv2d_run"), TORCH_FN(conv_run)); + + m.impl( + TORCH_SELECTIVE_NAME("mkldnn_prepacked::linear_prepack"), + TORCH_FN(createLinearPrePackOpContext)); + + m.impl( + TORCH_SELECTIVE_NAME("mkldnn_prepacked::linear_run"), + TORCH_FN(linear_run)); } } // namespace mkldnn From df92236f925e5050e0a385be9ff92ca4d09a74ac Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Wed, 22 Jun 2022 13:58:09 +0800 Subject: [PATCH 02/15] linear: add graph rewrite for single linear --- torch/csrc/jit/passes/mkldnn_rewrite.cpp | 69 +++++++++++++++++++++++- 1 file changed, 68 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/passes/mkldnn_rewrite.cpp b/torch/csrc/jit/passes/mkldnn_rewrite.cpp index 79cfb2646a09f0..95dade426df1d5 100644 --- a/torch/csrc/jit/passes/mkldnn_rewrite.cpp +++ b/torch/csrc/jit/passes/mkldnn_rewrite.cpp @@ -209,6 +209,51 @@ void insertPrePackedConvOpForNode(Node* n) { n->output()->replaceAllUsesWith(prepack_conv->output()); } +void insertPrePackedLinearOpForNode(Node* n) { + constexpr int POS_INPUT = 0; + constexpr int POS_WEIGHT = 1; + // TODO: check input and weight should be contiguous + + WithInsertPoint guard(n); + auto graph = n->owningGraph(); + + auto input_sizes = getSizesOf(n, POS_INPUT); + IValue input_size_value(*input_sizes.concrete_sizes()); + auto input_size = graph->insertConstant(input_size_value); + + auto prepack_node = graph->create( + Symbol::fromQualString("mkldnn_prepacked::linear_prepack"), 1); + + // skip input value + for (auto i = 1; i < n->inputs().size(); i++) { + Value* v = n->input(i); + prepack_node->addInput(v); + } + prepack_node->addInput(input_size); + auto attr = graph->insertConstant(IValue("none")); + prepack_node->addInput(attr); + + std::vector> empty_scalars; + auto scalars = graph->insertConstant(IValue(empty_scalars)); + prepack_node->addInput(scalars); + + c10::optional empty_algorithm; + auto algorithm = graph->insertConstant(IValue(empty_algorithm)); + prepack_node->addInput(algorithm); + + prepack_node->output()->setType( + getCustomClass("__torch__.torch.classes.mkldnn.LinearOpContext")); + graph->insertNode(prepack_node); + + auto prepack_linear = graph->insertNode( + graph->create(Symbol::fromQualString("mkldnn_prepacked::linear_run"), 1)); + prepack_linear->addInput(n->input(0)); + prepack_linear->addInput(prepack_node->output()); + prepack_linear->output()->setType(n->output()->type()->cast()); + + n->output()->replaceAllUsesWith(prepack_linear->output()); +} + bool isTensorTypeCPU(Node* node) { for (const auto& input : node->inputs()) { auto type = input->type()->cast(); @@ -241,6 +286,21 @@ void insertPrePackedConvOp(Block* b) { EliminateDeadCode(b); } +void insertPrePackedLinearOp(Block* b) { 
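+  // Recurse into nested blocks first, then rewrite eligible CPU aten::linear
+  // nodes into mkldnn_prepacked::linear_prepack / linear_run pairs; the dead
+  // aten::linear nodes are cleaned up by EliminateDeadCode below.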
+ for (Node* n : b->nodes()) { + for (Block* b : n->blocks()) { + insertPrePackedLinearOp(b); + } + + if (n->kind() == aten::linear) { + if (isTensorTypeCPU(n)) { + insertPrePackedLinearOpForNode(n); + } + } + } + EliminateDeadCode(b); +} + void insertMkldnnPrePackedConv2dOp(std::shared_ptr& graph) { // Replace _convolution with conv2d graph_rewrite_helper::replaceConvolutionWithAtenConv(graph); @@ -248,8 +308,13 @@ void insertMkldnnPrePackedConv2dOp(std::shared_ptr& graph) { insertPrePackedConvOp(graph->block()); } +void insertMkldnnPrePackedLinearOp(std::shared_ptr& graph) { + insertPrePackedLinearOp(graph->block()); +} + void insertMkldnnPrePackedOps(std::shared_ptr& graph) { insertMkldnnPrePackedConv2dOp(graph); + insertMkldnnPrePackedLinearOp(graph); } void insertMkldnnPrePackedOps(script::Module& module) { @@ -357,7 +422,9 @@ void PrePackingOpsFolder(Block* b) { auto is_foldable_op = [](const Node* n) -> bool { return ( n->kind() == - Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack")); + Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack") || + n->kind() == + Symbol::fromQualString("mkldnn_prepacked::linear_prepack")); }; std::unordered_set nodes_to_delete; From ba1f5e1124b4444b7c6d0a7652281169b53f84c1 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Wed, 22 Jun 2022 13:58:53 +0800 Subject: [PATCH 03/15] linear: integrate into NNC via external call --- torch/csrc/jit/passes/tensorexpr_fuser.cpp | 7 +++++ .../jit/tensorexpr/external_functions.cpp | 28 +++++++++++++++++++ torch/csrc/jit/tensorexpr/kernel.cpp | 21 ++++++++++++++ torch/csrc/jit/tensorexpr/kernel.h | 3 ++ torch/csrc/jit/tensorexpr/lowerings.cpp | 3 ++ .../csrc/jit/tensorexpr/operators/matmul.cpp | 20 +++++++++++++ torch/csrc/jit/tensorexpr/operators/matmul.h | 7 ++++- 7 files changed, 88 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index a885d7c177ba3f..3cc710d2b392c0 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -81,6 +81,7 @@ static const OperatorSet& supported_non_eltwise_set() { "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "aten::conv2d(Tensor input, Tensor weight, Tensor? bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", "aten::matmul(Tensor self, Tensor other) -> Tensor", + "aten::linear(Tensor input, Tensor weight, Tensor? 
bias=None) -> Tensor", }; // clang-format on return supported_non_eltwise_set; @@ -1058,6 +1059,12 @@ class TensorExprFuser { return false; } } + if (node->kind() == aten::linear) { + if (!tensorexpr::mkldnnLinearIsSupported(node)) { + GRAPH_DEBUG("Shapes of linear inputs are not supported"); + return false; + } + } return true; } diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp index 3b9ad4d1f255bc..5c33d0e954ccd8 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions.cpp @@ -1434,6 +1434,30 @@ void nnc_mkldnn_prepacked_conv_run( context->run(x, buf_data[0]); } +void nnc_mkldnn_prepacked_linear_run( + int64_t bufs_num, + void** buf_data, + int64_t* buf_ranks, + int64_t* buf_dims, + int64_t* buf_strides, + int8_t* buf_dtypes, + int64_t args_num, + int64_t* extra_args) { + using namespace at::native::mkldnn; + + auto tensors = constructTensors( + bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes); + + const at::Tensor& x = tensors[1]; + auto context = reinterpret_cast(buf_data[2]); + + at::Tensor output = context->run(x); + memcpy( + buf_data[0], output.data_ptr(), output.element_size() * output.numel()); + // TODO: remove mem copy + // context->run(x, buf_data[0]); +} + #endif // AT_MKLDNN_ENABLED() #ifdef USE_XNNPACK @@ -1618,6 +1642,10 @@ const static RegisterNNCExternalFunction nnc_embedding( const static RegisterNNCExternalFunction reg_nnc_mkldnn_prepacked_conv_run( "nnc_mkldnn_prepacked_conv_run", nnc_mkldnn_prepacked_conv_run); + +const static RegisterNNCExternalFunction reg_nnc_mkldnn_prepacked_linear_run( + "nnc_mkldnn_prepacked_linear_run", + nnc_mkldnn_prepacked_linear_run); #endif // AT_MKLDNN_ENABLED() #ifdef USE_XNNPACK diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 5c7c783e1b78a1..0cc4d47c4fdc3c 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -314,6 +314,27 @@ bool mkldnnPrepackedConvIsSupportedJit(const torch::jit::Node* node) { return false; } +bool mkldnnLinearIsSupported(const torch::jit::Node* node) { + auto const& input0 = getTensorInfoJit(node->input(0)); + auto const& input1 = getTensorInfoJit(node->input(1)); + + // Everything should be statically known. + if (!input0 || !input1) { + GRAPH_DEBUG("mkldnnLinearIsSupported: Input shapes aren't static"); + return false; + } + + // Inputs should be contiguous, or the TE will needlessly transpose them. + if (!isContiguous(node->input(0)) || !isContiguous(node->input(1))) { + GRAPH_DEBUG("mkldnnLinearIsSupported: Input shapes are not contiguous"); + return false; + } + + // TODO: only support BF16 to make sure there's no performance regression + + return true; +} + // The fuser currently only supports matmul of 2D x 2D matrices bool matmulIsSupported(const torch::jit::Node* node) { auto const& input0 = getTensorInfoJit(node->input(0)); diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index dc803ced2e29da..decf069edd4c46 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -26,6 +26,9 @@ struct SmallSizeTPairHash { bool conv2dIsSupportedJit(const Node* node); // Returns true if the TE fuser supports this conv2d with mkldnn prepacked conv. bool mkldnnPrepackedConvIsSupportedJit(const Node* node); +// Returns true if the TE fuser supports this linear with mkldnn prepacked +// linear. 
+bool mkldnnLinearIsSupported(const Node* node); // Returns true if the TE fuser supports this matmul. bool matmulIsSupported(const Node* node); template diff --git a/torch/csrc/jit/tensorexpr/lowerings.cpp b/torch/csrc/jit/tensorexpr/lowerings.cpp index 5bfab4e26915cc..2c48e91636469a 100644 --- a/torch/csrc/jit/tensorexpr/lowerings.cpp +++ b/torch/csrc/jit/tensorexpr/lowerings.cpp @@ -49,6 +49,9 @@ int nnc_lowerings_lazy_registration() { RegisterNNCLoweringsFunction mkldnn_prepacked_conv2d_run( {"mkldnn_prepacked::conv2d_run(Tensor X, __torch__.torch.classes.mkldnn.ConvOpContext W_prepack) -> (Tensor Y)"}, computeMkldnnPrepackedConvRun); + RegisterNNCLoweringsFunction mkldnn_prepacked_linear_run( + {"mkldnn_prepacked::linear_run(Tensor X, __torch__.torch.classes.mkldnn.LinearOpContext W_prepack) -> (Tensor Y)"}, + computeMkldnnPrepackedLinearRun); #endif // AT_MKLDNN_ENABLED() RegisterNNCLoweringsFunction aten_sub( diff --git a/torch/csrc/jit/tensorexpr/operators/matmul.cpp b/torch/csrc/jit/tensorexpr/operators/matmul.cpp index 408ae6e8c36105..888864930080ff 100644 --- a/torch/csrc/jit/tensorexpr/operators/matmul.cpp +++ b/torch/csrc/jit/tensorexpr/operators/matmul.cpp @@ -76,6 +76,26 @@ Tensor computeAddMM( inputs[4])})); // TODO: handle other dtypes of alpha and beta } +Tensor computeMkldnnPrepackedLinearRun( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const c10::optional& outputType, + at::Device device) { + Dtype dtype = kFloat; + if (outputType) { + dtype = Dtype(*outputType); + } + + BufHandle ResultBuf( + "mkldnn_prepacked_linear_run", outputShape, outputStrides, dtype); + const BufHandle& inp = c10::get(inputs[0]); + const BufHandle& prepacked = c10::get(inputs[1]); + StmtPtr s = ExternalCall::make( + ResultBuf, "nnc_mkldnn_prepacked_linear_run", {inp, prepacked}, {}); + return Tensor(ResultBuf.node(), s); +} + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/operators/matmul.h b/torch/csrc/jit/tensorexpr/operators/matmul.h index 70f3f4bf7bf03f..de08007994d5c6 100644 --- a/torch/csrc/jit/tensorexpr/operators/matmul.h +++ b/torch/csrc/jit/tensorexpr/operators/matmul.h @@ -18,7 +18,12 @@ Tensor computeAddMM( const std::vector& outputStrides, const c10::optional& outputType, at::Device device); - +Tensor computeMkldnnPrepackedLinearRun( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const c10::optional& outputType, + at::Device device); } // namespace tensorexpr } // namespace jit } // namespace torch From ef562f29adfc8e32548059470f3518a7bfb4cb02 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Wed, 22 Jun 2022 13:59:12 +0800 Subject: [PATCH 04/15] linear: add UT --- test/test_mkldnn_fusion.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/test/test_mkldnn_fusion.py b/test/test_mkldnn_fusion.py index b51cd7f47e0d2f..832a2002ca445f 100644 --- a/test/test_mkldnn_fusion.py +++ b/test/test_mkldnn_fusion.py @@ -192,5 +192,29 @@ def forward(self, x): else: self.assertGraphContains(graph, kind='aten::conv2d') + def test_single_linear(self): + class M(nn.Module): + def __init__(self, in_channels, out_channels, bias, **kwargs): + super(M, self).__init__() + self.linear = torch.nn.Linear(in_channels, out_channels, bias=bias, **kwargs) + + def forward(self, x): + res = self.linear(x) + return res + iC = 2 + oC = 3 + for bias in [True, False]: + # TODO: refactor x_sghape generation + for x_shape in [ + 
[1, iC], + [2, iC], + [3, 2, iC] + ]: + m = M(iC, oC, bias) + x = torch.randn(x_shape) + graph = self._check_model(m, x) + self.assertFused(graph, ['aten::linear']) + self.assertGraphContainsExactly(graph, FUSION_GROUP, 1) + if __name__ == "__main__": run_tests() From dfc0b36530bda4413c2b08d681410a52d6020569 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Wed, 22 Jun 2022 14:40:59 +0800 Subject: [PATCH 05/15] linear: enable fusion --- torch/csrc/jit/passes/mkldnn_rewrite.cpp | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/passes/mkldnn_rewrite.cpp b/torch/csrc/jit/passes/mkldnn_rewrite.cpp index 95dade426df1d5..66644e48d3b83c 100644 --- a/torch/csrc/jit/passes/mkldnn_rewrite.cpp +++ b/torch/csrc/jit/passes/mkldnn_rewrite.cpp @@ -415,7 +415,16 @@ void FuseEltwiseWithPackedOps(std::shared_ptr& graph) { "%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int,"), std::string("%weight, %bias, %stride, %padding, %dilation, %groups,")); - // TODO: add linear, matmul + RewriteEltwiseGraph( + graph, + mkldnn::fusion_attr_map(), + std::string("mkldnn_prepacked::linear_prepack"), + std::string("mkldnn_prepacked::linear_run"), + std::string("mkldnn.LinearOpContext"), + std::string("%input, %weight, %bias,"), + std::string("%weight, %bias,")); + + // TODO: add matmul } void PrePackingOpsFolder(Block* b) { From 1838aa18607affd9a7b813b8828bc24b1ea6ff00 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Wed, 22 Jun 2022 14:54:24 +0800 Subject: [PATCH 06/15] UT: 1. add UT for linear_eltwise and linear_clamp 2. refactor code to reuse common parts --- test/test_mkldnn_fusion.py | 178 ++++++++++++++++++++++++------------- 1 file changed, 115 insertions(+), 63 deletions(-) diff --git a/test/test_mkldnn_fusion.py b/test/test_mkldnn_fusion.py index 832a2002ca445f..60e56efd4f9a65 100644 --- a/test/test_mkldnn_fusion.py +++ b/test/test_mkldnn_fusion.py @@ -47,6 +47,72 @@ def _check_model(self, m, x): torch._C._jit_set_te_must_use_llvm_cpu(old_te_must_use_llvm_cpu) return graph + def _eltwise_list(self): + eltwise_list = [ + [torch.relu, 'aten::relu'], + [torch.sigmoid, 'aten::sigmoid'], + [torch.tanh, 'aten::tanh'], + [torch.nn.Hardswish(inplace=False), 'aten::hardswish'], + [nn.LeakyReLU(0.1, inplace=False), 'aten::leaky_relu'], + [nn.Hardtanh(inplace=False), 'aten::hardtanh'], + [nn.GELU(approximate="none"), 'aten::gelu'], + [nn.GELU(approximate="tanh"), 'aten::gelu'], + ] + return eltwise_list + + def _clamp_modules(self): + class MNoOpt(nn.Module): + def __init__(self, m, in_channels, out_channels, bias, **kwargs): + super(MNoOpt, self).__init__() + self.conv = m(in_channels, out_channels, bias=bias, **kwargs) + + def forward(self, x): + x = self.conv(x) + x = torch.clamp(x, min=-0.5, max=0.9) + return x + + class MInf(nn.Module): + def __init__(self, m, in_channels, out_channels, bias, **kwargs): + super(MInf, self).__init__() + self.conv = m(in_channels, out_channels, bias=bias, **kwargs) + + def forward(self, x): + x = self.conv(x) + x = torch.clamp(x, min=0, max=float('inf')) + return x + + class MNegInf(nn.Module): + def __init__(self, m, in_channels, out_channels, bias, **kwargs): + super(MNegInf, self).__init__() + self.conv = m(in_channels, out_channels, bias=bias, **kwargs) + + def forward(self, x): + x = self.conv(x) + x = torch.clamp(x, min=float('-inf'), max=0) + return x + + class MOptMin(nn.Module): + def __init__(self, m, in_channels, out_channels, bias, **kwargs): + super(MOptMin, self).__init__() + self.conv = 
m(in_channels, out_channels, bias=bias, **kwargs) + + def forward(self, x): + x = self.conv(x) + x = torch.clamp(x, max=2) + return x + + class MOptMax(nn.Module): + def __init__(self, m, in_channels, out_channels, bias, **kwargs): + super(MOptMax, self).__init__() + self.conv = m(in_channels, out_channels, bias=bias, **kwargs) + + def forward(self, x): + x = self.conv(x) + x = torch.clamp(x, min=0) + return x + + return [MNoOpt, MInf, MNegInf, MOptMin, MOptMax] + def test_single_conv(self): class M(nn.Module): def __init__(self, in_channels, out_channels, bias, **kwargs): @@ -100,16 +166,7 @@ def forward(self, x): [torch.contiguous_format, False], [torch.channels_last, True], ]: - for eltwise_fn, op_name in [ - [torch.relu, 'aten::relu'], - [torch.sigmoid, 'aten::sigmoid'], - [torch.tanh, 'aten::tanh'], - [torch.nn.Hardswish(inplace=False), 'aten::hardswish'], - [nn.LeakyReLU(0.1, inplace=False), 'aten::leaky_relu'], - [nn.Hardtanh(inplace=False), 'aten::hardtanh'], - [nn.GELU(approximate="none"), 'aten::gelu'], - [nn.GELU(approximate="tanh"), 'aten::gelu'], - ]: + for eltwise_fn, op_name in self._eltwise_list(): for bias in [True, False]: for oC in [1, 10]: m = M(eltwise_fn, 3, oC, bias, kernel_size=(3, 3)).to(memory_format=memory_format) @@ -123,57 +180,7 @@ def forward(self, x): self.assertGraphContains(graph, kind='aten::conv2d') def test_conv_clamp(self): - class MNoOpt(nn.Module): - def __init__(self, in_channels, out_channels, bias, **kwargs): - super(MNoOpt, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) - - def forward(self, x): - x = self.conv(x) - x = torch.clamp(x, min=-0.5, max=0.9) - return x - - class MInf(nn.Module): - def __init__(self, in_channels, out_channels, bias, **kwargs): - super(MInf, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) - - def forward(self, x): - x = self.conv(x) - x = torch.clamp(x, min=0, max=float('inf')) - return x - - class MNegInf(nn.Module): - def __init__(self, in_channels, out_channels, bias, **kwargs): - super(MNegInf, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) - - def forward(self, x): - x = self.conv(x) - x = torch.clamp(x, min=float('-inf'), max=0) - return x - - class MOptMin(nn.Module): - def __init__(self, in_channels, out_channels, bias, **kwargs): - super(MOptMin, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) - - def forward(self, x): - x = self.conv(x) - x = torch.clamp(x, max=2) - return x - - class MOptMax(nn.Module): - def __init__(self, in_channels, out_channels, bias, **kwargs): - super(MOptMax, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) - - def forward(self, x): - x = self.conv(x) - x = torch.clamp(x, min=0) - return x - - modules = [MNoOpt, MInf, MNegInf, MOptMin, MOptMax] + modules = self._clamp_modules() op_name = 'aten::clamp' for memory_format, enabled in [ @@ -182,7 +189,7 @@ def forward(self, x): ]: for M in modules: for bias in [True, False]: - m = M(3, 10, bias, kernel_size=(3, 3)).to(memory_format=memory_format) + m = M(nn.Conv2d, 3, 10, bias, kernel_size=(3, 3)).to(memory_format=memory_format) x = torch.randn(1, 3, 224, 224).to(memory_format=memory_format) graph = self._check_model(m, x) @@ -200,7 +207,7 @@ def __init__(self, in_channels, out_channels, bias, **kwargs): def forward(self, x): res = self.linear(x) - return res + return res iC = 2 oC = 3 for bias in [True, 
False]: @@ -216,5 +223,50 @@ def forward(self, x): self.assertFused(graph, ['aten::linear']) self.assertGraphContainsExactly(graph, FUSION_GROUP, 1) + def test_linear_eltwise(self): + class M(nn.Module): + def __init__(self, eltwise_fn, in_channels, out_channels, bias, **kwargs): + super(M, self).__init__() + self.linear = torch.nn.Linear(in_channels, out_channels, bias=bias, **kwargs) + self.eltwise = eltwise_fn + + def forward(self, x): + x = self.linear(x) + x = self.eltwise(x) + return x + iC = 2 + oC = 3 + for eltwise_fn, op_name in self._eltwise_list(): + for bias in [True, False]: + for x_shape in [ + [1, iC], + [2, iC], + [3, 2, iC] + ]: + m = M(eltwise_fn, iC, oC, bias) + x = torch.randn(x_shape) + + graph = self._check_model(m, x) + self.assertFused(graph, ['aten::linear', 'aten::' + op_name]) + self.assertGraphContainsExactly(graph, FUSION_GROUP, 1) + + def test_linear_clamp(self): + modules = self._clamp_modules() + op_name = 'aten::clamp' + iC = 2 + oC = 3 + for M in modules: + for bias in [True, False]: + for x_shape in [ + [1, iC], + [2, iC], + [3, 2, iC] + ]: + m = M(nn.Linear, iC, oC, bias) + x = torch.randn(x_shape) + graph = self._check_model(m, x) + self.assertFused(graph, ['aten::linear', 'aten::' + op_name]) + self.assertGraphContainsExactly(graph, FUSION_GROUP, 1) + if __name__ == "__main__": run_tests() From ccaa901a2a17941290fef02db2a2a9b6ea31e042 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Wed, 22 Jun 2022 17:57:55 +0800 Subject: [PATCH 07/15] add cpp test --- test/cpp/tensorexpr/test_kernel.cpp | 32 +++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 98e7fca452b0f9..14a42c2446b315 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -254,6 +254,38 @@ TEST_F(Kernel, Huge) { torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } +TEST_F(Kernel, MkldnnConvEltwise) { + const auto graph_string = R"IR( + graph(%x.1 : Float(1, 1, 224, 224, strides=[50176, 1, 224, 1], requires_grad=0, device=cpu)): + %1 : str = prim::Constant[value="tanh"]() + %2 : int = prim::Constant[value=1]() + %3 : int[] = prim::Constant[value=[0, 0]]() + %4 : int[] = prim::Constant[value=[1, 1]]() + %self.conv.bias : NoneType = prim::Constant() + %self.conv.weight : Float(3, 1, 1, 2, strides=[2, 1, 2, 1], requires_grad=0, device=cpu) = prim::Constant[value=]() + %x.14 : Float(1, 3, 224, 223, strides=[149856, 1, 669, 3], requires_grad=0, device=cpu) = aten::conv2d(%x.1, %self.conv.weight, %self.conv.bias, %4, %3, %4, %2) + %x.10 : Float(1, 3, 224, 223, strides=[149856, 1, 669, 3], requires_grad=0, device=cpu) = aten::gelu(%x.14, %1) + return (%x.10))IR"; + auto graph = std::make_shared(); + parseIR(graph_string, &*graph, /*parse_tensor_constants*/ true); + TensorExprKernel k(graph); + std::ostringstream oss; + oss << *k.getCodeGenStmt(); + + torch::jit::testing::FileCheck() + .check("mkldnn_prepacked_conv_run") + ->run(oss.str()); + + torch::jit::testing::FileCheck().check_not("aten_tanh_gelu")->run(oss.str()); + + testing::FileCheck() + .check("mkldnn_prepacked::conv2d_run") + ->check_not("mkldnn:prepacked::conv2d_prepack") // this should be folded + // as constant. 
+ ->check_not("aten::gelu") + ->run(*k.graph()); +} + TEST_F(Kernel, ParallelStrided) { const auto graph_string = R"IR( graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu), From d5da32646d5e99707ca3c83c82f2231e687c8bb4 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Thu, 23 Jun 2022 10:03:13 +0800 Subject: [PATCH 08/15] linear: add check on dtype (commented out for now) --- torch/csrc/jit/passes/tensorexpr_fuser.cpp | 2 +- torch/csrc/jit/tensorexpr/kernel.cpp | 21 ++++++++++++------- torch/csrc/jit/tensorexpr/kernel.h | 2 +- .../csrc/jit/tensorexpr/operators/matmul.cpp | 13 ++++++++++++ torch/csrc/jit/tensorexpr/operators/matmul.h | 4 ++++ 5 files changed, 32 insertions(+), 10 deletions(-) diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index 3cc710d2b392c0..89280788f59e2e 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -1060,7 +1060,7 @@ class TensorExprFuser { } } if (node->kind() == aten::linear) { - if (!tensorexpr::mkldnnLinearIsSupported(node)) { + if (!tensorexpr::mkldnnPrepackedLinearIsSupportedJit(node)) { GRAPH_DEBUG("Shapes of linear inputs are not supported"); return false; } diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 0cc4d47c4fdc3c..3eef032d8ee0f7 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -314,25 +314,30 @@ bool mkldnnPrepackedConvIsSupportedJit(const torch::jit::Node* node) { return false; } -bool mkldnnLinearIsSupported(const torch::jit::Node* node) { - auto const& input0 = getTensorInfoJit(node->input(0)); - auto const& input1 = getTensorInfoJit(node->input(1)); +bool mkldnnPrepackedLinearIsSupportedJit(const torch::jit::Node* node) { +#if AT_MKLDNN_ENABLED() + + auto const& input = getTensorInfoJit(node->input(0)); + auto const& weight = getTensorInfoJit(node->input(1)); // Everything should be statically known. - if (!input0 || !input1) { - GRAPH_DEBUG("mkldnnLinearIsSupported: Input shapes aren't static"); + if (!input || !weight) { + GRAPH_DEBUG( + "mkldnnPrepackedLinearIsSupportedJit: Input shapes aren't static"); return false; } // Inputs should be contiguous, or the TE will needlessly transpose them. if (!isContiguous(node->input(0)) || !isContiguous(node->input(1))) { - GRAPH_DEBUG("mkldnnLinearIsSupported: Input shapes are not contiguous"); + GRAPH_DEBUG( + "mkldnnPrepackedLinearIsSupportedJit: Input shapes are not contiguous"); return false; } - // TODO: only support BF16 to make sure there's no performance regression + return mkldnnPrepackedLinearIsSupported(*input, *weight); - return true; +#endif + return false; } // The fuser currently only supports matmul of 2D x 2D matrices diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index decf069edd4c46..88fa9f606eaf83 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -28,7 +28,7 @@ bool conv2dIsSupportedJit(const Node* node); bool mkldnnPrepackedConvIsSupportedJit(const Node* node); // Returns true if the TE fuser supports this linear with mkldnn prepacked // linear. -bool mkldnnLinearIsSupported(const Node* node); +bool mkldnnPrepackedLinearIsSupportedJit(const Node* node); // Returns true if the TE fuser supports this matmul. 
bool matmulIsSupported(const Node* node); template diff --git a/torch/csrc/jit/tensorexpr/operators/matmul.cpp b/torch/csrc/jit/tensorexpr/operators/matmul.cpp index 888864930080ff..43c121343d5846 100644 --- a/torch/csrc/jit/tensorexpr/operators/matmul.cpp +++ b/torch/csrc/jit/tensorexpr/operators/matmul.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -5,6 +6,18 @@ namespace torch { namespace jit { namespace tensorexpr { +bool mkldnnPrepackedLinearIsSupported( + const TensorInfo& input, + const TensorInfo& weight) { + // TODO: only support BF16 to make sure there's no performance regression + // if (input.dtype != c10::ScalarType::BFloat16 || + // weight.dtype != c10::ScalarType::BFloat16) { + // GRAPH_DEBUG("conv2dIsSupported: only bfloat16 allowed"); + // return false; + // } + return true; +} + Tensor computeMatmul( const std::vector& inputs, const std::vector& outputShape, diff --git a/torch/csrc/jit/tensorexpr/operators/matmul.h b/torch/csrc/jit/tensorexpr/operators/matmul.h index de08007994d5c6..0c1a6f95a5520e 100644 --- a/torch/csrc/jit/tensorexpr/operators/matmul.h +++ b/torch/csrc/jit/tensorexpr/operators/matmul.h @@ -1,11 +1,15 @@ #pragma once #include +#include namespace torch { namespace jit { namespace tensorexpr { +bool mkldnnPrepackedLinearIsSupported( + const TensorInfo& input, + const TensorInfo& weight); Tensor computeMatmul( const std::vector& inputs, const std::vector& outputShape, From c83e18497b509d5801566dbf36aedb6abcb59a95 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Thu, 23 Jun 2022 10:10:00 +0800 Subject: [PATCH 09/15] linear: check weight and bias should be Constant --- torch/csrc/jit/tensorexpr/kernel.cpp | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp index 3eef032d8ee0f7..d1d24e98f18e13 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -320,17 +320,28 @@ bool mkldnnPrepackedLinearIsSupportedJit(const torch::jit::Node* node) { auto const& input = getTensorInfoJit(node->input(0)); auto const& weight = getTensorInfoJit(node->input(1)); + // TODO: skip bias check when bias is None + auto const& bias = getTensorInfoJit(node->input(2)); + // Everything should be statically known. if (!input || !weight) { GRAPH_DEBUG( - "mkldnnPrepackedLinearIsSupportedJit: Input shapes aren't static"); + "mkldnnPrepackedLinearIsSupportedJit: some params aren't static"); + return false; + } + + // Weights and bias should be Constant when using mkldnn backend + if (node->input(1)->node()->kind() != prim::Constant || + node->input(2)->node()->kind() != prim::Constant) { + GRAPH_DEBUG( + "mkldnnPrepackedLinearIsSupportedJit: weight or bias is not Constant"); return false; } // Inputs should be contiguous, or the TE will needlessly transpose them. 
if (!isContiguous(node->input(0)) || !isContiguous(node->input(1))) { GRAPH_DEBUG( - "mkldnnPrepackedLinearIsSupportedJit: Input shapes are not contiguous"); + "mkldnnPrepackedLinearIsSupportedJit: some inputs are not contiguous"); return false; } From 8e961421c8facbc726fe7ea6c6ab9287ae187004 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Thu, 23 Jun 2022 10:12:59 +0800 Subject: [PATCH 10/15] linear: check contiguous during graph rewrite --- torch/csrc/jit/passes/mkldnn_rewrite.cpp | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/torch/csrc/jit/passes/mkldnn_rewrite.cpp b/torch/csrc/jit/passes/mkldnn_rewrite.cpp index 66644e48d3b83c..225f3b7ca2688a 100644 --- a/torch/csrc/jit/passes/mkldnn_rewrite.cpp +++ b/torch/csrc/jit/passes/mkldnn_rewrite.cpp @@ -212,7 +212,15 @@ void insertPrePackedConvOpForNode(Node* n) { void insertPrePackedLinearOpForNode(Node* n) { constexpr int POS_INPUT = 0; constexpr int POS_WEIGHT = 1; - // TODO: check input and weight should be contiguous + if (!tensorexpr::isContiguous(n->input(POS_INPUT))) { + GRAPH_DEBUG("insertPrePackedLinearOpForNode: input is not contiguous"); + return; + } + + if (!tensorexpr::isContiguous(n->input(POS_WEIGHT))) { + GRAPH_DEBUG("insertPrePackedLinearOpForNode: weight is not contiguous"); + return; + } WithInsertPoint guard(n); auto graph = n->owningGraph(); From 03218d21b8fb6b49b54dd5549da857d4a4941c17 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Thu, 23 Jun 2022 11:04:12 +0800 Subject: [PATCH 11/15] linear: remove mem copy for output --- aten/src/ATen/native/mkldnn/LinearPrepack.cpp | 59 +++++++++++++++++++ aten/src/ATen/native/mkldnn/LinearPrepack.h | 2 + aten/src/ATen/native/mkldnn/OpContext.cpp | 4 ++ aten/src/ATen/native/mkldnn/OpContext.h | 6 +- .../jit/tensorexpr/external_functions.cpp | 7 ++- 5 files changed, 73 insertions(+), 5 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/LinearPrepack.cpp b/aten/src/ATen/native/mkldnn/LinearPrepack.cpp index d26eab5ae83593..bd8b0f3f3a842d 100644 --- a/aten/src/ATen/native/mkldnn/LinearPrepack.cpp +++ b/aten/src/ATen/native/mkldnn/LinearPrepack.cpp @@ -125,6 +125,65 @@ Tensor run(ContextLinear& context, const Tensor& input) { return output; } +void run(ContextLinear& context, const Tensor& input, void* output) { + const ideep::tensor& mkldnn_weight = context.weight_packed_; + + auto input_size = input.sizes(); + + const int64_t dim = input.dim(); + auto input_reshaped = + dim == 2 ? 
input : input.reshape({-1, input.size(input.dim() - 1)}); + + std::vector output_size(input_size.begin(), input_size.end() - 1); + output_size.push_back(mkldnn_weight.get_dim(0)); + // auto output = at::empty(output_size, input.options()); + + std::vector output_size_reshaped = { + input_reshaped.size(0), mkldnn_weight.get_dim(0)}; + // output = output.reshape(output_size_reshaped); + + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + const ideep::tensor mkldnn_input = itensor_view_from_dense(input_reshaped); + + ideep::tensor::desc o_desc = { + output_size_reshaped, mkldnn_input.get_data_type()}; + ideep::tensor mkldnn_output = {o_desc, output}; + + // ideep::tensor mkldnn_output = itensor_view_from_dense(output); + + c10::MaybeOwned bias_maybe_owned = + at::borrow_from_optional_tensor(context.at_bias_); + const Tensor& bias = *bias_maybe_owned; + + if (bias.defined()) { + const ideep::tensor mkldnn_bias = itensor_view_from_dense(bias); + ideep::inner_product_forward::compute( + mkldnn_input, + mkldnn_weight, + mkldnn_bias, + mkldnn_output, + ideep::scale_t(), + ideep::scale_t(), + ideep::scale_t(), + context.attr_); + } else { + ideep::inner_product_forward::compute( + mkldnn_input, + mkldnn_weight, + mkldnn_output, + ideep::scale_t(), + ideep::scale_t(), + ideep::scale_t(), + context.attr_); + } + + // if (dim != 2) { + // output = output.reshape(output_size); + // } + + // return output; +} + Tensor linear_run( const Tensor& input, const c10::intrusive_ptr& op_context) { diff --git a/aten/src/ATen/native/mkldnn/LinearPrepack.h b/aten/src/ATen/native/mkldnn/LinearPrepack.h index aa43b6791862da..e1696f50eee381 100644 --- a/aten/src/ATen/native/mkldnn/LinearPrepack.h +++ b/aten/src/ATen/native/mkldnn/LinearPrepack.h @@ -32,6 +32,8 @@ ContextLinear create( Tensor run(ContextLinear& context, const Tensor& input); +void run(ContextLinear& context, const Tensor& input, void* output); + } // namespace linear } // namespace internal } // namespace mkldnn diff --git a/aten/src/ATen/native/mkldnn/OpContext.cpp b/aten/src/ATen/native/mkldnn/OpContext.cpp index f408da5daa15e8..126fdc9eba84d0 100644 --- a/aten/src/ATen/native/mkldnn/OpContext.cpp +++ b/aten/src/ATen/native/mkldnn/OpContext.cpp @@ -62,6 +62,10 @@ Tensor MkldnnLinearOpContext::run(const Tensor& input) { return mkldnn::internal::linear::run(op_context_, input); } +void MkldnnLinearOpContext::run(const Tensor& input, void* output) { + return mkldnn::internal::linear::run(op_context_, input, output); +} + } // namespace mkldnn } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mkldnn/OpContext.h b/aten/src/ATen/native/mkldnn/OpContext.h index 591e9b29973cf0..d0f193e549067b 100644 --- a/aten/src/ATen/native/mkldnn/OpContext.h +++ b/aten/src/ATen/native/mkldnn/OpContext.h @@ -118,7 +118,7 @@ class LinearOpContext : public torch::jit::CustomClassHolder { virtual at::Tensor run(const at::Tensor& input) = 0; - // TODO: run with void* output + virtual void run(const Tensor& input, void* output) = 0; }; class MkldnnLinearOpContext final : public LinearOpContext { @@ -137,7 +137,9 @@ class MkldnnLinearOpContext final : public LinearOpContext { input_size_ = std::move(input_size); } - virtual at::Tensor run(const at::Tensor& input) override; + at::Tensor run(const at::Tensor& input) override; + + void run(const Tensor& input, void* output) override; static c10::intrusive_ptr create_context( at::Tensor&& weight, diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp 
b/torch/csrc/jit/tensorexpr/external_functions.cpp index 5c33d0e954ccd8..842573e4cfa928 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions.cpp @@ -1452,10 +1452,11 @@ void nnc_mkldnn_prepacked_linear_run( auto context = reinterpret_cast(buf_data[2]); at::Tensor output = context->run(x); - memcpy( - buf_data[0], output.data_ptr(), output.element_size() * output.numel()); + // memcpy( + // buf_data[0], output.data_ptr(), output.element_size() * + // output.numel()); // TODO: remove mem copy - // context->run(x, buf_data[0]); + context->run(x, buf_data[0]); } #endif // AT_MKLDNN_ENABLED() From 10916614e0725a2742628d423aaafbc739e5ac63 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Thu, 23 Jun 2022 11:04:49 +0800 Subject: [PATCH 12/15] linear: refactor run with void* output --- aten/src/ATen/native/mkldnn/LinearPrepack.cpp | 90 +++++++++++-------- 1 file changed, 51 insertions(+), 39 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/LinearPrepack.cpp b/aten/src/ATen/native/mkldnn/LinearPrepack.cpp index bd8b0f3f3a842d..55dcc364b714bd 100644 --- a/aten/src/ATen/native/mkldnn/LinearPrepack.cpp +++ b/aten/src/ATen/native/mkldnn/LinearPrepack.cpp @@ -125,6 +125,50 @@ Tensor run(ContextLinear& context, const Tensor& input) { return output; } +void _mkldnn_linear_out( + const ideep::tensor& x, + ideep::tensor& y, + const ideep::tensor& w, + const c10::optional& b, + const ideep::attr_t& attr = ideep::attr_t()) { + if (b.has_value()) { + ideep::inner_product_forward::compute( + x, + w, + b.value(), + y, + ideep::scale_t(), + ideep::scale_t(), + ideep::scale_t(), + attr); + } else { + ideep::inner_product_forward::compute( + x, w, y, ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), attr); + } +} + +void mkldnn_linear_out( + const Tensor& input, + ideep::tensor& mkldnn_output, + const ideep::tensor& mkldnn_weight, + const c10::optional& bias_opt, + const ideep::attr_t& attr = ideep::attr_t()) { + c10::MaybeOwned bias_maybe_owned = + at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + const ideep::tensor mkldnn_input = itensor_view_from_dense(input); + + c10::optional mkldnn_bias{c10::nullopt}; + if (bias.defined()) { + mkldnn_bias = itensor_from_tensor(bias); + } + + _mkldnn_linear_out( + mkldnn_input, mkldnn_output, mkldnn_weight, mkldnn_bias, attr); +} + void run(ContextLinear& context, const Tensor& input, void* output) { const ideep::tensor& mkldnn_weight = context.weight_packed_; @@ -136,52 +180,20 @@ void run(ContextLinear& context, const Tensor& input, void* output) { std::vector output_size(input_size.begin(), input_size.end() - 1); output_size.push_back(mkldnn_weight.get_dim(0)); - // auto output = at::empty(output_size, input.options()); std::vector output_size_reshaped = { input_reshaped.size(0), mkldnn_weight.get_dim(0)}; - // output = output.reshape(output_size_reshaped); - - c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); - const ideep::tensor mkldnn_input = itensor_view_from_dense(input_reshaped); ideep::tensor::desc o_desc = { - output_size_reshaped, mkldnn_input.get_data_type()}; + output_size_reshaped, get_mkldnn_dtype(input.scalar_type())}; ideep::tensor mkldnn_output = {o_desc, output}; - // ideep::tensor mkldnn_output = itensor_view_from_dense(output); - - c10::MaybeOwned bias_maybe_owned = - at::borrow_from_optional_tensor(context.at_bias_); - const Tensor& bias = 
*bias_maybe_owned; - - if (bias.defined()) { - const ideep::tensor mkldnn_bias = itensor_view_from_dense(bias); - ideep::inner_product_forward::compute( - mkldnn_input, - mkldnn_weight, - mkldnn_bias, - mkldnn_output, - ideep::scale_t(), - ideep::scale_t(), - ideep::scale_t(), - context.attr_); - } else { - ideep::inner_product_forward::compute( - mkldnn_input, - mkldnn_weight, - mkldnn_output, - ideep::scale_t(), - ideep::scale_t(), - ideep::scale_t(), - context.attr_); - } - - // if (dim != 2) { - // output = output.reshape(output_size); - // } - - // return output; + mkldnn_linear_out( + input_reshaped, + mkldnn_output, + mkldnn_weight, + context.at_bias_, + context.attr_); } Tensor linear_run( From 434191a9b853cda813a52abb77db06fcea4ad212 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Thu, 23 Jun 2022 11:10:06 +0800 Subject: [PATCH 13/15] linear: refactor run without output buffer --- aten/src/ATen/native/mkldnn/LinearPrepack.cpp | 92 ++++++++----------- 1 file changed, 36 insertions(+), 56 deletions(-) diff --git a/aten/src/ATen/native/mkldnn/LinearPrepack.cpp b/aten/src/ATen/native/mkldnn/LinearPrepack.cpp index 55dcc364b714bd..e3b3c976ce2fc0 100644 --- a/aten/src/ATen/native/mkldnn/LinearPrepack.cpp +++ b/aten/src/ATen/native/mkldnn/LinearPrepack.cpp @@ -69,62 +69,6 @@ ContextLinear create( std::move(attr)}; } -Tensor run(ContextLinear& context, const Tensor& input) { - const ideep::tensor& mkldnn_weight = context.weight_packed_; - - auto input_size = input.sizes(); - - const int64_t dim = input.dim(); - auto input_reshaped = - dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)}); - - std::vector output_size(input_size.begin(), input_size.end() - 1); - output_size.push_back(mkldnn_weight.get_dim(0)); - auto output = at::empty(output_size, input.options()); - - if (dim != 2) { - std::vector output_size_reshaped = { - input_reshaped.size(0), mkldnn_weight.get_dim(0)}; - output = output.reshape(output_size_reshaped); - } - - c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); - const ideep::tensor mkldnn_input = itensor_view_from_dense(input_reshaped); - ideep::tensor mkldnn_output = itensor_view_from_dense(output); - - c10::MaybeOwned bias_maybe_owned = - at::borrow_from_optional_tensor(context.at_bias_); - const Tensor& bias = *bias_maybe_owned; - - if (bias.defined()) { - const ideep::tensor mkldnn_bias = itensor_view_from_dense(bias); - ideep::inner_product_forward::compute( - mkldnn_input, - mkldnn_weight, - mkldnn_bias, - mkldnn_output, - ideep::scale_t(), - ideep::scale_t(), - ideep::scale_t(), - context.attr_); - } else { - ideep::inner_product_forward::compute( - mkldnn_input, - mkldnn_weight, - mkldnn_output, - ideep::scale_t(), - ideep::scale_t(), - ideep::scale_t(), - context.attr_); - } - - if (dim != 2) { - output = output.reshape(output_size); - } - - return output; -} - void _mkldnn_linear_out( const ideep::tensor& x, ideep::tensor& y, @@ -169,6 +113,42 @@ void mkldnn_linear_out( mkldnn_input, mkldnn_output, mkldnn_weight, mkldnn_bias, attr); } +Tensor run(ContextLinear& context, const Tensor& input) { + const ideep::tensor& mkldnn_weight = context.weight_packed_; + + auto input_size = input.sizes(); + + const int64_t dim = input.dim(); + auto input_reshaped = + dim == 2 ? 
input : input.reshape({-1, input.size(input.dim() - 1)}); + + std::vector output_size(input_size.begin(), input_size.end() - 1); + output_size.push_back(mkldnn_weight.get_dim(0)); + auto output = at::empty(output_size, input.options()); + + if (dim != 2) { + std::vector output_size_reshaped = { + input_reshaped.size(0), mkldnn_weight.get_dim(0)}; + output = output.reshape(output_size_reshaped); + } + + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + ideep::tensor mkldnn_output = itensor_from_tensor(output); + + mkldnn_linear_out( + input_reshaped, + mkldnn_output, + mkldnn_weight, + context.at_bias_, + context.attr_); + + if (dim != 2) { + output = output.reshape(output_size); + } + + return output; +} + void run(ContextLinear& context, const Tensor& input, void* output) { const ideep::tensor& mkldnn_weight = context.weight_packed_; From a95b1402fcc93eacf8de4740f52e2cf2aa8f07a3 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Thu, 23 Jun 2022 11:10:30 +0800 Subject: [PATCH 14/15] linear: remove mem copy in external call --- torch/csrc/jit/tensorexpr/external_functions.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp index 842573e4cfa928..22082f3f8e2b5a 100644 --- a/torch/csrc/jit/tensorexpr/external_functions.cpp +++ b/torch/csrc/jit/tensorexpr/external_functions.cpp @@ -1452,10 +1452,6 @@ void nnc_mkldnn_prepacked_linear_run( auto context = reinterpret_cast(buf_data[2]); at::Tensor output = context->run(x); - // memcpy( - // buf_data[0], output.data_ptr(), output.element_size() * - // output.numel()); - // TODO: remove mem copy context->run(x, buf_data[0]); } From 77dca307300ac78c00a37c594bd3151f1438b398 Mon Sep 17 00:00:00 2001 From: chunyuan-w Date: Thu, 23 Jun 2022 12:16:42 +0800 Subject: [PATCH 15/15] fix typo in the comment --- torch/csrc/jit/tensorexpr/operators/matmul.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torch/csrc/jit/tensorexpr/operators/matmul.cpp b/torch/csrc/jit/tensorexpr/operators/matmul.cpp index 43c121343d5846..329dacca3decc0 100644 --- a/torch/csrc/jit/tensorexpr/operators/matmul.cpp +++ b/torch/csrc/jit/tensorexpr/operators/matmul.cpp @@ -12,7 +12,7 @@ bool mkldnnPrepackedLinearIsSupported( // TODO: only support BF16 to make sure there's no performance regression // if (input.dtype != c10::ScalarType::BFloat16 || // weight.dtype != c10::ScalarType::BFloat16) { - // GRAPH_DEBUG("conv2dIsSupported: only bfloat16 allowed"); + // GRAPH_DEBUG("mkldnnPrepackedLinearIsSupported: only bfloat16 allowed"); // return false; // } return true;
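
Usage sketch (illustrative only, not part of the patches): the snippet below mirrors what the new test_linear_eltwise case exercises -- scripting and freezing a small Linear+ReLU module with the TensorExpr fuser enabled on CPU, so that aten::linear is rewritten into mkldnn_prepacked::linear_prepack / linear_run and fused with the following eltwise op. The fuser/executor toggles and the warm-up loop are assumptions taken from the existing test helpers and may differ in other setups.

import torch
import torch.nn as nn

class M(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 3)

    def forward(self, x):
        return torch.relu(self.linear(x))

# Enable the TE fuser on CPU (the same switches the tests flip).
torch._C._jit_override_can_fuse_on_cpu(True)
torch._C._jit_set_texpr_fuser_enabled(True)

m = torch.jit.freeze(torch.jit.script(M().eval()))
x = torch.randn(4, 2)
# Warm up the profiling executor so the fusion passes run.
for _ in range(3):
    m(x)
# The optimized graph is expected to contain a single TensorExpr fusion group
# calling mkldnn_prepacked::linear_run instead of aten::linear + aten::relu.
print(m.graph_for(x))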