From 5126211f5fca9b998e46e57744c18e83cd25ec0e Mon Sep 17 00:00:00 2001 From: Bingxin Date: Fri, 17 May 2024 22:26:12 +0200 Subject: [PATCH 1/6] [UPDATE] support scale-invariant mode; pass invariance flag through model_index.json --- infer.py | 7 ++- marigold/marigold_pipeline.py | 63 +++++++++++++++++++++++--- marigold/util/ensemble.py | 85 +++++++++++++++++++++++++++++++++++ run.py | 3 ++ 4 files changed, 149 insertions(+), 9 deletions(-) diff --git a/infer.py b/infer.py index 3bc3eb5..6f95209 100644 --- a/infer.py +++ b/infer.py @@ -1,4 +1,4 @@ -# Last modified: 2024-04-15 +# Last modified: 2024-05-17 # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -213,7 +213,10 @@ def check_directory(directory): logging.debug("run without xformers") pipe = pipe.to(device) - + logging.info( + f"{pipe.prediction_type = }, {pipe.scale_invariant = }, {pipe.shift_invariant = }" + ) + # -------------------- Inference and saving -------------------- with torch.no_grad(): for batch in tqdm( diff --git a/marigold/marigold_pipeline.py b/marigold/marigold_pipeline.py index ec9ea4e..970f126 100644 --- a/marigold/marigold_pipeline.py +++ b/marigold/marigold_pipeline.py @@ -33,13 +33,13 @@ from diffusers.utils import BaseOutput from PIL import Image from torch.utils.data import DataLoader, TensorDataset -from torchvision.transforms.functional import resize, pil_to_tensor from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import pil_to_tensor, resize from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from .util.batchsize import find_batch_size -from .util.ensemble import ensemble_depths +from .util.ensemble import ensemble_depths, ensemble_depths_up2scale from .util.image_util import ( chw2hwc, colorize_depth_maps, @@ -97,9 +97,32 @@ def __init__( scheduler: Union[DDIMScheduler, LCMScheduler], text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, + prediction_type: str = None, + scale_invariant: bool = None, + shift_invariant: bool = None, ): super().__init__() + if prediction_type is None: + logging.warn( + "`prediction_type` is required but not given, filled with 'depth'" + ) + prediction_type = "depth" + if scale_invariant is None: + logging.warn( + "`scale_invariant` is required but not given, filled with `True`" + ) + scale_invariant = True + if shift_invariant is None: + logging.warn( + "`shift_invariant` is required but not given, filled with `True`" + ) + shift_invariant = True + + self.prediction_type = prediction_type + self.scale_invariant = scale_invariant + self.shift_invariant = shift_invariant + self.register_modules( unet=unet, vae=vae, @@ -107,6 +130,11 @@ def __init__( text_encoder=text_encoder, tokenizer=tokenizer, ) + self.register_to_config( + prediction_type=prediction_type, + scale_invariant=scale_invariant, + shift_invariant=shift_invariant, + ) self.empty_text_embed = None @@ -152,6 +180,10 @@ def __call__( Display a progress bar of diffusion denoising. color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation): Colormap used to colorize the depth map. + scale_invariant (`str`, *optional*, defaults to `True`): + Flag of scale-invariant prediction, if True, scale will be adjusted from the raw prediction. 
+ shift_invariant (`str`, *optional*, defaults to `True`): + Flag of shift-invariant prediction, if True, shift will be adjusted from the raw prediction, if False, near plane will be fixed at 0m. ensemble_kwargs (`dict`, *optional*, defaults to `None`): Arguments for detailed ensembling settings. Returns: @@ -236,17 +268,34 @@ def __call__( # ----------------- Test-time ensembling ----------------- if ensemble_size > 1: - depth_pred, pred_uncert = ensemble_depths( - depth_preds, **(ensemble_kwargs or {}) - ) + if self.scale_invariant and self.shift_invariant: + depth_pred, pred_uncert = ensemble_depths( + depth_preds, **(ensemble_kwargs or {}) + ) + elif self.scale_invariant and (not self.shift_invariant): + depth_pred, pred_uncert = ensemble_depths_up2scale( + depth_preds, **(ensemble_kwargs or {}) + ) + else: + raise NotImplementedError("Metric depth is not supported.") else: depth_pred = depth_preds pred_uncert = None # ----------------- Post processing ----------------- # Scale prediction to [0, 1] - min_d = torch.min(depth_pred) - max_d = torch.max(depth_pred) + if self.shift_invariant: + min_d = torch.min(depth_pred) + else: + min_d = 0 + + if self.scale_invariant: + max_d = torch.max(depth_pred) + else: + raise NotImplementedError( + "Metric depth is not supported." + ) + depth_pred = (depth_pred - min_d) / (max_d - min_d) # Resize back to original resolution diff --git a/marigold/util/ensemble.py b/marigold/util/ensemble.py index 5a2908e..922f0c5 100644 --- a/marigold/util/ensemble.py +++ b/marigold/util/ensemble.py @@ -130,3 +130,88 @@ def closure(x): uncertainty /= _max - _min return aligned_images, uncertainty + + +def ensemble_depths_up2scale( + input_images: torch.Tensor, + regularizer_strength: float = 0.02, + max_iter: int = 2, + tol: float = 1e-3, + reduction: str = "median", + max_res: int = None, +): + """ + To ensemble multiple scale-invariant depth images (fixed near plane at 0) + """ + device = input_images.device + dtype = input_images.dtype + np_dtype = np.float32 + + original_input = input_images.clone() + n_img = input_images.shape[0] + ori_shape = input_images.shape + + if max_res is not None: + scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:])) + if scale_factor < 1: + downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") + input_images = downscaler(input_images) + + # init guess + _min = 0 + _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) + s_init = 1.0 / (_max - _min).reshape((-1)) + x = s_init + + input_images = input_images.to(device) + + # objective function + def closure(x): + s = torch.from_numpy(x).to(dtype=dtype).to(device) + + transformed_arrays = input_images * s.view((-1, 1, 1)) + dists = inter_distances(transformed_arrays) + sqrt_dist = torch.sqrt(torch.mean(dists**2)) + + if "mean" == reduction: + pred = torch.mean(transformed_arrays, dim=0) + elif "median" == reduction: + pred = torch.median(transformed_arrays, dim=0).values + else: + raise ValueError + + near_err = torch.sqrt((0 - torch.min(pred)) ** 2) + far_err = torch.sqrt((1 - torch.max(pred)) ** 2) + + err = sqrt_dist + (near_err + far_err) * regularizer_strength + err = err.detach().cpu().numpy().astype(np_dtype) + return err + + res = minimize( + closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False} + ) + s = res.x + + # Prediction + s = torch.from_numpy(s).to(dtype=dtype).to(device) + transformed_arrays = original_input * s.view(-1, 1, 1) + if "mean" == reduction: + aligned_images = 
torch.mean(transformed_arrays, dim=0) + std = torch.std(transformed_arrays, dim=0) + uncertainty = std + elif "median" == reduction: + aligned_images = torch.median(transformed_arrays, dim=0).values + # MAD (median absolute deviation) as uncertainty indicator + abs_dev = torch.abs(transformed_arrays - aligned_images) + mad = torch.median(abs_dev, dim=0).values + uncertainty = mad + else: + raise ValueError(f"Unknown reduction method: {reduction}") + + # Scale and shift to [0, 1] + _min = 0 + _max = torch.max(aligned_images) + aligned_images = (aligned_images - _min) / (_max - _min) + uncertainty /= _max - _min + + return aligned_images, uncertainty diff --git a/run.py b/run.py index 5274ca9..84e7580 100644 --- a/run.py +++ b/run.py @@ -220,6 +220,9 @@ pass # run without xformers pipe = pipe.to(device) + logging.info( + f"{pipe.prediction_type = }, {pipe.scale_invariant = }, {pipe.shift_invariant = }" + ) # -------------------- Inference and saving -------------------- with torch.no_grad(): From 9af23c9f3b941c4457e67a36960aa485d60ecafa Mon Sep 17 00:00:00 2001 From: Bingxin Date: Fri, 17 May 2024 22:47:17 +0200 Subject: [PATCH 2/6] [CLEAN] ruff format --- infer.py | 2 +- marigold/marigold_pipeline.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/infer.py b/infer.py index 6f95209..51e8ddf 100644 --- a/infer.py +++ b/infer.py @@ -216,7 +216,7 @@ def check_directory(directory): logging.info( f"{pipe.prediction_type = }, {pipe.scale_invariant = }, {pipe.shift_invariant = }" ) - + # -------------------- Inference and saving -------------------- with torch.no_grad(): for batch in tqdm( diff --git a/marigold/marigold_pipeline.py b/marigold/marigold_pipeline.py index 970f126..522ae77 100644 --- a/marigold/marigold_pipeline.py +++ b/marigold/marigold_pipeline.py @@ -292,9 +292,7 @@ def __call__( if self.scale_invariant: max_d = torch.max(depth_pred) else: - raise NotImplementedError( - "Metric depth is not supported." - ) + raise NotImplementedError("Metric depth is not supported.") depth_pred = (depth_pred - min_d) / (max_d - min_d) From 37b3d6941c8ec744175630f5f706888afe05d2a7 Mon Sep 17 00:00:00 2001 From: Bingxin Date: Fri, 24 May 2024 13:36:06 +0200 Subject: [PATCH 3/6] [UPDATE] remove prediction_type --- infer.py | 4 ++-- marigold/marigold_pipeline.py | 10 ---------- run.py | 2 +- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/infer.py b/infer.py index 51e8ddf..5db976f 100644 --- a/infer.py +++ b/infer.py @@ -1,4 +1,4 @@ -# Last modified: 2024-05-17 +# Last modified: 2024-05-24 # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. 
# # Licensed under the Apache License, Version 2.0 (the "License"); @@ -214,7 +214,7 @@ def check_directory(directory): pipe = pipe.to(device) logging.info( - f"{pipe.prediction_type = }, {pipe.scale_invariant = }, {pipe.shift_invariant = }" + f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}" ) # -------------------- Inference and saving -------------------- diff --git a/marigold/marigold_pipeline.py b/marigold/marigold_pipeline.py index 522ae77..5b0c820 100644 --- a/marigold/marigold_pipeline.py +++ b/marigold/marigold_pipeline.py @@ -97,17 +97,10 @@ def __init__( scheduler: Union[DDIMScheduler, LCMScheduler], text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - prediction_type: str = None, scale_invariant: bool = None, shift_invariant: bool = None, ): super().__init__() - - if prediction_type is None: - logging.warn( - "`prediction_type` is required but not given, filled with 'depth'" - ) - prediction_type = "depth" if scale_invariant is None: logging.warn( "`scale_invariant` is required but not given, filled with `True`" @@ -118,8 +111,6 @@ def __init__( "`shift_invariant` is required but not given, filled with `True`" ) shift_invariant = True - - self.prediction_type = prediction_type self.scale_invariant = scale_invariant self.shift_invariant = shift_invariant @@ -131,7 +122,6 @@ def __init__( tokenizer=tokenizer, ) self.register_to_config( - prediction_type=prediction_type, scale_invariant=scale_invariant, shift_invariant=shift_invariant, ) diff --git a/run.py b/run.py index 84e7580..c714180 100644 --- a/run.py +++ b/run.py @@ -221,7 +221,7 @@ pipe = pipe.to(device) logging.info( - f"{pipe.prediction_type = }, {pipe.scale_invariant = }, {pipe.shift_invariant = }" + f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}" ) # -------------------- Inference and saving -------------------- From 528d02e970394a27ff02ccdcf10c090435edb7dc Mon Sep 17 00:00:00 2001 From: Bingxin Date: Fri, 24 May 2024 15:54:34 +0200 Subject: [PATCH 4/6] [UPDATE] unified ensemble function --- marigold/marigold_pipeline.py | 41 ++--- marigold/util/ensemble.py | 317 ++++++++++++++++------------------ marigold/util/image_util.py | 3 +- 3 files changed, 165 insertions(+), 196 deletions(-) diff --git a/marigold/marigold_pipeline.py b/marigold/marigold_pipeline.py index 5b0c820..cd1f0df 100644 --- a/marigold/marigold_pipeline.py +++ b/marigold/marigold_pipeline.py @@ -39,7 +39,7 @@ from transformers import CLIPTextModel, CLIPTokenizer from .util.batchsize import find_batch_size -from .util.ensemble import ensemble_depths, ensemble_depths_up2scale +from .util.ensemble import ensemble_depth from .util.image_util import ( chw2hwc, colorize_depth_maps, @@ -253,50 +253,35 @@ def __call__( generator=generator, ) depth_pred_ls.append(depth_pred_raw.detach()) - depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze() + depth_preds = torch.concat(depth_pred_ls, dim=0) torch.cuda.empty_cache() # clear vram cache for ensembling # ----------------- Test-time ensembling ----------------- if ensemble_size > 1: - if self.scale_invariant and self.shift_invariant: - depth_pred, pred_uncert = ensemble_depths( - depth_preds, **(ensemble_kwargs or {}) - ) - elif self.scale_invariant and (not self.shift_invariant): - depth_pred, pred_uncert = ensemble_depths_up2scale( - depth_preds, **(ensemble_kwargs or {}) - ) - else: - raise NotImplementedError("Metric depth is not supported.") + depth_pred, pred_uncert = ensemble_depth( + depth_preds, + 
scale_invariant=self.scale_invariant, + shift_invariant=self.shift_invariant, + **(ensemble_kwargs or {}), + ) else: depth_pred = depth_preds pred_uncert = None - # ----------------- Post processing ----------------- - # Scale prediction to [0, 1] - if self.shift_invariant: - min_d = torch.min(depth_pred) - else: - min_d = 0 - - if self.scale_invariant: - max_d = torch.max(depth_pred) - else: - raise NotImplementedError("Metric depth is not supported.") - - depth_pred = (depth_pred - min_d) / (max_d - min_d) - # Resize back to original resolution if match_input_res: depth_pred = resize( - depth_pred.unsqueeze(0), + depth_pred, input_size[1:], interpolation=resample_method, antialias=True, - ).squeeze() + ) # Convert to numpy + depth_pred = depth_pred.squeeze() depth_pred = depth_pred.cpu().numpy() + if pred_uncert is not None: + pred_uncert = pred_uncert.squeeze().cpu().numpy() # Clip output range depth_pred = depth_pred.clip(0, 1) diff --git a/marigold/util/ensemble.py b/marigold/util/ensemble.py index 922f0c5..d8c087b 100644 --- a/marigold/util/ensemble.py +++ b/marigold/util/ensemble.py @@ -18,10 +18,13 @@ # -------------------------------------------------------------------------- +from functools import partial +from typing import Optional, Tuple + import numpy as np import torch -from scipy.optimize import minimize +from .image_util import get_tv_resample_method, resize_max_res def inter_distances(tensors: torch.Tensor): @@ -37,181 +40,161 @@ def inter_distances(tensors: torch.Tensor): return dist -def ensemble_depths( - input_images: torch.Tensor, +def ensemble_depth( + depth: torch.Tensor, + scale_invariant: bool = True, + shift_invariant: bool = True, + output_uncertainty: bool = False, + reduction: str = "median", regularizer_strength: float = 0.02, max_iter: int = 2, tol: float = 1e-3, - reduction: str = "median", - max_res: int = None, -): + max_res: int = 1024, +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: """ - To ensemble multiple affine-invariant depth images (up to scale and shift), - by aligning estimating the scale and shift + Ensembles depth maps represented by the `depth` tensor with expected shape `(B, 1, H, W)`, where B is the + number of ensemble members for a given prediction of size `(H x W)`. Even though the function is designed for + depth maps, it can also be used with disparity maps as long as the input tensor values are non-negative. The + alignment happens when the predictions have one or more degrees of freedom, that is when they are either + affine-invariant (`scale_invariant=True` and `shift_invariant=True`), or just scale-invariant (only + `scale_invariant=True`). For absolute predictions (`scale_invariant=False` and `shift_invariant=False`) + alignment is skipped and only ensembling is performed. + + Args: + depth (`torch.Tensor`): + Input ensemble depth maps. + scale_invariant (`bool`, *optional*, defaults to `True`): + Whether to treat predictions as scale-invariant. + shift_invariant (`bool`, *optional*, defaults to `True`): + Whether to treat predictions as shift-invariant. + output_uncertainty (`bool`, *optional*, defaults to `False`): + Whether to output uncertainty map. + reduction (`str`, *optional*, defaults to `"median"`): + Reduction method used to ensemble aligned predictions. The accepted values are: `"mean"` and + `"median"`. + regularizer_strength (`float`, *optional*, defaults to `0.02`): + Strength of the regularizer that pulls the aligned predictions to the unit range from 0 to 1. 
+ max_iter (`int`, *optional*, defaults to `2`): + Maximum number of the alignment solver steps. Refer to `scipy.optimize.minimize` function, `options` + argument. + tol (`float`, *optional*, defaults to `1e-3`): + Alignment solver tolerance. The solver stops when the tolerance is reached. + max_res (`int`, *optional*, defaults to `1024`): + Resolution at which the alignment is performed; `None` matches the `processing_resolution`. + Returns: + A tensor of aligned and ensembled depth maps and optionally a tensor of uncertainties of the same shape: + `(1, 1, H, W)`. """ - device = input_images.device - dtype = input_images.dtype - np_dtype = np.float32 - - original_input = input_images.clone() - n_img = input_images.shape[0] - ori_shape = input_images.shape - - if max_res is not None: - scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:])) - if scale_factor < 1: - downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") - input_images = downscaler(input_images) - - # init guess - _min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) - _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) - s_init = 1.0 / (_max - _min).reshape((-1, 1, 1)) - t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1)) - x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype) - - input_images = input_images.to(device) - - # objective function - def closure(x): - len_x = len(x) - s = x[: int(len_x / 2)] - t = x[int(len_x / 2) :] - s = torch.from_numpy(s).to(dtype=dtype).to(device) - t = torch.from_numpy(t).to(dtype=dtype).to(device) - - transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1)) - dists = inter_distances(transformed_arrays) - sqrt_dist = torch.sqrt(torch.mean(dists**2)) - - if "mean" == reduction: - pred = torch.mean(transformed_arrays, dim=0) - elif "median" == reduction: - pred = torch.median(transformed_arrays, dim=0).values + if depth.dim() != 4 or depth.shape[1] != 1: + raise ValueError(f"Expecting 4D tensor of shape [B,1,H,W]; got {depth.shape}.") + if reduction not in ("mean", "median"): + raise ValueError(f"Unrecognized reduction method: {reduction}.") + if not scale_invariant and shift_invariant: + raise ValueError("Pure shift-invariant ensembling is not supported.") + + def init_param(depth: torch.Tensor): + init_min = depth.reshape(ensemble_size, -1).min(dim=1).values + init_max = depth.reshape(ensemble_size, -1).max(dim=1).values + + if scale_invariant and shift_invariant: + init_s = 1.0 / (init_max - init_min).clamp(min=1e-6) + init_t = -init_s * init_min + param = torch.cat((init_s, init_t)).cpu().numpy() + elif scale_invariant: + init_s = 1.0 / init_max.clamp(min=1e-6) + param = init_s.cpu().numpy() else: - raise ValueError - - near_err = torch.sqrt((0 - torch.min(pred)) ** 2) - far_err = torch.sqrt((1 - torch.max(pred)) ** 2) - - err = sqrt_dist + (near_err + far_err) * regularizer_strength - err = err.detach().cpu().numpy().astype(np_dtype) - return err - - res = minimize( - closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False} - ) - x = res.x - len_x = len(x) - s = x[: int(len_x / 2)] - t = x[int(len_x / 2) :] - - # Prediction - s = torch.from_numpy(s).to(dtype=dtype).to(device) - t = torch.from_numpy(t).to(dtype=dtype).to(device) - transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1) - if "mean" == reduction: - aligned_images = torch.mean(transformed_arrays, dim=0) - std = torch.std(transformed_arrays, dim=0) - uncertainty = std - 
elif "median" == reduction: - aligned_images = torch.median(transformed_arrays, dim=0).values - # MAD (median absolute deviation) as uncertainty indicator - abs_dev = torch.abs(transformed_arrays - aligned_images) - mad = torch.median(abs_dev, dim=0).values - uncertainty = mad - else: - raise ValueError(f"Unknown reduction method: {reduction}") + raise ValueError("Unrecognized alignment.") + + return param + + def align(depth: torch.Tensor, param: np.ndarray) -> torch.Tensor: + if scale_invariant and shift_invariant: + s, t = np.split(param, 2) + s = torch.from_numpy(s).to(depth).view(ensemble_size, 1, 1, 1) + t = torch.from_numpy(t).to(depth).view(ensemble_size, 1, 1, 1) + out = depth * s + t + elif scale_invariant: + s = torch.from_numpy(param).to(depth).view(ensemble_size, 1, 1, 1) + out = depth * s + else: + raise ValueError("Unrecognized alignment.") + return out + + def ensemble( + depth_aligned: torch.Tensor, return_uncertainty: bool = False + ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + uncertainty = None + if reduction == "mean": + prediction = torch.mean(depth_aligned, dim=0, keepdim=True) + if return_uncertainty: + uncertainty = torch.std(depth_aligned, dim=0, keepdim=True) + elif reduction == "median": + prediction = torch.median(depth_aligned, dim=0, keepdim=True).values + if return_uncertainty: + uncertainty = torch.median( + torch.abs(depth_aligned - prediction), dim=0, keepdim=True + ).values + else: + raise ValueError(f"Unrecognized reduction method: {reduction}.") + return prediction, uncertainty - # Scale and shift to [0, 1] - _min = torch.min(aligned_images) - _max = torch.max(aligned_images) - aligned_images = (aligned_images - _min) / (_max - _min) - uncertainty /= _max - _min + def cost_fn(param: np.ndarray, depth: torch.Tensor) -> float: + cost = 0.0 + depth_aligned = align(depth, param) - return aligned_images, uncertainty + for i, j in torch.combinations(torch.arange(ensemble_size)): + diff = depth_aligned[i] - depth_aligned[j] + cost += (diff**2).mean().sqrt().item() + if regularizer_strength > 0: + prediction, _ = ensemble(depth_aligned, return_uncertainty=False) + err_near = (0.0 - prediction.min()).abs().item() + err_far = (1.0 - prediction.max()).abs().item() + cost += (err_near + err_far) * regularizer_strength -def ensemble_depths_up2scale( - input_images: torch.Tensor, - regularizer_strength: float = 0.02, - max_iter: int = 2, - tol: float = 1e-3, - reduction: str = "median", - max_res: int = None, -): - """ - To ensemble multiple scale-invariant depth images (fixed near plane at 0) - """ - device = input_images.device - dtype = input_images.dtype - np_dtype = np.float32 - - original_input = input_images.clone() - n_img = input_images.shape[0] - ori_shape = input_images.shape - - if max_res is not None: - scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:])) - if scale_factor < 1: - downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") - input_images = downscaler(input_images) - - # init guess - _min = 0 - _max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) - s_init = 1.0 / (_max - _min).reshape((-1)) - x = s_init - - input_images = input_images.to(device) - - # objective function - def closure(x): - s = torch.from_numpy(x).to(dtype=dtype).to(device) - - transformed_arrays = input_images * s.view((-1, 1, 1)) - dists = inter_distances(transformed_arrays) - sqrt_dist = torch.sqrt(torch.mean(dists**2)) - - if "mean" == reduction: - pred = torch.mean(transformed_arrays, dim=0) - elif "median" == 
reduction: - pred = torch.median(transformed_arrays, dim=0).values - else: - raise ValueError - - near_err = torch.sqrt((0 - torch.min(pred)) ** 2) - far_err = torch.sqrt((1 - torch.max(pred)) ** 2) - - err = sqrt_dist + (near_err + far_err) * regularizer_strength - err = err.detach().cpu().numpy().astype(np_dtype) - return err - - res = minimize( - closure, x, method="BFGS", tol=tol, options={"maxiter": max_iter, "disp": False} - ) - s = res.x - - # Prediction - s = torch.from_numpy(s).to(dtype=dtype).to(device) - transformed_arrays = original_input * s.view(-1, 1, 1) - if "mean" == reduction: - aligned_images = torch.mean(transformed_arrays, dim=0) - std = torch.std(transformed_arrays, dim=0) - uncertainty = std - elif "median" == reduction: - aligned_images = torch.median(transformed_arrays, dim=0).values - # MAD (median absolute deviation) as uncertainty indicator - abs_dev = torch.abs(transformed_arrays - aligned_images) - mad = torch.median(abs_dev, dim=0).values - uncertainty = mad - else: - raise ValueError(f"Unknown reduction method: {reduction}") + return cost + + def compute_param(depth: torch.Tensor): + import scipy + + depth_to_align = depth.to(torch.float32) + if max_res is not None and max(depth_to_align.shape[2:]) > max_res: + depth_to_align = resize_max_res( + depth_to_align, max_res, get_tv_resample_method("nearest-exact") + ) + + param = init_param(depth_to_align) + + res = scipy.optimize.minimize( + partial(cost_fn, depth=depth_to_align), + param, + method="BFGS", + tol=tol, + options={"maxiter": max_iter, "disp": False}, + ) - # Scale and shift to [0, 1] - _min = 0 - _max = torch.max(aligned_images) - aligned_images = (aligned_images - _min) / (_max - _min) - uncertainty /= _max - _min + return res.x + + requires_aligning = scale_invariant or shift_invariant + ensemble_size = depth.shape[0] + + if requires_aligning: + param = compute_param(depth) + depth = align(depth, param) + + depth, uncertainty = ensemble(depth, return_uncertainty=output_uncertainty) + + depth_max = depth.max() + if scale_invariant and shift_invariant: + depth_min = depth.min() + elif scale_invariant: + depth_min = 0 + else: + raise ValueError("Unrecognized alignment.") + depth_range = (depth_max - depth_min).clamp(min=1e-6) + depth = (depth - depth_min) / depth_range + if output_uncertainty: + uncertainty /= depth_range - return aligned_images, uncertainty + return depth, uncertainty # [1,1,H,W], [1,1,H,W] diff --git a/marigold/util/image_util.py b/marigold/util/image_util.py index 90f0623..9924bab 100644 --- a/marigold/util/image_util.py +++ b/marigold/util/image_util.py @@ -1,5 +1,5 @@ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. -# Last modified: 2024-04-16 +# Last modified: 2024-05-24 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. 
@@ -113,6 +113,7 @@ def get_tv_resample_method(method_str: str) -> InterpolationMode: "bilinear": InterpolationMode.BILINEAR, "bicubic": InterpolationMode.BICUBIC, "nearest": InterpolationMode.NEAREST_EXACT, + "nearest-exact": InterpolationMode.NEAREST_EXACT, } resample_method = resample_method_dict.get(method_str, None) if resample_method is None: From d129f6c744ea69d4c7502c6a038b79c04f0fb606 Mon Sep 17 00:00:00 2001 From: Bingxin Date: Fri, 24 May 2024 15:56:20 +0200 Subject: [PATCH 5/6] [UPDATE] add default_denoising_steps, default_processing_resolution --- README.md | 8 ++-- marigold/marigold_pipeline.py | 77 +++++++++++++++++++++++------------ run.py | 24 ++++++----- 3 files changed, 67 insertions(+), 42 deletions(-) diff --git a/README.md b/README.md index 13ac069..8e45365 100644 --- a/README.md +++ b/README.md @@ -125,12 +125,10 @@ Activate the environment again after restarting the terminal session. ### 🚀 Run inference with LCM (faster) -The [LCM checkpoint](https://huggingface.co/prs-eth/marigold-lcm-v1-0) is distilled from our original checkpoint towards faster inference speed (by reducing inference steps). The inference steps can be as few as 1 to 4: +The [LCM checkpoint](https://huggingface.co/prs-eth/marigold-lcm-v1-0) is distilled from our original checkpoint towards faster inference speed (by reducing inference steps). The inference steps can be as few as 1 (default) to 4. Run with default LCM setting: ```bash python run.py \ - --denoise_steps 4 \ - --ensemble_size 5 \ --input_rgb_dir input/in-the-wild_example \ --output_dir output/in-the-wild_example_lcm ``` @@ -156,11 +154,11 @@ The default settings are optimized for the best result. However, the behavior of - Trade-offs between the **accuracy** and **speed** (for both options, larger values result in better accuracy at the cost of slower inference.) - `--ensemble_size`: Number of inference passes in the ensemble. For LCM `ensemble_size` is more important than `denoise_steps`. Default: ~~10~~ 5 (for LCM). - - `--denoise_steps`: Number of denoising steps of each inference pass. For the original (DDIM) version, it's recommended to use 10-50 steps, while for LCM 1-4 steps. Default: ~~10~~ 4 (for LCM). + - `--denoise_steps`: Number of denoising steps of each inference pass. For the original (DDIM) version, it's recommended to use 10-50 steps, while for LCM 1-4 steps. When unassigned (`None`), will read default setting from model config. Default: ~~10 4 (for LCM)~~ `None`. - By default, the inference script resizes input images to the *processing resolution*, and then resizes the prediction back to the original resolution. This gives the best quality, as Stable Diffusion, from which Marigold is derived, performs best at 768x768 resolution. - - `--processing_res`: the processing resolution; set 0 to process the input resolution directly. Default: 768. + - `--processing_res`: the processing resolution; set as 0 to process the input resolution directly. When unassigned (`None`), will read default setting from model config. Default: ~~768~~ `None`. - `--output_processing_res`: produce output at the processing resolution instead of upsampling it to the input resolution. Default: False. - `--resample_method`: resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`. Default: `bilinear`. 
diff --git a/marigold/marigold_pipeline.py b/marigold/marigold_pipeline.py index cd1f0df..15b8283 100644 --- a/marigold/marigold_pipeline.py +++ b/marigold/marigold_pipeline.py @@ -1,4 +1,5 @@ # Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved. +# Last modified: 2024-05-24 # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -19,7 +20,7 @@ import logging -from typing import Dict, Union +from typing import Dict, Optional, Union import numpy as np import torch @@ -85,6 +86,25 @@ class MarigoldPipeline(DiffusionPipeline): Text-encoder, for empty text embedding. tokenizer (`CLIPTokenizer`): CLIP tokenizer. + scale_invariant (`bool`, *optional*): + A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in + the model config. When used together with the `shift_invariant=True` flag, the model is also called + "affine-invariant". NB: overriding this value is not supported. + shift_invariant (`bool`, *optional*): + A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in + the model config. When used together with the `scale_invariant=True` flag, the model is also called + "affine-invariant". NB: overriding this value is not supported. + default_denoising_steps (`int`, *optional*): + The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable + quality with the given model. This value must be set in the model config. When the pipeline is called + without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure + reasonable results with various model flavors compatible with the pipeline, such as those relying on very + short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). + default_processing_resolution (`int`, *optional*): + The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in + the model config. When the pipeline is called without explicitly setting `processing_resolution`, the + default value is used. This is required to ensure reasonable results with various model flavors trained + with varying optimal processing resolution values. 
""" rgb_latent_scale_factor = 0.18215 @@ -97,23 +117,12 @@ def __init__( scheduler: Union[DDIMScheduler, LCMScheduler], text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, - scale_invariant: bool = None, - shift_invariant: bool = None, + scale_invariant: Optional[bool] = True, + shift_invariant: Optional[bool] = True, + default_denoising_steps: Optional[int] = None, + default_processing_resolution: Optional[int] = None, ): super().__init__() - if scale_invariant is None: - logging.warn( - "`scale_invariant` is required but not given, filled with `True`" - ) - scale_invariant = True - if shift_invariant is None: - logging.warn( - "`shift_invariant` is required but not given, filled with `True`" - ) - shift_invariant = True - self.scale_invariant = scale_invariant - self.shift_invariant = shift_invariant - self.register_modules( unet=unet, vae=vae, @@ -124,17 +133,24 @@ def __init__( self.register_to_config( scale_invariant=scale_invariant, shift_invariant=shift_invariant, + default_denoising_steps=default_denoising_steps, + default_processing_resolution=default_processing_resolution, ) + self.scale_invariant = scale_invariant + self.shift_invariant = shift_invariant + self.default_denoising_steps = default_denoising_steps + self.default_processing_resolution = default_processing_resolution + self.empty_text_embed = None @torch.no_grad() def __call__( self, input_image: Union[Image.Image, torch.Tensor], - denoising_steps: int = 10, - ensemble_size: int = 10, - processing_res: int = 768, + denoising_steps: Optional[int] = None, + ensemble_size: int = 5, + processing_res: Optional[int] = None, match_input_res: bool = True, resample_method: str = "bilinear", batch_size: int = 0, @@ -149,18 +165,21 @@ def __call__( Args: input_image (`Image`): Input RGB (or gray-scale) image. - processing_res (`int`, *optional*, defaults to `768`): - Maximum resolution of processing. - If set to 0: will not resize at all. + denoising_steps (`int`, *optional*, defaults to `None`): + Number of denoising diffusion steps during inference. The default value `None` results in automatic + selection. The number of steps should be at least 10 with the full Marigold models, and between 1 and 4 + for Marigold-LCM models. + ensemble_size (`int`, *optional*, defaults to `10`): + Number of predictions to be ensembled. + processing_res (`int`, *optional*, defaults to `None`): + Effective processing resolution. When set to `0`, processes at the original image resolution. This + produces crisper predictions, but may also lead to the overall loss of global context. The default + value `None` resolves to the optimal value from the model config. match_input_res (`bool`, *optional*, defaults to `True`): Resize depth prediction to match input resolution. Only valid if `processing_res` > 0. resample_method: (`str`, *optional*, defaults to `bilinear`): Resampling method used to resize images and depth predictions. This can be one of `bilinear`, `bicubic` or `nearest`, defaults to: `bilinear`. - denoising_steps (`int`, *optional*, defaults to `10`): - Number of diffusion denoising steps (DDIM) during inference. - ensemble_size (`int`, *optional*, defaults to `10`): - Number of predictions to be ensembled. batch_size (`int`, *optional*, defaults to `0`): Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. @@ -183,6 +202,12 @@ def __call__( - **uncertainty** (`None` or `np.ndarray`) Uncalibrated uncertainty(MAD, median absolute deviation) coming from ensembling. 
None if `ensemble_size = 1` """ + # Model-specific optimal default values leading to fast and reasonable results. + if denoising_steps is None: + denoising_steps = self.default_denoising_steps + if processing_res is None: + processing_res = self.default_processing_resolution + assert processing_res >= 0 assert ensemble_size >= 1 diff --git a/run.py b/run.py index c714180..029e9fb 100644 --- a/run.py +++ b/run.py @@ -62,7 +62,7 @@ parser.add_argument( "--denoise_steps", type=int, - default=4, + default=None, help="Diffusion denoising steps, more steps results in higher accuracy but slower inference speed. For the original (DDIM) version, it's recommended to use 10-50 steps, while for LCM 1-4 steps.", ) parser.add_argument( @@ -82,7 +82,7 @@ parser.add_argument( "--processing_res", type=int, - default=768, + default=None, help="Maximum resolution of processing. 0 for using input image resolution. Default: 768.", ) parser.add_argument( @@ -153,14 +153,6 @@ batch_size = 1 # set default batchsize # -------------------- Preparation -------------------- - # Print out config - logging.info( - f"Inference settings: checkpoint = `{checkpoint_path}`, " - f"with denoise_steps = {denoise_steps}, ensemble_size = {ensemble_size}, " - f"processing resolution = {processing_res}, seed = {seed}; " - f"color_map = {color_map}." - ) - # Output directories output_dir_color = os.path.join(output_dir, "depth_colored") output_dir_tif = os.path.join(output_dir, "depth_bw") @@ -210,7 +202,7 @@ dtype = torch.float32 variant = None - pipe = MarigoldPipeline.from_pretrained( + pipe: MarigoldPipeline = MarigoldPipeline.from_pretrained( checkpoint_path, variant=variant, torch_dtype=dtype ) @@ -224,6 +216,16 @@ f"scale_invariant: {pipe.scale_invariant}, shift_invariant: {pipe.shift_invariant}" ) + # Print out config + logging.info( + f"Inference settings: checkpoint = `{checkpoint_path}`, " + f"with denoise_steps = {denoise_steps or pipe.default_denoising_steps}, " + f"ensemble_size = {ensemble_size}, " + f"processing resolution = {processing_res or pipe.default_processing_resolution}, " + f"seed = {seed}; " + f"color_map = {color_map}." 
+ ) + # -------------------- Inference and saving -------------------- with torch.no_grad(): os.makedirs(output_dir, exist_ok=True) From c91a70aa458b70763ec7194066ff025be467977d Mon Sep 17 00:00:00 2001 From: Bingxin Date: Fri, 24 May 2024 17:00:15 +0200 Subject: [PATCH 6/6] [UPDATE] resize_max_res() takes 4-dim input --- marigold/marigold_pipeline.py | 12 +++++++----- marigold/util/image_util.py | 7 ++++--- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/marigold/marigold_pipeline.py b/marigold/marigold_pipeline.py index 15b8283..9410b32 100644 --- a/marigold/marigold_pipeline.py +++ b/marigold/marigold_pipeline.py @@ -222,14 +222,15 @@ def __call__( input_image = input_image.convert("RGB") # convert to torch tensor [H, W, rgb] -> [rgb, H, W] rgb = pil_to_tensor(input_image) + rgb = rgb.unsqueeze(0) # [1, rgb, H, W] elif isinstance(input_image, torch.Tensor): - rgb = input_image.squeeze() + rgb = input_image else: raise TypeError(f"Unknown input type: {type(input_image) = }") input_size = rgb.shape assert ( - 3 == rgb.dim() and 3 == input_size[0] - ), f"Wrong input shape {input_size}, expected [rgb, H, W]" + 4 == rgb.dim() and 3 == input_size[-3] + ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]" # Resize image if processing_res > 0: @@ -246,7 +247,7 @@ def __call__( # ----------------- Predicting depth ----------------- # Batch repeated input image - duplicated_rgb = torch.stack([rgb_norm] * ensemble_size) + duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1) single_rgb_dataset = TensorDataset(duplicated_rgb) if batch_size > 0: _bs = batch_size @@ -287,6 +288,7 @@ def __call__( depth_preds, scale_invariant=self.scale_invariant, shift_invariant=self.shift_invariant, + max_res=50, **(ensemble_kwargs or {}), ) else: @@ -297,7 +299,7 @@ def __call__( if match_input_res: depth_pred = resize( depth_pred, - input_size[1:], + input_size[-2:], interpolation=resample_method, antialias=True, ) diff --git a/marigold/util/image_util.py b/marigold/util/image_util.py index 9924bab..82078fe 100644 --- a/marigold/util/image_util.py +++ b/marigold/util/image_util.py @@ -86,7 +86,7 @@ def resize_max_res( Args: img (`torch.Tensor`): - Image tensor to be resized. + Image tensor to be resized. Expected shape: [B, C, H, W] max_edge_resolution (`int`): Maximum edge length (pixel). resample_method (`PIL.Image.Resampling`): @@ -95,8 +95,9 @@ def resize_max_res( Returns: `torch.Tensor`: Resized image. """ - assert 3 == img.dim() - _, original_height, original_width = img.shape + assert 4 == img.dim(), f"Invalid input shape {img.shape}" + + original_height, original_width = img.shape[-2:] downscale_factor = min( max_edge_resolution / original_width, max_edge_resolution / original_height )
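
---

Usage sketch (for reference only, not part of the patches): the snippet below illustrates how the pipeline behaves after this series — the invariance flags and default inference settings are read from `model_index.json` via `register_to_config`, and the unified `ensemble_depth` replaces the two earlier ensemble functions. It assumes the `prs-eth/marigold-lcm-v1-0` checkpoint with a config that already carries the new fields, that the package exposes `MarigoldPipeline` at the top level as `run.py` uses it, a placeholder image path under `input/in-the-wild_example/`, and the output field names `depth_np`/`uncertainty` of the existing output class; treat it as a sketch under those assumptions rather than a tested script.

```python
import torch
from PIL import Image

from marigold import MarigoldPipeline            # assumed top-level export, as used by run.py
from marigold.util.ensemble import ensemble_depth

device = "cuda" if torch.cuda.is_available() else "cpu"

# Invariance flags and default steps/resolution are read from model_index.json
# (registered to the config in __init__), so nothing has to be set manually here.
pipe = MarigoldPipeline.from_pretrained(
    "prs-eth/marigold-lcm-v1-0", torch_dtype=torch.float32
).to(device)
print(pipe.scale_invariant, pipe.shift_invariant)
print(pipe.default_denoising_steps, pipe.default_processing_resolution)

# Placeholder image path for illustration.
image = Image.open("input/in-the-wild_example/example_0.jpg").convert("RGB")

with torch.no_grad():
    out = pipe(
        image,
        # denoising_steps / processing_res default to None and resolve to
        # default_denoising_steps / default_processing_resolution from the model config.
        ensemble_size=5,
        match_input_res=True,
        batch_size=0,        # 0 -> batch size chosen automatically
    )

depth = out.depth_np         # assumed output field: (H, W) array in [0, 1]
uncert = out.uncertainty     # None when ensemble_size == 1

# The unified ensemble function can also be called directly on a stack of
# non-negative predictions with shape [B, 1, H, W].
preds = torch.rand(5, 1, 64, 64)
merged, unc = ensemble_depth(
    preds,
    scale_invariant=True,
    shift_invariant=False,   # scale-invariant only: near plane fixed at 0
    output_uncertainty=True,
)
print(merged.shape, unc.shape)  # torch.Size([1, 1, 64, 64]) each
```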