Change the location of the GEMM runner for batched GEMM
ScXfjiang committed Nov 21, 2024
1 parent e538a1c commit b7d31bd
Showing 1 changed file with 36 additions and 36 deletions.
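In effect, this commit moves the gpu::BlasLtGemmRunner fast path out of each ThenBlasGemmBatched overload and into the matching ThenBlasGemmBatchedWithScratch overload, so the runner now receives the caller's scratch_allocator instead of a hard-coded nullptr; the plain overloads still reach the BlasLt path by delegating through ThenBlasGemmBatchedWithScratch. The sketch below is a minimal, self-contained illustration of the same pattern; Stream, Batched, BatchedWithScratch, and ScratchAllocator here are hypothetical stand-ins, not the actual XLA stream_executor API.

// Illustrative sketch only (hypothetical names, not the XLA API): the fast
// path moves from the thin wrapper into the overload that actually has the
// scratch-allocator parameter, so the allocator can be forwarded.
#include <iostream>

struct ScratchAllocator {};  // stand-in for the real allocator type

struct Stream {
  bool blaslt_enabled = true;

  Stream &Batched(int m, int n, int k) {
    // Before this commit, the BlasLt fast path lived here and had to pass
    // nullptr, because this overload has no scratch-allocator parameter.
    return BatchedWithScratch(m, n, k, /*scratch_allocator=*/nullptr);
  }

  Stream &BatchedWithScratch(int m, int n, int k,
                             ScratchAllocator *scratch_allocator) {
    // After this commit, the fast path lives here, where the caller's
    // allocator is visible and can be forwarded to the runner.
    if (blaslt_enabled) {
      std::cout << "BlasLt runner: " << m << "x" << n << "x" << k
                << ", allocator="
                << (scratch_allocator ? "provided" : "nullptr") << "\n";
      return *this;
    }
    std::cout << "legacy BLAS path\n";
    return *this;
  }
};

int main() {
  Stream s;
  ScratchAllocator alloc;
  s.Batched(64, 64, 64);                     // still reaches the BlasLt path
  s.BatchedWithScratch(64, 64, 64, &alloc);  // allocator now reaches the runner
}

The design point is that the overload owning the scratch_allocator parameter is the one that invokes the runner, so batched GEMM calls made with an explicit scratch allocator can actually use it.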
72 changes: 36 additions & 36 deletions tensorflow/compiler/xla/stream_executor/stream.cc
@@ -1703,12 +1703,6 @@ Stream &Stream::ThenBlasGemmBatched(
uint64_t k, float alpha, DeviceMemorySlice<Eigen::half> a, int lda,
DeviceMemorySlice<Eigen::half> b, int ldb, float beta,
DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count, blas::CallContext context) {
if (gpu::GpuBlasLtEnabled()) {
auto &r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, batch_count, /* allocator */nullptr));
return *this;
}
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr, context);
@@ -1724,7 +1718,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if (gpu::GpuBlasLtEnabled()) {
auto &r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, batch_count, scratch_allocator));
return *this;
}
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
float, DeviceMemorySlice<Eigen::half>, int,
DeviceMemorySlice<Eigen::half>, int, float,
@@ -1744,7 +1743,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if (gpu::GpuBlasLtEnabled()) {
auto &r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, batch_count, scratch_allocator));
return *this;
}
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
float, DeviceMemorySlice<Eigen::bfloat16>, int,
DeviceMemorySlice<Eigen::bfloat16>, int, float,
@@ -1762,12 +1766,6 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa,
DeviceMemorySlice<float> b, int ldb,
float beta, DeviceMemorySlice<float> c,
int ldc, int batch_count, blas::CallContext context) {
if (gpu::GpuBlasLtEnabled()) {
auto &r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, batch_count, /* allocator */nullptr));
return *this;
}
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr, context);
@@ -1782,7 +1780,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if (gpu::GpuBlasLtEnabled()) {
auto &r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, batch_count, scratch_allocator));
return *this;
}
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
float, DeviceMemorySlice<float>, int, DeviceMemorySlice<float>,
int, float, DeviceMemorySlice<float>, int, int,
@@ -1800,12 +1803,6 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa,
DeviceMemorySlice<double> b, int ldb,
double beta, DeviceMemorySlice<double> c,
int ldc, int batch_count, blas::CallContext context) {
if (gpu::GpuBlasLtEnabled()) {
auto &r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, batch_count, /* allocator */nullptr));
return *this;
}
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr, context);
@@ -1820,7 +1817,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if (gpu::GpuBlasLtEnabled()) {
auto &r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, batch_count, scratch_allocator));
return *this;
}
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
double, DeviceMemorySlice<double>, int,
DeviceMemorySlice<double>, int, double,
@@ -1837,12 +1839,6 @@ Stream &Stream::ThenBlasGemmBatched(
DeviceMemorySlice<std::complex<float>> a, int lda,
DeviceMemorySlice<std::complex<float>> b, int ldb, std::complex<float> beta,
DeviceMemorySlice<std::complex<float>> c, int ldc, int batch_count, blas::CallContext context) {
if (gpu::GpuBlasLtEnabled()) {
auto &r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, batch_count, /* allocator */nullptr));
return *this;
}
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr, context);
@@ -1858,7 +1854,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if (gpu::GpuBlasLtEnabled()) {
auto &r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, batch_count, scratch_allocator));
return *this;
}
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
std::complex<float>, DeviceMemorySlice<std::complex<float>>, int,
DeviceMemorySlice<std::complex<float>>, int, std::complex<float>,
@@ -1877,12 +1878,6 @@ Stream &Stream::ThenBlasGemmBatched(
DeviceMemorySlice<std::complex<double>> b, int ldb,
std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,
int ldc, int batch_count, blas::CallContext context) {
if (gpu::GpuBlasLtEnabled()) {
auto &r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, batch_count, /* allocator */nullptr));
return *this;
}
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr, context);
@@ -1899,7 +1894,12 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if (gpu::GpuBlasLtEnabled()) {
auto &r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc, batch_count, scratch_allocator));
return *this;
}
ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
std::complex<double>, DeviceMemorySlice<std::complex<double>>,
int, DeviceMemorySlice<std::complex<double>>, int,
