Skip to content

Commit

Permalink
merged feature/euler-scheduler
Browse files Browse the repository at this point in the history
  • Loading branch information
israfelsr committed Jan 9, 2024
2 parents b40a3d2 + 02a1f16 commit fae45b7
Show file tree
Hide file tree
Showing 7 changed files with 238 additions and 86 deletions.
3 changes: 2 additions & 1 deletion src/refiners/foundationals/latent_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

T = TypeVar("T", bound="fl.Module")


TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")


Expand Down Expand Up @@ -91,6 +90,8 @@ def forward(
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)

latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
# scale latents for schedulers that need it
latents = self.scheduler.scale_model_input(latents, step=step)
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)

# classifier-free guidance
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,6 @@
from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler

__all__ = [
"Scheduler",
"DPMSolver",
"DDPM",
"DDIM",
]
__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]
40 changes: 22 additions & 18 deletions src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor, Generator

from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler


class DDIM(Scheduler):

def __init__(
self,
num_inference_steps: int,
Expand Down Expand Up @@ -31,27 +32,30 @@ def _generate_timesteps(self) -> Tensor:
similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
"""
step_ratio = self.num_train_timesteps // self.num_inference_steps
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
timesteps = arange(
start=0, end=self.num_inference_steps, step=1,
device=self.device) * step_ratio + 1
return timesteps.flip(0)

def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self,
x: Tensor,
noise: Tensor,
step: int,
generator: Generator | None = None) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
(
self.timesteps[step + 1]
if step < self.num_inference_steps - 1
else tensor(data=[0], device=self.device, dtype=self.dtype)
),
)
current_scale_factor, previous_scale_factor = (
self.cumulative_scale_factors[timestep],
(
self.cumulative_scale_factors[previous_timestep]
if previous_timestep > 0
else self.cumulative_scale_factors[0]
),
(self.timesteps[step + 1] if step < self.num_inference_steps -
1 else tensor(data=[0], device=self.device, dtype=self.dtype)),
)
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[
timestep], (self.cumulative_scale_factors[previous_timestep]
if previous_timestep > 0 else
self.cumulative_scale_factors[0])
predicted_x = (x - sqrt(1 - current_scale_factor**2) *
noise) / current_scale_factor
denoised_x = previous_scale_factor * predicted_x + sqrt(
1 - previous_scale_factor**2) * noise

self.previous_scale_factor = previous_scale_factor

return denoised_x
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from collections import deque

import numpy as np
from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor

from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
import numpy as np
from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype, Generator
from collections import deque


class DPMSolver(Scheduler):
Expand Down Expand Up @@ -40,11 +38,13 @@ def _generate_timesteps(self) -> Tensor:
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
return tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
np.linspace(0, self.num_train_timesteps - 1,
self.num_inference_steps + 1).round().astype(int)[1:],
device=self.device,
).flip(0)

def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor,
step: int) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0],
Expand All @@ -53,49 +53,52 @@ def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) ->
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[timestep],
)
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_scale_factor = self.cumulative_scale_factors[
previous_timestep]
previous_noise_std, current_noise_std = (
self.noise_std[previous_timestep],
self.noise_std[timestep],
)
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
denoised_x = (previous_noise_std / current_noise_std) * x - (
factor * previous_scale_factor) * noise
return denoised_x

def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
def multistep_dpm_solver_second_order_update(self, x: Tensor,
step: int) -> Tensor:
previous_timestep, current_timestep, next_timestep = (
self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]),
self.timesteps[step + 1] if step < len(self.timesteps) -
1 else tensor([0]),
self.timesteps[step],
self.timesteps[step - 1],
)
current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2]
current_data_estimation, next_data_estimation = self.estimated_data[
-1], self.estimated_data[-2]
previous_ratio, current_ratio, next_ratio = (
self.signal_to_noise_ratios[previous_timestep],
self.signal_to_noise_ratios[current_timestep],
self.signal_to_noise_ratios[next_timestep],
)
previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
previous_scale_factor = self.cumulative_scale_factors[
previous_timestep]
previous_std, current_std = (
self.noise_std[previous_timestep],
self.noise_std[current_timestep],
)
estimation_delta = (current_data_estimation - next_data_estimation) / (
(current_ratio - next_ratio) / (previous_ratio - current_ratio)
)
(current_ratio - next_ratio) / (previous_ratio - current_ratio))
factor = exp(-(previous_ratio - current_ratio)) - 1.0
denoised_x = (
(previous_std / current_std) * x
- (factor * previous_scale_factor) * current_data_estimation
- 0.5 * (factor * previous_scale_factor) * estimation_delta
)
(previous_std / current_std) * x -
(factor * previous_scale_factor) * current_data_estimation - 0.5 *
(factor * previous_scale_factor) * estimation_delta)
return denoised_x

def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
) -> Tensor:
def __call__(self,
x: Tensor,
noise: Tensor,
step: int,
generator: Generator | None = None) -> Tensor:
"""
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
Expand All @@ -104,14 +107,15 @@ def __call__(
(ODEs).
"""
current_timestep = self.timesteps[step]
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
scale_factor, noise_ratio = self.cumulative_scale_factors[
current_timestep], self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)
denoised_x = (
self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
if (self.initial_steps == 0)
else self.multistep_dpm_solver_second_order_update(x=x, step=step)
)
denoised_x = (self.dpm_solver_first_order_update(
x=x, noise=estimated_denoised_data, step=step) if
(self.initial_steps == 0) else
self.multistep_dpm_solver_second_order_update(x=x,
step=step))
if self.initial_steps < 2:
self.initial_steps += 1
return denoised_x
83 changes: 83 additions & 0 deletions src/refiners/foundationals/latent_diffusion/schedulers/euler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from torch import Tensor, device as Device, dtype as Dtype, float32, tensor, Generator
import torch
import numpy as np


class EulerScheduler(Scheduler):
    """First-order Euler sampler working on noise levels (``sigmas``).

    On top of the tables built by the base ``Scheduler``, this class keeps a
    ``sigmas`` tensor with one entry per inference step plus a trailing ``0.0``,
    so that ``sigmas[step + 1]`` is always defined, including on the last step.
    """

    def __init__(
        self,
        num_inference_steps: int,
        num_train_timesteps: int = 1_000,
        initial_diffusion_rate: float = 8.5e-4,
        final_diffusion_rate: float = 1.2e-2,
        noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
        device: Device | str = "cpu",
        dtype: Dtype = float32,
    ):
        """Build the scheduler and pre-compute its sigma table.

        Raises:
            NotImplementedError: if `noise_schedule` is not `NoiseSchedule.QUADRATIC`.
        """
        if noise_schedule != NoiseSchedule.QUADRATIC:
            raise NotImplementedError(
                f"EulerScheduler only supports NoiseSchedule.QUADRATIC, got {noise_schedule!r}"
            )
        super().__init__(
            num_inference_steps=num_inference_steps,
            num_train_timesteps=num_train_timesteps,
            initial_diffusion_rate=initial_diffusion_rate,
            final_diffusion_rate=final_diffusion_rate,
            noise_schedule=noise_schedule,
            device=device,
            dtype=dtype,
        )
        # Must come after super().__init__(): it reads self.timesteps, self.noise_std
        # and self.cumulative_scale_factors.
        self.sigmas = self._generate_sigmas()

    @property
    def init_noise_sigma(self) -> Tensor:
        """Largest sigma of the inference schedule (0-dim tensor)."""
        return self.sigmas.max()

    def _generate_timesteps(self) -> Tensor:
        """Return `num_inference_steps` evenly spaced (fractional) timesteps, highest first."""
        # We need to use numpy here because:
        # numpy.linspace(0,999,31)[15] is 499.49999999999994
        # torch.linspace(0,999,31)[15] is 499.5
        # ...and we want the same result as the original codebase.
        timesteps = torch.tensor(
            np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps),
            dtype=self.dtype,
            device=self.device,
        ).flip(0)
        return timesteps

    def _generate_sigmas(self) -> Tensor:
        """Interpolate the training sigmas at the inference timesteps and append a final 0."""
        sigmas = self.noise_std / self.cumulative_scale_factors
        # np.interp computes in float64 anyway; upcasting explicitly before `.numpy()`
        # keeps the values identical while also supporting dtypes numpy lacks (bfloat16).
        timesteps_np = self.timesteps.to(device="cpu", dtype=torch.float64).numpy()
        sigmas_np = sigmas.to(device="cpu", dtype=torch.float64).numpy()
        interpolated = torch.tensor(np.interp(timesteps_np, np.arange(0, len(sigmas_np)), sigmas_np))
        # Trailing zero: the noise level targeted by the very last step.
        interpolated = torch.cat([interpolated, interpolated.new_zeros(1)])
        return interpolated.to(device=self.device, dtype=self.dtype)

    def to(  # type: ignore
        self,
        device: Device | str | None = None,
        dtype: Dtype | None = None,
    ) -> "EulerScheduler":
        """Move the scheduler, including `sigmas`, to the given device and/or dtype.

        The base implementation does not know about `sigmas`; without this override
        they would stay on the construction device/dtype after `scheduler.to(...)`.
        """
        super().to(device=device, dtype=dtype)
        self.sigmas = self.sigmas.to(device=device, dtype=dtype)
        return self

    def scale_model_input(self, x: Tensor, step: int) -> Tensor:
        """Return `x` divided by `sqrt(sigma**2 + 1)`, where `sigma = sigmas[step]`."""
        sigma = self.sigmas[step]
        return x / ((sigma**2 + 1) ** 0.5)

    def __call__(
        self,
        x: Tensor,
        noise: Tensor,
        step: int,
        generator: Generator | None = None,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
    ) -> Tensor:
        """Apply one Euler step to `x` using the model's `noise` prediction.

        Args:
            x: current latents.
            noise: noise predicted by the model for `x` at this step.
            step: index into `sigmas` (0 is the noisiest level).
            generator: optional RNG used to sample the churn noise.
            s_churn: controls the extra noise level: `gamma = min(s_churn / N, sqrt(2) - 1)`
                with `N = len(sigmas) - 1`; the default 0 disables churn.
            s_tmin: lower bound of the sigma range in which churn is applied.
            s_tmax: upper bound of the sigma range in which churn is applied.
            s_noise: scale applied to the sampled churn noise.

        Returns:
            The latents moved from `sigmas[step]` to `sigmas[step + 1]`.
        """
        sigma = self.sigmas[step]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0

        # Sample on the generator's own device (torch rejects a generator whose device
        # differs from the requested one), then align with `x` so that the churn update
        # below never mixes devices or silently changes the dtype of `x`.
        # Sampling stays unconditional so the generator state advances as it always did.
        sample_device = generator.device if generator is not None else x.device
        alt_noise = torch.randn(noise.shape, generator=generator, device=sample_device)
        alt_noise = alt_noise.to(device=x.device, dtype=x.dtype)
        eps = alt_noise * s_noise
        sigma_hat = sigma * (gamma + 1)
        if gamma > 0:
            # Raise the noise level of x from sigma to sigma_hat before stepping.
            x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5

        predicted_x = x - sigma_hat * noise

        # 1st order Euler
        derivative = (x - predicted_x) / sigma_hat
        dt = self.sigmas[step + 1] - sigma_hat
        denoised_x = x + derivative * dt

        return denoised_x
48 changes: 30 additions & 18 deletions src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
from abc import ABC, abstractmethod
from enum import Enum
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log, Generator
from typing import TypeVar

from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt

T = TypeVar("T", bound="Scheduler")


Expand Down Expand Up @@ -46,11 +45,16 @@ def __init__(
self.scale_factors = self.sample_noise_schedule()
self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(
self.noise_std)
self.timesteps = self._generate_timesteps()

@abstractmethod
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self,
x: Tensor,
noise: Tensor,
step: int,
generator: Generator | None = None) -> Tensor:
"""
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
Expand All @@ -71,17 +75,20 @@ def _generate_timesteps(self) -> Tensor:
def steps(self) -> list[int]:
return list(range(self.num_inference_steps))

def scale_model_input(self, x: Tensor, step: int) -> Tensor:
    """
    Hook for schedulers that need to scale the model input according to the current step.

    This default implementation ignores `step` and returns `x` unchanged; subclasses
    that require scaling (e.g. `EulerScheduler`) override it.
    """
    return x

def sample_power_distribution(self, power: float = 2, /) -> Tensor:
return (
linspace(
start=self.initial_diffusion_rate ** (1 / power),
end=self.final_diffusion_rate ** (1 / power),
steps=self.num_train_timesteps,
device=self.device,
dtype=self.dtype,
)
** power
)
return (linspace(
start=self.initial_diffusion_rate**(1 / power),
end=self.final_diffusion_rate**(1 / power),
steps=self.num_train_timesteps,
device=self.device,
dtype=self.dtype,
)**power)

def sample_noise_schedule(self) -> Tensor:
match self.noise_schedule:
Expand All @@ -92,7 +99,8 @@ def sample_noise_schedule(self) -> Tensor:
case "karras":
return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")
raise ValueError(
f"Unknown noise schedule: {self.noise_schedule}")

def add_noise(
self,
Expand All @@ -115,14 +123,18 @@ def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
return denoised_x

def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T: # type: ignore
def to(self: T,
device: Device | str | None = None,
dtype: DType | None = None) -> T: # type: ignore
if device is not None:
self.device = Device(device)
self.timesteps = self.timesteps.to(device)
if dtype is not None:
self.dtype = dtype
self.scale_factors = self.scale_factors.to(device, dtype=dtype)
self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype)
self.cumulative_scale_factors = self.cumulative_scale_factors.to(
device, dtype=dtype)
self.noise_std = self.noise_std.to(device, dtype=dtype)
self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype)
self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(
device, dtype=dtype)
return self
Loading

0 comments on commit fae45b7

Please sign in to comment.