Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ROCm] enable hipGraph #18382

Merged
merged 3 commits into from
Jan 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions cmake/onnxruntime_unittests.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -1277,6 +1277,9 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
if (onnxruntime_USE_CUDA)
list(APPEND onnxruntime_shared_lib_test_LIBS cudart)
endif()
if (onnxruntime_USE_ROCM)
list(APPEND onnxruntime_shared_lib_test_LIBS hip::host)
endif()
if (onnxruntime_USE_TENSORRT)
list(APPEND onnxruntime_shared_lib_test_LIBS ${TENSORRT_LIBRARY_INFER})
endif()
Expand All @@ -1294,6 +1297,10 @@ if (NOT onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
target_include_directories(onnxruntime_shared_lib_test PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES})
target_sources(onnxruntime_shared_lib_test PRIVATE ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/cuda_ops.cu)
endif()
if (onnxruntime_USE_ROCM)
target_include_directories(onnxruntime_shared_lib_test PRIVATE ${onnxruntime_ROCM_HOME}/include)
target_compile_definitions(onnxruntime_shared_lib_test PRIVATE __HIP_PLATFORM_AMD__)
endif()
if (CMAKE_SYSTEM_NAME STREQUAL "Android")
target_sources(onnxruntime_shared_lib_test PRIVATE
"${ONNXRUNTIME_ROOT}/core/platform/android/cxa_demangle.cc"
Expand Down
3 changes: 3 additions & 0 deletions include/onnxruntime/core/session/onnxruntime_c_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,7 @@
has_user_compute_stream{},
user_compute_stream{},
default_memory_arena_cfg{},
enable_hip_graph{false},
tunable_op_enable{false},
tunable_op_tuning_enable{false},
tunable_op_max_tuning_duration_ms{} {}
Expand Down Expand Up @@ -548,9 +549,11 @@
*/
OrtArenaCfg* default_memory_arena_cfg;

int enable_hip_graph;

/** \brief Enable TunableOp for using.
* Set it to 1/0 to enable/disable TunableOp. Otherwise, it is disabled by default.
* This option can be overridden by environment variable ORT_ROCM_TUNABLE_OP_ENABLE.

Check warning on line 556 in include/onnxruntime/core/session/onnxruntime_c_api.h

View workflow job for this annotation

GitHub Actions / Optional Lint

[misspell] reported by reviewdog 🐶 "overriden" is a misspelling of "overridden" Raw Output: ./include/onnxruntime/core/session/onnxruntime_c_api.h:556:26: "overriden" is a misspelling of "overridden"
*/
int tunable_op_enable;

Expand Down
77 changes: 72 additions & 5 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -170,13 +170,42 @@

MIOPEN_CALL_THROW(miopenCreate(&miopen_handle_));
MIOPEN_CALL_THROW(miopenSetStream(miopen_handle_, stream));

hip_graph_.SetStream(stream);
}

ROCMExecutionProvider::PerThreadContext::~PerThreadContext() {
ORT_IGNORE_RETURN_VALUE(ROCBLAS_CALL(rocblas_destroy_handle(rocblas_handle_)));
ORT_IGNORE_RETURN_VALUE(MIOPEN_CALL(miopenDestroy(miopen_handle_)));
}

// Reports whether the number of regular runs counted so far has reached the
// minimum required before HIP graph capture may begin.
bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
  const bool enough_regular_runs =
      regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_;
  return enough_regular_runs;
}

// Starts a new HIP graph capture on this context's hip_graph_.
// The graph object is reset first, then capture is begun on it.
void ROCMExecutionProvider::PerThreadContext::CaptureBegin() {
hip_graph_.Reset();
hip_graph_.CaptureBegin();
}

// Ends the capture on hip_graph_ and records that this context now holds a
// captured graph, so IsGraphCaptured() returns true from here on.
void ROCMExecutionProvider::PerThreadContext::CaptureEnd() {
hip_graph_.CaptureEnd();
is_graph_captured_ = true;
}

// Returns true once CaptureEnd() has been called on this context.
bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const {
return is_graph_captured_;
}

// Replays the HIP graph held by this context.
// Enforces (via ORT_ENFORCE) that a graph has already been captured, then
// returns the Status produced by hip_graph_.Replay().
Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() {
  ORT_ENFORCE(IsGraphCaptured());
  return hip_graph_.Replay();
}

// Bumps the counter of regular runs; IsGraphCaptureAllowed() compares this
// counter against min_num_runs_before_hip_graph_capture_.
void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
++regular_run_count_before_graph_capture_;
}

void OverrideTunableOpInfoByEnv(ROCMExecutionProviderInfo& info) {
if (auto env_tunable_op_enable = onnxruntime::ParseTestOnlyEnvironmentVariable<bool>(
"ORT_ROCM_TUNABLE_OP_ENABLE", {"0", "1"}, "Use provider_options \"tunable_op_enable\" instead.");
Expand Down Expand Up @@ -219,6 +248,11 @@
if (info.external_allocator_info.UseExternalAllocator()) {
use_ep_level_unified_stream_ = true;
stream_ = nullptr;
} else if (info.enable_hip_graph) {
// current hip graph implementation only works with single stream
// use EP level unified stream for all the requests
HIP_CALL_THROW(hipStreamCreateWithFlags(&stream_, hipStreamNonBlocking));
use_ep_level_unified_stream_ = true;
} else {
stream_ = nullptr;
}
Expand Down Expand Up @@ -322,25 +356,58 @@
// Run-start hook.
// Sets the current HIP device to this provider's device id and, when HIP graph
// capture is enabled, allowed, and no graph has been captured yet for the
// per-thread context, begins capturing.
// Returns the HIP error as a Status if hipSetDevice fails, otherwise OK.
Status ROCMExecutionProvider::OnRunStart() {
  // always set ROCM device when session::Run() in case it runs in a worker thread
  HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId()));
  // GetPerThreadContext() is only evaluated when graph capture is enabled
  // (short-circuit), exactly as before; the condition is merely wrapped to
  // stay within the 120-column limit.
  if (IsGraphCaptureEnabled() &&
      GetPerThreadContext().IsGraphCaptureAllowed() &&
      !GetPerThreadContext().IsGraphCaptured()) {
    LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model";
    GetPerThreadContext().CaptureBegin();
  }
  return Status::OK();
}

// Run-end hook.
// When HIP graph capture is enabled and no graph has been captured yet, either
// finishes the capture and replays the graph (if capture is allowed), or counts
// this run as one more regular run. Optionally synchronizes stream_, and
// releases the per-thread context when graph capture is disabled.
//
// sync_stream: when true, stream_ is synchronized before returning.
// Returns a failing Status if graph replay or stream synchronization fails,
// otherwise OK.
Status ROCMExecutionProvider::OnRunEnd(bool sync_stream) {
  if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
    if (GetPerThreadContext().IsGraphCaptureAllowed()) {
      GetPerThreadContext().CaptureEnd();
      // HIP work issued to a capturing stream doesn't actually run on the GPU,
      // so run the captured graph here to actually execute the work.
      ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph());
    } else {
      GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
    }
  }

  if (sync_stream) {
    HIP_RETURN_IF_ERROR(hipStreamSynchronize(static_cast<hipStream_t>(stream_)));
  }

  // The reason of !IsGraphCaptureEnabled():
  //  If hip graph is enabled, the per thread context will not be released
  //  because the per thread hip graph needs to be maintained and replayed for
  //  the next run.
  // The reason of PerThreadContextCache()->find(this) != PerThreadContextCache()->end():
  //  In extreme cases (e.g., 1-op graph and that op falls back to CPU),
  //  PerThreadContext won't be created and there is nothing to
  //  release. This didn't happen before because we always call
  //  GetPerThreadContext in OnRunStart.
  if (!IsGraphCaptureEnabled() &&
      PerThreadContextCache()->find(this) != PerThreadContextCache()->end()) {
    ReleasePerThreadContext();
  }

  return Status::OK();
}

// Returns the enable_hip_graph setting stored in this provider's info_.
bool ROCMExecutionProvider::IsGraphCaptureEnabled() const {
return info_.enable_hip_graph;
}

// Forwards to IsGraphCaptured() on the context returned by GetPerThreadContext().
bool ROCMExecutionProvider::IsGraphCaptured() const {
return GetPerThreadContext().IsGraphCaptured();
}

// Forwards to ReplayGraph() on the context returned by GetPerThreadContext()
// and returns its Status.
Status ROCMExecutionProvider::ReplayGraph() {
return GetPerThreadContext().ReplayGraph();
}

namespace rocm {
// opset 1 to 9
class ONNX_OPERATOR_KERNEL_CLASS_NAME(kRocmExecutionProvider, kOnnxDomain, 1, MemcpyFromHost);
Expand Down
24 changes: 24 additions & 0 deletions onnxruntime/core/providers/rocm/rocm_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include "core/framework/execution_provider.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/rocm/rocm_execution_provider_info.h"
#include "core/providers/rocm/rocm_graph.h"
#include "core/providers/rocm/rocm_pch.h"
#include "core/providers/rocm/shared_inc/rocm_utils.h"
#include "core/providers/rocm/shared_inc/rocm_call.h"
Expand Down Expand Up @@ -73,6 +74,9 @@

std::unique_ptr<profiling::EpProfiler> GetProfiler() override;

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Status ReplayGraph() override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
Expand All @@ -81,6 +85,7 @@
ROCMExecutionProviderInfo info_;
hipDeviceProp_t device_prop_;
bool external_stream_ = false;
// only used when set user external stream or hip graph
hipStream_t stream_ = nullptr;

bool use_ep_level_unified_stream_ = false;
Expand Down Expand Up @@ -133,6 +138,13 @@
}
}

bool IsGraphCaptureAllowed() const;
void CaptureBegin();
void CaptureEnd();
bool IsGraphCaptured() const;
Status ReplayGraph();
void IncrementRegularRunCountBeforeGraphCapture();

private:
rocblas_handle rocblas_handle_ = nullptr;
miopenHandle_t miopen_handle_ = nullptr;
Expand All @@ -141,6 +153,18 @@
std::unique_ptr<rocm::IConstantBuffer<double>> constant_ones_double_;
std::unique_ptr<rocm::IConstantBuffer<half>> constant_ones_half_;
std::unique_ptr<rocm::IConstantBuffer<BFloat16>> constant_ones_bfloat16_;

// Hip graph with multi threads will be supported in the future, so hip_graph_
// is put under PerThreadContext.
ROCMGraph hip_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;

// There is a chance that the second regular run allocates GPU memory for reasons like:
// (1) memory pattern is enabled. (2) arena allocation for stream.
// Since no GPU memory allocation is allowed during graph capturing, we need at least two regular runs
// to allocate enough memory in Arena before graph capturing.
const int min_num_runs_before_hip_graph_capture_ = 2; // required min regular runs before graph capture for the necessary memory allocations.

Check warning on line 167 in onnxruntime/core/providers/rocm/rocm_execution_provider.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/rocm/rocm_execution_provider.h:167: Lines should be <= 120 characters long [whitespace/line_length] [2]
};

using PerThreadContextMap = std::unordered_map<const ROCMExecutionProvider*, std::weak_ptr<PerThreadContext>>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ constexpr const char* kGpuExternalAlloc = "gpu_external_alloc";
constexpr const char* kGpuExternalFree = "gpu_external_free";
constexpr const char* kGpuExternalEmptyCache = "gpu_external_empty_cache";
constexpr const char* kMiopenConvUseMaxWorkspace = "miopen_conv_use_max_workspace";
constexpr const char* kEnableHipGraph = "enable_hip_graph";
constexpr const char* kTunableOpEnable = "tunable_op_enable";
constexpr const char* kTunableOpTuningEnable = "tunable_op_tuning_enable";
constexpr const char* kTunableOpMaxTuningDurationMs = "tunable_op_max_tuning_duration_ms";
Expand Down Expand Up @@ -84,6 +85,7 @@ ROCMExecutionProviderInfo ROCMExecutionProviderInfo::FromProviderOptions(const P
info.miopen_conv_exhaustive_search)
.AddAssignmentToReference(rocm::provider_option_names::kDoCopyInDefaultStream, info.do_copy_in_default_stream)
.AddAssignmentToReference(rocm::provider_option_names::kMiopenConvUseMaxWorkspace, info.miopen_conv_use_max_workspace)
.AddAssignmentToReference(rocm::provider_option_names::kEnableHipGraph, info.enable_hip_graph)
.AddValueParser(
rocm::provider_option_names::kTunableOpEnable,
[&info](const std::string& value_str) -> Status {
Expand Down Expand Up @@ -121,6 +123,7 @@ ProviderOptions ROCMExecutionProviderInfo::ToProviderOptions(const ROCMExecution
{rocm::provider_option_names::kMiopenConvExhaustiveSearch, MakeStringWithClassicLocale(info.miopen_conv_exhaustive_search)},
{rocm::provider_option_names::kDoCopyInDefaultStream, MakeStringWithClassicLocale(info.do_copy_in_default_stream)},
{rocm::provider_option_names::kMiopenConvUseMaxWorkspace, MakeStringWithClassicLocale(info.miopen_conv_use_max_workspace)},
{rocm::provider_option_names::kEnableHipGraph, MakeStringWithClassicLocale(info.enable_hip_graph)},
{rocm::provider_option_names::kTunableOpEnable, MakeStringWithClassicLocale(info.tunable_op.enable)},
{rocm::provider_option_names::kTunableOpTuningEnable, MakeStringWithClassicLocale(info.tunable_op.tuning_enable)},
{rocm::provider_option_names::kTunableOpMaxTuningDurationMs, MakeStringWithClassicLocale(info.tunable_op.max_tuning_duration_ms)},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ struct ROCMExecutionProviderInfo {
// If set to false, use fix workspace size (32M) for Conv algo search, the final algo might not be the best.
bool miopen_conv_use_max_workspace{true};

bool enable_hip_graph{false};

rocm::TunableOpInfo tunable_op{};

static ROCMExecutionProviderInfo FromProviderOptions(const ProviderOptions& options);
Expand Down
2 changes: 2 additions & 0 deletions onnxruntime/core/providers/rocm/rocm_provider_factory.cc
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ struct ROCM_Provider : Provider {
info.has_user_compute_stream = params->has_user_compute_stream != 0;
info.user_compute_stream = params->user_compute_stream;
info.default_memory_arena_cfg = params->default_memory_arena_cfg;
info.enable_hip_graph = params->enable_hip_graph;
info.tunable_op.enable = params->tunable_op_enable;
info.tunable_op.tuning_enable = params->tunable_op_tuning_enable;
info.tunable_op.max_tuning_duration_ms = params->tunable_op_max_tuning_duration_ms;
Expand Down Expand Up @@ -215,6 +216,7 @@ struct ROCM_Provider : Provider {
rocm_options.user_compute_stream = internal_options.user_compute_stream;
}
rocm_options.default_memory_arena_cfg = internal_options.default_memory_arena_cfg;
rocm_options.enable_hip_graph = internal_options.enable_hip_graph;
rocm_options.tunable_op_enable = internal_options.tunable_op.enable;
rocm_options.tunable_op_tuning_enable = internal_options.tunable_op.tuning_enable;
rocm_options.tunable_op_max_tuning_duration_ms = internal_options.tunable_op.max_tuning_duration_ms;
Expand Down
Loading
Loading