diff --git a/src/python/python.cpp b/src/python/python.cpp index 8bd25a9d3..ba685c10d 100644 --- a/src/python/python.cpp +++ b/src/python/python.cpp @@ -24,10 +24,22 @@ pybind11::array_t ToPython(std::span v) { ONNXTensorElementDataType ToTensorType(const pybind11::dtype& type) { switch (type.num()) { - case pybind11::detail::npy_api::NPY_INT32_: - return Ort::TypeToTensorType::type; + case pybind11::detail::npy_api::NPY_UINT8_: + return Ort::TypeToTensorType::type; + case pybind11::detail::npy_api::NPY_INT8_: + return Ort::TypeToTensorType::type; + case pybind11::detail::npy_api::NPY_UINT16_: + return Ort::TypeToTensorType::type; + case pybind11::detail::npy_api::NPY_INT16_: + return Ort::TypeToTensorType::type; case pybind11::detail::npy_api::NPY_UINT32_: return Ort::TypeToTensorType::type; + case pybind11::detail::npy_api::NPY_INT32_: + return Ort::TypeToTensorType::type; + case pybind11::detail::npy_api::NPY_UINT64_: + return Ort::TypeToTensorType::type; + case pybind11::detail::npy_api::NPY_INT64_: + return Ort::TypeToTensorType::type; case 23 /*NPY_FLOAT16*/: return Ort::TypeToTensorType::type; case pybind11::detail::npy_api::NPY_FLOAT_: