diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake index 4ef0584b0273e..ec021a1550d6c 100644 --- a/cmake/onnxruntime_rocm_hipify.cmake +++ b/cmake/onnxruntime_rocm_hipify.cmake @@ -48,6 +48,9 @@ set(contrib_ops_excluded_files "diffusion/group_norm_impl.cu" "diffusion/group_norm_impl.h" "diffusion/nhwc_conv.cc" + "math/gemm_float8.cc" + "math/gemm_float8.cu" + "math/gemm_float8.h" "quantization/attention_quantization.cc" "quantization/attention_quantization.h" "quantization/attention_quantization_impl.cu" diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md index 1a76c18a6a8e0..890403556cc47 100644 --- a/docs/ContribOperators.md +++ b/docs/ContribOperators.md @@ -40,6 +40,7 @@ Do not modify directly.* * com.microsoft.GatherND * com.microsoft.Gelu * com.microsoft.GemmFastGelu + * com.microsoft.GemmFloat8 * com.microsoft.GreedySearch * com.microsoft.GridSample * com.microsoft.GroupNorm @@ -2137,6 +2138,71 @@ This version of the operator has been available since version 1 of the 'com.micr +### **com.microsoft.GemmFloat8** + + Generic Gemm for float and float 8. + +#### Version + +This version of the operator has been available since version 1 of the 'com.microsoft' operator set. + +#### Attributes + +
+
activation : string
+
Activation function, RELU or GELU or NONE (default).
+
alpha : float
+
Scalar multiplier for the product of input tensors A * B.
+
beta : float
+
Scalar multiplier for the product of input bias C.
+
dtype : int
+
Output Type. Same definition as attribute 'to' for operator Cast.
+
transA : int
+
Whether A should be transposed. Float 8 only supprted transA=0.
+
transB : int
+
Whether B should be transposed. Float 8 only supprted transB=1.
+
+ +#### Inputs (2 - 6) + +
+
A : TA
+
Input tensor A. The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
+
B : TB
+
Input tensor B. The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.
+
C (optional) : TC
+
Input tensor C.
+
scaleA (optional) : TS
+
Scale of tensor A if A is float 8 tensor
+
scaleB (optional) : TS
+
Scale of tensor B if B is float 8 tensor
+
scaleY (optional) : TS
+
Scale of the output tensor if A or B is float 8.
+
+ +#### Outputs + +
+
Y : TR
+
Output tensor of shape (M, N).
+
+ +#### Type Constraints + +
+
TA : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input A.
+
TB : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input B.
+
TC : tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to input C.
+
TR : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+
Constrain type to result type.
+
TS : tensor(float)
+
Constrain type for all input scales (scaleA, scaleB, scaleY).
+
+ + ### **com.microsoft.GreedySearch** Greedy Search for text generation. diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md index d047096cb8c80..bfb7716dc5cea 100644 --- a/docs/OperatorKernels.md +++ b/docs/OperatorKernels.md @@ -831,6 +831,7 @@ Do not modify directly.* |FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)| |GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)| |Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)| +|GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)| |GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)| |GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)| |GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)| diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc index 29ca8124bfd05..e6a216795c10b 100644 --- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc +++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc @@ -144,6 +144,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1 class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention); class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8); #ifdef ENABLE_ATEN class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen); @@ -317,6 +318,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) { BuildKernelCreateInfo, BuildKernelCreateInfo, BuildKernelCreateInfo, + BuildKernelCreateInfo, #ifdef ENABLE_ATEN BuildKernelCreateInfo, diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc new file mode 100644 index 0000000000000..251850f621361 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc @@ -0,0 +1,70 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#include +#include "core/providers/cuda/math/gemm.h" +#include "core/providers/cuda/cuda_common.h" +#include "core/providers/cuda/shared_inc/fpgeneric.h" +#include "core/providers/cpu/math/gemm_helper.h" +#include "contrib_ops/cuda/math/gemm_float8.h" + +using namespace ONNX_NAMESPACE; + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +#define REGISTER_KERNEL() \ + ONNX_OPERATOR_KERNEL_EX( \ + GemmFloat8, \ + kMSDomain, \ + 1, \ + kCudaExecutionProvider, \ + (*KernelDefBuilder::Create()) \ + .TypeConstraint("TA", BuildKernelDefConstraints()) \ + .TypeConstraint("TB", BuildKernelDefConstraints()) \ + .TypeConstraint("TR", BuildKernelDefConstraints()) \ + .TypeConstraint("TS", BuildKernelDefConstraints()), \ + GemmFloat8); + +REGISTER_KERNEL() + +GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) { + transA_ = info.GetAttrOrDefault("transA", 0); + transB_ = info.GetAttrOrDefault("transB", 0); + dtype_ = info.GetAttrOrDefault("dtype", ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + auto& device_prop = GetDeviceProp(); + sm_count_ = device_prop.multiProcessorCount; + alpha_ = info.GetAttrOrDefault("alpha", 1); + beta_ = info.GetAttrOrDefault("beta", 0); + +#if (CUDA_VERSION <= 12000) + ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0."); +#endif + + std::string stemp = info.GetAttrOrDefault("activation", "NONE"); + if (stemp == "NONE") { + epilogue_ = CUBLASLT_EPILOGUE_DEFAULT; + } else if (stemp == "RELU") { + epilogue_ = CUBLASLT_EPILOGUE_RELU; + } else if (stemp == "GELU") { + epilogue_ = CUBLASLT_EPILOGUE_GELU; + } else { + ORT_THROW("Unexpected value for activation: '", stemp, "'."); + } +} + +Status GemmFloat8::SetCheck(const TensorShape& a_shape, const TensorShape& b_shape, int& M, int& N, int& K) const { + GemmHelper helper(a_shape, transA_, b_shape, transB_, TensorShape({})); + if (!helper.State().IsOK()) + return helper.State(); + + M = 
gsl::narrow_cast(helper.M()); + N = gsl::narrow_cast(helper.N()); + K = gsl::narrow_cast(helper.K()); + return helper.State(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu new file mode 100644 index 0000000000000..df25342342cd5 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -0,0 +1,402 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. +// +// The operator calls function 'cublasLtMatmul' +// (https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul). +// It lets the function checks what configuration is valid or not. If not, the error message +// shows the error message 'CUBLAS_STATUS_NOT_SUPPORTED'. NVIDIA documentation provides +// information on what attribute or type must be modified. +// This operator requires CUDA_VERSION >= 11.8 for float 8 and CUDA_VERSION >= 12.0 +// for beta != 0. + +#include +#include +#include +#include "contrib_ops/cuda/math/gemm_float8.h" +#include "core/providers/cuda/cu_inc/common.cuh" +#include "core/providers/cuda/shared_inc/cuda_utils.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// It must exist somewhere already. +int32_t TypeSize(int32_t element_type) { + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return 4; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return 2; +#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + return 1; +#endif + default: + ORT_THROW("Unexpected element_type=", element_type, "."); + } +} + +void GemmFloat8::SetParams(const TensorShape& a_shape, const TensorShape& b_shape, + int& M, int& N, int& K, int& lda, int& ldb, int& ldd) const { + int m_idx = transA_ ? 1 : 0; + int k_idx = 1 - m_idx; + int n_idx = transB_ ? 0 : 1; + + M = static_cast(a_shape[m_idx]); + K = static_cast(a_shape[k_idx]); + N = static_cast(b_shape[n_idx]); + lda = static_cast(a_shape[1]); + ldb = static_cast(b_shape[1]); + ldd = static_cast(b_shape[n_idx]); +} + +template +int32_t GetTypeAndShape(const TValue* input, + TensorShape& shape, + bool swap = false) { + shape = input->Shape(); + ORT_ENFORCE(shape.NumDimensions() == 2); + if (swap) { + std::swap(shape[0], shape[1]); + } + return input->GetElementType(); +} + +Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { + const Tensor* input_A = nullptr; + const Tensor* input_B = nullptr; + const Tensor* input_C = nullptr; + const Tensor* scale_A = nullptr; + const Tensor* scale_B = nullptr; + const Tensor* scale_Y = nullptr; + bool has_scales = false; + bool has_bias = false; + int n_inputs = ctx->InputCount(); + + input_A = ctx->Input(0); + input_B = ctx->Input(1); + if (n_inputs == 3) { + input_C = ctx->Input(2); + has_bias = true; + } else if (n_inputs > 3) { + ORT_ENFORCE(n_inputs >= 5, "Unexpected number of inputs=", n_inputs, "."); + has_scales = true; + scale_A = ctx->Input(3); + scale_B = ctx->Input(4); + scale_Y = n_inputs < 6 ? 
nullptr : ctx->Input(5); + ORT_ENFORCE(scale_A->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ORT_ENFORCE(scale_B->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + ORT_ENFORCE(scale_Y == nullptr || scale_Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT); + if (ctx->Input(2) != nullptr) { + input_C = ctx->Input(2); + has_bias = true; + ORT_ENFORCE(input_C->GetElementType() == dtype_, "Bias type must be equal to dtype."); + } + } + + auto first_type = input_A->GetElementType(); + bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; + if (!is_float8) + return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, + input_C, scale_A, scale_B, scale_Y); + return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, + input_C, scale_A, scale_B, scale_Y); +} + +Status GemmFloat8::ComputeRowMajor( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + const Tensor* input_A, const Tensor* input_B, + const Tensor* input_C, const Tensor* scale_A, + const Tensor* scale_B, const Tensor* scale_Y) const { + TensorShape shape_A, shape_B, shape_C, shape_Y; + int32_t dtype_A, dtype_B, dtype_C, dtype_Y; + dtype_A = GetTypeAndShape(input_A, shape_A); + dtype_B = GetTypeAndShape(input_B, shape_B); + + int M, N, K, lda, ldb, ldd; + SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd); + + TensorShape dimensions{M, N}; + Tensor* Y = ctx->Output(0, dimensions); + dtype_Y = GetTypeAndShape(Y, shape_Y); + dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C) + : ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_A, dtype_B, dtype_C, + dtype_Y, shape_A, shape_B, shape_C, shape_Y, transA_, transB_, + input_A->DataRaw(), input_B->DataRaw(), + has_bias ? input_C->DataRaw() : nullptr, + has_scales ? scale_A->DataRaw() : nullptr, + has_scales ? scale_B->DataRaw() : nullptr, + has_scales && scale_Y != nullptr ? scale_Y->DataRaw() : nullptr, + Y->MutableDataRaw(), M, N, K, lda, ldb, ldd, true); +} + +Status GemmFloat8::ComputeColMajor( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + const Tensor* input_A, const Tensor* input_B, + const Tensor* input_C, const Tensor* scale_A, + const Tensor* scale_B, const Tensor* scale_Y) const { + TensorShape shape_A, shape_B, shape_C, shape_Y; + int32_t dtype_A, dtype_B, dtype_C, dtype_Y; + dtype_A = GetTypeAndShape(input_A, shape_A); + dtype_B = GetTypeAndShape(input_B, shape_B); + + int M, N, K, lda, ldb, ldd; + SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd); + + std::swap(shape_A[0], shape_A[1]); + std::swap(shape_B[0], shape_B[1]); + + TensorShape dimensions{M, N}; + Tensor* Y = ctx->Output(0, dimensions); + dtype_Y = GetTypeAndShape(Y, shape_Y); + dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C, true) + : ONNX_NAMESPACE::TensorProto_DataType_FLOAT; + + return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_B, dtype_A, dtype_C, + dtype_Y, shape_B, shape_A, shape_C, shape_Y, transB_, transA_, + input_B->DataRaw(), input_A->DataRaw(), + has_bias ? input_C->DataRaw() : nullptr, + has_scales ? scale_B->DataRaw() : nullptr, + has_scales ? scale_A->DataRaw() : nullptr, + has_scales && scale_Y != nullptr ? 
scale_Y->DataRaw() : nullptr, + Y->MutableDataRaw(), N, M, K, ldb, lda, ldd, false); +} + +Status GemmFloat8::ComputeGemm( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_B, + int32_t dtype_C, int32_t dtype_Y, + const TensorShape& shape_A, const TensorShape& shape_B, + const TensorShape& shape_C, const TensorShape& shape_Y, + bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute) const { + cudaStream_t stream = Stream(ctx); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream)); + + cublasLtHandle_t cublasLt; + CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&cublasLt)); + + cublasLtMatmulDesc_t operationDesc = nullptr; + cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr, + Ddesc = nullptr; + + // Create matrix descriptors. Not setting any extra attributes. + cudaDataType_t a_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_A); + cudaDataType_t b_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_B); + cudaDataType_t d_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_Y); + cudaDataType_t scale_cuda_type = + onnxruntime::cuda::ToCudaDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT); + cudaDataType_t bias_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_C); + + cublasComputeType_t compute_type; + switch (d_cuda_type) { + case CUDA_R_16F: + switch (a_cuda_type) { + case CUDA_R_8F_E4M3: + case CUDA_R_8F_E5M2: + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + break; + default: + compute_type = CUBLAS_COMPUTE_32F_FAST_16F; + break; + } + break; + case CUDA_R_16BF: + compute_type = CUBLAS_COMPUTE_32F_FAST_16BF; + break; + case CUDA_R_32F: + compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; + break; + default: + ORT_THROW("Unable to determine computeType in operator GemmFloat8."); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Adesc, a_cuda_type, trans_A ? K : M, trans_A ? M : K, lda)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate( + &Bdesc, b_cuda_type, trans_B ? N : K, trans_B ? K : N, ldb)); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Ddesc, d_cuda_type, M, N, ldd)); + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + CUBLAS_RETURN_IF_ERROR( + cublasLtMatmulDescCreate(&operationDesc, compute_type, scale_cuda_type)); + cublasOperation_t ctransa = trans_A ? CUBLAS_OP_T : CUBLAS_OP_N; + cublasOperation_t ctransb = trans_B ? 
CUBLAS_OP_T : CUBLAS_OP_N; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &ctransa, sizeof(ctransa))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb))); + + if (sm_count_ != 0) { + int math_sm_count = static_cast(sm_count_); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, + sizeof(math_sm_count))); + } + + if (has_scales) { + // gemm float 8 + const int8_t ifast_accumulation_mode = 1; + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, + cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &ifast_accumulation_mode, sizeof(ifast_accumulation_mode))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &p_scale_a, + sizeof(p_scale_a))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &p_scale_b, + sizeof(p_scale_b))); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y, + sizeof(p_scale_b))); + + // float 8 +#if CUDA_VERSION >= 11080 + if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || + dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) { + // For FP8 output, cuBLAS requires C_type to be same as bias_type + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, bias_cuda_type, M, N, ldd)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( + operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_cuda_type, + sizeof(bias_cuda_type))); + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } + } else { + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); + } +#else + // An output is still needed but it is not initialized. + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd)); +#endif + + if (row_major_compute) { + cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW; + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + CUBLAS_RETURN_IF_ERROR( + cublasLtMatrixLayoutSetAttribute(Ddesc, CUBLASLT_MATRIX_LAYOUT_ORDER, + &matrixOrder, sizeof(matrixOrder))); + } + + cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE, + &epilogue_, sizeof(epilogue_)); + + // See + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulPreferenceAttributes_t#cublasltmatmulpreferenceattributes-t + // The workspace should be allocated once from OpKernelContext assuming + // only one cuda function is running at a time (which is not necessarily true + // with H100). 
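+  // A possible refinement (sketch only, not wired up in this change): request the
+  // workspace from the kernel's scratch allocator instead of cudaMalloc/cudaFree on
+  // every call, e.g.
+  //   auto workspace_buffer = GetScratchBuffer<void>(workspaceSize, ctx->GetComputeStream());
+  //   void* workspace = workspace_buffer.get();
+  // The GetScratchBuffer/GetComputeStream usage above is an assumption about the
+  // current CudaKernel helpers, not something exercised by this change.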
+ size_t workspaceSize = static_cast(1 << 25); // suggested fixed value 32Mb + cublasLtMatmulPreference_t preference = nullptr; + cublasLtMatmulPreferenceCreate(&preference); + cublasLtMatmulPreferenceSetAttribute(preference, + CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES, + &workspaceSize, sizeof(workspaceSize)); + + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulAlgoGetHeuristic#cublasltmatmulalgogetheuristic + cublasLtMatmulHeuristicResult_t heuristicResult = {}; + int returnedResults = 0; + cublasStatus_t cuda_status = cublasLtMatmulAlgoGetHeuristic( + cublasLt, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1, + &heuristicResult, &returnedResults); + ORT_ENFORCE( + returnedResults > 0 && cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to find any suitable algorithm due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, + ", alpha=", alpha_, ", beta=", beta_, ", n_inputs=", n_inputs, + ", A_type=", onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", C_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", + shape_B[1], ", shape_C=", (shape_C.NumDimensions() > 0 ? shape_C[0] : 0), "x", + (shape_C.NumDimensions() > 1 ? shape_C[1] : 0), ", M=", M, ", N=", N, ", K=", K, + ", lda=", lda, ", ldb=", ldb, ", ldd=", ldd, + ", workspaceSize=", workspaceSize, ", rowMajorCompute=", (row_major_compute ? 1 : 0), + ". Check NVIDIA documentation to see what combination is valid: ", + "https://docs.nvidia.com/cuda/cublas/" + "index.html?highlight=cublasLtMatmulAlgoGetHeuristic#" + "cublasltmatmulalgogetheuristic."); + + void* workspace = nullptr; + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaMalloc(reinterpret_cast(&workspace), workspaceSize)); + } + // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul + const void* bias = has_bias ? 
p_input_c : p_output_y; + cuda_status = cublasLtMatmul( + cublasLt, operationDesc, static_cast(&alpha_), /* alpha */ + p_input_a, /* A */ + Adesc, p_input_b, /* B */ + Bdesc, static_cast(&beta_), /* beta */ + bias, /* C */ + Cdesc, p_output_y, /* Y */ + Ddesc, &heuristicResult.algo, /* algo */ + workspace, /* workspace */ + workspaceSize, stream); /* stream */ + ORT_ENFORCE( + cuda_status == CUBLAS_STATUS_SUCCESS, + " Unable to run cublasLtMatmul due to ", + onnxruntime::cuda::cublasGetErrorEnum(cuda_status), + ", returnedResults=", returnedResults, ", alpha=", alpha_, + ", n_inputs=", n_inputs, ", A_type=", + onnxruntime::cuda::CudaDataTypeToString(a_cuda_type), + ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type), + ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type), + ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type), + ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type), + ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type), + ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A, + ", transB=", trans_B, + ", fastAccumulationMode=", 1, + ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x", + shape_B[1], ", M=", M, ", N=", N, ", K=", K, ", lda=", lda, ", ldb=", ldb, + ", ldd=", ldd, ", workspaceSize=", workspaceSize, + ", rowMajorCompute=", (row_major_compute ? 1 : 0), "."); + + if (workspaceSize > 0) { + CUDA_RETURN_IF_ERROR(cudaFree(workspace)); + } + + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulPreferenceDestroy(preference)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Ddesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Cdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Bdesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Adesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescDestroy(operationDesc)); + CUBLAS_RETURN_IF_ERROR(cublasLtDestroy(cublasLt)); + return Status::OK(); +} + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.h b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h new file mode 100644 index 0000000000000..e84ccd55b2003 --- /dev/null +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h @@ -0,0 +1,65 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT License. + +#pragma once + +#include "cublas_v2.h" +#include "core/providers/cuda/cuda_kernel.h" + +namespace onnxruntime { +namespace contrib { +namespace cuda { + +// Calls https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul. 
+// D = alpha*(A*B) +class GemmFloat8 final : public onnxruntime::cuda::CudaKernel { + public: + GemmFloat8(const OpKernelInfo& info); + + Status ComputeInternal(OpKernelContext* context) const override; + + private: + void SetParams(const TensorShape& shape_a, + const TensorShape& shape_b, + int& M, int& N, int& K, + int& lda, int& ldb, int& ldd) const; + Status SetCheck(const TensorShape& shape_a, + const TensorShape& shape_b, + int& M, int& N, int& K) const; + + Status ComputeRowMajor(OpKernelContext* ctx, int n_inputs, bool has_bias, + bool has_scales, const Tensor* input_A, + const Tensor* input_B, const Tensor* input_C, + const Tensor* scale_A, const Tensor* scale_B, + const Tensor* scale_Y) const; + Status ComputeColMajor(OpKernelContext* ctx, int n_inputs, bool has_bias, + bool has_scales, const Tensor* input_A, + const Tensor* input_B, const Tensor* input_C, + const Tensor* scale_A, const Tensor* scale_B, + const Tensor* scale_Y) const; + + Status ComputeGemm( + OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales, + int32_t dtype_A, int32_t dtype_b, + int32_t dtype_c, int32_t dtype_Y, + const TensorShape& shape_A, const TensorShape& shape_B, + const TensorShape& shape_C, const TensorShape& shape_Y, + bool transa, bool transb, const void* p_input_a, const void* p_input_b, + const void* p_input_c, const void* p_scale_a, const void* p_scale_b, + const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda, + int ldb, int ldd, bool row_major_compute) const; + + float alpha_; + float beta_; + bool transA_; + bool transB_; + int64_t sm_count_; + int64_t dtype_; + cublasLtEpilogue_t epilogue_; + + // TODO(xadupre): add epilogue (= activation function, Relu or Gelu are available). +}; + +} // namespace cuda +} // namespace contrib +} // namespace onnxruntime diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc index 681a728f823da..e757e39130d39 100644 --- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc +++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc @@ -2573,6 +2573,124 @@ ONNX_MS_OPERATOR_SET_SCHEMA(CropAndResize, 1, a fixed size = [crop_height, crop_width]. The result is a 4-D tensor [num_boxes, crop_height, crop_width, depth]. The resizing is corner aligned.)DOC")); +#if !defined(DISABLE_FLOAT8_TYPES) +#define GEMM_FLOAT8_TYPES \ + { "tensor(float8e4m3fn)", "tensor(float8e5m2)", "tensor(float16)", "tensor(bfloat16)", "tensor(float)" } +#else +#define GEMM_FLOAT8_TYPES \ + { "tensor(float16)", "tensor(bfloat16)", "tensor(float)" } +#endif + +ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1, + OpSchema() + .SetDoc(R"DOC(Generic Gemm for float and float 8.)DOC") + .Attr( + "transA", + "Whether A should be transposed. Float 8 only supprted transA=0.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "transB", + "Whether B should be transposed. Float 8 only supprted transB=1.", + AttributeProto::INT, + static_cast(0)) + .Attr( + "alpha", + "Scalar multiplier for the product of input tensors A * B.", + AttributeProto::FLOAT, + 1.0f) + .Attr( + "beta", + "Scalar multiplier for the product of input bias C.", + AttributeProto::FLOAT, + 0.0f) + .Attr( + "dtype", + "Output Type. Same definition as attribute 'to' for operator Cast.", + AttributeProto::INT, + static_cast(1)) + .Attr( + "activation", + "Activation function, RELU or GELU or NONE (default).", + AttributeProto::STRING, + OPTIONAL_VALUE) + .Input( + 0, + "A", + "Input tensor A. 
" + "The shape of A should be (M, K) if transA is 0, " + "or (K, M) if transA is non-zero.", + "TA") + .Input( + 1, + "B", + "Input tensor B. " + "The shape of B should be (K, N) if transB is 0, " + "or (N, K) if transB is non-zero.", + "TB") + .Input( + 2, + "C", + "Input tensor C.", + "TC", + OpSchema::Optional) + .Input( + 3, + "scaleA", + "Scale of tensor A if A is float 8 tensor", + "TS", + OpSchema::Optional) + .Input( + 4, + "scaleB", + "Scale of tensor B if B is float 8 tensor", + "TS", + OpSchema::Optional) + .Input( + 5, + "scaleY", + "Scale of the output tensor if A or B is float 8.", + "TS", + OpSchema::Optional) + .Output(0, "Y", "Output tensor of shape (M, N).", "TR") + .TypeConstraint( + "TA", + GEMM_FLOAT8_TYPES, + "Constrain type to input A.") + .TypeConstraint( + "TB", + GEMM_FLOAT8_TYPES, + "Constrain type to input B.") + .TypeConstraint( + "TC", + {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"}, + "Constrain type to input C.") + .TypeConstraint( + "TR", + GEMM_FLOAT8_TYPES, + "Constrain type to result type.") + .TypeConstraint("TS", {"tensor(float)"}, + "Constrain type for all input scales (scaleA, scaleB, scaleY).") + .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) { + propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT); + if (!hasNInputShapes(ctx, 2)) { + return; + } + auto transAAttr = ctx.getAttribute("transA"); + bool transA = transAAttr ? static_cast(transAAttr->i()) != 0 : false; + auto transBAttr = ctx.getAttribute("transB"); + bool transB = transBAttr ? static_cast(transBAttr->i()) != 0 : false; + auto& first_input_shape = getInputShape(ctx, 0); + auto& second_input_shape = getInputShape(ctx, 1); + if (first_input_shape.dim_size() != 2) { + fail_shape_inference("First input does not have rank 2"); + } + if (second_input_shape.dim_size() != 2) { + fail_shape_inference("Second input does not have rank 2"); + } + updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 
0 : 1)}); + })); + static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx, int64_t K, int64_t N) { diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h index afaa380d6ac79..aa31f3b5a7c62 100644 --- a/onnxruntime/core/graph/contrib_ops/ms_opset.h +++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h @@ -112,6 +112,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WordConvEmbedding); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFastGelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedSelfAttention); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedMultiHeadAttention); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFloat8); class OpSet_Microsoft_ver1 { public: @@ -218,6 +219,7 @@ class OpSet_Microsoft_ver1 { fn(GetOpSchema()); fn(GetOpSchema()); fn(GetOpSchema()); + fn(GetOpSchema()); } }; } // namespace contrib diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc index 57477f167c555..288ca8e97e34d 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.cc +++ b/onnxruntime/core/providers/cuda/cuda_common.cc @@ -27,5 +27,90 @@ const HalfGemmOptions* HalfGemmOptions::GetInstance() { return &instance; } +const char* cublasGetErrorEnum(cublasStatus_t error) { + switch (error) { + case CUBLAS_STATUS_SUCCESS: + return "CUBLAS_STATUS_SUCCESS"; + case CUBLAS_STATUS_NOT_INITIALIZED: + return "CUBLAS_STATUS_NOT_INITIALIZED"; + case CUBLAS_STATUS_ALLOC_FAILED: + return "CUBLAS_STATUS_ALLOC_FAILED"; + case CUBLAS_STATUS_INVALID_VALUE: + return "CUBLAS_STATUS_INVALID_VALUE"; + case CUBLAS_STATUS_ARCH_MISMATCH: + return "CUBLAS_STATUS_ARCH_MISMATCH"; + case CUBLAS_STATUS_MAPPING_ERROR: + return "CUBLAS_STATUS_MAPPING_ERROR"; + case CUBLAS_STATUS_EXECUTION_FAILED: + return "CUBLAS_STATUS_EXECUTION_FAILED"; + case CUBLAS_STATUS_INTERNAL_ERROR: + return "CUBLAS_STATUS_INTERNAL_ERROR"; + case CUBLAS_STATUS_NOT_SUPPORTED: + return "CUBLAS_STATUS_NOT_SUPPORTED"; + case CUBLAS_STATUS_LICENSE_ERROR: + return "CUBLAS_STATUS_LICENSE_ERROR"; + default: + return ""; + } +} + +const char* CudaDataTypeToString(cudaDataType_t dt) { + switch (dt) { + case CUDA_R_16F: + return "CUDA_R_16F"; + case CUDA_R_16BF: + return "CUDA_R_16BF"; + case CUDA_R_32F: + return "CUDA_R_32F"; +#if (CUDA_VERSION >= 11080) + case CUDA_R_8F_E4M3: + return "CUDA_R_8F_E4M3"; + case CUDA_R_8F_E5M2: + return "CUDA_R_8F_E5M2"; +#endif + default: + return ""; + } +} + +const char* CublasComputeTypeToString(cublasComputeType_t ct) { + switch (ct) { + case CUBLAS_COMPUTE_16F: + return "CUBLAS_COMPUTE_16F"; + case CUBLAS_COMPUTE_32F: + return "CUBLAS_COMPUTE_32F"; + case CUBLAS_COMPUTE_32F_FAST_16F: + return "CUBLAS_COMPUTE_32F_FAST_16F"; + case CUBLAS_COMPUTE_32F_FAST_16BF: + return "CUBLAS_COMPUTE_32F_FAST_16BF"; + case CUBLAS_COMPUTE_32F_FAST_TF32: + return "CUBLAS_COMPUTE_32F_FAST_TF32"; + case CUBLAS_COMPUTE_64F: + return "CUBLAS_COMPUTE_64F"; + default: + return ""; + } +} + +// It must exist somewhere already. 
+cudaDataType_t ToCudaDataType(int32_t element_type) { + switch (element_type) { + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT: + return CUDA_R_32F; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: + return CUDA_R_16F; + case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: + return CUDA_R_16BF; +#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: + return CUDA_R_8F_E4M3; + case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: + return CUDA_R_8F_E5M2; +#endif + default: + ORT_THROW("Unexpected element_type=", element_type, "."); + } +} + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h index fa258961f1155..9cd4e721ccab8 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.h +++ b/onnxruntime/core/providers/cuda/cuda_common.h @@ -11,6 +11,7 @@ #include "core/providers/shared_library/provider_api.h" #include "core/common/status.h" +#include "core/framework/float8.h" #include "core/framework/float16.h" #include "core/providers/cuda/cuda_pch.h" #include "core/providers/cuda/shared_inc/cuda_call.h" @@ -48,6 +49,33 @@ class ToCudaType { } }; +template <> +class ToCudaType { + public: + typedef BFloat16 MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + +template <> +class ToCudaType { + public: + typedef Float8E4M3FN MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + +template <> +class ToCudaType { + public: + typedef Float8E5M2 MappedType; + static MappedType FromFloat(float f) { + return MappedType(f); + } +}; + inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) { int stride = 1; if (dims.empty() || p.size() < dims.size()) @@ -152,5 +180,13 @@ class HalfGemmOptions { static HalfGemmOptions instance; }; +const char* cublasGetErrorEnum(cublasStatus_t error); + +const char* CudaDataTypeToString(cudaDataType_t dt); + +const char* CublasComputeTypeToString(cublasComputeType_t ct); + +cudaDataType_t ToCudaDataType(int32_t element_type); + } // namespace cuda } // namespace onnxruntime diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py index 272727a9f5375..ef1c46b83946a 100755 --- a/onnxruntime/python/tools/symbolic_shape_infer.py +++ b/onnxruntime/python/tools/symbolic_shape_infer.py @@ -198,6 +198,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""): "GatedRelativePositionBias": self._infer_GatedRelativePositionBias, "Gelu": self._infer_Gelu, "GemmFastGelu": self._infer_GemmFastGelu, + "GemmFloat8": self._infer_GemmFloat8, "GroupNorm": self._infer_GroupNorm, "LayerNormalization": self._infer_LayerNormalization, "LongformerAttention": self._infer_LongformerAttention, @@ -2317,6 +2318,9 @@ def _infer_QuickGelu(self, node): # noqa: N802 def _infer_GemmFastGelu(self, node): # noqa: N802 self._compute_matmul_shape(node) + def _infer_GemmFloat8(self, node): # noqa: N802 + self._compute_matmul_shape(node) + def _infer_LayerNormalization(self, node): # noqa: N802 self._propagate_shape_and_type(node) if len(node.output) > 1: diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc index de5431ca4a460..0526ccca5bb4e 100644 --- a/onnxruntime/test/onnx/main.cc +++ b/onnxruntime/test/onnx/main.cc @@ -761,6 +761,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. 
\n)"); ORT_TSTR("sce_none_weights_expanded")}; std::unordered_set> all_disabled_tests(std::begin(immutable_broken_tests), std::end(immutable_broken_tests)); + if (enable_cuda) { all_disabled_tests.insert(std::begin(cuda_flaky_tests), std::end(cuda_flaky_tests)); } diff --git a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py new file mode 100644 index 0000000000000..784ae8ce70bd8 --- /dev/null +++ b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py @@ -0,0 +1,284 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. +# pylint: disable=C0116,W0212,R1720,C0103,C0114 +# +# Note: the precision is different on V100, H100 even with the same code. +# The thresholds were adjusted on H100 as the precision seems lower on this machine. + +import itertools +import unittest +import warnings + +import numpy as np +import parameterized +from numpy.testing import assert_allclose +from onnx import TensorProto +from onnx.checker import check_model +from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info +from onnx.numpy_helper import from_array + +from onnxruntime import InferenceSession + + +class TestFloat8Gemm8(unittest.TestCase): + def get_model_gemm( + self, + float_name, + alpha=1.0, + beta=0.0, + transA=0, + transB=0, + domain="", + dtype=TensorProto.FLOAT, + activation="NONE", + ): + proto_type = getattr(TensorProto, float_name) + use_f8 = proto_type in (TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E5M2) + + a = make_tensor_value_info("A", TensorProto.FLOAT, [None, None]) + b = make_tensor_value_info("B", TensorProto.FLOAT, [None, None]) + d = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None]) + + inits = [] + kwargs = {} + node_inputs = ["Af", "Bf"] + inputs = [a, b] + bias = beta != 0 + if bias: + inputs.append(make_tensor_value_info("C", TensorProto.FLOAT, [None, None])) + node_inputs = ["Af", "Bf", "Cf"] + if use_f8: + node_inputs.extends(["one"] * 3) + elif use_f8: + node_inputs.append("") + node_inputs.extend(["one"] * 3) + + if use_f8: + assert domain == "com.microsoft" + inits.append(from_array(np.array([1], dtype=np.float32), name="one")) + kwargs = dict( + domain=domain, + dtype=dtype, + ) + if activation is not None: + kwargs["activation"] = activation + op_name = "GemmFloat8" + elif domain == "com.microsoft": + op_name = "GemmFloat8" + kwargs = dict( + domain=domain, + dtype=dtype, + ) + else: + op_name = "Gemm" + nodes = [ + make_node("Cast", ["A"], ["Af"], to=proto_type), + make_node("Cast", ["B"], ["Bf"], to=proto_type), + make_node("Cast", ["C"], ["Cf"], to=proto_type) if bias else None, + make_node( + op_name, + node_inputs, + ["Yf"], + transA=transA, + transB=transB, + alpha=alpha, + beta=beta, + **kwargs, + ), + make_node("Cast", ["Yf"], ["Y"], to=TensorProto.FLOAT), + ] + nodes = [n for n in nodes if n is not None] + graph = make_graph(nodes, "gemm", inputs, [d], inits) + onnx_model = make_model(graph, opset_imports=[make_opsetid("", 19)], ir_version=9) + if domain != "com.microsoft": + check_model(onnx_model) + return onnx_model + + def common_test_model_gemm(self, float_type, mul=0.33, atol=0, rtol=0, square=True, **kwargs): + if square: + a = (np.arange(256) * 0.01).astype(np.float32).reshape((-1, 16)) + b = (np.arange(256) * -0.01).astype(np.float32).reshape((-1, 16)) + c = (np.arange(256) * 0.03).astype(np.float32).reshape((-1, 16)) + b[:, 0] += 1 + else: + a = (np.arange(256) / 
256).astype(np.float32).reshape((32, -1)) + b = (np.arange(512) / 512).astype(np.float32).reshape((32, -1)) + c = (np.arange(128) / 128).astype(np.float32).reshape((8, 16)) + + feeds = {"A": a, "B": b} + + expected = (a.T if kwargs.get("transA", 0) else a) @ (b.T if kwargs.get("transB", 0) else b) + expected *= kwargs.get("alpha", 1.0) + if kwargs.get("beta", 0) != 0: + expected += kwargs["beta"] * c + feeds["C"] = c + + onnx_model = self.get_model_gemm("FLOAT", **kwargs) + + ref = InferenceSession( + onnx_model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + ) + y = ref.run(None, feeds)[0] + if float_type in ("FLOAT", "FLOAT16"): + try: + assert_allclose(expected, y, atol=atol, rtol=rtol) + except Exception as e: + + def check(f): + try: + return f()[:2, :2] + except Exception as e: + return str(e) + + raise AssertionError( + f"Gemm ERROR len(inputs)={len(feeds)}" + f"\na@b=\n{check(lambda:a@b)}" + f"\na.T@b=\n{check(lambda:a.T@b)}" + f"\na@b.T=\n{check(lambda:a@b.T)}" + f"\na.T@b.T=\n{check(lambda:a.T@b.T)}" + f"\n----\nb@a=\n{check(lambda:b@a)}" + f"\nb.T@a=\n{check(lambda:b.T@a)}" + f"\nb@a.T=\n{check(lambda:b@a.T)}" + f"\nb.T@a.T=\n{check(lambda:b.T@a.T)}" + f"\n----\nexpected=\n{expected[:2,:2]}" + f"\n----\ngot=\n{y[:2,:2]}" + f"\nkwargs={kwargs}" + ) from e + + self.assertEqual(expected.shape, y.shape) + self.assertEqual(expected.dtype, y.dtype) + + onnx_model_f8 = self.get_model_gemm(float_type, domain="com.microsoft", **kwargs) + try: + ref8 = InferenceSession( + onnx_model_f8.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"] + ) + except Exception as e: + if "CUDA < 12.0 does not support bias" in str(e): + return + raise AssertionError(f"Could not load model {onnx_model_f8}") from e + try: + y = ref8.run(None, feeds)[0] + except Exception as e: + if "CUBLAS_STATUS_NOT_SUPPORTED" in str(e): + # Skipping. This machine does not support float8. 
+ warnings.warn("unable to test with float8 on this machine.") + return + raise AssertionError(f"Could not execute model {onnx_model_f8}") from e + try: + assert_allclose(expected, y, atol=atol, rtol=rtol) + except Exception as e: + + def check(f): + try: + return f()[:2, :2] + except Exception as e: + return str(e) + + raise AssertionError( + f"Gemm ERROR len(inputs)={len(feeds)}" + f"\na@b=\n{check(lambda:a@b)}" + f"\na.T@b=\n{check(lambda:a.T@b)}" + f"\na@b.T=\n{check(lambda:a@b.T)}" + f"\na.T@b.T=\n{check(lambda:a.T@b.T)}" + f"\n----\nb@a=\n{check(lambda:b@a)}" + f"\nb.T@a=\n{check(lambda:b.T@a)}" + f"\nb@a.T=\n{check(lambda:b@a.T)}" + f"\nb.T@a.T=\n{check(lambda:b.T@a.T)}" + f"\n----\nexpected=\n{expected[:2,:2]}" + f"\n----\ngot=\n{y[:2,:2]}" + f"\nkwargs={kwargs}" + ) from e + self.assertEqual(expected.shape, y.shape) + self.assertEqual(expected.dtype, y.dtype) + + def test_model_gemm_float(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3) + + def test_model_gemm_float_default_values(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation=None) + + def test_model_gemm_float_relu(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="RELU") + + def test_model_gemm_float_gelu(self): + self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="GELU") + + def test_model_gemm_float_bias(self): + self.common_test_model_gemm("FLOAT", transA=1, beta=1.0, rtol=1e-3) + + def test_model_gemm_float16(self): + self.common_test_model_gemm( + "FLOAT16", + rtol=1e-2, + dtype=TensorProto.FLOAT16, + transB=1, + ) + + def test_model_gemm_float8_e4m3(self): + self.common_test_model_gemm( + "FLOAT8E4M3FN", + rtol=0.5, + dtype=TensorProto.FLOAT, + transA=0, + transB=1, + alpha=10.0, + ) + + @parameterized.parameterized.expand(list(itertools.product([0, 1], [0, 1]))) + def test_combinations_square_matrices(self, transA, transB): + self.common_test_model_gemm("FLOAT", transA=transA, transB=transB, rtol=1e-3) + + @parameterized.parameterized.expand( + [ + ((2, 3), (3, 5), 0, 0), + ((2, 3), (5, 3), 0, 1), + ((2, 3), (5, 2), 1, 1), + ((2, 3), (2, 5), 1, 0), + ] + ) + def test_combinations(self, shapeA, shapeB, transA, transB): + model = make_model( + make_graph( + [ + make_node( + "GemmFloat8", + ["A", "B"], + ["Y"], + transA=transA, + transB=transB, + domain="com.microsoft", + ) + ], + "f8", + [ + make_tensor_value_info("A", TensorProto.FLOAT, [None, None]), + make_tensor_value_info("B", TensorProto.FLOAT, [None, None]), + ], + [make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])], + ) + ) + + sess = InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"]) + a = np.arange(np.prod(shapeA)).reshape(shapeA).astype(np.float32) + b = np.arange(np.prod(shapeB)).reshape(shapeB).astype(np.float32) + try: + expected = (a.T if transA else a) @ (b.T if transB else b) + except Exception as e: + raise AssertionError( + f"Unable to multiply shapes={shapeA}x{shapeB}, transA={transA}, transB={transB}" + ) from e + try: + got = sess.run(None, {"A": a, "B": b}) + except Exception as e: + raise AssertionError( + f"Unable to run Gemm with shapes={shapeA}x{shapeB}, transA={transA}, transB={transB}" + ) from e + self.assertEqual(expected.shape, got[0].shape) + self.assertEqual(expected.dtype, got[0].dtype) + assert_allclose(expected, got[0]) + + +if __name__ == "__main__": + # TestFloat8Gemm8().test_model_gemm_float() + unittest.main(verbosity=2)
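For reference, on the non float 8 path the operator's contract from the schema above reduces to Y = activation(alpha * A' @ B' + beta * C), where A' and B' are the optionally transposed inputs. The NumPy sketch below illustrates that contract only; it does not model the float 8 quantization driven by scaleA/scaleB/scaleY inside cublasLtMatmul, and the tanh-based GELU is an assumption about the epilogue rather than something the schema specifies.

import numpy as np


def gemm_float8_reference(A, B, C=None, alpha=1.0, beta=0.0, transA=0, transB=0, activation="NONE"):
    """Illustrative reference semantics for GemmFloat8 on non float 8 inputs."""
    a = A.T if transA else A
    b = B.T if transB else B
    y = alpha * (a @ b)
    if C is not None and beta != 0.0:
        y = y + beta * C
    if activation == "RELU":
        y = np.maximum(y, 0.0)
    elif activation == "GELU":
        # tanh approximation of GELU (assumed; the exact cublasLt epilogue may differ)
        y = 0.5 * y * (1.0 + np.tanh(np.sqrt(2.0 / np.pi) * (y + 0.044715 * y**3)))
    return y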