Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cuda graph enhancement #19636

Merged
merged 31 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
bcb90a9
cuda graph enhancement
wangyems Feb 24, 2024
03a3a80
update
wangyems Feb 24, 2024
772471c
add python test
wangyems Feb 24, 2024
5dbda60
accept RunOptions as an argument to ReplayGraph
wangyems Feb 28, 2024
773eb0d
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Feb 28, 2024
f2b5c1e
update
wangyems Feb 29, 2024
dfdb78c
update tests
wangyems Mar 1, 2024
d008515
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 1, 2024
5ff5586
update
wangyems Mar 1, 2024
e51c9c4
lint
wangyems Mar 1, 2024
0281e57
review comments
wangyems Mar 1, 2024
e655e35
review comments
wangyems Mar 4, 2024
b180c87
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 4, 2024
3078d2a
update
wangyems Mar 4, 2024
29fd4ff
update
wangyems Mar 4, 2024
74d9e18
update
wangyems Mar 5, 2024
ae1024f
fix trt/rocm build
wangyems Mar 5, 2024
ce68222
fix build
wangyems Mar 5, 2024
5d67e12
review comments
wangyems Mar 5, 2024
7a16be6
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 5, 2024
8993bdf
review comments
wangyems Mar 6, 2024
8884b55
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 6, 2024
1e40998
review comments
wangyems Mar 6, 2024
5d411a3
sync io
wangyems Mar 6, 2024
9725c7e
del session after test()
wangyems Mar 6, 2024
b65e244
update
wangyems Mar 7, 2024
1bbd87f
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 7, 2024
5bcce25
lint
wangyems Mar 7, 2024
e83a37f
review comments & add c++ test
wangyems Mar 7, 2024
2b1219e
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 7, 2024
7d0fca8
fix warnings
wangyems Mar 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,12 @@ class IExecutionProvider {
*/
virtual common::Status Sync() const { return Status::OK(); }

/**
Set graph annotation for saving/retrieving executable graphs (e.g., cuda graph).
Currently only CUDA execution provider supports it.
wangyems marked this conversation as resolved.
Show resolved Hide resolved
*/
virtual void SetGraphAnnotation(int) {}
wangyems marked this conversation as resolved.
Show resolved Hide resolved

/**
Called when InferenceSession::Run started
NOTE that due to async execution in provider, the actual work of previous
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,8 @@ static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_

// Set RPC control latency for QNN HTP backend
static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";

// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true.
// The value should be an integer. If the value is not set, ORT session only captures one cuda graph.
wangyems marked this conversation as resolved.
Show resolved Hide resolved
wangyems marked this conversation as resolved.
Show resolved Hide resolved
// If the value is set to -1, cuda graph capture/replay is disabled in that run.
static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "ep.cuda.cuda_graph_annotation";
wangyems marked this conversation as resolved.
Show resolved Hide resolved
28 changes: 23 additions & 5 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,19 @@
return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
}

// Returns true when cuda graph capture must be skipped for the current run,
// i.e. when the underlying CUDAGraph reports that capture is not allowed
// (annotation id set to the skip sentinel).
bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureSkippedOnRun() const {
  const bool capture_allowed = cuda_graph_.IsGraphCaptureAllowedOnRun();
  return !capture_allowed;
}

// Records the cuda graph annotation id for this run on the per-thread context
// and forwards it to the underlying CUDAGraph so that capture/replay can be
// keyed by that id. Signature wrapped to stay within the 120-character limit
// flagged by cpplint.
void CUDAExecutionProvider::PerThreadContext::SetCudaGraphAnnotationId(
    GraphAnnotationOptional_t cuda_graph_annotation_id) {
  cuda_graph_annotation_id_ = cuda_graph_annotation_id;
  cuda_graph_.SetGraphAnnotation(cuda_graph_annotation_id_);
}

// Begins cuda graph capture for this thread's context.
// When no annotation id is set (single-graph-per-session mode) the previously
// captured graph is destroyed first; in annotated mode the existing graphs are
// kept so multiple graphs can coexist, keyed by annotation id.
// NOTE(review): the scraped diff contained both the pre- and post-change lines;
// this is the post-change (conditional Reset) version.
void CUDAExecutionProvider::PerThreadContext::CaptureBegin() {
  if (!cuda_graph_annotation_id_.has_value()) {
    cuda_graph_.Reset();
  }
  cuda_graph_.CaptureBegin();
}

Expand All @@ -205,12 +216,15 @@
}

// Reports whether a cuda graph is available for replay on this thread.
// In annotated mode the per-annotation-id map inside CUDAGraph is consulted;
// otherwise the legacy single-graph flag is used.
bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured() const {
  if (!cuda_graph_annotation_id_.has_value()) {
    return is_graph_captured_;
  }
  return cuda_graph_.IsAdditionalGraphCaptured(*cuda_graph_annotation_id_);
}

// Replays the captured cuda graph for this thread, passing the current
// annotation id so CUDAGraph can pick the matching executable graph.
// Enforces that a graph has actually been captured first.
// NOTE(review): the scraped diff contained both the pre- and post-change
// return lines; this is the post-change (annotation-aware) version.
Status CUDAExecutionProvider::PerThreadContext::ReplayGraph() {
  ORT_ENFORCE(IsGraphCaptured());
  return cuda_graph_.Replay(cuda_graph_annotation_id_);
}

void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
Expand Down Expand Up @@ -389,15 +403,15 @@
// Called at the start of InferenceSession::Run. Binds the CUDA device for the
// calling thread, then begins cuda graph capture when capture is enabled,
// not skipped for this run (annotation id != skip sentinel), allowed (warm-up
// runs completed), and no graph has been captured yet.
// Condition wrapped across two lines to satisfy the 120-character cpplint
// limit flagged in review.
Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
  // always set CUDA device when session::Run() in case it runs in a worker thread
  CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
  if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptureSkippedOnRun() &&
      GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
    LOGS(*GetLogger(), INFO) << "Capturing the cuda graph for this model";
    GetPerThreadContext().CaptureBegin();
  }
  return Status::OK();
}

Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) {
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptureSkippedOnRun() && !GetPerThreadContext().IsGraphCaptured()) {

Check warning on line 414 in onnxruntime/core/providers/cuda/cuda_execution_provider.cc

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Lines should be <= 120 characters long [whitespace/line_length] [2] Raw Output: onnxruntime/core/providers/cuda/cuda_execution_provider.cc:414: Lines should be <= 120 characters long [whitespace/line_length] [2]
if (GetPerThreadContext().IsGraphCaptureAllowed()) {
GetPerThreadContext().CaptureEnd();
// CUDA work issued to a capturing stream doesn’t actually run on the GPU,
Expand Down Expand Up @@ -430,7 +444,11 @@
}

// True when the provider option enable_cuda_graph was set to 1.
// NOTE(review): the scraped diff contained both the pre-change
// (`return info_.enable_cuda_graph;`) and post-change lines; this is the
// post-change version — presumably enable_cuda_graph became an integer
// option, so the explicit `== 1` comparison is intended. Confirm against
// CUDAExecutionProviderInfo.
bool CUDAExecutionProvider::IsGraphCaptureEnabled() const {
  return info_.enable_cuda_graph == 1;
}

// Propagates a run-level cuda graph annotation id to this thread's context,
// wrapping the raw id in the optional type the context expects.
void CUDAExecutionProvider::SetGraphAnnotation(GraphAnnotation_t cuda_graph_annotation_id) {
  GraphAnnotationOptional_t annotation{cuda_graph_annotation_id};
  GetPerThreadContext().SetCudaGraphAnnotationId(annotation);
}

bool CUDAExecutionProvider::IsGraphCaptured() const {
Expand Down
4 changes: 4 additions & 0 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Status ReplayGraph() override;
void SetGraphAnnotation(GraphAnnotation_t graph_annotation_id) override;
wangyems marked this conversation as resolved.
Show resolved Hide resolved
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
Expand Down Expand Up @@ -177,6 +178,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
}

bool IsGraphCaptureAllowed() const;
bool IsGraphCaptureSkippedOnRun() const;
void SetCudaGraphAnnotationId(GraphAnnotationOptional_t cuda_graph_annotation_id);
void CaptureBegin();
void CaptureEnd();
bool IsGraphCaptured() const;
Expand All @@ -202,6 +205,7 @@ class CUDAExecutionProvider : public IExecutionProvider {
CUDAGraph cuda_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;
GraphAnnotationOptional_t cuda_graph_annotation_id_;

// There is chance that the second regular run allocates GPU memory for causes like:
// (1) memory pattern is enabled. (2) arena allocation for stream.
Expand Down
86 changes: 80 additions & 6 deletions onnxruntime/core/providers/cuda/cuda_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,20 @@
stream_ = stream;
}

// Stores the annotation id (possibly empty) that subsequent CaptureBegin /
// CaptureEnd / Replay calls use to key additional executable graphs.
void CUDAGraph::SetGraphAnnotation(GraphAnnotationOptional_t cuda_graph_annotation_id) {
  cuda_graph_annotation_id_ = cuda_graph_annotation_id;
}

void CUDAGraph::CaptureBegin() {
ORT_ENFORCE(!has_graph_exec_,
"This cuda graph has already captured a graph. "
"Create a new instance to capture a new graph.");
if (!cuda_graph_annotation_id_.has_value()) {
ORT_ENFORCE(!has_graph_exec_,
"This cuda graph has already captured a graph. "
"Create a new instance to capture a new graph.");
} else {
if (!IsGraphCaptureAllowedOnRun()) {
return;
wangyems marked this conversation as resolved.
Show resolved Hide resolved
}
}

CUDA_CALL_THROW(cudaStreamSynchronize(stream_));
// For now cuda graph can only work with a single thread. In the future, we
Expand All @@ -30,6 +40,29 @@
}

void CUDAGraph::CaptureEnd() {
if (!IsGraphCaptureAllowedOnRun()) {
wangyems marked this conversation as resolved.
Show resolved Hide resolved
return;
}

if (cuda_graph_annotation_id_.has_value()) {
CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &additional_graph_));
if (additional_graph_ == NULL) {
ORT_THROW("CUDAGraph::CaptureEnd: additional_graph_ is NULL");
}

cudaGraphExec_t graph_exec = NULL;

has_additional_graph_ = true;
CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec, additional_graph_, NULL, NULL, 0));
CUDA_CALL_THROW(cudaGraphDestroy(additional_graph_));
has_additional_graph_ = false;

GraphAnnotation_t cuda_graph_id = cuda_graph_annotation_id_.value();
graph_exec_map_.emplace(cuda_graph_id, graph_exec);
hariharans29 marked this conversation as resolved.
Show resolved Hide resolved

return;
}

CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph_));
if (graph_ == NULL) {
ORT_THROW("CUDAGraph::CaptureEnd: graph_ is NULL");
Expand All @@ -42,15 +75,42 @@
has_graph_ = false;
}

Status CUDAGraph::Replay() {
// Launches the captured cuda graph on this CUDAGraph's stream and synchronizes.
// - If capture is disabled for this run (skip sentinel annotation), returns OK
//   without doing anything.
// - If an annotation id is supplied, launches the executable graph stored
//   under that id in graph_exec_map_, failing if none was captured.
// - Otherwise launches the legacy single graph_exec_.
// Long log line wrapped to satisfy the 120-character cpplint limit flagged
// in review.
Status CUDAGraph::Replay(GraphAnnotationOptional_t cuda_graph_annotation_id) {
  if (!IsGraphCaptureAllowedOnRun()) {
    return Status::OK();
  }
  // Although this function is not thread safe, the lock is not needed here because
  // CUDA EP maintains a separate cuda graph per thread
  if (cuda_graph_annotation_id.has_value()) {
    LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_
                       << " with cuda_graph_annotation_id " << *cuda_graph_annotation_id;
    auto it = graph_exec_map_.find(*cuda_graph_annotation_id);
    if (it == graph_exec_map_.end()) {
      return ORT_MAKE_STATUS(ONNXRUNTIME,
                             FAIL,
                             "CUDAGraph::Replay: graph_exec_map_ does not contain the cuda_graph_annotation_id");
    }
    CUDA_RETURN_IF_ERROR(cudaGraphLaunch(it->second, stream_));
  } else {
    LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_;
    CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec_, stream_));
  }

  CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_));
  return Status::OK();
}

// True when an executable graph keyed by this annotation id has already been
// instantiated and stored in graph_exec_map_.
bool CUDAGraph::IsAdditionalGraphCaptured(GraphAnnotation_t cuda_graph_annotation_id) const {
  return graph_exec_map_.count(cuda_graph_annotation_id) != 0;
}

// Decides whether cuda graph capture/replay is permitted for the current run.
// Removed a leftover commented-out std::cout debug line from the original.
bool CUDAGraph::IsGraphCaptureAllowedOnRun() const {
  // No annotation id set: legacy single-graph mode, capture is always allowed.
  if (!cuda_graph_annotation_id_.has_value()) {
    return true;
  }
  // The sentinel annotation id disables capture/replay for this run.
  return *cuda_graph_annotation_id_ != kDefaultSkipGraphCapture;
}

void CUDAGraph::Reset() {
if (has_graph_) {
CUDA_CALL_THROW(cudaGraphDestroy(graph_));
Expand All @@ -62,8 +122,22 @@
}
}

void CUDAGraph::ResetAdditional() {
wangyems marked this conversation as resolved.
Show resolved Hide resolved
if (has_additional_graph_) {
CUDA_CALL_THROW(cudaGraphDestroy(additional_graph_));
has_additional_graph_ = false;
}
if (!graph_exec_map_.empty()) {
for (auto& it : graph_exec_map_) {
CUDA_CALL_THROW(cudaGraphExecDestroy(it.second));
}
graph_exec_map_.clear();
}
}

// Destroys both the legacy single graph state (Reset) and all
// annotation-keyed executable graphs (ResetAdditional).
CUDAGraph::~CUDAGraph() {
  Reset();
  ResetAdditional();
}

} // namespace onnxruntime
20 changes: 18 additions & 2 deletions onnxruntime/core/providers/cuda/cuda_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,23 +4,34 @@
#pragma once

#include "core/common/common.h"
#include "core/common/optional.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/cuda/cuda_pch.h"

namespace onnxruntime {

using CaptureId_t = unsigned long long;
using GraphAnnotation_t = int;
wangyems marked this conversation as resolved.
Show resolved Hide resolved
using GraphAnnotationOptional_t = optional<GraphAnnotation_t>;

constexpr GraphAnnotation_t kDefaultSkipGraphCapture = -1;
wangyems marked this conversation as resolved.
Show resolved Hide resolved

struct CUDAGraph {
CUDAGraph(){};
CUDAGraph(cudaStream_t stream);
~CUDAGraph();

void SetStream(cudaStream_t stream);
void SetGraphAnnotation(GraphAnnotationOptional_t cuda_graph_annotation_id);
wangyems marked this conversation as resolved.
Show resolved Hide resolved

void CaptureBegin();
void CaptureEnd();
Status Replay();
Status Replay(GraphAnnotationOptional_t cuda_graph_annotation_id);

void Reset();
void ResetAdditional();

bool IsAdditionalGraphCaptured(GraphAnnotation_t cuda_graph_annotation_id) const;
bool IsGraphCaptureAllowedOnRun() const;

private:
cudaGraph_t graph_ = NULL;
Expand All @@ -29,6 +40,11 @@
bool has_graph_ = false;
bool has_graph_exec_ = false;

cudaGraph_t additional_graph_ = NULL;
wangyems marked this conversation as resolved.
Show resolved Hide resolved
std::unordered_map<GraphAnnotation_t, cudaGraphExec_t> graph_exec_map_;

Check warning on line 44 in onnxruntime/core/providers/cuda/cuda_graph.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4] Raw Output: onnxruntime/core/providers/cuda/cuda_graph.h:44: Add #include <unordered_map> for unordered_map<> [build/include_what_you_use] [4]
GraphAnnotationOptional_t cuda_graph_annotation_id_;
bool has_additional_graph_ = false;

cudaStream_t stream_ = nullptr; // Does not own the stream
};

Expand Down
11 changes: 11 additions & 0 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2385,6 +2385,16 @@ Status InferenceSession::Run(const RunOptions& run_options,
auto* inter_tp = (control_spinning) ? inter_op_thread_pool_.get() : nullptr;
ThreadPoolSpinningSwitch runs_refcounter_and_tp_spin_control(intra_tp, inter_tp, current_num_runs_);

if (cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled()) {
// Cuda graph annotation is only considered when enable_cuda_graph is set to true in session options
const std::string& graph_annotation_str =
run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigCudaGraphAnnotation, "");
// If graph annotation is not provided, fall back to the one cuda graph per session behavior
if (!graph_annotation_str.empty()) {
cached_execution_provider_for_graph_replay_.SetGraphAnnotation(std::stoi(graph_annotation_str));
wangyems marked this conversation as resolved.
Show resolved Hide resolved
}
}

// Check if this Run() is simply going to be a CUDA Graph replay.
if (cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
LOGS(*session_logger_, INFO) << "Replaying the captured "
Expand Down Expand Up @@ -2552,6 +2562,7 @@ Status InferenceSession::Run(const RunOptions& run_options,
// N is defined in min_num_runs_before_hip_graph_capture_ for ROCM EP,
// and the value could be different for other EP.
if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() &&
!cached_execution_provider_for_graph_replay_.IsGraphCaptureSkippedOnRun() &&
!cached_execution_provider_for_graph_replay_.IsGraphCaptured()) {
LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture.";
ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info));
Expand Down
13 changes: 13 additions & 0 deletions onnxruntime/core/session/inference_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,14 @@ class InferenceSession {
cached_execution_provider_for_graph_replay_ = execution_provider;
}

// Caches the run-level graph annotation id and forwards it to the EP that
// owns graph capture/replay.
// Fix: validate the cached EP pointer BEFORE assigning graph_annotation_id_,
// so a throwing call leaves the previously cached id untouched instead of
// mutating state on the error path.
void SetGraphAnnotation(int graph_annotation_id) {
  if (!cached_execution_provider_for_graph_replay_) {
    ORT_THROW("Cached EP instance for graph replay is not set yet before calling SetGraphAnnotation()");
  }
  graph_annotation_id_ = graph_annotation_id;
  cached_execution_provider_for_graph_replay_->SetGraphAnnotation(graph_annotation_id_);
}

// True when a graph-replay-capable EP is cached and that EP reports cuda
// graph capture enabled. Split across a local to keep lines short.
bool IsGraphCaptureEnabled() const {
  const auto* ep = cached_execution_provider_for_graph_replay_;
  return ep != nullptr && ep->IsGraphCaptureEnabled();
}
Expand All @@ -871,6 +879,10 @@ class InferenceSession {
return cached_execution_provider_for_graph_replay_ != nullptr && cached_execution_provider_for_graph_replay_->IsGraphCaptured();
}

// True when a graph-replay-capable EP is cached and the current run's
// annotation id is -1, the sentinel that disables cuda graph capture/replay
// for that run (per kOrtRunOptionsConfigCudaGraphAnnotation docs).
// NOTE(review): presumably -1 should reference kDefaultSkipGraphCapture from
// cuda_graph.h rather than a magic number — confirm the constant is visible here.
bool IsGraphCaptureSkippedOnRun() const {
  return cached_execution_provider_for_graph_replay_ != nullptr && graph_annotation_id_ == -1;
}

Status ReplayGraph() {
ORT_ENFORCE(IsGraphCaptured());
if (cached_execution_provider_for_graph_replay_) {
Expand All @@ -884,6 +896,7 @@ class InferenceSession {
}

IExecutionProvider* cached_execution_provider_for_graph_replay_ = nullptr;
int graph_annotation_id_ = 0;
wangyems marked this conversation as resolved.
Show resolved Hide resolved
};

CachedExecutionProviderForGraphReplay cached_execution_provider_for_graph_replay_;
Expand Down
Loading
Loading