Skip to content

Commit

Permalink
Fix callback_outputs is None error with callback on step end.
Browse files Browse the repository at this point in the history
  • Loading branch information
Skquark committed Mar 8, 2024
1 parent 28190fd commit 943fc6a
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion src/diffusers/pipelines/pia/pipeline_pia.py
Original file line number Diff line number Diff line change
Expand Up @@ -1009,7 +1009,7 @@ def __call__(
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
if callback_outputs is None:
if callback_outputs is not None:
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,10 @@ def __call__(
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if callback_outputs is not None:
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

if output_type not in ["pt", "np", "pil", "latent"]:
raise ValueError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -593,10 +593,10 @@ def __call__(
for k in callback_on_step_end_tensor_inputs:
callback_kwargs[k] = locals()[k]
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
if callback_outputs is not None:
latents = callback_outputs.pop("latents", latents)
prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

# Offload all models
self.maybe_free_model_hooks()
Expand Down

0 comments on commit 943fc6a

Please sign in to comment.