diff --git a/benchmark/c/main.cpp b/benchmark/c/main.cpp index 2d4b62b1f..3a4c9b43b 100644 --- a/benchmark/c/main.cpp +++ b/benchmark/c/main.cpp @@ -112,7 +112,7 @@ void WriteE2EStats(std::string_view label, << "\n"; } -std::string GeneratePrompt(size_t num_prompt_tokens, OgaModel& model, const OgaTokenizer& tokenizer) { +std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, const OgaTokenizer& tokenizer) { const char* const base_prompt = "A"; auto base_prompt_sequences = OgaSequences::Create(); diff --git a/src/csharp/NativeMethods.cs b/src/csharp/NativeMethods.cs index f2906f3df..a56e7dd7e 100644 --- a/src/csharp/NativeMethods.cs +++ b/src/csharp/NativeMethods.cs @@ -71,7 +71,7 @@ internal class NativeLib IntPtr /* const OgaSequences* */ sequences); [DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] - public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* OgaModel* */ model, + public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* const OgaModel* */ model, IntPtr /* const OgaGeneratorParams* */ generatorParams, out IntPtr /* OgaGenerator** */ generator); @@ -129,7 +129,7 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq // This function is used to generate sequences for the given model using the given generator parameters. // The OgaSequences object is an array of sequences, where each sequence is an array of tokens. 
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)] - public static extern IntPtr /* OgaResult* */ OgaGenerate(IntPtr /* OgaModel* */ model, + public static extern IntPtr /* OgaResult* */ OgaGenerate(IntPtr /* const OgaModel* */ model, IntPtr /* const OgaGeneratorParams* */ generatorParams, out IntPtr /* OgaSequences** */ sequences); diff --git a/src/generators.cpp b/src/generators.cpp index 0c664f341..bc00f8d3e 100644 --- a/src/generators.cpp +++ b/src/generators.cpp @@ -61,7 +61,25 @@ GeneratorParams::GeneratorParams(const Model& model) eos_token_id{model.config_->model.eos_token_id}, vocab_size{model.config_->model.vocab_size}, device_type{model.device_type_}, - cuda_stream{model.cuda_stream_} { + cuda_stream{model.cuda_stream_}, + is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)} { +} + +void GeneratorParams::TryGraphCapture(int max_bs) { + if (!is_cuda_graph_enabled_ || device_type == DeviceType::CPU) { + // no-op + return; + } + + if (DeviceType::CUDA == device_type || DeviceType::DML == device_type) { + if (max_bs == 0) { + throw std::runtime_error("Graph capture is enabled, but max_batch_size is not set."); + } + use_cuda_graph = true; + max_batch_size = max_bs; + } else { + throw std::runtime_error("CUDA graph is not supported on this device"); + } } std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params) { diff --git a/src/generators.h b/src/generators.h index c10868570..c6a510739 100644 --- a/src/generators.h +++ b/src/generators.h @@ -61,6 +61,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> { int batch_size{1}; int max_batch_size{0}; + bool use_cuda_graph{}; int sequence_length{}; int BatchBeamSize() const { return search.num_beams * batch_size; } @@ -97,6 +98,11 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> { std::vector<int32_t> input_ids_owner; // Backing memory of input_ids in some cases std::shared_ptr<GeneratorParams> external_owner_; // Set to 'this' when
created by the C API to preserve lifetime + + void TryGraphCapture(int max_bs); + + private: + bool is_cuda_graph_enabled_{}; }; struct Generator { diff --git a/src/models/captured_graph_pool.cpp b/src/models/captured_graph_pool.cpp index 140f2a8cd..96cc029b8 100644 --- a/src/models/captured_graph_pool.cpp +++ b/src/models/captured_graph_pool.cpp @@ -24,7 +24,7 @@ static std::tuple<int, int, int> MakeKey(int max_batch_size, int max_length, int } CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, const GeneratorParams& params) const { - if (!model.use_cuda_graph_ || (model.device_type_ != DeviceType::CUDA && model.device_type_ != DeviceType::DML)) { + if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA && model.device_type_ != DeviceType::DML)) { return nullptr; } diff --git a/src/models/decoder_only.cpp b/src/models/decoder_only.cpp index 83d1f03d3..53f4f6697 100644 --- a/src/models/decoder_only.cpp +++ b/src/models/decoder_only.cpp @@ -26,7 +26,7 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArra RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) { if (first_run_) { - if (model_.use_cuda_graph_) { + if (params_->use_cuda_graph) { model_.run_options_->AddConfigEntry("gpu_graph_id", "-1"); } first_run_ = false; @@ -37,7 +37,7 @@ RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int - if (model_.use_cuda_graph_) { + if (params_->use_cuda_graph) { int new_batch_size = static_cast<int>(input_ids_.GetShape()[0]); if (new_batch_size != current_batch_size_) { current_batch_size_ = new_batch_size;
IsCudaGraphEnabled(config_->model.decoder.session_options); - max_batch_size_ = params.max_batch_size; - - if (DeviceType::CUDA == device_type_) { - if (is_cuda_graph_enabled) { - if (max_batch_size_ == 0) { - throw std::runtime_error("CUDA graph is enabled, but max_batch_size is not set."); - } - use_cuda_graph_ = true; - } - } else if (DeviceType::DML == device_type_) { - if (max_batch_size_ == 0) { - throw std::runtime_error("max_batch_size needs to be set when using DirectML."); - } - - use_cuda_graph_ = true; - } else if (is_cuda_graph_enabled) { - throw std::runtime_error("CUDA graph is not supported on this device"); - } -} - } // namespace Generators diff --git a/src/models/model.h b/src/models/model.h index b569373f8..d1b0a1ec0 100644 --- a/src/models/model.h +++ b/src/models/model.h @@ -119,8 +119,6 @@ struct Model : std::enable_shared_from_this<Model> { std::unique_ptr<OrtValue> ExpandInputs(std::unique_ptr<OrtValue>& input, int num_beams) const; - void GetMaxBatchSizeFromGeneratorParams(const GeneratorParams& params); - CapturedGraphPool* GetCapturedGraphPool() const { return captured_graph_pool_.get(); } std::unique_ptr<Config> config_; @@ -136,9 +134,6 @@ struct Model : std::enable_shared_from_this<Model> { std::shared_ptr<Model> external_owner_; // Set to 'this' when created by the C API to preserve lifetime - bool use_cuda_graph_{}; - int max_batch_size_{}; - #if USE_DML DmlExecutionContext* GetDmlExecutionContext() const { return dml_execution_context_.get(); } DmlReadbackHeap* GetDmlReadbackHeap() const { return dml_readback_heap_.get(); } diff --git a/src/ort_genai.h b/src/ort_genai.h index fb863dae2..b8e55bf19 100644 --- a/src/ort_genai.h +++ b/src/ort_genai.h @@ -75,7 +75,7 @@ struct OgaModel : OgaAbstract { return std::unique_ptr<OgaModel>(p); } - std::unique_ptr<OgaSequences> Generate(const OgaGeneratorParams& params) { + std::unique_ptr<OgaSequences> Generate(const OgaGeneratorParams& params) const { OgaSequences* p; OgaCheckResult(OgaGenerate(this, &params, &p)); return std::unique_ptr<OgaSequences>(p); } @@ -201,7 +201,7 @@ struct
OgaGeneratorParams : OgaAbstract { }; struct OgaGenerator : OgaAbstract { - static std::unique_ptr<OgaGenerator> Create(OgaModel& model, const OgaGeneratorParams& params) { + static std::unique_ptr<OgaGenerator> Create(const OgaModel& model, const OgaGeneratorParams& params) { OgaGenerator* p; OgaCheckResult(OgaCreateGenerator(&model, &params, &p)); return std::unique_ptr<OgaGenerator>(p); } diff --git a/src/ort_genai_c.cpp b/src/ort_genai_c.cpp index 13cae5235..d5ab67040 100644 --- a/src/ort_genai_c.cpp +++ b/src/ort_genai_c.cpp @@ -108,7 +108,7 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* gene OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* generator_params, int32_t max_batch_size) { OGA_TRY auto* params = reinterpret_cast<Generators::GeneratorParams*>(generator_params); - params->max_batch_size = max_batch_size; + params->TryGraphCapture(max_batch_size); return nullptr; OGA_CATCH } @@ -143,23 +143,17 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGenera OGA_CATCH } -OgaResult* OGA_API_CALL OgaGenerate(OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out) { +OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out) { OGA_TRY - auto* model_p = reinterpret_cast<Generators::Model*>(model); - auto* params = reinterpret_cast<const Generators::GeneratorParams*>(generator_params); - model_p->GetMaxBatchSizeFromGeneratorParams(*params); - auto result = Generators::Generate(*model_p, *params); + auto result = Generators::Generate(*reinterpret_cast<const Generators::Model*>(model), *reinterpret_cast<const Generators::GeneratorParams*>(generator_params)); *out = reinterpret_cast<OgaSequences*>(std::make_unique<Generators::TokenSequences>(std::move(result)).release()); return nullptr; OGA_CATCH } -OgaResult* OgaCreateGenerator(OgaModel* model, const OgaGeneratorParams* generator_params, OgaGenerator** out) { +OgaResult* OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaGenerator** out) { OGA_TRY - auto* model_p = reinterpret_cast<Generators::Model*>(model); - auto* params =
reinterpret_cast<const Generators::GeneratorParams*>(generator_params); - model_p->GetMaxBatchSizeFromGeneratorParams(*params); - *out = reinterpret_cast<OgaGenerator*>(CreateGenerator(*model_p, *params).release()); + *out = reinterpret_cast<OgaGenerator*>(CreateGenerator(*reinterpret_cast<const Generators::Model*>(model), *reinterpret_cast<const Generators::GeneratorParams*>(generator_params)).release()); return nullptr; OGA_CATCH } diff --git a/src/ort_genai_c.h b/src/ort_genai_c.h index 0939d2c36..3e44c29e4 100644 --- a/src/ort_genai_c.h +++ b/src/ort_genai_c.h @@ -117,7 +117,7 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyModel(OgaModel* model); * after it is done using the sequences. * \return OgaResult containing the error message if the generation failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerate(OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out); +OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out); /* * \brief Creates a OgaGeneratorParams from the given model. @@ -167,7 +167,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperDecoderInputIDs(O * \param[out] out The created generator. * \return OgaResult containing the error message if the generator creation failed. */ -OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out); +OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out); /* * \brief Destroys the given generator.
diff --git a/src/python/python.cpp b/src/python/python.cpp index cd974d916..1d8a4e567 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -104,7 +104,7 @@ struct PyGeneratorParams { } void TryUseCudaGraphWithMaxBatchSize(pybind11::int_ max_batch_size) { - params_->max_batch_size = max_batch_size.cast<int>(); + params_->TryGraphCapture(max_batch_size.cast<int>()); } pybind11::array_t<int32_t> py_input_ids_; @@ -115,7 +115,6 @@ struct PyGenerator { PyGenerator(Model& model, PyGeneratorParams& params) { params.Prepare(); - model.GetMaxBatchSizeFromGeneratorParams(params); generator_ = CreateGenerator(model, params); } @@ -229,7 +228,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) { .def(pybind11::init([](const std::string& config_path) { return CreateModel(GetOrtEnv(), config_path.c_str()); })) - .def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); model.GetMaxBatchSizeFromGeneratorParams(params); return Generate(model, params); }) + .def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, params); }) .def_property_readonly("device_type", [](const Model& s) { return s.device_type_; }); pybind11::class_<PyGenerator>(m, "Generator")