Skip to content

Commit

Permalink
feat(example): Allow inspect to set data types for input and output t…
Browse files Browse the repository at this point in the history
…ensors
  • Loading branch information
uttarayan21 committed Oct 8, 2024
1 parent 1ae6a2d commit 7507380
Show file tree
Hide file tree
Showing 3 changed files with 56 additions and 30 deletions.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ mnn-threadpool = ["mnn-sys/mnn-threadpool"]
tracing = ["dep:tracing"]
profile = ["tracing"]

default = ["mnn-threadpool", "opencl"]
default = ["mnn-threadpool"]


[dev-dependencies]
Expand Down
70 changes: 50 additions & 20 deletions examples/inspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,21 @@ pub struct Cli {
precision: PrecisionMode,
#[clap(short, long, default_value = "high")]
memory: MemoryMode,
#[clap(short, long, default_value = "f32")]
output_data_type: DataType,
#[clap(short, long, default_value = "f32")]
input_data_type: DataType,
#[clap(short, long, default_value = "1")]
loops: usize,
}

/// Element type for a tensor, selectable from the command line
/// (via `clap::ValueEnum`) for the model's input and output tensors.
/// The CLI flags default to `f32` for both directions.
#[derive(Debug, Clone, clap::ValueEnum)]
pub enum DataType {
    /// 32-bit floating point.
    F32,
    /// Unsigned 8-bit integer.
    U8,
    /// Signed 8-bit integer.
    I8,
}

macro_rules! time {
($($x:expr),+ ; $text: expr) => {
{
Expand All @@ -41,33 +52,52 @@ pub fn main() -> anyhow::Result<()> {
config.set_type(cli.forward);
let mut session = time!(interpreter.create_session(config)?; "create session");
interpreter.update_cache_file(&mut session)?;
let inputs = interpreter.inputs(&session);
let mut first = inputs
.iter()
.next()
.expect("No input")
.tensor::<f32>()
.unwrap();
let shape = first.shape();
interpreter.resize_tensor(&mut first, shape);
interpreter.resize_session(&mut session);
drop(first);
drop(inputs);

let mut current = 0;
time!(loop {
println!("--------------------------------Inputs--------------------------------");
interpreter.inputs(&session).iter().for_each(|x| {
let mut tensor = x.tensor::<f32>().expect("No tensor");
println!("{}: {:?}", x.name(), tensor.shape());
tensor.fill(1.0f32);
match cli.input_data_type {
DataType::F32 => {
let mut tensor = x.tensor::<f32>().expect("No tensor");
println!("{}: {:?}", x.name(), tensor.shape());
tensor.fill(1.0f32);
},
DataType::U8 => {
let mut tensor = x.tensor::<u8>().expect("No tensor");
println!("{}: {:?}", x.name(), tensor.shape());
tensor.fill(1u8);
},
DataType::I8 => {
let mut tensor = x.tensor::<i8>().expect("No tensor");
println!("{}: {:?}", x.name(), tensor.shape());
tensor.fill(1i8);
},
};
});
println!("Running session");
interpreter.run_session(&session)?;
println!("--------------------------------Outputs--------------------------------");
let outputs = interpreter.outputs(&session);
outputs.iter().for_each(|x| {
let tensor = x.tensor::<f32>().expect("No tensor");
time!(tensor.wait(ffi::MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name()));
println!("{}: {:?}", x.name(), tensor.shape());
let _ = tensor.create_host_tensor_from_device(true);
// std::fs::write(format!("{}.bin", x.name()), bytemuck::cast_slice(cpu_tensor.host())).expect("Unable to write");
match cli.output_data_type {
DataType::F32 => {
let tensor = x.tensor::<f32>().expect("No tensor");
println!("{}: {:?}", x.name(), tensor.shape());
time!(tensor.wait(MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name()));
},
DataType::U8 => {
let tensor = x.tensor::<u8>().expect("No tensor");
println!("{}: {:?}", x.name(), tensor.shape());
time!(tensor.wait(MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name()));
},
DataType::I8 => {
let tensor = x.tensor::<i8>().expect("No tensor");
println!("{}: {:?}", x.name(), tensor.shape());
time!(tensor.wait(MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name()));
},
};

});
current += 1;
if current >= cli.loops {
Expand Down
14 changes: 5 additions & 9 deletions src/interpreter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ impl Interpreter {
Ok(RawTensor::from_ptr(input))
}

/// * Safety
/// # Safety
/// We Still don't know the safety guarantees of this function so it's marked unsafe
pub unsafe fn input_unresized<'s, H: HalideType>(
&self,
Expand All @@ -309,7 +309,7 @@ impl Interpreter {
Ok(tensor)
}

/// * Safety
/// # Safety
/// Very unsafe since it doesn't check the type of the tensor
/// as well as the shape of the tensor
pub unsafe fn input_unchecked<'s, H: HalideType>(
Expand Down Expand Up @@ -468,13 +468,9 @@ impl<'t, 'tl> TensorInfo<'t, 'tl> {
Ok(tensor)
}

/// * Safety
/// The shape is not checked so it's marked unsafe since futher calls to interpreter might be
/// unsafe with this
pub unsafe fn tensor_unresized<H: HalideType>(&self) -> Result<Tensor<RefMut<'t, Device<H>>>>
where
H: HalideType,
{
/// # Safety
/// The shape is not checked so it's marked unsafe since further calls to interpreter might be unsafe with this
pub unsafe fn tensor_unresized<H: HalideType>(&self) -> Result<Tensor<RefMut<'t, Device<H>>>> {
debug_assert!(!self.tensor_info.is_null());
unsafe { debug_assert!(!(*self.tensor_info).tensor.is_null()) };
let tensor = unsafe { Tensor::from_ptr((*self.tensor_info).tensor.cast()) };
Expand Down

0 comments on commit 7507380

Please sign in to comment.