Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyems committed Feb 24, 2024
1 parent 3b4de7a commit df7b525
Showing 1 changed file with 0 additions and 11 deletions.
11 changes: 0 additions & 11 deletions onnxruntime/core/providers/cuda/cuda_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,11 @@ void CUDAGraph::SetGraphAnnotation(GraphAnnotationOptional_t cuda_graph_annotati

void CUDAGraph::CaptureBegin() {
if (!cuda_graph_annotation_id_.has_value()) {
//std::cout << "CaptureBegin: cuda_graph_annotation_id is empty" << std::endl;
ORT_ENFORCE(!has_graph_exec_,
"This cuda graph has already captured a graph. "
"Create a new instance to capture a new graph.");
} else {
//std::cout << "CaptureBegin: cuda_graph_annotation_id is " << *cuda_graph_annotation_id_ << std::endl;
if (!IsGraphCaptureAllowedOnRun()) {
//std::cout << "CaptureBegin: Graph capture is not allowed on this run" << std::endl;
return;
}
}
Expand All @@ -39,18 +36,15 @@ void CUDAGraph::CaptureBegin() {
// will support multiple threads. For multiple threads with multiple graphs
// and streams, `cudaStreamCaptureModeGlobal` needs to be changed to
// `cudaStreamCaptureModeThreadLocal`
//std::cout << "REAL cuda graph capture begins" << std::endl;
CUDA_CALL_THROW(cudaStreamBeginCapture(stream_, cudaStreamCaptureModeGlobal));
}

void CUDAGraph::CaptureEnd() {
//std::cout << "CUDAGraph::CaptureEnd()" << std::endl;
if (!IsGraphCaptureAllowedOnRun()) {
return;
}

if (cuda_graph_annotation_id_.has_value()) {
//std::cout << "CaptureEnd: cuda_graph_annotation_id is " << *cuda_graph_annotation_id_ << std::endl;
CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &additional_graph_));
if (additional_graph_ == NULL) {
ORT_THROW("CUDAGraph::CaptureEnd: additional_graph_ is NULL");
Expand All @@ -59,7 +53,6 @@ void CUDAGraph::CaptureEnd() {
cudaGraphExec_t graph_exec = NULL;

has_additional_graph_ = true;
//std::cout << "REAL cuda graph capture ends" << std::endl;
CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec, additional_graph_, NULL, NULL, 0));
CUDA_CALL_THROW(cudaGraphDestroy(additional_graph_));
has_additional_graph_ = false;
Expand All @@ -70,14 +63,12 @@ void CUDAGraph::CaptureEnd() {
return;
}

//std::cout << "CaptureEnd: cuda_graph_annotation_id is empty" << std::endl;
CUDA_CALL_THROW(cudaStreamEndCapture(stream_, &graph_));
if (graph_ == NULL) {
ORT_THROW("CUDAGraph::CaptureEnd: graph_ is NULL");
}

has_graph_ = true;
//std::cout << "REAL cuda graph capture ends" << std::endl;
CUDA_CALL_THROW(cudaGraphInstantiate(&graph_exec_, graph_, NULL, NULL, 0));
has_graph_exec_ = true;
CUDA_CALL_THROW(cudaGraphDestroy(graph_));
Expand All @@ -91,7 +82,6 @@ Status CUDAGraph::Replay(GraphAnnotationOptional_t cuda_graph_annotation_id) {
// Although this function is not thread safe, the lock is not needed here because
// CUDA EP maintains a separate cuda graph per thread
if (cuda_graph_annotation_id_.has_value()) {
//std::cout << "Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id " << *cuda_graph_annotation_id << std::endl;
LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_ << " with cuda_graph_annotation_id " << *cuda_graph_annotation_id;
auto it = graph_exec_map_.find(*cuda_graph_annotation_id);
if (it == graph_exec_map_.end()) {
Expand All @@ -101,7 +91,6 @@ Status CUDAGraph::Replay(GraphAnnotationOptional_t cuda_graph_annotation_id) {
}
CUDA_RETURN_IF_ERROR(cudaGraphLaunch(it->second, stream_));
} else {
//std::cout << "Replaying CUDA graph on stream " << stream_ << std::endl;
LOGS_DEFAULT(INFO) << "Replaying CUDA graph on stream " << stream_;
CUDA_RETURN_IF_ERROR(cudaGraphLaunch(graph_exec_, stream_));
}
Expand Down

0 comments on commit df7b525

Please sign in to comment.