Skip to content

Commit

Permalink
Added comments
Browse files Browse the repository at this point in the history
  • Loading branch information
rraminen committed May 16, 2024
1 parent d288d36 commit dfc8c6c
Show file tree
Hide file tree
Showing 8 changed files with 8 additions and 0 deletions.
1 change: 1 addition & 0 deletions csrc/includes/cublas_wrappers.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ int cublas_gemm_ex(cublasHandle_t handle,
const float* A,
const float* B,
float* C,
//TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
rocblas_gemm_algo algo = rocblas_gemm_algo_standard);
#else
Expand Down
1 change: 1 addition & 0 deletions csrc/includes/feed_forward.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ class FeedForward {
weights,
input_ptr,
out,
//TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
rocblas_gemm_algo(config_.gemm_algos[0]));
#else
Expand Down
1 change: 1 addition & 0 deletions csrc/includes/gemm_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ class GemmTest {
B,
A,
C,
//TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
static_cast<rocblas_gemm_algo>(algo));
#else
Expand Down
1 change: 1 addition & 0 deletions csrc/includes/strided_batch_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class StridedBatchGemm {
stride_b,
stride_c,
bsz,
//TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
rocblas_gemm_algo(_config.gemm_algos[0]));
#else
Expand Down
1 change: 1 addition & 0 deletions csrc/transformer/cublas_wrappers.cu
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "cublas_wrappers.h"

//TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
Expand Down
1 change: 1 addition & 0 deletions csrc/transformer/inference/csrc/pt_binding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ at::Tensor einsum_sec_sm_ecm(at::Tensor& Q, at::Tensor& W)
(T*)W.data_ptr(),
(T*)Q.data_ptr(),
(T*)O.data_ptr(),
//TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
rocblas_gemm_algo_standard);
#else
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
#endif
#include <stdio.h>

//TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
int cublas_gemm_ex(rocblas_handle handle,
rocblas_operation transa,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class BlasContext {

enum class BlasType { FP32, FP16, BF16 };

//TODO HIP: Remove backward compatibility for torch<=2.0 in future
#if defined(__HIP_PLATFORM_AMD__) && TORCH_VERSION_MAJOR <= 2 && TORCH_VERSION_MINOR <=0
rocblas_operation get_trans_op(bool do_trans)
{
Expand Down

0 comments on commit dfc8c6c

Please sign in to comment.