Add 'add_extra_input' to handle models like QLora (#370)
Add a new Python API, 'add_extra_input', that takes numpy tensors and
turns them into OrtValue inputs internally.
This allows models with extra custom inputs (like QLora models) to be
driven from Python.

C API to follow soon.
RyanUnderhill authored May 1, 2024
1 parent f94280f commit b3ff5ce
Showing 6 changed files with 53 additions and 18 deletions.
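
For context, a minimal sketch of how the new API might be used from Python. The model path, extra-input name, and tensor shape below are hypothetical; the name must match an extra input declared by the ONNX model, and the dtype must be one handled by ToTensorType in python.cpp:

import numpy as np
import onnxruntime_genai as og

model = og.Model("path/to/model")  # hypothetical model directory
params = og.GeneratorParams(model)
params.input_ids = np.array([[0, 1, 2]], dtype=np.int32)

# Hypothetical extra input; ToTensorType accepts int32, uint32,
# float16, float32, and float64 numpy arrays.
lora_scale = np.ones((1, 64), dtype=np.float32)
params.add_extra_input("lora_scale", lora_scale)

generator = og.Generator(model, params)  # extra input is matched by name at run time

Note that ToTensor wraps the numpy array's buffer rather than copying it, so PyGeneratorParams keeps a reference to the array (refs_) to ensure it is not garbage collected while in use.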
8 changes: 8 additions & 0 deletions src/generators.h
@@ -99,6 +99,14 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
 
   std::shared_ptr<GeneratorParams> external_owner_;  // Set to 'this' when created by the C API to preserve lifetime
 
+  struct Input {
+    std::string name;
+    std::unique_ptr<OrtValue> value;
+  };
+
+  // A list of extra model inputs that will be matched at runtime based on name
+  std::vector<Input> extra_inputs;
+
   void TryGraphCapture(int max_bs);
 
  private:
5 changes: 5 additions & 0 deletions src/models/model.cpp
@@ -35,6 +35,11 @@ static std::wstring CurrentModulePath() {
 namespace Generators {
 
 State::State(const GeneratorParams& params) : params_{params.shared_from_this()} {
+  // Add extra user inputs
+  for (auto& input : params.extra_inputs) {
+    input_names_.push_back(input.name.c_str());
+    inputs_.push_back(input.value.get());
+  }
 }
 
 void State::Run(OrtSession& session, OrtRunOptions& run_options) {
2 changes: 2 additions & 0 deletions src/models/model.h
@@ -16,6 +16,8 @@ struct Tokenizer;
 
 void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream);
 
+size_t GetOrtTypeSize(ONNXTensorElementDataType type);
+
 struct State {
   State(const GeneratorParams& params);
   virtual ~State() = default;
19 changes: 2 additions & 17 deletions src/models/static_buffer.cpp
@@ -1,4 +1,5 @@
 #include "../generators.h"
+#include "model.h"
 #include "static_buffer.h"
 
 namespace Generators {
@@ -8,7 +9,7 @@ StaticBuffer::StaticBuffer(Ort::Allocator* allocator, size_t max_beam_batch_size
 
 std::unique_ptr<OrtValue> StaticBuffer::CreateTensorOnStaticBuffer(std::span<const int64_t> shape,
                                                                    ONNXTensorElementDataType type) {
-  size_t new_bytes = GetElementSize(type) * GetNumElements(shape);
+  size_t new_bytes = GetOrtTypeSize(type) * GetNumElements(shape);
   if (buffer_ == nullptr) {
     // Assuming the first dimension is the batch size
     bytes_ = new_bytes * (max_beam_batch_size_ / shape[0]);
@@ -21,22 +22,6 @@ std::unique_ptr<OrtValue> StaticBuffer::CreateTensorOnStaticBuffer(std::span<con
   return OrtValue::CreateTensor(info_, buffer_, new_bytes, shape, type);
 }
 
-// TODO: same as GetOrtTypeSize() in model.cc. Should be moved to a common place
-size_t StaticBuffer::GetElementSize(ONNXTensorElementDataType type) {
-  switch (type) {
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16:
-      return sizeof(uint16_t);
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT:
-      return sizeof(float);
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32:
-      return sizeof(int32_t);
-    case ONNX_TENSOR_ELEMENT_DATA_TYPE_INT64:
-      return sizeof(int64_t);
-    default:
-      throw std::runtime_error("Unsupported tensor element data type");
-  }
-}
-
 size_t StaticBuffer::GetNumElements(std::span<const int64_t> shape) {
   size_t num_elements = 1;
   for (auto dim : shape) {
1 change: 0 additions & 1 deletion src/models/static_buffer.h
@@ -18,7 +18,6 @@ struct StaticBuffer {
                                                    ONNXTensorElementDataType type);
 
  private:
-  size_t GetElementSize(ONNXTensorElementDataType type);
   size_t GetNumElements(std::span<const int64_t> shape);
 
   Ort::Allocator* allocator_{nullptr};
36 changes: 36 additions & 0 deletions src/python/python.cpp
@@ -22,6 +22,34 @@ pybind11::array_t<T> ToPython(std::span<T> v) {
   return pybind11::array_t<T>(v.size(), v.data());
 }
 
+ONNXTensorElementDataType ToTensorType(const pybind11::dtype& type) {
+  switch (type.num()) {
+    case pybind11::detail::npy_api::NPY_INT32_:
+      return Ort::TypeToTensorType<int32_t>::type;
+    case pybind11::detail::npy_api::NPY_UINT32_:
+      return Ort::TypeToTensorType<uint32_t>::type;
+    case 23 /*NPY_FLOAT16*/:
+      return Ort::TypeToTensorType<Ort::Float16_t>::type;
+    case pybind11::detail::npy_api::NPY_FLOAT_:
+      return Ort::TypeToTensorType<float>::type;
+    case pybind11::detail::npy_api::NPY_DOUBLE_:
+      return Ort::TypeToTensorType<double>::type;
+    default:
+      throw std::runtime_error("Unsupported numpy type");
+  }
+}
+
+std::unique_ptr<OrtValue> ToTensor(pybind11::array& v) {
+  auto type = ToTensorType(v.dtype());
+
+  std::vector<int64_t> shape(v.ndim());
+  for (pybind11::ssize_t i = 0; i < v.ndim(); i++)
+    shape[i] = v.shape()[i];
+
+  auto p_memory_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+  return OrtValue::CreateTensor(*p_memory_info, v.mutable_data(), v.nbytes(), shape, type);
+}
+
 namespace Generators {
 
 // A roaming array is one that can be in CPU or GPU memory, and will copy the memory as needed to be used from anywhere
@@ -85,6 +113,11 @@ struct PyGeneratorParams {
     }
   }
 
+  void AddExtraInput(const std::string& name, pybind11::array& value) {
+    params_->extra_inputs.push_back({name, ToTensor(value)});
+    refs_.emplace_back(value);
+  }
+
   void SetSearchOptions(const pybind11::kwargs& dict) {
     for (auto& entry : dict) {
       auto name = entry.first.cast<std::string>();
@@ -110,6 +143,8 @@ struct PyGeneratorParams {
   pybind11::array_t<int32_t> py_input_ids_;
   pybind11::array_t<float> py_whisper_input_features_;
   pybind11::array_t<int32_t> py_whisper_decoder_input_ids_;
+
+  std::vector<pybind11::object> refs_;  // References to data we want to ensure doesn't get garbage collected
 };
 
 struct PyGenerator {
@@ -198,6 +233,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
       .def_readwrite("input_ids", &PyGeneratorParams::py_input_ids_)
       .def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_)
      .def_readwrite("whisper_decoder_input_ids", &PyGeneratorParams::py_whisper_decoder_input_ids_)
+      .def("add_extra_input", &PyGeneratorParams::AddExtraInput)
       .def("set_search_options", &PyGeneratorParams::SetSearchOptions)  // See config.h 'struct Search' for the options
       .def("try_use_cuda_graph_with_max_batch_size", &PyGeneratorParams::TryUseCudaGraphWithMaxBatchSize);
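
One detail worth noting in ToTensorType above: pybind11's npy_api enum has no constant for numpy's half-precision type, so its type number is hard-coded as 23. A quick check of that value from Python:

import numpy as np

# numpy assigns float16 (NPY_HALF) the fixed type number 23
assert np.dtype(np.float16).num == 23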
