Skip to content

Commit

Permalink
Add all numpy sized integer types to ToTensorType (#398)
Browse files Browse the repository at this point in the history
  • Loading branch information
RyanUnderhill authored May 3, 2024
1 parent 57a0a8c commit 31afa61
Showing 1 changed file with 14 additions and 2 deletions.
16 changes: 14 additions & 2 deletions src/python/python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,22 @@ pybind11::array_t<T> ToPython(std::span<T> v) {

ONNXTensorElementDataType ToTensorType(const pybind11::dtype& type) {
switch (type.num()) {
case pybind11::detail::npy_api::NPY_INT32_:
return Ort::TypeToTensorType<int32_t>::type;
case pybind11::detail::npy_api::NPY_UINT8_:
return Ort::TypeToTensorType<uint8_t>::type;
case pybind11::detail::npy_api::NPY_INT8_:
return Ort::TypeToTensorType<int8_t>::type;
case pybind11::detail::npy_api::NPY_UINT16_:
return Ort::TypeToTensorType<uint16_t>::type;
case pybind11::detail::npy_api::NPY_INT16_:
return Ort::TypeToTensorType<int16_t>::type;
case pybind11::detail::npy_api::NPY_UINT32_:
return Ort::TypeToTensorType<uint32_t>::type;
case pybind11::detail::npy_api::NPY_INT32_:
return Ort::TypeToTensorType<int32_t>::type;
case pybind11::detail::npy_api::NPY_UINT64_:
return Ort::TypeToTensorType<uint64_t>::type;
case pybind11::detail::npy_api::NPY_INT64_:
return Ort::TypeToTensorType<int64_t>::type;
case 23 /*NPY_FLOAT16*/:
return Ort::TypeToTensorType<Ort::Float16_t>::type;
case pybind11::detail::npy_api::NPY_FLOAT_:
Expand Down

0 comments on commit 31afa61

Please sign in to comment.