diff --git a/api/Cargo.toml b/api/Cargo.toml index 03315d3a17..d667f12303 100644 --- a/api/Cargo.toml +++ b/api/Cargo.toml @@ -20,6 +20,7 @@ half.workspace = true ndarray.workspace = true [dev-dependencies] +lazy_static = "1.4.0" reqwest.workspace = true tempfile.workspace = true serde_json.workspace = true diff --git a/api/tests/mobilenet/mod.rs b/api/tests/mobilenet/mod.rs index 59ee9a15ab..ea62a6a55d 100644 --- a/api/tests/mobilenet/mod.rs +++ b/api/tests/mobilenet/mod.rs @@ -1,3 +1,5 @@ +use std::sync::Once; + fn grace_hopper() -> Value { let data = std::fs::read("../tests/grace_hopper_3_224_224.f32.raw").unwrap(); let data: &[f32] = unsafe { std::slice::from_raw_parts(data.as_ptr() as _, 3 * 224 * 224) }; @@ -5,29 +7,29 @@ fn grace_hopper() -> Value { } fn ensure_models() -> anyhow::Result<()> { - for (url, file) in [( - "https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-7.onnx", - "mobilenetv2-7.onnx"), - ( - "https://sfo2.digitaloceanspaces.com/nnef-public/mobilenet_v2_1.0.onnx.nnef.tgz", - "mobilenet_v2_1.0.onnx.nnef.tgz") - ] { - if std::fs::metadata(file).is_err() { - let client = reqwest::blocking::Client::new(); - let model = client.get(url).send()?; - std::fs::write(file, model.bytes()?)?; + static START: Once = Once::new(); + START.call_once(|| { + for (url, file) in [( + "https://github.com/onnx/models/raw/main/vision/classification/mobilenet/model/mobilenetv2-7.onnx", + "mobilenetv2-7.onnx"), + ( + "https://sfo2.digitaloceanspaces.com/nnef-public/mobilenet_v2_1.0.onnx.nnef.tgz", + "mobilenet_v2_1.0.onnx.nnef.tgz") + ] { + if std::fs::metadata(file).is_err() { + let client = reqwest::blocking::Client::new(); + let model = client.get(url).send().unwrap(); + std::fs::write(file, model.bytes().unwrap()).unwrap(); + } } - } + }); Ok(()) } #[test] fn test_onnx() -> anyhow::Result<()> { ensure_models()?; - let model = onnx()? - .model_for_path("mobilenetv2-7.onnx")? - .into_optimized()? - .into_runnable()?; + let model = onnx()?.model_for_path("mobilenetv2-7.onnx")?.into_optimized()?.into_runnable()?; let hopper = grace_hopper(); let result = model.run([hopper])?; let result = result[0].view::()?; @@ -45,10 +47,7 @@ fn test_onnx() -> anyhow::Result<()> { #[test] fn test_state() -> anyhow::Result<()> { ensure_models()?; - let model = onnx()? - .model_for_path("mobilenetv2-7.onnx")? - .into_optimized()? - .into_runnable()?; + let model = onnx()?.model_for_path("mobilenetv2-7.onnx")?.into_optimized()?.into_runnable()?; let mut state = model.spawn_state()?; let hopper = grace_hopper(); let result = state.run([hopper])?;