Skip to content

Commit

Permalink
[ROCm] Add hipBLASLt workspace support (microsoft#17096)
Browse files Browse the repository at this point in the history
### Description
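* Add extra workspace support for hipBLASLt GEMM algorithms (needed for split-k): when an algorithm reports a non-zero workspace size, a scratch buffer is allocated from the tuning context and passed to `hipblasLtMatmul`, instead of rejecting the algorithm as unsupported.
* Update datatypes from `hipblasDatatype_t` / `HIPBLAS_R_*` to `hipblasltDatatype_t` / `HIPBLASLT_R_*` (due to extra support for fp8 in hipBLASLt).
* Minor changes: `MatMul<T>::ComputeInternal` now returns the `MatMulImpl` status directly rather than replacing it with a generic "MatMulImpl failed" error, and the rocBLAS tunable GEMM failure messages now include `rocblas_status_to_string(status)`.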
  • Loading branch information
mindest authored Aug 25, 2023
1 parent 7c98f45 commit 93ae17d
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 28 deletions.
13 changes: 5 additions & 8 deletions onnxruntime/core/providers/rocm/math/matmul.cc
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,11 @@ Status MatMul<T>::ComputeInternal(OpKernelContext* ctx) const {
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();

if (MatMulImpl<T>(this, helper, reinterpret_cast<const T*>(left_X->Data<T>()),
reinterpret_cast<const T*>(right_X->Data<T>()),
reinterpret_cast<T*>(Y->MutableData<T>()),
left_X->Shape(), right_X->Shape(),
transa, transb, trans_batch_a_, trans_batch_b_, alpha_, ctx->GetComputeStream()) != Status::OK()) {
return Status(common::ONNXRUNTIME, common::FAIL, "MatMulImpl failed");
}
return Status::OK();
return MatMulImpl<T>(this, helper, reinterpret_cast<const T*>(left_X->Data<T>()),
reinterpret_cast<const T*>(right_X->Data<T>()),
reinterpret_cast<T*>(Y->MutableData<T>()),
left_X->Shape(), right_X->Shape(),
transa, transb, trans_batch_a_, trans_batch_b_, alpha_, ctx->GetComputeStream());
}

} // namespace rocm
Expand Down
36 changes: 19 additions & 17 deletions onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,26 +33,26 @@ enum ActivationType {
};

template <typename T>
constexpr hipblasDatatype_t HipBlasDataTypeFor();
constexpr hipblasltDatatype_t HipBlasDataTypeFor();

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<float>() {
return HIPBLAS_R_32F;
constexpr hipblasltDatatype_t HipBlasDataTypeFor<float>() {
return HIPBLASLT_R_32F;
}

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<half>() {
return HIPBLAS_R_16F;
constexpr hipblasltDatatype_t HipBlasDataTypeFor<half>() {
return HIPBLASLT_R_16F;
}

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<BFloat16>() {
return HIPBLAS_R_16B;
constexpr hipblasltDatatype_t HipBlasDataTypeFor<BFloat16>() {
return HIPBLASLT_R_16B;
}

template <>
constexpr hipblasDatatype_t HipBlasDataTypeFor<double>() {
return HIPBLAS_R_64F;
constexpr hipblasltDatatype_t HipBlasDataTypeFor<double>() {
return HIPBLASLT_R_64F;
}

template <typename Layout>
Expand Down Expand Up @@ -104,7 +104,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp

hipblasOperation_t trans_a = MapCKLayoutToHipBlasLt<BLayout>();
hipblasOperation_t trans_b = MapCKLayoutToHipBlasLt<ALayout>();
hipblasDatatype_t in_out_datatype = HipBlasDataTypeFor<T>();
hipblasltDatatype_t in_out_datatype = HipBlasDataTypeFor<T>();
std::vector<hipblasLtMatmulHeuristicResult_t> heuristic_result;

HIPBLASLT_CALL_THROW(hipblaslt_ext::getAllAlgos(handle,
Expand Down Expand Up @@ -149,7 +149,7 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_a, in_out_datatype, row_a, col_a, lda));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_b, in_out_datatype, row_b, col_b, ldb));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutCreate(&mat_c, in_out_datatype, row_c, col_c, ldc));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLAS_R_32F));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescCreate(&matmul, HIPBLASLT_COMPUTE_F32, HIPBLASLT_R_32F));

int batch = GetBatchCountFromParams<T>(params);
if (batch > 1) {
Expand Down Expand Up @@ -213,9 +213,11 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp

TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
status != HIPBLAS_STATUS_SUCCESS, "hipBLASLt find_all: algo not supported, index ", std::to_string(i));
// TODO: support workspace in next PR
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
workspace_size > 0, "hipBLASLt find_all: extra workspace not supported for now.");

IAllocatorUniquePtr<void> workspace_buffer;
if (workspace_size > 0) {
workspace_buffer = params->tuning_ctx->GetScratchBuffer(workspace_size, params->stream);
}

HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmul(op_handle,
matmul,
Expand All @@ -230,9 +232,9 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp
params->c,
mat_c,
&algo_i,
nullptr,
0,
params->stream));
workspace_buffer.get(),
workspace_size,
params->StreamHandle()));

HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatmulDescDestroy(matmul));
HIPBLASLT_RETURN_IF_ERROR(hipblasLtMatrixLayoutDestroy(mat_a));
Expand Down
6 changes: 3 additions & 3 deletions onnxruntime/core/providers/rocm/tunable/gemm_rocblas.h
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ auto GetRocBlasGemmTypeStringAndOps() {
status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");

TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
status != rocblas_status_success, "Solution ", solution, " failed.");
status != rocblas_status_success, "Solution ", solution, " failed: ", rocblas_status_to_string(status));

return Status::OK();
};
Expand Down Expand Up @@ -232,7 +232,7 @@ auto GetRocBlasBatchedGemmTypeStringAndOps() {
status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");

TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
status != rocblas_status_success, "Solution ", solution, " failed.");
status != rocblas_status_success, "Solution ", solution, " failed: ", rocblas_status_to_string(status));

return Status::OK();
};
Expand Down Expand Up @@ -299,7 +299,7 @@ auto GetRocBlasStridedBatchedGemmTypeStringAndOps() {
status == rocblas_status_invalid_size, "Solution ", solution, " not supported: INVALID VALUE.");

TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(
status != rocblas_status_success, "Solution ", solution, " failed.");
status != rocblas_status_success, "Solution ", solution, " failed: ", rocblas_status_to_string(status));

return Status::OK();
};
Expand Down

0 comments on commit 93ae17d

Please sign in to comment.