Skip to content

Commit

Permalink
[ROCm] enable hipGraph (#18382)
Browse files Browse the repository at this point in the history
This ports the cudaGraph support from the CUDA EP to the ROCM EP's
hipGraph.
  • Loading branch information
jeffdaily authored Jan 23, 2024
1 parent 37d14d7 commit b2aec41
Show file tree
Hide file tree
Showing 10 changed files with 241 additions and 42 deletions.
7 changes: 7 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
if (onnxruntime_USE_CUDA)
list(APPEND onnxruntime_shared_lib_test_LIBS cudart)
endif()
if (onnxruntime_USE_ROCM)
list(APPEND onnxruntime_shared_lib_test_LIBS hip::host)
endif()
if (onnxruntime_USE_TENSORRT)
list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER})
endif()
Expand All @@ -1294,6 +1297,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu)
endif()
if (onnxruntime_USE_ROCM)
target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include)
target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__)
endif()
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
target_sources(onnxruntime_shared_lib_test PRIVATE
"${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc"
Expand Down
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@ typedef struct OrtROCMProviderOptions {
has_user_compute_stream{},
user_compute_stream{},
default_memory_arena_cfg{},
enable_hip_graph{false},
tunable_op_enable{false},
tunable_op_tuning_enable{false},
tunable_op_max_tuning_duration_ms{} {}
Expand Down Expand Up @@ -548,6 +549,8 @@ typedef struct OrtROCMProviderOptions {
*/
OrtArenaCfg* default_memory_arena_cfg;

int enable_hip_graph;

/** \brief Enable TunableOp for using.
* Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default.
* This option can be overriden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE.
Expand Down
77 changes: 72 additions & 5 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,42 @@ ROCMExecutionProvider::PerThreadContext::PerThreadContext(OrtDevice::DeviceId de

MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_));
MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream));

hip_graph_.SetStream(stream);
}

ROCMExecutionProvider::PerThreadContext::~PerThreadContext() {
ORT_IGNORE_RETURN_VALUE(ROCBLAS_CALL(rocblas_destroy_handle(rocblas_handle_)));
ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_)));
}

bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_;
}

void ROCMExecutionProvider::PerThreadContext::CaptureBegin() {
hip_graph_.Reset();
hip_graph_.CaptureBegin();
}

void ROCMExecutionProvider::PerThreadContext::CaptureEnd() {
hip_graph_.CaptureEnd();
is_graph_captured_ = true;
}

bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const {
return is_graph_captured_;
}

Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() {
ORT_ENFORCE(IsGraphCaptured());
return hip_graph_.Replay();
}

void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
++regular_run_count_before_graph_capture_;
}

void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) {
if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable<bool>(
"ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead.");
Expand Down Expand Up @@ -219,6 +248,11 @@ ROCMExecutionProvider::ROCMExecutionProvider(const ROCMExecutionProviderInfo& in
if (info.external_allocator_info.UseExternalAllocator()) {
use_ep_level_unified_stream_ = true;
stream_ = nullptr;
} else if (info.enable_hip_graph) {
// current hip graph implementation only works with single stream
// use EP level unified stream for all the reqeust
HIP_CALL_THROW(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking));
use_ep_level_unified_stream_ = true;
} else {
stream_ = nullptr;
}
Expand Down Expand Up @@ -322,25 +356,58 @@ Status ROCMExecutionProvider::Sync() const {
Status ROCMExecutionProvider::OnRunStart() {
// always set ROCM device when session::Run() in case it runs in a worker thread
HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId()));
if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model";
GetPerThreadContext().CaptureBegin();
}
return Status::OK();
}

Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) {
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
if (GetPerThreadContext().IsGraphCaptureAllowed()) {
GetPerThreadContext().CaptureEnd();
// HIP work issued to a capturing stream doesn’t actually run on the GPU,
// so run the captured graph here to actually execute the work.
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph());
} else {
GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
}
}

if (sync_stream) {
HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream_)));
}

// In extreme cases (e.g., 1-op graph and that op fallbacks to CPU),
// PerThreadContext won't be created and there is nothing to
// release. This didn't happen before because we always call
// GetPerThreadContext in OnRunStart.
if (PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
// The reason of !IsGraphCaptureEnabled():
// If hip graph is enabled, the per thread context will not be released
// because the per thread hip graph needs to be maintained and replayed for
// the next run.
// The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end():
// In extreme cases (e.g., 1-op graph and that op fallbacks to CPU),
// PerThreadContext won't be created and there is nothing to
// release. This didn't happen before because we always call
// GetPerThreadContext in OnRunStart.
if (!IsGraphCaptureEnabled() &&
PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
ReleasePerThreadContext();
}

return Status::OK();
}

bool ROCMExecutionProvider::IsGraphCaptureEnabled() const {
return info_.enable_hip_graph;
}

bool ROCMExecutionProvider::IsGraphCaptured() const {
return GetPerThreadContext().IsGraphCaptured();
}

Status ROCMExecutionProvider::ReplayGraph() {
return GetPerThreadContext().ReplayGraph();
}

namespace rocm {
// opset 1 to 9
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/framework/execution_provider.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/rocm/rocm_execution_provider_info.h"
#include "core/providers/rocm/rocm_graph.h"
#include "core/providers/rocm/rocm_pch.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"
Expand Down Expand Up @@ -73,6 +74,9 @@ class ROCMExecutionProvider : public IExecutionProvider {

std::unique_ptr<profiling::EpProfiler> GetProfiler() override;

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Status ReplayGraph() override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
Expand All @@ -81,6 +85,7 @@ class ROCMExecutionProvider : public IExecutionProvider {
ROCMExecutionProviderInfo info_;
hipDeviceProp_t device_prop_;
bool external_stream_ = false;
// only used when set user external stream or hip graph
hipStream_t stream_ = nullptr;

bool use_ep_level_unified_stream_ = false;
Expand Down Expand Up @@ -133,6 +138,13 @@ class ROCMExecutionProvider : public IExecutionProvider {
}
}

bool IsGraphCaptureAllowed() const;
void CaptureBegin();
void CaptureEnd();
bool IsGraphCaptured() const;
Status ReplayGraph();
void IncrementRegularRunCountBeforeGraphCapture();

private:
rocblas_handle rocblas_handle_ = nullptr;
miopenHandle_t miopen_handle_ = nullptr;
Expand All @@ -141,6 +153,18 @@ class ROCMExecutionProvider : public IExecutionProvider {
std::unique_ptr<rocm::IConstantBuffer<double>> constant_ones_double_;
std::unique_ptr<rocm::IConstantBuffer<half>> constant_ones_half_;
std::unique_ptr<rocm::IConstantBuffer<BFloat16>> constant_ones_bfloat16_;

// Hip graph with multi threads will be supported in the future, so hip_graph_
// is put under PerThreadContext.
ROCMGraph hip_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;

// There is chance that the second regular run allocates GPU memory for causes like:
// (1) memory pattern is enabled. (2) arena allocation for stream.
// Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
// to allocate enough memory in Arena before graph capturing.
const int min_num_runs_before_hip_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations.
};

using PerThreadContextMap = std::unordered_map<const ROCMExecutionProvider*, std::weak_ptr<PerThreadContext>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ constexpr const char* kGpuExternalAlloc = "gpu_external_alloc";
constexpr const char* kGpuExternalFree = "gpu_external_free";
constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache";
constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace";
constexpr const char* kEnableHipGraph = "enable_hip_graph";
constexpr const char* kTunableOpEnable = "tunable_op_enable";
constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable";
constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms";
Expand Down Expand Up @@ -84,6 +85,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
info.miopen_conv_exhaustive_search)
.AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream)
.AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace)
.AddAssignmentToReference(rocm::provider_option_names::kEnableHipGraph, info.enable_hip_graph)
.AddValueParser(
rocm::provider_option_names::kTunableOpEnable,
[&info](const std::string& value_str) -> Status {
Expand Down Expand Up @@ -121,6 +123,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
{rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
{rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)},
{rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)},
{rocm::provider_option_names::kEnableHipGraph, MakeStringWithClassicLocale(info.enable_hip_graph)},
{rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)},
{rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)},
{rocm::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ struct ROCMExecutionProviderInfo {
// If set to false, use fix workspace size (32M) for Conv algo search, the final algo might not be the best.
bool miopen_conv_use_max_workspace{true};

bool enable_hip_graph{false};

rocm::TunableOpInfo tunable_op{};

static ROCMExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/rocm/rocm_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ struct ROCM_Provider : Provider {
info.has_user_compute_stream = params->has_user_compute_stream != 0;
info.user_compute_stream = params->user_compute_stream;
info.default_memory_arena_cfg = params->default_memory_arena_cfg;
info.enable_hip_graph = params->enable_hip_graph;
info.tunable_op.enable = params->tunable_op_enable;
info.tunable_op.tuning_enable = params->tunable_op_tuning_enable;
info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms;
Expand Down Expand Up @@ -215,6 +216,7 @@ struct ROCM_Provider : Provider {
rocm_options.user_compute_stream = internal_options.user_compute_stream;
}
rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg;
rocm_options.enable_hip_graph = internal_options.enable_hip_graph;
rocm_options.tunable_op_enable = internal_options.tunable_op.enable;
rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable;
rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms;
Expand Down
Loading

0 comments on commit b2aec41

Please sign in to comment.