Added callback_on_step_end to AuraFlow.
Skquark committed Jul 13, 2024
1 parent b3fefb8 commit 443a49c
Showing 1 changed file with 26 additions and 1 deletion.
src/diffusers/pipelines/aura_flow/pipeline_aura_flow.py (27 changes: 26 additions & 1 deletion)
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import inspect
-from typing import List, Optional, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import T5Tokenizer, UMT5EncoderModel

+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import VaeImageProcessor
 from ...models import AuraFlowTransformer2DModel, AutoencoderKL
 from ...models.attention_processor import AttnProcessor2_0, FusedAttnProcessor2_0, XFormersAttnProcessor
@@ -154,10 +155,18 @@ def check_inputs(
         negative_prompt_embeds=None,
         prompt_attention_mask=None,
         negative_prompt_attention_mask=None,
+        callback_on_step_end_tensor_inputs=None,
     ):
         if height % 8 != 0 or width % 8 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

+        if callback_on_step_end_tensor_inputs is not None and not all(
+            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
+        ):
+            raise ValueError(
+                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
+            )
+
         if prompt is not None and prompt_embeds is not None:
             raise ValueError(
                 f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
@@ -402,6 +411,10 @@ def __call__(
         max_sequence_length: int = 256,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
+        callback_on_step_end: Optional[
+            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
+        ] = None,
+        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
     ) -> Union[ImagePipelineOutput, Tuple]:
         r"""
         Function invoked when calling the pipeline for generation.
@@ -468,6 +481,9 @@ def __call__(
                 If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is returned
                 where the first element is a list with the generated images.
         """
+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
         # 1. Check inputs. Raise error if not correct
         height = height or self.transformer.config.sample_size * self.vae_scale_factor
         width = width or self.transformer.config.sample_size * self.vae_scale_factor
@@ -481,6 +497,7 @@ def __call__(
             negative_prompt_embeds,
             prompt_attention_mask,
             negative_prompt_attention_mask,
+            callback_on_step_end_tensor_inputs,
         )

         # 2. Determine batch size.
@@ -567,6 +584,14 @@ def __call__(
                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]

+                if callback_on_step_end is not None:
+                    callback_kwargs = {}
+                    for k in callback_on_step_end_tensor_inputs:
+                        callback_kwargs[k] = locals()[k]
+                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
+                    if callback_outputs is not None:
+                        latents = callback_outputs.pop("latents", latents)
+
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
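For context, the new arguments are consumed at the end of each denoising step: the pipeline collects the tensors named in callback_on_step_end_tensor_inputs, passes them to callback_on_step_end(pipeline, step_index, timestep, callback_kwargs), and, if the callback returns a dict, picks an updated "latents" entry back up. A minimal usage sketch (not part of this commit; the checkpoint id, prompt, and callback body are illustrative only, and a CUDA device is assumed):

import torch
from diffusers import AuraFlowPipeline

def log_latents(pipe, step, timestep, callback_kwargs):
    # Inspect (or modify) the requested tensors; here we only log a statistic.
    latents = callback_kwargs["latents"]
    print(f"step {step} (t={timestep}): latents mean = {latents.mean().item():.4f}")
    # Returning the dict is optional; returning None leaves the latents untouched.
    return callback_kwargs

pipe = AuraFlowPipeline.from_pretrained("fal/AuraFlow", torch_dtype=torch.float16).to("cuda")
image = pipe(
    "a watercolor painting of a lighthouse at dusk",
    num_inference_steps=28,
    callback_on_step_end=log_latents,
    callback_on_step_end_tensor_inputs=["latents"],  # must be a subset of pipe._callback_tensor_inputs
).images[0]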
