Skip to content

Commit

Permalink
refactor: rename noise => predicted_noise
Browse files Browse the repository at this point in the history
and in euler, `alt_noise` can now be simply `noise`
  • Loading branch information
brycedrennan authored and deltheil committed Jan 24, 2024
1 parent 695c24d commit 12a5439
Show file tree
Hide file tree
Showing 8 changed files with 29 additions and 25 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -96,5 +96,5 @@ black = true
[tool.pyright]
include = ["src/refiners", "tests", "scripts"]
strict = ["*"]
exclude = ["**/__pycache__"]
exclude = ["**/__pycache__", "tests/weights"]
reportMissingTypeStubs = "warning"
8 changes: 5 additions & 3 deletions src/refiners/foundationals/latent_diffusion/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,15 +91,17 @@ def forward(
unconditional_prediction, conditional_prediction = self.unet(latents).chunk(2)

# classifier-free guidance
noise = unconditional_prediction + condition_scale * (conditional_prediction - unconditional_prediction)
predicted_noise = unconditional_prediction + condition_scale * (
conditional_prediction - unconditional_prediction
)
x = x.narrow(dim=1, start=0, length=4) # support > 4 channels for inpainting

if self.has_self_attention_guidance():
noise += self.compute_self_attention_guidance(
predicted_noise += self.compute_self_attention_guidance(
x=x, noise=unconditional_prediction, step=step, clip_text_embedding=clip_text_embedding, **kwargs
)

return self.scheduler(x, noise=noise, step=step)
return self.scheduler(x, predicted_noise=predicted_noise, step=step)

def structural_copy(self: TLatentDiffusionModel) -> TLatentDiffusionModel:
return self.__class__(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def _generate_timesteps(self) -> Tensor:
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio + 1
return timesteps.flip(0)

def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
assert self.first_inference_step <= step < self.num_inference_steps, "invalid step {step}"

timestep, previous_timestep = (
Expand All @@ -55,13 +55,13 @@ def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | N
else self.cumulative_scale_factors[0]
),
)
predicted_x = (x - sqrt(1 - current_scale_factor**2) * noise) / current_scale_factor
predicted_x = (x - sqrt(1 - current_scale_factor**2) * predicted_noise) / current_scale_factor
noise_factor = sqrt(1 - previous_scale_factor**2)

# Do not add noise at the last step to avoid visual artifacts.
if step == self.num_inference_steps - 1:
noise_factor = 0

denoised_x = previous_scale_factor * predicted_x + noise_factor * noise
denoised_x = previous_scale_factor * predicted_x + noise_factor * predicted_noise

return denoised_x
Original file line number Diff line number Diff line change
Expand Up @@ -35,5 +35,5 @@ def _generate_timesteps(self) -> Tensor:
timesteps = arange(start=0, end=self.num_inference_steps, step=1, device=self.device) * step_ratio
return timesteps.flip(0)

def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
raise NotImplementedError
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def multistep_dpm_solver_second_order_update(self, x: Tensor, step: int) -> Tens
)
return denoised_x

def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Represents one step of the backward diffusion process that iteratively denoises the input data `x`.
Expand All @@ -118,7 +118,7 @@ def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | N

current_timestep = self.timesteps[step]
scale_factor, noise_ratio = self.cumulative_scale_factors[current_timestep], self.noise_std[current_timestep]
estimated_denoised_data = (x - noise_ratio * noise) / scale_factor
estimated_denoised_data = (x - noise_ratio * predicted_noise) / scale_factor
self.estimated_data.append(estimated_denoised_data)

if step == self.first_inference_step or (self.last_step_first_order and step == self.num_inference_steps - 1):
Expand Down
10 changes: 6 additions & 4 deletions src/refiners/foundationals/latent_diffusion/schedulers/euler.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def scale_model_input(self, x: Tensor, step: int) -> Tensor:
def __call__(
self,
x: Tensor,
noise: Tensor,
predicted_noise: Tensor,
step: int,
generator: Generator | None = None,
s_churn: float = 0.0,
Expand All @@ -72,13 +72,15 @@ def __call__(

gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0

alt_noise = torch.randn(noise.shape, generator=generator, device=noise.device, dtype=noise.dtype)
eps = alt_noise * s_noise
noise = torch.randn(
predicted_noise.shape, generator=generator, device=predicted_noise.device, dtype=predicted_noise.dtype
)
eps = noise * s_noise
sigma_hat = sigma * (gamma + 1)
if gamma > 0:
x = x + eps * (sigma_hat**2 - sigma**2) ** 0.5

predicted_x = x - sigma_hat * noise
predicted_x = x - sigma_hat * predicted_noise

# 1st order Euler
derivative = (x - predicted_x) / sigma_hat
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,9 @@ def __init__(
self.timesteps = self._generate_timesteps()

@abstractmethod
def __call__(self, x: Tensor, noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
def __call__(self, x: Tensor, predicted_noise: Tensor, step: int, generator: Generator | None = None) -> Tensor:
"""
Applies a step of the diffusion process to the input tensor `x` using the provided `noise` and `timestep`.
Applies a step of the diffusion process to the input tensor `x` using the provided `predicted_noise` and `timestep`.
This method should be overridden by subclasses to implement the specific diffusion process.
"""
Expand Down
18 changes: 9 additions & 9 deletions tests/foundationals/latent_diffusion/test_schedulers.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,11 @@ def test_dpm_solver_diffusers(n_steps: int, last_step_first_order: bool):
refiners_scheduler = DPMSolver(num_inference_steps=n_steps, last_step_first_order=last_step_first_order)

sample = randn(1, 3, 32, 32)
noise = randn(1, 3, 32, 32)
predicted_noise = randn(1, 3, 32, 32)

for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)
assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"


Expand All @@ -60,11 +60,11 @@ def test_ddim_diffusers():
refiners_scheduler = DDIM(num_inference_steps=30)

sample = randn(1, 4, 32, 32)
noise = randn(1, 4, 32, 32)
predicted_noise = randn(1, 4, 32, 32)

for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)

assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"

Expand All @@ -86,15 +86,15 @@ def test_euler_diffusers():
refiners_scheduler = EulerScheduler(num_inference_steps=30)

sample = randn(1, 4, 32, 32)
noise = randn(1, 4, 32, 32)
predicted_noise = randn(1, 4, 32, 32)

ref_init_noise_sigma = diffusers_scheduler.init_noise_sigma # type: ignore
assert isinstance(ref_init_noise_sigma, Tensor)
assert isclose(ref_init_noise_sigma, refiners_scheduler.init_noise_sigma), "init_noise_sigma differ"

for step, timestep in enumerate(diffusers_scheduler.timesteps):
diffusers_output = cast(Tensor, diffusers_scheduler.step(noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, noise=noise, step=step)
diffusers_output = cast(Tensor, diffusers_scheduler.step(predicted_noise, timestep, sample).prev_sample) # type: ignore
refiners_output = refiners_scheduler(x=sample, predicted_noise=predicted_noise, step=step)

assert allclose(diffusers_output, refiners_output, rtol=0.01), f"outputs differ at step {step}"

Expand Down

0 comments on commit 12a5439

Please sign in to comment.