Skip to content

Commit

Permalink
hipblaslt 6.0 adaptions for datatypes and enums
Browse files Browse the repository at this point in the history
  • Loading branch information
jayfurmanek committed Dec 21, 2023
1 parent d287248 commit 6b05526
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 58 deletions.
2 changes: 1 addition & 1 deletion third_party/xla/xla/service/gpu/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -1545,7 +1545,7 @@ cc_library(
"//xla/stream_executor:host_or_device_scalar",
]) + if_rocm_is_configured([
"//xla/stream_executor/rocm:hipblas_lt_header",
"//xla/stream_executor/rocm:hipblaslt_plugin",
"//xla/stream_executor/rocm:amdhipblaslt_plugin",
"//xla/stream_executor:host_or_device_scalar",
"//xla/stream_executor/platform:dso_loader",
]) + if_static([
Expand Down
10 changes: 0 additions & 10 deletions third_party/xla/xla/service/gpu/matmul_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -857,15 +857,6 @@ StatusOr<se::gpu::BlasLt::MatrixLayout> AsBlasLtMatrixLayout(
}

#if TF_HIPBLASLT
#if TF_ROCM_VERSION < 60000
using cudaDataType_t = hipblasDatatype_t;
#define CUDA_R_16BF HIPBLAS_R_16B
#define CUDA_R_16F HIPBLAS_R_16F
#define CUDA_R_32F HIPBLAS_R_32F
#define CUDA_R_64F HIPBLAS_R_64F
#define CUDA_C_32F HIPBLAS_C_32F
#define CUDA_C_64F HIPBLAS_C_64F
#else
using cudaDataType_t = hipblasltDatatype_t;
#define CUDA_R_16BF HIPBLASLT_R_16B
#define CUDA_R_16F HIPBLASLT_R_16F
Expand All @@ -874,7 +865,6 @@ using cudaDataType_t = hipblasltDatatype_t;
#define CUDA_C_32F HIPBLASLT_C_32F
#define CUDA_C_64F HIPBLASLT_C_64F
#endif
#endif

template <cudaDataType_t CudaT>
struct CudaToNativeT;
Expand Down
4 changes: 2 additions & 2 deletions third_party/xla/xla/stream_executor/rocm/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,7 @@ cc_library(
)

cc_library(
name = "hipblaslt_plugin",
name = "amdhipblaslt_plugin",
srcs = if_rocm_is_configured(["hip_blas_lt.cc"]),
hdrs = if_rocm_is_configured([
"hip_blas_lt.h",
Expand Down Expand Up @@ -556,7 +556,7 @@ cc_library(
":rocm_driver",
":rocm_platform",
":rocm_helpers",
":hipblaslt_plugin",
":amdhipblaslt_plugin",
]),
alwayslink = 1,
)
Expand Down
10 changes: 5 additions & 5 deletions third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,7 @@ tsl::Status BlasLt::Init() {
return std::move(layout);
}

hipblasltDatatype_t BlasLt::MatrixLayout::type() const { return HIPBLASLT_R_32F; }
hipDatatype BlasLt::MatrixLayout::type() const { return HIP_R_32F; }

/*static*/ tsl::StatusOr<BlasLt::MatmulDesc> BlasLt::MatmulDesc::Create(
blas::ComputationType compute_type, blas::DataType scale_type,
Expand All @@ -175,12 +175,12 @@ hipblasltDatatype_t BlasLt::MatrixLayout::type() const { return HIPBLASLT_R_32F;
return std::move(desc);
}

hipblasLtComputeType_t BlasLt::MatmulDesc::compute_type() const {
return HIPBLASLT_COMPUTE_F32;
hipblasComputeType_t BlasLt::MatmulDesc::compute_type() const {
return HIPBLAS_COMPUTE_32F;
}

hipblasltDatatype_t BlasLt::MatmulDesc::scale_type() const {
return HIPBLASLT_R_32F;
hipblasDatatype_t BlasLt::MatmulDesc::scale_type() const {
return HIP_R_32F;
}

hipblasPointerMode_t BlasLt::MatmulDesc::pointer_mode() const {
Expand Down
18 changes: 3 additions & 15 deletions third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,6 @@ limitations under the License.
#include "rocm/rocm_config.h"
#if TF_HIPBLASLT

#if TF_ROCM_VERSION < 60000
#define hipblasltDatatype_t hipblasDatatype_t
#define HIPBLASLT_R_16F HIPBLAS_R_16F
#define HIPBLASLT_R_16B HIPBLAS_R_16B
#define HIPBLASLT_R_32F HIPBLAS_R_32F
#define HIPBLASLT_R_64F HIPBLAS_R_64F
#define HIPBLASLT_R_8I HIPBLAS_R_8I
#define HIPBLASLT_R_32I HIPBLAS_R_32I
#define HIPBLASLT_C_32F HIPBLAS_R_32F
#define HIPBLASLT_C_64F HIPBLAS_R_64F
#endif

#include "xla/status.h"
#include "xla/stream_executor/rocm/hip_blas_utils.h"
#include "xla/stream_executor/rocm/hipblaslt_wrapper.h"
Expand Down Expand Up @@ -66,7 +54,7 @@ class BlasLt {
std::optional<int64_t> leading_dim_stride = std::nullopt,
std::optional<int64_t> batch_stride = std::nullopt);

hipblasltDatatype_t type() const;
hipDatatype type() const;

hipblasLtMatrixLayout_t get() const { return handle_.get(); }

Expand Down Expand Up @@ -103,8 +91,8 @@ class BlasLt {
Epilogue epilogue = Epilogue::kDefault,
PointerMode pointer_mode = PointerMode::kHost);

hipblasLtComputeType_t compute_type() const;
hipblasltDatatype_t scale_type() const;
hipblasComputeType_t compute_type() const;
hipDatatype scale_type() const;
hipblasPointerMode_t pointer_mode() const;

hipblasLtMatmulDesc_t get() const { return handle_.get(); }
Expand Down
33 changes: 11 additions & 22 deletions third_party/xla/xla/stream_executor/rocm/hip_blas_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,6 @@ limitations under the License.
#include "absl/strings/str_cat.h"
#include "xla/stream_executor/blas.h"
#include "rocm/rocm_config.h"
#if TF_ROCM_VERSION < 60000
#define hipblasltDatatype_t hipblasDatatype_t
#define HIPBLASLT_R_16F HIPBLAS_R_16F
#define HIPBLASLT_R_16B HIPBLAS_R_16B
#define HIPBLASLT_R_32F HIPBLAS_R_32F
#define HIPBLASLT_R_64F HIPBLAS_R_64F
#define HIPBLASLT_R_8I HIPBLAS_R_8I
#define HIPBLASLT_R_32I HIPBLAS_R_32I
#define HIPBLASLT_C_32F HIPBLAS_R_32F
#define HIPBLASLT_C_64F HIPBLAS_R_64F
#endif

namespace stream_executor {
namespace rocm {
Expand All @@ -42,36 +31,36 @@ tsl::Status ToStatus(hipblasStatus_t status, const char* prefix) {
return tsl::OkStatus();
}

hipblasltDatatype_t AsHipblasDataType(blas::DataType type) {
hiptDatatype_t AsHipblasDataType(blas::DataType type) {
switch (type) {
case blas::DataType::kF8E5M2:
case blas::DataType::kF8E4M3FN:
LOG(FATAL) << "hipblaslt does not support F8 yet";
case blas::DataType::kHalf:
return HIPBLASLT_R_16F;
return HIP_R_16F;
case blas::DataType::kBF16:
return HIPBLASLT_R_16B;
return HIP_R_16B;
case blas::DataType::kFloat:
return HIPBLASLT_R_32F;
return HIP_R_32F;
case blas::DataType::kDouble:
return HIPBLASLT_R_64F;
return HIP_R_64F;
case blas::DataType::kInt8:
return HIPBLASLT_R_8I;
return HIP_R_8I;
case blas::DataType::kInt32:
return HIPBLASLT_R_32I;
return HIP_R_32I;
case blas::DataType::kComplexFloat:
return HIPBLASLT_C_32F;
return HIP_C_32F;
case blas::DataType::kComplexDouble:
return HIPBLASLT_C_64F;
return HIP_C_64F;
default:
LOG(FATAL) << "unknown data type";
}
}

hipblasLtComputeType_t AsHipblasComputeType(blas::ComputationType type) {
hipblasComputeType_t AsHipblasComputeType(blas::ComputationType type) {
if (type == blas::ComputationType::kF32 ||
type == blas::ComputationType::kTF32AsF32)
return HIPBLASLT_COMPUTE_F32;
return HIPBLAS_COMPUTE_32F;
else
LOG(FATAL) << "unsupported hipblaslt computation type";
}
Expand Down
19 changes: 16 additions & 3 deletions third_party/xla/xla/stream_executor/rocm/hip_blas_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,20 @@ limitations under the License.

#include "rocm/rocm_config.h"
#if TF_ROCM_VERSION < 60000
#define hipblasltDatatype_t hipblasDatatype_t
#define hipDataType hipblasDatatype_t
#define HIP_R_16F HIPBLAS_R_16F
#define HIP_R_16BF HIPBLAS_R_16B
#define HIP_R_32F HIPBLAS_R_32F
#define HIP_R_64F HIPBLAS_R_64F
#define HIP_R_8I HIPBLAS_R_8I
#define HIP_R_32I HIPBLAS_R_32I
#define HIP_C_32F HIPBLAS_C_32F
#define HIP_C_64F HIPBLAS_C_64F

#define hipblasComputeType_t hipblasLtComputeType_t
#define HIPBLAS_COMPUTE_32F HIPBLASLT_COMPUTE_F32
#define HIPBLAS_COMPUTE_64F HIPBLASLT_COMPUTE_F64
#define HIPBLAS_COMPUTE_32I HIPBLASLT_COMPUTE_I32
#endif


Expand All @@ -36,8 +49,8 @@ namespace rocm {
TF_RETURN_IF_ERROR(::stream_executor::rocm::ToStatus(expr, #expr))

tsl::Status ToStatus(hipblasStatus_t status, const char* prefix);
hipblasltDatatype_t AsHipblasDataType(blas::DataType type);
hipblasLtComputeType_t AsHipblasComputeType(blas::ComputationType type);
hipDatatype AsHipblasDataType(blas::DataType type);
hipblasComputeType_t AsHipblasComputeType(blas::ComputationType type);
hipblasOperation_t AsHipblasOperation(blas::Transpose trans);

} // namespace rocm
Expand Down

0 comments on commit 6b05526

Please sign in to comment.