Skip to content

Commit

Permalink
make Scheduler a fl.Module + Change name Scheduler -> Solver
Browse files Browse the repository at this point in the history
  • Loading branch information
limiteinductive committed Jan 31, 2024
1 parent 07cb2ff commit 73f6ccf
Show file tree
Hide file tree
Showing 19 changed files with 157 additions and 146 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ ______________________________________________________________________

## Latest News 🔥

- Added [Euler's method](https://arxiv.org/abs/2206.00364) to schedulers (contributed by [@israfelsr](https://github.com/israfelsr))
- Added [Euler's method](https://arxiv.org/abs/2206.00364) to solvers (contributed by [@israfelsr](https://github.com/israfelsr))
- Added [DINOv2](https://github.com/facebookresearch/dinov2) for high-performance visual features (contributed by [@Laurent2916](https://github.com/Laurent2916))
- Added [FreeU](https://github.com/ChenyangSi/FreeU) for improved quality at no cost (contributed by [@isamu-isozaki](https://github.com/isamu-isozaki))
- Added [Restart Sampling](https://github.com/Newbeeer/diffusion_restart_sampling) for improved image generation ([example](https://github.com/Newbeeer/diffusion_restart_sampling/issues/4))
Expand Down
4 changes: 2 additions & 2 deletions scripts/conversion/convert_diffusers_controlnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ def convert(args: Args) -> dict[str, torch.Tensor]:
clip_text_embedding = torch.rand(1, 77, 768)
unet.set_clip_text_embedding(clip_text_embedding=clip_text_embedding)

scheduler = DPMSolver(num_inference_steps=10)
timestep = scheduler.timesteps[0].unsqueeze(dim=0)
solver = DPMSolver(num_inference_steps=10)
timestep = solver.timesteps[0].unsqueeze(dim=0)
unet.set_timestep(timestep=timestep.unsqueeze(dim=0))

x = torch.randn(1, 4, 64, 64)
Expand Down
4 changes: 2 additions & 2 deletions src/refiners/foundationals/latent_diffusion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
LatentDiffusionAutoencoder,
)
from refiners.foundationals.latent_diffusion.freeu import SDFreeUAdapter
from refiners.foundationals.latent_diffusion.schedulers import DPMSolver, Scheduler
from refiners.foundationals.latent_diffusion.solvers import DPMSolver, Solver
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import (
SD1ControlnetAdapter,
SD1IPAdapter,
Expand Down Expand Up @@ -33,7 +33,7 @@
"SDXLIPAdapter",
"SDXLT2IAdapter",
"DPMSolver",
"Scheduler",
"Solver",
"CLIPTextEncoderL",
"LatentDiffusionAutoencoder",
"SDFreeUAdapter",
Expand Down
24 changes: 12 additions & 12 deletions src/refiners/foundationals/latent_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import refiners.fluxion.layers as fl
from refiners.foundationals.latent_diffusion.auto_encoder import LatentDiffusionAutoencoder
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import Solver

T = TypeVar("T", bound="fl.Module")

Expand All @@ -20,7 +20,7 @@ def __init__(
unet: fl.Module,
lda: LatentDiffusionAutoencoder,
clip_text_encoder: fl.Module,
scheduler: Scheduler,
solver: Solver,
device: Device | str = "cpu",
dtype: DType = torch.float32,
) -> None:
Expand All @@ -30,10 +30,10 @@ def __init__(
self.unet = unet.to(device=self.device, dtype=self.dtype)
self.lda = lda.to(device=self.device, dtype=self.dtype)
self.clip_text_encoder = clip_text_encoder.to(device=self.device, dtype=self.dtype)
self.scheduler = scheduler.to(device=self.device, dtype=self.dtype)
self.solver = solver.to(device=self.device, dtype=self.dtype)

def set_inference_steps(self, num_steps: int, first_step: int = 0) -> None:
self.scheduler = self.scheduler.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)
self.solver = self.solver.rebuild(num_inference_steps=num_steps, first_inference_step=first_step)

def init_latents(
self,
Expand All @@ -51,15 +51,15 @@ def init_latents(
if init_image is None:
return noise
encoded_image = self.lda.encode_image(image=init_image.resize(size=(width, height)))
return self.scheduler.add_noise(
return self.solver.add_noise(
x=encoded_image,
noise=noise,
step=self.scheduler.first_inference_step,
step=self.solver.first_inference_step,
)

@property
def steps(self) -> list[int]:
return self.scheduler.inference_steps
return self.solver.inference_steps

@abstractmethod
def set_unet_context(self, *, timestep: Tensor, clip_text_embedding: Tensor, **_: Tensor) -> None:
Expand All @@ -82,12 +82,12 @@ def compute_self_attention_guidance(
def forward(
self, x: Tensor, step: int, *, clip_text_embedding: Tensor, condition_scale: float = 7.5, **kwargs: Tensor
) -> Tensor:
timestep = self.scheduler.timesteps[step].unsqueeze(dim=0)
timestep = self.solver.timesteps[step].unsqueeze(dim=0)
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)

latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
# scale latents for schedulers that need it
latents = self.scheduler.scale_model_input(latents, step=step)
# scale latents for solvers that need it
latents = self.solver.scale_model_input(latents, step=step)
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)

# classifier-free guidance
Expand All @@ -101,14 +101,14 @@ def forward(
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
)

return self.scheduler(x, predicted_noise=predicted_noise, step=step)
return self.solver(x, predicted_noise=predicted_noise, step=step)

def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
return self.__class__(
unet=self.unet.structural_copy(),
lda=self.lda.structural_copy(),
clip_text_encoder=self.clip_text_encoder.structural_copy(),
scheduler=self.scheduler,
solver=self.solver,
device=self.device,
dtype=self.dtype,
)
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def __call__(self, x: Tensor, /, noise: Tensor, step: int, targets: list[D]) ->
match step:
case step if step == target.start_step and target.init_latents is not None:
noise_view = target.crop(noise)
view = self.ldm.scheduler.add_noise(
view = self.ldm.solver.add_noise(
x=target.init_latents,
noise=noise_view,
step=step,
Expand Down
34 changes: 17 additions & 17 deletions src/refiners/foundationals/latent_diffusion/restart.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,22 @@
import torch

from refiners.foundationals.latent_diffusion.model import LatentDiffusionModel
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
from refiners.foundationals.latent_diffusion.solvers.solver import Solver

T = TypeVar("T", bound=LatentDiffusionModel)


def add_noise_interval(
scheduler: Scheduler,
solver: Solver,
/,
x: torch.Tensor,
noise: torch.Tensor,
initial_timestep: torch.Tensor,
target_timestep: torch.Tensor,
) -> torch.Tensor:
initial_cumulative_scale_factors = scheduler.cumulative_scale_factors[initial_timestep]
target_cumulative_scale_factors = scheduler.cumulative_scale_factors[target_timestep]
initial_cumulative_scale_factors = solver.cumulative_scale_factors[initial_timestep]
target_cumulative_scale_factors = solver.cumulative_scale_factors[target_timestep]

factor = target_cumulative_scale_factors / initial_cumulative_scale_factors
noised_x = factor * x + torch.sqrt(1 - factor**2) * noise
Expand All @@ -33,7 +33,7 @@ class Restart(Generic[T]):
Implements the restart sampling strategy from the paper "Restart Sampling for Improving Generative Processes"
(https://arxiv.org/pdf/2306.14878.pdf)
Works only with the DDIM scheduler for now.
Works only with the DDIM solver for now.
"""

ldm: T
Expand All @@ -43,7 +43,7 @@ class Restart(Generic[T]):
end_time: float = 2

def __post_init__(self) -> None:
assert isinstance(self.ldm.scheduler, DDIM), "Restart sampling only works with DDIM scheduler"
assert isinstance(self.ldm.solver, DDIM), "Restart sampling only works with DDIM solver"

def __call__(
self,
Expand All @@ -53,15 +53,15 @@ def __call__(
condition_scale: float = 7.5,
**kwargs: torch.Tensor,
) -> torch.Tensor:
original_scheduler = self.ldm.scheduler
new_scheduler = DDIM(self.ldm.scheduler.num_inference_steps, device=self.device, dtype=self.dtype)
new_scheduler.timesteps = self.timesteps
self.ldm.scheduler = new_scheduler
original_solver = self.ldm.solver
new_solver = DDIM(self.ldm.solver.num_inference_steps, device=self.device, dtype=self.dtype)
new_solver.timesteps = self.timesteps
self.ldm.solver = new_solver

for _ in range(self.num_iterations):
noise = torch.randn_like(input=x, device=self.device, dtype=self.dtype)
x = add_noise_interval(
new_scheduler,
new_solver,
x=x,
noise=noise,
initial_timestep=self.timesteps[-1],
Expand All @@ -73,26 +73,26 @@ def __call__(
x, step=step, clip_text_embedding=clip_text_embedding, condition_scale=condition_scale, **kwargs
)

self.ldm.scheduler = original_scheduler
self.ldm.solver = original_solver

return x

@cached_property
def start_step(self) -> int:
    # Index of the inference step whose sigma (noise_std / cumulative scale
    # factor) is closest to `self.start_time`.
    # NOTE(review): the two pairs of lines below are old/new diff residue from
    # the scraped commit page (scheduler -> solver rename); only one pair
    # belongs in the real file — confirm against the repository.
    sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
    return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.scheduler.timesteps] - self.start_time)))
    sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors
    return int(torch.argmin(input=torch.abs(input=sigmas[self.ldm.solver.timesteps] - self.start_time)))

@cached_property
def end_timestep(self) -> int:
    # Training timestep whose sigma is closest to `self.end_time`.
    # NOTE(review): the two `sigmas` lines are old/new diff residue from the
    # scraped commit page (scheduler -> solver rename); the second assignment
    # overwrites the first — confirm which belongs against the repository.
    sigmas = self.ldm.scheduler.noise_std / self.ldm.scheduler.cumulative_scale_factors
    sigmas = self.ldm.solver.noise_std / self.ldm.solver.cumulative_scale_factors
    return int(torch.argmin(input=torch.abs(input=sigmas - self.end_time)))

@cached_property
def timesteps(self) -> torch.Tensor:
return (
torch.round(
torch.linspace(
start=int(self.ldm.scheduler.timesteps[self.start_step]),
start=int(self.ldm.solver.timesteps[self.start_step]),
end=self.end_timestep,
steps=self.num_steps,
)
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from refiners.fluxion.adapters.adapter import Adapter
from refiners.fluxion.context import Contexts
from refiners.fluxion.utils import gaussian_blur, interpolate
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import Solver

if TYPE_CHECKING:
from refiners.foundationals.latent_diffusion.stable_diffusion_1.unet import SD1UNet
Expand Down Expand Up @@ -89,13 +89,13 @@ def compute_sag_mask(
return interpolate(attn_mask, Size((h, w)))

def compute_degraded_latents(
self, scheduler: Scheduler, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
self, solver: Solver, latents: Tensor, noise: Tensor, step: int, classifier_free_guidance: bool = True
) -> Tensor:
sag_mask = self.compute_sag_mask(latents=latents, classifier_free_guidance=classifier_free_guidance)
original_latents = scheduler.remove_noise(x=latents, noise=noise, step=step)
original_latents = solver.remove_noise(x=latents, noise=noise, step=step)
degraded_latents = gaussian_blur(original_latents, kernel_size=self.kernel_size, sigma=self.sigma)
degraded_latents = degraded_latents * sag_mask + original_latents * (1 - sag_mask)
return scheduler.add_noise(degraded_latents, noise=noise, step=step)
return solver.add_noise(degraded_latents, noise=noise, step=step)

def init_context(self) -> Contexts:
return {"self_attention_map": {"middle_block_attn_map": None, "middle_block_attn_shape": []}}
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from refiners.foundationals.latent_diffusion.solvers.ddim import DDIM
from refiners.foundationals.latent_diffusion.solvers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.solvers.dpm import DPMSolver
from refiners.foundationals.latent_diffusion.solvers.euler import Euler
from refiners.foundationals.latent_diffusion.solvers.solver import Solver

# Explicit public API of the `solvers` package: the `Solver` base class plus
# the concrete solvers (DPM-Solver++, DDPM, DDIM, Euler) re-exported here.
__all__ = ["Solver", "DPMSolver", "DDPM", "DDIM", "Euler"]
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor

from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver


class DDIM(Scheduler):
class DDIM(Solver):
def __init__(
self,
num_inference_steps: int,
Expand All @@ -25,15 +25,14 @@ def __init__(
device=device,
dtype=dtype,
)
self.timesteps = self._generate_timesteps()

def _generate_timesteps(self) -> Tensor:
    """
    Generates decreasing timesteps with 'leading' spacing and offset of 1
    similar to diffusers settings for the DDIM solver in Stable Diffusion 1.5
    """
    step_ratio = self.num_train_timesteps // self.num_inference_steps
    # NOTE(review): the next two lines are old/new diff residue from the
    # scraped commit page — they differ only in the `device=` argument, and
    # the second assignment overwrites the first. The commit appears to drop
    # the explicit device placement; confirm against the repository.
    timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
    timesteps = arange(start=0, end=self.num_inference_steps, step=1) * step_ratio + 1
    return timesteps.flip(0)

def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from torch import Generator, Tensor, arange, device as Device, dtype as DType

from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver


class DDPM(Scheduler):
class DDPM(Solver):
"""
Denoising Diffusion Probabilistic Model
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
import numpy as np
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor

from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver


class DPMSolver(Scheduler):
class DPMSolver(Solver):
"""
Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
Expand Down Expand Up @@ -48,7 +48,6 @@ def _generate_timesteps(self) -> Tensor:
# ...and we want the same result as the original codebase.
return tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
device=self.device,
).flip(0)

def rebuild(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor

from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from refiners.foundationals.latent_diffusion.solvers.solver import NoiseSchedule, Solver


class EulerScheduler(Scheduler):
class Euler(Solver):
def __init__(
self,
num_inference_steps: int,
Expand Down Expand Up @@ -40,9 +40,7 @@ def _generate_timesteps(self) -> Tensor:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
timesteps = torch.tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
).flip(0)
timesteps = torch.tensor(np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps)).flip(0)
return timesteps

def _generate_sigmas(self) -> Tensor:
Expand Down
Loading

0 comments on commit 73f6ccf

Please sign in to comment.