From 227c4419fcb3e3bbeb3fbc3c4d52922e9cfa2be7 Mon Sep 17 00:00:00 2001
From: Frank Dong <123416088+frank-dong-ms@users.noreply.github.com>
Date: Thu, 25 Apr 2024 11:28:34 -0700
Subject: [PATCH] add bf16 support for few ops (#20385)
### Description
Add bf16 support for the following ops:
ConstantOfShape
Exp
Erf
convolution
PythonOp

### Motivation and Context
The phimm model runs in bf16; ORT needs bf16 support for the ops above so that phimm can run end to end in bf16.
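Since ORT has no native bf16 convolution kernel, the exported graph casts bf16 convolution inputs to float32, runs the convolution, and casts the result back to bf16 (see `_custom_op_symbolic_registry.py` below). A minimal sketch of how the new paths are exercised from the user side, assuming a CUDA device and a training build of onnxruntime; the model, shapes, and names are illustrative only and not part of this patch:

```python
import copy

import torch
from onnxruntime.training.ortmodule import ORTModule


class TinyBf16Net(torch.nn.Module):  # hypothetical example model
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=False, dtype=torch.bfloat16)

    def forward(self, x):
        y = self.conv(x)                        # exported via the bf16-aware _convolution symbolic
        return 0.5 * y * (1.0 + torch.erf(y))   # erf exercises the new bf16 Erf CUDA kernel


pt_model = TinyBf16Net().to("cuda")
ort_model = ORTModule(copy.deepcopy(pt_model))

x = torch.randn(2, 3, 16, 16, dtype=torch.bfloat16, device="cuda", requires_grad=True)
out = ort_model(x)        # convolution runs in fp32 internally; output is cast back to bf16
out.sum().backward()
```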
---
docs/OperatorKernels.md | 6 +-
.../cpu/generator/constant_of_shape_base.h | 6 +-
.../core/providers/cuda/cu_inc/common.cuh | 3 +
.../providers/cuda/cuda_execution_provider.cc | 6 ++
.../cuda/math/unary_elementwise_ops.cc | 4 +-
.../cuda/math/unary_elementwise_ops_impl.cu | 2 +-
.../core/providers/rocm/cu_inc/common.cuh | 3 +
.../providers/rocm/rocm_execution_provider.cc | 6 ++
.../provider_bridge_provider.cc | 2 +
.../shared_library/provider_interfaces.h | 1 +
.../core/session/provider_bridge_ort.cc | 1 +
.../cpu/generator/constant_of_shape_test.cc | 7 +-
.../core/graph/training_op_defs.cc | 4 +-
.../ortmodule/_custom_op_symbolic_registry.py | 97 +++++++++++++++++++
.../python/orttraining_test_ortmodule_api.py | 36 +++++++
15 files changed, 174 insertions(+), 10 deletions(-)
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index 635f65696eae2..823125ef4ef4c 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -71,7 +71,7 @@ Do not modify directly.*
|ConcatFromSequence|*in* input_sequence:**S**<br> *out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|ConstantOfShape|*in* input:**T1**<br> *out* output:**T2**|21+|**T1** = tensor(int64)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||20|**T1** = tensor(int64)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[9, 19]|**T1** = tensor(int64)<br> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[9, 19]|**T1** = tensor(int64)<br> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|11+|**T** = tensor(float)|
|||[1, 10]|**T** = tensor(float)|
|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(uint8)<br> **T2** = tensor(uint8)<br> **T3** = tensor(int32)|
@@ -601,9 +601,9 @@ Do not modify directly.*
|Equal|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br> **T1** = tensor(bool)|
|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[7, 10]|**T** = tensor(bool), tensor(int32), tensor(int64)|
-|Erf|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Erf|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
-|Exp|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Exp|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|Expand|*in* input:**T**<br> *in* shape:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[8, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
diff --git a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h
index 5ce7ab8553276..2e4e1730d5e6a 100644
--- a/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h
+++ b/onnxruntime/core/providers/cpu/generator/constant_of_shape_base.h
@@ -15,13 +15,16 @@
namespace onnxruntime {
+// Add bf16 support to the ConstantOfShape operator for the phimm model.
+// Although ONNX does not list bf16 for ConstantOfShape in opset 9, we add support here:
+// https://github.com/onnx/onnx/blob/main/docs/Changelog.md#constantofshape-9
using ConstantOfShapeDefaultOutputTypes =
TypeList<
MLFloat16,
float, double,
int8_t, int16_t, int32_t, int64_t,
uint8_t, uint16_t, uint32_t, uint64_t,
- bool>;
+ bool, BFloat16>;
using ConstantOfShapeDefaultOutputTypesOpset20 =
TypeList<
@@ -158,6 +161,7 @@ void ConstantOfShapeBase::SetValueFromTensorProto(const O
CASE_FETCH_VALUE_DATA(uint16_t)
CASE_FETCH_VALUE_DATA(uint32_t)
CASE_FETCH_VALUE_DATA(uint64_t)
+ CASE_FETCH_VALUE_DATA(BFloat16)
default:
ORT_THROW("Unsupported value attribute datatype: ", tensor_type);
}
diff --git a/onnxruntime/core/providers/cuda/cu_inc/common.cuh b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
index 052dd05574ab1..db36754319309 100644
--- a/onnxruntime/core/providers/cuda/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -231,6 +231,9 @@ __device__ __inline__ double _Erf(double a) { return erf(a); }
template <>
__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }
+template <>
+__device__ __inline__ BFloat16 _Erf(BFloat16 a) { return BFloat16(erff((float)a)); }
+
template <typename T>
__device__ __host__ __inline__ T _Round(T a);
diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
index 652626ce9e241..4b8a481c02ee9 100644
--- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
+++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1031,9 +1031,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp);
+// Add bf16 support for Exp in opset 13+ for the phimm model
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf);
+// Add bf16 support for Erf in opset 13+ for the phimm model
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Erf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max);
@@ -1947,9 +1951,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max)>,
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
index 24593b255371c..fb03b4326c4e8 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -244,8 +244,8 @@ UNARY_OP_HFD(Ceil, 13)
UNARY_OP_HFD(Reciprocal, 13)
UNARY_OP_HFDX(Sqrt, 13)
UNARY_OP_HFD(Log, 13)
-UNARY_OP_HFD(Exp, 13)
-UNARY_OP_HFD(Erf, 13)
+UNARY_OP_HFDX(Exp, 13)
+UNARY_OP_HFDX(Erf, 13)
UNARY_OP_BWUZCSILHFD(Sign, 13)
UNARY_LOGICALOP_NOT_TYPED(1, bool)
diff --git a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
index 2cdfcda5be26a..295969af07973 100644
--- a/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
+++ b/onnxruntime/core/providers/cuda/math/unary_elementwise_ops_impl.cu
@@ -87,7 +87,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Sqrt)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp)
-SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
+SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Erf)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)
diff --git a/onnxruntime/core/providers/rocm/cu_inc/common.cuh b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
index 1698e5ca8478c..cdb4d1f7edac6 100644
--- a/onnxruntime/core/providers/rocm/cu_inc/common.cuh
+++ b/onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -138,6 +138,9 @@ __device__ __inline__ double _Erf(double a) { return erf(a); }
template <>
__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }
+template <>
+__device__ __inline__ BFloat16 _Erf(BFloat16 a) { return BFloat16(erff((float)a)); }
+
template <typename T>
__device__ __inline__ T _Round(T a);
diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
index 4b0fd783deeac..76964e1aed93c 100644
--- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
+++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1021,9 +1021,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp);
+// Add bf16 support for Exp in opset 13+ for the phimm model
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf);
+// Add bf16 support for Erf in opset 13+ for the phimm model
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Erf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Sum);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Max);
@@ -1973,9 +1977,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf)>,
+ BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Sum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Max)>,
diff --git a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
index 7b73ab36b3742..575434d19bf35 100644
--- a/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
+++ b/onnxruntime/core/providers/shared_library/provider_bridge_provider.cc
@@ -478,6 +478,8 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d
template <>
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
template <>
+Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
+template <>
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
template <>
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h
index 8c8d5b1fd460a..1824a82995bce 100644
--- a/onnxruntime/core/providers/shared_library/provider_interfaces.h
+++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h
@@ -198,6 +198,7 @@ struct ProviderHost {
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ double* p_data, size_t expected_size) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) = 0;
+ virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int16_t* p_data, size_t expected_size) = 0;
diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc
index 507c094422509..9ec6bb0181004 100644
--- a/onnxruntime/core/session/provider_bridge_ort.cc
+++ b/onnxruntime/core/session/provider_bridge_ort.cc
@@ -271,6 +271,7 @@ struct ProviderHostImpl : ProviderHost {
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ double* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
+ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int16_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
diff --git a/onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc b/onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc
index d6efa01c9ae8b..3d7805d237935 100644
--- a/onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc
+++ b/onnxruntime/test/providers/cpu/generator/constant_of_shape_test.cc
@@ -85,6 +85,10 @@ inline void SetValue(TensorProto& t_proto, MLFloat16 value) {
t_proto.mutable_int32_data()->Add(value.val);
}
+inline void SetValue(TensorProto& t_proto, BFloat16 value) {
+ t_proto.mutable_int32_data()->Add(value.val);
+}
+
// This works for int64_t
template <typename T>
inline void SetValue(TensorProto& t_proto, T value,
@@ -100,7 +104,7 @@ inline void SetValue(TensorProto& t_proto, T value,
t_proto.mutable_uint64_data()->Add(value);
}
-// For everything else except float, double and MLFloat16
+// For everything else except float, double, MLFloat16 and BFloat16
template <typename T>
inline void SetValue(TensorProto& t_proto, T value,
typename std::enable_if::value &&
@@ -153,6 +157,7 @@ TEST(ConstantOfShape, TypeTests) {
RunTypedTest(TensorProto::UINT16, uint16_t(6U), opset);
RunTypedTest(TensorProto::UINT32, uint32_t(32U), opset);
RunTypedTest(TensorProto::UINT64, uint64_t(64U), opset);
+ RunTypedTest(TensorProto::BFLOAT16, BFloat16::FromBits(static_cast<uint16_t>(7)), opset);
}
}
diff --git a/orttraining/orttraining/core/graph/training_op_defs.cc b/orttraining/orttraining/core/graph/training_op_defs.cc
index 21207c8e3ce40..677f383264c75 100644
--- a/orttraining/orttraining/core/graph/training_op_defs.cc
+++ b/orttraining/orttraining/core/graph/training_op_defs.cc
@@ -3949,7 +3949,7 @@ Return true if all elements are true and false otherwise.
static_cast(1))
.TypeConstraint(
"T",
- OpSchema::all_tensor_types(),
+ OpSchema::all_tensor_types_ir4(),
"Allow inputs and outputs to be any kind of tensor.")
.TypeConstraint(
"TInt64",
@@ -4116,7 +4116,7 @@ Return true if all elements are true and false otherwise.
static_cast(1))
.TypeConstraint(
"T",
- OpSchema::all_tensor_types(),
+ OpSchema::all_tensor_types_ir4(),
"Allow inputs and outputs to be any kind of tensor.")
.TypeConstraint(
"TInt64",
diff --git a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
index dd7fea3ceda10..0bd29b8d155c4 100644
--- a/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
+++ b/orttraining/orttraining/python/training/ortmodule/_custom_op_symbolic_registry.py
@@ -847,6 +847,103 @@ def layer_norm(g, input, normalized_shape, weight, bias, eps, cudnn_enable):
return res
+# Adapted from torch.onnx.symbolic_opset9._convolution -
+# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset9.py#L2334
+# We override aten::_convolution here to support bf16 for the phimm model from the GenAI team.
+# For bf16 inputs, we cast the inputs to float32, run the convolution, then cast the output back to bf16.
+# TODO: This may have a negative impact on performance.
+@register_symbolic("_convolution")
+@parse_args("v", "v", "v", "is", "is", "is", "i", "is", "i", "i", "i", "i", "i")
+def convolution(
+ g,
+ input,
+ weight,
+ bias,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ benchmark,
+ deterministic,
+ cudnn_enabled,
+ allow_tf32=None,
+):
+ from torch.onnx.symbolic_opset9 import _convolution
+
+ input_casted = (
+ g.op("Cast", input, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+ if input.type().scalarType() == "BFloat16"
+ else input
+ )
+ weight_casted = (
+ g.op("Cast", weight, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+ if weight.type().scalarType() == "BFloat16"
+ else weight
+ )
+
+ n = _convolution(
+ g,
+ input_casted,
+ weight_casted,
+ bias,
+ stride,
+ padding,
+ dilation,
+ transposed,
+ output_padding,
+ groups,
+ benchmark,
+ deterministic,
+ cudnn_enabled,
+ allow_tf32,
+ )
+
+ n_casted = (
+ g.op("Cast", n, to_i=torch.onnx.TensorProtoDataType.BFLOAT16) if input.type().scalarType() == "BFloat16" else n
+ )
+ return n_casted
+
+
+# Adapted from torch.onnx.symbolic_opset9._convolution_mode -
+# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset9.py#L2406
+# We override aten::_convolution_mode here to support bf16 for the phimm model from the GenAI team.
+# For bf16 inputs, we cast the inputs to float32, run the convolution, then cast the output back to bf16.
+# TODO: This may have a negative impact on performance.
+@register_symbolic("_convolution_mode")
+@parse_args("v", "v", "v", "is", "s", "is", "i")
+def convolution_mode(
+ g,
+ input,
+ weight,
+ bias,
+ stride,
+ padding,
+ dilation,
+ groups,
+):
+ from torch.onnx.symbolic_opset9 import _convolution_mode
+
+ input_casted = (
+ g.op("Cast", input, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+ if input.type().scalarType() == "BFloat16"
+ else input
+ )
+ weight_casted = (
+ g.op("Cast", weight, to_i=torch.onnx.TensorProtoDataType.FLOAT)
+ if weight.type().scalarType() == "BFloat16"
+ else weight
+ )
+
+ n = _convolution_mode(g, input_casted, weight_casted, bias, stride, padding, dilation, groups)
+
+ n_casted = (
+ g.op("Cast", n, to_i=torch.onnx.TensorProtoDataType.BFLOAT16) if input.type().scalarType() == "BFloat16" else n
+ )
+ return n_casted
+
+
# Adapted from torch.onnx.symbolic_opset13.softmax -
# https://github.com/pytorch/pytorch/blob/cf06189a2d2785ac493bcd0d55e520af5a0e3b97/torch/onnx/symbolic_opset13.py#L27
# We don't need overloads symbolic_opset9 because training support opsets >= 13.
diff --git a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
index 0839f957c26f3..ae319ab4ab473 100644
--- a/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
+++ b/orttraining/orttraining/test/python/orttraining_test_ortmodule_api.py
@@ -6610,6 +6610,42 @@ def run_step(model, attn_weight):
assert to_value == pytorch_type_to_onnx_dtype(softmax_compute_type), "Cast to attribute is not as expected"
+def test_aten_conv_bf16():
+ class NeuralNetConv(torch.nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.conv = torch.nn.Conv2d(
+ in_channels=3,
+ out_channels=1024,
+ kernel_size=14,
+ stride=14,
+ bias=False,
+ dtype=torch.bfloat16,
+ )
+
+ def forward(self, input):
+ return self.conv(input)
+
+ device = "cuda"
+ pt_model = NeuralNetConv().to(device)
+ ort_model = ORTModule(copy.deepcopy(pt_model))
+
+ def run_step(model, input):
+ prediction = model(input)
+ prediction.sum().backward()
+ return prediction
+
+ # reset manual seed to reset the generator
+ torch.manual_seed(2333)
+ pt_input = torch.randn([2, 3, 336, 336], dtype=torch.bfloat16, device=device, requires_grad=True)
+ ort_input = copy.deepcopy(pt_input)
+ pt_prediction = run_step(pt_model, pt_input)
+ ort_prediction = run_step(ort_model, ort_input)
+
+ _test_helpers.assert_values_are_close(ort_prediction, pt_prediction)
+ _test_helpers.assert_values_are_close(ort_input.grad, pt_input.grad)
+
+
@pytest.mark.parametrize("memory_optimization_level", [None, 0, 1, 2])
@pytest.mark.parametrize("allow_gradient_checkpoint_export", [None, 0, 1])
@pytest.mark.parametrize("fx", ["torch", "deepspeed"])