Skip to content

Commit

Permalink
Implement CastLike for opset 15 to 18
Browse files Browse the repository at this point in the history
  • Loading branch information
wschin committed Jan 30, 2024
1 parent ffc3431 commit a4d0848
Show file tree
Hide file tree
Showing 3 changed files with 101 additions and 0 deletions.
26 changes: 26 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1224,6 +1224,19 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain,
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, double, BatchNormalization);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, MLFloat16, BatchNormalization);
class ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, Shape);
// CastLike, opsets 15 through 18: one kernel class declaration per element type listed below.
// Each declaration has a matching BuildKernelCreateInfo entry in RegisterCudaKernels; keep the
// two lists in sync when adding or removing a type.
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, MLFloat16, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, float, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, double, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, int8_t, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, int16_t, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, int32_t, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, int64_t, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, uint8_t, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, uint16_t, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, uint32_t, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, uint64_t, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, bool, CastLike);
class ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, BFloat16, CastLike);

// Opset 16
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LeakyRelu);
Expand Down Expand Up @@ -2110,6 +2123,19 @@ static Status RegisterCudaKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, double, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, MLFloat16, BatchNormalization)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, Shape)>,
// CastLike, opsets 15 through 18: one entry per element type, mirroring the kernel class
// declarations earlier in this file.
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, MLFloat16, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, float, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, double, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, int8_t, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, int16_t, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, int32_t, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, int64_t, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, uint8_t, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, uint16_t, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, uint32_t, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, uint64_t, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, bool, CastLike)>,
BuildKernelCreateInfo<ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 15, 18, BFloat16, CastLike)>,

// Opset 16
BuildKernelCreateInfo<ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kOnnxDomain, 16, float, LeakyRelu)>,
Expand Down
68 changes: 68 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/cast_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -329,5 +329,73 @@ SPECIALIZE_IMPL_19(Float8E5M2)

#endif

///////////////////////////////////////////////////////////////////
// The section below implements CastLike.
///////////////////////////////////////////////////////////////////

// Registers the CUDA CastLike kernel (ONNX domain, opsets 15 through 18) for a single
// element type T of the tensor being cast.
//   - Type constraint "T1" is pinned to tensors of T.
//   - Type constraint "T2" accepts whatever CastOpTypeConstraints() returns.
//   - The kernel implementation is CastLike<T>.
// Every type registered below must also be declared and listed in RegisterCudaKernels in
// cuda_execution_provider.cc.
#define REGISTER_CASTLIKE_KERNEL_TYPED(T)                         \
  ONNX_OPERATOR_VERSIONED_TYPED_KERNEL_EX(                        \
      CastLike,                                                   \
      kOnnxDomain,                                                \
      15, 18,                                                     \
      T,                                                          \
      kCudaExecutionProvider,                                     \
      (*KernelDefBuilder::Create())                               \
          .TypeConstraint("T1", DataTypeImpl::GetTensorType<T>()) \
          .TypeConstraint("T2", CastOpTypeConstraints()),         \
      CastLike<T>);

REGISTER_CASTLIKE_KERNEL_TYPED(MLFloat16)
REGISTER_CASTLIKE_KERNEL_TYPED(float)
REGISTER_CASTLIKE_KERNEL_TYPED(double)
REGISTER_CASTLIKE_KERNEL_TYPED(int8_t)
REGISTER_CASTLIKE_KERNEL_TYPED(int16_t)
REGISTER_CASTLIKE_KERNEL_TYPED(int32_t)
REGISTER_CASTLIKE_KERNEL_TYPED(int64_t)
REGISTER_CASTLIKE_KERNEL_TYPED(uint8_t)
REGISTER_CASTLIKE_KERNEL_TYPED(uint16_t)
REGISTER_CASTLIKE_KERNEL_TYPED(uint32_t)
REGISTER_CASTLIKE_KERNEL_TYPED(uint64_t)
REGISTER_CASTLIKE_KERNEL_TYPED(bool)
REGISTER_CASTLIKE_KERNEL_TYPED(BFloat16)

// Casts input 0 to the element type of input 1 and writes the result to output 0.
//
// The output is created with the same shape as input 0. Only the element type of input 1 is
// read in this function; its data is not accessed here. The per-destination-type work is done
// by the CASE macro, which is defined outside this section and expanded inside the switch.
// NOTE(review): CASE presumably refers to the locals `context`, `x_data`, `Y` and `count` and to
// the alias `CudaSrcT` by name, so those names are intentionally left unchanged - confirm
// against the macro definition before renaming any of them.
//
// Returns Status::OK() once the switch completes, or a FAIL status when the element type of
// input 1 is string, undefined, or not one of the types listed in the switch.
template <typename SrcT>
Status CastLike<SrcT>::ComputeInternal(OpKernelContext* context) const {
  using CudaSrcT = typename ToCudaType<SrcT>::MappedType;

  const Tensor* X = context->Input<Tensor>(0);
  // Input 1 only supplies the destination element type.
  const Tensor* liked = context->Input<Tensor>(1);
  const auto liked_element_type = liked->GetElementType();

  const TensorShape& shape = X->Shape();
  Tensor* Y = context->Output(0, shape);
  const auto* x_data = reinterpret_cast<const CudaSrcT*>(X->Data<SrcT>());
  size_t count = static_cast<size_t>(shape.Size());

  switch (liked_element_type) {
    CASE(TensorProto_DataType_FLOAT16, MLFloat16)
    CASE(TensorProto_DataType_BFLOAT16, BFloat16)
    CASE(TensorProto_DataType_FLOAT, float)
    CASE(TensorProto_DataType_DOUBLE, double)
    CASE(TensorProto_DataType_INT8, int8_t)
    CASE(TensorProto_DataType_INT16, int16_t)
    CASE(TensorProto_DataType_INT32, int32_t)
    CASE(TensorProto_DataType_INT64, int64_t)
    CASE(TensorProto_DataType_UINT8, uint8_t)
    CASE(TensorProto_DataType_UINT16, uint16_t)
    CASE(TensorProto_DataType_UINT32, uint32_t)
    CASE(TensorProto_DataType_UINT64, uint64_t)
    CASE(TensorProto_DataType_BOOL, bool)

    case TensorProto_DataType_STRING:
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CastLike: casting to and from strings is not supported yet.");
    case TensorProto_DataType_UNDEFINED:
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CastLike: the target_type input (input 1) has an undefined element type.");
    default:
      return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "CastLike: unexpected element type of the target_type input (input 1): ", liked_element_type, ". Search for e.g., TensorProto_DataType_FLOAT for supported element type.");
  }
  return Status::OK();
}

} // namespace cuda
} // namespace onnxruntime
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/cuda/tensor/cast_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,5 +36,12 @@ class Cast final : public CudaKernel {
bool saturate_;
};

// CUDA kernel class for the ONNX CastLike operator, templated on the element type (SrcT) of
// the tensor being cast. Unlike Cast, no destination type is stored at construction time;
// ComputeInternal (defined in cast_op.cc) takes it from the element type of the second input.
template <typename SrcT>
class CastLike final : public CudaKernel {
 public:
  // `explicit` prevents an OpKernelInfo from being implicitly converted into a kernel.
  explicit CastLike(const OpKernelInfo& info) : CudaKernel(info) {}

  // Performs the cast for one invocation of the kernel; see the definition in cast_op.cc.
  Status ComputeInternal(OpKernelContext* context) const override;
};

} // namespace cuda
} // namespace onnxruntime

0 comments on commit a4d0848

Please sign in to comment.