Skip to content

Commit

Permalink
refine for cuda graph (#301)
Browse files Browse the repository at this point in the history
  • Loading branch information
yufenglee authored Apr 23, 2024
1 parent b1180a6 commit 0f7fb7a
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,12 +544,9 @@ void Model::GetMaxBatchSizeFromGeneratorParams(const GeneratorParams& params) {
max_batch_size_ = params.max_batch_size;

if (DeviceType::CUDA == device_type_) {
if (max_batch_size_ == 0 && is_cuda_graph_enabled) {
throw std::runtime_error("CUDA graph is enabled, but max_batch_size is not set.");
}
if (max_batch_size_ > 0) {
if (!is_cuda_graph_enabled) {
throw std::runtime_error("CUDA graph is not enabled.");
if (is_cuda_graph_enabled) {
if (max_batch_size_ == 0) {
throw std::runtime_error("CUDA graph is enabled, but max_batch_size is not set.");
}
use_cuda_graph_ = true;
}
Expand Down

0 comments on commit 0f7fb7a

Please sign in to comment.