From 7c007a937710f80340a468fa1e99d4fe6def021e Mon Sep 17 00:00:00 2001 From: uttarayan21 Date: Wed, 16 Oct 2024 14:21:47 +0530 Subject: [PATCH] feat(mnn-bridge): Added support for ndarray@0.15 for mnn-bridge --- Cargo.lock | 17 ++- mnn-bridge/Cargo.toml | 3 + mnn-bridge/src/lib.rs | 7 +- mnn-bridge/src/ndarray.rs | 216 +++++++++++++++++++++----------------- 4 files changed, 142 insertions(+), 101 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index fbbebaf..40d7701 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -532,8 +532,10 @@ dependencies = [ name = "mnn-bridge" version = "0.1.0" dependencies = [ + "error-stack", "mnn", - "ndarray", + "ndarray 0.15.6", + "ndarray 0.16.1", ] [[package]] @@ -571,6 +573,19 @@ dependencies = [ "getrandom", ] +[[package]] +name = "ndarray" +version = "0.15.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb12d4e967ec485a5f71c6311fe28158e9d6f4bc4a447b474184d0f91a8fa32" +dependencies = [ + "matrixmultiply", + "num-complex", + "num-integer", + "num-traits", + "rawpointer", +] + [[package]] name = "ndarray" version = "0.16.1" diff --git a/mnn-bridge/Cargo.toml b/mnn-bridge/Cargo.toml index deda01d..ed08658 100644 --- a/mnn-bridge/Cargo.toml +++ b/mnn-bridge/Cargo.toml @@ -5,12 +5,15 @@ edition = "2021" license = { workspace = true } [dependencies] +error-stack = "0.5.0" mnn = { workspace = true } ndarray = { version = "0.16", optional = true } +ndarray_0_15 = { package = "ndarray", version = "0.15", optional = true } # opencv = { version = "0.92.3", default-features = false, optional = true } [features] ndarray = ["dep:ndarray"] +ndarray_0_15 = ["dep:ndarray_0_15"] # opencv = ["dep:opencv"] default = [] diff --git a/mnn-bridge/src/lib.rs b/mnn-bridge/src/lib.rs index 2aa4b2a..b7796ce 100644 --- a/mnn-bridge/src/lib.rs +++ b/mnn-bridge/src/lib.rs @@ -1,4 +1,7 @@ #[cfg(feature = "ndarray")] pub mod ndarray; -// #[cfg(feature = "opencv")] -// pub mod opencv; +#[cfg(feature = "ndarray_0_15")] +mod ndarray_0_15 { + use ndarray_0_15 as ndarray; + include!("ndarray.rs"); +} diff --git a/mnn-bridge/src/ndarray.rs b/mnn-bridge/src/ndarray.rs index 7a3cf7f..6ba5df5 100644 --- a/mnn-bridge/src/ndarray.rs +++ b/mnn-bridge/src/ndarray.rs @@ -1,129 +1,149 @@ +use error_stack::*; +use ndarray::*; + +#[derive(Debug)] +pub struct MnnBridge; +impl Context for MnnBridge {} +impl core::fmt::Display for MnnBridge { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + write!(f, "MnnBridgeError") + } +} + pub trait MnnToNdarray { type H: mnn::HalideType; - fn as_ndarray(&self) -> ndarray::ArrayViewD { - self.try_as_ndarray() + fn as_ndarray(&self) -> ndarray::ArrayView { + self.try_as_ndarray::() .expect("Failed to create ndarray::ArrayViewD from mnn::Tensor") } - fn try_as_ndarray(&self) -> Option>; + fn try_as_ndarray(&self) -> Result, MnnBridge>; } -impl MnnToNdarray for mnn::Tensor -where - T: mnn::TensorType + mnn::HostTensorType, - T::H: mnn::HalideType, -{ - type H = T::H; - fn try_as_ndarray(&self) -> Option> { - let shape = self - .shape() - .as_ref() - .into_iter() - .copied() - .map(|i| i as usize) - .collect::>(); - let data = self.host(); - ndarray::ArrayViewD::from_shape(shape, data).ok() +pub trait MnnToNdarrayMut { + type H: mnn::HalideType; + fn as_ndarray_mut(&mut self) -> ndarray::ArrayViewMut { + self.try_as_ndarray_mut::() + .expect("Failed to create ndarray::ArrayViewMutD from mnn::Tensor") } + fn try_as_ndarray_mut( + &mut self, + ) -> Result, MnnBridge>; } -#[test] -pub fn test_tensor_to_ndarray_ref() { - let mut tensor: mnn::Tensor> = - mnn::Tensor::new([1, 2, 3], mnn::DimensionType::Caffe); - tensor.fill(64); - let ndarr = tensor.as_ndarray(); - let ndarr_2 = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()) - .unwrap() - .into_dyn(); - assert_eq!(ndarr, ndarr_2); +pub trait NdarrayToMnn { + type H: mnn::HalideType; + fn as_mnn_tensor(&self) -> Option>>>; } -pub trait MnnToNdarrayMut { +pub trait NdarrayToMnnMut { type H: mnn::HalideType; - fn as_ndarray_mut(&mut self) -> ndarray::ArrayViewMutD { - self.try_as_ndarray_mut() - .expect("Failed to create ndarray::ArrayViewMutD from mnn::Tensor") - } - fn try_as_ndarray_mut(&mut self) -> Option>; + fn as_mnn_tensor_mut(&mut self) -> Option>>>; } -impl MnnToNdarrayMut for mnn::Tensor -where - T: mnn::TensorType + mnn::MutableTensorType + mnn::HostTensorType, - T::H: mnn::HalideType, -{ - type H = T::H; - fn try_as_ndarray_mut(&mut self) -> Option> { - let shape = self - .shape() - .as_ref() - .into_iter() - .copied() - .map(|i| i as usize) - .collect::>(); - let data = self.host_mut(); - ndarray::ArrayViewMutD::from_shape(shape, data).ok() +const _: () = { + impl MnnToNdarray for mnn::Tensor + where + T: mnn::TensorType + mnn::HostTensorType, + T::H: mnn::HalideType, + { + type H = T::H; + fn try_as_ndarray( + &self, + ) -> Result, MnnBridge> { + let shape = self + .shape() + .as_ref() + .into_iter() + .copied() + .map(|i| i as usize) + .collect::>(); + let data = self.host(); + Ok(ndarray::ArrayViewD::from_shape(shape, data) + .change_context(MnnBridge)? + .into_dimensionality() + .change_context(MnnBridge)?) + } + } + + impl MnnToNdarrayMut for mnn::Tensor + where + T: mnn::TensorType + mnn::MutableTensorType + mnn::HostTensorType, + T::H: mnn::HalideType, + { + type H = T::H; + fn try_as_ndarray_mut( + &mut self, + ) -> Result, MnnBridge> { + let shape = self + .shape() + .as_ref() + .into_iter() + .copied() + .map(|i| i as usize) + .collect::>(); + let data = self.host_mut(); + Ok(ndarray::ArrayViewMutD::from_shape(shape, data) + .change_context(MnnBridge)? + .into_dimensionality() + .change_context(MnnBridge)?) + } } -} + impl NdarrayToMnn for ndarray::ArrayBase + where + A: ndarray::Data, + D: ndarray::Dimension, + T: mnn::HalideType, + { + type H = T; + fn as_mnn_tensor(&self) -> Option>>> { + let shape = self.shape().iter().map(|i| *i as i32).collect::>(); + let data = self.as_slice()?; + Some(mnn::Tensor::borrowed(shape, data)) + } + } + + impl NdarrayToMnnMut for ndarray::ArrayBase + where + A: ndarray::DataMut, + D: ndarray::Dimension, + T: mnn::HalideType, + { + type H = T; + fn as_mnn_tensor_mut(&mut self) -> Option>>> { + let shape = self.shape().iter().map(|i| *i as i32).collect::>(); + let data = self.as_slice_mut()?; + Some(mnn::Tensor::borrowed_mut(shape, data)) + } + } +}; +#[test] +pub fn test_tensor_to_ndarray_ref() { + let mut tensor: mnn::Tensor> = + mnn::Tensor::new([1, 2, 3], mnn::DimensionType::Caffe); + tensor.fill(64); + let ndarr = tensor.as_ndarray(); + let ndarr_other = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()).unwrap(); + assert_eq!(ndarr, ndarr_other); +} #[test] pub fn test_tensor_to_ndarray_ref_mut() { let mut data = vec![100; 8 * 8 * 3]; let mut tensor: mnn::Tensor>> = mnn::Tensor::borrowed_mut([8, 8, 3], &mut data); - let mut ndarray = tensor.as_ndarray_mut(); + let mut ndarray = tensor.as_ndarray_mut::(); ndarray.fill(600); assert_eq!(data, [600; 8 * 8 * 3]); } - -pub trait NdarrayToMnn { - type H: mnn::HalideType; - fn as_mnn_tensor(&self) -> Option>>>; -} - -impl NdarrayToMnn for ndarray::ArrayBase -where - A: ndarray::Data, - D: ndarray::Dimension, - T: mnn::HalideType, -{ - type H = T; - fn as_mnn_tensor(&self) -> Option>>> { - let shape = self.shape().iter().map(|i| *i as i32).collect::>(); - let data = self.as_slice()?; - Some(mnn::Tensor::borrowed(shape, data)) - } +#[test] +pub fn test_ndarray_to_tensor_ref_mut() { + let mut arr = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()).unwrap(); + arr.as_mnn_tensor_mut().unwrap().fill(600); + assert_eq!(arr.as_slice().unwrap(), &[600; 6]); } - #[test] pub fn test_ndarray_to_tensor_ref() { let arr = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()).unwrap(); let t = arr.as_mnn_tensor().unwrap(); assert_eq!(t.host(), &[64; 6]); } - -pub trait NdarrayToMnnMut { - type H: mnn::HalideType; - fn as_mnn_tensor_mut(&mut self) -> Option>>>; -} - -impl NdarrayToMnnMut for ndarray::ArrayBase -where - A: ndarray::DataMut, - D: ndarray::Dimension, - T: mnn::HalideType, -{ - type H = T; - fn as_mnn_tensor_mut(&mut self) -> Option>>> { - let shape = self.shape().iter().map(|i| *i as i32).collect::>(); - let data = self.as_slice_mut()?; - Some(mnn::Tensor::borrowed_mut(shape, data)) - } -} - -#[test] -pub fn test_ndarray_to_tensor_ref_mut() { - let mut arr = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()).unwrap(); - arr.as_mnn_tensor_mut().unwrap().fill(600); - assert_eq!(arr.as_slice().unwrap(), &[600; 6]); -}