Commit

pass max batch from gen params
wangyems committed Apr 12, 2024
1 parent c3df56f commit 559c2d3
Showing 16 changed files with 68 additions and 50 deletions.
4 changes: 2 additions & 2 deletions src/config.cpp
@@ -322,7 +322,7 @@ bool IsCudaGraphEnabled(Config::SessionOptions& session_options) {
for (const auto& provider_options : session_options.provider_options) {
if (provider_options.name == "cuda") {
for (const auto& value : provider_options.options) {
if (value.first == "enable_cuda_graph") {
if (value.first == "can_use_cuda_graph") {
// Is it the correct value string?
return value.second == "1";
}
@@ -399,7 +399,7 @@ Config::Config(const std::filesystem::path& path) : config_path{path} {
if (search.max_length == 0)
search.max_length = model.context_length;

use_cuda_graphs = IsCudaGraphEnabled(model.decoder.session_options);
can_use_cuda_graphs = IsCudaGraphEnabled(model.decoder.session_options);
}

} // namespace Generators
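
For context, IsCudaGraphEnabled above now matches the renamed provider option. A small Python sketch of the structure it scans, inferred from the builder.py change further down (the genai_config layout shown here is illustrative, not quoted from a real file):

# Sketch of the provider_options entry that carries the renamed flag.
genai_config = {"model": {"decoder": {"session_options": {"provider_options": []}}}}
genai_config["model"]["decoder"]["session_options"]["provider_options"].append(
    {"cuda": {"can_use_cuda_graph": "1"}}
)

def is_cuda_graph_enabled(session_options: dict) -> bool:
    # Mirrors src/config.cpp: true only when the "cuda" provider sets
    # can_use_cuda_graph to the string "1".
    for provider in session_options["provider_options"]:
        for name, options in provider.items():
            if name == "cuda" and options.get("can_use_cuda_graph") == "1":
                return True
    return False

assert is_cuda_graph_enabled(genai_config["model"]["decoder"]["session_options"])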
5 changes: 1 addition & 4 deletions src/config.h
@@ -8,10 +8,7 @@ struct Config {

std::filesystem::path config_path; // Path of the config directory

bool use_cuda_graphs{false};
// TODO: pass it from config file/generator ctor when serving stack is supported
// Hardcoded for now
size_t max_batch_size{16};
bool can_use_cuda_graphs{false};

using ProviderOption = std::pair<std::string, std::string>;
struct ProviderOptions {
3 changes: 2 additions & 1 deletion src/generators.h
@@ -57,6 +57,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
int context_length{};

int batch_size{1};
int max_batch_size{0};
int sequence_length{};
int BatchBeamSize() const { return search.num_beams * batch_size; }

@@ -74,7 +75,7 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {

struct T5 {
std::span<const int32_t> encoder_input_ids; // Array of [batchsize][sequence_length]
std::span<const int32_t> decoder_input_ids; // Array of [batchsize][sequence_length]
std::span<const int32_t> decoder_input_ids; // Array of [batchsize][sequence_length]
};
using Bart=T5;

4 changes: 2 additions & 2 deletions src/models/decoder_only.cpp
@@ -26,7 +26,7 @@ DecoderOnly_State::DecoderOnly_State(const DecoderOnly_Model& model, RoamingArra

RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int32_t> next_tokens, RoamingArray<int32_t> next_indices) {
if (first_run_) {
if (model_.config_->use_cuda_graphs) {
if (model_.use_cuda_graphs_) {
model_.run_options_->AddConfigEntry("gpu_graph_id", "-1");
}
first_run_ = false;
@@ -37,7 +37,7 @@ RoamingArray<float> DecoderOnly_State::Run(int current_length, RoamingArray<int3
State::Run(*model_.session_decoder_, *model_.run_options_);

// Set the graph id for the following runs.
if (model_.config_->use_cuda_graphs) {
if (model_.use_cuda_graphs_) {
int new_graph_annotation_id = GetGraphAnnotationId();
if (new_graph_annotation_id != graph_annotation_id_) {
graph_annotation_id_ = new_graph_annotation_id;
6 changes: 3 additions & 3 deletions src/models/input_ids.cpp
@@ -28,8 +28,8 @@ InputIDs::InputIDs(const Model& model, State& state)
value_ = model_.ExpandInputs(value_, state_.params_->search.num_beams);
shape_[0] *= state_.params_->search.num_beams;

if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
if (model_.device_type_ == DeviceType::CUDA && model_.use_cuda_graphs_) {
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.max_batch_size_;
sb_input_ids_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
}
}
@@ -48,7 +48,7 @@ void InputIDs::Update(RoamingArray<int32_t> next_tokens_unk) {
if (!sb_input_ids_) {
value_ = OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_);
} else {
value_ = sb_input_ids_->GetOrCreateTensor(shape_, type_);
value_ = sb_input_ids_->CreateTensorOnStaticBuffer(shape_, type_);
}

state_.inputs_[input_index_] = value_.get();
6 changes: 3 additions & 3 deletions src/models/kv_cache.cpp
@@ -146,9 +146,9 @@ KV_Cache::KV_Cache(const Model& model, State& state)
else
shape_[2] = state_.params_->sequence_length;

if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
if (model_.device_type_ == DeviceType::CUDA && model_.use_cuda_graphs_) {
assert(past_present_share_buffer_);
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.max_batch_size_;
sb_kv_caches_.reserve(layer_count_ * 2);
for (int i = 0; i < layer_count_ * 2; ++i) {
sb_kv_caches_.push_back(std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size));
@@ -158,7 +158,7 @@ KV_Cache::KV_Cache(const Model& model, State& state)
for (int i = 0; i < layer_count_ * 2; ++i) {
presents_.push_back(
sb_kv_caches_.empty() ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
: sb_kv_caches_[i]->GetOrCreateTensor(shape_, type_));
: sb_kv_caches_[i]->CreateTensorOnStaticBuffer(shape_, type_));
}
}

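The static buffers above (and the analogous ones in input_ids.cpp, logits.cpp, and position_metadata.cpp) are now sized from the per-generator max_batch_size_ rather than the old hard-coded config value. The sizing, restated as a sketch with illustrative numbers:

# Illustrative only: each StaticBuffer is dimensioned for the largest
# expected batch, i.e. num_beams * max_batch_size rows.
num_beams = 4
max_batch_size = 16  # now supplied via GeneratorParams, not hard-coded in config.h
max_beam_batch_size = num_beams * max_batch_size
assert max_beam_batch_size == 64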
8 changes: 4 additions & 4 deletions src/models/logits.cpp
@@ -18,8 +18,8 @@ Logits::Logits(const Model& model, State& state)
else
value16_ = std::move(logits_tensor);

if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
if (model_.device_type_ == DeviceType::CUDA && model_.use_cuda_graphs_) {
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.max_batch_size_;
if (type_ == Ort::TypeToTensorType<float>::type) {
sb_logits32_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
}
@@ -49,7 +49,7 @@ RoamingArray<float> Logits::Get() {
shape_[1] = 1;
// bugbug: not done yet
auto value_next = !sb_logits32_ ? OrtValue::CreateTensor<float>(*model_.allocator_device_, shape_)
: sb_logits32_->GetOrCreateTensor(shape_, type_);
: sb_logits32_->CreateTensorOnStaticBuffer(shape_, type_);
auto logits_next = cpu_span<float>{value_next->GetTensorMutableData<float>(), element_count};

size_t vocab_index = 0; // Simpler math to have this index go up by vocab_size for every logit chunk we process
@@ -83,7 +83,7 @@ RoamingArray<float> Logits::Get() {
value32_ = std::move(value_next);
if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type)
value16_ = !sb_logits16_ ? OrtValue::CreateTensor(*model_.allocator_device_, shape_, type_)
: sb_logits16_->GetOrCreateTensor(shape_, type_);
: sb_logits16_->CreateTensorOnStaticBuffer(shape_, type_);

state_.outputs_[output_index_] = type_ == Ort::TypeToTensorType<float>::type ? value32_.get() : value16_.get();
}
10 changes: 10 additions & 0 deletions src/models/model.cpp
@@ -473,4 +473,14 @@ std::unique_ptr<OrtValue> Model::ExpandInputs(std::unique_ptr<OrtValue>& input,
return expanded;
}

void Model::GetMaxBatchSizeFromGeneratorParams(const GeneratorParams& params) {
max_batch_size_ = params.max_batch_size;
if (max_batch_size_ > 0 && DeviceType::CUDA == device_type_) {
if (!config_->can_use_cuda_graphs) {
throw std::runtime_error("CUDA graphs are not enabled in this model");
}
use_cuda_graphs_ = true;
}
}

} // namespace Generators
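
The gate added above can be restated in Python for readability (an illustrative mirror, not the real binding):

from dataclasses import dataclass

@dataclass
class _Model:  # illustrative stand-in for Generators::Model
    device_type: str
    can_use_cuda_graphs: bool  # parsed from genai_config (see config.cpp above)
    use_cuda_graphs: bool = False
    max_batch_size: int = 0

def get_max_batch_size_from_generator_params(model: _Model, max_batch_size: int) -> None:
    # Mirrors Model::GetMaxBatchSizeFromGeneratorParams: a positive max batch
    # size on CUDA turns on graph capture, but only if the model config says
    # CUDA graphs can be used; otherwise generator creation fails.
    model.max_batch_size = max_batch_size
    if model.max_batch_size > 0 and model.device_type == "cuda":
        if not model.can_use_cuda_graphs:
            raise RuntimeError("CUDA graphs are not enabled in this model")
        model.use_cuda_graphs = True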
5 changes: 5 additions & 0 deletions src/models/model.h
@@ -108,6 +108,8 @@ struct Model : std::enable_shared_from_this<Model> {

std::unique_ptr<OrtValue> ExpandInputs(std::unique_ptr<OrtValue>& input, int num_beams) const;

void GetMaxBatchSizeFromGeneratorParams(const GeneratorParams& params);

std::unique_ptr<Config> config_;
std::unique_ptr<OrtSessionOptions> session_options_;
std::unique_ptr<OrtRunOptions> run_options_;
@@ -121,6 +123,9 @@ struct Model : std::enable_shared_from_this<Model> {

std::shared_ptr<Model> external_owner_; // Set to 'this' when created by the C API to preserve lifetime

bool use_cuda_graphs_{};
int max_batch_size_{};

protected:
void InitDeviceAllocator(OrtSession& session);
void CreateSessionOptions();
26 changes: 13 additions & 13 deletions src/models/position_metadata.cpp
@@ -37,8 +37,8 @@ PositionMetadata::PositionMetadata(const Model& model, State& state, RoamingArra
position_ids_shape_ = shape;
attention_mask_shape_ = shape;

if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
if (model_.device_type_ == DeviceType::CUDA && model_.use_cuda_graphs_) {
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.max_batch_size_;
sb_position_ids_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
sb_seqlens_k_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
}
@@ -105,7 +105,7 @@ void PositionMetadata::AddSeqlensK() {
void PositionMetadata::AddTotalSequenceLength() {
total_sequence_length_input_index_ = state_.inputs_.size();
total_sequence_length_ = OrtValue::CreateTensor(model_.allocator_cpu_,
total_sequence_length_shape_,
{},
Ort::TypeToTensorType<int32_t>::type);

total_sequence_length_->GetTensorMutableData<int32_t>()[0] = state_.params_->sequence_length;
@@ -121,7 +121,7 @@ void PositionMetadata::UpdatePositionIDs(int current_length) {
position_ids_ = std::move(position_ids_next_);
} else {
#if USE_CUDA
position_ids_ = sb_position_ids_->GetOrCreateTensor(position_ids_shape_, type_);
position_ids_ = sb_position_ids_->CreateTensorOnStaticBuffer(position_ids_shape_, type_);
assert(model_.device_type_ == DeviceType::CUDA);
if (type_ == Ort::TypeToTensorType<int32_t>::type) {
cudaMemcpyAsync(position_ids_->GetTensorMutableRawData(),
@@ -150,15 +150,15 @@ void PositionMetadata::UpdatePositionIDs(int current_length) {
break;
}
#if USE_CUDA
case DeviceType::CUDA:
if (type_ == Ort::TypeToTensorType<int32_t>::type)
cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int32_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
else
cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int64_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
break;
case DeviceType::CUDA:
if (type_ == Ort::TypeToTensorType<int32_t>::type)
cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int32_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
else
cuda::Launch_UpdatePositionIds(position_ids_->GetTensorMutableData<int64_t>(), static_cast<int>(position_ids_shape_[0]), model_.cuda_stream_);
break;
#endif
default:
throw std::runtime_error("PositionIDs::Update - Unsupported device type");
default:
throw std::runtime_error("PositionIDs::Update - Unsupported device type");
}
}
}
@@ -172,7 +172,7 @@ void PositionMetadata::UpdateSeqlensK(int current_length) {
if (!sb_seqlens_k_) {
seqlens_k_ = OrtValue::CreateTensor(*model_.allocator_device_, senlens_k_shape_, Ort::TypeToTensorType<int32_t>::type);
} else {
seqlens_k_ = sb_seqlens_k_->GetOrCreateTensor(senlens_k_shape_, Ort::TypeToTensorType<int32_t>::type);
seqlens_k_ = sb_seqlens_k_->CreateTensorOnStaticBuffer(senlens_k_shape_, Ort::TypeToTensorType<int32_t>::type);
}
state_.inputs_[seqlens_k_input_index_] = seqlens_k_.get();
cudaMemcpyAsync(seqlens_k_->GetTensorMutableRawData(), initial_sequence_lengths_.data(), sizeof(int32_t) * initial_sequence_lengths_.size(), cudaMemcpyHostToDevice, model_.cuda_stream_);
3 changes: 1 addition & 2 deletions src/models/position_metadata.h
@@ -50,8 +50,7 @@ struct PositionMetadata {
std::unique_ptr<OrtValue> attention_mask_;
std::array<int64_t, 1> senlens_k_shape_{}; // {params.batch_size*params.beam_size}
std::unique_ptr<OrtValue> seqlens_k_;
std::array<int64_t, 0> total_sequence_length_shape_{}; // scalar
std::unique_ptr<OrtValue> total_sequence_length_;
std::unique_ptr<OrtValue> total_sequence_length_; // Scalar

std::unique_ptr<OrtValue> position_ids_next_; // Replaces position_ids_ after the first Run() call
std::vector<int32_t> initial_sequence_lengths_;
4 changes: 2 additions & 2 deletions src/models/static_buffer.cpp
@@ -7,8 +7,8 @@ namespace Generators {
StaticBuffer::StaticBuffer(Ort::Allocator* allocator, size_t max_beam_batch_size) : allocator_{allocator}, info_{allocator_->GetInfo()}, max_beam_batch_size_{max_beam_batch_size} {
}

std::unique_ptr<OrtValue> StaticBuffer::GetOrCreateTensor(std::span<const int64_t> shape,
ONNXTensorElementDataType type) {
std::unique_ptr<OrtValue> StaticBuffer::CreateTensorOnStaticBuffer(std::span<const int64_t> shape,
ONNXTensorElementDataType type) {
size_t new_bytes = GetElementSize(type) * GetNumElements(shape);
if (buffer_ == nullptr) {
// Assuming the first dimension is the batch size
4 changes: 2 additions & 2 deletions src/models/static_buffer.h
@@ -7,8 +7,8 @@ struct StaticBuffer {
StaticBuffer(Ort::Allocator* allocator, size_t max_beam_batch_size);
~StaticBuffer();

std::unique_ptr<OrtValue> GetOrCreateTensor(std::span<const int64_t> shape,
ONNXTensorElementDataType type);
std::unique_ptr<OrtValue> CreateTensorOnStaticBuffer(std::span<const int64_t> shape,
ONNXTensorElementDataType type);

private:
size_t GetElementSize(ONNXTensorElementDataType type);
16 changes: 8 additions & 8 deletions src/python/py/models/builder.py
@@ -38,7 +38,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
self.io_dtype = io_dtype # {'fp16', 'fp32'}
self.onnx_dtype = onnx_dtype # {"int4", "fp16", "fp32"}
self.ep = ep
self.enable_cuda_graph = "enable_cuda_graph" in extra_options and extra_options["enable_cuda_graph"] == "1"
self.can_use_cuda_graph = "can_use_cuda_graph" in extra_options and extra_options["can_use_cuda_graph"] == "1"

self.cache_dir = cache_dir
self.filename = extra_options["filename"] if "filename" in extra_options else "model.onnx"
@@ -215,8 +215,8 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):

if self.ep == "cuda":
cuda_options = { "cuda" : { } }
if self.enable_cuda_graph:
cuda_options["cuda"]["enable_cuda_graph"] = "1"
if self.can_use_cuda_graph:
cuda_options["cuda"]["can_use_cuda_graph"] = "1"
genai_config["model"]["decoder"]["session_options"]["provider_options"].append(cuda_options)

print(f"Saving GenAI config in {out_dir}")
@@ -822,8 +822,8 @@ def make_attention_op(self, name, **kwargs):
if op_type == "MultiHeadAttention":
self.make_multi_head_attention(name, add_qk=f"{self.mask_attrs['mask_name']}/output_0", **kwargs)
elif op_type == "GroupQueryAttention":
seqlens_k_name = f"{self.mask_attrs['seqlens_k']}/output_0" if not self.enable_cuda_graph else "seqlens_k"
total_seq_len_name = f"{self.mask_attrs['total_seq_len']}/output_0" if not self.enable_cuda_graph else "total_seq_len"
seqlens_k_name = f"{self.mask_attrs['seqlens_k']}/output_0" if not self.can_use_cuda_graph else "seqlens_k"
total_seq_len_name = f"{self.mask_attrs['total_seq_len']}/output_0" if not self.can_use_cuda_graph else "total_seq_len"
self.make_group_query_attention(name, seqlens_k=seqlens_k_name, total_seq_len=total_seq_len_name, **kwargs)
else:
raise NotImplementedError(f"The {op_type} op is not currently supported.")
@@ -1565,7 +1565,7 @@ def make_position_ids_reformatting(self):
class LlamaModel(Model):
def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
super().__init__(config, io_dtype, onnx_dtype, ep, cache_dir, extra_options)
if self.enable_cuda_graph:
if self.can_use_cuda_graph:
self.model_inputs = ["input_ids", "position_ids", "seqlens_k", "total_seq_len"]
else:
self.model_inputs = ["input_ids", "attention_mask", "position_ids"]
@@ -1575,7 +1575,7 @@ def make_attention_mask_reformatting(self):
super().make_attention_mask_reformatting()
return

if self.enable_cuda_graph:
if self.can_use_cuda_graph:
# ORT does not allow nodes to be placed on multiple execution providers
# with cuda graph enabled. Thus the attention mask is deprecated and the
# subgraph is replaced with seqlens_k and total_seq_len as the raw
@@ -1761,7 +1761,7 @@ def get_args():
The filename for each component will be '<filename>_<component-name>.onnx' (ex: '<filename>_encoder.onnx', '<filename>_decoder.onnx').
config_only = Generate config and pre/post processing files only.
Use this option when you already have your optimized and/or quantized ONNX model.
enable_cuda_graph = 1 : Enable CUDA graph capture for CUDA execution provider. Limitations may apply.
can_use_cuda_graph = 1 : The model can use CUDA graph capture for CUDA execution provider. Limitations may apply.
"""),
)

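With the extra option renamed, a CUDA-graph-capable model is produced the same way the e2e test below does it (a sketch; the model name, precision, and output path are placeholders):

import subprocess
import sys

# Build a model whose genai_config advertises can_use_cuda_graph=1.
subprocess.run(
    [
        sys.executable, "-m", "onnxruntime_genai.models.builder",
        "-m", "microsoft/phi-2", "-p", "int4", "-e", "cuda",
        "--extra_options", "can_use_cuda_graph=1",
        "-o", "./phi2-int4-cuda",
    ],
    check=True,
)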
8 changes: 7 additions & 1 deletion src/python/python.cpp
@@ -112,6 +112,10 @@ struct PyGeneratorParams {
}
}

void TryUseCudaGraphWithMaxBatchSize(pybind11::int_ max_batch_size) {
params_->max_batch_size = max_batch_size.cast<int>();
}

pybind11::array_t<int32_t> py_input_ids_;
pybind11::array_t<float> py_whisper_input_features_;
pybind11::array_t<int32_t> py_whisper_decoder_input_ids_;
@@ -120,6 +124,7 @@ struct PyGenerator {
struct PyGenerator {
PyGenerator(Model& model, PyGeneratorParams& params) {
params.Prepare();
model.GetMaxBatchSizeFromGeneratorParams(params);
generator_ = CreateGenerator(model, params);
}

@@ -179,7 +184,8 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
.def_readwrite("input_ids", &PyGeneratorParams::py_input_ids_)
.def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_)
.def_readwrite("whisper_decoder_input_ids", &PyGeneratorParams::py_whisper_decoder_input_ids_)
.def("set_search_options", &PyGeneratorParams::SetSearchOptions);
.def("set_search_options", &PyGeneratorParams::SetSearchOptions)
.def("try_use_cuda_graph_with_max_batch_size", &PyGeneratorParams::TryUseCudaGraphWithMaxBatchSize);

// We need to init the OrtApi before we can use it
Ort::InitApi();
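How the new binding is meant to be driven from Python, mirroring the e2e test change below (a sketch; the model directory is a placeholder):

import onnxruntime_genai as og

model = og.Model("./phi2-int4-cuda")  # a model built with can_use_cuda_graph=1
tokenizer = og.Tokenizer(model)

params = og.GeneratorParams(model)
params.set_search_options({"max_length": 200})
# Records max_batch_size on the params; generator creation then routes it
# through Model::GetMaxBatchSizeFromGeneratorParams and enables CUDA graphs.
params.try_use_cuda_graph_with_max_batch_size(16)
params.input_ids = tokenizer.encode_batch(["The quick brown fox"])

output_sequences = model.generate(params)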
6 changes: 3 additions & 3 deletions test/python/test_onnxruntime_genai_e2e.py
@@ -14,7 +14,7 @@ def download_model(
):
# python -m onnxruntime_genai.models.builder -m microsoft/phi-2 -p int4 -e cpu -o download_path
# Or with cuda graph enabled:
# python -m onnxruntime_genai.models.builder -m microsoft/phi-2 -p int4 -e cuda --extra_options enable_cuda_graph=1 -o download_path
# python -m onnxruntime_genai.models.builder -m microsoft/phi-2 -p int4 -e cuda --extra_options can_use_cuda_graph=1 -o download_path
command = [
sys.executable,
"-m",
@@ -30,7 +30,7 @@
]
if device == "cuda":
command.append("--extra_options")
command.append("enable_cuda_graph=1")
command.append("can_use_cuda_graph=1")
run_subprocess(command).check_returncode()


@@ -47,6 +47,7 @@ def run_model(model_path: str | bytes | os.PathLike):
sequences = tokenizer.encode_batch(prompts)
params = og.GeneratorParams(model)
params.set_search_options({"max_length": 200})
params.try_use_cuda_graph_with_max_batch_size(16)
params.input_ids = sequences

output_sequences = model.generate(params)
@@ -61,4 +62,3 @@
device = "cuda" if og.is_cuda_available() else "cpu"
download_model(temp_dir, device, model_name, precision)
run_model(temp_dir)
