diff --git a/tensorflow/core/kernels/matmul_op_fused.cc b/tensorflow/core/kernels/matmul_op_fused.cc
index cdf4358fb6f1f4..f784cbf38c3eee 100644
--- a/tensorflow/core/kernels/matmul_op_fused.cc
+++ b/tensorflow/core/kernels/matmul_op_fused.cc
@@ -597,32 +597,31 @@ struct LaunchFusedMatMulOp {
         epilog_op};

     absl::Mutex* pmu;
     auto plan_and_algorithms_or =
-        PlanAndAlgorithms::GetOrCreate(stream, matmul_params, &pmu);
+        BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu);
     OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
     absl::MutexLock lock(pmu);
-    const auto* plan_and_algorithms = std::move(plan_and_algorithms_or).value();
-    const auto& algorithms = plan_and_algorithms->algorithms;
-    OP_REQUIRES(context, algorithms.size() > 0,
+    const auto& entry = *plan_and_algorithms_or.value();
+    OP_REQUIRES(context, entry.algorithms.size() > 0,
                 errors::InvalidArgument("No matmul algorithm returned!"));

     auto launch_func = [&](BlasScratchAllocator& scratch_allocator,
                            size_t alg_idx,
                            se::blas::ProfileResult* profile_result) {
-      return plan_and_algorithms->ExecuteOnStream(stream, a_ptr, b_ptr, c_ptr,
-                                                  alg_idx, scratch_allocator, bias_ptr,
-                                                  profile_result);
+      return BlasLtMatmulPlanCache::ExecuteOnStream(
+          stream, entry, a_ptr, b_ptr, c_ptr, alg_idx,
+          scratch_allocator, bias_ptr, profile_result);
     };

     size_t alg_idx = 0;
     if (use_autotune) {
       auto algorithm_config =
-          AutotuneMatmul(algorithms, matmul_params, context, launch_func);
+          AutotuneMatmul(entry.algorithms, matmul_params, context, launch_func);

       alg_idx = algorithm_config.algorithm();
     }

     OP_REQUIRES_OK(context, launch_func(scratch_allocator, alg_idx, nullptr));
-#endif
+#endif  // GOOGLE_CUDA || TF_HIPBLASLT
   }
 };
diff --git a/tensorflow/core/kernels/matmul_op_impl.h b/tensorflow/core/kernels/matmul_op_impl.h
index 71230312869092..b55677c4a15f0a 100644
--- a/tensorflow/core/kernels/matmul_op_impl.h
+++ b/tensorflow/core/kernels/matmul_op_impl.h
@@ -637,7 +637,7 @@ struct LaunchBatchMatMul {
     std::optional<int> max_algorithm_count;
     if (!use_autotune) max_algorithm_count = 1;
     absl::Mutex* pmu = nullptr;
-    auto plan_and_algorithms_or = PlanAndAlgorithms::GetOrCreate(
+    auto plan_and_algorithms_or = BlasLtMatmulPlanCache::GetOrCreate(
         stream, matmul_params, &pmu, max_algorithm_count);
     OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
     absl::MutexLock lock(pmu);
@@ -660,8 +660,9 @@ struct LaunchBatchMatMul {
             // scratch space is deallocated between runs.
            BlasScratchAllocator scratch_allocator(context, max_scratch_size);
            Status cublas_launch_status =
-               plan_and_algorithms->ExecuteOnStream(stream, *a_ptrs[0],
-                   *b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
+               BlasLtMatmulPlanCache::ExecuteOnStream(stream,
+                   *plan_and_algorithms,
+                   *a_ptrs[0], *b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
                    se::DeviceMemoryBase{}, &profile_result);

            VLOG(4) << "  Autotune algorithm " << i
@@ -702,8 +703,10 @@ struct LaunchBatchMatMul {

       OP_REQUIRES_OK(
           context,
-          plan_and_algorithms->ExecuteOnStream(stream, *a_ptrs[0], *b_ptrs[0],
-              *c_ptrs[0], algorithm_idx, scratch_allocator));
+          BlasLtMatmulPlanCache::ExecuteOnStream(stream,
+              *plan_and_algorithms,
+              *a_ptrs[0], *b_ptrs[0], *c_ptrs[0],
+              algorithm_idx, scratch_allocator, se::DeviceMemoryBase{}));
     } else {  // requires mixed broadcasting
       const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
       const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();
diff --git a/tensorflow/core/kernels/matmul_util.cc b/tensorflow/core/kernels/matmul_util.cc
index 50764c44159897..8f95e9a9336fe2 100644
--- a/tensorflow/core/kernels/matmul_util.cc
+++ b/tensorflow/core/kernels/matmul_util.cc
@@ -16,6 +16,7 @@ limitations under the License.
 #include
 #include
+#include <deque>
 #include

 #include "xla/status_macros.h"
@@ -24,6 +25,8 @@ limitations under the License.
 #include "tensorflow/core/platform/tensor_float_32_utils.h"
 #include "tensorflow/core/util/env_var.h"
 #include "tensorflow/core/util/matmul_autotune.h"
+#include "xla/stream_executor/stream.h"
+#include "xla/stream_executor/stream_executor.h"

 namespace tensorflow {

@@ -44,10 +47,6 @@ int64_t GetWorkspaceLimit(int64_t default_value_in_bytes) {
   return default_value_in_bytes;
 }

-std::string BlasLtMatmulPlanParams::ToString() const {
-  return "";  // TODO
-}
-
 bool BlasLtMatmulPlanParams::operator==(
     const BlasLtMatmulPlanParams& other) const {
   return internal::AsTuple(*this) == internal::AsTuple(other);
@@ -55,22 +54,6 @@ bool BlasLtMatmulPlanParams::operator==(

 namespace {

-// Thread-safe map from matmul parameters to their corresponding plan and
-// algorithms.
-struct BlasLtMatmulPlanMap {
-  absl::Mutex mu;
-
-  template <typename... Args>
-  auto emplace(Args&&... args) {
-    absl::MutexLock lock(&mu);
-    return map_.emplace(std::forward<Args>(args)...);
-  }
-
- private:
-  absl::node_hash_map<BlasLtMatmulPlanParams, PlanAndAlgorithms> map_
-      ABSL_GUARDED_BY(mu);
-};
-
 int MatmulMaxAutotuneAlgorithmCount() {
   int64_t value;
   Status status =
@@ -110,9 +93,19 @@ StatusOr GetBlasComputationType(

 }  // namespace

-/* static */ StatusOr<const PlanAndAlgorithms*> PlanAndAlgorithms::GetOrCreate(
+/* static */ BlasLtMatmulPlanCache& BlasLtMatmulPlanCache::i(se::Stream *stream) {
+  static absl::Mutex m(absl::kConstInit);
+  // Each GPU gets different cache instance
+  static std::deque< BlasLtMatmulPlanCache > meta(8);
+  absl::MutexLock lock(&m);
+  size_t dev_id = stream->parent()->device_ordinal();
+  if (dev_id >= meta.size()) meta.resize(dev_id + 1);
+  return meta[dev_id];
+}
+
+/* static */ auto BlasLtMatmulPlanCache::GetOrCreate(
     se::Stream* stream, const BlasLtMatmulPlanParams& params,
-    absl::Mutex** ppmu, std::optional<int> max_algorithm_count) {
+    absl::Mutex** ppmu, std::optional<int> max_algorithm_count) -> StatusOr<const Entry *>{
   static const int64_t max_scratch_size =
       GetWorkspaceLimit(1LL << 32);  // 4GB by default
   static const int64_t max_autotune_algorithm_count =
@@ -120,9 +113,11 @@ StatusOr GetBlasComputationType(

   if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count;

-  static BlasLtMatmulPlanMap plan_map;
+  auto& self = BlasLtMatmulPlanCache::i(stream);

-  auto [ptr, inserted] = plan_map.emplace(params, PlanAndAlgorithms{});
+  absl::MutexLock lock(self.mutex_.get());
+  auto [ptr, inserted] = self.map_.emplace(params, Entry{});
+  auto& entry = ptr->second;
   if (inserted) {
     TF_ASSIGN_OR_RETURN(auto xlatype,
                         se::gpu::AsXlaPrimitiveType(params.dtype));
@@ -171,32 +166,28 @@ StatusOr GetBlasComputationType(
         .compute_type = computation_type,
     };

-    TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
+    TF_ASSIGN_OR_RETURN(entry.plan, se::gpu::BlasLt::GetMatmulPlan(
                                        stream, cfg, params.epilogue));
     TF_ASSIGN_OR_RETURN(
-        auto algorithms,
-        plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
-
-    ptr->second = {std::move(plan), std::move(algorithms)};
+        entry.algorithms,
+        entry.plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
   }
-  *ppmu = &plan_map.mu;
-  return &ptr->second;
+  *ppmu = self.mutex_.get();
+  return &entry;
 }

-Status PlanAndAlgorithms::ExecuteOnStream(se::Stream* stream,
+/*static */ Status BlasLtMatmulPlanCache::ExecuteOnStream(se::Stream* stream,
+    const Entry& entry,
     const se::DeviceMemoryBase& a,
     const se::DeviceMemoryBase& b,
     se::DeviceMemoryBase& c,
     size_t algorithm_idx,
     se::ScratchAllocator& scratch_allocator,
     const se::DeviceMemoryBase& bias,
-    se::blas::ProfileResult* profile_result) const {
+    se::blas::ProfileResult* profile_result) {

-  if(!plan || algorithm_idx >= algorithms.size()) {
-    return errors::Internal("MatmulPlan or algorithms are not initialized!");
-  }
-  return plan->ExecuteOnStream(
+  return entry.plan->ExecuteOnStream(
       stream, a, b, c, c,
       bias,                    // bias_buffer
       se::DeviceMemoryBase{},  // aux_buffer
@@ -205,9 +196,8 @@ Status PlanAndAlgorithms::ExecuteOnStream(se::Stream* stream,
       se::DeviceMemoryBase{},  // c_scale_buffer
       se::DeviceMemoryBase{},  // d_scale_buffer
       se::DeviceMemoryBase{},  // d_amax_buffer
-      algorithms[algorithm_idx],
-      std::nullopt,  // workspace
-      &scratch_allocator,
+      entry.algorithms[algorithm_idx],
+      scratch_allocator,
       profile_result);
 }

diff --git a/tensorflow/core/kernels/matmul_util.h b/tensorflow/core/kernels/matmul_util.h
index acf734a92d0d56..dbf85eab41242c 100644
--- a/tensorflow/core/kernels/matmul_util.h
+++ b/tensorflow/core/kernels/matmul_util.h
@@ -35,7 +35,8 @@ namespace tensorflow {
 int64_t GetWorkspaceLimit(int64_t default_value_in_bytes);

 struct BlasLtMatmulPlanParams {
-  std::string ToString() const;
+
+  std::string ToString() const { return "NOP"; }
   bool operator==(const BlasLtMatmulPlanParams& other) const;

   se::blas::DataType dtype;
@@ -50,26 +51,6 @@ struct BlasLtMatmulPlanParams {
   se::gpu::BlasLt::Epilogue epilogue = se::gpu::BlasLt::Epilogue::kDefault;
 };

-struct PlanAndAlgorithms {
-
-  static StatusOr<const PlanAndAlgorithms*> GetOrCreate(
-      se::Stream* stream, const BlasLtMatmulPlanParams& params, absl::Mutex** pmu,
-      std::optional<int> max_algorithm_count = std::nullopt
-  );
-
-  Status ExecuteOnStream(se::Stream* stream,
-                         const se::DeviceMemoryBase& a,
-                         const se::DeviceMemoryBase& b,
-                         se::DeviceMemoryBase& c,
-                         size_t algorithm_idx,
-                         se::ScratchAllocator& scratch_allocator,
-                         const se::DeviceMemoryBase& bias = se::DeviceMemoryBase{},
-                         se::blas::ProfileResult* profile_result = nullptr) const;
-
-  se::gpu::BlasLt::MatmulPlanPtr plan;
-  std::vector<se::gpu::BlasLt::MatmulAlgorithm> algorithms;
-};
-
 namespace internal {

 inline auto AsTuple(const BlasLtMatmulPlanParams& p) {
@@ -85,6 +66,40 @@ H AbslHashValue(H h, const BlasLtMatmulPlanParams& params) {
   return H::combine(std::move(h), internal::AsTuple(params));
 }

+struct BlasLtMatmulPlanCache {
+  struct Entry {
+    se::gpu::BlasLt::MatmulPlanPtr plan;
+    std::vector< se::gpu::BlasLt::MatmulAlgorithm > algorithms;
+  };
+
+  static StatusOr<const Entry *> GetOrCreate(
+      se::Stream* stream, const BlasLtMatmulPlanParams& params, absl::Mutex** pmu,
+      std::optional<int> max_algorithm_count = std::nullopt
+  );
+
+  // helper function for plan execution
+  static Status ExecuteOnStream(se::Stream* stream,
+                                const Entry& entry,
+                                const se::DeviceMemoryBase& a,
+                                const se::DeviceMemoryBase& b,
+                                se::DeviceMemoryBase& c,
+                                size_t algorithm_idx,
+                                se::ScratchAllocator& scratch_allocator,
+                                const se::DeviceMemoryBase& bias,
+                                se::blas::ProfileResult* profile_result = nullptr);
+
+  BlasLtMatmulPlanCache() : mutex_(new absl::Mutex) {
+  }
+
+private:
+  static BlasLtMatmulPlanCache& i(se::Stream *stream);
+
+  std::unique_ptr<absl::Mutex> mutex_;
+  absl::node_hash_map<BlasLtMatmulPlanParams, Entry> map_
+      ABSL_GUARDED_BY(mutex_);
+
+};  // BlasLtMatmulPlanCache
+
 }  // namespace tensorflow

 #endif  // GOOGLE_CUDA || TF_HIPBLASLT
diff --git a/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc b/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc
index b8e5a8e8d1e662..531ad8d2f605fb 100644
--- a/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc
+++ b/third_party/xla/xla/service/gpu/buffer_comparator.cu.cc
@@ -103,11 +103,16 @@ __global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a,
 #endif  // GOOGLE_CUDA

 #if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200
+
 __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
                                             __hip_fp8_storage_t* buffer_b,
                                             float rel_error_threshold,
                                             uint64_t buffer_length,
                                             int* mismatch_count) {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
+// NOTE: according to amd_hip_fp8.h, GFX1200 and GFX1201 support ocp __hip_fp8_e4m3
+// but not __hip_fp8_e4m3_fnuz
+
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx >= buffer_length) return;
   __hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8;
@@ -123,6 +128,10 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,

   if (rel_error > rel_error_threshold || isnan(rel_error))
     atomicAdd(mismatch_count, 1);
+#else
+  // on unsupported architectures, this should not / cannot be used!
+  atomicAdd(mismatch_count, 1);
+#endif
 }

 __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
@@ -130,6 +139,7 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
                                             float rel_error_threshold,
                                             uint64_t buffer_length,
                                             int* mismatch_count) {
+#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
   int idx = threadIdx.x + blockIdx.x * blockDim.x;
   if (idx >= buffer_length) return;
   __hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8;
@@ -145,7 +155,12 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,

   if (rel_error > rel_error_threshold || isnan(rel_error))
     atomicAdd(mismatch_count, 1);
+#else
+  // on unsupported architectures, this should not / cannot be used!
+  atomicAdd(mismatch_count, 1);
+#endif
 }
+
 #endif  // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200

 __global__ void xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
diff --git a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
index 18c84744491fbe..f3be29c92fea20 100644
--- a/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
+++ b/third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
@@ -618,4 +618,4 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream(

 }  // namespace stream_executor

-#endif  // TF_HIPBLASLT
+#endif  // TF_HIPBLASLT
\ No newline at end of file
diff --git a/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h b/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h
index c53cff6a933913..9574c7b7ac28db 100644
--- a/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h
+++ b/third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h
@@ -18,6 +18,7 @@ limitations under the License.
 #define XLA_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_

 #define __HIP_DISABLE_CPP_FUNCTIONS__
+#define LEGACY_HIPBLAS_DIRECT

 #include "rocm/rocm_config.h"
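Reviewer note (not part of the patch): a minimal usage sketch of the new BlasLtMatmulPlanCache API introduced above, mirroring the call pattern in LaunchFusedMatMulOp. The stream, matmul_params, device buffers, and scratch allocator are assumed to be set up by the surrounding kernel; error handling via TF_RETURN_IF_ERROR is used here only because the sketch is written as a Status-returning helper.

  // Sketch only: look up (or build) the cached plan for this device, then run it.
  absl::Mutex* pmu = nullptr;
  auto entry_or = BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu);
  TF_RETURN_IF_ERROR(entry_or.status());
  absl::MutexLock lock(pmu);  // cache entry is guarded by the per-device mutex
  const auto& entry = *entry_or.value();
  if (entry.algorithms.empty()) {
    return errors::Internal("No matmul algorithm returned!");
  }
  // Execute with algorithm 0; autotuning would select a better index.
  TF_RETURN_IF_ERROR(BlasLtMatmulPlanCache::ExecuteOnStream(
      stream, entry, a_ptr, b_ptr, c_ptr, /*algorithm_idx=*/0,
      scratch_allocator, /*bias=*/se::DeviceMemoryBase{},
      /*profile_result=*/nullptr));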