* define structs
* construct ResidualConvUnit
* forward() for ResidualConvUnit
* implement FeatureFusionBlock
* implement Scratch
* implement DPTHead
* add identity module
* implement forward for DPTHead
* add get_intermediate_layers to DinoVisionTransformer
* implement DepthAnythingV2
* some minor tweaks
* fix compile errors
* fix var builder prefixes
* set up initial example
* use a fixed patch size of 37 (518 / 14)
* debug until output
* print min and max values
* add some dynamism to the output location
* scale input image
* extract prep function
* extract output path function
* normalize image with magic mean and std
* add spectral coloring
* squeeze in the right place
* make enterpolation optional
* use bail instead of panic
* omit unnecessary Shape call
* remove empty curly braces
* use bail instead of assert
* use vb and pp
* remove closures
* extract config object
* apply rustfmt
* fix some clippy lints
* more lints
* use the array methods

Co-authored-by: laurent <[email protected]>
Commit 242e006 (1 parent: 6baa1d4). Showing 8 changed files with 911 additions and 1 deletion.
@@ -0,0 +1,13 @@
# candle-depth-anything-v2

[Depth Anything V2](https://huggingface.co/spaces/depth-anything/Depth-Anything-V2) is a model for
Monocular Depth Estimation (MDE), i.e. depth estimation from a single image, which builds on the
[DINOv2](https://github.com/facebookresearch/dinov2) vision transformer.

This example first instantiates the DINOv2 model, then builds DepthAnythingV2 on top of it and runs
inference.

## Running an example with color map and CUDA

```bash
cargo run --features cuda,depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --color-map --image candle-examples/examples/yolo-v8/assets/bike.jpg
```
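The example also exposes a `--cpu` flag (see the `Args` struct in the example code below), so a run
without CUDA should look roughly like the following; the exact feature flags may vary with your setup:

```bash
cargo run --features depth_anything_v2 --package candle-examples --example depth_anything_v2 -- --cpu --image candle-examples/examples/yolo-v8/assets/bike.jpg
```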
@@ -0,0 +1,50 @@
use enterpolation::linear::ConstEquidistantLinear;
use enterpolation::Generator;
use palette::LinSrgb;

use candle::Tensor;

pub struct SpectralRColormap {
    gradient: ConstEquidistantLinear<f32, LinSrgb, 9>,
}

impl SpectralRColormap {
    pub(crate) fn new() -> Self {
        // Define a colormap similar to 'Spectral_r' by specifying key colors.
        // Color values obtained from ChatGPT-4o.
        let gradient = ConstEquidistantLinear::<f32, _, 9>::equidistant_unchecked([
            LinSrgb::new(0.3686, 0.3098, 0.6353), // Dark blue
            LinSrgb::new(0.1961, 0.5333, 0.7412), // Blue
            LinSrgb::new(0.4000, 0.7608, 0.6471), // Cyan
            LinSrgb::new(0.6706, 0.8667, 0.6431), // Green
            LinSrgb::new(0.9020, 0.9608, 0.5961), // Yellow
            LinSrgb::new(0.9961, 0.8784, 0.5451), // Orange
            LinSrgb::new(0.9922, 0.6824, 0.3804), // Red
            LinSrgb::new(0.9569, 0.4275, 0.2627), // Dark red
            LinSrgb::new(0.8353, 0.2431, 0.3098), // Dark purple
        ]);
        Self { gradient }
    }

    /// Map a scalar in [0, 1] to an RGB color along the gradient.
    fn get_color(&self, value: f32) -> LinSrgb {
        self.gradient.gen(value)
    }

    /// Convert a single-channel tensor (last two dims: height, width) into an
    /// RGB tensor of shape (3, height, width).
    pub fn gray2color(&self, gray: &Tensor) -> candle::Result<Tensor> {
        let gray_values: Vec<f32> = gray.flatten_all()?.to_vec1()?;
        let rgb_values: Vec<f32> = gray_values
            .iter()
            .map(|g| self.get_color(*g))
            .flat_map(|rgb| [rgb.red, rgb.green, rgb.blue])
            .collect();

        let [.., height, width] = gray.dims() else {
            candle::bail!("expected at least two dimensions (height, width)")
        };

        let color = Tensor::from_vec(rgb_values, (*height, *width, 3), gray.device())?;

        color.permute((2, 0, 1))
    }
}
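For intuition, `gray2color` expects the depth values to already be scaled to `[0, 1]` (see `scale_image` in the example below) and the last two dimensions to be height and width. A minimal usage sketch with a hypothetical function name and a hand-made 2x2 depth map:

```rust
// Hypothetical sketch: colorize a 2x2 depth map whose values are already in [0, 1].
fn colorize_example() -> candle::Result<candle::Tensor> {
    let device = candle::Device::Cpu;
    // Shape (1, height, width): leading dimensions of size 1 are fine because the
    // values are flattened before the (height, width, 3) image is rebuilt.
    let depth = candle::Tensor::from_vec(vec![0.0f32, 0.25, 0.5, 1.0], (1, 2, 2), &device)?;
    let colormap = SpectralRColormap::new();
    // Returns an RGB tensor of shape (3, height, width) with values in [0, 1].
    colormap.gray2color(&depth)
}
```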
@@ -0,0 +1,187 @@
//! Depth Anything V2
//! https://huggingface.co/spaces/depth-anything/Depth-Anything-V2

#[cfg(feature = "accelerate")]
extern crate accelerate_src;
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use std::ffi::OsString;
use std::path::PathBuf;

use clap::Parser;

use candle::DType::{F32, U8};
use candle::{DType, Device, Module, Result, Tensor};
use candle_examples::{load_image, load_image_and_resize, save_image};
use candle_nn::VarBuilder;
use candle_transformers::models::depth_anything_v2::{DepthAnythingV2, DepthAnythingV2Config};
use candle_transformers::models::dinov2;

use crate::color_map::SpectralRColormap;

mod color_map;

// Normalization constants taken from:
// https://huggingface.co/spaces/depth-anything/Depth-Anything-V2/blob/main/depth_anything_v2/dpt.py#L207
// (the standard ImageNet mean and std).
const MAGIC_MEAN: [f32; 3] = [0.485, 0.456, 0.406];
const MAGIC_STD: [f32; 3] = [0.229, 0.224, 0.225];

// DINOv2 input size: 37 patches of size 14 per side (518 = 37 * 14).
const DINO_IMG_SIZE: usize = 518;

#[derive(Parser)]
struct Args {
    #[arg(long)]
    dinov2_model: Option<PathBuf>,

    #[arg(long)]
    depth_anything_v2_model: Option<PathBuf>,

    #[arg(long)]
    image: PathBuf,

    #[arg(long)]
    output_dir: Option<PathBuf>,

    #[arg(long)]
    cpu: bool,

    #[arg(long)]
    color_map: bool,
}

pub fn main() -> anyhow::Result<()> {
    let args = Args::parse();
    let device = candle_examples::device(args.cpu)?;

    let dinov2_model_file = match args.dinov2_model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("lmz/candle-dino-v2".into());
            api.get("dinov2_vits14.safetensors")?
        }
        Some(dinov2_model) => dinov2_model,
    };
    println!("Using file {:?}", dinov2_model_file);

    let vb = unsafe { VarBuilder::from_mmaped_safetensors(&[dinov2_model_file], F32, &device)? };
    let dinov2 = dinov2::vit_small(vb)?;
    println!("DinoV2 model built");

    let depth_anything_model_file = match args.depth_anything_v2_model {
        None => {
            let api = hf_hub::api::sync::Api::new()?;
            let api = api.model("jeroenvlek/depth-anything-v2-safetensors".into());
            api.get("depth_anything_v2_vits.safetensors")?
        }
        Some(depth_anything_model) => depth_anything_model,
    };
    println!("Using file {:?}", depth_anything_model_file);

    let vb = unsafe {
        VarBuilder::from_mmaped_safetensors(&[depth_anything_model_file], DType::F32, &device)?
    };

    let config = DepthAnythingV2Config::vit_small();
    let depth_anything = DepthAnythingV2::new(&dinov2, &config, vb)?;

    let (original_height, original_width, image) = load_and_prep_image(&args.image, &device)?;

    println!("Loaded image {image:?}");

    let depth = depth_anything.forward(&image)?;

    println!("Got predictions {:?}", depth.shape());

    let output_image =
        post_process_image(&depth, original_height, original_width, args.color_map)?;

    let output_path = full_output_path(&args.image, &args.output_dir);
    println!("Saving image to {}", output_path.to_string_lossy());
    save_image(&output_image, output_path)?;

    Ok(())
}

/// Build the output path by prefixing the input file name with `depth_` and placing
/// it either next to the input image or in `output_dir` when one is given.
fn full_output_path(image_path: &PathBuf, output_dir: &Option<PathBuf>) -> PathBuf {
    let input_file_name = image_path.file_name().unwrap();
    let mut output_file_name = OsString::from("depth_");
    output_file_name.push(input_file_name);
    let mut output_path = match output_dir {
        None => image_path.parent().unwrap().to_path_buf(),
        Some(output_path) => output_path.clone(),
    };
    output_path.push(output_file_name);

    output_path
}

/// Load the image, resize it to the DINOv2 input size, scale the pixel values to
/// [0, 1] and normalize them with the ImageNet mean and std. Returns the original
/// height and width alongside the prepared tensor.
fn load_and_prep_image(
    image_path: &PathBuf,
    device: &Device,
) -> anyhow::Result<(usize, usize, Tensor)> {
    let (_original_image, original_height, original_width) = load_image(image_path, None)?;

    let image = load_image_and_resize(image_path, DINO_IMG_SIZE, DINO_IMG_SIZE)?
        .unsqueeze(0)?
        .to_dtype(F32)?
        .to_device(device)?;

    let max_pixel_val = Tensor::try_from(255.0f32)?
        .to_device(device)?
        .broadcast_as(image.shape())?;
    let image = (image / max_pixel_val)?;
    let image = normalize_image(&image, &MAGIC_MEAN, &MAGIC_STD)?;

    Ok((original_height, original_width, image))
}

/// Channel-wise normalization: (image - mean) / std.
fn normalize_image(image: &Tensor, mean: &[f32; 3], std: &[f32; 3]) -> Result<Tensor> {
    let mean_tensor =
        Tensor::from_vec(mean.to_vec(), (3, 1, 1), image.device())?.broadcast_as(image.shape())?;
    let std_tensor =
        Tensor::from_vec(std.to_vec(), (3, 1, 1), image.device())?.broadcast_as(image.shape())?;
    image.sub(&mean_tensor)?.div(&std_tensor)
}

/// Resize the depth map back to the original image size, scale it to [0, 1],
/// optionally apply the Spectral_r color map, and convert to u8 pixel values.
fn post_process_image(
    image: &Tensor,
    original_height: usize,
    original_width: usize,
    color_map: bool,
) -> Result<Tensor> {
    let out = image.interpolate2d(original_height, original_width)?;
    let out = scale_image(&out)?;

    let out = if color_map {
        let spectral_r = SpectralRColormap::new();
        spectral_r.gray2color(&out)?
    } else {
        let rgb_slice = [&out, &out, &out];
        Tensor::cat(&rgb_slice, 0)?.squeeze(1)?
    };

    let max_pixel_val = Tensor::try_from(255.0f32)?
        .to_device(out.device())?
        .broadcast_as(out.shape())?;
    let out = (out * max_pixel_val)?;

    out.to_dtype(U8)
}

/// Min-max normalize the depth values to the [0, 1] range.
fn scale_image(depth: &Tensor) -> Result<Tensor> {
    let flat_values: Vec<f32> = depth.flatten_all()?.to_vec1()?;

    let min_val = flat_values.iter().min_by(|a, b| a.total_cmp(b)).unwrap();
    let max_val = flat_values.iter().max_by(|a, b| a.total_cmp(b)).unwrap();

    let min_val_tensor = Tensor::try_from(*min_val)?
        .to_device(depth.device())?
        .broadcast_as(depth.shape())?;
    let depth = (depth - min_val_tensor)?;

    let range = max_val - min_val;
    let range_tensor = Tensor::try_from(range)?
        .to_device(depth.device())?
        .broadcast_as(depth.shape())?;

    depth / range_tensor
}
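For intuition, `scale_image` above is a plain min-max normalization: every depth value is mapped to `(v - min) / (max - min)`, so the smallest value becomes 0.0 and the largest 1.0. A small illustrative check, reusing the imports of this file (the function name is hypothetical):

```rust
// Hypothetical sketch: verify the min-max scaling on a tiny hand-made depth map.
fn scale_image_example() -> Result<()> {
    let device = Device::Cpu;
    let depth = Tensor::from_vec(vec![2.0f32, 4.0, 6.0, 10.0], (2, 2), &device)?;
    let scaled = scale_image(&depth)?;
    // Prints [0.0, 0.25, 0.5, 1.0]: the minimum maps to 0.0, the maximum to 1.0.
    println!("{:?}", scaled.flatten_all()?.to_vec1::<f32>()?);
    Ok(())
}
```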