Skip to content

Commit

Permalink
add unimplemented types to output tensor and move path for onnxruntim…
Browse files Browse the repository at this point in the history
…e_c_api.h
  • Loading branch information
devigned committed Nov 7, 2023
1 parent 8d78074 commit 8aeac1a
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 22 deletions.
2 changes: 0 additions & 2 deletions rust/onnxruntime-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,6 @@ fn generate_bindings(include_dir: &Path) {

let path = include_dir
.join("onnxruntime")
.join("core")
.join("session")
.join("onnxruntime_c_api.h");

// The bindgen::Builder is the main entry point
Expand Down
2 changes: 1 addition & 1 deletion rust/onnxruntime-sys/examples/c_api_sample.rs
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ fn main() {

let output_node_names_cstring: Vec<std::ffi::CString> = output_node_names
.iter()
.map(|n| std::ffi::CString::new(n.clone()).unwrap())
.map(|n| std::ffi::CString::new(*n).unwrap())
.collect();
let output_node_names_ptr: Vec<*const i8> = output_node_names_cstring
.iter()
Expand Down
33 changes: 14 additions & 19 deletions rust/onnxruntime/src/tensor/ort_output_tensor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,8 @@ impl<'a, T> std::ops::Deref for WithOutputTensor<'a, T> {
}

impl<'a, T> TryFrom<OrtOutputTensor> for WithOutputTensor<'a, T>
where
T: TypeToTensorElementDataType,
where
T: TypeToTensorElementDataType,
{
type Error = OrtError;

Expand Down Expand Up @@ -290,9 +290,6 @@ impl<'a> TryFrom<OrtOutputTensor> for OrtOutput<'a> {
.unwrap()(shape_info);

match element_type {
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED => {
unimplemented!()
}
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT => {
WithOutputTensor::try_from(value).map(OrtOutput::Float)
}
Expand All @@ -317,12 +314,6 @@ impl<'a> TryFrom<OrtOutputTensor> for OrtOutput<'a> {
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_STRING => {
WithOutputTensor::try_from(value).map(OrtOutput::String)
}
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL => {
unimplemented!()
}
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16 => {
unimplemented!()
}
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_DOUBLE => {
WithOutputTensor::try_from(value).map(OrtOutput::Double)
}
Expand All @@ -332,14 +323,18 @@ impl<'a> TryFrom<OrtOutputTensor> for OrtOutput<'a> {
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UINT64 => {
WithOutputTensor::try_from(value).map(OrtOutput::UInt64)
}
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64 => {
unimplemented!()
}
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128 => {
unimplemented!()
}
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16 => {
unimplemented!()
// Unimplemented output tensor data types
sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX64
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_UNDEFINED
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BOOL
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT16
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_COMPLEX128
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_BFLOAT16
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FN
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E4M3FNUZ
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2FNUZ
| sys::ONNXTensorElementDataType::ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT8E5M2 => {
unimplemented!("{:?}", element_type)
}
}
}
Expand Down

0 comments on commit 8aeac1a

Please sign in to comment.