Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make OgaModel* const again #356

Merged
merged 3 commits into from
Apr 30, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmark/c/main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ void WriteE2EStats(std::string_view label,
<< "\n";
}

std::string GeneratePrompt(size_t num_prompt_tokens, OgaModel& model, const OgaTokenizer& tokenizer) {
std::string GeneratePrompt(size_t num_prompt_tokens, const OgaModel& model, const OgaTokenizer& tokenizer) {
const char* const base_prompt = "A";
auto base_prompt_sequences = OgaSequences::Create();

Expand Down
4 changes: 2 additions & 2 deletions src/csharp/NativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ internal class NativeLib
IntPtr /* const OgaSequences* */ sequences);

[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* OgaModel* */ model,
public static extern IntPtr /* OgaResult* */ OgaCreateGenerator(IntPtr /* const OgaModel* */ model,
IntPtr /* const OgaGeneratorParams* */ generatorParams,
out IntPtr /* OgaGenerator** */ generator);

Expand Down Expand Up @@ -129,7 +129,7 @@ public static extern UIntPtr OgaSequencesGetSequenceCount(IntPtr /* const OgaSeq
// This function is used to generate sequences for the given model using the given generator parameters.
// The OgaSequences object is an array of sequences, where each sequence is an array of tokens.
[DllImport(NativeLib.DllName, CallingConvention = CallingConvention.Winapi)]
public static extern IntPtr /* OgaResult* */ OgaGenerate(IntPtr /* OgaModel* */ model,
public static extern IntPtr /* OgaResult* */ OgaGenerate(IntPtr /* const OgaModel* */ model,
IntPtr /* const OgaGeneratorParams* */ generatorParams,
out IntPtr /* OgaSequences** */ sequences);

Expand Down
24 changes: 23 additions & 1 deletion src/generators.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,29 @@ GeneratorParams::GeneratorParams(const Model& model)
eos_token_id{model.config_->model.eos_token_id},
vocab_size{model.config_->model.vocab_size},
device_type{model.device_type_},
cuda_stream{model.cuda_stream_} {
cuda_stream{model.cuda_stream_},
is_cuda_graph_enabled_{IsCudaGraphEnabled(model.config_->model.decoder.session_options)} {
}

// Attempt to turn on graph capture for this set of generator params.
// max_bs is the caller-declared maximum batch size; it must be non-zero
// whenever graph capture will actually be used (CUDA with graphs enabled,
// or any DML device). Throws std::runtime_error on invalid configuration.
void GeneratorParams::TryGraphCapture(int max_bs) {
  switch (device_type) {
    case DeviceType::CUDA:
      // CUDA only captures graphs when the session options enabled them.
      if (!is_cuda_graph_enabled_)
        break;
      if (max_bs == 0) {
        throw std::runtime_error("CUDA graph is enabled, but max_batch_size is not set.");
      }
      use_cuda_graph = true;
      max_batch_size = max_bs;
      break;

    case DeviceType::DML:
      // DML always requires a bounded batch size for capture.
      if (max_bs == 0) {
        throw std::runtime_error("max_batch_size needs to be set when using DirectML.");
      }
      use_cuda_graph = true;
      max_batch_size = max_bs;
      break;

    default:
      // Any other device cannot honor a CUDA-graph-enabled session config.
      if (is_cuda_graph_enabled_) {
        throw std::runtime_error("CUDA graph is not supported on this device");
      }
      break;
  }
}

std::unique_ptr<Generator> CreateGenerator(const Model& model, const GeneratorParams& params) {
Expand Down
6 changes: 6 additions & 0 deletions src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {

int batch_size{1};
int max_batch_size{0};
bool use_cuda_graph{};
int sequence_length{};
int BatchBeamSize() const { return search.num_beams * batch_size; }

Expand Down Expand Up @@ -97,6 +98,11 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
std::vector<int32_t> input_ids_owner; // Backing memory of input_ids in some cases

std::shared_ptr<GeneratorParams> external_owner_; // Set to 'this' when created by the C API to preserve lifetime

void TryGraphCapture(int max_bs);

private:
bool is_cuda_graph_enabled_{};
};

struct Generator {
Expand Down
2 changes: 1 addition & 1 deletion src/models/captured_graph_pool.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ static std::tuple<int, int, int> MakeKey(int max_batch_size, int max_length, int
}

CapturedGraphInfoPtr CapturedGraphPool::ReserveCapturedGraph(const Model& model, const GeneratorParams& params) const {
if (!model.use_cuda_graph_ || (model.device_type_ != DeviceType::CUDA && model.device_type_ != DeviceType::DML)) {
if (!params.use_cuda_graph || (model.device_type_ != DeviceType::CUDA && model.device_type_ != DeviceType::DML)) {
return nullptr;
}

Expand Down
4 changes: 2 additions & 2 deletions src/models/decoder_only.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArra

RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
if (first_run_) {
if (model_.use_cuda_graph_) {
if (params_->use_cuda_graph) {
model_.run_options_->AddConfigEntry("gpu_graph_id", "-1");
}
first_run_ = false;
Expand All @@ -37,7 +37,7 @@ RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int3
State::Run(*model_.session_decoder_, *model_.run_options_);

// Set the graph id for the following runs.
if (model_.use_cuda_graph_) {
if (params_->use_cuda_graph) {
int new_batch_size = static_cast<int>(input_ids_.GetShape()[0]);
if (new_batch_size != current_batch_size_) {
current_batch_size_ = new_batch_size;
Expand Down
22 changes: 0 additions & 22 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -537,26 +537,4 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
return expanded;
}

// Legacy helper (removed by this PR): validates and records the max batch size
// from the generator params, enabling graph capture on the model when supported.
// Its logic moves to GeneratorParams::TryGraphCapture so the model can stay const.
void Model::GetMaxBatchSizeFromGeneratorParams(const GeneratorParams& params) {
// Graph capture is considered "enabled" for any DML device, or for CUDA when
// the decoder session options request it.
bool is_cuda_graph_enabled = device_type_ == DeviceType::DML || IsCudaGraphEnabled(config_->model.decoder.session_options);
max_batch_size_ = params.max_batch_size;

if (DeviceType::CUDA == device_type_) {
if (is_cuda_graph_enabled) {
// CUDA graphs need a fixed upper bound on batch size; 0 means the caller never set one.
if (max_batch_size_ == 0) {
throw std::runtime_error("CUDA graph is enabled, but max_batch_size is not set.");
}
use_cuda_graph_ = true;
}
} else if (DeviceType::DML == device_type_) {
// DML always requires a bounded batch size.
if (max_batch_size_ == 0) {
throw std::runtime_error("max_batch_size needs to be set when using DirectML.");
}

use_cuda_graph_ = true;
} else if (is_cuda_graph_enabled) {
// Session options asked for CUDA graphs on a device that cannot provide them.
throw std::runtime_error("CUDA graph is not supported on this device");
}
}

} // namespace Generators
5 changes: 0 additions & 5 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,6 @@ struct Model : std::enable_shared_from_this<Model> {

std::unique_ptr<OrtValue> ExpandInputs(std::unique_ptr<OrtValue>& input, int num_beams) const;

void GetMaxBatchSizeFromGeneratorParams(const GeneratorParams& params);

CapturedGraphPool* GetCapturedGraphPool() const { return captured_graph_pool_.get(); }

std::unique_ptr<Config> config_;
Expand All @@ -136,9 +134,6 @@ struct Model : std::enable_shared_from_this<Model> {

std::shared_ptr<Model> external_owner_; // Set to 'this' when created by the C API to preserve lifetime

bool use_cuda_graph_{};
int max_batch_size_{};

#if USE_DML
DmlExecutionContext* GetDmlExecutionContext() const { return dml_execution_context_.get(); }
DmlReadbackHeap* GetDmlReadbackHeap() const { return dml_readback_heap_.get(); }
Expand Down
4 changes: 2 additions & 2 deletions src/ort_genai.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ struct OgaModel : OgaAbstract {
return std::unique_ptr<OgaModel>(p);
}

std::unique_ptr<OgaSequences> Generate(const OgaGeneratorParams& params) {
std::unique_ptr<OgaSequences> Generate(const OgaGeneratorParams& params) const {
OgaSequences* p;
OgaCheckResult(OgaGenerate(this, &params, &p));
return std::unique_ptr<OgaSequences>(p);
Expand Down Expand Up @@ -201,7 +201,7 @@ struct OgaGeneratorParams : OgaAbstract {
};

struct OgaGenerator : OgaAbstract {
static std::unique_ptr<OgaGenerator> Create(OgaModel& model, const OgaGeneratorParams& params) {
static std::unique_ptr<OgaGenerator> Create(const OgaModel& model, const OgaGeneratorParams& params) {
OgaGenerator* p;
OgaCheckResult(OgaCreateGenerator(&model, &params, &p));
return std::unique_ptr<OgaGenerator>(p);
Expand Down
16 changes: 5 additions & 11 deletions src/ort_genai_c.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ OgaResult* OGA_API_CALL OgaGeneratorParamsSetSearchBool(OgaGeneratorParams* gene
OgaResult* OGA_API_CALL OgaGeneratorParamsTryGraphCaptureWithMaxBatchSize(OgaGeneratorParams* generator_params, int32_t max_batch_size) {
OGA_TRY
auto* params = reinterpret_cast<Generators::GeneratorParams*>(generator_params);
params->max_batch_size = max_batch_size;
params->TryGraphCapture(max_batch_size);
baijumeswani marked this conversation as resolved.
Show resolved Hide resolved
return nullptr;
OGA_CATCH
}
Expand Down Expand Up @@ -143,23 +143,17 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetInputSequences(OgaGenera
OGA_CATCH
}

OgaResult* OGA_API_CALL OgaGenerate(OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out) {
OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out) {
OGA_TRY
auto* model_p = reinterpret_cast<Generators::Model*>(model);
auto* params = reinterpret_cast<const Generators::GeneratorParams*>(generator_params);
model_p->GetMaxBatchSizeFromGeneratorParams(*params);
auto result = Generators::Generate(*model_p, *params);
auto result = Generators::Generate(*reinterpret_cast<const Generators::Model*>(model), *reinterpret_cast<const Generators::GeneratorParams*>(generator_params));
*out = reinterpret_cast<OgaSequences*>(std::make_unique<Generators::TokenSequences>(std::move(result)).release());
return nullptr;
OGA_CATCH
}

OgaResult* OgaCreateGenerator(OgaModel* model, const OgaGeneratorParams* generator_params, OgaGenerator** out) {
OgaResult* OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaGenerator** out) {
OGA_TRY
auto* model_p = reinterpret_cast<Generators::Model*>(model);
auto* params = reinterpret_cast<const Generators::GeneratorParams*>(generator_params);
model_p->GetMaxBatchSizeFromGeneratorParams(*params);
*out = reinterpret_cast<OgaGenerator*>(CreateGenerator(*model_p, *params).release());
*out = reinterpret_cast<OgaGenerator*>(CreateGenerator(*reinterpret_cast<const Generators::Model*>(model), *reinterpret_cast<const Generators::GeneratorParams*>(generator_params)).release());
return nullptr;
OGA_CATCH
}
Expand Down
4 changes: 2 additions & 2 deletions src/ort_genai_c.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ OGA_EXPORT void OGA_API_CALL OgaDestroyModel(OgaModel* model);
* after it is done using the sequences.
* \return OgaResult containing the error message if the generation failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerate(OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out);
OGA_EXPORT OgaResult* OGA_API_CALL OgaGenerate(const OgaModel* model, const OgaGeneratorParams* generator_params, OgaSequences** out);

/*
* \brief Creates a OgaGeneratorParams from the given model.
Expand Down Expand Up @@ -167,7 +167,7 @@ OGA_EXPORT OgaResult* OGA_API_CALL OgaGeneratorParamsSetWhisperDecoderInputIDs(O
* \param[out] out The created generator.
* \return OgaResult containing the error message if the generator creation failed.
*/
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out);
OGA_EXPORT OgaResult* OGA_API_CALL OgaCreateGenerator(const OgaModel* model, const OgaGeneratorParams* params, OgaGenerator** out);

/*
* \brief Destroys the given generator.
Expand Down
5 changes: 2 additions & 3 deletions src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ struct PyGeneratorParams {
}

void TryUseCudaGraphWithMaxBatchSize(pybind11::int_ max_batch_size) {
params_->max_batch_size = max_batch_size.cast<int>();
params_->TryGraphCapture(max_batch_size.cast<int>());
}

pybind11::array_t<int32_t> py_input_ids_;
Expand All @@ -115,7 +115,6 @@ struct PyGeneratorParams {
struct PyGenerator {
PyGenerator(Model& model, PyGeneratorParams& params) {
params.Prepare();
model.GetMaxBatchSizeFromGeneratorParams(params);
generator_ = CreateGenerator(model, params);
}

Expand Down Expand Up @@ -229,7 +228,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
.def(pybind11::init([](const std::string& config_path) {
return CreateModel(GetOrtEnv(), config_path.c_str());
}))
.def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); model.GetMaxBatchSizeFromGeneratorParams(params); return Generate(model, params); })
.def("generate", [](Model& model, PyGeneratorParams& params) { params.Prepare(); return Generate(model, params); })
.def_property_readonly("device_type", [](const Model& s) { return s.device_type_; });

pybind11::class_<PyGenerator>(m, "Generator")
Expand Down
Loading