diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc
index e38b854db4ec19..640cdfcf02f68d 100644
--- a/tensorflow/core/kernels/matmul_op_fused.cc
+++ b/tensorflow/core/kernels/matmul_op_fused.cc
@@ -200,6 +200,10 @@ struct LaunchFusedMatMulOp {
 namespace {
 
 #if GOOGLE_CUDA || TF_HIPBLASLT
+/*
+  hipBLASLt supported epilogues:
+  https://rocm.docs.amd.com/projects/hipBLASLt/en/latest/datatypes.html#hipblasltepilogue-t
+*/
 StatusOr<se::gpu::BlasLt::Epilogue> GetBlasLtEpilogOp(
     FusedComputationType fusion) {
   if (fusion == FusedComputationType::kBiasAdd) {
@@ -263,7 +267,7 @@ se::blas::AlgorithmConfig AutotuneMatmul(
   }
   return algorithm_config;
 }
-#endif
+#endif  // GOOGLE_CUDA || TF_HIPBLASLT
 
 template
 StatusOr> AutotuneMatMulImpl(
@@ -477,6 +481,17 @@ struct LaunchFusedMatMulOp {
     se::dnn::ActivationMode matmul_activation_mode;
     bool use_cudnn = false;
+
+#if !(GOOGLE_CUDA || TF_HIPBLASLT)
+    use_cudnn = true;
+#endif
+    const auto& cc = stream->parent()->GetDeviceDescription().
+                         gpu_compute_capability();
+    if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+      use_cudnn = !procm->gfx9_mi200_or_later();
+    }
+
+    // Use cuDNN for the fusions that hipBLASLt does not support yet.
     switch (fusion) {
       case FusedComputationType::kBiasAddWithGeluExact:
         matmul_activation_mode = se::dnn::ActivationMode::kGeluExact;
@@ -511,15 +526,6 @@ struct LaunchFusedMatMulOp {
       default:
         use_cudnn = false;
     }
-#if !(GOOGLE_CUDA || TF_HIPBLASLT)
-    use_cudnn = true;
-#endif
-
-#if TF_HIPBLASLT
-    auto cap = stream->GetRocmComputeCapability();
-    // as of ROCm 5.5, hipblaslt only supports MI200.
-    if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") use_cudnn = true;
-#endif
 
     BlasScratchAllocator scratch_allocator(context);
 
@@ -590,32 +596,31 @@ struct LaunchFusedMatMulOp {
         epilog_op};
     absl::Mutex* pmu;
     auto plan_and_algorithms_or =
-        GetPlanAndAlgorithms(stream, matmul_params, &pmu);
+        BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu);
     OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
     absl::MutexLock lock(pmu);
-    const auto* plan_and_algorithms = std::move(plan_and_algorithms_or).value();
-    const auto& algorithms = plan_and_algorithms->algorithms;
-    OP_REQUIRES(context, algorithms.size() > 0,
+    const auto& entry = *plan_and_algorithms_or.value();
+    OP_REQUIRES(context, entry.algorithms.size() > 0,
                 errors::InvalidArgument("No matmul algorithm returned!"));
 
     auto launch_func = [&](BlasScratchAllocator& scratch_allocator,
                            size_t alg_idx,
                            se::blas::ProfileResult* profile_result) {
-      return DoBlasLtMatmul(stream, *plan_and_algorithms, a_ptr, b_ptr, c_ptr,
-                            alg_idx, scratch_allocator, bias_ptr,
-                            profile_result);
+      return BlasLtMatmulPlanCache::ExecuteOnStream(
+          stream, entry, a_ptr, b_ptr, c_ptr, alg_idx,
+          scratch_allocator, bias_ptr, profile_result);
    };
 
    size_t alg_idx = 0;
    if (use_autotune) {
      auto algorithm_config =
-          AutotuneMatmul(algorithms, matmul_params, context, launch_func);
+          AutotuneMatmul(entry.algorithms, matmul_params, context, launch_func);
 
      alg_idx = algorithm_config.algorithm();
    }
 
    OP_REQUIRES_OK(context, launch_func(scratch_allocator, alg_idx, nullptr));
-#endif
+#endif  // GOOGLE_CUDA || TF_HIPBLASLT
  }
};
diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h
index 2c547d00d106f8..3f097a7fc4ed15 100644
--- a/tensorflow/core/kernels/matmul_op_impl.h
+++ b/tensorflow/core/kernels/matmul_op_impl.h
@@ -601,12 +601,13 @@ struct LaunchBatchMatMul {
 #if GOOGLE_CUDA || TF_HIPBLASLT
     static const bool use_autotune = MatmulAutotuneEnable();
     bool bCublasLtSupport = true;
-#if TF_HIPBLASLT
-    if (!std::is_same_v) bCublasLtSupport = false;
-    auto cap = stream->GetRocmComputeCapability();
-    // as of ROCm 5.5, hipblaslt only supports MI200.
-    if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") bCublasLtSupport = false;
-#endif
+
+    const auto& cc = stream->parent()->GetDeviceDescription().
+                         gpu_compute_capability();
+    if (auto* procm = std::get_if<se::RocmComputeCapability>(&cc)) {
+      bCublasLtSupport = procm->gfx9_mi200_or_later();
+    }
+
     if (EnableCublasLtGemm() && bCublasLtSupport) {
       static const int64_t max_scratch_size =
           GetWorkspaceLimit(1LL << 32);  // 4GB by default
@@ -636,7 +637,7 @@ struct LaunchBatchMatMul {
       std::optional max_algorithm_count;
       if (!use_autotune) max_algorithm_count = 1;
       absl::Mutex* pmu = nullptr;
-      auto plan_and_algorithms_or = GetPlanAndAlgorithms(
+      auto plan_and_algorithms_or = BlasLtMatmulPlanCache::GetOrCreate(
           stream, matmul_params, &pmu, max_algorithm_count);
       OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
       absl::MutexLock lock(pmu);
@@ -659,9 +660,10 @@ struct LaunchBatchMatMul {
           // scratch space is deallocated between runs.
           BlasScratchAllocator scratch_allocator(context, max_scratch_size);
           Status cublas_launch_status =
-              DoBlasLtMatmul(stream, *plan_and_algorithms, *a_ptrs[0],
-                             *b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
-                             /*bias = */ {}, &profile_result);
+              BlasLtMatmulPlanCache::ExecuteOnStream(stream,
+                  *plan_and_algorithms,
+                  *a_ptrs[0], *b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
+                  se::DeviceMemoryBase{}, &profile_result);
 
           VLOG(4) << "  Autotune algorithm " << i
                   << " result: " << profile_result.elapsed_time_in_ms()
@@ -701,8 +703,10 @@ struct LaunchBatchMatMul {
 
         OP_REQUIRES_OK(
             context,
-            DoBlasLtMatmul(stream, *plan_and_algorithms, *a_ptrs[0], *b_ptrs[0],
-                           *c_ptrs[0], algorithm_idx, scratch_allocator));
+            BlasLtMatmulPlanCache::ExecuteOnStream(stream,
+                *plan_and_algorithms,
+                *a_ptrs[0], *b_ptrs[0], *c_ptrs[0],
+                algorithm_idx, scratch_allocator, se::DeviceMemoryBase{}));
       } else {  // requires mixed broadcasting
         const std::vector& a_batch_indices = bcast.x_batch_indices();
         const std::vector& b_batch_indices = bcast.y_batch_indices();
diff --git a/tensorflow/core/kernels/matmul_util.cc b/tensorflow/core/kernels/matmul_util.cc
index 930de6e25ed604..9291eb211dca05 100644
--- a/tensorflow/core/kernels/matmul_util.cc
+++ b/tensorflow/core/kernels/matmul_util.cc
@@ -16,6 +16,7 @@ limitations under the License.
 
 #include
 #include
+#include <deque>
 #include
 
 #include "xla/status_macros.h"
@@ -23,6 +24,7 @@ limitations under the License.
 #include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/matmul_autotune.h"
+#include "xla/stream_executor/stream.h"
 
 namespace tensorflow {
 
@@ -43,10 +45,6 @@ int64_t GetWorkspaceLimit(int64_t default_value_in_bytes) {
   return default_value_in_bytes;
 }
 
-std::string BlasLtMatmulPlanParams::ToString() const {
-  return "";  // TODO
-}
-
 bool BlasLtMatmulPlanParams::operator==(
     const BlasLtMatmulPlanParams& other) const {
   return internal::AsTuple(*this) == internal::AsTuple(other);
@@ -54,22 +52,6 @@ bool BlasLtMatmulPlanParams::operator==(
 
 namespace {
 
-// Thread-safe map from matmul parameters to their corresponding plan and
-// algorithms.
-struct BlasLtMatmulPlanMap {
-  absl::Mutex mu;
-
-  template <typename... Args>
-  auto emplace(Args&&... args) {
-    absl::MutexLock lock(&mu);
-    return map_.emplace(std::forward<Args>(args)...);
-  }
-
- private:
-  absl::flat_hash_map<BlasLtMatmulPlanParams, PlanAndAlgorithms> map_
-      ABSL_GUARDED_BY(mu);
-};
-
 int MatmulMaxAutotuneAlgorithmCount() {
   int64_t value;
   Status status =
@@ -109,9 +91,19 @@ StatusOr GetBlasComputationType(
 
 }  // namespace
 
-StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
+/* static */ BlasLtMatmulPlanCache& BlasLtMatmulPlanCache::i(se::Stream* stream) {
+  static absl::Mutex m(absl::kConstInit);
+  // Each GPU gets its own cache instance.
+  static std::deque<BlasLtMatmulPlanCache> meta(8);
+  absl::MutexLock lock(&m);
+  size_t dev_id = stream->parent()->device_ordinal();
+  if (dev_id >= meta.size()) meta.resize(dev_id + 1);
+  return meta[dev_id];
+}
+
+/* static */ auto BlasLtMatmulPlanCache::GetOrCreate(
     se::Stream* stream, const BlasLtMatmulPlanParams& params,
-    absl::Mutex** ppmu, std::optional max_algorithm_count) {
+    absl::Mutex** ppmu, std::optional max_algorithm_count) -> StatusOr<const Entry*> {
   static const int64_t max_scratch_size =
       GetWorkspaceLimit(1LL << 32);  // 4GB by default
   static const int64_t max_autotune_algorithm_count =
@@ -119,17 +111,17 @@ StatusOr GetPlanAndAlgorithms(
 
   if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count;
 
-  static BlasLtMatmulPlanMap plan_map;
+  auto& self = BlasLtMatmulPlanCache::i(stream);
 
-  auto [ptr, inserted] = plan_map.emplace(params, PlanAndAlgorithms{});
+  absl::MutexLock lock(self.mutex_.get());
+  auto [ptr, inserted] = self.map_.emplace(params, Entry{});
+  auto& entry = ptr->second;
   if (inserted) {
     TF_ASSIGN_OR_RETURN(auto xlatype,
                         se::gpu::AsXlaPrimitiveType(params.dtype));
     TF_ASSIGN_OR_RETURN(auto computation_type,
                         GetBlasComputationType(params.dtype));
 
-    auto scale_type = se::gpu::GetScaleType(params.dtype, computation_type);
-
     // row-major output is now handled automatically by blas-lt API
     constexpr auto kRowMajor = se::gpu::MatrixLayout::Order::kRowMajor;
 
@@ -171,19 +163,42 @@ StatusOr GetPlanAndAlgorithms(
         .compute_type = computation_type,
     };
 
-    TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
+    TF_ASSIGN_OR_RETURN(entry.plan, se::gpu::BlasLt::GetMatmulPlan(
                                        stream, cfg, params.epilogue));
 
    TF_ASSIGN_OR_RETURN(
-        auto algorithms,
-        plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
-
-    ptr->second = {std::move(plan), std::move(algorithms), scale_type};
+        entry.algorithms,
+        entry.plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
  }
 
-  *ppmu = &plan_map.mu;
-  return &ptr->second;
+  *ppmu = self.mutex_.get();
+  return &entry;
 }
 
+/* static */ Status BlasLtMatmulPlanCache::ExecuteOnStream(se::Stream* stream,
+                          const Entry& entry,
+                          const se::DeviceMemoryBase& a,
+                          const se::DeviceMemoryBase& b,
+                          se::DeviceMemoryBase& c,
+                          size_t algorithm_idx,
+                          se::ScratchAllocator& scratch_allocator,
+                          const se::DeviceMemoryBase& bias,
+                          se::blas::ProfileResult* profile_result) {
+
+  return entry.plan->ExecuteOnStream(
+      stream, a, b, c, c,
+      bias,                    // bias_buffer
+      se::DeviceMemoryBase{},  // aux_buffer
+      se::DeviceMemoryBase{},  // a_scale_buffer
+      se::DeviceMemoryBase{},  // b_scale_buffer
+      se::DeviceMemoryBase{},  // c_scale_buffer
+      se::DeviceMemoryBase{},  // d_scale_buffer
+      se::DeviceMemoryBase{},  // d_amax_buffer
+      entry.algorithms[algorithm_idx],
+      scratch_allocator,
+      profile_result);
+}
+
+
 }  // namespace tensorflow
 
 #endif
\ No newline at end of file
diff --git a/tensorflow/core/kernels/matmul_util.h b/tensorflow/core/kernels/matmul_util.h
index 371964424eff85..dbf85eab41242c 100644
--- a/tensorflow/core/kernels/matmul_util.h
+++ b/tensorflow/core/kernels/matmul_util.h
@@ -21,7 +21,7 @@ limitations under the License.
 
 #if GOOGLE_CUDA || TF_HIPBLASLT
 
-#include "absl/container/flat_hash_map.h"
+#include "absl/container/node_hash_map.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/gpu/gpu_blas_lt.h"
 #include "tensorflow/core/framework/types.h"
@@ -35,7 +35,8 @@ namespace tensorflow {
 int64_t GetWorkspaceLimit(int64_t default_value_in_bytes);
 
 struct BlasLtMatmulPlanParams {
-  std::string ToString() const;
+
+  std::string ToString() const { return "NOP"; }
 
   bool operator==(const BlasLtMatmulPlanParams& other) const;
 
   se::blas::DataType dtype;
@@ -50,12 +51,6 @@ struct BlasLtMatmulPlanParams {
   se::gpu::BlasLt::Epilogue epilogue = se::gpu::BlasLt::Epilogue::kDefault;
 };
 
-struct PlanAndAlgorithms {
-  se::gpu::BlasLt::MatmulPlanPtr plan;
-  std::vector<se::gpu::BlasLt::MatmulAlgorithm> algorithms;
-  se::blas::DataType scale_type;  // this is needed for half / bf16 treatment
-};
-
 namespace internal {
 
 inline auto AsTuple(const BlasLtMatmulPlanParams& p) {
@@ -71,37 +66,42 @@ H AbslHashValue(H h, const BlasLtMatmulPlanParams& params) {
   return H::combine(std::move(h), internal::AsTuple(params));
 }
 
-StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
+struct BlasLtMatmulPlanCache {
+  struct Entry {
+    se::gpu::BlasLt::MatmulPlanPtr plan;
+    std::vector<se::gpu::BlasLt::MatmulAlgorithm> algorithms;
+  };
+
+  static StatusOr<const Entry*> GetOrCreate(
     se::Stream* stream, const BlasLtMatmulPlanParams& params,
     absl::Mutex** pmu,
-    std::optional max_algorithm_count = std::nullopt);
-
-template <typename T>
-Status DoBlasLtMatmul(se::Stream* stream, const PlanAndAlgorithms& paa,
-                      const se::DeviceMemory<T>& a,
-                      const se::DeviceMemory<T>& b, se::DeviceMemory<T>& c,
-                      size_t alg_idx, se::ScratchAllocator& scratch_allocator,
-                      const se::DeviceMemory<T>& bias = {},
-                      se::blas::ProfileResult* profile_result = nullptr) {
-  se::DeviceMemory<T> aux{};  // We don't use the auxilary buffers.
-  const auto& algorithm = paa.algorithms[alg_idx];
-
-  // The scale type may be f32 if the data type is f16 and bf16.
-  if constexpr (std::is_same_v<T, Eigen::half> ||
-                std::is_same_v<T, Eigen::bfloat16>) {
-    if (paa.scale_type == se::blas::DataType::kFloat) {
-      return paa.plan->DoMatmul(stream, se::HostOrDeviceScalar<float>(1.0), b,
-                                a, se::HostOrDeviceScalar<float>(0.0), c, c,
-                                algorithm, scratch_allocator, bias, aux,
-                                profile_result);
-    }
+    std::optional max_algorithm_count = std::nullopt
+  );
+
+  // Helper function for plan execution.
+  static Status ExecuteOnStream(se::Stream* stream,
+                                const Entry& entry,
+                                const se::DeviceMemoryBase& a,
+                                const se::DeviceMemoryBase& b,
+                                se::DeviceMemoryBase& c,
+                                size_t algorithm_idx,
+                                se::ScratchAllocator& scratch_allocator,
+                                const se::DeviceMemoryBase& bias,
+                                se::blas::ProfileResult* profile_result = nullptr);
+
+  BlasLtMatmulPlanCache() : mutex_(new absl::Mutex) {
   }
-  return paa.plan->DoMatmul(stream, se::HostOrDeviceScalar<T>(T(1.0)), b, a,
-                            se::HostOrDeviceScalar<T>(T(0.0)), c, c, algorithm,
-                            scratch_allocator, bias, aux, profile_result);
-}
+
+private:
+  static BlasLtMatmulPlanCache& i(se::Stream* stream);
+
+  std::unique_ptr<absl::Mutex> mutex_;
+  absl::node_hash_map<BlasLtMatmulPlanParams, Entry> map_
+      ABSL_GUARDED_BY(mutex_);
+
+};  // BlasLtMatmulPlanCache
 
 }  // namespace tensorflow
 
-#endif
+#endif  // GOOGLE_CUDA || TF_HIPBLASLT
 
 #endif  // TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_
diff --git a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
index 41dc54fb3deae7..69a2d772e44ca5 100644
--- a/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
+++ b/third_party/xla/xla/service/gpu/ir_emitter_unnested.cc
@@ -1078,14 +1078,16 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk(
   TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue,
                       gpublas_lt::AsBlasLtEpilogue(epilogue));
   auto thunk = std::make_unique<CublasLtMatmulThunk>(
-      Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(gemm_config),
+      instr, Thunk::ThunkInfo::WithProfileAnnotation(instr),
+      std::move(gemm_config),
       blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale,
       c_scale, d_scale, d_amax);
   AddThunkToThunkSequence(std::move(thunk));
   return absl::OkStatus();
 }
 
-absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) {
+absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk(
+    mlir::Operation* op) {
   auto matmul = mlir::dyn_cast(op);
   TF_RET_CHECK(matmul != nullptr);
 
@@ -1107,7 +1109,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunk(mlir::Operation* op) {
   TF_ASSIGN_OR_RETURN(GemmConfig gemm_config, GemmConfig::For(matmul));
   TF_ASSIGN_OR_RETURN(auto epilogue,
                       gpublas_lt::AsBlasLtEpilogue(matmul.getEpilogue()));
-  auto thunk = std::make_unique<CublasLtMatmulThunk>(
+  auto thunk = std::make_unique<CublasLtMatmulThunk>(nullptr,  // instruction is not given
       Thunk::ThunkInfo::WithProfileAnnotation(op), std::move(gemm_config),
       epilogue, matmul.getAlgorithm(), a, b, c, d, bias, aux, a_scale, b_scale,
       c_scale, d_scale, d_amax);
@@ -1187,7 +1189,7 @@ absl::Status IrEmitterUnnested::EmitCublasLtMatmulThunkF8(
   TF_ASSIGN_OR_RETURN(se::gpu::BlasLt::Epilogue blas_lt_epilogue,
                       gpublas_lt::AsBlasLtEpilogue(epilogue));
   auto thunk = std::make_unique<CublasLtMatmulThunk>(
-      Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(gemm_config),
+      instr, Thunk::ThunkInfo::WithProfileAnnotation(instr), std::move(gemm_config),
       blas_lt_epilogue, algorithm, a, b, c, d, bias, aux, a_scale, b_scale,
       c_scale, d_scale, d_amax);
   AddThunkToThunkSequence(std::move(thunk));
diff --git a/third_party/xla/xla/service/gpu/runtime3/BUILD b/third_party/xla/xla/service/gpu/runtime3/BUILD
index aee0672f68402f..5acf58fb280664 100644
--- a/third_party/xla/xla/service/gpu/runtime3/BUILD
+++ b/third_party/xla/xla/service/gpu/runtime3/BUILD
@@ -444,6 +444,7 @@ cc_library(
     deps = if_gpu_is_configured([
         "//xla/service:buffer_assignment",
         "//xla/service/gpu:matmul_utils",
+        "//xla/service/gpu:autotuner_util",
         "//xla/service/gpu:thunk",
         "//xla:status",
         "//xla/stream_executor:device_memory",
diff --git a/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.cc b/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.cc
index ae0c7eb908cb8b..eaf8c462912149 100644
--- a/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.cc
+++ b/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.cc
@@ -19,7 +19,8 @@ limitations under the License.
 
 #include "xla/service/gpu/matmul_utils.h"
 #include "xla/service/gpu/thunk.h"
-#include "xla/status.h"
+#include "xla/service/gpu/autotuner_util.h"
+#include "xla/status_macros.h"
 #include "xla/stream_executor/device_memory.h"
 #include "xla/stream_executor/scratch_allocator.h"
 #include "tsl/platform/logging.h"
@@ -27,8 +28,46 @@ limitations under the License.
 namespace xla {
 namespace gpu {
 
+struct MatmulPlanCache {
+
+  static MatmulPlanCache& i(const se::Stream* stream) {
+    static absl::Mutex m(absl::kConstInit);
+    // Each GPU gets its own cache instance.
+    static std::vector<std::unique_ptr<MatmulPlanCache>> meta(8);
+    absl::MutexLock lock(&m);
+    size_t dev_id = stream->parent()->device_ordinal();
+    if (dev_id >= meta.size()) meta.resize(dev_id + 1);
+    auto& res = meta[dev_id];
+    if (!res) res.reset(new MatmulPlanCache());
+    return *res;
+  }
+
+  template <class Func>
+  StatusOr<se::gpu::BlasLt::MatmulPlan*>
+  GetOrCreate(const std::string& key, Func&& create) {
+    // Each GPU has a different mutex => different GPU instances can
+    // create matmul plans in parallel.
+    absl::MutexLock lock(mutex_.get());
+    auto res = map_.emplace(key, se::gpu::BlasLt::MatmulPlanPtr{});
+    if (res.second) {  // new entry inserted
+      TF_ASSIGN_OR_RETURN(res.first->second, create());
+    }
+    return res.first->second.get();
+  }
+
+private:
+  MatmulPlanCache() : mutex_(std::make_unique<absl::Mutex>()) { }
+
+private:
+  std::unique_ptr<absl::Mutex> mutex_;
+  absl::flat_hash_map<std::string, se::gpu::BlasLt::MatmulPlanPtr> map_;
+};
+
+
 CublasLtMatmulThunk::CublasLtMatmulThunk(
-    ThunkInfo thunk_info, GemmConfig gemm_config,
+    const HloInstruction* instr,
+    ThunkInfo thunk_info,
+    GemmConfig gemm_config,
     se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx,
     BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer,
     BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer,
@@ -50,13 +89,20 @@ CublasLtMatmulThunk::CublasLtMatmulThunk(
       b_scale_buffer_(b_scale),
       c_scale_buffer_(c_scale),
       d_scale_buffer_(d_scale),
-      d_amax_buffer_(d_amax) {}
+      d_amax_buffer_(d_amax) {
+  // If the instruction is not available, cache based on the profile
+  // annotation; this execution path is not used with the new XLA runtime.
+  canonical_hlo_ = instr ? xla::gpu::AutotuneCacheKey("nope", *instr).GetHlo() :
+                           thunk_info.profile_annotation;
+}
 
 absl::Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) {
-  TF_ASSIGN_OR_RETURN(auto plan, GetMatmulPlan(params.stream));
-  TF_ASSIGN_OR_RETURN(auto algorithm, GetMatmulAlgorithm(plan));
 
-  VLOG(3) << "Running cublas_lt matmul thunk";
+  TF_ASSIGN_OR_RETURN(auto* plan, GetCachedMatmulPlan(params));
+
+  VLOG(2) << params.stream->parent()->device_ordinal() <<
+             ": cublas_lt_matmul for: " << canonical_hlo_;
+
   const BufferAllocations& allocs = *params.buffer_allocations;
 
   se::DeviceMemoryBase bias, a_scale, b_scale, c_scale, d_scale, d_amax;
@@ -90,33 +136,29 @@ absl::Status CublasLtMatmulThunk::ExecuteOnStream(const ExecuteParams& params) {
       params.stream, allocs.GetDeviceAddress(a_buffer_),
      allocs.GetDeviceAddress(b_buffer_), allocs.GetDeviceAddress(c_buffer_),
      allocs.GetDeviceAddress(d_buffer_), bias, aux, a_scale, b_scale, c_scale,
-      d_scale, d_amax, *algorithm, scratch_allocator);
+      d_scale, d_amax, {}, scratch_allocator);
 }
 
-absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> CublasLtMatmulThunk::GetMatmulPlan(
-    const stream_executor::Stream* stream) {
-  absl::MutexLock lock(&matmul_plans_cache_mutex_);
-  auto it = matmul_plans_cache_.find(stream);
-  if (it == matmul_plans_cache_.end()) {
-    TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
-                                       stream, gemm_config_, epilogue_));
-    it = matmul_plans_cache_.emplace(stream, std::move(plan)).first;
-  }
-  return it->second.get();
-}
+auto CublasLtMatmulThunk::GetCachedMatmulPlan(
+    const ExecuteParams& params) -> absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> {
 
-absl::StatusOr<std::optional<se::gpu::BlasLt::MatmulAlgorithm> >
-CublasLtMatmulThunk::GetMatmulAlgorithm(
-    const se::gpu::BlasLt::MatmulPlan* plan) {
-  absl::MutexLock lock(&matmul_algorithm_cache_mutex_);
-  auto it = matmul_algorithm_cache_.find(plan);
-  if (it == matmul_algorithm_cache_.end()) {
-    TF_ASSIGN_OR_RETURN(auto algorithms, plan->GetAlgorithms());
-    TF_RET_CHECK(algorithm_idx_ >= 0 && algorithm_idx_ < algorithms.size());
-    auto algorithm = algorithms[algorithm_idx_];
-    it = matmul_algorithm_cache_.emplace(plan, algorithm).first;
-  }
-  return it->second;
+  auto& cache = MatmulPlanCache::i(params.stream);
+
+  auto create = [&]() -> StatusOr<se::gpu::BlasLt::MatmulPlanPtr> {
+    VLOG(2) << this << ": Adding new MatmulPlan for stream: " << params.stream <<
+               " instr: " << canonical_hlo_;
+
+    TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
+                                       params.stream, gemm_config_, epilogue_));
+
+    int64_t num_algorithms = algorithm_idx_ == 0 ? 1 : 128;
+    TF_ASSIGN_OR_RETURN(auto algorithms,
+                        plan->GetAlgorithms(num_algorithms));
+
+    TF_RETURN_IF_ERROR(plan->SetAlgorithm(algorithms[algorithm_idx_]));
+    return std::move(plan);
+  };
+  return cache.GetOrCreate(canonical_hlo_, create);
 }
 
 }  // namespace gpu
diff --git a/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.h b/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.h
index 9558890ea828f1..0ecc2bbafba692 100644
--- a/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.h
+++ b/third_party/xla/xla/service/gpu/runtime3/gpublas_lt_matmul_thunk.h
@@ -31,7 +31,9 @@ namespace gpu {
 
 class CublasLtMatmulThunk : public Thunk {
  public:
-  CublasLtMatmulThunk(ThunkInfo thunk_info, GemmConfig gemm_config,
+  CublasLtMatmulThunk(const HloInstruction* instr,
+                      ThunkInfo thunk_info,
+                      GemmConfig gemm_config,
                       se::gpu::BlasLt::Epilogue epilogue, int64_t algorithm_idx,
                       BufferAllocation::Slice a_buffer, BufferAllocation::Slice b_buffer,
                       BufferAllocation::Slice c_buffer, BufferAllocation::Slice d_buffer,
@@ -48,24 +50,13 @@ class CublasLtMatmulThunk : public Thunk {
   absl::Status ExecuteOnStream(const ExecuteParams& params) override;
 
  private:
-  absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> GetMatmulPlan(
-      const stream_executor::Stream* stream);
-  absl::StatusOr<std::optional<se::gpu::BlasLt::MatmulAlgorithm> >
-  GetMatmulAlgorithm(const se::gpu::BlasLt::MatmulPlan* plan);
-
-  absl::Mutex matmul_plans_cache_mutex_;
-  absl::flat_hash_map<const stream_executor::Stream*, se::gpu::BlasLt::MatmulPlanPtr>
-      matmul_plans_cache_ ABSL_GUARDED_BY(matmul_plans_cache_mutex_);
-
-  absl::Mutex matmul_algorithm_cache_mutex_;
-  absl::flat_hash_map<const se::gpu::BlasLt::MatmulPlan*, se::gpu::BlasLt::MatmulAlgorithm>
-      matmul_algorithm_cache_ ABSL_GUARDED_BY(matmul_algorithm_cache_mutex_);
+  absl::StatusOr<se::gpu::BlasLt::MatmulPlan*> GetCachedMatmulPlan(
+      const ExecuteParams& params);
 
   GemmConfig gemm_config_;
   se::gpu::BlasLt::Epilogue epilogue_;
   int64_t algorithm_idx_;
+  std::string canonical_hlo_;
   BufferAllocation::Slice a_buffer_;
   BufferAllocation::Slice b_buffer_;
   BufferAllocation::Slice c_buffer_;
diff --git a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
index 6bac2ea13676e5..05e5872985243b 100644
--- a/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/gpu/gpu_blas_lt.h
@@ -200,6 +200,9 @@ struct BlasLt {
         size_t max_algorithm_count = 128,
         size_t max_workspace_size = 1ll << 32) const = 0;
 
+    // The algorithm must be set via SetAlgorithm() before ExecuteOnStream() is called.
+    virtual absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) const = 0;
+
     virtual ~MatmulPlan() {}
 
    protected:
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
index bbc7f834fccfcb..f1e757d9d448f8 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
@@ -349,18 +349,26 @@ absl::Status BlasLt::MatmulPlan::ValidateInputs(
   return absl::OkStatus();
 }
 
+absl::Status BlasLt::MatmulPlan::SetAlgorithm(const MatmulAlgorithm& algorithm) const {
+  algorithm_ = algorithm;
+  return absl::OkStatus();
+}
+
 absl::Status BlasLt::MatmulPlan::DoMatmul(
     Stream* stream, const void* alpha, DeviceMemoryBase a, DeviceMemoryBase b,
     const void* beta, DeviceMemoryBase c, DeviceMemoryBase d,
-    const MatmulAlgorithm& algorithm, ScratchAllocator& scratch_allocator,
+    const MatmulAlgorithm& Xalgorithm, ScratchAllocator& scratch_allocator,
     DeviceMemoryBase bias, DeviceMemoryBase aux, DeviceMemoryBase a_scale,
     DeviceMemoryBase b_scale, DeviceMemoryBase c_scale,
     DeviceMemoryBase d_scale, DeviceMemoryBase d_amax,
     blas::ProfileResult* profile_result) const {
+
   TF_ASSIGN_OR_RETURN(
       std::optional<gpu::GpuTimer> timer,
       gpu::GpuTimer::CreateIfNeeded(gpu::AsGpuStream(stream), profile_result));
 
+  auto algorithm = algorithm_.has_value() ? *algorithm_ : Xalgorithm;
+
   void* workspace = nullptr;
   if (algorithm.workspace_size > 0) {
     TF_ASSIGN_OR_RETURN(
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h
index 3f0daaca1915fa..720574a6a55381 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.h
@@ -114,6 +114,8 @@ class BlasLt : public gpu::BlasLt {
     absl::StatusOr<std::vector<MatmulAlgorithm>> GetAlgorithms(
         size_t max_algorithm_count, size_t max_workspace_size) const override;
 
+    absl::Status SetAlgorithm(const MatmulAlgorithm& algorithm) const override;
+
    protected:
     absl::Status ValidateInputs(blas::DataType scale_type, bool alpha_on_device,
                                 bool beta_on_device, blas::DataType A_type,
@@ -142,6 +144,7 @@ class BlasLt : public gpu::BlasLt {
     xla::complex128 alpha_;
     double beta_;
     bool must_swap_operands_;
+    mutable std::optional<MatmulAlgorithm> algorithm_;  // selected algorithm
   };  // class MatmulPlan
 
   explicit BlasLt(gpu::GpuExecutor* parent)
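
For reference, a minimal sketch of the call pattern the new `BlasLtMatmulPlanCache` API (added in `matmul_util.h` above) is designed for. This is illustrative only and not part of the patch: the wrapper function `RunCachedMatmul` and its arguments (`stream`, `context`, the device buffers) are hypothetical placeholders, while the cache lookup, the mutex returned through `ppmu`, and the `ExecuteOnStream` call mirror the updated call sites in `matmul_op_fused.cc`.

```cpp
// Illustrative sketch only (not part of the patch). Assumes TensorFlow kernel
// context: BlasLtMatmulPlanCache, BlasLtMatmulPlanParams, BlasScratchAllocator
// and the se:: aliases are available as in matmul_op_fused.cc.
Status RunCachedMatmul(se::Stream* stream, OpKernelContext* context,
                       const BlasLtMatmulPlanParams& params,
                       const se::DeviceMemoryBase& a,
                       const se::DeviceMemoryBase& b,
                       se::DeviceMemoryBase c) {
  absl::Mutex* pmu = nullptr;
  // Returns the cached Entry (plan + algorithms) for this GPU, creating it on
  // first use; *pmu guards the per-device cache while the entry is in use.
  TF_ASSIGN_OR_RETURN(const auto* entry,
                      BlasLtMatmulPlanCache::GetOrCreate(stream, params, &pmu));
  absl::MutexLock lock(pmu);
  if (entry->algorithms.empty()) {
    return errors::InvalidArgument("No matmul algorithm returned!");
  }
  BlasScratchAllocator scratch_allocator(context);
  // Run with algorithm 0; an autotuning path would select a different index.
  return BlasLtMatmulPlanCache::ExecuteOnStream(
      stream, *entry, a, b, c, /*algorithm_idx=*/0, scratch_allocator,
      /*bias=*/se::DeviceMemoryBase{}, /*profile_result=*/nullptr);
}
```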