Skip to content

Commit

Permalink
Merge pull request #4 from aftershootco/run_with_callbacks
Browse files Browse the repository at this point in the history
Added support for run_session with callbacks using rust closures
  • Loading branch information
uttarayan21 authored Sep 13, 2024
2 parents 080a4df + 2276a44 commit 9c61a7a
Show file tree
Hide file tree
Showing 21 changed files with 292 additions and 268 deletions.
153 changes: 12 additions & 141 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ thiserror = "1.0"
error-stack = { version = "0.5" }
oneshot = "0.1"
tracing = { version = "0.1.40", optional = true }
semver = "1.0.23"

[features]
metal = ["mnn-sys/metal"]
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@

Rust wrapper over [alibaba/MNN](https://github.com/alibaba/MNN) c++ library with handwritten C wrapper over mnn

If you have nix you can just build the inspect binary with

```
nix build .#inspect
```

NOTES:
On windows it will only compile with --release mode
There's a few issues with rustc linking to msvcrt by default and anything compiled with /MTd will not link properly
Expand Down
35 changes: 11 additions & 24 deletions examples/inspect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub struct Cli {
forward: ForwardType,
#[clap(short, long, default_value = "high")]
power: PowerMode,
#[clap(short, long, default_value = "high")]
#[clap(short = 'P', long, default_value = "high")]
precision: PrecisionMode,
#[clap(short, long, default_value = "high")]
memory: MemoryMode,
Expand Down Expand Up @@ -39,37 +39,26 @@ pub fn main() -> anyhow::Result<()> {

let mut config = ScheduleConfig::new();
config.set_type(cli.forward);
// let mut backend_config = BackendConfig::new();
// backend_config.set_precision_mode(PrecisionMode::High);
// backend_config.set_power_mode(PowerMode::High);
// config.set_backend_config(backend_config);
// let handle = mnn::sync::SessionHandle::new(interpreter, config)?;
let mut session = time!(interpreter.create_session(config)?; "create session");
interpreter.update_cache_file(&mut session)?;
let mut input = interpreter.input::<f32>(&session, "image")?;
let mut shape = input.shape();
shape[0] = 512;
shape[1] = 512;
shape[2] = 3;
interpreter.resize_tensor(&mut input, shape);
drop(input);
interpreter.resize_session(&mut session);
// let session = time!(interpreter.create_session(config)?; "create session");
// handle.run(|sr| {
// let interpreter = sr.interpreter();
// let session = sr.session();

let mut current = 0;
time!(loop {
let inputs = interpreter.inputs(&session);
inputs.iter().for_each(|x| {
interpreter.inputs(&session).iter().for_each(|x| {
let mut tensor = x.tensor::<f32>().expect("No tensor");
println!("{}: {:?}", x.name(), tensor.shape());
tensor.fill(1.0f32);
});
time!(interpreter.run_session(&session)?;"run session");
time!(interpreter.run_session_with_callback(&session, |_, name| {
println!("Before Callback: {:?}", name);
1
},|_ , name| {
println!("After Callback: {:?}", name);
1
} , true)?;"run session");
let outputs = interpreter.outputs(&session);
outputs.iter().for_each(|x| {
let tensor = x.tensor::<u8>().expect("No tensor");
let tensor = x.tensor::<f32>().expect("No tensor");
time!(tensor.wait(ffi::MapType::MAP_TENSOR_READ, true); format!("Waiting for tensor: {}", x.name()));
println!("{}: {:?}", x.name(), tensor.shape());
let _ = tensor.create_host_tensor_from_device(true);
Expand All @@ -80,7 +69,5 @@ pub fn main() -> anyhow::Result<()> {
break;
}
}; "run loop");
// Ok(())
// })?;
Ok(())
}
17 changes: 17 additions & 0 deletions flake.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 9c61a7a

Please sign in to comment.