diff --git a/aten/src/ATen/native/mkldnn/Common.h b/aten/src/ATen/native/mkldnn/Common.h index 4e048ebce7597..da6a2c3f604cb 100644 --- a/aten/src/ATen/native/mkldnn/Common.h +++ b/aten/src/ATen/native/mkldnn/Common.h @@ -39,6 +39,22 @@ struct ContextConv final { attr_(attr) {} }; +struct ContextLinear final { + ideep::tensor weight_packed_; + c10::optional at_bias_; + ideep::attr_t attr_; + + ContextLinear() = delete; + + ContextLinear( + ideep::tensor&& weight_packed, + c10::optional at_bias, + ideep::attr_t attr) + : weight_packed_(std::move(weight_packed)), + at_bias_(std::move(at_bias)), + attr_(attr) {} +}; + } // namespace mkldnn } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mkldnn/LinearPrepack.cpp b/aten/src/ATen/native/mkldnn/LinearPrepack.cpp new file mode 100644 index 0000000000000..e3b3c976ce2fc --- /dev/null +++ b/aten/src/ATen/native/mkldnn/LinearPrepack.cpp @@ -0,0 +1,191 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +#if AT_MKLDNN_ENABLED() + +namespace at { +namespace native { +namespace mkldnn { +namespace internal { +namespace linear { + +using namespace torch::jit::mkldnn; + +c10::intrusive_ptr createLinearPrePackOpContext( + Tensor weight, + c10::optional bias, + std::vector input_size, + std::string attr, + std::vector> scalars, + c10::optional algorithm) { + auto it = fusion_attr_map().find(attr); + TORCH_CHECK(it != fusion_attr_map().end(), "Fusion behavior undefined."); + ideep::attr_t op_attr = it->second.attr_function(scalars, algorithm); + return mkldnn::MkldnnLinearOpContext::create_context( + std::move(weight), std::move(bias), std::move(input_size), op_attr); +} + +ContextLinear create( + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef input_size, + const ideep::attr_t& attr) { + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + ideep::tensor w = itensor_view_from_dense(weight); + auto dtype = w.get_data_type(); + + 
int64_t b_size = std::accumulate( + input_size.begin(), + input_size.end(), + (int64_t)1, + std::multiplies()) / + input_size[input_size.size() - 1]; + + auto out_features = weight.size(0); + auto in_features = weight.size(1); + ideep::dims reshaped_input_size = {b_size, in_features}; + + ideep::tensor::desc expected_weight_desc = + ideep::inner_product_forward::expected_weights_desc( + {out_features, in_features}, + reshaped_input_size, + /* w_dtype */ dtype, + /* x_dtype */ dtype); + + ideep::tensor packed_weight; + packed_weight.init(expected_weight_desc); + packed_weight.feed_from(w); + + return ContextLinear{ + std::move(packed_weight), + bias.has_value() ? c10::make_optional(*bias) : c10::nullopt, + std::move(attr)}; +} + +void _mkldnn_linear_out( + const ideep::tensor& x, + ideep::tensor& y, + const ideep::tensor& w, + const c10::optional& b, + const ideep::attr_t& attr = ideep::attr_t()) { + if (b.has_value()) { + ideep::inner_product_forward::compute( + x, + w, + b.value(), + y, + ideep::scale_t(), + ideep::scale_t(), + ideep::scale_t(), + attr); + } else { + ideep::inner_product_forward::compute( + x, w, y, ideep::scale_t(), ideep::scale_t(), ideep::scale_t(), attr); + } +} + +void mkldnn_linear_out( + const Tensor& input, + ideep::tensor& mkldnn_output, + const ideep::tensor& mkldnn_weight, + const c10::optional& bias_opt, + const ideep::attr_t& attr = ideep::attr_t()) { + c10::MaybeOwned bias_maybe_owned = + at::borrow_from_optional_tensor(bias_opt); + const Tensor& bias = *bias_maybe_owned; + + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + const ideep::tensor mkldnn_input = itensor_view_from_dense(input); + + c10::optional mkldnn_bias{c10::nullopt}; + if (bias.defined()) { + mkldnn_bias = itensor_from_tensor(bias); + } + + _mkldnn_linear_out( + mkldnn_input, mkldnn_output, mkldnn_weight, mkldnn_bias, attr); +} + +Tensor run(ContextLinear& context, const Tensor& input) { + const ideep::tensor& mkldnn_weight = 
context.weight_packed_; + + auto input_size = input.sizes(); + + const int64_t dim = input.dim(); + auto input_reshaped = + dim == 2 ? input : input.reshape({-1, input.size(input.dim() - 1)}); + + std::vector output_size(input_size.begin(), input_size.end() - 1); + output_size.push_back(mkldnn_weight.get_dim(0)); + auto output = at::empty(output_size, input.options()); + + if (dim != 2) { + std::vector output_size_reshaped = { + input_reshaped.size(0), mkldnn_weight.get_dim(0)}; + output = output.reshape(output_size_reshaped); + } + + c10::impl::ExcludeDispatchKeyGuard edkg(c10::autograd_dispatch_keyset); + ideep::tensor mkldnn_output = itensor_from_tensor(output); + + mkldnn_linear_out( + input_reshaped, + mkldnn_output, + mkldnn_weight, + context.at_bias_, + context.attr_); + + if (dim != 2) { + output = output.reshape(output_size); + } + + return output; +} + +void run(ContextLinear& context, const Tensor& input, void* output) { + const ideep::tensor& mkldnn_weight = context.weight_packed_; + + auto input_size = input.sizes(); + + const int64_t dim = input.dim(); + auto input_reshaped = + dim == 2 ? 
input : input.reshape({-1, input.size(input.dim() - 1)}); + + std::vector output_size(input_size.begin(), input_size.end() - 1); + output_size.push_back(mkldnn_weight.get_dim(0)); + + std::vector output_size_reshaped = { + input_reshaped.size(0), mkldnn_weight.get_dim(0)}; + + ideep::tensor::desc o_desc = { + output_size_reshaped, get_mkldnn_dtype(input.scalar_type())}; + ideep::tensor mkldnn_output = {o_desc, output}; + + mkldnn_linear_out( + input_reshaped, + mkldnn_output, + mkldnn_weight, + context.at_bias_, + context.attr_); +} + +Tensor linear_run( + const Tensor& input, + const c10::intrusive_ptr& op_context) { + return op_context->run(input); +} + +} // namespace linear +} // namespace internal +} // namespace mkldnn +} // namespace native +} // namespace at + +#endif // AT_MKLDNN_ENABLED() \ No newline at end of file diff --git a/aten/src/ATen/native/mkldnn/LinearPrepack.h b/aten/src/ATen/native/mkldnn/LinearPrepack.h new file mode 100644 index 0000000000000..e1696f50eee38 --- /dev/null +++ b/aten/src/ATen/native/mkldnn/LinearPrepack.h @@ -0,0 +1,43 @@ +#pragma once + +#include +#include +#include + +#if AT_MKLDNN_ENABLED() + +namespace at { +namespace native { +namespace mkldnn { +namespace internal { +namespace linear { + +c10::intrusive_ptr createLinearPrePackOpContext( + Tensor weight, + c10::optional bias, + std::vector input_size, + std::string attr, + std::vector> scalars, + c10::optional algorithm); + +Tensor linear_run( + const Tensor& input, + const c10::intrusive_ptr& op_context); + +ContextLinear create( + const Tensor& weight, + const c10::optional& bias, + const IntArrayRef input_size, + const ideep::attr_t& attr); + +Tensor run(ContextLinear& context, const Tensor& input); + +void run(ContextLinear& context, const Tensor& input, void* output); + +} // namespace linear +} // namespace internal +} // namespace mkldnn +} // namespace native +} // namespace at + +#endif // AT_MKLDNN_ENABLED() \ No newline at end of file diff --git 
a/aten/src/ATen/native/mkldnn/OpContext.cpp b/aten/src/ATen/native/mkldnn/OpContext.cpp index 2716b4908eb30..126fdc9eba84d 100644 --- a/aten/src/ATen/native/mkldnn/OpContext.cpp +++ b/aten/src/ATen/native/mkldnn/OpContext.cpp @@ -1,4 +1,5 @@ #include +#include #include #if AT_MKLDNN_ENABLED() @@ -40,6 +41,31 @@ void MkldnnConvOpContext::run(const Tensor& input, void* output) { return mkldnn::internal::convolution::run(op_context_, input, output); } +c10::intrusive_ptr MkldnnLinearOpContext::create_context( + at::Tensor&& weight, + c10::optional&& bias, + std::vector&& input_size, + const ideep::attr_t& attr) { + auto op_context = + mkldnn::internal::linear::create(weight, bias, input_size, attr); + + auto linear_op_context = c10::make_intrusive( + std::move(weight), + std::move(bias), + std::move(input_size), + std::move(op_context)); + + return linear_op_context; +} + +Tensor MkldnnLinearOpContext::run(const Tensor& input) { + return mkldnn::internal::linear::run(op_context_, input); +} + +void MkldnnLinearOpContext::run(const Tensor& input, void* output) { + return mkldnn::internal::linear::run(op_context_, input, output); +} + } // namespace mkldnn } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mkldnn/OpContext.h b/aten/src/ATen/native/mkldnn/OpContext.h index 6c09598288405..d0f193e549067 100644 --- a/aten/src/ATen/native/mkldnn/OpContext.h +++ b/aten/src/ATen/native/mkldnn/OpContext.h @@ -93,6 +93,61 @@ class MkldnnConvOpContext final : public ConvOpContext { const ideep::attr_t& attr); }; +using SerializationTypeLinearPrePack = std::tuple< + at::Tensor, + c10::optional, + std::vector, + std::string, + std::vector>, + c10::optional>; + +class LinearOpContext : public torch::jit::CustomClassHolder { + protected: + Tensor orig_weight_; + c10::optional orig_bias_; + std::vector input_size_; + std::string attr_; + std::vector> scalars_; + c10::optional algorithm_; + + public: + SerializationTypeLinearPrePack unpack() { + return 
std::make_tuple( + orig_weight_, orig_bias_, input_size_, attr_, scalars_, algorithm_); + } + + virtual at::Tensor run(const at::Tensor& input) = 0; + + virtual void run(const Tensor& input, void* output) = 0; +}; + +class MkldnnLinearOpContext final : public LinearOpContext { + private: + ContextLinear op_context_; + + public: + MkldnnLinearOpContext( + Tensor&& weight, + c10::optional&& bias, + std::vector&& input_size, + ContextLinear&& op_context) + : op_context_(std::move(op_context)) { + orig_weight_ = std::move(weight); + orig_bias_ = std::move(bias); + input_size_ = std::move(input_size); + } + + at::Tensor run(const at::Tensor& input) override; + + void run(const Tensor& input, void* output) override; + + static c10::intrusive_ptr create_context( + at::Tensor&& weight, + c10::optional&& bias, + std::vector&& input_size, + const ideep::attr_t& attr); +}; + } // namespace mkldnn } // namespace native } // namespace at diff --git a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp index 534e6388f4428..b1c36c6072d79 100644 --- a/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp +++ b/aten/src/ATen/native/mkldnn/RegisterMkldnnOpContextClass.cpp @@ -1,5 +1,6 @@ #include #include +#include #include #include #include @@ -11,6 +12,7 @@ namespace native { namespace mkldnn { using namespace internal::convolution; +using namespace internal::linear; TORCH_LIBRARY(mkldnn, m) { m.class_(TORCH_SELECTIVE_CLASS("ConvOpContext")) @@ -36,14 +38,39 @@ TORCH_LIBRARY(mkldnn, m) { std::move(std::get<8>(state)), std::move(std::get<9>(state))); }); + + m.class_(TORCH_SELECTIVE_CLASS("LinearOpContext")) + .def_pickle( + [](const c10::intrusive_ptr& op_context) + -> SerializationTypeLinearPrePack { // __getstate__ + return op_context->unpack(); + }, + [](SerializationTypeLinearPrePack state) + -> c10::intrusive_ptr { // __setstate__ + return createLinearPrePackOpContext( + 
std::move(std::get<0>(state)), + std::move(std::get<1>(state)), + std::move(std::get<2>(state)), + std::move(std::get<3>(state)), + std::move(std::get<4>(state)), + std::move(std::get<5>(state))); + }); } TORCH_LIBRARY(mkldnn_prepacked, m) { + // conv m.def(TORCH_SELECTIVE_SCHEMA( "mkldnn_prepacked::conv2d_prepack(Tensor W, Tensor? B, int[2] stride, int[2] padding, int[2] dilation, int groups, int[4] input_size, str attr, Scalar?[] scalars, str? algorithm) -> __torch__.torch.classes.mkldnn.ConvOpContext")); m.def(TORCH_SELECTIVE_SCHEMA( "mkldnn_prepacked::conv2d_run(Tensor X, __torch__.torch.classes.mkldnn.ConvOpContext W_prepack) -> Tensor Y")); + + // linear + m.def(TORCH_SELECTIVE_SCHEMA( + "mkldnn_prepacked::linear_prepack(Tensor w, Tensor? B, int[] input_sizes, str attr, Scalar?[] scalars, str? algorithm) -> __torch__.torch.classes.mkldnn.LinearOpContext")); + + m.def(TORCH_SELECTIVE_SCHEMA( + "mkldnn_prepacked::linear_run(Tensor X, __torch__.torch.classes.mkldnn.LinearOpContext W_prepack) -> Tensor Y")); } TORCH_LIBRARY_IMPL(mkldnn_prepacked, CPU, m) { @@ -53,6 +80,14 @@ TORCH_LIBRARY_IMPL(mkldnn_prepacked, CPU, m) { m.impl( TORCH_SELECTIVE_NAME("mkldnn_prepacked::conv2d_run"), TORCH_FN(conv_run)); + + m.impl( + TORCH_SELECTIVE_NAME("mkldnn_prepacked::linear_prepack"), + TORCH_FN(createLinearPrePackOpContext)); + + m.impl( + TORCH_SELECTIVE_NAME("mkldnn_prepacked::linear_run"), + TORCH_FN(linear_run)); } } // namespace mkldnn diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp index 98e7fca452b0f..14a42c2446b31 100644 --- a/test/cpp/tensorexpr/test_kernel.cpp +++ b/test/cpp/tensorexpr/test_kernel.cpp @@ -254,6 +254,38 @@ TEST_F(Kernel, Huge) { torch::jit::testing::FileCheck().run(verification_pattern, oss.str()); } +TEST_F(Kernel, MkldnnConvEltwise) { + const auto graph_string = R"IR( + graph(%x.1 : Float(1, 1, 224, 224, strides=[50176, 1, 224, 1], requires_grad=0, device=cpu)): + %1 : str = prim::Constant[value="tanh"]() + 
%2 : int = prim::Constant[value=1]()
+        %3 : int[] = prim::Constant[value=[0, 0]]()
+        %4 : int[] = prim::Constant[value=[1, 1]]()
+        %self.conv.bias : NoneType = prim::Constant()
+        %self.conv.weight : Float(3, 1, 1, 2, strides=[2, 1, 2, 1], requires_grad=0, device=cpu) = prim::Constant[value=<Tensor>]()
+        %x.14 : Float(1, 3, 224, 223, strides=[149856, 1, 669, 3], requires_grad=0, device=cpu) = aten::conv2d(%x.1, %self.conv.weight, %self.conv.bias, %4, %3, %4, %2)
+        %x.10 : Float(1, 3, 224, 223, strides=[149856, 1, 669, 3], requires_grad=0, device=cpu) = aten::gelu(%x.14, %1)
+        return (%x.10))IR";
+  auto graph = std::make_shared<Graph>();
+  parseIR(graph_string, &*graph, /*parse_tensor_constants*/ true);
+  TensorExprKernel k(graph);
+  std::ostringstream oss;
+  oss << *k.getCodeGenStmt();
+
+  torch::jit::testing::FileCheck()
+      .check("mkldnn_prepacked_conv_run")
+      ->run(oss.str());
+
+  torch::jit::testing::FileCheck().check_not("aten_tanh_gelu")->run(oss.str());
+
+  testing::FileCheck()
+      .check("mkldnn_prepacked::conv2d_run")
+      ->check_not("mkldnn_prepacked::conv2d_prepack") // this should be folded
+                                                      // as constant.
+ ->check_not("aten::gelu") + ->run(*k.graph()); +} + TEST_F(Kernel, ParallelStrided) { const auto graph_string = R"IR( graph(%0 : Float(5, 3, 40005, strides=[120015, 40005, 1], device=cpu), diff --git a/test/test_mkldnn_fusion.py b/test/test_mkldnn_fusion.py index b51cd7f47e0d2..60e56efd4f9a6 100644 --- a/test/test_mkldnn_fusion.py +++ b/test/test_mkldnn_fusion.py @@ -47,6 +47,72 @@ def _check_model(self, m, x): torch._C._jit_set_te_must_use_llvm_cpu(old_te_must_use_llvm_cpu) return graph + def _eltwise_list(self): + eltwise_list = [ + [torch.relu, 'aten::relu'], + [torch.sigmoid, 'aten::sigmoid'], + [torch.tanh, 'aten::tanh'], + [torch.nn.Hardswish(inplace=False), 'aten::hardswish'], + [nn.LeakyReLU(0.1, inplace=False), 'aten::leaky_relu'], + [nn.Hardtanh(inplace=False), 'aten::hardtanh'], + [nn.GELU(approximate="none"), 'aten::gelu'], + [nn.GELU(approximate="tanh"), 'aten::gelu'], + ] + return eltwise_list + + def _clamp_modules(self): + class MNoOpt(nn.Module): + def __init__(self, m, in_channels, out_channels, bias, **kwargs): + super(MNoOpt, self).__init__() + self.conv = m(in_channels, out_channels, bias=bias, **kwargs) + + def forward(self, x): + x = self.conv(x) + x = torch.clamp(x, min=-0.5, max=0.9) + return x + + class MInf(nn.Module): + def __init__(self, m, in_channels, out_channels, bias, **kwargs): + super(MInf, self).__init__() + self.conv = m(in_channels, out_channels, bias=bias, **kwargs) + + def forward(self, x): + x = self.conv(x) + x = torch.clamp(x, min=0, max=float('inf')) + return x + + class MNegInf(nn.Module): + def __init__(self, m, in_channels, out_channels, bias, **kwargs): + super(MNegInf, self).__init__() + self.conv = m(in_channels, out_channels, bias=bias, **kwargs) + + def forward(self, x): + x = self.conv(x) + x = torch.clamp(x, min=float('-inf'), max=0) + return x + + class MOptMin(nn.Module): + def __init__(self, m, in_channels, out_channels, bias, **kwargs): + super(MOptMin, self).__init__() + self.conv = m(in_channels, 
out_channels, bias=bias, **kwargs) + + def forward(self, x): + x = self.conv(x) + x = torch.clamp(x, max=2) + return x + + class MOptMax(nn.Module): + def __init__(self, m, in_channels, out_channels, bias, **kwargs): + super(MOptMax, self).__init__() + self.conv = m(in_channels, out_channels, bias=bias, **kwargs) + + def forward(self, x): + x = self.conv(x) + x = torch.clamp(x, min=0) + return x + + return [MNoOpt, MInf, MNegInf, MOptMin, MOptMax] + def test_single_conv(self): class M(nn.Module): def __init__(self, in_channels, out_channels, bias, **kwargs): @@ -100,16 +166,7 @@ def forward(self, x): [torch.contiguous_format, False], [torch.channels_last, True], ]: - for eltwise_fn, op_name in [ - [torch.relu, 'aten::relu'], - [torch.sigmoid, 'aten::sigmoid'], - [torch.tanh, 'aten::tanh'], - [torch.nn.Hardswish(inplace=False), 'aten::hardswish'], - [nn.LeakyReLU(0.1, inplace=False), 'aten::leaky_relu'], - [nn.Hardtanh(inplace=False), 'aten::hardtanh'], - [nn.GELU(approximate="none"), 'aten::gelu'], - [nn.GELU(approximate="tanh"), 'aten::gelu'], - ]: + for eltwise_fn, op_name in self._eltwise_list(): for bias in [True, False]: for oC in [1, 10]: m = M(eltwise_fn, 3, oC, bias, kernel_size=(3, 3)).to(memory_format=memory_format) @@ -123,57 +180,7 @@ def forward(self, x): self.assertGraphContains(graph, kind='aten::conv2d') def test_conv_clamp(self): - class MNoOpt(nn.Module): - def __init__(self, in_channels, out_channels, bias, **kwargs): - super(MNoOpt, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) - - def forward(self, x): - x = self.conv(x) - x = torch.clamp(x, min=-0.5, max=0.9) - return x - - class MInf(nn.Module): - def __init__(self, in_channels, out_channels, bias, **kwargs): - super(MInf, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) - - def forward(self, x): - x = self.conv(x) - x = torch.clamp(x, min=0, max=float('inf')) - return x - - class 
MNegInf(nn.Module): - def __init__(self, in_channels, out_channels, bias, **kwargs): - super(MNegInf, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) - - def forward(self, x): - x = self.conv(x) - x = torch.clamp(x, min=float('-inf'), max=0) - return x - - class MOptMin(nn.Module): - def __init__(self, in_channels, out_channels, bias, **kwargs): - super(MOptMin, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) - - def forward(self, x): - x = self.conv(x) - x = torch.clamp(x, max=2) - return x - - class MOptMax(nn.Module): - def __init__(self, in_channels, out_channels, bias, **kwargs): - super(MOptMax, self).__init__() - self.conv = torch.nn.Conv2d(in_channels, out_channels, bias=bias, **kwargs) - - def forward(self, x): - x = self.conv(x) - x = torch.clamp(x, min=0) - return x - - modules = [MNoOpt, MInf, MNegInf, MOptMin, MOptMax] + modules = self._clamp_modules() op_name = 'aten::clamp' for memory_format, enabled in [ @@ -182,7 +189,7 @@ def forward(self, x): ]: for M in modules: for bias in [True, False]: - m = M(3, 10, bias, kernel_size=(3, 3)).to(memory_format=memory_format) + m = M(nn.Conv2d, 3, 10, bias, kernel_size=(3, 3)).to(memory_format=memory_format) x = torch.randn(1, 3, 224, 224).to(memory_format=memory_format) graph = self._check_model(m, x) @@ -192,5 +199,74 @@ def forward(self, x): else: self.assertGraphContains(graph, kind='aten::conv2d') + def test_single_linear(self): + class M(nn.Module): + def __init__(self, in_channels, out_channels, bias, **kwargs): + super(M, self).__init__() + self.linear = torch.nn.Linear(in_channels, out_channels, bias=bias, **kwargs) + + def forward(self, x): + res = self.linear(x) + return res + iC = 2 + oC = 3 + for bias in [True, False]: + # TODO: refactor x_sghape generation + for x_shape in [ + [1, iC], + [2, iC], + [3, 2, iC] + ]: + m = M(iC, oC, bias) + x = torch.randn(x_shape) + graph = self._check_model(m, x) + 
self.assertFused(graph, ['aten::linear'])
+                self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
+
+    def test_linear_eltwise(self):
+        class M(nn.Module):
+            def __init__(self, eltwise_fn, in_channels, out_channels, bias, **kwargs):
+                super(M, self).__init__()
+                self.linear = torch.nn.Linear(in_channels, out_channels, bias=bias, **kwargs)
+                self.eltwise = eltwise_fn
+
+            def forward(self, x):
+                x = self.linear(x)
+                x = self.eltwise(x)
+                return x
+        iC = 2
+        oC = 3
+        for eltwise_fn, op_name in self._eltwise_list():
+            for bias in [True, False]:
+                for x_shape in [
+                    [1, iC],
+                    [2, iC],
+                    [3, 2, iC]
+                ]:
+                    m = M(eltwise_fn, iC, oC, bias)
+                    x = torch.randn(x_shape)
+
+                    graph = self._check_model(m, x)
+                    self.assertFused(graph, ['aten::linear', op_name])  # op_name is already 'aten::'-qualified
+                    self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
+
+    def test_linear_clamp(self):
+        modules = self._clamp_modules()
+        op_name = 'aten::clamp'
+        iC = 2
+        oC = 3
+        for M in modules:
+            for bias in [True, False]:
+                for x_shape in [
+                    [1, iC],
+                    [2, iC],
+                    [3, 2, iC]
+                ]:
+                    m = M(nn.Linear, iC, oC, bias)
+                    x = torch.randn(x_shape)
+                    graph = self._check_model(m, x)
+                    self.assertFused(graph, ['aten::linear', op_name])  # op_name is already 'aten::'-qualified
+                    self.assertGraphContainsExactly(graph, FUSION_GROUP, 1)
+
 if __name__ == "__main__":
     run_tests()
diff --git a/torch/csrc/jit/passes/mkldnn_rewrite.cpp b/torch/csrc/jit/passes/mkldnn_rewrite.cpp
index 79cfb2646a09f..225f3b7ca2688 100644
--- a/torch/csrc/jit/passes/mkldnn_rewrite.cpp
+++ b/torch/csrc/jit/passes/mkldnn_rewrite.cpp
@@ -209,6 +209,59 @@ void insertPrePackedConvOpForNode(Node* n) {
   n->output()->replaceAllUsesWith(prepack_conv->output());
 }
 
+void insertPrePackedLinearOpForNode(Node* n) {
+  constexpr int POS_INPUT = 0;
+  constexpr int POS_WEIGHT = 1;
+  if (!tensorexpr::isContiguous(n->input(POS_INPUT))) {
+    GRAPH_DEBUG("insertPrePackedLinearOpForNode: input is not contiguous");
+    return;
+  }
+
+  if (!tensorexpr::isContiguous(n->input(POS_WEIGHT))) {
+
GRAPH_DEBUG("insertPrePackedLinearOpForNode: weight is not contiguous"); + return; + } + + WithInsertPoint guard(n); + auto graph = n->owningGraph(); + + auto input_sizes = getSizesOf(n, POS_INPUT); + IValue input_size_value(*input_sizes.concrete_sizes()); + auto input_size = graph->insertConstant(input_size_value); + + auto prepack_node = graph->create( + Symbol::fromQualString("mkldnn_prepacked::linear_prepack"), 1); + + // skip input value + for (auto i = 1; i < n->inputs().size(); i++) { + Value* v = n->input(i); + prepack_node->addInput(v); + } + prepack_node->addInput(input_size); + auto attr = graph->insertConstant(IValue("none")); + prepack_node->addInput(attr); + + std::vector> empty_scalars; + auto scalars = graph->insertConstant(IValue(empty_scalars)); + prepack_node->addInput(scalars); + + c10::optional empty_algorithm; + auto algorithm = graph->insertConstant(IValue(empty_algorithm)); + prepack_node->addInput(algorithm); + + prepack_node->output()->setType( + getCustomClass("__torch__.torch.classes.mkldnn.LinearOpContext")); + graph->insertNode(prepack_node); + + auto prepack_linear = graph->insertNode( + graph->create(Symbol::fromQualString("mkldnn_prepacked::linear_run"), 1)); + prepack_linear->addInput(n->input(0)); + prepack_linear->addInput(prepack_node->output()); + prepack_linear->output()->setType(n->output()->type()->cast()); + + n->output()->replaceAllUsesWith(prepack_linear->output()); +} + bool isTensorTypeCPU(Node* node) { for (const auto& input : node->inputs()) { auto type = input->type()->cast(); @@ -241,6 +294,21 @@ void insertPrePackedConvOp(Block* b) { EliminateDeadCode(b); } +void insertPrePackedLinearOp(Block* b) { + for (Node* n : b->nodes()) { + for (Block* b : n->blocks()) { + insertPrePackedLinearOp(b); + } + + if (n->kind() == aten::linear) { + if (isTensorTypeCPU(n)) { + insertPrePackedLinearOpForNode(n); + } + } + } + EliminateDeadCode(b); +} + void insertMkldnnPrePackedConv2dOp(std::shared_ptr& graph) { // Replace 
_convolution with conv2d graph_rewrite_helper::replaceConvolutionWithAtenConv(graph); @@ -248,8 +316,13 @@ void insertMkldnnPrePackedConv2dOp(std::shared_ptr& graph) { insertPrePackedConvOp(graph->block()); } +void insertMkldnnPrePackedLinearOp(std::shared_ptr& graph) { + insertPrePackedLinearOp(graph->block()); +} + void insertMkldnnPrePackedOps(std::shared_ptr& graph) { insertMkldnnPrePackedConv2dOp(graph); + insertMkldnnPrePackedLinearOp(graph); } void insertMkldnnPrePackedOps(script::Module& module) { @@ -350,14 +423,25 @@ void FuseEltwiseWithPackedOps(std::shared_ptr& graph) { "%input, %weight, %bias, %stride:int[], %padding:int[], %dilation:int[], %groups:int,"), std::string("%weight, %bias, %stride, %padding, %dilation, %groups,")); - // TODO: add linear, matmul + RewriteEltwiseGraph( + graph, + mkldnn::fusion_attr_map(), + std::string("mkldnn_prepacked::linear_prepack"), + std::string("mkldnn_prepacked::linear_run"), + std::string("mkldnn.LinearOpContext"), + std::string("%input, %weight, %bias,"), + std::string("%weight, %bias,")); + + // TODO: add matmul } void PrePackingOpsFolder(Block* b) { auto is_foldable_op = [](const Node* n) -> bool { return ( n->kind() == - Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack")); + Symbol::fromQualString("mkldnn_prepacked::conv2d_prepack") || + n->kind() == + Symbol::fromQualString("mkldnn_prepacked::linear_prepack")); }; std::unordered_set nodes_to_delete; diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp index a885d7c177ba3..89280788f59e2 100644 --- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp +++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp @@ -81,6 +81,7 @@ static const OperatorSet& supported_non_eltwise_set() { "aten::batch_norm(Tensor input, Tensor? weight, Tensor? bias, Tensor? running_mean, Tensor? running_var, bool training, float momentum, float eps, bool cudnn_enabled) -> Tensor", "aten::conv2d(Tensor input, Tensor weight, Tensor? 
bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor",
       "aten::matmul(Tensor self, Tensor other) -> Tensor",
+      "aten::linear(Tensor input, Tensor weight, Tensor? bias=None) -> Tensor",
   };
   // clang-format on
   return supported_non_eltwise_set;
@@ -1058,6 +1059,12 @@ class TensorExprFuser {
           return false;
         }
       }
+      if (node->kind() == aten::linear) {
+        if (!tensorexpr::mkldnnPrepackedLinearIsSupportedJit(node)) {
+          GRAPH_DEBUG("Shapes of linear inputs are not supported");
+          return false;
+        }
+      }
       return true;
     }
 
diff --git a/torch/csrc/jit/tensorexpr/external_functions.cpp b/torch/csrc/jit/tensorexpr/external_functions.cpp
index 3b9ad4d1f255b..22082f3f8e2b5 100644
--- a/torch/csrc/jit/tensorexpr/external_functions.cpp
+++ b/torch/csrc/jit/tensorexpr/external_functions.cpp
@@ -1434,6 +1434,29 @@ void nnc_mkldnn_prepacked_conv_run(
   context->run(x, buf_data[0]);
 }
 
+// NNC external-call entry point for the prepacked MKLDNN linear op.
+// buf_data[0] receives the result, the input tensor is read from tensors[1],
+// and buf_data[2] is reinterpreted as the prepacked LinearOpContext.
+void nnc_mkldnn_prepacked_linear_run(
+    int64_t bufs_num,
+    void** buf_data,
+    int64_t* buf_ranks,
+    int64_t* buf_dims,
+    int64_t* buf_strides,
+    int8_t* buf_dtypes,
+    int64_t args_num,
+    int64_t* extra_args) {
+  using namespace at::native::mkldnn;
+
+  auto tensors = constructTensors(
+      bufs_num - 1, buf_data, buf_ranks, buf_dims, buf_strides, buf_dtypes);
+
+  const at::Tensor& x = tensors[1];
+  auto context = reinterpret_cast<LinearOpContext*>(buf_data[2]);
+
+  // Run once, writing straight into the caller-provided output buffer.
+  context->run(x, buf_data[0]);
+}
+
 #endif // AT_MKLDNN_ENABLED()
 
 #ifdef USE_XNNPACK
@@ -1618,6 +1641,10 @@ const static RegisterNNCExternalFunction nnc_embedding(
 const static RegisterNNCExternalFunction reg_nnc_mkldnn_prepacked_conv_run(
     "nnc_mkldnn_prepacked_conv_run",
     nnc_mkldnn_prepacked_conv_run);
+
+const static RegisterNNCExternalFunction reg_nnc_mkldnn_prepacked_linear_run(
+    "nnc_mkldnn_prepacked_linear_run",
+    nnc_mkldnn_prepacked_linear_run);
 #endif // AT_MKLDNN_ENABLED()
 
 #ifdef USE_XNNPACK
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index
5c7c783e1b78a..d1d24e98f18e1 100644 --- a/torch/csrc/jit/tensorexpr/kernel.cpp +++ b/torch/csrc/jit/tensorexpr/kernel.cpp @@ -314,6 +314,43 @@ bool mkldnnPrepackedConvIsSupportedJit(const torch::jit::Node* node) { return false; } +bool mkldnnPrepackedLinearIsSupportedJit(const torch::jit::Node* node) { +#if AT_MKLDNN_ENABLED() + + auto const& input = getTensorInfoJit(node->input(0)); + auto const& weight = getTensorInfoJit(node->input(1)); + + // TODO: skip bias check when bias is None + auto const& bias = getTensorInfoJit(node->input(2)); + + // Everything should be statically known. + if (!input || !weight) { + GRAPH_DEBUG( + "mkldnnPrepackedLinearIsSupportedJit: some params aren't static"); + return false; + } + + // Weights and bias should be Constant when using mkldnn backend + if (node->input(1)->node()->kind() != prim::Constant || + node->input(2)->node()->kind() != prim::Constant) { + GRAPH_DEBUG( + "mkldnnPrepackedLinearIsSupportedJit: weight or bias is not Constant"); + return false; + } + + // Inputs should be contiguous, or the TE will needlessly transpose them. + if (!isContiguous(node->input(0)) || !isContiguous(node->input(1))) { + GRAPH_DEBUG( + "mkldnnPrepackedLinearIsSupportedJit: some inputs are not contiguous"); + return false; + } + + return mkldnnPrepackedLinearIsSupported(*input, *weight); + +#endif + return false; +} + // The fuser currently only supports matmul of 2D x 2D matrices bool matmulIsSupported(const torch::jit::Node* node) { auto const& input0 = getTensorInfoJit(node->input(0)); diff --git a/torch/csrc/jit/tensorexpr/kernel.h b/torch/csrc/jit/tensorexpr/kernel.h index dc803ced2e29d..88fa9f606eaf8 100644 --- a/torch/csrc/jit/tensorexpr/kernel.h +++ b/torch/csrc/jit/tensorexpr/kernel.h @@ -26,6 +26,9 @@ struct SmallSizeTPairHash { bool conv2dIsSupportedJit(const Node* node); // Returns true if the TE fuser supports this conv2d with mkldnn prepacked conv. 
bool mkldnnPrepackedConvIsSupportedJit(const Node* node); +// Returns true if the TE fuser supports this linear with mkldnn prepacked +// linear. +bool mkldnnPrepackedLinearIsSupportedJit(const Node* node); // Returns true if the TE fuser supports this matmul. bool matmulIsSupported(const Node* node); template diff --git a/torch/csrc/jit/tensorexpr/lowerings.cpp b/torch/csrc/jit/tensorexpr/lowerings.cpp index 5bfab4e26915c..2c48e91636469 100644 --- a/torch/csrc/jit/tensorexpr/lowerings.cpp +++ b/torch/csrc/jit/tensorexpr/lowerings.cpp @@ -49,6 +49,9 @@ int nnc_lowerings_lazy_registration() { RegisterNNCLoweringsFunction mkldnn_prepacked_conv2d_run( {"mkldnn_prepacked::conv2d_run(Tensor X, __torch__.torch.classes.mkldnn.ConvOpContext W_prepack) -> (Tensor Y)"}, computeMkldnnPrepackedConvRun); + RegisterNNCLoweringsFunction mkldnn_prepacked_linear_run( + {"mkldnn_prepacked::linear_run(Tensor X, __torch__.torch.classes.mkldnn.LinearOpContext W_prepack) -> (Tensor Y)"}, + computeMkldnnPrepackedLinearRun); #endif // AT_MKLDNN_ENABLED() RegisterNNCLoweringsFunction aten_sub( diff --git a/torch/csrc/jit/tensorexpr/operators/matmul.cpp b/torch/csrc/jit/tensorexpr/operators/matmul.cpp index 408ae6e8c3610..329dacca3decc 100644 --- a/torch/csrc/jit/tensorexpr/operators/matmul.cpp +++ b/torch/csrc/jit/tensorexpr/operators/matmul.cpp @@ -1,3 +1,4 @@ +#include #include #include @@ -5,6 +6,18 @@ namespace torch { namespace jit { namespace tensorexpr { +bool mkldnnPrepackedLinearIsSupported( + const TensorInfo& input, + const TensorInfo& weight) { + // TODO: only support BF16 to make sure there's no performance regression + // if (input.dtype != c10::ScalarType::BFloat16 || + // weight.dtype != c10::ScalarType::BFloat16) { + // GRAPH_DEBUG("mkldnnPrepackedLinearIsSupported: only bfloat16 allowed"); + // return false; + // } + return true; +} + Tensor computeMatmul( const std::vector& inputs, const std::vector& outputShape, @@ -76,6 +89,26 @@ Tensor computeAddMM( inputs[4])})); // 
TODO: handle other dtypes of alpha and beta } +Tensor computeMkldnnPrepackedLinearRun( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const c10::optional& outputType, + at::Device device) { + Dtype dtype = kFloat; + if (outputType) { + dtype = Dtype(*outputType); + } + + BufHandle ResultBuf( + "mkldnn_prepacked_linear_run", outputShape, outputStrides, dtype); + const BufHandle& inp = c10::get(inputs[0]); + const BufHandle& prepacked = c10::get(inputs[1]); + StmtPtr s = ExternalCall::make( + ResultBuf, "nnc_mkldnn_prepacked_linear_run", {inp, prepacked}, {}); + return Tensor(ResultBuf.node(), s); +} + } // namespace tensorexpr } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/tensorexpr/operators/matmul.h b/torch/csrc/jit/tensorexpr/operators/matmul.h index 70f3f4bf7bf03..0c1a6f95a5520 100644 --- a/torch/csrc/jit/tensorexpr/operators/matmul.h +++ b/torch/csrc/jit/tensorexpr/operators/matmul.h @@ -1,11 +1,15 @@ #pragma once #include +#include namespace torch { namespace jit { namespace tensorexpr { +bool mkldnnPrepackedLinearIsSupported( + const TensorInfo& input, + const TensorInfo& weight); Tensor computeMatmul( const std::vector& inputs, const std::vector& outputShape, @@ -18,7 +22,12 @@ Tensor computeAddMM( const std::vector& outputStrides, const c10::optional& outputType, at::Device device); - +Tensor computeMkldnnPrepackedLinearRun( + const std::vector& inputs, + const std::vector& outputShape, + const std::vector& outputStrides, + const c10::optional& outputType, + at::Device device); } // namespace tensorexpr } // namespace jit } // namespace torch