Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add 'add_extra_input' to handle models like QLora #370

Merged
merged 5 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/generators.h
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,14 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {

std::shared_ptr<GeneratorParams> external_owner_; // Set to 'this' when created by the C API to preserve lifetime

struct Input {
std::string name;                 // Graph input name to match against the model's inputs
std::unique_ptr<OrtValue> value;  // Owned tensor bound to that input for every Run()
};

// A list of extra model inputs that will be matched at runtime based on name
// (e.g. LoRA adapter weights for QLoRA-style models). Owned here so the
// OrtValues outlive any State created from these params.
std::vector<Input> extra_inputs;

void TryGraphCapture(int max_bs);

private:
Expand Down
5 changes: 5 additions & 0 deletions src/models/model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ static std::wstring CurrentModulePath() {
namespace Generators {

State::State(const GeneratorParams& params) : params_{params.shared_from_this()} {
  // Register any user-supplied extra inputs up front so every subsequent
  // session Run() passes them through. The c_str() pointers remain valid
  // because params_ keeps the GeneratorParams (and its strings) alive.
  for (const auto& extra : params.extra_inputs) {
    input_names_.push_back(extra.name.c_str());
    inputs_.push_back(extra.value.get());
  }
}

void State::Run(OrtSession& session, OrtRunOptions& run_options) {
Expand Down
2 changes: 2 additions & 0 deletions src/models/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ struct Tokenizer;

void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream);

size_t GetOrtTypeSize(ONNXTensorElementDataType type);

struct State {
State(const GeneratorParams& params);
virtual ~State() = default;
Expand Down
19 changes: 2 additions & 17 deletions src/models/static_buffer.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include "../generators.h"
#include "model.h"
#include "static_buffer.h"

namespace Generators {
Expand All @@ -8,7 +9,7 @@ StaticBuffer::StaticBuffer(Ort::Allocator* allocator, size_t max_beam_batch_size

std::unique_ptr<OrtValue> StaticBuffer::CreateTensorOnStaticBuffer(std::span<const int64_t> shape,
ONNXTensorElementDataType type) {
size_t new_bytes = GetElementSize(type) * GetNumElements(shape);
size_t new_bytes = GetOrtTypeSize(type) * GetNumElements(shape);
if (buffer_ == nullptr) {
// Assuming the first dimension is the batch size
bytes_ = new_bytes * (max_beam_batch_size_ / shape[0]);
Expand All @@ -21,22 +22,6 @@ std::unique_ptr<OrtValue> StaticBuffer::CreateTensorOnStaticBuffer(std::span<con
return OrtValue::CreateTensor(info_, buffer_, new_bytes, shape, type);
}

// Returns the byte width of one tensor element of the given ONNX element type.
// Throws std::runtime_error for any type not listed below.
// TODO: same as GetOrtTypeSize() in model.cc. Should be moved to a common place
size_t StaticBuffer::GetElementSize(ONNXTensorElementDataType type) {
  if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16)
    return sizeof(uint16_t);  // fp16 is carried as a raw 16-bit pattern
  if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT)
    return sizeof(float);
  if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32)
    return sizeof(int32_t);
  if (type == ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64)
    return sizeof(int64_t);
  throw std::runtime_error("Unsupported tensor element data type");
}

size_t StaticBuffer::GetNumElements(std::span<const int64_t> shape) {
size_t num_elements = 1;
for (auto dim : shape) {
Expand Down
1 change: 0 additions & 1 deletion src/models/static_buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ struct StaticBuffer {
ONNXTensorElementDataType type);

private:
size_t GetElementSize(ONNXTensorElementDataType type);
size_t GetNumElements(std::span<const int64_t> shape);

Ort::Allocator* allocator_{nullptr};
Expand Down
36 changes: 36 additions & 0 deletions src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,34 @@ pybind11::array_t<T> ToPython(std::span<T> v) {
return pybind11::array_t<T>(v.size(), v.data());
}

// Maps a numpy dtype to the corresponding Ort tensor element type.
// Throws std::runtime_error for dtypes with no supported mapping.
ONNXTensorElementDataType ToTensorType(const pybind11::dtype& type) {
  // pybind11's detail::npy_api enum does not expose a float16 constant, so we
  // name the raw numpy type number (NPY_HALF) here instead of a bare literal.
  constexpr int kNpyFloat16 = 23;

  switch (type.num()) {
    case pybind11::detail::npy_api::NPY_INT32_:
      return Ort::TypeToTensorType<int32_t>::type;
    case pybind11::detail::npy_api::NPY_UINT32_:
      return Ort::TypeToTensorType<uint32_t>::type;
    case kNpyFloat16:
      return Ort::TypeToTensorType<Ort::Float16_t>::type;
    case pybind11::detail::npy_api::NPY_FLOAT_:
      return Ort::TypeToTensorType<float>::type;
    case pybind11::detail::npy_api::NPY_DOUBLE_:
      return Ort::TypeToTensorType<double>::type;
    default:
      throw std::runtime_error("Unsupported numpy type");
  }
}

// Wraps a numpy array's buffer in a CPU OrtValue without copying the data
// (the tensor aliases v.mutable_data()). The caller must keep `v` alive for
// as long as the returned OrtValue is in use.
std::unique_ptr<OrtValue> ToTensor(pybind11::array& v) {
  const auto element_type = ToTensorType(v.dtype());

  const pybind11::ssize_t rank = v.ndim();
  std::vector<int64_t> shape;
  shape.reserve(rank);
  for (pybind11::ssize_t axis = 0; axis < rank; ++axis)
    shape.push_back(v.shape()[axis]);

  auto memory_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
  return OrtValue::CreateTensor(*memory_info, v.mutable_data(), v.nbytes(), shape, element_type);
}

namespace Generators {

// A roaming array is one that can be in CPU or GPU memory, and will copy the memory as needed to be used from anywhere
Expand Down Expand Up @@ -85,6 +113,11 @@ struct PyGeneratorParams {
}
}

// Registers a named extra model input backed by the given numpy array.
void AddExtraInput(const std::string& name, pybind11::array& value) {
  // Hold a reference to the array: the OrtValue aliases its buffer, so the
  // numpy object must not be garbage collected while these params live.
  refs_.emplace_back(value);
  params_->extra_inputs.push_back({name, ToTensor(value)});
}

void SetSearchOptions(const pybind11::kwargs& dict) {
for (auto& entry : dict) {
auto name = entry.first.cast<std::string>();
Expand All @@ -110,6 +143,8 @@ struct PyGeneratorParams {
pybind11::array_t<int32_t> py_input_ids_;
pybind11::array_t<float> py_whisper_input_features_;
pybind11::array_t<int32_t> py_whisper_decoder_input_ids_;

std::vector<pybind11::object> refs_; // References to data we want to ensure doesn't get garbage collected
};

struct PyGenerator {
Expand Down Expand Up @@ -198,6 +233,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
.def_readwrite("input_ids", &PyGeneratorParams::py_input_ids_)
.def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_)
.def_readwrite("whisper_decoder_input_ids", &PyGeneratorParams::py_whisper_decoder_input_ids_)
.def("add_extra_input", &PyGeneratorParams::AddExtraInput)
.def("set_search_options", &PyGeneratorParams::SetSearchOptions) // See config.h 'struct Search' for the options
.def("try_use_cuda_graph_with_max_batch_size", &PyGeneratorParams::TryUseCudaGraphWithMaxBatchSize);

Expand Down
Loading