Skip to content

Commit

Permalink
feat(mnn-bridge): Added support for [email protected] for mnn-bridge
Browse files Browse the repository at this point in the history
  • Loading branch information
uttarayan21 committed Oct 16, 2024
1 parent d145ef8 commit 7c007a9
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 101 deletions.
17 changes: 16 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions mnn-bridge/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,15 @@ edition = "2021"
license = { workspace = true }

[dependencies]
error-stack = "0.5.0"
mnn = { workspace = true }
ndarray = { version = "0.16", optional = true }
ndarray_0_15 = { package = "ndarray", version = "0.15", optional = true }
# opencv = { version = "0.92.3", default-features = false, optional = true }

[features]
ndarray = ["dep:ndarray"]
ndarray_0_15 = ["dep:ndarray_0_15"]
# opencv = ["dep:opencv"]

default = []
7 changes: 5 additions & 2 deletions mnn-bridge/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
#[cfg(feature = "ndarray")]
pub mod ndarray;
// #[cfg(feature = "opencv")]
// pub mod opencv;
#[cfg(feature = "ndarray_0_15")]
mod ndarray_0_15 {
use ndarray_0_15 as ndarray;
include!("ndarray.rs");
}
216 changes: 118 additions & 98 deletions mnn-bridge/src/ndarray.rs
Original file line number Diff line number Diff line change
@@ -1,129 +1,149 @@
use error_stack::*;
use ndarray::*;

#[derive(Debug)]
pub struct MnnBridge;
impl Context for MnnBridge {}
impl core::fmt::Display for MnnBridge {
fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result {
write!(f, "MnnBridgeError")
}
}

pub trait MnnToNdarray {
type H: mnn::HalideType;
fn as_ndarray(&self) -> ndarray::ArrayViewD<Self::H> {
self.try_as_ndarray()
fn as_ndarray<D: Dimension>(&self) -> ndarray::ArrayView<Self::H, D> {
self.try_as_ndarray::<D>()
.expect("Failed to create ndarray::ArrayViewD from mnn::Tensor")
}
fn try_as_ndarray(&self) -> Option<ndarray::ArrayViewD<Self::H>>;
fn try_as_ndarray<D: Dimension>(&self) -> Result<ndarray::ArrayView<Self::H, D>, MnnBridge>;
}

impl<T> MnnToNdarray for mnn::Tensor<T>
where
T: mnn::TensorType + mnn::HostTensorType,
T::H: mnn::HalideType,
{
type H = T::H;
fn try_as_ndarray(&self) -> Option<ndarray::ArrayViewD<Self::H>> {
let shape = self
.shape()
.as_ref()
.into_iter()
.copied()
.map(|i| i as usize)
.collect::<Vec<_>>();
let data = self.host();
ndarray::ArrayViewD::from_shape(shape, data).ok()
pub trait MnnToNdarrayMut {
type H: mnn::HalideType;
fn as_ndarray_mut<D: Dimension>(&mut self) -> ndarray::ArrayViewMut<Self::H, D> {
self.try_as_ndarray_mut::<D>()
.expect("Failed to create ndarray::ArrayViewMutD from mnn::Tensor")
}
fn try_as_ndarray_mut<D: Dimension>(
&mut self,
) -> Result<ndarray::ArrayViewMut<Self::H, D>, MnnBridge>;
}

#[test]
pub fn test_tensor_to_ndarray_ref() {
let mut tensor: mnn::Tensor<mnn::Host<i32>> =
mnn::Tensor::new([1, 2, 3], mnn::DimensionType::Caffe);
tensor.fill(64);
let ndarr = tensor.as_ndarray();
let ndarr_2 = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec())
.unwrap()
.into_dyn();
assert_eq!(ndarr, ndarr_2);
pub trait NdarrayToMnn {
type H: mnn::HalideType;
fn as_mnn_tensor(&self) -> Option<mnn::Tensor<mnn::Ref<mnn::Host<Self::H>>>>;
}

pub trait MnnToNdarrayMut {
pub trait NdarrayToMnnMut {
type H: mnn::HalideType;
fn as_ndarray_mut(&mut self) -> ndarray::ArrayViewMutD<Self::H> {
self.try_as_ndarray_mut()
.expect("Failed to create ndarray::ArrayViewMutD from mnn::Tensor")
}
fn try_as_ndarray_mut(&mut self) -> Option<ndarray::ArrayViewMutD<Self::H>>;
fn as_mnn_tensor_mut(&mut self) -> Option<mnn::Tensor<mnn::RefMut<mnn::Host<Self::H>>>>;
}

impl<T> MnnToNdarrayMut for mnn::Tensor<T>
where
T: mnn::TensorType + mnn::MutableTensorType + mnn::HostTensorType,
T::H: mnn::HalideType,
{
type H = T::H;
fn try_as_ndarray_mut(&mut self) -> Option<ndarray::ArrayViewMutD<Self::H>> {
let shape = self
.shape()
.as_ref()
.into_iter()
.copied()
.map(|i| i as usize)
.collect::<Vec<_>>();
let data = self.host_mut();
ndarray::ArrayViewMutD::from_shape(shape, data).ok()
const _: () = {
impl<T> MnnToNdarray for mnn::Tensor<T>
where
T: mnn::TensorType + mnn::HostTensorType,
T::H: mnn::HalideType,
{
type H = T::H;
fn try_as_ndarray<D: Dimension>(
&self,
) -> Result<ndarray::ArrayView<Self::H, D>, MnnBridge> {
let shape = self
.shape()
.as_ref()
.into_iter()
.copied()
.map(|i| i as usize)
.collect::<Vec<_>>();
let data = self.host();
Ok(ndarray::ArrayViewD::from_shape(shape, data)
.change_context(MnnBridge)?
.into_dimensionality()
.change_context(MnnBridge)?)
}
}

impl<T> MnnToNdarrayMut for mnn::Tensor<T>
where
T: mnn::TensorType + mnn::MutableTensorType + mnn::HostTensorType,
T::H: mnn::HalideType,
{
type H = T::H;
fn try_as_ndarray_mut<D: Dimension>(
&mut self,
) -> Result<ndarray::ArrayViewMut<Self::H, D>, MnnBridge> {
let shape = self
.shape()
.as_ref()
.into_iter()
.copied()
.map(|i| i as usize)
.collect::<Vec<_>>();
let data = self.host_mut();
Ok(ndarray::ArrayViewMutD::from_shape(shape, data)
.change_context(MnnBridge)?
.into_dimensionality()
.change_context(MnnBridge)?)
}
}
}

impl<T, D, A> NdarrayToMnn for ndarray::ArrayBase<A, D>
where
A: ndarray::Data<Elem = T>,
D: ndarray::Dimension,
T: mnn::HalideType,
{
type H = T;
fn as_mnn_tensor(&self) -> Option<mnn::Tensor<mnn::Ref<mnn::Host<Self::H>>>> {
let shape = self.shape().iter().map(|i| *i as i32).collect::<Vec<_>>();
let data = self.as_slice()?;
Some(mnn::Tensor::borrowed(shape, data))
}
}

impl<T, D, A> NdarrayToMnnMut for ndarray::ArrayBase<A, D>
where
A: ndarray::DataMut<Elem = T>,
D: ndarray::Dimension,
T: mnn::HalideType,
{
type H = T;
fn as_mnn_tensor_mut(&mut self) -> Option<mnn::Tensor<mnn::RefMut<mnn::Host<Self::H>>>> {
let shape = self.shape().iter().map(|i| *i as i32).collect::<Vec<_>>();
let data = self.as_slice_mut()?;
Some(mnn::Tensor::borrowed_mut(shape, data))
}
}
};
#[test]
pub fn test_tensor_to_ndarray_ref() {
let mut tensor: mnn::Tensor<mnn::Host<i32>> =
mnn::Tensor::new([1, 2, 3], mnn::DimensionType::Caffe);
tensor.fill(64);
let ndarr = tensor.as_ndarray();
let ndarr_other = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()).unwrap();
assert_eq!(ndarr, ndarr_other);
}
#[test]
pub fn test_tensor_to_ndarray_ref_mut() {
let mut data = vec![100; 8 * 8 * 3];
let mut tensor: mnn::Tensor<mnn::RefMut<mnn::Host<i16>>> =
mnn::Tensor::borrowed_mut([8, 8, 3], &mut data);
let mut ndarray = tensor.as_ndarray_mut();
let mut ndarray = tensor.as_ndarray_mut::<Ix3>();
ndarray.fill(600);
assert_eq!(data, [600; 8 * 8 * 3]);
}

pub trait NdarrayToMnn {
type H: mnn::HalideType;
fn as_mnn_tensor(&self) -> Option<mnn::Tensor<mnn::Ref<mnn::Host<Self::H>>>>;
}

impl<T, D, A> NdarrayToMnn for ndarray::ArrayBase<A, D>
where
A: ndarray::Data<Elem = T>,
D: ndarray::Dimension,
T: mnn::HalideType,
{
type H = T;
fn as_mnn_tensor(&self) -> Option<mnn::Tensor<mnn::Ref<mnn::Host<Self::H>>>> {
let shape = self.shape().iter().map(|i| *i as i32).collect::<Vec<_>>();
let data = self.as_slice()?;
Some(mnn::Tensor::borrowed(shape, data))
}
#[test]
pub fn test_ndarray_to_tensor_ref_mut() {
let mut arr = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()).unwrap();
arr.as_mnn_tensor_mut().unwrap().fill(600);
assert_eq!(arr.as_slice().unwrap(), &[600; 6]);
}

#[test]
pub fn test_ndarray_to_tensor_ref() {
let arr = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()).unwrap();
let t = arr.as_mnn_tensor().unwrap();
assert_eq!(t.host(), &[64; 6]);
}

pub trait NdarrayToMnnMut {
type H: mnn::HalideType;
fn as_mnn_tensor_mut(&mut self) -> Option<mnn::Tensor<mnn::RefMut<mnn::Host<Self::H>>>>;
}

impl<T, D, A> NdarrayToMnnMut for ndarray::ArrayBase<A, D>
where
A: ndarray::DataMut<Elem = T>,
D: ndarray::Dimension,
T: mnn::HalideType,
{
type H = T;
fn as_mnn_tensor_mut(&mut self) -> Option<mnn::Tensor<mnn::RefMut<mnn::Host<Self::H>>>> {
let shape = self.shape().iter().map(|i| *i as i32).collect::<Vec<_>>();
let data = self.as_slice_mut()?;
Some(mnn::Tensor::borrowed_mut(shape, data))
}
}

#[test]
pub fn test_ndarray_to_tensor_ref_mut() {
let mut arr = ndarray::Array3::from_shape_vec([1, 2, 3], [64; 6].to_vec()).unwrap();
arr.as_mnn_tensor_mut().unwrap().fill(600);
assert_eq!(arr.as_slice().unwrap(), &[600; 6]);
}

0 comments on commit 7c007a9

Please sign in to comment.