diff --git a/csrc/includes/cublas_wrappers.h b/csrc/includes/cublas_wrappers.h
index b57ff79923fc..2721fb990c7e 100644
--- a/csrc/includes/cublas_wrappers.h
+++ b/csrc/includes/cublas_wrappers.h
@@ -17,6 +17,7 @@
 #include <mma.h>
 #endif
 #include <stdio.h>
+#include <torch/version.h>

 int cublas_gemm_ex(cublasHandle_t handle,
                    cublasOperation_t transa,
@@ -29,7 +30,9 @@ int cublas_gemm_ex(cublasHandle_t handle,
                    const float* A,
                    const float* B,
                    float* C,
-#ifdef __HIP_PLATFORM_AMD__
+// TODO HIP: Remove backward compatibility for torch<=2.0 in future
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                    rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
 #else
                    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
@@ -46,7 +49,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
                    const __half* A,
                    const __half* B,
                    __half* C,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                    rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
 #else
                    cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -67,7 +71,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
                                 int stride_B,
                                 int stride_C,
                                 int batch,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                 rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
 #else
                                 cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT);
@@ -88,7 +93,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
                                 int stride_B,
                                 int stride_C,
                                 int batch,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                 rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
 #else
                                 cublasGemmAlgo_t algo = CUBLAS_GEMM_DEFAULT_TENSOR_OP);
diff --git a/csrc/includes/feed_forward.h b/csrc/includes/feed_forward.h
index 46e3ba748d52..d2056403d265 100644
--- a/csrc/includes/feed_forward.h
+++ b/csrc/includes/feed_forward.h
@@ -48,7 +48,9 @@ class FeedForward {
                        weights,
                        input_ptr,
                        out,
-#ifdef __HIP_PLATFORM_AMD__
+// TODO HIP: Remove backward compatibility for torch<=2.0 in future
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo(config_.gemm_algos[0]));
 #else
                        cublasGemmAlgo_t(config_.gemm_algos[0]));
@@ -77,7 +79,8 @@ class FeedForward {
                        input_ptr,
                        out_grad,
                        weights_grad,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo(config_.gemm_algos[1]));
 #else
                        cublasGemmAlgo_t(config_.gemm_algos[1]));
@@ -94,7 +97,8 @@ class FeedForward {
                        weights,
                        out_grad,
                        inp_grad_out,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo(config_.gemm_algos[2]));
 #else
                        cublasGemmAlgo_t(config_.gemm_algos[2]));
diff --git a/csrc/includes/gemm_test.h b/csrc/includes/gemm_test.h
index 278515174523..de5b55cd3df1 100644
--- a/csrc/includes/gemm_test.h
+++ b/csrc/includes/gemm_test.h
@@ -67,7 +67,9 @@ class GemmTest {
                        B,
                        A,
                        C,
-#ifdef __HIP_PLATFORM_AMD__
+// TODO HIP: Remove backward compatibility for torch<=2.0 in future
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        static_cast<rocblas_gemm_algo>(algo));
 #else
                        static_cast<cublasGemmAlgo_t>(algo));
@@ -86,7 +88,8 @@ class GemmTest {
                        A,
                        C,
                        B,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        static_cast<rocblas_gemm_algo>(algo));
 #else
                        static_cast<cublasGemmAlgo_t>(algo));
@@ -105,7 +108,8 @@ class GemmTest {
                        B,
                        C,
                        A,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        static_cast<rocblas_gemm_algo>(algo));
 #else
                        static_cast<cublasGemmAlgo_t>(algo));
@@ -121,8 +125,11 @@ class GemmTest {
         float fast_latency = (std::numeric_limits<float>::max)();
         int fast_algo = 0;
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
         for (int algo = (int)rocblas_gemm_algo_standard;
              algo <= (int)rocblas_gemm_algo_standard;
+#elif defined(__HIP_PLATFORM_AMD__)
+        for (int algo = (int)HIPBLAS_GEMM_DEFAULT; algo <= (int)HIPBLAS_GEMM_DEFAULT;
 #else
         for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
              algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
@@ -211,7 +218,8 @@ class StridedGemmTest {
                                     stride_b,
                                     stride_c,
                                     bsz,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                     static_cast<rocblas_gemm_algo>(algo));
 #else
                                     static_cast<cublasGemmAlgo_t>(algo));
@@ -245,7 +253,8 @@ class StridedGemmTest {
                                     stride_b,
                                     stride_c,
                                     bsz,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                     static_cast<rocblas_gemm_algo>(algo));
 #else
                                     static_cast<cublasGemmAlgo_t>(algo));
@@ -276,7 +285,8 @@ class StridedGemmTest {
                                     stride_b,
                                     stride_c,
                                     bsz,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                     static_cast<rocblas_gemm_algo>(algo));
 #else
                                     static_cast<cublasGemmAlgo_t>(algo));
@@ -292,11 +302,17 @@ class StridedGemmTest {
         float fast_latency = (std::numeric_limits<float>::max)();
         int fast_algo = 0;
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
         for (int algo = (int)rocblas_gemm_algo_standard;
              algo <= (int)rocblas_gemm_algo_standard;
+#else
+#ifdef __HIP_PLATFORM_AMD__
+        for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
+             algo <= (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
 #else
         for (int algo = (int)CUBLAS_GEMM_DEFAULT_TENSOR_OP;
              algo <= (int)CUBLAS_GEMM_ALGO15_TENSOR_OP;
+#endif
 #endif
              algo++) {
             int warm_up = 5;
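Aside (not part of the diff): under the gate above, the ROCm tuning loops in gemm_test.h collapse to a single candidate, rocblas_gemm_algo_standard for torch<=2.0 and HIPBLAS_GEMM_DEFAULT afterwards, while CUDA still sweeps CUBLAS_GEMM_DEFAULT_TENSOR_OP through CUBLAS_GEMM_ALGO15_TENSOR_OP. A minimal sketch of those bounds factored into one helper; AlgoRange and gemm_algo_range are hypothetical names, and the BLAS headers plus <torch/version.h> are assumed to be included as in the files above:

// Hypothetical helper, not in the PR: the three search ranges the loops encode.
struct AlgoRange {
    int first;
    int last;
};

inline AlgoRange gemm_algo_range()
{
#if defined(__HIP_PLATFORM_AMD__) && \
    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
    // Legacy rocblas path: nothing to tune.
    return {(int)rocblas_gemm_algo_standard, (int)rocblas_gemm_algo_standard};
#elif defined(__HIP_PLATFORM_AMD__)
    // hipblas path: a single default algorithm as well.
    return {(int)HIPBLAS_GEMM_DEFAULT, (int)HIPBLAS_GEMM_DEFAULT};
#else
    // CUDA keeps its 16-algorithm tensor-op sweep.
    return {(int)CUBLAS_GEMM_DEFAULT_TENSOR_OP, (int)CUBLAS_GEMM_ALGO15_TENSOR_OP};
#endif
}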
diff --git a/csrc/includes/strided_batch_gemm.h b/csrc/includes/strided_batch_gemm.h
index 86d1e3dea11a..9767fcf589b8 100644
--- a/csrc/includes/strided_batch_gemm.h
+++ b/csrc/includes/strided_batch_gemm.h
@@ -77,7 +77,9 @@ class StridedBatchGemm {
                                     stride_b,
                                     stride_c,
                                     bsz,
-#ifdef __HIP_PLATFORM_AMD__
+// TODO HIP: Remove backward compatibility for torch<=2.0 in future
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                     rocblas_gemm_algo(_config.gemm_algos[0]));
 #else
                                     cublasGemmAlgo_t(_config.gemm_algos[0]));
@@ -105,7 +107,8 @@ class StridedBatchGemm {
                                     stride_b,
                                     stride_c,
                                     _config.batch_size,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                     rocblas_gemm_algo(_config.gemm_algos[0]));
 #else
                                     cublasGemmAlgo_t(_config.gemm_algos[0]));
@@ -149,7 +152,8 @@ class StridedBatchGemm {
                                     stride_b,
                                     stride_c,
                                     bsz,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                     rocblas_gemm_algo(_config.gemm_algos[1]));
 #else
                                     cublasGemmAlgo_t(_config.gemm_algos[1]));
@@ -178,7 +182,8 @@ class StridedBatchGemm {
                                     stride_b,
                                     stride_c,
                                     bsz,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                     rocblas_gemm_algo(_config.gemm_algos[2]));
 #else
                                     cublasGemmAlgo_t(_config.gemm_algos[2]));
diff --git a/csrc/transformer/cublas_wrappers.cu b/csrc/transformer/cublas_wrappers.cu
index 7821a8759ab0..d982e65b8a81 100644
--- a/csrc/transformer/cublas_wrappers.cu
+++ b/csrc/transformer/cublas_wrappers.cu
@@ -5,7 +5,9 @@

 #include "cublas_wrappers.h"

-#ifdef __HIP_PLATFORM_AMD__
+// TODO HIP: Remove backward compatibility for torch<=2.0 in future
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
 int cublas_gemm_ex(rocblas_handle handle,
                    rocblas_operation transa,
                    rocblas_operation transb,
@@ -33,7 +35,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
                    cublasGemmAlgo_t algo)
 #endif
 {
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     rocblas_status status = rocblas_gemm_ex(handle,
                                             transa,
                                             transb,
@@ -67,20 +70,39 @@ int cublas_gemm_ex(cublasHandle_t handle,
                                         k,
                                         (const void*)alpha,
                                         (const void*)A,
+#ifdef __HIP_PLATFORM_AMD__
+                                        HIPBLAS_R_32F,
+#else
                                         CUDA_R_32F,
+#endif
                                         (transa == CUBLAS_OP_N) ? m : k,
                                         (const void*)B,
+#ifdef __HIP_PLATFORM_AMD__
+                                        HIPBLAS_R_32F,
+#else
                                         CUDA_R_32F,
+#endif
                                         (transb == CUBLAS_OP_N) ? k : n,
                                         (const void*)beta,
                                         C,
+#ifdef __HIP_PLATFORM_AMD__
+                                        HIPBLAS_R_32F,
+#else
                                         CUDA_R_32F,
+#endif
                                         m,
+#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
+                                        HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
+                                        HIPBLAS_R_32F,
+#else
                                         CUDA_R_32F,
+#endif
                                         algo);
 #endif

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -96,7 +118,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
     return 0;
 }

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
 int cublas_gemm_ex(rocblas_handle handle,
                    rocblas_operation transa,
                    rocblas_operation transb,
@@ -124,7 +147,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
                    cublasGemmAlgo_t algo)
 #endif
 {
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     rocblas_status status = rocblas_gemm_ex(handle,
                                             transa,
                                             transb,
@@ -158,20 +182,39 @@ int cublas_gemm_ex(cublasHandle_t handle,
                                        k,
                                        (const void*)alpha,
                                        (const void*)A,
+#ifdef __HIP_PLATFORM_AMD__
+                                       HIPBLAS_R_16F,
+#else
                                        CUDA_R_16F,
+#endif
                                        (transa == CUBLAS_OP_N) ? m : k,
                                        (const void*)B,
+#ifdef __HIP_PLATFORM_AMD__
+                                       HIPBLAS_R_16F,
+#else
                                        CUDA_R_16F,
+#endif
                                        (transb == CUBLAS_OP_N) ? k : n,
                                        (const void*)beta,
                                        (void*)C,
+#ifdef __HIP_PLATFORM_AMD__
+                                       HIPBLAS_R_16F,
+#else
                                        CUDA_R_16F,
+#endif
                                        m,
+#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
+                                       HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
+                                       HIPBLAS_R_32F,
+#else
                                        CUDA_R_32F,
+#endif
                                        algo);
 #endif

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -187,7 +230,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
     return 0;
 }

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
 int cublas_strided_batched_gemm(rocblas_handle handle,
                                 int m,
                                 int n,
@@ -223,7 +267,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
                                 cublasGemmAlgo_t algo)
 #endif
 {
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     rocblas_status status =
         rocblas_gemm_strided_batched_ex(handle,
                                         op_A,
@@ -263,24 +308,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
                                  k,
                                  alpha,
                                  A,
+#ifdef __HIP_PLATFORM_AMD__
+                                 HIPBLAS_R_32F,
+#else
                                  CUDA_R_32F,
+#endif
                                  (op_A == CUBLAS_OP_N) ? m : k,
                                  stride_A,
                                  B,
+#ifdef __HIP_PLATFORM_AMD__
+                                 HIPBLAS_R_32F,
+#else
                                  CUDA_R_32F,
+#endif
                                  (op_B == CUBLAS_OP_N) ? k : n,
                                  stride_B,
                                  beta,
                                  C,
+#ifdef __HIP_PLATFORM_AMD__
+                                 HIPBLAS_R_32F,
+#else
                                  CUDA_R_32F,
+#endif
                                  m,
                                  stride_C,
                                  batch,
+#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
+                                 HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
+                                 HIPBLAS_R_32F,
+#else
                                  CUDA_R_32F,
+#endif
                                  algo);
 #endif

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -297,7 +361,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
     return 0;
 }

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
 int cublas_strided_batched_gemm(rocblas_handle handle,
                                 int m,
                                 int n,
@@ -333,7 +398,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
                                 cublasGemmAlgo_t algo)
 #endif
 {
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     rocblas_status status =
         rocblas_gemm_strided_batched_ex(handle,
                                         op_A,
@@ -373,24 +439,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
                                  k,
                                  alpha,
                                  A,
+#ifdef __HIP_PLATFORM_AMD__
+                                 HIPBLAS_R_16F,
+#else
                                  CUDA_R_16F,
+#endif
                                  (op_A == CUBLAS_OP_N) ? m : k,
                                  stride_A,
                                  B,
+#ifdef __HIP_PLATFORM_AMD__
+                                 HIPBLAS_R_16F,
+#else
                                  CUDA_R_16F,
+#endif
                                  (op_B == CUBLAS_OP_N) ? k : n,
                                  stride_B,
                                  beta,
                                  C,
+#ifdef __HIP_PLATFORM_AMD__
+                                 HIPBLAS_R_16F,
+#else
                                  CUDA_R_16F,
+#endif
                                  m,
                                  stride_C,
                                  batch,
+#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
+                                 HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
+                                 HIPBLAS_R_32F,
+#else
                                  CUDA_R_32F,
+#endif
                                  algo);
 #endif

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
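Aside (not part of the diff): cublas_wrappers.cu now repeats the same #ifdef at every datatype argument. The mapping from element type to backend enum could be stated once as a trait; blas_dtype is a hypothetical name, and <cuda_fp16.h> (or its HIP equivalent) plus the BLAS headers are assumed to be pulled in by cublas_wrappers.h as above:

#include <type_traits>

// Hypothetical trait, not in the PR: one place to name the BLAS datatype enum
// for each element type, instead of an #ifdef block per GEMM argument.
template <typename T>
struct blas_dtype;

#ifdef __HIP_PLATFORM_AMD__
template <>
struct blas_dtype<float> {
    static constexpr auto value = HIPBLAS_R_32F;
};
template <>
struct blas_dtype<__half> {
    static constexpr auto value = HIPBLAS_R_16F;
};
#else
template <>
struct blas_dtype<float> {
    static constexpr auto value = CUDA_R_32F;
};
template <>
struct blas_dtype<__half> {
    static constexpr auto value = CUDA_R_16F;
};
#endif

A call site would then pass blas_dtype<T>::value where the diff inserts the four-line #ifdef ladders.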
diff --git a/csrc/transformer/inference/csrc/pt_binding.cpp b/csrc/transformer/inference/csrc/pt_binding.cpp
index b7277d1e1678..1b9f91cd9c88 100644
--- a/csrc/transformer/inference/csrc/pt_binding.cpp
+++ b/csrc/transformer/inference/csrc/pt_binding.cpp
@@ -163,7 +163,9 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
                        (T*)W.data_ptr(),
                        (T*)Q.data_ptr(),
                        (T*)O.data_ptr(),
-#ifdef __HIP_PLATFORM_AMD__
+// TODO HIP: Remove backward compatibility for torch<=2.0 in future
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -216,7 +218,8 @@ void attention_unfused(at::Tensor& prev_key_cont,
                                 seq_len * k,
                                 seq_len * soft_len,
                                 bsz * heads,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                 rocblas_gemm_algo_standard);
 #else
                                 CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -253,7 +256,8 @@ void attention_unfused(at::Tensor& prev_key_cont,
                                 seq_len * soft_len,
                                 seq_len * k,
                                 bsz * heads,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                 rocblas_gemm_algo_standard);
 #else
                                 CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -388,7 +392,8 @@ void attention_unfused(T* prev_key_cont,
                                 seq_len * k,
                                 seq_len * soft_len,
                                 bsz * heads,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                 rocblas_gemm_algo_standard);
 #else
                                 CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -421,7 +426,8 @@ void attention_unfused(T* prev_key_cont,
                                 seq_len * soft_len,
                                 seq_len * k,
                                 bsz * heads,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                                 rocblas_gemm_algo_standard);
 #else
                                 CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -886,7 +892,8 @@ void quantized_gemm(void* output,
                        weight16,
                        (T*)input,
                        (T*)output,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -931,7 +938,8 @@ at::Tensor qkv_unfused_cublas(at::Tensor& output,
                        (T*)weight.data_ptr(),
                        workspace,
                        (T*)output.data_ptr(),
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -1003,7 +1011,8 @@ std::vector<at::Tensor> ds_rms_qkv(at::Tensor& input,
                        (T*)weight.data_ptr(),
                        (T*)rms_norm.data_ptr(),
                        (T*)output.data_ptr(),
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -1089,7 +1098,8 @@ void quantized_gemm(at::Tensor& output,
                        (T*)weight16.data_ptr(),
                        (T*)input.data_ptr(),
                        (T*)output.data_ptr(),
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -1135,7 +1145,8 @@ at::Tensor ds_linear_layer(at::Tensor& input,
                        (T*)weight.data_ptr(),
                        (T*)input_cont.data_ptr(),
                        (T*)output.data_ptr(),
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -1353,7 +1364,8 @@ at::Tensor ds_vector_matmul(at::Tensor& input,
                        (T*)weight.data_ptr(),
                        (T*)input.data_ptr(),
                        (T*)output.data_ptr(),
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -1439,7 +1451,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
                        (T*)weight.data_ptr(),
                        inp_norm,
                        intermediate,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -1483,7 +1496,8 @@ at::Tensor mlp_unfused_cublas(at::Tensor& output,
                        (T*)weight1.data_ptr(),
                        intermediate,
                        (T*)output.data_ptr(),
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -1617,7 +1631,8 @@ std::vector<at::Tensor> ds_rms_mlp_gemm(at::Tensor& input,
                        (T*)weight_interm.data_ptr(),
                        (T*)inp_norm.data_ptr(),
                        intermediate_ptr,
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -1680,7 +1695,8 @@ std::vector<at::Tensor> ds_rms_mlp_gemm(at::Tensor& input,
                        (T*)weight_out.data_ptr(),
                        intermediate_ptr,
                        (T*)output.data_ptr(),
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard,
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP,
@@ -1742,7 +1758,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
                        (T*)weight.data_ptr(),
                        (T*)input.data_ptr(),
                        (T*)intermediate.data_ptr(),
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
@@ -1776,7 +1793,8 @@ at::Tensor fused_gemm_gelu(at::Tensor& input,
                        (T*)weight_out.data_ptr(),
                        (T*)intermediate.data_ptr(),
                        (T*)output.data_ptr(),
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
                        rocblas_gemm_algo_standard);
 #else
                        CUBLAS_GEMM_DEFAULT_TENSOR_OP);
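Aside (not part of the diff): every pt_binding.cpp call site above makes the same default-algorithm choice inline. It could be hoisted into one macro; DS_DEFAULT_GEMM_ALGO is a hypothetical name, and the sketch relies on ROCm builds of these sources going through hipify, which rewrites the CUBLAS_* names for the non-legacy path just as it does in the file above:

// Hypothetical macro, not in the PR. Assumes <torch/version.h> and the BLAS
// headers are already included.
#if defined(__HIP_PLATFORM_AMD__) && \
    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
#define DS_DEFAULT_GEMM_ALGO rocblas_gemm_algo_standard
#else
#define DS_DEFAULT_GEMM_ALGO CUBLAS_GEMM_DEFAULT_TENSOR_OP
#endif

// Each call above would then end with DS_DEFAULT_GEMM_ALGO instead of the
// five-line conditional.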
diff --git a/csrc/transformer/inference/includes/inference_cublas_wrappers.h b/csrc/transformer/inference/includes/inference_cublas_wrappers.h
index 640751b12c8f..40c3e443941d 100644
--- a/csrc/transformer/inference/includes/inference_cublas_wrappers.h
+++ b/csrc/transformer/inference/includes/inference_cublas_wrappers.h
@@ -18,7 +18,9 @@
 #endif
 #include <stdio.h>

-#ifdef __HIP_PLATFORM_AMD__
+// TODO HIP: Remove backward compatibility for torch<=2.0 in future
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
 int cublas_gemm_ex(rocblas_handle handle,
                    rocblas_operation transa,
                    rocblas_operation transb,
@@ -49,7 +51,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
 #endif
 {
     const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride;
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     rocblas_status status = rocblas_gemm_ex(handle,
                                             transa,
                                             transb,
@@ -83,20 +86,39 @@ int cublas_gemm_ex(cublasHandle_t handle,
                                        k,
                                        (const void*)alpha,
                                        (const void*)A,
+#ifdef __HIP_PLATFORM_AMD__
+                                       HIPBLAS_R_32F,
+#else
                                        CUDA_R_32F,
+#endif
                                        (transa == CUBLAS_OP_N) ? m : k,
                                        (const void*)B,
+#ifdef __HIP_PLATFORM_AMD__
+                                       HIPBLAS_R_32F,
+#else
                                        CUDA_R_32F,
+#endif
                                        ldb,
                                        (const void*)beta,
                                        C,
+#ifdef __HIP_PLATFORM_AMD__
+                                       HIPBLAS_R_32F,
+#else
                                        CUDA_R_32F,
+#endif
                                        m,
+#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
+                                       HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
+                                       HIPBLAS_R_32F,
+#else
                                        CUDA_R_32F,
+#endif
                                        algo);
 #endif

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -113,7 +135,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
     return 0;
 }

 template <typename T>
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
 int cublas_gemm_ex(rocblas_handle handle,
                    rocblas_operation transa,
                    rocblas_operation transb,
@@ -144,7 +167,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
 #endif
 {
     const int ldb = (b_stride == -1) ? ((transb == CUBLAS_OP_N) ? k : n) : b_stride;
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     constexpr auto rocblas_dtype_16 = std::is_same<T, __half>::value ? rocblas_datatype_f16_r
                                                                      : rocblas_datatype_bf16_r;
     rocblas_status status = rocblas_gemm_ex(handle,
@@ -171,8 +195,12 @@ int cublas_gemm_ex(cublasHandle_t handle,
                                             algo,
                                             0,
                                             0);
+#else
+#ifdef __HIP_PLATFORM_AMD__
+    constexpr auto cublas_dtype_16 = std::is_same<T, __half>::value ? HIPBLAS_R_16F : HIPBLAS_R_16B;
 #else
     constexpr auto cublas_dtype_16 = std::is_same<T, __half>::value ? CUDA_R_16F : CUDA_R_16BF;
+#endif
     cublasStatus_t status = cublasGemmEx(handle,
                                          transa,
                                          transb,
@@ -190,11 +218,18 @@ int cublas_gemm_ex(cublasHandle_t handle,
                                          (void*)C,
                                          cublas_dtype_16,
                                          m,
+#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
+                                         HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
+                                         HIPBLAS_R_32F,
+#else
                                          CUDA_R_32F,
+#endif
                                          algo);
 #endif

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -210,7 +245,8 @@ int cublas_gemm_ex(cublasHandle_t handle,
     return 0;
 }

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
 int cublas_strided_batched_gemm(rocblas_handle handle,
                                 int m,
                                 int n,
@@ -246,7 +282,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
                                 cublasGemmAlgo_t algo)
 #endif
 {
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     rocblas_status status =
         rocblas_gemm_strided_batched_ex(handle,
                                         op_A,
@@ -286,24 +323,43 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
                                  k,
                                  alpha,
                                  A,
+#ifdef __HIP_PLATFORM_AMD__
+                                 HIPBLAS_R_32F,
+#else
                                  CUDA_R_32F,
+#endif
                                  (op_A == CUBLAS_OP_N) ? m : k,
                                  stride_A,
                                  B,
+#ifdef __HIP_PLATFORM_AMD__
+                                 HIPBLAS_R_32F,
+#else
                                  CUDA_R_32F,
+#endif
                                  (op_B == CUBLAS_OP_N) ? k : n,
                                  stride_B,
                                  beta,
                                  C,
+#ifdef __HIP_PLATFORM_AMD__
+                                 HIPBLAS_R_32F,
+#else
                                  CUDA_R_32F,
+#endif
                                  m,
                                  stride_C,
                                  batch,
+#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
+                                 HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
+                                 HIPBLAS_R_32F,
+#else
                                  CUDA_R_32F,
+#endif
                                  algo);
 #endif

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -321,7 +377,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
     return 0;
 }

 template <typename T>
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
 int cublas_strided_batched_gemm(rocblas_handle handle,
                                 int m,
                                 int n,
@@ -357,7 +414,8 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
                                 cublasGemmAlgo_t algo)
 #endif
 {
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     constexpr auto rocblas_dtype_16 = std::is_same<T, __half>::value ? rocblas_datatype_f16_r
                                                                      : rocblas_datatype_bf16_r;
     rocblas_status status =
@@ -390,8 +448,12 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
                                         algo,
                                         0,
                                         0);
+#else
+#ifdef __HIP_PLATFORM_AMD__
+    constexpr auto cublas_dtype_16 = std::is_same<T, __half>::value ? HIPBLAS_R_16F : HIPBLAS_R_16B;
 #else
     constexpr auto cublas_dtype_16 = std::is_same<T, __half>::value ? CUDA_R_16F : CUDA_R_16BF;
+#endif
     cublasStatus_t status = cublasGemmStridedBatchedEx(handle,
                                                        op_A,
                                                        op_B,
@@ -413,11 +475,18 @@ int cublas_strided_batched_gemm(cublasHandle_t handle,
                                                        m,
                                                        stride_C,
                                                        batch,
+#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
+                                                       HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
+                                                       HIPBLAS_R_32F,
+#else
                                                        CUDA_R_32F,
+#endif
                                                        algo);
 #endif

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
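Aside (not part of the diff): the templated wrappers above select the 16-bit datatype from the template parameter with std::is_same. The same test can be factored into a helper for the two non-legacy backends; blas_dtype_16 is a hypothetical name, the legacy rocblas branch keeps rocblas_datatype_f16_r / rocblas_datatype_bf16_r, and T is assumed to be __half or the backend's bfloat16 type:

#include <type_traits>

// Hypothetical helper, not in the PR: HIPBLAS_R_16B is hipblas's bfloat16
// enum, mirroring CUDA_R_16BF on the CUDA side.
template <typename T>
constexpr auto blas_dtype_16()
{
#ifdef __HIP_PLATFORM_AMD__
    return std::is_same<T, __half>::value ? HIPBLAS_R_16F : HIPBLAS_R_16B;
#else
    return std::is_same<T, __half>::value ? CUDA_R_16F : CUDA_R_16BF;
#endif
}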
diff --git a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h
index c02cc76905e0..294db7528699 100644
--- a/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h
+++ b/deepspeed/inference/v2/kernels/core_ops/blas_kernels/blas_utils.h
@@ -55,7 +55,9 @@ class BlasContext {

 enum class BlasType { FP32, FP16, BF16 };

-#ifdef __HIP_PLATFORM_AMD__
+// TODO HIP: Remove backward compatibility for torch<=2.0 in future
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
 rocblas_operation get_trans_op(bool do_trans)
 {
     return (do_trans) ? rocblas_operation_transpose : rocblas_operation_none;
@@ -76,9 +78,15 @@ cublasOperation_t get_trans_op(bool do_trans) { return (do_trans) ? CUBLAS_OP_T
 cublasDataType_t get_datatype(BlasType type)
 {
     switch (type) {
+#ifdef __HIP_PLATFORM_AMD__
+        case BlasType::FP32: return HIPBLAS_R_32F;
+        case BlasType::FP16: return HIPBLAS_R_16F;
+        case BlasType::BF16: return HIPBLAS_R_16B;
+#else
         case BlasType::FP32: return CUDA_R_32F;
         case BlasType::FP16: return CUDA_R_16F;
         case BlasType::BF16: return CUDA_R_16BF;
+#endif
         default: throw std::runtime_error("Unsupported BlasType");
     }
 }
@@ -99,7 +107,8 @@ int blas_gemm_ex(void* C,
                  const float* beta,
                  BlasType type)
 {
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     rocblas_operation_t transa_op = get_trans_op(transa);
     rocblas_operation_t transb_op = get_trans_op(transb);

@@ -151,11 +160,18 @@ int blas_gemm_ex(void* C,
                                       C,
                                       abc_type,
                                       ldc,
+#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
+                                      HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
+                                      HIPBLAS_R_32F,
+#else
                                       CUDA_R_32F,
+#endif
                                       CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 #endif

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
@@ -190,7 +206,8 @@ int blas_strided_batched_gemm(void* C,
                               int batch,
                               BlasType type)
 {
-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     rocblas_operation_t transa_op = get_trans_op(transa);
     rocblas_operation_t transb_op = get_trans_op(transb);

@@ -253,11 +270,18 @@ int blas_strided_batched_gemm(void* C,
                                                    ldc,
                                                    stride_C,
                                                    batch,
+#if defined(__HIP_PLATFORM_AMD__) && defined(HIPBLAS_V2)
+                                                   HIPBLAS_COMPUTE_32F,
+#elif defined(__HIP_PLATFORM_AMD__)
+                                                   HIPBLAS_R_32F,
+#else
                                                    CUDA_R_32F,
+#endif
                                                    CUBLAS_GEMM_DEFAULT_TENSOR_OP);
 #endif

-#ifdef __HIP_PLATFORM_AMD__
+#if defined(__HIP_PLATFORM_AMD__) && \
+    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
     if (status != rocblas_status_success) {
 #else
     if (status != CUBLAS_STATUS_SUCCESS) {
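Aside (not part of the diff): get_datatype() above maps BlasType to the backend enum, but callers still have to produce a BlasType. A hypothetical ATen-side mapping; to_blas_type is not a DeepSpeed function, and the sketch assumes <ATen/ATen.h> plus the blas_utils.h declarations above:

#include <ATen/ATen.h>

// Hypothetical helper, not in the PR: derive the BlasType consumed by
// get_datatype() from a tensor's scalar type.
inline BlasType to_blas_type(at::ScalarType t)
{
    switch (t) {
        case at::kFloat: return BlasType::FP32;
        case at::kHalf: return BlasType::FP16;
        case at::kBFloat16: return BlasType::BF16;
        default: throw std::runtime_error("Unsupported dtype for BLAS GEMM");
    }
}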
diff --git a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu
index cfa62f94596a..fc14b1831361 100644
--- a/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu
+++ b/deepspeed/inference/v2/kernels/core_ops/gated_activations/gated_activation_kernels_cuda.cu
@@ -17,7 +17,7 @@ constexpr int access_size = 16;
 constexpr int threads = 1024;

 template <ActivationType ActType>
-float gated_act_fn(float x, float y);
+DS_D_INLINE float gated_act_fn(float x, float y);

 template <>
 DS_D_INLINE float gated_act_fn<ActivationType::GEGLU>(float x, float y)
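Aside (not part of the diff): taken together, the PR makes the same three-way decision at every call site: legacy rocblas for torch<=2.0 on ROCm, hipblas (with an HIPBLAS_V2 compute type where available) otherwise, and cublas on CUDA. A consolidated compatibility header could make that decision once; every name below (DS_LEGACY_ROCBLAS, blas_algo_t, DS_BLAS_ALGO_DEFAULT, DS_BLAS_COMPUTE_32F) is hypothetical, and the rocblas/hipblas include paths vary across ROCm releases:

// blas_compat.h, an illustrative sketch only; NOT part of this PR.
#pragma once

#include <torch/version.h>

// TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && \
    ((TORCH_VERSION_MAJOR < 2) || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR == 0))
#define DS_LEGACY_ROCBLAS 1
#endif

#if defined(DS_LEGACY_ROCBLAS)
#include <rocblas/rocblas.h>
using blas_algo_t = rocblas_gemm_algo;
constexpr blas_algo_t DS_BLAS_ALGO_DEFAULT = rocblas_gemm_algo_standard;
constexpr auto DS_BLAS_COMPUTE_32F = rocblas_datatype_f32_r;
#elif defined(__HIP_PLATFORM_AMD__)
#include <hipblas/hipblas.h>
using blas_algo_t = hipblasGemmAlgo_t;
constexpr blas_algo_t DS_BLAS_ALGO_DEFAULT = HIPBLAS_GEMM_DEFAULT;
#ifdef HIPBLAS_V2
constexpr auto DS_BLAS_COMPUTE_32F = HIPBLAS_COMPUTE_32F;  // hipblasComputeType_t
#else
constexpr auto DS_BLAS_COMPUTE_32F = HIPBLAS_R_32F;  // hipblasDatatype_t
#endif
#else
#include <cublas_v2.h>
using blas_algo_t = cublasGemmAlgo_t;
constexpr blas_algo_t DS_BLAS_ALGO_DEFAULT = CUBLAS_GEMM_DEFAULT_TENSOR_OP;
constexpr auto DS_BLAS_COMPUTE_32F = CUDA_R_32F;
#endif

Keeping the ladders inline, as the PR does, has one advantage: the sources stay directly hipify-friendly, since each CUDA identifier remains visible at its call site.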