Skip to content

Commit

Permalink
[ROCm] Some fixes in tunable (#18575)
Browse files Browse the repository at this point in the history
### Description

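* Cap the workspace size for hipBLASLt algorithms at 32 MB; algorithms that need more are reported as unsupported
* Update the Triton kernel launches to match API changes (`params->stream` is replaced by `params->StreamHandle()`)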
  • Loading branch information
mindest authored Nov 24, 2023
1 parent 62f00ad commit b9c935f
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ auto GetTritonGroupNormNHWCTypeStringAndOps() {
params->epsilon};

// Grid dim is (batch_count, groups, 1)
return LaunchTritonKernel(params->stream, i, params->n, params->groups, 1, &args, sizeof(args));
return LaunchTritonKernel(params->StreamHandle(), i, params->n, params->groups, 1, &args, sizeof(args));
};
ret.emplace_back(std::make_pair(metadata->name, std::move(impl)));
}
Expand Down
2 changes: 1 addition & 1 deletion onnxruntime/core/providers/rocm/math/softmax_triton.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ auto GetSoftmaxTritonOps() {
} args = {(void*)params->output, (const void*)params->input, params->input_stride, params->output_stride, params->softmax_elements};

// grid dim is (batch_count, 1, 1)
return LaunchTritonKernel(params->stream, i, params->batch_count, 1, 1, &args, sizeof(args));
return LaunchTritonKernel(params->StreamHandle(), i, params->batch_count, 1, 1, &args, sizeof(args));
};
ret.emplace_back(std::make_pair(metadata->name, std::move(impl)));
}
Expand Down
7 changes: 7 additions & 0 deletions onnxruntime/core/providers/rocm/tunable/gemm_hipblaslt.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ using onnxruntime::contrib::rocm::blas::GemmFastGeluParams;

#ifdef USE_HIPBLASLT

// When K is large and M/N are small, the K dimension is split across multiple workgroups
// and buffers, which needs extra workspace. The maximum workspace size is capped at 32MB
// (32 * 2^20 bytes).
constexpr size_t kHipBlasLtMaxWorkSpaceSizeInBytes = size_t{32} << 20;

enum ActivationType {
NONE = 0,
RELU = 1,
Expand Down Expand Up @@ -225,6 +229,9 @@ auto GetHipBlasLtTypeStringAndOps(ActivationType activation_type = ActivationTyp

IAllocatorUniquePtr<void> workspace_buffer;
if (workspace_size > 0) {
TUNABLE_OP_RETURN_UNSUPPORTED_ARGUMENT_IF(workspace_size > kHipBlasLtMaxWorkSpaceSizeInBytes,
"Workspace size exceeds limit (32M): ", workspace_size);
workspace_size = kHipBlasLtMaxWorkSpaceSizeInBytes;
workspace_buffer = params->tuning_ctx->GetScratchBuffer(workspace_size, params->stream);
}

Expand Down

0 comments on commit b9c935f

Please sign in to comment.