Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cuda graph enhancement #19636

Merged
merged 31 commits into from
Mar 7, 2024
Merged
Show file tree
Hide file tree
Changes from 28 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
bcb90a9
cuda graph enhancement
wangyems Feb 24, 2024
03a3a80
update
wangyems Feb 24, 2024
772471c
add python test
wangyems Feb 24, 2024
5dbda60
accept RunOptions as an argument to ReplayGraph
wangyems Feb 28, 2024
773eb0d
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Feb 28, 2024
f2b5c1e
update
wangyems Feb 29, 2024
dfdb78c
update tests
wangyems Mar 1, 2024
d008515
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 1, 2024
5ff5586
update
wangyems Mar 1, 2024
e51c9c4
lint
wangyems Mar 1, 2024
0281e57
review comments
wangyems Mar 1, 2024
e655e35
review comments
wangyems Mar 4, 2024
b180c87
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 4, 2024
3078d2a
update
wangyems Mar 4, 2024
29fd4ff
update
wangyems Mar 4, 2024
74d9e18
update
wangyems Mar 5, 2024
ae1024f
fix trt/rocm build
wangyems Mar 5, 2024
ce68222
fix build
wangyems Mar 5, 2024
5d67e12
review comments
wangyems Mar 5, 2024
7a16be6
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 5, 2024
8993bdf
review comments
wangyems Mar 6, 2024
8884b55
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 6, 2024
1e40998
review comments
wangyems Mar 6, 2024
5d411a3
sync io
wangyems Mar 6, 2024
9725c7e
del session after test()
wangyems Mar 6, 2024
b65e244
update
wangyems Mar 7, 2024
1bbd87f
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 7, 2024
5bcce25
lint
wangyems Mar 7, 2024
e83a37f
review comments & add c++ test
wangyems Mar 7, 2024
2b1219e
Merge branch 'main' of github.com:microsoft/onnxruntime into wangye/c…
wangyems Mar 7, 2024
7d0fca8
fix warnings
wangyems Mar 7, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions include/onnxruntime/core/framework/execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,21 +202,21 @@ class IExecutionProvider {

/**
Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for
the provider. Currently only CUDA execution provider supports it.
the provider.
*/
virtual bool IsGraphCaptureEnabled() const { return false; }

/**
Indicate whether the graph has been captured and instantiated. Currently
only CUDA execution provider supports it.
Indicate whether the graph has been captured and instantiated.
*/
virtual bool IsGraphCaptured() const { return false; }
virtual bool IsGraphCaptured(int /*graph_annotation_id*/) const { return false; }

/**
Run the instantiated graph. Currently only CUDA execution provider supports
it.
Run the instantiated graph.
*/
virtual common::Status ReplayGraph() { return Status::OK(); }
virtual common::Status ReplayGraph(int /*graph_annotation_id*/) {
return Status::OK();
}

/**
Called when session creation is complete
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,10 @@ static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_

// Set RPC control latency for QNN HTP backend
static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency";

// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true.
// The value should be an integer. If the value is not set, the default value is 0 and
// ORT session only captures one cuda graph before another capture is requested.
// If the value is set to -1, cuda graph capture/replay is disabled in that run.
// User are not expected to set the value to 0 as it is reserved for internal use.
static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id";
74 changes: 49 additions & 25 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// Licensed under the MIT License.

#include "core/common/inlined_containers.h"
#include "core/common/parse_string.h"
#include "core/providers/shared_library/provider_api.h"
#include "core/platform/env_var_utils.h"
#include "core/providers/cuda/cuda_execution_provider.h"
Expand All @@ -11,6 +12,7 @@
#include "core/providers/cuda/cuda_fwd.h"
#include "core/providers/cuda/gpu_data_transfer.h"
#include "core/providers/cuda/cuda_profiler.h"
#include "core/session/onnxruntime_run_options_config_keys.h"

#ifndef USE_CUDA_MINIMAL
#ifndef DISABLE_CONTRIB_OPS
Expand Down Expand Up @@ -190,27 +192,46 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() {
#endif
}

bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const {
return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_;
bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed(
CudaGraphAnnotation_t cuda_graph_annotation_id) const {
return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_ &&
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need regular run counter for each cuda_graph_annotation_id.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#19856 for bug fix

IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id);
}

void CUDAExecutionProvider::PerThreadContext::CaptureBegin() {
cuda_graph_.Reset();
cuda_graph_.CaptureBegin();
bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun(
CudaGraphAnnotation_t cuda_graph_annotation_id) const {
return cuda_graph_.IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id);
}

void CUDAExecutionProvider::PerThreadContext::CaptureEnd() {
cuda_graph_.CaptureEnd();
is_graph_captured_ = true;
CudaGraphAnnotation_t CUDAExecutionProvider::PerThreadContext::GetCudaGraphAnnotationId(
const onnxruntime::RunOptions& run_options) const {
auto graph_annotation_str =
run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation);
// If graph annotation is not provided, fall back to the one cuda graph per session behavior
CudaGraphAnnotation_t cuda_graph_annotation_id = 0;
if (graph_annotation_str.has_value()) {
ORT_ENFORCE(TryParseStringWithClassicLocale<int>(*graph_annotation_str, cuda_graph_annotation_id),
"Failed to parse the cuda graph annotation id: ",
*graph_annotation_str);
}

return cuda_graph_annotation_id;
}

void CUDAExecutionProvider::PerThreadContext::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) {
cuda_graph_.CaptureBegin(cuda_graph_annotation_id);
}

void CUDAExecutionProvider::PerThreadContext::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) {
cuda_graph_.CaptureEnd(cuda_graph_annotation_id);
}

bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured() const {
return is_graph_captured_;
bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const {
return cuda_graph_.IsGraphCaptured(graph_annotation_id);
}

Status CUDAExecutionProvider::PerThreadContext::ReplayGraph() {
ORT_ENFORCE(IsGraphCaptured());
return cuda_graph_.Replay();
Status CUDAExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) {
return cuda_graph_.Replay(graph_annotation_id);
}

void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() {
Expand Down Expand Up @@ -386,23 +407,26 @@ Status CUDAExecutionProvider::Sync() const {
return Status::OK();
}

Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) {
Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) {
// always set CUDA device when session::Run() in case it runs in a worker thread
CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId()));
if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) {
CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options);
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id) &&
GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) {
LOGS(*GetLogger(), INFO) << "Capturing the cuda graph for this model";
GetPerThreadContext().CaptureBegin();
GetPerThreadContext().CaptureBegin(cuda_graph_annotation_id);
}
return Status::OK();
}

Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) {
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) {
if (GetPerThreadContext().IsGraphCaptureAllowed()) {
GetPerThreadContext().CaptureEnd();
Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) {
CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options);
if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) {
if (GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) {
GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id);
// CUDA work issued to a capturing stream doesn’t actually run on the GPU,
// so run the captured graph here to actually execute the work.
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph());
ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id));
} else {
GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture();
}
Expand Down Expand Up @@ -433,12 +457,12 @@ bool CUDAExecutionProvider::IsGraphCaptureEnabled() const {
return info_.enable_cuda_graph;
}

bool CUDAExecutionProvider::IsGraphCaptured() const {
return GetPerThreadContext().IsGraphCaptured();
bool CUDAExecutionProvider::IsGraphCaptured(int graph_annotation_id) const {
return GetPerThreadContext().IsGraphCaptured(graph_annotation_id);
}

Status CUDAExecutionProvider::ReplayGraph() {
return GetPerThreadContext().ReplayGraph();
Status CUDAExecutionProvider::ReplayGraph(int graph_annotation_id) {
return GetPerThreadContext().ReplayGraph(graph_annotation_id);
}

namespace cuda {
Expand Down
17 changes: 9 additions & 8 deletions onnxruntime/core/providers/cuda/cuda_execution_provider.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,8 @@ class CUDAExecutionProvider : public IExecutionProvider {
std::unique_ptr<profiling::EpProfiler> GetProfiler() override;

bool IsGraphCaptureEnabled() const override;
bool IsGraphCaptured() const override;
Status ReplayGraph() override;
bool IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const override;
Status ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) override;
void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override;
OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override;
std::vector<AllocatorPtr> CreatePreferredAllocators() override;
Expand Down Expand Up @@ -168,11 +168,13 @@ class CUDAExecutionProvider : public IExecutionProvider {
}
}

bool IsGraphCaptureAllowed() const;
void CaptureBegin();
void CaptureEnd();
bool IsGraphCaptured() const;
Status ReplayGraph();
bool IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id);
void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id);
bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
CudaGraphAnnotation_t GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const;
wangyems marked this conversation as resolved.
Show resolved Hide resolved
Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id);
void IncrementRegularRunCountBeforeGraphCapture();

private:
Expand All @@ -192,7 +194,6 @@ class CUDAExecutionProvider : public IExecutionProvider {
// Cuda graph with multi threads will be supported in the future, so cuda_graph_
// is put under PerThreadContext.
CUDAGraph cuda_graph_;
bool is_graph_captured_ = false;
int regular_run_count_before_graph_capture_ = 0;

// There is chance that the second regular run allocates GPU memory for causes like:
Expand Down
91 changes: 64 additions & 27 deletions onnxruntime/core/providers/cuda/cuda_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,44 @@

namespace onnxruntime {

CUDAGraph::CUDAGraph(cudaStream_t stream) : stream_(stream) {
CudaGraphSet::~CudaGraphSet() {
Clear();
}

void CUDAGraph::SetStream(cudaStream_t stream) {
void CudaGraphSet::Clear() {
for (auto& it : cuda_graphs_) {
CUDA_CALL_THROW(cudaGraphExecDestroy(it.second));
wangyems marked this conversation as resolved.
Show resolved Hide resolved
}
cuda_graphs_.clear();
}

bool CudaGraphSet::Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
return cuda_graphs_.find(cuda_graph_annotation_id) != cuda_graphs_.end();
}

void CudaGraphSet::Put(CudaGraphAnnotation_t cuda_graph_annotation_id, cudaGraphExec_t graph_exec) {
ORT_ENFORCE(!Contains(cuda_graph_annotation_id));
cuda_graphs_.emplace(cuda_graph_annotation_id, graph_exec);
}

cudaGraphExec_t CudaGraphSet::Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
ORT_ENFORCE(Contains(cuda_graph_annotation_id));
return cuda_graphs_.at(cuda_graph_annotation_id);
}

CUDAGraphManager::CUDAGraphManager(cudaStream_t stream) : stream_(stream) {
}

void CUDAGraphManager::SetStream(cudaStream_t stream) {
stream_ = stream;
}

void CUDAGraph::CaptureBegin() {
ORT_ENFORCE(!has_graph_exec_,
"This cuda graph has already captured a graph. "
"Create a new instance to capture a new graph.");
void CUDAGraphManager::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) {
ORT_ENFORCE(IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id));

ORT_ENFORCE(!cuda_graph_set_.Contains(cuda_graph_annotation_id),
wangyems marked this conversation as resolved.
Show resolved Hide resolved
"Trying to capture a graph with annotation id ", cuda_graph_annotation_id,
" that already used. Please use a different annotation id.");

CUDA_CALL_THROW(cudaStreamSynchronize(stream_));
// For now cuda graph can only work with a single thread. In the future, we
Expand All @@ -29,40 +56,50 @@ void CUDAGraph::CaptureBegin() {
CUDA_CALL_THROW(cudaStreamBeginCapture(stream_, cudaStreamCaptureModeGlobal));
}

void CUDAGraph::CaptureEnd() {
CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph_));
if (graph_ == NULL) {
void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) {
cudaGraph_t graph = NULL;
CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph));
if (graph == NULL) {
ORT_THROW("CUDAGraph::CaptureEnd: graph_ is NULL");
}

has_graph_ = true;
CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
has_graph_exec_ = true;
CUDA_CALL_THROW(cudaGraphDestroy(graph_));
has_graph_ = false;
cudaGraphExec_t graph_exec = NULL;
CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0));
CUDA_CALL_THROW(cudaGraphDestroy(graph));

// Currently all the captured graphs will be tied to the session's lifecycle
// TODO(wy): Addd an interface to free captured graphs
cuda_graph_set_.Put(cuda_graph_annotation_id, graph_exec);
wangyems marked this conversation as resolved.
Show resolved Hide resolved
}

Status CUDAGraph::Replay() {
Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) {
ORT_ENFORCE(IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id));
wangyems marked this conversation as resolved.
Show resolved Hide resolved

// Although this function is not thread safe, the lock is not needed here because
// CUDA EP maintains a separate cuda graph per thread
LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_;
CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec_, stream_));
LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id "
<< cuda_graph_annotation_id;

cudaGraphExec_t graph_exec = cuda_graph_set_.Get(cuda_graph_annotation_id);
CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec, stream_));

CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_));
return Status::OK();
}

void CUDAGraph::Reset() {
if (has_graph_) {
CUDA_CALL_THROW(cudaGraphDestroy(graph_));
has_graph_ = false;
}
if (has_graph_exec_) {
CUDA_CALL_THROW(cudaGraphExecDestroy(graph_exec_));
has_graph_exec_ = false;
}
bool CUDAGraphManager::IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
return cuda_graph_annotation_id != kCudaGraphAnnotationSkip;
}

bool CUDAGraphManager::IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const {
return cuda_graph_set_.Contains(cuda_graph_annotation_id);
}

void CUDAGraphManager::Reset() {
cuda_graph_set_.Clear();
}

CUDAGraph::~CUDAGraph() {
CUDAGraphManager::~CUDAGraphManager() {
Reset();
}

Expand Down
48 changes: 35 additions & 13 deletions onnxruntime/core/providers/cuda/cuda_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,55 @@

#pragma once

#include <unordered_map>

#include "core/common/common.h"
#include "core/platform/ort_mutex.h"
#include "core/providers/cuda/cuda_pch.h"

namespace onnxruntime {

using CaptureId_t = unsigned long long;
using CudaGraphAnnotation_t = int;
using CudaGraphSet_t = std::unordered_map<CudaGraphAnnotation_t, cudaGraphExec_t>;
wangyems marked this conversation as resolved.
Show resolved Hide resolved

constexpr CudaGraphAnnotation_t kCudaGraphAnnotationSkip = -1;
constexpr CudaGraphAnnotation_t kCudaGraphAnnotationDefault = 0;

struct CudaGraphSet {
CudaGraphSet(){};

Check warning on line 21 in onnxruntime/core/providers/cuda/cuda_graph.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 You don't need a ; after a } [readability/braces] [4] Raw Output: onnxruntime/core/providers/cuda/cuda_graph.h:21: You don't need a ; after a } [readability/braces] [4]
~CudaGraphSet();

struct CUDAGraph {
CUDAGraph(){};
CUDAGraph(cudaStream_t stream);
~CUDAGraph();
void Clear();
bool Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
void Put(CudaGraphAnnotation_t cuda_graph_annotation_id, cudaGraphExec_t graph_exec);
cudaGraphExec_t Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const;

private:
CudaGraphSet_t cuda_graphs_;
};

struct CUDAGraphManager {
CUDAGraphManager(){};

Check warning on line 34 in onnxruntime/core/providers/cuda/cuda_graph.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 You don't need a ; after a } [readability/braces] [4] Raw Output: onnxruntime/core/providers/cuda/cuda_graph.h:34: You don't need a ; after a } [readability/braces] [4]
CUDAGraphManager(cudaStream_t stream);

Check warning on line 35 in onnxruntime/core/providers/cuda/cuda_graph.h

View workflow job for this annotation

GitHub Actions / Lint C++

[cpplint] reported by reviewdog 🐶 Single-parameter constructors should be marked explicit. [runtime/explicit] [5] Raw Output: onnxruntime/core/providers/cuda/cuda_graph.h:35: Single-parameter constructors should be marked explicit. [runtime/explicit] [5]
~CUDAGraphManager();

void SetStream(cudaStream_t stream);
void CaptureBegin();
void CaptureEnd();
Status Replay();
void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id);
void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id);
Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id);

void Reset();

private:
cudaGraph_t graph_ = NULL;
cudaGraphExec_t graph_exec_ = NULL;
wangyems marked this conversation as resolved.
Show resolved Hide resolved
bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const;
wangyems marked this conversation as resolved.
Show resolved Hide resolved

bool has_graph_ = false;
bool has_graph_exec_ = false;
private:
CudaGraphSet cuda_graph_set_;
CudaGraphAnnotation_t cuda_graph_annotation_id_ = kCudaGraphAnnotationDefault;

cudaStream_t stream_ = nullptr; // Does not own the stream
};

using CUDAGraph = CUDAGraphManager;

} // namespace onnxruntime
Loading
Loading