From 227c4419fcb3e3bbeb3fbc3c4d52922e9cfa2be7 Mon Sep 17 00:00:00 2001
From: Frank Dong <123416088+frank-dong-ms@users.noreply.github.com>
Date: Thu, 25 Apr 2024 11:28:34 -0700
Subject: [PATCH] add bf16 support for a few ops (#20385)

### Description
Add bf16 support for the following ops:
ConstantOfShape
Exp
Erf
convolution
PythonOp

### Motivation and Context
The phimm model runs in bf16, so ORT needs to support bf16 in the ops above to run phimm in bf16. (A usage sketch exercising the new kernels follows the diff.)
---
 docs/OperatorKernels.md | 6 +-
 .../cpu/generator/constant_of_shape_base.h | 6 +-
 .../core/providers/cuda/cu_inc/common.cuh | 3 +
 .../providers/cuda/cuda_execution_provider.cc | 6 ++
 .../cuda/math/unary_elementwise_ops.cc | 4 +-
 .../cuda/math/unary_elementwise_ops_impl.cu | 2 +-
 .../core/providers/rocm/cu_inc/common.cuh | 3 +
 .../providers/rocm/rocm_execution_provider.cc | 6 ++
 .../provider_bridge_provider.cc | 2 +
 .../shared_library/provider_interfaces.h | 1 +
 .../core/session/provider_bridge_ort.cc | 1 +
 .../cpu/generator/constant_of_shape_test.cc | 7 +-
 .../core/graph/training_op_defs.cc | 4 +-
 .../ortmodule/_custom_op_symbolic_registry.py | 97 +++++++++++++++++++
 .../python/orttraining_test_ortmodule_api.py | 36 +++++++
 15 files changed, 174 insertions(+), 10 deletions(-)

diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 635f65696eae2..823125ef4ef4c 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -71,7 +71,7 @@ Do not modify directly.*
|ConcatFromSequence|*in* input_sequence:**S**
*out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))| |ConstantOfShape|*in* input:**T1**
*out* output:**T2**|21+|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||20|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| -|||[9, 19]|**T1** = tensor(int64)
**T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| +|||[9, 19]|**T1** = tensor(int64)
**T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |Conv|*in* X:**T**
*in* W:**T**
*in* B:**T**
*out* Y:**T**|11+|**T** = tensor(float)| |||[1, 10]|**T** = tensor(float)| |ConvInteger|*in* x:**T1**
*in* w:**T2**
*in* x_zero_point:**T1**
*in* w_zero_point:**T2**
*out* y:**T3**|10+|**T1** = tensor(uint8)
**T2** = tensor(uint8)
**T3** = tensor(int32)| @@ -601,9 +601,9 @@ Do not modify directly.* |Equal|*in* A:**T**
*in* B:**T**
*out* C:**T1**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)
**T1** = tensor(bool)| |||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)| |||[7, 10]|**T** = tensor(bool), tensor(int32), tensor(int64)| -|Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|Erf|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16)| -|Exp|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)| +|Exp|*in* input:**T**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)| |Expand|*in* input:**T**
*in* shape:**tensor(int64)**
*out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| |||[8, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)| diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h index 5ce7ab8553276..2e4e1730d5e6a 100644 --- a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h +++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h @@ -15,13 +15,16 @@ namespace onnxruntime { +// Add bf16 support for ConstantOfShape operator for phimm model. +// Although ONNX don't have bf16 support in opset-9 for ConstantOfShape we add support here: +// https://github.com/onnx/onnx/blob/main/docs/Changelog.md#constantofshape-9 using ConstantOfShapeDefaultOutputTypes = TypeList< MLFloat16, float, double, int8_t, int16_t, int32_t, int64_t, uint8_t, uint16_t, uint32_t, uint64_t, - bool>; + bool, BFloat16>; using ConstantOfShapeDefaultOutputTypesOpset20 = TypeList< @@ -158,6 +161,7 @@ void ConstantOfShapeBase::SetValueFromTensorProto(const O CASE_FETCH_VALUE_DATA(uint16_t) CASE_FETCH_VALUE_DATA(uint32_t) CASE_FETCH_VALUE_DATA(uint64_t) + CASE_FETCH_VALUE_DATA(BFloat16) default: ORT_THROW("Unsupported value attribute datatype: ", tensor_type); } diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh index 052dd05574ab1..db36754319309 100644 --- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh +++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh @@ -231,6 +231,9 @@ __device__ __inline__ double _Erf(double a) { return erf(a); } template <> __device__ __inline__ half _Erf(half a) { return half(erff((float)a)); } +template <> +__device__ __inline__ BFloat16 _Erf(BFloat16 a) { return BFloat16(erff((float)a)); } + template __device__ __host__ __inline__ T _Round(T a); diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 652626ce9e241..4b8a481c02ee9 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -1031,9 +1031,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Exp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Exp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp); +// Add bf16 support for Exp in opset 13+ for phimm model +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Exp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Erf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Erf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf); +// Add bf16 support for Erf in opset 13+ for phimm model +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Erf); class 
ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Expand); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max); @@ -1947,9 +1951,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc index 24593b255371c..fb03b4326c4e8 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc @@ -244,8 +244,8 @@ UNARY_OP_HFD(Ceil, 13) UNARY_OP_HFD(Reciprocal, 13) UNARY_OP_HFDX(Sqrt, 13) UNARY_OP_HFD(Log, 13) -UNARY_OP_HFD(Exp, 13) -UNARY_OP_HFD(Erf, 13) +UNARY_OP_HFDX(Exp, 13) +UNARY_OP_HFDX(Erf, 13) UNARY_OP_BWUZCSILHFD(Sign, 13) UNARY_LOGICALOP_NOT_TYPED(1, bool) diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu index 2cdfcda5be26a..295969af07973 100644 --- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu +++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu @@ -87,7 +87,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Sqrt) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp) -SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf) +SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Erf) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin) SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos) diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh index 1698e5ca8478c..cdb4d1f7edac6 100644 --- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh +++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh @@ -138,6 +138,9 @@ __device__ __inline__ double _Erf(double a) { return erf(a); } template <> __device__ __inline__ half _Erf(half a) { return half(erff((float)a)); } +template <> +__device__ __inline__ BFloat16 _Erf(BFloat16 a) { return BFloat16(erff((float)a)); } + template __device__ __inline__ T _Round(T a); diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 4b0fd783deeac..76964e1aed93c 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -1021,9 +1021,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Exp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Exp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp); +// Add bf16 support for Exp in opset 13+ for phimm model +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Exp); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Erf); class 
ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Erf); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf); +// Add bf16 support for Erf in opset 13+ for phimm model +class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Erf); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Expand); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Sum); class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Max); @@ -1973,9 +1977,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc index 7b73ab36b3742..575434d19bf35 100644 --- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc +++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc @@ -478,6 +478,8 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } template <> +Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } +template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } template <> Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index 8c8d5b1fd460a..1824a82995bce 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -198,6 +198,7 @@ struct ProviderHost { virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ double* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) = 0; + virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, 
size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) = 0; virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int16_t* p_data, size_t expected_size) = 0; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 507c094422509..9ec6bb0181004 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -271,6 +271,7 @@ struct ProviderHostImpl : ProviderHost { Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ double* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } + Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int16_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); } diff --git a/onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc b/onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc index d6efa01c9ae8b..3d7805d237935 100644 --- a/onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc +++ b/onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc @@ -85,6 +85,10 @@ inline void SetValue(TensorProto& t_proto, MLFloat16 value) { t_proto.mutable_int32_data()->Add(value.val); } +inline void SetValue(TensorProto& t_proto, BFloat16 value) { + t_proto.mutable_int32_data()->Add(value.val); +} + // This works for int64_t template inline void SetValue(TensorProto& t_proto, T value, @@ -100,7 +104,7 @@ inline void SetValue(TensorProto& t_proto, T value, t_proto.mutable_uint64_data()->Add(value); } -// For everything else except float, double and MLFloat16 +// For everything else except float, double, MLFloat16 and BFloat16 template inline void SetValue(TensorProto& t_proto, T value, typename std::enable_if::value && @@ -153,6 +157,7 @@ TEST(ConstantOfShape, TypeTests) { 
RunTypedTest(TensorProto::UINT16, uint16_t(6U), opset); RunTypedTest(TensorProto::UINT32, uint32_t(32U), opset); RunTypedTest(TensorProto::UINT64, uint64_t(64U), opset); + RunTypedTest(TensorProto::BFLOAT16, BFloat16::FromBits(static_cast(7)), opset); } } diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc index 21207c8e3ce40..677f383264c75 100644 --- a/orttraining/orttraining/core/graph/training_op_defs.cc +++ b/orttraining/orttraining/core/graph/training_op_defs.cc @@ -3949,7 +3949,7 @@ Return true if all elements are true and false otherwise. static_cast(1)) .TypeConstraint( "T", - OpSchema::all_tensor_types(), + OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor.") .TypeConstraint( "TInt64", @@ -4116,7 +4116,7 @@ Return true if all elements are true and false otherwise. static_cast(1)) .TypeConstraint( "T", - OpSchema::all_tensor_types(), + OpSchema::all_tensor_types_ir4(), "Allow inputs and outputs to be any kind of tensor.") .TypeConstraint( "TInt64", diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py index dd7fea3ceda10..0bd29b8d155c4 100644 --- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py +++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py @@ -847,6 +847,103 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable): return res +# Adapted from torch.onnx.symbolic_opset9._convolution - +# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset9.py#L2334 +# We override aten::_convolution here to support bf16 for phimm model from GenAI team. +# For bf16 inputs, we will convert input to float32, do convolution then convert output back to bf16. +# TODO: This might have negative impact on performance. +@register_symbolic("_convolution") +@parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i") +def convolution( + g, + input, + weight, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32=None, +): + from torch.onnx.symbolic_opset9 import _convolution + + input_casted = ( + g.op("Cast", input, to_i=torch.onnx.TensorProtoDataType.FLOAT) + if input.type().scalarType() == "BFloat16" + else input + ) + weight_casted = ( + g.op("Cast", weight, to_i=torch.onnx.TensorProtoDataType.FLOAT) + if weight.type().scalarType() == "BFloat16" + else weight + ) + + n = _convolution( + g, + input_casted, + weight_casted, + bias, + stride, + padding, + dilation, + transposed, + output_padding, + groups, + benchmark, + deterministic, + cudnn_enabled, + allow_tf32, + ) + + n_casted = ( + g.op("Cast", n, to_i=torch.onnx.TensorProtoDataType.BFLOAT16) if input.type().scalarType() == "BFloat16" else n + ) + return n_casted + + +# Adapted from torch.onnx.symbolic_opset9._convolution_mode - +# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset9.py#L2406 +# We override aten::_convolution_mode here to support bf16 for phimm model from GenAI team. +# For bf16 inputs, we will convert input to float32, do convolution then convert output back to bf16. +# TODO: This might have negative impact on performance. 
+@register_symbolic("_convolution_mode") +@parse_args("v", "v", "v", "is", "s", "is", "i") +def convolution_mode( + g, + input, + weight, + bias, + stride, + padding, + dilation, + groups, +): + from torch.onnx.symbolic_opset9 import _convolution_mode + + input_casted = ( + g.op("Cast", input, to_i=torch.onnx.TensorProtoDataType.FLOAT) + if input.type().scalarType() == "BFloat16" + else input + ) + weight_casted = ( + g.op("Cast", weight, to_i=torch.onnx.TensorProtoDataType.FLOAT) + if weight.type().scalarType() == "BFloat16" + else weight + ) + + n = _convolution_mode(g, input_casted, weight_casted, bias, stride, padding, dilation, groups) + + n_casted = ( + g.op("Cast", n, to_i=torch.onnx.TensorProtoDataType.BFLOAT16) if input.type().scalarType() == "BFloat16" else n + ) + return n_casted + + # Adapted from torch.onnx.symbolic_opset13.softmax - # https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset13.py#L27 # We don't need overloads symbolic_opset9 because training support opsets >= 13. diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py index 0839f957c26f3..ae319ab4ab473 100644 --- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py +++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py @@ -6610,6 +6610,42 @@ def run_step(model, attn_weight): assert to_value == pytorch_type_to_onnx_dtype(softmax_compute_type), "Cast to attribute is not as expected" +def test_aten_conv_bf16(): + class NeuralNetConv(torch.nn.Module): + def __init__(self): + super().__init__() + self.conv = torch.nn.Conv2d( + in_channels=3, + out_channels=1024, + kernel_size=14, + stride=14, + bias=False, + dtype=torch.bfloat16, + ) + + def forward(self, input): + return self.conv(input) + + device = "cuda" + pt_model = NeuralNetConv().to(device) + ort_model = ORTModule(copy.deepcopy(pt_model)) + + def run_step(model, input): + prediction = model(input) + prediction.sum().backward() + return prediction + + # reset manual seed to reset the generator + torch.manual_seed(2333) + pt_input = torch.randn([2, 3, 336, 336], dtype=torch.bfloat16, device=device, requires_grad=True) + ort_input = copy.deepcopy(pt_input) + pt_prediction = run_step(pt_model, pt_input) + ort_prediction = run_step(ort_model, ort_input) + + _test_helpers.assert_values_are_close(ort_prediction, pt_prediction) + _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad) + + @pytest.mark.parametrize("memory_optimization_level", [None, 0, 1, 2]) @pytest.mark.parametrize("allow_gradient_checkpoint_export", [None, 0, 1]) @pytest.mark.parametrize("fx", ["torch", "deepspeed"])
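Usage sketch (not part of the patch above): a minimal, hedged example of driving the new bf16 Exp/Erf CUDA kernels end to end through ORTModule, in the same style as the `test_aten_conv_bf16` test added here. It assumes an onnxruntime-training build with CUDA enabled and a bf16-capable GPU; the module name, shapes, and tolerances are illustrative only.

```python
# Sketch only: exercise the new BFloat16 Exp and Erf CUDA kernels via ORTModule.
# Assumes onnxruntime-training with CUDA support; not part of this patch.
import copy

import torch
from onnxruntime.training.ortmodule import ORTModule


class ExpErfNet(torch.nn.Module):
    # erf(exp(x)) exports to ONNX Exp/Erf, which now accept bfloat16 inputs.
    def forward(self, x):
        return torch.erf(torch.exp(x))


device = "cuda"
pt_model = ExpErfNet().to(device)
ort_model = ORTModule(copy.deepcopy(pt_model))

pt_input = torch.randn(2, 4, dtype=torch.bfloat16, device=device, requires_grad=True)
ort_input = copy.deepcopy(pt_input)

pt_out = pt_model(pt_input)
ort_out = ort_model(ort_input)
pt_out.sum().backward()
ort_out.sum().backward()

# Forward outputs and input gradients should closely match between PyTorch and ORT.
torch.testing.assert_close(ort_out, pt_out, rtol=1e-2, atol=1e-2)
torch.testing.assert_close(ort_input.grad, pt_input.grad, rtol=1e-2, atol=1e-2)
```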