diff --git a/docs/source/en/using-diffusers/write_own_pipeline.md b/docs/source/en/using-diffusers/write_own_pipeline.md index e4de0d2f0fbe..8bd478e4ec10 100644 --- a/docs/source/en/using-diffusers/write_own_pipeline.md +++ b/docs/source/en/using-diffusers/write_own_pipeline.md @@ -273,7 +273,7 @@ Lastly, convert the image to a `PIL.Image` to see your generated image! ```py >>> image = (image / 2 + 0.5).clamp(0, 1).squeeze() >>> image = (image.permute(1, 2, 0) * 255).to(torch.uint8).cpu().numpy() ->>> images = (image * 255).round().astype("uint8") +>>> image = (image * 255).round().astype("uint8") >>> image = Image.fromarray(image) >>> image ``` diff --git a/examples/community/imagic_stable_diffusion.py b/examples/community/imagic_stable_diffusion.py index 6a612c74aa08..2fd36285b4b7 100644 --- a/examples/community/imagic_stable_diffusion.py +++ b/examples/community/imagic_stable_diffusion.py @@ -370,8 +370,9 @@ def __call__( r""" Function invoked when calling the pipeline for generation. Args: - prompt (`str` or `List[str]`): - The prompt or prompts to guide the image generation. + alpha (`float`, *optional*, defaults to 1.2): + The interpolation factor between the original and optimized text embeddings. A value closer to 0 + will resemble the original input image. height (`int`, *optional*, defaults to 512): The height in pixels of the generated image. width (`int`, *optional*, defaults to 512): @@ -385,16 +386,9 @@ def __call__( Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality. - eta (`float`, *optional*, defaults to 0.0): - Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to - [`schedulers.DDIMScheduler`], will be ignored for others. generator (`torch.Generator`, *optional*): A [torch generator](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make generation deterministic. - latents (`torch.FloatTensor`, *optional*): - Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image - generation. Can be used to tweak the same generation with different prompts. If not provided, a latents - tensor will ge generated by sampling using the supplied random `generator`. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generate image. Choose between [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `nd.array`. @@ -407,6 +401,9 @@ def __call__( callback_steps (`int`, *optional*, defaults to 1): The frequency at which the `callback` function will be called. If not specified, the callback will be called at every step. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. Returns: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. 
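For context on the reordered `alpha` docstring above: Imagic interpolates between the text embeddings optimized to reconstruct the input image and the embeddings of the edit prompt. A minimal sketch of that step, assuming the two embedding tensors are already computed (the names here are illustrative, not the pipeline's API):

```py
import torch

# Hypothetical precomputed embeddings, shaped like CLIP text embeddings.
optimized_embeddings = torch.randn(1, 77, 768)  # tuned to reconstruct the input image
target_embeddings = torch.randn(1, 77, 768)  # encoded from the edit prompt

# alpha = 0 keeps the optimized embeddings, so the output resembles the
# original image; larger values move the result toward the edit prompt.
alpha = 1.2
text_embeddings = alpha * target_embeddings + (1 - alpha) * optimized_embeddings
```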
diff --git a/examples/community/stable_diffusion_reference.py b/examples/community/stable_diffusion_reference.py index 924548b35ca3..1b02831ad19c 100644 --- a/examples/community/stable_diffusion_reference.py +++ b/examples/community/stable_diffusion_reference.py @@ -31,7 +31,7 @@ torch_dtype=torch.float16 ).to('cuda:0') - >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe_controlnet.scheduler.config) + >>> pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config) >>> result_img = pipe(ref_image=input_image, prompt="1girl", diff --git a/examples/community/stable_diffusion_tensorrt_inpaint.py b/examples/community/stable_diffusion_tensorrt_inpaint.py index 0c97b23fc5fd..9ace697cc7a6 100755 --- a/examples/community/stable_diffusion_tensorrt_inpaint.py +++ b/examples/community/stable_diffusion_tensorrt_inpaint.py @@ -1011,7 +1011,7 @@ def __call__( """ self.generator = generator self.denoising_steps = num_inference_steps - self.guidance_scale = guidance_scale + self._guidance_scale = guidance_scale # Pre-compute latent input scales and linear multistep coefficients self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device) diff --git a/examples/community/stable_diffusion_tensorrt_txt2img.py b/examples/community/stable_diffusion_tensorrt_txt2img.py index 548dad984b6d..54661d66a2eb 100755 --- a/examples/community/stable_diffusion_tensorrt_txt2img.py +++ b/examples/community/stable_diffusion_tensorrt_txt2img.py @@ -882,7 +882,7 @@ def __call__( """ self.generator = generator self.denoising_steps = num_inference_steps - self.guidance_scale = guidance_scale + self._guidance_scale = guidance_scale # Pre-compute latent input scales and linear multistep coefficients self.scheduler.set_timesteps(self.denoising_steps, device=self.torch_device) diff --git a/examples/text_to_image/train_text_to_image_lora.py b/examples/text_to_image/train_text_to_image_lora.py index 39590fa8666b..34838500748e 100644 --- a/examples/text_to_image/train_text_to_image_lora.py +++ b/examples/text_to_image/train_text_to_image_lora.py @@ -58,12 +58,17 @@ def save_model_card( - repo_id: str, images: list = None, base_model: str = None, dataset_name: str = None, repo_folder: str = None + repo_id: str, + images: list = None, + base_model: str = None, + dataset_name: str = None, + repo_folder: str = None, ): img_str = "" - for i, image in enumerate(images): - image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f"![img_{i}](./image_{i}.png)\n" + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" model_description = f""" # LoRA text2image fine-tuning - {repo_id} diff --git a/examples/text_to_image/train_text_to_image_sdxl.py b/examples/text_to_image/train_text_to_image_sdxl.py index 04f8c3dba417..2d77e9c8bfa3 100644 --- a/examples/text_to_image/train_text_to_image_sdxl.py +++ b/examples/text_to_image/train_text_to_image_sdxl.py @@ -74,9 +74,10 @@ def save_model_card( vae_path: str = None, ): img_str = "" - for i, image in enumerate(images): - image.save(os.path.join(repo_folder, f"image_{i}.png")) - img_str += f"![img_{i}](./image_{i}.png)\n" + if images is not None: + for i, image in enumerate(images): + image.save(os.path.join(repo_folder, f"image_{i}.png")) + img_str += f"![img_{i}](./image_{i}.png)\n" model_description = f""" # Text-to-image finetuning - {repo_id} diff --git a/src/diffusers/loaders/lora.py b/src/diffusers/loaders/lora.py index 
55ccdda94973..0870f3a67a3d 100644 --- a/src/diffusers/loaders/lora.py +++ b/src/diffusers/loaders/lora.py @@ -1192,7 +1192,7 @@ def set_lora_device(self, adapter_names: List[str], device: Union[torch.device, class StableDiffusionXLLoraLoaderMixin(LoraLoaderMixin): """This class overrides `LoraLoaderMixin` with LoRA loading/saving code that's specific to SDXL""" - # Overrride to properly handle the loading and unloading of the additional text encoder. + # Override to properly handle the loading and unloading of the additional text encoder. def load_lora_weights( self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], diff --git a/src/diffusers/loaders/textual_inversion.py b/src/diffusers/loaders/textual_inversion.py index e5aeea488407..aaaf4b68bb5f 100644 --- a/src/diffusers/loaders/textual_inversion.py +++ b/src/diffusers/loaders/textual_inversion.py @@ -215,7 +215,7 @@ def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer): embedding = state_dict["string_to_param"]["*"] else: raise ValueError( - f"Loaded state dictonary is incorrect: {state_dict}. \n\n" + f"Loaded state dictionary is incorrect: {state_dict}. \n\n" "Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`" " input key." ) diff --git a/src/diffusers/models/unets/unet_2d_blocks.py b/src/diffusers/models/unets/unet_2d_blocks.py index 73099eafc7b4..9ebf6982ca82 100644 --- a/src/diffusers/models/unets/unet_2d_blocks.py +++ b/src/diffusers/models/unets/unet_2d_blocks.py @@ -249,6 +249,81 @@ def get_down_block( raise ValueError(f"{down_block_type} does not exist.") +def get_mid_block( + mid_block_type: str, + temb_channels: int, + in_channels: int, + resnet_eps: float, + resnet_act_fn: str, + resnet_groups: int, + output_scale_factor: float = 1.0, + transformer_layers_per_block: int = 1, + num_attention_heads: Optional[int] = None, + cross_attention_dim: Optional[int] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + mid_block_only_cross_attention: bool = False, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + attention_type: str = "default", + resnet_skip_time_act: bool = False, + cross_attention_norm: Optional[str] = None, + attention_head_dim: Optional[int] = 1, + dropout: float = 0.0, +): + if mid_block_type == "UNetMidBlock2DCrossAttn": + return UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block, + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + num_attention_heads=num_attention_heads, + resnet_groups=resnet_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + return UNetMidBlock2DSimpleCrossAttn( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + cross_attention_dim=cross_attention_dim, + attention_head_dim=attention_head_dim, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + 
only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlock2D": + return UNetMidBlock2D( + in_channels=in_channels, + temb_channels=temb_channels, + dropout=dropout, + num_layers=0, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + output_scale_factor=output_scale_factor, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + return None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + def get_up_block( up_block_type: str, num_layers: int, diff --git a/src/diffusers/models/unets/unet_2d_condition.py b/src/diffusers/models/unets/unet_2d_condition.py index 05addc2481f7..fee7a34fb216 100644 --- a/src/diffusers/models/unets/unet_2d_condition.py +++ b/src/diffusers/models/unets/unet_2d_condition.py @@ -44,10 +44,8 @@ ) from ..modeling_utils import ModelMixin from .unet_2d_blocks import ( - UNetMidBlock2D, - UNetMidBlock2DCrossAttn, - UNetMidBlock2DSimpleCrossAttn, get_down_block, + get_mid_block, get_up_block, ) @@ -239,44 +237,18 @@ def __init__( num_attention_heads = num_attention_heads or attention_head_dim # Check inputs - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - - if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." 
- ) - if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: - for layer_number_per_block in transformer_layers_per_block: - if isinstance(layer_number_per_block, list): - raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + self._check_config( + down_block_types=down_block_types, + up_block_types=up_block_types, + only_cross_attention=only_cross_attention, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + cross_attention_dim=cross_attention_dim, + transformer_layers_per_block=transformer_layers_per_block, + reverse_transformer_layers_per_block=reverse_transformer_layers_per_block, + attention_head_dim=attention_head_dim, + num_attention_heads=num_attention_heads, + ) # input conv_in_padding = (conv_in_kernel - 1) // 2 @@ -285,23 +257,13 @@ def __init__( ) # time - if time_embedding_type == "fourier": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 - if time_embed_dim % 2 != 0: - raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") - self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos - ) - timestep_input_dim = time_embed_dim - elif time_embedding_type == "positional": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - else: - raise ValueError( - f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." - ) + time_embed_dim, timestep_input_dim = self._set_time_proj( + time_embedding_type, + block_out_channels=block_out_channels, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + time_embedding_dim=time_embedding_dim, + ) self.time_embedding = TimestepEmbedding( timestep_input_dim, @@ -311,96 +273,33 @@ def __init__( cond_proj_dim=time_cond_proj_dim, ) - if encoder_hid_dim_type is None and encoder_hid_dim is not None: - encoder_hid_dim_type = "text_proj" - self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) - logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") - - if encoder_hid_dim is None and encoder_hid_dim_type is not None: - raise ValueError( - f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." - ) - - if encoder_hid_dim_type == "text_proj": - self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) - elif encoder_hid_dim_type == "text_image_proj": - # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` - self.encoder_hid_proj = TextImageProjection( - text_embed_dim=encoder_hid_dim, - image_embed_dim=cross_attention_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - self.encoder_hid_proj = ImageProjection( - image_embed_dim=encoder_hid_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type is not None: - raise ValueError( - f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." 
- ) - else: - self.encoder_hid_proj = None + self._set_encoder_hid_proj( + encoder_hid_dim_type, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + ) # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. - self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif class_embed_type == "simple_projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" - ) - self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - if addition_embed_type == "text": - if encoder_hid_dim is not None: - text_time_embedding_from_dim = encoder_hid_dim - else: - text_time_embedding_from_dim = cross_attention_dim + self._set_class_embedding( + class_embed_type, + act_fn=act_fn, + num_class_embeds=num_class_embeds, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + timestep_input_dim=timestep_input_dim, + ) - self.add_embedding = TextTimeEmbedding( - text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads - ) - elif addition_embed_type == "text_image": - # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` - self.add_embedding = TextImageTimeEmbedding( - text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim - ) - elif addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif addition_embed_type == "image": - # Kandinsky 2.2 - self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type == "image_hint": - # Kandinsky 2.2 ControlNet - self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + self._set_add_embedding( + addition_embed_type, + addition_embed_type_num_heads=addition_embed_type_num_heads, + addition_time_embed_dim=addition_time_embed_dim, + cross_attention_dim=cross_attention_dim, + encoder_hid_dim=encoder_hid_dim, + flip_sin_to_cos=flip_sin_to_cos, + freq_shift=freq_shift, + projection_class_embeddings_input_dim=projection_class_embeddings_input_dim, + time_embed_dim=time_embed_dim, + ) if time_embedding_act_fn is None: self.time_embed_act = None @@ -478,57 +377,28 @@ def __init__( self.down_blocks.append(down_block) # mid - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - dropout=dropout, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim[-1], - num_attention_heads=num_attention_heads[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": - self.mid_block = UNetMidBlock2DSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - dropout=dropout, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim[-1], - attention_head_dim=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - only_cross_attention=mid_block_only_cross_attention, - cross_attention_norm=cross_attention_norm, - ) - elif mid_block_type == "UNetMidBlock2D": - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - dropout=dropout, - num_layers=0, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - add_attention=False, - ) - elif mid_block_type is None: - self.mid_block = None - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") + self.mid_block = get_mid_block( + mid_block_type, + 
temb_channels=blocks_time_embed_dim, + in_channels=block_out_channels[-1], + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + output_scale_factor=mid_block_scale_factor, + transformer_layers_per_block=transformer_layers_per_block[-1], + num_attention_heads=num_attention_heads[-1], + cross_attention_dim=cross_attention_dim[-1], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + mid_block_only_cross_attention=mid_block_only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[-1], + dropout=dropout, + ) # count how many layers upsample the images self.num_upsamplers = 0 @@ -599,14 +469,214 @@ def __init__( self.conv_act = get_activation(act_fn) else: - self.conv_norm_out = None - self.conv_act = None + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + self._set_pos_net_if_use_gligen(attention_type=attention_type, cross_attention_dim=cross_attention_dim) + + def _check_config( + self, + down_block_types: Tuple[str], + up_block_types: Tuple[str], + only_cross_attention: Union[bool, Tuple[bool]], + block_out_channels: Tuple[int], + layers_per_block: Union[int, Tuple[int]], + cross_attention_dim: Union[int, Tuple[int]], + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]], + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]], + attention_head_dim: int, + num_attention_heads: Optional[Union[int, Tuple[int]]], + ): + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
+ ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide `reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + def _set_time_proj( + self, + time_embedding_type: str, + block_out_channels: Tuple[int], + flip_sin_to_cos: bool, + freq_shift: float, + time_embedding_dim: Optional[int], + ) -> Tuple[int, int]: + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." + ) + + return time_embed_dim, timestep_input_dim + + def _set_encoder_hid_proj( + self, + encoder_hid_dim_type: Optional[str], + cross_attention_dim: Union[int, Tuple[int]], + encoder_hid_dim: Optional[int], + ): + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1) + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
+ ) + else: + self.encoder_hid_proj = None + + def _set_class_embedding( + self, + class_embed_type: Optional[str], + act_fn: str, + num_class_embeds: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + timestep_input_dim: int, + ): + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. + self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None - conv_out_padding = (conv_out_kernel - 1) // 2 - self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding - ) + def _set_add_embedding( + self, + addition_embed_type: str, + addition_embed_type_num_heads: int, + addition_time_embed_dim: Optional[int], + flip_sin_to_cos: bool, + freq_shift: float, + cross_attention_dim: Optional[int], + encoder_hid_dim: Optional[int], + projection_class_embeddings_input_dim: Optional[int], + time_embed_dim: int, + ): + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. 
To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1) + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'.") + def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int): if attention_type in ["gated", "gated-text-image"]: positive_len = 768 if isinstance(cross_attention_dim, int): @@ -840,6 +910,130 @@ def unload_lora(self): if hasattr(module, "set_lora_layer"): module.set_lora_layer(None) + def get_time_embed( + self, sample: torch.Tensor, timestep: Union[torch.Tensor, float, int] + ) -> Optional[torch.Tensor]: + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + return t_emb + + def get_class_embed(self, sample: torch.Tensor, class_labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]: + class_emb = None + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this.
+ class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + return class_emb + + def get_aug_embed( + self, emb: torch.Tensor, encoder_hidden_states: torch.Tensor, added_cond_kwargs: Dict + ) -> Optional[torch.Tensor]: + aug_emb = None + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb = self.add_embedding(image_embs, hint) + return aug_emb + + def process_encoder_hidden_states(self, encoder_hidden_states: torch.Tensor, added_cond_kwargs) -> torch.Tensor: + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states,
image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + return encoder_hidden_states + def forward( self, sample: torch.FloatTensor, @@ -952,96 +1146,22 @@ def forward( sample = 2 * sample - 1.0 # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - + t_emb = self.get_time_embed(sample=sample, timestep=timestep) emb = self.time_embedding(t_emb, timestep_cond) aug_emb = None - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # there might be better ways to encapsulate this. 
- class_labels = class_labels.to(dtype=sample.dtype) - - class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) - + class_emb = self.get_class_embed(sample=sample, class_labels=class_labels) + if class_emb is not None: if self.config.class_embeddings_concat: emb = torch.cat([emb, class_emb], dim=-1) else: emb = emb + class_emb - if self.config.addition_embed_type == "text": - aug_emb = self.add_embedding(encoder_hidden_states) - elif self.config.addition_embed_type == "text_image": - # Kandinsky 2.1 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) - - image_embs = added_cond_kwargs.get("image_embeds") - text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) - aug_emb = self.add_embedding(text_embs, image_embs) - elif self.config.addition_embed_type == "text_time": - # SDXL - style - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(emb.dtype) - aug_emb = self.add_embedding(add_embeds) - elif self.config.addition_embed_type == "image": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) - image_embs = added_cond_kwargs.get("image_embeds") - aug_emb = self.add_embedding(image_embs) - elif self.config.addition_embed_type == "image_hint": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" - ) - image_embs = added_cond_kwargs.get("image_embeds") - hint = added_cond_kwargs.get("hint") - aug_emb, hint = self.add_embedding(image_embs, hint) + aug_emb = self.get_aug_embed( + emb=emb, encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) + if self.config.addition_embed_type == "image_hint": + aug_emb, hint = aug_emb sample = torch.cat([sample, hint], dim=1) emb = emb + aug_emb if aug_emb is not None else emb @@ -1049,33 +1169,9 @@ def forward( if self.time_embed_act is not None: emb = self.time_embed_act(emb) - if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": - # Kadinsky 2.1 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - 
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - image_embeds = added_cond_kwargs.get("image_embeds") - image_embeds = self.encoder_hid_proj(image_embeds) - encoder_hidden_states = (encoder_hidden_states, image_embeds) + encoder_hidden_states = self.process_encoder_hidden_states( + encoder_hidden_states=encoder_hidden_states, added_cond_kwargs=added_cond_kwargs + ) # 2. pre-process sample = self.conv_in(sample) diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py index c46dadb53e6a..c794bd00ce85 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff.py @@ -797,7 +797,11 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): diff --git a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py index f5ada63dfdfc..4b5cc12b1265 100644 --- a/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +++ b/src/diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py @@ -441,6 +441,41 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds + def prepare_ip_adapter_image_embeds( + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt + ): + if ip_adapter_image_embeds is None: + if not isinstance(ip_adapter_image, list): + ip_adapter_image = [ip_adapter_image] + + if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers): + raise ValueError( + f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters." 
+ ) + + image_embeds = [] + for single_ip_adapter_image, image_proj_layer in zip( + ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers + ): + output_hidden_state = not isinstance(image_proj_layer, ImageProjection) + single_image_embeds, single_negative_image_embeds = self.encode_image( + single_ip_adapter_image, device, 1, output_hidden_state + ) + single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0) + single_negative_image_embeds = torch.stack( + [single_negative_image_embeds] * num_images_per_prompt, dim=0 + ) + + if self.do_classifier_free_guidance: + single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) + single_image_embeds = single_image_embeds.to(device) + + image_embeds.append(single_image_embeds) + else: + image_embeds = ip_adapter_image_embeds + return image_embeds + # Copied from diffusers.pipelines.text_to_video_synthesis/pipeline_text_to_video_synth.TextToVideoSDPipeline.decode_latents def decode_latents(self, latents): latents = 1 / self.vae.config.scaling_factor * latents @@ -735,6 +770,7 @@ def __call__( prompt_embeds: Optional[torch.FloatTensor] = None, negative_prompt_embeds: Optional[torch.FloatTensor] = None, ip_adapter_image: Optional[PipelineImageInput] = None, + ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None, output_type: Optional[str] = "pil", return_dict: bool = True, cross_attention_kwargs: Optional[Dict[str, Any]] = None, @@ -784,6 +820,9 @@ def __call__( not provided, `negative_prompt_embeds` are generated from the `negative_prompt` input argument. ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters. + ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*): + Pre-generated image embeddings for IP-Adapter. If not + provided, embeddings are computed from the `ip_adapter_image` input argument. output_type (`str`, *optional*, defaults to `"pil"`): The output format of the generated video. Choose between `torch.FloatTensor`, `PIL.Image` or `np.array`. @@ -870,13 +909,10 @@ def __call__( if self.do_classifier_free_guidance: prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) - if ip_adapter_image is not None: - output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True - image_embeds, negative_image_embeds = self.encode_image( - ip_adapter_image, device, num_videos_per_prompt, output_hidden_state + if ip_adapter_image is not None or ip_adapter_image_embeds is not None: + image_embeds = self.prepare_ip_adapter_image_embeds( + ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_videos_per_prompt ) - if self.do_classifier_free_guidance: - image_embeds = torch.cat([negative_image_embeds, image_embeds]) # 4. Prepare timesteps timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps) @@ -902,7 +938,11 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7. 
Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1 for free_init_iter in range(num_free_init_iters): diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py index b186ec5cab2f..c8af65c78505 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet.py @@ -1206,7 +1206,11 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) # 7.2 Create tensor stating which controlnets to keep controlnet_keep = [] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py index 10fc4384de29..377af876aaeb 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py @@ -1206,7 +1206,11 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) # 7.2 Create tensor stating which controlnets to keep controlnet_keep = [] diff --git a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py index 35a4ae67c9be..b23f78a8b3fd 100644 --- a/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +++ b/src/diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py @@ -1495,7 +1495,11 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) # 7.2 Create tensor stating which controlnets to keep controlnet_keep = [] diff --git a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py index 60707cc1e2f7..97c533be5864 100644 --- a/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +++ b/src/diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py @@ -268,7 +268,6 @@ def forward( return objs -# Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat class UNetFlatConditionModel(ModelMixin, ConfigMixin): r""" A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample diff --git 
a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py index acaeab1c6f50..f914020dd505 100644 --- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py @@ -477,8 +477,9 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state return image_embeds, uncond_image_embeds + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds def prepare_ip_adapter_image_embeds( - self, ip_adapter_image, ip_adapter_image_embeds, do_classifier_free_guidance, device, num_images_per_prompt + self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt ): if ip_adapter_image_embeds is None: if not isinstance(ip_adapter_image, list): @@ -502,7 +503,7 @@ def prepare_ip_adapter_image_embeds( [single_negative_image_embeds] * num_images_per_prompt, dim=0 ) - if do_classifier_free_guidance: + if self.do_classifier_free_guidance: single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds]) single_image_embeds = single_image_embeds.to(device) @@ -699,6 +700,10 @@ def cross_attention_kwargs(self): def clip_skip(self): return self._clip_skip + @property + def do_classifier_free_guidance(self): + return False + @property def num_timesteps(self): return self._num_timesteps @@ -845,7 +850,7 @@ def __call__( if ip_adapter_image is not None or ip_adapter_image_embeds is not None: image_embeds = self.prepare_ip_adapter_image_embeds( - ip_adapter_image, ip_adapter_image_embeds, False, device, batch_size * num_images_per_prompt + ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt ) # 3. Encode input prompt @@ -860,7 +865,7 @@ def __call__( prompt, device, num_images_per_prompt, - False, + self.do_classifier_free_guidance, negative_prompt=None, prompt_embeds=prompt_embeds, negative_prompt_embeds=None, @@ -906,7 +911,11 @@ def __call__( extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None) # 7.1 Add image embeds for IP-Adapter - added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None + added_cond_kwargs = ( + {"image_embeds": image_embeds} + if ip_adapter_image is not None or ip_adapter_image_embeds is not None + else None + ) # 8. 
diff --git a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
index 469305f248e7..967d845367d4 100644
--- a/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
+++ b/src/diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py
@@ -461,6 +461,41 @@ def encode_image(self, image, device, num_images_per_prompt, output_hidden_state
 
         return image_embeds, uncond_image_embeds
 
+    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_ip_adapter_image_embeds
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.unet.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            image_embeds = []
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
+            ):
+                output_hidden_state = not isinstance(image_proj_layer, ImageProjection)
+                single_image_embeds, single_negative_image_embeds = self.encode_image(
+                    single_ip_adapter_image, device, 1, output_hidden_state
+                )
+                single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
+                single_negative_image_embeds = torch.stack(
+                    [single_negative_image_embeds] * num_images_per_prompt, dim=0
+                )
+
+                if self.do_classifier_free_guidance:
+                    single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
+                    single_image_embeds = single_image_embeds.to(device)
+
+                image_embeds.append(single_image_embeds)
+        else:
+            image_embeds = ip_adapter_image_embeds
+        return image_embeds
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.run_safety_checker
     def run_safety_checker(self, image, device, dtype):
         if self.safety_checker is None:
@@ -590,6 +625,10 @@ def cross_attention_kwargs(self):
     def clip_skip(self):
         return self._clip_skip
 
+    @property
+    def do_classifier_free_guidance(self):
+        return False
+
     @property
     def num_timesteps(self):
         return self._num_timesteps
@@ -610,6 +649,7 @@ def __call__(
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.FloatTensor]] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -660,6 +700,9 @@ def __call__(
                 provided, text embeddings are generated from the `prompt` input argument.
             ip_adapter_image: (`PipelineImageInput`, *optional*):
                 Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.FloatTensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generated image. Choose between `PIL.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
@@ -726,12 +769,10 @@ def __call__(
         batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
-        # do_classifier_free_guidance = guidance_scale > 1.0
 
-        if ip_adapter_image is not None:
-            output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
-            image_embeds, negative_image_embeds = self.encode_image(
-                ip_adapter_image, device, num_images_per_prompt, output_hidden_state
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image, ip_adapter_image_embeds, device, batch_size * num_images_per_prompt
             )
 
         # 3. Encode input prompt
@@ -746,7 +787,7 @@ def __call__(
             prompt,
             device,
             num_images_per_prompt,
-            False,
+            self.do_classifier_free_guidance,
             negative_prompt=None,
             prompt_embeds=prompt_embeds,
             negative_prompt_embeds=None,
@@ -786,7 +827,11 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, None)
 
         # 7.1 Add image embeds for IP-Adapter
-        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+            else None
+        )
 
         # 8. LCM MultiStep Sampling Loop:
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
diff --git a/src/diffusers/pipelines/pia/pipeline_pia.py b/src/diffusers/pipelines/pia/pipeline_pia.py
index a32844762b55..59e21678b62f 100644
--- a/src/diffusers/pipelines/pia/pipeline_pia.py
+++ b/src/diffusers/pipelines/pia/pipeline_pia.py
@@ -987,7 +987,11 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7. Add image embeds for IP-Adapter
-        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+            else None
+        )
 
         # 8. Denoising loop
         num_free_init_iters = self._free_init_num_iters if self.free_init_enabled else 1
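Both LCM pipelines pin `do_classifier_free_guidance` to `False` as a property rather than threading a flag through every helper: LCM distills guidance into the model (it is applied via a guidance-scale embedding, not a second unconditional pass), and a fixed property lets helpers that are `# Copied from` Stable Diffusion, such as `prepare_ip_adapter_image_embeds`, stay textually identical while the CFG branch is simply never taken. A minimal sketch of that pattern (class and method names here are illustrative):

```py
import torch


class LCMLikePipeline:
    @property
    def do_classifier_free_guidance(self) -> bool:
        # LCM never runs a negative/unconditional branch; guidance is baked in
        # at distillation time through the guidance-scale (w) embedding.
        return False

    def maybe_add_negative_embeds(self, image_embeds, negative_image_embeds):
        # Mirrors the copied helper: the branch exists so the code can stay
        # identical to StableDiffusionPipeline, but it is dead code for LCM.
        if self.do_classifier_free_guidance:
            image_embeds = torch.cat([negative_image_embeds, image_embeds])
        return image_embeds


pipe = LCMLikePipeline()
embeds = torch.randn(1, 4)
assert pipe.maybe_add_negative_embeds(embeds, torch.zeros_like(embeds)).shape == (1, 4)
```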
diff --git a/src/diffusers/pipelines/pipeline_utils.py b/src/diffusers/pipelines/pipeline_utils.py
index 18a4b5cb346b..c92df24251f4 100644
--- a/src/diffusers/pipelines/pipeline_utils.py
+++ b/src/diffusers/pipelines/pipeline_utils.py
@@ -170,7 +170,7 @@ def is_safetensors_compatible(filenames, variant=None, passed_components=None) -
         sf_filenames.add(os.path.normpath(filename))
 
     for filename in pt_filenames:
-        # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extention = '.bam'
+        # filename = 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
         path, filename = os.path.split(filename)
         filename, extension = os.path.splitext(filename)
 
@@ -375,7 +375,7 @@ def _get_pipeline_class(
 
     if repo_id is not None and hub_revision is not None:
         # if we load the pipeline code from the Hub
-        # make sure to overwrite the `revison`
+        # make sure to overwrite the `revision`
         revision = hub_revision
 
     return get_class_from_dynamic_module(
@@ -451,7 +451,7 @@ def load_sub_model(
     )
 
     load_method_name = None
-    # retrive load method name
+    # retrieve load method name
     for class_name, class_candidate in class_candidates.items():
         if class_candidate is not None and issubclass(class_obj, class_candidate):
             load_method_name = importable_classes[class_name][1]
@@ -1897,7 +1897,7 @@ def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
         else:
             # 2. we forced `local_files_only=True` when `model_info` failed
             raise EnvironmentError(
-                f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occured"
+                f"Cannot load model {pretrained_model_name}: model is not cached locally and an error occurred"
                 " while trying to fetch metadata from the Hub. Please check out the root cause in the stacktrace"
                 " above."
             ) from model_info_call_error
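Besides the spelling fixes, the corrected comment in `is_safetensors_compatible` is worth a quick sanity check; the split it describes is exactly what the two stdlib calls do:

```py
import os

# 'foo/bar/baz.bam' -> path = 'foo/bar', filename = 'baz', extension = '.bam'
path, filename = os.path.split("foo/bar/baz.bam")   # -> ('foo/bar', 'baz.bam')
filename, extension = os.path.splitext(filename)    # -> ('baz', '.bam')
assert (path, filename, extension) == ("foo/bar", "baz", ".bam")
```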
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
index 2746c6ad43ea..5c6e67d7282b 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py
@@ -1111,7 +1111,11 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7.1 Add image embeds for IP-Adapter
-        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+            else None
+        )
 
         # 7.2 Optionally get Guidance Scale Embedding
         timestep_cond = None
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
index 2be701f1601a..bef7b256b0cb 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py
@@ -1398,7 +1398,11 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 9.1 Add image embeds for IP-Adapter
-        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+            else None
+        )
 
         # 9.2 Optionally get Guidance Scale Embedding
         timestep_cond = None
diff --git a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
index 51e6f47b83b6..3773ea6e9728 100644
--- a/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
+++ b/src/diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py
@@ -777,7 +777,11 @@ def __call__(
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7.1 Add image embeds for IP-Adapter
-        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
+        added_cond_kwargs = (
+            {"image_embeds": image_embeds}
+            if ip_adapter_image is not None or ip_adapter_image_embeds is not None
+            else None
+        )
 
         # 8. Denoising loop
         # Each denoising step also includes refinement of the latents with respect to the
diff --git a/tests/models/unets/test_models_unet_2d_condition.py b/tests/models/unets/test_models_unet_2d_condition.py
index 2d3dce7a31c2..db07b126e44c 100644
--- a/tests/models/unets/test_models_unet_2d_condition.py
+++ b/tests/models/unets/test_models_unet_2d_condition.py
@@ -62,7 +62,10 @@ def create_ip_adapter_state_dict(model):
     key_id = 1
 
     for name in model.attn_processors.keys():
-        cross_attention_dim = None if name.endswith("attn1.processor") else model.config.cross_attention_dim
+        cross_attention_dim = (
+            None if name.endswith("attn1.processor") or "motion_module" in name else model.config.cross_attention_dim
+        )
+
         if name.startswith("mid_block"):
             hidden_size = model.config.block_out_channels[-1]
         elif name.startswith("up_blocks"):
@@ -71,6 +74,7 @@ def create_ip_adapter_state_dict(model):
         elif name.startswith("down_blocks"):
             block_id = int(name[len("down_blocks.")])
             hidden_size = model.config.block_out_channels[block_id]
+
         if cross_attention_dim is not None:
             sd = IPAdapterAttnProcessor(
                 hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
diff --git a/tests/pipelines/animatediff/test_animatediff.py b/tests/pipelines/animatediff/test_animatediff.py
index 412d536c6e14..3b789e4ff0f3 100644
--- a/tests/pipelines/animatediff/test_animatediff.py
+++ b/tests/pipelines/animatediff/test_animatediff.py
@@ -18,7 +18,7 @@
 from diffusers.utils.testing_utils import numpy_cosine_similarity_distance, require_torch_gpu, slow, torch_device
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
 
 
 def to_np(tensor):
@@ -28,7 +28,7 @@ def to_np(tensor):
     return tensor
 
 
-class AnimateDiffPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class AnimateDiffPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = AnimateDiffPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS
diff --git a/tests/pipelines/animatediff/test_animatediff_video2video.py b/tests/pipelines/animatediff/test_animatediff_video2video.py
index bfb607ea507d..6cc54d97d8c6 100644
--- a/tests/pipelines/animatediff/test_animatediff_video2video.py
+++ b/tests/pipelines/animatediff/test_animatediff_video2video.py
@@ -18,7 +18,7 @@
 from diffusers.utils.testing_utils import torch_device
 
 from ..pipeline_params import TEXT_TO_IMAGE_PARAMS, VIDEO_TO_VIDEO_BATCH_PARAMS
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
 
 
 def to_np(tensor):
@@ -28,7 +28,7 @@ def to_np(tensor):
     return tensor
 
 
-class AnimateDiffVideoToVideoPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class AnimateDiffVideoToVideoPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = AnimateDiffVideoToVideoPipeline
     params = TEXT_TO_IMAGE_PARAMS
     batch_params = VIDEO_TO_VIDEO_BATCH_PARAMS
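The change to `create_ip_adapter_state_dict` matters for the AnimateDiff and PIA testers enabled above: motion modules contain temporal self-attention processors with no cross-attention dimension, so the helper must skip them (as it already skipped spatial `attn1` self-attention) when fabricating dummy IP-Adapter weights. The selection rule, restated on its own (processor names follow the AnimateDiff naming convention):

```py
def needs_ip_adapter_processor(name: str) -> bool:
    # IP-Adapter attention processors attach only to spatial cross-attention
    # ("attn2") layers; self-attention and motion-module attention are skipped.
    return not (name.endswith("attn1.processor") or "motion_module" in name)


assert needs_ip_adapter_processor("mid_block.attentions.0.transformer_blocks.0.attn2.processor")
assert not needs_ip_adapter_processor("mid_block.attentions.0.transformer_blocks.0.attn1.processor")
assert not needs_ip_adapter_processor("down_blocks.0.motion_modules.0.transformer_blocks.0.attn1.processor")
```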
diff --git a/tests/pipelines/controlnet/test_controlnet.py b/tests/pipelines/controlnet/test_controlnet.py
index 96373a1a11f4..b7839eb99638 100644
--- a/tests/pipelines/controlnet/test_controlnet.py
+++ b/tests/pipelines/controlnet/test_controlnet.py
@@ -54,6 +54,7 @@
     TEXT_TO_IMAGE_PARAMS,
 )
 from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
     PipelineKarrasSchedulerTesterMixin,
     PipelineLatentTesterMixin,
     PipelineTesterMixin,
@@ -110,7 +111,11 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
 
 
 class ControlNetPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
@@ -273,7 +278,7 @@ def test_controlnet_lcm_custom_timesteps(self):
 
 
 class StableDiffusionMultiControlNetPipelineFastTests(
-    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
 ):
     pipeline_class = StableDiffusionControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
@@ -490,7 +495,7 @@ def test_inference_multiple_prompt_input(self):
 
 
 class StableDiffusionMultiControlNetOneModelPipelineFastTests(
-    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
 ):
     pipeline_class = StableDiffusionControlNetPipeline
     params = TEXT_TO_IMAGE_PARAMS
diff --git a/tests/pipelines/controlnet/test_controlnet_img2img.py b/tests/pipelines/controlnet/test_controlnet_img2img.py
index 5e54384d14c2..89e2b3803dee 100644
--- a/tests/pipelines/controlnet/test_controlnet_img2img.py
+++ b/tests/pipelines/controlnet/test_controlnet_img2img.py
@@ -52,6 +52,7 @@
     TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
 )
 from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
     PipelineKarrasSchedulerTesterMixin,
     PipelineLatentTesterMixin,
     PipelineTesterMixin,
@@ -62,7 +63,11 @@ class ControlNetImg2ImgPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionControlNetImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
@@ -181,7 +186,7 @@ def test_inference_batch_single_identical(self):
 
 
 class StableDiffusionMultiControlNetPipelineFastTests(
-    PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin, PipelineTesterMixin, PipelineKarrasSchedulerTesterMixin, unittest.TestCase
 ):
     pipeline_class = StableDiffusionControlNetImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
diff --git a/tests/pipelines/controlnet/test_controlnet_inpaint.py b/tests/pipelines/controlnet/test_controlnet_inpaint.py
index 661fa1107af6..67e0da4de9cd 100644
--- a/tests/pipelines/controlnet/test_controlnet_inpaint.py
+++ b/tests/pipelines/controlnet/test_controlnet_inpaint.py
@@ -51,11 +51,7 @@
     TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
     TEXT_TO_IMAGE_IMAGE_PARAMS,
 )
-from ..test_pipelines_common import (
-    PipelineKarrasSchedulerTesterMixin,
-    PipelineLatentTesterMixin,
-    PipelineTesterMixin,
-)
+from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
 
 
 enable_full_determinism()
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl.py b/tests/pipelines/controlnet/test_controlnet_sdxl.py
index 5e9a6f997b20..dd566403157e 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl.py
@@ -48,6 +48,7 @@
     TEXT_TO_IMAGE_PARAMS,
 )
 from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
     PipelineKarrasSchedulerTesterMixin,
     PipelineLatentTesterMixin,
     PipelineTesterMixin,
@@ -59,6 +60,7 @@
 
 
 class StableDiffusionXLControlNetPipelineFastTests(
+    IPAdapterTesterMixin,
     PipelineLatentTesterMixin,
     PipelineKarrasSchedulerTesterMixin,
     PipelineTesterMixin,
diff --git a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
index e6ec616eefb5..7d2ba8cc28fd 100644
--- a/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
+++ b/tests/pipelines/controlnet/test_controlnet_sdxl_img2img.py
@@ -36,6 +36,7 @@
     TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
 )
 from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
     PipelineKarrasSchedulerTesterMixin,
     PipelineLatentTesterMixin,
     PipelineTesterMixin,
@@ -46,7 +47,11 @@ class ControlNetPipelineSDXLImg2ImgFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionXLControlNetImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
index 5d33b45c0973..eaf8fa2cdd59 100644
--- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
+++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models.py
@@ -20,13 +20,15 @@
 )
 
 from ..pipeline_params import TEXT_TO_IMAGE_BATCH_PARAMS, TEXT_TO_IMAGE_IMAGE_PARAMS, TEXT_TO_IMAGE_PARAMS
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
 
 
 enable_full_determinism()
 
 
-class LatentConsistencyModelPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class LatentConsistencyModelPipelineFastTests(
+    IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
     pipeline_class = LatentConsistencyModelPipeline
     params = TEXT_TO_IMAGE_PARAMS - {"negative_prompt", "negative_prompt_embeds"}
     batch_params = TEXT_TO_IMAGE_BATCH_PARAMS - {"negative_prompt"}
diff --git a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
index 5b4e2b191f53..cfd596dcd0ed 100644
--- a/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
+++ b/tests/pipelines/latent_consistency_models/test_latent_consistency_models_img2img.py
@@ -27,14 +27,14 @@
     TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS,
     TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
 )
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
 
 
 enable_full_determinism()
 
 
 class LatentConsistencyModelImg2ImgPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
 ):
     pipeline_class = LatentConsistencyModelImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width", "negative_prompt", "negative_prompt_embeds"}
diff --git a/tests/pipelines/pia/test_pia.py b/tests/pipelines/pia/test_pia.py
index 214f085e057e..2813dc70a71d 100644
--- a/tests/pipelines/pia/test_pia.py
+++ b/tests/pipelines/pia/test_pia.py
@@ -17,7 +17,7 @@
 from diffusers.utils import is_xformers_available, logging
 from diffusers.utils.testing_utils import floats_tensor, torch_device
 
-from ..test_pipelines_common import PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineTesterMixin
 
 
 def to_np(tensor):
@@ -27,7 +27,7 @@ def to_np(tensor):
     return tensor
 
 
-class PIAPipelineFastTests(PipelineTesterMixin, unittest.TestCase):
+class PIAPipelineFastTests(IPAdapterTesterMixin, PipelineTesterMixin, unittest.TestCase):
     pipeline_class = PIAPipeline
     params = frozenset(
         [
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion.py b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
index d8c3710310ce..57671bbdcc9a 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion.py
@@ -23,7 +23,11 @@
 import numpy as np
 import torch
 from huggingface_hub import hf_hub_download
-from transformers import CLIPTextConfig, CLIPTextModel, CLIPTokenizer
+from transformers import (
+    CLIPTextConfig,
+    CLIPTextModel,
+    CLIPTokenizer,
+)
 
 from diffusers import (
     AutoencoderKL,
@@ -60,7 +64,12 @@
     TEXT_TO_IMAGE_IMAGE_PARAMS,
     TEXT_TO_IMAGE_PARAMS,
 )
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)
 
 
 enable_full_determinism()
@@ -100,7 +109,11 @@ def _test_stable_diffusion_compile(in_queue, out_queue, timeout):
 
 
 class StableDiffusionPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionPipeline
     params = TEXT_TO_IMAGE_PARAMS
@@ -177,7 +190,7 @@ def get_dummy_inputs(self, device, seed=0):
             "generator": generator,
             "num_inference_steps": 2,
             "guidance_scale": 6.0,
-            "output_type": "numpy",
+            "output_type": "np",
         }
         return inputs
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
index 2259143a59d2..4483fd8e0b8c 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_img2img.py
@@ -55,7 +55,12 @@
     TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
     TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
 )
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)
 
 
 enable_full_determinism()
@@ -94,7 +99,11 @@ def _test_img2img_compile(in_queue, out_queue, timeout):
 
 
 class StableDiffusionImg2ImgPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
index 522b62892fc9..d8cdcc867dfe 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_inpaint.py
@@ -57,7 +57,12 @@
     TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
     TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
 )
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)
 
 
 enable_full_determinism()
@@ -98,7 +103,11 @@ def _test_inpaint_compile(in_queue, out_queue, timeout):
 
 
 class StableDiffusionInpaintPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineKarrasSchedulerTesterMixin, PipelineTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionInpaintPipeline
     params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
diff --git a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
index 4262133bbe34..0986f02deeaa 100644
--- a/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
+++ b/tests/pipelines/stable_diffusion/test_stable_diffusion_instruction_pix2pix.py
@@ -47,7 +47,11 @@
     TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
     TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
 )
-from ..test_pipelines_common import PipelineKarrasSchedulerTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import (
+    PipelineKarrasSchedulerTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+)
 
 
 enable_full_determinism()
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
index 16ef7e3009bd..a27614a2c717 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl.py
@@ -49,14 +49,23 @@
     TEXT_TO_IMAGE_IMAGE_PARAMS,
     TEXT_TO_IMAGE_PARAMS,
 )
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+    SDXLOptionalComponentsTesterMixin,
+)
 
 
 enable_full_determinism()
 
 
 class StableDiffusionXLPipelineFastTests(
-    PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+    SDXLOptionalComponentsTesterMixin,
+    unittest.TestCase,
 ):
     pipeline_class = StableDiffusionXLPipeline
     params = TEXT_TO_IMAGE_PARAMS
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
index b9827df2f98e..0bcffeb078b8 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_adapter.py
@@ -44,6 +44,7 @@
 from ..pipeline_params import TEXT_GUIDED_IMAGE_VARIATION_BATCH_PARAMS, TEXT_GUIDED_IMAGE_VARIATION_PARAMS
 from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
     PipelineTesterMixin,
     SDXLOptionalComponentsTesterMixin,
     assert_mean_pixel_difference,
@@ -54,7 +55,7 @@ class StableDiffusionXLAdapterPipelineFastTests(
-    PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
+    IPAdapterTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin, unittest.TestCase
 ):
     pipeline_class = StableDiffusionXLAdapterPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
index 6474d02c194d..3a0229ac23ca 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_img2img.py
@@ -54,13 +54,20 @@
     TEXT_GUIDED_IMAGE_VARIATION_PARAMS,
     TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
 )
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin, SDXLOptionalComponentsTesterMixin
+from ..test_pipelines_common import (
+    IPAdapterTesterMixin,
+    PipelineLatentTesterMixin,
+    PipelineTesterMixin,
+    SDXLOptionalComponentsTesterMixin,
+)
 
 
 enable_full_determinism()
 
 
-class StableDiffusionXLImg2ImgPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionXLImg2ImgPipelineFastTests(
+    IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
     pipeline_class = StableDiffusionXLImg2ImgPipeline
     params = TEXT_GUIDED_IMAGE_VARIATION_PARAMS - {"height", "width"}
     required_optional_params = PipelineTesterMixin.required_optional_params - {"latents"}
diff --git a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
index c0a20df5020f..11c711e82e8b 100644
--- a/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
+++ b/tests/pipelines/stable_diffusion_xl/test_stable_diffusion_xl_inpaint.py
@@ -48,13 +48,15 @@
     TEXT_GUIDED_IMAGE_INPAINTING_PARAMS,
     TEXT_TO_IMAGE_CALLBACK_CFG_PARAMS,
 )
-from ..test_pipelines_common import PipelineLatentTesterMixin, PipelineTesterMixin
+from ..test_pipelines_common import IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin
 
 
 enable_full_determinism()
 
 
-class StableDiffusionXLInpaintPipelineFastTests(PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase):
+class StableDiffusionXLInpaintPipelineFastTests(
+    IPAdapterTesterMixin, PipelineLatentTesterMixin, PipelineTesterMixin, unittest.TestCase
+):
     pipeline_class = StableDiffusionXLInpaintPipeline
     params = TEXT_GUIDED_IMAGE_INPAINTING_PARAMS
     batch_params = TEXT_GUIDED_IMAGE_INPAINTING_BATCH_PARAMS
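All of the test files above opt into a new `IPAdapterTesterMixin` from `test_pipelines_common`, shown in full in the final diff below. Its core contract: with adapter scale 0 the pipeline must reproduce baseline inference within `expected_max_diff`, and with a large scale (42.0) the output must move by more than 1e-2. A condensed sketch of that invariant, with numpy arrays standing in for pipeline outputs:

```py
import numpy as np


def check_ip_adapter_effect(baseline, out_scale_zero, out_scaled, expected_max_diff=1e-4):
    # scale=0 must be a no-op; a large scale must visibly change the output.
    assert np.abs(out_scale_zero - baseline).max() < expected_max_diff
    assert np.abs(out_scaled - baseline).max() > 1e-2


baseline = np.zeros((1, 8, 8, 3))
check_ip_adapter_effect(baseline, baseline + 1e-6, baseline + 0.5)
```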
diff --git a/tests/pipelines/test_pipelines_common.py b/tests/pipelines/test_pipelines_common.py
index 7f51847caf07..3c439d9c7042 100644
--- a/tests/pipelines/test_pipelines_common.py
+++ b/tests/pipelines/test_pipelines_common.py
@@ -8,7 +8,7 @@
 import tempfile
 import unittest
 import uuid
-from typing import Callable, Union
+from typing import Any, Callable, Dict, Union
 
 import numpy as np
 import PIL.Image
@@ -29,6 +29,7 @@
     UNet2DConditionModel,
 )
 from diffusers.image_processor import VaeImageProcessor
+from diffusers.loaders import IPAdapterMixin
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.utils import logging
 from diffusers.utils.import_utils import is_accelerate_available, is_accelerate_version, is_xformers_available
@@ -44,6 +45,7 @@
     get_autoencoder_tiny_config,
     get_consistency_vae_config,
 )
+from ..models.unets.test_models_unet_2d_condition import create_ip_adapter_state_dict
 from ..others.test_utils import TOKEN, USER, is_staging_test
@@ -59,6 +61,118 @@ def check_same_shape(tensor_list):
     return all(shape == shapes[0] for shape in shapes[1:])
 
 
+class IPAdapterTesterMixin:
+    """
+    This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.
+    It provides a set of common tests for pipelines that support IP Adapters.
+    """
+
+    def test_pipeline_signature(self):
+        parameters = inspect.signature(self.pipeline_class.__call__).parameters
+
+        assert issubclass(self.pipeline_class, IPAdapterMixin)
+        self.assertIn(
+            "ip_adapter_image",
+            parameters,
+            "`ip_adapter_image` argument must be supported by the `__call__` method",
+        )
+        self.assertIn(
+            "ip_adapter_image_embeds",
+            parameters,
+            "`ip_adapter_image_embeds` argument must be supported by the `__call__` method",
+        )
+
+    def _get_dummy_image_embeds(self, cross_attention_dim: int = 32):
+        return torch.randn((2, 1, cross_attention_dim), device=torch_device)
+
+    def _modify_inputs_for_ip_adapter_test(self, inputs: Dict[str, Any]):
+        parameters = inspect.signature(self.pipeline_class.__call__).parameters
+        if "image" in parameters.keys() and "strength" in parameters.keys():
+            inputs["num_inference_steps"] = 4
+
+        inputs["output_type"] = "np"
+        inputs["return_dict"] = False
+        return inputs
+
+    def test_ip_adapter_single(self, expected_max_diff: float = 1e-4):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
+
+        # forward pass without ip adapter
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        output_without_adapter = pipe(**inputs)[0]
+
+        adapter_state_dict = create_ip_adapter_state_dict(pipe.unet)
+        pipe.unet._load_ip_adapter_weights(adapter_state_dict)
+
+        # forward pass with single ip adapter, but scale=0 which should have no effect
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+        pipe.set_ip_adapter_scale(0.0)
+        output_without_adapter_scale = pipe(**inputs)[0]
+
+        # forward pass with single ip adapter, but with scale of adapter weights
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)]
+        pipe.set_ip_adapter_scale(42.0)
+        output_with_adapter_scale = pipe(**inputs)[0]
+
+        max_diff_without_adapter_scale = np.abs(output_without_adapter_scale - output_without_adapter).max()
+        max_diff_with_adapter_scale = np.abs(output_with_adapter_scale - output_without_adapter).max()
+
+        self.assertLess(
+            max_diff_without_adapter_scale,
+            expected_max_diff,
+            "Output without ip-adapter must be same as normal inference",
+        )
+        self.assertGreater(
+            max_diff_with_adapter_scale, 1e-2, "Output with ip-adapter must be different from normal inference"
+        )
+
+    def test_ip_adapter_multi(self, expected_max_diff: float = 1e-4):
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components).to(torch_device)
+        pipe.set_progress_bar_config(disable=None)
+        cross_attention_dim = pipe.unet.config.get("cross_attention_dim", 32)
+
+        # forward pass without ip adapter
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        output_without_adapter = pipe(**inputs)[0]
+
+        adapter_state_dict_1 = create_ip_adapter_state_dict(pipe.unet)
+        adapter_state_dict_2 = create_ip_adapter_state_dict(pipe.unet)
+        pipe.unet._load_ip_adapter_weights([adapter_state_dict_1, adapter_state_dict_2])
+
+        # forward pass with multi ip adapter, but scale=0 which should have no effect
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+        pipe.set_ip_adapter_scale([0.0, 0.0])
+        output_without_multi_adapter_scale = pipe(**inputs)[0]
+
+        # forward pass with multi ip adapter, but with scale of adapter weights
+        inputs = self._modify_inputs_for_ip_adapter_test(self.get_dummy_inputs(torch_device))
+        inputs["ip_adapter_image_embeds"] = [self._get_dummy_image_embeds(cross_attention_dim)] * 2
+        pipe.set_ip_adapter_scale([42.0, 42.0])
+        output_with_multi_adapter_scale = pipe(**inputs)[0]
+
+        max_diff_without_multi_adapter_scale = np.abs(
+            output_without_multi_adapter_scale - output_without_adapter
+        ).max()
+        max_diff_with_multi_adapter_scale = np.abs(output_with_multi_adapter_scale - output_without_adapter).max()
+        self.assertLess(
+            max_diff_without_multi_adapter_scale,
+            expected_max_diff,
+            "Output without multi-ip-adapter must be same as normal inference",
+        )
+        self.assertGreater(
+            max_diff_with_multi_adapter_scale,
+            1e-2,
+            "Output with multi-ip-adapter scale must be different from normal inference",
+        )
+
+
 class PipelineLatentTesterMixin:
     """
     This mixin is designed to be used with PipelineTesterMixin and unittest.TestCase classes.