diff --git a/mindone/diffusers/__init__.py b/mindone/diffusers/__init__.py index 091c35ed57..eac4836f7a 100644 --- a/mindone/diffusers/__init__.py +++ b/mindone/diffusers/__init__.py @@ -162,6 +162,7 @@ "DPMSolverMultistepScheduler", "DPMSolverMultistepInverseScheduler", "DPMSolverSinglestepScheduler", + "EDMDPMSolverMultistepScheduler", "EDMEulerScheduler", "EulerAncestralDiscreteScheduler", "EulerDiscreteScheduler", diff --git a/mindone/diffusers/schedulers/scheduling_edm_euler.py b/mindone/diffusers/schedulers/scheduling_edm_euler.py index f359339bd1..439742a6de 100644 --- a/mindone/diffusers/schedulers/scheduling_edm_euler.py +++ b/mindone/diffusers/schedulers/scheduling_edm_euler.py @@ -249,18 +249,22 @@ def index_for_timestep(self, timestep, schedule_timesteps=None): if schedule_timesteps is None: schedule_timesteps = self.timesteps - if (schedule_timesteps == timestep).sum() > 1: - pos = 1 - else: - pos = 0 + index_candidates_num = (schedule_timesteps == timestep).sum() + if index_candidates_num == 0: + step_index = len(self.timesteps) - 1 # The sigma index that is taken for the **very** first `step` # is always the second index (or the last index if there is only 1) # This way we can ensure we don't accidentally skip a sigma in # case we start in the middle of the denoising schedule (e.g. for image-to-image) - indices = (schedule_timesteps == timestep).nonzero() + else: + if index_candidates_num > 1: + pos = 1 + else: + pos = 0 + step_index = int((schedule_timesteps == timestep).nonzero()[pos]) - return int(indices[pos]) + return step_index # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index def _init_step_index(self, timestep):