Skip to content

Commit

Permalink
Fix gemm_float8 build failure on CUDA 11.3-11.7 (#18760)
Browse files Browse the repository at this point in the history
### Fix gemm_float8 build failure on CUDA 11.3 ~ 11.7

User env: CUDA 11.3, build option include "--disable_types float8"


```

/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu(256): error: identifier "CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET" is undefined

/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu(264): error: enum "cublasLtMatmulDescAttributes_t" has no member "CUBLASLT_MATMUL_DESC_FAST_ACCUM"

/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu(268): error: identifier "CUBLASLT_MATMUL_DESC_A_SCALE_POINTER" is undefined

/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu(271): error: identifier "CUBLASLT_MATMUL_DESC_B_SCALE_POINTER" is undefined

/tmp/onnxruntime/onnxruntime/contrib_ops/cuda/math/gemm_float8.cu(274): error: identifier "CUBLASLT_MATMUL_DESC_D_SCALE_POINTER" is undefined

5 errors detected in the compilation of "/tmp/onnxruntime/onnxruntime/contrib_ops/cu

```

Here is a versions (major version) diff on the requested attributes:

```

cuda 11.5.1

no CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET


cuda 11.6

https://docs.nvidia.com/cuda/archive/11.6.0/pdf/CUBLAS_Library.pdf

has CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET



cuda 11.7

no CUBLASLT_MATMUL_DESC_FAST_ACCUM

no CUBLASLT_MATMUL_DESC_A_SCALE_POINTER

no CUBLASLT_MATMUL_DESC_B_SCALE_POINTER

no CUBLASLT_MATMUL_DESC_D_SCALE_POINTER



cuda 11.8

https://docs.nvidia.com/cuda/archive/11.8.0/pdf/CUBLAS_Library.pdf

has CUBLASLT_MATMUL_DESC_FAST_ACCUM

has CUBLASLT_MATMUL_DESC_A_SCALE_POINTER

has CUBLASLT_MATMUL_DESC_A_SCALE_POINTER

has CUBLASLT_MATMUL_DESC_B_SCALE_POINTER

has CUBLASLT_MATMUL_DESC_D_SCALE_POINTER


```



### Motivation and Context
<!-- - Why is this change required? What problem does it solve?
- If it fixes an open issue, please link to the issue here. -->
  • Loading branch information
pengwa authored Dec 8, 2023
1 parent e8f33b5 commit 44b5843
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions onnxruntime/contrib_ops/cuda/math/gemm_float8.cu
Original file line number Diff line number Diff line change
Expand Up @@ -251,15 +251,21 @@ Status GemmFloat8::ComputeGemm(
CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_TRANSB, &ctransb, sizeof(ctransb)));

#if CUDA_VERSION >= 11060
// CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET exists from https://docs.nvidia.com/cuda/archive/11.6.0/pdf/CUBLAS_Library.pdf
if (sm_count_ != 0) {
int math_sm_count = static_cast<int>(sm_count_);
CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_SM_COUNT_TARGET, &math_sm_count,
sizeof(math_sm_count)));
}
#endif

if (has_scales) {
// gemm float 8
#if CUDA_VERSION >= 11080
// CUBLASLT_MATMUL_DESC_FAST_ACCUM, CUBLASLT_MATMUL_DESC_A_SCALE_POINTER, CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
// CUBLASLT_MATMUL_DESC_D_SCALE_POINTER exist from https://docs.nvidia.com/cuda/archive/11.8.0/pdf/CUBLAS_Library.pdf
const int8_t ifast_accumulation_mode = 1;
CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc,
Expand All @@ -274,6 +280,7 @@ Status GemmFloat8::ComputeGemm(
CUBLAS_RETURN_IF_ERROR(cublasLtMatmulDescSetAttribute(
operationDesc, CUBLASLT_MATMUL_DESC_D_SCALE_POINTER, &p_scale_y,
sizeof(p_scale_b)));
#endif

// float 8
#if !defined(DISABLE_FLOAT8_TYPES)
Expand Down

0 comments on commit 44b5843

Please sign in to comment.