diff --git a/marigold/marigold_pipeline.py b/marigold/marigold_pipeline.py index bad29af..ec9ea4e 100644 --- a/marigold/marigold_pipeline.py +++ b/marigold/marigold_pipeline.py @@ -120,7 +120,7 @@ def __call__( match_input_res: bool = True, resample_method: str = "bilinear", batch_size: int = 0, - seed: Union[int, None] = None, + generator: Union[torch.Generator, None] = None, color_map: str = "Spectral", show_progress_bar: bool = True, ensemble_kwargs: Dict = None, @@ -146,8 +146,8 @@ def __call__( batch_size (`int`, *optional*, defaults to `0`): Inference batch size, no bigger than `num_ensemble`. If set to 0, the script will automatically decide the proper batch size. - seed (`int`, *optional*, defaults to `None`) - Reproducibility seed. + generator (`torch.Generator`, *optional*, defaults to `None`) + Random generator for initial noise generation. show_progress_bar (`bool`, *optional*, defaults to `True`): Display a progress bar of diffusion denoising. color_map (`str`, *optional*, defaults to `"Spectral"`, pass `None` to skip colorized depth map generation): @@ -228,7 +228,7 @@ def __call__( rgb_in=batched_img, num_inference_steps=denoising_steps, show_pbar=show_progress_bar, - seed=seed, + generator=generator, ) depth_pred_ls.append(depth_pred_raw.detach()) depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze() @@ -322,7 +322,7 @@ def single_infer( self, rgb_in: torch.Tensor, num_inference_steps: int, - seed: Union[int, None], + generator: Union[torch.Generator, None], show_pbar: bool, ) -> torch.Tensor: """ @@ -335,6 +335,8 @@ def single_infer( Number of diffusion denoisign steps (DDIM) during inference. show_pbar (`bool`): Display a progress bar of diffusion denoising. + generator (`torch.Generator`) + Random generator for initial noise generation. Returns: `torch.Tensor`: Predicted depth map. """ @@ -349,16 +351,11 @@ def single_infer( rgb_latent = self.encode_rgb(rgb_in) # Initial depth map (noise) - if seed is None: - rand_num_generator = None - else: - rand_num_generator = torch.Generator(device=device) - rand_num_generator.manual_seed(seed) depth_latent = torch.randn( rgb_latent.shape, device=device, dtype=self.dtype, - generator=rand_num_generator, + generator=generator, ) # [B, 4, h, w] # Batched empty text embedding @@ -391,7 +388,7 @@ def single_infer( # compute the previous noisy sample x_t -> x_t-1 depth_latent = self.scheduler.step( - noise_pred, t, depth_latent, generator=rand_num_generator + noise_pred, t, depth_latent, generator=generator ).prev_sample depth = self.decode_depth(depth_latent) diff --git a/run.py b/run.py index aacc231..5274ca9 100644 --- a/run.py +++ b/run.py @@ -229,6 +229,13 @@ # Read input image input_image = Image.open(rgb_path) + # Random number generator + if seed is None: + generator = None + else: + generator = torch.Generator(device=device) + generator.manual_seed(seed) + # Predict depth pipe_out = pipe( input_image, @@ -240,7 +247,7 @@ color_map=color_map, show_progress_bar=True, resample_method=resample_method, - seed=seed, + generator=generator, ) depth_pred: np.ndarray = pipe_out.depth_np