Merge pull request #2780 from ROCm/r2.18-rocm-enhanced-hipblaslt-and-fp8-fixes

Numerous hipBLASLt-related fixes and an fp8 buffer_comparator fix
i-chaochen authored Dec 2, 2024
2 parents 8f61c78 + 75c9894 commit 432b95b
Showing 7 changed files with 99 additions and 76 deletions.
17 changes: 8 additions & 9 deletions tensorflow/core/kernels/matmul_op_fused.cc
@@ -597,32 +597,31 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {
epilog_op};
absl::Mutex* pmu;
auto plan_and_algorithms_or =
PlanAndAlgorithms::GetOrCreate(stream, matmul_params, &pmu);
BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu);
OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
absl::MutexLock lock(pmu);
const auto* plan_and_algorithms = std::move(plan_and_algorithms_or).value();
const auto& algorithms = plan_and_algorithms->algorithms;
OP_REQUIRES(context, algorithms.size() > 0,
const auto& entry = *plan_and_algorithms_or.value();
OP_REQUIRES(context, entry.algorithms.size() > 0,
errors::InvalidArgument("No matmul algorithm returned!"));

auto launch_func = [&](BlasScratchAllocator& scratch_allocator,
size_t alg_idx,
se::blas::ProfileResult* profile_result) {
return plan_and_algorithms->ExecuteOnStream(stream, a_ptr, b_ptr, c_ptr,
alg_idx, scratch_allocator, bias_ptr,
profile_result);
return BlasLtMatmulPlanCache::ExecuteOnStream(
stream, entry, a_ptr, b_ptr, c_ptr, alg_idx,
scratch_allocator, bias_ptr, profile_result);
};

size_t alg_idx = 0;
if (use_autotune) {
auto algorithm_config =
AutotuneMatmul(algorithms, matmul_params, context, launch_func);
AutotuneMatmul(entry.algorithms, matmul_params, context, launch_func);

alg_idx = algorithm_config.algorithm();
}

OP_REQUIRES_OK(context, launch_func(scratch_allocator, alg_idx, nullptr));
#endif
#endif // GOOGLE_CUDA || TF_HIPBLASLT
}
};
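Note: the hunk above replaces calls on a cached PlanAndAlgorithms object with a BlasLtMatmulPlanCache that hands back an Entry plus the cache mutex, and a static ExecuteOnStream helper that takes that Entry explicitly. Below is a self-contained toy model of that shape; every name and type in it is an illustrative stand-in, not the real TensorFlow/StreamExecutor API.

```cpp
// Toy model of the refactor: GetOrCreate returns a cached Entry and the cache
// mutex; execution goes through a static helper that takes the Entry as a
// parameter instead of being a method on the cached object.
#include <cstdio>
#include <map>
#include <mutex>
#include <string>
#include <vector>

struct ToyPlanCache {
  struct Entry {
    std::string plan;             // stands in for BlasLt::MatmulPlanPtr
    std::vector<int> algorithms;  // stands in for the MatmulAlgorithm list
  };

  // Returns the cached entry for `key`, creating it on first use, and hands the
  // caller the cache mutex so the entry can be used under the same lock.
  static const Entry* GetOrCreate(const std::string& key, std::mutex** pmu) {
    static ToyPlanCache cache;
    *pmu = &cache.mu_;
    std::lock_guard<std::mutex> lock(cache.mu_);
    auto [it, inserted] = cache.map_.emplace(key, Entry{});
    if (inserted) it->second = Entry{"plan:" + key, {0, 1, 2}};
    return &it->second;
  }

  // Static execution helper: the entry is a parameter, not `this`.
  static bool ExecuteOnStream(const Entry& entry, size_t algorithm_idx) {
    if (algorithm_idx >= entry.algorithms.size()) return false;
    std::printf("running %s with algorithm %d\n", entry.plan.c_str(),
                entry.algorithms[algorithm_idx]);
    return true;
  }

 private:
  std::mutex mu_;
  std::map<std::string, Entry> map_;
};

int main() {
  std::mutex* pmu = nullptr;
  const auto* entry = ToyPlanCache::GetOrCreate("f16_64x64x128", &pmu);
  std::lock_guard<std::mutex> lock(*pmu);
  bool ok = ToyPlanCache::ExecuteOnStream(*entry, /*algorithm_idx=*/0);
  return ok ? 0 : 1;
}
```

In the real change, the per-call validity check (`if (!plan || algorithm_idx >= algorithms.size())`) is dropped, presumably because entries handed out by GetOrCreate are always initialized before the mutex is released.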

13 changes: 8 additions & 5 deletions tensorflow/core/kernels/matmul_op_impl.h
@@ -637,7 +637,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
std::optional<int> max_algorithm_count;
if (!use_autotune) max_algorithm_count = 1;
absl::Mutex* pmu = nullptr;
auto plan_and_algorithms_or = PlanAndAlgorithms::GetOrCreate(
auto plan_and_algorithms_or = BlasLtMatmulPlanCache::GetOrCreate(
stream, matmul_params, &pmu, max_algorithm_count);
OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
absl::MutexLock lock(pmu);
@@ -660,8 +660,9 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
// scratch space is deallocated between runs.
BlasScratchAllocator scratch_allocator(context, max_scratch_size);
Status cublas_launch_status =
plan_and_algorithms->ExecuteOnStream(stream, *a_ptrs[0],
*b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
BlasLtMatmulPlanCache::ExecuteOnStream(stream,
*plan_and_algorithms,
*a_ptrs[0], *b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
se::DeviceMemoryBase{}, &profile_result);

VLOG(4) << " Autotune algorithm " << i
@@ -702,8 +703,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {

OP_REQUIRES_OK(
context,
plan_and_algorithms->ExecuteOnStream(stream, *a_ptrs[0], *b_ptrs[0],
*c_ptrs[0], algorithm_idx, scratch_allocator));
BlasLtMatmulPlanCache::ExecuteOnStream(stream,
*plan_and_algorithms,
*a_ptrs[0], *b_ptrs[0], *c_ptrs[0],
algorithm_idx, scratch_allocator, se::DeviceMemoryBase{}));
} else { // requires mixed broadcasting
const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();
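The hunks above keep the existing autotune flow: when use_autotune is set, each algorithm in the cached entry is executed once with a profile result and the fastest index is kept. A minimal CPU-only sketch of that loop shape follows (purely illustrative: the real code profiles each candidate on the GPU stream and uses BlasScratchAllocator for workspace).

```cpp
// Illustrative autotune loop: run every candidate once, time it, keep the
// fastest index. Timing uses std::chrono on dummy workloads; the real code
// uses se::blas::ProfileResult on the device stream.
#include <chrono>
#include <cstdio>
#include <functional>
#include <limits>
#include <thread>
#include <vector>

size_t AutotuneIndex(const std::vector<std::function<void()>>& algorithms) {
  size_t best = 0;
  double best_ms = std::numeric_limits<double>::infinity();
  for (size_t i = 0; i < algorithms.size(); ++i) {
    auto t0 = std::chrono::steady_clock::now();
    algorithms[i]();  // run candidate i once
    auto t1 = std::chrono::steady_clock::now();
    double ms = std::chrono::duration<double, std::milli>(t1 - t0).count();
    if (ms < best_ms) { best_ms = ms; best = i; }
  }
  return best;
}

int main() {
  std::vector<std::function<void()>> algos = {
      [] { std::this_thread::sleep_for(std::chrono::milliseconds(3)); },
      [] { std::this_thread::sleep_for(std::chrono::milliseconds(1)); },
  };
  std::printf("best algorithm index: %zu\n", AutotuneIndex(algos));
  return 0;
}
```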
70 changes: 30 additions & 40 deletions tensorflow/core/kernels/matmul_util.cc
@@ -16,6 +16,7 @@ limitations under the License.

#include <optional>
#include <string>
#include <deque>
#include <utility>

#include "xla/status_macros.h"
@@ -24,6 +25,8 @@ limitations under the License.
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "xla/stream_executor/stream.h"
#include "xla/stream_executor/stream_executor.h"

namespace tensorflow {

@@ -44,33 +47,13 @@ int64_t GetWorkspaceLimit(int64_t default_value_in_bytes) {
return default_value_in_bytes;
}

std::string BlasLtMatmulPlanParams::ToString() const {
return ""; // TODO
}

bool BlasLtMatmulPlanParams::operator==(
const BlasLtMatmulPlanParams& other) const {
return internal::AsTuple(*this) == internal::AsTuple(other);
}

namespace {

// Thread-safe map from matmul parameters to their corresponding plan and
// algorithms.
struct BlasLtMatmulPlanMap {
absl::Mutex mu;

template <class... Args>
auto emplace(Args&&... args) {
absl::MutexLock lock(&mu);
return map_.emplace(std::forward<Args>(args)...);
}

private:
absl::node_hash_map<BlasLtMatmulPlanParams, PlanAndAlgorithms> map_
ABSL_GUARDED_BY(mu);
};

int MatmulMaxAutotuneAlgorithmCount() {
int64_t value;
Status status =
@@ -110,19 +93,31 @@ StatusOr<se::blas::ComputationType> GetBlasComputationType(

} // namespace

/* static */ StatusOr<const PlanAndAlgorithms*> PlanAndAlgorithms::GetOrCreate(
/* static */ BlasLtMatmulPlanCache& BlasLtMatmulPlanCache::i(se::Stream *stream) {
static absl::Mutex m(absl::kConstInit);
// Each GPU gets different cache instance
static std::deque< BlasLtMatmulPlanCache > meta(8);
absl::MutexLock lock(&m);
size_t dev_id = stream->parent()->device_ordinal();
if (dev_id >= meta.size()) meta.resize(dev_id + 1);
return meta[dev_id];
}

/* static */ auto BlasLtMatmulPlanCache::GetOrCreate(
se::Stream* stream, const BlasLtMatmulPlanParams& params,
absl::Mutex** ppmu, std::optional<int> max_algorithm_count) {
absl::Mutex** ppmu, std::optional<int> max_algorithm_count) -> StatusOr<const Entry *>{
static const int64_t max_scratch_size =
GetWorkspaceLimit(1LL << 32); // 4GB by default
static const int64_t max_autotune_algorithm_count =
MatmulMaxAutotuneAlgorithmCount();

if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count;

static BlasLtMatmulPlanMap plan_map;
auto& self = BlasLtMatmulPlanCache::i(stream);

auto [ptr, inserted] = plan_map.emplace(params, PlanAndAlgorithms{});
absl::MutexLock lock(self.mutex_.get());
auto [ptr, inserted] = self.map_.emplace(params, Entry{});
auto& entry = ptr->second;
if (inserted) {
TF_ASSIGN_OR_RETURN(auto xlatype,
se::gpu::AsXlaPrimitiveType(params.dtype));
@@ -171,32 +166,28 @@ StatusOr<se::blas::ComputationType> GetBlasComputationType(
.compute_type = computation_type,
};

TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
TF_ASSIGN_OR_RETURN(entry.plan, se::gpu::BlasLt::GetMatmulPlan(
stream, cfg, params.epilogue));

TF_ASSIGN_OR_RETURN(
auto algorithms,
plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));

ptr->second = {std::move(plan), std::move(algorithms)};
entry.algorithms,
entry.plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
}
*ppmu = &plan_map.mu;
return &ptr->second;
*ppmu = self.mutex_.get();
return &entry;
}

Status PlanAndAlgorithms::ExecuteOnStream(se::Stream* stream,
/*static */ Status BlasLtMatmulPlanCache::ExecuteOnStream(se::Stream* stream,
const Entry& entry,
const se::DeviceMemoryBase& a,
const se::DeviceMemoryBase& b,
se::DeviceMemoryBase& c,
size_t algorithm_idx,
se::ScratchAllocator& scratch_allocator,
const se::DeviceMemoryBase& bias,
se::blas::ProfileResult* profile_result) const {
se::blas::ProfileResult* profile_result) {

if(!plan || algorithm_idx >= algorithms.size()) {
return errors::Internal("MatmulPlan or algorithms are not initialized!");
}
return plan->ExecuteOnStream(
return entry.plan->ExecuteOnStream(
stream, a, b, c, c,
bias, // bias_buffer
se::DeviceMemoryBase{}, // aux_buffer
Expand All @@ -205,9 +196,8 @@ Status PlanAndAlgorithms::ExecuteOnStream(se::Stream* stream,
se::DeviceMemoryBase{}, // c_scale_buffer
se::DeviceMemoryBase{}, // d_scale_buffer
se::DeviceMemoryBase{}, // d_amax_buffer
algorithms[algorithm_idx],
std::nullopt, // workspace
&scratch_allocator,
entry.algorithms[algorithm_idx],
scratch_allocator,
profile_result);
}
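Note: GetOrCreate above routes through BlasLtMatmulPlanCache::i(stream), which keeps one cache instance per device ordinal in a static std::deque guarded by a static mutex. A deque (rather than a vector) keeps references to existing elements valid as it grows, which matters because the per-device reference is returned after the lock is released. A minimal, self-contained sketch of that pattern follows (hypothetical names, std::mutex in place of absl::Mutex).

```cpp
// Per-device singleton lookup: one cache object per device ordinal, created
// lazily under a static mutex. Growing a std::deque does not invalidate
// references to existing elements, so the returned reference stays valid.
#include <cstddef>
#include <deque>
#include <mutex>

struct DeviceCache {
  // ... per-device state (plans, algorithms) would live here ...
};

DeviceCache& CacheForDevice(std::size_t device_ordinal) {
  static std::mutex m;
  static std::deque<DeviceCache> caches(8);  // pre-size for typical GPU counts
  std::lock_guard<std::mutex> lock(m);
  if (device_ordinal >= caches.size()) caches.resize(device_ordinal + 1);
  return caches[device_ordinal];  // reference remains valid as the deque grows
}
```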

57 changes: 36 additions & 21 deletions tensorflow/core/kernels/matmul_util.h
@@ -35,7 +35,8 @@ namespace tensorflow {
int64_t GetWorkspaceLimit(int64_t default_value_in_bytes);

struct BlasLtMatmulPlanParams {
std::string ToString() const;

std::string ToString() const { return "NOP"; }
bool operator==(const BlasLtMatmulPlanParams& other) const;

se::blas::DataType dtype;
@@ -50,26 +51,6 @@ struct BlasLtMatmulPlanParams {
se::gpu::BlasLt::Epilogue epilogue = se::gpu::BlasLt::Epilogue::kDefault;
};

struct PlanAndAlgorithms {

static StatusOr<const PlanAndAlgorithms*> GetOrCreate(
se::Stream* stream, const BlasLtMatmulPlanParams& params, absl::Mutex** pmu,
std::optional<int> max_algorithm_count = std::nullopt
);

Status ExecuteOnStream(se::Stream* stream,
const se::DeviceMemoryBase& a,
const se::DeviceMemoryBase& b,
se::DeviceMemoryBase& c,
size_t algorithm_idx,
se::ScratchAllocator& scratch_allocator,
const se::DeviceMemoryBase& bias = se::DeviceMemoryBase{},
se::blas::ProfileResult* profile_result = nullptr) const;

se::gpu::BlasLt::MatmulPlanPtr plan;
std::vector<se::gpu::BlasLt::MatmulAlgorithm> algorithms;
};

namespace internal {

inline auto AsTuple(const BlasLtMatmulPlanParams& p) {
@@ -85,6 +66,40 @@ H AbslHashValue(H h, const BlasLtMatmulPlanParams& params) {
return H::combine(std::move(h), internal::AsTuple(params));
}

struct BlasLtMatmulPlanCache {
struct Entry {
se::gpu::BlasLt::MatmulPlanPtr plan;
std::vector< se::gpu::BlasLt::MatmulAlgorithm > algorithms;
};

static StatusOr<const Entry *> GetOrCreate(
se::Stream* stream, const BlasLtMatmulPlanParams& params, absl::Mutex** pmu,
std::optional<int> max_algorithm_count = std::nullopt
);

// helper function for plan execution
static Status ExecuteOnStream(se::Stream* stream,
const Entry& entry,
const se::DeviceMemoryBase& a,
const se::DeviceMemoryBase& b,
se::DeviceMemoryBase& c,
size_t algorithm_idx,
se::ScratchAllocator& scratch_allocator,
const se::DeviceMemoryBase& bias,
se::blas::ProfileResult* profile_result = nullptr);

BlasLtMatmulPlanCache() : mutex_(new absl::Mutex) {
}

private:
static BlasLtMatmulPlanCache& i(se::Stream *stream);

std::unique_ptr<absl::Mutex> mutex_;
absl::node_hash_map<BlasLtMatmulPlanParams, Entry> map_
ABSL_GUARDED_BY(mutex_);

}; // BlasLtMatmulPlanCache

} // namespace tensorflow

#endif // GOOGLE_CUDA || TF_HIPBLASLT
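Note: the cache is keyed by BlasLtMatmulPlanParams through the AsTuple helper above, which drives both operator== and AbslHashValue so the struct can key an absl::node_hash_map. A small, self-contained example of that keying pattern follows; the field set is illustrative, not the real parameter list.

```cpp
// A params struct used as a hash-map key: a tuple view of its fields backs
// both equality and Abseil hashing, so it can key an absl::node_hash_map
// whose values hold the cached entries.
#include <cstdint>
#include <iostream>
#include <tuple>

#include "absl/container/node_hash_map.h"
#include "absl/hash/hash.h"

struct PlanKey {
  std::uint64_t m, n, k;
  bool trans_a, trans_b;

  bool operator==(const PlanKey& other) const {
    return AsTuple(*this) == AsTuple(other);
  }

  friend auto AsTuple(const PlanKey& p) {
    return std::make_tuple(p.m, p.n, p.k, p.trans_a, p.trans_b);
  }

  template <typename H>
  friend H AbslHashValue(H h, const PlanKey& p) {
    return H::combine(std::move(h), AsTuple(p));
  }
};

struct CachedEntry {
  int algorithm_count = 0;  // stand-in for plan + algorithm list
};

int main() {
  absl::node_hash_map<PlanKey, CachedEntry> cache;
  auto [it, inserted] =
      cache.emplace(PlanKey{64, 64, 128, false, true}, CachedEntry{});
  if (inserted) it->second.algorithm_count = 4;  // "build the plan" on first use
  std::cout << "entries: " << cache.size() << "\n";
  return 0;
}
```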
15 changes: 15 additions & 0 deletions third_party/xla/xla/service/gpu/buffer_comparator.cu.cc
@@ -103,11 +103,16 @@ __global__ void xla_fp8_e5m2_comparison(__nv_fp8_storage_t* buffer_a,
#endif // GOOGLE_CUDA

#if TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200

__global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,
__hip_fp8_storage_t* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
// NOTE: according to amd_hip_fp8.h, GFX1200 and GFX1201 support ocp __hip_fp8_e4m3
// but not __hip_fp8_e4m3_fnuz

int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= buffer_length) return;
__hip_fp8_e4m3_fnuz elem_a_fp8, elem_b_fp8;
@@ -123,13 +128,18 @@ __global__ void xla_fp8_e4m3fnuz_comparison(__hip_fp8_storage_t* buffer_a,

if (rel_error > rel_error_threshold || isnan(rel_error))
atomicAdd(mismatch_count, 1);
#else
// on unsupported architectures, this should not / cannot be used!
atomicAdd(mismatch_count, 1);
#endif
}

__global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,
__hip_fp8_storage_t* buffer_b,
float rel_error_threshold,
uint64_t buffer_length,
int* mismatch_count) {
#if defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__)
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx >= buffer_length) return;
__hip_fp8_e5m2_fnuz elem_a_fp8, elem_b_fp8;
@@ -145,7 +155,12 @@ __global__ void xla_fp8_e5m2fnuz_comparison(__hip_fp8_storage_t* buffer_a,

if (rel_error > rel_error_threshold || isnan(rel_error))
atomicAdd(mismatch_count, 1);
#else
// on unsupported architectures, this should not / cannot be used!
atomicAdd(mismatch_count, 1);
#endif
}

#endif // TENSORFLOW_USE_ROCM && TF_ROCM_VERSION >= 60200

__global__ void xla_fp16_comparison(__half* buffer_a, __half* buffer_b,
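Note: the new #if guards mean that on architectures without FNUZ FP8 support (anything other than gfx940/941/942 here) these kernels count every element as a mismatch, so the comparison fails loudly instead of silently converting through unsupported types. For context, a self-contained HIP sketch of the overall comparison pattern (convert to float, relative-error check, atomic mismatch counter) follows; the relative-error formula is an assumption, not copied from the elided kernel body.

```cpp
// Toy buffer comparison in HIP: each thread compares one element pair and
// atomically bumps a mismatch counter. The rel_error definition below is an
// assumed "|a-b| / (max(|a|,|b|) + 1)"-style check, used here for illustration.
#include <hip/hip_runtime.h>
#include <cstdint>
#include <cstdio>
#include <vector>

__global__ void compare_f32(const float* a, const float* b,
                            float rel_error_threshold, uint64_t n,
                            int* mismatch_count) {
  uint64_t idx = threadIdx.x + blockIdx.x * (uint64_t)blockDim.x;
  if (idx >= n) return;
  float ea = a[idx], eb = b[idx];
  float rel_error = fabsf(ea - eb) / (fmaxf(fabsf(ea), fabsf(eb)) + 1.0f);
  if (rel_error > rel_error_threshold || isnan(rel_error))
    atomicAdd(mismatch_count, 1);
}

int main() {
  const uint64_t n = 1 << 20;
  std::vector<float> ha(n, 1.0f), hb(n, 1.0f);
  hb[123] = 2.0f;  // inject one mismatch

  float *da, *db;
  int *dcount, hcount = 0;
  hipMalloc(&da, n * sizeof(float));
  hipMalloc(&db, n * sizeof(float));
  hipMalloc(&dcount, sizeof(int));
  hipMemcpy(da, ha.data(), n * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(db, hb.data(), n * sizeof(float), hipMemcpyHostToDevice);
  hipMemcpy(dcount, &hcount, sizeof(int), hipMemcpyHostToDevice);

  dim3 block(256), grid((n + 255) / 256);
  compare_f32<<<grid, block>>>(da, db, 0.1f, n, dcount);
  hipMemcpy(&hcount, dcount, sizeof(int), hipMemcpyDeviceToHost);
  std::printf("mismatches: %d\n", hcount);

  hipFree(da); hipFree(db); hipFree(dcount);
  return 0;
}
```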
2 changes: 1 addition & 1 deletion third_party/xla/xla/stream_executor/rocm/hip_blas_lt.cc
@@ -618,4 +618,4 @@ absl::Status BlasLt::MatmulPlan::ExecuteOnStream(

} // namespace stream_executor

#endif // TF_HIPBLASLT
#endif // TF_HIPBLASLT
1 change: 1 addition & 0 deletions third_party/xla/xla/stream_executor/rocm/hipblaslt_wrapper.h
@@ -18,6 +18,7 @@ limitations under the License.
#define XLA_STREAM_EXECUTOR_ROCM_HIPBLASLT_WRAPPER_H_

#define __HIP_DISABLE_CPP_FUNCTIONS__
#define LEGACY_HIPBLAS_DIRECT

#include "rocm/rocm_config.h"

