non-xla hipblaslt for ThenBlasGemmStridedBatched
ScXfjiang committed Nov 22, 2024
1 parent 7630bad commit 8ee6c6d
Showing 2 changed files with 76 additions and 35 deletions.
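For context: the commit adds a hipBLASLt fast path to the non-XLA ThenBlasGemmStridedBatched, gated on gpu::GpuBlasLtEnabled(), and moves the fallback implementation from stream.h into stream.cc. A minimal caller sketch follows; the buffer setup, the blas::kDefaultComputePrecision constant, and blas::CallContext::kNone are illustrative assumptions, not part of this diff.

```cpp
// Hypothetical caller sketch (not from this commit): one strided-batched
// GEMM, C_i = alpha * A_i * B_i + beta * C_i for i in [0, batch_count).
#include "tensorflow/compiler/xla/stream_executor/stream.h"

namespace se = stream_executor;

tsl::Status RunStridedBatchedGemm(se::Stream *stream,
                                  const se::DeviceMemory<Eigen::half> &a,
                                  const se::DeviceMemory<Eigen::half> &b,
                                  se::DeviceMemory<Eigen::half> *c,
                                  int batch_count) {
  const int m = 128, n = 128, k = 64;
  // Column-major layout: lda/ldb/ldc are the row counts; each stride is
  // the element count of one matrix in the batch.
  return stream->ThenBlasGemmStridedBatched(
      se::blas::Transpose::kNoTranspose, se::blas::Transpose::kNoTranspose,
      m, n, k, /*alpha=*/1.0f, a, /*lda=*/m, /*stride_a=*/m * k,
      b, /*ldb=*/k, /*stride_b=*/k * n, /*beta=*/0.0f,
      c, /*ldc=*/m, /*stride_c=*/m * n, batch_count,
      se::blas::kDefaultComputePrecision, se::blas::CallContext::kNone);
}
```

With float alpha/beta and Eigen::half inputs this deduces the (Eigen::half, float) pair, which is one of the combinations the commit explicitly instantiates.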
82 changes: 75 additions & 7 deletions tensorflow/compiler/xla/stream_executor/stream.cc
@@ -1477,13 +1477,6 @@ tsl::Status Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
                                  ConstantType beta, DeviceMemory<InputType> *c,
                                  int ldc, blas::ComputePrecision precision,
                                  blas::CallContext context) {
-  if(gpu::GpuBlasLtEnabled()) {
-    auto& r = gpu::BlasLtGemmRunner::i(this);
-    CheckStatus(r.Run(*this, transa, transb, m, n, k,
-        alpha, a, lda, b, ldb, beta, c, ldc,
-        /* allocator */nullptr));  //! NOTE: allocator is not available!!
-    return ::tsl::OkStatus();
-  }
   static_assert(
       detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16, float,
                         double, std::complex<float>, std::complex<double>>(),
@@ -1500,6 +1493,15 @@ tsl::Status Stream::ThenBlasGemm(blas::Transpose transa, blas::Transpose transb,
   static_assert(
       detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16, ConstantType>(),
       "If input is not Eigen::half, constant and input types have to match");
+
+  if(gpu::GpuBlasLtEnabled()) {
+    auto& r = gpu::BlasLtGemmRunner::i(this);
+    CheckStatus(r.Run(*this, transa, transb, m, n, k,
+        alpha, a, lda, b, ldb, beta, c, ldc,
+        /* allocator */nullptr));  //! NOTE: allocator is not available!!
+    return ::tsl::OkStatus();
+  }
+
   blas::BlasSupport *blas = parent()->AsBlas();
   if (!blas) {
     return tsl::errors::Internal(
@@ -1540,6 +1542,72 @@ INSTANTIATE_THEN_BLAS_GEMM(Eigen::bfloat16, float)
 
 #undef INSTANTIATE_THEN_BLAS_GEMM
 
+template <typename InputType, typename ConstantType>
+tsl::Status Stream::ThenBlasGemmStridedBatched(
+    blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
+    uint64_t k, ConstantType alpha, const DeviceMemory<InputType> &a, int lda,
+    int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
+    int64_t stride_b, ConstantType beta, DeviceMemory<InputType> *c, int ldc,
+    int64_t stride_c, int batch_count, blas::ComputePrecision precision,
+    blas::CallContext context) {
+  static_assert(
+      detail::is_any_of<InputType, float, Eigen::half, Eigen::bfloat16,
+                        double, std::complex<float>, std::complex<double>>(),
+      "Unsupported input type");
+  static_assert(
+      std::is_same_v<ConstantType, InputType> ||
+          (detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16>() &&
+           std::is_same_v<ConstantType, float>),
+      "Mismatched input and alpha/beta types");
+
+  if (gpu::GpuBlasLtEnabled()) {
+    auto &r = gpu::BlasLtGemmRunner::i(this);
+    CheckStatus(r.RunStridedBatched(
+        *this, transa, transb, m, n, k, alpha, a, lda, stride_a, b, ldb,
+        stride_b, beta, c, ldc, stride_c, batch_count,
+        /* allocator */ nullptr));  //! NOTE: allocator is not available!!
+    return ::tsl::OkStatus();
+  }
+
+  blas::BlasSupport *blas = parent()->AsBlas();
+  if (!blas) {
+    return tsl::errors::Internal(
+        "Attempting to perform BLAS operation using "
+        "StreamExecutor without BLAS support");
+  }
+
+  void *alpha_ptr = &alpha;
+  void *beta_ptr = &beta;
+  float alpha_storage, beta_storage;
+  UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
+                                  &beta_storage);
+
+  return blas->DoBlasGemmStridedBatched(
+      this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
+      alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc,
+      stride_c, batch_count, precision, context);
+}
+
+#define INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(INPUT_TYPE, CONSTANT_TYPE) \
+  template tsl::Status Stream::ThenBlasGemmStridedBatched<INPUT_TYPE, CONSTANT_TYPE> ( \
+      blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, \
+      uint64_t k, CONSTANT_TYPE alpha, const DeviceMemory<INPUT_TYPE> &a, int lda, \
+      int64_t stride_a, const DeviceMemory<INPUT_TYPE> &b, int ldb, \
+      int64_t stride_b, CONSTANT_TYPE beta, DeviceMemory<INPUT_TYPE> *c, int ldc, \
+      int64_t stride_c, int batch_count, blas::ComputePrecision precision, \
+      blas::CallContext context);
+
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(float, float)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(double, double)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::half, Eigen::half)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::bfloat16, Eigen::bfloat16)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(std::complex<float>, std::complex<float>)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(std::complex<double>, std::complex<double>)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::half, float)
+INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED(Eigen::bfloat16, float)
+
+#undef INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED
+
 namespace {
 // Like ThenBlasImpl, except this expects the last argument of blas_func to be a
 // blas::ProfileResult*. This functor doesn't put the stream into an error
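The non-BlasLt fallback above widens half-precision alpha/beta before handing them to DoBlasGemmStridedBatched, which takes type-erased void* scalars. Below is a sketch of what the UpcastHalfToFloat call site implies, inferred from its usage here rather than copied from the real helper:

```cpp
#include <type_traits>

#include "Eigen/Core"  // Eigen::half, Eigen::bfloat16 (assumed include path)

// Sketch only: if ConstantType is a 16-bit float type, the underlying BLAS
// expects float scalars, so widen alpha/beta into caller-provided float
// storage and repoint the type-erased pointers at it; otherwise leave the
// pointers aimed at the original ConstantType values.
template <typename ConstantType>
void UpcastHalfToFloatSketch(void **alpha_ptr, void **beta_ptr,
                             float *alpha_storage, float *beta_storage) {
  if constexpr (std::is_same_v<ConstantType, Eigen::half> ||
                std::is_same_v<ConstantType, Eigen::bfloat16>) {
    *alpha_storage =
        static_cast<float>(*reinterpret_cast<ConstantType *>(*alpha_ptr));
    *beta_storage =
        static_cast<float>(*reinterpret_cast<ConstantType *>(*beta_ptr));
    *alpha_ptr = alpha_storage;
    *beta_ptr = beta_storage;
  }
}
```

This is why the function declares alpha_storage and beta_storage on the stack: the repointed scalars must stay alive until the BLAS call returns.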
29 changes: 1 addition & 28 deletions tensorflow/compiler/xla/stream_executor/stream.h
@@ -1093,34 +1093,7 @@ class Stream {
       int64_t stride_a, const DeviceMemory<InputType> &b, int ldb,
       int64_t stride_b, ConstantType beta, DeviceMemory<InputType> *c, int ldc,
       int64_t stride_c, int batch_count, blas::ComputePrecision precision,
-      blas::CallContext context) {
-    static_assert(
-        detail::is_any_of<InputType, float, Eigen::half, Eigen::bfloat16,
-                          double, std::complex<float>, std::complex<double>>(),
-        "Unsupported input type");
-    static_assert(
-        std::is_same_v<ConstantType, InputType> ||
-            (detail::is_any_of<InputType, Eigen::half, Eigen::bfloat16>() &&
-             std::is_same_v<ConstantType, float>),
-        "Mismatched input and alpha/beta types");
-    blas::BlasSupport *blas = parent()->AsBlas();
-    if (!blas) {
-      return tsl::errors::Internal(
-          "Attempting to perform BLAS operation using "
-          "StreamExecutor without BLAS support");
-    }
-
-    void *alpha_ptr = &alpha;
-    void *beta_ptr = &beta;
-    float alpha_storage, beta_storage;
-    UpcastHalfToFloat<ConstantType>(&alpha_ptr, &beta_ptr, &alpha_storage,
-                                    &beta_storage);
-
-    return blas->DoBlasGemmStridedBatched(
-        this, transa, transb, m, n, k, blas::ToDataType<InputType>::value,
-        alpha_ptr, a, lda, stride_a, b, ldb, stride_b, beta_ptr, c, ldc,
-        stride_c, batch_count, precision, context);
-  }
+      blas::CallContext context);
 
   // See BlasSupport::DoBlasTrsm.
   Stream &ThenBlasTrsm(blas::Side side, blas::UpperLower uplo,
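On the header side, only the declaration remains. Because the template body now lives in stream.cc, client translation units can no longer instantiate it themselves, which is why stream.cc explicitly instantiates every supported (InputType, ConstantType) pair via the INSTANTIATE_THEN_BLAS_GEMM_STRIDED_BATCHED macro. A generic sketch of that technique, with hypothetical names:

```cpp
// scale.h -- declaration only; the definition is hidden in scale.cc.
template <typename T>
T Scale(T value, T factor);

// scale.cc -- definition plus explicit instantiations. Callers may use
// exactly these types; any other T fails at link time, which acts as a
// compile-firewall for the supported type set.
template <typename T>
T Scale(T value, T factor) {
  return value * factor;
}

template float Scale<float>(float, float);
template double Scale<double>(double, double);
```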
