Skip to content

Commit

Permalink
[ROCm] enable hipGraph
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffdaily committed Nov 9, 2023
1 parent 59262df commit e3dff16
Show file tree
Hide file tree
Showing 8 changed files with 140 additions and 24 deletions.
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -495,6 +495,7 @@ typedef struct OrtROCMProviderOptions {
has_user_compute_stream{},
user_compute_stream{},
default_memory_arena_cfg{},
enable_hip_graph{false},
tunable_op_enable{false},
tunable_op_tuning_enable{false},
tunable_op_max_tuning_duration_ms{} {}
Expand Down Expand Up @@ -547,6 +548,8 @@ typedef struct OrtROCMProviderOptions {
*/
OrtArenaCfg* default_memory_arena_cfg;

int enable_hip_graph;

/** \brief Enable TunableOp for using.
* Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default.
* This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE.

Check warning on line 555 in include/onnxruntime/core/session/onnxruntime_c_api.h

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "overriden" is a misspelling of "overridden" Raw Output: ./include/onnxruntime/core/session/onnxruntime_c_api.h:555:26: "overriden" is a misspelling of "overridden"
Expand Down
77 changes: 72 additions & 5 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,42 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de

MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_));
MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream));

hip_graph_.SetStream(stream);
}

ROCMExecutionProvider::PerThreadContext::~PerThreadContext() {
ORT_IGNORE_RETURN_VALUE(ROCBLAS_CALL(rocblas_destroy_handle(rocblas_handle_)));
ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_)));
}

bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_;
}

void ROCMExecutionProvider::PerThreadContext::CaptureBegin() {
hip_graph_.Reset();
hip_graph_.CaptureBegin();
}

void ROCMExecutionProvider::PerThreadContext::CaptureEnd() {
hip_graph_.CaptureEnd();
is_graph_captured_ = true;
}

bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const {
return is_graph_captured_;
}

Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() {
ORT_ENFORCE(IsGraphCaptured());
return hip_graph_.Replay();
}

void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
++regular_run_count_before_graph_capture_;
}

void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) {
if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable<bool>(
"ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead.");
Expand Down Expand Up @@ -219,6 +248,11 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in
if (info.external_allocator_info.UseExternalAllocator()) {
use_ep_level_unified_stream_ = true;
stream_ = nullptr;
} else if (info.enable_hip_graph) {
// current hip graph implementation only works with single stream
// use EP level unified stream for all the reqeust
HIP_CALL_THROW(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking));
use_ep_level_unified_stream_ = true;
} else {
stream_ = nullptr;
}
Expand Down Expand Up @@ -322,25 +356,58 @@ Status ROCMExecutionProvider::Sync() const {
Status ROCMExecutionProvider::OnRunStart() {
// always set ROCM device when session::Run() in case it runs in a worker thread
HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId()));
if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {

Check warning on line 359 in onnxruntime/core/providers/rocm/rocm_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/rocm/rocm_execution_provider.cc:359: Lines should be <= 120 characters long [whitespace/line_length] [2]
LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model";
GetPerThreadContext().CaptureBegin();
}
return Status::OK();
}

Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) {
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
if (GetPerThreadContext().IsGraphCaptureAllowed()) {
GetPerThreadContext().CaptureEnd();
// HIP work issued to a capturing stream doesn’t actually run on the GPU,
// so run the captured graph here to actually execute the work.
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph());
} else {
GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
}
}

if (sync_stream) {
HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream_)));
}

// In extreme cases (e.g., 1-op graph and that op fallbacks to CPU),
// PerThreadContext won't be created and there is nothing to
// release. This didn't happen before because we always call
// GetPerThreadContext in OnRunStart.
if (PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
// The reason of !IsGraphCaptureEnabled():
// If hip graph is enabled, the per thread context will not be released
// because the per thread hip graph needs to be maintained and replayed for
// the next run.
// The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end():
// In extreme cases (e.g., 1-op graph and that op fallbacks to CPU),
// PerThreadContext won't be created and there is nothing to
// release. This didn't happen before because we always call
// GetPerThreadContext in OnRunStart.
if (!IsGraphCaptureEnabled() &&
PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
ReleasePerThreadContext();
}

return Status::OK();
}

bool ROCMExecutionProvider::IsGraphCaptureEnabled() const {
return info_.enable_hip_graph;
}

bool ROCMExecutionProvider::IsGraphCaptured() const {
return GetPerThreadContext().IsGraphCaptured();
}

Status ROCMExecutionProvider::ReplayGraph() {
return GetPerThreadContext().ReplayGraph();
}

namespace rocm {
// opset 1 to 9
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/framework/execution_provider.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/rocm/rocm_execution_provider_info.h"
#include "core/providers/rocm/rocm_graph.h"
#include "core/providers/rocm/rocm_pch.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"
Expand Down Expand Up @@ -73,6 +74,9 @@ class ROCMExecutionProvider : public IExecutionProvider {

std::unique_ptr<profiling::EpProfiler> GetProfiler() override;

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Status ReplayGraph() override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
Expand All @@ -81,6 +85,7 @@ class ROCMExecutionProvider : public IExecutionProvider {
ROCMExecutionProviderInfo info_;
hipDeviceProp_t device_prop_;
bool external_stream_ = false;
// only used when set user external stream or hip graph
hipStream_t stream_ = nullptr;

bool use_ep_level_unified_stream_ = false;
Expand Down Expand Up @@ -133,6 +138,13 @@ class ROCMExecutionProvider : public IExecutionProvider {
}
}

bool IsGraphCaptureAllowed() const;
void CaptureBegin();
void CaptureEnd();
bool IsGraphCaptured() const;
Status ReplayGraph();
void IncrementRegularRunCountBeforeGraphCapture();

private:
rocblas_handle rocblas_handle_ = nullptr;
miopenHandle_t miopen_handle_ = nullptr;
Expand All @@ -141,6 +153,18 @@ class ROCMExecutionProvider : public IExecutionProvider {
std::unique_ptr<rocm::IConstantBuffer<double>> constant_ones_double_;
std::unique_ptr<rocm::IConstantBuffer<half>> constant_ones_half_;
std::unique_ptr<rocm::IConstantBuffer<BFloat16>> constant_ones_bfloat16_;

// Hip graph with multi threads will be supported in the future, so hip_graph_
// is put under PerThreadContext.
ROCMGraph hip_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;

// There is chance that the second regular run allocates GPU memory for causes like:
// (1) memory pattern is enabled. (2) arena allocation for stream.
// Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
// to allocate enough memory in Arena before graph capturing.
const int min_num_runs_before_hip_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations.

Check warning on line 167 in onnxruntime/core/providers/rocm/rocm_execution_provider.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/rocm/rocm_execution_provider.h:167: Lines should be <= 120 characters long [whitespace/line_length] [2]
};

using PerThreadContextMap = std::unordered_map<const ROCMExecutionProvider*, std::weak_ptr<PerThreadContext>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ constexpr const char* kGpuExternalAlloc = "gpu_external_alloc";
constexpr const char* kGpuExternalFree = "gpu_external_free";
constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache";
constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace";
constexpr const char* kEnableHipGraph = "enable_hip_graph";
constexpr const char* kTunableOpEnable = "tunable_op_enable";
constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable";
constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms";
Expand Down Expand Up @@ -84,6 +85,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
info.miopen_conv_exhaustive_search)
.AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream)
.AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace)
.AddAssignmentToReference(rocm::provider_option_names::kEnableHipGraph, info.enable_hip_graph)
.AddValueParser(
rocm::provider_option_names::kTunableOpEnable,
[&info](const std::string& value_str) -> Status {
Expand Down Expand Up @@ -121,6 +123,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
{rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
{rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)},
{rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)},
{rocm::provider_option_names::kEnableHipGraph, MakeStringWithClassicLocale(info.enable_hip_graph)},
{rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)},
{rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)},
{rocm::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ struct ROCMExecutionProviderInfo {
// If set to false, use fix workspace size (32M) for Conv algo search, the final algo might not be the best.
bool miopen_conv_use_max_workspace{true};

bool enable_hip_graph{false};

rocm::TunableOpInfo tunable_op{};

static ROCMExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/rocm/rocm_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ struct ROCM_Provider : Provider {
info.has_user_compute_stream = params->has_user_compute_stream != 0;
info.user_compute_stream = params->user_compute_stream;
info.default_memory_arena_cfg = params->default_memory_arena_cfg;
info.enable_hip_graph = params->enable_hip_graph;
info.tunable_op.enable = params->tunable_op_enable;
info.tunable_op.tuning_enable = params->tunable_op_tuning_enable;
info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms;
Expand Down Expand Up @@ -215,6 +216,7 @@ struct ROCM_Provider : Provider {
rocm_options.user_compute_stream = internal_options.user_compute_stream;
}
rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg;
rocm_options.enable_hip_graph = internal_options.enable_hip_graph;
rocm_options.tunable_op_enable = internal_options.tunable_op.enable;
rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable;
rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms;
Expand Down
52 changes: 33 additions & 19 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ static bool AreAllComputeNodesAssignedToCudaEp(const Graph& graph) {

// Empty node provider means CPU EP
if (!node_provider.empty() &&
node_provider != kCudaExecutionProvider &&
!(node_provider == kCudaExecutionProvider || node_provider == kRocmExecutionProvider) &&
node_provider != kCpuExecutionProvider) {
nodes_on_cpu_and_cuda_eps_only = false;
break;
Expand Down Expand Up @@ -1648,7 +1648,8 @@ common::Status InferenceSession::Initialize() {
// now that all the transforms are done, call Resolve on the main graph. this will recurse into the subgraphs.
ORT_RETURN_IF_ERROR_SESSIONID_(graph.Resolve());

// Currently CUDA graph is only considered by CUDA EP and TRT EP.
// Currently CUDA graph is only considered by CUDA EP and TRT EP, and
// HIP graph is only considered by ROCM EP.
//
// Check for CUDA EP:
// If the CUDA EP is part of the providers list for this session AND
Expand All @@ -1661,47 +1662,58 @@ common::Status InferenceSession::Initialize() {
// The TRT EP is configured to do a graph capture AND
// All the graph nodes have been assigned to the TRT EP,
// Then the TRT EP is cached for triggering a ReplayGraph() in Run().
std::vector<const char*> cuda_graph_support_ep_list = {onnxruntime::kTensorrtExecutionProvider, onnxruntime::kCudaExecutionProvider};
//
// Check for ROCM EP:
// If the ROCM EP is part of the providers list for this session AND
// The ROCM EP is configured to do a graph capture AND
// All the "compute" graph nodes have been assigned to the ROCM EP,
// Then the ROCM EP is cached for triggering a ReplayGraph() in Run().
//
std::vector<const char*> graph_support_ep_list = {
onnxruntime::kTensorrtExecutionProvider,
onnxruntime::kCudaExecutionProvider,
onnxruntime::kRocmExecutionProvider};

for (auto& it : cuda_graph_support_ep_list) {
for (auto& it : graph_support_ep_list) {
auto* target_ep = execution_providers_.Get(it);

if (target_ep && target_ep->IsGraphCaptureEnabled()) {
// CUDA Graphs can't work with control flow nodes
// CUDA/HIP Graphs can't work with control flow nodes
if (HasControlflowNodes(graph)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< "as the model has control flow nodes which can't be supported by CUDA Graphs.";
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user "

Check warning on line 1683 in onnxruntime/core/session/inference_session.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/session/inference_session.cc:1683: Lines should be <= 120 characters long [whitespace/line_length] [2]
<< "as the model has control flow nodes which can't be supported by CUDA/HIP Graphs.";

Check warning on line 1684 in onnxruntime/core/session/inference_session.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/session/inference_session.cc:1684: Lines should be <= 120 characters long [whitespace/line_length] [2]

ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
"as the model has control flow nodes which can't be supported by CUDA Graphs."));
"This session cannot use the CUDA/HIP Graph feature as requested by the user "
"as the model has control flow nodes which can't be supported by CUDA/HIP Graphs."));
}

if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0) {
if (strcmp(target_ep->Type().c_str(), onnxruntime::kCudaExecutionProvider) == 0 ||
strcmp(target_ep->Type().c_str(), onnxruntime::kRocmExecutionProvider) == 0) {
// Ensure that all nodes have been partitioned to CUDA or CPU EP && there are no memcpy nodes
// The reasoning behind this logic is that certain shape nodes will be forced onto CPU
// and as long as there are no memcpy nodes this is confirmation that no compute nodes have been placed on the CPU EP
// which is all we care about.
if (!AreAllComputeNodesAssignedToCudaEp(graph)) {
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA Graph feature as requested by the user "
<< " as all compute graph nodes have not been partitioned to the CUDA EP.";
LOGS(*session_logger_, ERROR) << "This session cannot use the CUDA/HIP Graph feature as requested by the user "

Check warning on line 1699 in onnxruntime/core/session/inference_session.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/session/inference_session.cc:1699: Lines should be <= 120 characters long [whitespace/line_length] [2]
<< " as all compute graph nodes have not been partitioned to the CUDA/HIP EP.";

Check warning on line 1700 in onnxruntime/core/session/inference_session.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/session/inference_session.cc:1700: Lines should be <= 120 characters long [whitespace/line_length] [2]

ORT_RETURN_IF_ERROR_SESSIONID_(
ORT_MAKE_STATUS(ONNXRUNTIME, FAIL,
"This session cannot use the CUDA Graph feature as requested by the user "
" as all compute graph nodes have not been partitioned to the CUDA EP."));
"This session cannot use the CUDA/HIP Graph feature as requested by the user "
" as all compute graph nodes have not been partitioned to the CUDA/HIP EP."));
}

// Log a warning for the user to know that there are shape subgraphs that will execute on CPU
if (HasShapeSubgraphNodes(graph)) {
LOGS(*session_logger_, WARNING) << "This model has shape massaging nodes that will execute on CPU. "
<< "Use the CUDA Graph feature with caution. "
<< "Use the CUDA/HIP Graph feature with caution. "
<< "As long as the intermediate shapes produced in the model "
<< "using the representative input used to capture the CUDA graph, "
<< "using the representative input used to capture the CUDA/HIP graph, "
<< "will match the shapes produced in the model for other inputs "
<< "of the same shape as the representative input (common case), "
<< "it is safe to use the CUDA Graph feature.";
<< "it is safe to use the CUDA/HIP Graph feature.";
}
} else {
// Following code path is for TRT EP currently.
Expand All @@ -1720,7 +1732,7 @@ common::Status InferenceSession::Initialize() {
}
}

LOGS(*session_logger_, INFO) << "This session will use the CUDA Graph feature as requested by the user.";
LOGS(*session_logger_, INFO) << "This session will use the CUDA/HIP Graph feature as requested by the user.";
cached_execution_provider_for_graph_replay_.SetExecutionProvider(target_ep);
break; // Make sure only one ep can run CUDA graph.
}
Expand Down Expand Up @@ -2408,7 +2420,9 @@ Status InferenceSession::Run(const RunOptions& run_options,
// As N+1 inference runs (N for memory allocation and 1 for graph capturing)
// are needed before replaying the captured graph, here run N inference runs recursively until graph captured,
// so that users just need one session run to capture the graph.
// N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP, and the value could be different for other EP.
// N is defined in min_num_runs_before_cuda_graph_capture_ for CUDA EP,
// N is defined in min_num_runs_before_hip_graph_capture_ for ROCM EP,
// and the value could be different for other EP.
if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() &&
!cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture.";
Expand Down
1 change: 1 addition & 0 deletions onnxruntime/core/session/provider_bridge_ort.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2252,6 +2252,7 @@ ORT_API_STATUS_IMPL(OrtApis::CreateROCMProviderOptions, _Outptr_ OrtROCMProvider
options->has_user_compute_stream = 0;
options->user_compute_stream = nullptr;
options->default_memory_arena_cfg = nullptr;
options->enable_hip_graph = false;
options->tunable_op_enable = 0;
options->tunable_op_tuning_enable = 0;
options->tunable_op_max_tuning_duration_ms = 0;
Expand Down

0 comments on commit e3dff16

Please sign in to comment.