From 44b58437402b207c8216f3be8c75accb7409be1c Mon Sep 17 00:00:00 2001 From: pengwa Date: Fri, 8 Dec 2023 21:01:34 +0800 Subject: [PATCH] Fix gemm_float8 build failure on CUDA 11.3-11.7 (#18760) ### Fix gemm_float8 build failure on CUDA 11.3 ~ 11.7 User env: CUDA 11.3, build option include "--disable_types float8" ``` /tmp/onnxruntime/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu(256): error: identifier "CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET" is undefined /tmp/onnxruntime/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu(264): error: enum "cublasLtMatmulDescAttributes_t" has no member "CUBLASLT_MATMUL_DESC_FAST_ACCUM" /tmp/onnxruntime/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu(268): error: identifier "CUBLASLT_MATMUL_DESC_A_SCALE_POINTER" is undefined /tmp/onnxruntime/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu(271): error: identifier "CUBLASLT_MATMUL_DESC_B_SCALE_POINTER" is undefined /tmp/onnxruntime/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu(274): error: identifier "CUBLASLT_MATMUL_DESC_D_SCALE_POINTER" is undefined 5 errors detected in the compilation of "/tmp/onnxruntime/onnxruntime/contrib_ops/cu ``` Here is a versions (major version) diff on the requested attributes: ``` cuda 11.5.1 no CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET cuda 11.6 https://docs.nvidia.com/cuda/archive/11.6.0/pdf/CUBLAS_Library.pdf has CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET cuda 11.7 no CUBLASLT_MATMUL_DESC_FAST_ACCUM no CUBLASLT_MATMUL_DESC_A_SCALE_POINTER no CUBLASLT_MATMUL_DESC_B_SCALE_POINTER no CUBLASLT_MATMUL_DESC_D_SCALE_POINTER cuda 11.8 https://docs.nvidia.com/cuda/archive/11.8.0/pdf/CUBLAS_Library.pdf has CUBLASLT_MATMUL_DESC_FAST_ACCUM has CUBLASLT_MATMUL_DESC_A_SCALE_POINTER has CUBLASLT_MATMUL_DESC_A_SCALE_POINTER has CUBLASLT_MATMUL_DESC_B_SCALE_POINTER has CUBLASLT_MATMUL_DESC_D_SCALE_POINTER ``` ### Motivation and Context --- onnxruntime/contrib_ops/cuda/math/gemm_float8.cu | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu index 56b541f5256bf..064b6dd392437 100644 --- a/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu +++ b/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu @@ -251,15 +251,21 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb))); +#if CUDA_VERSION >= 11060 + // CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET exists from https://docs.nvidia.com/cuda/archive/11.6.0/pdf/CUBLAS_Library.pdf if (sm_count_ != 0) { int math_sm_count = static_cast(sm_count_); CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count, sizeof(math_sm_count))); } +#endif if (has_scales) { // gemm float 8 +#if CUDA_VERSION >= 11080 + // CUBLASLT_MATMUL_DESC_FAST_ACCUM, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER, + // CUBLASLT_MATMUL_DESC_D_SCALE_POINTER exist from https://docs.nvidia.com/cuda/archive/11.8.0/pdf/CUBLAS_Library.pdf const int8_t ifast_accumulation_mode = 1; CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, @@ -274,6 +280,7 @@ Status GemmFloat8::ComputeGemm( CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute( operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y, sizeof(p_scale_b))); +#endif // float 8 #if !defined(DISABLE_FLOAT8_TYPES)