Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Euler scheduler: follow up #172

Merged
merged 6 commits into from
Jan 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ ______________________________________________________________________

## Latest News 🔥

- Added [Euler's method](https://arxiv.org/abs/2206.00364) to schedulers (contributed by [@israfelsr](https://github.com/israfelsr))
- Added [DINOv2](https://github.com/facebookresearch/dinov2) for high-performance visual features (contributed by [@Laurent2916](https://github.com/Laurent2916))
- Added [FreeU](https://github.com/ChenyangSi/FreeU) for improved quality at no cost (contributed by [@isamu-isozaki](https://github.com/isamu-isozaki))
- Added [Restart Sampling](https://github.com/Newbeeer/diffusion_restart_sampling) for improved image generation ([example](https://github.com/Newbeeer/diffusion_restart_sampling/issues/4))
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from refiners.foundationals.latent_diffusion.schedulers.ddim import DDIM
from refiners.foundationals.latent_diffusion.schedulers.ddpm import DDPM
from refiners.foundationals.latent_diffusion.schedulers.dpm_solver import DPMSolver
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler
from refiners.foundationals.latent_diffusion.schedulers.euler import EulerScheduler
from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler

__all__ = ["Scheduler", "DPMSolver", "DDPM", "DDIM", "EulerScheduler"]
13 changes: 8 additions & 5 deletions src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch import Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor, Generator
from torch import Generator, Tensor, arange, device as Device, dtype as Dtype, float32, sqrt, tensor

from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler

Expand Down Expand Up @@ -43,10 +43,13 @@ def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | N
else tensor(data=[0], device=self.device, dtype=self.dtype)
),
)
current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
self.cumulative_scale_factors[previous_timestep]
if previous_timestep > 0
else self.cumulative_scale_factors[0]
current_scale_factor, previous_scale_factor = (
self.cumulative_scale_factors[timestep],
(
self.cumulative_scale_factors[previous_timestep]
if previous_timestep > 0
else self.cumulative_scale_factors[0]
),
)
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from torch import Tensor, arange, device as Device
from torch import Generator, Tensor, arange, device as Device

from refiners.foundationals.latent_diffusion.schedulers.scheduler import Scheduler

Expand Down Expand Up @@ -30,5 +30,5 @@ def _generate_timesteps(self) -> Tensor:
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0)

def __call__(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
raise NotImplementedError
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
import numpy as np
from torch import Tensor, device as Device, tensor, exp, float32, dtype as Dtype, Generator
from collections import deque

import numpy as np
from torch import Generator, Tensor, device as Device, dtype as Dtype, exp, float32, tensor

from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler


class DPMSolver(Scheduler):
"""Implements DPM-Solver++ from https://arxiv.org/abs/2211.01095
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler
from torch import Tensor, device as Device, dtype as Dtype, float32, tensor, Generator
import torch
import numpy as np
import torch
from torch import Generator, Tensor, device as Device, dtype as Dtype, float32, tensor

from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler


class EulerScheduler(Scheduler):
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
from abc import ABC, abstractmethod
from enum import Enum
from torch import Tensor, device as Device, dtype as DType, linspace, float32, sqrt, log, Generator
from typing import TypeVar

from torch import Generator, Tensor, device as Device, dtype as DType, float32, linspace, log, sqrt

T = TypeVar("T", bound="Scheduler")


Expand Down
56 changes: 55 additions & 1 deletion tests/e2e/test_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from refiners.foundationals.latent_diffusion.multi_diffusion import DiffusionTarget
from refiners.foundationals.latent_diffusion.reference_only_control import ReferenceOnlyControlAdapter
from refiners.foundationals.latent_diffusion.restart import Restart
from refiners.foundationals.latent_diffusion.schedulers import DDIM
from refiners.foundationals.latent_diffusion.schedulers import DDIM, EulerScheduler
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule
from refiners.foundationals.latent_diffusion.stable_diffusion_1.multi_diffusion import SD1MultiDiffusion
from refiners.foundationals.latent_diffusion.stable_diffusion_xl.model import StableDiffusion_XL
Expand Down Expand Up @@ -65,6 +65,11 @@ def expected_image_std_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_std_random_init.png").convert("RGB")


@pytest.fixture
def expected_image_std_random_init_euler(ref_path: Path) -> Image.Image:
    """Reference output image for the Euler-scheduler random-init diffusion test."""
    image_path = ref_path / "expected_std_random_init_euler.png"
    return Image.open(image_path).convert("RGB")


@pytest.fixture
def expected_karras_random_init(ref_path: Path) -> Image.Image:
return Image.open(ref_path / "expected_karras_random_init.png").convert("RGB")
Expand Down Expand Up @@ -438,6 +443,24 @@ def sd15_ddim_karras(
return sd15


@pytest.fixture
def sd15_euler(
    text_encoder_weights: Path, lda_weights: Path, unet_weights_std: Path, test_device: torch.device
) -> StableDiffusion_1:
    """Stable Diffusion 1.5 model wired to an Euler scheduler (30 inference steps),
    with pretrained weights loaded into each sub-model."""
    # Full-size diffusion is too slow to be worth running on CPU.
    if test_device.type == "cpu":
        warn("not running on CPU, skipping")
        pytest.skip()

    scheduler = EulerScheduler(num_inference_steps=30)
    sd15 = StableDiffusion_1(scheduler=scheduler, device=test_device)

    # Load pretrained safetensors weights for each sub-model.
    for module, weights_path in (
        (sd15.clip_text_encoder, text_encoder_weights),
        (sd15.lda, lda_weights),
        (sd15.unet, unet_weights_std),
    ):
        module.load_from_safetensors(weights_path)

    return sd15


@pytest.fixture
def sd15_ddim_lda_ft_mse(
text_encoder_weights: Path, lda_ft_mse_weights: Path, unet_weights_std: Path, test_device: torch.device
Expand Down Expand Up @@ -529,6 +552,37 @@ def test_diffusion_std_random_init(
ensure_similar_images(predicted_image, expected_image_std_random_init)


@no_grad()
def test_diffusion_std_random_init_euler(
    sd15_euler: StableDiffusion_1, expected_image_std_random_init_euler: Image.Image, test_device: torch.device
):
    """End-to-end SD 1.5 generation with the Euler scheduler, compared to a reference image."""
    sd15 = sd15_euler
    euler_scheduler = sd15_euler.scheduler
    assert isinstance(euler_scheduler, EulerScheduler)
    n_steps = 30

    prompt = "a cute cat, detailed high-quality professional image"
    negative_prompt = "lowres, bad anatomy, bad hands, cropped, worst quality"
    clip_text_embedding = sd15.compute_clip_text_embedding(text=prompt, negative_text=negative_prompt)

    sd15.set_num_inference_steps(n_steps)

    manual_seed(2)
    # Euler expects the initial latent noise scaled by the scheduler's init_noise_sigma.
    x = torch.randn(1, 4, 64, 64, device=test_device) * euler_scheduler.init_noise_sigma

    for step in sd15.steps:
        x = sd15(
            x,
            step=step,
            clip_text_embedding=clip_text_embedding,
            condition_scale=7.5,
        )

    predicted_image = sd15.lda.decode_latents(x)
    ensure_similar_images(predicted_image, expected_image_std_random_init_euler)


@no_grad()
def test_diffusion_karras_random_init(
sd15_ddim_karras: StableDiffusion_1, expected_karras_random_init: Image.Image, test_device: torch.device
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
12 changes: 6 additions & 6 deletions tests/foundationals/latent_diffusion/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from warnings import warn

import pytest
from torch import Tensor, allclose, device as Device, equal, randn, isclose
from torch import Tensor, allclose, device as Device, equal, isclose, randn

from refiners.fluxion import manual_seed
from refiners.foundationals.latent_diffusion.schedulers import DDIM, DDPM, DPMSolver, EulerScheduler
Expand Down Expand Up @@ -64,7 +64,7 @@ def test_ddim_diffusers():


def test_euler_diffusers():
from diffusers import EulerDiscreteScheduler
from diffusers import EulerDiscreteScheduler # type: ignore

manual_seed(0)
diffusers_scheduler = EulerDiscreteScheduler(
Expand All @@ -82,7 +82,9 @@ def test_euler_diffusers():
sample = randn(1, 4, 32, 32)
noise = randn(1, 4, 32, 32)

assert isclose(diffusers_scheduler.init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ"
ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
assert isinstance(ref_init_noise_sigma, Tensor)
assert isclose(ref_init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ"

for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
Expand Down Expand Up @@ -112,9 +114,7 @@ def test_scheduler_remove_noise():
noise = randn(1, 4, 32, 32)

for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(
Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample
) # type: ignore
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample) # type: ignore
refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step)

assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
Expand Down