Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
ScXfjiang committed Oct 29, 2024
1 parent 3f7e238 commit b178ee4
Show file tree
Hide file tree
Showing 4 changed files with 74 additions and 18 deletions.
2 changes: 1 addition & 1 deletion tensorflow/compiler/xla/service/computation_placer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
ComputationPlacerCreationFunction creation_function) {
absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_);
auto* computation_placers = GetPlatformComputationPlacers();
CHECK(computation_placers->find(platform_id) == computation_placers->end());
// CHECK(computation_placers->find(platform_id) == computation_placers->end());
(*computation_placers)[platform_id].creation_function = creation_function;
}

Expand Down
10 changes: 10 additions & 0 deletions tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,16 @@ using blas::ComputationType;
using blas::DataType;
using xla::PrimitiveType;

// Returns true when cuBLASLt/hipBLASLt GEMM dispatch is enabled via the
// TF_ENABLE_GPU_BLASLT environment variable. The variable is read exactly
// once (thread-safe function-local static init); later changes to the
// environment have no effect. Defaults to false when unset or unparsable.
bool GpuBlasLtEnabled() {
  static const bool enabled = [] {
    bool flag = false;
    // NOTE(review): the returned Status is intentionally ignored — on a
    // malformed value we fall back to the default `false`.
    tsl::ReadBoolFromEnvVar("TF_ENABLE_GPU_BLASLT",
                            /*default_value=*/false, &flag);
    return flag;
  }();
  return enabled;
}

namespace {

bool TF32_Enabled() {
Expand Down
2 changes: 2 additions & 0 deletions tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ limitations under the License.
namespace stream_executor {

namespace gpu {

// Whether blas-lt GEMM dispatch is enabled; controlled by the
// TF_ENABLE_GPU_BLASLT environment variable (read once, defaults to false).
bool GpuBlasLtEnabled();

xla::StatusOr<blas::DataType> AsBlasDataType(xla::PrimitiveType dtype);

Expand Down
78 changes: 61 additions & 17 deletions tensorflow/compiler/xla/stream_executor/stream.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ limitations under the License.
#include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h"
#include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h"
#include "tensorflow/tsl/platform/stacktrace.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h"
#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h"

namespace stream_executor {

Expand Down Expand Up @@ -1592,48 +1594,62 @@ Stream &Stream::ThenBlasGemmBatched(
DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count, blas::CallContext context) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr, context);
/*allocator=*/nullptr, context);
}

Stream &Stream::ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
uint64_t k, float alpha, DeviceMemorySlice<Eigen::half> a, int lda,
DeviceMemorySlice<Eigen::half> b, int ldb, float beta,
DeviceMemorySlice<Eigen::half> c, int ldc, int batch_count,
ScratchAllocator *scratch_allocator,
ScratchAllocator *allocator,
blas::CallContext context) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if(gpu::GpuBlasLtEnabled()) {
auto& r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha,
a, lda, b, ldb, beta, c, ldc, batch_count, allocator));
return *this;
}

ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
float, DeviceMemorySlice<Eigen::half>, int,
DeviceMemorySlice<Eigen::half>, int, float,
DeviceMemorySlice<Eigen::half>, int, int, ScratchAllocator *, blas::CallContext>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
scratch_allocator, context);
allocator, context);
}

Stream &Stream::ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
uint64_t k, float alpha, DeviceMemorySlice<Eigen::bfloat16> a, int lda,
DeviceMemorySlice<Eigen::bfloat16> b, int ldb, float beta,
DeviceMemorySlice<Eigen::bfloat16> c, int ldc, int batch_count,
ScratchAllocator *scratch_allocator, blas::CallContext context ) {
ScratchAllocator *allocator, blas::CallContext context ) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if(gpu::GpuBlasLtEnabled()) {
auto& r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha,
a, lda, b, ldb, beta, c, ldc, batch_count, allocator));
return *this;
}

ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
float, DeviceMemorySlice<Eigen::bfloat16>, int,
DeviceMemorySlice<Eigen::bfloat16>, int, float,
DeviceMemorySlice<Eigen::bfloat16>, int, int, ScratchAllocator *, blas::CallContext>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
scratch_allocator, context);
allocator, context);
}

Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa,
Expand All @@ -1645,27 +1661,34 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa,
int ldc, int batch_count, blas::CallContext context) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr, context);
/*allocator=*/nullptr, context);
}

Stream &Stream::ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
uint64_t k, float alpha, DeviceMemorySlice<float> a, int lda,
DeviceMemorySlice<float> b, int ldb, float beta, DeviceMemorySlice<float> c,
int ldc, int batch_count, ScratchAllocator *scratch_allocator,
int ldc, int batch_count, ScratchAllocator *allocator,
blas::CallContext context) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if(gpu::GpuBlasLtEnabled()) {
auto& r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha,
a, lda, b, ldb, beta, c, ldc, batch_count, allocator));
return *this;
}

ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
float, DeviceMemorySlice<float>, int, DeviceMemorySlice<float>,
int, float, DeviceMemorySlice<float>, int, int,
ScratchAllocator *, blas::CallContext>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
scratch_allocator, context);
allocator, context);
}

Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa,
Expand All @@ -1677,27 +1700,34 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa,
int ldc, int batch_count, blas::CallContext context) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr, context);
/*allocator=*/nullptr, context);
}

Stream &Stream::ThenBlasGemmBatchedWithScratch(
blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n,
uint64_t k, double alpha, DeviceMemorySlice<double> a, int lda,
DeviceMemorySlice<double> b, int ldb, double beta,
DeviceMemorySlice<double> c, int ldc, int batch_count,
ScratchAllocator *scratch_allocator, blas::CallContext context) {
ScratchAllocator *allocator, blas::CallContext context) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if(gpu::GpuBlasLtEnabled()) {
auto& r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha,
a, lda, b, ldb, beta, c, ldc, batch_count, allocator));
return *this;
}

ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
double, DeviceMemorySlice<double>, int,
DeviceMemorySlice<double>, int, double,
DeviceMemorySlice<double>, int, int, ScratchAllocator *, blas::CallContext>
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
scratch_allocator, context);
allocator, context);
}

Stream &Stream::ThenBlasGemmBatched(
Expand All @@ -1708,7 +1738,7 @@ Stream &Stream::ThenBlasGemmBatched(
DeviceMemorySlice<std::complex<float>> c, int ldc, int batch_count, blas::CallContext context) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr, context);
/*allocator=*/nullptr, context);
}

Stream &Stream::ThenBlasGemmBatchedWithScratch(
Expand All @@ -1717,11 +1747,18 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
DeviceMemorySlice<std::complex<float>> a, int lda,
DeviceMemorySlice<std::complex<float>> b, int ldb, std::complex<float> beta,
DeviceMemorySlice<std::complex<float>> c, int ldc, int batch_count,
ScratchAllocator *scratch_allocator, blas::CallContext context) {
ScratchAllocator *allocator, blas::CallContext context) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if(gpu::GpuBlasLtEnabled()) {
auto& r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha,
a, lda, b, ldb, beta, c, ldc, batch_count, allocator));
return *this;
}

ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
std::complex<float>, DeviceMemorySlice<std::complex<float>>, int,
DeviceMemorySlice<std::complex<float>>, int, std::complex<float>,
Expand All @@ -1730,7 +1767,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
scratch_allocator, context);
allocator, context);
}

Stream &Stream::ThenBlasGemmBatched(
Expand All @@ -1742,7 +1779,7 @@ Stream &Stream::ThenBlasGemmBatched(
int ldc, int batch_count, blas::CallContext context) {
return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda,
b, ldb, beta, c, ldc, batch_count,
/*scratch_allocator=*/nullptr, context);
/*allocator=*/nullptr, context);
}

Stream &Stream::ThenBlasGemmBatchedWithScratch(
Expand All @@ -1751,12 +1788,19 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
DeviceMemorySlice<std::complex<double>> a, int lda,
DeviceMemorySlice<std::complex<double>> b, int ldb,
std::complex<double> beta, DeviceMemorySlice<std::complex<double>> c,
int ldc, int batch_count, ScratchAllocator *scratch_allocator,
int ldc, int batch_count, ScratchAllocator *allocator,
blas::CallContext context) {
VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k),
PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb),
PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count));

if(gpu::GpuBlasLtEnabled()) {
auto& r = gpu::BlasLtGemmRunner::i(this);
CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha,
a, lda, b, ldb, beta, c, ldc, batch_count, allocator));
return *this;
}

ThenBlasImpl<blas::Transpose, blas::Transpose, uint64_t, uint64_t, uint64,
std::complex<double>, DeviceMemorySlice<std::complex<double>>,
int, DeviceMemorySlice<std::complex<double>>, int,
Expand All @@ -1765,7 +1809,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch(
impl;
return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n,
k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count,
scratch_allocator, context);
allocator, context);
}

Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64_t seed_bytes) {
Expand Down

0 comments on commit b178ee4

Please sign in to comment.