add bf16 support for few ops (microsoft#20385)
### Description
Add bf16 support for the following ops:
ConstantOfShape
Exp
Erf
convolution
PythonOp

### Motivation and Context
The phimm model runs in bf16; ORT needs bf16 support in the ops above to run phimm in bf16.
frank-dong-ms authored Apr 25, 2024
1 parent 464f199 commit 227c441
Showing 15 changed files with 174 additions and 10 deletions.
6 changes: 3 additions & 3 deletions docs/OperatorKernels.md
@@ -71,7 +71,7 @@ Do not modify directly.*
|ConcatFromSequence|*in* input_sequence:**S**<br> *out* concat_result:**T**|11+|**S** = seq(tensor(bfloat16)), seq(tensor(bool)), seq(tensor(double)), seq(tensor(float)), seq(tensor(float16)), seq(tensor(int16)), seq(tensor(int32)), seq(tensor(int64)), seq(tensor(int8)), seq(tensor(string)), seq(tensor(uint16)), seq(tensor(uint32)), seq(tensor(uint64)), seq(tensor(uint8))|
|ConstantOfShape|*in* input:**T1**<br> *out* output:**T2**|21+|**T1** = tensor(int64)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||20|**T1** = tensor(int64)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e4m3fnuz), tensor(float8e5m2), tensor(float8e5m2fnuz), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
-|||[9, 19]|**T1** = tensor(int64)<br/> **T2** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
+|||[9, 19]|**T1** = tensor(int64)<br/> **T2** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|Conv|*in* X:**T**<br> *in* W:**T**<br> *in* B:**T**<br> *out* Y:**T**|11+|**T** = tensor(float)|
|||[1, 10]|**T** = tensor(float)|
|ConvInteger|*in* x:**T1**<br> *in* w:**T2**<br> *in* x_zero_point:**T1**<br> *in* w_zero_point:**T2**<br> *out* y:**T3**|10+|**T1** = tensor(uint8)<br/> **T2** = tensor(uint8)<br/> **T3** = tensor(int32)|
@@ -601,9 +601,9 @@ Do not modify directly.*
|Equal|*in* A:**T**<br> *in* B:**T**<br> *out* C:**T1**|13+|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)<br/> **T1** = tensor(bool)|
|||[11, 12]|**T** = tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int32), tensor(int64), tensor(uint32), tensor(uint64)|
|||[7, 10]|**T** = tensor(bool), tensor(int32), tensor(int64)|
-|Erf|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Erf|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[9, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
-|Exp|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
+|Exp|*in* input:**T**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|||[6, 12]|**T** = tensor(double), tensor(float), tensor(float16)|
|Expand|*in* input:**T**<br> *in* shape:**tensor(int64)**<br> *out* output:**T**|13+|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
|||[8, 12]|**T** = tensor(bfloat16), tensor(bool), tensor(double), tensor(float), tensor(float16), tensor(int16), tensor(int32), tensor(int64), tensor(int8), tensor(uint16), tensor(uint32), tensor(uint64), tensor(uint8)|
@@ -15,13 +15,16 @@

namespace onnxruntime {

+// Add bf16 support for the ConstantOfShape operator for the phimm model.
+// Although ONNX does not define bf16 for ConstantOfShape in opset 9, we add support here:
+// https://github.com/onnx/onnx/blob/main/docs/Changelog.md#constantofshape-9
using ConstantOfShapeDefaultOutputTypes =
TypeList<
MLFloat16,
float, double,
int8_t, int16_t, int32_t, int64_t,
uint8_t, uint16_t, uint32_t, uint64_t,
-bool>;
+bool, BFloat16>;

using ConstantOfShapeDefaultOutputTypesOpset20 =
TypeList<
@@ -158,6 +161,7 @@ void ConstantOfShapeBase<EnabledOutputTypeList>::SetValueFromTensorProto(const O
CASE_FETCH_VALUE_DATA(uint16_t)
CASE_FETCH_VALUE_DATA(uint32_t)
CASE_FETCH_VALUE_DATA(uint64_t)
+CASE_FETCH_VALUE_DATA(BFloat16)
default:
ORT_THROW("Unsupported value attribute datatype: ", tensor_type);
}
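For context, CASE_FETCH_VALUE_DATA appears to stamp out one switch case per supported C++ type, unpacking the ConstantOfShape `value` attribute into a typed scalar. A standalone sketch of that macro-dispatch style, with all names (DType, Value, FetchValue) invented for illustration rather than taken from ORT:

#include <cstdint>
#include <stdexcept>

// Illustrative stand-ins: DType mimics the proto datatype enum, BF16 the
// onnxruntime::BFloat16 wrapper.
enum class DType { kFloat, kBFloat16 };
struct BF16 { uint16_t val; };

struct Value {  // holds the unpacked scalar; ORT's holder differs
  float f{};
  BF16 bf{};
};

// One case per type: reinterpret the raw attribute bytes as that type.
#define CASE_FETCH_VALUE(enum_val, c_type, field)  \
  case DType::enum_val:                            \
    out.field = *static_cast<const c_type*>(raw);  \
    break;

void FetchValue(DType t, const void* raw, Value& out) {
  switch (t) {
    CASE_FETCH_VALUE(kFloat, float, f)
    CASE_FETCH_VALUE(kBFloat16, BF16, bf)
    default:
      throw std::runtime_error("Unsupported value attribute datatype");
  }
}

int main() {
  float one_and_half = 1.5f;
  Value v;
  FetchValue(DType::kFloat, &one_and_half, v);
  return v.f == 1.5f ? 0 : 1;
}

Adding a new type to this style of switch is a one-line change, which is why the BFloat16 case above is so small.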
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/cuda/cu_inc/common.cuh
@@ -231,6 +231,9 @@ __device__ __inline__ double _Erf(double a) { return erf(a); }
template <>
__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }

+template <>
+__device__ __inline__ BFloat16 _Erf(BFloat16 a) { return BFloat16(erff((float)a)); }

template <typename T>
__device__ __host__ __inline__ T _Round(T a);

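The _Erf specialization added above follows the usual bf16 device-math recipe: widen to float, call the float intrinsic, narrow back, since bf16 has no native erf. For illustration only, a hypothetical _Exp specialization in the same style would be (this exact snippet is an assumption; Exp's bf16 path in this commit comes via the HFDX macros below):

template <>
__device__ __inline__ BFloat16 _Exp(BFloat16 a) { return BFloat16(expf((float)a)); }

The two conversions cost little relative to the transcendental itself and keep results well-defined on GPUs without native bf16 math.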
6 changes: 6 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
@@ -1031,9 +1031,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp);
+// Add bf16 support for Exp in opset 13+ for the phimm model
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf);
+// Add bf16 support for Erf in opset 13+ for the phimm model
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Erf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max);
@@ -1947,9 +1951,11 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, float, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, double, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, BFloat16, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Sum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 13, Max)>,
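Note the two places each typed kernel must appear in this file: a forward declaration via ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME and a matching BuildKernelCreateInfo entry in the registration table. The bf16 additions for Exp and Erf are mirrored in both, since missing either would leave the kernel unregistered.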
4 changes: 2 additions & 2 deletions onnxruntime/core/providers/cuda/math/unary_elementwise_ops.cc
@@ -244,8 +244,8 @@ UNARY_OP_HFD(Ceil, 13)
UNARY_OP_HFD(Reciprocal, 13)
UNARY_OP_HFDX(Sqrt, 13)
UNARY_OP_HFD(Log, 13)
-UNARY_OP_HFD(Exp, 13)
-UNARY_OP_HFD(Erf, 13)
+UNARY_OP_HFDX(Exp, 13)
+UNARY_OP_HFDX(Erf, 13)
UNARY_OP_BWUZCSILHFD(Sign, 13)

UNARY_LOGICALOP_NOT_TYPED(1, bool)
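The macro suffix encodes the instantiated type set, apparently H(alf), F(loat), D(ouble), with X adding BFloat16, so switching Exp and Erf from HFD to HFDX is precisely what turns on their bf16 kernels. A rough sketch of how such a macro family could compose (assumed structure, not ORT's actual definitions):

// Hypothetical stand-in so the sketch compiles; the real macro registers a kernel.
#define REGISTER_UNARY_TYPED_KERNEL(name, ver, T) /* registration goes here */

#define UNARY_OP_TYPED(name, ver, T) REGISTER_UNARY_TYPED_KERNEL(name, ver, T)

#define UNARY_OP_HFD(name, ver)        \
  UNARY_OP_TYPED(name, ver, MLFloat16) \
  UNARY_OP_TYPED(name, ver, float)     \
  UNARY_OP_TYPED(name, ver, double)

// HFDX = HFD plus BFloat16.
#define UNARY_OP_HFDX(name, ver) \
  UNARY_OP_HFD(name, ver)        \
  UNARY_OP_TYPED(name, ver, BFloat16)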
@@ -87,7 +87,7 @@ SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Reciprocal)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Sqrt)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Log)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp)
-SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Erf)
+SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Erf)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Round)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Sin)
SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFD(Cos)
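Worth noting: the implementation file already instantiated Exp for bf16 (SPECIALIZED_UNARY_ELEMENTWISE_IMPL_HFDX(Exp) above is untouched), so that instantiation was apparently unreachable before; this commit adds the matching Erf instantiation and, in unary_elementwise_ops.cc, the registrations that expose both.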
3 changes: 3 additions & 0 deletions onnxruntime/core/providers/rocm/cu_inc/common.cuh
@@ -138,6 +138,9 @@ __device__ __inline__ double _Erf(double a) { return erf(a); }
template <>
__device__ __inline__ half _Erf(half a) { return half(erff((float)a)); }

+template <>
+__device__ __inline__ BFloat16 _Erf(BFloat16 a) { return BFloat16(erff((float)a)); }

template <typename T>
__device__ __inline__ T _Round(T a);

6 changes: 6 additions & 0 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.cc
@@ -1021,9 +1021,13 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp);
+// Add bf16 support for Exp in opset 13+ for the phimm model
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Exp);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Erf);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf);
+// Add bf16 support for Erf in opset 13+ for the phimm model
+class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Erf);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Expand);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Sum);
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Max);
@@ -1973,9 +1977,11 @@ static Status RegisterRocmKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Exp)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Exp)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, float, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, double, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, MLFloat16, Erf)>,
+BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, BFloat16, Erf)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Expand)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Sum)>,
BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 13, Max)>,
@@ -478,6 +478,8 @@ Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_d
template <>
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
+template <>
+Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
template <>
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
template <>
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) { return g_host->UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
@@ -198,6 +198,7 @@ struct ProviderHost {
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ double* p_data, size_t expected_size) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) = 0;
+virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) = 0;
virtual Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int16_t* p_data, size_t expected_size) = 0;
1 change: 1 addition & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
@@ -271,6 +271,7 @@ struct ProviderHostImpl : ProviderHost {
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ float* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ double* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ MLFloat16* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
+Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ BFloat16* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int8_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ uint8_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
Status UnpackTensor(const ONNX_NAMESPACE::TensorProto& tensor, const void* raw_data, size_t raw_data_len, /*out*/ int16_t* p_data, size_t expected_size) override { return utils::UnpackTensor(tensor, raw_data, raw_data_len, p_data, expected_size); }
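These three UnpackTensor additions are one logical change threaded through the provider bridge: shared-library execution providers call g_host->UnpackTensor, ProviderHost declares a pure-virtual overload per element type, and ProviderHostImpl forwards to utils::UnpackTensor. Supporting a new element type like BFloat16 therefore means adding the overload at all three layers.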
@@ -85,6 +85,10 @@ inline void SetValue(TensorProto& t_proto, MLFloat16 value) {
t_proto.mutable_int32_data()->Add(value.val);
}

+inline void SetValue(TensorProto& t_proto, BFloat16 value) {
+  t_proto.mutable_int32_data()->Add(value.val);
+}

// This works for int64_t
template <class T>
inline void SetValue(TensorProto& t_proto, T value,
@@ -100,7 +104,7 @@ inline void SetValue(TensorProto& t_proto, T value,
t_proto.mutable_uint64_data()->Add(value);
}

-// For everything else except float, double and MLFloat16
+// For everything else except float, double, MLFloat16 and BFloat16
template <class T>
inline void SetValue(TensorProto& t_proto, T value,
typename std::enable_if<!std::is_same<T, int64_t>::value &&
@@ -153,6 +157,7 @@ TEST(ConstantOfShape, TypeTests) {
RunTypedTest(TensorProto::UINT16, uint16_t(6U), opset);
RunTypedTest(TensorProto::UINT32, uint32_t(32U), opset);
RunTypedTest(TensorProto::UINT64, uint64_t(64U), opset);
+RunTypedTest(TensorProto::BFLOAT16, BFloat16::FromBits(static_cast<uint16_t>(7)), opset);
}
}

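The test builds its bf16 value with BFloat16::FromBits, and the SetValue overload above stores value.val (the raw 16 bits) in int32_data, which works because bfloat16 is just the high half of an IEEE float32 bit pattern. A standalone illustration of that relationship (not ORT code; real conversions may round rather than truncate):

#include <cassert>
#include <cstdint>
#include <cstring>

// bfloat16 <-> float32 via bit manipulation: bf16 keeps the float32 sign,
// exponent, and top 7 mantissa bits.
uint16_t FloatToBF16Bits(float f) {
  uint32_t u;
  std::memcpy(&u, &f, sizeof(u));
  return static_cast<uint16_t>(u >> 16);  // truncating conversion
}

float BF16BitsToFloat(uint16_t b) {
  uint32_t u = static_cast<uint32_t>(b) << 16;
  float f;
  std::memcpy(&f, &u, sizeof(f));
  return f;
}

int main() {
  // 1.0f is exactly representable in bf16, so the round-trip is lossless.
  assert(BF16BitsToFloat(FloatToBF16Bits(1.0f)) == 1.0f);
  return 0;
}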
4 changes: 2 additions & 2 deletions orttraining/orttraining/core/graph/training_op_defs.cc
@@ -3949,7 +3949,7 @@ Return true if all elements are true and false otherwise.
static_cast<int64_t>(1))
.TypeConstraint(
"T",
-OpSchema::all_tensor_types(),
+OpSchema::all_tensor_types_ir4(),
"Allow inputs and outputs to be any kind of tensor.")
.TypeConstraint(
"TInt64",
@@ -4116,7 +4116,7 @@ Return true if all elements are true and false otherwise.
static_cast<int64_t>(1))
.TypeConstraint(
"T",
-OpSchema::all_tensor_types(),
+OpSchema::all_tensor_types_ir4(),
"Allow inputs and outputs to be any kind of tensor.")
.TypeConstraint(
"TInt64",
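The switch from OpSchema::all_tensor_types() to all_tensor_types_ir4() is what carries bf16 here: the IR4 list is the earlier tensor type list plus bfloat16, so the schemas changed in this file (presumably the PythonOp ones, per the description) can now constrain T to bf16 tensors as well.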