From b178ee40a7fe383f0bcb231a7a54217e848409f0 Mon Sep 17 00:00:00 2001 From: scxfjiang Date: Tue, 29 Oct 2024 18:02:32 +0000 Subject: [PATCH] update --- .../xla/service/computation_placer.cc | 2 +- .../xla/stream_executor/gpu/gpu_blas_lt.cc | 10 +++ .../xla/stream_executor/gpu/gpu_blas_lt.h | 2 + .../compiler/xla/stream_executor/stream.cc | 78 +++++++++++++++---- 4 files changed, 74 insertions(+), 18 deletions(-) diff --git a/tensorflow/compiler/xla/service/computation_placer.cc b/tensorflow/compiler/xla/service/computation_placer.cc index f00a1399aefec3..29972f2764af8d 100644 --- a/tensorflow/compiler/xla/service/computation_placer.cc +++ b/tensorflow/compiler/xla/service/computation_placer.cc @@ -163,7 +163,7 @@ StatusOr ComputationPlacer::AssignDevices( ComputationPlacerCreationFunction creation_function) { absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_); auto* computation_placers = GetPlatformComputationPlacers(); - CHECK(computation_placers->find(platform_id) == computation_placers->end()); + // CHECK(computation_placers->find(platform_id) == computation_placers->end()); (*computation_placers)[platform_id].creation_function = creation_function; } diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc index 195e1161a3aa3a..c9fe6109ef47fd 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.cc @@ -31,6 +31,16 @@ using blas::ComputationType; using blas::DataType; using xla::PrimitiveType; +bool GpuBlasLtEnabled() { + static std::atomic_bool result{[] { + bool value = false; + tsl::ReadBoolFromEnvVar("TF_ENABLE_GPU_BLASLT", + /*default_value=*/false, &value); + return value; + }()}; + return result; +} + namespace { bool TF32_Enabled() { diff --git a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h index 9eb8d121bb44fc..a050d120caf2e8 100644 --- a/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h +++ b/tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h @@ -33,6 +33,8 @@ limitations under the License. namespace stream_executor { namespace gpu { + +bool GpuBlasLtEnabled(); xla::StatusOr AsBlasDataType(xla::PrimitiveType dtype); diff --git a/tensorflow/compiler/xla/stream_executor/stream.cc b/tensorflow/compiler/xla/stream_executor/stream.cc index f9d3f37a73b63c..f8db97ba430f74 100644 --- a/tensorflow/compiler/xla/stream_executor/stream.cc +++ b/tensorflow/compiler/xla/stream_executor/stream.cc @@ -33,6 +33,8 @@ limitations under the License. #include "tensorflow/compiler/xla/stream_executor/stream_executor_internal.h" #include "tensorflow/compiler/xla/stream_executor/stream_executor_pimpl.h" #include "tensorflow/tsl/platform/stacktrace.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt.h" +#include "tensorflow/compiler/xla/stream_executor/gpu/gpu_blas_lt_gemm_runner.h" namespace stream_executor { @@ -1592,7 +1594,7 @@ Stream &Stream::ThenBlasGemmBatched( DeviceMemorySlice c, int ldc, int batch_count, blas::CallContext context) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - /*scratch_allocator=*/nullptr, context); + /*allocator=*/nullptr, context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1600,12 +1602,19 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, - ScratchAllocator *scratch_allocator, + ScratchAllocator *allocator, blas::CallContext context) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); + if(gpu::GpuBlasLtEnabled()) { + auto& r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc, batch_count, allocator)); + return *this; + } + ThenBlasImpl, int, DeviceMemorySlice, int, float, @@ -1613,7 +1622,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator, context); + allocator, context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1621,11 +1630,18 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, int ldc, int batch_count, - ScratchAllocator *scratch_allocator, blas::CallContext context ) { + ScratchAllocator *allocator, blas::CallContext context ) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); + if(gpu::GpuBlasLtEnabled()) { + auto& r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc, batch_count, allocator)); + return *this; + } + ThenBlasImpl, int, DeviceMemorySlice, int, float, @@ -1633,7 +1649,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator, context); + allocator, context); } Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, @@ -1645,19 +1661,26 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, int ldc, int batch_count, blas::CallContext context) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - /*scratch_allocator=*/nullptr, context); + /*allocator=*/nullptr, context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( blas::Transpose transa, blas::Transpose transb, uint64_t m, uint64 n, uint64_t k, float alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, float beta, DeviceMemorySlice c, - int ldc, int batch_count, ScratchAllocator *scratch_allocator, + int ldc, int batch_count, ScratchAllocator *allocator, blas::CallContext context) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); + if(gpu::GpuBlasLtEnabled()) { + auto& r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc, batch_count, allocator)); + return *this; + } + ThenBlasImpl, int, DeviceMemorySlice, int, float, DeviceMemorySlice, int, int, @@ -1665,7 +1688,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator, context); + allocator, context); } Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, @@ -1677,7 +1700,7 @@ Stream &Stream::ThenBlasGemmBatched(blas::Transpose transa, int ldc, int batch_count, blas::CallContext context) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - /*scratch_allocator=*/nullptr, context); + /*allocator=*/nullptr, context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1685,11 +1708,18 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( uint64_t k, double alpha, DeviceMemorySlice a, int lda, DeviceMemorySlice b, int ldb, double beta, DeviceMemorySlice c, int ldc, int batch_count, - ScratchAllocator *scratch_allocator, blas::CallContext context) { + ScratchAllocator *allocator, blas::CallContext context) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); + if(gpu::GpuBlasLtEnabled()) { + auto& r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc, batch_count, allocator)); + return *this; + } + ThenBlasImpl, int, DeviceMemorySlice, int, double, @@ -1697,7 +1727,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator, context); + allocator, context); } Stream &Stream::ThenBlasGemmBatched( @@ -1708,7 +1738,7 @@ Stream &Stream::ThenBlasGemmBatched( DeviceMemorySlice> c, int ldc, int batch_count, blas::CallContext context) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - /*scratch_allocator=*/nullptr, context); + /*allocator=*/nullptr, context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1717,11 +1747,18 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, int ldc, int batch_count, - ScratchAllocator *scratch_allocator, blas::CallContext context) { + ScratchAllocator *allocator, blas::CallContext context) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); + if(gpu::GpuBlasLtEnabled()) { + auto& r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc, batch_count, allocator)); + return *this; + } + ThenBlasImpl, DeviceMemorySlice>, int, DeviceMemorySlice>, int, std::complex, @@ -1730,7 +1767,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator, context); + allocator, context); } Stream &Stream::ThenBlasGemmBatched( @@ -1742,7 +1779,7 @@ Stream &Stream::ThenBlasGemmBatched( int ldc, int batch_count, blas::CallContext context) { return ThenBlasGemmBatchedWithScratch(transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - /*scratch_allocator=*/nullptr, context); + /*allocator=*/nullptr, context); } Stream &Stream::ThenBlasGemmBatchedWithScratch( @@ -1751,12 +1788,19 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( DeviceMemorySlice> a, int lda, DeviceMemorySlice> b, int ldb, std::complex beta, DeviceMemorySlice> c, - int ldc, int batch_count, ScratchAllocator *scratch_allocator, + int ldc, int batch_count, ScratchAllocator *allocator, blas::CallContext context) { VLOG_CALL(PARAM(transa), PARAM(transb), PARAM(m), PARAM(n), PARAM(k), PARAM(alpha), PARAM(a), PARAM(lda), PARAM(b), PARAM(ldb), PARAM(beta), PARAM(c), PARAM(ldc), PARAM(batch_count)); + if(gpu::GpuBlasLtEnabled()) { + auto& r = gpu::BlasLtGemmRunner::i(this); + CheckStatus(r.RunBatched(*this, transa, transb, m, n, k, alpha, + a, lda, b, ldb, beta, c, ldc, batch_count, allocator)); + return *this; + } + ThenBlasImpl, DeviceMemorySlice>, int, DeviceMemorySlice>, int, @@ -1765,7 +1809,7 @@ Stream &Stream::ThenBlasGemmBatchedWithScratch( impl; return impl(this, &blas::BlasSupport::DoBlasGemmBatched, transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc, batch_count, - scratch_allocator, context); + allocator, context); } Stream &Stream::ThenSetRngSeed(const uint8 *seed, uint64_t seed_bytes) {