From 90dc79bc2381328994b5d175d212802af97e2ff0 Mon Sep 17 00:00:00 2001 From: Xavier Dupre Date: Fri, 3 Nov 2023 09:25:39 +0100 Subject: [PATCH] improve flags --- onnxruntime/contrib_ops/cuda/math/gemm_float8.cu | 11 +++++++++-- onnxruntime/core/providers/cuda/cuda_common.cc | 5 +++-- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index df25342342cd5..d0c52aa570749 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -28,7 +28,7 @@ int32_t TypeSize(int32_t element_type) { case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: case ONNX_NAMESPACE::TensorProto_DataType_FLOAT16: return 2; -#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) +#if !defined(DISABLE_FLOAT8_TYPES) case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: return 1; @@ -97,12 +97,16 @@ Status GemmFloat8::ComputeInternal(OpKernelContext* ctx) const { } auto first_type = input_A->GetElementType(); +#if !defined(DISABLE_FLOAT8_TYPES) bool is_float8 = first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN || first_type == ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2; if (!is_float8) +#endif return ComputeRowMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, input_C, scale_A, scale_B, scale_Y); +#if !defined(DISABLE_FLOAT8_TYPES) return ComputeColMajor(ctx, n_inputs, has_bias, has_scales, input_A, input_B, input_C, scale_A, scale_B, scale_Y); +#endif } Status GemmFloat8::ComputeRowMajor( @@ -197,10 +201,13 @@ Status GemmFloat8::ComputeGemm( switch (d_cuda_type) { case CUDA_R_16F: switch (a_cuda_type) { +#if !defined(DISABLE_FLOAT8_TYPES) + // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8 case CUDA_R_8F_E4M3: case CUDA_R_8F_E5M2: compute_type = CUBLAS_COMPUTE_32F_FAST_TF32; break; +#endif default: compute_type = CUBLAS_COMPUTE_32F_FAST_16F; break; @@ -267,7 +274,7 @@ Status GemmFloat8::ComputeGemm( sizeof(p_scale_b))); // float 8 -#if CUDA_VERSION >= 11080 +#if !defined(DISABLE_FLOAT8_TYPES) if (dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN || dtype_Y == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2) { // For FP8 output, cuBLAS requires C_type to be same as bias_type diff --git a/onnxruntime/core/providers/cuda/cuda_common.cc b/onnxruntime/core/providers/cuda/cuda_common.cc index 288ca8e97e34d..33f2938940e4d 100644 --- a/onnxruntime/core/providers/cuda/cuda_common.cc +++ b/onnxruntime/core/providers/cuda/cuda_common.cc @@ -62,7 +62,8 @@ const char* CudaDataTypeToString(cudaDataType_t dt) { return "CUDA_R_16BF"; case CUDA_R_32F: return "CUDA_R_32F"; -#if (CUDA_VERSION >= 11080) +#if !defined(DISABLE_FLOAT8_TYPES) + // Note: CUDA_R_8F_E4M3 is defined with CUDA>=11.8 case CUDA_R_8F_E4M3: return "CUDA_R_8F_E4M3"; case CUDA_R_8F_E5M2: @@ -101,7 +102,7 @@ cudaDataType_t ToCudaDataType(int32_t element_type) { return CUDA_R_16F; case ONNX_NAMESPACE::TensorProto_DataType_BFLOAT16: return CUDA_R_16BF; -#if (!defined(DISABLE_FLOAT8_TYPES) && (CUDA_VERSION >= 11080)) +#if !defined(DISABLE_FLOAT8_TYPES) case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E4M3FN: return CUDA_R_8F_E4M3; case ONNX_NAMESPACE::TensorProto_DataType_FLOAT8E5M2: