diff --git a/cmake/onnxruntime_rocm_hipify.cmake b/cmake/onnxruntime_rocm_hipify.cmake
index 4ef0584b0273e..ec021a1550d6c 100644
--- a/cmake/onnxruntime_rocm_hipify.cmake
+++ b/cmake/onnxruntime_rocm_hipify.cmake
@@ -48,6 +48,9 @@ set(contrib_ops_excluded_files
"diffusion/group_norm_impl.cu"
"diffusion/group_norm_impl.h"
"diffusion/nhwc_conv.cc"
+ "math/gemm_float8.cc"
+ "math/gemm_float8.cu"
+ "math/gemm_float8.h"
"quantization/attention_quantization.cc"
"quantization/attention_quantization.h"
"quantization/attention_quantization_impl.cu"
diff --git a/docs/ContribOperators.md b/docs/ContribOperators.md
index 1a76c18a6a8e0..890403556cc47 100644
--- a/docs/ContribOperators.md
+++ b/docs/ContribOperators.md
@@ -40,6 +40,7 @@ Do not modify directly.*
* com.microsoft.GatherND
* com.microsoft.Gelu
* com.microsoft.GemmFastGelu
+ * com.microsoft.GemmFloat8
* com.microsoft.GreedySearch
* com.microsoft.GridSample
* com.microsoft.GroupNorm
@@ -2137,6 +2138,71 @@ This version of the operator has been available since version 1 of the 'com.micr
+### **com.microsoft.GemmFloat8**
+
+ Generic Gemm for float and float 8.
+
+#### Version
+
+This version of the operator has been available since version 1 of the 'com.microsoft' operator set.
+
+#### Attributes
+
+
+- activation : string
+- Activation function, RELU or GELU or NONE (default).
+- alpha : float
+- Scalar multiplier for the product of input tensors A * B.
+- beta : float
+- Scalar multiplier for input bias C.
+- dtype : int
+- Output Type. Same definition as attribute 'to' for operator Cast.
+- transA : int
+- Whether A should be transposed. Float 8 only supports transA=0.
+- transB : int
+- Whether B should be transposed. Float 8 only supports transB=1.
+
+
+#### Inputs (2 - 6)
+
+
+- A : TA
+- Input tensor A. The shape of A should be (M, K) if transA is 0, or (K, M) if transA is non-zero.
+- B : TB
+- Input tensor B. The shape of B should be (K, N) if transB is 0, or (N, K) if transB is non-zero.
+- C (optional) : TC
+- Input tensor C.
+- scaleA (optional) : TS
+- Scale of tensor A if A is a float 8 tensor.
+- scaleB (optional) : TS
+- Scale of tensor B if B is a float 8 tensor.
+- scaleY (optional) : TS
+- Scale of the output tensor if A or B is float 8.
+
+
+#### Outputs
+
+
+- Y : TR
+- Output tensor of shape (M, N).
+
+
+#### Type Constraints
+
+
+- TA : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+- Constrain type to input A.
+- TB : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+- Constrain type to input B.
+- TC : tensor(float16), tensor(bfloat16), tensor(float)
+- Constrain type to input C.
+- TR : tensor(float8e4m3fn), tensor(float8e5m2), tensor(float16), tensor(bfloat16), tensor(float)
+- Constrain type to result type.
+- TS : tensor(float)
+- Constrain type for all input scales (scaleA, scaleB, scaleY).
+
+
+
### **com.microsoft.GreedySearch**
Greedy Search for text generation.
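The schema documented above can be exercised from Python with the standard onnx helpers. The following is a minimal sketch, not part of the change: the float 8 type, scale initializers, and attribute values are illustrative assumptions, and only model construction is shown (running the model requires a CUDA build of onnxruntime and a GPU with float 8 support).

```python
# Minimal sketch (assumptions: onnx >= 1.14 for float 8 types; names and values are
# illustrative). A and B are cast to FLOAT8E4M3FN, then multiplied with GemmFloat8.
import numpy as np
from onnx import TensorProto
from onnx.helper import (
    make_graph,
    make_model,
    make_node,
    make_opsetid,
    make_tensor_value_info,
)
from onnx.numpy_helper import from_array

gemm = make_node(
    "GemmFloat8",
    ["A8", "B8", "", "scale_a", "scale_b"],  # C is skipped with "", scaleY is optional
    ["Y"],
    domain="com.microsoft",
    transA=0,                   # float 8 only supports transA=0
    transB=1,                   # float 8 only supports transB=1
    dtype=TensorProto.FLOAT16,  # output type, same values as attribute 'to' of Cast
)
graph = make_graph(
    [
        make_node("Cast", ["A"], ["A8"], to=TensorProto.FLOAT8E4M3FN),
        make_node("Cast", ["B"], ["B8"], to=TensorProto.FLOAT8E4M3FN),
        gemm,
    ],
    "gemm_float8_example",
    [
        make_tensor_value_info("A", TensorProto.FLOAT, ["M", "K"]),
        make_tensor_value_info("B", TensorProto.FLOAT, ["N", "K"]),  # transB=1
    ],
    [make_tensor_value_info("Y", TensorProto.FLOAT16, ["M", "N"])],
    [
        from_array(np.array([1.0], dtype=np.float32), name="scale_a"),
        from_array(np.array([1.0], dtype=np.float32), name="scale_b"),
    ],
)
model = make_model(
    graph, opset_imports=[make_opsetid("", 19), make_opsetid("com.microsoft", 1)]
)
```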
diff --git a/docs/OperatorKernels.md b/docs/OperatorKernels.md
index d047096cb8c80..bfb7716dc5cea 100644
--- a/docs/OperatorKernels.md
+++ b/docs/OperatorKernels.md
@@ -831,6 +831,7 @@ Do not modify directly.*
|FusedMatMul|*in* A:**T**
*in* B:**T**
*out* Y:**T**|1+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
|GatedRelativePositionBias|*in* query_layer:**T**
*in* query_bias:**T**
*in* rel_pos:**T**
*in* weight:**T**
*in* bias:**T**
*in* eco_a:**T**
*in* token_offset:**M**
*out* output:**T**|1+|**T** = tensor(float), tensor(float16)|
|Gelu|*in* X:**T**
*out* Y:**T**|1+|**T** = tensor(double), tensor(float), tensor(float16)|
+|GemmFloat8|*in* A:**TA**
*in* B:**TB**
*in* C:**TC**
*in* scaleA:**TS**
*in* scaleB:**TS**
*in* scaleY:**TS**
*out* Y:**TR**|1+|**TA** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TB** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TR** = tensor(bfloat16), tensor(float), tensor(float16), tensor(float8e4m3fn), tensor(float8e5m2)
**TS** = tensor(float)|
|GreedySearch|*in* input_ids:**I**
*in* max_length:**I**
*in* min_length:**I**
*in* repetition_penalty:**T**
*in* vocab_mask:**I**
*in* prefix_vocab_mask:**I**
*in* attention_mask:**I**
*out* sequences:**I**|1+|**T** = tensor(float), tensor(float16)|
|GridSample|*in* X:**T1**
*in* Grid:**T1**
*out* Y:**T2**|1+|**T1** = tensor(float)
**T2** = tensor(float)|
|GroupNorm|*in* X:**T**
*in* gamma:**M**
*in* beta:**M**
*out* Y:**T**|1+|**T** = tensor(float), tensor(float16)|
diff --git a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
index 29ca8124bfd05..e6a216795c10b 100644
--- a/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
+++ b/onnxruntime/contrib_ops/cuda/cuda_contrib_kernels.cc
@@ -144,6 +144,7 @@ class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedSelfAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, float, DecoderMaskedMultiHeadAttention);
class ONNX_OPERATOR_TYPED_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, MLFloat16, DecoderMaskedMultiHeadAttention);
+class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8);
#ifdef ENABLE_ATEN
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kPytorchAtenDomain, 1, ATen);
@@ -317,6 +318,7 @@ Status RegisterCudaContribKernels(KernelRegistry& kernel_registry) {
BuildKernelCreateInfo,
BuildKernelCreateInfo,
BuildKernelCreateInfo,
+ BuildKernelCreateInfo<ONNX_OPERATOR_KERNEL_CLASS_NAME(kCudaExecutionProvider, kMSDomain, 1, GemmFloat8)>,
#ifdef ENABLE_ATEN
BuildKernelCreateInfo,
diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc
new file mode 100644
index 0000000000000..251850f621361
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cc
@@ -0,0 +1,70 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#include
+#include "core/providers/cuda/math/gemm.h"
+#include "core/providers/cuda/cuda_common.h"
+#include "core/providers/cuda/shared_inc/fpgeneric.h"
+#include "core/providers/cpu/math/gemm_helper.h"
+#include "contrib_ops/cuda/math/gemm_float8.h"
+
+using namespace ONNX_NAMESPACE;
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+#define REGISTER_KERNEL() \
+ ONNX_OPERATOR_KERNEL_EX( \
+ GemmFloat8, \
+ kMSDomain, \
+ 1, \
+ kCudaExecutionProvider, \
+ (*KernelDefBuilder::Create()) \
+ .TypeConstraint("TA", BuildKernelDefConstraints()) \
+ .TypeConstraint("TB", BuildKernelDefConstraints()) \
+ .TypeConstraint("TR", BuildKernelDefConstraints()) \
+ .TypeConstraint("TS", BuildKernelDefConstraints()), \
+ GemmFloat8);
+
+REGISTER_KERNEL()
+
+GemmFloat8::GemmFloat8(const OpKernelInfo& info) : CudaKernel(info) {
+ transA_ = info.GetAttrOrDefault<int64_t>("transA", 0);
+ transB_ = info.GetAttrOrDefault<int64_t>("transB", 0);
+ dtype_ = info.GetAttrOrDefault<int64_t>("dtype", ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+ auto& device_prop = GetDeviceProp();
+ sm_count_ = device_prop.multiProcessorCount;
+ alpha_ = info.GetAttrOrDefault<float>("alpha", 1);
+ beta_ = info.GetAttrOrDefault<float>("beta", 0);
+
+#if (CUDA_VERSION <= 12000)
+ ORT_ENFORCE(beta_ == 0, "CUDA < 12.0 does not support bias, beta must be 0.");
+#endif
+
+ std::string stemp = info.GetAttrOrDefault<std::string>("activation", "NONE");
+ if (stemp == "NONE") {
+ epilogue_ = CUBLASLT_EPILOGUE_DEFAULT;
+ } else if (stemp == "RELU") {
+ epilogue_ = CUBLASLT_EPILOGUE_RELU;
+ } else if (stemp == "GELU") {
+ epilogue_ = CUBLASLT_EPILOGUE_GELU;
+ } else {
+ ORT_THROW("Unexpected value for activation: '", stemp, "'.");
+ }
+}
+
+Status GemmFloat8::SetCheck(const TensorShape& a_shape, const TensorShape& b_shape, int& M, int& N, int& K) const {
+ GemmHelper helper(a_shape, transA_, b_shape, transB_, TensorShape({}));
+ if (!helper.State().IsOK())
+ return helper.State();
+
+ M = gsl::narrow_cast<int>(helper.M());
+ N = gsl::narrow_cast<int>(helper.N());
+ K = gsl::narrow_cast<int>(helper.K());
+ return helper.State();
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
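For review purposes, here is a small NumPy sketch of the computation the schema and the attributes parsed above describe, Y = activation(alpha * op(A) @ op(B) + beta * C). This is an illustration only: the erf-based GELU is an assumption (the cuBLASLt epilogue may use an approximation), and the float 8 scaling by scaleA/scaleB/scaleY is intentionally omitted.

```python
# Reference semantics sketch (assumption: exact erf-based GELU; cuBLASLt may apply an
# approximation). Float 8 scaling by scaleA/scaleB/scaleY is intentionally omitted.
from math import erf, sqrt

import numpy as np


def gemm_float8_reference(A, B, C=None, alpha=1.0, beta=0.0,
                          transA=0, transB=0, activation="NONE"):
    # Y = activation(alpha * op(A) @ op(B) + beta * C)
    Y = alpha * ((A.T if transA else A) @ (B.T if transB else B))
    if C is not None and beta != 0.0:
        Y = Y + beta * C
    if activation == "RELU":
        Y = np.maximum(Y, 0.0)
    elif activation == "GELU":
        Y = 0.5 * Y * (1.0 + np.vectorize(erf)(Y / sqrt(2.0)))
    return Y


A = np.random.rand(4, 3).astype(np.float32)
B = np.random.rand(5, 3).astype(np.float32)
print(gemm_float8_reference(A, B, transB=1, activation="RELU").shape)  # (4, 5)
```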
diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
new file mode 100644
index 0000000000000..df25342342cd5
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
@@ -0,0 +1,402 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+//
+// The operator calls function 'cublasLtMatmul'
+// (https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul).
+// It lets the function check whether a configuration is valid. If it is not, the call
+// fails with 'CUBLAS_STATUS_NOT_SUPPORTED' and the NVIDIA documentation describes
+// which attribute or type must be modified.
+// This operator requires CUDA_VERSION >= 11.8 for float 8 and CUDA_VERSION >= 12.0
+// for beta != 0.
+
+#include
+#include
+#include
+#include "contrib_ops/cuda/math/gemm_float8.h"
+#include "core/providers/cuda/cu_inc/common.cuh"
+#include "core/providers/cuda/shared_inc/cuda_utils.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+// A similar helper likely exists elsewhere already.
+int32_t TypeSize(int32_t element_type) {
+ switch (element_type) {
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
+ return 4;
+ case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
+ return 2;
+#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080))
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
+ return 1;
+#endif
+ default:
+ ORT_THROW("Unexpected element_type=", element_type, ".");
+ }
+}
+
+void GemmFloat8::SetParams(const TensorShape& a_shape, const TensorShape& b_shape,
+ int& M, int& N, int& K, int& lda, int& ldb, int& ldd) const {
+ int m_idx = transA_ ? 1 : 0;
+ int k_idx = 1 - m_idx;
+ int n_idx = transB_ ? 0 : 1;
+
+ M = static_cast<int>(a_shape[m_idx]);
+ K = static_cast<int>(a_shape[k_idx]);
+ N = static_cast<int>(b_shape[n_idx]);
+ lda = static_cast<int>(a_shape[1]);
+ ldb = static_cast<int>(b_shape[1]);
+ ldd = static_cast<int>(b_shape[n_idx]);
+}
+
+template <typename TValue>
+int32_t GetTypeAndShape(const TValue* input,
+ TensorShape& shape,
+ bool swap = false) {
+ shape = input->Shape();
+ ORT_ENFORCE(shape.NumDimensions() == 2);
+ if (swap) {
+ std::swap(shape[0], shape[1]);
+ }
+ return input->GetElementType();
+}
+
+Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const {
+ const Tensor* input_A = nullptr;
+ const Tensor* input_B = nullptr;
+ const Tensor* input_C = nullptr;
+ const Tensor* scale_A = nullptr;
+ const Tensor* scale_B = nullptr;
+ const Tensor* scale_Y = nullptr;
+ bool has_scales = false;
+ bool has_bias = false;
+ int n_inputs = ctx->InputCount();
+
+ input_A = ctx->Input<Tensor>(0);
+ input_B = ctx->Input<Tensor>(1);
+ if (n_inputs == 3) {
+ input_C = ctx->Input<Tensor>(2);
+ has_bias = true;
+ } else if (n_inputs > 3) {
+ ORT_ENFORCE(n_inputs >= 5, "Unexpected number of inputs=", n_inputs, ".");
+ has_scales = true;
+ scale_A = ctx->Input<Tensor>(3);
+ scale_B = ctx->Input<Tensor>(4);
+ scale_Y = n_inputs < 6 ? nullptr : ctx->Input<Tensor>(5);
+ ORT_ENFORCE(scale_A->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+ ORT_ENFORCE(scale_B->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+ ORT_ENFORCE(scale_Y == nullptr || scale_Y->GetElementType() == ONNX_NAMESPACE::TensorProto_DataType_FLOAT);
+ if (ctx->Input<Tensor>(2) != nullptr) {
+ input_C = ctx->Input<Tensor>(2);
+ has_bias = true;
+ ORT_ENFORCE(input_C->GetElementType() == dtype_, "Bias type must be equal to dtype.");
+ }
+ }
+
+ auto first_type = input_A->GetElementType();
+ bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2;
+ if (!is_float8)
+ return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B,
+ input_C, scale_A, scale_B, scale_Y);
+ return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B,
+ input_C, scale_A, scale_B, scale_Y);
+}
+
+Status GemmFloat8::ComputeRowMajor(
+ OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales,
+ const Tensor* input_A, const Tensor* input_B,
+ const Tensor* input_C, const Tensor* scale_A,
+ const Tensor* scale_B, const Tensor* scale_Y) const {
+ TensorShape shape_A, shape_B, shape_C, shape_Y;
+ int32_t dtype_A, dtype_B, dtype_C, dtype_Y;
+ dtype_A = GetTypeAndShape(input_A, shape_A);
+ dtype_B = GetTypeAndShape(input_B, shape_B);
+
+ int M, N, K, lda, ldb, ldd;
+ SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd);
+
+ TensorShape dimensions{M, N};
+ Tensor* Y = ctx->Output(0, dimensions);
+ dtype_Y = GetTypeAndShape(Y, shape_Y);
+ dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C)
+ : ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
+ return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_A, dtype_B, dtype_C,
+ dtype_Y, shape_A, shape_B, shape_C, shape_Y, transA_, transB_,
+ input_A->DataRaw(), input_B->DataRaw(),
+ has_bias ? input_C->DataRaw() : nullptr,
+ has_scales ? scale_A->DataRaw() : nullptr,
+ has_scales ? scale_B->DataRaw() : nullptr,
+ has_scales && scale_Y != nullptr ? scale_Y->DataRaw() : nullptr,
+ Y->MutableDataRaw(), M, N, K, lda, ldb, ldd, true);
+}
+
+Status GemmFloat8::ComputeColMajor(
+ OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales,
+ const Tensor* input_A, const Tensor* input_B,
+ const Tensor* input_C, const Tensor* scale_A,
+ const Tensor* scale_B, const Tensor* scale_Y) const {
+ TensorShape shape_A, shape_B, shape_C, shape_Y;
+ int32_t dtype_A, dtype_B, dtype_C, dtype_Y;
+ dtype_A = GetTypeAndShape(input_A, shape_A);
+ dtype_B = GetTypeAndShape(input_B, shape_B);
+
+ int M, N, K, lda, ldb, ldd;
+ SetParams(shape_A, shape_B, M, N, K, lda, ldb, ldd);
+
+ std::swap(shape_A[0], shape_A[1]);
+ std::swap(shape_B[0], shape_B[1]);
+
+ TensorShape dimensions{M, N};
+ Tensor* Y = ctx->Output(0, dimensions);
+ dtype_Y = GetTypeAndShape(Y, shape_Y);
+ dtype_C = has_bias ? GetTypeAndShape(input_C, shape_C, true)
+ : ONNX_NAMESPACE::TensorProto_DataType_FLOAT;
+
+ return ComputeGemm(ctx, n_inputs, has_bias, has_scales, dtype_B, dtype_A, dtype_C,
+ dtype_Y, shape_B, shape_A, shape_C, shape_Y, transB_, transA_,
+ input_B->DataRaw(), input_A->DataRaw(),
+ has_bias ? input_C->DataRaw() : nullptr,
+ has_scales ? scale_B->DataRaw() : nullptr,
+ has_scales ? scale_A->DataRaw() : nullptr,
+ has_scales && scale_Y != nullptr ? scale_Y->DataRaw() : nullptr,
+ Y->MutableDataRaw(), N, M, K, ldb, lda, ldd, false);
+}
+
+Status GemmFloat8::ComputeGemm(
+ OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales,
+ int32_t dtype_A, int32_t dtype_B,
+ int32_t dtype_C, int32_t dtype_Y,
+ const TensorShape& shape_A, const TensorShape& shape_B,
+ const TensorShape& shape_C, const TensorShape& shape_Y,
+ bool trans_A, bool trans_B, const void* p_input_a, const void* p_input_b,
+ const void* p_input_c, const void* p_scale_a, const void* p_scale_b,
+ const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda,
+ int ldb, int ldd, bool row_major_compute) const {
+ cudaStream_t stream = Stream(ctx);
+ CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream));
+
+ cublasLtHandle_t cublasLt;
+ CUBLAS_RETURN_IF_ERROR(cublasLtCreate(&cublasLt));
+
+ cublasLtMatmulDesc_t operationDesc = nullptr;
+ cublasLtMatrixLayout_t Adesc = nullptr, Bdesc = nullptr, Cdesc = nullptr,
+ Ddesc = nullptr;
+
+ // Create matrix descriptors. Not setting any extra attributes.
+ cudaDataType_t a_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_A);
+ cudaDataType_t b_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_B);
+ cudaDataType_t d_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_Y);
+ cudaDataType_t scale_cuda_type =
+ onnxruntime::cuda::ToCudaDataType(ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
+ cudaDataType_t bias_cuda_type = onnxruntime::cuda::ToCudaDataType(dtype_C);
+
+ cublasComputeType_t compute_type;
+ switch (d_cuda_type) {
+ case CUDA_R_16F:
+ switch (a_cuda_type) {
+ case CUDA_R_8F_E4M3:
+ case CUDA_R_8F_E5M2:
+ compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
+ break;
+ default:
+ compute_type = CUBLAS_COMPUTE_32F_FAST_16F;
+ break;
+ }
+ break;
+ case CUDA_R_16BF:
+ compute_type = CUBLAS_COMPUTE_32F_FAST_16BF;
+ break;
+ case CUDA_R_32F:
+ compute_type = CUBLAS_COMPUTE_32F_FAST_TF32;
+ break;
+ default:
+ ORT_THROW("Unable to determine computeType in operator GemmFloat8.");
+ }
+
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate(
+ &Adesc, a_cuda_type, trans_A ? K : M, trans_A ? M : K, lda));
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutCreate(
+ &Bdesc, b_cuda_type, trans_B ? N : K, trans_B ? K : N, ldb));
+ CUBLAS_RETURN_IF_ERROR(
+ cublasLtMatrixLayoutCreate(&Ddesc, d_cuda_type, M, N, ldd));
+
+ if (row_major_compute) {
+ cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW;
+ CUBLAS_RETURN_IF_ERROR(
+ cublasLtMatrixLayoutSetAttribute(Adesc, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &matrixOrder, sizeof(matrixOrder)));
+ CUBLAS_RETURN_IF_ERROR(
+ cublasLtMatrixLayoutSetAttribute(Bdesc, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &matrixOrder, sizeof(matrixOrder)));
+ }
+
+ CUBLAS_RETURN_IF_ERROR(
+ cublasLtMatmulDescCreate(&operationDesc, compute_type, scale_cuda_type));
+ cublasOperation_t ctransa = trans_A ? CUBLAS_OP_T : CUBLAS_OP_N;
+ cublasOperation_t ctransb = trans_B ? CUBLAS_OP_T : CUBLAS_OP_N;
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
+ operationDesc, CUBLASLT_MATMUL_DESC_TRANSA, &ctransa, sizeof(ctransa)));
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
+ operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb)));
+
+ if (sm_count_ != 0) {
+ int math_sm_count = static_cast<int>(sm_count_);
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
+ operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count,
+ sizeof(math_sm_count)));
+ }
+
+ if (has_scales) {
+ // gemm float 8
+ const int8_t ifast_accumulation_mode = 1;
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
+ operationDesc,
+ cublasLtMatmulDescAttributes_t::CUBLASLT_MATMUL_DESC_FAST_ACCUM,
+ &ifast_accumulation_mode, sizeof(ifast_accumulation_mode)));
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
+ operationDesc, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, &p_scale_a,
+ sizeof(p_scale_a)));
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
+ operationDesc, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, &p_scale_b,
+ sizeof(p_scale_b)));
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
+ operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y,
+ sizeof(p_scale_y)));
+
+ // float 8
+#if CUDA_VERSION >= 11080
+ if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN ||
+ dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) {
+ // For FP8 output, cuBLAS requires the C type to be the same as the bias type.
+ CUBLAS_RETURN_IF_ERROR(
+ cublasLtMatrixLayoutCreate(&Cdesc, bias_cuda_type, M, N, ldd));
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
+ operationDesc, CUBLASLT_MATMUL_DESC_BIAS_DATA_TYPE, &bias_cuda_type,
+ sizeof(bias_cuda_type)));
+ } else {
+ CUBLAS_RETURN_IF_ERROR(
+ cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd));
+ }
+ } else {
+ CUBLAS_RETURN_IF_ERROR(
+ cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd));
+ }
+#else
+ // An output is still needed but it is not initialized.
+ CUBLAS_RETURN_IF_ERROR(
+ cublasLtMatrixLayoutCreate(&Cdesc, d_cuda_type, M, N, ldd));
+#endif
+
+ if (row_major_compute) {
+ cublasLtOrder_t matrixOrder = CUBLASLT_ORDER_ROW;
+ CUBLAS_RETURN_IF_ERROR(
+ cublasLtMatrixLayoutSetAttribute(Cdesc, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &matrixOrder, sizeof(matrixOrder)));
+ CUBLAS_RETURN_IF_ERROR(
+ cublasLtMatrixLayoutSetAttribute(Ddesc, CUBLASLT_MATRIX_LAYOUT_ORDER,
+ &matrixOrder, sizeof(matrixOrder)));
+ }
+
+ cublasLtMatmulDescSetAttribute(operationDesc, CUBLASLT_MATMUL_DESC_EPILOGUE,
+ &epilogue_, sizeof(epilogue_));
+
+ // See
+ // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulPreferenceAttributes_t#cublasltmatmulpreferenceattributes-t
+ // The workspace should be allocated once from OpKernelContext assuming
+ // only one cuda function is running at a time (which is not necessarily true
+ // with H100).
+ size_t workspaceSize = static_cast<size_t>(1 << 25); // suggested fixed value: 32 MB
+ cublasLtMatmulPreference_t preference = nullptr;
+ cublasLtMatmulPreferenceCreate(&preference);
+ cublasLtMatmulPreferenceSetAttribute(preference,
+ CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
+ &workspaceSize, sizeof(workspaceSize));
+
+ // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmulAlgoGetHeuristic#cublasltmatmulalgogetheuristic
+ cublasLtMatmulHeuristicResult_t heuristicResult = {};
+ int returnedResults = 0;
+ cublasStatus_t cuda_status = cublasLtMatmulAlgoGetHeuristic(
+ cublasLt, operationDesc, Adesc, Bdesc, Cdesc, Ddesc, preference, 1,
+ &heuristicResult, &returnedResults);
+ ORT_ENFORCE(
+ returnedResults > 0 && cuda_status == CUBLAS_STATUS_SUCCESS,
+ " Unable to find any suitable algorithm due to ",
+ onnxruntime::cuda::cublasGetErrorEnum(cuda_status),
+ ", returnedResults=", returnedResults,
+ ", alpha=", alpha_, ", beta=", beta_, ", n_inputs=", n_inputs,
+ ", A_type=", onnxruntime::cuda::CudaDataTypeToString(a_cuda_type),
+ ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type),
+ ", C_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type),
+ ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type),
+ ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type),
+ ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type),
+ ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type),
+ ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A,
+ ", transB=", trans_B,
+ ", fastAccumulationMode=", 1,
+ ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x",
+ shape_B[1], ", shape_C=", (shape_C.NumDimensions() > 0 ? shape_C[0] : 0), "x",
+ (shape_C.NumDimensions() > 1 ? shape_C[1] : 0), ", M=", M, ", N=", N, ", K=", K,
+ ", lda=", lda, ", ldb=", ldb, ", ldd=", ldd,
+ ", workspaceSize=", workspaceSize, ", rowMajorCompute=", (row_major_compute ? 1 : 0),
+ ". Check NVIDIA documentation to see what combination is valid: ",
+ "https://docs.nvidia.com/cuda/cublas/"
+ "index.html?highlight=cublasLtMatmulAlgoGetHeuristic#"
+ "cublasltmatmulalgogetheuristic.");
+
+ void* workspace = nullptr;
+ if (workspaceSize > 0) {
+ CUDA_RETURN_IF_ERROR(cudaMalloc(reinterpret_cast<void**>(&workspace), workspaceSize));
+ }
+ // https://docs.nvidia.com/cuda/cublas/index.html?highlight=cublasLtMatmul#cublasltmatmul
+ const void* bias = has_bias ? p_input_c : p_output_y;
+ cuda_status = cublasLtMatmul(
+ cublasLt, operationDesc, static_cast<const void*>(&alpha_), /* alpha */
+ p_input_a, /* A */
+ Adesc, p_input_b, /* B */
+ Bdesc, static_cast<const void*>(&beta_), /* beta */
+ bias, /* C */
+ Cdesc, p_output_y, /* Y */
+ Ddesc, &heuristicResult.algo, /* algo */
+ workspace, /* workspace */
+ workspaceSize, stream); /* stream */
+ ORT_ENFORCE(
+ cuda_status == CUBLAS_STATUS_SUCCESS,
+ " Unable to run cublasLtMatmul due to ",
+ onnxruntime::cuda::cublasGetErrorEnum(cuda_status),
+ ", returnedResults=", returnedResults, ", alpha=", alpha_,
+ ", n_inputs=", n_inputs, ", A_type=",
+ onnxruntime::cuda::CudaDataTypeToString(a_cuda_type),
+ ", B_type=", onnxruntime::cuda::CudaDataTypeToString(b_cuda_type),
+ ", result_type=", onnxruntime::cuda::CudaDataTypeToString(d_cuda_type),
+ ", bias_type=", onnxruntime::cuda::CudaDataTypeToString(bias_cuda_type),
+ ", scale_type=", onnxruntime::cuda::CudaDataTypeToString(scale_cuda_type),
+ ", computeType=", onnxruntime::cuda::CublasComputeTypeToString(compute_type),
+ ", epilogue=", epilogue_, ", smCount=", sm_count_, ", transA=", trans_A,
+ ", transB=", trans_B,
+ ", fastAccumulationMode=", 1,
+ ", shape_A=", shape_A[0], "x", shape_A[1], ", shape_B=", shape_B[0], "x",
+ shape_B[1], ", M=", M, ", N=", N, ", K=", K, ", lda=", lda, ", ldb=", ldb,
+ ", ldd=", ldd, ", workspaceSize=", workspaceSize,
+ ", rowMajorCompute=", (row_major_compute ? 1 : 0), ".");
+
+ if (workspaceSize > 0) {
+ CUDA_RETURN_IF_ERROR(cudaFree(workspace));
+ }
+
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatmulPreferenceDestroy(preference));
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Ddesc));
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Cdesc));
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Bdesc));
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatrixLayoutDestroy(Adesc));
+ CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescDestroy(operationDesc));
+ CUBLAS_RETURN_IF_ERROR(cublasLtDestroy(cublasLt));
+ return Status::OK();
+}
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
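ComputeColMajor above swaps A and B because cuBLASLt consumes column-major data: a row-major buffer read as a column-major matrix is the transpose, so multiplying the swapped operands yields Y transposed, which is exactly the row-major Y. A minimal NumPy check of that identity (not of the CUDA code itself):

```python
# Why ComputeColMajor swaps A and B: a row-major buffer read as a column-major matrix
# is the transpose, so multiplying the swapped operands yields Y^T, i.e. the row-major Y.
import numpy as np

M, K, N = 4, 3, 5
A = np.arange(M * K, dtype=np.float32).reshape(M, K)  # row-major A, shape (M, K)
B = np.arange(K * N, dtype=np.float32).reshape(K, N)  # row-major B, shape (K, N)

A_cm = A.T  # what a column-major GEMM sees in A's buffer: shape (K, M)
B_cm = B.T  # what a column-major GEMM sees in B's buffer: shape (N, K)

Y_cm = B_cm @ A_cm  # column-major result, shape (N, M)
assert np.allclose(Y_cm.T, A @ B)  # reading it back row-major gives A @ B
print("swapped-operand product matches A @ B")
```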
diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.h b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h
new file mode 100644
index 0000000000000..e84ccd55b2003
--- /dev/null
+++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.h
@@ -0,0 +1,65 @@
+// Copyright (c) Microsoft Corporation. All rights reserved.
+// Licensed under the MIT License.
+
+#pragma once
+
+#include "cublas_v2.h"
+#include "core/providers/cuda/cuda_kernel.h"
+
+namespace onnxruntime {
+namespace contrib {
+namespace cuda {
+
+// Calls https://docs.nvidia.com/cuda/cublas/index.html#cublasltmatmul.
+// D = alpha*(A*B)
+class GemmFloat8 final : public onnxruntime::cuda::CudaKernel {
+ public:
+ GemmFloat8(const OpKernelInfo& info);
+
+ Status ComputeInternal(OpKernelContext* context) const override;
+
+ private:
+ void SetParams(const TensorShape& shape_a,
+ const TensorShape& shape_b,
+ int& M, int& N, int& K,
+ int& lda, int& ldb, int& ldd) const;
+ Status SetCheck(const TensorShape& shape_a,
+ const TensorShape& shape_b,
+ int& M, int& N, int& K) const;
+
+ Status ComputeRowMajor(OpKernelContext* ctx, int n_inputs, bool has_bias,
+ bool has_scales, const Tensor* input_A,
+ const Tensor* input_B, const Tensor* input_C,
+ const Tensor* scale_A, const Tensor* scale_B,
+ const Tensor* scale_Y) const;
+ Status ComputeColMajor(OpKernelContext* ctx, int n_inputs, bool has_bias,
+ bool has_scales, const Tensor* input_A,
+ const Tensor* input_B, const Tensor* input_C,
+ const Tensor* scale_A, const Tensor* scale_B,
+ const Tensor* scale_Y) const;
+
+ Status ComputeGemm(
+ OpKernelContext* ctx, int n_inputs, bool has_bias, bool has_scales,
+ int32_t dtype_A, int32_t dtype_b,
+ int32_t dtype_c, int32_t dtype_Y,
+ const TensorShape& shape_A, const TensorShape& shape_B,
+ const TensorShape& shape_C, const TensorShape& shape_Y,
+ bool transa, bool transb, const void* p_input_a, const void* p_input_b,
+ const void* p_input_c, const void* p_scale_a, const void* p_scale_b,
+ const void* p_scale_y, void* p_output_y, int M, int N, int K, int lda,
+ int ldb, int ldd, bool row_major_compute) const;
+
+ float alpha_;
+ float beta_;
+ bool transA_;
+ bool transB_;
+ int64_t sm_count_;
+ int64_t dtype_;
+ cublasLtEpilogue_t epilogue_;
+
+ // TODO(xadupre): add epilogue (= activation function, Relu or Gelu are available).
+};
+
+} // namespace cuda
+} // namespace contrib
+} // namespace onnxruntime
diff --git a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
index 681a728f823da..e757e39130d39 100644
--- a/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
+++ b/onnxruntime/core/graph/contrib_ops/contrib_defs.cc
@@ -2573,6 +2573,124 @@ ONNX_MS_OPERATOR_SET_SCHEMA(CropAndResize, 1,
a fixed size = [crop_height, crop_width]. The result is a 4-D tensor [num_boxes, crop_height, crop_width, depth].
The resizing is corner aligned.)DOC"));
+#if !defined(DISABLE_FLOAT8_TYPES)
+#define GEMM_FLOAT8_TYPES \
+ { "tensor(float8e4m3fn)", "tensor(float8e5m2)", "tensor(float16)", "tensor(bfloat16)", "tensor(float)" }
+#else
+#define GEMM_FLOAT8_TYPES \
+ { "tensor(float16)", "tensor(bfloat16)", "tensor(float)" }
+#endif
+
+ONNX_MS_OPERATOR_SET_SCHEMA(GemmFloat8, 1,
+ OpSchema()
+ .SetDoc(R"DOC(Generic Gemm for float and float 8.)DOC")
+ .Attr(
+ "transA",
+ "Whether A should be transposed. Float 8 only supprted transA=0.",
+ AttributeProto::INT,
+ static_cast<int64_t>(0))
+ .Attr(
+ "transB",
+ "Whether B should be transposed. Float 8 only supprted transB=1.",
+ AttributeProto::INT,
+ static_cast<int64_t>(0))
+ .Attr(
+ "alpha",
+ "Scalar multiplier for the product of input tensors A * B.",
+ AttributeProto::FLOAT,
+ 1.0f)
+ .Attr(
+ "beta",
+ "Scalar multiplier for the product of input bias C.",
+ AttributeProto::FLOAT,
+ 0.0f)
+ .Attr(
+ "dtype",
+ "Output Type. Same definition as attribute 'to' for operator Cast.",
+ AttributeProto::INT,
+ static_cast<int64_t>(1))
+ .Attr(
+ "activation",
+ "Activation function, RELU or GELU or NONE (default).",
+ AttributeProto::STRING,
+ OPTIONAL_VALUE)
+ .Input(
+ 0,
+ "A",
+ "Input tensor A. "
+ "The shape of A should be (M, K) if transA is 0, "
+ "or (K, M) if transA is non-zero.",
+ "TA")
+ .Input(
+ 1,
+ "B",
+ "Input tensor B. "
+ "The shape of B should be (K, N) if transB is 0, "
+ "or (N, K) if transB is non-zero.",
+ "TB")
+ .Input(
+ 2,
+ "C",
+ "Input tensor C.",
+ "TC",
+ OpSchema::Optional)
+ .Input(
+ 3,
+ "scaleA",
+ "Scale of tensor A if A is float 8 tensor",
+ "TS",
+ OpSchema::Optional)
+ .Input(
+ 4,
+ "scaleB",
+ "Scale of tensor B if B is float 8 tensor",
+ "TS",
+ OpSchema::Optional)
+ .Input(
+ 5,
+ "scaleY",
+ "Scale of the output tensor if A or B is float 8.",
+ "TS",
+ OpSchema::Optional)
+ .Output(0, "Y", "Output tensor of shape (M, N).", "TR")
+ .TypeConstraint(
+ "TA",
+ GEMM_FLOAT8_TYPES,
+ "Constrain type to input A.")
+ .TypeConstraint(
+ "TB",
+ GEMM_FLOAT8_TYPES,
+ "Constrain type to input B.")
+ .TypeConstraint(
+ "TC",
+ {"tensor(float16)", "tensor(bfloat16)", "tensor(float)"},
+ "Constrain type to input C.")
+ .TypeConstraint(
+ "TR",
+ GEMM_FLOAT8_TYPES,
+ "Constrain type to result type.")
+ .TypeConstraint("TS", {"tensor(float)"},
+ "Constrain type for all input scales (scaleA, scaleB, scaleY).")
+ .TypeAndShapeInferenceFunction([](ONNX_NAMESPACE::InferenceContext& ctx) {
+ propagateElemTypeFromAttributeToOutput(ctx, "dtype", 0, TensorProto::FLOAT);
+ if (!hasNInputShapes(ctx, 2)) {
+ return;
+ }
+ auto transAAttr = ctx.getAttribute("transA");
+ bool transA = transAAttr ? static_cast<int>(transAAttr->i()) != 0 : false;
+ auto transBAttr = ctx.getAttribute("transB");
+ bool transB = transBAttr ? static_cast<int>(transBAttr->i()) != 0 : false;
+ auto& first_input_shape = getInputShape(ctx, 0);
+ auto& second_input_shape = getInputShape(ctx, 1);
+ if (first_input_shape.dim_size() != 2) {
+ fail_shape_inference("First input does not have rank 2");
+ }
+ if (second_input_shape.dim_size() != 2) {
+ fail_shape_inference("Second input does not have rank 2");
+ }
+ updateOutputShape(ctx, 0, {first_input_shape.dim(transA ? 1 : 0), second_input_shape.dim(transB ? 0 : 1)});
+ }));
+
static void MatmulWithQuantWeightShapeInference(ONNX_NAMESPACE::InferenceContext& ctx,
int64_t K,
int64_t N) {
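The TypeAndShapeInferenceFunction added above reduces to a simple output-shape rule. A short Python sketch of that rule, for illustration only (the helper name `gemm_float8_output_shape` is hypothetical and not part of the change):

```python
# Illustrative sketch of the output-shape rule implemented by the schema's
# TypeAndShapeInferenceFunction; 'gemm_float8_output_shape' is a hypothetical helper.
def gemm_float8_output_shape(shape_a, shape_b, transA=0, transB=0):
    if len(shape_a) != 2 or len(shape_b) != 2:
        raise ValueError("GemmFloat8 expects rank-2 inputs")
    m = shape_a[1] if transA else shape_a[0]
    n = shape_b[0] if transB else shape_b[1]
    return (m, n)


assert gemm_float8_output_shape((2, 3), (3, 5)) == (2, 5)
assert gemm_float8_output_shape((2, 3), (5, 3), transB=1) == (2, 5)
assert gemm_float8_output_shape((3, 2), (5, 3), transA=1, transB=1) == (2, 5)
```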
diff --git a/onnxruntime/core/graph/contrib_ops/ms_opset.h b/onnxruntime/core/graph/contrib_ops/ms_opset.h
index afaa380d6ac79..aa31f3b5a7c62 100644
--- a/onnxruntime/core/graph/contrib_ops/ms_opset.h
+++ b/onnxruntime/core/graph/contrib_ops/ms_opset.h
@@ -112,6 +112,7 @@ class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, WordConvEmbedding);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFastGelu);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedSelfAttention);
class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, DecoderMaskedMultiHeadAttention);
+class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFloat8);
class OpSet_Microsoft_ver1 {
public:
@@ -218,6 +219,7 @@ class OpSet_Microsoft_ver1 {
fn(GetOpSchema());
fn(GetOpSchema());
fn(GetOpSchema());
+ fn(GetOpSchema<ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Microsoft, 1, GemmFloat8)>());
}
};
} // namespace contrib
diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc
index 57477f167c555..288ca8e97e34d 100644
--- a/onnxruntime/core/providers/cuda/cuda_common.cc
+++ b/onnxruntime/core/providers/cuda/cuda_common.cc
@@ -27,5 +27,90 @@ const HalfGemmOptions* HalfGemmOptions::GetInstance() {
return &instance;
}
+const char* cublasGetErrorEnum(cublasStatus_t error) {
+ switch (error) {
+ case CUBLAS_STATUS_SUCCESS:
+ return "CUBLAS_STATUS_SUCCESS";
+ case CUBLAS_STATUS_NOT_INITIALIZED:
+ return "CUBLAS_STATUS_NOT_INITIALIZED";
+ case CUBLAS_STATUS_ALLOC_FAILED:
+ return "CUBLAS_STATUS_ALLOC_FAILED";
+ case CUBLAS_STATUS_INVALID_VALUE:
+ return "CUBLAS_STATUS_INVALID_VALUE";
+ case CUBLAS_STATUS_ARCH_MISMATCH:
+ return "CUBLAS_STATUS_ARCH_MISMATCH";
+ case CUBLAS_STATUS_MAPPING_ERROR:
+ return "CUBLAS_STATUS_MAPPING_ERROR";
+ case CUBLAS_STATUS_EXECUTION_FAILED:
+ return "CUBLAS_STATUS_EXECUTION_FAILED";
+ case CUBLAS_STATUS_INTERNAL_ERROR:
+ return "CUBLAS_STATUS_INTERNAL_ERROR";
+ case CUBLAS_STATUS_NOT_SUPPORTED:
+ return "CUBLAS_STATUS_NOT_SUPPORTED";
+ case CUBLAS_STATUS_LICENSE_ERROR:
+ return "CUBLAS_STATUS_LICENSE_ERROR";
+ default:
+ return "";
+ }
+}
+
+const char* CudaDataTypeToString(cudaDataType_t dt) {
+ switch (dt) {
+ case CUDA_R_16F:
+ return "CUDA_R_16F";
+ case CUDA_R_16BF:
+ return "CUDA_R_16BF";
+ case CUDA_R_32F:
+ return "CUDA_R_32F";
+#if (CUDA_VERSION >= 11080)
+ case CUDA_R_8F_E4M3:
+ return "CUDA_R_8F_E4M3";
+ case CUDA_R_8F_E5M2:
+ return "CUDA_R_8F_E5M2";
+#endif
+ default:
+ return "";
+ }
+}
+
+const char* CublasComputeTypeToString(cublasComputeType_t ct) {
+ switch (ct) {
+ case CUBLAS_COMPUTE_16F:
+ return "CUBLAS_COMPUTE_16F";
+ case CUBLAS_COMPUTE_32F:
+ return "CUBLAS_COMPUTE_32F";
+ case CUBLAS_COMPUTE_32F_FAST_16F:
+ return "CUBLAS_COMPUTE_32F_FAST_16F";
+ case CUBLAS_COMPUTE_32F_FAST_16BF:
+ return "CUBLAS_COMPUTE_32F_FAST_16BF";
+ case CUBLAS_COMPUTE_32F_FAST_TF32:
+ return "CUBLAS_COMPUTE_32F_FAST_TF32";
+ case CUBLAS_COMPUTE_64F:
+ return "CUBLAS_COMPUTE_64F";
+ default:
+ return "";
+ }
+}
+
+// A similar helper likely exists elsewhere already.
+cudaDataType_t ToCudaDataType(int32_t element_type) {
+ switch (element_type) {
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT:
+ return CUDA_R_32F;
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16:
+ return CUDA_R_16F;
+ case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16:
+ return CUDA_R_16BF;
+#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080))
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN:
+ return CUDA_R_8F_E4M3;
+ case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2:
+ return CUDA_R_8F_E5M2;
+#endif
+ default:
+ ORT_THROW("Unexpected element_type=", element_type, ".");
+ }
+}
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/core/providers/cuda/cuda_common.h b/onnxruntime/core/providers/cuda/cuda_common.h
index fa258961f1155..9cd4e721ccab8 100644
--- a/onnxruntime/core/providers/cuda/cuda_common.h
+++ b/onnxruntime/core/providers/cuda/cuda_common.h
@@ -11,6 +11,7 @@
#include "core/providers/shared_library/provider_api.h"
#include "core/common/status.h"
+#include "core/framework/float8.h"
#include "core/framework/float16.h"
#include "core/providers/cuda/cuda_pch.h"
#include "core/providers/cuda/shared_inc/cuda_call.h"
@@ -48,6 +49,33 @@ class ToCudaType {
}
};
+template <>
+class ToCudaType<BFloat16> {
+ public:
+ typedef BFloat16 MappedType;
+ static MappedType FromFloat(float f) {
+ return MappedType(f);
+ }
+};
+
+template <>
+class ToCudaType<Float8E4M3FN> {
+ public:
+ typedef Float8E4M3FN MappedType;
+ static MappedType FromFloat(float f) {
+ return MappedType(f);
+ }
+};
+
+template <>
+class ToCudaType<Float8E5M2> {
+ public:
+ typedef Float8E5M2 MappedType;
+ static MappedType FromFloat(float f) {
+ return MappedType(f);
+ }
+};
+
inline bool CalculateFdmStrides(gsl::span p, const std::vector& dims) {
int stride = 1;
if (dims.empty() || p.size() < dims.size())
@@ -152,5 +180,13 @@ class HalfGemmOptions {
static HalfGemmOptions instance;
};
+const char* cublasGetErrorEnum(cublasStatus_t error);
+
+const char* CudaDataTypeToString(cudaDataType_t dt);
+
+const char* CublasComputeTypeToString(cublasComputeType_t ct);
+
+cudaDataType_t ToCudaDataType(int32_t element_type);
+
} // namespace cuda
} // namespace onnxruntime
diff --git a/onnxruntime/python/tools/symbolic_shape_infer.py b/onnxruntime/python/tools/symbolic_shape_infer.py
index 272727a9f5375..ef1c46b83946a 100755
--- a/onnxruntime/python/tools/symbolic_shape_infer.py
+++ b/onnxruntime/python/tools/symbolic_shape_infer.py
@@ -198,6 +198,7 @@ def __init__(self, int_max, auto_merge, guess_output_rank, verbose, prefix=""):
"GatedRelativePositionBias": self._infer_GatedRelativePositionBias,
"Gelu": self._infer_Gelu,
"GemmFastGelu": self._infer_GemmFastGelu,
+ "GemmFloat8": self._infer_GemmFloat8,
"GroupNorm": self._infer_GroupNorm,
"LayerNormalization": self._infer_LayerNormalization,
"LongformerAttention": self._infer_LongformerAttention,
@@ -2317,6 +2318,9 @@ def _infer_QuickGelu(self, node): # noqa: N802
def _infer_GemmFastGelu(self, node): # noqa: N802
self._compute_matmul_shape(node)
+ def _infer_GemmFloat8(self, node): # noqa: N802
+ self._compute_matmul_shape(node)
+
def _infer_LayerNormalization(self, node): # noqa: N802
self._propagate_shape_and_type(node)
if len(node.output) > 1:
diff --git a/onnxruntime/test/onnx/main.cc b/onnxruntime/test/onnx/main.cc
index de5431ca4a460..0526ccca5bb4e 100644
--- a/onnxruntime/test/onnx/main.cc
+++ b/onnxruntime/test/onnx/main.cc
@@ -761,6 +761,7 @@ select from 'TF8', 'TF16', 'UINT8', 'FLOAT', 'ITENSOR'. \n)");
ORT_TSTR("sce_none_weights_expanded")};
std::unordered_set> all_disabled_tests(std::begin(immutable_broken_tests), std::end(immutable_broken_tests));
+
if (enable_cuda) {
all_disabled_tests.insert(std::begin(cuda_flaky_tests), std::end(cuda_flaky_tests));
}
diff --git a/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py
new file mode 100644
index 0000000000000..784ae8ce70bd8
--- /dev/null
+++ b/onnxruntime/test/python/onnxruntime_test_float8_gemm8.py
@@ -0,0 +1,284 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# pylint: disable=C0116,W0212,R1720,C0103,C0114
+#
+# Note: precision differs between V100 and H100 even with the same code.
+# The thresholds were relaxed for H100, where precision appears to be lower.
+
+import itertools
+import unittest
+import warnings
+
+import numpy as np
+import parameterized
+from numpy.testing import assert_allclose
+from onnx import TensorProto
+from onnx.checker import check_model
+from onnx.helper import make_graph, make_model, make_node, make_opsetid, make_tensor_value_info
+from onnx.numpy_helper import from_array
+
+from onnxruntime import InferenceSession
+
+
+class TestFloat8Gemm8(unittest.TestCase):
+ def get_model_gemm(
+ self,
+ float_name,
+ alpha=1.0,
+ beta=0.0,
+ transA=0,
+ transB=0,
+ domain="",
+ dtype=TensorProto.FLOAT,
+ activation="NONE",
+ ):
+ proto_type = getattr(TensorProto, float_name)
+ use_f8 = proto_type in (TensorProto.FLOAT8E4M3FN, TensorProto.FLOAT8E5M2)
+
+ a = make_tensor_value_info("A", TensorProto.FLOAT, [None, None])
+ b = make_tensor_value_info("B", TensorProto.FLOAT, [None, None])
+ d = make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])
+
+ inits = []
+ kwargs = {}
+ node_inputs = ["Af", "Bf"]
+ inputs = [a, b]
+ bias = beta != 0
+ if bias:
+ inputs.append(make_tensor_value_info("C", TensorProto.FLOAT, [None, None]))
+ node_inputs = ["Af", "Bf", "Cf"]
+ if use_f8:
+ node_inputs.extends(["one"] * 3)
+ elif use_f8:
+ node_inputs.append("")
+ node_inputs.extend(["one"] * 3)
+
+ if use_f8:
+ assert domain == "com.microsoft"
+ inits.append(from_array(np.array([1], dtype=np.float32), name="one"))
+ kwargs = dict(
+ domain=domain,
+ dtype=dtype,
+ )
+ if activation is not None:
+ kwargs["activation"] = activation
+ op_name = "GemmFloat8"
+ elif domain == "com.microsoft":
+ op_name = "GemmFloat8"
+ kwargs = dict(
+ domain=domain,
+ dtype=dtype,
+ )
+ else:
+ op_name = "Gemm"
+ nodes = [
+ make_node("Cast", ["A"], ["Af"], to=proto_type),
+ make_node("Cast", ["B"], ["Bf"], to=proto_type),
+ make_node("Cast", ["C"], ["Cf"], to=proto_type) if bias else None,
+ make_node(
+ op_name,
+ node_inputs,
+ ["Yf"],
+ transA=transA,
+ transB=transB,
+ alpha=alpha,
+ beta=beta,
+ **kwargs,
+ ),
+ make_node("Cast", ["Yf"], ["Y"], to=TensorProto.FLOAT),
+ ]
+ nodes = [n for n in nodes if n is not None]
+ graph = make_graph(nodes, "gemm", inputs, [d], inits)
+ onnx_model = make_model(graph, opset_imports=[make_opsetid("", 19)], ir_version=9)
+ if domain != "com.microsoft":
+ check_model(onnx_model)
+ return onnx_model
+
+ def common_test_model_gemm(self, float_type, mul=0.33, atol=0, rtol=0, square=True, **kwargs):
+ if square:
+ a = (np.arange(256) * 0.01).astype(np.float32).reshape((-1, 16))
+ b = (np.arange(256) * -0.01).astype(np.float32).reshape((-1, 16))
+ c = (np.arange(256) * 0.03).astype(np.float32).reshape((-1, 16))
+ b[:, 0] += 1
+ else:
+ a = (np.arange(256) / 256).astype(np.float32).reshape((32, -1))
+ b = (np.arange(512) / 512).astype(np.float32).reshape((32, -1))
+ c = (np.arange(128) / 128).astype(np.float32).reshape((8, 16))
+
+ feeds = {"A": a, "B": b}
+
+ expected = (a.T if kwargs.get("transA", 0) else a) @ (b.T if kwargs.get("transB", 0) else b)
+ expected *= kwargs.get("alpha", 1.0)
+ if kwargs.get("beta", 0) != 0:
+ expected += kwargs["beta"] * c
+ feeds["C"] = c
+
+ onnx_model = self.get_model_gemm("FLOAT", **kwargs)
+
+ ref = InferenceSession(
+ onnx_model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
+ )
+ y = ref.run(None, feeds)[0]
+ if float_type in ("FLOAT", "FLOAT16"):
+ try:
+ assert_allclose(expected, y, atol=atol, rtol=rtol)
+ except Exception as e:
+
+ def check(f):
+ try:
+ return f()[:2, :2]
+ except Exception as e:
+ return str(e)
+
+ raise AssertionError(
+ f"Gemm ERROR len(inputs)={len(feeds)}"
+ f"\na@b=\n{check(lambda:a@b)}"
+ f"\na.T@b=\n{check(lambda:a.T@b)}"
+ f"\na@b.T=\n{check(lambda:a@b.T)}"
+ f"\na.T@b.T=\n{check(lambda:a.T@b.T)}"
+ f"\n----\nb@a=\n{check(lambda:b@a)}"
+ f"\nb.T@a=\n{check(lambda:b.T@a)}"
+ f"\nb@a.T=\n{check(lambda:b@a.T)}"
+ f"\nb.T@a.T=\n{check(lambda:b.T@a.T)}"
+ f"\n----\nexpected=\n{expected[:2,:2]}"
+ f"\n----\ngot=\n{y[:2,:2]}"
+ f"\nkwargs={kwargs}"
+ ) from e
+
+ self.assertEqual(expected.shape, y.shape)
+ self.assertEqual(expected.dtype, y.dtype)
+
+ onnx_model_f8 = self.get_model_gemm(float_type, domain="com.microsoft", **kwargs)
+ try:
+ ref8 = InferenceSession(
+ onnx_model_f8.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"]
+ )
+ except Exception as e:
+ if "CUDA < 12.0 does not support bias" in str(e):
+ return
+ raise AssertionError(f"Could not load model {onnx_model_f8}") from e
+ try:
+ y = ref8.run(None, feeds)[0]
+ except Exception as e:
+ if "CUBLAS_STATUS_NOT_SUPPORTED" in str(e):
+ # Skipping. This machine does not support float8.
+ warnings.warn("unable to test with float8 on this machine.")
+ return
+ raise AssertionError(f"Could not execute model {onnx_model_f8}") from e
+ try:
+ assert_allclose(expected, y, atol=atol, rtol=rtol)
+ except Exception as e:
+
+ def check(f):
+ try:
+ return f()[:2, :2]
+ except Exception as e:
+ return str(e)
+
+ raise AssertionError(
+ f"Gemm ERROR len(inputs)={len(feeds)}"
+ f"\na@b=\n{check(lambda:a@b)}"
+ f"\na.T@b=\n{check(lambda:a.T@b)}"
+ f"\na@b.T=\n{check(lambda:a@b.T)}"
+ f"\na.T@b.T=\n{check(lambda:a.T@b.T)}"
+ f"\n----\nb@a=\n{check(lambda:b@a)}"
+ f"\nb.T@a=\n{check(lambda:b.T@a)}"
+ f"\nb@a.T=\n{check(lambda:b@a.T)}"
+ f"\nb.T@a.T=\n{check(lambda:b.T@a.T)}"
+ f"\n----\nexpected=\n{expected[:2,:2]}"
+ f"\n----\ngot=\n{y[:2,:2]}"
+ f"\nkwargs={kwargs}"
+ ) from e
+ self.assertEqual(expected.shape, y.shape)
+ self.assertEqual(expected.dtype, y.dtype)
+
+ def test_model_gemm_float(self):
+ self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3)
+
+ def test_model_gemm_float_default_values(self):
+ self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation=None)
+
+ def test_model_gemm_float_relu(self):
+ self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="RELU")
+
+ def test_model_gemm_float_gelu(self):
+ self.common_test_model_gemm("FLOAT", transA=1, rtol=1e-3, activation="GELU")
+
+ def test_model_gemm_float_bias(self):
+ self.common_test_model_gemm("FLOAT", transA=1, beta=1.0, rtol=1e-3)
+
+ def test_model_gemm_float16(self):
+ self.common_test_model_gemm(
+ "FLOAT16",
+ rtol=1e-2,
+ dtype=TensorProto.FLOAT16,
+ transB=1,
+ )
+
+ def test_model_gemm_float8_e4m3(self):
+ self.common_test_model_gemm(
+ "FLOAT8E4M3FN",
+ rtol=0.5,
+ dtype=TensorProto.FLOAT,
+ transA=0,
+ transB=1,
+ alpha=10.0,
+ )
+
+ @parameterized.parameterized.expand(list(itertools.product([0, 1], [0, 1])))
+ def test_combinations_square_matrices(self, transA, transB):
+ self.common_test_model_gemm("FLOAT", transA=transA, transB=transB, rtol=1e-3)
+
+ @parameterized.parameterized.expand(
+ [
+ ((2, 3), (3, 5), 0, 0),
+ ((2, 3), (5, 3), 0, 1),
+ ((2, 3), (5, 2), 1, 1),
+ ((2, 3), (2, 5), 1, 0),
+ ]
+ )
+ def test_combinations(self, shapeA, shapeB, transA, transB):
+ model = make_model(
+ make_graph(
+ [
+ make_node(
+ "GemmFloat8",
+ ["A", "B"],
+ ["Y"],
+ transA=transA,
+ transB=transB,
+ domain="com.microsoft",
+ )
+ ],
+ "f8",
+ [
+ make_tensor_value_info("A", TensorProto.FLOAT, [None, None]),
+ make_tensor_value_info("B", TensorProto.FLOAT, [None, None]),
+ ],
+ [make_tensor_value_info("Y", TensorProto.FLOAT, [None, None])],
+ )
+ )
+
+ sess = InferenceSession(model.SerializeToString(), providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
+ a = np.arange(np.prod(shapeA)).reshape(shapeA).astype(np.float32)
+ b = np.arange(np.prod(shapeB)).reshape(shapeB).astype(np.float32)
+ try:
+ expected = (a.T if transA else a) @ (b.T if transB else b)
+ except Exception as e:
+ raise AssertionError(
+ f"Unable to multiply shapes={shapeA}x{shapeB}, transA={transA}, transB={transB}"
+ ) from e
+ try:
+ got = sess.run(None, {"A": a, "B": b})
+ except Exception as e:
+ raise AssertionError(
+ f"Unable to run Gemm with shapes={shapeA}x{shapeB}, transA={transA}, transB={transB}"
+ ) from e
+ self.assertEqual(expected.shape, got[0].shape)
+ self.assertEqual(expected.dtype, got[0].dtype)
+ assert_allclose(expected, got[0])
+
+
+if __name__ == "__main__":
+ # TestFloat8Gemm8().test_model_gemm_float()
+ unittest.main(verbosity=2)