Skip to content

Commit

Permalink
Add Apple ml-depth-pro model
Browse files Browse the repository at this point in the history
  • Loading branch information
jamjamjon committed Oct 9, 2024
1 parent 1d59638 commit 10cee26
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 1 deletion.
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "usls"
version = "0.0.18"
version = "0.0.19"
edition = "2021"
description = "A Rust library integrated with ONNXRuntime, providing a collection of ML models."
repository = "https://github.com/jamjamjon/usls"
Expand Down
26 changes: 26 additions & 0 deletions examples/depth-pro/main.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
use usls::{models::DepthPro, Annotator, DataLoader, Options};

fn main() -> Result<(), Box<dyn std::error::Error>> {
// options
let options = Options::default()
.with_model("depth-pro/q4f16.onnx")? // bnb4, f16
.with_ixx(0, 0, 1.into()) // batch. Note: now only support batch_size = 1
.with_ixx(0, 1, 3.into()) // channel
.with_ixx(0, 2, 1536.into()) // height
.with_ixx(0, 3, 1536.into()); // width
let mut model = DepthPro::new(options)?;

// load
let x = [DataLoader::try_read("images/street.jpg")?];

// run
let y = model.run(&x)?;

// annotate
let annotator = Annotator::default()
.with_colormap("Turbo")
.with_saveout("Depth-Pro");
annotator.annotate(&x, &y);

Ok(())
}
86 changes: 86 additions & 0 deletions src/models/depth_pro.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use crate::{Mask, MinOptMax, Ops, Options, OrtEngine, Xs, X, Y};
use anyhow::Result;
use image::DynamicImage;
use ndarray::Axis;

#[derive(Debug)]
pub struct DepthPro {
engine: OrtEngine,
height: MinOptMax,
width: MinOptMax,
batch: MinOptMax,
}

impl DepthPro {
pub fn new(options: Options) -> Result<Self> {
let mut engine = OrtEngine::new(&options)?;
let (batch, height, width) = (
engine.batch().clone(),
engine.height().clone(),
engine.width().clone(),
);
engine.dry_run()?;

Ok(Self {
engine,
height,
width,
batch,
})
}

pub fn run(&mut self, xs: &[DynamicImage]) -> Result<Vec<Y>> {
let xs_ = X::apply(&[
Ops::Resize(
xs,
self.height.opt() as u32,
self.width.opt() as u32,
"Bilinear",
),
Ops::Normalize(0., 255.),
Ops::Standardize(&[0.5, 0.5, 0.5], &[0.5, 0.5, 0.5], 3),
Ops::Nhwc2nchw,
])?;
let ys = self.engine.run(Xs::from(xs_))?;

self.postprocess(ys, xs)
}

pub fn postprocess(&self, xs: Xs, xs0: &[DynamicImage]) -> Result<Vec<Y>> {
let (predicted_depth, _focallength_px) = (&xs["predicted_depth"], &xs["focallength_px"]);
let predicted_depth = predicted_depth.mapv(|x| 1. / x);

let mut ys: Vec<Y> = Vec::new();
for (idx, luma) in predicted_depth.axis_iter(Axis(0)).enumerate() {
let (w1, h1) = (xs0[idx].width(), xs0[idx].height());
let v = luma.into_owned().into_raw_vec_and_offset().0;
let max_ = v.iter().max_by(|x, y| x.total_cmp(y)).unwrap();
let min_ = v.iter().min_by(|x, y| x.total_cmp(y)).unwrap();
let v = v
.iter()
.map(|x| (((*x - min_) / (max_ - min_)) * 255.).clamp(0., 255.) as u8)
.collect::<Vec<_>>();

let luma = Ops::resize_luma8_u8(
&v,
self.width.opt() as _,
self.height.opt() as _,
w1 as _,
h1 as _,
false,
"Bilinear",
)?;
let luma: image::ImageBuffer<image::Luma<_>, Vec<_>> =
match image::ImageBuffer::from_raw(w1 as _, h1 as _, luma) {
None => continue,
Some(x) => x,
};
ys.push(Y::default().with_masks(&[Mask::default().with_mask(luma)]));
}
Ok(ys)
}

pub fn batch(&self) -> isize {
self.batch.opt() as _
}
}
2 changes: 2 additions & 0 deletions src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod blip;
mod clip;
mod db;
mod depth_anything;
mod depth_pro;
mod dinov2;
mod florence2;
mod grounding_dino;
Expand All @@ -20,6 +21,7 @@ pub use blip::Blip;
pub use clip::Clip;
pub use db::DB;
pub use depth_anything::DepthAnything;
pub use depth_pro::DepthPro;
pub use dinov2::Dinov2;
pub use florence2::Florence2;
pub use grounding_dino::GroundingDINO;
Expand Down

0 comments on commit 10cee26

Please sign in to comment.