diff --git a/src/config.cpp b/src/config.cpp
index a57ec4b99..dff68f07d 100644
--- a/src/config.cpp
+++ b/src/config.cpp
@@ -322,7 +322,7 @@ bool IsCudaGraphEnabled(Config::SessionOptions& session_options) {
   for (const auto& provider_options : session_options.provider_options) {
     if (provider_options.name == "cuda") {
       for (const auto& value : provider_options.options) {
-        if (value.first == "enable_cuda_graph") {
+        if (value.first == "can_use_cuda_graph") {  // Is it the correct value string?
          return value.second == "1";
        }
@@ -399,7 +399,7 @@ Config::Config(const std::filesystem::path& path) : config_path{path} {
   if (search.max_length == 0)
     search.max_length = model.context_length;
 
-  use_cuda_graphs = IsCudaGraphEnabled(model.decoder.session_options);
+  can_use_cuda_graphs = IsCudaGraphEnabled(model.decoder.session_options);
 }
 
 }  // namespace Generators
diff --git a/src/config.h b/src/config.h
index 1b14e97f6..38d5c69fe 100644
--- a/src/config.h
+++ b/src/config.h
@@ -8,10 +8,7 @@ struct Config {
   std::filesystem::path config_path;  // Path of the config directory
 
-  bool use_cuda_graphs{false};
-  // TODO: pass it from config file/generator ctor when serving stack is supported
-  // Hardcoded for now
-  size_t max_batch_size{16};
+  bool can_use_cuda_graphs{false};
 
   using ProviderOption = std::pair<std::string, std::string>;
   struct ProviderOptions {
diff --git a/src/generators.h b/src/generators.h
index 3fb9f5201..55fea3a3d 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -57,6 +57,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
   int context_length{};
 
   int batch_size{1};
+  int max_batch_size{0};
   int sequence_length{};
 
   int BatchBeamSize() const { return search.num_beams * batch_size; }
@@ -74,7 +75,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
   struct T5 {
     std::span encoder_input_ids;  // Array of [batchsize][sequence_length]
-    std::span decoder_input_ids;  // Array of [batchsize][sequence_length]
+    std::span decoder_input_ids;  // Array of [batchsize][sequence_length]
   };
   using Bart=T5;
diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp
index 4de1e5672..213caad09 100644
--- a/src/models/decoder_only.cpp
+++ b/src/models/decoder_only.cpp
@@ -26,7 +26,7 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArra
 
 RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
   if (first_run_) {
-    if (model_.config_->use_cuda_graphs) {
+    if (model_.use_cuda_graphs_) {
       model_.run_options_->AddConfigEntry("gpu_graph_id", "-1");
     }
     first_run_ = false;
@@ -37,7 +37,7 @@ RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int
-    if (model_.config_->use_cuda_graphs) {
+    if (model_.use_cuda_graphs_) {
       int new_graph_annotation_id = GetGraphAnnotationId();
       if (new_graph_annotation_id != graph_annotation_id_) {
         graph_annotation_id_ = new_graph_annotation_id;
diff --git a/src/models/input_ids.cpp b/src/models/input_ids.cpp
index 6bb4b01ea..7781b8b63 100644
--- a/src/models/input_ids.cpp
+++ b/src/models/input_ids.cpp
@@ -28,8 +28,8 @@ InputIDs::InputIDs(const Model& model, State& state)
   value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams);
   shape_[0] *= state_.params_->search.num_beams;
 
-  if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
-    size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
+  if (model_.device_type_ == DeviceType::CUDA && model_.use_cuda_graphs_) {
+    size_t max_beam_batch_size = model_.config_->search.num_beams * model_.max_batch_size_;
     sb_input_ids_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
   }
 }
@@ -48,7 +48,7 @@ void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {
   if (!sb_input_ids_) {
     value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
   } else {
-    value_ = sb_input_ids_->GetOrCreateTensor(shape_, type_);
+    value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_);
   }
 
   state_.inputs_[input_index_] = value_.get();
diff --git a/src/models/kv_cache.cpp b/src/models/kv_cache.cpp
index efb031ce4..f76ae0557 100644
--- a/src/models/kv_cache.cpp
+++ b/src/models/kv_cache.cpp
@@ -146,9 +146,9 @@ KV_Cache::KV_Cache(const Model& model, State& state)
   else
     shape_[2] = state_.params_->sequence_length;
 
-  if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
+  if (model_.device_type_ == DeviceType::CUDA && model_.use_cuda_graphs_) {
     assert(past_present_share_buffer_);
-    size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
+    size_t max_beam_batch_size = model_.config_->search.num_beams * model_.max_batch_size_;
     sb_kv_caches_.reserve(layer_count_ * 2);
     for (int i = 0; i < layer_count_ * 2; ++i) {
       sb_kv_caches_.push_back(std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size));
@@ -158,7 +158,7 @@ KV_Cache::KV_Cache(const Model& model, State& state)
   for (int i = 0; i < layer_count_ * 2; ++i) {
     presents_.push_back(
         sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
-                              : sb_kv_caches_[i]->GetOrCreateTensor(shape_, type_));
+                              : sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_));
   }
 }
diff --git a/src/models/logits.cpp b/src/models/logits.cpp
index 9daacec94..a964fe727 100644
--- a/src/models/logits.cpp
+++ b/src/models/logits.cpp
@@ -18,8 +18,8 @@ Logits::Logits(const Model& model, State& state)
   else
     value16_ = std::move(logits_tensor);
 
-  if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
-    size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
+  if (model_.device_type_ == DeviceType::CUDA && model_.use_cuda_graphs_) {
+    size_t max_beam_batch_size = model_.config_->search.num_beams * model_.max_batch_size_;
     if (type_ == Ort::TypeToTensorType<float>::type) {
       sb_logits32_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
     }
@@ -49,7 +49,7 @@ RoamingArray<float> Logits::Get() {
     shape_[1] = 1;  // bugbug: not done yet
 
     auto value_next = !sb_logits32_ ? OrtValue::CreateTensor<float>(*model_.allocator_device_, shape_)
-                                    : sb_logits32_->GetOrCreateTensor(shape_, type_);
+                                    : sb_logits32_->CreateTensorOnStaticBuffer(shape_, type_);
     auto logits_next = cpu_span<float>{value_next->GetTensorMutableData<float>(), element_count};
 
     size_t vocab_index = 0;  // Simpler math to have this index go up by vocab_size for every logit chunk we process
@@ -83,7 +83,7 @@ RoamingArray<float> Logits::Get() {
   value32_ = std::move(value_next);
   if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type)
     value16_ = !sb_logits16_ ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
-                             : sb_logits16_->GetOrCreateTensor(shape_, type_);
+                             : sb_logits16_->CreateTensorOnStaticBuffer(shape_, type_);
 
   state_.outputs_[output_index_] = type_ == Ort::TypeToTensorType<float>::type ? value32_.get() : value16_.get();
 }
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 75292143f..7681ff6b3 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -473,4 +473,14 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
   return expanded;
 }
 
+void Model::GetMaxBatchSizeFromGeneratorParams(const GeneratorParams& params) {
+  max_batch_size_ = params.max_batch_size;
+  if (max_batch_size_ > 0 && DeviceType::CUDA == device_type_) {
+    if (!config_->can_use_cuda_graphs) {
+      throw std::runtime_error("CUDA graphs are not enabled in this model");
+    }
+    use_cuda_graphs_ = true;
+  }
+}
+
 }  // namespace Generators
diff --git a/src/models/model.h b/src/models/model.h
index cab82a3e3..d60c353fe 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -108,6 +108,8 @@ struct Model : std::enable_shared_from_this<Model> {
   std::unique_ptr<OrtValue> ExpandInputs(std::unique_ptr<OrtValue>& input, int num_beams) const;
 
+  void GetMaxBatchSizeFromGeneratorParams(const GeneratorParams& params);
+
   std::unique_ptr<Config> config_;
   std::unique_ptr<OrtSessionOptions> session_options_;
   std::unique_ptr<OrtRunOptions> run_options_;
@@ -121,6 +123,9 @@ struct Model : std::enable_shared_from_this<Model> {
   std::shared_ptr<Model> external_owner_;  // Set to 'this' when created by the C API to preserve lifetime
 
+  bool use_cuda_graphs_{};
+  int max_batch_size_{};
+
 protected:
   void InitDeviceAllocator(OrtSession& session);
   void CreateSessionOptions();
diff --git a/src/models/position_metadata.cpp b/src/models/position_metadata.cpp
index fbf848216..ec6d6b495 100644
--- a/src/models/position_metadata.cpp
+++ b/src/models/position_metadata.cpp
@@ -37,8 +37,8 @@ PositionMetadata::PositionMetadata(const Model& model, State& state, RoamingArra
   position_ids_shape_ = shape;
   attention_mask_shape_ = shape;
 
-  if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
-    size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
+  if (model_.device_type_ == DeviceType::CUDA && model_.use_cuda_graphs_) {
+    size_t max_beam_batch_size = model_.config_->search.num_beams * model_.max_batch_size_;
     sb_position_ids_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
     sb_seqlens_k_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
   }
@@ -105,7 +105,7 @@ void PositionMetadata::AddSeqlensK() {
 void PositionMetadata::AddTotalSequenceLength() {
   total_sequence_length_input_index_ = state_.inputs_.size();
   total_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_,
-                                                  total_sequence_length_shape_,
+                                                  {},
                                                   Ort::TypeToTensorType<int32_t>::type);
   total_sequence_length_->GetTensorMutableData<int32_t>()[0] = state_.params_->sequence_length;
@@ -121,7 +121,7 @@ void PositionMetadata::UpdatePositionIDs(int current_length) {
     position_ids_ = std::move(position_ids_next_);
   } else {
 #if USE_CUDA
-    position_ids_ = sb_position_ids_->GetOrCreateTensor(position_ids_shape_, type_);
+    position_ids_ = sb_position_ids_->CreateTensorOnStaticBuffer(position_ids_shape_, type_);
     assert(model_.device_type_ == DeviceType::CUDA);
     if (type_ == Ort::TypeToTensorType<int32_t>::type) {
       cudaMemcpyAsync(position_ids_->GetTensorMutableRawData(),
@@ -150,15 +150,15 @@ void PositionMetadata::UpdatePositionIDs(int current_length) {
       break;
     }
 #if USE_CUDA
-    case DeviceType::CUDA:
-      if (type_ == Ort::TypeToTensorType<int32_t>::type)
-        cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int32_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
-      else
-        cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int64_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
-      break;
+    case DeviceType::CUDA:
+      if (type_ == Ort::TypeToTensorType<int32_t>::type)
+        cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int32_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
+      else
+        cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int64_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
+      break;
 #endif
-    default:
-      throw std::runtime_error("PositionIDs::Update - Unsupported device type");
+    default:
+      throw std::runtime_error("PositionIDs::Update - Unsupported device type");
   }
 }
@@ -172,7 +172,7 @@ void PositionMetadata::UpdateSeqlensK(int current_length) {
   if (!sb_seqlens_k_) {
     seqlens_k_ = OrtValue::CreateTensor(*model_.allocator_device_, senlens_k_shape_, Ort::TypeToTensorType<int32_t>::type);
   } else {
-    seqlens_k_ = sb_seqlens_k_->GetOrCreateTensor(senlens_k_shape_, Ort::TypeToTensorType<int32_t>::type);
+    seqlens_k_ = sb_seqlens_k_->CreateTensorOnStaticBuffer(senlens_k_shape_, Ort::TypeToTensorType<int32_t>::type);
   }
   state_.inputs_[seqlens_k_input_index_] = seqlens_k_.get();
 
   cudaMemcpyAsync(seqlens_k_->GetTensorMutableRawData(), initial_sequence_lengths_.data(), sizeof(int32_t) * initial_sequence_lengths_.size(), cudaMemcpyHostToDevice, model_.cuda_stream_);
diff --git a/src/models/position_metadata.h b/src/models/position_metadata.h
index 765243e25..7e0c267a0 100644
--- a/src/models/position_metadata.h
+++ b/src/models/position_metadata.h
@@ -50,8 +50,7 @@ struct PositionMetadata {
   std::unique_ptr<OrtValue> attention_mask_;
   std::array<int64_t, 1> senlens_k_shape_{};  // {params.batch_size*params.beam_size}
   std::unique_ptr<OrtValue> seqlens_k_;
-  std::array total_sequence_length_shape_{};  // scalar
-  std::unique_ptr<OrtValue> total_sequence_length_;
+  std::unique_ptr<OrtValue> total_sequence_length_;  // Scalar
 
   std::unique_ptr<OrtValue> position_ids_next_;  // Replaces position_ids_ after the first Run() call
   std::vector<int32_t> initial_sequence_lengths_;
diff --git a/src/models/static_buffer.cpp b/src/models/static_buffer.cpp
index 0ceb56033..c7f2e5f90 100644
--- a/src/models/static_buffer.cpp
+++ b/src/models/static_buffer.cpp
@@ -7,8 +7,8 @@ namespace Generators {
 StaticBuffer::StaticBuffer(Ort::Allocator* allocator, size_t max_beam_batch_size)
     : allocator_{allocator}, info_{allocator_->GetInfo()}, max_beam_batch_size_{max_beam_batch_size} {
 }
 
-std::unique_ptr<OrtValue> StaticBuffer::GetOrCreateTensor(std::span<const int64_t> shape,
-                                                          ONNXTensorElementDataType type) {
+std::unique_ptr<OrtValue> StaticBuffer::CreateTensorOnStaticBuffer(std::span<const int64_t> shape,
+                                                                   ONNXTensorElementDataType type) {
   size_t new_bytes = GetElementSize(type) * GetNumElements(shape);
   if (buffer_ == nullptr) {
     // Assuming the first dimension is the batch size
diff --git a/src/models/static_buffer.h b/src/models/static_buffer.h
index 3f77f0321..8b3b43315 100644
--- a/src/models/static_buffer.h
+++ b/src/models/static_buffer.h
@@ -7,8 +7,8 @@ struct StaticBuffer {
   StaticBuffer(Ort::Allocator* allocator, size_t max_beam_batch_size);
   ~StaticBuffer();
 
-  std::unique_ptr<OrtValue> GetOrCreateTensor(std::span<const int64_t> shape,
-                                              ONNXTensorElementDataType type);
+  std::unique_ptr<OrtValue> CreateTensorOnStaticBuffer(std::span<const int64_t> shape,
+                                                       ONNXTensorElementDataType type);
 
  private:
  size_t GetElementSize(ONNXTensorElementDataType type);
diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py
index 2e14dc804..d806e626f 100644
--- a/src/python/py/models/builder.py
+++ b/src/python/py/models/builder.py
@@ -38,7 +38,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         self.io_dtype = io_dtype      # {'fp16', 'fp32'}
         self.onnx_dtype = onnx_dtype  # {"int4", "fp16", "fp32"}
         self.ep = ep
-        self.enable_cuda_graph = "enable_cuda_graph" in extra_options and extra_options["enable_cuda_graph"] == "1"
+        self.can_use_cuda_graph = "can_use_cuda_graph" in extra_options and extra_options["can_use_cuda_graph"] == "1"
         self.cache_dir = cache_dir
         self.filename = extra_options["filename"] if "filename" in extra_options else "model.onnx"
@@ -215,8 +215,8 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
 
         if self.ep == "cuda":
             cuda_options = { "cuda" : { } }
-            if self.enable_cuda_graph:
-                cuda_options["cuda"]["enable_cuda_graph"] = "1"
+            if self.can_use_cuda_graph:
+                cuda_options["cuda"]["can_use_cuda_graph"] = "1"
             genai_config["model"]["decoder"]["session_options"]["provider_options"].append(cuda_options)
 
         print(f"Saving GenAI config in {out_dir}")
@@ -822,8 +822,8 @@ def make_attention_op(self, name, **kwargs):
         if op_type == "MultiHeadAttention":
             self.make_multi_head_attention(name, add_qk=f"{self.mask_attrs['mask_name']}/output_0", **kwargs)
         elif op_type == "GroupQueryAttention":
-            seqlens_k_name = f"{self.mask_attrs['seqlens_k']}/output_0" if not self.enable_cuda_graph else "seqlens_k"
-            total_seq_len_name = f"{self.mask_attrs['total_seq_len']}/output_0" if not self.enable_cuda_graph else "total_seq_len"
+            seqlens_k_name = f"{self.mask_attrs['seqlens_k']}/output_0" if not self.can_use_cuda_graph else "seqlens_k"
+            total_seq_len_name = f"{self.mask_attrs['total_seq_len']}/output_0" if not self.can_use_cuda_graph else "total_seq_len"
             self.make_group_query_attention(name, seqlens_k=seqlens_k_name, total_seq_len=total_seq_len_name, **kwargs)
         else:
             raise NotImplementedError(f"The {op_type} op is not currently supported.")
@@ -1565,7 +1565,7 @@ def make_position_ids_reformatting(self):
 class LlamaModel(Model):
     def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
         super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
-        if self.enable_cuda_graph:
+        if self.can_use_cuda_graph:
             self.model_inputs = ["input_ids", "position_ids", "seqlens_k", "total_seq_len"]
         else:
             self.model_inputs = ["input_ids", "attention_mask", "position_ids"]
@@ -1575,7 +1575,7 @@ def make_attention_mask_reformatting(self):
             super().make_attention_mask_reformatting()
             return
 
-        if self.enable_cuda_graph:
+        if self.can_use_cuda_graph:
            # ORT does not allow nodes to be placed on mulitple execution providers
            # with cuda graph enabled. Thus the attention mask is deprecated and the
            # subgraph is replaced with seqlens_k and total_seq_len as the raw
@@ -1761,7 +1761,7 @@ def get_args():
                 The filename for each component will be '_.onnx' (ex: '_encoder.onnx', '_decoder.onnx').
             config_only = Generate config and pre/post processing files only.
                 Use this option when you already have your optimized and/or quantized ONNX model.
-            enable_cuda_graph = 1 : Enable CUDA graph capture for CUDA execution provider. Limitations may apply.
+            can_use_cuda_graph = 1 : The model can use CUDA graph capture for the CUDA execution provider. Limitations may apply.
         """),
     )
diff --git a/src/python/python.cpp b/src/python/python.cpp
index 584beb97c..2ded7cff9 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -112,6 +112,10 @@ struct PyGeneratorParams {
     }
   }
 
+  void TryUseCudaGraphWithMaxBatchSize(pybind11::int_ max_batch_size) {
+    params_->max_batch_size = max_batch_size.cast<int32_t>();
+  }
+
   pybind11::array_t<int32_t> py_input_ids_;
   pybind11::array_t<float> py_whisper_input_features_;
   pybind11::array_t<int32_t> py_whisper_decoder_input_ids_;
@@ -120,6 +124,7 @@ struct PyGeneratorParams {
 struct PyGenerator {
   PyGenerator(Model& model, PyGeneratorParams& params) {
     params.Prepare();
+    model.GetMaxBatchSizeFromGeneratorParams(params);
     generator_ = CreateGenerator(model, params);
   }
@@ -179,7 +184,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
       .def_readwrite("input_ids", &PyGeneratorParams::py_input_ids_)
       .def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_)
       .def_readwrite("whisper_decoder_input_ids", &PyGeneratorParams::py_whisper_decoder_input_ids_)
-      .def("set_search_options", &PyGeneratorParams::SetSearchOptions);
+      .def("set_search_options", &PyGeneratorParams::SetSearchOptions)
+      .def("try_use_cuda_graph_with_max_batch_size", &PyGeneratorParams::TryUseCudaGraphWithMaxBatchSize);
 
   // We need to init the OrtApi before we can use it
   Ort::InitApi();
diff --git a/test/python/test_onnxruntime_genai_e2e.py b/test/python/test_onnxruntime_genai_e2e.py
index 077d48504..e4c792cfd 100644
--- a/test/python/test_onnxruntime_genai_e2e.py
+++ b/test/python/test_onnxruntime_genai_e2e.py
@@ -14,7 +14,7 @@ def download_model(
 ):
     # python -m onnxruntime_genai.models.builder -m microsoft/phi-2 -p int4 -e cpu -o download_path
     # Or with cuda graph enabled:
-    # python -m onnxruntime_genai.models.builder -m microsoft/phi-2 -p int4 -e cuda --extra_options enable_cuda_graph=1 -o download_path
+    # python -m onnxruntime_genai.models.builder -m microsoft/phi-2 -p int4 -e cuda --extra_options can_use_cuda_graph=1 -o download_path
     command = [
         sys.executable,
         "-m",
@@ -30,7 +30,7 @@ def download_model(
     ]
     if device == "cuda":
         command.append("--extra_options")
-        command.append("enable_cuda_graph=1")
+        command.append("can_use_cuda_graph=1")
 
     run_subprocess(command).check_returncode()
@@ -47,6 +47,7 @@ def run_model(model_path: str | bytes | os.PathLike):
     sequences = tokenizer.encode_batch(prompts)
     params = og.GeneratorParams(model)
     params.set_search_options({"max_length": 200})
+    params.try_use_cuda_graph_with_max_batch_size(16)
     params.input_ids = sequences
     output_sequences = model.generate(params)
@@ -61,4 +62,3 @@ def run_model(model_path: str | bytes | os.PathLike):
     device = "cuda" if og.is_cuda_available() else "cpu"
     download_model(temp_dir, device, model_name, precision)
     run_model(temp_dir)
-