From 96aa40b3e9068ef40063b834b2410c6352277afe Mon Sep 17 00:00:00 2001 From: Cui-yshoho Date: Thu, 7 Nov 2024 14:54:27 +0800 Subject: [PATCH] fix(diffusers/schedulers): fix some bug in schedulers --- mindone/diffusers/__init__.py | 1 + .../diffusers/schedulers/scheduling_edm_euler.py | 16 ++++++++++------ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/mindone/diffusers/__init__.py b/mindone/diffusers/__init__.py index 091c35ed57..eac4836f7a 100644 --- a/mindone/diffusers/__init__.py +++ b/mindone/diffusers/__init__.py @@ -162,6 +162,7 @@ "DPMSolverMultistepScheduler", "DPMSolverMultistepInverseScheduler", "DPMSolverSinglestepScheduler", + "EDMDPMSolverMultistepScheduler", "EDMEulerScheduler", "EulerAncestralDiscreteScheduler", "EulerDiscreteScheduler", diff --git a/mindone/diffusers/schedulers/scheduling_edm_euler.py b/mindone/diffusers/schedulers/scheduling_edm_euler.py index f359339bd1..439742a6de 100644 --- a/mindone/diffusers/schedulers/scheduling_edm_euler.py +++ b/mindone/diffusers/schedulers/scheduling_edm_euler.py @@ -249,18 +249,22 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps - if (schedule_timesteps == timestep).sum() > 1: - pos = 1 - else: - pos = 0 + index_candidates_num = (schedule_timesteps == timestep).sum() + if index_candidates_num == 0: + step_index = len(self.timesteps) - 1 # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - indices = (schedule_timesteps == timestep).nonzero() + else: + if index_candidates_num > 1: + pos = 1 + else: + pos = 0 + step_index = int((schedule_timesteps == timestep).nonzero()[pos]) - return int(indices[pos]) + return step_index # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index def _init_step_index(self, timestep):