diff --git a/kornia/augmentation/_2d/intensity/dissolving.py b/kornia/augmentation/_2d/intensity/dissolving.py
index 13e9df00b65..270f5278734 100644
--- a/kornia/augmentation/_2d/intensity/dissolving.py
+++ b/kornia/augmentation/_2d/intensity/dissolving.py
@@ -1,4 +1,4 @@
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple
 
 from kornia.augmentation import random_generator as rg
 from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
@@ -44,7 +44,7 @@ def __init__(
         version: str = "2.1",
         p: float = 0.5,
         keepdim: bool = False,
-        **kwargs
+        **kwargs,
     ) -> None:
         super().__init__(p=p, same_on_batch=True, keepdim=keepdim)
         self.step_range = step_range
diff --git a/kornia/core/__init__.py b/kornia/core/__init__.py
index 0aba5a8c0c1..2e1ebc2f7c8 100644
--- a/kornia/core/__init__.py
+++ b/kornia/core/__init__.py
@@ -1,3 +1,4 @@
+from . import external
 from ._backend import (
     Device,
     Dtype,
@@ -33,8 +34,6 @@
 )
 from .module import ImageModule
 from .tensor_wrapper import TensorWrapper  # type: ignore
-from . import external
-
 
 __all__ = [
     "external",
diff --git a/kornia/filters/dissolving.py b/kornia/filters/dissolving.py
index 0eed8f072dd..f4fb3734dd1 100644
--- a/kornia/filters/dissolving.py
+++ b/kornia/filters/dissolving.py
@@ -1,7 +1,5 @@
-from typing import Dict, Optional
-
 import torch
-import torch.nn as nn
+
 from kornia.core import Module, Tensor
 from kornia.core.external import diffusers
 
@@ -18,15 +16,14 @@ def __init__(self, model: Module, num_ddim_steps: int = 50) -> None:
 
     def predict_start_from_noise(self, noise_pred: Tensor, timestep: int, latent: Tensor) -> Tensor:
         return (
-            torch.sqrt(1. / self.scheduler.alphas_cumprod[timestep]) * latent
-            - torch.sqrt(1. / self.scheduler.alphas_cumprod[timestep] - 1) * noise_pred
+            torch.sqrt(1.0 / self.scheduler.alphas_cumprod[timestep]) * latent
+            - torch.sqrt(1.0 / self.scheduler.alphas_cumprod[timestep] - 1) * noise_pred
         )
 
     @torch.no_grad()
     def init_prompt(self, prompt: str) -> None:
         uncond_input = self.model.tokenizer(
-            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
-            return_tensors="pt"
+            [""], padding="max_length", max_length=self.model.tokenizer.model_max_length, return_tensors="pt"
         )
         uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
         text_input = self.model.tokenizer(
@@ -45,14 +42,14 @@ def init_prompt(self, prompt: str) -> None:
     def encode_tensor_to_latent(self, image: Tensor) -> Tensor:
         with torch.no_grad():
             image = (image / 0.5 - 1).to(self.model.device)
-            latents = self.model.vae.encode(image)['latent_dist'].sample()
+            latents = self.model.vae.encode(image)["latent_dist"].sample()
             latents = latents * 0.18215
         return latents
 
     @torch.no_grad()
     def decode_tensor_to_latent(self, latents: Tensor) -> Tensor:
         latents = 1 / 0.18215 * latents.detach()
-        image = self.model.vae.decode(latents)['sample']
+        image = self.model.vae.decode(latents)["sample"]
         image = (image / 2 + 0.5).clamp(0, 1)
         return image
@@ -104,6 +101,7 @@ class StableDiffusionDissolving(Module):
         version: the version of the stable diffusion model.
         **kwargs: additional arguments for `.from_pretrained`.
     """
+
     def __init__(self, version: str = "2.1", **kwargs):
         super().__init__()
         StableDiffusionPipeline = diffusers.StableDiffusionPipeline
@@ -111,19 +109,26 @@ def __init__(self, version: str = "2.1", **kwargs):
 
         # Load the scheduler and model pipeline from diffusers library
        scheduler = DDIMScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False,
+            beta_start=0.00085,
+            beta_end=0.012,
+            beta_schedule="scaled_linear",
+            clip_sample=False,
+            set_alpha_to_one=False,
             steps_offset=1,
         )
 
         if version == "1.4":
             self._sdm_model = StableDiffusionPipeline.from_pretrained(
-                "CompVis/stable-diffusion-v1-4", scheduler=scheduler, **kwargs)
+                "CompVis/stable-diffusion-v1-4", scheduler=scheduler, **kwargs
+            )
         elif version == "1.5":
             self._sdm_model = StableDiffusionPipeline.from_pretrained(
-                "runwayml/stable-diffusion-v1-5", scheduler=scheduler, **kwargs)
+                "runwayml/stable-diffusion-v1-5", scheduler=scheduler, **kwargs
+            )
         elif version == "2.1":
             self._sdm_model = StableDiffusionPipeline.from_pretrained(
-                "stabilityai/stable-diffusion-2-1", scheduler=scheduler, **kwargs)
+                "stabilityai/stable-diffusion-2-1", scheduler=scheduler, **kwargs
+            )
         else:
             raise NotImplementedError
 