Depth Anything v2 (#2279)
* define structs

* construct ResidualConvUnit

* forward() for ResidualConvUnit

* implement FeatureFusionBlock

* implement Scratch

* implement DPTHead

* add identity module

* implement forward for DPTHead

* add get_intermediate_layers to DinoVisionTransformer

* implement DepthAnythingV2

* some minor tweaks

* fix compile errors

* fix var builder prefixes

* setup initial example

* use fixed patch size of 37 (518 / 14)

* debugged until output

* print min and max values

* add some dynamism to the output location

* scale input image

* extract prep function

* extract output path function

* normalize image with magic mean and std

* add spectral coloring

* squeeze in the right place

* make enterpolation optional

* use bail instead of panic

* omit unnecessary Shape call

* remove empty curly braces

* use bail instead of assert

* use vb and pp

* remove closures

* extract config object

* Apply rustfmt.

* Fix some clippy lints.

* More lints.

* Use the array methods.

---------

Co-authored-by: laurent <[email protected]>
jeroenvlek and LaurentMazare authored Jun 24, 2024
1 parent 6baa1d4 commit 242e006
Showing 8 changed files with 911 additions and 1 deletion.
7 changes: 7 additions & 0 deletions candle-examples/Cargo.toml
@@ -25,6 +25,8 @@ hf-hub = { workspace = true, features = ["tokio"] }
image = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
num-traits = { workspace = true }
palette = { version = "0.7.6", optional = true }
enterpolation = { version = "0.2.1", optional = true}
pyo3 = { version = "0.21.0", features = ["auto-initialize"], optional = true }
rayon = { workspace = true }
rubato = { version = "0.15.0", optional = true }
@@ -65,6 +67,7 @@ onnx = ["candle-onnx"]
metal = ["candle/metal", "candle-nn/metal"]
microphone = ["cpal"]
encodec = ["cpal", "symphonia", "rubato"]
depth_anything_v2 = ["palette", "enterpolation"]

[[example]]
name = "llama_multiprocess"
@@ -101,3 +104,7 @@ required-features = ["candle-datasets"]
[[example]]
name = "encodec"
required-features = ["encodec"]

[[example]]
name = "depth_anything_v2"
required-features = ["depth_anything_v2"]
13 changes: 13 additions & 0 deletions candle-examples/examples/depth_anything_v2/README.md
@@ -0,0 +1,13 @@
# candle-depth-anything-v2

[Depth Anything V2](https://huggingface.co/spaces/depth-anything/Depth-Anything-V2) is a model for Monocular
Depth Estimation (MDE), i.e. estimating depth from a single image, which builds on the
[DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.

This example first instantiates the DINOv2 backbone, then builds the DepthAnythingV2 model on top of it and runs inference on the input image, as sketched below.
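
For orientation, here is a rough sketch of that flow (weight download, image loading and device setup are elided; the helper name `build_and_run` is purely illustrative):

```rust
// Sketch only: mirrors the steps in main.rs.
use candle::{DType, Device, Module, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;

fn build_and_run(
    dinov2_weights: &std::path::Path,
    depth_weights: &std::path::Path,
    image: &Tensor,
    device: &Device,
) -> candle::Result<Tensor> {
    // DINOv2 backbone from its safetensors weights.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_weights], DType::F32, device)? };
    let dinov2 = dinov2::vit_small(vb)?;
    // Depth Anything V2 head wired on top of the backbone.
    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[depth_weights], DType::F32, device)? };
    let depth_anything = DepthAnythingV2::new(&dinov2, &DepthAnythingV2Config::vit_small(), vb)?;
    // Expects a normalized (1, 3, 518, 518) input and returns a relative depth map.
    depth_anything.forward(image)
}
```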

## Running an example with color map and CUDA

```bash
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
```

50 changes: 50 additions & 0 deletions candle-examples/examples/depth_anything_v2/color_map.rs
@@ -0,0 +1,50 @@
use enterpolation::linear::ConstEquidistantLinear;
use enterpolation::Generator;
use palette::LinSrgb;

use candle::Tensor;

pub struct SpectralRColormap {
    gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
}

impl SpectralRColormap {
    pub(crate) fn new() -> Self {
        // Define a colormap similar to 'Spectral_r' by specifying key colors.
        // got the colors from ChatGPT-4o
        let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
            LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
            LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
            LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
            LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
            LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
            LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
            LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
            LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
            LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
        ]);
        Self { gradient }
    }

    fn get_color(&self, value: f32) -> LinSrgb {
        self.gradient.gen(value)
    }

    pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
        println!("Gray: {:?}", gray.dims());
        let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
        let rgb_values: Vec<f32> = gray_values
            .iter()
            .map(|g| self.get_color(*g))
            .flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
            .collect();

        let [.., height, width] = gray.dims() else {
            candle::bail!("Not enough dims!")
        };

        let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;

        color.permute((2, 0, 1))
    }
}
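
For reference, a minimal sketch of how this colormap is used (it mirrors the `color_map` branch of `post_process_image` in `main.rs`; the `colorize_demo` name is illustrative and assumes `SpectralRColormap` is in scope, e.g. via `use crate::color_map::SpectralRColormap;`):

```rust
use candle::{Device, Tensor};

fn colorize_demo() -> candle::Result<Tensor> {
    // A tiny 2x2 "depth map" whose values are already scaled to [0, 1].
    let depth = Tensor::from_vec(vec![0.0f32, 0.33, 0.66, 1.0], (1, 2, 2), &Device::Cpu)?;
    let spectral_r = SpectralRColormap::new();
    // gray2color maps every scalar through the gradient and returns a channels-first
    // (3, H, W) tensor, i.e. (3, 2, 2) here, ready to be scaled to u8 and saved.
    let color = spectral_r.gray2color(&depth)?;
    assert_eq!(color.dims(), &[3, 2, 2]);
    Ok(color)
}
```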
187 changes: 187 additions & 0 deletions candle-examples/examples/depth_anything_v2/main.rs
@@ -0,0 +1,187 @@
//! Depth Anything V2
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use std::ffi::OsString;
use std::path::PathBuf;

use clap::Parser;

use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
use candle_examples::{load_image, load_image_and_resize, save_image};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;

use crate::color_map::SpectralRColormap;

mod color_map;

// Normalization constants taken from: https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];

const DINO_IMG_SIZE: usize = 518;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    dinov2_model: Option<PathBuf>,

    #[arg(long)]
    depth_anything_v2_model: Option<PathBuf>,

    #[arg(long)]
    image: PathBuf,

    #[arg(long)]
    output_dir: Option<PathBuf>,

    #[arg(long)]
    cpu: bool,

    #[arg(long)]
    color_map: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;

    let dinov2_model_file = match args.dinov2_model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-dino-v2".into());
            api.get("dinov2_vits14.safetensors")?
        }
        Some(dinov2_model) => dinov2_model,
    };
    println!("Using file {:?}", dinov2_model_file);

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
    let dinov2 = dinov2::vit_small(vb)?;
    println!("DinoV2 model built");

    let depth_anything_model_file = match args.depth_anything_v2_model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
            api.get("depth_anything_v2_vits.safetensors")?
        }
        Some(depth_anything_model) => depth_anything_model,
    };
    println!("Using file {:?}", depth_anything_model_file);

    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
    };

    let config = DepthAnythingV2Config::vit_small();
    let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;

    let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;

    println!("Loaded image {image:?}");

    let depth = depth_anything.forward(&image)?;

    println!("Got predictions {:?}", depth.shape());

    let output_image = post_process_image(&depth, original_height, original_width, args.color_map)?;

    let output_path = full_output_path(&args.image, &args.output_dir);
    println!("Saving image to {}", output_path.to_string_lossy());
    save_image(&output_image, output_path)?;

    Ok(())
}

fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
    let input_file_name = image_path.file_name().unwrap();
    let mut output_file_name = OsString::from("depth_");
    output_file_name.push(input_file_name);
    let mut output_path = match output_dir {
        None => image_path.parent().unwrap().to_path_buf(),
        Some(output_path) => output_path.clone(),
    };
    output_path.push(output_file_name);

    output_path
}

fn load_and_prep_image(
    image_path: &PathBuf,
    device: &Device,
) -> anyhow::Result<(usize, usize, Tensor)> {
    let (_original_image, original_height, original_width) = load_image(&image_path, None)?;

    let image = load_image_and_resize(&image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
        .unsqueeze(0)?
        .to_dtype(F32)?
        .to_device(&device)?;

    let max_pixel_val = Tensor::try_from(255.0f32)?
        .to_device(&device)?
        .broadcast_as(image.shape())?;
    let image = (image / max_pixel_val)?;
    let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;

    Ok((original_height, original_width, image))
}

fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
    let mean_tensor =
        Tensor::from_vec(mean.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
    let std_tensor =
        Tensor::from_vec(std.to_vec(), (3, 1, 1), &image.device())?.broadcast_as(image.shape())?;
    image.sub(&mean_tensor)?.div(&std_tensor)
}
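
As a small sanity check of the normalization above, a sketch (the `normalize_demo` helper is illustrative and assumes it lives in this example module so `normalize_image` and the magic constants are in scope):

```rust
fn normalize_demo() -> anyhow::Result<()> {
    // One pixel per channel, already scaled to [0, 1].
    let image = Tensor::from_vec(vec![0.5f32, 0.5, 0.5], (1, 3, 1, 1), &Device::Cpu)?;
    let normalized = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;
    // Each channel becomes (x - mean) / std, e.g. (0.5 - 0.485) / 0.229 ≈ 0.066 for the first channel.
    println!("{:?}", normalized.flatten_all()?.to_vec1::<f32>()?);
    Ok(())
}
```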

fn post_process_image(
    image: &Tensor,
    original_height: usize,
    original_width: usize,
    color_map: bool,
) -> Result<Tensor> {
    let out = image.interpolate2d(original_height, original_width)?;
    let out = scale_image(&out)?;

    let out = if color_map {
        let spectral_r = SpectralRColormap::new();
        spectral_r.gray2color(&out)?
    } else {
        let rgb_slice = [&out, &out, &out];
        Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
    };

    let max_pixel_val = Tensor::try_from(255.0f32)?
        .to_device(out.device())?
        .broadcast_as(out.shape())?;
    let out = (out * max_pixel_val)?;

    out.to_dtype(U8)
}

fn scale_image(depth: &Tensor) -> Result<Tensor> {
    let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;

    let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
    let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();

    let min_val_tensor = Tensor::try_from(*min_val)?
        .to_device(depth.device())?
        .broadcast_as(depth.shape())?;
    let depth = (depth - min_val_tensor)?;

    let range = max_val - min_val;
    let range_tensor = Tensor::try_from(range)?
        .to_device(depth.device())?
        .broadcast_as(depth.shape())?;

    depth / range_tensor
}
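
Similarly, a short sketch of what `scale_image` produces (the `scale_demo` name is illustrative and assumes the same module scope):

```rust
fn scale_demo() -> anyhow::Result<()> {
    // Raw relative-depth predictions; scale_image maps the minimum to 0.0 and the maximum to 1.0.
    let depth = Tensor::from_vec(vec![2.0f32, 4.0, 6.0, 10.0], (1, 1, 2, 2), &Device::Cpu)?;
    let scaled = scale_image(&depth)?;
    // (x - min) / (max - min): [2, 4, 6, 10] -> [0.0, 0.25, 0.5, 1.0]
    println!("{:?}", scaled.flatten_all()?.to_vec1::<f32>()?);
    Ok(())
}
```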
23 changes: 22 additions & 1 deletion candle-nn/src/ops.rs
@@ -1,4 +1,4 @@
-use candle::{CpuStorage, DType, Layout, Result, Shape, Tensor, D};
+use candle::{CpuStorage, DType, Layout, Module, Result, Shape, Tensor, D};
use rayon::prelude::*;

/// Applies the softmax function to the input tensor, rescaling the element so that elements on
@@ -926,3 +926,24 @@ pub fn replication_pad2d(xs: &Tensor, pad: usize) -> Result<Tensor> {
        n => candle::bail!("replication-pad with a size of {n} is not supported"),
    }
}

#[derive(Clone, Debug)]
pub struct Identity;

impl Identity {
    pub fn new() -> Identity {
        Self
    }
}

impl Default for Identity {
    fn default() -> Self {
        Self
    }
}

impl Module for Identity {
    fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        Ok(xs.clone())
    }
}
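
A quick sketch of the new `Identity` module in use (a standalone example assuming `candle-core` and `candle-nn` as dependencies; the struct should be reachable at `candle_nn::ops::Identity`):

```rust
use candle::{Device, Module, Tensor};
use candle_nn::ops::Identity;

fn main() -> candle::Result<()> {
    let xs = Tensor::new(&[1f32, 2.0, 3.0], &Device::Cpu)?;
    // Identity::forward simply returns a clone of the input tensor.
    let ys = Identity::new().forward(&xs)?;
    assert_eq!(xs.to_vec1::<f32>()?, ys.to_vec1::<f32>()?);
    Ok(())
}
```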
