feature: Euler scheduler (#138)
israfelsr authored Jan 10, 2024
1 parent ff5ec74 commit 8423c5e
Showing 7 changed files with 138 additions and 33 deletions.
3 changes: 2 additions & 1 deletion src/refiners/foundationals/latent_diffusion/model.py
@@ -11,7 +11,6 @@

T = TypeVar("T", bound="fl.Module")


TLatentDiffusionModel = TypeVar("TLatentDiffusionModel", bound="LatentDiffusionModel")


@@ -91,6 +90,8 @@ def forward(
self.set_unet_context(timestep=timestep, clip_text_embedding=clip_text_embedding, **kwargs)

latents = torch.cat(tensors=(x, x)) # for classifier-free guidance
# scale latents for schedulers that need it
latents = self.scheduler.scale_model_input(latents, step=step)
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)

# classifier-free guidance
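For context, a minimal sketch of the loop that drives this `forward` — not part of the commit; it assumes `sd` is a `LatentDiffusionModel` wired with the new `EulerScheduler` and that `clip_text_embedding` was precomputed by its text encoder:

```python
import torch

# Start from noise scaled by the scheduler's largest sigma (EulerScheduler only;
# other schedulers start from unit-variance noise).
x = torch.randn(1, 4, 32, 32) * sd.scheduler.init_noise_sigma
with torch.no_grad():
    for step in sd.scheduler.steps:
        # scale_model_input is applied inside forward(), so the caller's loop
        # is identical for schedulers that need no input scaling.
        x = sd(x, step=step, clip_text_embedding=clip_text_embedding)
```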
src/refiners/foundationals/latent_diffusion/schedulers/__init__.py
@@ -2,10 +2,6 @@
from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler

__all__ = [
"Scheduler",
"DPMSolver",
"DDPM",
"DDIM",
]
__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]
15 changes: 6 additions & 9 deletions src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
@@ -1,4 +1,4 @@
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor, Generator

from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler

@@ -34,7 +34,7 @@ def _generate_timesteps(self) -> Tensor:
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
return timesteps.flip(0)

def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
timestep, previous_timestep = (
self.timesteps[step],
(
@@ -43,13 +43,10 @@ def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
else tensor(data=[0], device=self.device, dtype=self.dtype)
),
)
current_scale_factor, previous_scale_factor = (
self.cumulative_scale_factors[timestep],
(
self.cumulative_scale_factors[previous_timestep]
if previous_timestep > 0
else self.cumulative_scale_factors[0]
),
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
self.cumulative_scale_factors[previous_timestep]
if previous_timestep > 0
else self.cumulative_scale_factors[0]
)
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
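For reference, the unchanged math these lines implement — here $c_t$ is `cumulative_scale_factors[timestep]` (i.e. $\sqrt{\bar\alpha_t}$ in DDPM notation), $c_{t'}$ the previous timestep's factor, and $\varepsilon$ the `noise` argument:

```latex
\hat{x}_0 = \frac{x_t - \sqrt{1 - c_t^{2}}\,\varepsilon}{c_t},
\qquad
x_{t'} = c_{t'}\,\hat{x}_0 + \sqrt{1 - c_{t'}^{2}}\,\varepsilon
```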
src/refiners/foundationals/latent_diffusion/schedulers/dpm_solver.py
@@ -1,9 +1,7 @@
from collections import deque

import numpy as np
from torch import Tensor, device as Device, dtype as Dtype, exp, float32, tensor

from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
import numpy as np
from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype, Generator
from collections import deque


class DPMSolver(Scheduler):
@@ -90,12 +88,7 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
)
return denoised_x

def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
83 changes: 83 additions & 0 deletions src/refiners/foundationals/latent_diffusion/schedulers/euler.py
@@ -0,0 +1,83 @@
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from torch import Tensor, device as Device, dtype as Dtype, float32, tensor, Generator
import torch
import numpy as np


class EulerScheduler(Scheduler):
def __init__(
self,
num_inference_steps: int,
num_train_timesteps: int = 1_000,
initial_diffusion_rate: float = 8.5e-4,
final_diffusion_rate: float = 1.2e-2,
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC,
device: Device | str = "cpu",
dtype: Dtype = float32,
):
if noise_schedule != NoiseSchedule.QUADRATIC:
raise NotImplementedError
super().__init__(
num_inference_steps=num_inference_steps,
num_train_timesteps=num_train_timesteps,
initial_diffusion_rate=initial_diffusion_rate,
final_diffusion_rate=final_diffusion_rate,
noise_schedule=noise_schedule,
device=device,
dtype=dtype,
)
self.sigmas = self._generate_sigmas()

@property
def init_noise_sigma(self) -> Tensor:
return self.sigmas.max()

def _generate_timesteps(self) -> Tensor:
# We need to use numpy here because:
# numpy.linspace(0,999,31)[15] is 499.49999999999994
# torch.linspace(0,999,31)[15] is 499.5
# ...and we want the same result as the original codebase.
timesteps = torch.tensor(
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device
).flip(0)
return timesteps

def _generate_sigmas(self) -> Tensor:
sigmas = self.noise_std / self.cumulative_scale_factors
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy()))
sigmas = torch.cat([sigmas, tensor([0.0])])
return sigmas.to(device=self.device, dtype=self.dtype)

def scale_model_input(self, x: Tensor, step: int) -> Tensor:
sigma = self.sigmas[step]
return x / ((sigma**2 + 1) ** 0.5)

def __call__(
self,
x: Tensor,
noise: Tensor,
step: int,
generator: Generator | None = None,
s_churn: float = 0.0,
s_tmin: float = 0.0,
s_tmax: float = float("inf"),
s_noise: float = 1.0,
) -> Tensor:
sigma = self.sigmas[step]

gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0

alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
eps = alt_noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5

predicted_x = x - sigma_hat * noise

# 1st order Euler
derivative = (x - predicted_x) / sigma_hat
dt = self.sigmas[step + 1] - sigma_hat
denoised_x = x + derivative * dt

return denoised_x
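A minimal usage sketch for the new scheduler — `unet` is a hypothetical epsilon-predicting denoiser, and shapes follow the tests below; with the default `s_churn=0.0`, `gamma` is zero and every step is the deterministic first-order Euler update `x + derivative * dt`:

```python
import torch
from refiners.foundationals.latent_diffusion.schedulers import EulerScheduler

scheduler = EulerScheduler(num_inference_steps=30)
# Euler works in sigma space: start from noise scaled by the largest sigma.
x = torch.randn(1, 4, 32, 32) * scheduler.init_noise_sigma

for step in scheduler.steps:
    scaled = scheduler.scale_model_input(x, step=step)  # x / sqrt(sigma^2 + 1)
    noise = unet(scaled, scheduler.timesteps[step])     # hypothetical denoiser call
    x = scheduler(x, noise=noise, step=step)
```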
src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py
@@ -1,9 +1,8 @@
from abc import ABC, abstractmethod
from enum import Enum
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log, Generator
from typing import TypeVar

from torch import Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt

T = TypeVar("T", bound="Scheduler")


@@ -50,7 +49,7 @@ def __init__(
self.timesteps = self._generate_timesteps()

@abstractmethod
def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
@@ -71,6 +70,12 @@ def _generate_timesteps(self) -> Tensor:
def steps(self) -> list[int]:
return list(range(self.num_inference_steps))

def scale_model_input(self, x: Tensor, step: int) -> Tensor:
"""
For compatibility with schedulers that need to scale the input according to the current timestep.
"""
return x

def sample_power_distribution(self, power: float = 2, /) -> Tensor:
return (
linspace(
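Because the default implementation is the identity, schedulers that predate this hook behave exactly as before. A quick sketch of the contract (assuming `DDIM` keeps its existing `num_inference_steps` constructor argument):

```python
from torch import randn
from refiners.foundationals.latent_diffusion.schedulers import DDIM

scheduler = DDIM(num_inference_steps=30)
x = randn(1, 4, 32, 32)
# Base Scheduler.scale_model_input returns its input unchanged.
assert scheduler.scale_model_input(x, step=0) is x
```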
36 changes: 33 additions & 3 deletions tests/foundationals/latent_diffusion/test_schedulers.py
@@ -2,10 +2,10 @@
from warnings import warn

import pytest
from torch import Tensor, allclose, device as Device, equal, randn
from torch import Tensor, allclose, device as Device, equal, randn, isclose

from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler


def test_ddpm_diffusers():
@@ -63,6 +63,34 @@ def test_ddim_diffusers():
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


def test_euler_diffusers():
from diffusers import EulerDiscreteScheduler

manual_seed(0)
diffusers_scheduler = EulerDiscreteScheduler(
beta_end=0.012,
beta_schedule="scaled_linear",
beta_start=0.00085,
num_train_timesteps=1000,
steps_offset=1,
timestep_spacing="linspace",
use_karras_sigmas=False,
)
diffusers_scheduler.set_timesteps(30)
refiners_scheduler = EulerScheduler(num_inference_steps=30)

sample = randn(1, 4, 32, 32)
noise = randn(1, 4, 32, 32)

assert isclose(diffusers_scheduler.init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differs"

for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)

assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


def test_scheduler_remove_noise():
from diffusers import DDIMScheduler # type: ignore

@@ -84,7 +112,9 @@ def test_scheduler_remove_noise():
noise = randn(1, 4, 32, 32)

for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore
diffusers_output = cast(
Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample
) # type: ignore
refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step)

assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
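The new parity test can be run on its own with `pytest tests/foundationals/latent_diffusion/test_schedulers.py -k euler` (assuming a development environment with `diffusers` installed).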
