diff --git a/include/onnxruntime/core/framework/execution_provider.h b/include/onnxruntime/core/framework/execution_provider.h index c1cc69edc17d8..40ca96a19aef1 100644 --- a/include/onnxruntime/core/framework/execution_provider.h +++ b/include/onnxruntime/core/framework/execution_provider.h @@ -202,21 +202,21 @@ class IExecutionProvider { /** Indicate whether the graph capturing mode (e.g., cuda graph) is enabled for - the provider. Currently only CUDA execution provider supports it. + the provider. */ virtual bool IsGraphCaptureEnabled() const { return false; } /** - Indicate whether the graph has been captured and instantiated. Currently - only CUDA execution provider supports it. + Indicate whether the graph has been captured and instantiated. */ - virtual bool IsGraphCaptured() const { return false; } + virtual bool IsGraphCaptured(int /*graph_annotation_id*/) const { return false; } /** - Run the instantiated graph. Currently only CUDA execution provider supports - it. + Run the instantiated graph. */ - virtual common::Status ReplayGraph() { return Status::OK(); } + virtual common::Status ReplayGraph(int /*graph_annotation_id*/) { + return Status::OK(); + } /** Called when session creation is complete diff --git a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h index b0a17e175fef3..c80b8c0c164b6 100644 --- a/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h +++ b/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h @@ -42,3 +42,10 @@ static const char* const kOrtRunOptionsConfigQnnPerfModePostRun = "qnn.htp_perf_ // Set RPC control latency for QNN HTP backend static const char* const kOrtRunOptionsConfigQnnRpcControlLatency = "qnn.rpc_control_latency"; + +// Set graph annotation id for CUDA EP. Use with enable_cuda_graph=true. +// The value should be an integer. If the value is not set, the default value is 0 and +// the ORT session captures only one cuda graph before another capture is requested. +// If the value is set to -1, cuda graph capture/replay is disabled in that run. +// Users are not expected to set the value to 0 as it is reserved for internal use. +static const char* const kOrtRunOptionsConfigCudaGraphAnnotation = "gpu_graph_id"; diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc index 3c0930638a205..bade2faf8f2e2 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.cc +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.cc @@ -3,6 +3,7 @@ // Licensed under the MIT License. 
#include "core/common/inlined_containers.h" +#include "core/common/parse_string.h" #include "core/providers/shared_library/provider_api.h" #include "core/platform/env_var_utils.h" #include "core/providers/cuda/cuda_execution_provider.h" @@ -11,6 +12,7 @@ #include "core/providers/cuda/cuda_fwd.h" #include "core/providers/cuda/gpu_data_transfer.h" #include "core/providers/cuda/cuda_profiler.h" +#include "core/session/onnxruntime_run_options_config_keys.h" #ifndef USE_CUDA_MINIMAL #ifndef DISABLE_CONTRIB_OPS @@ -190,27 +192,46 @@ CUDAExecutionProvider::PerThreadContext::~PerThreadContext() { #endif } -bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { - return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; +bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowed( + CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_ && + IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id); } -void CUDAExecutionProvider::PerThreadContext::CaptureBegin() { - cuda_graph_.Reset(); - cuda_graph_.CaptureBegin(); +bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptureAllowedOnRun( + CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_.IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id); } -void CUDAExecutionProvider::PerThreadContext::CaptureEnd() { - cuda_graph_.CaptureEnd(); - is_graph_captured_ = true; +CudaGraphAnnotation_t CUDAExecutionProvider::PerThreadContext::GetCudaGraphAnnotationId( + const onnxruntime::RunOptions& run_options) const { + auto graph_annotation_str = + run_options.GetConfigOptions().GetConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation); + // If graph annotation is not provided, fall back to the one cuda graph per session behavior + CudaGraphAnnotation_t cuda_graph_annotation_id = 0; + if (graph_annotation_str.has_value()) { + ORT_ENFORCE(TryParseStringWithClassicLocale(*graph_annotation_str, cuda_graph_annotation_id), + "Failed to parse the cuda graph annotation id: ", + *graph_annotation_str); + } + + return cuda_graph_annotation_id; +} + +void CUDAExecutionProvider::PerThreadContext::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.CaptureBegin(cuda_graph_annotation_id); +} + +void CUDAExecutionProvider::PerThreadContext::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cuda_graph_.CaptureEnd(cuda_graph_annotation_id); } -bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured() const { - return is_graph_captured_; +bool CUDAExecutionProvider::PerThreadContext::IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const { + return cuda_graph_.IsGraphCaptured(graph_annotation_id); } -Status CUDAExecutionProvider::PerThreadContext::ReplayGraph() { - ORT_ENFORCE(IsGraphCaptured()); - return cuda_graph_.Replay(); +Status CUDAExecutionProvider::PerThreadContext::ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) { + return cuda_graph_.Replay(graph_annotation_id); } void CUDAExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { @@ -386,23 +407,26 @@ Status CUDAExecutionProvider::Sync() const { return Status::OK(); } -Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { +Status CUDAExecutionProvider::OnRunStart(const onnxruntime::RunOptions& run_options) { // always set CUDA device when session::Run() in case it runs in a worker thread 
CUDA_RETURN_IF_ERROR(cudaSetDevice(GetDeviceId())); - if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options); + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id) && + GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) { LOGS(*GetLogger(), INFO) << "Capturing the cuda graph for this model"; - GetPerThreadContext().CaptureBegin(); + GetPerThreadContext().CaptureBegin(cuda_graph_annotation_id); } return Status::OK(); } -Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { - if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { - if (GetPerThreadContext().IsGraphCaptureAllowed()) { - GetPerThreadContext().CaptureEnd(); +Status CUDAExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) { + CudaGraphAnnotation_t cuda_graph_annotation_id = GetPerThreadContext().GetCudaGraphAnnotationId(run_options); + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(cuda_graph_annotation_id)) { + if (GetPerThreadContext().IsGraphCaptureAllowed(cuda_graph_annotation_id)) { + GetPerThreadContext().CaptureEnd(cuda_graph_annotation_id); // CUDA work issued to a capturing stream doesn’t actually run on the GPU, // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph()); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(cuda_graph_annotation_id)); } else { GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(); } @@ -433,12 +457,12 @@ bool CUDAExecutionProvider::IsGraphCaptureEnabled() const { return info_.enable_cuda_graph; } -bool CUDAExecutionProvider::IsGraphCaptured() const { - return GetPerThreadContext().IsGraphCaptured(); +bool CUDAExecutionProvider::IsGraphCaptured(int graph_annotation_id) const { + return GetPerThreadContext().IsGraphCaptured(graph_annotation_id); } -Status CUDAExecutionProvider::ReplayGraph() { - return GetPerThreadContext().ReplayGraph(); +Status CUDAExecutionProvider::ReplayGraph(int graph_annotation_id) { + return GetPerThreadContext().ReplayGraph(graph_annotation_id); } namespace cuda { diff --git a/onnxruntime/core/providers/cuda/cuda_execution_provider.h b/onnxruntime/core/providers/cuda/cuda_execution_provider.h index 75fe1dff7c4a4..6c70e6abc4fdf 100644 --- a/onnxruntime/core/providers/cuda/cuda_execution_provider.h +++ b/onnxruntime/core/providers/cuda/cuda_execution_provider.h @@ -92,8 +92,8 @@ class CUDAExecutionProvider : public IExecutionProvider { std::unique_ptr GetProfiler() override; bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured() const override; - Status ReplayGraph() override; + bool IsGraphCaptured(CudaGraphAnnotation_t graph_annotation_id) const override; + Status ReplayGraph(CudaGraphAnnotation_t graph_annotation_id) override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; @@ -168,11 +168,13 @@ class CUDAExecutionProvider : public IExecutionProvider { } } - bool IsGraphCaptureAllowed() const; - void CaptureBegin(); - void CaptureEnd(); - bool IsGraphCaptured() const; - Status ReplayGraph(); + bool 
IsGraphCaptureAllowed(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); + void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); + bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + CudaGraphAnnotation_t GetCudaGraphAnnotationId(const onnxruntime::RunOptions& run_options) const; + Status ReplayGraph(CudaGraphAnnotation_t cuda_graph_annotation_id); void IncrementRegularRunCountBeforeGraphCapture(); private: @@ -192,7 +194,6 @@ class CUDAExecutionProvider : public IExecutionProvider { // Cuda graph with multi threads will be supported in the future, so cuda_graph_ // is put under PerThreadContext. CUDAGraph cuda_graph_; - bool is_graph_captured_ = false; int regular_run_count_before_graph_capture_ = 0; // There is chance that the second regular run allocates GPU memory for causes like: diff --git a/onnxruntime/core/providers/cuda/cuda_graph.cc b/onnxruntime/core/providers/cuda/cuda_graph.cc index 230d664391611..8353c654681fc 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.cc +++ b/onnxruntime/core/providers/cuda/cuda_graph.cc @@ -9,17 +9,44 @@ namespace onnxruntime { -CUDAGraph::CUDAGraph(cudaStream_t stream) : stream_(stream) { +CudaGraphSet::~CudaGraphSet() { + Clear(); } -void CUDAGraph::SetStream(cudaStream_t stream) { +void CudaGraphSet::Clear() { + for (auto& it : cuda_graphs_) { + CUDA_CALL_THROW(cudaGraphExecDestroy(it.second)); + } + cuda_graphs_.clear(); +} + +bool CudaGraphSet::Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graphs_.find(cuda_graph_annotation_id) != cuda_graphs_.end(); +} + +void CudaGraphSet::Put(CudaGraphAnnotation_t cuda_graph_annotation_id, cudaGraphExec_t graph_exec) { + ORT_ENFORCE(!Contains(cuda_graph_annotation_id)); + cuda_graphs_.emplace(cuda_graph_annotation_id, graph_exec); +} + +cudaGraphExec_t CudaGraphSet::Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + ORT_ENFORCE(Contains(cuda_graph_annotation_id)); + return cuda_graphs_.at(cuda_graph_annotation_id); +} + +CUDAGraphManager::CUDAGraphManager(cudaStream_t stream) : stream_(stream) { +} + +void CUDAGraphManager::SetStream(cudaStream_t stream) { stream_ = stream; } -void CUDAGraph::CaptureBegin() { - ORT_ENFORCE(!has_graph_exec_, - "This cuda graph has already captured a graph. " - "Create a new instance to capture a new graph."); +void CUDAGraphManager::CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id) { + ORT_ENFORCE(IsGraphCaptureAllowedOnRun(cuda_graph_annotation_id)); + + ORT_ENFORCE(!cuda_graph_set_.Contains(cuda_graph_annotation_id), + "Trying to capture a graph with annotation id ", cuda_graph_annotation_id, + " that is already used. Please use a different annotation id."); CUDA_CALL_THROW(cudaStreamSynchronize(stream_)); // For now cuda graph can only work with a single thread. 
In the future, we @@ -29,40 +56,48 @@ void CUDAGraph::CaptureBegin() { CUDA_CALL_THROW(cudaStreamBeginCapture(stream_, cudaStreamCaptureModeGlobal)); } -void CUDAGraph::CaptureEnd() { - CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph_)); - if (graph_ == NULL) { +void CUDAGraphManager::CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id) { + cudaGraph_t graph = NULL; + CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph)); + if (graph == NULL) { ORT_THROW("CUDAGraph::CaptureEnd: graph_ is NULL"); } - has_graph_ = true; - CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0)); - has_graph_exec_ = true; - CUDA_CALL_THROW(cudaGraphDestroy(graph_)); - has_graph_ = false; + cudaGraphExec_t graph_exec = NULL; + CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec, graph, NULL, NULL, 0)); + CUDA_CALL_THROW(cudaGraphDestroy(graph)); + + // Currently all the captured graphs will be tied to the session's lifecycle + // TODO(wy): Add an interface to free captured graphs + cuda_graph_set_.Put(cuda_graph_annotation_id, graph_exec); } -Status CUDAGraph::Replay() { +Status CUDAGraphManager::Replay(CudaGraphAnnotation_t cuda_graph_annotation_id) { // Although this function is not thread safe, the lock is not needed here because // CUDA EP maintains a separate cuda graph per thread - LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_; - CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec_, stream_)); + LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id " + << cuda_graph_annotation_id; + + cudaGraphExec_t graph_exec = cuda_graph_set_.Get(cuda_graph_annotation_id); + CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec, stream_)); + CUDA_RETURN_IF_ERROR(cudaStreamSynchronize(stream_)); return Status::OK(); } -void CUDAGraph::Reset() { - if (has_graph_) { - CUDA_CALL_THROW(cudaGraphDestroy(graph_)); - has_graph_ = false; - } - if (has_graph_exec_) { - CUDA_CALL_THROW(cudaGraphExecDestroy(graph_exec_)); - has_graph_exec_ = false; - } +bool CUDAGraphManager::IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_annotation_id != kCudaGraphAnnotationSkip; +} + +bool CUDAGraphManager::IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const { + return cuda_graph_set_.Contains(cuda_graph_annotation_id); +} + +void CUDAGraphManager::Reset() { + cuda_graph_set_.Clear(); } -CUDAGraph::~CUDAGraph() { +CUDAGraphManager::~CUDAGraphManager() { Reset(); } diff --git a/onnxruntime/core/providers/cuda/cuda_graph.h b/onnxruntime/core/providers/cuda/cuda_graph.h index 9bcefcc64ea77..064994c1f14ae 100644 --- a/onnxruntime/core/providers/cuda/cuda_graph.h +++ b/onnxruntime/core/providers/cuda/cuda_graph.h @@ -3,33 +3,55 @@ #pragma once +#include <unordered_map> + #include "core/common/common.h" #include "core/platform/ort_mutex.h" #include "core/providers/cuda/cuda_pch.h" namespace onnxruntime { -using CaptureId_t = unsigned long long; +using CudaGraphAnnotation_t = int; +using CudaGraphSet_t = std::unordered_map<CudaGraphAnnotation_t, cudaGraphExec_t>; + +constexpr CudaGraphAnnotation_t kCudaGraphAnnotationSkip = -1; +constexpr CudaGraphAnnotation_t kCudaGraphAnnotationDefault = 0; + +struct CudaGraphSet { + CudaGraphSet(){}; + ~CudaGraphSet(); -struct CUDAGraph { - CUDAGraph(){}; - CUDAGraph(cudaStream_t stream); - ~CUDAGraph(); + void Clear(); + bool Contains(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + void Put(CudaGraphAnnotation_t cuda_graph_annotation_id, cudaGraphExec_t graph_exec); + cudaGraphExec_t 
Get(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + + private: + CudaGraphSet_t cuda_graphs_; +}; + +struct CUDAGraphManager { + CUDAGraphManager(){}; + CUDAGraphManager(cudaStream_t stream); + ~CUDAGraphManager(); void SetStream(cudaStream_t stream); - void CaptureBegin(); - void CaptureEnd(); - Status Replay(); + void CaptureBegin(CudaGraphAnnotation_t cuda_graph_annotation_id); + void CaptureEnd(CudaGraphAnnotation_t cuda_graph_annotation_id); + Status Replay(CudaGraphAnnotation_t cuda_graph_annotation_id); + void Reset(); - private: - cudaGraph_t graph_ = NULL; - cudaGraphExec_t graph_exec_ = NULL; + bool IsGraphCaptureAllowedOnRun(CudaGraphAnnotation_t cuda_graph_annotation_id) const; + bool IsGraphCaptured(CudaGraphAnnotation_t cuda_graph_annotation_id) const; - bool has_graph_ = false; - bool has_graph_exec_ = false; + private: + CudaGraphSet cuda_graph_set_; + CudaGraphAnnotation_t cuda_graph_annotation_id_ = kCudaGraphAnnotationDefault; cudaStream_t stream_ = nullptr; // Does not own the stream }; +using CUDAGraph = CUDAGraphManager; + } // namespace onnxruntime diff --git a/onnxruntime/core/providers/js/js_execution_provider.cc b/onnxruntime/core/providers/js/js_execution_provider.cc index 62c3981682cfc..2d2c89f36f1a7 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.cc +++ b/onnxruntime/core/providers/js/js_execution_provider.cc @@ -757,7 +757,7 @@ JsExecutionProvider::~JsExecutionProvider() { } Status JsExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { - if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + if (IsGraphCaptureEnabled() && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { LOGS(*GetLogger(), INFO) << "Capturing the webgpu graph for this model"; EM_ASM({ Module.jsepCaptureBegin(); }); } @@ -765,7 +765,7 @@ Status JsExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_opti } Status JsExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { - if (IsGraphCaptureEnabled() && !IsGraphCaptured()) { + if (IsGraphCaptureEnabled() && !IsGraphCaptured(0)) { if (IsGraphCaptureAllowed()) { EM_ASM({ Module.jsepCaptureEnd(); }); is_graph_captured_ = true; @@ -781,12 +781,12 @@ bool JsExecutionProvider::IsGraphCaptureEnabled() const { return enable_graph_capture_; } -bool JsExecutionProvider::IsGraphCaptured() const { +bool JsExecutionProvider::IsGraphCaptured(int) const { return is_graph_captured_; } -Status JsExecutionProvider::ReplayGraph() { - ORT_ENFORCE(IsGraphCaptured()); +Status JsExecutionProvider::ReplayGraph(int) { + ORT_ENFORCE(IsGraphCaptured(0)); EM_ASM({ Module.jsepReplay(); }); return Status::OK(); } diff --git a/onnxruntime/core/providers/js/js_execution_provider.h b/onnxruntime/core/providers/js/js_execution_provider.h index b4518c67d1e60..efacf510e75df 100644 --- a/onnxruntime/core/providers/js/js_execution_provider.h +++ b/onnxruntime/core/providers/js/js_execution_provider.h @@ -63,8 +63,8 @@ class JsExecutionProvider : public IExecutionProvider { Status OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& run_options) override; bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured() const override; - Status ReplayGraph() override; + bool IsGraphCaptured(int graph_annotation_id) const override; + Status ReplayGraph(int graph_annotation_id) override; private: bool IsGraphCaptureAllowed() const; diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc 
b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc index 4a679b790ee40..32be74550951e 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.cc +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.cc @@ -183,23 +183,24 @@ bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptureAllowed() const { return regular_run_count_before_graph_capture_ >= min_num_runs_before_hip_graph_capture_; } -void ROCMExecutionProvider::PerThreadContext::CaptureBegin() { +void ROCMExecutionProvider::PerThreadContext::CaptureBegin(int) { hip_graph_.Reset(); - hip_graph_.CaptureBegin(); + hip_graph_.CaptureBegin(0); } -void ROCMExecutionProvider::PerThreadContext::CaptureEnd() { - hip_graph_.CaptureEnd(); +void ROCMExecutionProvider::PerThreadContext::CaptureEnd(int) { + hip_graph_.CaptureEnd(0); is_graph_captured_ = true; } -bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured() const { +bool ROCMExecutionProvider::PerThreadContext::IsGraphCaptured(int) const { return is_graph_captured_; } -Status ROCMExecutionProvider::PerThreadContext::ReplayGraph() { - ORT_ENFORCE(IsGraphCaptured()); - return hip_graph_.Replay(); +Status ROCMExecutionProvider::PerThreadContext::ReplayGraph(int graph_annotation_id) { + ORT_ENFORCE(IsGraphCaptured(graph_annotation_id)); + + return hip_graph_.Replay(graph_annotation_id); } void ROCMExecutionProvider::PerThreadContext::IncrementRegularRunCountBeforeGraphCapture() { @@ -356,20 +357,20 @@ Status ROCMExecutionProvider::Sync() const { Status ROCMExecutionProvider::OnRunStart(const onnxruntime::RunOptions& /*run_options*/) { // always set ROCM device when session::Run() in case it runs in a worker thread HIP_RETURN_IF_ERROR(hipSetDevice(GetDeviceId())); - if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured()) { + if (IsGraphCaptureEnabled() && GetPerThreadContext().IsGraphCaptureAllowed() && !GetPerThreadContext().IsGraphCaptured(0)) { LOGS_DEFAULT(INFO) << "Capturing the hip graph for this model"; - GetPerThreadContext().CaptureBegin(); + GetPerThreadContext().CaptureBegin(0); } return Status::OK(); } Status ROCMExecutionProvider::OnRunEnd(bool sync_stream, const onnxruntime::RunOptions& /*run_options*/) { - if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured()) { + if (IsGraphCaptureEnabled() && !GetPerThreadContext().IsGraphCaptured(0)) { if (GetPerThreadContext().IsGraphCaptureAllowed()) { - GetPerThreadContext().CaptureEnd(); + GetPerThreadContext().CaptureEnd(0); // HIP work issued to a capturing stream doesn’t actually run on the GPU, // so run the captured graph here to actually execute the work. 
- ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph()); + ORT_RETURN_IF_ERROR(GetPerThreadContext().ReplayGraph(0)); } else { GetPerThreadContext().IncrementRegularRunCountBeforeGraphCapture(); } @@ -400,12 +401,12 @@ bool ROCMExecutionProvider::IsGraphCaptureEnabled() const { return info_.enable_hip_graph; } -bool ROCMExecutionProvider::IsGraphCaptured() const { - return GetPerThreadContext().IsGraphCaptured(); +bool ROCMExecutionProvider::IsGraphCaptured(int) const { + return GetPerThreadContext().IsGraphCaptured(0); } -Status ROCMExecutionProvider::ReplayGraph() { - return GetPerThreadContext().ReplayGraph(); +Status ROCMExecutionProvider::ReplayGraph(int /*graph_annotation_id*/) { + return GetPerThreadContext().ReplayGraph(0); } namespace rocm { diff --git a/onnxruntime/core/providers/rocm/rocm_execution_provider.h b/onnxruntime/core/providers/rocm/rocm_execution_provider.h index da671d9e863bb..6d6c05027e7bd 100644 --- a/onnxruntime/core/providers/rocm/rocm_execution_provider.h +++ b/onnxruntime/core/providers/rocm/rocm_execution_provider.h @@ -75,8 +75,8 @@ class ROCMExecutionProvider : public IExecutionProvider { std::unique_ptr GetProfiler() override; bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured() const override; - Status ReplayGraph() override; + bool IsGraphCaptured(int graph_annotation_id) const override; + Status ReplayGraph(int graph_annotation_id) override; void RegisterStreamHandlers(IStreamCommandHandleRegistry& stream_handle_registry, AllocatorMap& allocators) const override; OrtDevice GetOrtDeviceByMemType(OrtMemType mem_type) const override; std::vector CreatePreferredAllocators() override; @@ -139,10 +139,10 @@ class ROCMExecutionProvider : public IExecutionProvider { } bool IsGraphCaptureAllowed() const; - void CaptureBegin(); - void CaptureEnd(); - bool IsGraphCaptured() const; - Status ReplayGraph(); + void CaptureBegin(int graph_annotation_id); + void CaptureEnd(int graph_annotation_id); + bool IsGraphCaptured(int graph_annotation_id) const; + Status ReplayGraph(int graph_annotation_id); void IncrementRegularRunCountBeforeGraphCapture(); private: diff --git a/onnxruntime/core/providers/shared_library/provider_api.h b/onnxruntime/core/providers/shared_library/provider_api.h index b78279040acb6..1cebe4a256fd4 100644 --- a/onnxruntime/core/providers/shared_library/provider_api.h +++ b/onnxruntime/core/providers/shared_library/provider_api.h @@ -159,6 +159,7 @@ class OpKernel; struct OpKernelContext; struct OpKernelInfo; struct PrimitiveDataTypeBase; +struct OrtRunOptions; struct Tensor; struct SparseTensor; class TensorSeq; diff --git a/onnxruntime/core/providers/shared_library/provider_interfaces.h b/onnxruntime/core/providers/shared_library/provider_interfaces.h index f5a8327443864..0b8551e0c5a66 100644 --- a/onnxruntime/core/providers/shared_library/provider_interfaces.h +++ b/onnxruntime/core/providers/shared_library/provider_interfaces.h @@ -481,6 +481,9 @@ struct ProviderHost { // ConfigOptions virtual std::optional ConfigOptions__GetConfigEntry(const ConfigOptions* p, const std::string& config_key) = 0; + // OrtRunOptions + virtual const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) = 0; + // ComputeCapability virtual std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) = 0; virtual void ComputeCapability__operator_delete(ComputeCapability* p) = 0; diff --git a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h 
index dde4005c80b9d..dc2b79015d95e 100644 --- a/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h +++ b/onnxruntime/core/providers/shared_library/provider_wrappedtypes.h @@ -393,6 +393,14 @@ struct ConfigOptions final { PROVIDER_DISALLOW_ALL(ConfigOptions) }; +struct OrtRunOptions final { + const ConfigOptions& GetConfigOptions() const { + return g_host->RunOptions__GetConfigOptions(this); + } + + PROVIDER_DISALLOW_ALL(OrtRunOptions) +}; + struct ComputeCapability final { static std::unique_ptr Create(std::unique_ptr t_sub_graph) { return g_host->ComputeCapability__construct(std::move(t_sub_graph)); } static void operator delete(void* p) { g_host->ComputeCapability__operator_delete(reinterpret_cast(p)); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc index e521640681a77..632d521dc21a8 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.cc @@ -1633,26 +1633,26 @@ bool TensorrtExecutionProvider::IsGraphCaptureAllowed() const { return regular_run_count_before_graph_capture_ >= min_num_runs_before_cuda_graph_capture_; } -void TensorrtExecutionProvider::CaptureBegin() { +void TensorrtExecutionProvider::CaptureBegin(int) { cuda_graph_.Reset(); - cuda_graph_.CaptureBegin(); + cuda_graph_.CaptureBegin(0); } -void TensorrtExecutionProvider::CaptureEnd() { - cuda_graph_.CaptureEnd(); +void TensorrtExecutionProvider::CaptureEnd(int) { + cuda_graph_.CaptureEnd(0); is_graph_captured_ = true; } -bool TensorrtExecutionProvider::IsGraphCaptured() const { +bool TensorrtExecutionProvider::IsGraphCaptured(int) const { return is_graph_captured_; } -Status TensorrtExecutionProvider::ReplayGraph() { - ORT_ENFORCE(IsGraphCaptured()); +Status TensorrtExecutionProvider::ReplayGraph(int) { + ORT_ENFORCE(IsGraphCaptured(0)); // Please note that CUDAGraph::Replay() is not thread safe. - // ORT TRT calls ReplayGraph() in compute_func() where synchromization is enforced due to lock_guard(), + // ORT TRT calls ReplayGraph() in compute_func() where synchronization is enforced due to lock_guard(), // therefore calling CUDAGraph::Replay() here is guaranteed to be thread safe. - return cuda_graph_.Replay(); + return cuda_graph_.Replay(0); } void TensorrtExecutionProvider::IncrementRegularRunCountBeforeGraphCapture() { @@ -3412,10 +3412,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; cuda_graph_.SetStream(stream); - CaptureBegin(); + CaptureBegin(0); } // Run TRT inference @@ -3483,12 +3483,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromGraph(const GraphView // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. 
// It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured()) { + if (cuda_graph_enable_ && !IsGraphCaptured(0)) { if (IsGraphCaptureAllowed()) { - CaptureEnd(); + CaptureEnd(0); // CUDA work issued to a capturing stream doesn’t actually run on the GPU, // so run the captured graph here to actually execute the work. - ORT_RETURN_IF_ERROR(ReplayGraph()); + ORT_RETURN_IF_ERROR(ReplayGraph(0)); } else { IncrementRegularRunCountBeforeGraphCapture(); } @@ -3705,10 +3705,10 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con // Start CUDA graph capture. // Note: The reason we don't put graph capture in OnRunStart() like CUDA EP does is because // current ORT TRT doesn't get cuda stream until compute time and graph capture requires cuda stream. - if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured()) { + if (cuda_graph_enable_ && IsGraphCaptureAllowed() && !IsGraphCaptured(0)) { LOGS_DEFAULT(INFO) << "Capturing the cuda graph for this model"; cuda_graph_.SetStream(stream); - CaptureBegin(); + CaptureBegin(0); } // Run TRT inference @@ -3776,12 +3776,12 @@ Status TensorrtExecutionProvider::CreateNodeComputeInfoFromPrecompiledEngine(con // Note: One reason we don't put end of graph capture in OnRunEnd() like CUDA EP does is because of cuda stream mentioned in graph capture // above, another reason is because OnRunEnd() is not synchronized with OnRunStart() and ExecuteGraph() per inference_session.cc. // It's safe to start/end CUDA graph capture in compute_func() here since cuda graph object is maintained by a per thread basis. - if (cuda_graph_enable_ && !IsGraphCaptured()) { + if (cuda_graph_enable_ && !IsGraphCaptured(0)) { if (IsGraphCaptureAllowed()) { - CaptureEnd(); + CaptureEnd(0); // CUDA work issued to a capturing stream doesn’t actually run on the GPU, // so run the captured graph here to actually execute the work. 
- ORT_RETURN_IF_ERROR(ReplayGraph()); + ORT_RETURN_IF_ERROR(ReplayGraph(0)); } else { IncrementRegularRunCountBeforeGraphCapture(); } diff --git a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h index 339c45a8742d2..f73031eaefceb 100644 --- a/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h +++ b/onnxruntime/core/providers/tensorrt/tensorrt_execution_provider.h @@ -250,8 +250,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector CreatePreferredAllocators() override; bool IsGraphCaptureEnabled() const override; - bool IsGraphCaptured() const override; - Status ReplayGraph() override; + bool IsGraphCaptured(int graph_annotation_id) const override; + Status ReplayGraph(int graph_annotation_id) override; private: mutable TensorrtExecutionProviderInfo info_; @@ -373,10 +373,10 @@ class TensorrtExecutionProvider : public IExecutionProvider { void InitCUDAGraph(); void SetGraphStream(cudaStream_t stream); bool IsGraphCaptureAllowed() const; - void CaptureBegin(); - void CaptureEnd(); - bool IsGraphCaptured() const; - Status ReplayGraph(); + void CaptureBegin(int graph_annotation_id); + void CaptureEnd(int graph_annotation_id); + bool IsGraphCaptured(int graph_annotation_id) const; + Status ReplayGraph(int graph_annotation_id); void IncrementRegularRunCountBeforeGraphCapture(); private: @@ -540,8 +540,8 @@ class TensorrtExecutionProvider : public IExecutionProvider { std::vector& node_compute_funcs); bool IsGraphCaptureAllowed() const; - void CaptureBegin(); - void CaptureEnd(); + void CaptureBegin(int graph_annotation_id); + void CaptureEnd(int graph_annotation_id); void IncrementRegularRunCountBeforeGraphCapture(); /** diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 5fd66c459d382..684f390857d0b 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -2383,21 +2383,32 @@ Status InferenceSession::Run(const RunOptions& run_options, Status retval = Status::OK(); const Env& env = Env::Default(); + int graph_annotation_id = 0; + const std::string& graph_annotation_str = + run_options.config_options.GetConfigOrDefault(kOrtRunOptionsConfigCudaGraphAnnotation, ""); + if (!graph_annotation_str.empty()) { + if (!TryParseStringWithClassicLocale(graph_annotation_str, graph_annotation_id)) { + return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "Failed to parse the cuda graph annotation id: ", + graph_annotation_str); + } + } + // Increment/decrement concurrent_num_runs_ and control // session threads spinning as configured. Do nothing for graph replay except the counter. const bool control_spinning = use_per_session_threads_ && force_spinning_stop_between_runs_ && - !cached_execution_provider_for_graph_replay_.IsGraphCaptured(); + !cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id); auto* intra_tp = (control_spinning) ? thread_pool_.get() : nullptr; auto* inter_tp = (control_spinning) ? inter_op_thread_pool_.get() : nullptr; ThreadPoolSpinningSwitch runs_refcounter_and_tp_spin_control(intra_tp, inter_tp, current_num_runs_); // Check if this Run() is simply going to be a CUDA Graph replay. 
- if (cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { + if (cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id)) { LOGS(*session_logger_, INFO) << "Replaying the captured " << cached_execution_provider_for_graph_replay_.Type() - << " CUDA Graph for this model with tag: " << run_options.run_tag; - ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph()); + << " CUDA Graph for this model with tag: " << run_options.run_tag + << " with graph annotation id: " << graph_annotation_id; + ORT_RETURN_IF_ERROR_SESSIONID_(cached_execution_provider_for_graph_replay_.ReplayGraph(graph_annotation_id)); } else { InlinedVector exec_providers_to_stop; exec_providers_to_stop.reserve(execution_providers_.NumProviders()); @@ -2559,7 +2570,8 @@ Status InferenceSession::Run(const RunOptions& run_options, // N is defined in min_num_runs_before_hip_graph_capture_ for ROCM EP, // and the value could be different for other EP. if (retval.IsOK() && cached_execution_provider_for_graph_replay_.IsGraphCaptureEnabled() && - !cached_execution_provider_for_graph_replay_.IsGraphCaptured()) { + cached_execution_provider_for_graph_replay_.AllowGraphCaptureOnRun(graph_annotation_id) && + !cached_execution_provider_for_graph_replay_.IsGraphCaptured(graph_annotation_id)) { LOGS(*session_logger_, INFO) << "Start another run for necessary memory allocation or graph capture."; ORT_RETURN_IF_ERROR(Run(run_options, feed_names, feeds, output_names, p_fetches, p_fetches_device_info)); } diff --git a/onnxruntime/core/session/inference_session.h b/onnxruntime/core/session/inference_session.h index f8211bfd2dd4e..3038c8d22ec80 100644 --- a/onnxruntime/core/session/inference_session.h +++ b/onnxruntime/core/session/inference_session.h @@ -675,7 +675,6 @@ class InferenceSession { * If we encounter an invalid request, we return an error * back to the user. */ - [[nodiscard]] common::Status ValidateAndParseShrinkArenaString(const std::string& ort_device_list, /*out*/ InlinedVector& arenas_to_shrink) const; @@ -867,14 +866,17 @@ class InferenceSession { return cached_execution_provider_for_graph_replay_ != nullptr && cached_execution_provider_for_graph_replay_->IsGraphCaptureEnabled(); } - bool IsGraphCaptured() const { - return cached_execution_provider_for_graph_replay_ != nullptr && cached_execution_provider_for_graph_replay_->IsGraphCaptured(); + bool IsGraphCaptured(int graph_annotation_id) const { + return cached_execution_provider_for_graph_replay_ != nullptr && cached_execution_provider_for_graph_replay_->IsGraphCaptured(graph_annotation_id); + } + + bool AllowGraphCaptureOnRun(int graph_annotation_id) const { + return cached_execution_provider_for_graph_replay_ != nullptr && graph_annotation_id != kGraphAnnotationSkip; } - Status ReplayGraph() { - ORT_ENFORCE(IsGraphCaptured()); + Status ReplayGraph(int graph_annotation_id) { if (cached_execution_provider_for_graph_replay_) { - return cached_execution_provider_for_graph_replay_->ReplayGraph(); + return cached_execution_provider_for_graph_replay_->ReplayGraph(graph_annotation_id); } return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Cached EP instance for graph replay is not set yet before calling ReplayGraph()"); } @@ -884,6 +886,8 @@ class InferenceSession { } IExecutionProvider* cached_execution_provider_for_graph_replay_ = nullptr; + // TODO(wy): Same as kCudaGraphAnnotationSkip in cuda_graph.h. Move to a common place. 
+ constexpr static int kGraphAnnotationSkip = -1; }; CachedExecutionProviderForGraphReplay cached_execution_provider_for_graph_replay_; diff --git a/onnxruntime/core/session/provider_bridge_ort.cc b/onnxruntime/core/session/provider_bridge_ort.cc index 3bec9aa146f76..d6797512d9e47 100644 --- a/onnxruntime/core/session/provider_bridge_ort.cc +++ b/onnxruntime/core/session/provider_bridge_ort.cc @@ -14,6 +14,7 @@ #include "core/framework/execution_provider.h" #include "core/framework/kernel_registry.h" #include "core/framework/provider_shutdown.h" +#include "core/framework/run_options.h" #include "core/framework/tensorprotoutils.h" #include "core/framework/TensorSeq.h" #include "core/framework/provider_options.h" @@ -676,6 +677,9 @@ struct ProviderHostImpl : ProviderHost { return p->GetConfigEntry(config_key); } + // OrtRunOptions (wrapped) + const ConfigOptions& RunOptions__GetConfigOptions(const RunOptions* p) override { return p->config_options; } + // ComputeCapability (wrapped) std::unique_ptr ComputeCapability__construct(std::unique_ptr t_sub_graph) override { return std::make_unique(std::move(t_sub_graph)); } void ComputeCapability__operator_delete(ComputeCapability* p) override { delete p; } diff --git a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py index 796d6ec55ef80..8083778423241 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py +++ b/onnxruntime/python/tools/transformers/models/phi2/convert_to_onnx.py @@ -13,6 +13,7 @@ import torch from benchmark_helper import Precision from fusion_options import AttentionOpType +from onnx_model import OnnxModel from transformers import AutoConfig, AutoModelForCausalLM from onnxruntime.quantization.matmul_4bits_quantizer import MatMul4BitsQuantizer @@ -168,6 +169,58 @@ def optimize_phi2_onnx(self, onnx_path: str, onnx_path_opt: str): quant.process() quant.model.save_model_to_file(onnx_path_opt, use_external_data_format=True) + # This function currently only works for phi2 model + def convert_to_use_cuda_graph(self, in_onnx_path: str, out_onnx_path: str): + onnx_model = OnnxModel(onnx.load(in_onnx_path, load_external_data=True)) + + from onnx import TensorProto, helper + + graph = onnx_model.graph() + new_inputs = [] + for vi in graph.input: + if "attention_mask" in vi.name: + vi_seqlen_k = helper.make_tensor_value_info( + "seqlens_k", + elem_type=TensorProto.INT32, + shape=["batch_size"], + ) + vi_total_seq_len = helper.make_tensor_value_info( + "total_sequence_length", + elem_type=TensorProto.INT32, + shape=[1], + ) + new_inputs.extend([vi_seqlen_k, vi_total_seq_len]) + else: + new_inputs.append(vi) + + graph.ClearField("input") + graph.input.extend(new_inputs) + + gqas = onnx_model.get_nodes_by_op_type("GroupQueryAttention") + gqa = gqas[0] + seqlens_path = onnx_model.match_parent_path( + gqa, + ["Cast", "Sub", "ReduceSum", "Cast"], + [5, 0, 0, 0], + ) + if seqlens_path is None: + raise RuntimeError("Failed to find seqlens path for GroupQueryAttention node.") + total_seq_len_path = onnx_model.match_parent_path( + gqa, + ["Cast", "Gather", "Shape"], + [6, 0, 0], + ) + if total_seq_len_path is None: + raise RuntimeError("Failed to find total_seq_len path for GroupQueryAttention node.") + onnx_model.remove_nodes(seqlens_path) + onnx_model.remove_nodes(total_seq_len_path) + + for gqa in gqas: + gqa.input[5] = "seqlens_k" + gqa.input[6] = "total_sequence_length" + + onnx_model.save(onnx_model.model, out_onnx_path, 
save_as_external_data=True) + def parse_arguments(): parser = argparse.ArgumentParser() @@ -235,6 +288,13 @@ def parse_arguments(): help="Generate int4 ONNX model for ORT VLLM", ) + parser.add_argument( + "--use_cuda_graph", + required=False, + action="store_true", + help="Use CUDA Graph in decoding process", + ) + parser.add_argument( "--overwrite", required=False, @@ -265,6 +325,13 @@ def parse_arguments(): help="Run ORT inference example", ) + parser.add_argument( + "--run_benchmark", + required=False, + action="store_true", + help="Run ORT benchmark", + ) + parser.add_argument( "--skip_export", required=False, @@ -375,6 +442,9 @@ def run_optimize_phi2_onnx( ): converter.init_attn_type_and_precision(attention_type, precision) converter.optimize_phi2_onnx(original_onnx_path, optimized_onnx_path) + if args.use_cuda_graph: + assert args.fp16_gpu_sm8x or args.int4_gpu_sm8x + converter.convert_to_use_cuda_graph(optimized_onnx_path, optimized_onnx_path) processes = [] if args.fp32_cpu: @@ -447,7 +517,7 @@ def run_optimize_phi2_onnx( [p.start() for p in processes] [p.join() for p in processes] - if args.run_example: + if args.run_example or args.run_benchmark: from inference_example import run_phi2 if args.fp16_gpu_sm8x: @@ -457,6 +527,8 @@ def run_optimize_phi2_onnx( use_buffer_share=True, device_id=args.device_id, use_step=True, + use_cuda_graph=args.use_cuda_graph, + run_benchmark=args.run_benchmark, ) if args.int4_gpu_sm8x: logging.info("Running int4_gpu_sm8x example...") @@ -465,6 +537,8 @@ def run_optimize_phi2_onnx( use_buffer_share=True, device_id=args.device_id, use_step=True, + use_cuda_graph=args.use_cuda_graph, + run_benchmark=args.run_benchmark, ) if args.fp32_gpu: logging.info("Running fp32_gpu example...") @@ -474,6 +548,7 @@ def run_optimize_phi2_onnx( device_id=args.device_id, packed_kv=True, use_fp16=False, + run_benchmark=args.run_benchmark, ) if args.fp16_gpu: logging.info("Running fp16_gpu example...") @@ -482,6 +557,7 @@ def run_optimize_phi2_onnx( use_buffer_share=False, device_id=args.device_id, packed_kv=True, + run_benchmark=args.run_benchmark, ) if args.int4_gpu: logging.info("Running int4_gpu example...") @@ -490,6 +566,7 @@ def run_optimize_phi2_onnx( use_buffer_share=False, device_id=args.device_id, packed_kv=True, + run_benchmark=args.run_benchmark, ) if args.fp32_cpu or args.int4_cpu or args.fp16_vllm or args.int4_vllm: raise NotImplementedError("CPU/vllm inference example is not implemented yet.") diff --git a/onnxruntime/python/tools/transformers/models/phi2/inference_example.py b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py index 28828ffb853cb..829334b46b469 100644 --- a/onnxruntime/python/tools/transformers/models/phi2/inference_example.py +++ b/onnxruntime/python/tools/transformers/models/phi2/inference_example.py @@ -17,6 +17,17 @@ } +def cuda_memcpy(dst, src): + from cuda import cudart + + cudart.cudaMemcpy( + dst.data_ptr(), + src.data_ptr(), + src.element_size() * src.nelement(), + cudart.cudaMemcpyKind.cudaMemcpyDeviceToDevice, + ) + + class ORTGenerator: def __init__(self, decoder_path): self.onnx_decoder_path = decoder_path @@ -24,13 +35,68 @@ def __init__(self, decoder_path): self.head_size = 80 self.num_layers = 32 self.max_sequence_length = 2048 + self.device_id = 0 + self.use_cuda_graph = False + self.use_traced_inputs = False + self.static_inputs_map = {} + + def append_static_inputs(self, batch_size): + # Only use this function with GQA and with use_cuda_graph=True + if batch_size in self.static_inputs_map: + return + 
+ cpu_device = torch.device("cpu") + cuda_device = torch.device("cuda", self.device_id) + + static_io = {} + static_io["input_ids"] = torch.zeros((batch_size, 1), dtype=torch.int32, device=cuda_device) + static_io["step"] = torch.tensor([0], dtype=torch.int64, device=cuda_device) + static_io["seqlens_k"] = torch.tensor(batch_size * [0], dtype=torch.int32, device=cuda_device) + static_io["total_sequence_length"] = torch.tensor([0], dtype=torch.int32, device=cpu_device) + + cache_shape = (batch_size, self.num_heads, self.max_sequence_length, self.head_size) + for i in range(self.num_layers): + cache = torch.zeros(cache_shape, device=cuda_device, dtype=torch.float16) + static_io.update({f"past_key_{i}": cache.contiguous(), f"past_value_{i}": cache.clone().contiguous()}) + + static_io["logits"] = torch.zeros((batch_size, 1, 51200), dtype=torch.float16, device=cuda_device) + + self.static_inputs_map[batch_size] = static_io def get_initial_inputs_and_outputs(self, encodings_dict): self.torch_dtype = torch.float16 if self.use_fp16 else torch.float32 input_ids = torch.tensor(encodings_dict["input_ids"], device=self.device, dtype=torch.int32) attention_mask = torch.tensor(encodings_dict["attention_mask"], device=self.device, dtype=torch.int32) - step = torch.tensor([0], device=self.device, dtype=torch.int64) + + batch_size, sequence_length = input_ids.shape + + self.use_traced_inputs = ( + self.use_cuda_graph + and (batch_size in self.static_inputs_map) + and self.use_buffer_share + and not self.packed_kv + ) + + step = ( + torch.tensor([0], device=self.device, dtype=torch.int64) + if not self.use_traced_inputs + else self.static_inputs_map[batch_size]["step"] + ) + + seqlens_k = ( + torch.tensor(batch_size * [0], device=self.device, dtype=torch.int32) + if not self.use_traced_inputs + else self.static_inputs_map[batch_size]["seqlens_k"] + ) + cuda_memcpy(seqlens_k, attention_mask.sum(1).sub(1).to(torch.int32)) + + total_seq_length = ( + torch.tensor([0], device=torch.device("cpu"), dtype=torch.int32) + if not self.use_traced_inputs + else self.static_inputs_map[batch_size]["total_sequence_length"] + ) + total_seq_length[0] = sequence_length inputs = { "input_ids": input_ids.contiguous(), @@ -40,7 +106,10 @@ def get_initial_inputs_and_outputs(self, encodings_dict): if self.use_step: inputs["step"] = step.contiguous() - batch_size, sequence_length = input_ids.shape + if self.use_cuda_graph: + inputs["seqlens_k"] = seqlens_k.contiguous() + inputs["total_sequence_length"] = total_seq_length.contiguous() + del inputs["attention_mask"] past_seq_length = self.max_sequence_length if self.use_buffer_share else 0 past_shape = ( @@ -48,11 +117,21 @@ def get_initial_inputs_and_outputs(self, encodings_dict): if self.packed_kv else (batch_size, self.num_heads, past_seq_length, self.head_size) ) - for i in range(self.num_layers): - past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype) - inputs.update( - {f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()} - ) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()}) + + if not self.use_traced_inputs: + for i in range(self.num_layers): + past = torch.zeros(past_shape, device=self.device, dtype=self.torch_dtype) + inputs.update( + {f"past_key_{i}": past.contiguous(), f"past_value_{i}": past.clone().contiguous()} + ) if not self.packed_kv else inputs.update({f"past_{i}": past.contiguous()}) + else: + for i in range(self.num_layers): + inputs.update( + { + f"past_key_{i}": 
self.static_inputs_map[batch_size][f"past_key_{i}"].contiguous(), + f"past_value_{i}": self.static_inputs_map[batch_size][f"past_value_{i}"].contiguous(), + } + ) logits = torch.zeros(batch_size, sequence_length, 51200, device=self.device, dtype=self.torch_dtype) outputs = {"logits": logits.contiguous()} @@ -111,12 +190,23 @@ def apply_io_binding(self, model: ort.InferenceSession, inputs: dict, outputs: d return io_binding - def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False): + def create_session( + self, device_id, use_fp16=True, use_buffer_share=True, packed_kv=False, use_step=False, use_cuda_graph=False + ): + self.device_id = device_id sess_options = ort.SessionOptions() - ep = ("CUDAExecutionProvider", {"device_id": device_id}) if device_id >= 0 else "CPUExecutionProvider" + sess_options.log_verbosity_level = 4 + sess_options.log_severity_level = 4 + self.use_cuda_graph = use_cuda_graph + ep = ( + ("CUDAExecutionProvider", {"device_id": self.device_id, "enable_cuda_graph": self.use_cuda_graph}) + if self.device_id >= 0 + else "CPUExecutionProvider" + ) self.sess = ort.InferenceSession(self.onnx_decoder_path, sess_options=sess_options, providers=[ep]) + self.ro = ort.RunOptions() - self.device = torch.device("cuda", device_id) if torch.cuda.is_available() else torch.device("cpu") + self.device = torch.device("cuda", self.device_id) if torch.cuda.is_available() else torch.device("cpu") self.use_fp16 = use_fp16 self.use_buffer_share = use_buffer_share self.packed_kv = packed_kv @@ -125,9 +215,7 @@ def create_session(self, device_id, use_fp16=True, use_buffer_share=True, packed self.tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True) self.tokenizer.pad_token = "[PAD]" - def generate(self, prompt, max_length): - encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True) - + def generate_impl(self, encodings_dict, max_length, cuda_graph_annotation, benchmark=False): inputs, outputs = self.get_initial_inputs_and_outputs(encodings_dict) all_token_ids = inputs["input_ids"].clone() @@ -136,13 +224,38 @@ def generate(self, prompt, max_length): current_length = sequence_length has_eos = torch.zeros(batch_size, device=self.device, dtype=torch.bool) + if benchmark: + import time + + latency = [] + + prompt_run = True while current_length < max_length: io_binding = self.apply_io_binding(self.sess, inputs, outputs) + if benchmark: + start = time.time() + io_binding.synchronize_inputs() - self.sess.run_with_iobinding(io_binding) + if prompt_run: + if self.use_cuda_graph: + # Disable CUDA graph for the prompt run + self.ro.add_run_config_entry("gpu_graph_id", "-1") + self.sess.run_with_iobinding(io_binding, self.ro) + if self.use_cuda_graph: + # Enable CUDA graph for the decoding run + self.ro.add_run_config_entry( + "gpu_graph_id", str(cuda_graph_annotation) if self.use_traced_inputs else "-1" + ) + prompt_run = False + else: + self.sess.run_with_iobinding(io_binding, self.ro) io_binding.synchronize_outputs() + if benchmark: + end = time.time() + latency.append(end - start) + # Sample with argmax (greedy search) next_token_logits = outputs["logits"][:, -1, :] next_tokens = torch.argmax(next_token_logits, dim=-1) @@ -161,16 +274,37 @@ def generate(self, prompt, max_length): # Update inputs for next inference run current_length += 1 + inputs["input_ids"] = tokens_to_add.to(torch.int32) + if self.use_traced_inputs: + cuda_memcpy(self.static_inputs_map[batch_size]["input_ids"], inputs["input_ids"]) + 
inputs["input_ids"] = self.static_inputs_map[batch_size]["input_ids"] + if self.use_step: inputs["step"] = torch.tensor([current_length - 1], device=self.device, dtype=torch.int64) - inputs["attention_mask"] = torch.cat([inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1).to( - torch.int32 - ) + if self.use_traced_inputs: + cuda_memcpy(self.static_inputs_map[batch_size]["step"], inputs["step"]) + inputs["step"] = self.static_inputs_map[batch_size]["step"] + + if self.use_cuda_graph: + previous_seqlens_k = inputs["seqlens_k"] + inputs["seqlens_k"] = (previous_seqlens_k + (~has_eos).reshape(batch_size, 1)).to(torch.int32) + inputs["total_sequence_length"][0] = current_length + if self.use_traced_inputs: + cuda_memcpy(self.static_inputs_map[batch_size]["seqlens_k"], inputs["seqlens_k"]) + inputs["seqlens_k"] = self.static_inputs_map[batch_size]["seqlens_k"] + self.static_inputs_map[batch_size]["total_sequence_length"][0] = inputs["total_sequence_length"][0] + inputs["total_sequence_length"] = self.static_inputs_map[batch_size]["total_sequence_length"] + else: + inputs["attention_mask"] = torch.cat( + [inputs["attention_mask"], (~has_eos).reshape(batch_size, 1)], 1 + ).to(torch.int32) # Set logits to zeros for next inference run and re-use memory buffer if outputs["logits"].shape[1] != 1: outputs["logits"] = outputs["logits"][:, :1, :].contiguous() + if self.use_traced_inputs: + outputs["logits"] = self.static_inputs_map[batch_size]["logits"] outputs["logits"].zero_() if not self.use_buffer_share: @@ -193,11 +327,59 @@ def generate(self, prompt, max_length): {f"present_key_{i}": present.contiguous(), f"present_value_{i}": present.clone().contiguous()} ) if not self.packed_kv else outputs.update({f"present_{i}": present.contiguous()}) + if benchmark: + print( + f"Batch size: {batch_size}, Sequence length: {sequence_length}, Token num: {max_length - sequence_length}" + ) + print(f"Prompt letency: {1000 * latency[0]}ms, Token latency: {1000 * np.mean(latency[1:])}ms") + return + texts = self.tokenizer.batch_decode(all_token_ids, skip_special_tokens=True) return texts + def generate(self, prompt, max_length, cuda_graph_annotation): + encodings_dict = self.tokenizer.batch_encode_plus(prompt, padding=True) + + return self.generate_impl(encodings_dict, max_length, cuda_graph_annotation) + + def generate_benchmark(self, prompt_shape, token_num, cuda_graph_annotation): + batch_size, sequence_length = prompt_shape + max_length = sequence_length + token_num + + encodings_dict = {} + encodings_dict["input_ids"] = torch.randint(0, 50264, (batch_size, sequence_length), dtype=torch.int32).tolist() + encodings_dict["attention_mask"] = torch.ones((batch_size, sequence_length), dtype=torch.int32).tolist() + + # Warm up run + self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=False) + + # Benchmark run + self.generate_impl(encodings_dict, max_length, cuda_graph_annotation, benchmark=True) + + +def run_phi2( + onnx_model_path, + use_buffer_share, + device_id, + packed_kv=False, + use_fp16=True, + use_step=False, + use_cuda_graph=False, + run_benchmark=False, +): + generator = ORTGenerator(onnx_model_path) + generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step, use_cuda_graph) + + def simple_run(prompt): + example_batch_size = len(prompt) + if use_cuda_graph: + generator.append_static_inputs(batch_size=example_batch_size) + texts = generator.generate(prompt, max_length=210, cuda_graph_annotation=example_batch_size) + + for i in 
range(len(texts)): + print("Prompt: ", prompt[i]) + print("Texts: ", texts[i]) -def run_phi2(onnx_model_path, use_buffer_share, device_id, packed_kv=False, use_fp16=True, use_step=False): prompt = [ '''```python def print_prime(n): @@ -206,10 +388,14 @@ def print_prime(n): """''' ] - generator = ORTGenerator(onnx_model_path) - generator.create_session(device_id, use_fp16, use_buffer_share, packed_kv, use_step) - texts = generator.generate(prompt, max_length=200) - - for i in range(len(texts)): - print("Prompt: ", prompt[i]) - print("Texts: ", texts[i]) + if not run_benchmark: + simple_run(prompt) + + # Run simple benchmark. Time the decoder only. + if run_benchmark: + token_num = 32 + for batch_size in [1, 2, 4, 8]: + generator.append_static_inputs(batch_size) + for sequence_length in [16, 512]: + prompt_shape = (batch_size, sequence_length) + generator.generate_benchmark(prompt_shape, token_num, cuda_graph_annotation=batch_size) diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py index c4e13e773535d..ce04dff2aecb0 100644 --- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py +++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py @@ -84,6 +84,7 @@ def test_select_ep_to_run_cuda_graph(self): elif "CUDAExecutionProvider" in onnxrt.get_available_providers(): providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})] self.run_model_with_cuda_graph(providers) + self.run_model_with_cuda_graph_annotation(providers) def run_model_with_cuda_graph(self, providers): INPUT_SIZE = 1280 # noqa: N806 @@ -100,13 +101,15 @@ def run_model_with_cuda_graph(self, providers): io_binding.bind_ortvalue_input("X", x_ortvalue) io_binding.bind_ortvalue_output("Y", y_ortvalue) + ro = onnxrt.RunOptions() + # One regular run for the necessary memory allocation and cuda graph capturing - session.run_with_iobinding(io_binding) + session.run_with_iobinding(io_binding, ro) expected_y = np.array([[5.0], [11.0], [17.0]] * INPUT_SIZE, dtype=np.float32) np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) # After capturing, CUDA graph replay happens from this Run onwards - session.run_with_iobinding(io_binding) + session.run_with_iobinding(io_binding, ro) np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05) # Update input and then replay CUDA graph @@ -116,7 +119,7 @@ def run_model_with_cuda_graph(self, providers): dtype=np.float32, ) ) - session.run_with_iobinding(io_binding) + session.run_with_iobinding(io_binding, ro) np.testing.assert_allclose( np.array([[50.0], [110.0], [170.0]] * INPUT_SIZE, dtype=np.float32), y_ortvalue.numpy(), @@ -124,6 +127,58 @@ def run_model_with_cuda_graph(self, providers): atol=1e-05, ) + def run_model_with_cuda_graph_annotation(self, providers): + INPUT_SIZE = 1280 # noqa: N806 + + x_base = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]] + y_base = [[0.0], [0.0], [0.0], [0.0]] + expected_y_base = [[5.0], [11.0], [17.0], [23.0]] + + x_base_mul_10 = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0], [70.0, 80.0]] + expected_y_base_mul_10 = [[50.0], [110.0], [170.0], [230.0]] + + test_num = 4 + + x_ortvalues = [] + y_ortvalues = [] + for i in range(test_num): + x = np.array(x_base[: i + 1][:] * INPUT_SIZE, dtype=np.float32) + y = np.array(y_base[: i + 1][:] * INPUT_SIZE, dtype=np.float32) + x_ortvalues.append(onnxrt.OrtValue.ortvalue_from_numpy(x, "cuda", 0)) + y_ortvalues.append(onnxrt.OrtValue.ortvalue_from_numpy(y, "cuda", 0)) 
diff --git a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py
index c4e13e773535d..ce04dff2aecb0 100644
--- a/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py
+++ b/onnxruntime/test/python/onnxruntime_test_python_cudagraph.py
@@ -84,6 +84,7 @@ def test_select_ep_to_run_cuda_graph(self):
         elif "CUDAExecutionProvider" in onnxrt.get_available_providers():
             providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})]
             self.run_model_with_cuda_graph(providers)
+            self.run_model_with_cuda_graph_annotation(providers)
 
     def run_model_with_cuda_graph(self, providers):
         INPUT_SIZE = 1280  # noqa: N806
@@ -100,13 +101,15 @@ def run_model_with_cuda_graph(self, providers):
         io_binding.bind_ortvalue_input("X", x_ortvalue)
         io_binding.bind_ortvalue_output("Y", y_ortvalue)
 
+        ro = onnxrt.RunOptions()
+
         # One regular run for the necessary memory allocation and cuda graph capturing
-        session.run_with_iobinding(io_binding)
+        session.run_with_iobinding(io_binding, ro)
         expected_y = np.array([[5.0], [11.0], [17.0]] * INPUT_SIZE, dtype=np.float32)
         np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05)
 
         # After capturing, CUDA graph replay happens from this Run onwards
-        session.run_with_iobinding(io_binding)
+        session.run_with_iobinding(io_binding, ro)
         np.testing.assert_allclose(expected_y, y_ortvalue.numpy(), rtol=1e-05, atol=1e-05)
 
         # Update input and then replay CUDA graph
@@ -116,7 +119,7 @@ def run_model_with_cuda_graph(self, providers):
                 dtype=np.float32,
             )
         )
-        session.run_with_iobinding(io_binding)
+        session.run_with_iobinding(io_binding, ro)
         np.testing.assert_allclose(
             np.array([[50.0], [110.0], [170.0]] * INPUT_SIZE, dtype=np.float32),
             y_ortvalue.numpy(),
@@ -124,6 +127,58 @@ def run_model_with_cuda_graph(self, providers):
             atol=1e-05,
         )
 
+    def run_model_with_cuda_graph_annotation(self, providers):
+        INPUT_SIZE = 1280  # noqa: N806
+
+        x_base = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
+        y_base = [[0.0], [0.0], [0.0], [0.0]]
+        expected_y_base = [[5.0], [11.0], [17.0], [23.0]]
+
+        x_base_mul_10 = [[10.0, 20.0], [30.0, 40.0], [50.0, 60.0], [70.0, 80.0]]
+        expected_y_base_mul_10 = [[50.0], [110.0], [170.0], [230.0]]
+
+        test_num = 4
+
+        x_ortvalues = []
+        y_ortvalues = []
+        for i in range(test_num):
+            x = np.array(x_base[: i + 1][:] * INPUT_SIZE, dtype=np.float32)
+            y = np.array(y_base[: i + 1][:] * INPUT_SIZE, dtype=np.float32)
+            x_ortvalues.append(onnxrt.OrtValue.ortvalue_from_numpy(x, "cuda", 0))
+            y_ortvalues.append(onnxrt.OrtValue.ortvalue_from_numpy(y, "cuda", 0))
+
+        onnxrt.set_default_logger_severity(0)
+        session = onnxrt.InferenceSession(get_name("matmul_2.onnx"), providers=providers)
+        io_bindings = [session.io_binding()] * test_num
+        ro = onnxrt.RunOptions()
+
+        # Regular run to capture CUDA graph
+        for i in range(test_num):
+            io_bindings[i].bind_ortvalue_input("X", x_ortvalues[i])
+            io_bindings[i].bind_ortvalue_output("Y", y_ortvalues[i])
+            # TODO: Temporarily remove the default cuda graph capture test for the first regular run
+            # because it fails on a training CI. Need to investigate the root cause.
+            ro.add_run_config_entry("gpu_graph_id", str(i + 1))
+            io_bindings[i].synchronize_inputs()
+            session.run_with_iobinding(io_bindings[i], ro)
+            io_bindings[i].synchronize_outputs()
+            expected_y = np.array(expected_y_base[: i + 1][:] * INPUT_SIZE, dtype=np.float32)
+            np.testing.assert_allclose(expected_y, y_ortvalues[i].numpy(), rtol=1e-05, atol=1e-05)
+
+        del ro
+        ro = onnxrt.RunOptions()
+
+        # After capturing, CUDA graph replay happens from this Run onwards
+        for i in range(test_num):
+            # Update input and then replay CUDA graph
+            x_ortvalues[i].update_inplace(np.array(x_base_mul_10[: i + 1][:] * INPUT_SIZE, dtype=np.float32))
+            ro.add_run_config_entry("gpu_graph_id", str(i + 1))
+            io_bindings[i].synchronize_inputs()
+            session.run_with_iobinding(io_bindings[i], ro)
+            io_bindings[i].synchronize_outputs()
+            expected_y = np.array(expected_y_base_mul_10[: i + 1][:] * INPUT_SIZE, dtype=np.float32)
+            np.testing.assert_allclose(expected_y, y_ortvalues[i].numpy(), rtol=1e-05, atol=1e-05)
+
     def test_arena_with_cuda_graph(self):
         if "CUDAExecutionProvider" in onnxrt.get_available_providers():
             # To test cuda graph capture, we set Arena extend strategy to be SameAsRequested so as to detect any
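Distilled from the test above, the end-to-end usage pattern for the new run option looks roughly like this. This is a sketch only: the model path, input/output names, and tensor shapes below are placeholders, not part of this change.

    import numpy as np
    import onnxruntime as onnxrt

    providers = [("CUDAExecutionProvider", {"enable_cuda_graph": True})]
    session = onnxrt.InferenceSession("model.onnx", providers=providers)  # placeholder model

    # Keep inputs/outputs at fixed device addresses so graph replay sees valid pointers.
    x = onnxrt.OrtValue.ortvalue_from_numpy(np.zeros((3, 2), dtype=np.float32), "cuda", 0)
    y = onnxrt.OrtValue.ortvalue_from_numpy(np.zeros((3, 1), dtype=np.float32), "cuda", 0)
    binding = session.io_binding()
    binding.bind_ortvalue_input("X", x)
    binding.bind_ortvalue_output("Y", y)

    ro = onnxrt.RunOptions()
    ro.add_run_config_entry("gpu_graph_id", "1")  # annotation id selecting which graph this run uses
    session.run_with_iobinding(binding, ro)  # first run with this id: regular run plus capture
    session.run_with_iobinding(binding, ro)  # later runs with the same id: replay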
diff --git a/onnxruntime/test/shared_lib/test_inference.cc b/onnxruntime/test/shared_lib/test_inference.cc
index 8dad2c8e2d10d..453b5fdd360bf 100644
--- a/onnxruntime/test/shared_lib/test_inference.cc
+++ b/onnxruntime/test/shared_lib/test_inference.cc
@@ -180,6 +180,9 @@ static void TestInference(Ort::Env& env, const std::basic_string<ORTCHAR_T>& mod
 }
 
 static constexpr PATH_TYPE MODEL_URI = TSTR("testdata/mul_1.onnx");
+#if defined(USE_CUDA)
+static constexpr PATH_TYPE CUDA_GRAPH_ANNOTATION_MODEL_URI = TSTR("testdata/mul_1_dynamic.onnx");
+#endif
 static constexpr PATH_TYPE MATMUL_MODEL_URI = TSTR("testdata/matmul_1.onnx");
 #ifndef ORT_NO_RTTI
 static constexpr PATH_TYPE SEQUENCE_MODEL_URI = TSTR("testdata/sequence_length.onnx");
@@ -2082,6 +2085,152 @@ TEST(CApiTest, basic_cuda_graph) {
 #endif
 }
 
+#if defined(USE_CUDA)
+struct CudaGraphInputOutputData_0 {
+  const std::array<int64_t, 2> x_shape = {3, 2};
+  std::array<float, 6> x_values = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
+  const std::array<int64_t, 2> expected_y_shape = {3, 2};
+  std::array<float, 6> expected_y = {1.0f, 4.0f, 9.0f, 16.0f, 25.0f, 36.0f};
+
+  std::array<float, 6> y_values;
+  std::array<float, 6> new_x_values = {10.0f, 20.0f, 30.0f, 40.0f, 50.0f, 60.0f};
+  std::array<float, 6> new_expected_y = {10.0f, 40.0f, 90.0f, 160.0f, 250.0f, 360.0f};
+} cg_data_0;
+
+struct CudaGraphInputOutputData_1 {
+  const std::array<int64_t, 2> x_shape = {3, 1};
+  std::array<float, 3> x_values = {1.0f, 3.0f, 5.0f};
+  const std::array<int64_t, 2> expected_y_shape = {3, 2};
+  std::array<float, 6> expected_y = {1.0f, 2.0f, 9.0f, 12.0f, 25.0f, 30.0f};
+
+  std::array<float, 6> y_values;
+  std::array<float, 3> new_x_values = {10.0f, 30.0f, 50.0f};
+  std::array<float, 6> new_expected_y = {10.0f, 20.0f, 90.0f, 120.0f, 250.0f, 300.0f};
+} cg_data_1;
+
+struct CudaGraphInputOutputData_2 {
+  const std::array<int64_t, 2> x_shape = {1, 2};
+  std::array<float, 2> x_values = {1.0f, 2.0f};
+  const std::array<int64_t, 2> expected_y_shape = {3, 2};
+  std::array<float, 6> expected_y = {1.0f, 4.0f, 3.0f, 8.0f, 5.0f, 12.0f};
+
+  std::array<float, 6> y_values;
+  std::array<float, 2> new_x_values = {10.0f, 20.0f};
+  std::array<float, 6> new_expected_y = {10.0f, 40.0f, 30.0f, 80.0f, 50.0f, 120.0f};
+} cg_data_2;
+
+template <typename T>
+static void RunWithCudaGraphAnnotation(T& cg_data,
+                                       Ort::Session& session,
+                                       Ort::MemoryInfo& info_mem,
+                                       Ort::MemoryAllocation& input_data,
+                                       Ort::MemoryAllocation& output_data,
+                                       const char* cuda_graph_annotation) {
+  (void)cudaMemcpy(input_data.get(),
+                   cg_data.x_values.data(),
+                   sizeof(float) * cg_data.x_values.size(),
+                   cudaMemcpyHostToDevice);
+
+  // Create an OrtValue tensor backed by data on CUDA memory
+  Ort::Value bound_x = Ort::Value::CreateTensor(info_mem,
+                                                reinterpret_cast<float*>(input_data.get()),
+                                                cg_data.x_values.size(),
+                                                cg_data.x_shape.data(),
+                                                cg_data.x_shape.size());
+
+  // Create an OrtValue tensor backed by data on CUDA memory
+  Ort::Value bound_y = Ort::Value::CreateTensor(info_mem,
+                                                reinterpret_cast<float*>(output_data.get()),
+                                                cg_data.expected_y.size(),
+                                                cg_data.expected_y_shape.data(),
+                                                cg_data.expected_y_shape.size());
+
+  // Create IoBinding for inputs and outputs.
+  Ort::IoBinding binding(session);
+  binding.BindInput("X", bound_x);
+  binding.BindOutput("Y", bound_y);
+
+  Ort::RunOptions run_option;
+  if (cuda_graph_annotation != nullptr) {
+    run_option.AddConfigEntry(kOrtRunOptionsConfigCudaGraphAnnotation, cuda_graph_annotation);
+  }
+
+  // One regular run for necessary memory allocation and graph capturing
+  session.Run(run_option, binding);
+
+  // Check the values against the bound raw memory (needs copying from device to host first)
+  (void)cudaMemcpy(cg_data.y_values.data(),
+                   output_data.get(),
+                   sizeof(float) * cg_data.y_values.size(),
+                   cudaMemcpyDeviceToHost);
+  ASSERT_THAT(cg_data.y_values, ::testing::ContainerEq(cg_data.expected_y));
+
+  // Replay the captured CUDA graph
+  session.Run(run_option, binding);
+  (void)cudaMemcpy(cg_data.y_values.data(),
+                   output_data.get(),
+                   sizeof(float) * cg_data.y_values.size(),
+                   cudaMemcpyDeviceToHost);
+  ASSERT_THAT(cg_data.y_values, ::testing::ContainerEq(cg_data.expected_y));
+
+  // Change the input and replay the CUDA graph again.
+  (void)cudaMemcpy(input_data.get(),
+                   cg_data.new_x_values.data(),
+                   sizeof(float) * cg_data.new_x_values.size(),
+                   cudaMemcpyHostToDevice);
+  binding.SynchronizeInputs();
+
+  session.Run(run_option, binding);
+  (void)cudaMemcpy(cg_data.y_values.data(),
+                   output_data.get(),
+                   sizeof(float) * cg_data.y_values.size(),
+                   cudaMemcpyDeviceToHost);
+  ASSERT_THAT(cg_data.y_values, ::testing::ContainerEq(cg_data.new_expected_y));
+
+  // Clean up
+  binding.ClearBoundInputs();
+  binding.ClearBoundOutputs();
+}
+
+TEST(CApiTest, basic_cuda_graph_with_annotation) {
+  const auto& api = Ort::GetApi();
+  Ort::SessionOptions session_options;
+
+  // Enable cuda graph in cuda provider option.
+  OrtCUDAProviderOptionsV2* cuda_options = nullptr;
+  ASSERT_TRUE(api.CreateCUDAProviderOptions(&cuda_options) == nullptr);
+  std::unique_ptr<OrtCUDAProviderOptionsV2, decltype(api.ReleaseCUDAProviderOptions)>
+      rel_cuda_options(cuda_options, api.ReleaseCUDAProviderOptions);
+  std::vector<const char*> keys{"enable_cuda_graph"};
+  std::vector<const char*> values{"1"};
+  ASSERT_TRUE(api.UpdateCUDAProviderOptions(rel_cuda_options.get(), keys.data(), values.data(), 1) == nullptr);
+
+  ASSERT_TRUE(api.SessionOptionsAppendExecutionProvider_CUDA_V2(
+                  static_cast<OrtSessionOptions*>(session_options),
+                  rel_cuda_options.get()) == nullptr);
+
+  Ort::Session session(*ort_env, CUDA_GRAPH_ANNOTATION_MODEL_URI, session_options);
+  Ort::MemoryInfo info_mem("Cuda", OrtAllocatorType::OrtArenaAllocator, 0, OrtMemTypeDefault);
+
+  Ort::Allocator allocator(session, info_mem);
+  auto allocator_info = allocator.GetInfo();
+  ASSERT_TRUE(info_mem == allocator_info);
+
+  size_t max_input_size = 6;
+  size_t max_output_size = 6;
+
+  auto input_data = allocator.GetAllocation(max_input_size * sizeof(float));
+  auto output_data = allocator.GetAllocation(max_output_size * sizeof(float));
+
+  ASSERT_NE(input_data.get(), nullptr);
+  ASSERT_NE(output_data.get(), nullptr);
+
+  RunWithCudaGraphAnnotation(cg_data_0, session, info_mem, input_data, output_data, nullptr);
+  RunWithCudaGraphAnnotation(cg_data_1, session, info_mem, input_data, output_data, "1");
+  RunWithCudaGraphAnnotation(cg_data_2, session, info_mem, input_data, output_data, "2");
+}
+#endif
+
 // The following test uses some ops not supported in the reduced ops build
 #ifndef REDUCED_OPS_BUILD
 #if defined(USE_CUDA) || defined(USE_TENSORRT)
diff --git a/onnxruntime/test/testdata/mul_1_dynamic.onnx b/onnxruntime/test/testdata/mul_1_dynamic.onnx
new file mode 100644
index 0000000000000..fb7822498b004
Binary files /dev/null and b/onnxruntime/test/testdata/mul_1_dynamic.onnx differ
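The new test model is checked in as a binary file, so its contents are not visible in this diff. Judging only from the expected outputs in basic_cuda_graph_with_annotation, it behaves like an elementwise Mul of a dynamically shaped input X against a fixed (3, 2) constant [[1, 2], [3, 4], [5, 6]]. The sketch below is a hypothetical reconstruction of a comparable model; the graph, tensor names, and output file name are invented here and may differ from the real testdata/mul_1_dynamic.onnx.

    # Hypothetical sketch only; not the actual contents of testdata/mul_1_dynamic.onnx.
    import onnx
    from onnx import TensorProto, helper

    w = helper.make_tensor("W", TensorProto.FLOAT, [3, 2], [1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
    node = helper.make_node("Mul", ["X", "W"], ["Y"])
    graph = helper.make_graph(
        [node],
        "mul_dynamic",
        [helper.make_tensor_value_info("X", TensorProto.FLOAT, ["M", "N"])],  # dynamic input dims
        [helper.make_tensor_value_info("Y", TensorProto.FLOAT, [3, 2])],
        initializer=[w],
    )
    onnx.save(helper.make_model(graph), "mul_1_dynamic_sketch.onnx")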