Skip to content

Commit

Permalink
fix model dl
Browse files Browse the repository at this point in the history
  • Loading branch information
kali committed Sep 12, 2023
1 parent 86e68e4 commit ff4d0a6
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 20 deletions.
1 change: 1 addition & 0 deletions api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ half.workspace = true
ndarray.workspace = true

[dev-dependencies]
lazy_static = "1.4.0"
reqwest.workspace = true
tempfile.workspace = true
serde_json.workspace = true
39 changes: 19 additions & 20 deletions api/tests/mobilenet/mod.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,35 @@
use std::sync::Once;

fn grace_hopper() -> Value {
let data = std::fs::read("../tests/grace_hopper_3_224_224.f32.raw").unwrap();
let data: &[f32] = unsafe { std::slice::from_raw_parts(data.as_ptr() as _, 3 * 224 * 224) };
Value::from_slice(&[1, 3, 224, 224], data).unwrap()
}

fn ensure_models() -> anyhow::Result<()> {
for (url, file) in [(
"https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-7.onnx",
"mobilenetv2-7.onnx"),
(
"https://sfo2.digitaloceanspaces.com/nnef-public/mobilenet_v2_1.0.onnx.nnef.tgz",
"mobilenet_v2_1.0.onnx.nnef.tgz")
] {
if std::fs::metadata(file).is_err() {
let client = reqwest::blocking::Client::new();
let model = client.get(url).send()?;
std::fs::write(file, model.bytes()?)?;
static START: Once = Once::new();
START.call_once(|| {
for (url, file) in [(
"https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-7.onnx",
"mobilenetv2-7.onnx"),
(
"https://sfo2.digitaloceanspaces.com/nnef-public/mobilenet_v2_1.0.onnx.nnef.tgz",
"mobilenet_v2_1.0.onnx.nnef.tgz")
] {
if std::fs::metadata(file).is_err() {
let client = reqwest::blocking::Client::new();
let model = client.get(url).send().unwrap();
std::fs::write(file, model.bytes().unwrap()).unwrap();
}
}
}
});
Ok(())
}

#[test]
fn test_onnx() -> anyhow::Result<()> {
ensure_models()?;
let model = onnx()?
.model_for_path("mobilenetv2-7.onnx")?
.into_optimized()?
.into_runnable()?;
let model = onnx()?.model_for_path("mobilenetv2-7.onnx")?.into_optimized()?.into_runnable()?;
let hopper = grace_hopper();
let result = model.run([hopper])?;
let result = result[0].view::<f32>()?;
Expand All @@ -45,10 +47,7 @@ fn test_onnx() -> anyhow::Result<()> {
#[test]
fn test_state() -> anyhow::Result<()> {
ensure_models()?;
let model = onnx()?
.model_for_path("mobilenetv2-7.onnx")?
.into_optimized()?
.into_runnable()?;
let model = onnx()?.model_for_path("mobilenetv2-7.onnx")?.into_optimized()?.into_runnable()?;
let mut state = model.spawn_state()?;
let hopper = grace_hopper();
let result = state.run([hopper])?;
Expand Down

0 comments on commit ff4d0a6

Please sign in to comment.