Skip to content

Commit

Permalink
Merge pull request #2772 from ROCm/r2.17-rocm-enhanced-hipblaslt-fixes
Browse files Browse the repository at this point in the history
Numerous hipblaslt related fixes
  • Loading branch information
pemeliya authored Dec 4, 2024
2 parents e27a3c5 + 394b32e commit e49be7f
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 160 deletions.
43 changes: 24 additions & 19 deletions tensorflow/core/kernels/matmul_op_fused.cc
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,10 @@ struct LaunchFusedMatMulOp<CPUDevice, T> {
namespace {

#if GOOGLE_CUDA || TF_HIPBLASLT
/*
hipBLASLt support Epilogue:
https://rocm.docs.amd.com/projects/hipBLASLt/en/latest/datatypes.html#hipblasltepilogue-t
*/
StatusOr<se::gpu::BlasLt::Epilogue> GetBlasLtEpilogOp(
FusedComputationType fusion) {
if (fusion == FusedComputationType::kBiasAdd) {
Expand Down Expand Up @@ -263,7 +267,7 @@ se::blas::AlgorithmConfig AutotuneMatmul(
}
return algorithm_config;
}
#endif
#endif // GOOGLE_CUDA || TF_HIPBLASLT

template <typename LaunchFunc, typename Sig>
StatusOr<std::vector<xla::AutotuneResult>> AutotuneMatMulImpl(
Expand Down Expand Up @@ -477,6 +481,17 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {

se::dnn::ActivationMode matmul_activation_mode;
bool use_cudnn = false;

#if !(GOOGLE_CUDA || TF_HIPBLASLT)
use_cudnn = true;
#endif
const auto& cc = stream->parent()->GetDeviceDescription().
gpu_compute_capability();
if (auto *procm = std::get_if< se::RocmComputeCapability >(&cc)) {
use_cudnn = !procm->gfx9_mi200_or_later();
}

// use_cudnn is for hipblaslt doesn't support yet
switch (fusion) {
case FusedComputationType::kBiasAddWithGeluExact:
matmul_activation_mode = se::dnn::ActivationMode::kGeluExact;
Expand Down Expand Up @@ -511,15 +526,6 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {
default:
use_cudnn = false;
}
#if !(GOOGLE_CUDA || TF_HIPBLASLT)
use_cudnn = true;
#endif

#if TF_HIPBLASLT
auto cap = stream->GetRocmComputeCapability();
// as of ROCm 5.5, hipblaslt only supports MI200.
if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") use_cudnn = true;
#endif

BlasScratchAllocator scratch_allocator(context);

Expand Down Expand Up @@ -590,32 +596,31 @@ struct LaunchFusedMatMulOp<GPUDevice, T> {
epilog_op};
absl::Mutex* pmu;
auto plan_and_algorithms_or =
GetPlanAndAlgorithms(stream, matmul_params, &pmu);
BlasLtMatmulPlanCache::GetOrCreate(stream, matmul_params, &pmu);
OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
absl::MutexLock lock(pmu);
const auto* plan_and_algorithms = std::move(plan_and_algorithms_or).value();
const auto& algorithms = plan_and_algorithms->algorithms;
OP_REQUIRES(context, algorithms.size() > 0,
const auto& entry = *plan_and_algorithms_or.value();
OP_REQUIRES(context, entry.algorithms.size() > 0,
errors::InvalidArgument("No matmul algorithm returned!"));

auto launch_func = [&](BlasScratchAllocator& scratch_allocator,
size_t alg_idx,
se::blas::ProfileResult* profile_result) {
return DoBlasLtMatmul(stream, *plan_and_algorithms, a_ptr, b_ptr, c_ptr,
alg_idx, scratch_allocator, bias_ptr,
profile_result);
return BlasLtMatmulPlanCache::ExecuteOnStream(
stream, entry, a_ptr, b_ptr, c_ptr, alg_idx,
scratch_allocator, bias_ptr, profile_result);
};

size_t alg_idx = 0;
if (use_autotune) {
auto algorithm_config =
AutotuneMatmul(algorithms, matmul_params, context, launch_func);
AutotuneMatmul(entry.algorithms, matmul_params, context, launch_func);

alg_idx = algorithm_config.algorithm();
}

OP_REQUIRES_OK(context, launch_func(scratch_allocator, alg_idx, nullptr));
#endif
#endif // GOOGLE_CUDA || TF_HIPBLASLT
}
};

Expand Down
28 changes: 16 additions & 12 deletions tensorflow/core/kernels/matmul_op_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -601,12 +601,13 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
#if GOOGLE_CUDA || TF_HIPBLASLT
static const bool use_autotune = MatmulAutotuneEnable();
bool bCublasLtSupport = true;
#if TF_HIPBLASLT
if (!std::is_same_v<Scalar, float>) bCublasLtSupport = false;
auto cap = stream->GetRocmComputeCapability();
// as of ROCm 5.5, hipblaslt only supports MI200.
if (cap.gcn_arch_name().substr(0, 6) != "gfx90a") bCublasLtSupport = false;
#endif

const auto& cc = stream->parent()->GetDeviceDescription().
gpu_compute_capability();
if(auto *procm = std::get_if< se::RocmComputeCapability >(&cc)) {
bCublasLtSupport = procm->gfx9_mi200_or_later();
}

if (EnableCublasLtGemm() && bCublasLtSupport) {
static const int64_t max_scratch_size =
GetWorkspaceLimit(1LL << 32); // 4GB by default
Expand Down Expand Up @@ -636,7 +637,7 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
std::optional<int> max_algorithm_count;
if (!use_autotune) max_algorithm_count = 1;
absl::Mutex* pmu = nullptr;
auto plan_and_algorithms_or = GetPlanAndAlgorithms(
auto plan_and_algorithms_or = BlasLtMatmulPlanCache::GetOrCreate(
stream, matmul_params, &pmu, max_algorithm_count);
OP_REQUIRES_OK(context, plan_and_algorithms_or.status());
absl::MutexLock lock(pmu);
Expand All @@ -659,9 +660,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {
// scratch space is deallocated between runs.
BlasScratchAllocator scratch_allocator(context, max_scratch_size);
Status cublas_launch_status =
DoBlasLtMatmul(stream, *plan_and_algorithms, *a_ptrs[0],
*b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
/*bias = */ {}, &profile_result);
BlasLtMatmulPlanCache::ExecuteOnStream(stream,
*plan_and_algorithms,
*a_ptrs[0], *b_ptrs[0], *c_ptrs[0], i, scratch_allocator,
se::DeviceMemoryBase{}, &profile_result);

VLOG(4) << " Autotune algorithm " << i
<< " result: " << profile_result.elapsed_time_in_ms()
Expand Down Expand Up @@ -701,8 +703,10 @@ struct LaunchBatchMatMul<GPUDevice, Scalar> {

OP_REQUIRES_OK(
context,
DoBlasLtMatmul(stream, *plan_and_algorithms, *a_ptrs[0], *b_ptrs[0],
*c_ptrs[0], algorithm_idx, scratch_allocator));
BlasLtMatmulPlanCache::ExecuteOnStream(stream,
*plan_and_algorithms,
*a_ptrs[0], *b_ptrs[0], *c_ptrs[0],
algorithm_idx, scratch_allocator, se::DeviceMemoryBase{}));
} else { // requires mixed broadcasting
const std::vector<int64_t>& a_batch_indices = bcast.x_batch_indices();
const std::vector<int64_t>& b_batch_indices = bcast.y_batch_indices();
Expand Down
81 changes: 48 additions & 33 deletions tensorflow/core/kernels/matmul_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ limitations under the License.

#include <optional>
#include <string>
#include <deque>
#include <utility>

#include "xla/status_macros.h"
Expand All @@ -24,6 +25,7 @@ limitations under the License.
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/util/env_var.h"
#include "tensorflow/core/util/matmul_autotune.h"
#include "xla/stream_executor/stream.h"

namespace tensorflow {

Expand All @@ -44,33 +46,13 @@ int64_t GetWorkspaceLimit(int64_t default_value_in_bytes) {
return default_value_in_bytes;
}

std::string BlasLtMatmulPlanParams::ToString() const {
return ""; // TODO
}

bool BlasLtMatmulPlanParams::operator==(
const BlasLtMatmulPlanParams& other) const {
return internal::AsTuple(*this) == internal::AsTuple(other);
}

namespace {

// Thread-safe map from matmul parameters to their corresponding plan and
// algorithms.
struct BlasLtMatmulPlanMap {
absl::Mutex mu;

template <class... Args>
auto emplace(Args&&... args) {
absl::MutexLock lock(&mu);
return map_.emplace(std::forward<Args>(args)...);
}

private:
absl::flat_hash_map<BlasLtMatmulPlanParams, PlanAndAlgorithms> map_
ABSL_GUARDED_BY(mu);
};

int MatmulMaxAutotuneAlgorithmCount() {
int64_t value;
Status status =
Expand Down Expand Up @@ -110,27 +92,37 @@ StatusOr<se::blas::ComputationType> GetBlasComputationType(

} // namespace

StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
/* static */ BlasLtMatmulPlanCache& BlasLtMatmulPlanCache::i(se::Stream *stream) {
static absl::Mutex m(absl::kConstInit);
// Each GPU gets different cache instance
static std::deque< BlasLtMatmulPlanCache > meta(8);
absl::MutexLock lock(&m);
size_t dev_id = stream->parent()->device_ordinal();
if (dev_id >= meta.size()) meta.resize(dev_id + 1);
return meta[dev_id];
}

/* static */ auto BlasLtMatmulPlanCache::GetOrCreate(
se::Stream* stream, const BlasLtMatmulPlanParams& params,
absl::Mutex** ppmu, std::optional<int> max_algorithm_count) {
absl::Mutex** ppmu, std::optional<int> max_algorithm_count) -> StatusOr<const Entry *>{
static const int64_t max_scratch_size =
GetWorkspaceLimit(1LL << 32); // 4GB by default
static const int64_t max_autotune_algorithm_count =
MatmulMaxAutotuneAlgorithmCount();

if (!max_algorithm_count) max_algorithm_count = max_autotune_algorithm_count;

static BlasLtMatmulPlanMap plan_map;
auto& self = BlasLtMatmulPlanCache::i(stream);

auto [ptr, inserted] = plan_map.emplace(params, PlanAndAlgorithms{});
absl::MutexLock lock(self.mutex_.get());
auto [ptr, inserted] = self.map_.emplace(params, Entry{});
auto& entry = ptr->second;
if (inserted) {
TF_ASSIGN_OR_RETURN(auto xlatype,
se::gpu::AsXlaPrimitiveType(params.dtype));
TF_ASSIGN_OR_RETURN(auto computation_type,
GetBlasComputationType(params.dtype));

auto scale_type = se::gpu::GetScaleType(params.dtype, computation_type);

// row-major output is now handled automatically by blas-lt API
constexpr auto kRowMajor = se::gpu::MatrixLayout::Order::kRowMajor;

Expand Down Expand Up @@ -173,19 +165,42 @@ StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
.compute_type = computation_type,
};

TF_ASSIGN_OR_RETURN(auto plan, se::gpu::BlasLt::GetMatmulPlan(
TF_ASSIGN_OR_RETURN(entry.plan, se::gpu::BlasLt::GetMatmulPlan(
stream, cfg, params.epilogue));

TF_ASSIGN_OR_RETURN(
auto algorithms,
plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));

ptr->second = {std::move(plan), std::move(algorithms), scale_type};
entry.algorithms,
entry.plan->GetAlgorithms(*max_algorithm_count, max_scratch_size));
}
*ppmu = &plan_map.mu;
return &ptr->second;
*ppmu = self.mutex_.get();
return &entry;
}

/*static */ Status BlasLtMatmulPlanCache::ExecuteOnStream(se::Stream* stream,
const Entry& entry,
const se::DeviceMemoryBase& a,
const se::DeviceMemoryBase& b,
se::DeviceMemoryBase& c,
size_t algorithm_idx,
se::ScratchAllocator& scratch_allocator,
const se::DeviceMemoryBase& bias,
se::blas::ProfileResult* profile_result) {

return entry.plan->ExecuteOnStream(
stream, a, b, c, c,
bias, // bias_buffer
se::DeviceMemoryBase{}, // aux_buffer
se::DeviceMemoryBase{}, // a_scale_buffer
se::DeviceMemoryBase{}, // b_scale_buffer
se::DeviceMemoryBase{}, // c_scale_buffer
se::DeviceMemoryBase{}, // d_scale_buffer
se::DeviceMemoryBase{}, // d_amax_buffer
entry.algorithms[algorithm_idx],
scratch_allocator,
profile_result);
}


} // namespace tensorflow

#endif
70 changes: 35 additions & 35 deletions tensorflow/core/kernels/matmul_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ limitations under the License.

#if GOOGLE_CUDA || TF_HIPBLASLT

#include "absl/container/flat_hash_map.h"
#include "absl/container/node_hash_map.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/gpu/gpu_blas_lt.h"
#include "tensorflow/core/framework/types.h"
Expand All @@ -35,7 +35,8 @@ namespace tensorflow {
int64_t GetWorkspaceLimit(int64_t default_value_in_bytes);

struct BlasLtMatmulPlanParams {
std::string ToString() const;

std::string ToString() const { return "NOP"; }
bool operator==(const BlasLtMatmulPlanParams& other) const;

se::blas::DataType dtype;
Expand All @@ -50,12 +51,6 @@ struct BlasLtMatmulPlanParams {
se::gpu::BlasLt::Epilogue epilogue = se::gpu::BlasLt::Epilogue::kDefault;
};

struct PlanAndAlgorithms {
se::gpu::BlasLt::MatmulPlanPtr plan;
std::vector<se::gpu::BlasLt::MatmulAlgorithm> algorithms;
se::blas::DataType scale_type; // this is needed for half / bf16 treatment
};

namespace internal {

inline auto AsTuple(const BlasLtMatmulPlanParams& p) {
Expand All @@ -71,37 +66,42 @@ H AbslHashValue(H h, const BlasLtMatmulPlanParams& params) {
return H::combine(std::move(h), internal::AsTuple(params));
}

StatusOr<const PlanAndAlgorithms*> GetPlanAndAlgorithms(
struct BlasLtMatmulPlanCache {
struct Entry {
se::gpu::BlasLt::MatmulPlanPtr plan;
std::vector< se::gpu::BlasLt::MatmulAlgorithm > algorithms;
};

static StatusOr<const Entry *> GetOrCreate(
se::Stream* stream, const BlasLtMatmulPlanParams& params, absl::Mutex** pmu,
std::optional<int> max_algorithm_count = std::nullopt);

template <typename T>
Status DoBlasLtMatmul(se::Stream* stream, const PlanAndAlgorithms& paa,
const se::DeviceMemory<T>& a,
const se::DeviceMemory<T>& b, se::DeviceMemory<T>& c,
size_t alg_idx, se::ScratchAllocator& scratch_allocator,
const se::DeviceMemory<T>& bias = {},
se::blas::ProfileResult* profile_result = nullptr) {
se::DeviceMemory<T> aux{}; // We don't use the auxilary buffers.
const auto& algorithm = paa.algorithms[alg_idx];

// The scale type may be f32 if the data type is f16 and bf16.
if constexpr (std::is_same_v<T, Eigen::half> ||
std::is_same_v<T, Eigen::bfloat16>) {
if (paa.scale_type == se::blas::DataType::kFloat) {
return paa.plan->DoMatmul(stream, se::HostOrDeviceScalar<float>(1.0), b,
a, se::HostOrDeviceScalar<float>(0.0), c, c,
algorithm, scratch_allocator, bias, aux,
profile_result);
}
std::optional<int> max_algorithm_count = std::nullopt
);

// helper function for plan execution
static Status ExecuteOnStream(se::Stream* stream,
const Entry& entry,
const se::DeviceMemoryBase& a,
const se::DeviceMemoryBase& b,
se::DeviceMemoryBase& c,
size_t algorithm_idx,
se::ScratchAllocator& scratch_allocator,
const se::DeviceMemoryBase& bias,
se::blas::ProfileResult* profile_result = nullptr);

BlasLtMatmulPlanCache() : mutex_(new absl::Mutex) {
}
return paa.plan->DoMatmul(stream, se::HostOrDeviceScalar<T>(T(1.0)), b, a,
se::HostOrDeviceScalar<T>(T(0.0)), c, c, algorithm,
scratch_allocator, bias, aux, profile_result);
}

private:
static BlasLtMatmulPlanCache& i(se::Stream *stream);

std::unique_ptr<absl::Mutex> mutex_;
absl::node_hash_map<BlasLtMatmulPlanParams, Entry> map_
ABSL_GUARDED_BY(mutex_);

}; // BlasLtMatmulPlanCache

} // namespace tensorflow

#endif
#endif // GOOGLE_CUDA || TF_HIPBLASLT

#endif // TENSORFLOW_CORE_KERNELS_MATMUL_UTIL_H_
Loading

0 comments on commit e49be7f

Please sign in to comment.