Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Aug 7, 2024
1 parent cd22af9 commit 2c50e6f
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 17 deletions.
4 changes: 2 additions & 2 deletions kornia/augmentation/_2d/intensity/dissolving.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Optional, Tuple, Union
from typing import Any, Dict, Optional, Tuple

from kornia.augmentation import random_generator as rg
from kornia.augmentation._2d.intensity.base import IntensityAugmentationBase2D
Expand Down Expand Up @@ -44,7 +44,7 @@ def __init__(
version: str = "2.1",
p: float = 0.5,
keepdim: bool = False,
**kwargs
**kwargs,
) -> None:
super().__init__(p=p, same_on_batch=True, keepdim=keepdim)
self.step_range = step_range
Expand Down
3 changes: 1 addition & 2 deletions kornia/core/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from . import external
from ._backend import (
Device,
Dtype,
Expand Down Expand Up @@ -33,8 +34,6 @@
)
from .module import ImageModule
from .tensor_wrapper import TensorWrapper # type: ignore
from . import external


__all__ = [
"external",
Expand Down
31 changes: 18 additions & 13 deletions kornia/filters/dissolving.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
from typing import Dict, Optional

import torch
import torch.nn as nn

from kornia.core import Module, Tensor
from kornia.core.external import diffusers

Expand All @@ -18,15 +16,14 @@ def __init__(self, model: Module, num_ddim_steps: int = 50) -> None:

def predict_start_from_noise(self, noise_pred: Tensor, timestep: int, latent: Tensor) -> Tensor:
return (
torch.sqrt(1. / self.scheduler.alphas_cumprod[timestep]) * latent -
torch.sqrt(1. / self.scheduler.alphas_cumprod[timestep] - 1) * noise_pred
torch.sqrt(1.0 / self.scheduler.alphas_cumprod[timestep]) * latent
- torch.sqrt(1.0 / self.scheduler.alphas_cumprod[timestep] - 1) * noise_pred
)

@torch.no_grad()
def init_prompt(self, prompt: str) -> None:
uncond_input = self.model.tokenizer(
[""], padding="max_length", max_length=self.model.tokenizer.model_max_length,
return_tensors="pt"
[""], padding="max_length", max_length=self.model.tokenizer.model_max_length, return_tensors="pt"
)
uncond_embeddings = self.model.text_encoder(uncond_input.input_ids.to(self.model.device))[0]
text_input = self.model.tokenizer(
Expand All @@ -45,14 +42,14 @@ def init_prompt(self, prompt: str) -> None:
def encode_tensor_to_latent(self, image: Tensor) -> Tensor:
with torch.no_grad():
image = (image / 0.5 - 1).to(self.model.device)
latents = self.model.vae.encode(image)['latent_dist'].sample()
latents = self.model.vae.encode(image)["latent_dist"].sample()
latents = latents * 0.18215
return latents

@torch.no_grad()
def decode_tensor_to_latent(self, latents: Tensor) -> Tensor:
latents = 1 / 0.18215 * latents.detach()
image = self.model.vae.decode(latents)['sample']
image = self.model.vae.decode(latents)["sample"]
image = (image / 2 + 0.5).clamp(0, 1)
return image

Expand Down Expand Up @@ -104,26 +101,34 @@ class StableDiffusionDissolving(Module):
version: the version of the stable diffusion model.
**kwargs: additional arguments for `.from_pretrained`.
"""

def __init__(self, version: str = "2.1", **kwargs):
super().__init__()
StableDiffusionPipeline = diffusers.StableDiffusionPipeline
DDIMScheduler = diffusers.DDIMScheduler

# Load the scheduler and model pipeline from diffusers library
scheduler = DDIMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False,
beta_start=0.00085,
beta_end=0.012,
beta_schedule="scaled_linear",
clip_sample=False,
set_alpha_to_one=False,
steps_offset=1,
)

if version == "1.4":
self._sdm_model = StableDiffusionPipeline.from_pretrained(
"CompVis/stable-diffusion-v1-4", scheduler=scheduler, **kwargs)
"CompVis/stable-diffusion-v1-4", scheduler=scheduler, **kwargs
)
elif version == "1.5":
self._sdm_model = StableDiffusionPipeline.from_pretrained(
"runwayml/stable-diffusion-v1-5", scheduler=scheduler, **kwargs)
"runwayml/stable-diffusion-v1-5", scheduler=scheduler, **kwargs
)
elif version == "2.1":
self._sdm_model = StableDiffusionPipeline.from_pretrained(
"stabilityai/stable-diffusion-2-1", scheduler=scheduler, **kwargs)
"stabilityai/stable-diffusion-2-1", scheduler=scheduler, **kwargs
)
else:
raise NotImplementedError

Expand Down

0 comments on commit 2c50e6f

Please sign in to comment.