Skip to content

Commit

Permalink
max_batch
Browse files Browse the repository at this point in the history
  • Loading branch information
wangyems committed Mar 14, 2024
1 parent 905c2ba commit 0ea99db
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 13 deletions.
2 changes: 1 addition & 1 deletion src/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ bool IsCudaGraphEnabled(Config::SessionOptions& session_options) {
for (const auto& value : provider_options.options) {
if (value.first == "enable_cuda_graph") {
// Is it the correct value string?
return value.second == "true" || value.second == "1";
return value.second == "1";
}
}
}
Expand Down
2 changes: 2 additions & 0 deletions src/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ struct Config {
std::filesystem::path config_path; // Path of the config directory

bool use_cuda_graphs{false};
// TODO: read max_batch_size from the config file instead of hard-coding it
size_t max_batch_size{16};

using ProviderOption = std::pair<std::string, std::string>;
struct ProviderOptions {
Expand Down
2 changes: 1 addition & 1 deletion src/models/decoder_only.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void DecoderOnly_State::UpdateInputs(const RoamingArray<int32_t>& next_tokens_un

// Returns the annotation id that identifies which captured CUDA graph to
// replay. The (beam-expanded) batch size — the first dimension of the
// input_ids tensor — is used, so each distinct batch size maps to its own
// captured graph.
int DecoderOnly_State::GetGraphAnnotationId() const {
  // Here we use the batch size as the graph annotation id.
  // GetShape()[0] is int64_t; narrow explicitly to the int annotation id
  // (batch sizes are far below INT_MAX, so the cast is safe).
  return static_cast<int>(input_ids_.GetShape()[0]);
}

} // namespace Generators
3 changes: 2 additions & 1 deletion src/models/input_ids.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ InputIDs::InputIDs(const Model& model, State& state)
shape_[0] *= state_.params_.search.num_beams;

if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
sb_input_ids_ = std::make_unique<StaticBuffer>(model_.allocator_device_);
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
sb_input_ids_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
}
}

Expand Down
3 changes: 2 additions & 1 deletion src/models/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,9 +148,10 @@ KV_Cache::KV_Cache(const Model& model, State& state)

if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
assert(past_present_share_buffer_);
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
sb_kv_caches_.reserve(layer_count_ * 2);
for (int i = 0; i < layer_count_ * 2; ++i) {
sb_kv_caches_.push_back(std::make_unique<StaticBuffer>(model_.allocator_device_));
sb_kv_caches_.push_back(std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size));
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/models/logits.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,12 @@ Logits::Logits(const Model& model, State& state)
value16_ = std::move(logits_tensor);

if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
if (type_ == Ort::TypeToTensorType<float>::type) {
sb_logits32_ = std::make_unique<StaticBuffer>(model_.allocator_device_);
sb_logits32_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
}
if (type_ == Ort::TypeToTensorType<Ort::Float16_t>::type) {
sb_logits16_ = std::make_unique<StaticBuffer>(model_.allocator_device_);
sb_logits16_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
}
}
}
Expand Down
5 changes: 3 additions & 2 deletions src/models/position_metadata.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,9 @@ PositionMetadata::PositionMetadata(const Model& model, State& state, RoamingArra
attention_mask_shape_ = shape;

if (model_.device_type_ == DeviceType::CUDA && model_.config_->use_cuda_graphs) {
sb_position_ids_ = std::make_unique<StaticBuffer>(model_.allocator_device_);
sb_seqlens_k_ = std::make_unique<StaticBuffer>(model_.allocator_device_);
size_t max_beam_batch_size = model_.config_->search.num_beams * model_.config_->max_batch_size;
sb_position_ids_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
sb_seqlens_k_ = std::make_unique<StaticBuffer>(model_.allocator_device_, max_beam_batch_size);
}
}

Expand Down
10 changes: 6 additions & 4 deletions src/models/static_buffer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,18 @@

namespace Generators {

// A device buffer whose address stays fixed across generation steps so that
// CUDA graph capture/replay sees stable pointers.
// `max_beam_batch_size` (num_beams * max_batch_size at the call sites) bounds
// the largest first dimension this buffer must ever hold; the backing storage
// is sized for it on first allocation.
StaticBuffer::StaticBuffer(Ort::Allocator* allocator, size_t max_beam_batch_size)
    : allocator_{allocator}, info_{allocator_->GetInfo()}, max_beam_batch_size_{max_beam_batch_size} {
}

std::unique_ptr<OrtValue> StaticBuffer::GetOrCreateTensor(std::span<const int64_t> shape,
ONNXTensorElementDataType type) {
size_t new_bytes = GetElementSize(type) * GetNumElements(shape);
if (buffer_ == nullptr) {
buffer_ = allocator_->Alloc(new_bytes);
bytes_ = new_bytes;
return OrtValue::CreateTensor(info_, buffer_, bytes_, shape, type);
// Assuming the first dimension is the batch size
bytes_ = new_bytes * (max_beam_batch_size_ / shape[0]);
buffer_ = allocator_->Alloc(bytes_);
return OrtValue::CreateTensor(info_, buffer_, new_bytes, shape, type);
}
if (new_bytes > bytes_) {
std::runtime_error("StaticBuffer: new_bytes > bytes_");
Expand Down
3 changes: 2 additions & 1 deletion src/models/static_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ namespace Generators {

struct StaticBuffer {
// Add max_beam_batch_size to the constructor
StaticBuffer(Ort::Allocator* allocator);
StaticBuffer(Ort::Allocator* allocator, size_t max_beam_batch_size);
~StaticBuffer();

std::unique_ptr<OrtValue> GetOrCreateTensor(std::span<const int64_t> shape,
Expand All @@ -18,6 +18,7 @@ struct StaticBuffer {
const OrtMemoryInfo& info_;
void* buffer_{nullptr};
size_t bytes_{0};
size_t max_beam_batch_size_{0};
};

} // namespace Generators

0 comments on commit 0ea99db

Please sign in to comment.