diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc
index e38b854db4ec19..640cdfcf02f68d 100644
--- a/tensorflow/core/kernels/matmul_op_fused.cc
+++ b/tensorflow/core/kernels/matmul_op_fused.cc
@@ -200,6 +200,10 @@ struct LaunchFusedMatMulOp {
 namespace {
 
 #if GOOGLE_CUDA || TF_HIPBLASLT
+/*
+  Epilogues supported by hipBLASLt:
+  https://rocm.docs.amd.com/projects/hipBLASLt/en/latest/datatypes.html#hipblasltepilogue-t
+*/
 StatusOr<se::gpu::BlasLt::Epilogue> GetBlasLtEpilogOp(
     FusedComputationType fusion) {
   if (fusion == FusedComputationType::kBiasAdd) {
@@ -263,7 +267,7 @@ se::blas::AlgorithmConfig AutotuneMatmul(
   }
   return algorithm_config;
 }
-#endif
+#endif  // GOOGLE_CUDA || TF_HIPBLASLT
 
 template <typename LaunchFunc>
 StatusOr<std::vector<xla::AutotuneResult>> AutotuneMatMulImpl(
@@ -477,6 +481,17 @@ struct LaunchFusedMatMulOp {
 
   se::dnn::ActivationMode matmul_activation_mode;
   bool use_cudnn = false;
+
+#if !(GOOGLE_CUDA || TF_HIPBLASLT)
+  use_cudnn = true;
+#endif
+  const auto& cc = stream->parent()->GetDeviceDescription().
+                       gpu_compute_capability();
+  if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+    use_cudnn = !procm->gfx9_mi200_or_later();
+  }
+
+  // Fall back to cuDNN for fusion types that hipBLASLt does not support yet.
   switch (fusion) {
     case FusedComputationType::kBiasAddWithGeluExact:
       matmul_activation_mode = se::dnn::ActivationMode::kGeluExact;
@@ -511,15 +526,6 @@ struct LaunchFusedMatMulOp {
     default:
       use_cudnn = false;
   }
-#if !(GOOGLE_CUDA || TF_HIPBLASLT)
-  use_cudnn = true;
-#endif
-
-#if TF_HIPBLASLT
-  auto cap = stream->GetRocmComputeCapability();
-  // as of ROCm 5.5, hipblaslt only supports MI200.
-  if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") use_cudnn = true;
-#endif
 
   BlasScratchAllocator scratch_allocator(context);
 
@@ -590,32 +596,31 @@ struct LaunchFusedMatMulOp {
       epilog_op};
   absl::Mutex* pmu;
   auto plan_and_algorithms_or =
-      GetPlanAndAlgorithms(stream, matmul_params, &pmu);
+      BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu);
   OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
   absl::MutexLock lock(pmu);
-  const auto* plan_and_algorithms = std::move(plan_and_algorithms_or).value();
-  const auto& algorithms = plan_and_algorithms->algorithms;
-  OP_REQUIRES(context, algorithms.size() > 0,
+  const auto& entry = *plan_and_algorithms_or.value();
+  OP_REQUIRES(context, entry.algorithms.size() > 0,
               errors::InvalidArgument("No matmul algorithm returned!"));
 
   auto launch_func = [&](BlasScratchAllocator& scratch_allocator,
                          size_t alg_idx,
                          se::blas::ProfileResult* profile_result) {
-    return DoBlasLtMatmul(stream, *plan_and_algorithms, a_ptr, b_ptr, c_ptr,
-                          alg_idx, scratch_allocator, bias_ptr,
-                          profile_result);
+    return BlasLtMatmulPlanCache::ExecuteOnStream(
+        stream, entry, a_ptr, b_ptr, c_ptr, alg_idx,
+        scratch_allocator, bias_ptr, profile_result);
   };
 
   size_t alg_idx = 0;
   if (use_autotune) {
     auto algorithm_config =
-        AutotuneMatmul(algorithms, matmul_params, context, launch_func);
+        AutotuneMatmul(entry.algorithms, matmul_params, context, launch_func);
 
     alg_idx = algorithm_config.algorithm();
   }
 
   OP_REQUIRES_OK(context, launch_func(scratch_allocator, alg_idx, nullptr));
-#endif
+#endif  // GOOGLE_CUDA || TF_HIPBLASLT
 }
 };
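Both kernels in this patch gate the hipBLASLt path on the device's compute capability instead of a gfx-name prefix. The sketch below illustrates the std::get_if dispatch idiom over the capability variant; the types here are illustrative stand-ins, not the real se:: classes, and the program is only a minimal, self-contained demonstration of the pattern.

    // Sketch only: variant-based dispatch mirroring the check added above.
    #include <iostream>
    #include <variant>

    struct CudaCC { int major, minor; };
    struct RocmCC {
      bool mi200_or_later;
      bool gfx9_mi200_or_later() const { return mi200_or_later; }
    };
    using GpuComputeCapability = std::variant<CudaCC, RocmCC>;

    bool UseCudnnFallback(const GpuComputeCapability& cc) {
      // On CUDA the variant holds CudaCC, so get_if<RocmCC> returns nullptr and
      // the blas-lt path stays enabled; on ROCm it is disabled for pre-MI200 parts.
      if (const auto* rocm = std::get_if<RocmCC>(&cc)) {
        return !rocm->gfx9_mi200_or_later();
      }
      return false;
    }

    int main() {
      std::cout << UseCudnnFallback(GpuComputeCapability{RocmCC{false}}) << "\n";  // 1
      std::cout << UseCudnnFallback(GpuComputeCapability{CudaCC{8, 0}}) << "\n";   // 0
    }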
diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h
index 9c5fe075d97ff5..e78a13a4d164e9 100644
--- a/tensorflow/core/kernels/matmul_op_impl.h
+++ b/tensorflow/core/kernels/matmul_op_impl.h
@@ -601,12 +601,13 @@ struct LaunchBatchMatMul {
 #if GOOGLE_CUDA || TF_HIPBLASLT
     static const bool use_autotune = MatmulAutotuneEnable();
     bool bCublasLtSupport = true;
-#if TF_HIPBLASLT
-    if (!std::is_same_v) bCublasLtSupport = false;
-    auto cap = stream->GetRocmComputeCapability();
-    // as of ROCm 5.5, hipblaslt only supports MI200.
-    if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") bCublasLtSupport = false;
-#endif
+
+    const auto& cc = stream->parent()->GetDeviceDescription().
+                         gpu_compute_capability();
+    if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+      bCublasLtSupport = procm->gfx9_mi200_or_later();
+    }
+
     if (EnableCublasLtGemm() && bCublasLtSupport) {
       static const int64_t max_scratch_size =
           GetWorkspaceLimit(1LL << 32);  // 4GB by default
@@ -636,7 +637,7 @@ struct LaunchBatchMatMul {
         std::optional<int> max_algorithm_count;
         if (!use_autotune) max_algorithm_count = 1;
         absl::Mutex* pmu = nullptr;
-        auto plan_and_algorithms_or = GetPlanAndAlgorithms(
+        auto plan_and_algorithms_or = BlasLtMatmulPlanCache::GetOrCreate(
             stream, matmul_params, &pmu, max_algorithm_count);
         OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
         absl::MutexLock lock(pmu);
@@ -659,9 +660,10 @@ struct LaunchBatchMatMul {
             // scratch space is deallocated between runs.
             BlasScratchAllocator scratch_allocator(context, max_scratch_size);
             Status cublas_launch_status =
-                DoBlasLtMatmul(stream, *plan_and_algorithms, *a_ptrs[0],
-                               *b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
-                               /*bias = */ {}, &profile_result);
+                BlasLtMatmulPlanCache::ExecuteOnStream(stream,
+                    *plan_and_algorithms,
+                    *a_ptrs[0], *b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
+                    se::DeviceMemoryBase{}, &profile_result);
 
             VLOG(4) << " Autotune algorithm " << i
                     << " result: " << profile_result.elapsed_time_in_ms()
@@ -701,8 +703,10 @@ struct LaunchBatchMatMul {
 
           OP_REQUIRES_OK(
               context,
-              DoBlasLtMatmul(stream, *plan_and_algorithms, *a_ptrs[0], *b_ptrs[0],
-                             *c_ptrs[0], algorithm_idx, scratch_allocator));
+              BlasLtMatmulPlanCache::ExecuteOnStream(stream,
+                  *plan_and_algorithms,
+                  *a_ptrs[0], *b_ptrs[0], *c_ptrs[0],
+                  algorithm_idx, scratch_allocator, se::DeviceMemoryBase{}));
       } else {  // requires mixed broadcasting
         const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
         const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();
diff --git a/tensorflow/core/kernels/matmul_util.cc b/tensorflow/core/kernels/matmul_util.cc
index c4be5da2b62ece..48ed14491b310a 100644
--- a/tensorflow/core/kernels/matmul_util.cc
+++ b/tensorflow/core/kernels/matmul_util.cc
@@ -16,6 +16,7 @@ limitations under the License.
 
 #include <optional>
 #include <string>
+#include <deque>
 #include <utility>
 
 #include "xla/status_macros.h"
@@ -24,6 +25,7 @@ limitations under the License.
 #include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/matmul_autotune.h"
+#include "xla/stream_executor/stream.h"
 
 namespace tensorflow {
 
@@ -44,10 +46,6 @@ int64_t GetWorkspaceLimit(int64_t default_value_in_bytes) {
   return default_value_in_bytes;
 }
 
-std::string BlasLtMatmulPlanParams::ToString() const {
-  return "";  // TODO
-}
-
 bool BlasLtMatmulPlanParams::operator==(
     const BlasLtMatmulPlanParams& other) const {
   return internal::AsTuple(*this) == internal::AsTuple(other);
@@ -55,22 +53,6 @@ bool BlasLtMatmulPlanParams::operator==(
 
 namespace {
 
-// Thread-safe map from matmul parameters to their corresponding plan and
-// algorithms.
-struct BlasLtMatmulPlanMap {
-  absl::Mutex mu;
-
-  template <typename... Args>
-  auto emplace(Args&&... args) {
-    absl::MutexLock lock(&mu);
-    return map_.emplace(std::forward<Args>(args)...);
-  }
-
- private:
-  absl::flat_hash_map<BlasLtMatmulPlanParams, PlanAndAlgorithms> map_
-      ABSL_GUARDED_BY(mu);
-};
-
 int MatmulMaxAutotuneAlgorithmCount() {
   int64_t value;
   Status status =
@@ -110,9 +92,19 @@ StatusOr<se::blas::ComputationType> GetBlasComputationType(
 
 }  // namespace
 
-StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
+/* static */ BlasLtMatmulPlanCache& BlasLtMatmulPlanCache::i(se::Stream* stream) {
+  static absl::Mutex m(absl::kConstInit);
+  // Each GPU gets a different cache instance.
+  static std::deque<BlasLtMatmulPlanCache> meta(8);
+  absl::MutexLock lock(&m);
+  size_t dev_id = stream->parent()->device_ordinal();
+  if (dev_id >= meta.size()) meta.resize(dev_id + 1);
+  return meta[dev_id];
+}
+
+/* static */ auto BlasLtMatmulPlanCache::GetOrCreate(
     se::Stream* stream, const BlasLtMatmulPlanParams& params,
-    absl::Mutex** ppmu, std::optional<int> max_algorithm_count) {
+    absl::Mutex** ppmu, std::optional<int> max_algorithm_count) -> StatusOr<const Entry*> {
   static const int64_t max_scratch_size =
       GetWorkspaceLimit(1LL << 32);  // 4GB by default
   static const int64_t max_autotune_algorithm_count =
@@ -120,17 +112,17 @@ StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
 
   if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count;
 
-  static BlasLtMatmulPlanMap plan_map;
+  auto& self = BlasLtMatmulPlanCache::i(stream);
 
-  auto [ptr, inserted] = plan_map.emplace(params, PlanAndAlgorithms{});
+  absl::MutexLock lock(self.mutex_.get());
+  auto [ptr, inserted] = self.map_.emplace(params, Entry{});
+  auto& entry = ptr->second;
   if (inserted) {
     TF_ASSIGN_OR_RETURN(auto xlatype,
                         se::gpu::AsXlaPrimitiveType(params.dtype));
     TF_ASSIGN_OR_RETURN(auto computation_type,
                         GetBlasComputationType(params.dtype));
 
-    auto scale_type = se::gpu::GetScaleType(params.dtype, computation_type);
-
     // row-major output is now handled automatically by blas-lt API
     constexpr auto kRowMajor = se::gpu::MatrixLayout::Order::kRowMajor;
 
@@ -173,19 +165,42 @@ StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
         .compute_type = computation_type,
     };
 
-    TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
+    TF_ASSIGN_OR_RETURN(entry.plan, se::gpu::BlasLt::GetMatmulPlan(
                                        stream, cfg, params.epilogue));
 
     TF_ASSIGN_OR_RETURN(
-        auto algorithms,
-        plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
-
-    ptr->second = {std::move(plan), std::move(algorithms), scale_type};
+        entry.algorithms,
+        entry.plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
   }
 
-  *ppmu = &plan_map.mu;
-  return &ptr->second;
+  *ppmu = self.mutex_.get();
+  return &entry;
 }
 
+/* static */ Status BlasLtMatmulPlanCache::ExecuteOnStream(se::Stream* stream,
+    const Entry& entry,
+    const se::DeviceMemoryBase& a,
+    const se::DeviceMemoryBase& b,
+    se::DeviceMemoryBase& c,
+    size_t algorithm_idx,
+    se::ScratchAllocator& scratch_allocator,
+    const se::DeviceMemoryBase& bias,
+    se::blas::ProfileResult* profile_result) {
+
+  return entry.plan->ExecuteOnStream(
+      stream, a, b, c, c,
+      bias,                    // bias_buffer
+      se::DeviceMemoryBase{},  // aux_buffer
+      se::DeviceMemoryBase{},  // a_scale_buffer
+      se::DeviceMemoryBase{},  // b_scale_buffer
+      se::DeviceMemoryBase{},  // c_scale_buffer
+      se::DeviceMemoryBase{},  // d_scale_buffer
+      se::DeviceMemoryBase{},  // d_amax_buffer
+      entry.algorithms[algorithm_idx],
+      scratch_allocator,
+      profile_result);
+}
+
 }  // namespace tensorflow
 
 #endif
\ No newline at end of file
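BlasLtMatmulPlanCache::i above keys the cache on the stream's device ordinal so plans built for one GPU are never handed to another, and each per-device cache carries its own mutex so plan creation on different GPUs can proceed in parallel. A stripped-down sketch of the same idiom follows; names and the int payload are hypothetical stand-ins for the real plan entries, assuming only the standard library.

    // Sketch: one cache object per device ordinal, created lazily under a global mutex.
    // The per-cache mutex lives behind a unique_ptr so the cache type stays movable,
    // which std::deque::resize requires.
    #include <cstddef>
    #include <deque>
    #include <memory>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    struct PlanCache {
      std::unique_ptr<std::mutex> mu = std::make_unique<std::mutex>();
      std::unordered_map<std::string, int> plans;  // stand-in for real plan objects

      static PlanCache& ForDevice(size_t dev_id) {
        static std::mutex meta_mu;
        // deque, not vector: growing at the end never relocates existing elements,
        // so references already handed out for other devices stay valid.
        static std::deque<PlanCache> meta(8);
        std::lock_guard<std::mutex> lock(meta_mu);
        if (dev_id >= meta.size()) meta.resize(dev_id + 1);
        return meta[dev_id];
      }
    };

    int main() {
      auto& cache0 = PlanCache::ForDevice(0);
      std::lock_guard<std::mutex> lock(*cache0.mu);
      cache0.plans.emplace("dummy-plan-key", 42);
    }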
diff --git a/tensorflow/core/kernels/matmul_util.h b/tensorflow/core/kernels/matmul_util.h
index 371964424eff85..dbf85eab41242c 100644
--- a/tensorflow/core/kernels/matmul_util.h
+++ b/tensorflow/core/kernels/matmul_util.h
@@ -21,7 +21,7 @@ limitations under the License.
 
 #if GOOGLE_CUDA || TF_HIPBLASLT
 
-#include "absl/container/flat_hash_map.h"
+#include "absl/container/node_hash_map.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/gpu/gpu_blas_lt.h"
 #include "tensorflow/core/framework/types.h"
@@ -35,7 +35,8 @@ namespace tensorflow {
 int64_t GetWorkspaceLimit(int64_t default_value_in_bytes);
 
 struct BlasLtMatmulPlanParams {
-  std::string ToString() const;
+
+  std::string ToString() const { return "NOP"; }
 
   bool operator==(const BlasLtMatmulPlanParams& other) const;
 
   se::blas::DataType dtype;
@@ -50,12 +51,6 @@ struct BlasLtMatmulPlanParams {
   se::gpu::BlasLt::Epilogue epilogue = se::gpu::BlasLt::Epilogue::kDefault;
 };
 
-struct PlanAndAlgorithms {
-  se::gpu::BlasLt::MatmulPlanPtr plan;
-  std::vector<se::gpu::BlasLt::MatmulAlgorithm> algorithms;
-  se::blas::DataType scale_type;  // this is needed for half / bf16 treatment
-};
-
 namespace internal {
 
 inline auto AsTuple(const BlasLtMatmulPlanParams& p) {
@@ -71,37 +66,42 @@ H AbslHashValue(H h, const BlasLtMatmulPlanParams& params) {
   return H::combine(std::move(h), internal::AsTuple(params));
 }
 
-StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
+struct BlasLtMatmulPlanCache {
+  struct Entry {
+    se::gpu::BlasLt::MatmulPlanPtr plan;
+    std::vector<se::gpu::BlasLt::MatmulAlgorithm> algorithms;
+  };
+
+  static StatusOr<const Entry*> GetOrCreate(
     se::Stream* stream, const BlasLtMatmulPlanParams& params, absl::Mutex** pmu,
-    std::optional<int> max_algorithm_count = std::nullopt);
-
-template <typename T>
-Status DoBlasLtMatmul(se::Stream* stream, const PlanAndAlgorithms& paa,
-                      const se::DeviceMemory<T>& a,
-                      const se::DeviceMemory<T>& b, se::DeviceMemory<T>& c,
-                      size_t alg_idx, se::ScratchAllocator& scratch_allocator,
-                      const se::DeviceMemory<T>& bias = {},
-                      se::blas::ProfileResult* profile_result = nullptr) {
-  se::DeviceMemory<T> aux{};  // We don't use the auxilary buffers.
-  const auto& algorithm = paa.algorithms[alg_idx];
-
-  // The scale type may be f32 if the data type is f16 and bf16.
-  if constexpr (std::is_same_v<T, Eigen::half> ||
-                std::is_same_v<T, Eigen::bfloat16>) {
-    if (paa.scale_type == se::blas::DataType::kFloat) {
-      return paa.plan->DoMatmul(stream, se::HostOrDeviceScalar<float>(1.0), b,
-                                a, se::HostOrDeviceScalar<float>(0.0), c, c,
-                                algorithm, scratch_allocator, bias, aux,
-                                profile_result);
-    }
+      std::optional<int> max_algorithm_count = std::nullopt
+  );
+
+  // Helper function for plan execution.
+  static Status ExecuteOnStream(se::Stream* stream,
+      const Entry& entry,
+      const se::DeviceMemoryBase& a,
+      const se::DeviceMemoryBase& b,
+      se::DeviceMemoryBase& c,
+      size_t algorithm_idx,
+      se::ScratchAllocator& scratch_allocator,
+      const se::DeviceMemoryBase& bias,
+      se::blas::ProfileResult* profile_result = nullptr);
+
+  BlasLtMatmulPlanCache() : mutex_(new absl::Mutex) {
   }
-  return paa.plan->DoMatmul(stream, se::HostOrDeviceScalar<T>(T(1.0)), b, a,
-                            se::HostOrDeviceScalar<T>(T(0.0)), c, c, algorithm,
-                            scratch_allocator, bias, aux, profile_result);
-}
+
+private:
+  static BlasLtMatmulPlanCache& i(se::Stream* stream);
+
+  std::unique_ptr<absl::Mutex> mutex_;
+  absl::node_hash_map<BlasLtMatmulPlanParams, Entry> map_
+      ABSL_GUARDED_BY(mutex_);
+
+};  // BlasLtMatmulPlanCache
 
 }  // namespace tensorflow
 
-#endif
+#endif  // GOOGLE_CUDA || TF_HIPBLASLT
 
 #endif  // TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
index 20ff725849aee9..50ac6eba566715 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -740,7 +740,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk(
   TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue,
                       gpublas_lt::AsBlasLtEpilogue(epilogue));
   auto thunk = std::make_unique<CublasLtMatmulThunk>(
-      Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(gemm_config),
+      instr, std::move(gemm_config),
       blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale,
       c_scale, d_scale, d_amax, workspace_buffer);
   AddThunkToThunkSequence(std::move(thunk));
@@ -831,7 +831,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
   TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue,
                       gpublas_lt::AsBlasLtEpilogue(epilogue));
   auto thunk = std::make_unique<CublasLtMatmulThunk>(
-      Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(gemm_config),
+      instr, std::move(gemm_config),
       blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale,
       c_scale, d_scale, d_amax, workspace_buffer);
   AddThunkToThunkSequence(std::move(thunk));
diff --git a/third_party/xla/xla/service/gpu/runtime/BUILD b/third_party/xla/xla/service/gpu/runtime/BUILD
index de8cc69d94cbd3..7250dcb2a1a082 100644
--- a/third_party/xla/xla/service/gpu/runtime/BUILD
+++ b/third_party/xla/xla/service/gpu/runtime/BUILD
@@ -688,6 +688,7 @@ cc_library(
         "@com_google_absl//absl/synchronization",
         "//xla/service:buffer_assignment",
        "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:autotuner_util",
        "//xla/service/gpu/runtime:thunk",
        "//xla/stream_executor:device_memory",
        "//xla/stream_executor",
#include "xla/service/gpu/buffer_allocations.h" #include "xla/service/gpu/matmul_utils.h" #include "xla/service/gpu/runtime/thunk.h" +#include "xla/service/gpu/autotuner_util.h" #include "xla/status_macros.h" #include "xla/stream_executor/device_memory.h" #include "xla/stream_executor/gpu/gpu_blas_lt.h" @@ -35,8 +36,44 @@ limitations under the License. namespace xla { namespace gpu { +struct MatmulPlanCache { + + static MatmulPlanCache& i(const se::Stream *stream) { + static absl::Mutex m(absl::kConstInit); + // Each GPU gets different cache instance + static std::vector< std::unique_ptr< MatmulPlanCache > > meta(8); + absl::MutexLock lock(&m); + size_t dev_id = stream->parent()->device_ordinal(); + if (dev_id >= meta.size()) meta.resize(dev_id + 1); + auto& res = meta[dev_id]; + if (!res) res.reset(new MatmulPlanCache()); + return *res; + } + + template < class Func > + StatusOr + GetOrCreate(const std::string& key, Func&& create) { + // each GPU has a different mutex => hence different GPU instances can + // create matmul plans in parallel + absl::MutexLock lock(mutex_.get()); + auto res = map_.emplace(key, se::gpu::BlasLt::MatmulPlanPtr{}); + if(res.second) { // new entry inserted + TF_ASSIGN_OR_RETURN(res.first->second, create()); + } + return res.first->second.get(); + } + +private: + MatmulPlanCache() : mutex_(std::make_unique< absl::Mutex >()) { } + +private: + std::unique_ptr< absl::Mutex > mutex_; + absl::flat_hash_map map_; +}; + + CublasLtMatmulThunk::CublasLtMatmulThunk( - ThunkInfo thunk_info, GemmConfig gemm_config, + const HloInstruction *instr, GemmConfig gemm_config, se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx, BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer, BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer, @@ -45,7 +82,7 @@ CublasLtMatmulThunk::CublasLtMatmulThunk( BufferAllocation::Slice c_scale, BufferAllocation::Slice d_scale, BufferAllocation::Slice d_amax, std::optional workspace_buffer) - : Thunk(Kind::kCublasLtMatmul, thunk_info), + : Thunk(Kind::kCublasLtMatmul, Thunk::ThunkInfo::WithProfileAnnotation(instr)), gemm_config_(std::move(gemm_config)), epilogue_(epilogue), algorithm_idx_(algorithm_idx), @@ -60,18 +97,18 @@ CublasLtMatmulThunk::CublasLtMatmulThunk( c_scale_buffer_(c_scale), d_scale_buffer_(d_scale), d_amax_buffer_(d_amax), - workspace_buffer_(workspace_buffer) {} + workspace_buffer_(workspace_buffer) { + + canonical_hlo_ = xla::gpu::AutotuneCacheKey("nope", *instr).GetHlo(); +} absl::Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) { - TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(params.stream)); - TF_ASSIGN_OR_RETURN( - auto algorithm, - GetMatmulAlgorithm(plan, workspace_buffer_.has_value() - ? 
@@ -103,47 +140,39 @@ absl::Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) {
   if (workspace_buffer_.has_value()) {
     workspace = allocs.GetDeviceAddress(workspace_buffer_.value());
   }
+
   return plan->ExecuteOnStream(
       params.stream, allocs.GetDeviceAddress(a_buffer_),
       allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_),
       allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale,
-      d_scale, d_amax, algorithm, workspace);
+      d_scale, d_amax, {}, workspace);
 }
 
-absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> CublasLtMatmulThunk::GetMatmulPlan(
-    const stream_executor::Stream* stream) {
-  {
-    absl::MutexLock lock(&matmul_plans_cache_mutex_);
-    auto it = matmul_plans_cache_.find(stream);
-    if (it != matmul_plans_cache_.end()) return it->second.get();
-  }
-  TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
-                                     stream, gemm_config_, epilogue_));
-  absl::MutexLock lock(&matmul_plans_cache_mutex_);
-  auto [it, _] = matmul_plans_cache_.emplace(stream, std::move(plan));
-  return it->second.get();
-}
-
-absl::StatusOr<se::gpu::BlasLt::MatmulAlgorithm>
-CublasLtMatmulThunk::GetMatmulAlgorithm(const se::gpu::BlasLt::MatmulPlan* plan,
-                                        int64_t max_workspace) {
-  {
-    absl::MutexLock lock(&matmul_algorithm_cache_mutex_);
-    auto it = matmul_algorithm_cache_.find(plan);
-    if (it != matmul_algorithm_cache_.end()) return it->second;
-  }
-  TF_ASSIGN_OR_RETURN(
-      auto algorithms,
-      plan->GetAlgorithms(/*max_algorithm_count*/ 128,
-                          /*max_workspace_size*/ max_workspace));
-  TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size());
-
-  absl::MutexLock lock(&matmul_algorithm_cache_mutex_);
-  auto [it, _] =
-      matmul_algorithm_cache_.emplace(plan, algorithms[algorithm_idx_]);
-  return it->second;
+auto CublasLtMatmulThunk::GetCachedMatmulPlan(
+    const ExecuteParams& params) -> absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> {
+
+  auto& cache = MatmulPlanCache::i(params.stream);
+
+  auto create = [&]() -> StatusOr<se::gpu::BlasLt::MatmulPlanPtr> {
+    VLOG(2) << this << ": Adding new MatmulPlan for stream: " << params.stream <<
+         " instr: " << canonical_hlo_;
+
+    TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
+                                       params.stream, gemm_config_, epilogue_));
+
+    int64_t max_workspace = workspace_buffer_.has_value()
+                                ? workspace_buffer_.value().size() : 0;
+    int64_t num_algorithms = algorithm_idx_ == se::blas::kDefaultAlgorithm ?
+                                 1 : 128;
+    TF_ASSIGN_OR_RETURN(auto algorithms,
+                        plan->GetAlgorithms(num_algorithms, max_workspace));
+
+    TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[algorithm_idx_]));
+    return std::move(plan);
+  };
+  return cache.GetOrCreate(canonical_hlo_, create);
 }
 
 absl::Status CublasLtMatmulThunk::Initialize(const InitializeParams& params) {
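GetCachedMatmulPlan builds a plan at most once per canonical HLO string and per device; later executions of the same fused gemm reuse it. The sketch below shows the create-under-lock idiom used by MatmulPlanCache::GetOrCreate with hypothetical stand-in types (an int payload instead of a real BlasLt plan); it is an illustration of the pattern, not the patch's implementation.

    // Sketch of GetOrCreate: emplace a placeholder first, build the value only if
    // the key was newly inserted, all under the cache's mutex.
    #include <functional>
    #include <memory>
    #include <mutex>
    #include <string>
    #include <unordered_map>

    class KeyedCache {
     public:
      using Value = std::unique_ptr<int>;  // stand-in for a matmul plan pointer

      int* GetOrCreate(const std::string& key,
                       const std::function<Value()>& create) {
        std::lock_guard<std::mutex> lock(mu_);
        auto [it, inserted] = map_.emplace(key, Value{});
        if (inserted) {           // first caller for this key builds the value
          it->second = create();
        }
        return it->second.get();  // later callers just reuse it
      }

     private:
      std::mutex mu_;
      std::unordered_map<std::string, Value> map_;
    };

    int main() {
      KeyedCache cache;
      auto* a = cache.GetOrCreate("gemm-hlo", [] { return std::make_unique<int>(7); });
      auto* b = cache.GetOrCreate("gemm-hlo", [] { return std::make_unique<int>(9); });
      return (a == b) ? 0 : 1;  // second factory never runs; same pointer comes back
    }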
diff --git a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h
index aa114bd3ee93fd..5602a22f6fd1fe 100644
--- a/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h
+++ b/third_party/xla/xla/service/gpu/runtime/gpublas_lt_matmul_thunk.h
@@ -38,7 +38,7 @@ namespace gpu {
 class CublasLtMatmulThunk : public Thunk {
  public:
   CublasLtMatmulThunk(
-      ThunkInfo thunk_info, GemmConfig gemm_config,
+      const HloInstruction* instr, GemmConfig gemm_config,
       se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx,
       BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer,
       BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer,
@@ -74,24 +74,13 @@ class CublasLtMatmulThunk : public Thunk {
   }
 
  private:
-  absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> GetMatmulPlan(
-      const stream_executor::Stream* stream);
-  absl::StatusOr<se::gpu::BlasLt::MatmulAlgorithm> GetMatmulAlgorithm(
-      const se::gpu::BlasLt::MatmulPlan* plan, int64_t max_workspace);
-
-  absl::Mutex matmul_plans_cache_mutex_;
-  absl::flat_hash_map<const stream_executor::Stream*,
-                      se::gpu::BlasLt::MatmulPlanPtr>
-      matmul_plans_cache_ ABSL_GUARDED_BY(matmul_plans_cache_mutex_);
-
-  absl::Mutex matmul_algorithm_cache_mutex_;
-  absl::flat_hash_map<const se::gpu::BlasLt::MatmulPlan*,
-                      se::gpu::BlasLt::MatmulAlgorithm>
-      matmul_algorithm_cache_ ABSL_GUARDED_BY(matmul_algorithm_cache_mutex_);
+  absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> GetCachedMatmulPlan(
+      const ExecuteParams& params);
 
   GemmConfig gemm_config_;
   se::gpu::BlasLt::Epilogue epilogue_;
   int64_t algorithm_idx_;
+  std::string canonical_hlo_;
   BufferAllocation::Slice a_buffer_;
   BufferAllocation::Slice b_buffer_;
   BufferAllocation::Slice c_buffer_;
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
index 2365d53358d2cf..fb0107db3877cb 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
@@ -261,6 +261,9 @@ struct BlasLt {
         size_t max_algorithm_count = 128,
         size_t max_workspace_size = 1ll << 32) const = 0;
 
+    // The algorithm needs to be set before calling ExecuteOnStream().
+    virtual absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) const = 0;
+
     virtual ~MatmulPlan() {}
 
    protected:
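The new SetAlgorithm hook lets the thunk pin an algorithm on the cached plan once, so later ExecuteOnStream calls can pass an empty algorithm argument. The sketch below illustrates the "set once, execute many" pattern with stand-in types; it is not the real BlasLt API, only a minimal model of the fallback logic added in hip_blas_lt.cc.

    // Sketch: the plan remembers the chosen algorithm and Execute() falls back
    // to it when the caller passes no explicit algorithm.
    #include <iostream>
    #include <optional>

    struct Algorithm { int id; };

    class Plan {
     public:
      void SetAlgorithm(const Algorithm& algo) const { algorithm_ = algo; }

      void Execute(std::optional<Algorithm> explicit_algo = std::nullopt) const {
        // Mirrors the change below: a previously stored algorithm wins over the
        // per-call argument, so cached plans can be run with an empty argument.
        const Algorithm algo =
            algorithm_.has_value() ? *algorithm_
                                   : explicit_algo.value_or(Algorithm{0});
        std::cout << "running with algorithm " << algo.id << "\n";
      }

     private:
      mutable std::optional<Algorithm> algorithm_;  // selected algorithm
    };

    int main() {
      Plan plan;
      plan.SetAlgorithm(Algorithm{3});  // done once when the plan is first cached
      plan.Execute();                   // later executions omit the algorithm
    }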
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
index cb1b2e1094f439..4d715d8e9f871c 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
@@ -382,6 +382,11 @@ absl::Status BlasLt::MatmulPlan::ValidateInputs(
   return absl::OkStatus();
 }
 
+absl::Status BlasLt::MatmulPlan::SetAlgorithm(const MatmulAlgorithm& algorithm) const {
+  algorithm_ = algorithm;
+  return absl::OkStatus();
+}
+
 absl::Status BlasLt::MatmulPlan::DoMatmul(
     Stream* stream, const void* alpha, DeviceMemoryBase a, DeviceMemoryBase b,
     const void* beta, DeviceMemoryBase c, DeviceMemoryBase d,
@@ -412,7 +417,7 @@ absl::Status BlasLt::MatmulPlan::DoMatmul(
     Stream* stream, const void* alpha, DeviceMemoryBase a, DeviceMemoryBase b,
     const void* beta, DeviceMemoryBase c, DeviceMemoryBase d,
-    const MatmulAlgorithm& algorithm, DeviceMemoryBase bias,
+    const MatmulAlgorithm& Xalgorithm, DeviceMemoryBase bias,
     DeviceMemoryBase aux, DeviceMemoryBase a_scale, DeviceMemoryBase b_scale,
     DeviceMemoryBase c_scale, DeviceMemoryBase d_scale, DeviceMemoryBase d_amax,
     std::optional<DeviceMemoryBase> workspace,
@@ -429,6 +434,8 @@ absl::Status BlasLt::MatmulPlan::DoMatmul(
       stream, profile_result && profile_result->warmup_run_executed(),
       profile_result));
 
+  auto algorithm = algorithm_.has_value() ? *algorithm_ : Xalgorithm;
+
   void* workspace_addr = nullptr;
   uint64_t workspace_size = 0;
   if (workspace.has_value()) {
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h
index 767b80737e5a8e..aa5568e7e50988 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h
@@ -137,6 +137,8 @@ class BlasLt : public gpu::BlasLt {
     absl::StatusOr<std::vector<MatmulAlgorithm>> GetAlgorithms(
         size_t max_algorithm_count, size_t max_workspace_size) const override;
 
+    absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) const override;
+
    protected:
     absl::Status ValidateInputs(blas::DataType scale_type, bool alpha_on_device,
                                 bool beta_on_device, blas::DataType A_type,
@@ -190,6 +192,7 @@ class BlasLt : public gpu::BlasLt {
     xla::complex128 alpha_;
     double beta_;
     bool must_swap_operands_;
+    mutable std::optional<MatmulAlgorithm> algorithm_;  // selected algorithm
   };  // class MatmulPlan
 
   explicit BlasLt(gpu::GpuExecutor* parent)