-
Notifications
You must be signed in to change notification settings - Fork 58
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
7 changed files
with
138 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
83 changes: 83 additions & 0 deletions
83
src/refiners/foundationals/latent_diffusion/schedulers/euler.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,83 @@ | ||
from refiners.foundationals.latent_diffusion.schedulers.scheduler import NoiseSchedule, Scheduler | ||
from torch import Tensor, device as Device, dtype as Dtype, float32, tensor, Generator | ||
import torch | ||
import numpy as np | ||
|
||
|
||
class EulerScheduler(Scheduler): | ||
def __init__( | ||
self, | ||
num_inference_steps: int, | ||
num_train_timesteps: int = 1_000, | ||
initial_diffusion_rate: float = 8.5e-4, | ||
final_diffusion_rate: float = 1.2e-2, | ||
noise_schedule: NoiseSchedule = NoiseSchedule.QUADRATIC, | ||
device: Device | str = "cpu", | ||
dtype: Dtype = float32, | ||
): | ||
if noise_schedule != NoiseSchedule.QUADRATIC: | ||
raise NotImplementedError | ||
super().__init__( | ||
num_inference_steps=num_inference_steps, | ||
num_train_timesteps=num_train_timesteps, | ||
initial_diffusion_rate=initial_diffusion_rate, | ||
final_diffusion_rate=final_diffusion_rate, | ||
noise_schedule=noise_schedule, | ||
device=device, | ||
dtype=dtype, | ||
) | ||
self.sigmas = self._generate_sigmas() | ||
|
||
@property | ||
def init_noise_sigma(self) -> Tensor: | ||
return self.sigmas.max() | ||
|
||
def _generate_timesteps(self) -> Tensor: | ||
# We need to use numpy here because: | ||
# numpy.linspace(0,999,31)[15] is 499.49999999999994 | ||
# torch.linspace(0,999,31)[15] is 499.5 | ||
# ...and we want the same result as the original codebase. | ||
timesteps = torch.tensor( | ||
np.linspace(0, self.num_train_timesteps - 1, self.num_inference_steps), dtype=self.dtype, device=self.device | ||
).flip(0) | ||
return timesteps | ||
|
||
def _generate_sigmas(self) -> Tensor: | ||
sigmas = self.noise_std / self.cumulative_scale_factors | ||
sigmas = torch.tensor(np.interp(self.timesteps.cpu().numpy(), np.arange(0, len(sigmas)), sigmas.cpu().numpy())) | ||
sigmas = torch.cat([sigmas, tensor([0.0])]) | ||
return sigmas.to(device=self.device, dtype=self.dtype) | ||
|
||
def scale_model_input(self, x: Tensor, step: int) -> Tensor: | ||
sigma = self.sigmas[step] | ||
return x / ((sigma**2 + 1) ** 0.5) | ||
|
||
def __call__( | ||
self, | ||
x: Tensor, | ||
noise: Tensor, | ||
step: int, | ||
generator: Generator | None = None, | ||
s_churn: float = 0.0, | ||
s_tmin: float = 0.0, | ||
s_tmax: float = float("inf"), | ||
s_noise: float = 1.0, | ||
) -> Tensor: | ||
sigma = self.sigmas[step] | ||
|
||
gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0 | ||
|
||
alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype) | ||
eps = alt_noise * s_noise | ||
sigma_hat = sigma * (gamma + 1) | ||
if gamma > 0: | ||
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5 | ||
|
||
predicted_x = x - sigma_hat * noise | ||
|
||
# 1st order Euler | ||
derivative = (x - predicted_x) / sigma_hat | ||
dt = self.sigmas[step + 1] - sigma_hat | ||
denoised_x = x + derivative * dt | ||
|
||
return denoised_x |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters