From 447889028e0c998c31a743eaa179fc27798f79b3 Mon Sep 17 00:00:00 2001
From: Jamjamjon <51357717+jamjamjon@users.noreply.github.com>
Date: Thu, 10 Oct 2024 00:26:26 +0800
Subject: [PATCH] Add Apple ml-depth-pro model

---
 Cargo.toml                 |  2 +-
 examples/depth-pro/main.rs | 26 ++++++++++++
 src/models/depth_pro.rs    | 86 ++++++++++++++++++++++++++++++++++++++
 src/models/mod.rs          |  2 +
 4 files changed, 115 insertions(+), 1 deletion(-)
 create mode 100644 examples/depth-pro/main.rs
 create mode 100644 src/models/depth_pro.rs

diff --git a/Cargo.toml b/Cargo.toml
index a8d0afb..db6fc67 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "usls"
-version = "0.0.18"
+version = "0.0.19"
 edition = "2021"
 description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
 repository = "https://github.com/jamjamjon/usls"
diff --git a/examples/depth-pro/main.rs b/examples/depth-pro/main.rs
new file mode 100644
index 0000000..eb72a9a
--- /dev/null
+++ b/examples/depth-pro/main.rs
@@ -0,0 +1,26 @@
+use usls::{models::DepthPro, Annotator, DataLoader, Options};
+
+fn main() -> Result<(), Box<dyn std::error::Error>> {
+    // options
+    let options = Options::default()
+        .with_model("depth-pro/q4f16.onnx")? // bnb4, f16
+        .with_ixx(0, 0, 1.into()) // batch. Note: now only support batch_size = 1
+        .with_ixx(0, 1, 3.into()) // channel
+        .with_ixx(0, 2, 1536.into()) // height
+        .with_ixx(0, 3, 1536.into()); // width
+    let mut model = DepthPro::new(options)?;
+
+    // load
+    let x = [DataLoader::try_read("images/street.jpg")?];
+
+    // run
+    let y = model.run(&x)?;
+
+    // annotate
+    let annotator = Annotator::default()
+        .with_colormap("Turbo")
+        .with_saveout("Depth-Pro");
+    annotator.annotate(&x, &y);
+
+    Ok(())
+}
diff --git a/src/models/depth_pro.rs b/src/models/depth_pro.rs
new file mode 100644
index 0000000..26938f7
--- /dev/null
+++ b/src/models/depth_pro.rs
@@ -0,0 +1,86 @@
+use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
+use anyhow::Result;
+use image::DynamicImage;
+use ndarray::Axis;
+
+#[derive(Debug)]
+pub struct DepthPro {
+    engine: OrtEngine,
+    height: MinOptMax,
+    width: MinOptMax,
+    batch: MinOptMax,
+}
+
+impl DepthPro {
+    pub fn new(options: Options) -> Result<Self> {
+        let mut engine = OrtEngine::new(&options)?;
+        let (batch, height, width) = (
+            engine.batch().clone(),
+            engine.height().clone(),
+            engine.width().clone(),
+        );
+        engine.dry_run()?;
+
+        Ok(Self {
+            engine,
+            height,
+            width,
+            batch,
+        })
+    }
+
+    pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
+        let xs_ = X::apply(&[
+            Ops::Resize(
+                xs,
+                self.height.opt() as u32,
+                self.width.opt() as u32,
+                "Bilinear",
+            ),
+            Ops::Normalize(0., 255.),
+            Ops::Standardize(&[0.5, 0.5, 0.5], &[0.5, 0.5, 0.5], 3),
+            Ops::Nhwc2nchw,
+        ])?;
+        let ys = self.engine.run(Xs::from(xs_))?;
+
+        self.postprocess(ys, xs)
+    }
+
+    pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
+        let (predicted_depth, _focallength_px) = (&xs["predicted_depth"], &xs["focallength_px"]);
+        let predicted_depth = predicted_depth.mapv(|x| 1. / x);
+
+        let mut ys: Vec<Y> = Vec::new();
+        for (idx, luma) in predicted_depth.axis_iter(Axis(0)).enumerate() {
+            let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
+            let v = luma.into_owned().into_raw_vec_and_offset().0;
+            let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
+            let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap();
+            let v = v
+                .iter()
+                .map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8)
+                .collect::<Vec<_>>();
+
+            let luma = Ops::resize_luma8_u8(
+                &v,
+                self.width.opt() as _,
+                self.height.opt() as _,
+                w1 as _,
+                h1 as _,
+                false,
+                "Bilinear",
+            )?;
+            let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
+                match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
+                    None => continue,
+                    Some(x) => x,
+                };
+            ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
+        }
+        Ok(ys)
+    }
+
+    pub fn batch(&self) -> isize {
+        self.batch.opt() as _
+    }
+}
diff --git a/src/models/mod.rs b/src/models/mod.rs
index 28ecb53..6df4c8c 100644
--- a/src/models/mod.rs
+++ b/src/models/mod.rs
@@ -4,6 +4,7 @@ mod blip;
 mod clip;
 mod db;
 mod depth_anything;
+mod depth_pro;
 mod dinov2;
 mod florence2;
 mod grounding_dino;
@@ -20,6 +21,7 @@ pub use blip::Blip;
 pub use clip::Clip;
 pub use db::DB;
 pub use depth_anything::DepthAnything;
+pub use depth_pro::DepthPro;
 pub use dinov2::Dinov2;
 pub use florence2::Florence2;
 pub use grounding_dino::GroundingDINO;