diff --git a/src/generators.h b/src/generators.h
index c6a510739..e6ad6f0e1 100644
--- a/src/generators.h
+++ b/src/generators.h
@@ -99,6 +99,14 @@ struct GeneratorParams : std::enable_shared_from_this<GeneratorParams> {
 
   std::shared_ptr<GeneratorParams> external_owner_;  // Set to 'this' when created by the C API to preserve lifetime
 
+  struct Input {
+    std::string name;
+    std::unique_ptr<OrtValue> value;
+  };
+
+  // A list of extra model inputs that will be matched at runtime based on name
+  std::vector<Input> extra_inputs;
+
   void TryGraphCapture(int max_bs);
 
  private:
diff --git a/src/models/model.cpp b/src/models/model.cpp
index 6f0cc294a..35a9b4ad4 100644
--- a/src/models/model.cpp
+++ b/src/models/model.cpp
@@ -35,6 +35,11 @@ static std::wstring CurrentModulePath() {
 namespace Generators {
 
 State::State(const GeneratorParams& params) : params_{params.shared_from_this()} {
+  // Add extra user inputs
+  for (auto& input : params.extra_inputs) {
+    input_names_.push_back(input.name.c_str());
+    inputs_.push_back(input.value.get());
+  }
 }
 
 void State::Run(OrtSession& session, OrtRunOptions& run_options) {
diff --git a/src/models/model.h b/src/models/model.h
index 5b9ec12d9..165e7c345 100644
--- a/src/models/model.h
+++ b/src/models/model.h
@@ -16,6 +16,8 @@ struct Tokenizer;
 
 void ConvertFp16ToFp32(OrtAllocator& allocator, OrtValue& in, std::unique_ptr<OrtValue>& p_out, DeviceType device_type, cudaStream_t stream);
 
+size_t GetOrtTypeSize(ONNXTensorElementDataType type);
+
 struct State {
   State(const GeneratorParams& params);
   virtual ~State() = default;
diff --git a/src/models/static_buffer.cpp b/src/models/static_buffer.cpp
index 9bc5f50ea..eab776e65 100644
--- a/src/models/static_buffer.cpp
+++ b/src/models/static_buffer.cpp
@@ -1,4 +1,5 @@
 #include "../generators.h"
+#include "model.h"
 #include "static_buffer.h"
 
 namespace Generators {
@@ -8,7 +9,7 @@ StaticBuffer::StaticBuffer(Ort::Allocator* allocator, size_t max_beam_batch_size
 
 std::unique_ptr<OrtValue> StaticBuffer::CreateTensorOnStaticBuffer(std::span<const int64_t> shape,
                                                                    ONNXTensorElementDataType type) {
-  size_t new_bytes = GetElementSize(type) * GetNumElements(shape);
+  size_t new_bytes = GetOrtTypeSize(type) * GetNumElements(shape);
   if (buffer_ == nullptr) {
     // Assuming the first dimension is the batch size
     bytes_ = new_bytes * (max_beam_batch_size_ / shape[0]);
@@ -21,22 +22,6 @@ std::unique_ptr<OrtValue> StaticBuffer::CreateTensorOnStaticBuffer(std::span<co
-size_t StaticBuffer::GetElementSize(ONNXTensorElementDataType type) {
-  // ... (per-type element-size switch, lost in extraction; superseded by GetOrtTypeSize) ...
-}
-
 size_t StaticBuffer::GetNumElements(std::span<const int64_t> shape) {
   size_t num_elements = 1;
   for (auto dim : shape) {
diff --git a/src/models/static_buffer.h b/src/models/static_buffer.h
index ce9e14686..8c133fdae 100644
--- a/src/models/static_buffer.h
+++ b/src/models/static_buffer.h
@@ -18,7 +18,6 @@ struct StaticBuffer {
                                  ONNXTensorElementDataType type);
 
  private:
-  size_t GetElementSize(ONNXTensorElementDataType type);
   size_t GetNumElements(std::span<const int64_t> shape);
 
   Ort::Allocator* allocator_{nullptr};
diff --git a/src/python/python.cpp b/src/python/python.cpp
index 1d8a4e567..8bd25a9d3 100644
--- a/src/python/python.cpp
+++ b/src/python/python.cpp
@@ -22,6 +22,34 @@ pybind11::array_t<T> ToPython(std::span<T> v) {
   return pybind11::array_t<T>(v.size(), v.data());
 }
 
+ONNXTensorElementDataType ToTensorType(const pybind11::dtype& type) {
+  switch (type.num()) {
+    case pybind11::detail::npy_api::NPY_INT32_:
+      return Ort::TypeToTensorType<int32_t>::type;
+    case pybind11::detail::npy_api::NPY_UINT32_:
+      return Ort::TypeToTensorType<uint32_t>::type;
+    case 23 /*NPY_FLOAT16*/:
+      return Ort::TypeToTensorType<Ort::Float16_t>::type;
+    case pybind11::detail::npy_api::NPY_FLOAT_:
+      return Ort::TypeToTensorType<float>::type;
+    case pybind11::detail::npy_api::NPY_DOUBLE_:
+      return Ort::TypeToTensorType<double>::type;
+    default:
+      throw std::runtime_error("Unsupported numpy type");
+  }
+}
+
+std::unique_ptr<OrtValue> ToTensor(pybind11::array& v) {
+  auto type = ToTensorType(v.dtype());
+
+  std::vector<int64_t> shape(v.ndim());
+  for (pybind11::ssize_t i = 0; i < v.ndim(); i++)
+    shape[i] = v.shape()[i];
+
+  auto p_memory_info = OrtMemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+  return OrtValue::CreateTensor(*p_memory_info, v.mutable_data(), v.nbytes(), shape, type);
+}
+
 namespace Generators {
 
 // A roaming array is one that can be in CPU or GPU memory, and will copy the memory as needed to be used from anywhere
@@ -85,6 +113,11 @@ struct PyGeneratorParams {
     }
   }
 
+  void AddExtraInput(const std::string& name, pybind11::array& value) {
+    params_->extra_inputs.push_back({name, ToTensor(value)});
+    refs_.emplace_back(value);
+  }
+
   void SetSearchOptions(const pybind11::kwargs& dict) {
     for (auto& entry : dict) {
      auto name = entry.first.cast<std::string>();
@@ -110,6 +143,8 @@ struct PyGeneratorParams {
   pybind11::array_t<int32_t> py_input_ids_;
   pybind11::array_t<float> py_whisper_input_features_;
   pybind11::array_t<int32_t> py_whisper_decoder_input_ids_;
+
+  std::vector<pybind11::object> refs_;  // References to data we want to ensure doesn't get garbage collected
 };
 
 struct PyGenerator {
@@ -198,6 +233,7 @@ PYBIND11_MODULE(onnxruntime_genai, m) {
       .def_readwrite("input_ids", &PyGeneratorParams::py_input_ids_)
       .def_readwrite("whisper_input_features", &PyGeneratorParams::py_whisper_input_features_)
       .def_readwrite("whisper_decoder_input_ids", &PyGeneratorParams::py_whisper_decoder_input_ids_)
+      .def("add_extra_input", &PyGeneratorParams::AddExtraInput)
       .def("set_search_options", &PyGeneratorParams::SetSearchOptions)  // See config.h 'struct Search' for the options
       .def("try_use_cuda_graph_with_max_batch_size", &PyGeneratorParams::TryUseCudaGraphWithMaxBatchSize);
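
Usage sketch (not part of the diff): the new add_extra_input binding converts a named numpy array into a GeneratorParams::Input, and State's constructor then appends it to the session's input list by name. A minimal Python example follows; the model path, the input name "custom_mask", and the surrounding Model/Tokenizer/Generator calls are assumptions based on the existing onnxruntime_genai Python surface, not something this diff adds.

    import numpy as np
    import onnxruntime_genai as og

    model = og.Model("path/to/model")   # hypothetical model directory
    tokenizer = og.Tokenizer(model)

    params = og.GeneratorParams(model)
    params.input_ids = tokenizer.encode("Hello")
    # The name must match an input of the ONNX model; the array's dtype must
    # be one handled by ToTensorType (int32, uint32, float16, float32, float64).
    params.add_extra_input("custom_mask", np.ones((1, 16), dtype=np.float32))

    generator = og.Generator(model, params)

Note that ToTensor wraps the numpy buffer without copying (OrtValue::CreateTensor over v.mutable_data()), which is why AddExtraInput also stores a reference in refs_: the array must stay alive for as long as the generator uses the params.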