Skip to content

Commit

Permalink
Make OgaModel* const again (#356)
Browse files · Browse the repository at this point in the history
  • Loading branch information
baijumeswani committed May 1, 2024
1 parent df51783 commit ed41b21
Show file tree
Hide file tree
Showing 12 changed files with 42 additions and 52 deletions.
2 changes: 1 addition & 1 deletion benchmark/c/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ void WriteE2EStats(std::string_view label,
<< "\n";
}

std::string GeneratePrompt(size_t num_prompt_tokens, OgaModel& model, const OgaTokenizer& tokenizer) {
std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, const OgaTokenizer& tokenizer) {
const char* const base_prompt = "A";
auto base_prompt_sequences = OgaSequences::Create();

Expand Down
4 changes: 2 additions & 2 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ internal class NativeLib
IntPtr /* const OgaSequences* */ sequences);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* OgaModel* */ model,
public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* const OgaModel* */ model,
IntPtr /* const OgaGeneratorParams* */ generatorParams,
out IntPtr /* OgaGenerator** */ generator);

Expand Down Expand Up @@ -129,7 +129,7 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq
// This function is used to generate sequences for the given model using the given generator parameters.
// The OgaSequences object is an array of sequences, where each sequence is an array of tokens.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGenerate(IntPtr /* OgaModel* */ model,
public static extern IntPtr /* OgaResult* */ OgaGenerate(IntPtr /* const OgaModel* */ model,
IntPtr /* const OgaGeneratorParams* */ generatorParams,
out IntPtr /* OgaSequences** */ sequences);

Expand Down
20 changes: 19 additions & 1 deletion src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,25 @@ GeneratorParams::GeneratorParams(const Model& model)
eos_token_id{model.config_->model.eos_token_id},
vocab_size{model.config_->model.vocab_size},
device_type{model.device_type_},
cuda_stream{model.cuda_stream_} {
cuda_stream{model.cuda_stream_},
is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)} {
}

// Attempts to enable CUDA-graph capture on this GeneratorParams.
//
// Silently does nothing when graph capture was not requested in the decoder
// session options (is_cuda_graph_enabled_, set in the constructor above) or
// when running on CPU. On CUDA or DML it records max_bs as the graph's
// maximum batch size and flips use_cuda_graph on; any other device throws.
//
// max_bs: upper bound on the batch size the captured graph must serve.
//         Must be non-zero once graph capture is enabled, since the graph
//         is specialized for a fixed maximum batch size.
//
// Throws std::runtime_error if max_bs is 0 while capture is enabled, or if
// the device type supports neither CUDA nor DML graph capture.
//
// NOTE(review): unlike the Model::GetMaxBatchSizeFromGeneratorParams logic
// this replaces, the DML path here is also gated on the session-option flag
// — presumably intentional, but worth confirming for DML users.
void GeneratorParams::TryGraphCapture(int max_bs) {
if (!is_cuda_graph_enabled_ || device_type == DeviceType::CPU) {
// no-op
return;
}

if (DeviceType::CUDA == device_type || DeviceType::DML == device_type) {
if (max_bs == 0) {
// A captured graph is sized for a fixed maximum batch; 0 means the
// caller never provided one, so fail loudly rather than capture a
// useless graph.
throw std::runtime_error("Graph capture is enabled, but max_batch_size is not set.");
}
use_cuda_graph = true;
max_batch_size = max_bs;
} else {
throw std::runtime_error("CUDA graph is not supported on this device");
}
}

std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params) {
Expand Down
6 changes: 6 additions & 0 deletions src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {

int batch_size{1};
int max_batch_size{0};
bool use_cuda_graph{};
int sequence_length{};
int BatchBeamSize() const { return search.num_beams * batch_size; }

Expand Down Expand Up @@ -97,6 +98,11 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
std::vector<int32_t> input_ids_owner; // Backing memory of input_ids in some cases

std::shared_ptr<GeneratorParams> external_owner_; // Set to 'this' when created by the C API to preserve lifetime

void TryGraphCapture(int max_bs);

private:
bool is_cuda_graph_enabled_{};
};

struct Generator {
Expand Down
2 changes: 1 addition & 1 deletion src/models/captured_graph_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ static std::tuple<int, int, int> MakeKey(int max_batch_size, int max_length, int
}

CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, const GeneratorParams& params) const {
if (!model.use_cuda_graph_ || (model.device_type_ != DeviceType::CUDA && model.device_type_ != DeviceType::DML)) {
if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA && model.device_type_ != DeviceType::DML)) {
return nullptr;
}

Expand Down
4 changes: 2 additions & 2 deletions src/models/decoder_only.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArra

RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
if (first_run_) {
if (model_.use_cuda_graph_) {
if (params_->use_cuda_graph) {
model_.run_options_->AddConfigEntry("gpu_graph_id", "-1");
}
first_run_ = false;
Expand All @@ -37,7 +37,7 @@ RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int3
State::Run(*model_.session_decoder_, *model_.run_options_);

// Set the graph id for the following runs.
if (model_.use_cuda_graph_) {
if (params_->use_cuda_graph) {
int new_batch_size = static_cast<int>(input_ids_.GetShape()[0]);
if (new_batch_size != current_batch_size_) {
current_batch_size_ = new_batch_size;
Expand Down
22 changes: 0 additions & 22 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,26 +537,4 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
return expanded;
}

// (Removed by this commit; superseded by logic on GeneratorParams so that
// the Model can stay const at generation time.)
//
// Copies max_batch_size from the generator params onto the model and decides
// whether graph capture should be used, based on the device type and the
// decoder session options.
//
// Per-device behavior as written:
//   - CUDA: enables graph capture only if the session options request it,
//     and then requires a non-zero max_batch_size.
//   - DML: always requires a non-zero max_batch_size and always enables
//     use_cuda_graph_ (is_cuda_graph_enabled is forced true for DML on the
//     first line below).
//   - Any other device: throws if the session options requested graph
//     capture, since it cannot be honored.
//
// Throws std::runtime_error in the cases described above.
void Model::GetMaxBatchSizeFromGeneratorParams(const GeneratorParams& params) {
// DML is treated as graph-capture-capable unconditionally; other devices
// consult the decoder session options.
bool is_cuda_graph_enabled = device_type_ == DeviceType::DML || IsCudaGraphEnabled(config_->model.decoder.session_options);
max_batch_size_ = params.max_batch_size;

if (DeviceType::CUDA == device_type_) {
if (is_cuda_graph_enabled) {
if (max_batch_size_ == 0) {
throw std::runtime_error("CUDA graph is enabled, but max_batch_size is not set.");
}
use_cuda_graph_ = true;
}
} else if (DeviceType::DML == device_type_) {
if (max_batch_size_ == 0) {
throw std::runtime_error("max_batch_size needs to be set when using DirectML.");
}

use_cuda_graph_ = true;
} else if (is_cuda_graph_enabled) {
// Session options asked for graph capture on a device that cannot do it.
throw std::runtime_error("CUDA graph is not supported on this device");
}
}

} // namespace Generators
5 changes: 0 additions & 5 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ struct Model : std::enable_shared_from_this<Model> {

std::unique_ptr<OrtValue> ExpandInputs(std::unique_ptr<OrtValue>& input, int num_beams) const;

void GetMaxBatchSizeFromGeneratorParams(const GeneratorParams& params);

CapturedGraphPool* GetCapturedGraphPool() const { return captured_graph_pool_.get(); }

std::unique_ptr<Config> config_;
Expand All @@ -136,9 +134,6 @@ struct Model : std::enable_shared_from_this<Model> {

std::shared_ptr<Model> external_owner_; // Set to 'this' when created by the C API to preserve lifetime

bool use_cuda_graph_{};
int max_batch_size_{};

#if USE_DML
DmlExecutionContext* GetDmlExecutionContext() const { return dml_execution_context_.get(); }
DmlReadbackHeap* GetDmlReadbackHeap() const { return dml_readback_heap_.get(); }
Expand Down
4 changes: 2 additions & 2 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ struct OgaModel : OgaAbstract {
return std::unique_ptr<OgaModel>(p);
}

std::unique_ptr<OgaSequences> Generate(const OgaGeneratorParams& params) {
std::unique_ptr<OgaSequences> Generate(const OgaGeneratorParams& params) const {
OgaSequences* p;
OgaCheckResult(OgaGenerate(this, &params, &p));
return std::unique_ptr<OgaSequences>(p);
Expand Down Expand Up @@ -201,7 +201,7 @@ struct OgaGeneratorParams : OgaAbstract {
};

struct OgaGenerator : OgaAbstract {
static std::unique_ptr<OgaGenerator> Create(OgaModel& model, const OgaGeneratorParams& params) {
static std::unique_ptr<OgaGenerator> Create(const OgaModel& model, const OgaGeneratorParams& params) {
OgaGenerator* p;
OgaCheckResult(OgaCreateGenerator(&model, &params, &p));
return std::unique_ptr<OgaGenerator>(p);
Expand Down
16 changes: 5 additions & 11 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* gene
OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* generator_params, int32_t max_batch_size) {
OGA_TRY
auto* params = reinterpret_cast<Generators::GeneratorParams*>(generator_params);
params->max_batch_size = max_batch_size;
params->TryGraphCapture(max_batch_size);
return nullptr;
OGA_CATCH
}
Expand Down Expand Up @@ -140,23 +140,17 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGenera
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaGenerate(OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out) {
OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out) {
OGA_TRY
auto* model_p = reinterpret_cast<Generators::Model*>(model);
auto* params = reinterpret_cast<const Generators::GeneratorParams*>(generator_params);
model_p->GetMaxBatchSizeFromGeneratorParams(*params);
auto result = Generators::Generate(*model_p, *params);
auto result = Generators::Generate(*reinterpret_cast<const Generators::Model*>(model), *reinterpret_cast<const Generators::GeneratorParams*>(generator_params));
*out = reinterpret_cast<OgaSequences*>(std::make_unique<Generators::TokenSequences>(std::move(result)).release());
return nullptr;
OGA_CATCH
}

OgaResult* OgaCreateGenerator(OgaModel* model, const OgaGeneratorParams* generator_params, OgaGenerator** out) {
OgaResult* OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaGenerator** out) {
OGA_TRY
auto* model_p = reinterpret_cast<Generators::Model*>(model);
auto* params = reinterpret_cast<const Generators::GeneratorParams*>(generator_params);
model_p->GetMaxBatchSizeFromGeneratorParams(*params);
*out = reinterpret_cast<OgaGenerator*>(CreateGenerator(*model_p, *params).release());
*out = reinterpret_cast<OgaGenerator*>(CreateGenerator(*reinterpret_cast<const Generators::Model*>(model), *reinterpret_cast<const Generators::GeneratorParams*>(generator_params)).release());
return nullptr;
OGA_CATCH
}
Expand Down
4 changes: 2 additions & 2 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyModel(OgaModel* model);
* after it is done using the sequences.
* \return OgaResult containing the error message if the generation failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerate(OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out);

/*
* \brief Creates a OgaGeneratorParams from the given model.
Expand Down Expand Up @@ -165,7 +165,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperDecoderInputIDs(O
* \param[out] out The created generator.
* \return OgaResult containing the error message if the generator creation failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out);
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out);

/*
* \brief Destroys the given generator.
Expand Down
5 changes: 2 additions & 3 deletions src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ struct PyGeneratorParams {
}

void TryUseCudaGraphWithMaxBatchSize(pybind11::int_ max_batch_size) {
params_->max_batch_size = max_batch_size.cast<int>();
params_->TryGraphCapture(max_batch_size.cast<int>());
}

pybind11::array_t<int32_t> py_input_ids_;
Expand All @@ -115,7 +115,6 @@ struct PyGeneratorParams {
struct PyGenerator {
PyGenerator(Model& model, PyGeneratorParams& params) {
params.Prepare();
model.GetMaxBatchSizeFromGeneratorParams(params);
generator_ = CreateGenerator(model, params);
}

Expand Down Expand Up @@ -229,7 +228,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
.def(pybind11::init([](const std::string& config_path) {
return CreateModel(GetOrtEnv(), config_path.c_str());
}))
.def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); model.GetMaxBatchSizeFromGeneratorParams(params); return Generate(model, params); })
.def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, params); })
.def_property_readonly("device_type", [](const Model& s) { return s.device_type_; });

pybind11::class_<PyGenerator>(m, "Generator")
Expand Down

0 comments on commit ed41b21

Please sign in to comment.