
Commit 3f823cf: euler scheduler
israfelsr committed Jan 9, 2024
1 parent fae45b7
Showing 4 changed files with 65 additions and 110 deletions.
33 changes: 14 additions & 19 deletions src/refiners/foundationals/latent_diffusion/schedulers/ddim.py
@@ -4,7 +4,6 @@


 class DDIM(Scheduler):
-
     def __init__(
         self,
         num_inference_steps: int,
@@ -32,29 +31,25 @@ def _generate_timesteps(self) -> Tensor:
         similar to diffusers settings for the DDIM scheduler in Stable Diffusion 1.5
         """
         step_ratio = self.num_train_timesteps // self.num_inference_steps
-        timesteps = arange(
-            start=0, end=self.num_inference_steps, step=1,
-            device=self.device) * step_ratio + 1
+        timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
         return timesteps.flip(0)

-    def __call__(self,
-                 x: Tensor,
-                 noise: Tensor,
-                 step: int,
-                 generator: Generator | None = None) -> Tensor:
+    def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         timestep, previous_timestep = (
             self.timesteps[step],
-            (self.timesteps[step + 1] if step < self.num_inference_steps -
-             1 else tensor(data=[0], device=self.device, dtype=self.dtype)),
+            (
+                self.timesteps[step + 1]
+                if step < self.num_inference_steps - 1
+                else tensor(data=[0], device=self.device, dtype=self.dtype)
+            ),
         )
-        current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[
-            timestep], (self.cumulative_scale_factors[previous_timestep]
-                        if previous_timestep > 0 else
-                        self.cumulative_scale_factors[0])
+        current_scale_factor, previous_scale_factor = self.cumulative_scale_factors[timestep], (
+            self.cumulative_scale_factors[previous_timestep]
+            if previous_timestep > 0
+            else self.cumulative_scale_factors[0]
+        )
-        predicted_x = (x - sqrt(1 - current_scale_factor**2) *
-                       noise) / current_scale_factor
-        denoised_x = previous_scale_factor * predicted_x + sqrt(
-            1 - previous_scale_factor**2) * noise
+        predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
+        denoised_x = previous_scale_factor * predicted_x + sqrt(1 - previous_scale_factor**2) * noise

         self.previous_scale_factor = previous_scale_factor
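Note: the two reformatted assignments above transcribe directly into the deterministic (eta = 0) DDIM step of Song et al., written here in the code's own quantities (alpha for the cumulative scale factors, epsilon for the predicted noise); this is a reference sketch, not part of the diff:

    \hat{x}_0 = \frac{x - \sqrt{1 - \alpha_t^2}\,\epsilon}{\alpha_t},
    \qquad
    x_{t-1} = \alpha_{t-1}\,\hat{x}_0 + \sqrt{1 - \alpha_{t-1}^2}\,\epsilon
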
@@ -38,13 +38,11 @@ def _generate_timesteps(self) -> Tensor:
         # torch.linspace(0,999,31)[15] is 499.5
         # ...and we want the same result as the original codebase.
         return tensor(
-            np.linspace(0, self.num_train_timesteps - 1,
-                        self.num_inference_steps + 1).round().astype(int)[1:],
+            np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps + 1).round().astype(int)[1:],
             device=self.device,
         ).flip(0)

-    def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor,
-                                      step: int) -> Tensor:
+    def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
         timestep, previous_timestep = (
             self.timesteps[step],
             self.timesteps[step + 1 if step < len(self.timesteps) - 1 else 0],
@@ -53,52 +51,44 @@ def dpm_solver_first_order_update(self, x: Tensor, noise: Tensor,
             self.signal_to_noise_ratios[previous_timestep],
             self.signal_to_noise_ratios[timestep],
         )
-        previous_scale_factor = self.cumulative_scale_factors[
-            previous_timestep]
+        previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
         previous_noise_std, current_noise_std = (
             self.noise_std[previous_timestep],
             self.noise_std[timestep],
         )
         factor = exp(-(previous_ratio - current_ratio)) - 1.0
-        denoised_x = (previous_noise_std / current_noise_std) * x - (
-            factor * previous_scale_factor) * noise
+        denoised_x = (previous_noise_std / current_noise_std) * x - (factor * previous_scale_factor) * noise
         return denoised_x

-    def multistep_dpm_solver_second_order_update(self, x: Tensor,
-                                                 step: int) -> Tensor:
+    def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tensor:
         previous_timestep, current_timestep, next_timestep = (
-            self.timesteps[step + 1] if step < len(self.timesteps) -
-            1 else tensor([0]),
+            self.timesteps[step + 1] if step < len(self.timesteps) - 1 else tensor([0]),
             self.timesteps[step],
             self.timesteps[step - 1],
         )
-        current_data_estimation, next_data_estimation = self.estimated_data[
-            -1], self.estimated_data[-2]
+        current_data_estimation, next_data_estimation = self.estimated_data[-1], self.estimated_data[-2]
         previous_ratio, current_ratio, next_ratio = (
             self.signal_to_noise_ratios[previous_timestep],
             self.signal_to_noise_ratios[current_timestep],
             self.signal_to_noise_ratios[next_timestep],
         )
-        previous_scale_factor = self.cumulative_scale_factors[
-            previous_timestep]
+        previous_scale_factor = self.cumulative_scale_factors[previous_timestep]
         previous_std, current_std = (
             self.noise_std[previous_timestep],
             self.noise_std[current_timestep],
         )
         estimation_delta = (current_data_estimation - next_data_estimation) / (
-            (current_ratio - next_ratio) / (previous_ratio - current_ratio))
+            (current_ratio - next_ratio) / (previous_ratio - current_ratio)
+        )
         factor = exp(-(previous_ratio - current_ratio)) - 1.0
-        denoised_x = (
-            (previous_std / current_std) * x -
-            (factor * previous_scale_factor) * current_data_estimation - 0.5 *
-            (factor * previous_scale_factor) * estimation_delta)
+        denoised_x = (
+            (previous_std / current_std) * x
+            - (factor * previous_scale_factor) * current_data_estimation
+            - 0.5 * (factor * previous_scale_factor) * estimation_delta
+        )
         return denoised_x

-    def __call__(self,
-                 x: Tensor,
-                 noise: Tensor,
-                 step: int,
-                 generator: Generator | None = None) -> Tensor:
+    def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         """
         Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
@@ -107,15 +97,14 @@ def __call__(self,
         (ODEs).
         """
         current_timestep = self.timesteps[step]
-        scale_factor, noise_ratio = self.cumulative_scale_factors[
-            current_timestep], self.noise_std[current_timestep]
+        scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
         estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
         self.estimated_data.append(estimated_denoised_data)
-        denoised_x = (self.dpm_solver_first_order_update(
-            x=x, noise=estimated_denoised_data, step=step) if
-                      (self.initial_steps == 0) else
-                      self.multistep_dpm_solver_second_order_update(x=x,
-                                                                    step=step))
+        denoised_x = (
+            self.dpm_solver_first_order_update(x=x, noise=estimated_denoised_data, step=step)
+            if (self.initial_steps == 0)
+            else self.multistep_dpm_solver_second_order_update(x=x, step=step)
+        )
         if self.initial_steps < 2:
             self.initial_steps += 1
         return denoised_x
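Note: the first-order branch above appears to match the data-prediction form of the DPM-Solver step. Observe that `__call__` passes `estimated_denoised_data` as the `noise` argument of `dpm_solver_first_order_update`, so the update is in terms of the data estimate rather than raw noise. A sketch in the code's own quantities (alpha = cumulative_scale_factors, sigma = noise_std, lambda = signal_to_noise_ratios), offered as a reading aid rather than as the authors' derivation:

    x_{\mathrm{prev}} = \frac{\sigma_{\mathrm{prev}}}{\sigma_{\mathrm{cur}}}\, x
    \;-\; \alpha_{\mathrm{prev}}\left(e^{-h} - 1\right)\hat{x}_0,
    \qquad h = \lambda_{\mathrm{prev}} - \lambda_{\mathrm{cur}},
    \qquad \lambda = \log \alpha - \log \sigma
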
39 changes: 16 additions & 23 deletions src/refiners/foundationals/latent_diffusion/schedulers/scheduler.py
@@ -45,16 +45,11 @@ def __init__(
         self.scale_factors = self.sample_noise_schedule()
         self.cumulative_scale_factors = sqrt(self.scale_factors.cumprod(dim=0))
         self.noise_std = sqrt(1.0 - self.scale_factors.cumprod(dim=0))
-        self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(
-            self.noise_std)
+        self.signal_to_noise_ratios = log(self.cumulative_scale_factors) - log(self.noise_std)
         self.timesteps = self._generate_timesteps()

     @abstractmethod
-    def __call__(self,
-                 x: Tensor,
-                 noise: Tensor,
-                 step: int,
-                 generator: Generator | None = None) -> Tensor:
+    def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
         """
         Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
@@ -82,13 +77,16 @@ def scale_model_input(self, x: Tensor, step: int) -> Tensor:
         return x

     def sample_power_distribution(self, power: float = 2, /) -> Tensor:
-        return (linspace(
-            start=self.initial_diffusion_rate**(1 / power),
-            end=self.final_diffusion_rate**(1 / power),
-            steps=self.num_train_timesteps,
-            device=self.device,
-            dtype=self.dtype,
-        )**power)
+        return (
+            linspace(
+                start=self.initial_diffusion_rate ** (1 / power),
+                end=self.final_diffusion_rate ** (1 / power),
+                steps=self.num_train_timesteps,
+                device=self.device,
+                dtype=self.dtype,
+            )
+            ** power
+        )

     def sample_noise_schedule(self) -> Tensor:
         match self.noise_schedule:
@@ -99,8 +97,7 @@ def sample_noise_schedule(self) -> Tensor:
case "karras":
return 1 - self.sample_power_distribution(7)
case _:
raise ValueError(
f"Unknown noise schedule: {self.noise_schedule}")
raise ValueError(f"Unknown noise schedule: {self.noise_schedule}")

def add_noise(
self,
@@ -123,18 +120,14 @@ def remove_noise(self, x: Tensor, noise: Tensor, step: int) -> Tensor:
         denoised_x = (x - noise_stds * noise) / cumulative_scale_factors
         return denoised_x

-    def to(self: T,
-           device: Device | str | None = None,
-           dtype: DType | None = None) -> T:  # type: ignore
+    def to(self: T, device: Device | str | None = None, dtype: DType | None = None) -> T:  # type: ignore
         if device is not None:
             self.device = Device(device)
             self.timesteps = self.timesteps.to(device)
         if dtype is not None:
             self.dtype = dtype
             self.scale_factors = self.scale_factors.to(device, dtype=dtype)
-            self.cumulative_scale_factors = self.cumulative_scale_factors.to(
-                device, dtype=dtype)
+            self.cumulative_scale_factors = self.cumulative_scale_factors.to(device, dtype=dtype)
             self.noise_std = self.noise_std.to(device, dtype=dtype)
-            self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(
-                device, dtype=dtype)
+            self.signal_to_noise_ratios = self.signal_to_noise_ratios.to(device, dtype=dtype)
         return self
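Note: sample_power_distribution above generalizes the "scaled_linear" schedule (its power-2 case) and the "karras" schedule (power 7). A standalone sketch of the same computation in plain PyTorch; the rate endpoints are assumptions taken from the beta_start/beta_end values used in this commit's tests, and 1000 train timesteps is an assumed default:

    import torch

    def power_schedule(power: float, num_train_timesteps: int = 1000) -> torch.Tensor:
        # Interpolate linearly in the (1/power)-root domain, then raise back.
        initial_rate, final_rate = 0.00085, 0.012  # assumed defaults
        return torch.linspace(initial_rate ** (1 / power), final_rate ** (1 / power), num_train_timesteps) ** power

    betas = power_schedule(2)  # the "scaled_linear" case
    scale_factors = 1 - betas
    cumulative_scale_factors = torch.sqrt(scale_factors.cumprod(dim=0))
    noise_std = torch.sqrt(1 - scale_factors.cumprod(dim=0))
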
50 changes: 14 additions & 36 deletions tests/foundationals/latent_diffusion/test_schedulers.py
@@ -11,9 +11,7 @@
 def test_ddpm_diffusers():
     from diffusers import DDPMScheduler  # type: ignore

-    diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear",
-                                        beta_start=0.00085,
-                                        beta_end=0.012)
+    diffusers_scheduler = DDPMScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
     diffusers_scheduler.set_timesteps(1000)
     refiners_scheduler = DDPM(num_inference_steps=1000)

@@ -25,23 +23,17 @@ def test_dpm_solver_diffusers():

     manual_seed(0)

-    diffusers_scheduler = DiffuserScheduler(beta_schedule="scaled_linear",
-                                            beta_start=0.00085,
-                                            beta_end=0.012)
+    diffusers_scheduler = DiffuserScheduler(beta_schedule="scaled_linear", beta_start=0.00085, beta_end=0.012)
     diffusers_scheduler.set_timesteps(30)
     refiners_scheduler = DPMSolver(num_inference_steps=30)

     sample = randn(1, 3, 32, 32)
     noise = randn(1, 3, 32, 32)

     for step, timestep in enumerate(diffusers_scheduler.timesteps):
-        diffusers_output = cast(Tensor,
-                                diffusers_scheduler.step(
-                                    noise, timestep,
-                                    sample).prev_sample)  # type: ignore
+        diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample)  # type: ignore
         refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
-        assert allclose(diffusers_output, refiners_output,
-                        rtol=0.01), f"outputs differ at step {step}"
+        assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


 def test_ddim_diffusers():
@@ -65,14 +57,10 @@ def test_ddim_diffusers():
     noise = randn(1, 4, 32, 32)

     for step, timestep in enumerate(diffusers_scheduler.timesteps):
-        diffusers_output = cast(Tensor,
-                                diffusers_scheduler.step(
-                                    noise, timestep,
-                                    sample).prev_sample)  # type: ignore
+        diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample)  # type: ignore
         refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)

-        assert allclose(diffusers_output, refiners_output,
-                        rtol=0.01), f"outputs differ at step {step}"
+        assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


 def test_euler_diffusers():
@@ -94,19 +82,13 @@
     sample = randn(1, 4, 32, 32)
     noise = randn(1, 4, 32, 32)

-    assert isclose(
-        diffusers_scheduler.init_noise_sigma,
-        refiners_scheduler.init_noise_sigma), "init_noise_sigma differ"
+    assert isclose(diffusers_scheduler.init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ"

     for step, timestep in enumerate(diffusers_scheduler.timesteps):
-        diffusers_output = cast(Tensor,
-                                diffusers_scheduler.step(
-                                    noise, timestep,
-                                    sample).prev_sample)  # type: ignore
+        diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample)  # type: ignore
         refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)

-        assert allclose(diffusers_output, refiners_output,
-                        rtol=0.01), f"outputs differ at step {step}"
+        assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


 def test_scheduler_remove_noise():
@@ -131,15 +113,11 @@

     for step, timestep in enumerate(diffusers_scheduler.timesteps):
         diffusers_output = cast(
-            Tensor,
-            diffusers_scheduler.step(
-                noise, timestep, sample).pred_original_sample)  # type: ignore
-        refiners_output = refiners_scheduler.remove_noise(x=sample,
-                                                          noise=noise,
-                                                          step=step)
-
-        assert allclose(diffusers_output, refiners_output,
-                        rtol=0.01), f"outputs differ at step {step}"
+            Tensor, diffusers_scheduler.step(noise, timestep, sample).pred_original_sample
+        )  # type: ignore
+        refiners_output = refiners_scheduler.remove_noise(x=sample, noise=noise, step=step)
+
+        assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


def test_scheduler_device(test_device: Device):
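Note: every test in this file follows the same step-by-step parity pattern against diffusers. A condensed sketch of that scaffold, reusable for any scheduler pair (the helper function and its name are my own; the tests inline this logic):

    from typing import cast
    from torch import Tensor, allclose, randn

    def check_parity(diffusers_scheduler, refiners_scheduler, channels: int = 4) -> None:
        # Compare one full denoising trajectory, step by step.
        sample = randn(1, channels, 32, 32)
        noise = randn(1, channels, 32, 32)
        for step, timestep in enumerate(diffusers_scheduler.timesteps):
            diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample)  # type: ignore
            refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
            assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"
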
